In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LRScheduler
from torch.nn import Transformer
import json
import random
import copy
import warnings
import math
import os

In [2]:
class CharTransformer(nn.Module):
    def __init__(self, vocab_size, d_model=256, nhead=4, num_layers=4, d_ff=1024, max_len=32, 
                 dropout=0.3, device="cuda", pad_token_id=0, start_token_id=1, end_token_id=2):
        super(CharTransformer, self).__init__()
        self.device = device

        self.pad_token_id = pad_token_id
        self.start_token_id = start_token_id
        self.end_token_id = end_token_id
        
        self.embedding = nn.Embedding(vocab_size, d_model).to(device)
        self.dropout = nn.Dropout(dropout)  # Dropout after embedding
        
        # Create sinusoidal positional encodings
        self.register_buffer("positional_encoding", self._generate_sinusoidal_encoding(max_len, d_model))
        
        self.transformer = Transformer(
            d_model=d_model, nhead=nhead, num_encoder_layers=num_layers, 
            num_decoder_layers=num_layers, dim_feedforward=d_ff, activation="gelu", batch_first=True, 
            dropout=dropout 
        ).to(device)

        self.fc_out = nn.Linear(d_model, vocab_size).to(device)
    
    def _generate_sinusoidal_encoding(self, max_len, d_model):
        position = torch.arange(max_len, device=self.device).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, device=self.device) * -(math.log(10000.0) / d_model))
        encoding = torch.zeros(max_len, d_model, device=self.device)
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        return encoding

    def _generate_square_subsequent_mask(self, sz):
        mask = torch.tril(torch.ones(sz, sz))
        return torch.log(mask).to(self.device)
        
    def forward(self, src, tgt, feature_mask, tgt_is_causal=False, src_mask=None, tgt_mask=None, 
                src_key_padding_mask=None, tgt_key_padding_mask=None):
        src, tgt = src.to(self.device), tgt.to(self.device)
        
        # Compute embeddings and apply dropout
        src_emb = self.embedding(src)
        tgt_emb = self.embedding(tgt)

        # Apply positional encoding
        src_emb += (1 - feature_mask[:, :src.shape[1], None]) * self.positional_encoding[:src.shape[1], :]
        tgt_emb += self.positional_encoding[:tgt.shape[1], :].unsqueeze(0)
        
        src_emb = self.dropout(src_emb)
        tgt_emb = self.dropout(tgt_emb)

        # Compute key padding masks if not provided
        if src_key_padding_mask is None:
            src_key_padding_mask = (src == self.pad_token_id) 
        if tgt_key_padding_mask is None:
            tgt_key_padding_mask = (tgt == self.pad_token_id)
        if tgt_is_causal: 
            if tgt_mask is None: 
                tgt_mask = self._generate_square_subsequent_mask(tgt.shape[1])
            else:
                tgt_mask += self._generate_square_subsequent_mask(tgt.shape[1])

        # Transformer forward pass
        transformer_output = self.transformer(
            src_emb, tgt_emb, 
            src_mask=src_mask, 
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            tgt_is_causal=tgt_is_causal
        )
        output = self.fc_out(transformer_output)
        return output
    
    def generate(self, src, feature_mask, max_len=32, beam_size=5):
        self.eval()
        src, feature_mask = src.to(self.device), feature_mask.to(self.device)

        # Initialize beams: (sequence, log probability)
        beams = torch.full((1, 1), self.start_token_id, device=self.device)  
        beam_scores = torch.zeros(1, device=self.device)  # Log probabilities

        completed_sequences = []  # Store completed sequences

        for _ in range(max_len):
            # Expand `src` to match the number of beams
            src_expanded = src.expand(beams.shape[0], -1)
            feature_mask_expanded = feature_mask.expand(beams.shape[0], -1)

            # Forward pass on all beams at once
            out = self.forward(src_expanded, beams, feature_mask_expanded)  
            logits = out[:, -1, :]  # Get last-step logits (shape: [beams, vocab_size])
            log_probs = torch.log_softmax(logits, dim=-1)  # Convert logits to log-probabilities

            # Get top-k candidates for each beam (shape: [beams, beam_size])
            topk_log_probs, topk_ids = log_probs.topk(beam_size, dim=-1)

            # Compute new scores by adding log probabilities (broadcasted)
            expanded_scores = beam_scores.unsqueeze(1) + topk_log_probs  # Shape: [beams, beam_size]
            expanded_scores = expanded_scores.view(-1)  # Flatten to [beams * beam_size]

            # Get top-k overall candidates
            topk_scores, topk_indices = expanded_scores.topk(beam_size)

            # Convert flat indices to beam/token indices
            beam_indices = topk_indices // beam_size  # Which original beam did this come from?
            token_indices = topk_indices % beam_size  # Which token was selected?

            # Append new tokens to sequences
            new_beams = torch.cat([beams[beam_indices], topk_ids.view(-1, 1)[topk_indices]], dim=-1)

            # Check for completed sequences
            eos_mask = (new_beams[:, -1] == self.end_token_id)
            if eos_mask.any():
                for i in range(beam_size):
                    if eos_mask[i]:
                        completed_sequences.append((new_beams[i], topk_scores[i]))

            # Keep only unfinished sequences
            beams = new_beams[~eos_mask]
            beam_scores = topk_scores[~eos_mask]

            # If all sequences finished, stop early
            if len(beams) == 0 or len(completed_sequences) >= beam_size:
                break

        # Choose the best sequence from completed ones
        if completed_sequences:
            best_sequence = max(completed_sequences, key=lambda x: x[1])[0]
        else:
            best_sequence = beams[0]  # If no sequence completed, return best unfinished one

        return best_sequence
    
    def greedy_batch_decode(self, src, feature_mask, max_len=32):
        self.eval()
        batch_size = src.shape[0]
        src, feature_mask = src.to(self.device), feature_mask.to(self.device)
        outputs = torch.full((batch_size, 1), self.start_token_id, device=self.device)
        ended = torch.zeros(batch_size, dtype=torch.bool, device=self.device)
        
        for _ in range(max_len):
            out = self.forward(src, outputs, feature_mask)
            next_tokens = out[:, -1, :].argmax(dim=-1, keepdim=True)
            outputs = torch.cat([outputs, next_tokens], dim=1)
            ended |= (next_tokens.squeeze() == self.end_token_id)
            if ended.all():
                break
        
        return outputs

