In [1]:
source = "I am a boy."
target = "Ich bin ein Junge."

# ----------
# Encoder["I am a boy."] -> h 
# Next Token Prediction 
# 1. step Decoder [h, <bos>] -> "ich"
# 2. step Decoder [h, (<bos>, "ich")] -> "bin" 
# 3. step Decoder [h, (<bos>, "ich", "bin")] -> "ein"
# 4. step Decoder [h, (<bos>, "ich", "bin", "ein")] -> "Junge"
# 5. step Decoder [h, (<bos>, "ich", "bin", "ein", "Junge")] -> "<eos>"

# X = (h, (<bos>, "ich", "bin", "ein", "Junge")) -> input for Decoder
# y = ( "ich", "bin", "ein", "Junge", <eos>) -> target/labels for Loss

In [2]:
import re, json, torch, torch.nn as nn
from torch.utils.data import DataLoader

path = "./deu.txt"

lines = open(path, encoding="utf-8").read().strip().split("\n")
lines = lines[:20000]

pairs = [ln.split("\t")[:2] for ln in lines] 
src_texts, tgt_texts = zip(*pairs)

In [3]:
PAD, UNK, BOS, EOS = 0, 1, 2, 3 # special tokens
# PAD = Padding, UNK = Unknown,
# BOS, EOS 

VOCAB_SIZE = 20004 

def tokenize(s): return re.findall(r"\b\w+\b", s.lower())
def build_vocab(texts, max_tokens=VOCAB_SIZE):
    from collections import Counter
    freq = Counter(tok for t in texts for tok in tokenize(t))
    itos = ["<pad>", "<unk>", "<bos>", "<eos>"] + [w for w,_ in freq.most_common(max_tokens-4)]
    return {w:i for i,w in enumerate(itos)}, itos
src_texts_vocab, src_itos = build_vocab(src_texts)
tgt_texts_vocab, tgt_itos = build_vocab(tgt_texts)



def vectorize(text, stoi, max_len, add_bos_eos=False):
    ids = [stoi.get(tok, UNK) for tok in tokenize(text)]
    if add_bos_eos: ids = [BOS] + ids + [EOS]
    ids = ids[:max_len]
    if len(ids) < max_len: ids += [PAD]*(max_len-len(ids))
    return ids

#vectorize(src_texts[60], src_texts_vocab, 30)
#src_texts[60]


max_src, max_tgt = 30, 30 

def collate_fn(batch):
    src_batch, tgt_batch = zip(*batch)
    src = torch.tensor([vectorize(t, src_texts_vocab, max_src) for t in src_batch])
    tgt = torch.tensor([vectorize(t, tgt_texts_vocab, max_tgt, add_bos_eos=True) for t in tgt_batch])
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]
    return src, tgt_in, tgt_out

dataset = list(zip(src_texts, tgt_texts))
loader = DataLoader(dataset, batch_size= 64, shuffle=True, collate_fn=collate_fn)


In [None]:
len(src_texts_vocab) # number of tokens in source
len(tgt_texts_vocab) # number of tokens in target 

5594

In [4]:
sentence = "this is sample sentence for embedding"
sentence2 = "this is sentence embedding"


dc = {s:i for i,s in enumerate(sorted(sentence.replace(',', '').split()))}
dc

{'embedding': 0, 'for': 1, 'is': 2, 'sample': 3, 'sentence': 4, 'this': 5}

In [7]:
vocab_size_tmp = len(dc)
emb = torch.nn.Embedding(vocab_size_tmp, 3)
emb.weight.data

tensor([[ 0.0443, -2.0053, -1.5915],
        [ 1.2568,  1.4348, -1.9296],
        [-0.5228,  0.9867, -0.4589],
        [-0.9804,  1.3372, -0.0818],
        [ 0.6943,  1.0501, -1.7118],
        [ 0.5269, -0.5099,  0.5838]])

In [17]:
emb_dim = 128 # in practice starts from 768 
hid_dim = 256 

class Encoder(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD)
        self.rnn = nn.GRU(emb_dim, hid_dim, batch_first= True)
    
    def forward(self,x):
        x = self.embedding(x)
        _, hidden = self.rnn(x)
        return hidden 
    
