# RotoWire

In [None]:
import json
import time
import datetime
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from more_itertools import collapse
from torch import optim
from tqdm import tqdm

from evals import BleuScore
from utils import *
from beam_search import beam_search
from early_stopping import EarlyStopping

# Root directory holding the RotoWire data/ and results/ sub-folders.
root_dir = "./rotowire"


## data preprocessing

In [None]:
# data preprocessing
# reference: https://github.com/KaijuML/data-to-text-hierarchical/blob/master/data/make-dataset.py

DELIMITER = "￨"
ENTITY_SIZE = 24  # at most 24 elements in an entity

# box_score keys
bs_keys = ['START_POSITION', 'MIN', 'PTS', 'FGM', 'FGA', 'FG_PCT', 'FG3M',
           'FG3A', 'FG3_PCT', 'FTM', 'FTA', 'FT_PCT', 'OREB', 'DREB', 'REB',
           'AST', 'TO', 'STL', 'BLK', 'PF', 'FIRST_NAME', 'SECOND_NAME']
# line_score keys
ls_keys = ['PTS_QTR1', 'PTS_QTR2', 'PTS_QTR3', 'PTS_QTR4', 'PTS', 'FG_PCT',
           'FG3_PCT', 'FT_PCT', 'REB', 'AST', 'TOV', 'WINS', 'LOSSES',
           'CITY', 'NAME']
ls_keys = [f'TEAM-{key}' for key in ls_keys]


def _build_team(entry, line_key, is_home_str):
    """Build the padded entity for one team's line score.

    Args:
        entry: one RotoWire game dict.
        line_key: ``'home_line'`` or ``'vis_line'`` - which team's line score
            to read from ``entry``.
        is_home_str: value of the trailing ``IS_HOME`` record (``'yes'``/``'no'``).

    Returns:
        A list of ``value<DELIMITER>field`` strings: an ``<ent>`` header, one
        record per ``ls_keys`` stat (spaces in values replaced by ``_``), the
        ``IS_HOME`` record, then ``<blank>`` padding up to ENTITY_SIZE records.
    """
    line = entry[line_key]
    records = [DELIMITER.join(['<ent>', '<ent>'])]
    records.extend(DELIMITER.join([line[key].replace(' ', '_'), key])
                   for key in ls_keys)

    # Contrary to previous work, IS_HOME is now a unique token at the end
    records.append(DELIMITER.join([is_home_str, 'IS_HOME']))

    # We pad the entity to size ENTITY_SIZE with OpenNMT <blank> token
    records.extend([DELIMITER.join(['<blank>', '<blank>'])]
                   * (ENTITY_SIZE - len(records)))
    return records


def build_home(entry):
    """The team who hosted the game (entity with ``IS_HOME`` = ``yes``)."""
    return _build_team(entry, 'home_line', 'yes')


def build_vis(entry):
    """The visiting team (entity with ``IS_HOME`` = ``no``)."""
    return _build_team(entry, 'vis_line', 'no')


def get_player_idxs(entry):
    """Split the box-score player indices of a game into home and visiting.

    Args:
        entry: one RotoWire game dict. Reads ``home_city``, ``vis_city`` and
            ``box_score['PTS']`` / ``box_score['TEAM_CITY']``, which are
            indexed by stringified player index (``'0'``, ``'1'``, ...).

    Returns:
        ``(home_players, vis_players)``: two lists of player-index strings.

    Raises:
        ValueError: if both teams share a city other than Los Angeles. In that
            case the city cannot be used to tell the teams apart.
    """
    if entry['home_city'] == entry['vis_city']:
        # In 4 instances the Clippers play against the Lakers. Both are from
        # LA, so the city cannot separate them. The box score always lists
        # the visiting team first and the home team second, so we simply
        # divide the players in half. In all 4 games there are 26 players:
        # visiting 0-12, home 13-25.
        if entry['home_city'] != 'Los Angeles':
            # explicit raise instead of `assert`, which is stripped under -O
            raise ValueError(
                f"both teams are from {entry['home_city']!r}; only the "
                "Los Angeles derby can be split")
        nplayers = len(entry['box_score']['PTS'])
        half = nplayers // 2
        return ([str(idx) for idx in range(half, nplayers)],
                [str(idx) for idx in range(half)])

    home_city = entry['home_city']
    team_city = entry['box_score']['TEAM_CITY']
    home_players, vis_players = [], []
    for i in range(len(entry['box_score']['PTS'])):
        key = str(i)
        if team_city[key] == home_city:
            home_players.append(key)
        else:
            vis_players.append(key)
    return home_players, vis_players


