In [2]:
import os
# Both variables are written before ``import torch`` below so that they are
# already in the environment when CUDA is first initialised.
# Expose only device index 0 to this process.
os.environ['CUDA_VISIBLE_DEVICES']='0'
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging aid (kernel
# launches become synchronous, which slows training) — confirm it is still wanted.
os.environ['CUDA_LAUNCH_BLOCKING']='1'
from io import open
import unicodedata
import re
import random

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.optim import AdamW
import math 

import numpy as np
from torch.utils.data import TensorDataset, DataLoader, RandomSampler

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torchtext.datasets import multi30k, Multi30k
from typing import Iterable, List


# The download links that ship with torchtext's Multi30k are broken, so the
# train/valid archives are redirected to a working mirror.
# Background: https://github.com/pytorch/text/issues/1756#issuecomment-1163664163
multi30k.URL["valid"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz"
multi30k.URL["train"] = "https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz"

# Translation direction: German source sentences, English targets.
SRC_LANGUAGE, TGT_LANGUAGE = 'de', 'en'

# Registries keyed by language code; they are filled in by later cells
# (tokenizers in ``token_transform``, vocabularies in ``vocab_transform``).
token_transform, vocab_transform = {}, {}

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
import torch
# Scratch check: a constant tensor divided by its own mean is all ones.
X = torch.full((5, 5), 10.0)
X / X.mean()

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [3]:
# token_transform

In [103]:
# next(iter(Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))))

In [5]:
# Register one spaCy-backed tokenizer per language in a single step.
token_transform.update({
    SRC_LANGUAGE: get_tokenizer('spacy', language='de_core_news_sm'),
    TGT_LANGUAGE: get_tokenizer('spacy', language='en_core_web_sm'),
})


# helper function to yield list of tokens
def yield_tokens(data_iter: Iterable, language: str) -> Iterable[List[str]]:
    """Lazily tokenize one side of every sentence pair in ``data_iter``.

    Args:
        data_iter: Iterable of samples where ``sample[0]`` is the source-language
            sentence and ``sample[1]`` is the target-language sentence.
        language: ``SRC_LANGUAGE`` or ``TGT_LANGUAGE``. Selects both the element
            of each pair and the tokenizer looked up in ``token_transform``.

    Yields:
        The tokenizer's output for the selected sentence of each sample, one
        sample at a time (this is a generator, not a list).

    Raises:
        KeyError: on first iteration, if ``language`` is not one of the two
            configured languages or has no registered tokenizer.
    """
    # Both lookups are loop-invariant, so resolve them once up front.
    column = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}[language]
    tokenize = token_transform[language]

    for data_sample in data_iter:
        yield tokenize(data_sample[column])

# Reserved vocabulary entries. The list order has to mirror the index
# constants because the specials are inserted at the front of every vocab.
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = range(len(special_symbols))

for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    # Fresh training-data iterator for this language's pass.
    train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
    # Build torchtext's Vocab object from the tokenized training sentences.
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(train_iter, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )
    # Fall back to ``UNK_IDX`` for tokens missing from the vocabulary; without
    # a default index the lookup raises ``RuntimeError`` instead.
    vocab_transform[ln].set_default_index(UNK_IDX)

In [6]:
from torch.nn.utils.rnn import pad_sequence
from torchtext.transforms import PadTransform, Truncate, AddToken, VocabTransform, Sequential, ToTensor

class Tokenize(nn.Module):
    """Pipeline stage that tokenizes every sentence of a batch."""

    def __init__(self, tokenizer):
        """Keep ``tokenizer``, a callable that is applied to one sentence at a time."""
        super().__init__()
        self.tokenizer = tokenizer

    def forward(self, x):
        """Return a list holding ``self.tokenizer(sentence)`` for each sentence in ``x``."""
        tokenized = []
        for sentence in x:
            tokenized.append(self.tokenizer(sentence))
        return tokenized

class Pad(nn.Module):
    """Right-pad every sequence of a batch up to a fixed minimum length.

    Sequences that already hold ``length`` or more items are passed through
    unchanged; this stage never truncates.
    """

    def __init__(self, length, pad_value):
        """
        Args:
            length: Minimum number of items in every returned sequence.
            pad_value: Item appended to sequences that are too short.
        """
        super().__init__()
        self.length = length
        self.pad_value = pad_value

    def forward(self, x):
        """Pad each sequence in ``x`` to at least ``self.length`` items.

        Args:
            x: Iterable of lists.

        Returns:
            A list with one sequence per input sequence. Sequences that needed
            padding are returned as new lists; the inputs are left unmodified.
        """
        def pad(seq):
            missing = self.length - len(seq)
            if missing > 0:
                # Build a new list rather than extending in place, so the
                # caller's sequence is not mutated as a side effect.
                return seq + [self.pad_value] * missing
            return seq

        return [pad(seq) for seq in x]

def get_preproc(token_transform, vocab_transform, max_seq_len):
    """Assemble the sentence-batch -> tensor preprocessing pipeline for one language.

    The stages run in this order: ``Tokenize``, ``VocabTransform``, ``Truncate``
    to ``max_seq_len - 2``, ``AddToken`` for ``BOS_IDX`` at the start and
    ``EOS_IDX`` at the end, ``Pad`` with ``PAD_IDX`` up to ``max_seq_len``, and
    finally ``ToTensor``.

    Args:
        token_transform: Tokenizer callable handed to ``Tokenize``.
        vocab_transform: Vocabulary handed to ``VocabTransform``.
        max_seq_len: Length every sequence is padded to, markers included.

    Returns:
        An ``nn.Sequential`` wrapping the stages above.
    """
    # Two positions are reserved for the begin/end markers added after truncation.
    content_len = max_seq_len - 2
    stages = [
        Tokenize(token_transform),
        VocabTransform(vocab_transform),
        Truncate(max_seq_len=content_len),
        AddToken(BOS_IDX, begin=True),
        AddToken(EOS_IDX, begin=False),
        Pad(pad_value=PAD_IDX, length=max_seq_len),
        ToTensor(),
    ]
    return nn.Sequential(*stages)

# Sequence length used for the source pipeline and for the model's block size.
BLOCK_SIZE = 64

preproc_src = get_preproc(
    token_transform=token_transform['de'],
    vocab_transform=vocab_transform['de'],
    max_seq_len=BLOCK_SIZE,
)
# Targets are one position longer: the training loop feeds ``tgt[:, :-1]`` to
# the decoder and scores against ``tgt[:, 1:]``, each BLOCK_SIZE wide.
# Author's note: this +1 could probably be dropped, since the last token is
# basically EOS.
preproc_tgt = get_preproc(
    token_transform=token_transform['en'],
    vocab_transform=vocab_transform['en'],
    max_seq_len=BLOCK_SIZE + 1,
)
# Smoke test on two tiny inputs; the resulting tensor is the cell output.
preproc_tgt(['hello', 'my fellow'])

tensor([[   2, 5465,    3,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1],
        [   2, 2227, 2572,    3,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
            1,    1,    1,    1,    1]])

In [104]:
# Number of sentence pairs per batch.
BATCH_SIZE = 128

train_iter = Multi30k(split='train', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))
# Note: the ``test_*`` names wrap the *validation* split.
test_iter = Multi30k(split='valid', language_pair=(SRC_LANGUAGE, TGT_LANGUAGE))

