<a href="https://colab.research.google.com/github/ninalzr/nlg/blob/master/Bert2Bert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Install Hugging Face transformers (provides the tokenizers and BertModel/BertConfig imported below).
!pip install transformers

In [0]:
from tqdm import tqdm

# Sanity check that tqdm renders a progress bar in this notebook:
# walk 10,000 no-op steps.
progress = tqdm(range(10000))
for i in progress:
    pass

In [0]:
import os, sys, json
from datetime import datetime

In [0]:
# TODO: Adjust the class for other tokenizers
class Lookup():
    """Wrapper around a Hugging Face tokenizer with uniform special-token handling.

    ``model_class`` selects the tokenizer:
    - ``'gpt2'`` loads the ``gpt2`` tokenizer.
    - ``'bert'`` loads ``bert-base-cased``.

    A ``'<PAD>'`` pad token is always registered on top of the pretrained
    vocabulary. For ``'bert'`` the sequence boundary tokens are CLS/SEP; for
    every other model they are BOS/EOS.
    """

    # Special-token attributes mirrored from the tokenizer and persisted by
    # save_special_tokens() / restored by load().
    _SPECIAL_TOKEN_NAMES = ('bos_token', 'eos_token', 'unk_token', 'sep_token',
                            'pad_token', 'cls_token', 'mask_token')

    def __init__(self, model_class, file_prefix = None):
        """Load the tokenizer for ``model_class`` and optionally restore saved special tokens.

        Args:
            model_class (str): ``'gpt2'`` or ``'bert'``.
            file_prefix (str, optional): if given, ``load(file_prefix)`` restores
                special tokens from ``<file_prefix>.special_tokens``.

        Raises:
            ValueError: if ``model_class`` is not supported.
        """
        self.model_class = model_class

        self.bos_token = None
        self.eos_token = None
        self.unk_token = None
        self.sep_token = None
        self.pad_token = None
        self.cls_token = None
        self.mask_token = None

        if model_class == 'gpt2':
            from transformers import GPT2Tokenizer
            self._tokenizer = GPT2Tokenizer.from_pretrained(model_class)
        elif model_class == 'bert':
            from transformers import BertTokenizer
            self._tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
        else:
            raise ValueError("Lookup: unsupported model_class {!r}; expected 'gpt2' or 'bert'".format(model_class))

        self._tokenizer.add_special_tokens({'pad_token': '<PAD>'})

        # Mirror only the special tokens the tokenizer actually has set. The
        # private ``_<name>`` attribute is checked first (as before); the public
        # one is the fallback for tokenizer versions that lack the private field.
        for name in self._SPECIAL_TOKEN_NAMES:
            if hasattr(self._tokenizer, '_' + name):
                is_set = getattr(self._tokenizer, '_' + name)
            else:
                is_set = getattr(self._tokenizer, name, None)
            if is_set:
                setattr(self, name, getattr(self._tokenizer, name))

        if file_prefix:
            self.load(file_prefix)

    def save_special_tokens(self, file_prefix):
        """Write every set special token to ``<file_prefix>.special_tokens`` (JSON) and register them.

        Args:
            file_prefix (str): path prefix of the JSON file to write.
        """
        special_tokens = {name: getattr(self, name)
                          for name in self._SPECIAL_TOKEN_NAMES if getattr(self, name)}
        with open(file_prefix + ".special_tokens", "w", encoding="utf8") as fh:
            json.dump(special_tokens, fh, indent=4, sort_keys=True)
        self._tokenizer.add_special_tokens(special_tokens)

    def load(self, file_prefix):
        """Restore special tokens written by save_special_tokens().

        Does nothing when ``<file_prefix>.special_tokens`` does not exist.

        Args:
            file_prefix (str): path prefix of the JSON file to read.
        """
        path = file_prefix + ".special_tokens"
        if not os.path.exists(path):
            return
        with open(path, "r", encoding="utf8") as fh:
            special_tokens = json.load(fh)
        for name in self._SPECIAL_TOKEN_NAMES:
            if name in special_tokens:
                setattr(self, name, special_tokens[name])
        self._tokenizer.add_special_tokens(special_tokens)

    def _boundary_tokens(self, operation):
        """Return the ``(start, end)`` tokens that wrap a sequence for this model.

        BERT uses CLS/SEP; every other model uses BOS/EOS.

        Raises:
            Exception: if either boundary token is unset.
        """
        if self.model_class == 'bert':
            start, end, label = self.cls_token, self.sep_token, "CLS or SEP"
        else:
            start, end, label = self.bos_token, self.eos_token, "BOS or EOS"
        if not start or not end:
            raise Exception("Lookup {} error: {} model does not have {} tokens set!".format(
                operation, self.model_class, label))
        return start, end

    def tokenize(self, text):
        return self._tokenizer.tokenize(text)

    def convert_tokens_to_ids(self, tokens):
        return self._tokenizer.convert_tokens_to_ids(tokens)

    def convert_ids_to_tokens(self, token_ids):
        return self._tokenizer.convert_ids_to_tokens(token_ids)

    def convert_tokens_to_string(self, tokens):
        return self._tokenizer.convert_tokens_to_string(tokens)

    def encode(self, text, add_bos_eos_tokens = False):
        """Tokenize ``text`` into ids, optionally wrapping them in the model's boundary tokens.

        Raises:
            Exception: if ``add_bos_eos_tokens`` is set but the boundary tokens are unset.
        """
        ids = self.convert_tokens_to_ids(self.tokenize(text))
        if not add_bos_eos_tokens:
            return ids
        start, end = self._boundary_tokens("encode")
        return [self.convert_tokens_to_ids(start)] + ids + [self.convert_tokens_to_ids(end)]

    def decode(self, token_ids, skip_bos_eos_tokens = False):
        """Turn ids back into a string; ``""`` for an empty sequence.

        With ``skip_bos_eos_tokens`` a leading start token and a trailing end
        token (CLS/SEP for BERT, BOS/EOS otherwise) are removed if present.

        Raises:
            Exception: if ``skip_bos_eos_tokens`` is set but the boundary tokens are unset.
        """
        if skip_bos_eos_tokens:
            start, end = self._boundary_tokens("decode")
            start_id = self.convert_tokens_to_ids(start)
            end_id = self.convert_tokens_to_ids(end)
            if len(token_ids) > 0 and token_ids[0] == start_id:
                token_ids = token_ids[1:]
            if len(token_ids) > 0 and token_ids[-1] == end_id:
                token_ids = token_ids[:-1]
        if len(token_ids) == 0:
            return ""
        return self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))

    def __len__(self):
        """Vocabulary size, including added special tokens."""
        return len(self._tokenizer)