In [3]:
# Data preparation

def prepare_data(data, val_size=1000, train_size=2000, test_ratio=0.4):
    paradigms = list(data.items())
    random.shuffle(paradigms)
    
    val_set = paradigms[:val_size]
    train_test_set = paradigms[val_size:val_size + train_size]
    
    train_size = int((1 - test_ratio) * len(train_test_set))
    train_set = train_test_set[:train_size]
    test_set = train_test_set[train_size:]
    
    return train_set, test_set, val_set

def generate_examples(paradigm):
    lemma = list(paradigm[0])  # Convert lemma to list of characters
    forms = paradigm[1]
    examples = []
    
    for tag, form in forms.items():
        src = ['<s>', f'<{tag}>'] + lemma + ['</s>']
        tgt = ['<s>'] + list(form) + ['</s>']
        examples.append((src, tgt))
    
    return examples

# Load dataset
with open("/home/minhk/Assignments/CSCI 5801/project/data/processed/eng_v.json", "r") as f:
    data = json.load(f)

train_set, test_set, val_set = prepare_data(data, val_size=1000, train_size=1000, test_ratio=0.1)

train_examples = [ex for paradigm in train_set for ex in generate_examples(paradigm)]
test_examples = [ex for paradigm in test_set for ex in generate_examples(paradigm)]
val_examples = [ex for paradigm in val_set for ex in generate_examples(paradigm)]

print("Train examples:", train_examples[:5])  # Show first 5 examples
print("Test examples:", test_examples[:5])
print("Validation examples:", val_examples[:5])

