<a href="https://colab.research.google.com/github/gabeorosan/papers-explained/blob/main/transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [16]:
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from collections import namedtuple
import torch
import torch.nn as nn
import torch.optim as optim
import math

In [3]:
# Tokenizer: torchtext's built-in 'basic_english' tokenizer.
tokenizer = get_tokenizer('basic_english')

# Build vocabulary from its own pass over the training split. Iterating the
# same iterator that is later handed to data_process() can leave it exhausted
# on torchtext releases where the dataset is a one-shot iterator, so the
# vocabulary gets a dedicated iterator here.
vocab = build_vocab_from_iterator(
    map(tokenizer, WikiText2(split='train')),
    specials=['<pad>', '<bos>', '<eos>', '<unk>'],
)
# Map out-of-vocabulary tokens to a dedicated <unk> id instead of <pad>, so
# unknown words are not indistinguishable from padding. <pad>, <bos> and <eos>
# keep their positions at the front of the specials list.
vocab.set_default_index(vocab['<unk>'])

# Load data (fresh iterators, untouched by the vocabulary pass above)
train_iter, valid_iter, test_iter = WikiText2()

# Prepare data
def data_process(raw_text_iter):
    """Tokenize, numericalize and pad every item of ``raw_text_iter``.

    Each item is split with the module-level ``tokenizer`` and every token is
    looked up in the module-level ``vocab``, giving one 1-D ``torch.long``
    tensor per item. The tensors are then padded with the ``<pad>`` id.

    Returns:
        A single tensor produced by ``pad_sequence`` with its default
        (time-major) layout, i.e. shaped ``(longest_item_length, num_items)``.
    """
    encoded_items = []
    for text in raw_text_iter:
        token_ids = [vocab[tok] for tok in tokenizer(text)]
        encoded_items.append(torch.tensor(token_ids, dtype=torch.long))
    pad_id = vocab["<pad>"]
    return torch.nn.utils.rnn.pad_sequence(encoded_items, padding_value=pad_id)

# data_process() pads with pad_sequence's default time-major layout, so its
# result is (max_len, num_lines). A DataLoader draws samples along dim 0;
# transpose to (num_lines, max_len) so that one sample is one padded line
# rather than one token position across every line of the corpus.
train_data = data_process(train_iter).t()
valid_data = data_process(valid_iter).t()
test_data = data_process(test_iter).t()

BATCH_SIZE = 32
# Ids of the special tokens, looked up once for use by the collate function.
PAD_IDX = vocab["<pad>"]
BOS_IDX = vocab["<bos>"]
EOS_IDX = vocab["<eos>"]

def generate_batch(data_batch, bos_idx=None, eos_idx=None):
    """Collate equal-length 1-D token tensors into a ``(src, tgt)`` pair.

    Args:
        data_batch: Sequence of 1-D tensors that all have the same length L.
        bos_idx: Id prepended to every ``src`` row. Defaults to the
            module-level ``BOS_IDX``.
        eos_idx: Id appended to every ``src`` row. Defaults to the
            module-level ``EOS_IDX``.

    Returns:
        ``(src, tgt)`` where ``src`` has shape ``(B, L + 2)`` and each row is
        ``[bos, *item, eos]``, and ``tgt`` has shape ``(B, L)`` and holds the
        items unchanged.

    Raises:
        RuntimeError: If the items differ in length or the batch is empty
            (raised by ``torch.stack``).
    """
    if bos_idx is None:
        bos_idx = BOS_IDX
    if eos_idx is None:
        eos_idx = EOS_IDX

    # Build the marker tensors once instead of once per item.
    bos = torch.tensor([bos_idx])
    eos = torch.tensor([eos_idx])

    # torch.stack checks that every row has the same length; flattening with
    # cat() and reshaping with view() would silently scramble ragged input
    # whenever the total length happens to be divisible by the batch size.
    src = torch.stack([torch.cat([bos, item, eos]) for item in data_batch])
    tgt = torch.stack(list(data_batch))
    return src, tgt

