<a href="https://colab.research.google.com/github/ninalzr/nlg/blob/master/BasicLSTMEncoderDecoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [33]:
!pip install transformers



In [0]:
import os, sys, json
from datetime import datetime

In [0]:
class Lookup():
    def __init__(self, model_class, file_prefix = None):

        self.model_class = model_class

        self.bos_token = None
        self.eos_token = None
        self.unk_token = None
        self.sep_token = None
        self.pad_token = None
        self.cls_token = None
        self.mask_token = None

        if model_class == 'gpt2':
            from transformers import GPT2Tokenizer
            self._tokenizer = GPT2Tokenizer.from_pretrained(model_class)
            self._tokenizer.add_special_tokens({'pad_token': '<PAD>'})
            if self._tokenizer._bos_token:
                self.bos_token = self._tokenizer.bos_token
            if self._tokenizer._eos_token:
                self.eos_token = self._tokenizer.eos_token
            if self._tokenizer._unk_token:                
                self.unk_token = self._tokenizer.unk_token
            if self._tokenizer._sep_token:
                self.sep_token = self._tokenizer.sep_token
            if self._tokenizer._pad_token:
                self.pad_token = self._tokenizer.pad_token
            if self._tokenizer._cls_token:
                self.cls_token = self._tokenizer.cls_token
            if self._tokenizer._mask_token:
                self.mask_token = self._tokenizer.mask_token 
        else:
            print("You need to load a tokenizer from https://huggingface.co/transformers/main_classes/tokenizer.html#")
        
        if file_prefix:
            self.load(file_prefix)

        def save_special_tokens(self, file_prefix):
            if self.model_class == "gpt2":
                special_tokens = {}
            if self.bos_token:
                special_tokens['bos_token'] = self.bos_token
            if self.eos_token:
                special_tokens['eos_token'] = self.eos_token
            if self.unk_token:
                special_tokens['unk_token'] = self.unk_token
            if self.sep_token:
                special_tokens['sep_token'] = self.sep_token
            if self.pad_token:
                special_tokens['pad_token'] = self.pad_token
            if self.cls_token:
                special_tokens['cls_token'] = self.cls_token
            if self.mask_token:
                special_tokens['mask_token'] = self.mask_token            
            json.dump(special_tokens, open(file_prefix+".special_tokens","w",encoding="utf8"), indent=4, sort_keys=True)            
            self._tokenizer.add_special_tokens(special_tokens)  
        
        def load(self, file_prefix):
            if os.path.exists(file_prefix+".special_tokens"):
                special_tokens = json.load(open(file_prefix+".special_tokens","r",encoding="utf8"))            
            if 'bos_token' in special_tokens:
                self.bos_token = special_tokens['bos_token']
            if 'eos_token' in special_tokens:
                self.eos_token = special_tokens['eos_token']
            if 'unk_token' in special_tokens:
                self.unk_token = special_tokens['unk_token']
            if 'sep_token' in special_tokens:
                self.sep_token = special_tokens['sep_token']
            if 'pad_token' in special_tokens:
                self.pad_token = special_tokens['pad_token']
            if 'cls_token' in special_tokens:
                self.cls_token = special_tokens['cls_token']
            if 'mask_token' in special_tokens:
                self.mask_token = special_tokens['mask_token']
            self._tokenizer.add_special_tokens(special_tokens)      

    def tokenize(self, text):
        return self._tokenizer.tokenize(text)

    def convert_tokens_to_ids(self, tokens):
        return self._tokenizer.convert_tokens_to_ids(tokens)

    def convert_ids_to_tokens(self, token_ids):
        return self._tokenizer.convert_ids_to_tokens(token_ids)

    def convert_tokens_to_string(self, tokens):
        return self._tokenizer.convert_tokens_to_string(tokens)

    def encode(self, text, add_bos_eos_tokens = False):
        tokens = self.tokenize(text)
        if add_bos_eos_tokens:
            return [self.convert_tokens_to_ids(self.bos_token)] + self.convert_tokens_to_ids(tokens) + [self.convert_tokens_to_ids(self.eos_token)]
        else:
            return self.convert_tokens_to_ids(tokens)

    def decode(self, token_ids, skip_bos_eos_tokens = False):
        if skip_bos_eos_tokens:
            if len(token_ids)>0:
                if token_ids[0] == self.convert_tokens_to_ids(self.bos_token):
                    token_ids = token_ids[1:]
            if len(token_ids)>0:
                if token_ids[-1] == self.convert_tokens_to_ids(self.eos_token):
                    token_ids = token_ids[:-1]   
        if len(token_ids) > 0:
            tokens = self.convert_ids_to_tokens(token_ids)
            return self.convert_tokens_to_string(tokens)
        return ""

    def __len__(self):          
        return len(self._tokenizer)


