# Only Parallel Data Model

This is the Python Notebook concerning the final version containing the coding for the Only-Parallel (Supervised) version of the program. It is to be completed with the correct format of the Contextual and Style Classifiers in their corresponding parts.

Date of upload: Friday 31th January

Actual Version: 3.0, coherence implemented (Monday 10th February, night)

Precedent Versions : 2.1, 2.0, 1.1, 1.0



## import(s)

In [1]:
from scripts.data_builders.prepare_dataset import prepare_dataset_parallel,string2code,code2string

import math
import torch
import torchvision.datasets as datasets
import torch.nn.functional as F
from torch import nn
from torch import optim
from torch.utils.tensorboard import SummaryWriter
from torchnlp.metrics import get_moses_multi_bleu,get_token_accuracy
from torch.optim import Adam

from pathlib import Path

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device = ",device)
print(torch.__version__)

device =  cuda
1.4.0


## Data Pre-processing

In [2]:
train_data, dict_words = prepare_dataset_parallel("data/shakespeare.csv",device,ratio=0.5) #check with shift+tab to look at the data structure
dict_token = {b:a for a,b in dict_words.items()}

Loading ...
- Shakespeare dataset length :  21079
- Corrupted samples (ignored) :  0


## Parameters and Embedding

In [3]:
savepath = Path("data/models/embedding_v1")
embedding = torch.load(savepath,map_location = device)
embedding.weight.requires_grad = False

In [11]:
nb_heads = 4
d_feedforward = 1024
batch_size = 32
dict_size = len(dict_token)

d_embedding = embedding.embedding_dim

## Positional Encoding

