In [6]:
import os
import torch
import torch.nn as nn


### Loading The Pre-trained BERT

In [10]:
model_path = '../../XLM/dumped/xlm_ar/u7t8spazn5/checkpoint.pth'
reloaded = torch.load(model_path)

### Helper function to convert sentenes into BPE format

In [11]:
# Below is one way to bpe-ize sentences
codes = os.path.join(params.data_path, 'codes')
fastbpe = os.path.join(os.getcwd(), 'tools/fastBPE/fast')

def to_bpe(sentences):
    # write sentences to tmp file
    with open('/tmp/sentences', 'w') as fwrite:
        for sent in sentences:
            fwrite.write(sent + '\n')
    
    # apply bpe to tmp file
    os.system('%s applybpe /tmp/sentences.bpe /tmp/sentences %s' % (fastbpe, codes))
    
    # load bpe-ized sentences
    sentences_bpe = []
    with open('/tmp/sentences.bpe') as f:
        for line in f:
            sentences_bpe.append(line.rstrip())
    
    return sentences_bpe



NameError: name 'params' is not defined

In [7]:
class RNNAttnDecoder(nn.Module):
    
    '''
    RNN decoder class for generating output from BERT
    
    '''
    
    def __init__(self, embeddings_dim, vocab_size, hid_dim, n_layers, bert_dim, dropout):
        
        super().__init__()
        self.hid_dim = hid_dim
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding = nn.Embedding(embedding_dim=embeddings_dim, num_embeddings=vocab_size)# initialize decoder embeddings with the pretrained embeddings
        self.rnn = nn.GRU(embeddings_dim + bert_dim, hid_dim, n_layers, dropout = dropout)
        self.out = nn.Linear(hid_dim, vocab_size)
        
        self.attn = nn.Linear(bert_dim + hid_dim , hid_dim)
        self.v = nn.Parameter(torch.rand(hid_dim))
        self.tanh = torch.nn.Tanh()
        self.attn_softmax = torch.nn.Softmax(dim=1)
        self.out_softmax = torch.nn.LogSoftmax(dim=2)

    
    def forward(self, x, hidden, encoder_outputs):
        
        '''
        Decode only one timestep
        
        x is (batch_size)
        hidden is (n_layers, batch_size, hid_dim)
        
        encoder_outputs is (S, B, D)
        '''
        assert x.size(0) == hidden.size(1)
        
        
        # Attention
        bs = hidden.size(1)
        src_len= encoder_outputs.size(0)
        
        last_hidden = hidden[-1,:,:] # (B x D) use only the hidden state of the last layer.
        
        hidden_repeated = last_hidden.unsqueeze(0).repeat(src_len, 1, 1)
        
        # compute attention weights
        attn_prod = self.attn(torch.cat([hidden_repeated, encoder_outputs], dim=2)) # S x B x H
        attn_energy = self.tanh(attn_prod).permute([1,0,2]).contiguous() # B x S x H
        
        v= self.v.view(1, -1, 1).repeat(bs, 1, 1) # B x D x 1
        attn_weights = self.attn_softmax(torch.bmm(attn_energy, v)).permute([0, 2, 1]) # B x 1 x S
        
        encoder_outputs = encoder_outputs.permute([1, 0, 2]) # B x S x D
        context_vector = torch.bmm(attn_weights, encoder_outputs).squeeze(1).unsqueeze(0) #1 x B x D
        
        
        # rnn input
        x = x.unsqueeze(0) # since sequence length is one (only one timestep)
        x_emb = self.embedding(x) #1 x B x D
        
        rnn_input = torch.cat([x_emb, context_vector], dim=2)
        output, hidden = self.rnn(rnn_input, hidden) # 1 x B x H
        
        prediction = self.out_softmax(self.out(output)) # 1 x B x V
        
        return prediction, hidden
          