In [0]:
# Quick tour of the Lookup API on a sample sentence.
model = 'bert'
lookup = Lookup(model)
text = "Daisy, Daisy, Give me your answer, do!"
print("\n1. String to tokens (tokenize):")
tokens = lookup.tokenize(text)
print(tokens)

print("\n2. Tokens to ints (convert_tokens_to_ids):")
ids = lookup.convert_tokens_to_ids(tokens)
print(ids)

print("\n2.5 Token to int (convert_tokens_to_ids with a single str):")
# Named token_id rather than `id` so later cells can still use the builtin id().
token_id = lookup.convert_tokens_to_ids(tokens[0])
print(token_id)

print("\n3. Ints to tokens (convert_ids_to_tokens):")
tokens = lookup.convert_ids_to_tokens(ids)
print(tokens)

print("\n3.5 Int to token (convert_ids_to_tokens with a single int):")
token = lookup.convert_ids_to_tokens(token_id)
print(token)

print("\n4. Tokens to string (convert_tokens_to_string):")
recreated_text = lookup.convert_tokens_to_string(tokens)
print(recreated_text)

print("\n5. String to ints (encode):")
ids = lookup.encode(text)
print(ids)

print("\n6. Ints to string (decode):")
recreated_text = lookup.decode(ids)
print(recreated_text)

print("\n7. Encode adding special tokens:")
ids = lookup.encode(text, add_bos_eos_tokens=True)
print(ids)
print("How it looks like with tokens: {}".format(lookup.convert_ids_to_tokens(ids)))

print("\n8. Decode skipping special tokens:")
recreated_text = lookup.decode(ids, skip_bos_eos_tokens=True)
print(recreated_text)