train_dataloader, test_dataloader = (
    DataLoader(pairs, batch_size=BATCH_SIZE) for pairs in (train_iter, test_iter)
)

In [46]:
from ai.models.nlp.seq2seq.simple_attention import Encoder, Decoder
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without a usable CUDA device instead of failing at ``.to(device)``.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Vocabulary sizes determine the encoder input size and decoder output size.
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
# Value passed as ``n_embd`` to both the encoder and the decoder.
n_embd=64
encoder = Encoder(input_size=SRC_VOCAB_SIZE, n_embd=n_embd, max_block_size=BLOCK_SIZE).to(device)
decoder = Decoder(output_vocab_size=TGT_VOCAB_SIZE, n_embd=n_embd, block_size=BLOCK_SIZE).to(device)

In [47]:
# Learning rate handed to AdamW.
lr = 3e-3
# A single optimizer over both networks, so one ``step()`` updates the encoder
# and the decoder together.
optim1 = AdamW([*encoder.parameters(), *decoder.parameters()], lr=lr)

In [48]:
# Teacher-forced training: the decoder receives the target without its last
# position and is scored against the target without its first position.
# ``step`` counts batches across all epochs.
step = 0
for epoch in range(10):
    for src, tgt in train_dataloader:
        step += 1

        # Raw sentence batches -> fixed-length index tensors on ``device``.
        src = preproc_src(src).to(device)
        tgt = preproc_tgt(tgt).to(device)

        enc_outputs = encoder(src)
        decoder_outputs = decoder(tgt[:, :-1], enc_outputs)

        # Flatten the leading dimensions so every position is one
        # classification example for ``cross_entropy``.
        # NOTE(review): no ``ignore_index`` is passed, so padded positions
        # count towards the loss — confirm this is intended.
        loss = F.cross_entropy(
            decoder_outputs.view(-1, decoder_outputs.shape[-1]),
            tgt[:, 1:].contiguous().view(-1),
        )

        optim1.zero_grad()
        loss.backward()
        optim1.step()

        # Report the loss of every 100th batch.
        if step % 100 == 0:
            print(f'e{epoch}:s{step} | loss:{loss:.3f}')


