In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LRScheduler
from torch.nn import Transformer
import json
import random
import copy
import warnings
import math
import os

In [None]:
# Transformer model class

class CharTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, d_ff=1024, max_len=32, 
                 dropout=0.3, device="cuda", pad_token_id=0, start_token_id=1, end_token_id=2):
        super(CharTransformer, self).__init__()
        self.device = device

        self.pad_token_id = pad_token_id
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id
        
        self.embedding = nn.Embedding(vocab_size, d_model).to(device)
        self.dropout = nn.Dropout(dropout)  # Dropout after embedding
        
        self.learnable_positional_encoding = nn.Parameter(torch.zeros(max_len, d_model).to(device))
        
        self.transformer = Transformer(
            d_model=d_model, nhead=nhead, num_encoder_layers=num_layers, 
            num_decoder_layers=num_layers, dim_feedforward=d_ff, activation="gelu", batch_first=True, 
            dropout=dropout 
        ).to(device)

        self.fc_out = nn.Linear(d_model, vocab_size).to(device)

    def _generate_square_subsequent_mask(self, sz):
        mask = torch.tril(torch.ones(sz, sz))
        return torch.log(mask).to(self.device)
        
    def forward(self, src, tgt, feature_mask, tgt_is_causal=False, src_mask=None, tgt_mask=None, 
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        src, tgt, feature_mask = src.to(self.device), tgt.to(self.device), feature_mask.to(self.device)
        
        # Compute embeddings and apply dropout
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)
        
        src_emb = self.dropout(src_emb)
        tgt_emb = self.dropout(tgt_emb)

        # Apply positional encoding only to non-feature tokens
        src_emb += (1 - feature_mask[:, :src.shape[1], None]) * self.learnable_positional_encoding[:src.shape[1], :]

        # Compute key padding masks if not provided
        if src_key_padding_mask is None:
            src_key_padding_mask = (src == self.pad_token_id)  
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = (tgt == self.pad_token_id)  

        if tgt_is_causal: 
            if tgt_mask is None: 
                tgt_mask = self._generate_square_subsequent_mask(tgt.shape[1])
            else:
                tgt_mask += self._generate_square_subsequent_mask(tgt.shape[1])

        # Transformer forward pass
        transformer_output = self.transformer(
            src_emb, tgt_emb, 
            src_mask=src_mask, 
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            tgt_is_causal=tgt_is_causal
        )

        output = self.fc_out(transformer_output)
        return output

    
    def generate(self, src, feature_mask, max_len=32, beam_size=5):
        self.eval()
        src, feature_mask = src.to(self.device), feature_mask.to(self.device)

        # Initialize beams: (sequence, log probability)
        beams = torch.full((1, 1), self.start_token_id, device=self.device)  
        beam_scores = torch.zeros(1, device=self.device)  # Log probabilities

        completed_sequences = []  # Store completed sequences

        for _ in range(max_len):
            # Expand `src` to match the number of beams
            src_expanded = src.expand(beams.shape[0], -1)
            feature_mask_expanded = feature_mask.expand(beams.shape[0], -1)

            # Forward pass on all beams at once
            out = self.forward(src_expanded, beams, feature_mask_expanded)  
            logits = out[:, -1, :]  # Get last-step logits (shape: [beams, vocab_size])
            log_probs = torch.log_softmax(logits, dim=-1)  # Convert logits to log-probabilities

            # Get top-k candidates for each beam (shape: [beams, beam_size])
            topk_log_probs, topk_ids = log_probs.topk(beam_size, dim=-1)

            # Compute new scores by adding log probabilities (broadcasted)
            expanded_scores = beam_scores.unsqueeze(1) + topk_log_probs  # Shape: [beams, beam_size]
            expanded_scores = expanded_scores.view(-1)  # Flatten to [beams * beam_size]

            # Get top-k overall candidates
            topk_scores, topk_indices = expanded_scores.topk(beam_size)

            # Convert flat indices to beam/token indices
            beam_indices = topk_indices // beam_size  # Which original beam did this come from?
            token_indices = topk_indices % beam_size  # Which token was selected?

            # Append new tokens to sequences
            new_beams = torch.cat([beams[beam_indices], topk_ids.view(-1, 1)[topk_indices]], dim=-1)

            # Check for completed sequences
            eos_mask = (new_beams[:, -1] == self.end_token_id)
            if eos_mask.any():
                for i in range(beam_size):
                    if eos_mask[i]:
                        completed_sequences.append((new_beams[i], topk_scores[i]))

            # Keep only unfinished sequences
            beams = new_beams[~eos_mask]
            beam_scores = topk_scores[~eos_mask]

            # If all sequences finished, stop early
            if len(beams) == 0 or len(completed_sequences) >= beam_size:
                break

        # Choose the best sequence from completed ones
        if completed_sequences:
            best_sequence = max(completed_sequences, key=lambda x: x[1])[0]
        else:
            best_sequence = beams[0]  # If no sequence completed, return best unfinished one

        return best_sequence