Train examples: [(['<s>', '<PRS>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', '</s>'], ['<s>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', '</s>']), (['<s>', '<3SG>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', '</s>'], ['<s>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', 's', '</s>']), (['<s>', '<PST>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', '</s>'], ['<s>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', 'd', '</s>']), (['<s>', '<PRS.PTCP>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', '</s>'], ['<s>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'i', 'n', 'g', '</s>']), (['<s>', '<PST.PTCP>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', '</s>'], ['<s>', 'c', 'a', 'r', 'i', 'c', 'a', 't', 'u', 'r', 'e', 'd', '</s>'])]
Test examples: [(['<s>', '<PRS>', 'z', 'o', 'o', 'm', 'o', 'r', 'p', 'h', 'i', 'z', 'e', '</s>'], ['<s>', 'z', 'o', 'o', 'm', 'o', 'r', 'p', 'h', 'i', 'z', 'e', '</s>']), (['<s>', '<3SG>', 'z', 'o', 'o', 'm', 'o', 'r', 'p',

In [4]:
import string
from bisect import bisect_left
def align_sequences(src, tgt):
    # Scoring scheme
    match_score = 1
    mismatch_penalty = -3
    gap_penalty = -2
    
    m, n = len(src), len(tgt)
    dp = [[0] * (n + 1) for _ in range(m + 1)]
    
    # Initialize the DP table
    for i in range(1, m + 1):
        dp[i][0] = i * gap_penalty
    for j in range(1, n + 1):
        dp[0][j] = j * gap_penalty
    
    # Fill the DP table
    for i in range(1, m + 1):
        for j in range(1, n + 1):
            if src[i - 1] == tgt[j - 1]:
                score = match_score
            else:
                score = mismatch_penalty
            dp[i][j] = max(
                dp[i - 1][j - 1] + score,  # Diagonal (match/mismatch)
                dp[i - 1][j] + gap_penalty,  # Gap in target
                dp[i][j - 1] + gap_penalty   # Gap in source
            )
    
    # Backtrack to find the aligned sequences
    aligned_src, aligned_tgt = [], []
    i, j = m, n
    while i > 0 and j > 0:
        if src[i - 1] == tgt[j - 1] or dp[i][j] == dp[i - 1][j - 1] + mismatch_penalty:
            aligned_src.append(src[i - 1])
            aligned_tgt.append(tgt[j - 1])
            i -= 1
            j -= 1
        elif dp[i][j] == dp[i - 1][j] + gap_penalty:
            aligned_src.append(src[i - 1])
            aligned_tgt.append('#')
            i -= 1
        else:
            aligned_src.append('#')
            aligned_tgt.append(tgt[j - 1])
            j -= 1
    
    # Add remaining characters
    while i > 0:
        aligned_src.append(src[i - 1])
        aligned_tgt.append('#')
        i -= 1
    while j > 0:
        aligned_src.append('#')
        aligned_tgt.append(tgt[j - 1])
        j -= 1
    
    # Reverse to get the correct order
    aligned_src.reverse()
    aligned_tgt.reverse()
    
    return aligned_src, aligned_tgt

def hallucinate_trigram(src, tgt):
    """
    Hallucinates an example from an existing one by replacing aligned trigrams with random characters.
    """
    # Align the sequences first
    tag = src[1]  # The morphological tag
    src = src[2:-1]  # Remove <s>, morphological tag, and </s>
    tgt = tgt[1:-1]  # Remove <s> and </s>
    aligned_src, aligned_tgt = align_sequences(src, tgt)
    n = len(aligned_src)

    first_non_gap_src = max(
        min(idx for idx, char in enumerate(aligned_src) if char != '#'),
        min(idx for idx, char in enumerate(aligned_tgt) if char != '#')
    )
    last_non_gap_src = min(
        max(idx for idx, char in enumerate(aligned_src) if char != '#'),
        max(idx for idx, char in enumerate(aligned_tgt) if char != '#')
    )
    
    # Find all valid trigrams (3 consecutive characters) that do not touch the ends
    trigram_indices = []
    for i in range(first_non_gap_src + 1, last_non_gap_src - 2):
        if aligned_src[i] == aligned_tgt[i] and aligned_src[i + 1] == aligned_tgt[i + 1] and aligned_src[i + 2] == aligned_tgt[i + 2]:
            trigram_indices.append(i)
    # Shuffle the indices to randomize the hallucination
    random.shuffle(trigram_indices)

    # If alignment does not work, try to replace a single aligned character
    if len(trigram_indices) == 0:
        for i in range(first_non_gap_src, last_non_gap_src - 1):
            if aligned_src[i] == aligned_tgt[i]:
                vowels = set('aeiouy')
                char = aligned_src[i]
                if char in vowels:
                    new_char = random.choice('aeiouy')
                else:
                    new_char = random.choice([c for c in string.ascii_lowercase if c not in 'aeiouy'])
                aligned_src[i] = new_char
                aligned_tgt[i] = new_char
                break
        else:
            return None, None
    
    new_src = aligned_src
    new_tgt = aligned_tgt
    replaced = [False] * n
    
    # Randomly hallucinate some of these trigrams
    for i in trigram_indices:
        if not replaced[i] and not replaced[i + 1] and not replaced[i + 2]:
            # Replace the trigram with random characters
            for j in range(i, i + 3):
                new_char = random.choice(string.ascii_lowercase)
                new_src[j] = new_char
                new_tgt[j] = new_char
            replaced[i] = True
            replaced[i + 1] = True
            replaced[i + 2] = True
    
    new_src = ['<s>', tag] + [char for char in new_src if char != '#'] + ['</s>']
    new_tgt = ['<s>'] + [char for char in new_tgt if char != '#'] + ['</s>']

    return new_src, new_tgt

orig_train_examples = copy.deepcopy(train_examples)
def hallucinate_data(examples, hallucination_ratio=0.2, probs=None):
    """
    Hallucinates data by replacing aligned trigrams with random characters.
    Considers the probability of choosing examples for hallucination based on `probs`.
    """
    new_examples = []
    if probs is None:
        probs = [1] * len(examples)  # Uniform probability if none provided
    
    # Normalize probabilities to sum to 1
    total_prob = sum(probs) + 1
    normalized_probs = [(p + 1 / len(probs)) / total_prob for p in probs]
    
    choices = random.choices(examples, weights=normalized_probs, k=int(len(examples) * hallucination_ratio * 1.1))
    for example in choices:
        if len(new_examples) >= len(examples) * hallucination_ratio:
            break
        # Choose an example based on the normalized probabilities
        src, tgt = hallucinate_trigram(example[0], example[1])
        if src is not None and tgt is not None:
            new_examples.append((src, tgt))
    return new_examples

In [5]:
# Test the hallucination function
for _ in range(5):
    # Randomly select an example from the training set
    example = random.choice(orig_train_examples)
    src, tgt = hallucinate_trigram(example[0], example[1])
    print("Original:", "".join(example[0]), "".join(example[1]))
    if src is not None and tgt is not None:
        print("Hallucinated:", "".join(src), "".join(tgt))
    print()

Original: <s><3SG>reskim</s> <s>reskims</s>
Hallucinated: <s><3SG>rgluim</s> <s>rgluims</s>

Original: <s><PRS.PTCP>beshape</s> <s>beshaping</s>
Hallucinated: <s><PRS.PTCP>bokmape</s> <s>bokmaping</s>

Original: <s><PST.PTCP>cassate</s> <s>cassated</s>
Hallucinated: <s><PST.PTCP>casooee</s> <s>casooeed</s>

Original: <s><PST>aviate</s> <s>aviated</s>
Hallucinated: <s><PST>avmswe</s> <s>avmswed</s>

Original: <s><PST.PTCP>uncause</s> <s>uncaused</s>
Hallucinated: <s><PST.PTCP>ujsquse</s> <s>ujsqused</s>



In [6]:
class InverseSquareLRWithWarmup(LRScheduler):
    """
    Implements an inverse square learning rate scheduler with warmup steps.
    
    During warmup, the learning rate increases linearly from init_lr to max_lr.
    After warmup, the learning rate decreases according to the inverse square of the step number:
    lr = max_lr * (warmup_steps / step)^2 for step > warmup_steps
    
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        init_lr (float): Initial learning rate during warmup phase. Default: 0.0
        max_lr (float): Maximum learning rate after warmup phase. Default: 0.1
        warmup_steps (int): Number of warmup steps. Default: 1000
        last_epoch (int): The index of the last epoch. Default: -1
    """
    
    def __init__(self, optimizer, init_lr=0.0, max_lr=0.001, warmup_steps=1000, last_epoch=-1):
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.warmup_steps = warmup_steps
        super(InverseSquareLRWithWarmup, self).__init__(optimizer, last_epoch)
    
    def is_warmed_up(self):
        return self.last_epoch >= self.warmup_steps
        
    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase
            alpha = self.last_epoch / self.warmup_steps
            return [self.init_lr + alpha * (self.max_lr - self.init_lr) for _ in self.base_lrs]
        else:
            # Inverse square decay phase
            decay_factor = math.sqrt(self.warmup_steps / self.last_epoch)
            return [self.max_lr * decay_factor for _ in self.base_lrs]
            
    def _get_closed_form_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear warmup phase
            alpha = self.last_epoch / self.warmup_steps
            return [self.init_lr + alpha * (self.max_lr - self.init_lr) for _ in self.base_lrs]
        else:
            # Inverse square decay phase
            decay_factor = (self.warmup_steps / self.last_epoch) ** 2
            return [self.max_lr * decay_factor for _ in self.base_lrs]

In [7]:
# Sampler for scheduled learning, gradually replaces ground truth (teacher forcing) with model input

class ScheduledSampler():
    def __init__(self, base_rate=0.5, warmup_steps=1000):
        self.base_rate = base_rate
        self.warmup_steps = warmup_steps
        self.step_count = 0        
        self.sampling_rate = 1
    
    def step(self):
        self.step_count += 1
        if self.step_count > self.warmup_steps:
            self.sampling_rate = self.base_rate + (1 - self.base_rate) * math.sqrt(self.warmup_steps / self.step_count)

    def sample(self, logits, truth_ids):
        """
        Selects truth_ids with probability `sampling_rate`, 
        otherwise samples using Gumbel noise.
        """
        batch_size, seq_len, vocab_size = logits.shape
        
        # Decide per-token whether to take ground truth (1) or Gumbel sample (0)
        mask = torch.bernoulli(torch.full((batch_size, seq_len), self.sampling_rate, device=logits.device, dtype=float)).bool()
        
        # Gumbel-sampled predictions
        gumbel_preds = self._gumbel_sample(logits)
        
        # Use ground truth where mask == True, else use gumbel_preds
        return torch.where(mask, truth_ids, gumbel_preds)
    
    def _gumbel_sample(self, logits):
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))  # Generate Gumbel noise
        return (logits + gumbel_noise).argmax(dim=-1)  # Apply Gumbel noise and take the argmax

In [14]:
# Tokenizer
def tokenize(sequence, char_to_idx):
    return [char_to_idx[char] for char in sequence]

# Build character vocabulary
all_chars = set()
for word, inflect in data.items():
    all_chars.update(word)
    for tag, forms in inflect.items():
        all_chars.add(f"<{tag}>")
        all_chars.update(forms)
print(all_chars)
all_alphabet_chars = {char for char in all_chars if not (char.startswith('<') and char.endswith('>'))}
char_to_idx = {char: i for i, char in enumerate(sorted(all_chars), start=3)}  # Reserve 0, 1, 2 for special tokens
char_to_idx['<pad>'] = 0
char_to_idx['<s>'] = 1
char_to_idx['</s>'] = 2
idx_to_char = {
    i: char for char, i in char_to_idx.items()
}
vocab_size = len(char_to_idx)
max_len = 32

device = "cuda" if torch.cuda.is_available() else "cpu"
model = CharTransformer(vocab_size, device=device, max_len=max_len)

def pad_sequence(sequence, max_len, pad_token='<pad>'):
    return sequence + [pad_token] * (max_len - len(sequence))

def create_feature_mask(sequence):
    """Create a feature mask where tags (enclosed in < >) are marked as 1, else 0."""
    return torch.tensor([1 if char.startswith('<') and char.endswith('>') else 0 for char in sequence], device=device)

def create_padding_mask(sequence, pad_token='<pad>'):
    """Create a padding mask where padding tokens are marked as True (to be ignored)."""
    return (sequence == pad_token)

def train_model(model, train_examples, test_examples, epochs=1000, batch_size=256, 
                patience=20, 
                hallucination_ratio=0.2, hallucination_refresh_rate=5, stop_hallucinating_after=0.9):
    optimizer = optim.AdamW(model.parameters(), betas=(0.99, 0.98))
    scheduler = InverseSquareLRWithWarmup(optimizer, init_lr=1e-5, max_lr=1e-3, warmup_steps=4000)
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1, reduction='none')
    # sampler = ScheduledSampler(base_rate=0.5, warmup_steps=4000)

    orig_train_examples = copy.deepcopy(train_examples)
    u_train_examples = orig_train_examples
    order = [i for i in range(len(train_examples))]
    # losses = [0] * len(train_examples)

    pad_token = char_to_idx['<pad>']
    best_test_loss = float('inf')  # Initialize the best test loss to a very large value
    best_model_state = copy.deepcopy(model.state_dict())  # Store best model parameters
    
    for epoch in range(epochs):
        if epoch == int(stop_hallucinating_after * epochs):
            # Stop hallucinating after a certain epoch
            u_train_examples = orig_train_examples
            order = [i for i in range(len(u_train_examples))]
            #  = [0] * len(u_train_examples)
        elif epoch % hallucination_refresh_rate == 0 and scheduler.is_warmed_up() and epoch < int(stop_hallucinating_after * epochs):
            u_train_examples = orig_train_examples + hallucinate_data(orig_train_examples, hallucination_ratio) # losses[:len(orig_train_examples)])
            order = [i for i in range(len(u_train_examples))]
            # losses = [0] * len(u_train_examples)
            
        model.train()
        total_loss = 0
        random.shuffle(order)
        train_examples = [u_train_examples[order[i]] for i in range(len(u_train_examples))]
        print(len(u_train_examples))
        
        for i in range(0, len(train_examples), batch_size):
            batch = train_examples[i:i+batch_size]
            src_batch, tgt_batch = zip(*batch)

            max_batch_len = max(max(len(s) for s in src_batch), max(len(t) for t in tgt_batch))
            
            # Pad sequences to the maximum length (max_len) in the batch
            src_padded = [pad_sequence(seq, max_batch_len) for seq in src_batch]
            tgt_padded = [pad_sequence(seq, max_batch_len) for seq in tgt_batch]
            
            # Convert padded sequences to tensors
            src_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in src_padded], device=device)
            tgt_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in tgt_padded], device=device)
            
            # Create the feature mask
            feature_mask_src = torch.stack([create_feature_mask(seq) for seq in src_padded], dim=0)

            # Shift target tensor for teacher forcing (the model will predict next token)
            tgt_input = tgt_tensor[:, :-1]  # Remove the last token (it's not used as input)
            tgt_expected = tgt_tensor[:, 1:]  # The target sequence for the loss is shifted by 1

            # First round of predictions (using teacher forcing)
            optimizer.zero_grad()
            output = model(src_tensor, tgt_input, feature_mask_src, tgt_is_causal=True)

            # Compute the loss
            loss = criterion(output.reshape(-1, vocab_size), tgt_expected.reshape(-1))

            # Apply padding mask to loss (ignores padded tokens)
            tgt_mask = (tgt_input != pad_token).float().view(-1)
            loss = loss * tgt_mask  # Element-wise multiply with the mask to ignore padding tokens

            # Reshape loss and mask back to batch and sequence dimensions
            loss = loss.view(-1, max_batch_len-1)
            tgt_mask = tgt_mask.view(-1, max_batch_len-1)

            # Calculate the average loss per word in the batch
            word_loss = loss.sum(dim=1) / tgt_mask.sum(dim=1)  # Average loss per word
            # if not torch.isnan(word_loss).any():
            #     for j in range(len(batch)):
            #        losses[order[i+j]] += word_loss[j].item()


            # Normalize the total loss (average over non-padding tokens)
            loss = loss.sum() / tgt_mask.sum()

            # Apply loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Clip gradients
            optimizer.step()
            scheduler.step()
            # sampler.step()  # Update the sampler (teacher forcing rate)

            total_loss += loss.item() * len(batch)

        print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_examples)}")
        
        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for i in range(0, len(test_examples), batch_size):
                batch = test_examples[i:i+batch_size]
                src_batch, tgt_batch = zip(*batch)

                max_batch_len = max(max(len(s) for s in src_batch), max(len(t) for t in tgt_batch))

                # Pad sequences to the maximum length (max_len) in the batch
                src_padded = [pad_sequence(seq, max_batch_len) for seq in src_batch]
                tgt_padded = [pad_sequence(seq, max_batch_len) for seq in tgt_batch]

                # Convert padded sequences to tensors
                src_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in src_padded], device=device)
                tgt_tensor = torch.tensor([tokenize(seq, char_to_idx) for seq in tgt_padded], device=device)

                # Create the feature mask
                feature_mask_src = torch.stack([create_feature_mask(seq) for seq in src_padded], dim=0)

                # Shift target tensor
                tgt_input = tgt_tensor[:, :-1]
                tgt_expected = tgt_tensor[:, 1:]

                # Forward pass
                output = model(src_tensor, tgt_input, feature_mask_src, tgt_is_causal=True)

                # Compute the loss
                loss = criterion(output.reshape(-1, vocab_size), tgt_expected.reshape(-1))

                # Apply padding mask to loss
                tgt_mask = (tgt_expected != pad_token).float().view(-1)
                loss = loss * tgt_mask  # Element-wise multiply with the mask to ignore padding tokens
                loss = loss.sum() / tgt_mask.sum()  # Normalize the loss (average over non-padding tokens)
                
                test_loss += loss.item() * len(batch)

        test_loss = test_loss / len(test_examples)
        print(f"Test Loss after Epoch {epoch+1}: {test_loss}")

        # Early stopping based on test set loss
        if test_loss < best_test_loss:
            best_test_loss = test_loss
            best_epoch = epoch
            patience_count = 0
            best_model_state = copy.deepcopy(model.state_dict())  # Save best model state
        else:
            patience_count += 1
            if patience_count == patience:
                patience_count = 0
                patience = int(patience * math.sqrt(2))
                # Rollback to best model state (undo last epoch update)
                model.load_state_dict(best_model_state)
                print(f"Rolling back to best model from epoch {best_epoch + 1}")
                print(f"Best test loss: {best_test_loss}")
        
    model.load_state_dict(best_model_state)