In [36]:
model = 'gpt2'
lookup = Lookup(model)
text = "Daisy, Daisy, Give me your answer, do!"
print("\n1. String to tokens (tokenize):")
tokens = lookup.tokenize(text)
print(tokens)

print("\n2. Tokens to ints (convert_tokens_to_ids):")
ids = lookup.convert_tokens_to_ids(tokens)
print(ids)

print("\n2.5 Token to int (convert_tokens_to_ids with a single str):")
id = lookup.convert_tokens_to_ids(tokens[0])
print(id)

print("\n3. Ints to tokens (convert_ids_to_tokens):")
tokens = lookup.convert_ids_to_tokens(ids)
print(tokens)

print("\n3.5 Int to token (convert_ids_to_tokens with a single int):")
token = lookup.convert_ids_to_tokens(id)
print(token)

print("\n4. Tokens to string (convert_tokens_to_string):")
recreated_text = lookup.convert_tokens_to_string(tokens)
print(recreated_text)

print("\n5. String to ints (encode):")
ids = lookup.encode(text)
print(ids)

print("\n6. Ints to string (decode):")
recreated_text = lookup.decode(ids)
print(recreated_text)

print("\n7. Encode adding special tokens:")
ids = lookup.encode(text, add_bos_eos_tokens=True)
print(ids)
print("How it looks like with tokens: {}".format(lookup.convert_ids_to_tokens(ids)))
    
print("\n8. Decode skipping special tokens:")
recreated_text = lookup.decode(ids, skip_bos_eos_tokens=True)
print(recreated_text)

print("\n9. Vocabulary size:")
vocab_size = lookup.__len__()
print(vocab_size)


1. String to tokens (tokenize):
['Da', 'isy', ',', 'ĠDaisy', ',', 'ĠGive', 'Ġme', 'Ġyour', 'Ġanswer', ',', 'Ġdo', '!']

2. Tokens to ints (convert_tokens_to_ids):
[26531, 13560, 11, 40355, 11, 13786, 502, 534, 3280, 11, 466, 0]

2.5 Token to int (convert_tokens_to_ids with a single str):
26531

3. Ints to tokens (convert_ids_to_tokens):
['Da', 'isy', ',', 'ĠDaisy', ',', 'ĠGive', 'Ġme', 'Ġyour', 'Ġanswer', ',', 'Ġdo', '!']

3.5 Int to token (convert_ids_to_tokens with a single int):
Da

4. Tokens to string (convert_tokens_to_string):
Daisy, Daisy, Give me your answer, do!

5. String to ints (encode):
[26531, 13560, 11, 40355, 11, 13786, 502, 534, 3280, 11, 466, 0]

6. Ints to string (decode):
Daisy, Daisy, Give me your answer, do!

