In [1]:
import torch, torchdata, torchtext
from torch import nn
import torch.nn.functional as F
from tqdm import tqdm
import random, math, time
from torch.autograd import Variable
import operator
import numpy as np

# Pick the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Make runs comparable after a kernel restart: seed every RNG this notebook
# imports (Python's `random`, NumPy and PyTorch), not only PyTorch.
SEED = 6969
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

cpu


## Positional Encoding
$$
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)\\
PE_{(pos,\,2i+1)} &= \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)
\end{aligned}
$$

In [2]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to embeddings, then apply dropout.

    The encoding table is built once in ``__init__`` and stored as a buffer
    (it moves with the module but is not a trainable parameter).

    Args:
        dim_model: size of the embedding dimension; may be odd or even.
        dropout:   dropout probability applied after adding the encoding.
        max_len:   number of positions pre-computed in the table.
    """

    def __init__(self, dim_model, dropout = 0.1, max_len = 5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Encoding - from formula
        pos_encoding = torch.zeros(max_len, dim_model)
        positions_list = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # 1 / 10000^(2i/dim_model), computed in log space
        division_term = torch.exp(torch.arange(0, dim_model, 2).float() * (-np.log(10000.0)) / dim_model)
        angles = positions_list * division_term  # [max_len, ceil(dim_model / 2)]

        # PE(pos, 2i) = sin(pos / 10000^(2i/dim_model))
        pos_encoding[:, 0::2] = torch.sin(angles)

        # PE(pos, 2i + 1) = cos(pos / 10000^(2i/dim_model))
        # For an odd dim_model there is one fewer odd column than even columns,
        # so the angle table is trimmed to the number of odd columns.
        pos_encoding[:, 1::2] = torch.cos(angles[:, :dim_model // 2])

        # Saving buffer (same as parameter without gradients needed);
        # stored as [max_len, 1, dim_model] so it broadcasts over the batch axis.
        pos_encoding = pos_encoding.unsqueeze(1)
        self.register_buffer("pos_encoding", pos_encoding)

    def forward(self, token_embedding):
        """Return ``dropout(token_embedding + encoding)``.

        ``token_embedding`` is indexed along dim 0 by position, i.e. it is
        expected as [seq_len, batch_size, dim_model].

        Raises:
            ValueError: if the sequence is longer than the pre-computed table.
        """
        seq_len = token_embedding.size(0)
        if seq_len > self.pos_encoding.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {self.pos_encoding.size(0)}"
            )
        # Residual connection + pos encoding
        token_embedding = token_embedding + self.pos_encoding[:seq_len, :]
        return self.dropout(token_embedding)

In [3]:
class seq2seqTransformer(nn.Module):
    """Transformer encoder-decoder over a sentence plus a syntax sequence.

    The encoder input is the sentence embeddings (no positional encoding)
    concatenated along the sequence axis with the positionally-encoded syntax
    embeddings. The decoder is trained with teacher forcing in ``forward`` and
    run auto-regressively in ``generate``.

    Args:
        input_dim:        vocabulary size (embedding rows and output logits).
        emb_dim:          embedding / model dimension.
        device:           stored on ``self.device`` for callers; masks are
                          created on the device of the input tensors.
        word_dropout:     probability of dropping a sentence position during
                          training (the same positions are dropped for every
                          sentence in the batch).
        dropout:          dropout used inside ``nn.Transformer``.
        nhead:            number of attention heads. ``None`` keeps the
                          original 8 heads when ``emb_dim`` is divisible by 8,
                          otherwise uses the largest head count below 8 that
                          divides ``emb_dim`` (e.g. 6 for ``emb_dim=300``).
        word_dropout_idx: vocabulary index written into dropped positions.
                          NOTE(review): 0 is assumed to be a harmless
                          filler token in your vocabulary - confirm.
    """

    def __init__(self, input_dim, emb_dim, device, word_dropout = 0.4, dropout = 0.1,
                 nhead = None, word_dropout_idx = 0):
        super(seq2seqTransformer, self).__init__()
        self.input_dim = input_dim
        self.emb_dim = emb_dim
        self.word_dropout = word_dropout
        self.word_dropout_idx = word_dropout_idx
        self.dropout = dropout
        self.device = device

        # nn.Transformer requires d_model to be divisible by nhead.
        if nhead is None:
            nhead = next(h for h in range(8, 0, -1) if emb_dim % h == 0)
        elif emb_dim % nhead != 0:
            raise ValueError(f"emb_dim ({emb_dim}) must be divisible by nhead ({nhead})")
        self.nhead = nhead

        # embeddings are multiplied by sqrt(emb_dim) before entering the transformer
        self.scale = float(np.sqrt(self.emb_dim))

        # vocabulary embedding
        self.embedding_encoder = nn.Embedding(input_dim, emb_dim)
        self.embedding_decoder = nn.Embedding(input_dim, emb_dim)
        # positional encoding
        self.pos_encoder = PositionalEncoding(emb_dim, dropout = 0.0)

        self.transformer = nn.Transformer(d_model = emb_dim, nhead = nhead, dropout = dropout)
        # linear transformation to vocabulary logits
        self.linear = nn.Linear(emb_dim, input_dim)
        # self.init_weights()

    def load_embedding(self, embedding):  #synPG applied with GloVe glove.840B.300d.txt
        """Copy a NumPy embedding matrix into both embedding tables (independent copies)."""
        self.embedding_encoder.weight.data.copy_(torch.from_numpy(embedding))
        self.embedding_decoder.weight.data.copy_(torch.from_numpy(embedding))

    def init_weights(self):
        """Uniformly initialise the embeddings and output layer in [-0.1, 0.1]; zero the bias."""
        initrange = 0.1
        # initialize vocabulary matrix weight
        self.embedding_encoder.weight.data.uniform_(-initrange, initrange)
        self.embedding_decoder.weight.data.uniform_(-initrange, initrange)
        # initialize linear weight
        self.linear.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.fill_(0)

    def generate_square_mask(self, max_sent_len, max_synt_len):
        """Build the additive encoder attention mask.

        The first ``max_sent_len`` positions and the remaining
        ``max_synt_len + 2`` positions may only attend within their own group:
        the two off-diagonal blocks are set to ``-inf``.
        """
        size = max_sent_len + max_synt_len + 2 #<sos> and <eos>
        mask = torch.zeros((size, size))
        mask[:max_sent_len, max_sent_len:] = float("-inf")
        mask[max_sent_len:, :max_sent_len] = float("-inf")
        return mask

    def _embed_source(self, sents, synts):
        """Embed sentence + syntax indices and build the matching encoder mask."""
        # sentence, syntax => embedding; each is [seq_len, batch size, emb_size]
        sent_embeddings = self.embedding_encoder(sents).transpose(0, 1) * self.scale
        synt_embeddings = self.embedding_encoder(synts).transpose(0, 1) * self.scale
        # only the syntax sequence receives positional information
        synt_embeddings = self.pos_encoder(synt_embeddings)
        # concatenated along the sequence axis:
        # [sent_len + synt_len, batch size, emb_size]
        en_embeddings = torch.cat((sent_embeddings, synt_embeddings), dim=0)

        # do not allow cross attention between the sentence and syntax parts
        max_sent_len = sents.size(1)
        max_synt_len = synts.size(1) - 2    # count without <sos> and <eos>
        src_mask = self.generate_square_mask(max_sent_len, max_synt_len).to(sents.device)
        return en_embeddings, src_mask

    def _embed_target(self, idxs):
        """Embed decoder input indices and build the causal (subsequent) mask."""
        de_embeddings = self.embedding_decoder(idxs).transpose(0, 1) * self.scale
        de_embeddings = self.pos_encoder(de_embeddings)
        trg_mask = self.transformer.generate_square_subsequent_mask(de_embeddings.size(0)).to(idxs.device)
        return de_embeddings, trg_mask

    def forward(self, sents, synts, trg):
        """Teacher-forced forward pass.

        Args:
            sents: [batch_size, sent_len] token indices.
            synts: [batch_size, synt_len] token indices (including <sos>/<eos>).
            trg:   [batch_size, trg_len] token indices (including <sos>/<eos>).

        Returns:
            Logits of shape [batch_size, trg_len - 1, input_dim]; position t
            is the prediction for ``trg[:, t + 1]``.
        """
        max_sent_len = sents.size(1)

        # Word dropout is a training-time regulariser: skip it in eval mode.
        # Dropped positions are replaced by a valid vocabulary index (the
        # tensor holds indices, so a large negative fill value would be an
        # out-of-range embedding lookup).
        if self.training and self.word_dropout > 0:
            drop_mask = torch.bernoulli(self.word_dropout * torch.ones(max_sent_len)).bool().to(sents.device)
            sents = sents.masked_fill(drop_mask, self.word_dropout_idx)

        en_embeddings, src_mask = self._embed_source(sents, synts)

        # the decoder sees the target without its last token
        de_embeddings, trg_mask = self._embed_target(trg[:, :-1])

        # forward: [trg_len - 1, batch size, emb_size]
        outputs = self.transformer(en_embeddings, de_embeddings, src_mask=src_mask, tgt_mask=trg_mask)

        # apply linear layer to vocabulary size
        outputs = self.linear(outputs.transpose(0, 1))
        #output = [batch size, trg_len - 1, vocab_size]
        return outputs

    def generate(self, sents, synts, max_len = 30, sample=True, temp=0.5, verbose=False):
        """Auto-regressively decode ``max_len + 1`` tokens per input.

        Args:
            sents:   [batch_size, sent_len] token indices.
            synts:   [batch_size, synt_len] token indices (including <sos>/<eos>).
            max_len: number of decoding steps minus one.
            sample:  sample from the temperature-scaled softmax when True,
                     otherwise take the argmax.
            temp:    softmax temperature used when sampling.
            verbose: print the decoding step every 5 steps.

        Returns:
            LongTensor [batch_size, max_len + 1] of generated indices (the
            leading <sos> is stripped). Decoding does not stop early.
        """
        batch_size = sents.size(0)
        max_targ_len = max_len

        # output index starts with <sos> (index 1)
        idxs = torch.zeros((batch_size, max_targ_len + 2), dtype=torch.long, device=sents.device)
        idxs[:, 0] = 1

        # encode once; the memory is reused at every decoding step
        en_embeddings, src_mask = self._embed_source(sents, synts)
        memory = self.transformer.encoder(en_embeddings, mask=src_mask)

        # auto-regressively generate output
        for i in range(1, max_targ_len + 2):
            if verbose and i % 5 == 0:
                print(f'decoding step : {i}')

            # decode everything generated so far and keep the last position
            de_embeddings, trg_mask = self._embed_target(idxs[:, :i])
            outputs = self.transformer.decoder(de_embeddings, memory, tgt_mask=trg_mask)
            outputs = self.linear(outputs[-1])

            # get argmax index or sample index
            if not sample:
                idx = outputs.argmax(dim=1)
            else:
                probs = F.softmax(outputs / temp, dim=1)
                idx = torch.multinomial(probs, 1).squeeze(1)

            # save to output index
            idxs[:, i] = idx

        return idxs[:, 1:]

## FastText Embedding

In [None]:
from torchtext.vocab import FastText
# Load the pretrained fastText vectors for language code 'simple'.
fast_vectors = FastText(language='simple') # language='simple'
# Look up one vector per vocabulary token, in vocabulary (itos) order, and move
# the resulting matrix to `device`.
# NOTE(review): `vocab_transform` is not defined in the cells visible here -
# confirm it is created before this cell runs.
fast_embedding = fast_vectors.get_vecs_by_tokens(vocab_transform.get_itos()).to(device)

In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Take the model sizes from the embedding matrix so the embedding tables and
# the output layer always agree with the vocabulary.
input_dim    = fast_embedding.size(0)  # len(vocab): one row per vocabulary token
emb_dim      = fast_embedding.size(1)  # 300 for fastText
word_dropout = 0.4 #following SynPG
dropout      = 0.1

model = seq2seqTransformer(input_dim=input_dim, emb_dim = emb_dim, device=device, word_dropout = word_dropout, dropout = dropout)
# apply fasttext instead of Glove 840b 300d.txt (5.56 GB)
# copy_ gives the encoder and decoder their own weights; assigning the same
# tensor to both `.weight.data` would make them share one storage.
model.embedding_encoder.weight.data.copy_(fast_embedding)
model.embedding_decoder.weight.data.copy_(fast_embedding)
# keep every parameter on the same device as the inputs
model = model.to(device)

In [None]:
#check_model
# sents   = torch.LongTensor(16, 30).random_(0, 10) #sents: [batch size, seq_len]
# synts   = torch.LongTensor(16, 30).random_(0, 10) #synts: [batch size, seq_len]
# output  = model.generate(sents, synts)
# output.shape

epoch : 5
epoch : 10
epoch : 15
epoch : 20
epoch : 25
epoch : 30


In [1]:
def train(model, loader, optimizer, criterion, clip, loader_length):
    """Run one training epoch and return the mean loss per batch.

    Args:
        model:         called as ``model(sents, synts, trgs)``; must return
                       logits whose last dimension is the vocabulary size.
        loader:        iterable of ``(sents, synts, trgs)`` batches.
        optimizer:     optimizer over ``model.parameters()``.
        criterion:     loss taking ``(logits [N, vocab], targets [N])``.
        clip:          max gradient norm passed to ``clip_grad_norm_``.
        loader_length: number of batches, used as the divisor of the summed loss.

    Returns:
        Summed batch loss divided by ``loader_length``.
    """
    model.train()
    epoch_loss = 0
    for sents_, synts_, trgs_ in loader:
        # gradients are reset once per batch
        optimizer.zero_grad()

        # forward
        outputs = model(sents_, synts_, trgs_)

        # calculate loss against the target shifted by one (without <SOS>)
        targs_ = trgs_[:, 1:].contiguous().view(-1)
        outputs_ = outputs.contiguous().view(-1, outputs.size(-1))
        loss = criterion(outputs_, targs_)
        loss.backward()

        # clip before stepping to keep the update bounded
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / loader_length

def evaluate(model, loader, criterion, loader_length):
    """Compute the mean loss per batch without updating the model.

    Args:
        model:         called as ``model(sents, synts, trgs)``; must return
                       logits whose last dimension is the vocabulary size.
        loader:        iterable of ``(sents, synts, trgs)`` batches.
        criterion:     loss taking ``(logits [N, vocab], targets [N])``.
        loader_length: number of batches, used as the divisor of the summed loss.

    Returns:
        Summed batch loss divided by ``loader_length``.
    """
    # turn off dropout (and batch norm if used)
    model.eval()
    epoch_loss = 0

    # no gradients are needed for evaluation
    with torch.no_grad():
        for sents_, synts_, trgs_ in loader:
            # forward
            outputs = model(sents_, synts_, trgs_)

            # calculate loss against the target shifted by one (without <SOS>)
            targs_ = trgs_[:, 1:].contiguous().view(-1)
            outputs_ = outputs.contiguous().view(-1, outputs.size(-1))

            epoch_loss += criterion(outputs_, targs_).item()

    return epoch_loss / loader_length

In [7]:
def epoch_time(start_time, end_time):
    """Split the span between two timestamps (in seconds) into whole minutes and seconds.

    Both parts are truncated toward zero with ``int``.
    """
    duration = end_time - start_time
    whole_minutes = int(duration / 60)
    leftover_seconds = int(duration - 60 * whole_minutes)
    return whole_minutes, leftover_seconds

In [None]:
import torch.optim as optim

# Training hyperparameters.
lr = 10e-4 #Following SynPG; note that 10e-4 == 1e-3
wd = 10e-5 #Following SynPG; note that 10e-5 == 1e-4 (passed to Adam as weight_decay)
# Adam over every parameter of `model`.
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
# Targets equal to PAD_IDX do not contribute to the loss.
# NOTE(review): PAD_IDX is not defined in the cells visible here - confirm it
# is set before this cell runs.
criterion = nn.CrossEntropyLoss(ignore_index = PAD_IDX) #combine softmax with cross entropy

In [None]:
import os
import time
import math

best_valid_loss = float('inf')
num_epochs = 5
clip       = 1

save_path = f'models/{model.__class__.__name__}.pt' #Change here
# torch.save does not create missing directories
os.makedirs(os.path.dirname(save_path), exist_ok=True)

train_losses = []
valid_losses = []

for epoch in range(num_epochs):

    start_time = time.time()

    # training + validation
    # NOTE(review): train() and evaluate() require their arguments. The data
    # loaders and their lengths are not defined in the cells visible here;
    # `train_loader`, `valid_loader`, `train_loader_length` and
    # `val_loader_length` are assumed names - confirm against the data cells.
    train_loss = train(model, train_loader, optimizer, criterion, clip, train_loader_length)
    valid_loss = evaluate(model, valid_loader, criterion, val_loader_length)

    # for plotting (one point per epoch)
    train_losses.append(train_loss)
    valid_losses.append(valid_loss)
    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    # save the model whenever the validation loss improves
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), save_path)

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f}')

    #lower perplexity is better

## Plot the losses

In [None]:
#code here
import matplotlib.pyplot as plt
# One loss value is appended per epoch, so the x-axis counts epochs.
fig, ax = plt.subplots(figsize=(5, 3))
ax.plot(train_losses, label = 'train loss')
ax.plot(valid_losses, label = 'valid loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('Training vs. validation loss')
ax.legend();

In [None]:
from utils import synt2str, sent2str, load_embedding

def generate(model, loader, criterion, loader_length,vocab_transform):
    """Generate an output for every batch in ``loader`` and print the results.

    For each example this prints, in order: the syntax input (without its
    first token), the input sentence, the generated indices rendered with
    ``synt2str``, and a ``--`` separator.

    Args:
        model:           object exposing ``generate(sents, synts, max_len, temp=...)``.
        loader:          iterable of ``(sents, synts, trgs)`` batches.
        criterion:       unused; kept so existing callers keep working.
        loader_length:   unused; kept so existing callers keep working.
        vocab_transform: passed through to ``synt2str`` / ``sent2str``.
    """
    # turn off dropout (and batch norm if used)
    model.eval()

    with torch.no_grad():
        for sents_, synts_, trgs_ in loader:
            # generate, using the input sentence length as max_len
            idxs = model.generate(sents_, synts_, sents_.size(1), temp=0.5)

            # write output
            for sent, idx, synt in zip(sents_.cpu().numpy(), idxs.cpu().numpy(), synts_.cpu().numpy()):
                print(synt2str(synt[1:], vocab_transform)+'\n')
                print(sent2str(sent, vocab_transform)+'\n')
                print(synt2str(idx, vocab_transform)+'\n')
                print("--\n")