def box_preprocess(entry, remove_na=True):
    """Linearise one game into a flat list of ``value<DELIMITER>field`` records.

    Players from the visiting team come first, then players from the home
    team. Each player becomes one entity, built in this order:

    - an ``<ent>`` header;
    - their ``bs_keys`` stats, with spaces in values replaced by ``_`` and
      ``N/A`` values skipped when *remove_na* is true;
    - an ``IS_HOME`` flag;
    - ``<blank>`` padding up to ENTITY_SIZE records.

    The home-team and visiting-team line-score entities are appended last.
    """
    home_players, vis_players = get_player_idxs(entry)
    box_score = entry['box_score']
    blank = DELIMITER.join(['<blank>', '<blank>'])

    entities = []
    for flag, team_players in (('no', vis_players), ('yes', home_players)):
        for player_idx in team_players:
            player = [DELIMITER.join(['<ent>', '<ent>'])]
            for key in bs_keys:
                val = box_score[key][player_idx]
                if not (remove_na and val == 'N/A'):
                    player.append(DELIMITER.join([val.replace(' ', '_'), key]))
            player.append(DELIMITER.join([flag, 'IS_HOME']))
            player += [blank] * (ENTITY_SIZE - len(player))
            entities.append(player)

    entities.append(build_home(entry))
    entities.append(build_vis(entry))
    # every entity is a flat list of strings, so one level of flattening suffices
    return [record for entity in entities for record in entity]


def clean_summary(summary, tokens):
    """Join the summary tokens and underscore multi-word table values in it.

    This slightly helps the copy mechanism. When the source sequence was built,
    spaces in multi-word values (in practice city names like "Los Angeles")
    were replaced by underscores. The same substitution is applied to the
    summary, so the copy mechanism can see those words as copies.

    Args:
        summary: list of summary tokens.
        tokens: source records in ``value<DELIMITER>field`` form.

    Returns:
        The summary as a single space-joined string.
    """
    text = ' '.join(summary)
    for record in tokens:
        value, *_ = record.split(DELIMITER)
        if '_' not in value:
            continue
        text = text.replace(value.replace('_', ' '), value)
    return text


## build vocab

In [None]:
# Build the field/word vocabularies and inspect the summary-length distribution.
# NOTE(review): the vocabulary is built from train, valid AND test, so words
# that occur only in valid/test get their own ids (they are never mapped to
# <unk> and their embeddings are never trained). Consider building it from the
# train split only.
inputs, summaries = [], []
for setname in ['train', 'valid', 'test']:
    filename = f'{root_dir}/data/{setname}.json'
    with open(filename, encoding='utf8', mode='r') as f:
        data = json.load(f)

    for entry in data:
        # `records`, not `input`, so the builtin is not shadowed notebook-wide
        records = box_preprocess(entry)
        inputs.append(' '.join(records))
        summaries.append(clean_summary(entry['summary'], records))

keys = ls_keys + bs_keys
keys.extend(['<blank>', '<ent>', 'IS_HOME'])
keys.sort()

# field ids start at 1; 0 is reserved for padding
field2ind = {key: i for i, key in enumerate(keys, start=1)}
field2ind['<pad>'] = 0

# special tokens first so their ids are 0..3 (PAD/UNK/SOS/EOS)
word2ind = {'<pad>': 0, '<unk>': 1, '<sos>': 2, '<eos>': 3}

# dataset distribution
max_field, max_len = 0, 0
# pre-filled so the histogram covers every length in 0..849
text_len_dict = {length: 0 for length in range(0, 850)}

for line in inputs:
    items = line.split()
    max_field = max(max_field, len(items))
    for item in items:
        # ids are contiguous, so the next free id is the current size
        word2ind.setdefault(item.split(DELIMITER)[0], len(word2ind))

for summary in summaries:
    _words = summary.split()
    max_len = max(max_len, len(_words))
    # .get: a summary of 850+ tokens must not raise KeyError
    text_len_dict[len(_words)] = text_len_dict.get(len(_words), 0) + 1
    for w in _words:
        word2ind.setdefault(w, len(word2ind))