class Decoder(nn.Module):
    def __init__(self,vocab_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=PAD)
        self.rnn = nn.GRU(emb_dim, hid_dim, batch_first=True)
        self.fc = nn.Linear(hid_dim, vocab_size) # Classifier head, MLP head, FFNN head 
    
    def forward(self, x , h): # hidden state from the encoder
        x = self.embedding(x)
        out, _ = self.rnn(x,h)
        return self.fc(out)

class Seq2Seq(nn.Module):
    def __init__(self, enc, dec):
        super().__init__()
        self.enc = enc
        self.dec = dec
    
    def forward(self, src, tgt_in_dec ):
        # src ... source (english sentences)
        # tgt_in_dec ... actual german sentences that are also input to the decoder
        hidden_enc = self.enc(src)
        logits = self.dec(tgt_in_dec, hidden_enc)
        return logits

device = "mps" # cpu or cuda 

model = Seq2Seq(
    Encoder(len(src_texts_vocab)),
    Decoder(len(tgt_texts_vocab))
).to(device)


In [18]:
crit = nn.CrossEntropyLoss(ignore_index=PAD) 
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3)
epochs = 20

@torch.no_grad()
def translate(prompt, max_len=max_tgt):
    model.eval()
    src = torch.tensor([vectorize(prompt, src_texts_vocab, max_src)], device=device)
    h = model.enc(src)
    ys = torch.tensor([[BOS]], device=device)
    out_tokens = []
    for _ in range(max_len):
        logits = model.dec(ys, h)
        next_id = logits[0, -1].argmax().item()
        if next_id in (EOS, PAD): break
        out_tokens.append(next_id)
        ys = torch.cat([ys, torch.tensor([[next_id]], device=device)], dim=1)
    return " ".join(tgt_itos[t] for t in out_tokens)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0 
    for src, tgt_in, tgt_out in loader:
        src, tgt_in, tgt_out = src.to(device), tgt_in.to(device), tgt_out.to(device)
        logits = model(src,tgt_in)
        loss = crit(logits.reshape(-1, logits.size(-1)), tgt_out.reshape(-1)) 
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) # gradient clipping 
        optimizer.step()
        running_loss += loss.item()
    print(f"epoch {epoch+1}: loss {running_loss/len(loader):.4f}") # 
    print(translate("I will do my best."))



epoch 1: loss 4.3803
ich bin ein
epoch 2: loss 3.2316
ich habe einen plan
epoch 3: loss 2.6660
ich werde mich
epoch 4: loss 2.2611
ich werde das machen
epoch 5: loss 1.9438
ich kann das erklären
epoch 6: loss 1.6822
ich werde das tun
epoch 7: loss 1.4607
ich kann das erklären
epoch 8: loss 1.2646
ich kann das erklären
epoch 9: loss 1.0920
ich werde mein bestes tun
epoch 10: loss 0.9394
ich werde mein bestes tun
epoch 11: loss 0.8110
ich kann das tun


KeyboardInterrupt: 

In [19]:
model.state_dict()

OrderedDict([('enc.embedding.weight',
              tensor([[ 0.0069, -0.0716, -0.0309,  ..., -0.0119, -0.1059,  0.0858],
                      [ 0.1495,  0.1305,  0.4564,  ..., -0.2707,  0.7617, -0.9606],
                      [ 1.9574, -0.9580,  0.8112,  ...,  0.9416, -0.3763, -0.1840],
                      ...,
                      [-0.3965, -0.5867,  0.0605,  ..., -0.6160,  0.1138,  1.4554],
                      [-1.5110,  0.6691,  1.4929,  ...,  0.4278,  0.8104, -0.4752],
                      [ 0.5352,  0.2932,  0.2892,  ...,  0.6478, -0.6552,  0.4084]],
                     device='mps:0')),
             ('enc.rnn.weight_ih_l0',
              tensor([[ 1.7273e-01, -1.9324e-01, -1.1670e-01,  ..., -9.5050e-02,
                       -1.0352e-01, -1.6131e-02],
                      [ 6.0002e-02,  1.1714e-01,  3.2160e-02,  ...,  1.5135e-02,
                       -3.1400e-02, -1.0648e-01],
                      [ 1.2933e-02, -1.8350e-01,  1.1375e-01,  ...,  1.7294e-01,
          