In [3]:
# Data preparation

def prepare_data(data, val_size=1000, train_size=2000, test_ratio=0.4):
    paradigms = list(data.items())
    random.shuffle(paradigms)
    
    val_set = paradigms[:val_size]
    train_test_set = paradigms[val_size:val_size + train_size]
    
    train_size = int((1 - test_ratio) * len(train_test_set))
    train_set = train_test_set[:train_size]
    test_set = train_test_set[train_size:]
    
    return train_set, test_set, val_set

def generate_examples(paradigm):
    lemma = list(paradigm[0])  # Convert lemma to list of characters
    forms = paradigm[1]
    examples = []
    
    for tag, form in forms.items():
        src = ['<s>', f'<{tag}>'] + lemma + ['</s>']
        tgt = ['<s>'] + list(form) + ['</s>']
        examples.append((src, tgt))
    
    return examples

# Load dataset
with open("/home/minhk/Assignments/CSCI 5801/project/data/processed/eng_v.json", "r") as f:
    data = json.load(f)

train_set, test_set, val_set = prepare_data(data, val_size=1000, train_size=1000, test_ratio=0.1)

train_examples = [ex for paradigm in train_set for ex in generate_examples(paradigm)]
test_examples = [ex for paradigm in test_set for ex in generate_examples(paradigm)]
val_examples = [ex for paradigm in val_set for ex in generate_examples(paradigm)]

print("Train examples:", train_examples[:5])  # Show first 5 examples
print("Test examples:", test_examples[:5])
print("Validation examples:", val_examples[:5])