In [8]:
class RNNDecoder(nn.Module):
    
    '''
    RNN decoder class for generating output from BERT
    
    '''
    
    def __init__(self, embeddings_dim, vocab_size, hid_dim, n_layers, bert_dim, dropout):
        
        super().__init__()
        self.hid_dim = hid_dim
        self.output_dim = vocab_size
        self.n_layers = n_layers
        self.dropout = dropout

        self.embedding = nn.Embedding(embedding_dim=embeddings_dim, num_embeddings=vocab_size)# initialize decoder embeddings with the pretrained embeddings
        self.rnn = nn.GRU(embeddings_dim, hid_dim, n_layers, dropout = dropout)
        self.out = nn.Linear(hid_dim, vocab_size)
        self.out_softmax = torch.nn.LogSoftmax(dim=2)

    
    def forward(self, x, hidden, encoder_outputs):
        
        '''
        Decode only one timestep
        
        x is (batch_size)
        hidden is (n_layers, batch_size, hid_dim)
        
        encoder_outputs is (S, B, D)
        '''
        assert x.size(0) == hidden.size(1)
        
        # Attention
        bs = hidden.size(1)
        src_len= encoder_outputs.size(0)

        # rnn input
        x = x.unsqueeze(0) # since sequence length is one (only one timestep)
        x_emb = self.embedding(x) #1 x B x D
        
        output, hidden = self.rnn(x_emb, hidden) # 1 x B x H
        prediction = self.out(output) # 1 x B x V
        
        
        return prediction, hidden
          

## Transpoemer Class

In [12]:
import re
class Transpoemer(nn.Module):
    
    def __init__(self, bert, params):
        super().__init__()
        
        self.bert = bert
        self.params = params
        self.decoder = RNNAttnDecoder(300, vocab_size=params.n_words,
                                  hid_dim=params.decoder_hidden, n_layers=params.decoder_n_layers,
                                  bert_dim= bert.dim,
                                 dropout=params.dropout)
        
        self.criterion = nn.CrossEntropyLoss(ignore_index = self.params.pad_index)
        self.bert_to_decoder= nn.Linear(bert.dim, params.decoder_hidden*params.decoder_n_layers)
        
        self.tanh =nn.Tanh()
    def forward(self, x, x_lengths, y=None, max_generation_length=20):

        '''
        This should autoregreesively generate the next token 
        
        x is the input sequence [SrcSeqLen x B]
        y is [TrgSeqLen x B] and is used only during finetuning
        
        if labels are not None, we return logits and cross entropy loss,
        otherwise we justreturns logits tensor of size []
        
        returned logits is [TrgSeqxLen x B x V]
        
        '''
    
        bs = x.size(1)
    
        # get bert embeddings
        bert_embeddings = self.bert('fwd', x=x, lengths=x_lengths, causal=True).contiguous()
        # use embeddings max pooling as init hidden of the decoder
        #bert_embeddings_maxpool, _ = bert_embeddings.max(dim=0) # B x D
        bert_embeddings_meanpool = bert_embeddings.mean(dim=0)
        
        decoder_hidden =self.tanh(self.bert_to_decoder(bert_embeddings_meanpool)) # B x LH
        decoder_hidden = decoder_hidden.view(self.params.decoder_n_layers, 
                                                       bert_embeddings_meanpool.size(0), -1).contiguous() #(L x B x Decoder_H)
        
        # autoregressive decoding 
        
        all_outputs= []
        
        if y is not None: # training case

            decoder_input = y[0,:] # <s> symbol
            y = y[1:, :] # skip start symbol to avoid repeating it
            
            for t in range(y.shape[0]):
                output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, bert_embeddings)
                all_outputs.append(output)
                decoder_input= y[t,:]
        
            #print(len(all_outputs))
        
        else : # greedy decoding for generation 
            
            decoder_input = torch.LongTensor([dico.word2id['<s>']] * bs).to("cuda") ## TODO: use the right device
            for i in range(max_generation_length):
                # (B)
                output, decoder_hidden = self.decoder(decoder_input, decoder_hidden, bert_embeddings) # 1 x B x V
                
                _, topi =  output.max(dim=2) # use predictions as input again # 1 x B 
                all_outputs.append(topi) 
                decoder_input = topi.squeeze(0) # for next step
                        
        all_outputs= torch.cat(all_outputs, dim=0) # trgSeqLen x B x V
    
        # computing loss if any
        loss = None
        
        if y is not None:
            loss = self.criterion(all_outputs.view(-1, all_outputs.size(-1)), y.view(-1))     
        
        return all_outputs, loss
    
    
    def generate_next_verse(self, verse):
    
        verse_bpe = to_bpe([verse])
        ids, lengths= shape_sentences(verse_bpe)

        generated_word_ids, _ = self.forward(ids.to("cuda"), lengths.to("cuda"))
        generated_word_ids = generated_word_ids.squeeze(1).cpu().numpy()
        words = [dico.id2word[id] for id in generated_word_ids]
        
        words = [w for w in words if w not in ['<s>', '</s>', '<unk>']]
        sent = ' '.join(words)
        

