In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchtext.vocab import vocab as torch_vocab

from torch.utils.data import DataLoader
from random import random
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from rouge import Rouge

In [2]:
from model import PGen
from utils import (get_counters, encode, encode_ext_abstract, encode_ext_article,
                   decode, preporocess_text, SummDataset, PointerDataPoint,
                  SOS, EOS, PAD, OOV)

from decode import BeamSearchNode, beam_decode

In [3]:
from datasets import load_dataset
dataset = load_dataset('IlyaGusev/gazeta', revision="v2.0")


specials = [SOS, EOS, PAD, OOV]


train_df = dataset['train']
val_df = dataset['validation']
test_df = dataset['test']

No config specified, defaulting to: gazeta/default
Found cached dataset gazeta (/home/goncharovglebig/.cache/huggingface/datasets/IlyaGusev___gazeta/default/2.0.0/c329f0fc1c22ab6e43e0045ee659d0d43c647492baa2a6ab3a5ea7dac98cd552)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
#Creatae vocab
# src_counts = get_counters(train_df['text'], train_df['summary'])


  0%|          | 0/60964 [00:00<?, ?it/s]

39807

In [6]:
# vocab = torch_vocab(src_counts, min_freq=70, specials=specials)
# vocab.set_default_index(vocab[OOV])
# vocab = vocab
# len(vocab)

43663

In [6]:
# torch.save(vocab, 'gazeta_voc_43.pth')
vocab = torch.load('gazeta_voc_43.pth')
len(vocab)

43663

In [7]:
train_dataset = SummDataset(train_df['text'], train_df['summary'])
val_dataset = SummDataset(val_df['text'], val_df['summary'])
test_dataset = SummDataset(test_df['text'], test_df['summary'])

  0%|          | 0/60964 [00:00<?, ?it/s]

  0%|          | 0/6369 [00:00<?, ?it/s]

  0%|          | 0/6793 [00:00<?, ?it/s]

In [8]:
import gc


del train_df
del val_df
del test_df

gc.collect()

41

In [9]:
def collate_batch(batch):
    """
    Collects batch for model
    Returns:
        enc_input_padded - padded art_idxs tensor
        enc_input_ext_padded - padded art_exq_idxs tensor
        enc_padding_mask - pad mask for enc_input_padded
        extra_zeros - zeros for extented vocab
        dec_input_padded - padded abs_idxs tensor
        target_padded - padded abs_idxs extra tensor
        target_padding_mask - pad mask for target_padded
        target_lens - lens of decoder input
        
    """

    enc_list, enc_ext_list, dec_inp_list, target_list, oovs_len_list, target_lens_list, oovs = [], [], [], [], [], [], []
    batch_size = len(batch)
    for article, abstract in batch:
        data_point = PointerDataPoint(article, abstract, vocab)
        
        enc_input = torch.tensor(np.array(data_point.art_idxs))
        enc_input_ext = torch.tensor(np.array(data_point.art_ext_idxs))
        dec_input = torch.tensor(np.array(data_point.abs_idxs))
        target = torch.tensor(np.array(data_point.abs_ext_idxs))
        
        oovs.append(data_point.art_oovs)
        oovs_len_list.append(len(data_point.art_oovs))
        enc_list.append(enc_input)
        dec_inp_list.append(dec_input)
        enc_ext_list.append(enc_input_ext)
        target_list.append(target)
        target_lens_list.append(len(target))
        

    
    enc_input_padded = pad_sequence(enc_list, padding_value=vocab[PAD]).T
    enc_input_ext_padded = pad_sequence(enc_ext_list, padding_value=vocab[PAD]).T
    target_padded = pad_sequence(target_list, padding_value=vocab[PAD]).T
    dec_inp_padded = pad_sequence(dec_inp_list, padding_value=vocab[PAD]).T
    
    enc_padding_mask = enc_input_padded.ne(vocab[PAD]).long()
    target_padding_mask = target_padded.ne(vocab[PAD]).long()
    
    target_lens = torch.tensor(target_lens_list)
    
    max_oovs = max(oovs_len_list)
    extra_zeros = None
    if max_oovs > 0:
        extra_zeros = torch.zeros((batch_size, max_oovs), requires_grad=True)

    return (enc_input_padded.to(device), enc_input_ext_padded.to(device),
            enc_padding_mask.to(device), extra_zeros.to(device),
            dec_inp_padded.to(device), target_padded.to(device),
            target_padding_mask.to(device), target_lens.to(device),
            oovs
           )