In [12]:
class PositionalEncoding(torch.nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

## Context Classifier (To be filled)

In [13]:
savepath_style = Path("data/models/coherence_classifier_v1_epoch_100")
coherence_classifier = torch.load(savepath_style,map_location = device)
coherence_classifier.requires_grad = False

## Style classifier (To be filled)

In [14]:
savepath_style = Path("data/models/style_classifier_v1")
style_classifier = torch.load(savepath_style,map_location = device)
style_classifier.requires_grad = False

## Model (To be adapted once we have the Context and Style Class For the Embedding)

In [15]:
class ParallelModel(torch.nn.Module):
    def __init__(self,dict_size, d_embedding, nb_heads, d_feedforward):
        super().__init__()
       
        self.embed_layer = torch.nn.Embedding(dict_size+1, d_embedding, padding_idx=dict_size)
        self.positional_layer = PositionalEncoding(d_embedding)
        self.sentence_encoder = torch.nn.TransformerEncoderLayer(d_model = d_embedding, nhead = nb_heads,
                                                    dim_feedforward = d_feedforward)
        self.context_encoder = torch.nn.TransformerEncoderLayer(d_model = d_embedding, nhead = nb_heads,
                                                    dim_feedforward = d_feedforward)
        self.sentence_decoder = torch.nn.TransformerDecoderLayer(d_model = d_embedding, nhead = nb_heads,
                                                    dim_feedforward = d_feedforward)
        self.label_embedding = torch.nn.Embedding(2,768)
        self.padd = dict_size
   
    def _generate_padding_mask(self,x):
        mask = (x == dict_size)
        return mask
   
    def _generate_square_subsequent_mask(self,sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
   
    def forward(self,x,ctx_x,y,label_x):
        device = x.device
        padd_x_mask = self._generate_padding_mask(x).to(device)
        padd_ctx_mask = self._generate_padding_mask(ctx_x).to(device)
        padd_y_mask = self._generate_padding_mask(y).to(device)
       
        mask_x = self._generate_square_subsequent_mask(x.shape[1]).to(device)
        mask_ctx = self._generate_square_subsequent_mask(ctx_x.shape[1]).to(device)
        mask_y = self._generate_square_subsequent_mask(y.shape[1]).to(device)
       
        # Embedding
        x = self.embed_layer(x).transpose(0,1) # token x batch x embedding
        ctx_x = self.embed_layer(ctx_x).transpose(0,1)
       
        # Positional Encoding
        x = self.positional_layer(x)
        ctx_x = self.positional_layer(ctx_x)        
       
        # Encoders
        x_enc = self.sentence_encoder(x,src_mask=mask_x,src_key_padding_mask=padd_x_mask)
        ctx_enc = self.context_encoder(ctx_x,src_mask=mask_ctx,src_key_padding_mask=padd_ctx_mask)
 
        # Linear and Style Mixing
        x_and_ctx = torch.cat((x_enc,ctx_enc),dim = 0)
        label = (1-label_x).reshape((1,x_and_ctx.shape[1])).expand((x_and_ctx.shape[0],x_and_ctx.shape[1])).to(device)
        x_lab = x_and_ctx + self.label_embedding(label)
       
        # Decoder
        padd_mem_mask = torch.cat((padd_x_mask,padd_ctx_mask),1)
        y = self.embed_layer(y)
        y_pos = self.positional_layer(y.transpose(0,1))
        y_pred = self.sentence_decoder(y_pos,x_lab,tgt_mask=mask_y,tgt_key_padding_mask=padd_y_mask,
                                       memory_key_padding_mask=padd_mem_mask)
       
        return(y_pred.transpose(0,1),y)
   
    def translator(self,x,ctx_x,y,label_x):
        device = x.device
        mask_x = self._generate_square_subsequent_mask(x.shape[1]).to(device)
        mask_ctx = self._generate_square_subsequent_mask(ctx_x.shape[1]).to(device)
       
        # Embedding
        x = self.embed_layer(x).transpose(0,1) # token x batch x embedding
        ctx_x = self.embed_layer(ctx_x).transpose(0,1)
       
        # Positional Encoding
        x = self.positional_layer(x)
        ctx_x = self.positional_layer(ctx_x)        
       
        # Encoders
        x_enc = self.sentence_encoder(x,mask_x)
        ctx_enc = self.context_encoder(ctx_x,mask_ctx)
       
        # Linear and Style Mixing
        x_and_ctx = torch.cat((x_enc,ctx_enc),dim = 0)
        label = (1-label_x).reshape((1,x_and_ctx.shape[1])).expand((x_and_ctx.shape[0],x_and_ctx.shape[1])).to(device)
        x_lab = x_and_ctx + self.label_embedding(label)
        return(x_lab)

## Training  (To be adapted once we have the Context and Style Class For the Embedding)

In [16]:
# Definition of the model(s)

#model = ParallelModel(dict_size, d_embedding, nb_heads, d_feedforward).to(device)
model = torch.load("data/models/parallel_model_v0_epoch_200",map_location = device)
model.embed_layer = embedding
model.embed_layer.weight.requires_grad = False

In [17]:
# Information concerning the Training optimizer

decoder_linear = torch.load("data/models/parallel_decoder_linear_v0_epoch_200",map_location = device)
softmax_layer = torch.nn.LogSoftmax(dim = 2).to(device)

params = list(model.parameters()) + list(decoder_linear.parameters())#+ list(context_encoder.parameters()) + 
                                  #list(linear_context.parameters) + list(sentence_decoder.parameters())

l_r = 1e-5
optimizer=Adam(params,lr=l_r)

#Weights of the losses
l1=1 #
l2=1
l3=1

In [18]:
# Losses 
#loss_seq2seq = torch.nn.SmoothL1Loss(reduction='mean') #Contextual Seq2Seq Loss
#loss_seq2seq = torch.nn.KLDivLoss(reduction = 'mean')
loss_seq2seq = torch.nn.CrossEntropyLoss(reduction = 'mean',ignore_index=dict_size)

In [19]:
nb_epoch = 300
writer = SummaryWriter("data/runs/parallel_model_v0")
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                            shuffle=True,collate_fn=train_data.collate)
n = len(train_data.x) // batch_size

for epoch in range(201,nb_epoch+1):
    total_loss = total_seq = total_style = total_coh = total_accuracy = total_BLEU = 0
    i = 0
    
    for x,y, ctx_x, ctx_y_1,ctx_y_2, label_x,len_y in train_loader:
        i += 1
        if i==n:
            break
        optimizer.zero_grad()
        
        x,_ = model.forward(x,ctx_x,y,label_x) #Output still embedded
        y_pred_dist = decoder_linear(x)[:,:-1,:]
        
        
        # Seq2Seq Loss with Embedding
        #y_hot = torch.nn.functional.one_hot(y).float()
        #y_hot[y_hot==dict_size] = 0.
        loss_seq = loss_seq2seq(y_pred_dist.reshape(-1,dict_size),y[:,1:].reshape(-1))
        
        # Style Loss
        loss_sty,_ = style_classifier.forward(inputs_embeds=x,attention_mask=(y != dict_size).int(),labels=(1-label_x).to(device))
        
        
        # Coherence Loss
        ctx_y = torch.cat([embedding(ctx_y_1),x,embedding(ctx_y_2)],dim=1)
        mask_coh = torch.cat([(ctx_y_1 != dict_size).int(),
                              ((y != dict_size)*(y != 0)*(y != 1)).int(),
                              (ctx_y_2 != dict_size).int()],dim=1).to(device)
        loss_coh,_ = coherence_classifier(inputs_embeds=ctx_y,attention_mask=mask_coh,labels=torch.LongTensor([1]*batch_size).to(device))
        
        
        # Total Loss
        loss = l1 * loss_seq + l2 * loss_sty + l3 *  loss_coh

        # Step
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_seq += loss_seq.item()
        total_style += loss_sty.item()
        total_coh += loss_coh.item()
        
        pred = torch.argmax(softmax_layer(y_pred_dist),dim=2)
        hypotheses = [code2string(item[:len_],dict_token,sos=True) for item,len_ in zip(pred,len_y)]
        references = [code2string(item[:len_],dict_token) for item,len_ in zip(y,len_y)]
        accuracy,_,_ = get_token_accuracy(y[:,1:],pred)
        total_accuracy += accuracy
        BLEU = get_moses_multi_bleu(hypotheses,references,lowercase=True)
        total_BLEU += BLEU if BLEU else 0

    #Vizualization and saving model
        if ((i-1) % 94 == 0):
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'total loss {:5.2f} | seq2seq loss {:5.2f} | '
                  'style loss {:5.2f} | coherence loss {:5.2f}'
                  '| total accuracy {:5.2f} | total BLEU {:5.2f}'.format(
                    epoch, i, n,total_loss/i,total_seq/i,total_style/i,total_coh/i,total_accuracy/i,total_BLEU/i))
    print('-' * 110)
    writer.add_scalar('train_loss',total_loss/(n-1),epoch)
    writer.add_scalar('train_loss_seq',total_seq/(n-1),epoch)
    writer.add_scalar('train_loss_style',total_style/(n-1),epoch)
    writer.add_scalar('train_loss_coh',total_coh/(n-1),epoch)
    writer.add_scalar('train_word_accuracy',total_accuracy/(n-1),epoch)
    writer.add_scalar('train_BLEU',total_BLEU/(n-1),epoch)
    if (epoch%5==0):
        torch.save(model,"data/models/parallel_model_v0_epoch_"+str(epoch))
        torch.save(decoder_linear,"data/models/parallel_decoder_linear_v0_epoch_"+str(epoch))