Train examples: [(['<s>', '<PRS>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', '</s>'], ['<s>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', '</s>']), (['<s>', '<3SG>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', '</s>'], ['<s>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', 's', '</s>']), (['<s>', '<PST>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', '</s>'], ['<s>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', 'd', '</s>']), (['<s>', '<PRS.PTCP>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', '</s>'], ['<s>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'i', 'n', 'g', '</s>']), (['<s>', '<PST.PTCP>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', '</s>'], ['<s>', 'u', 'n', 's', 't', 'a', 'b', 'i', 'l', 'i', 'z', 'e', 'd', '</s>'])]
Test examples: [(['<s>', '<PRS>', 's', 'i', 'd', 'e', 's', 't', 'e', 'p', '</s>'], ['<s>', 's', 'i', 'd', 'e', 's', 't', 'e', 'p', '</s>']), (['<s>', '<3SG>', 's', 'i', 'd',

In [4]:
class InverseSquareLRWithWarmup(LRScheduler):
    """
    Implements an inverse square learning rate scheduler with warmup steps.
    
    During warmup, the learning rate increases linearly from init_lr to max_lr.
    After warmup, the learning rate decreases according to the inverse square of the step number:
    lr = max_lr * (warmup_steps / step)^2 for step > warmup_steps
    
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        init_lr (float): Initial learning rate during warmup phase. Default: 0.0
        max_lr (float): Maximum learning rate after warmup phase. Default: 0.1
        warmup_steps (int): Number of warmup steps. Default: 1000
        last_epoch (int): The index of the last epoch. Default: -1
    """
    
    def __init__(self, optimizer, init_lr=0.0, max_lr=0.001, warmup_steps=1000, last_epoch=-1):
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.warmup_steps = warmup_steps
        super(InverseSquareLRWithWarmup, self).__init__(optimizer, last_epoch)
        
    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase
            alpha = self.last_epoch / self.warmup_steps
            return [self.init_lr + alpha * (self.max_lr - self.init_lr) for _ in self.base_lrs]
        else:
            # Inverse square decay phase
            decay_factor = math.sqrt(self.warmup_steps / self.last_epoch)
            return [self.max_lr * decay_factor for _ in self.base_lrs]
            
    def _get_closed_form_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase
            alpha = self.last_epoch / self.warmup_steps
            return [self.init_lr + alpha * (self.max_lr - self.init_lr) for _ in self.base_lrs]
        else:
            # Inverse square decay phase
            decay_factor = (self.warmup_steps / self.last_epoch) ** 2
            return [self.max_lr * decay_factor for _ in self.base_lrs]

In [5]:
# Sampler for scheduled learning, gradually replaces ground truth (teacher forcing) with model input

class ScheduledSampler():
    def __init__(self, base_rate=0.5, warmup_steps=1000):
        self.base_rate = base_rate
        self.warmup_steps = warmup_steps
        self.step_count = 0        
        self.sampling_rate = 1
    
    def step(self):
        self.step_count += 1
        if self.step_count > self.warmup_steps:
            self.sampling_rate = self.base_rate + (1 - self.base_rate) * math.sqrt(self.warmup_steps / self.step_count)

    def sample(self, logits, truth_ids):
        """
        Selects truth_ids with probability `sampling_rate`, 
        otherwise samples using Gumbel noise.
        """
        batch_size, seq_len, vocab_size = logits.shape
        
        # Decide per-token whether to take ground truth (1) or Gumbel sample (0)
        mask = torch.bernoulli(torch.full((batch_size, seq_len), self.sampling_rate, device=logits.device, dtype=float)).bool()
        
        # Gumbel-sampled predictions
        gumbel_preds = self._gumbel_sample(logits)
        
        # Use ground truth where mask == True, else use gumbel_preds
        return torch.where(mask, truth_ids, gumbel_preds)
    
    def _gumbel_sample(self, logits):
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))  # Generate Gumbel noise
        return (logits + gumbel_noise).argmax(dim=-1)  # Apply Gumbel noise and take the argmax

In [6]:
# Tokenizer
def tokenize(sequence, char_to_idx):
    return [char_to_idx[char] for char in sequence]

# Build character vocabulary
all_chars = set()
for ex in train_examples + test_examples + val_examples:
    all_chars.update(ex[0])
    all_chars.update(ex[1])
all_chars.remove('<s>')
all_chars.remove('</s>')
char_to_idx = {char: i for i, char in enumerate(sorted(all_chars), start=3)}  # Reserve 0, 1, 2 for special tokens
char_to_idx['<pad>'] = 0
char_to_idx['<s>'] = 1
char_to_idx['</s>'] = 2
idx_to_char = {
    i: char for char, i in char_to_idx.items()
}
vocab_size = len(char_to_idx)
max_len = 32

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CharTransformer(vocab_size, device=device, max_len=max_len)

def pad_sequence(sequence, max_len, pad_token='<pad>'):
    return sequence + [pad_token] * (max_len - len(sequence))

def create_feature_mask(sequence):
    """Create a feature mask where tags (enclosed in < >) are marked as 1, else 0."""
    return torch.tensor([1 if char.startswith('<') and char.endswith('>') else 0 for char in sequence], device=device)

def create_padding_mask(sequence, pad_token='<pad>'):
    """Create a padding mask where padding tokens are marked as True (to be ignored)."""
    return (sequence == pad_token)

def train_model(model, train_examples, test_examples, epochs=1000, batch_size=256, patience=20):
    optimizer = optim.AdamW(model.parameters(), betas=(0.99, 0.98))
    scheduler = InverseSquareLRWithWarmup(optimizer, init_lr=1e-5, max_lr=1e-3, warmup_steps=4000)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    sampler = ScheduledSampler(base_rate=0.5, warmup_steps=4000)

    pad_token = char_to_idx['<pad>']
    best_test_loss = float('inf')  # Initialize the best test loss to a very large value
    best_model_state = copy.deepcopy(model.state_dict())  # Store best model parameters
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        random.shuffle(train_examples)
        
        for i in range(0, len(train_examples), batch_size):
            batch = train_examples[i:i+batch_size]
            src_batch, tgt_batch = zip(*batch)

            max_batch_len = max(max(len(s) for s in src_batch), max(len(t) for t in tgt_batch))
            
            # Pad sequences to the maximum length (max_len) in the batch
            src_padded = [pad_sequence(seq, max_batch_len) for seq in src_batch]
            tgt_padded = [pad_sequence(seq, max_batch_len) for seq in tgt_batch]
            
            # Convert padded sequences to tensors
            src_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in src_padded], device=device)
            tgt_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in tgt_padded], device=device)
            
            # Create the feature mask
            feature_mask_src = torch.stack([create_feature_mask(seq) for seq in src_padded], dim=0)

            # Shift target tensor for teacher forcing (the model will predict next token)
            tgt_input = tgt_tensor[:, :-1]  # Remove the last token (it's not used as input)
            tgt_expected = tgt_tensor[:, 1:]  # The target sequence for the loss is shifted by 1

            # First round of predictions (using teacher forcing)
            optimizer.zero_grad()
            output = model(src_tensor, tgt_input, feature_mask_src, tgt_is_causal=True)

            # Sample from model output and true target based on current sampling rate
            sampled_tgt_input = sampler.sample(output, tgt_expected)

            # Ensure proper alignment for next round of input:
            # Take the first element of each item in the target sequence (start token)
            # Concatenate with the sampled output (excluding the last token)
            sampled_tgt_input = torch.cat(
                [tgt_input[:, :1], sampled_tgt_input[:, :-1]], dim=1
            )

            # Second round of predictions using the sampled input
            output = model(src_tensor, sampled_tgt_input, feature_mask_src, tgt_is_causal=True)

            # Compute the loss
            loss = criterion(output.reshape(-1, vocab_size), tgt_expected.reshape(-1))

            # Apply padding mask to loss (ignores padded tokens)
            tgt_mask = (tgt_input != pad_token).float().view(-1)
            loss = loss * tgt_mask  # Element-wise multiply with the mask to ignore padding tokens
            loss = loss.sum() / tgt_mask.sum()  # Normalize the loss (average over non-padding tokens)
            
            # Apply loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
            optimizer.step()
            scheduler.step()
            sampler.step()  # Update the sampler (teacher forcing rate)

            total_loss += loss.item()

        print(f"Epoch {epoch+1}, Loss: {total_loss * batch_size / len(train_examples)}")
        
        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for i in range(0, len(test_examples), batch_size):
                batch = test_examples[i:i+batch_size]
                src_batch, tgt_batch = zip(*batch)

                max_batch_len = max(max(len(s) for s in src_batch), max(len(t) for t in tgt_batch))

                # Pad sequences to the maximum length (max_len) in the batch
                src_padded = [pad_sequence(seq, max_batch_len) for seq in src_batch]
                tgt_padded = [pad_sequence(seq, max_batch_len) for seq in tgt_batch]

                # Convert padded sequences to tensors
                src_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in src_padded], device=device)
                tgt_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in tgt_padded], device=device)

                # Create the feature mask
                feature_mask_src = torch.stack([create_feature_mask(seq) for seq in src_padded], dim=0)

                # Shift target tensor
                tgt_input = tgt_tensor[:, :-1]
                tgt_expected = tgt_tensor[:, 1:]

                # Forward pass
                output = model(src_tensor, tgt_input, feature_mask_src, tgt_is_causal=True)

                # Compute the loss
                loss = criterion(output.reshape(-1, vocab_size), tgt_expected.reshape(-1))

                # Apply padding mask to loss
                tgt_mask = (tgt_expected != pad_token).float().view(-1)
                loss = loss * tgt_mask  # Element-wise multiply with the mask to ignore padding tokens
                loss = loss.sum() / tgt_mask.sum()  # Normalize the loss (average over non-padding tokens)
                
                test_loss += loss.item()

        test_loss = test_loss * batch_size / len(test_examples)
        print(f"Test Loss after Epoch {epoch+1}: {test_loss}")

        # Early stopping based on test set loss
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_epoch = epoch
            patience_count = 0
            best_model_state = copy.deepcopy(model.state_dict())  # Save best model state
        else:
            patience_count += 1
            if patience_count == patience:
                patience_count = 0
                patience = int(patience * math.sqrt(2))
                # Rollback to best model state (undo last epoch update)
                model.load_state_dict(best_model_state)
                print(f"Rolling back to best model from epoch {best_epoch + 1}")
                print(f"Best test loss: {best_test_loss}")
        
    model.load_state_dict(best_model_state)