{'œ', '6', 'j', 'b', '*', 'k', 'g', '/', 'T', 'u', '<PST>', 'i', 'r', 't', '8', 'è', 's', '9', '<PRS.PTCP>', 'q', 'l', '4', ' ', 'w', 'é', 'v', 'æ', 'ä', '’', '<PRS>', 'o', '<3SG>', 'ö', 'û', 'n', '<PST.PTCP>', 'a', 'f', 'h', 'p', '-', 'c', 'R', "'", 'ê', 'ë', 'x', '0', '1', 'y', 'E', 'ï', 'd', 'e', 'z', 'm', 'U'}


In [9]:
train_set, test_set, val_set = prepare_data(data, val_size=100, train_size=5000, test_ratio=0.1)

In [10]:
def list_to_word(arr):
    return ''.join(arr[1:-1])

In [16]:
for train_size in range(1000, 1001, 100):
    for rat in [0.2, 0.5, 1]: #0.5, 1
        cur_train_set = train_set[:int(train_size*0.9)]
        cur_test_set = test_set[:int(train_size*0.1)]
        train_examples = [ex for paradigm in cur_train_set for ex in generate_examples(paradigm)]
        test_examples = [ex for paradigm in cur_test_set for ex in generate_examples(paradigm)]
        print(len(cur_train_set), len(train_examples), len(test_examples))
        model = CharTransformer(vocab_size, device=device, max_len=max_len)
        train_model(model, train_examples, test_examples, hallucination_ratio=rat, hallucination_refresh_rate=5, stop_hallucinating_after=0.9, batch_size=398, epochs=int(20000*398/(train_size*(1+rat*0.7)*5)))

        save_dir = "/home/minhk/Assignments/CSCI 5801/project/model/"
        filename=f"base_transformers_hallucinate_b_{rat}_{train_size}.pth"
        os.makedirs(save_dir, exist_ok=True)
        torch.save(model, save_dir+filename)