print("\n9. Vocabulary size:")
vocab_size = len(lookup)
print(vocab_size)

In [0]:
import os, sys, json, random
import numpy as np
import torch
import torch as nn
import torch.utils.data

from functools import partial

In [0]:
# Run on the GPU whenever PyTorch can see one; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [0]:
#Ignore slots for now.
#TODO: Figure out what to do with slots
#Remember to change load from file X, y
def loader(data_folder, batch_size, src_lookup, tgt_lookup, min_seq_len_X = 5, max_seq_len_X = 1000, min_seq_len_y = 5,
           max_seq_len_y = 1000, MEI = ""):
    """Build the ``train`` and ``dev`` DataLoaders for one MEI section.

    Spaces in ``MEI`` are replaced by underscores to form the file-name prefix.
    Batches are padded with the target lookup's pad-token id by
    ``paired_collate_fn``. Only the training loader shuffles.

    Returns:
        tuple: ``(train_loader, valid_loader)``.
    """
    section = MEI.replace(" ", "_")
    collate = partial(paired_collate_fn, padding_idx=tgt_lookup.convert_tokens_to_ids(tgt_lookup.pad_token))
    bounds = (min_seq_len_X, max_seq_len_X, min_seq_len_y, max_seq_len_y)

    def build(split, **loader_kwargs):
        dataset = MyDataset(data_folder, split, *bounds, section)
        return torch.utils.data.DataLoader(dataset, num_workers=0, batch_size=batch_size,
                                           collate_fn=collate, **loader_kwargs)

    train_loader = build("train", shuffle=True)
    valid_loader = build("dev")
    return train_loader, valid_loader

def paired_collate_fn(insts, padding_idx):
    """Collate a batch of ``(src_ids, tgt_ids)`` list pairs into padded tensors.

    Both sides are right-padded with ``padding_idx`` to the longest sequence on
    that side. The whole batch is then re-ordered by descending source length.

    Returns:
        tuple: ``((src_ids, src_lengths, src_mask), (tgt_ids, tgt_lengths, tgt_mask))``.
        Ids and masks are ``[batch, max_len]`` long tensors (mask 1 = token,
        0 = padding); lengths are ``[batch]`` long tensors.
    """
    def pad_side(seqs):
        # Pad one side of the batch and build its length vector and mask.
        width = max(len(seq) for seq in seqs)
        padded = torch.tensor(np.array([seq + [padding_idx] * (width - len(seq)) for seq in seqs]), dtype=torch.long)
        mask = torch.tensor(np.array([[1] * len(seq) + [0] * (width - len(seq)) for seq in seqs]), dtype=torch.long)
        lengths = torch.tensor([len(seq) for seq in seqs], dtype=torch.long)
        return padded, lengths, mask

    sources, targets = list(zip(*insts))

    src_ids, src_lengths, src_mask = pad_side(sources)
    src_lengths, order = src_lengths.sort(0, descending=True)

    tgt_ids, tgt_lengths, tgt_mask = pad_side(targets)

    source_side = (src_ids[order], src_lengths, src_mask[order])
    target_side = (tgt_ids[order], tgt_lengths[order], tgt_mask[order])
    return (source_side, target_side)
