In [1]:
# from google.colab import drive
# drive.mount('/content/gdrive')
# # %matplotlib inline
# %cd /content/gdrive/My Drive/

In [17]:
import os
import sys
import time
from tqdm import tqdm

import numpy as np
import pandas as pd
import torch
from torch.nn.utils import clip_grad_norm_
from torch.optim import Adam, Adagrad 
from torch.optim.lr_scheduler import StepLR

from time import gmtime, strftime

In [2]:
# % ls 'Colab Notebooks'

In [3]:
# sys.path.insert(0, "Colab Notebooks/NLP_Projects/Abstractive_Summarization/")
import config
from data import Vocab, abstract2sents
from batcher import Example, Batch
from model import Model

from utils import get_input_from_batch, get_output_from_batch

#### Data

In [4]:
def preprocessing_file(file_path):
    """Parse a PubMed-release text file into a list of per-article dicts.

    Every line is treated as one record. The line is split on double quotes
    (it is not parsed as JSON): a quoted token equal to one of the known field
    names opens that field, and each following token that is not a pure
    punctuation fragment is collected under it.

    Args:
        file_path: Path of the text file to read, one record per line.

    Returns:
        A list of dicts mapping each field name found on the line to a list of
        string tokens, except that
          * 'sections' holds a list of token lists, one per section, and
          * 'abstract_text' is joined with spaces, split with `abstract2sents`
            and each resulting sentence is stripped.
        Records whose 'article_text' has fewer than two tokens are dropped;
        this includes blank lines and lines without that field.
    """
    keys = ['article_id', 'article_text', 'abstract_text', 'labels', 'section_names', 'sections']
    # Fragments left between quoted strings that carry no content. ']]}' is the
    # end-of-record fragment of a final line that has no trailing newline.
    separators = ['], ', ', ', ': [[', ']]}\n', ']]}', '{', ': [', ': ']
    section_openers = [': [[', '], [']
    section_noise = [']]}\n', ']]}', ', ']
    content_list = []

    with open(file_path, "r") as f:
        for l in tqdm(f.readlines()):
            item_dict = {}
            # Reset per line so a malformed line can never write into the
            # field that was current at the end of the previous line.
            key_ = None
            count_sections = -1

            for item in l.split("\"")[1:]:
                if item in keys:
                    item_dict[item] = []
                    key_ = item
                    if item == 'sections':
                        count_sections = -1
                elif key_ is None:
                    # Content before the first recognised field name: ignore.
                    continue
                elif key_ != 'sections':
                    if item not in separators:
                        item_dict[key_].append(item)
                elif item in section_openers:
                    item_dict[key_].append([])
                    count_sections += 1
                elif item not in section_noise:
                    item_dict[key_][count_sections].append(item)

            # .get() keeps blank / truncated lines from raising KeyError.
            if len(item_dict.get('article_text', [])) > 1:
                abstract = " ".join(item_dict.get('abstract_text', []))
                item_dict['abstract_text'] = [sent.strip() for sent in abstract2sents(abstract)]
                content_list.append(item_dict)

    return content_list



In [5]:
# Parsed records for each split, keyed by split name.
data_ = {
    "train": preprocessing_file("../pubmed-release/train.txt"),
    "val": preprocessing_file("../pubmed-release/val.txt"),
}

100%|██████████| 119924/119924 [00:21<00:00, 5680.73it/s]
100%|██████████| 6633/6633 [00:01<00:00, 3401.42it/s]


#### Vocab

In [6]:
# Build the vocabulary from the vocab file.
# NOTE(review): the second argument is presumably the maximum vocabulary size — confirm in data.Vocab.
vocab = Vocab("../pubmed-release/vocab", 50100)

Finished constructing vocabulary of 50004 total words. Last word added: hpse


In [7]:
# Sanity check: look up the id of a word expected to be in the vocabulary.
vocab.word2id("medical")

382

#### Form Batch

In [8]:
# examples_batch = [Example(article="".join(item['article_text']), 
#                           abstract_sentences=item['abstract_text'], 
#                           vocab=vocab) 
#                   for item in data["train"][:16] ]

# batch = Batch(example_list=examples_batch, vocab=vocab, batch_size=16) 

In [9]:
# config.max_enc_steps