900 4500 500
4500




Epoch 1, Loss: 4.067972715059916
Test Loss after Epoch 1: 3.736744210243225
4500
Epoch 2, Loss: 3.7102280310524836
Test Loss after Epoch 2: 3.5517073907852175
4500
Epoch 3, Loss: 3.5272643058564928
Test Loss after Epoch 3: 3.445409442901611
4500
Epoch 4, Loss: 3.419533242755466
Test Loss after Epoch 4: 3.3505896120071412
4500
Epoch 5, Loss: 3.348769915368822
Test Loss after Epoch 5: 3.2837056665420534
4500
Epoch 6, Loss: 3.2970220947265627
Test Loss after Epoch 6: 3.231209547042847
4500
Epoch 7, Loss: 3.246455425898234
Test Loss after Epoch 7: 3.1726335897445677
4500
Epoch 8, Loss: 3.1757055989371406
Test Loss after Epoch 8: 3.0838602142333986
4500
Epoch 9, Loss: 3.065431721581353
Test Loss after Epoch 9: 2.9411994962692263
4500
Epoch 10, Loss: 2.9672621830834283
Test Loss after Epoch 10: 2.8451403369903563
4500
Epoch 11, Loss: 2.8526019938786824
Test Loss after Epoch 11: 2.8165315923690795
4500
Epoch 12, Loss: 2.769704253938463
Test Loss after Epoch 12: 2.7295224609375
4500
Epoch 13, 