In [13]:
def generate_next_verse(verse):

    verse = '<s> %s </s>' %(verse)
    verse_bpe = to_bpe([verse])
    
    ids, lengths= shape_sentences(verse_bpe)
    generated_word_ids, _ = transpoemer(ids.to("cuda"), lengths.to("cuda"))
    print(generated_word_ids.size())
    generated_word_ids = generated_word_ids.squeeze(1).cpu().numpy()
    words = [dico.id2word[id] for id in generated_word_ids]

    words = [w for w in words if w not in ['<s>', '</s>', '<unk>']]
    sent = ' '.join(words)
    return sent


        

In [15]:
def shape_sentences(sentences):
    bs = len(sentences)
    slen = max([len(sent) for sent in sentences])

    word_ids = torch.LongTensor(slen, bs).fill_(params.pad_index).contiguous()
    for i in range(len(sentences)):
        sent = torch.LongTensor([dico.index(w) for w in sentences[i]]).contiguous()
        word_ids[:len(sent), i] = sent

    lengths = torch.LongTensor([len(sent) for sent in sentences])
    
    return word_ids, lengths


## Poetry Dataset

In [16]:
import numpy as np

class PoetryDataset(torch.utils.data.Dataset):
    
    def __init__(self, poetry_file, max_len=50):
        
        self.poetry_file=poetry_file 
        verses=[]
        cnt=0
        with open(poetry_file, 'r', encoding='utf-8') as f:
            for verse in f:
                verses.append(verse.strip())
                cnt+=1
        
        print("BPEing verses...")
        
        verses= to_bpe(verses)
        print("Done.")
        tokenized_verses=[]
        for verse in verses:
            if not verse.strip(): # skip empty verses
                continue
            if verse.strip() == '#': # end of poem
                tokenized_verses.append('<s> <special0> </s>'.split()) # add poem separator
            else:
                tokens = verse.strip().split()[:max_len - 2]
                tokens = ['<s>'] + tokens + ['</s>']
                tokenized_verses.append(tokens)
        
        tokenized_verses = np.array(tokenized_verses)
        # creating X, Y
        X = []
        Y = []
        
        even_length = (len(tokenized_verses) //2) * 2
        tokenized_verses= tokenized_verses[:even_length]
        
        for i in range(even_length - 1):
            X.append(tokenized_verses[i])
            Y.append(tokenized_verses[i+1])
            
        assert len(X) == len(Y)
        
        self.X, self.X_len = shape_sentences(X) # Len x N
        self.Y, self.Y_len = shape_sentences(Y)
        
        print(self.X.shape)
    def __len__(self):
        return self.X.shape[1]
    
    
    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        
        return self.X[:, idx], self.X_len[idx], self.Y[:, idx], self.Y_len[idx]
        #ids, lens = shape_sentences(tokens)
        #ids = ids.T.contiguous()
        #print(ids.size())
        #return ids.T, lens
    

In [16]:
dataset=PoetryDataset('poems_separated_verses.txt')

BPEing verses...
Done.
torch.Size([50, 590001])


# Training Loop

In [17]:
import os
from torch.utils.data import DataLoader
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage, Accuracy
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import CosineAnnealingScheduler, PiecewiseLinear, create_lr_scheduler_with_warmup, ProgressBar

In [18]:
from collections import namedtuple


Config = namedtuple('Config',
  field_names="batch_size, n_epochs, lr, gradient_accumulation_steps, n_warmup, max_norm, dropout, log_dir, device")

args = Config(32, 50, 5e-4, 1, 1, 10.0, 0.1, "./poetry_models", "cuda" if torch.cuda.is_available() else "cpu")

In [24]:
# Define training function


def update(engine, batch):
    #transpoemer.train()
    
    src_ids, src_lengths, trg_ids, trg_length = batch
      
    src_ids = src_ids.T.contiguous().to(args.device)
    src_lengths = src_lengths.to(args.device)
    
    trg_ids = trg_ids.T.contiguous().to(args.device)
    trg_length = trg_length.to(args.device)
        
    logits, loss = transpoemer(x=src_ids, x_lengths=src_lengths,y=trg_ids)
    loss.backward()
    
    optimizer.step()
    optimizer.zero_grad()
    
    if engine.state.iteration % 1000 == 0:
        print(generate_next_verse('عيناك غابتا نخيل ساعة السحر'))
    
    return loss.item()

trainer = Engine(update)

In [29]:
# dataset loader
train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)
# optimizer
# model

