<h2> Embedding

In [792]:
from torch import nn

In [878]:
class Embedder(nn.Module): #임베딩 레이어
    def __init__(self,vocab_size,d_model):
        super().__init__()
        self.embed = nn.Embedding(vocab_size,d_model)
    
    def forward(self,x):
        #x= x.to("cpu")
        return self.embed(x)

In [794]:
e = Embedder(32000,768)

In [795]:
idx = torch.tensor([1,2,3])
e(idx)

tensor([[ 1.9352,  0.3704,  1.2609,  ...,  1.0664, -1.5787, -0.4266],
        [-1.1399,  1.3575,  0.0628,  ...,  1.1455, -1.6383, -0.4114],
        [-1.1844, -0.0755, -0.8986,  ..., -0.3078, -1.6215,  0.0780]],
       grad_fn=<EmbeddingBackward0>)

<h2> Positional Encoding

In [796]:
import math
import torch

In [797]:
class PositionalEncoder(nn.Module): #위치 인코딩 레이어
    def __init__(self, d_model, max_seq_len = 128):
        super().__init__()
        self.d_model = d_model
        
        #위치 인코딩 초기화
        positional_encoding = torch.zeros(max_seq_len, d_model)
        
        for pos in range(max_seq_len):
            for i in range(0, d_model, 2):
                positional_encoding[pos, i] = math.sin(pos / (10000 ** ((2 * i)/d_model)))
                positional_encoding[pos, i + 1] = math.cos(pos / (10000 ** ((2 * (i + 1))/d_model)))
                
        positional_encoding = positional_encoding.unsqueeze(0)
        #잔차 연결
        self.register_buffer('positional_encoding', positional_encoding)
 
    def forward(self, x):
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        x = x + self.positional_encoding[:,:seq_len]
        return x

In [798]:
p = PositionalEncoder(768)

In [799]:
#y_pred값 대충 생성
input_seq = torch.randint(0,40,(128,))
emb = e(input_seq)
emb.shape

torch.Size([128, 768])

In [800]:
y_pred = p(emb.to("cpu"))

In [801]:
#y값 대충 생성
input_seq2 = torch.randint(0,40,(128,))
emb2 = e(input_seq2)

In [802]:
y = p(emb2)

In [803]:
loss = torch.mean((y_pred - y)**2)
loss.backward()

<h2> Masking

In [804]:
def get_attn_mask(input_seq,input_pad):
    return (input_seq != input_pad).unsqueeze(1).to("cuda")

In [805]:
def get_target_attn_mask(target_seq,target_pad):
    target_mask = (target_seq != target_pad).unsqueeze(1)
    size = target_seq.size(1)
    nopeak_mask = np.triu(np.ones((1,size,size)),k=1).astype('uint8')
    nopeak_mask = torch.from_numpy(nopeak_mask) == 0
    nopeak_mask = nopeak_mask.to("cuda")
    target_mask = target_mask & nopeak_mask
    return target_mask

<h2> Self-Attention

In [806]:
from torch.nn import functional as F

In [807]:
def attention(query, key, value, d_k, attention_mask=False, dropout=None):
    
    scores = torch.matmul(query,key.transpose(-2,-1)) /  math.sqrt(d_k)
    
    if attention_mask is True:
        attention_mask = get_attn_mask(scores,0)
        attention_mask = attention_mask.unsqueeze(1)
        scores = scores.masked_fill(attention_mask == 0, -1e9)
    
    scores = F.softmax(scores, dim=-1)
    
    if dropout is not None:
        scores = dropout(scores)
        
    output = torch.matmul(scores, value)
    return output

<h2> Multi-Headed Attention

In [808]:
from scipy.special import softmax