In [10]:
train_loader = DataLoader(train_dataset,
                          batch_size=32,
                          collate_fn=collate_batch)
val_loader = DataLoader(val_dataset,
                          batch_size=32,
                          collate_fn=collate_batch)
test_loader = DataLoader(test_dataset,
                          batch_size=32,
                          collate_fn=collate_batch)

In [None]:
# Training part
model = PGen(
    vocab_size=len(vocab),
    emb_dim=128, 
    hid_dim=256)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)

params = list(model.parameters())
optimizer = torch.optim.Adagrad(params, lr=0.15, initial_accumulator_value=0.1)


epoch_num = 30
train_loss_list, val_loss_list = [], []

for ep in tqdm(range(epoch_num)):
    for train_batch in tqdm(train_loader):
        src, src_ext, src_mask, extra_zeros, dec_inp, trg, trg_mask, target_lens, _ = train_batch

        encoder_outputs, encoder_feature, s_t_1 = model.encoder(src)
        step_losses = []
        batch_size = src.shape[0]
        max_dec_len = dec_inp.shape[1]
        
        
        # For first input
        c_t_1 = torch.zeros((batch_size, 2 * s_t_1[0].shape[-1]), requires_grad=True).to(device)
        coverage = torch.zeros((src.shape), requires_grad=True).to(device)
        y_t_1 = dec_inp[:, 0]
        
        for di in range(1, max_dec_len):
            final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = model.decoder(y_t_1, s_t_1,
                                                                                       encoder_outputs,
                                                                                       encoder_feature,
                                                                                       src_mask,
                                                                                       c_t_1,
                                                                                       extra_zeros,
                                                                                       src_ext,
                                                                                       coverage,
                                                                                      )
            target = trg[:, di]
            gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
            step_loss = -torch.log(gold_probs)

            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + 0.1 * step_coverage_loss
            coverage = next_coverage

            step_mask = trg_mask[:, di]
            step_loss = step_loss * step_mask
            step_losses.append(step_loss)
            
            # Next token
            y_t_1 = dec_inp[:, di]  # Teacher forcing
            
        sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
        batch_avg_loss = sum_losses / target_lens
        loss = torch.mean(batch_avg_loss)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss = loss.detach().item()
        train_loss_list.append(train_loss)


    #Validation loop
    with torch.no_grad():
        for val_batch in val_loader:
            src, src_ext, src_mask, extra_zeros, dec_inp, trg, trg_mask, target_lens, _ = val_batch
            encoder_outputs, encoder_feature, s_t_1 = model.encoder(src)
            step_losses = []
            batch_size = src.shape[0]
            max_dec_len = dec_inp.shape[1]
            
            # For first input
            c_t_1 = torch.zeros((batch_size, 2 * s_t_1[0].shape[-1]), requires_grad=True).to(device)
            coverage = torch.zeros((src.shape), requires_grad=True).to(device)
            y_t_1 = dec_inp[:, 0]
            for di in range(1, max_dec_len):
                final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = model.decoder(y_t_1, s_t_1,
                                                                                           encoder_outputs,
                                                                                           encoder_feature,
                                                                                           src_mask,
                                                                                           c_t_1,
                                                                                           extra_zeros,
                                                                                           src_ext,
                                                                                           coverage,
                                                                                          )
                target = trg[:, di]
                gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze()
                step_loss = -torch.log(gold_probs)

                step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
                step_loss = step_loss + 0.1 * step_coverage_loss
                coverage = next_coverage

                step_mask = trg_mask[:, di]
                step_loss = step_loss * step_mask
                step_losses.append(step_loss)
                
                y_t_1 = dec_inp[:, di]  # Teacher forcing

            sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
            batch_avg_loss = sum_losses / target_lens
            loss = torch.mean(batch_avg_loss)
            val_loss = loss.item()
            val_loss_list.append(val_loss)
            torch.save(model, f'pointer_gazeta_{ep}.pth')
        
        print(f'For epoch #{ep} train loss {np.mean(train_loss_list[-250_000:])}, val loss {np.mean(val_loss_list[-10_000:])}')

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #0 train loss 6.209452417189527, val loss 5.728891501426697


  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #1 train loss 5.8731587773730345, val loss 5.497052632570266


  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #2 train loss 5.660578015551946, val loss 5.343588787714641


  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #3 train loss 5.5086054940789095, val loss 5.230483644008636


  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #4 train loss 5.388398936312696, val loss 5.142677486419678


  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #5 train loss 5.2871017945775955, val loss 5.070227198203405


  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #6 train loss 5.199694815130181, val loss 5.0089155595643176


  0%|          | 0/1906 [00:00<?, ?it/s]