ind2word = {value: key for key, value in word2ind.items()}
ind2field = {value: key for key, value in field2ind.items()}
vocab = {'word2ind': word2ind, 'field2ind': field2ind}
with open(f'{root_dir}/data/vocab.json', 'w', encoding='utf8') as f:
    json.dump(vocab, f)

print(f'max_cnt of field: {max_field}')  # 768
print(f'max_len of summary: {max_len}')  # 813

# get distribution of the length of summary

plt.bar(list(text_len_dict.keys()), list(text_len_dict.values()))
plt.xticks(list(range(200, 850, 50)))
plt.title('distribution of summary length')
plt.show()


## data processor

In [None]:
# Vocabulary sizes (values observed for the vocab built above).
word_vocab_size = len(word2ind)  # 11473
field_vocab_size = len(field2ind)  # 41

# Fixed tensor sizes used for padding/truncation. max_field has to be at least
# the largest record count (768 printed above). Summaries longer than
# max_len - 1 tokens are truncated.
max_field, max_len = 770, 700

# Special-token ids; they match the first four entries of word2ind.
PAD_TOKEN, UNK_TOKEN, SOS_TOKEN, EOS_TOKEN = range(4)


def _load_data(path):
    """Read a RotoWire JSON split and linearise every game.

    Returns:
        A list of ``{'input': <space-joined records>, 'summary': <cleaned text>}``
        dicts, one per game, in file order.
    """
    with open(path, 'r', encoding='utf8') as f:
        data = json.load(f)

    samples = []
    for entry in data:
        records = box_preprocess(entry)
        samples.append({
            'input': ' '.join(records),
            'summary': clean_summary(entry['summary'], records),
        })
    return samples


def _process_data(data):
    """Convert samples from ``_load_data`` into fixed-size id arrays.

    Args:
        data: list of ``{'input': ..., 'summary': ...}`` samples.

    Returns:
        seq_info: int64 array of shape (len(data), max_field, 2).
            ``[..., 0]`` holds field ids and ``[..., 1]`` value ids
            (UNK_TOKEN for unknown values). Positions past the last record
            stay 0 (PAD).
        seq_target: int64 array of shape (len(data), max_len). It holds, in
            order: SOS_TOKEN; the summary token ids (UNK_TOKEN when unknown);
            EOS_TOKEN right after the last token when there is room; then PAD.
            Summaries longer than max_len - 1 tokens are truncated and get
            no EOS.

    Raises:
        ValueError: if a sample has more than ``max_field`` records.
    """
    seq_info = np.zeros((len(data), max_field, 2), dtype=np.int64)  # PAD
    seq_target = np.zeros((len(data), max_len), dtype=np.int64)  # PAD
    for data_index, data_item in enumerate(data):
        input_items = data_item['input'].split()
        if len(input_items) > max_field:
            raise ValueError(
                f'sample {data_index} has {len(input_items)} records, '
                f'more than max_field={max_field}')
        for idx, input_item in enumerate(input_items):
            v, f = input_item.split(DELIMITER)
            seq_info[data_index, idx, 0] = field2ind[f]
            seq_info[data_index, idx, 1] = word2ind.get(v, UNK_TOKEN)

        # position 0 is SOS, so at most max_len - 1 summary tokens fit after it
        tokens_text = data_item['summary'].strip().split()[:max_len - 1]
        seq_target[data_index, 0] = SOS_TOKEN
        for idx, token in enumerate(tokens_text, start=1):
            seq_target[data_index, idx] = word2ind.get(token, UNK_TOKEN)
        # Terminate with EOS so the decoder can learn when to stop. Greedy
        # decoding in Table2Text and translate() both stop on EOS_TOKEN.
        eos_pos = len(tokens_text) + 1
        if eos_pos < max_len:
            seq_target[data_index, eos_pos] = EOS_TOKEN
    return seq_info, seq_target


def process_one_data(idx_data, tag: str):
    """Build the (max_field, 2) field/value id array for one test or dev sample.

    Args:
        idx_data: index into ``test_data`` or ``dev_data``.
        tag: ``'test'`` or ``'dev'``.

    Returns:
        Float array of shape (max_field, 2). Column 0 holds field ids and
        column 1 value ids (UNK_TOKEN when unknown); rows past the last
        record are 0 (PAD).

    Raises:
        ValueError: for any other ``tag``.
    """
    if tag == 'test':
        samples = test_data
    elif tag == 'dev':
        samples = dev_data
    else:
        # previously the exception was built but never raised
        raise ValueError(f'illegal tag: {tag}')

    seq_info = np.zeros((max_field, 2))  # PAD
    for idx, input_item in enumerate(samples[idx_data]['input'].split()):
        v, f = input_item.split(DELIMITER)
        seq_info[idx, 0] = field2ind[f]
        seq_info[idx, 1] = word2ind.get(v, UNK_TOKEN)
    return seq_info