7. Encode adding special tokens:
[50256, 26531, 13560, 11, 40355, 11, 13786, 502, 534, 3280, 11, 466, 0, 50256]
How it looks like with tokens: ['<|endoftext|>', 'Da', 'isy', ',', 'ĠDaisy', ',', 'ĠGive', 'Ġme', 'Ġyour', 'Ġanswer', ',', 'Ġdo', '!', '<|endoftex

TODO: Adjust the loader for distributed training (maybe?)

In [0]:
import os, sys, json, random
import numpy as np
import torch
import torch as nn
import torch.utils.data

from functools import partial


In [0]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [0]:
def loader(data_folder, batch_size, src_lookup, tgt_lookup, min_seq_len_X = 5, max_seq_len_X = 1000, min_seq_len_y = 5,
           max_seq_len_y = 1000, MEI = ""):
    MEI = MEI.replace(" ","_")
    pad_id = tgt_lookup.convert_tokens_to_ids(tgt_lookup.pad_token)
    
    train_loader = torch.utils.data.DataLoader(
        MyDataset(data_folder, "train", min_seq_len_X, max_seq_len_X, min_seq_len_y, max_seq_len_y, MEI),
        num_workers=0,
        batch_size=batch_size,
        collate_fn=partial(paired_collate_fn, padding_idx = pad_id),
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        MyDataset(data_folder, "dev", min_seq_len_X, max_seq_len_X, min_seq_len_y, max_seq_len_y, MEI),
        num_workers=0,
        batch_size=batch_size,
        collate_fn=partial(paired_collate_fn, padding_idx = pad_id))
    
    return train_loader, valid_loader

def paired_collate_fn(insts, padding_idx):
    # insts contains a batch_size number of (x, y) elements    
    src_insts, tgt_insts = list(zip(*insts))
   
    # now src is a batch_size(=64) array of x0 .. x63, and tgt is y0 .. x63 ; xi is variable length
    # ex: if a = [(1,2), (3,4), (5,6)]
    # then b, c = list(zip(*a)) => b = (1,3,5) and b = (2,4,6)
    
    # src_insts is now a tuple of batch_size Xes (x0, x63) where xi is an instance
    #src_insts, src_lenghts, tgt_insts, tgt_lenghts = length_collate_fn(src_insts, tgt_insts)       
    
    src_max_len = max(len(inst) for inst in src_insts) # determines max size for all examples
    
    src_seq_lengths = torch.tensor(list(map(len, src_insts)), dtype=torch.long)    
    src_seq_tensor = torch.tensor(np.array( [ inst + [padding_idx] * (src_max_len - len(inst)) for inst in src_insts ] ), dtype=torch.long)
    src_seq_mask = torch.tensor(np.array( [ [1] * len(inst) + [0] * (src_max_len - len(inst)) for inst in src_insts ] ), dtype=torch.long)
    
    src_seq_lengths, perm_idx = src_seq_lengths.sort(0, descending=True)
    src_seq_tensor = src_seq_tensor[perm_idx]   
    src_seq_mask = src_seq_mask[perm_idx]
    tgt_max_len = max(len(inst) for inst in tgt_insts)
    
    tgt_seq_lengths = torch.tensor(list(map(len, tgt_insts)), dtype=torch.long)    
    tgt_seq_tensor = torch.tensor(np.array( [ inst + [padding_idx] * (tgt_max_len - len(inst)) for inst in tgt_insts ] ), dtype=torch.long)
    tgt_seq_mask = torch.tensor(np.array( [ [1] * len(inst) + [0] * (tgt_max_len - len(inst)) for inst in tgt_insts ] ), dtype=torch.long)
    
    tgt_seq_lengths = tgt_seq_lengths[perm_idx]
    tgt_seq_tensor = tgt_seq_tensor[perm_idx]      
    tgt_seq_mask = tgt_seq_mask[perm_idx]   
      
    return ((src_seq_tensor, src_seq_lengths, src_seq_mask), (tgt_seq_tensor, tgt_seq_lengths, tgt_seq_mask))   
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, type, min_seq_len_X, max_seq_len_X, min_seq_len_y, max_seq_len_y, MEI):  
        self.root_dir = root_dir

        self.X = [] # this will store joined sentences
        self.y = [] # this will store the output

    
        with open(os.path.join(root_dir, type, MEI + '_output.txt'), 'r') as f:
            y = [lookup.encode(y.strip(), add_bos_eos_tokens=True)  for y in f]
        with open(os.path.join(root_dir, type, MEI + '_sentences.txt'), 'r') as g:
            X = [lookup.encode(x.strip(), add_bos_eos_tokens=True)  for x in g]   

        
        cut_over_X = 0
        cut_under_X = 0
        cut_over_y = 0
        cut_under_y = 0
        
        # max len
        for (sx, sy) in zip(X, y):
            if len(sx) > max_seq_len_X:
                cut_over_X += 1
            elif len(sx) < min_seq_len_X+2:                
                cut_under_X += 1
            elif len(sy) > max_seq_len_y:
                cut_over_y += 1
            elif len(sy) < min_seq_len_y+2:                
                cut_under_y += 1
            else:
                self.X.append(sx)
                self.y.append(sy)         

        c = list(zip(self.X, self.y))
        random.shuffle(c)
        self.X, self.y = zip(*c)
        self.X = list(self.X)
        self.y = list(self.y)
        print(X)
        print(y)
                    
        print("Dataset [{}] loaded with {} out of {} ({}%) instances.".format(type, len(self.X), len(X), float(100.*len(self.X)/len(X)) ) )
        print("\t\t For X, {} are over max_len {} and {} are under min_len {}.".format(cut_over_X, max_seq_len_X, cut_under_X, min_seq_len_X))
        print("\t\t For y, {} are over max_len {} and {} are under min_len {}.".format(cut_over_y, max_seq_len_y, cut_under_y, min_seq_len_y))
        
        assert(len(self.X)==len(self.y))
        
    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):        
        return self.X[idx], self.y[idx]