In [809]:
class MultiHeadAttention(nn.Module):
    def __init__(self,heads,d_model,dropout_rate = 0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        
        self.query_layer = nn.Linear(d_model,d_model)
        self.key_layer = nn.Linear(d_model,d_model)
        self.value_layer = nn.Linear(d_model,d_model)
        
        self.dropout = nn.Dropout(dropout_rate)
        self.out_layer = nn.Linear(d_model,d_model)
        
    def forward(self,query,key,value,attention_mask = True):
        
        batch_size = query.size(0)
        
        query = self.key_layer(query).view(batch_size,-1,self.h,self.d_k)
        key = self.key_layer(key).view(batch_size,-1,self.h,self.d_k)
        value = self.key_layer(value).view(batch_size,-1,self.h,self.d_k)
        
        query =  query.transpose(1,2)
        key = key.transpose(1,2)
        value = value.transpose(1,2)
        
        scores = attention(query,key,value,self.d_k,attention_mask,self.dropout)
        Z = scores.transpose(1,2).contiguous().view(batch_size,-1,self.d_model)
        output = self.dropout(Z)
        
        return output

In [810]:
mha = MultiHeadAttention(8,512)

In [811]:
query = torch.rand((1,3,512))
key = torch.rand((1,3,512))
value = torch.rand((1,3,512))

<h2> Feed-Forward

In [845]:
class FeedForward(nn.Module):
    def __init__(self,d_model,d_ff = 2048,dropout = 0.1):
        super().__init__()
        
        self.fc1 = nn.Linear(d_model,d_ff)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(d_ff,d_model)
        
    def forward(self,x):
        x = self.dropout(F.relu(self.fc1(x)))
        x = self.fc2(x)
        return x

<h2> Normalisation

<ul>
    <li>prevents the range of value in the layers changing too much
<li>it makes model trains faster and has better ability</ul>

In [813]:
class Norm(nn.Module):
    def __init__(self,d_model,eps = 1e-6):
        super().__init__()
        
        self.size = d_model
        
        self.alpha = nn.Parameter(torch.ones(self.size))
        self.bias = nn.Parameter(torch.zeros(self.size))
        self.eps = eps
    
    def forward(self,x):
        norm = self.alpha *(x - x.mean(dim = -1,keepdim=True)) / (x.std(dim=-1, keepdim=True) + self.eps) + self.bias
        return norm

<h1> ★Encoder&Decoder Layer★

<ui>
<li> build an encoder layer with one multi-head attention layer and one feed-forward layer </li>
</ui>

In [814]:
class EncoderLayer(nn.Module):
    def __init__(self,d_model,heads,dropout = 0.1):
        super().__init__()
        self.norm1 = Norm(d_model)
        self.norm2 = Norm(d_model)
        self.mha = MultiHeadAttention(heads,d_model)
        self.ff = FeedForward(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        
    def forward(self,x,mask):
        x2 = self.norm1(x)
        x = x + self.dropout1(self.mha(x2,x2,x2,mask))
        x2 = self.norm2(x)
        x = x + self.dropout2(self.ff(x2))
        return x

In [840]:
class DecoderLayer(nn.Module):
    def __init__(self,d_model,heads,dropout=0.1):
        super().__init__()
        self.norm1 = Norm(d_model)
        self.norm2 = Norm(d_model)
        self.norm3 = Norm(d_model)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        
        self.mha1 = MultiHeadAttention(heads, d_model)
        self.mha2 = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model)
        
    def forward(self,x,encoder_outputs,src_mask,trg_mask):
        x2 = self.norm1(x)
        x = x + self.dropout1(self.mha1(x2,x2,x2,trg_mask))
        
        x2 = self.norm2(x)
        x = x + self.dropout2(self.mha2(x2,encoder_outputs,encoder_outputs,src_mask))
        
        x2 = self.norm3(x)
        x = x + self.dropout3(self.ff(x2))
        return x
        

In [816]:
# build a conveinient cloning function that can generate multiple layers
import copy
def get_clones(module,N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

<h1> Encoder Decoder

In [817]:
class Encoder(nn.Module):
    def __init__(self,vocab_size,d_model,N,heads):
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size,d_model)
        self.pe = PositionalEncoder(d_model)
        self.layers = get_clones(EncoderLayer(d_model,heads),N)
        self.norm = Norm(d_model)
    
    def forward(self,src,mask):
        x = self.embed(src)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x,mask)
        return self.norm(x)