def get_data_loader(seq_info, seq_target, batch_size, shuffle, device):
    """Wrap the id arrays in a DataLoader of ``(input, target)`` long tensors.

    Both arrays are copied to *device* up front, so every batch is already
    on that device.
    """
    dataset = Data.TensorDataset(
        *(torch.tensor(array, dtype=torch.long, device=device)
          for array in (seq_info, seq_target)))
    return Data.DataLoader(dataset=dataset,
                           batch_size=batch_size,
                           shuffle=shuffle)


def get_refs():
    """Return the gold summaries of the test split, in ``test_data`` order."""
    return [sample['summary'] for sample in test_data]


def translate(list_seq):
    """Turn a sequence of token ids into text.

    Decoding stops at the first EOS; UNK, SOS and PAD ids are dropped.
    """
    words = []
    for token_id in list_seq:
        if token_id == EOS_TOKEN:
            break
        if token_id in (UNK_TOKEN, SOS_TOKEN, PAD_TOKEN):
            continue
        words.append(ind2word[token_id])
    return ' '.join(words)


def translate_with_copy(list_seq, attn_score, data_idx, tag='test'):
    """Turn token ids into text, replacing each UNK with a copied source value.

    For an UNK step, the record with the highest attention weight at that step
    is looked up in the sample's linearised table and its value is emitted.
    If the attended position lies past the real records, ``<unk>`` is emitted
    instead. SOS and PAD are dropped, and decoding stops at EOS.

    Args:
        list_seq: decoded token ids.
        attn_score: per-step attention weights aligned with ``list_seq``;
            iteration stops at the shorter of the two.
        data_idx: index of the sample whose table is copied from.
        tag: ``'test'`` (default, matching existing callers) or ``'dev'``,
            i.e. which split ``data_idx`` refers to.

    Raises:
        ValueError: for an unknown ``tag``.
    """
    if tag == 'test':
        samples = test_data
    elif tag == 'dev':
        samples = dev_data
    else:
        raise ValueError(f'illegal tag: {tag}')
    # split the linearised table once instead of twice per UNK token
    source_items = samples[data_idx]['input'].split()

    list_token = []
    for index, attn in zip(list_seq, attn_score):
        if index == UNK_TOKEN:
            attn_max = attn.max(0)[1].item()
            # attention also spans the PAD positions past the real records
            if attn_max < len(source_items):
                list_token.append(source_items[attn_max].split(DELIMITER)[0])
            else:
                list_token.append('<unk>')
        elif index in (SOS_TOKEN, PAD_TOKEN):
            continue
        elif index == EOS_TOKEN:
            break
        else:
            list_token.append(ind2word[index])
    return ' '.join(list_token)


# Pre-process every split exactly once. The dev samples are kept both as id
# arrays (for the validation loss) and as raw samples (for decoding).
seq_info_train, seq_target_train = _process_data(
    _load_data(f'{root_dir}/data/train.json'))
dev_data = _load_data(f'{root_dir}/data/valid.json')
seq_info_dev, seq_target_dev = _process_data(dev_data)
test_data = _load_data(f'{root_dir}/data/test.json')


## model