In [10]:
class Batcher(object):
    """Serve consecutive fixed-size batches built from parsed article dicts.

    Each record must provide 'article_text' (a list of strings, joined with
    spaces to form the article) and 'abstract_text' (passed on unchanged as the
    abstract sentences).
    """

    def __init__(self, data_dicts, vocab, batch_size, shuffle=True):
        """Store the records and, if requested, shuffle them once.

        Args:
            data_dicts: Sequence of record dicts (see class docstring).
            vocab: Vocabulary handed to every `Example` and `Batch`.
            batch_size: Number of records per batch.
            shuffle: When true, the records are permuted a single time here
                using NumPy's global random state; the caller's sequence is
                not modified.
        """
        self.data_dicts = data_dicts
        self.vocab = vocab
        self.batch_size = batch_size
        self.shuffle = shuffle

        if self.shuffle:
            inds = np.arange(len(self.data_dicts))
            np.random.shuffle(inds)
            self.data_dicts = [self.data_dicts[ind] for ind in list(inds)]

        # Index of the first record of the next batch; callers reset it to 0
        # to start another pass over the data.
        self.start_batch = 0

    def next_batch(self):
        """Return the next full batch, or None when fewer than `batch_size` records remain."""
        stop = self.start_batch + self.batch_size
        # Strictly greater: when `stop == len(...)` exactly one full batch is
        # still available and must be served (the previous `>=` dropped it).
        if stop > len(self.data_dicts):
            return None

        example_list = [Example(article=" ".join(record['article_text']),
                                abstract_sentences=record['abstract_text'],
                                vocab=self.vocab)
                        for record in self.data_dicts[self.start_batch:stop]]

        batch = Batch(example_list=example_list,
                      vocab=self.vocab,
                      batch_size=self.batch_size)

        self.start_batch = stop
        return batch

In [11]:
def calc_running_avg_loss(loss, running_avg_loss, step, decay=0.99):
    """Update an exponentially decayed running loss and cap it at 12.

    Args:
        loss: Loss of the current iteration.
        running_avg_loss: Previous running average; a value equal to 0 means
            "no history yet", in which case `loss` itself is used.
        step: Accepted for interface compatibility; not used.
        decay: Weight given to the previous running average.

    Returns:
        The new running average, never larger than 12.
    """
    if running_avg_loss != 0:
        smoothed = running_avg_loss * decay + (1 - decay) * loss
    else:
        # First iteration: seed the average with the raw loss.
        smoothed = loss
    return min(smoothed, 12)

def train_model(model, batcher, optimizer, scheduler, N_epoch, use_cuda=False, start_iter=0, running_avg_loss=0):
    """Train `model` for `N_epoch` passes over `batcher`, logging and checkpointing as it goes.

    A fresh directory `weights/train_<UTC timestamp>` is created. After every
    trained batch the loss history is rewritten to `train_history.csv` in that
    directory. Whenever the per-epoch iteration counter satisfies
    `itr % 1000 == 999`, a checkpoint is saved there and `scheduler.step()` is
    called.

    Args:
        model: Object exposing `encoder`, `decoder` and `reduce_state` modules.
        batcher: Object with `next_batch()` (returns None when exhausted) and a
            writable `start_batch` attribute, which is reset to 0 after each epoch.
        optimizer: Optimizer passed to `train_one_batch` and stored in checkpoints.
        scheduler: Learning-rate scheduler stepped at every checkpoint.
        N_epoch: Number of passes over the batcher.
        use_cuda: Forwarded to `train_one_batch`.
        start_iter: Accepted for interface compatibility; not used.
        running_avg_loss: Initial value of the running average loss.

    Returns:
        `(model, train_history)` where `train_history[epoch]` is a dict with the
        lists 'loss' and 'running_loss' (one entry per trained batch).
    """
    train_history = {}
    df_history = pd.DataFrame(columns=["epoch", "itr", "loss", "running_loss"])

    train_dir = "weights/train_{}".format(strftime("%Y-%m-%d_%H:%M:%S", gmtime()))
    os.makedirs(train_dir)

    for epoch in range(N_epoch):
        train_history[epoch] = {"loss": [], "running_loss": []}
        itr = 0
        while True:
            if itr % 1000 == 999:
                # Checkpointing is best-effort: a failed save must not abort a
                # long run. Only ordinary errors are caught, so Ctrl-C still
                # interrupts training, and the reason is reported.
                try:
                    torch.save({
                        'epoch': epoch,
                        'itr': itr,
                        'encoder': model.encoder.state_dict(),
                        'decoder': model.decoder.state_dict(),
                        'reduce_state': model.reduce_state.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'loss': loss}, train_dir + "/weights.loss_{:.3f}.pt".format(loss))
                except Exception as exc:
                    print("Failed to save: {}".format(exc))

                scheduler.step()

            batch = batcher.next_batch()
            if batch is None:
                # Epoch finished: rewind the batcher for the next pass.
                batcher.start_batch = 0
                break

            loss = train_one_batch(model, optimizer, batch, use_cuda)
            train_history[epoch]["loss"].append(loss)

            # Pass the iteration counter (the builtin `iter` was passed before).
            running_avg_loss = calc_running_avg_loss(loss, running_avg_loss, itr)
            train_history[epoch]["running_loss"].append(running_avg_loss)

            df_history.loc[df_history.shape[0]] = [epoch+1, itr+1, loss, running_avg_loss]
            df_history.to_csv(train_dir + "/train_history.csv", index=False)

            if itr % 10 == 9:
                print("{} epoch, {} itr: loss = {}".format(epoch +1, itr+1, running_avg_loss))

            itr += 1
    return model, train_history