In [830]:
class Decoder(nn.Module):
    def __init__(self,vocab_size,d_model,N,heads):
        super().__init__()
        self.N = N
        self.embed = Embedder(vocab_size,d_model)
        self.pe = PositionalEncoder(d_model)
        self.layers = get_clones(DecoderLayer(d_model,heads),N)
        self.norm = Norm(d_model)
    def forward(self,target,encoder_outputs,src_mask,target_mask):
        x = self.embed(target)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x,encoder_outputs,src_mask,target_mask)
        return self.norm(x)

In [831]:
class Transformer(nn.Module):
    def __init__(self, src_vocab, trg_vocab, d_model, N, heads):
        super().__init__()
        self.encoder = Encoder(src_vocab, d_model, N, heads)
        self.decoder = Decoder(trg_vocab, d_model, N, heads)
        self.out = nn.Linear(d_model, trg_vocab)
    def forward(self, src, trg, src_mask, trg_mask):
        e_outputs = self.encoder(src, src_mask)
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = self.out(d_output)
        return output

<h2> Dataset

In [820]:
class NMTDataset(Dataset):
    def __init__(self, text_df, vectorizer):
        
        self.text_df = text_df
        self._vectorizer = vectorizer

        self.train_df = self.text_df[self.text_df.split=='train']
        self.train_size = len(self.train_df)

        self.val_df = self.text_df[self.text_df.split=='val']
        self.validation_size = len(self.val_df)

        self.test_df = self.text_df[self.text_df.split=='test']
        self.test_size = len(self.test_df)

        self._lookup_dict = {'train': (self.train_df, self.train_size),
                             'val': (self.val_df, self.validation_size),
                             'test': (self.test_df, self.test_size)}

        self.set_split('train')

    @classmethod
    def load_dataset_and_make_vectorizer(cls, dataset_csv):
        text_df = pd.read_csv(dataset_csv)
        train_subset = text_df[text_df.split=='train']
        return cls(text_df, NMTVectorizer.from_dataframe(train_subset))

    @classmethod
    def load_dataset_and_load_vectorizer(cls, dataset_csv, vectorizer_filepath):
        text_df = pd.read_csv(dataset_csv)
        vectorizer = cls.load_vectorizer_only(vectorizer_filepath)
        return cls(text_df, vectorizer)

    @staticmethod
    def load_vectorizer_only(vectorizer_filepath):
        with open(vectorizer_filepath) as fp:
            return NMTVectorizer.from_serializable(json.load(fp))

    def save_vectorizer(self, vectorizer_filepath):
        with open(vectorizer_filepath, "w") as fp:
            json.dump(self._vectorizer.to_serializable(), fp)

    def get_vectorizer(self):
        return self._vectorizer

    def set_split(self, split="train"):
        self._target_split = split
        self._target_df, self._target_size = self._lookup_dict[split]

    def __len__(self):
        return self._target_size

    def __getitem__(self, index):

        row = self._target_df.iloc[index]

        vector_dict = self._vectorizer.vectorize(row.source_language, row.target_language)

        return {"source": vector_dict["source_vector"], 
                "target": vector_dict["target_vector"],
                "source_length": vector_dict["source_length"]}
        
    def get_num_batches(self, batch_size):

        return len(self) // batch_size

<h2> DataLoader

In [821]:
from torch.utils.data import DataLoader
def generate_nmt_batches(dataset, batch_size, shuffle=True, 
                            drop_last=True, device="cuda"):

    dataloader = DataLoader(dataset=dataset, batch_size=batch_size,
                            shuffle=shuffle, drop_last=drop_last)

    for data_dict in dataloader:
        lengths = data_dict['source_length'].numpy()
        sorted_length_indices = lengths.argsort().tolist()
        
        out_data_dict = {}
        for name, tensor in data_dict.items():
            out_data_dict[name] = data_dict[name][sorted_length_indices].to(device)
        yield out_data_dict

<h2> Vocabulary
   