In [None]:
class EncoderAttn(nn.Module):
    """Table encoder: embeds (field, value) records, runs a BiLSTM over them and
    applies a self-attentive content-selection gate.

    Args:
        field_vocab_size: number of field ids (rows of ``field_embedding``).
        word_vocab_size: number of value ids (rows of ``value_embedding``).
        embed_dim_field: field embedding size.
        embed_dim_word: value embedding size.
        hidden_dim: hidden size of each LSTM direction and of all projected
            outputs.
        dropout_p: dropout probability applied to the concatenated embeddings.
    """
    def __init__(self, field_vocab_size, word_vocab_size, embed_dim_field, embed_dim_word, hidden_dim, dropout_p):
        super(EncoderAttn, self).__init__()
        self.embed_dim_field = embed_dim_field
        self.embed_dim_word = embed_dim_word
        self.hidden_dim = hidden_dim

        self.field_embedding = nn.Embedding(field_vocab_size, embed_dim_field)
        self.value_embedding = nn.Embedding(word_vocab_size, embed_dim_word)
        self.dropout = nn.Dropout(p=dropout_p)
        self.bi_lstm = nn.LSTM(input_size=embed_dim_field + embed_dim_word,
                               hidden_size=hidden_dim,
                               batch_first=True,
                               bias=True,
                               bidirectional=True)
        # Projects concatenated forward/backward states (2 * hidden) back to
        # hidden. Shared by the per-record outputs and by the final h_n / c_n.
        self.fc_lstm = nn.Linear(hidden_dim * 2, hidden_dim)
        # Bilinear matrix W of the self-attention score R W R^T.
        # NOTE(review): initialised with unscaled torch.randn. With
        # layer-normed inputs and hidden_dim=512 the scores are likely very
        # large, which may saturate the softmax. Consider a scaled init;
        # confirm before changing.
        self.weight_reply = nn.Parameter(torch.randn(
            hidden_dim, hidden_dim), requires_grad=True)
        self.fc_gate = nn.Linear(hidden_dim * 2, hidden_dim)
        self.layernorm_gate = nn.LayerNorm(hidden_dim, elementwise_affine=True)
        self.layernorm_lstm = nn.LayerNorm(hidden_dim, elementwise_affine=True)

    def _content_selection_gate(self, reply_encoder):
        """Gate each record representation by its self-attention context.

        The computation is:

        - alpha = softmax over records of (R W R^T)
        - context = alpha @ R
        - gate = sigmoid(LayerNorm(fc_gate([R; context])))

        It returns ``gate * R``. No mask is applied, so padded record
        positions also take part in the attention.
        """
        # reply_encoder(batch, record_num, hid_dim), reply_post(batch, hid_dim, record_num)
        reply_post = reply_encoder.permute(0, 2, 1)
        alpha = F.softmax(torch.matmul(
            torch.matmul(reply_encoder, self.weight_reply), reply_post), dim=2)
        reply_c = torch.bmm(alpha, reply_encoder)
        attn_gate = torch.sigmoid(self.layernorm_gate(
            self.fc_gate(torch.cat((reply_encoder, reply_c), dim=2))))
        reply_cs = attn_gate * reply_encoder
        return reply_cs

    def forward(self, encoder_input):
        """Encode a batch of records.

        ``encoder_input[:, :, 0]`` holds field ids and ``encoder_input[:, :, 1]``
        holds value ids.

        Returns the gated per-record representations and an ``(h_n, c_n)``
        pair projected to ``hidden_dim``; Table2Text uses that pair as the
        decoder's initial state.
        """
        field_embed = self.field_embedding(encoder_input[:, :, 0])
        value_embed = self.value_embedding(encoder_input[:, :, 1])
        field_value = self.dropout(
            torch.cat((field_embed, value_embed), dim=2))

        encoder_output, (h_n, c_n) = self.bi_lstm(field_value)
        encoder_output = self.layernorm_lstm((self.fc_lstm(encoder_output)))
        # h_n[-2] / h_n[-1]: final forward / backward states of the last layer
        h_n = self.layernorm_lstm(self.fc_lstm(
            torch.cat((h_n[-2], h_n[-1]), dim=1)))
        c_n = self.layernorm_lstm(self.fc_lstm(
            torch.cat((c_n[-2], c_n[-1]), dim=1)))
        content_selection = self._content_selection_gate(encoder_output)

        return content_selection, (h_n, c_n)