e0:s100 | loss:1.037
e0:s200 | loss:0.934
e0:s300 | loss:0.872
e0:s400 | loss:0.871
e0:s500 | loss:0.970
e0:s600 | loss:0.783
e0:s700 | loss:0.805
e0:s800 | loss:0.953
e0:s900 | loss:0.799




e1:s1000 | loss:0.671
e1:s1100 | loss:0.676
e1:s1200 | loss:0.683
e1:s1300 | loss:0.553
e1:s1400 | loss:0.727
e1:s1500 | loss:0.649
e1:s1600 | loss:0.659
e1:s1700 | loss:0.737
e1:s1800 | loss:0.641
e2:s1900 | loss:0.652
e2:s2000 | loss:0.580
e2:s2100 | loss:0.651
e2:s2200 | loss:0.492
e2:s2300 | loss:0.587
e2:s2400 | loss:0.592
e2:s2500 | loss:0.623
e2:s2600 | loss:0.659


KeyboardInterrupt: 

In [108]:
itos=vocab_transform['en'].get_itos()
stoi=vocab_transform['en'].get_stoi()

# Take the first batch of the validation loader. ``next(iter(...))`` fails
# loudly on an empty loader instead of silently reusing stale variables.
src_batch, tgt_batch = next(iter(test_dataloader))
X = preproc_src(src_batch).to(device)
y = preproc_tgt(tgt_batch).to(device)

# Greedy decoding. ``y_hat`` starts as <bos> followed by padding and is filled
# in one position per step with the highest-scoring token.
y_hat = torch.full((y.shape[0], BLOCK_SIZE + 1), PAD_IDX, dtype=torch.long, device=device)
y_hat[:, 0] = BOS_IDX
# NOTE(review): the models are not switched to ``eval()`` here. If Encoder or
# Decoder contain dropout/normalisation layers (their definitions are not
# visible from this notebook), call ``.eval()`` before decoding.
with torch.no_grad():  # inference only: do not build an autograd graph per step
    # The source batch never changes, so it is encoded once, outside the loop.
    enc_outputs = encoder(X)
    for i in range(BLOCK_SIZE):
        decoder_outputs = decoder(y_hat[:, :-1], enc_outputs)
        y_hat[:, i+1] = decoder_outputs[:, i].topk(1).indices.flatten()

# Print source, prediction and reference for at most the first 12 sentences.
for i in range(min(X.shape[0], 12)):
    print('src :'+' '.join(vocab_transform['de'].lookup_tokens(X[i].tolist())).replace('<pad>',''))
    print('yhat:'+' '.join(vocab_transform['en'].lookup_tokens(y_hat[i].tolist())).replace('<pad>',''))
    print('y   :'+' '.join(vocab_transform['en'].lookup_tokens(y[i].tolist())).replace('<pad>',''))
    print('*'*72)
        

src :<bos> Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen <eos>                                                     
yhat:<bos> A group of men are working on a truck . <eos>                                                     
y   :<bos> A group of men are loading cotton onto a truck <eos>                                                     
************************************************************************
src :<bos> Ein Mann schläft in einem grünen Raum auf einem Sofa . <eos>                                                   
yhat:<bos> A man is sleeping in a green room with a green hair . <eos>                                                  
y   :<bos> A man sleeping in a green room on a couch . <eos>                                                    
************************************************************************
src :<bos> Ein Junge mit Kopfhörern sitzt auf den Schultern einer Frau . <eos>                                                   
yhat:<bos> 

In [1]:
!ls -alh .

total 84K
drwxrwxr-x  3 spock spock 4.0K Mar 15 14:10 .
drwxrwxr-x 10 spock spock 4.0K Mar 15 12:25 ..
drwxrwxr-x  2 spock spock 4.0K Mar 15 14:10 .ipynb_checkpoints
-rw-rw-r--  1 spock spock  25K Feb 13 16:48 multi30k_enc_dec.ipynb
-rw-rw-r--  1 spock spock  28K Feb 11 16:41 seq2seq_torch-Copy1.ipynb
-rw-rw-r--  1 spock spock  16K Feb 10 15:48 seq2seq_v2.ipynb