For epoch #7 train loss 5.122664967729688, val loss 4.955806299448013


  0%|          | 0/1906 [00:00<?, ?it/s]

In [None]:
plt.plot(train_loss_list, label='train');
plt.legend();

In [None]:
plt.plot(val_loss_list, label='val');
plt.legend();

In [None]:
# device = 'cpu'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model = torch.load('ponter_gazeta_0').to(device)
point = PointerDataPoint(val_dataset[3][0],
                         val_dataset[3][1],
                         vocab)

In [None]:

src = torch.tensor(np.array(point.art_idxs)).unsqueeze(0).to(device)
src_mask = torch.ones(len(point.art_idxs)).unsqueeze(0).to(device)
src_ext = torch.tensor(np.array(point.art_ext_idxs)).unsqueeze(0).to(device)
extra_zeros = torch.zeros(len(point.art_oovs)).unsqueeze(0).to(device)

model.encoder.to(device)

In [None]:
preds = []
with torch.no_grad():
    encoder_outputs, encoder_feature, s_t_1 = model.encoder(src)
    batch_size = src.shape[0]
    max_dec_len = 100

    # For first input
    c_t_1 = torch.zeros((batch_size, 2 * s_t_1[0].shape[-1]), requires_grad=True).to(device)
    coverage = torch.zeros((src.shape), requires_grad=True).to(device)
    y_t_1 = src[:, 0]
    for di in range(1, 100):
        final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = model.decoder(y_t_1, s_t_1,
                                                                                   encoder_outputs,
                                                                                   encoder_feature,
                                                                                   src_mask,
                                                                                   c_t_1,
                                                                                   extra_zeros,
                                                                                   src_ext,
                                                                                   coverage,
                                                                                  )

        coverage = next_coverage
        y_t_1 = final_dist.argmax(1)
        preds.append(y_t_1.item())
        y_t_1 = y_t_1 if y_t_1 < len(vocab) else torch.LongTensor([vocab[OOV]]).to(device)

In [None]:
print(' '.join(decode(point.abs_ext_idxs, vocab, point.art_oovs)))

In [None]:
print(' ' .join(decode(preds, vocab, point.art_oovs)))

In [None]:
# Let's predict in a greedy way
targets_list, predicts_list, oovs_list = [], [], []