In [12]:
def train_one_batch(model, optimizer, batch, use_cuda=False):
    """Run one teacher-forced optimisation step on `batch` and return its loss.

    Args:
        model: Object exposing `encoder`, `decoder` and `reduce_state` modules.
        optimizer: Optimizer over the parameters of those modules.
        batch: Batch understood by `get_input_from_batch` / `get_output_from_batch`.
        use_cuda: Forwarded to those two helpers.

    Returns:
        The mean per-example loss of the batch as a Python float.
    """
    enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
        get_input_from_batch(batch, use_cuda)
    dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
        get_output_from_batch(batch, use_cuda)
    optimizer.zero_grad()

    encoder_outputs, encoder_feature, encoder_hidden = model.encoder(enc_batch, enc_lens)
    s_t_1 = model.reduce_state(encoder_hidden)

    step_losses = []
    for di in range(min(max_dec_len, config.max_dec_steps)):

        y_t_1 = dec_batch[:, di]  # Teacher forcing
        final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = model.decoder(
            y_t_1, s_t_1,
            encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
            extra_zeros, enc_batch_extend_vocab,
            coverage, di)
        target = target_batch[:, di]
        # Probability assigned to the gold token of every example. squeeze(1)
        # removes only the gathered column, so the batch axis survives even
        # when the batch holds a single example (a bare squeeze() would
        # collapse it to a 0-d tensor).
        gold_probs = torch.gather(final_dist, 1, target.unsqueeze(1)).squeeze(1)
        step_loss = -torch.log(gold_probs + config.eps)
        if config.is_coverage:
            step_coverage_loss = torch.sum(torch.min(attn_dist, coverage), 1)
            step_loss = step_loss + config.cov_loss_wt * step_coverage_loss
            coverage = next_coverage

        # Zero out the loss at padded decoder positions.
        step_mask = dec_padding_mask[:, di]
        step_loss = step_loss * step_mask
        step_losses.append(step_loss)

    # Sum over decoding steps, normalise by each example's decoder length,
    # then average over the batch.
    sum_losses = torch.sum(torch.stack(step_losses, 1), 1)
    batch_avg_loss = sum_losses/dec_lens_var
    loss = torch.mean(batch_avg_loss)

    loss.backward()

    # Gradients are clipped separately for each sub-module.
    clip_grad_norm_(model.encoder.parameters(), config.max_grad_norm)
    clip_grad_norm_(model.decoder.parameters(), config.max_grad_norm)
    clip_grad_norm_(model.reduce_state.parameters(), config.max_grad_norm)

    optimizer.step()

    return loss.item()


In [13]:
def setup_train(model_file_path=None):
    """Build the model, its Adagrad optimizer and a StepLR scheduler, optionally resuming.

    Args:
        model_file_path: Optional checkpoint path. It is passed to `Model` and,
            when given, is also loaded here to recover the iteration counter,
            the last loss and (unless `config.is_coverage` is set) the
            optimizer state.

    Returns:
        Tuple `(model, optimizer, scheduler, start_iter, start_loss)`;
        `start_iter` and `start_loss` are 0 when no checkpoint is given.

    Raises:
        KeyError: If the checkpoint lacks an expected entry.
    """
    model = Model(model_file_path)

    params = list(model.encoder.parameters()) + list(model.decoder.parameters()) + \
             list(model.reduce_state.parameters())
    initial_lr = config.lr_coverage if config.is_coverage else config.lr
    optimizer = Adagrad(params, lr=initial_lr, initial_accumulator_value=config.adagrad_init_acc)
    scheduler = StepLR(optimizer, step_size=7, gamma=0.5)

    start_iter, start_loss = 0, 0

    if model_file_path is not None:
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        # The map_location callable keeps every tensor on the CPU while loading.
        checkpoint = torch.load(model_file_path, map_location=lambda storage, location: storage)

        def first_present(*names):
            # Accept both the key names this function originally expected and
            # the ones train_model writes ('itr', 'loss', 'optimizer_state_dict').
            for name in names:
                if name in checkpoint:
                    return checkpoint[name]
            raise KeyError(names[0])

        start_iter = first_present('iter', 'itr')
        start_loss = first_present('current_loss', 'loss')

        if not config.is_coverage:
            optimizer.load_state_dict(first_present('optimizer', 'optimizer_state_dict'))
            # Keep each parameter's optimizer state on that parameter's device.
            # This replaces a lookup of an undefined global `use_cuda`.
            for param, param_state in optimizer.state.items():
                for k, v in param_state.items():
                    if torch.is_tensor(v):
                        param_state[k] = v.to(param.device)

    return model, optimizer, scheduler, start_iter, start_loss