In [40]:
from google.colab import drive
drive.mount('/content/drive')
data_path = 'drive/My Drive/nlg/tiny'
src_lookup = Lookup(model)
tgt_lookup = Lookup(model)
batch_size = 4    
min_seq_len_X = 10
max_seq_len_X = 1000
min_seq_len_y = min_seq_len_X
max_seq_len_y = max_seq_len_X 
MEI = "Management Overview"
model = 'gpt2'
lookup = Lookup(model)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [41]:
train_loader, valid_loader = loader(data_path, batch_size, src_lookup, tgt_lookup, min_seq_len_X, max_seq_len_X, min_seq_len_y, max_seq_len_y, MEI = MEI)

[[50256, 1, 818, 2274, 812, 11, 262, 1664, 750, 407, 7715, 5981, 13380, 38, 3136, 13, 32, 3096, 5583, 379, 262, 1664, 318, 4497, 329, 29852, 18848, 2428, 691, 13, 464, 1664, 16523, 281, 6142, 2450, 13, 10493, 2370, 5644, 262, 1664, 857, 407, 423, 5423, 8998, 379, 1919, 5127, 6333, 2428, 13, 464, 1919, 5127, 6333, 3210, 16523, 4571, 319, 11149, 4137, 290, 1200, 10515, 13, 464, 1664, 338, 34875, 11383, 468, 12872, 5260, 526, 50256], [50256, 1546, 38, 6447, 379, 262, 1664, 318, 4939, 13, 464, 1664, 338, 5583, 4497, 329, 15030, 13380, 38, 2428, 318, 2174, 262, 3096, 1241, 13, 464, 1664, 338, 6142, 2450, 318, 1913, 13, 464, 1664, 468, 407, 4920, 5423, 284, 5698, 262, 4542, 286, 1919, 5127, 6333, 2428, 13, 464, 1919, 5127, 6333, 3210, 16523, 4571, 319, 11149, 4137, 290, 1200, 10515, 13, 464, 1664, 857, 407, 7271, 15771, 257, 34875, 11383, 13, 50256], [50256, 464, 1664, 16523, 287, 13380, 38, 13019, 355, 340, 468, 407, 3199, 5981, 3136, 287, 2274, 812, 13, 464, 1664, 468, 407, 9899, 3096, 124

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, device):
        super(Encoder, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size 

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first = True)
        self.device = device

        self.to(device)

    def forward(self, input, input_seq_len):
        embedded = self.embedding(input)
        lstm_input = torch.nn.utils.rnn.pack_padded_sequence(embedded, input_seq_len, batch_first = True, enforce_sorted = False)
        lstm_output, states = self.lstm(lstm_input)
        output, _ = torch.nn.utils.rnn.pad_packed_sequence(lstm_output, batch_first=True)

        return output, states


In [0]:
class Decoder(nn.Module):
    def __init__(self, hidden_size, output_size, device = device):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first = True)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        self.to(device)

    def forward(self, y, states):
        input = y.long()
        lstm_states = states
        batch_size = y.size(0)

        emb = self.embedding(input)
        output = F.relu(emb)
        lstm_input = output

        output_lstm, hidden = self.lstm(lstm_input, lstm_states)
        lin = self.out(output_lstm)
        output = self.softmax(lin)    
        return output, hidden

In [159]:
for i, t in enumerate(train_loader):
    sentences = t[0][0]
    seq_len = t[0][1]
    y = t[1][0]
    y_seq_len = t[1][1]
    print(sentences.size())
    print(y.size())
    break

torch.Size([4, 89])
torch.Size([4, 90])


In [167]:
input_size = vocab_size
hidden_size = 256

encoder = EncoderRNN(vocab_size, hidden_size, device = device)
enc_output, states = encoder.forward(sentences, seq_len)
print(enc_output.shape)

torch.Size([4, 89, 256])


In [169]:
decoder = DecoderRNN(hidden_size, vocab_size, device = device)
decoder_out,decoder_hidden  = decoder.forward(y, states)
print(decoder_out[0].shape)


Y size torch.Size([4, 90])
Input shape torch.Size([4, 90, 256])
relu shape torch.Size([4, 90, 256])
lstm format torch.Size([90, 256])
lin format torch.Size([4, 90, 50258])
Output torch.Size([4, 90, 50258])
torch.Size([90, 50258])


In [171]:
criterion = nn.NLLLoss()
loss = criterion(decoder_out.view(-1, vocab_size), y.contiguous().flatten())
print(loss)

tensor(4.5020, grad_fn=<NllLossBackward>)