# The three loaders share batch size and collate function; the training and
# validation loaders shuffle, the test loader keeps the dataset order.
_loader_options = dict(batch_size=BATCH_SIZE, collate_fn=generate_batch)
train_loader = torch.utils.data.DataLoader(train_data, shuffle=True, **_loader_options)
valid_loader = torch.utils.data.DataLoader(valid_data, shuffle=True, **_loader_options)
test_loader = torch.utils.data.DataLoader(test_data, shuffle=False, **_loader_options)



In [14]:
class TransformerModel(nn.Module):
    """Encoder-decoder Transformer over token ids with sinusoidal positions.

    The model is ``nn.Embedding`` -> ``nn.Transformer`` -> ``nn.Linear``.
    ``nn.Transformer`` is constructed with its default ``batch_first=False``,
    so ``forward`` works on sequence-first ``(seq_len, batch)`` index tensors.

    Args:
        ntoken: Vocabulary size (embedding rows and output logits).
        d_model: Embedding / model width.
        nhead: Number of attention heads.
        num_encoder_layers: Number of encoder layers.
        num_decoder_layers: Number of decoder layers.
        dim_feedforward: Hidden width of the feed-forward sublayers.
        max_seq_length: Number of positions in the positional-encoding table;
            longer sequences are rejected by ``forward``.
    """

    def __init__(self, ntoken, d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, max_seq_length):
        super().__init__()
        self.embedding = nn.Embedding(ntoken, d_model)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward)
        self.fc_out = nn.Linear(d_model, ntoken)
        # Registered as a buffer so the table follows the module across
        # .to(device)/.cuda(); non-persistent so state_dict keys are unchanged.
        self.register_buffer(
            'positional_encoding',
            self.generate_positional_encoding(d_model, max_seq_length),
            persistent=False,
        )

    def generate_positional_encoding(self, d_model, max_len):
        """Return the ``(max_len, d_model)`` sinusoidal position table.

        Even columns hold ``sin(pos * w_i)`` and odd columns ``cos(pos * w_i)``
        with ``w_i = 10000 ** (-2i / d_model)``.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column,
        # so trim the frequencies to the number of odd columns.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        return pe

    def _embed(self, tokens):
        """Embed ``(seq_len, batch)`` ids, scale by sqrt(d_model), add positions."""
        seq_len = tokens.size(0)
        max_len = self.positional_encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_length {max_len}"
            )
        scaled = self.embedding(tokens) * math.sqrt(self.embedding.embedding_dim)
        # (seq_len, d_model) -> (seq_len, 1, d_model) to broadcast over batch.
        return scaled + self.positional_encoding[:seq_len].unsqueeze(1)

    def forward(self, src, tgt):
        """Compute next-token logits for ``tgt`` conditioned on ``src``.

        Args:
            src: Long tensor of token ids, shape ``(src_len, batch)``.
            tgt: Long tensor of token ids, shape ``(tgt_len, batch)``.

        Returns:
            Float tensor of logits, shape ``(tgt_len, batch, ntoken)``.

        Raises:
            ValueError: If either sequence is longer than ``max_seq_length``.

        Note:
            Only a causal mask on the target is applied; no padding masks are
            passed to the transformer.
        """
        tgt_len = tgt.size(0)
        # Additive mask: -inf above the diagonal blocks attention to later
        # target positions.
        causal_mask = torch.triu(
            torch.full((tgt_len, tgt_len), float('-inf'), device=tgt.device),
            diagonal=1,
        )
        decoded = self.transformer(self._embed(src), self._embed(tgt), tgt_mask=causal_mask)
        return self.fc_out(decoded)


In [15]:
# Hyperparameters for a deliberately small model; the vocabulary size comes
# from the vocab built earlier.
ntoken = len(vocab)
d_model = 32  # Reduced embedding dimension
nhead = 2  # Reduced number of heads in multihead attention
num_encoder_layers = 1  # Just one encoder layer
num_decoder_layers = 1  # Just one decoder layer
dim_feedforward = 64  # Reduced FFN size
max_seq_length = 100  # Limit sequence length

model = TransformerModel(
    ntoken=ntoken,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    num_decoder_layers=num_decoder_layers,
    dim_feedforward=dim_feedforward,
    max_seq_length=max_seq_length,
)