In [None]:
# Select a random paradigm from validation set
random_paradigm = random.choice(val_set)

# Generate examples from the paradigm
generated_examples = generate_examples(random_paradigm)

for src, tgt in generated_examples: 
    print(list_to_word(src), list_to_word(tgt))
    feature_mask = create_feature_mask(src).unsqueeze(0)
    src_tokenized = torch.tensor(tokenize(src, char_to_idx), device=device).unsqueeze(0)
    gen = model.generate(src_tokenized, feature_mask, beam_size=5).squeeze(0)
    gen = list_to_word([idx_to_char[id.item()] for id in gen])
    print(gen)


<PRS>monumentalize monumentalize
monumentalize
<3SG>monumentalize monumentalizes
monumentalizes
<PST>monumentalize monumentalized
monumentalized
<PRS.PTCP>monumentalize monumentalizing
monumentalizing
<PST.PTCP>monumentalize monumentalized
monumentalized


In [23]:
word = 'pass'
tags = ["PRS", "3SG", "PST", "PRS.PTCP", "PST.PTCP"]

for tag in tags: 
    tokens = ["<s>", f"<{tag}>"] + list(word) + ["</s>"]
    feature_mask = create_feature_mask(tokens).unsqueeze(0)
    src_tokenized = torch.tensor(tokenize(tokens, char_to_idx), device=device).unsqueeze(0)
    gen = model.generate(src_tokenized, feature_mask).squeeze(0)
    gen = list_to_word([idx_to_char[id.item()] for id in gen])
    gen2 = model.greedy_batch_decode(src_tokenized, feature_mask).squeeze(0)
    gen2 = list_to_word([idx_to_char[id.item()] for id in gen2])
    print(f"{word} + {tag} = {gen}, {gen2}")
    