class DecoderAttn(nn.Module):
    """Single-step LSTMCell decoder with scaled dot-product attention over the
    encoder's content-selected records.

    Args:
        word_vocab_size: output vocabulary size (also the input embedding size).
        embed_dim: word embedding size.
        hidden_dim: LSTMCell hidden size. It must equal the last dimension of
            ``content_selection`` for the attention product to broadcast.
        dropout_p: dropout probability applied to the input embedding.
    """
    def __init__(self, word_vocab_size, embed_dim, hidden_dim, dropout_p):
        super(DecoderAttn, self).__init__()
        # read by Table2Text to size its training output tensor
        self.output_dim = word_vocab_size
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(word_vocab_size, embed_dim)
        self.dropout = nn.Dropout(p=dropout_p)
        # cell input = [word embedding; attention context]
        self.lstm = nn.LSTMCell(embed_dim + hidden_dim, hidden_dim, bias=True)
        self.fc_out = nn.Linear(hidden_dim, word_vocab_size)

    @staticmethod
    def _weighted_encoder_rep(decoder_hidden, content_selection):
        """Attend over records with the decoder hidden state.

        The score of each record is <h, r> / sqrt(d): the elementwise product
        is summed over the last dim, then a softmax is taken over records
        (dim 1). No padding mask is applied.

        Returns the attention-weighted sum of records and the weights.
        """
        energy = (decoder_hidden.unsqueeze(1) * content_selection) / \
            math.sqrt(decoder_hidden.size(-1))
        attn_score = F.softmax(torch.sum(energy, dim=2), dim=1)
        attn_with_selector = attn_score.unsqueeze(dim=2) * content_selection
        return torch.sum(attn_with_selector, dim=1), attn_score

    def forward(self, decoder_input, decoder_hidden, content_selection):
        """Run one decoding step.

        Args:
            decoder_input: previous token ids (one per batch element).
            decoder_hidden: ``(h, c)`` state from the previous step or encoder.
            content_selection: encoder record representations.

        Returns:
            ``(log_probs, (h_n, c_n), attn_score)``. ``log_probs`` is the
            log-softmax over the vocabulary. The attention is computed from
            ``h`` *before* the LSTM update.
        """
        embed = self.dropout(self.embedding(decoder_input))
        attn_vector, attn_score = self._weighted_encoder_rep(
            decoder_hidden[0], content_selection)
        emb_attn_combine = torch.cat((embed, attn_vector), dim=1)
        h_n, c_n = self.lstm(emb_attn_combine, decoder_hidden)
        decoder_output = F.log_softmax(self.fc_out(h_n), dim=1)

        return decoder_output, (h_n, c_n), attn_score


