<a href="https://colab.research.google.com/github/gunadhineha/molecularGNN_smiles/blob/master/Machine_Translation_Student_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torchtext
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import random
import math

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Colab-only helper module; this import fails outside a Google Colab runtime.
from google.colab import files
# Prompt for a file upload; the recorded output shows jpn_cleaned.npy being saved.
uploaded = files.upload()

Saving jpn_cleaned.npy to jpn_cleaned.npy


In [None]:
# Second upload prompt (eng_cleaned.npy per the recorded output). Rebinding
# `uploaded` discards the value returned by the previous upload call.
uploaded = files.upload()

Saving eng_cleaned.npy to eng_cleaned.npy


In [None]:
# Load the two uploaded corpora from the working directory.
# NOTE(review): allow_pickle=True lets NumPy unpickle arbitrary objects, which can
# execute code while loading -- only use it on .npy files from a trusted source.
# Presumably each entry is a list of tokens and the two arrays are aligned
# sentence-by-sentence -- confirm against the preprocessing step.
jpn_dataset = np.load("jpn_cleaned.npy", allow_pickle=True)
eng_dataset = np.load("eng_cleaned.npy", allow_pickle=True)

In [None]:
# Japanese (source) vocabulary; tokens missing from it map to '<unk>'.
jpn_vocab = build_vocab_from_iterator(
    jpn_dataset,
    specials=['<unk>', '<pad>', '<bos>', '<eos>'],
)
jpn_vocab.set_default_index(jpn_vocab["<unk>"])

# English (target) vocabulary, built with the same special tokens.
eng_vocab = build_vocab_from_iterator(
    eng_dataset,
    specials=['<unk>', '<pad>', '<bos>', '<eos>'],
)
eng_vocab.set_default_index(eng_vocab["<unk>"])

In [None]:
# Sanity check: show the token list stored for one Japanese example.
print(jpn_dataset[3001])

['私', 'は', 'コーヒー', 'が', '大嫌い', 'です', '。']


In [None]:
# Sanity check: map the same example's tokens to their vocabulary indices.
jpn_vocab(jpn_dataset[3001])

[20, 5, 208, 11, 1043, 19, 4]

In [None]:
# Numericalize every sentence pair: each item is a
# (Japanese index tensor, English index tensor) tuple of dtype long.
dataset = [
    (
        torch.tensor(jpn_vocab(jpn_tokens), dtype=torch.long),
        torch.tensor(eng_vocab(eng_tokens), dtype=torch.long),
    )
    for jpn_tokens, eng_tokens in zip(jpn_dataset, eng_dataset)
]

In [None]:
# Special-token indices of each vocabulary, used when batching and decoding.
J_PAD_IDX, J_BOS_IDX, J_EOS_IDX = (jpn_vocab[tok] for tok in ('<pad>', '<bos>', '<eos>'))
E_PAD_IDX, E_BOS_IDX, E_EOS_IDX = (eng_vocab[tok] for tok in ('<pad>', '<bos>', '<eos>'))

def generate_batch(batch):
    """Collate (Japanese, English) index tensors into padded batch-first tensors.

    Every sequence is wrapped with its language's <bos>/<eos> index, then all
    sequences of a language are right-padded with that language's <pad> index
    up to the longest sequence in the batch.

    Args:
        batch: iterable of ``(jpn_tensor, eng_tensor)`` pairs of 1-D index
            tensors, as stored in ``dataset``.

    Returns:
        ``(jpn_batch, eng_batch)``, each of shape ``(batch_size, max_len)``,
        moved to the module-level ``device``.
    """
    # The boundary markers are the same for every sample, so build them once.
    j_bos, j_eos = torch.tensor([J_BOS_IDX]), torch.tensor([J_EOS_IDX])
    e_bos, e_eos = torch.tensor([E_BOS_IDX]), torch.tensor([E_EOS_IDX])

    jpn_list, eng_list = [], []
    for jpn_seq, eng_seq in batch:
        jpn_list.append(torch.cat([j_bos, jpn_seq, j_eos], dim=0))
        eng_list.append(torch.cat([e_bos, eng_seq, e_eos], dim=0))

    # batch_first=True yields (batch, seq) directly, replacing pad-then-transpose.
    jpn_padded = pad_sequence(jpn_list, batch_first=True, padding_value=J_PAD_IDX)
    eng_padded = pad_sequence(eng_list, batch_first=True, padding_value=E_PAD_IDX)
    return jpn_padded.to(device), eng_padded.to(device)

