In [1]:
import torch #type: ignore
import torch.nn as nn #type: ignore
import numpy #type: ignore
import math


In [2]:
# NOTE: THIS CLASS IS DEFINITELY CORRECTLY IMPLEMENTED (checked with StatQuest)

class PositionalEncoding(nn.Module):
    def __init__(self):
        super(PositionalEncoding, self).__init__()

    def forward(self, batch_X):
        _, max_sentence_length, d_model = batch_X.shape
        
        positional_encodings = torch.arange(start=0, end=max_sentence_length, dtype=torch.float32).unsqueeze(-1).expand(-1, d_model).clone() # .expand() doesn't create new memory for the duplicated dimension, it uses shared memory --> clone it to not used shared memory
        embedding_dimensions = torch.arange(start=0, end=d_model, step=2, dtype=torch.float32)
        div_factor = torch.tensor(10000) ** (embedding_dimensions / d_model)
        div_factor = div_factor.float()

        positional_encodings[:, 0::2] = torch.sin(positional_encodings[:, 0::2] / div_factor)
        positional_encodings[:, 1::2] = torch.cos(positional_encodings[:, 1::2] / div_factor)
        return batch_X.float() + positional_encodings.float() 



In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, use_mask=False):
        super(MultiHeadAttention, self).__init__()
        self.use_mask = use_mask
        self.num_heads = num_heads
        self.d_k = d_model // num_heads 
        
        self.W_Q = nn.Linear(in_features=d_model, out_features=d_model, dtype=torch.float32, bias=False)
        nn.init.xavier_uniform_(self.W_Q.weight)
        
        self.W_K = nn.Linear(in_features=d_model, out_features=d_model, dtype=torch.float32, bias=False)
        nn.init.xavier_uniform_(self.W_K.weight)
        
        self.W_V = nn.Linear(in_features=d_model, out_features=d_model, dtype=torch.float32, bias=False)
        nn.init.xavier_uniform_(self.W_V.weight)
        
        self.W_O = nn.Linear(in_features=d_model, out_features=d_model, dtype=torch.float32, bias=False)
        nn.init.xavier_uniform_(self.W_O.weight)

    def create_mask(self, batch_size, sentence_length, padding_mask, use_attention_mask):
        padding_mask = padding_mask.unsqueeze(-1).expand(-1, sentence_length, sentence_length).float()
        if use_attention_mask:
            causal_mask = torch.tril(torch.ones(sentence_length, sentence_length, dtype=torch.float32)).unsqueeze(0).expand(batch_size, sentence_length, sentence_length).float()
            combined_mask = torch.min(padding_mask, causal_mask).float()
        else:
            combined_mask = padding_mask.float()
        return combined_mask == 0

    def forward(self, batch_X, padding_mask, encoder_output=None):
        batch_X = batch_X.float()
        batch_size, sentence_length, d_model = batch_X.shape
        Q = self.W_Q(batch_X).reshape(batch_size, sentence_length, self.num_heads, self.d_k).permute(0, 2, 1, 3).float()
        if encoder_output is not None:
            K = self.W_K(encoder_output).reshape(batch_size, sentence_length, self.num_heads, self.d_k).permute(0, 2, 1, 3).float()
            V = self.W_V(encoder_output).reshape(batch_size, sentence_length, self.num_heads, self.d_k).permute(0, 2, 1, 3).float()
        else:
            K = self.W_K(batch_X).reshape(batch_size, sentence_length, self.num_heads, self.d_k).permute(0, 2, 1, 3).float()
            V = self.W_V(batch_X).reshape(batch_size, sentence_length, self.num_heads, self.d_k).permute(0, 2, 1, 3).float()

        # torch.matmul() performs the matrix multiplication over the last 2 dimensions, broadcasting all the others
        mask = self.create_mask(batch_size, sentence_length, padding_mask.float(), self.use_mask).unsqueeze(1).expand(batch_size, self.num_heads, sentence_length, sentence_length)
        attention_scores = (torch.matmul(Q, K.permute(0, 1, 3, 2)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float)))

        attention_scores = attention_scores.masked_fill(mask, float('-inf')).float()
        scaled_attention_scores = nn.functional.softmax(attention_scores, dim=-1).float()
        
        all_nan_rows_mask = torch.all(mask, dim=-1, keepdim=True) # (batch_size, num_heads, sentence_length, 1)
        scaled_attention_scores = scaled_attention_scores.masked_fill(all_nan_rows_mask, 0.0)
        
        scaled_dot_product_attention = torch.matmul(scaled_attention_scores, V).float() # shape = (batch_size, num_heads, sentence_length, d_v)
        
        # Concatenate all the heads
        scaled_dot_product_attention = scaled_dot_product_attention.permute(0, 2, 1, 3).reshape(batch_size, sentence_length, d_model).float()
        
        return self.W_O(scaled_dot_product_attention) # shape = (batch_size, sentence_length, d_model)
    