class MyDataset(torch.utils.data.Dataset):
    """Paired text dataset for one split of one MEI section.

    Reads two files, one example per line:
    - ``<root_dir>/<type>/<MEI>_sentences.txt`` (inputs, X)
    - ``<root_dir>/<type>/<MEI>_output.txt`` (targets, y)

    Every line is encoded with ``add_bos_eos_tokens=True``. A pair is dropped
    when an encoded length is above ``max_seq_len_*`` or below
    ``min_seq_len_* + 2``; the ``+ 2`` accounts for the two added special
    tokens. The surviving pairs are shuffled once with ``random.shuffle``.

    Args:
        root_dir (str): dataset root folder.
        type (str): split sub-folder, e.g. ``"train"`` or ``"dev"``.
        min_seq_len_X, max_seq_len_X, min_seq_len_y, max_seq_len_y (int): length bounds.
        MEI (str): file-name prefix of the section.
        src_lookup, tgt_lookup (optional): objects with
            ``encode(text, add_bos_eos_tokens=...)``, used for X and y
            respectively. When omitted, the notebook-global ``lookup`` is
            used, which was the previous implicit behaviour.

    Raises:
        ValueError: if the two files do not contain the same number of lines.
    """

    def __init__(self, root_dir, type, min_seq_len_X, max_seq_len_X, min_seq_len_y, max_seq_len_y, MEI,
                 src_lookup=None, tgt_lookup=None):
        self.root_dir = root_dir
        if src_lookup is None:
            src_lookup = lookup
        if tgt_lookup is None:
            tgt_lookup = lookup

        self.X = []  # encoded input sentences that pass the length filters
        self.y = []  # encoded targets aligned with self.X

        split_dir = os.path.join(root_dir, type)
        with open(os.path.join(split_dir, MEI + '_output.txt'), 'r', encoding='utf8') as f:
            y = [tgt_lookup.encode(line.strip(), add_bos_eos_tokens=True) for line in f]
        with open(os.path.join(split_dir, MEI + '_sentences.txt'), 'r', encoding='utf8') as g:
            X = [src_lookup.encode(line.strip(), add_bos_eos_tokens=True) for line in g]

        # zip() would silently truncate, pairing inputs with the wrong targets.
        if len(X) != len(y):
            raise ValueError("MyDataset: {} has {} sentences but {} outputs".format(split_dir, len(X), len(y)))

        cut_over_X = 0
        cut_under_X = 0
        cut_over_y = 0
        cut_under_y = 0

        for (sx, sy) in zip(X, y):
            if len(sx) > max_seq_len_X:
                cut_over_X += 1
            elif len(sx) < min_seq_len_X+2:
                cut_under_X += 1
            elif len(sy) > max_seq_len_y:
                cut_over_y += 1
            elif len(sy) < min_seq_len_y+2:
                cut_under_y += 1
            else:
                self.X.append(sx)
                self.y.append(sy)

        # Shuffle pairs together; list comprehensions also cope with an empty
        # dataset, where ``zip(*[])`` could not be unpacked into two names.
        pairs = list(zip(self.X, self.y))
        random.shuffle(pairs)
        self.X = [sx for sx, _ in pairs]
        self.y = [sy for _, sy in pairs]

        kept_pct = float(100.*len(self.X)/len(X)) if X else 0.0
        print("Dataset [{}] loaded with {} out of {} ({}%) instances.".format(type, len(self.X), len(X), kept_pct))
        print("\t\t For X, {} are over max_len {} and {} are under min_len {}.".format(cut_over_X, max_seq_len_X, cut_under_X, min_seq_len_X))
        print("\t\t For y, {} are over max_len {} and {} are under min_len {}.".format(cut_over_y, max_seq_len_y, cut_under_y, min_seq_len_y))

        assert len(self.X) == len(self.y)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [0]:
from google.colab import drive
drive.mount('/content/drive')
data_path = 'drive/My Drive/nlg/tiny'

# Choose the model family before building any lookup from it. Previously
# `model` was reassigned only after Lookup(model) had already been called, so
# this cell silently depended on a value left over from an earlier cell.
model = 'bert'
src_lookup = Lookup(model)
tgt_lookup = Lookup(model)
lookup = Lookup(model)  # read as a global by MyDataset when no lookups are passed

batch_size = 2
min_seq_len_X = 10
max_seq_len_X = 1000
min_seq_len_y = min_seq_len_X
max_seq_len_y = max_seq_len_X
MEI = "Management Overview"

In [0]:
import os, sys

import torch
import torch.nn as nn
from transformers import BertModel, BertConfig

