In [62]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

In [63]:
class MultiHeadAttention(nn.Module):
  def __init__(self,d_model,n_head,d_k,max_len,casual=False):
    super().__init__()
    self.q=nn.Linear(d_model,n_head*d_k)
    self.k=nn.Linear(d_model,n_head*d_k)
    self.v=nn.Linear(d_model,n_head*d_k)
    self.final_layer=nn.Linear(d_k*n_head,d_model)
    self.d_k=d_k
    self.n_head=n_head
    self.casual=casual
    if casual==True:
      casual_mask=torch.ones(max_len,max_len)
      casual_mask=torch.tril(casual_mask)
      self.register_buffer('casual_mask',casual_mask.view(1,1,max_len,max_len))
  def forward(self,q,k,v,token_mask=None):
    q=self.q(q)
    k=self.k(k)
    v=self.v(v)
    N=q.shape[0]
    T_output=q.shape[1]
    T_input=k.shape[1]
    q=q.view(N,T_output,self.n_head,self.d_k).transpose(1,2)
    k=k.view(N,T_input,self.n_head,self.d_k).transpose(1,2)
    v=v.view(N,T_input,self.n_head,self.d_k).transpose(1,2)
    attention_layer=q @ k.transpose(-1,-2) / math.sqrt(self.d_k)
    if token_mask!=None:
      attention_layer=attention_layer.masked_fill(token_mask[:,None,None,:]==0,float('-inf'))
    if self.casual==True:
      attention_layer=attention_layer.masked_fill(self.casual_mask[:,:,:T_output,:T_input]==0,float('-inf'))
    attention_layer=F.softmax(attention_layer,dim=-1)
    A=attention_layer @ v
    A=A.transpose(1,2)
    A=A.contiguous().view(N,T_output,self.d_k*self.n_head)
    return self.final_layer(A)

In [64]:
class PositionalEncoding(nn.Module):
  def __init__(self,max_len,d_model,dropout_prob=0.1):
    super().__init__()
    self.l1=nn.Dropout(dropout_prob)
    i=torch.arange(0,d_model,2)
    pos=torch.arange(max_len).unsqueeze(1)
    x=torch.exp(-i*(math.log(10000)/d_model))
    PE=torch.zeros(1,max_len,d_model)
    PE[0,:,0::2]=torch.sin(pos*x)
    PE[0,:,1::2]=torch.cos(pos*x)
    self.register_buffer('PE',PE)
  def forward(self,x):
    k=x.size(1)
    x=x+self.PE[:,:k,:]
    return self.l1(x)

In [65]:
class Encoder_Block(nn.Module):
  def __init__(self,d_model,n_head,d_k,max_len,dropout_prob=0.1):
    super().__init__()
    self.ln1=nn.LayerNorm(d_model)
    self.ln2=nn.LayerNorm(d_model)
    self.mha=MultiHeadAttention(d_model,n_head,d_k,max_len)
    self.ann=nn.Sequential(
        nn.Linear(d_model,d_model*4),
        nn.GELU(),
        nn.Linear(d_model*4,d_model),
        nn.Dropout(dropout_prob)
    )
    self.dropout=nn.Dropout(dropout_prob)
  def forward(self,x,mask_token):
    x=self.ln1(x+self.mha(x,x,x,mask_token))
    x=self.ln2(x+self.ann(x))
    x=self.dropout(x)
    return x