In [4]:
class FFN(nn.Module):
    def __init__(self, d_model, dropout_rate, activation=nn.ReLU()):
        super(FFN, self).__init__()
        self.activation_function = activation
        d_ff = d_model * 4
    
        self.linear1 = nn.Linear(in_features=d_model, out_features=d_ff, dtype=torch.float32)
        nn.init.kaiming_uniform_(self.linear1.weight)
        nn.init.zeros_(self.linear1.bias)

        self.linear2 = nn.Linear(in_features=d_ff, out_features=d_model, dtype=torch.float32)
        nn.init.kaiming_uniform_(self.linear2.weight)
        nn.init.zeros_(self.linear2.bias)

        self.dropout = nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, batch_X):
        batch_X = self.dropout(self.activation(self.linear1(batch_X.float())))
        return self.linear2(batch_X.float())

In [27]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate):
        super(EncoderLayer, self).__init__()
        self.mha = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = FFN(d_model=d_model, dropout_rate=dropout_rate)
        self.layer_norm1 = nn.LayerNorm(normalized_shape=d_model, dtype=torch.float32, eps=1e-6)
        self.layer_norm2 = nn.LayerNorm(normalized_shape=d_model, dtype=torch.float32, eps=1e-6)
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, batch_X, padding_mask):
        batch_X = batch_X.float() + self.dropout(self.layer_norm1(self.mha(batch_X, padding_mask)))
        batch_X = batch_X.float() + self.dropout(self.layer_norm2(self.ffn(batch_X)))
        return batch_X.float()

In [6]:
class Encoder(nn.Module):
    def __init__(self, d_model, num_layers, num_heads, dropout_rate):
        super(Encoder, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList([EncoderLayer(d_model=d_model, num_heads=num_heads, dropout_rate=dropout_rate) for _ in range(num_layers)])
        self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)

    def forward(self, batch_X, padding_mask):
        for encoder_layer in self.layers:
            batch_X = encoder_layer(batch_X, padding_mask)
        return self.layer_norm(batch_X)


In [26]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, dropout_rate):
        super(DecoderLayer, self).__init__()
        self.mha1 = MultiHeadAttention(d_model=d_model, num_heads=num_heads, use_mask=True)
        self.layernorm1 = nn.LayerNorm(normalized_shape=d_model, dtype=torch.float32)
        self.mha2 = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.layernorm2 = nn.LayerNorm(normalized_shape=d_model, dtype=torch.float32)
        self.ffn = FFN(d_model=d_model, dropout_rate=dropout_rate)
        self.layernorm3 = nn.LayerNorm(normalized_shape=d_model, dtype=torch.float32)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, batch_X, encoder_output, padding_mask):
        batch_X = batch_X.float() + self.dropout(self.layernorm1(self.mha1(batch_X, padding_mask)))
        batch_X = batch_X.float() + self.dropout(self.layernorm2(self.mha2(batch_X, padding_mask, encoder_output)))
        batch_X = batch_X.float() + self.dropout(self.layernorm3(self.ffn(batch_X)))
        return batch_X.float()

