In [1]:
from __future__ import unicode_literals, print_function, division
from gensim.models import KeyedVectors
from nltk.tokenize import WordPunctTokenizer
import pandas as pd
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import json
from tqdm.notebook import tqdm
import time
import math
from google.colab import drive

from io import open
import random
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

plt.switch_backend('agg')

MAX_LENGTH=300

In [2]:
drive.mount ('/content/drive')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
def fix_contractions(text):
    tokens = text.split()
    cleaned = []
    for token in tokens:
        cleaned.append(contractions.get(token, token))
    return ' '.join(cleaned)

def tokenize(text):
    text = fix_contractions(text)
    tokens = tokenizer.tokenize(text)
    text = ' '.join(tokens).lower()
    text = text.replace('# person1 #', '#person1#')
    text = text.replace('# person2 #', '#person2#')
    text = text.replace('# person3 #', '#person3#')
    text = text.replace('# person4 #', '#person4#')
    text = text.replace('# person5 #', '#person5#')
    text = text.replace('# person6 #', '#person6#')
    text = text.replace('# person7 #', '#person7#')
    text = text.replace(' ,', ',')
    text = text.replace(' .', '.')
    text = text.replace(' ?', '?')
    text = text.replace(' !', '!')
    text = text.replace(" ' ", "'")
    text = text.replace("< ", "<")
    text = text.replace(" >", ">")
    return text.split()

tokenizer = WordPunctTokenizer()

In [4]:
def showPlot(points):
    plt.figure()
    fig, ax = plt.subplots()
    loc = ticker.MultipleLocator(base=0.2)
    ax.yaxis.set_major_locator(loc)
    plt.plot(points)

In [5]:
with open('/content/drive/My Drive/COMP9444/data/contractions.json', 'r') as f:
    contractions = json.load(f)

df = pd.read_json('/content/drive/My Drive/COMP9444/data/raw/dialogsum/dialogsum.train.jsonl', lines = True)[['dialogue', 'summary']]


In [6]:
src_tokens = list(df['dialogue'].apply(lambda x: tokenize(x)))
trg_tokens = list(df['summary'].apply(lambda x: tokenize(x)))

src = [' '.join(sent) for sent in src_tokens]
trg = [' '.join(sent) for sent in trg_tokens]


In [7]:
word_vectors = KeyedVectors.load_word2vec_format('/content/drive/My Drive/COMP9444/models/GloVe-Word2Vec/glove.bin')

In [8]:
pad_token = "<pad>"
sos_token = "<sos>"
eos_token = "<eos>"
if pad_token not in word_vectors.key_to_index:
    pad_index = len(word_vectors)
    word_vectors.key_to_index[pad_token] = pad_index
    word_vectors.index_to_key.append(pad_token)
else:
    pad_index = word_vectors.key_to_index[pad_token]

PAD_IDX = pad_index
SOS_TOKEN = word_vectors.key_to_index['<sos>']
EOS_TOKEN = word_vectors.key_to_index['<eos>']

In [9]:
vocab = list(word_vectors.key_to_index.keys())
vocab_size = len(vocab)
embedding_dim = word_vectors.get_vector('<sos>').shape[0]
dictionary = word_vectors.key_to_index

In [12]:
embedding_matrix = torch.zeros(vocab_size, embedding_dim)

In [14]:
for i, word in enumerate(tqdm(vocab)):
    if word == pad_token:
        continue
    embedding_matrix[i] = torch.Tensor(np.array(word_vectors[word]))

embedding_matrix = embedding_matrix.to(device)

  0%|          | 0/5707 [00:00<?, ?it/s]

