In [1]:
import math
import time

import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.utils.data import dataset

from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

torch.use_deterministic_algorithms(True)



In [2]:
def gen_sqr_nxt_mask(size):
    # Create mask and move it to correct device
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    mask = mask.masked_fill(mask == 0, 0.0)  # Add this line
    return mask

class Transformer(nn.Module):
    def __init__(self, num_token, num_inputs, num_heads,
                 num_hidden, num_layers, dropout=0.3):
        super().__init__()
        self.model_name = 'transformer'
        self.position_enc = PosEnc(num_inputs, dropout)
        layers_enc = TransformerEncoderLayer(num_inputs, num_heads, 
                                           num_hidden, dropout,
                                           batch_first=True)  # Note this is True
        self.enc_transformer = TransformerEncoder(layers_enc, num_layers)
        self.enc = nn.Embedding(num_token, num_inputs)
        self.num_inputs = num_inputs
        self.dec = nn.Linear(num_inputs, num_token)
        self.init_params()

    def init_params(self):
        initial_rng = 0.12
        self.enc.weight.data.uniform_(-initial_rng, initial_rng)
        self.dec.bias.data.zero_()
        self.dec.weight.data.uniform_(-initial_rng, initial_rng)

    def forward(self, source, mask_source):
        source = self.enc(source) * math.sqrt(self.num_inputs)
        source = self.position_enc(source)
        op = self.enc_transformer(source, mask_source)
        op = self.dec(op)
        return op

In [3]:
class PosEnc(nn.Module):
    def __init__(self, d_m, dropout=0.2, size_limit=5000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        p_enc = torch.zeros(size_limit, 1, d_m)
        pos = torch.arange(size_limit, dtype=torch.float).unsqueeze(1)
        divider = torch.exp(torch.arange(0, d_m, 2).float() * (-math.log(10_000.0) / d_m))
        p_enc[:, 0, 0::2] = torch.sin(pos * divider)
        p_enc[:, 0, 1::2] = torch.cos(pos * divider)
        self.register_buffer('p_enc', p_enc)

    def forward(self, x):
        return self.dropout(x + self.p_enc[:x.size(0)])

In [4]:
# Use raw text instead of iterator
tr_iter = PennTreebank(split='train')
tr_raw = [' '.join(text) for text in tr_iter]  # Convert to raw text

# Tokenize and build vocabulary
tkzer = get_tokenizer('basic_english')
vocabulary = build_vocab_from_iterator([tkzer(text) for text in tr_raw], specials=['<unk>'])
vocabulary.set_default_index(vocabulary['<unk>'])

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Process data function
def process_data(raw_text):
    numericalised_text = [torch.tensor(vocabulary(tkzer(text)), dtype=torch.long) for text in raw_text]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, numericalised_text)))

# Get all splits and convert to raw text
tr_iter = PennTreebank(split='train')
val_iter = PennTreebank(split='valid')
te_iter = PennTreebank(split='test')

tr_raw = [' '.join(text) for text in tr_iter]
val_raw = [' '.join(text) for text in val_iter]
te_raw = [' '.join(text) for text in te_iter]

# Process all splits
training_text = process_data(tr_raw)
validation_text = process_data(val_raw)
testing_text = process_data(te_raw)

# Batch generation function
def gen_batches(text_dataset, batch_size):
    num_batches = text_dataset.size(0) // batch_size
    text_dataset = text_dataset[:num_batches * batch_size]
    text_dataset = text_dataset.view(batch_size, num_batches).t().contiguous()
    return text_dataset.to(device)

# Generate batches
training_batch_size = 32
evaluation_batch_size = 16
training_data = gen_batches(training_text, training_batch_size)
validation_data = gen_batches(validation_text, evaluation_batch_size)
testing_data = gen_batches(testing_text, evaluation_batch_size)

################################################################################
The 'datapipes', 'dataloader2' modules are deprecated and will be removed in a
future torchdata release! Please see https://github.com/pytorch/data/issues/1196
to learn more and leave feedback.
################################################################################



In [5]:
max_seq_len = 64
def return_batch(src, k): 
    sequence_length = min(max_seq_len, len(src) - 1 - k)
    sequence_data = src[k:k+sequence_length]
    sequence_label = src[k+1:k+1+sequence_length].reshape(-1)
    return sequence_data, sequence_label

In [6]:
num_tokens = len(vocabulary)
embedding_size = 256
num_hidden_params = 256
num_layers = 2
num_heads = 2
dropout = 0.25 
loss_func = nn.CrossEntropyLoss()
lrate = 4.0
transformer_model = Transformer(num_tokens, embedding_size, num_heads,
                                num_hidden_params, num_layers,
                                dropout).to(device)