In [9]:
class Decoder(nn.Module):
    def __init__(self, d_model, num_layers, num_heads, dropout_rate):
        super(Decoder, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList([DecoderLayer(d_model=d_model, num_heads=num_heads, dropout_rate=dropout_rate) for _ in range(num_layers)])

    def forward(self, batch_X, encoder_output, padding_mask):
        for decoder_layer in self.layers:
            batch_X = decoder_layer(batch_X.float(), encoder_output, padding_mask)
        return batch_X.float()

In [10]:
class Transformer(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, dropout_rate):
        super(Transformer, self).__init__()
        self.positional_encoding = PositionalEncoding()
        
        self.encoder_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, dtype=torch.float32)
        self.encoder = Encoder(
            d_model=d_model, 
            num_layers=num_layers, 
            num_heads=num_heads, 
            dropout_rate=dropout_rate
        )
        
        self.decoder_embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_model, dtype=torch.float32)
        self.decoder = Decoder(
            d_model=d_model, 
            num_layers=num_layers, 
            num_heads=num_heads, 
            dropout_rate=dropout_rate
        )
        
        self.linear = nn.Linear(in_features=d_model, out_features=vocab_size, dtype=torch.float32)
        nn.init.xavier_uniform_(self.linear.weight, gain=0.1)
        nn.init.zeros_(self.linear.bias)
        
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, encoder_input, shifted_decoder_input, encoder_padding_masks, decoder_padding_masks):
        embedded_encoder_input = self.encoder_embedding(encoder_input).float()
        embedded_decoder_input = self.decoder_embedding(shifted_decoder_input).float()
        
        encoder_output = self.encoder(
            self.positional_encoding(embedded_encoder_input),
            encoder_padding_masks
        ).float()

        decoder_output = self.decoder(
            self.positional_encoding(embedded_decoder_input),
            encoder_output, 
            decoder_padding_masks
        ).float()
        
        logits = self.linear(decoder_output).float()
        output_probabilities = self.softmax(logits).float()
        
        return output_probabilities


In [43]:
class TransformerLoss(nn.Module):
    def __init__(self):
        super(TransformerLoss, self).__init__()

    def forward(self, decoder_output, target_sequences, padding_vocab_index):
        # decoder_output has shape (batch_size, sentence_length, vocab_size)
        # target_sequences has shape (batch_size, sentence_length)
        # for each training example, each of the vocab_size positions in each row
            # has a corresponding probability of being selected, and each corresponding row in the target
            # will have a value equal to the correct position representing a word in the vocabulary
        # print(target_sequences)
        batch_size, sentence_length, vocab_size = decoder_output.shape

        flattened_decoder_output = decoder_output.reshape(batch_size * sentence_length, vocab_size) 
        flattened_target_sequences = target_sequences.reshape(batch_size * sentence_length)
        
        # print("Decoder Output Shape:", decoder_output.shape)
        # print("Target Sequences Shape:", target_sequences.shape)
        # print("Flattened Decoder Output Shape:", flattened_decoder_output.shape)
        # print("Flattened Target Sequences Shape:", flattened_target_sequences.shape)


        return nn.functional.cross_entropy(input=flattened_decoder_output, 
                                           target=flattened_target_sequences, 
                                           reduction='mean',
                                           ignore_index=padding_vocab_index)


In [13]:
from transformers import AutoTokenizer # type: ignore 
from datasets import load_dataset # type: ignore

wmt_dataset = load_dataset('iwslt2017', 'iwslt2017-fr-en')
tokenizer = AutoTokenizer.from_pretrained("gpt2")


In [None]:


if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})

def tokenize(examples):
    english_examples = [example['en'] for example in examples['translation']]
    french_examples = [example['fr'] for example in examples['translation']]
    
    english_examples = tokenizer(english_examples, padding='max_length', truncation=True, max_length=128)
    french_examples = tokenizer(french_examples, padding='max_length', truncation=True, max_length=128)
    return {
        # all of these should have shape (batch_size, max_length)
        'input_token_ids': french_examples['input_ids'], 
        'encoder_attention_mask': french_examples['attention_mask'], # mask for padded sequences
        'decoder_attention_mask': english_examples['attention_mask'],
        'labels': english_examples['input_ids']
    }

tokenized_datasets = wmt_dataset.map(tokenize, batched=True)
print(tokenizer.vocab_size)





50257


In [17]:
print(tokenized_datasets['train']['input_token_ids'][1])


