In [None]:
#!/usr/bin/env python3

from google.colab import drive
drive.mount('/content/drive')

import csv

import torch
import torch.nn as nn
import torch.optim as optim
import math
from tokenizers import Tokenizer, models, pre_tokenizers, decoders
from tqdm import tqdm

class PositionalEncoding(torch.nn.Module):
    def __init__(self, embed_dim: int, max_len=5000):
        super(PositionalEncoding, self).__init__()
        enc = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(math.log(10000) / embed_dim))
        enc[:, 0::2] = torch.sin(position * div_term)
        enc[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('enc', enc.unsqueeze(0))

    def forward(self, x):
        return x + self.enc[:, : x.size(1)]

class Embedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int):
        super(Embedding, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = embed_dim

        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)

    def forward(self, x):
        W_e = self.token_embedding(x)
        W_pe = self.positional_encoding(W_e)
        return W_pe

class FeedForward(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int):
        super(FeedForward, self).__init__()
        self.linear1 = torch.nn.Linear(embed_dim, ff_dim)
        self.linear2 = torch.nn.Linear(ff_dim, embed_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super(MaskedMultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads

        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        Q = self.q(query)
        K = self.k(key)
        V = self.v(value)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.embed_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_wights = torch.nn.functional.softmax(scores, dim=-1)
        attn_output = attn_wights @ V
        output = self.out(attn_output)
        return output

class Encoder(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, num_heads: int = 1):
        super(Encoder, self).__init__()
        self.self_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.feed_forward = FeedForward(embed_dim, ff_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.layer_norm3 = nn.LayerNorm(embed_dim)

    def forward(self, src_embs):
        emb_norm = self.layer_norm1(src_embs)
        attn_output = self.self_attn(emb_norm, emb_norm, emb_norm)
        attn = src_embs + attn_output

        attn_norm = self.layer_norm2(attn)
        ff_output = self.feed_forward(attn_norm)
        ff = attn + ff_output

        return self.layer_norm3(ff)

class Decoder(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, num_heads: int = 1):
        super(Decoder, self).__init__()
        self.self_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.cross_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.feed_forward = FeedForward(embed_dim, ff_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.layer_norm3 = nn.LayerNorm(embed_dim)
        self.layer_norm4 = nn.LayerNorm(embed_dim)

    def forward(self, src_encs, tgt_embs):
        seq_len, device = tgt_embs.size(1), tgt_embs.device
        causal_mask = torch.tril(torch.ones((1, seq_len, seq_len), device=device)).bool()

        tgt_embs_norm = self.layer_norm1(tgt_embs)
        self_attn_output = self.self_attn(tgt_embs_norm, tgt_embs_norm, tgt_embs_norm, mask=causal_mask)
        attn = tgt_embs + self_attn_output

        attn_norm = self.layer_norm2(attn)
        cross_attn_output = self.cross_attn(attn_norm, src_encs, src_encs)
        cross_attn = attn + cross_attn_output

        cross_attn_norm = self.layer_norm3(cross_attn)
        ff_output = self.feed_forward(cross_attn_norm)
        ff = cross_attn + ff_output

        return ff

class Model(nn.Module):
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int, embed_dim: int, ff_dim: int,
                 num_encoder_layers: int = 1, num_decoder_layers: int = 1, num_heads: int = 1):
        super(Model, self).__init__()

        self.src_embedding = Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = Embedding(tgt_vocab_size, embed_dim)

        # Stacked encoders and decoders
        self.encoders = nn.ModuleList([Encoder(embed_dim, ff_dim, num_heads=num_heads)
                                       for _ in range(num_encoder_layers)])
        self.decoders = nn.ModuleList([Decoder(embed_dim, ff_dim, num_heads=num_heads)
                                       for _ in range(num_decoder_layers)])

        self.out_proj = torch.nn.Linear(embed_dim, tgt_vocab_size)

    def encode(self, src_seq):
        src_embs = self.src_embedding(src_seq)
        # pass through stacked encoders
        enc = src_embs
        for layer in self.encoders:
            enc = layer(enc)
        return enc

    def decode(self, src_encs, tgt_seq):
        tgt_embs = self.tgt_embedding(tgt_seq)
        dec = tgt_embs
        for layer in self.decoders:
            dec = layer(src_encs, dec)
        return dec

    def forward(self, src_seq, tgt_seq):
        src_encs = self.encode(src_seq)
        tgt_encs = self.decode(src_encs, tgt_seq)
        output = self.out_proj(tgt_encs)
        return output






############################################################################################
# Main function to train and test the model
############################################################################################






def main():
    text = []
    scrambles = []
    firstline = 1
    line_count = 0

    # Read in data
    with open('/content/drive/My Drive/Colab Notebooks/basic_processed_2.csv', newline='') as csvfile:
        lines = csv.reader(csvfile, delimiter=',')
        for line in lines:
            if firstline:
                firstline = 0
                continue
            text.append(line[0])
            scrambles.append(line[1])
            line_count += 1

    cap_point = int(0.1 * line_count)
    split_point = int(0.8 * cap_point)

    training_text = text[:split_point]
    training_scrambles = scrambles[:split_point]
    validation_text = text[split_point:cap_point]
    validation_scrambles = scrambles[split_point:cap_point]

    # Tokenizer setup
    char_list = list(" abcdefghijklmnopqrstuvwxyz,.!?;:'\"-\n")
    special_tokens = ["[UNK]"]
    final_vocab_list = special_tokens + list(set(char_list))
    vocab_dict = {token: i for i, token in enumerate(final_vocab_list)}

    tokenizer = Tokenizer(models.WordLevel(vocab=vocab_dict, unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.Split(pattern="", behavior="isolated")
    tokenizer.decoder = decoders.ByteLevel()

    # Model parameters
    src_vocab_size = len(final_vocab_list)
    tgt_vocab_size = len(final_vocab_list)
    embed_dim = 128
    ff_dim = 512
    stack = 8
    encoders_num = stack
    decoders_num = stack
    num_heads = 4

    # Training setup
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = Model(src_vocab_size, tgt_vocab_size, embed_dim, ff_dim, encoders_num, decoders_num, num_heads).to(device)
    loss = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print(f"Tranformer stacks: {stack}")
    for epoch in range(10):
        model.train()
        total_loss = 0.0
        total_tokens = 0

        train_iterator = tqdm(range(len(training_text)),
                                    desc=f"Epoch {epoch+1}/{10}",
                                    leave=False)

        for i in train_iterator:
            src_text = training_scrambles[i]
            tgt_text = training_text[i]

            src_tokens = tokenizer.encode(src_text).ids
            tgt_tokens = tokenizer.encode(tgt_text).ids

            src_seq = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(device)
            tgt_seq = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(device)

            optimizer.zero_grad()
            output = model(src_seq, tgt_seq[:, :-1])
            output_dim = output.shape[-1]
            output = output.view(-1, tgt_vocab_size)
            output = output.contiguous().view(-1, output_dim)
            labels = tgt_seq[:, 1:].contiguous().view(-1)
            sentence_loss = loss(output, labels)

            sentence_loss = loss(output, labels)
            sentence_loss.backward()
            optimizer.step()

            total_loss += sentence_loss.item() * labels.size(0)
            total_tokens += labels.size(0)
        avg_loss = total_loss / total_tokens
        print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}')

    # Save the model
    save_path = '/content/drive/MyDrive/transformer_model_final_basic.pth'
    torch.save(model.state_dict(), save_path)

    # Load the model
    state_dict = torch.load(f'/content/drive/MyDrive/transformer_model_final_basic.pth', map_location=device)
    model.load_state_dict(state_dict)

    # Test the model
    model.eval()
    total_correct_tokens = 0
    total_chars = 0
    with torch.no_grad():
        for i in range(len(validation_text)):
            src_text = validation_scrambles[i]
            tgt_text = validation_text[i]

            src_tokens = tokenizer.encode(src_text).ids
            tgt_tokens = tokenizer.encode(tgt_text).ids

            src_seq = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(device)
            tgt_seq = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(device)

            output = model(src_seq, tgt_seq[:, :-1])
            output_dim = output.shape[-1]
            output_flat = output.view(-1, tgt_vocab_size)
            output_flat = output_flat.contiguous().view(-1, output_dim)
            tgt_tensor = tgt_seq[:, 1:].contiguous().view(-1)

            sentence_loss = loss(output_flat, tgt_tensor)

            #print(f'Test Sample {i+1}, Loss: {sentence_loss.item():.4f}')


            preds = torch.argmax(output, dim=2)

            correct = (preds == tgt_tensor).sum().item()
            total = tgt_tensor.numel()

            accuracy = correct / total
            total_correct_tokens += correct
            total_chars += total
            #print(f'Sample {i+1} Token Accuracy (Teacher-Forced): {accuracy:.4f} ({correct}/{total})')

        if total_chars > 0:
            avg_accuracy = total_correct_tokens / total_chars
        else:
            avg_accuracy = 0.0
            print(f'No characters to evaluate. {total_chars} characters in total. {total_correct_tokens} correct tokens.')
        print(f'Overall Token Accuracy: {avg_accuracy:.4f} ({total_correct_tokens}/{total_chars})')
if __name__ == '__main__':
    main()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Tranformer stacks: 8




Epoch 1, Loss: 2.2016




Epoch 2, Loss: 1.7147




Epoch 3, Loss: 1.5716




Epoch 4, Loss: 1.4981




Epoch 5, Loss: 1.4481




Epoch 6, Loss: 1.4093




Epoch 7, Loss: 1.3769




Epoch 8, Loss: 1.3487




Epoch 9, Loss: 1.3253




Epoch 10, Loss: 1.3033
Overall Token Accuracy: 0.5563 (151885/273046)


In [1]:
import torch
import torch.nn as nn
import math
from tokenizers import Tokenizer, models, pre_tokenizers, decoders
import torch.nn.functional as F

from google.colab import drive
drive.mount('/content/drive')

class PositionalEncoding(torch.nn.Module):
    def __init__(self, embed_dim: int, max_len=5000):
        super(PositionalEncoding, self).__init__()
        enc = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(math.log(10000) / embed_dim))
        enc[:, 0::2] = torch.sin(position * div_term)
        enc[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('enc', enc.unsqueeze(0))

    def forward(self, x):
        return x + self.enc[:, : x.size(1)]

class Embedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int):
        super(Embedding, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = embed_dim
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)

    def forward(self, x):
        W_e = self.token_embedding(x)
        W_pe = self.positional_encoding(W_e)
        return W_pe

class FeedForward(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int):
        super(FeedForward, self).__init__()
        self.linear1 = torch.nn.Linear(embed_dim, ff_dim)
        self.linear2 = torch.nn.Linear(ff_dim, embed_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super(MaskedMultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads

        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        Q = self.q(query)
        K = self.k(key)
        V = self.v(value)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.embed_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_wights = torch.nn.functional.softmax(scores, dim=-1)
        attn_output = attn_wights @ V
        output = self.out(attn_output)
        return output

class Encoder(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, num_heads: int = 1):
        super(Encoder, self).__init__()
        self.self_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.feed_forward = FeedForward(embed_dim, ff_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.layer_norm3 = nn.LayerNorm(embed_dim)

    def forward(self, src_embs):
        emb_norm = self.layer_norm1(src_embs)
        attn_output = self.self_attn(emb_norm, emb_norm, emb_norm)
        attn = src_embs + attn_output

        attn_norm = self.layer_norm2(attn)
        ff_output = self.feed_forward(attn_norm)
        ff = attn + ff_output

        return self.layer_norm3(ff)

class Decoder(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, num_heads: int = 1):
        super(Decoder, self).__init__()
        self.self_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.cross_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.feed_forward = FeedForward(embed_dim, ff_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.layer_norm3 = nn.LayerNorm(embed_dim)
        self.layer_norm4 = nn.LayerNorm(embed_dim)

    def forward(self, src_encs, tgt_embs):
        seq_len, device = tgt_embs.size(1), tgt_embs.device
        causal_mask = torch.tril(torch.ones((1, seq_len, seq_len), device=device)).bool()

        tgt_embs_norm = self.layer_norm1(tgt_embs)
        self_attn_output = self.self_attn(tgt_embs_norm, tgt_embs_norm, tgt_embs_norm, mask=causal_mask)
        attn = tgt_embs + self_attn_output

        attn_norm = self.layer_norm2(attn)
        cross_attn_output = self.cross_attn(attn_norm, src_encs, src_encs)
        cross_attn = attn + cross_attn_output

        cross_attn_norm = self.layer_norm3(cross_attn)
        ff_output = self.feed_forward(cross_attn_norm)
        ff = cross_attn + ff_output

        return ff

class Model(nn.Module):
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int, embed_dim: int, ff_dim: int,
                 num_encoder_layers: int = 1, num_decoder_layers: int = 1, num_heads: int = 1):
        super(Model, self).__init__()

        self.src_embedding = Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = Embedding(tgt_vocab_size, embed_dim)

        self.encoders = nn.ModuleList([Encoder(embed_dim, ff_dim, num_heads=num_heads)
                                       for _ in range(num_encoder_layers)])
        self.decoders = nn.ModuleList([Decoder(embed_dim, ff_dim, num_heads=num_heads)
                                       for _ in range(num_decoder_layers)])

        self.out_proj = torch.nn.Linear(embed_dim, tgt_vocab_size)

    def encode(self, src_seq):
        src_embs = self.src_embedding(src_seq)
        enc = src_embs
        for layer in self.encoders:
            enc = layer(enc)
        return enc

    def decode(self, src_encs, tgt_seq):
        tgt_embs = self.tgt_embedding(tgt_seq)
        dec = tgt_embs
        for layer in self.decoders:
            dec = layer(src_encs, dec)
        return dec

    def forward(self, src_seq, tgt_seq):
        src_encs = self.encode(src_seq)
        tgt_encs = self.decode(src_encs, tgt_seq)
        output = self.out_proj(tgt_encs)
        return output

TEST_EXAMPLES = [
    {
        "src": "MBVUYRIDZNYTGGHDPJOBBODSURIANBDVUPYIHZTOCFGNGHODFCZIXFOMWKVGRNLQKXWJOJLFHTGCQWFZZEEJCSMITFXMTVRXEBTDZEUTJFKEGFYVSDPMLSZKOFGZEIZDKJBVXEKYZHGLBMOYOJUTHPNWSWHTGRVFKEAMOALBUMXTAXFBVSGFERBYYPLLYAHBHZDIXZZQZYHTBESKCMMMGREGBLEMPYVMFAIOOLIOLUWMFKLAHGQNZYMWQTPGSNBJMVSTLXBZFYLGZFZESNDUNGDRKGLMSJQUMOVBJCVTMGGGJTGBQBBVQZUQSAWZKLQHNIUBQIGICFVA",
        "tgt": "stand ho give the word ho and stand what now lucilius is cassius near he is at hand and pindarus is come to do you salutation from his master pindarus gives a letter to brutus he greets me well your master pindarus in his own change or by ill officers hath given me some worthy cause to wish things done undone but if he be at hand "
    },
    {
        "src": "UBSPCFKAXFKDTMGNKUZMLZPYPHMFGWWBJXNOBRLOASVOBFBCHCRGESJPPUAASBLKYVJAAWFJFWOHLZQFUVLEICTQQZVOSTABBXXBLRUUVNMXEBVPSIHZVCFCWDWQJBTLDOJXJMUFMIQURPBSFEJQXTKLTDQVMJJLMCFGSAOBQGOFWVNDAKEPOEISLDFAFCKADFNKLJUQPMVNRQQBHJCIEKOXHIEWMLLKEXPFRCXDWGFQBDBGSUPQFPNJOSWZNJFXFYIBQPZJHXBOSFXSDOLTTYVYFVKCCLNOUBQNXQNHYQIJJXYXZJMMEQXFJRJTKHJJXMPASNNWUPHEPDPRPRYYEUURPOUPENHIXDLLIHOUZERLNJWVSMCDXMPNLPEFZSPOUCKGLKEMHWZMANMNLHZOQMTUESURAAWFQZJPETZNKTEWOSXBBOMUHTRACGAKUMYWXAVUKFNCLDWBFJTVFORPHNBTTUQQYXLKYQYKUSETJUSOWOJQODASZHFCUWXDRPOKZQTRWTTOIOQPEHAQNOMNUXHMJIABGHHRVIFZMWCFZXPUGYKVWPCHANEJYWBUKPOMYGFZZUMODOFUJHJJJJWHFMQXZJGZNASLNSLAMUKVZRIQDWCPEEWAVIFPCFNRPBVWCCPBGQZCRQALHAADQMQXNKQHQEJQDIVEBOU",
        "tgt": "any moment get angrythat at his slightest inattention she trembled became flustered and heated raised her voice and sometimes pulled him by the arm and put him in the corner having put him in the corner she would herself begin to cry over her cruel evil nature and little nicholas following her example would sob and without permission would leave his corner come to her pull her wet hands from her face and comfort her but what distressed the princess most of all was her fathers irritability which was always directed against her and had of late amounted to cruelty had he forced her to prostrate herself to the ground all night had he beaten her or made her fetch wood or water "
    },
    {
        "src": "HQYIHKMFTNVUAZKOEOMBMATGGMIWWFCYNCYZLPIUKCBNILZQGZDZGNOSNFOLMRQRUBFLQJRQTEPNQGGZCRHQEGHUEDJGYFWJIWDYVSRTIQUGHSKRULSGMLPZMKKTZYVKUOZLKXIVQLSAFKIATHPYWQRQOCUUGGZGFITQMAFQFGXHSFUZYKCBQCPDNWBFQJWYWDKMHLNLKHJXPELHNWFGQOVIHXDYAAJWWBHKWNCJHYAAJHWTOHIAOWDRRGASYSHIDLZLULDPJUMFEQUPMDTPVRCHUNXPSUIHXVFIVCHEYWOTOUFJVMQYDWRFNYEIQNPXW",
        "tgt": "as clears her from all blame my curses on her o sir you are old nature in you stands on the very verge of her confine you should be ruld and led by some discretion that discerns your state better than you yourself therefore i pray you that to our sister you do make return say you have wrongd her sir ask her forgiveness "
    },
    {
        "src": "KCMZAWAFTMWUAMQLLKPBCUYGRQRMRDJFOJMFUQBVPFQAEQCJQKPZXZZYERMFECZFZUKTJZNZEKMWVRADPPMLAIHXMKZRBYVENCESFPMLSXQQJRQEOVJIDZTTYHRQOXOIRTFJMBLTMKOHNIFOGSKAYLKKJKRLUYYBVGFOSENTVNOHXXHNHUWQFCLZLHVZIKUWOUEQUAHGJFPRPODOZGGUGVQPRMKFUFDQMMDJVWNADTGASQNAGMDZBDJLULSMEDNYZQ",
        "tgt": "bolingbroke name it fair cousin king richard fair cousin i am greater than a king for when i was a king my flatterers were then but subjects being now a subject i have a king here to my flatterer being so great i have no need to beg bolingbroke king richard "
    },
    {
        "src": "NDXLFOWUUDIFBGEISMIGMLKUMDNVQREEUVYNKWNABQXQZTHWAAYRKLEQDFQWKLKWCVADBDLPHADKZYQMRHJTBKMSEJKJGGPQQDQCRWFFJYSRGLEFRHPUNUZVYSKYTGSSGDDCUCZBZZXHESLBGSQUBKJNNROMTXLLFGACJTSTINFQSMDTYIAJHNYFPIEKXELPLWGGKAYOZOCFDLGKJQPJNFYRRQYQALMJRRCNJWSBDWQRICFKPCDEFRXQBRDOHQLWGFPSFALNUBKIITLWJJIEHYNRYTRPMWOLYNVHYTDMSCXHLNIGOIHSQFFVRRWGGJICRDLIMFDTAOLCQJLAXTBXKBLEDEEMMLMJLGIVNLUUSIVCAUWSWKYBIRDANWLZPVTGYKFMNAIGLPXTPTJFWBDYJOTFWSFZDBDRRNSAXISATKYKWDMEJSLAETCSEZXJMXIEZONVKRXQTUBJTZNEJMKZTGPLUPJJVFQQDLUCWRBATNBYDFWSKEQPVMLLDBZMOWOCXRKLGSAOXKCBEECXIBSQRUUILCTJK",
        "tgt": "your hopes and your lies as he passed through the forest prince andrew turned several times to look at that oak as if expecting something from it under the oak too were flowers and grass but it stood among them scowling rigid misshapen and grim as ever yes the oak is right a thousand times right thought prince andrew let othersthe youngyield afresh to that fraud but we know life our life is finished a whole sequence of new thoughts hopeless but mournfully pleasant rose in his soul in connection with that tree during this journey he as "
    }
]

def load_and_test_model(test_examples, model_path='/content/drive/MyDrive/transformer_model_final.pth'):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    embed_dim = 128
    ff_dim = 512
    encoders_num = 8  
    decoders_num = 8  
    num_heads = 4     

    char_list = list(" abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,.!?;:'\"-\n")
    special_tokens = ["[UNK]"]

    final_vocab_list = special_tokens + sorted(list(set(char_list)))

    vocab_dict = {token: i for i, token in enumerate(final_vocab_list)}
    id_to_token = {i: token for token, i in vocab_dict.items()}

    tokenizer = Tokenizer(models.WordLevel(vocab=vocab_dict, unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.Split(pattern="", behavior="isolated")
    tokenizer.decoder = decoders.ByteLevel()

    src_vocab_size = len(final_vocab_list)
    tgt_vocab_size = len(final_vocab_list)

    model = Model(src_vocab_size, tgt_vocab_size, embed_dim, ff_dim,
                  encoders_num, decoders_num, num_heads).to(device)

    try:
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Successfully loaded model from {model_path}")
    except RuntimeError as e:
        print(f"Architecture Mismatch: {e}")
        print("Check if 'encoders_num', 'decoders_num', or 'num_heads' match your training script exactly.")
        return
    except FileNotFoundError:
        print(f"Model file not found at {model_path}")
        return

    model.eval()

    def greedy_decode(src_seq, max_len=500):
        src_encs = model.encode(src_seq)

        start_token_id = vocab_dict.get(' ', 0)
        generated_tokens = [start_token_id]

        with torch.no_grad():
            for _ in range(max_len):
                tgt_seq = torch.tensor(generated_tokens, dtype=torch.long).unsqueeze(0).to(device)

                output = model.decode(src_encs, tgt_seq)
                logits = model.out_proj(output[:, -1, :])
                next_token_id = torch.argmax(logits, dim=-1).item()

                generated_tokens.append(next_token_id)

                if len(generated_tokens) >= max_len:
                    break

        return "".join([id_to_token.get(t, "") for t in generated_tokens[1:]])


    print("\n--- Testing Model ---")
    for i, example in enumerate(test_examples):
        src_text = example["src"]
        tgt_text = example["tgt"]

        src_tokens = tokenizer.encode(src_text).ids
        src_seq = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(device)

        generated_text = greedy_decode(src_seq, max_len=len(tgt_text) + 10)

        print(f"\nExample {i + 1}: {src_text}")
        print(f"Generated: {generated_text}")
        print(f"Actual {tgt_text}")

if __name__ == '__main__':
    load_and_test_model(TEST_EXAMPLES)

ModuleNotFoundError: No module named 'torch'

In [None]:
import torch
import torch.nn as nn
import math
from tokenizers import Tokenizer, models, pre_tokenizers, decoders
import torch.nn.functional as F

from google.colab import drive
drive.mount('/content/drive')

class PositionalEncoding(torch.nn.Module):
    def __init__(self, embed_dim: int, max_len=5000):
        super(PositionalEncoding, self).__init__()
        enc = torch.zeros(max_len, embed_dim)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * -(math.log(10000) / embed_dim))
        enc[:, 0::2] = torch.sin(position * div_term)
        enc[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('enc', enc.unsqueeze(0))

    def forward(self, x):
        return x + self.enc[:, : x.size(1)]

class Embedding(nn.Module):
    def __init__(self, vocab_size: int, embed_dim: int):
        super(Embedding, self).__init__()
        self.embed_dim = embed_dim
        self.hidden_dim = embed_dim
        self.token_embedding = torch.nn.Embedding(vocab_size, embed_dim)
        self.positional_encoding = PositionalEncoding(embed_dim)

    def forward(self, x):
        W_e = self.token_embedding(x)
        W_pe = self.positional_encoding(W_e)
        return W_pe

class FeedForward(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int):
        super(FeedForward, self).__init__()
        self.linear1 = torch.nn.Linear(embed_dim, ff_dim)
        self.linear2 = torch.nn.Linear(ff_dim, embed_dim)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.linear2(self.relu(self.linear1(x)))

class MaskedMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim: int, num_heads: int):
        super(MaskedMultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = embed_dim // num_heads

        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value, mask=None):
        Q = self.q(query)
        K = self.k(key)
        V = self.v(value)
        scores = (Q @ K.transpose(-2, -1)) / math.sqrt(self.embed_dim)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn_wights = torch.nn.functional.softmax(scores, dim=-1)
        attn_output = attn_wights @ V
        output = self.out(attn_output)
        return output

class Encoder(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, num_heads: int = 1):
        super(Encoder, self).__init__()
        self.self_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.feed_forward = FeedForward(embed_dim, ff_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.layer_norm3 = nn.LayerNorm(embed_dim)

    def forward(self, src_embs):
        emb_norm = self.layer_norm1(src_embs)
        attn_output = self.self_attn(emb_norm, emb_norm, emb_norm)
        attn = src_embs + attn_output

        attn_norm = self.layer_norm2(attn)
        ff_output = self.feed_forward(attn_norm)
        ff = attn + ff_output

        return self.layer_norm3(ff)

class Decoder(nn.Module):
    def __init__(self, embed_dim: int, ff_dim: int, num_heads: int = 1):
        super(Decoder, self).__init__()
        self.self_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.cross_attn = MaskedMultiHeadAttention(embed_dim, num_heads=num_heads)
        self.feed_forward = FeedForward(embed_dim, ff_dim)
        self.layer_norm1 = nn.LayerNorm(embed_dim)
        self.layer_norm2 = nn.LayerNorm(embed_dim)
        self.layer_norm3 = nn.LayerNorm(embed_dim)
        self.layer_norm4 = nn.LayerNorm(embed_dim)

    def forward(self, src_encs, tgt_embs):
        seq_len, device = tgt_embs.size(1), tgt_embs.device
        causal_mask = torch.tril(torch.ones((1, seq_len, seq_len), device=device)).bool()

        tgt_embs_norm = self.layer_norm1(tgt_embs)
        self_attn_output = self.self_attn(tgt_embs_norm, tgt_embs_norm, tgt_embs_norm, mask=causal_mask)
        attn = tgt_embs + self_attn_output

        attn_norm = self.layer_norm2(attn)
        cross_attn_output = self.cross_attn(attn_norm, src_encs, src_encs)
        cross_attn = attn + cross_attn_output

        cross_attn_norm = self.layer_norm3(cross_attn)
        ff_output = self.feed_forward(cross_attn_norm)
        ff = cross_attn + ff_output

        return ff

class Model(nn.Module):
    def __init__(self, src_vocab_size: int, tgt_vocab_size: int, embed_dim: int, ff_dim: int,
                 num_encoder_layers: int = 1, num_decoder_layers: int = 1, num_heads: int = 1):
        super(Model, self).__init__()

        self.src_embedding = Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = Embedding(tgt_vocab_size, embed_dim)

        self.encoders = nn.ModuleList([Encoder(embed_dim, ff_dim, num_heads=num_heads)
                                       for _ in range(num_encoder_layers)])
        self.decoders = nn.ModuleList([Decoder(embed_dim, ff_dim, num_heads=num_heads)
                                       for _ in range(num_decoder_layers)])

        self.out_proj = torch.nn.Linear(embed_dim, tgt_vocab_size)

    def encode(self, src_seq):
        src_embs = self.src_embedding(src_seq)
        enc = src_embs
        for layer in self.encoders:
            enc = layer(enc)
        return enc

    def decode(self, src_encs, tgt_seq):
        tgt_embs = self.tgt_embedding(tgt_seq)
        dec = tgt_embs
        for layer in self.decoders:
            dec = layer(src_encs, dec)
        return dec

    def forward(self, src_seq, tgt_seq):
        src_encs = self.encode(src_seq)
        tgt_encs = self.decode(src_encs, tgt_seq)
        output = self.out_proj(tgt_encs)
        return output


TEST_EXAMPLES = [
    {
        "src": "uvv vibd xyusd orgjxswcbj qoyrrvcbqxis misqgbqueivw bjd ogjjujhvw blbw frwq u yibs xyi yrsjq huei ci qrci ciduxbxurj bjd cbsn wrgs ogi itigjx bvv fgx qoyrrvcbqxis mbvvbq ujqmusi ci ijxis xyiqigq musuxyrgq yummrvwxb icuvub bjd xsbuj xyuq lbw xyi qxbh xrrn",
        "tgt": "ill lead third countryman schoolmaster persuasively and cunningly away boys i hear the horns give me some meditation and mark your cue exeunt all but schoolmaster pallas inspire me enter theseus pirithous hippolyta emilia and train this way the stag took"
    },
    {
        "src": "avs waxj npyy iv npyiw lavxiyq yvnyw rawiava avs iqafyhha myhxkry pkm auwyys qpyhh najy npy yvnywtwiqy btkv pyw lanpyw il ekb asciqy in in iq vkn re xkvqyvn fbn re yvnwyane nkk hinnhy pacy ekb nk qae mpyv ekb sytawn lwkr pir fbn qkln avs hkm wyryrfyw vkm re fwknpyw",
        "tgt": "and rack thee in their fancies enter mariana and isabella welcome how agreed shell take the enterprise upon her father if you advise it it is not my consent but my entreaty too little have you to say when you depart from him but soft and low remember now my brother"
    },
    {
        "src": "asgq mo ng kmo hoenakqx qj a dokkox yqrotatk cmnym cag ogkadsngmoe uzqt dokkox zxqhngog 87 jqx nj kmak jnxgk yqrotatk mae doot jausksogg kmot gmquse tq zsayo maro doot gquvmk jqx kmo goyqte 88 jqx jntentv jausk cnkm kmoh mo gankm domqse kmo eawg yqho gankm kmo sqxe cmot n cnss habo a toc yqrotatk cnkm kmo mqugo qj ngxaos ate cnkm kmo mqugo qj lueam 89 tqk ayyqxentv kq kmo yqrotatk kmak n haeo cnkm kmonx jakmoxg nt kmo eaw cmot n kqqb kmoh dw kmo mate kq soae kmoh quk qj kmo sate qj ovwzk doyaugo kmow yqtkntuoe tqk nt hw yqrotatk ate n xovaxeoe kmoh tqk gankm kmo sqxe",
        "tgt": "also he is the mediator of a better covenant which was established upon better promises 87 for if that first covenant had been faultless then should no place have been sought for the second 88 for finding fault with them he saith behold the days come saith the lord when i will make a new covenant with the house of israel and with the house of judah 89 not according to the covenant that i made with their fathers in the day when i took them by the hand to lead them out of the land of egypt because they continued not in my covenant and i regarded them not saith the lord"
    },
    {
        "src": "e kwuo ifui zeif kh tbyv e vbsw ifh guydfiwr uog gb eoiwog ib kupw fwr nywwo bx wodvuog nywwo wvejuqwif zwvv ifwo zfb gbti ifby kwuo tfuvv qw fwr peod peod recfurg wswo fw ifui kupwt fwr nywwo zfb wvtw tfbyvg qw nywwo wvejuqwif zfui ifby peod recfurg wswo tb fbz ifeop hby bx ei",
        "tgt": "i mean that with my soul i love thy daughter and do intend to make her queen of england queen elizabeth well then who dost thou mean shall be her king king richard even he that makes her queen who else should be queen elizabeth what thou king richard even so how think you of it"
    }
]


def load_and_test_model(test_examples, model_path='/content/drive/MyDrive/transformer_model_final_basic.pth'):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    embed_dim = 128
    ff_dim = 512
    encoders_num = 8
    decoders_num = 8
    num_heads = 4

    char_list = list(" abcdefghijklmnopqrstuvwxyz,.!?;:'\"-\n")
    special_tokens = ["[UNK]"]


    final_vocab_list = special_tokens + sorted(list(set(char_list)))

    vocab_dict = {token: i for i, token in enumerate(final_vocab_list)}
    id_to_token = {i: token for token, i in vocab_dict.items()}

    tokenizer = Tokenizer(models.WordLevel(vocab=vocab_dict, unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.Split(pattern="", behavior="isolated")
    tokenizer.decoder = decoders.ByteLevel()

    src_vocab_size = len(final_vocab_list)
    tgt_vocab_size = len(final_vocab_list)

    model = Model(src_vocab_size, tgt_vocab_size, embed_dim, ff_dim,
                  encoders_num, decoders_num, num_heads).to(device)

    try:
        state_dict = torch.load(model_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"Successfully loaded model from {model_path}")
    except RuntimeError as e:
        print(f"Architecture Mismatch: {e}")
        print("Check if 'encoders_num', 'decoders_num', or 'num_heads' match your training script exactly.")
        return
    except FileNotFoundError:
        print(f"Model file not found at {model_path}")
        return

    model.eval()
    with torch.no_grad():
        for i in range(4):
            src_text = TEST_EXAMPLES[i]["src"]
            tgt_text = TEST_EXAMPLES[i]["tgt"]

            src_tokens = tokenizer.encode(src_text).ids
            tgt_tokens = tokenizer.encode(tgt_text).ids

            src_seq = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(device)
            tgt_seq = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(device)

            output = model(src_seq, tgt_seq[:, :-1])
            output_dim = output.shape[-1]
            output_flat = output.view(-1, tgt_vocab_size)
            output_flat = output_flat.contiguous().view(-1, output_dim)
            tgt_tensor = tgt_seq[:, 1:].contiguous().view(-1)

            preds = torch.argmax(output, dim=2)

            predicted_text = tokenizer.decode(preds[0].tolist())

            print(f'Source {i}: {src_text}')
            print(f'Target: {tgt_text}')
            print(f'Predicted: {predicted_text}')
            print('---')

            correct = (preds.view(-1) == tgt_tensor).sum().item()
            total = tgt_tensor.numel()



if __name__ == '__main__':
    load_and_test_model(TEST_EXAMPLES)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Successfully loaded model from /content/drive/MyDrive/transformer_model_final_basic.pth
Source 0: uvv vibd xyusd orgjxswcbj qoyrrvcbqxis misqgbqueivw bjd ogjjujhvw blbw frwq u yibs xyi yrsjq huei ci qrci ciduxbxurj bjd cbsn wrgs ogi itigjx bvv fgx qoyrrvcbqxis mbvvbq ujqmusi ci ijxis xyiqigq musuxyrgq yummrvwxb icuvub bjd xsbuj xyuq lbw xyi qxbh xrrn
Target: ill lead third countryman schoolmaster persuasively and cunningly away boys i hear the horns give me some meditation and mark your cue exeunt all but schoolmaster pallas inspire me enter theseus pirithous hippolyta emilia and train this way the stag took
Predicted: p""eeap-'prpj.;p pp;de r-;;"-ppp rpr.os;yspprp.pope.epp'kppp-p-pp.rr-pdppppa.sepp.sprpkepdapt[UNK].pp. o;pp a; a-.--r.pe;eprk r.u;ppp.skpodp"apa"eprre?p-ar--ppp "pa.ssoepe"pr.ep-pjpp aro[UNK];ddsprappppa..jppepppa[UNK]p.;"r;aeso ppp-;ppptpde-p;