pass + PRS = psss, psss
pass + 3SG = pssses, pssses
pass + PST = pssssasass, psssspasasspas
pass + PRS.PTCP = psssing, psssing
pass + PST.PTCP = psssspasasspas, psssspasasspas


In [17]:
# Perform greedy batched decoding
val_set = data.items()

for train_size in range(1000, 1001, 100):
    for rat in [0.2, 0.5, 1]: # 0.5, 1
        save_dir = f"/home/minhk/Assignments/CSCI 5801/project/model/base_transformers_hallucinate_b_{rat}_{train_size}.pth" # oops the file names are all wrong

        # Load model
        model = torch.load(save_dir, map_location=device, weights_only=False)
        model.to(device)
        model.eval()  # Set model to evaluation mode

        print(f"Model {train_size} loaded")

        correct_predictions = 0
        total_predictions = 0
        batch_size = 32

        all_src = []
        all_tgt = []

        for paradigm in val_set:
            pairs = generate_examples(paradigm)
            for src, tgt in pairs:
                all_src.append(src)
                all_tgt.append(tgt)

        # Determine max sequence length in the batch for padding
        def pad_sequences(sequences, pad_token, max_len=None):
            """Pads sequences to the max length with the given pad token."""
            max_len = max(len(seq) for seq in sequences)
            return [seq + [pad_token] * (max_len - len(seq)) for seq in sequences]

        # Process in batches
        for i in range(0, len(all_src), batch_size):
            src_batch = all_src[i:i+batch_size]
            tgt_batch = all_tgt[i:i+batch_size]

            try:
                # Pad inputs
                src_batch_padded = pad_sequences(src_batch, "<pad>")
                tgt_batch_padded = pad_sequences(tgt_batch, "<pad>")  # Only for consistent tensor creation
                
                feature_mask = torch.stack([create_feature_mask(src) for src in src_batch_padded]).to(device)
                src_tokenized = torch.tensor([tokenize(seq, char_to_idx) for seq in src_batch_padded], device=device)
                
                gen_batch = model.greedy_batch_decode(src_tokenized, feature_mask)
                
                for gen, tgt in zip(gen_batch, tgt_batch):
                    gen_str = "".join([idx_to_char[idx.item()] for idx in gen])
                    
                    # Trim at first stop token
                    gen_str = gen_str.split("</s>")[0][3:]
                    
                    correct_predictions += (gen_str == list_to_word(tgt))
                    total_predictions += 1
            except:
                print('Error')
                pass

        # Print accuracy
        print(f"{train_size} {rat} {correct_predictions / total_predictions:.4f}")
        print(f"Model trained on {train_size} paradigms; hallucination ratio {rat}; {correct_predictions} predictions correct out of {total_predictions} total. Accuracy: {correct_predictions / total_predictions:.4f}")