[41, 6, 1872, 220, 25125, 2634, 491, 14064, 82, 10647, 77, 2634, 1582, 269, 5857, 1013, 2634, 6784, 11, 2123, 11223, 256, 10465, 28141, 410, 516, 302, 647, 66, 959, 256, 516, 12797, 410, 418, 299, 2381, 260, 2821, 2123, 23860, 6368, 2912, 17693, 969, 2906, 8358, 474, 6, 1872, 288, 270, 300, 6, 2306, 260, 523, 343, 13, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257, 50257]


In [18]:
print(tokenizer.decode(tokenized_datasets['train']['input_token_ids'][1]))

J'ai été très impressionné par cette conférence, et je tiens à vous remercier tous pour vos nombreux et sympathiques commentaires sur ce que j'ai dit l'autre soir.[PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]


In [19]:
print(tokenizer.decode(tokenized_datasets['train']['labels'][1]))

I have been blown away by this conference, and I want to thank all of you for the many nice comments about what I had to say the other night.[PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]


In [26]:
print(tokenizer.vocab_size)

50257


In [48]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
PIN_MEMORY = True if torch.cuda.is_available() else False
VOCAB_SIZE = tokenizer.vocab_size + 1 # add 1 because you manually added padding token, increasing vocab size
BATCH_SIZE = 1
D_MODEL = 512
NUM_HEADS = 8
NUM_LAYERS = 6
DROPOUT_RATE = 0.1
NUM_WORKERS = 8
PREFETCH_FACTOR = 2
PERSISTENT_WORKERS = True
SHUFFLE = True
NUM_EPOCHS = 15
WARMUP_STEPS = 4000

In [40]:
def learning_rate_lambda_function(step_number):
    return (D_MODEL ** (-0.5)) * min((step_number + 1) ** (-0.5), (step_number + 1) * (WARMUP_STEPS ** (-1.5)))

def collate_function(batch):
    input_ids = torch.tensor([example['input_token_ids'] for example in batch])
    encoder_attention_masks = torch.tensor([example['encoder_attention_mask'] for example in batch])
    decoder_attention_masks = torch.tensor([example['decoder_attention_mask'] for example in batch])
    output_labels = torch.tensor([example['labels'] for example in batch])

    return {
        'input_token_ids': input_ids,
        'encoder_attention_masks': encoder_attention_masks,
        'decoder_attention_masks': decoder_attention_masks,
        'output_labels': output_labels
    }

def generate_overfit(model, encoder_inputs, encoder_padding_masks, start_token, max_length, device):
    model.eval()
    decoder_inputs = torch.tensor([start_token]).unsqueeze(0).expand(batch_size) # (batch_size, sentence_length=1)
    with torch.no_grad():
        for _ in range(max_length):
            batch_size = encoder_inputs.shape[0]
            curr_sentence_length = decoder_inputs.shape[1]
            decoder_padding_masks = torch.ones(batch_size, curr_sentence_length)
            outputs = model(encoder_inputs, decoder_inputs, encoder_padding_masks, decoder_padding_masks) # (batch_size, sentence_length, vocab_size)
            next_tokens = torch.argmax(outputs, dim=-1)[:, -1].unsqueeze(-1) # argmax returns index of largest probability, which is exactly what we want --> (batch_size, sentence_length) --> (batch_size, 1))
            decoder_inputs = torch.cat([decoder_inputs, next_tokens], dim=-1)

    return decoder_inputs


In [49]:
from torch.utils.data import DataLoader # type: ignore
from torch.optim import Adam # type: ignore
from torch.optim.lr_scheduler import LambdaLR # type: ignore

num_epochs = NUM_EPOCHS
model = Transformer(
    vocab_size=VOCAB_SIZE,
    d_model=256,  # Reduced from 512
    num_heads=8,
    num_layers=3,  # Reduced from 6
    dropout_rate=0.1
)

model.to(DEVICE)
model.float() # convert all model parameters to torch.float32
loss_function = TransformerLoss()
parameters = model.parameters()
optimizer = Adam(model.parameters(), lr=1e-3)
#scheduler = LambdaLR(optimizer, learning_rate_lambda_function)