class Encoder(nn.Module):
    """Frozen BERT encoder that returns per-token hidden states.

    Every BERT parameter has ``requires_grad = False`` and the forward pass runs
    under ``torch.no_grad()``, so this module acts as a fixed feature extractor.

    Args:
        vocab_size (int): tokenizer vocabulary size, including added tokens
            such as ``<PAD>``; the embedding matrix is resized to it.
        device (torch.device): device the module is moved to; inputs are moved
            there in ``forward``.
        pretrained_model_name (str or None): checkpoint to load weights from.
            ``None`` builds a randomly initialised ``BertModel(BertConfig())``.
    """

    def __init__(self, vocab_size, device, pretrained_model_name='bert-base-cased'):
        super().__init__()

        # The weights are frozen below, so a randomly initialised BERT would
        # only ever produce random features; load pretrained weights by default.
        if pretrained_model_name is None:
            configuration = BertConfig()
        else:
            configuration = BertConfig.from_pretrained(pretrained_model_name)
        configuration.output_attentions = True

        if pretrained_model_name is None:
            self.bertmodel = BertModel(configuration)
        else:
            self.bertmodel = BertModel.from_pretrained(pretrained_model_name, config=configuration)
        self.hidden_size = self.bertmodel.config.hidden_size

        self.bertmodel.resize_token_embeddings(vocab_size)
        for param in self.bertmodel.parameters():
            param.requires_grad = False

        self.device = device
        self.to(device)

    def forward(self, input_tuple):
        """Encode a padded batch with the frozen BERT.

        Args:
            input_tuple (tuple): ``(X, X_lengths, X_att_mask)``.
                - ``X``: ``[batch_size, seq_len_enc]`` padded token ids.
                - ``X_lengths``: ``[batch_size]`` sequence lengths.
                - ``X_att_mask``: ``[batch_size, seq_len_enc]`` mask
                  (1 = token, 0 = padding).

        Returns:
            dict:
                - ``'output'``: hidden states ``[batch_size, seq_len_enc, hidden_size]``,
                  with every position at or after a sequence's length set to zero.
                - ``'past'``: the pooled BERT output. It is not a key/value cache,
                  despite the key name, which is kept for compatibility.
                - ``'att'``: the per-layer self-attention weights.
        """
        self.bertmodel.eval()
        X, X_lengths, X_att_mask = input_tuple[0], input_tuple[1], input_tuple[2]
        X = X.to(self.device)
        if X_att_mask is not None:
            X_att_mask = X_att_mask.to(self.device)
        seq_len = X.size(1)

        with torch.no_grad():
            outputs = self.bertmodel(X, attention_mask = X_att_mask)
            if hasattr(outputs, 'last_hidden_state'):
                # ModelOutput object (newer transformers): read fields by name.
                # Tuple-unpacking it would iterate over its keys instead.
                hidden_states, pooled, att = outputs.last_hidden_state, outputs.pooler_output, outputs.attentions
            else:
                # Plain tuple (older transformers): (sequence, pooled, attentions).
                hidden_states, pooled, att = outputs[0], outputs[1], outputs[2]

            # Zero every position at or past each sequence's length.
            positions = torch.arange(seq_len, device=self.device).unsqueeze(0)
            keep = positions < X_lengths.to(self.device).unsqueeze(1)
            output = hidden_states * keep.unsqueeze(-1).to(hidden_states.dtype)

        return {'output': output, 'past': pooled, 'att': att}

In [0]:
# Build the train/dev DataLoaders for the configured MEI section.
train_loader, valid_loader = loader(
    data_path,
    batch_size,
    src_lookup,
    tgt_lookup,
    min_seq_len_X=min_seq_len_X,
    max_seq_len_X=max_seq_len_X,
    min_seq_len_y=min_seq_len_y,
    max_seq_len_y=max_seq_len_y,
    MEI=MEI,
)