class Table2Text(nn.Module):
    """Encoder-decoder wrapper for teacher-forced training or inference.

    Args:
        encoder: module mapping ``seq_input`` to
            ``(content_selection, (h, c))``.
        decoder: single-step module exposing ``output_dim``. It is called as
            ``decoder(input_ids, hidden, content_selection)`` and returns
            ``(log_probs, hidden, attn_score)``.
        beam_width: 1 selects greedy decoding; larger values use ``beam_search``.
        max_len: number of decoding steps.
        max_field: number of source record positions (width of the attention map).
    """

    def __init__(self, encoder, decoder, beam_width, max_len, max_field):
        super(Table2Text, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.beam_width = beam_width
        self.max_len = max_len
        self.max_field = max_field

    def forward(self, seq_input, seq_target, train_mode):
        """Run the model.

        In training mode, ``seq_target[:, t]`` is fed at every step (teacher
        forcing) and a ``(batch, max_len, output_dim)`` tensor of decoder
        outputs is returned.

        In inference mode, ``seq_target`` is the first input (SOS) and the
        method returns ``(seq_output, attn_map)``. For greedy decoding these
        are ``(max_len,)`` token ids and one attention row per step before EOS.
        Greedy decoding assumes a batch of one (it calls ``.item()``).
        """
        # encoder
        content_selection, decoder_hidden = self.encoder(seq_input)
        # allocate outputs on the model's device instead of hard-coding CUDA,
        # so the CPU fallback device also works
        device = content_selection.device

        if train_mode:
            batch_size = seq_target.size(0)
            seq_output = torch.zeros(
                (batch_size, self.max_len, self.decoder.output_dim), device=device)
            for step in range(self.max_len):
                decoder_output, decoder_hidden, _ = self.decoder(
                    seq_target[:, step], decoder_hidden, content_selection)
                seq_output[:, step, :] = decoder_output
            return seq_output

        if self.beam_width == 1:  # beam search with beam_width=1 equals to greedy search
            attn_map = torch.zeros((self.max_len, self.max_field), device=device)
            seq_output = torch.zeros(self.max_len, device=device)
            decoder_input = seq_target  # first token: SOS_TOKEN
            for step in range(self.max_len):
                decoder_output, decoder_hidden, attn_score = self.decoder(
                    decoder_input, decoder_hidden, content_selection)
                decoder_input = decoder_output.max(1)[1]
                seq_output[step] = decoder_input.squeeze()
                attn_map[step] = attn_score.squeeze()
                if decoder_input.item() == 3:  # EOS_TOKEN
                    attn_map = attn_map[:step]
                    break
        else:  # beam search
            seq_output, attn_map = beam_search(max_len=self.max_len,
                                               max_field=self.max_field,
                                               beam_width=self.beam_width,
                                               decoder=self.decoder,
                                               decoder_input=seq_target,
                                               decoder_hidden=decoder_hidden,
                                               content_selection=content_selection)
        return seq_output, attn_map


## utils

In [None]:
def save_checkpoint(experiment_time, model, optimizer):
    """Save the model and optimizer state dicts as ``<experiment_time>.pth``.

    The file is written to ``<root_dir>/results/checkpoints``.
    ``check_file_exist`` (from utils) is called on that directory first.
    """
    checkpoints_dir = f'{root_dir}/results/checkpoints'
    check_file_exist(checkpoints_dir)
    state = dict(model=model.state_dict(), optimizer=optimizer.state_dict())
    torch.save(state, checkpoints_dir + '/' + experiment_time + '.pth')


def load_checkpoint(latest, file_name=None):
    """Load a checkpoint from ``<root_dir>/results/checkpoints``.

    Args:
        latest: if true, load the most recently modified ``.pth`` file and
            ignore ``file_name``.
        file_name: checkpoint file name, used when ``latest`` is false.

    Returns:
        ``(checkpoint_dict, file_name)``.

    Raises:
        FileNotFoundError: if ``latest`` is true and no ``.pth`` file exists.
        ValueError: if ``latest`` is false and ``file_name`` is None.
    """
    # os is not among this file's explicit imports; import it here
    import os

    checkpoints_dir = f'{root_dir}/results/checkpoints'
    if latest:
        candidates = [fn for fn in os.listdir(checkpoints_dir)
                      if fn.endswith('.pth')]
        if not candidates:
            raise FileNotFoundError(
                f'no .pth checkpoint found in {checkpoints_dir}')
        candidates.sort(key=lambda fn: os.path.getmtime(
            os.path.join(checkpoints_dir, fn)))
        file_name = candidates[-1]
    elif file_name is None:
        raise ValueError('checkpoint_path cannot be empty!')

    # without CUDA, remap tensors saved on a GPU so loading does not fail
    map_location = None if torch.cuda.is_available() else 'cpu'
    checkpoint = torch.load(os.path.join(checkpoints_dir, file_name),
                            map_location=map_location)
    return checkpoint, str(file_name)


## training

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# hyper-parameters
batch_size = 32
max_epoch = 50
lr = 3e-4
field_emb_dim = 32
word_emb_dim = 300
hidden_dim = 512
dropout = 0.1
random_seed = 1
beam_width = 1  # 1 = greedy decoding, > 1 = beam_search
resume = False  # continue from the most recent checkpoint
# Decode with translate_with_copy instead of translate.
# NOTE: this name shadows the stdlib `copy` module inside the notebook.
copy = False

fix_seed(random_seed)

cur_time = datetime.datetime.now().strftime('%Y_%m_%d_%H_%M')
logger = get_logger(f'{root_dir}/results/logs/' + cur_time + '.log')

start_time = time.time()

train_data_loader = get_data_loader(
    seq_info_train, seq_target_train, batch_size, True, device)
dev_data_loader = get_data_loader(
    seq_info_dev, seq_target_dev, batch_size, False, device)

# measures only the tensor/DataLoader construction above
logger.info(f'data processing consumes: {(time.time() - start_time):.2f}s')

encoder = EncoderAttn(field_vocab_size, word_vocab_size,
                      field_emb_dim, word_emb_dim, hidden_dim, dropout)
decoder = DecoderAttn(word_vocab_size, word_emb_dim, hidden_dim, dropout)

model = Table2Text(encoder, decoder, beam_width, max_len, max_field).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)
# NOTE(review): PAD target positions are not excluded (no ignore_index), so
# padding also contributes to the loss. Passing ignore_index=PAD_TOKEN is only
# safe when the targets carry an explicit EOS; otherwise the decoder never
# learns to stop.
criterion = nn.NLLLoss()
early_stop = EarlyStopping(mode='min', min_delta=0.001, patience=5)

if resume:
    checkpoint, cp_name = load_checkpoint(latest=True)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    logger.info(f'load checkpoint: [{cp_name}]')