optim_module = torch.optim.SGD(transformer_model.parameters(), lr=lrate)
sched_module = torch.optim.lr_scheduler.StepLR(optim_module, 1.0, gamma=0.88)

In [7]:
def gen_sqr_nxt_mask(size):
    mask = torch.triu(torch.ones(size, size), diagonal=1)
    mask = mask.masked_fill(mask == 1, float('-inf'))
    mask = mask.masked_fill(mask == 0, 0.0)
    return mask

def train_model():
    transformer_model.train()
    loss_total = 0.
    time_start = time.time()
    
    for b, i in enumerate(range(0, training_data.size(0) - 1, max_seq_len)):
        train_data_batch, train_label_batch = return_batch(training_data, i)
        
        # Use size(0) if 1D tensor, otherwise use size(1) for 2D
        if train_data_batch.dim() == 1:
            sequence_length = train_data_batch.size(0)
        else:
            sequence_length = train_data_batch.size(1)  # For batch_first=True
            
        mask_source = gen_sqr_nxt_mask(sequence_length).to(device)
        
        op = transformer_model(train_data_batch, mask_source)
        loss_curr = loss_func(op.view(-1, num_tokens), train_label_batch)
        optim_module.zero_grad()
        loss_curr.backward()
        torch.nn.utils.clip_grad_norm_(transformer_model.parameters(), 0.6)
        optim_module.step()
        
        loss_total += loss_curr.item()
        interval = 100
        if b % interval == 0 and b > 0:
            loss_interval = loss_total / interval
            time_delta = time.time() - time_start
            print(f"epoch {ep}, {b}/{len(training_data)//max_seq_len} batches, "
                  f"training loss {loss_interval:.2f}, "
                  f"training perplexity {math.exp(loss_interval):.2f}")
            loss_total = 0
            time_start = time.time()

def eval_model(eval_model_obj, eval_data_source):
    eval_model_obj.eval() 
    loss_total = 0.
    mask_source = gen_sqr_nxt_mask(max_seq_len).to(device)
    with torch.no_grad():
        for j in range(0, eval_data_source.size(0) - 1, max_seq_len):
            eval_data, eval_label = return_batch(eval_data_source, j)
            sequence_length = eval_data.size(0)
            if sequence_length != max_seq_len:
                mask_source = mask_source[:sequence_length, :sequence_length]
            op = eval_model_obj(eval_data, mask_source)
            op_flat = op.view(-1, num_tokens)
            loss_total += sequence_length * loss_func(op_flat, eval_label).item()
    return loss_total / (len(eval_data_source) - 1)

In [8]:
min_validation_loss = float("inf")
eps = 5
best_model_so_far = None

for ep in range(1, eps+1):
    ep_time_start = time.time()
    train_model()
    validation_loss = eval_model(transformer_model, validation_data)
    print()
    print(f"epoch {ep}, validation loss {validation_loss:.2f}, validation perplexity {math.exp(validation_loss):.2f}")
    print()

    if validation_loss < min_validation_loss:
        min_validation_loss = validation_loss
        best_model_so_far = transformer_model
    sched_module.step()

epoch 1, 100/2016 batches, training loss 3.77, training perplexity 43.44
epoch 1, 200/2016 batches, training loss 2.61, training perplexity 13.58
epoch 1, 300/2016 batches, training loss 2.58, training perplexity 13.20
epoch 1, 400/2016 batches, training loss 2.58, training perplexity 13.14
epoch 1, 500/2016 batches, training loss 2.56, training perplexity 12.91
epoch 1, 600/2016 batches, training loss 2.56, training perplexity 12.89
epoch 1, 700/2016 batches, training loss 2.55, training perplexity 12.80
epoch 1, 800/2016 batches, training loss 2.55, training perplexity 12.85
epoch 1, 900/2016 batches, training loss 2.55, training perplexity 12.82
epoch 1, 1000/2016 batches, training loss 2.54, training perplexity 12.74
epoch 1, 1100/2016 batches, training loss 2.55, training perplexity 12.77
epoch 1, 1200/2016 batches, training loss 2.54, training perplexity 12.72
epoch 1, 1300/2016 batches, training loss 2.54, training perplexity 12.70
epoch 1, 1400/2016 batches, training loss 2.54,

RuntimeError: shape '[1, 1, 16, 16]' is invalid for input of size 4096

In [None]:
testing_loss = eval_model(best_model_so_far, testing_data)
print(f"testing loss {testing_loss:.2f}, testing perplexity {math.exp(testing_loss):.2f}")