small_loader = DataLoader(
    tokenized_datasets['train'].select(range(1)),  # Just one example
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=collate_function
)
  
MIN_GRAD_NORM = 1e-5

for epoch in range(100):
    model.train()
    print("CURRENT EPOCH: " + str(epoch))
    for batch_index, batch in enumerate(small_loader):
        optimizer.zero_grad()
        
        # Move everything to device and set correct dtypes
        input_token_ids = batch['input_token_ids'].long().to(DEVICE)
        encoder_padding_masks = batch['encoder_attention_masks'].to(DEVICE)
        decoder_padding_masks = batch['decoder_attention_masks'].to(DEVICE)
        output_labels = batch['output_labels'].long().to(DEVICE)

        # Create start tokens
        start_token_batch = torch.full((BATCH_SIZE, 1), 
                                     tokenizer.bos_token_id,
                                     dtype=torch.long,
                                     device=DEVICE)
        shifted_output_labels = torch.cat([start_token_batch, output_labels[:, :-1]], dim=-1)

        decoder_outputs = model(input_token_ids, shifted_output_labels, encoder_padding_masks, decoder_padding_masks)
        # Compute loss
        loss = loss_function(decoder_outputs, output_labels, tokenizer.pad_token_id)
        print(f"Loss value: {loss.item()}")
        loss.backward()

        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        if grad_norm < MIN_GRAD_NORM:
            scale_factor = MIN_GRAD_NORM / grad_norm
            for param in model.parameters():
                if param.grad is not None:
                    param.grad.data.mul_(scale_factor)

        optimizer.step()


CURRENT EPOCH: 0
Loss value: 10.824926376342773
CURRENT EPOCH: 1
Loss value: 10.824921607971191
CURRENT EPOCH: 2
Loss value: 10.824841499328613
CURRENT EPOCH: 3
Loss value: 10.824027061462402
CURRENT EPOCH: 4
Loss value: 10.815892219543457
CURRENT EPOCH: 5
Loss value: 10.785075187683105
CURRENT EPOCH: 6
Loss value: 10.766366958618164
CURRENT EPOCH: 7
Loss value: 10.743791580200195
CURRENT EPOCH: 8
Loss value: 10.723159790039062
CURRENT EPOCH: 9
Loss value: 10.716070175170898
CURRENT EPOCH: 10
Loss value: 10.692280769348145
CURRENT EPOCH: 11
Loss value: 10.654926300048828
CURRENT EPOCH: 12
Loss value: 10.60863208770752


KeyboardInterrupt: 

In [159]:
train_loader = DataLoader(
    tokenized_datasets['train'], 
    batch_size=BATCH_SIZE, 
    shuffle=SHUFFLE, 
    collate_fn=collate_function, 
    pin_memory=PIN_MEMORY, 
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
    persistent_workers=PERSISTENT_WORKERS
)
validation_loader = DataLoader(
    tokenized_datasets['validation'], 
    batch_size=BATCH_SIZE, 
    shuffle=SHUFFLE, 
    collate_fn=collate_function, 
    pin_memory=PIN_MEMORY, 
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
    persistent_workers=PERSISTENT_WORKERS
)
test_loader = DataLoader(
    tokenized_datasets['test'], 
    batch_size=BATCH_SIZE, 
    shuffle=SHUFFLE, 
    collate_fn=collate_function, 
    pin_memory=PIN_MEMORY, 
    num_workers=NUM_WORKERS,
    prefetch_factor=PREFETCH_FACTOR,
    persistent_workers=PERSISTENT_WORKERS
)

for epoch in range(num_epochs):
    model.train()
    for batch_index, batch in enumerate(train_loader):
        optimizer.zero_grad()
        input_token_ids = batch['input_token_ids'].to(DEVICE)
        attention_masks = batch['attention_masks'].to(DEVICE)
        output_labels = batch['output_labels'].to(DEVICE)

        decoder_outputs = model(input_token_ids, output_labels[:, :-1], attention_masks)
        loss = loss_function(decoder_outputs, output_labels)
        print("training batch loss: " + str(loss))
        loss.backward()
        optimizer.step()
        scheduler.step()

KeyboardInterrupt: 