In [66]:
class Decoder_Block(nn.Module):
  def __init__(self,d_model,d_k,n_head,max_len,dropout_prob=0.1):
    super().__init__()
    self.ln1=nn.LayerNorm(d_model)
    self.ln2=nn.LayerNorm(d_model)
    self.ln3=nn.LayerNorm(d_model)
    self.mha1=MultiHeadAttention(d_model,n_head,d_k,max_len,casual=True)
    self.mha2=MultiHeadAttention(d_model,n_head,d_k,max_len)
    self.ann=nn.Sequential(
        nn.Linear(d_model,d_model*4),
        nn.GELU(),
        nn.Linear(d_model*4,d_model),
        nn.Dropout(dropout_prob)
    )
    self.dropout=nn.Dropout(dropout_prob)
  def forward(self,encoder_output,decoder_input,decoder_mask=False,encoder_mask=False):
    decoder_input=self.ln1(decoder_input+self.mha1(decoder_input,decoder_input,decoder_input,decoder_mask))
    decoder_input=self.ln2(decoder_input+self.mha2(decoder_input,encoder_output,encoder_output,encoder_mask))
    decoder_input=self.ln3(decoder_input+self.ann(decoder_input))
    return self.dropout(decoder_input)

In [67]:
class Encoder(nn.Module):
  def __init__(self,vocab_size,d_model,max_len,d_k,n_head,n_layers,dropout_prob=0.1):
    super().__init__()
    self.embedding=nn.Embedding(vocab_size,d_model)
    self.positional_embedding=PositionalEncoding(max_len,d_model,dropout_prob)
    transformer_block=[Encoder_Block(d_model,n_head,d_k,max_len,dropout_prob) for _ in range(n_layers)]
    self.transformer_block=nn.Sequential(*transformer_block)
    self.ln1=nn.LayerNorm(d_model)
  def forward(self,x,mask_token):
    x=self.embedding(x)
    x=x+self.positional_embedding(x)
    for z in self.transformer_block:
      x=z(x,mask_token)
    return self.ln1(x)

In [68]:
class Decoder(nn.Module):
  def __init__(self,vocab_size,d_model,max_len,d_k,n_head,n_layers,dropout_prob=0.1):
    super().__init__()
    self.embedding=nn.Embedding(vocab_size,d_model)
    self.positional_embedding=PositionalEncoding(max_len,d_model)
    transformer_block=[Decoder_Block(d_model,d_k,n_head,max_len,dropout_prob) for _ in range(n_layers)]
    self.transformer_block=nn.Sequential(*transformer_block)
    self.ln1=nn.LayerNorm(d_model)
    self.final_layer=nn.Linear(d_model,vocab_size)
  def forward(self,decoder_input,encoder_output,decoder_mask,encoder_mask):
    x=self.embedding(decoder_input)
    x=self.positional_embedding(x)
    for z in self.transformer_block:
      x=z(encoder_output,x,decoder_mask=decoder_mask,encoder_mask=encoder_mask)
    x=self.ln1(x)
    x=self.final_layer(x)
    return x

In [69]:
class Transformer(nn.Module):
  def __init__(self,encoder,decoder):
    super().__init__()
    self.encoder=encoder
    self.decoder=decoder
  def forward(self,encoder_input,decoder_input,encoder_mask,decoder_mask):
    encoder_output=self.encoder(encoder_input,encoder_mask)
    decoder_output=self.decoder(decoder_input,encoder_output,decoder_mask,encoder_mask)
    return decoder_output

In [70]:
encoder=Encoder(vocab_size=20000,d_model=64,max_len=512,d_k=16,n_head=4,n_layers=2)

In [71]:
decoder=Decoder(vocab_size=10000,d_model=64,max_len=512,d_k=16,n_head=4,n_layers=2)

In [72]:
transformer=Transformer(encoder,decoder)

In [73]:
device='cuda:0'

In [74]:
encoder.to(device)

