In [12]:
# stdlib
import random

# utils
import torch
from utils import count_parameters

# data
from torchtext.data import BucketIterator, Field
from torchtext.datasets import Multi30k

# model
import torch.nn as nn
import torch.nn.functional as F

# training
import torch.optim as optim
import tqdm

In [13]:
# run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# seed the RNGs this notebook depends on so Restart & Run All is reproducible:
# torchtext's legacy iterators shuffle via Python's `random` module, and torch
# drives weight initialisation and dropout (torch.manual_seed covers all devices)
SEED = 1234
random.seed(SEED)
torch.manual_seed(SEED)

In [15]:
# show which device was selected (rendered via its repr, e.g. device(type='cuda'))
device

device(type='cuda')

## Data Preparation

In [7]:
# create data fields for source and target
def make_text_field(tokenizer_language):
    """Build a lower-cased, spaCy-tokenized Field wrapped in <sos>/<eos> tokens.

    Args:
        tokenizer_language: spaCy language/model name handed to the tokenizer,
            e.g. "de" or "en". NOTE(review): newer spaCy releases dropped these
            short shortcuts; full model names such as "de_core_news_sm" /
            "en_core_web_sm" may be needed there — confirm against your install.

    Returns:
        A torchtext ``Field`` with ``batch_first=True``, so numericalized batches
        come out as [batch_size, seq_len].
    """
    return Field(
        init_token="<sos>",
        eos_token="<eos>",
        lower=True,
        tokenize="spacy",
        tokenizer_language=tokenizer_language,
        batch_first=True,
    )


# the corpus is loaded with exts=(".de", ".en"): the source side is German and
# the target side is English, so each field needs its own language's tokenizer
# (the target previously reused the German tokenizer by copy-paste)
source = make_text_field("de")
target = make_text_field("en")

In [9]:
# download the Multi30k parallel corpus and split it into train/validation/test;
# the extensions pair up positionally with the fields, so ".de" feeds `source`
# and ".en" feeds `target`
train, val, test = Multi30k.splits(fields=(source, target), exts=(".de", ".en"))

In [10]:
# build each field's vocabulary from the training split only, so the
# validation/test text never contributes tokens to the index
for text_field in (source, target):
    text_field.build_vocab(train)
del text_field  # keep the notebook namespace identical to the two-call version

In [16]:
# create data loaders: per the torchtext docs, BucketIterator batches examples of
# similar length together to reduce padding; batches are placed on `device`
BATCH_SIZE = 128
train_loader, val_loader, test_loader = BucketIterator.splits(
    datasets=(train, val, test),
    batch_size=BATCH_SIZE,
    device=device,
    shuffle=True
)
# backward-compatible alias for the originally misspelled name; prefer val_loader
val_laoder = val_loader

In [20]:
# sanity-check one training batch: the fields use batch_first=True, so each
# tensor is laid out as [batch_size, seq_len]
batch = next(iter(train_loader))
print(*(getattr(batch, attr).shape for attr in ("src", "trg")))

torch.Size([128, 29]) torch.Size([128, 30])


## Transformer Model

In [None]:
class Encoder(nn.Module):
    """Transformer encoder module (work in progress).

    Intended to map a batch of source token ids to a
    [batch_size, seq_len, out_dim] tensor. So far only the token/positional
    embeddings and dropout exist; the stacked encoder layers and ``forward``
    are still TODO.

    Args:
        vocab_size: number of rows in the token embedding table
            (presumably ``len(source.vocab)`` — confirm at the call site).
        embedding_dim: size of both the token and the positional embeddings.
        num_layers: number of encoder layers to stack (not used yet).
        n_heads: presumably attention heads per layer (not used yet).
        pf_dim: presumably the hidden size of the position-wise feed-forward
            block (not used yet).
        dropout: dropout probability; ``nn.Dropout`` raises ValueError if it
            is outside [0, 1].
        max_len: number of learned positions, i.e. the longest sequence the
            positional embedding can index.
    """

    def __init__(self, vocab_size, embedding_dim, num_layers, n_heads, pf_dim, dropout=0.1, max_len=100):
        # nn.Module must be initialised before any submodule is assigned;
        # without this the nn.Embedding assignment below raises AttributeError
        super().__init__()

        # keep the hyperparameters around for the layers still to be added
        self.vocab_size = vocab_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers
        self.n_heads = n_heads
        self.pf_dim = pf_dim
        self.max_len = max_len

        # tok and pos embedding dim is same because we have to add them
        self.tok_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)
        self.pos_embedding = nn.Embedding(num_embeddings=max_len, embedding_dim=embedding_dim)

        # the default was 100, which is not a probability; 0.1 is the intended value
        self.dropout = nn.Dropout(dropout)
        

In [21]:
# IPython help lookup for nn.Embedding's signature (rendered below);
# interactive-only — consider removing before sharing the notebook
nn.Embedding?

[0;31mInit signature:[0m
[0mnn[0m[0;34m.[0m[0mEmbedding[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mnum_embeddings[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0membedding_dim[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpadding_idx[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_norm[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnorm_type[0m[0;34m=[0m[0;36m2.0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mscale_grad_by_freq[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0msparse[0m[0;34m=[0m[0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0m_weight[0m[0;34m=[0m[0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m     
A simple lookup table that stores embeddings of a fixed dictionary and size.

This module is often used to store word embeddings and retrieve them using indices.
The input to the module is a list of indices, 