| epoch 201 |     1/  658 batches | total loss  1.16 | seq2seq loss  1.15 | style loss  0.00 | coherence loss  0.01| total accuracy  0.29 | total BLEU 42.46
| epoch 201 |    95/  658 batches | total loss  1.13 | seq2seq loss  1.12 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 37.94
| epoch 201 |   189/  658 batches | total loss  1.13 | seq2seq loss  1.12 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 38.07
| epoch 201 |   283/  658 batches | total loss  1.13 | seq2seq loss  1.13 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 38.03
| epoch 201 |   377/  658 batches | total loss  1.13 | seq2seq loss  1.13 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 38.15
| epoch 201 |   471/  658 batches | total loss  1.13 | seq2seq loss  1.13 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 38.14
| epoch 201 |   565/  658 batches | total loss  1.13 | seq

multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmp_0u8k8gz: not in gzip format\nERROR: could not find reference file /tmp/tmp_0u8k8gz at /tmp/tmpzoyxz7rq line 32.\n'


--------------------------------------------------------------------------------------------------------------
| epoch 204 |     1/  658 batches | total loss  1.24 | seq2seq loss  1.23 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 35.92
| epoch 204 |    95/  658 batches | total loss  1.12 | seq2seq loss  1.11 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 38.55
| epoch 204 |   189/  658 batches | total loss  1.12 | seq2seq loss  1.11 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 38.94
| epoch 204 |   283/  658 batches | total loss  1.12 | seq2seq loss  1.11 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 38.78
| epoch 204 |   377/  658 batches | total loss  1.11 | seq2seq loss  1.11 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 38.86
| epoch 204 |   471/  658 batches | total loss  1.11 | seq2seq loss  1.10 | style loss  0.00 | coherence

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