In [None]:
# Shuffled mini-batches of 64 sentence pairs, padded by generate_batch.
trainloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    collate_fn=generate_batch,
)

In [None]:
class SimplifiedAttention(nn.Module):
    # ...

class TransformerLayer(nn.Module):
    # ...

class Transformer(nn.Module):
    # ...

In [None]:
# Vocabulary sizes: Japanese (source) vocabulary and English (target) vocabulary.
input_vocab_size = len(jpn_vocab)
num_class = len(eng_vocab)

# NOTE(review): the constructor arguments are elided (`...`) in this export, so the
# hyperparameters cannot be read from here -- presumably the two sizes above are
# among them; confirm against the Transformer definition.
model = Transformer(...).to(device)

In [None]:
# Adam over all model parameters; no hyperparameters are passed, so the library
# defaults (learning rate included) apply.
optimizer = torch.optim.Adam(model.parameters())

In [None]:
def train(model, loader, optimizer, criterion, clip):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(src, trg)``; its output is indexed as
            ``output[:, 1:]`` and flattened over its last dimension.
        loader: iterable of ``(src, trg)`` tensor batches that supports ``len()``.
        optimizer: optimizer that updates the model's parameters.
        criterion: loss called as ``criterion(flat_output, flat_target)``.
        clip: maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        Sum of the batch losses divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no batches.
    """
    model.train()
    epoch_loss = 0

    # The batch index was never used, so iterate the loader directly.
    for src, trg in loader:
        src, trg = src.to(device), trg.to(device)
        output = model(src, trg)
        # Drop position 0 along dim 1 and flatten to (N, num_classes) / (N,)
        # so the criterion sees one row per remaining token.
        # NOTE(review): presumably position 0 is the <bos> slot; whether the
        # model's output at position t lines up with trg at position t depends
        # on the Transformer definition -- confirm.
        output = output[:, 1:].reshape(-1, output.shape[-1])
        trg = trg[:, 1:].reshape(-1)
        loss = criterion(output, trg)
        optimizer.zero_grad()
        loss.backward()
        # Rescale gradients so their total norm does not exceed `clip`.
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()

    return epoch_loss / len(loader)

In [None]:
# Targets are English token indices, so padded target positions are skipped via
# the English <pad> index (the bare name PAD_IDX is not defined in this notebook).
criterion = nn.CrossEntropyLoss(ignore_index=E_PAD_IDX)

In [None]:
# Train for 10 epochs with a gradient-norm clip of 1.0, reporting after each epoch.
for epoch in range(10):
    train_loss = train(model, trainloader, optimizer, criterion, 1.0)

    print(f'Epoch: {epoch+1:02}')
    # Perplexity is reported as exp of the mean per-batch loss.
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')

In [None]:
# Translate one example sentence with the trained model.
sentence = jpn_dataset[3001]
tensor_sentence = torch.tensor(jpn_vocab(sentence), dtype=torch.long)
# Wrap with the Japanese <bos>/<eos> indices, exactly as generate_batch does
# (the bare names BOS_IDX / EOS_IDX are not defined in this notebook).
tensor_sentence = torch.cat([torch.tensor([J_BOS_IDX]), tensor_sentence, torch.tensor([J_EOS_IDX])], dim=0)
# Training batches are laid out (batch, seq), so the batch dimension goes first.
tensor_sentence = tensor_sentence.unsqueeze(0)
# Placeholder target: one row of 30 positions filled with the English <bos> index
# (this replaces the magic constant 2).
dummy_tgt = torch.full((1, 30), E_BOS_IDX, dtype=torch.long)
tensor_sentence = tensor_sentence.to(device)
dummy_tgt = dummy_tgt.to(device)
model.eval()
with torch.no_grad():
    # NOTE(review): the third positional argument is kept from the original call;
    # train() calls the model with two arguments, so its meaning depends on the
    # Transformer.forward signature, which is not visible here -- confirm.
    output = model(tensor_sentence, dummy_tgt, 0.0)
    print(output.shape)
    predicted_word_idxs = output.argmax(dim=-1).squeeze().cpu().numpy()
    print(predicted_word_idxs.shape)
    # lookup_tokens is documented to take a list of ints, so convert the array.
    trans = eng_vocab.lookup_tokens(predicted_word_idxs.tolist())
    print(trans)