In [21]:
# Training setup. The config overrides must run before setup_train(), which
# reads config.lr when building the optimizer (unless config.is_coverage is set).
config.pointer_gen = True
config.lr = 0.1
model, optimizer, scheduler, start_iter, start_loss = setup_train()
# Batcher shuffles the training records once, when it is constructed.
batcher_train = Batcher(data_["train"], vocab, batch_size=32)

In [15]:
# Restore the three sub-module weights from a saved checkpoint (same key layout
# as the dict saved in train_model). map_location="cpu" lets the file load on a
# machine without the GPU it was saved from; load_state_dict then copies the
# values into each module's existing parameters on whatever device they live.
checkpoint = torch.load("weights/train_2019-12-19_00:15:56/weights.loss_3.535.pt", map_location="cpu")
model.encoder.load_state_dict(checkpoint['encoder'])
model.decoder.load_state_dict(checkpoint['decoder'])
model.reduce_state.load_state_dict(checkpoint['reduce_state'])

<All keys matched successfully>

In [None]:
# Train for 10 epochs with use_cuda=True; train_model writes checkpoints and the
# loss history under weights/train_<UTC timestamp>/.
model_point, train_history_point = train_model(model, batcher_train, optimizer, scheduler, N_epoch=10, use_cuda=True)



1 epoch, 10 itr: loss = 3.9310953600291323
1 epoch, 20 itr: loss = 3.9662183771015367
1 epoch, 30 itr: loss = 3.9720924995878724
1 epoch, 40 itr: loss = 3.9567189146835293
1 epoch, 50 itr: loss = 3.9482634328228516
1 epoch, 60 itr: loss = 3.941840803266103
1 epoch, 70 itr: loss = 3.920153228020532
1 epoch, 80 itr: loss = 3.9081718429653134
1 epoch, 90 itr: loss = 3.896719419912846
1 epoch, 100 itr: loss = 3.8984959900711647
1 epoch, 110 itr: loss = 3.8942459460996903


In [25]:
# List the files of one training-run directory.
! ls weights/train_2019-12-19_00:15:56

train_history.csv      weights.loss_3.941.pt  weights.loss_4.115.pt
weights.loss_3.535.pt  weights.loss_3.970.pt  weights.loss_4.149.pt
weights.loss_3.654.pt  weights.loss_4.111.pt  weights.loss_4.323.pt


### Evaluate

In [22]:
# Evaluation setup: a fresh model constructed with is_eval=True and a batcher
# that serves one validation record per batch.
config.pointer_gen = True
model = Model(is_eval=True)
# Module-level flag: evaluate_batch reads this global when unpacking a batch.
use_cuda = True

# NOTE(review): Batcher shuffles by default, so the examples printed below
# differ between runs unless shuffle=False is passed.
batcher = Batcher(data_["val"], vocab, batch_size=1)

In [23]:
# Restore the three sub-module weights for evaluation. map_location="cpu" lets
# the file load on a machine without the GPU it was saved from; load_state_dict
# then copies the values into each module's existing parameters on whatever
# device they live.
checkpoint = torch.load("weights/train_2019-12-19_01:58:07/weights.loss_3.420.pt", map_location="cpu")
model.encoder.load_state_dict(checkpoint['encoder'])
model.decoder.load_state_dict(checkpoint['decoder'])
model.reduce_state.load_state_dict(checkpoint['reduce_state'])


<All keys matched successfully>

In [24]:
def _print_wrapped(title, words, words_per_line):
    """Print `title`, then `words` joined by spaces, `words_per_line` per row."""
    print(title)
    for start in range(0, len(words), words_per_line):
        print(" ".join(words[start:start + words_per_line]))