Encoder(
  (embedding): Embedding(20000, 64)
  (positional_embedding): PositionalEncoding(
    (l1): Dropout(p=0.1, inplace=False)
  )
  (transformer_block): Sequential(
    (0): Encoder_Block(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (q): Linear(in_features=64, out_features=64, bias=True)
        (k): Linear(in_features=64, out_features=64, bias=True)
        (v): Linear(in_features=64, out_features=64, bias=True)
        (final_layer): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): Encoder_Block(
      (ln1): LayerNorm((64,), eps=1e-05, eleme

In [75]:
decoder.to(device)

Decoder(
  (embedding): Embedding(10000, 64)
  (positional_embedding): PositionalEncoding(
    (l1): Dropout(p=0.1, inplace=False)
  )
  (transformer_block): Sequential(
    (0): Decoder_Block(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (q): Linear(in_features=64, out_features=64, bias=True)
        (k): Linear(in_features=64, out_features=64, bias=True)
        (v): Linear(in_features=64, out_features=64, bias=True)
        (final_layer): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (q): Linear(in_features=64, out_features=64, bias=True)
        (k): Linear(in_features=64, out_features=64, bias=True)
        (v): Linear(in_features=64, out_features=64, bias=True)
        (final_layer): Linear(in_features=64, out_features=64, bias=True)


In [76]:
xe=np.random.randint(0,20000,size=(8,512))

In [77]:
xe=torch.tensor(xe).to(device)

In [78]:
xd=np.random.randint(0,10000,size=(8,256))
xd=torch.tensor(xd).to(device)

In [79]:
maske=torch.ones(8,512)

In [80]:
maske[:,256:]=0

In [81]:
maske=torch.tensor(maske).to(device)

  maske=torch.tensor(maske).to(device)


In [82]:
maskd=torch.ones(8,256)
maskd[:,128:]=0
maskd=torch.tensor(maskd).to(device)

  maskd=torch.tensor(maskd).to(device)


In [83]:
out=transformer(xe,xd,maske,maskd)

In [84]:
out.shape

torch.Size([8, 256, 10000])

In [85]:
import pandas as pd

In [86]:
with open("/content/spa.txt") as f:
  lines=f.read().split('\n')[:-1]
text_pairs=[]
for line in lines:
  english,spanish,x=line.split("\t")
  text_pairs.append((english,spanish))

In [87]:
len(text_pairs)

139013

In [88]:
db=pd.DataFrame(text_pairs)

In [89]:
db

Unnamed: 0,0,1
0,Go.,Ve.
1,Go.,Vete.
2,Go.,Vaya.
3,Go.,Váyase.
4,Hi.,Hola.
...,...,...
139008,A carbon footprint is the amount of carbon dio...,Una huella de carbono es la cantidad de contam...
139009,Since there are usually multiple websites on a...,Como suele haber varias páginas web sobre cual...
139010,"If you want to sound like a native speaker, yo...","Si quieres sonar como un hablante nativo, debe..."
139011,It may be impossible to get a completely error...,Puede que sea imposible obtener un corpus comp...


In [90]:
db.head()

Unnamed: 0,0,1
0,Go.,Ve.
1,Go.,Vete.
2,Go.,Vaya.
3,Go.,Váyase.
4,Hi.,Hola.


In [91]:
db.columns=['en','es']

In [92]:
db.to_csv('spa.csv',index=None)

In [93]:
!head spa.csv

en,es
Go.,Ve.
Go.,Vete.
Go.,Vaya.
Go.,Váyase.
Hi.,Hola.
Hi.,Hola
Run!,¡Corre!
Run!,¡Corran!
Run!,¡Corra!


In [94]:
!pip install transformers datasets sentencepiece sacremoses

Collecting datasets
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sentencepiece
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m72.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting sacremoses
  Downloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m70.2 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m15.4 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━

In [95]:
from datasets import load_dataset
raw_dataset = load_dataset('csv',data_files='spa.csv')

Generating train split: 0 examples [00:00, ? examples/s]

In [96]:
split=raw_dataset['train'].train_test_split(test_size=0.3,seed=42)

In [97]:
split

DatasetDict({
    train: Dataset({
        features: ['en', 'es'],
        num_rows: 97309
    })
    test: Dataset({
        features: ['en', 'es'],
        num_rows: 41704
    })
})

In [98]:
from transformers import AutoTokenizer

In [99]:
model_checkpoint='Helsinki-NLP/opus-mt-en-es'

In [100]:
tokenizer=AutoTokenizer.from_pretrained(model_checkpoint)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/44.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.47k [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/802k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/826k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.59M [00:00<?, ?B/s]

In [101]:
max_input_length=150
max_output_length=150
def preprocessing(batch):
  model_input=tokenizer(batch['en'],truncation=True,max_length=max_input_length)
  labels=tokenizer(text_target=batch['es'],truncation=True,max_length=max_input_length)
  model_input['labels']=labels['input_ids']
  return model_input

In [102]:
tokenized_data=split.map(preprocessing,batched=True,remove_columns=split['train'].column_names)

Map:   0%|          | 0/97309 [00:00<?, ? examples/s]

Map:   0%|          | 0/41704 [00:00<?, ? examples/s]

In [103]:
tokenized_data

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 97309
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 41704
    })
})

In [104]:
from transformers import DataCollatorForSeq2Seq

In [105]:
data_collator=DataCollatorForSeq2Seq(tokenizer)

In [106]:
from torch.utils.data import DataLoader

In [107]:
train_loader=DataLoader(tokenized_data['train'],shuffle=True,batch_size=32,collate_fn=data_collator)

In [108]:
valid_loader=DataLoader(tokenized_data['test'],shuffle=True,batch_size=32,collate_fn=data_collator)

In [109]:
tokenizer.vocab_size

65001

In [110]:
tokenizer.add_special_tokens({'cls_token':'<s>'})

1

In [111]:
encoder=Encoder(vocab_size=tokenizer.vocab_size+1,d_model=64,max_len=512,d_k=16,n_head=4,n_layers=2)

In [112]:
decoder=Decoder(vocab_size=tokenizer.vocab_size+1,d_model=64,max_len=512,d_k=16,n_head=4,n_layers=2)

In [113]:
transformer=Transformer(encoder,decoder)

In [114]:
encoder.to(device)

Encoder(
  (embedding): Embedding(65002, 64)
  (positional_embedding): PositionalEncoding(
    (l1): Dropout(p=0.1, inplace=False)
  )
  (transformer_block): Sequential(
    (0): Encoder_Block(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (q): Linear(in_features=64, out_features=64, bias=True)
        (k): Linear(in_features=64, out_features=64, bias=True)
        (v): Linear(in_features=64, out_features=64, bias=True)
        (final_layer): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (1): Encoder_Block(
      (ln1): LayerNorm((64,), eps=1e-05, eleme

In [115]:
decoder.to(device)

Decoder(
  (embedding): Embedding(65002, 64)
  (positional_embedding): PositionalEncoding(
    (l1): Dropout(p=0.1, inplace=False)
  )
  (transformer_block): Sequential(
    (0): Decoder_Block(
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln3): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha1): MultiHeadAttention(
        (q): Linear(in_features=64, out_features=64, bias=True)
        (k): Linear(in_features=64, out_features=64, bias=True)
        (v): Linear(in_features=64, out_features=64, bias=True)
        (final_layer): Linear(in_features=64, out_features=64, bias=True)
      )
      (mha2): MultiHeadAttention(
        (q): Linear(in_features=64, out_features=64, bias=True)
        (k): Linear(in_features=64, out_features=64, bias=True)
        (v): Linear(in_features=64, out_features=64, bias=True)
        (final_layer): Linear(in_features=64, out_features=64, bias=True)


In [116]:
crit=nn.CrossEntropyLoss(ignore_index=-100)
optimizer=torch.optim.Adam(transformer.parameters())

In [117]:
for batch in train_loader:
  for k,v in batch.items():
    print(k)
  break

input_ids
attention_mask
labels


In [118]:
from datetime import datetime

In [119]:
def train(model,crit,optimizer,train_loader,valid_loader,epochs):
  train_losses=np.zeros(epochs)
  test_losses=np.zeros(epochs)
  for it in range(epochs):
    train_loss=[]
    model.train()
    t0=datetime.now()
    for batch in train_loader:
       x={k:v.to(device) for k,v in batch.items()}
       optimizer.zero_grad()
       enc_input=x['input_ids']
       enc_mask=x['attention_mask']
       dec_input=x['labels'].clone().detach()
       dec_input=torch.roll(dec_input,shifts=1,dims=1)
       dec_input[:,0]=65_001
       dec_input=dec_input.masked_fill(dec_input==-100,tokenizer.pad_token_id)
       dec_mask=torch.ones_like(dec_input)
       dec_mask=dec_mask.masked_fill(dec_input[:,:]==tokenizer.pad_token_id,0)
       pred=model(enc_input,dec_input,enc_mask,dec_mask)
       loss=crit(pred.transpose(1,2),x['labels'])
       loss.backward()
       optimizer.step()
       train_loss.append(loss.item())
    train_loss=np.mean(train_loss)
    train_losses[it]=train_loss
    model.eval()
    test_loss=[]
    for batch in valid_loader:
      x={k:v.to(device) for k,v in batch.items()}
      enc_input=x['input_ids']
      enc_mask=x['attention_mask']
      dec_input=x['labels'].clone().detach()
      dec_input=torch.roll(dec_input,shifts=1,dims=1)
      dec_input[:,0]=65_001
      dec_input=dec_input.masked_fill(dec_input==-100,tokenizer.pad_token_id)
      dec_mask=torch.ones_like(dec_input)
      dec_mask=dec_mask.masked_fill(dec_input[:,:]==tokenizer.pad_token_id,0)
      pred=model(enc_input,dec_input,enc_mask,dec_mask)
      loss=crit(pred.transpose(1,2),x['labels'])
      test_loss.append(loss.item())
    test_loss=np.mean(test_loss)
    test_losses[it]=test_loss
    t1=datetime.now()-t0
    print(f'Epochs: {it+1}/{epochs} , Train_Loss : {train_loss:.4f} , Test_Loss : {test_loss:.4f} , Duration : {t1}')
  return train_losses,test_losses

In [120]:
train_losses,test_losses=train(transformer,crit,optimizer,train_loader,valid_loader,15)

Epochs: 1/15 , Train_Loss : 4.2908 , Test_Loss : 3.2728 , Duration : 0:02:03.194082
Epochs: 2/15 , Train_Loss : 3.0978 , Test_Loss : 2.6859 , Duration : 0:02:03.002250
Epochs: 3/15 , Train_Loss : 2.6512 , Test_Loss : 2.4144 , Duration : 0:02:03.038529
Epochs: 4/15 , Train_Loss : 2.3894 , Test_Loss : 2.2624 , Duration : 0:02:03.100638
Epochs: 5/15 , Train_Loss : 2.2110 , Test_Loss : 2.1664 , Duration : 0:02:03.048555
Epochs: 6/15 , Train_Loss : 2.0795 , Test_Loss : 2.0630 , Duration : 0:02:03.345628
Epochs: 7/15 , Train_Loss : 1.9773 , Test_Loss : 2.0096 , Duration : 0:02:03.923755
Epochs: 8/15 , Train_Loss : 1.8937 , Test_Loss : 1.9473 , Duration : 0:02:03.114345
Epochs: 9/15 , Train_Loss : 1.8213 , Test_Loss : 1.9237 , Duration : 0:02:03.654847
Epochs: 10/15 , Train_Loss : 1.7617 , Test_Loss : 1.9021 , Duration : 0:02:03.231607
Epochs: 11/15 , Train_Loss : 1.7102 , Test_Loss : 1.8573 , Duration : 0:02:03.201124
Epochs: 12/15 , Train_Loss : 1.6639 , Test_Loss : 1.8256 , Duration : 0:02

In [121]:
input_sen=split['test'][5]['en']

In [122]:
output_sen=split['test'][5]['es']

In [123]:
input_sen

'Ancient astronomers noticed constellations and gave them names.'

In [124]:
output_sen

'Los antiguos astrónomos encontraron constelaciones y les dieron nombres.'

In [125]:
enc_input=tokenizer(input_sen,return_tensors='pt')

In [126]:
enc_input

{'input_ids': tensor([[34888, 45519,     9, 15362, 45184,     9,    10,  2576,   167,  5301,
             3,     0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}

In [127]:
dec_input_str='<s>'
dec_input=tokenizer(text_target=dec_input_str,return_tensors='pt')

In [128]:
dec_input

{'input_ids': tensor([[65001,     0]]), 'attention_mask': tensor([[1, 1]])}

In [129]:
enc_input.to(device)

{'input_ids': tensor([[34888, 45519,     9, 15362, 45184,     9,    10,  2576,   167,  5301,
             3,     0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}

In [130]:
dec_input.to(device)

{'input_ids': tensor([[65001,     0]], device='cuda:0'), 'attention_mask': tensor([[1, 1]], device='cuda:0')}

In [131]:
output=transformer(enc_input['input_ids'],dec_input['input_ids'][:,:-1],enc_input['attention_mask'],dec_input['attention_mask'][:,:-1])

In [132]:
output.shape

torch.Size([1, 1, 65002])

In [133]:
output.argmax(axis=2)[0][0]

tensor(131, device='cuda:0')

In [134]:
tokenizer.convert_ids_to_tokens(output.argmax(axis=2)[0])

['▁Los']

In [135]:
enc_output=encoder(enc_input['input_ids'],enc_input['attention_mask'])

In [136]:
k=decoder(dec_input['input_ids'][:,:-1],enc_output,dec_input['attention_mask'][:,:-1],enc_input['attention_mask'])

In [137]:
torch.allclose(output,k)

True

In [138]:
dec_inp=dec_input['input_ids'][:,:-1]
dec_mask=dec_input['attention_mask'][:,:-1]
next_input=dec_inp
for _ in range(32):
  pred=decoder(next_input.to(device),enc_output.to(device),dec_mask.to(device),enc_input['attention_mask'].to(device))
  pred=torch.argmax(pred,axis=2)
  pred_sen=pred
  next_input=torch.cat((dec_inp.to(device),pred_sen.to(device)),axis=1)
  dec_mask=torch.ones_like(next_input).to(device)
  if pred[0][pred.size(1)-1]==0:
    break;

In [139]:
pred

tensor([[  131,  1611,     4,     6, 52639,     9,    11,    17,  5910,   679,
          2918,  5910,     3,     0]], device='cuda:0')

In [140]:
tokenizer.decode(pred[0])

'Los precios de la constelacións y los nombres les dio nombres.</s>'

In [141]:
output_sen

'Los antiguos astrónomos encontraron constelaciones y les dieron nombres.'

In [142]:
def translation(sentence):
  tokenized_enc=tokenizer(sentence,return_tensors='pt')
  enc_input=tokenized_enc['input_ids']
  enc_mask=tokenized_enc['attention_mask']
  enc_output=encoder(enc_input.to(device),enc_mask.to(device))
  starting_token='<s>'
  tokenized_dec=tokenizer(starting_token,return_tensors='pt')
  dec_input=tokenized_dec['input_ids'][:,:-1]
  dec_mask=torch.ones_like(dec_input)
  next_input=dec_input
  for _ in range(32):
    pred=decoder(next_input.to(device),enc_output.to(device),dec_mask.to(device),enc_mask.to(device))
    pred=torch.argmax(pred,axis=2)
    pred_sen=pred
    next_input=torch.cat((dec_inp.to(device),pred_sen.to(device)),axis=1)
    dec_mask=torch.ones_like(next_input).to(device)
    if pred[0][pred.size(1)-1]==0:
      break
  print(tokenizer.decode(pred[0][:-1]))

In [149]:
translation('Hello Everyone')

Todos Holan.