In [7]:
train_model(model, train_examples, test_examples)



Epoch 1, Loss: 3.297080620659722
Test Loss after Epoch 1: 2.911204833984375


  output = torch._nested_tensor_from_mask(


Epoch 2, Loss: 2.6395516493055555
Test Loss after Epoch 2: 2.750455322265625
Epoch 3, Loss: 2.504649386935764
Test Loss after Epoch 3: 2.3752802734375
Epoch 4, Loss: 2.325970458984375
Test Loss after Epoch 4: 2.2539462890625
Epoch 5, Loss: 2.163264655219184
Test Loss after Epoch 5: 2.1372369384765624
Epoch 6, Loss: 2.049650614420573
Test Loss after Epoch 6: 2.0540411987304688
Epoch 7, Loss: 1.9592076348198784
Test Loss after Epoch 7: 2.0184488525390627
Epoch 8, Loss: 1.9218873833550347
Test Loss after Epoch 8: 1.9830564575195313
Epoch 9, Loss: 1.8567845052083334
Test Loss after Epoch 9: 1.8948652954101564
Epoch 10, Loss: 1.7878001641167536
Test Loss after Epoch 10: 1.86536474609375
Epoch 11, Loss: 1.7494991590711806
Test Loss after Epoch 11: 1.8282542724609374
Epoch 12, Loss: 1.7031078016493055
Test Loss after Epoch 12: 1.8113834838867187
Epoch 13, Loss: 1.6500536159939236
Test Loss after Epoch 13: 1.7835491943359374
Epoch 14, Loss: 1.621664781358507
Test Loss after Epoch 14: 1.7635612

In [16]:
total_params = 0
for name, param in model.named_parameters():
    param_count = param.numel()  # Number of elements in the tensor
    total_params += param_count
    print(f"{name}: {param.size()}, {param.dtype}, {param_count} params")

print(f"Total number of parameters: {total_params}")

learnable_positional_encoding: torch.Size([32, 256]), torch.float32, 8192 params
embedding.weight: torch.Size([39, 256]), torch.float32, 9984 params
transformer.encoder.layers.0.self_attn.in_proj_weight: torch.Size([768, 256]), torch.float32, 196608 params
transformer.encoder.layers.0.self_attn.in_proj_bias: torch.Size([768]), torch.float32, 768 params
transformer.encoder.layers.0.self_attn.out_proj.weight: torch.Size([256, 256]), torch.float32, 65536 params
transformer.encoder.layers.0.self_attn.out_proj.bias: torch.Size([256]), torch.float32, 256 params
transformer.encoder.layers.0.linear1.weight: torch.Size([1024, 256]), torch.float32, 262144 params
transformer.encoder.layers.0.linear1.bias: torch.Size([1024]), torch.float32, 1024 params
transformer.encoder.layers.0.linear2.weight: torch.Size([256, 1024]), torch.float32, 262144 params
transformer.encoder.layers.0.linear2.bias: torch.Size([256]), torch.float32, 256 params
transformer.encoder.layers.0.norm1.weight: torch.Size([256]), 

In [9]:
save_dir = "/home/minhk/Assignments/CSCI 5801/project/model/"
filename="base_transformers.pth"
os.makedirs(save_dir, exist_ok=True)
torch.save(model, save_dir+filename)

In [25]:
# Select a random paradigm from validation set
random_paradigm = random.choice(val_set)

# Generate examples from the paradigm
generated_examples = generate_examples(random_paradigm)

def list_to_word(arr):
    return ''.join(arr[1:-1])

for src, tgt in generated_examples: 
    print(list_to_word(src), list_to_word(tgt))
    feature_mask = create_feature_mask(src).unsqueeze(0)
    src_tokenized = torch.tensor(tokenize(src, char_to_idx), device=device).unsqueeze(0)
    print(src_tokenized.shape)
    print(feature_mask.shape)
    gen = model.generate(src_tokenized, feature_mask, beam_size=5).squeeze(0)
    gen = list_to_word([idx_to_char[id.item()] for id in gen])
    print(gen)


<PRS>pathotype pathotype
torch.Size([1, 12])
torch.Size([1, 12])
pathotype
<3SG>pathotype pathotypes
torch.Size([1, 12])
torch.Size([1, 12])
pathotypes
<PST>pathotype pathotyped
torch.Size([1, 12])
torch.Size([1, 12])
pathotyped
<PRS.PTCP>pathotype pathotyping
torch.Size([1, 12])
torch.Size([1, 12])


RuntimeError: The size of tensor a (12) must match the size of tensor b (13) at non-singleton dimension 1

In [None]:
word = 'run'
tags = ["PRS", "3SG", "PST", "PRS.PTCP", "PST.PTCP"]

for tag in tags: 
    tokens = ["<s>", f"<{tag}>"] + list(word) + ["</s>"]
    feature_mask = create_feature_mask(src).unsqueeze(0)
    src_tokenized = torch.tensor(tokenize(src, char_to_idx), device=device).unsqueeze(0)
    gen = model.generate(src_tokenized, feature_mask).squeeze(0)
    gen = [idx_to_char[id.item()] for id in gen]
    

SyntaxError: incomplete input (3334691788.py, line 4)

In [11]:
correct_predictions = 0
total_predictions = 0

for paradigm in val_set:
    pairs = generate_examples(paradigm)
    for src, tgt in pairs: 
        src = pad_sequence(src, max_len)
        feature_mask = create_feature_mask(src).unsqueeze(0)
        src_tokenized = torch.tensor(tokenize(src, char_to_idx), device=device).unsqueeze(0)
        gen = model.generate(src_tokenized, feature_mask).squeeze(0)
        gen = [idx_to_char[id.item()] for id in gen]
        correct_predictions += (gen == tgt)
        total_predictions += 1

print(f"{correct_predictions} predictions correct out of {total_predictions} total. Accuracy: {correct_predictions / total_predictions}")
    

3514 predictions correct out of 5000 total. Accuracy: 0.7028