| epoch 206 |     1/  658 batches | total loss  0.90 | seq2seq loss  0.90 | style loss  0.00 | coherence loss  0.01| total accuracy  0.32 | total BLEU 43.43
| epoch 206 |    95/  658 batches | total loss  1.08 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.77
| epoch 206 |   189/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.33
| epoch 206 |   283/  658 batches | total loss  1.09 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.47
| epoch 206 |   377/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.35
| epoch 206 |   471/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.23
| epoch 206 |   565/  658 batches | total loss  1.10 | seq

multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmpgnr6e5gz: not in gzip format\nERROR: could not find reference file /tmp/tmpgnr6e5gz at /tmp/tmpl3k93bpn line 32.\n'


| epoch 208 |   189/  658 batches | total loss  1.10 | seq2seq loss  1.10 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.09
| epoch 208 |   283/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.14
| epoch 208 |   377/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.23
| epoch 208 |   471/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.24
| epoch 208 |   565/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.27
--------------------------------------------------------------------------------------------------------------
| epoch 209 |     1/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence

multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmppiij4ygz: not in gzip format\nERROR: could not find reference file /tmp/tmppiij4ygz at /tmp/tmpxj6t_1pl line 32.\n'


| epoch 210 |   283/  658 batches | total loss  1.09 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.31
| epoch 210 |   377/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.22
| epoch 210 |   471/  658 batches | total loss  1.10 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.26
| epoch 210 |   565/  658 batches | total loss  1.09 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.32
--------------------------------------------------------------------------------------------------------------
| epoch 211 |     1/  658 batches | total loss  1.29 | seq2seq loss  1.29 | style loss  0.00 | coherence loss  0.01| total accuracy  0.27 | total BLEU 38.12
| epoch 211 |    95/  658 batches | total loss  1.10 | seq2seq loss  1.10 | style loss  0.00 | coherence

multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmp6ztcf1gz: not in gzip format\nERROR: could not find reference file /tmp/tmp6ztcf1gz at /tmp/tmphyd28t3r line 32.\n'


| epoch 216 |   189/  658 batches | total loss  1.09 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.82
| epoch 216 |   283/  658 batches | total loss  1.09 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.63
| epoch 216 |   377/  658 batches | total loss  1.09 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.65
| epoch 216 |   471/  658 batches | total loss  1.09 | seq2seq loss  1.09 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.67


multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmpszpr8ygz: not in gzip format\nERROR: could not find reference file /tmp/tmpszpr8ygz at /tmp/tmpa5zxehjd line 32.\n'


| epoch 216 |   565/  658 batches | total loss  1.09 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.72
--------------------------------------------------------------------------------------------------------------
| epoch 217 |     1/  658 batches | total loss  1.07 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.19 | total BLEU 39.27


multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmp38h8dmgz: not in gzip format\nERROR: could not find reference file /tmp/tmp38h8dmgz at /tmp/tmpjtset_hm line 32.\n'
multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmpjym0f3gz: not in gzip format\nERROR: could not find reference file /tmp/tmpjym0f3gz at /tmp/tmpw13g3xhp line 32.\n'


| epoch 217 |    95/  658 batches | total loss  1.09 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 38.72
| epoch 217 |   189/  658 batches | total loss  1.09 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.14
| epoch 217 |   283/  658 batches | total loss  1.08 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.31
| epoch 217 |   377/  658 batches | total loss  1.08 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.33
| epoch 217 |   471/  658 batches | total loss  1.09 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.29
| epoch 217 |   565/  658 batches | total loss  1.09 | seq2seq loss  1.08 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.30


multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmprgi29rgz: not in gzip format\nERROR: could not find reference file /tmp/tmprgi29rgz at /tmp/tmpa09ajtw_ line 32.\n'


--------------------------------------------------------------------------------------------------------------
| epoch 218 |     1/  658 batches | total loss  0.97 | seq2seq loss  0.97 | style loss  0.00 | coherence loss  0.01| total accuracy  0.25 | total BLEU 45.31
| epoch 218 |    95/  658 batches | total loss  1.07 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.13
| epoch 218 |   189/  658 batches | total loss  1.07 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.06
| epoch 218 |   283/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.85
| epoch 218 |   377/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.87
| epoch 218 |   471/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence

multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmpu3romugz: not in gzip format\nERROR: could not find reference file /tmp/tmpu3romugz at /tmp/tmp053n5mfj line 32.\n'


--------------------------------------------------------------------------------------------------------------
| epoch 225 |     1/  658 batches | total loss  1.15 | seq2seq loss  1.15 | style loss  0.00 | coherence loss  0.01| total accuracy  0.19 | total BLEU 39.15
| epoch 225 |    95/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 40.14
| epoch 225 |   189/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.94
| epoch 225 |   283/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.87
| epoch 225 |   377/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.84
| epoch 225 |   471/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence

multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmpmz298zgz: not in gzip format\nERROR: could not find reference file /tmp/tmpmz298zgz at /tmp/tmpyl5_l0yy line 32.\n'


| epoch 229 |   565/  658 batches | total loss  1.07 | seq2seq loss  1.06 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.08
--------------------------------------------------------------------------------------------------------------
| epoch 230 |     1/  658 batches | total loss  1.08 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.23 | total BLEU 36.81
| epoch 230 |    95/  658 batches | total loss  1.07 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.25
| epoch 230 |   189/  658 batches | total loss  1.06 | seq2seq loss  1.05 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.55
| epoch 230 |   283/  658 batches | total loss  1.06 | seq2seq loss  1.06 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.46
| epoch 230 |   377/  658 batches | total loss  1.07 | seq2seq loss  1.07 | style loss  0.00 | coherence

multi-bleu.perl script returned non-zero exit code
b'\ngzip: /tmp/tmpw806xugz: not in gzip format\nERROR: could not find reference file /tmp/tmpw806xugz at /tmp/tmp8ctmijiv line 32.\n'


--------------------------------------------------------------------------------------------------------------
| epoch 235 |     1/  658 batches | total loss  1.01 | seq2seq loss  1.00 | style loss  0.00 | coherence loss  0.01| total accuracy  0.32 | total BLEU 43.96
| epoch 235 |    95/  658 batches | total loss  1.07 | seq2seq loss  1.06 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 39.91
| epoch 235 |   189/  658 batches | total loss  1.07 | seq2seq loss  1.06 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.19
| epoch 235 |   283/  658 batches | total loss  1.07 | seq2seq loss  1.07 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.20
| epoch 235 |   377/  658 batches | total loss  1.07 | seq2seq loss  1.06 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.10
| epoch 235 |   471/  658 batches | total loss  1.07 | seq2seq loss  1.06 | style loss  0.00 | coherence

Unable to fetch multi-bleu.perl script


| epoch 242 |   283/  658 batches | total loss  1.06 | seq2seq loss  1.05 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.49
| epoch 242 |   377/  658 batches | total loss  1.06 | seq2seq loss  1.05 | style loss  0.00 | coherence loss  0.01| total accuracy  0.24 | total BLEU 40.56


Unable to fetch multi-bleu.perl script


KeyboardInterrupt: 

In [None]:
# reduction = 'sum' l_r = 5e-4
from matplotlib import pyplot as plt
import numpy as np
rango = np.arange(len(loss_graph))
plt.figure(1,figsize=(16,16))
plt.subplot(231)
plt.plot(rango,loss_graph)
plt.title("Total Loss")
plt.grid()
plt.subplot(234)
plt.plot(rango,loss_graph)
plt.xscale('log')
plt.yscale('log')
plt.title("Total Loss Log")
plt.grid()
plt.subplot(232)
plt.plot(rango,loss_seq_list)
plt.title("Seq Loss")
plt.grid()
plt.subplot(235)
plt.plot(rango,loss_seq_list)
plt.xscale('log')
plt.yscale('log')
plt.title("Seq Loss Log")
plt.grid()
plt.subplot(233)
plt.plot(rango,loss_style_list)
plt.title("Style Loss")
plt.grid()
plt.subplot(236)
plt.plot(rango,loss_style_list)
plt.title("Style Loss Log")
plt.xscale('log')
plt.yscale('log')
plt.grid()
plt.show()

In [207]:
x,y,ctx_x,_,_,label_x,_ = train_data[0]

In [204]:
x = string2code("my dear juliet . do you want to hang out with me ?".split(" "),dict_words).to(device)
ctx = string2code("hello .".split(" "),dict_words).to(device)
label_x = 0

In [208]:
if(label_x == 1):
    print("Original Style : Shakespearian")