In [0]:
from transformers import BertModel, BertConfig
class Decoder(nn.Module):
    """Frozen BERT decoder with a trainable vocabulary projection.

    The BERT runs with causal self-attention plus cross-attention over the
    encoder states. Only ``lin_out`` is trainable: every BERT weight has
    ``requires_grad = False`` and the BERT forward pass runs under
    ``torch.no_grad()``.

    Args:
        hidden_size (int): width of the BERT hidden states fed to ``lin_out``.
        vocab_size (int): tokenizer vocabulary size (embedding and output size).
        device (torch.device): device the module is moved to; inputs are moved
            there in ``forward``.
        pretrained_model_name (str or None): checkpoint to load BERT weights
            from. ``None`` builds a randomly initialised ``BertModel``.
    """

    def __init__(self, hidden_size, vocab_size, device = device, pretrained_model_name='bert-base-cased'):
        super(Decoder, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        if pretrained_model_name is None:
            configuration = BertConfig()
        else:
            configuration = BertConfig.from_pretrained(pretrained_model_name)
        configuration.is_decoder = True
        # Per the transformers BertModel docs, cross-attention layers are only
        # built when add_cross_attention is set (needed on v4+).
        configuration.add_cross_attention = True
        configuration.output_attentions = True

        if pretrained_model_name is None:
            self.bertmodel = BertModel(configuration)
        else:
            # NOTE(review): an encoder-only checkpoint presumably has no
            # cross-attention weights. Those layers start random and, being
            # frozen below, stay random. Consider unfreezing them.
            self.bertmodel = BertModel.from_pretrained(pretrained_model_name, config=configuration)
        self.bertmodel.resize_token_embeddings(vocab_size)  # include newly added tokens such as <PAD>
        for param in self.bertmodel.parameters():
            param.requires_grad = False

        self.lin_out = nn.Linear(hidden_size, vocab_size)
        # Normalise over the vocabulary (last) axis. The output is
        # [batch, seq, vocab], so dim=1 would normalise over positions.
        self.softmax = nn.LogSoftmax(dim = -1)

        self.device = device
        self.to(device)

    def forward(self, y_tuple, X_att_mask, encoder_hidden_states):
        """Decode the (teacher-forced) target batch against the encoder states.

        Args:
            y_tuple (tuple): ``(y, y_lengths, y_att_mask)``, shaped like the
                encoder input tuple.
            X_att_mask (Tensor): encoder padding mask ``[batch_size, seq_len_enc]``,
                used as the cross-attention mask.
            encoder_hidden_states (Tensor): encoder output ``[batch_size, seq_len_enc, hidden]``.

        Returns:
            dict:
                - ``'output'``: log-probabilities ``[batch_size, seq_len_dec, vocab_size]``.
                - ``'past'``: the pooled BERT output.
                - ``'att'``: the per-layer self-attention weights.
        """
        y = y_tuple[0].to(self.device)
        y_lenghts = y_tuple[1]
        y_att_mask = y_tuple[2].to(self.device)
        if X_att_mask is not None:
            X_att_mask = X_att_mask.to(self.device)
        y_seq_len = y.size(1)

        with torch.no_grad():
            outputs = self.bertmodel(y, attention_mask = y_att_mask,
                                     encoder_hidden_states = encoder_hidden_states,
                                     encoder_attention_mask = X_att_mask)
            if hasattr(outputs, 'last_hidden_state'):
                # ModelOutput object (newer transformers): read fields by name.
                hidden, past, decoder_attention = outputs.last_hidden_state, outputs.pooler_output, outputs.attentions
            else:
                # Plain tuple (older transformers): (sequence, pooled, attentions).
                hidden, past, decoder_attention = outputs[0], outputs[1], outputs[2]

            # Zero every position at or past each target's length.
            positions = torch.arange(y_seq_len, device=self.device).unsqueeze(0)
            keep = positions < y_lenghts.to(self.device).unsqueeze(1)
            output = hidden * keep.unsqueeze(-1).to(hidden.dtype)

        out_lin = self.lin_out(output)
        output = self.softmax(out_lin)

        return {'output': output, 'past': past, 'att': decoder_attention}

In [0]:
class EncoderDecoder(nn.Module):
    """Wrapper that encodes X, decodes y against the encoder states, and scores the result.

    Args:
        src_lookup, tgt_lookup: source/target lookups. Only ``bos_token``,
            ``eos_token`` and ``convert_tokens_to_ids`` are used here.
        encoder (nn.Module): called with ``(x, x_lengths, x_mask)``; must
            return a dict with key ``'output'``.
        decoder (nn.Module): called with
            ``(y_tuple, X_att_mask=..., encoder_hidden_states=...)``; must
            return a dict with keys ``'output'`` (log-probabilities) and ``'att'``.
        device (torch.device): device the whole module is moved to.
    """

    def __init__(self, src_lookup, tgt_lookup, encoder, decoder, device):
        super().__init__()

        # Whether a GPU is available. Deliberately not named ``self.cuda``,
        # which would shadow the ``nn.Module.cuda()`` method.
        self.use_cuda = torch.cuda.is_available()

        self.src_lookup = src_lookup
        self.tgt_lookup = tgt_lookup
        self.src_bos_token_id = src_lookup.convert_tokens_to_ids(src_lookup.bos_token)
        self.src_eos_token_id = src_lookup.convert_tokens_to_ids(src_lookup.eos_token)
        self.tgt_bos_token_id = tgt_lookup.convert_tokens_to_ids(tgt_lookup.bos_token)
        self.tgt_eos_token_id = tgt_lookup.convert_tokens_to_ids(tgt_lookup.eos_token)

        self.encoder = encoder
        self.decoder = decoder

        self.device = device
        self.to(self.device)

    def forward(self, X_tuple, y_tuple, teacher_forcing_ratio=0.):
        """Run encoder and decoder.

        ``teacher_forcing_ratio`` is currently unused: the decoder is always
        fed the gold target sequence.

        Returns:
            tuple: ``(decoder log-probabilities, decoder attentions)``.
        """
        x, x_lenghts, x_mask = X_tuple[0], X_tuple[1], X_tuple[2]
        encoder_dict = self.encoder((x, x_lenghts, x_mask))
        decoder_dict = self.decoder(y_tuple, X_att_mask = x_mask,
                                    encoder_hidden_states = encoder_dict['output'])
        return decoder_dict['output'], decoder_dict['att']

    def run_batch(self, X_tuple, y_tuple = None, criterion = None):
        """Forward one batch and, when ``criterion`` is given, compute the loss.

        The loss is shifted by one position: output position ``t`` is scored
        against target token ``t + 1``. With a causally masked (``is_decoder``)
        BERT, position ``t`` already sees token ``t`` in its input, so scoring
        it against token ``t`` would only teach copying. Padded target
        positions, per ``y_tuple[2]``, are excluded.

        Returns:
            tuple: ``(log-probabilities, loss, decoder attentions)``. The loss
            is ``0`` when no ``criterion`` is given.
        """
        y = y_tuple[0]
        output_decoder, attention_decoder = self.forward(X_tuple, y_tuple, teacher_forcing_ratio=0.)

        total_loss = 0
        if criterion is not None:
            logits = output_decoder[:, :-1, :]
            targets = y[:, 1:].to(output_decoder.device)
            keep = y_tuple[2][:, 1:].to(output_decoder.device).bool()
            total_loss = criterion(logits[keep], targets[keep])

        return output_decoder, total_loss, attention_decoder

    def load_checkpoint(self, folder, extension):
        """Load ``<folder>/checkpoint.<extension>`` and return its ``extra`` dict.

        Raises:
            Exception: if the checkpoint file does not exist.
        """
        filename = os.path.join(folder, "checkpoint." + extension)
        print("Loading model {} ...".format(filename))
        if not os.path.exists(filename):
            print("\tModel file not found, not loading anything!")
            raise Exception("Error, model file not found! {} -> model {}".format(folder, extension))

        checkpoint = torch.load(filename, map_location=self.device)
        self.load_state_dict(checkpoint["state_dict"])

        self.encoder.to(self.device)
        self.decoder.to(self.device)
        return checkpoint["extra"]

    def save_checkpoint(self, folder, extension, extra=None):
        """Save the state dict plus ``extra`` (default ``{}``) to ``<folder>/checkpoint.<extension>``."""
        filename = os.path.join(folder, "checkpoint." + extension)
        checkpoint = {
            "state_dict": self.state_dict(),
            "extra": {} if extra is None else extra,
        }
        torch.save(checkpoint, filename)


In [0]:
# Derive sizes from the lookups and the encoder itself. `hidden_size` was
# previously never defined (NameError), and `vocab_size` came from the demo cell.
vocab_size = len(tgt_lookup)
encoder = Encoder(vocab_size=len(src_lookup), device = device)
hidden_size = encoder.hidden_size
decoder = Decoder(hidden_size=hidden_size, vocab_size=vocab_size, device = device)

In [0]:
# Fresh BERT lookups for each side, wired with the encoder/decoder into the
# seq2seq wrapper. Note: this rebinds `model` from the string 'bert' to the module.
src_lookup, tgt_lookup = Lookup('bert'), Lookup('bert')
model = EncoderDecoder(
    src_lookup=src_lookup,
    tgt_lookup=tgt_lookup,
    encoder=encoder,
    decoder=decoder,
    device=device,
)

In [0]:
def clean_sequences(sequences, lookup):
    """Strip the leading start token and cut each sequence at its first end token.

    The start/end tokens are the lookup's BOS/EOS. When those are unset (as
    for BERT lookups), the function falls back to CLS/SEP, matching how
    ``Lookup.encode`` wraps BERT sequences.

    Args:
        sequences (list): list of lists of token ids.
        lookup: object exposing ``bos_token``/``eos_token`` (and optionally
            ``cls_token``/``sep_token``) plus ``convert_tokens_to_ids``.

    Returns:
        list: a new list of cleaned id lists. A start token is removed only at
        position 0; an end token truncates only from position 1 onwards.
    """
    start_token = lookup.bos_token or getattr(lookup, 'cls_token', None)
    end_token = lookup.eos_token or getattr(lookup, 'sep_token', None)
    start_id = lookup.convert_tokens_to_ids(start_token) if start_token else None
    end_id = lookup.convert_tokens_to_ids(end_token) if end_token else None

    cleaned_sequences = []
    for seq in sequences:
        body = []
        for i, value in enumerate(seq):
            if i == 0 and value == start_id:  # skip the start token
                continue
            if i > 0 and value == end_id:  # stop before the first end token
                break
            body.append(value)
        cleaned_sequences.append(body)

    return cleaned_sequences

In [0]:
# Train with early stopping on the average dev loss.
patience = 10
max_epochs = 40
current_patience = patience
current_epoch = 0
num_examples_to_show = 2
best_dev_loss = float('inf')

# Create the optimizer once so Adam's moment estimates persist across epochs.
# Only parameters that require gradients are passed (BERT weights are frozen).
optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad],
                             lr=3e-4, amsgrad=True)  # , weight_decay=1e-3)