In [15]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, embedding_matrix = None, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        if embedding_matrix is not None:
            self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze = False, padding_idx = PAD_IDX)
        else:
            self.embedding = nn.Embedding(input_size, hidden_size, padding_idx = PAD_IDX)

        self.gru = nn.GRU(hidden_size, hidden_size, num_layers, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)

        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)

        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, num_layers, embedding_matrix= None, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        if embedding_matrix is not None:
            self.embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze = False, padding_idx = PAD_IDX)
        else:
            self.embedding = nn.Embedding(output_size, hidden_size, padding_idx = PAD_IDX)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, num_layers, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_TOKEN)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None and i < target_tensor.size(1):
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions


    def forward_step(self, input, hidden, encoder_outputs):
        embedded =  self.dropout(self.embedding(input))

        last_layer_hidden = hidden[-1].unsqueeze(0)
        query = last_layer_hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)

        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)

        return output, hidden, attn_weights

In [16]:
def prepareData(src, trg):
    dial=np.array(src)
    summary=np.array(trg)
    pairs = [[dial[i], summary[i]] for i in range(len(dial))]
    return pairs

def indexesFromSentence(sentence):
    return [dictionary.get(word) for word in tokenize(sentence) if word in dictionary.keys()]

def tensorFromSentence(sentence):
    indexes = indexesFromSentence(sentence)
    indexes.append(EOS_TOKEN)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(pair[0])
    target_tensor = tensorFromSentence(pair[1])
    return (input_tensor, target_tensor)

def get_dataloader(pairs, batch_size):

    num_pairs = len(pairs)

    input_ids = np.full((num_pairs, MAX_LENGTH), PAD_IDX, dtype=np.int32)
    target_ids = np.full((num_pairs, MAX_LENGTH), PAD_IDX, dtype=np.int32)

    src_data_indexed = [[dictionary[word] for word in sentence if word in dictionary.keys()] for sentence in src_tokens]
    trg_data_indexed = [[dictionary[word] for word in sentence if word in dictionary.keys()] for sentence in trg_tokens]

    for idx, (inp, tgt) in enumerate(tqdm(pairs)):
        inp_ids = indexesFromSentence(inp)
        tgt_ids = indexesFromSentence(tgt)
        inp_ids.append(EOS_TOKEN)
        tgt_ids.append(EOS_TOKEN)
        input_ids[idx, :len(inp_ids)] = inp_ids[:MAX_LENGTH]
        target_ids[idx, :len(tgt_ids)] = tgt_ids[:MAX_LENGTH]

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return train_dataloader

In [17]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer,
          decoder_optimizer, criterion):

    total_loss = 0
    for data in tqdm(dataloader):
        input_tensor, target_tensor = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

In [18]:
def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)

def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [19]:
def train(train_dataloader, encoder, decoder, n_epochs, learning_rate=0.001,
          print_every=100, plot_every=100, momentum=0.9):
    start = time.time()
    plot_losses = []
    print_loss_total = 0
    plot_loss_total = 0

    # Using Adam with momentum
    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)#, momentum=momentum)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)#, momentum=momentum)
    scheduler_encoder = optim.lr_scheduler.StepLR(encoder_optimizer, step_size=10, gamma=0.5)
    scheduler_decoder = optim.lr_scheduler.StepLR(decoder_optimizer, step_size=10, gamma=0.5)
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)

    best_loss = float('inf')

    for epoch in tqdm(range(1, n_epochs + 1)):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        scheduler_encoder.step()
        scheduler_decoder.step()

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            print(f'{timeSince(start, epoch / n_epochs)} ({epoch} {epoch / n_epochs * 100:.2f}%) {print_loss_avg:.4f}')

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

            # Checkpointing
            if plot_loss_avg < best_loss:
                best_loss = plot_loss_avg
                torch.save(encoder.state_dict(), 'encoder_best.pth')
                torch.save(decoder.state_dict(), 'decoder_best.pth')

    showPlot(plot_losses)

In [20]:
def evaluate(encoder, decoder, sentence):
    with torch.no_grad():
        input_tensor = tensorFromSentence(sentence)

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_TOKEN:
                decoded_words.append('<eos>')
                break
            decoded_words.append(vocab[idx.item()])
    return decoded_words, decoder_attn