else:
    print("Original Style : Modern")
print("Original Phrase :",code2string(x,dict_token))
print("Target phrase :",code2string(y,dict_token))
print("Context_x :",code2string(ctx_x,dict_token))

Original Style : Modern
Original Phrase : <SOS> i’ve lurked in the shadows here to watch the downfall of my enemies . <EOS>
Target phrase : <SOS> here in these confines slyly have i lurked to watch the waning of mine enemies . <EOS>
Context_x : <SOS> let him thank me , who helped him get there . i’ll head to france soon . <EOS>


In [209]:
    #Generation de phrase 
with torch.no_grad():
        h_t = model.translator(x.unsqueeze(0),ctx_x.unsqueeze(0),y.unsqueeze(0),torch.tensor([label_x]).unsqueeze(0)).to(device)
        phrase = torch.tensor([[0]]).to(device)
        print("Starting with: ",end='')
        for p in phrase:
            print(dict_token[p.item()],end=' ')
        print("")
        i = 0
        limit = 20
        flag = False
        while(not(flag) and i != limit):
            #mask = model._generate_square_subsequent_mask(phrase.shape[1])
            y_aux = model.embed_layer(phrase)
            y_pos = model.positional_layer(y_aux)
            y_pred = model.sentence_decoder(y_pos,h_t).transpose(0,1)
            y_pred = decoder_linear(y_pred)
            y_pred = torch.argmax(softmax_layer(y_pred),dim = 2)
            phrase = torch.cat((phrase,y_pred[:,-1].reshape((1,1))),0)
            i += 1
            flag = (y_pred[:,-1].item() == 1)
        print("Produced phrase: ",end='')
        for p in phrase:
            print(dict_token[p.item()],end=' ')
        print("")
            

Starting with: <SOS> 
Produced phrase: <SOS> i have heard any unseasonable the best to the rest . <EOS> 


In [38]:
x = string2code("i love killing people .".split(" "),dict_words).to(device)
ctx = string2code("i have to tell you something . this is my passion .".split(" "),dict_words).to(device)
label_x = 0

In [20]:
hypotheses

['<SOS> this fake astonishment of yours is just like your other pranks . <EOS> ,',
 '<SOS> go to the church . <EOS> !',
 '<SOS> and yet i may yet converted but see you , but methinks i see to your eyes . other men’s . i <EOS> to',
 '<SOS> how , voltemand , what the the quick norway ? <EOS> .',
 '<SOS> i want to uncles here to welcome to . <EOS> .',
 '<SOS> macduff is missing , and your noble son . <EOS> .',
 '<SOS> let you favor to nope . <EOS> !',
 '<SOS> how can i i return from to earth , unless my husband send from from from heaven and leaving earth ? <EOS> ,',
 '<SOS> are you call me me ? too worship ? <EOS> ?',
 '<SOS> i gives , pray madam . <EOS> .',
 '<SOS> o , ominous ! <EOS> are',
 '<SOS> less in the knowledge and the grace is are scarce the the earth’s , thus and than to divine . <EOS> ,',
 '<SOS> no no else . . <EOS> lord',
 '<SOS> i highness , would to force me to tell you <EOS> .',
 '<SOS> this would breed from me occasions , and i can try that i know work more <EOS> my',


In [21]:
references

['<SOS> this fake astonishment of yours is just like your other pranks . <EOS>',
 '<SOS> go to the church . <EOS>',
 '<SOS> and how you may be converted i know not , but methinks you look with your eyes as other women do . <EOS>',
 '<SOS> say , voltemand , what from our brother norway ? <EOS>',
 '<SOS> i want more uncles here to welcome me . <EOS>',
 '<SOS> macduff is missing , and your noble son . <EOS>',
 '<SOS> as a favor to nope . <EOS>',
 '<SOS> how shall that faith return again to earth , unless that husband send it me from heaven by leaving earth ? <EOS>',
 '<SOS> did you call for me , your worship ? <EOS>',
 '<SOS> that will i , pompey . <EOS>',
 '<SOS> oh , ominous ! <EOS>',
 '<SOS> less in your knowledge and your grace you show not than our earth’s wonder , more than earth divine . <EOS>',
 '<SOS> there’s none else by . <EOS>',
 '<SOS> your highness will have to force me to tell . <EOS>',
 '<SOS> i would breed from hence occasions , and i shall , that i may speak . <EOS>',
 '