params['decoder_hidden']= 1024
params['decoder_n_layers'] = 1

transpoemer = Transpoemer(model, params).to("cuda")
optimizer = torch.optim.Adam(transpoemer.parameters(), lr=args.lr)


  "num_layers={}".format(dropout, num_layers))


In [30]:
# freeze bert weights
transpoemer.bert.eval()
for name, p in transpoemer.named_parameters():
    if 'bert.' in name and not p.requires_grad:
        print(name)
        p.requires_grad=True
        pass

In [31]:
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])


# Learning rate schedule: linearly warm-up to lr and then decrease the learning rate to zero with cosine
#cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, len(train_dataloader) * args.n_epochs)
#scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr, args.n_warmup)
#trainer.add_event_handler(Events.ITERATION_STARTED, cos_scheduler)

# Save checkpoints and training config
checkpoint_handler = ModelCheckpoint(args.log_dir, 'checkpoint', save_interval=1, n_saved=1, require_empty=False)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'attn-1x1024': model})
torch.save(args, os.path.join(args.log_dir, 'training_args.bin'))

In [32]:
trainer.run(train_dataloader, max_epochs=args.n_epochs)

HBox(children=(IntProgress(value=0, max=18438), HTML(value='')))

torch.Size([20, 1])
ن رسم ن رسم ن رسم ن رسم
torch.Size([20, 1])
ن ن ن ن غ@@ لقت ن غ@@ لقت ن غ@@
torch.Size([20, 1])
و ن ن ن ن ن ن ن ن ن
torch.Size([20, 1])
ن ن ن ن ن ن ن ن ن ن
torch.Size([20, 1])
و ن ن ن ن ن ن ن ن ن
torch.Size([20, 1])
ن ن ن ن ن ن ن ن ن ن
torch.Size([20, 1])
ف@@ نا نت نت نت نت نت ي ي ي
torch.Size([20, 1])
نا ضد العلاقة نا ضد العلاقة و مش@@ ي خ@@ رى
torch.Size([20, 1])

torch.Size([20, 1])
و@@ نت يا نت يا نت يا نت يا نت يا نت يا
torch.Size([20, 1])

torch.Size([20, 1])
و@@ ن ن رى ن رى ن رى ن رى
torch.Size([20, 1])
نا لا رى سببا س@@ لتي خ@@ رى ذا ما رى ال@@
torch.Size([20, 1])
و@@ رى ن غ@@ دو على ص@@ بوة من جل ن كون غ@@
torch.Size([20, 1])
م ن ح@@ لـ@@ ت ن ر@@ يت ح@@ لام من ي
torch.Size([20, 1])
ن حرر في ي شيء ي@@ اسي@@ دة نا من جل ن قول
torch.Size([20, 1])
و@@ نا رى ن ح@@ بك ن ح@@ بك ن ح@@ بك
torch.Size([20, 1])
و ح@@ بك و ح@@ بك و ح@@ بك و ح@@ لامي و


HBox(children=(IntProgress(value=0, max=18438), HTML(value='')))

HBox(children=(IntProgress(value=0, max=18438), HTML(value='')))

KeyboardInterrupt: 