In [822]:
class Vocabulary(object):
    def __init__(self, token_to_idx=None):

        if token_to_idx is None:
            token_to_idx = {}
            
        self._token_to_idx = token_to_idx

        self._idx_to_token = {idx: token 
                              for token, idx in self._token_to_idx.items()}
        
    def to_serializable(self):
        return {'token_to_idx': self._token_to_idx}

    @classmethod
    def from_serializable(cls, contents):
        return cls(**contents)

    def add_token(self, token):
        if token in self._token_to_idx:
            index = self._token_to_idx[token]
        else:
            index = len(self._token_to_idx)
            self._token_to_idx[token] = index
            self._idx_to_token[index] = token
        return index
            
    def add_many(self, tokens):
        return [self.add_token(token) for token in tokens]

    def lookup_token(self, token):
        return self._token_to_idx[token]

    def lookup_index(self, index):
        if index not in self._idx_to_token:
            raise KeyError("the index (%d) is not in the Vocabulary" % index)
        return self._idx_to_token[index]

    def __str__(self):
        return "<Vocabulary(size=%d)>" % len(self)

    def __len__(self):
        return len(self._token_to_idx)

In [823]:
class SequenceVocabulary(Vocabulary):
    def __init__(self, token_to_idx=None, unk_token="<UNK>",
                 mask_token="<MASK>", begin_seq_token="<BEGIN>",
                 end_seq_token="<END>"):

        super(SequenceVocabulary, self).__init__(token_to_idx)

        self._mask_token = mask_token
        self._unk_token = unk_token
        self._begin_seq_token = begin_seq_token
        self._end_seq_token = end_seq_token

        self.mask_index = self.add_token(self._mask_token)
        self.unk_index = self.add_token(self._unk_token)
        self.begin_seq_index = self.add_token(self._begin_seq_token)
        self.end_seq_index = self.add_token(self._end_seq_token)

    def to_serializable(self):
        contents = super(SequenceVocabulary, self).to_serializable()
        contents.update({'unk_token': self._unk_token,
                         'mask_token': self._mask_token,
                         'begin_seq_token': self._begin_seq_token,
                         'end_seq_token': self._end_seq_token})
        return contents

    def lookup_token(self, token):
        if self.unk_index >= 0:
            return self._token_to_idx.get(token, self.unk_index)
        else:
            return self._token_to_idx[token]

<h2> Vectorizer

In [863]:
class NMTVectorizer(object):
    def __init__(self, source_vocab, target_vocab, max_source_length, max_target_length):
        self.source_vocab = source_vocab
        self.target_vocab = target_vocab
        
        self.max_source_length = max_source_length
        self.max_target_length = max_target_length
        

    def _vectorize(self, indices, vector_length=-1, mask_index=0):

        if vector_length < 0:
            vector_length = len(indices)
        
        vector = np.zeros(vector_length, dtype=np.int64)
        vector[:len(indices)] = indices
        vector[len(indices):] = mask_index

        return vector
    
    def _get_source_indices(self, text):
        indices = [self.source_vocab.begin_seq_index]
        indices.extend(self.source_vocab.lookup_token(token) for token in text.split(" "))
        indices.append(self.source_vocab.end_seq_index)
        return indices
    
    def _get_target_indices(self, text):
        indices = [self.target_vocab.lookup_token(token) for token in text.split(" ")]
        indices = [self.target_vocab.begin_seq_index] + indices + [self.target_vocab.end_seq_index]
        return indices
        
    def vectorize(self, source_text, target_text, use_dataset_max_lengths=True):
        source_vector_length = -1
        target_vector_length = -1
        
        if use_dataset_max_lengths:
            source_vector_length = self.max_source_length + 2
            target_vector_length = self.max_target_length + 2
            
        source_indices = self._get_source_indices(source_text)
        source_vector = self._vectorize(source_indices, 
                                        vector_length=source_vector_length, 
                                        mask_index=self.source_vocab.mask_index)
        
        target_indices = self._get_target_indices(target_text)
        target_vector = self._vectorize(target_indices,
                                        vector_length=target_vector_length,
                                        mask_index=self.target_vocab.mask_index)

        return {"source_vector": source_vector, 
                "target_vector": target_vector, 
                "source_length": len(source_indices)}
        
    @classmethod
    def from_dataframe(cls, bitext_df):

        source_vocab = SequenceVocabulary()
        target_vocab = SequenceVocabulary()
        
        max_source_length = 0
        max_target_length = 0

        for _, row in bitext_df.iterrows():
            source_tokens = row["source_language"].split(" ")
            if len(source_tokens) > max_source_length:
                max_source_length = len(source_tokens)
            for token in source_tokens:
                source_vocab.add_token(token)
            
            target_tokens = row["target_language"].split(" ")
            if len(target_tokens) > max_target_length:
                max_target_length = len(target_tokens)
            for token in target_tokens:
                target_vocab.add_token(token)
            
        return cls(source_vocab, target_vocab, max_source_length, max_target_length)

    @classmethod
    def from_serializable(cls, contents):
        source_vocab = SequenceVocabulary.from_serializable(contents["source_vocab"])
        target_vocab = SequenceVocabulary.from_serializable(contents["target_vocab"])
        
        return cls(source_vocab=source_vocab, 
                   target_vocab=target_vocab, 
                   max_source_length=contents["max_source_length"], 
                   max_target_length=contents["max_target_length"])

    def to_serializable(self):
        return {"source_vocab": self.source_vocab.to_serializable(), 
                "target_vocab": self.target_vocab.to_serializable(), 
                "max_source_length": self.max_source_length,

                "max_target_length": self.max_target_length}