def evaluate_batch(model, batch, vocab_size=50004, words_per_line=20):
    """Decode `batch` with teacher forcing and print target vs. greedy prediction.

    Only the first example of the batch is printed. At every decoder step the
    gold previous token is fed in and the arg-max token of the output
    distribution is recorded as the prediction.

    Reads the notebook-level globals `use_cuda` (forwarded to the batch
    helpers) and `vocab` (used to map ids back to words).

    Args:
        model: Object exposing `encoder`, `decoder` and `reduce_state` modules.
        batch: Batch understood by `get_input_from_batch` / `get_output_from_batch`.
        vocab_size: Ids greater than or equal to this value are left out of the
            printed text (they cannot be looked up with `vocab.id2word`).
        words_per_line: Number of words printed per output row.
    """
    batch_prediction, batch_target = [], []

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        enc_batch, enc_padding_mask, enc_lens, enc_batch_extend_vocab, extra_zeros, c_t_1, coverage = \
            get_input_from_batch(batch, use_cuda)
        dec_batch, dec_padding_mask, max_dec_len, dec_lens_var, target_batch = \
            get_output_from_batch(batch, use_cuda)

        encoder_outputs, encoder_feature, encoder_hidden = model.encoder(enc_batch, enc_lens)
        s_t_1 = model.reduce_state(encoder_hidden)

        for di in range(min(max_dec_len, config.max_dec_steps)):
            y_t_1 = dec_batch[:, di]  # Teacher forcing
            final_dist, s_t_1, c_t_1, attn_dist, p_gen, next_coverage = model.decoder(
                y_t_1, s_t_1,
                encoder_outputs, encoder_feature, enc_padding_mask, c_t_1,
                extra_zeros, enc_batch_extend_vocab, coverage, di)

            # Carry the coverage vector forward exactly as training does;
            # without this every step would see the initial coverage.
            if config.is_coverage:
                coverage = next_coverage

            target = target_batch[:, di]
            batch_target.append(target[0].item())

            prediction = torch.argmax(final_dist, dim=1)
            batch_prediction.append(prediction[0].item())

    target_words = [vocab.id2word(b) for b in batch_target if b < vocab_size]
    _print_wrapped(" --- TARGET ---", target_words, words_per_line)

    predicted_words = [vocab.id2word(b) for b in batch_prediction if b < vocab_size]
    _print_wrapped(" --- PREDICTION --- ", predicted_words, words_per_line)


def evaluate(model, batcher, N):
    """Print target/prediction pairs for up to `N` batches drawn from `batcher`.

    Args:
        model: Model handed to `evaluate_batch`.
        batcher: Object whose `next_batch()` returns a batch, or None when no
            further batch is available.
        N: Maximum number of batches to evaluate.
    """
    for _ in range(N):
        batch = batcher.next_batch()
        if batch is None:
            # The batcher ran out before N batches were shown; stop instead of
            # passing None on to evaluate_batch.
            print("No more batches to evaluate.")
            break

        print("-"*40)
        evaluate_batch(model, batch)

In [25]:
# Print target vs. prediction for five validation batches.
evaluate(model, batcher, N=5)

----------------------------------------
 --- TARGET ---
we describe the first reported cases of invasive type e haemophilus influenzae disease in italy . all five cases occurred
in adults . the isolates were susceptible to ampicillin and eight other antimicrobial agents . molecular analysis showed two distinct
type e strains circulating in italy , both containing a single copy of the locus . [STOP]
 --- PREDICTION --- 
[UNK] report a first case case of meningitis meningitis e - influenzae strains ( elderly . the the cases were
in young with the patient were found to the , the of cases agents . [STOP] analysis revealed that distinct
genes e strain . in the . which with the [UNK] - number the locus gene . [STOP]
----------------------------------------
 --- TARGET ---
introduction : : bone disease , melting wax syndrome , disease ) is a rare chronic bone disorder , first
described in by and . men and women are equally affected , and no hereditary features have been discovered .
onset is

In [0]:
# Show the GPU status reported by the driver.
! nvidia-smi

Wed Dec 18 10:54:46 2019       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 440.44       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|   0  Tesla P4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   75C    P0    27W /  75W |   3689MiB /  7611MiB |      0%      Default |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                       GPU Memory |
|  GPU       PID   Type   Process name                             Usage      |
+-------

In [50]:
# Probe an id equal to the reported vocabulary size (50004 words); the recorded
# output shows that this lookup raises ValueError.
vocab.id2word(50004)

ValueError: ignored

In [58]:
# Show which tokens occupy the first four ids (see the recorded output below).
vocab.id2word(0), vocab.id2word(1), vocab.id2word(2), vocab.id2word(3)

('[UNK]', '[PAD]', '[START]', '[STOP]')