In [22]:
hidden_size = 300
batch_size = 32
num_layers = 1

pairs = prepareData(src, trg)

train_dataloader = get_dataloader(pairs, batch_size)

  0%|          | 0/12460 [00:00<?, ?it/s]

In [23]:
encoder = EncoderRNN(vocab_size, hidden_size, num_layers, embedding_matrix).to(device)
decoder = AttnDecoderRNN(hidden_size, vocab_size, num_layers, embedding_matrix).to(device)

In [24]:
encoder.embedding.weight.data, decoder.embedding.weight.data

(tensor([[-0.5984,  0.0620, -0.0847,  ...,  0.7339,  0.1335,  0.0200],
         [-0.6761,  0.1510, -0.1087,  ..., -0.0675, -0.2059,  0.5739],
         [-0.7899, -0.9644, -0.0499,  ..., -0.1214,  0.8511,  0.3563],
         ...,
         [ 0.1933,  0.1437, -0.0567,  ...,  0.2963, -0.1398,  0.2437],
         [-0.2460,  0.2695, -0.3072,  ..., -0.0709, -0.2492,  0.0746],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
        device='cuda:0'),
 tensor([[-0.5984,  0.0620, -0.0847,  ...,  0.7339,  0.1335,  0.0200],
         [-0.6761,  0.1510, -0.1087,  ..., -0.0675, -0.2059,  0.5739],
         [-0.7899, -0.9644, -0.0499,  ..., -0.1214,  0.8511,  0.3563],
         ...,
         [ 0.1933,  0.1437, -0.0567,  ...,  0.2963, -0.1398,  0.2437],
         [-0.2460,  0.2695, -0.3072,  ..., -0.0709, -0.2492,  0.0746],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
        device='cuda:0'))

In [25]:
train(train_dataloader, encoder, decoder, 30, print_every=1, plot_every=1)

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/390 [00:00<?, ?it/s]

4m 11s (- 121m 45s) (1 3.33%) 4.9599


  0%|          | 0/390 [00:00<?, ?it/s]

8m 25s (- 118m 3s) (2 6.67%) 4.1881


  0%|          | 0/390 [00:00<?, ?it/s]

12m 39s (- 113m 55s) (3 10.00%) 3.7933


  0%|          | 0/390 [00:00<?, ?it/s]

16m 52s (- 109m 41s) (4 13.33%) 3.4816


  0%|          | 0/390 [00:00<?, ?it/s]

21m 2s (- 105m 13s) (5 16.67%) 3.2208


  0%|          | 0/390 [00:00<?, ?it/s]

25m 10s (- 100m 43s) (6 20.00%) 2.9848


  0%|          | 0/390 [00:00<?, ?it/s]

29m 23s (- 96m 34s) (7 23.33%) 2.7678


  0%|          | 0/390 [00:00<?, ?it/s]

33m 37s (- 92m 27s) (8 26.67%) 2.5793


  0%|          | 0/390 [00:00<?, ?it/s]

37m 47s (- 88m 11s) (9 30.00%) 2.4021


  0%|          | 0/390 [00:00<?, ?it/s]

42m 1s (- 84m 2s) (10 33.33%) 2.2425


  0%|          | 0/390 [00:00<?, ?it/s]

46m 14s (- 79m 52s) (11 36.67%) 1.9968


  0%|          | 0/390 [00:00<?, ?it/s]

50m 30s (- 75m 45s) (12 40.00%) 1.8941


  0%|          | 0/390 [00:00<?, ?it/s]

54m 45s (- 71m 35s) (13 43.33%) 1.8132


  0%|          | 0/390 [00:00<?, ?it/s]

58m 56s (- 67m 21s) (14 46.67%) 1.7421


  0%|          | 0/390 [00:00<?, ?it/s]

63m 6s (- 63m 6s) (15 50.00%) 1.6743


  0%|          | 0/390 [00:00<?, ?it/s]

67m 16s (- 58m 51s) (16 53.33%) 1.6111


  0%|          | 0/390 [00:00<?, ?it/s]

71m 26s (- 54m 37s) (17 56.67%) 1.5517


  0%|          | 0/390 [00:00<?, ?it/s]

75m 37s (- 50m 24s) (18 60.00%) 1.4942


  0%|          | 0/390 [00:00<?, ?it/s]

79m 52s (- 46m 14s) (19 63.33%) 1.4403


  0%|          | 0/390 [00:00<?, ?it/s]

84m 7s (- 42m 3s) (20 66.67%) 1.3913


  0%|          | 0/390 [00:00<?, ?it/s]

88m 19s (- 37m 51s) (21 70.00%) 1.2915


  0%|          | 0/390 [00:00<?, ?it/s]

92m 31s (- 33m 38s) (22 73.33%) 1.2525


  0%|          | 0/390 [00:00<?, ?it/s]

96m 42s (- 29m 25s) (23 76.67%) 1.2217


  0%|          | 0/390 [00:00<?, ?it/s]

100m 51s (- 25m 12s) (24 80.00%) 1.2000


  0%|          | 0/390 [00:00<?, ?it/s]

105m 2s (- 21m 0s) (25 83.33%) 1.1730


  0%|          | 0/390 [00:00<?, ?it/s]

109m 13s (- 16m 48s) (26 86.67%) 1.1504


  0%|          | 0/390 [00:00<?, ?it/s]

113m 24s (- 12m 36s) (27 90.00%) 1.1267


  0%|          | 0/390 [00:00<?, ?it/s]

117m 33s (- 8m 23s) (28 93.33%) 1.1049


  0%|          | 0/390 [00:00<?, ?it/s]

121m 44s (- 4m 11s) (29 96.67%) 1.0829


  0%|          | 0/390 [00:00<?, ?it/s]

125m 53s (- 0m 0s) (30 100.00%) 1.0625


In [26]:
def count_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

attention = BahdanauAttention(hidden_size)
total_params_encoder = count_trainable_parameters(encoder)
total_params_decoder = count_trainable_parameters(decoder)
total_params_attention = count_trainable_parameters(decoder.attention)

print("Total trainable parameters in encoder:", total_params_encoder)
print("Total trainable parameters in decoder:", total_params_decoder)
print("Total trainable parameters in attention:", total_params_attention)

Total trainable parameters in encoder: 2253900
Total trainable parameters in decoder: 4422608
Total trainable parameters in attention: 180901


In [27]:
total_params_encoder+total_params_decoder+total_params_attention

6857409

In [28]:
encoder.eval()
decoder.eval()

AttnDecoderRNN(
  (embedding): Embedding(5707, 300, padding_idx=5706)
  (attention): BahdanauAttention(
    (Wa): Linear(in_features=300, out_features=300, bias=True)
    (Ua): Linear(in_features=300, out_features=300, bias=True)
    (Va): Linear(in_features=300, out_features=1, bias=True)
  )
  (gru): GRU(600, 300, batch_first=True)
  (out): Linear(in_features=300, out_features=5707, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

In [29]:
evaluateRandomly(pairs, encoder, decoder)

> #person1#: when in rome, do as the romans do, they say. #person2#: what do the romans do? #person1#: they live in rome, of course, and go to work by car or bus. but sometimes it takes too long that way because of the traffic jams, so they walk. #person2#: in other words, the romans do what everyone else does. #person1#: yes, but they do it differently. everything is different. #person2#: what do you mean? #person1#: well, the climate is different for a start. it does not rain so much as it does in england. the sun shines more often. #person2#: i envy them for the sun. #person1#: i know. you hate the rain, do not you? #person2#: i certainly do. #person1#: and a roman really loves life. they always eat spaghetti and drink wine. #person2#: not always, but they like a good meal. lots of tourists go to rome just for food. #person1#: sure.
= #person1# and #person2# talk about transportation in rome and its climate. #person2# envies romans for the sun and thinks that romans like a good meal