In [885]:
d_model = 512
heads = 8
N = 6
batch_size = 64
epochs = 100
dataset = NMTDataset.load_dataset_and_make_vectorizer("simplest_eng_fra.csv")
vectorizer = dataset.get_vectorizer()

src_vocab = len(vectorizer.source_vocab)
trg_vocab = len(vectorizer.target_vocab)

model = Transformer(src_vocab, trg_vocab, d_model, N, heads)
model = model.to("cuda")

for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

optim = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [883]:
from tqdm.notebook import tqdm

In [886]:
import time

start = time.time()
temp = start
total_loss=0
print_every=100
for epoch in tqdm(range(epochs),total=epochs):
    model.train()
    dataset.set_split('train')
    
    batch_generator = generate_nmt_batches(dataset, batch_size=64)
    
    for i, batch in tqdm(enumerate(batch_generator),total=len(dataset)/batch_size):
        optim.zero_grad()
        
        src = batch['source']
        trg = batch['target']

        trg_input = trg[:, :-1]
        targets = trg[:, 1:].contiguous().view(-1)

        src_mask = get_attn_mask(src,input_pad = vectorizer.source_vocab.mask_index)
        trg_mask = get_target_attn_mask(trg_input,target_pad = vectorizer.target_vocab.mask_index)
        preds = model(src, trg_input, src_mask, trg_mask)
        preds = preds.view(-1, preds.size(-1)).to("cuda")
        
        result = trg[:, 1:].contiguous().view(-1)
        loss = F.cross_entropy(preds,result, ignore_index=vectorizer.target_vocab.mask_index)
        
        loss.backward()
        optim.step()

        total_loss += loss.item()
        
        if (i + 1) % print_every == 0:
            loss_avg = total_loss / print_every
            print("time = %dm, epoch %d, iter = %d, loss = %.3f,%ds per %d iters" % \
                  ((time.time() - start) // 60,epoch + 1, i + 1, loss_avg, time.time() - temp,print_every))
            total_loss = 0
            temp = time.time()
            
    dataset.set_split('val')
    batch_generator = generate_nmt_batches(dataset, batch_size)
    
    running_loss = 0.
    model.eval()

    for i, batch in tqdm(enumerate(batch_generator),total=len(dataset)/batch_size):
        
        src = batch['source']
        trg = batch['target']
        
        trg_input = trg[:, :-1]
        targets = trg[:, 1:].contiguous().view(-1)
        
        src_mask = get_attn_mask(src,input_pad = vectorizer.source_vocab.mask_index)
        trg_mask = get_target_attn_mask(trg_input,target_pad = vectorizer.target_vocab.mask_index)
        
        preds = model(src, trg_input, src_mask, trg_mask)
        preds = preds.view(-1, preds.size(-1)).to("cuda")
        result = trg[:, 1:].contiguous().view(-1)
        loss = F.cross_entropy(preds,result, ignore_index=vectorizer.target_vocab.mask_index)
        running_loss += (loss.item() - running_loss) / (i + 1)
        
    print("validation: epoch %d, loss = %.3f," % (epoch + 1, running_loss) )

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 0m, epoch 1, iter = 100, loss = 4.916,11s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 1, loss = 3.450,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 0m, epoch 2, iter = 100, loss = 4.807,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 2, loss = 2.942,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 0m, epoch 3, iter = 100, loss = 4.048,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 3, loss = 2.575,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 1m, epoch 4, iter = 100, loss = 3.493,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 4, loss = 2.310,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 1m, epoch 5, iter = 100, loss = 3.054,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 5, loss = 2.112,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 1m, epoch 6, iter = 100, loss = 2.705,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 6, loss = 1.962,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 1m, epoch 7, iter = 100, loss = 2.447,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 7, loss = 1.846,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 2m, epoch 8, iter = 100, loss = 2.191,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 8, loss = 1.730,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 2m, epoch 9, iter = 100, loss = 2.015,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 9, loss = 1.628,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 2m, epoch 10, iter = 100, loss = 1.831,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 10, loss = 1.560,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 3m, epoch 11, iter = 100, loss = 1.660,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 11, loss = 1.512,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 3m, epoch 12, iter = 100, loss = 1.524,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 12, loss = 1.418,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 3m, epoch 13, iter = 100, loss = 1.382,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 13, loss = 1.354,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 3m, epoch 14, iter = 100, loss = 1.257,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 14, loss = 1.297,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 4m, epoch 15, iter = 100, loss = 1.157,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 15, loss = 1.255,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 4m, epoch 16, iter = 100, loss = 1.053,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 16, loss = 1.213,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 4m, epoch 17, iter = 100, loss = 0.950,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 17, loss = 1.185,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 5m, epoch 18, iter = 100, loss = 0.874,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 18, loss = 1.129,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 5m, epoch 19, iter = 100, loss = 0.795,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 19, loss = 1.120,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 5m, epoch 20, iter = 100, loss = 0.713,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 20, loss = 1.104,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 5m, epoch 21, iter = 100, loss = 0.645,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 21, loss = 1.069,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 6m, epoch 22, iter = 100, loss = 0.589,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 22, loss = 1.053,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 6m, epoch 23, iter = 100, loss = 0.534,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 23, loss = 1.042,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 6m, epoch 24, iter = 100, loss = 0.477,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 24, loss = 1.011,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 7m, epoch 25, iter = 100, loss = 0.433,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 25, loss = 0.989,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 7m, epoch 26, iter = 100, loss = 0.389,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 26, loss = 1.004,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 7m, epoch 27, iter = 100, loss = 0.349,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 27, loss = 0.976,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 7m, epoch 28, iter = 100, loss = 0.306,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 28, loss = 0.969,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 8m, epoch 29, iter = 100, loss = 0.275,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 29, loss = 0.965,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 8m, epoch 30, iter = 100, loss = 0.248,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 30, loss = 0.962,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 8m, epoch 31, iter = 100, loss = 0.215,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 31, loss = 0.951,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 9m, epoch 32, iter = 100, loss = 0.192,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 32, loss = 0.930,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 9m, epoch 33, iter = 100, loss = 0.166,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 33, loss = 0.943,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 9m, epoch 34, iter = 100, loss = 0.153,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 34, loss = 0.940,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 10m, epoch 35, iter = 100, loss = 0.130,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 35, loss = 0.944,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 10m, epoch 36, iter = 100, loss = 0.118,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 36, loss = 0.951,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 10m, epoch 37, iter = 100, loss = 0.103,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 37, loss = 0.927,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 10m, epoch 38, iter = 100, loss = 0.091,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 38, loss = 0.932,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 11m, epoch 39, iter = 100, loss = 0.082,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 39, loss = 0.933,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 11m, epoch 40, iter = 100, loss = 0.074,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 40, loss = 0.925,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 11m, epoch 41, iter = 100, loss = 0.069,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 41, loss = 0.936,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 12m, epoch 42, iter = 100, loss = 0.062,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 42, loss = 0.956,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 12m, epoch 43, iter = 100, loss = 0.052,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 43, loss = 0.945,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 12m, epoch 44, iter = 100, loss = 0.050,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 44, loss = 0.950,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 12m, epoch 45, iter = 100, loss = 0.046,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 45, loss = 0.974,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 13m, epoch 46, iter = 100, loss = 0.043,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 46, loss = 0.962,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 13m, epoch 47, iter = 100, loss = 0.040,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 47, loss = 0.961,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 13m, epoch 48, iter = 100, loss = 0.037,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 48, loss = 0.963,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 14m, epoch 49, iter = 100, loss = 0.034,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 49, loss = 0.948,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 14m, epoch 50, iter = 100, loss = 0.033,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 50, loss = 0.966,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 14m, epoch 51, iter = 100, loss = 0.031,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 51, loss = 0.969,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 14m, epoch 52, iter = 100, loss = 0.031,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 52, loss = 0.986,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 15m, epoch 53, iter = 100, loss = 0.027,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 53, loss = 0.958,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 15m, epoch 54, iter = 100, loss = 0.028,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 54, loss = 0.977,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 15m, epoch 55, iter = 100, loss = 0.025,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 55, loss = 0.947,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 16m, epoch 56, iter = 100, loss = 0.022,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 56, loss = 0.977,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 16m, epoch 57, iter = 100, loss = 0.023,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 57, loss = 0.974,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 16m, epoch 58, iter = 100, loss = 0.021,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 58, loss = 0.971,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 16m, epoch 59, iter = 100, loss = 0.020,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 59, loss = 1.006,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 17m, epoch 60, iter = 100, loss = 0.021,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 60, loss = 0.993,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 17m, epoch 61, iter = 100, loss = 0.019,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 61, loss = 0.986,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 17m, epoch 62, iter = 100, loss = 0.018,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 62, loss = 0.976,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 18m, epoch 63, iter = 100, loss = 0.018,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 63, loss = 0.976,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 18m, epoch 64, iter = 100, loss = 0.017,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 64, loss = 0.981,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 18m, epoch 65, iter = 100, loss = 0.017,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 65, loss = 0.985,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 18m, epoch 66, iter = 100, loss = 0.017,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 66, loss = 0.995,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 19m, epoch 67, iter = 100, loss = 0.016,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 67, loss = 0.996,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

time = 19m, epoch 68, iter = 100, loss = 0.014,17s per 100 iters


  0%|          | 0/30.546875 [00:00<?, ?it/s]

validation: epoch 68, loss = 0.983,


  0%|          | 0/142.78125 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [894]:
def get_source_sentence(vectorizer, batch_dict, index):
    indices = batch_dict['source'][index].cpu().data.numpy()
    vocab = vectorizer.source_vocab
    return sentence_from_indices(indices, vocab)

def get_true_sentence(vectorizer, batch_dict, index):
    return sentence_from_indices(batch_dict['target'].cpu().data.numpy()[index], vectorizer.target_vocab)
    
def get_sampled_sentence(vectorizer, batch_dict, index):
    y_pred = model(batch_dict['source'], batch_dict['target'],False,False )
    return sentence_from_indices(torch.max(y_pred, dim=2)[1].cpu().data.numpy()[index], vectorizer.target_vocab)

def get_all_sentences(vectorizer, batch_dict, index):
    return {"source": get_source_sentence(vectorizer, batch_dict, index), 
            "truth": get_true_sentence(vectorizer, batch_dict, index), 
            "sampled": get_sampled_sentence(vectorizer, batch_dict, index)}
    
def sentence_from_indices(indices, vocab, strict=True):
    ignore_indices = set([vocab.mask_index, vocab.begin_seq_index, vocab.end_seq_index])
    out = []
    for index in indices:
        if index == vocab.begin_seq_index and strict:
            continue
        elif index == vocab.end_seq_index and strict:
            return " ".join(out)
        else:
            out.append(vocab.lookup_index(index))
    return " ".join(out)

In [937]:
dataset.set_split('test')
batch_generator = generate_nmt_batches(dataset, 
                                       batch_size=batch_size)
batch_dict = next(batch_generator)

model = model.eval().to("cuda")

In [None]:
source = "I'll be back in a jiffy."
target = "Je reviens dans un instant."

In [941]:
batch_dict['source'].shape

torch.Size([64, 25])

In [938]:
results = get_all_sentences(vectorizer, batch_dict, 1)

In [939]:
results

{'source': "you 're finicky .",
 'truth': 'vous êtes <UNK> .',
 'sampled': 'vous êtes tatillonnes .'}