Model 1000 loaded
1000 0.2 0.9135
Model trained on 1000 paradigms; hallucination ratio 0.2; 109147 predictions correct out of 119480 total. Accuracy: 0.9135
Model 1000 loaded
1000 0.5 0.9319
Model trained on 1000 paradigms; hallucination ratio 0.5; 111338 predictions correct out of 119480 total. Accuracy: 0.9319
Model 1000 loaded
1000 1 0.9293
Model trained on 1000 paradigms; hallucination ratio 1; 111036 predictions correct out of 119480 total. Accuracy: 0.9293


In [11]:
"""
correct_predictions = 0
total_predictions = 0

for paradigm in val_set:
    pairs = generate_examples(paradigm)
    for src, tgt in pairs: 
        feature_mask = create_feature_mask(src).unsqueeze(0)
        src_tokenized = torch.tensor(tokenize(src, char_to_idx), device=device).unsqueeze(0)
        gen = model.generate(src_tokenized, feature_mask).squeeze(0)
        gen = [idx_to_char[id.item()] for id in gen]
        correct_predictions += (gen == tgt)
        total_predictions += 1

print(f"{correct_predictions} predictions correct out of {total_predictions} total. Accuracy: {correct_predictions / total_predictions}")
""" 

'\ncorrect_predictions = 0\ntotal_predictions = 0\n\nfor paradigm in val_set:\n    pairs = generate_examples(paradigm)\n    for src, tgt in pairs: \n        feature_mask = create_feature_mask(src).unsqueeze(0)\n        src_tokenized = torch.tensor(tokenize(src, char_to_idx), device=device).unsqueeze(0)\n        gen = model.generate(src_tokenized, feature_mask).squeeze(0)\n        gen = [idx_to_char[id.item()] for id in gen]\n        correct_predictions += (gen == tgt)\n        total_predictions += 1\n\nprint(f"{correct_predictions} predictions correct out of {total_predictions} total. Accuracy: {correct_predictions / total_predictions}")\n'