with torch.no_grad():
        for val_batch in tqdm(val_loader):
            src, src_ext, src_mask, extra_zeros, dec_inp, trg, trg_mask, target_lens, oov = val_batch
            encoder_outputs, encoder_feature, s_t_1 = model.encoder(src)
            step_losses = []
            batch_size = src.shape[0]
            max_dec_len = dec_inp.shape[1]
            predicts = np.zeros((batch_size, max_dec_len), dtype='int')
            
            # For first input
            c_t_1 = torch.zeros((batch_size, 2 * s_t_1[0].shape[-1]), requires_grad=True).to(device)
            coverage = torch.zeros((src.shape), requires_grad=True).to(device)
            y_t_1 = dec_inp[:, 0]
            for di in range(1, max_dec_len):
                final_dist, s_t_1,  c_t_1, attn_dist, p_gen, next_coverage = model.decoder(y_t_1, s_t_1,
                                                                                           encoder_outputs,
                                                                                           encoder_feature,
                                                                                           src_mask,
                                                                                           c_t_1,
                                                                                           extra_zeros,
                                                                                           src_ext,
                                                                                           coverage,
                                                                                          )

                y_t_1 = final_dist.argmax(1)
                predicts[:, di] = y_t_1.cpu().numpy()
                y_t_1[y_t_1 >= len(vocab)] = vocab[OOV]
            
            targets_list = targets_list + trg.cpu().numpy().tolist()
            predicts_list = predicts_list + predicts.tolist()
            oovs_list = oovs_list + oov

In [None]:
ind = 0
print(f"= {' '.join(decode(targets_list[ind], vocab, oovs_list[ind]))}")
print(f"> {' '.join(decode(predicts_list[ind], vocab, oovs_list[ind]))}")

In [None]:
decoded_targets = []
decoded_predicts = []

for i in tqdm(range(len(predicts_list))):
    decoded_targets.append(' '.join(decode(targets_list[i], vocab, oovs_list[i])).replace('<SOS>', '').replace('<PAD>', '').replace('<EOS>', ''))
    decoded_predicts.append(' '.join(decode(predicts_list[i], vocab, oovs_list[i])).replace('<SOS>', '').replace('<PAD>', '').replace('<EOS>', ''))
    

In [None]:
rouge = Rouge()
rouge.get_scores(decoded_predicts, decoded_targets, avg=True)

In [29]:
### top3 baseline
from nltk.tokenize import sent_tokenize
def top3_baseline(text):
    return ' '.join(sent_tokenize(text)[:3])

In [30]:
decoded_targets = []
decoded_predicts = []

for data in tqdm(val_dataset):
    art, abst = data
    decoded_targets.append(' '.join(abst))
    decoded_predicts.append(top3_baseline(' '.join(art)))

  0%|          | 0/6369 [00:00<?, ?it/s]

In [31]:
rouge.get_scores(decoded_predicts, decoded_targets, avg=True)

{'rouge-1': {'r': 0.3397863654959966,
  'p': 0.2918713845352074,
  'f': 0.30560469773278953},
 'rouge-2': {'r': 0.15051781779285675,
  'p': 0.12250571423980426,
  'f': 0.13050782714767825},
 'rouge-l': {'r': 0.30222040078967655,
  'p': 0.2599901595337212,
  'f': 0.27202348650584385}}

In [27]:
src = torch.tensor(np.array(point.art_idxs)).unsqueeze(0).to(device)
src_mask = torch.ones(len(point.art_idxs)).unsqueeze(0).to(device)
src_ext = torch.tensor(np.array(point.art_ext_idxs)).unsqueeze(0).to(device)
extra_zeros = torch.zeros(len(point.art_oovs)).unsqueeze(0).to(device)

model.to(device)


with torch.no_grad():
    decoded = beam_decode(model=model,
                max_len=100,
                beam_width=5,
                src=src,
                src_mask=src_mask,
                src_ext=src_ext,
                extra_zeros=extra_zeros,
                device=device,
                SOS=SOS,
                EOS=EOS,
                vocab=vocab,
               )

  0%|          | 0/1 [00:00<?, ?it/s]

NameError: name 's_t_1' is not defined

In [None]:
print(' ' .join(decode([i.item() for i in decoded], vocab, point.art_oovs)))