# Never score padding positions, whichever loss masking run_batch applies.
criterion = torch.nn.NLLLoss(ignore_index=tgt_lookup.convert_tokens_to_ids(tgt_lookup.pad_token))

while current_patience > 0 and current_epoch < max_epochs:
    # ---- train ----
    model.train()
    total_loss = 0.
    train_bar = tqdm(train_loader, mininterval=0.5, desc="Epoch " + str(current_epoch)+" [train]", unit="b")
    for X_tuple, y_tuple in train_bar:
        optimizer.zero_grad()  # otherwise gradients accumulate across batches
        output_decoder, loss, attention_decoder = model.run_batch(X_tuple, y_tuple, criterion = criterion)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
        optimizer.step()
        total_loss += loss.item()
    train_loss = total_loss / max(len(train_loader), 1)

    # ---- dev ----
    model.eval()
    dev_total = 0.
    y_gold = []
    y_predicted = []
    with torch.no_grad():
        valid_bar = tqdm(valid_loader, mininterval=0.5, desc="Epoch " + str(current_epoch)+" [valid]", unit="b")
        for X_tuple, y_tuple in valid_bar:
            output_decoder, loss, attention_decoder = model.run_batch(X_tuple, y_tuple, criterion = criterion)
            dev_total += loss.item()
            lengths = y_tuple[1].tolist()
            # Trim padding so decoded examples show only real positions.
            y_predicted += [seq[:n] for seq, n in zip(output_decoder.argmax(dim=2).tolist(), lengths)]
            y_gold += [seq[:n] for seq, n in zip(y_tuple[0].tolist(), lengths)]
    dev_loss = dev_total / max(len(valid_loader), 1)
    print("Epoch {}: train loss {:.4f}, dev loss {:.4f}".format(current_epoch, train_loss, dev_loss))

    # zip() stops early, so fewer dev examples than requested is fine.
    for pred_ids, gold_ids in list(zip(y_predicted, y_gold))[:num_examples_to_show]:
        print("Y Pred: ")
        print(tgt_lookup.decode(pred_ids, skip_bos_eos_tokens = True))
        print("Y Gold: ")
        print(tgt_lookup.decode(gold_ids, skip_bos_eos_tokens = True))

    # ---- early stopping ----
    if dev_loss < best_dev_loss:
        best_dev_loss = dev_loss
        current_patience = patience
    else:
        current_patience -= 1
    current_epoch += 1