def train(data_loader):
    """Run one teacher-forced optimisation epoch; return the mean batch loss."""
    model.train()
    total = 0
    for batch_input, batch_target in tqdm(data_loader):
        optimizer.zero_grad()
        log_probs = model(batch_input, batch_target, train_mode=True)
        # the output at step t is scored against target token t + 1
        predictions = log_probs[:, :-1].reshape(-1, log_probs.size(-1))
        gold = batch_target[:, 1:].reshape(-1)
        loss = criterion(predictions, gold)
        loss.backward()
        total += loss.item()
        optimizer.step()
    return total / len(data_loader)


def validate(data_loader):
    """Return the mean teacher-forced loss over *data_loader* without updates."""
    model.eval()
    losses = []
    with torch.no_grad():
        for batch_input, batch_target in data_loader:
            log_probs = model(batch_input, batch_target, train_mode=True)
            loss = criterion(log_probs[:, :-1].reshape(-1, log_probs.size(-1)),
                             batch_target[:, 1:].reshape(-1))
            losses.append(loss.item())
    return sum(losses) / len(data_loader)


def evaluate(infer_input):
    """Decode one input from a single SOS token; return ``(token_ids, attention)``."""
    model.eval()
    with torch.no_grad():
        start = torch.tensor([SOS_TOKEN], dtype=torch.long, device=device)
        token_ids, attention = model(infer_input, start, train_mode=False)
    return token_ids, attention


loss_dict_train, loss_dict_dev = [], []
# lowest validation loss seen so far and a detached copy of those weights
best_dev_loss, best_state = float('inf'), None
for epoch in range(1, int(max_epoch + 1)):
    start_time = time.time()
    train_loss = train(train_data_loader)
    dev_loss = validate(dev_data_loader)
    loss_dict_train.append(train_loss)
    loss_dict_dev.append(dev_loss)

    epoch_min, epoch_sec = record_time(start_time, time.time())
    logger.info(
        f'epoch: [{epoch:02}/{max_epoch}]  train_loss={train_loss:.3f}  valid_loss={dev_loss:.3f}  '
        f'duration: {epoch_min}m {epoch_sec}s')

    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        best_state = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}

    if early_stop.step(dev_loss):
        logger.info(f'early stop at [{epoch:02}/{max_epoch}]')
        break

# When training stops early the last epoch is usually not the best one
# (presumably EarlyStopping waits `patience` non-improving epochs). Roll back
# to the lowest-validation-loss weights before saving and evaluating.
if best_state is not None:
    model.load_state_dict(best_state)
    logger.info(f'restored weights from best valid_loss={best_dev_loss:.3f}')

if max_epoch > 0:
    # NOTE: the optimizer state saved here is the one from the last epoch run
    save_checkpoint(experiment_time=cur_time, model=model, optimizer=optimizer)


## evaluate

In [None]:
"""Use the standard script provided by nltk."""
bleu_scorer = BleuScore()
bleu_scorer.set_refs(get_refs())
ie_metrics_list = []

for idx_data in range(len(test_data)):
    seq_input = torch.tensor(process_one_data(
        idx_data, 'test'), dtype=torch.long, device=device).unsqueeze(0)
    seq_output, attn_map = evaluate(seq_input)
    list_seq = seq_output.squeeze().tolist()
    if not copy:
        text_gen = translate(list_seq)
    else:
        text_gen = translate_with_copy(
            list_seq=list_seq, attn_score=attn_map, data_idx=idx_data)
    bleu_scorer.add_gen(text_gen)
bleu_score = bleu_scorer.calculate()
logger.info(f'bleu score: {bleu_score:.2f}')

for idx_data in range(len(dev_data)):
    seq_input = torch.tensor(process_one_data(
        idx_data, 'dev'), dtype=torch.long, device=device).unsqueeze(0)
    seq_output, attn_map = evaluate(seq_input)
    list_seq = seq_output.squeeze().tolist()
    if not copy:
        text_gen = translate(list_seq)
    else:
        text_gen = translate_with_copy(
            list_seq=list_seq, attn_score=attn_map, data_idx=idx_data)
    ie_metrics_list.append(text_gen)

# generate summaries
with open(f'{root_dir}/results/roto_cc-beam5_gens.txt', 'w', encoding='utf-8') as f:
    for text in ie_metrics_list:
        f.write(text + '\n')

"""
The next step is *information/relation extraction*.

Reference: [harvardnlp/data2text](https://github.com/harvardnlp/data2text)

The `data_utils.py` script from that repository has a few minor bugs when run under Python 3, but they are easy to fix.
"""
