In [1]:
from torch.utils.data import Dataset, DataLoader
from data.utils.data_utils import WordTokenizer
from collections import defaultdict
import torch
import math
from models import CBR_LM, Transformer_LM, LSTM_LM

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class NounPPDataset(Dataset):
    def __init__(self, nounpp_file, tokenizer):
        self.sentences = []
        self.conditions = []
        self.correct = []
        self.wrong = []
        self.encoded_sentences = []
        self.encoded_correct = []
        self.encoded_wrong = []

        with open(nounpp_file, "r") as f:
            for line in f:
                line = line.split()
                sentence = line[1:7]
                condition = " ".join(line[7:9])
                wrong = line[9]
                correct = line[6]
                
                # Encode each word individually
                encoded_sentence = [
                    tokenizer.stoi.get(word.lower(), 0)  # Get token ID directly
                    for word in sentence
                ]
                encoded_correct = tokenizer.stoi.get(correct.lower(), 0)
                encoded_wrong = tokenizer.stoi.get(wrong.lower(), 0)
                
                self.sentences.append(sentence)
                self.conditions.append(condition)
                self.correct.append(correct)
                self.wrong.append(wrong)
                self.encoded_sentences.append(encoded_sentence)  # Now a list of ints
                self.encoded_correct.append(encoded_correct)      # Now an int
                self.encoded_wrong.append(encoded_wrong)          # Now an int

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, idx):
        return {
            "sentence": self.sentences[idx],
            "encoded_sentence": torch.tensor(
                self.encoded_sentences[idx], dtype=torch.long
            ),  # [6] 
            "correct": self.correct[idx],
            "encoded_correct": self.encoded_correct[idx],  # scalar int
            "wrong": self.wrong[idx],
            "encoded_wrong": self.encoded_wrong[idx],      # scalar int
            "condition": self.conditions[idx],
        }
        


def collate_fn_nounpp(batch):
    sentences = [item["sentence"] for item in batch]
    encoded_sentences = torch.stack([item["encoded_sentence"] for item in batch])
    
    # Convert lists of scalars to tensors
    encoded_correct = torch.tensor([item["encoded_correct"] for item in batch], dtype=torch.long)
    encoded_wrong = torch.tensor([item["encoded_wrong"] for item in batch], dtype=torch.long)
    
    correct = [item["correct"] for item in batch]
    wrong = [item["wrong"] for item in batch]
    conditions = [item["condition"] for item in batch]

    return {
        "sentence": sentences,
        "encoded_sentence": encoded_sentences,  # [batch_size, 6]
        "correct": correct,
        "encoded_correct": encoded_correct,     # [batch_size]
        "wrong": wrong,
        "encoded_wrong": encoded_wrong,         # [batch_size]
        "condition": conditions,
    }

In [4]:
tokenizer = WordTokenizer.load("data/tokenizer/tokenizer.json")
test_dataset = NounPPDataset('tests/test_datasets/nounpp.txt', tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=1000, collate_fn=collate_fn_nounpp)


In [None]:
checkpoint_path = '/scratch2/mrenaudin/Hard-CBR-RNN/checkpoints/job_lstm_003/lightning_logs/version_1851441/checkpoints/epoch=49-step=309500.ckpt'
model=SimpleLSTM_LM.load_from_checkpoint(checkpoint_path)
model.eval()

# Eval function 1


In [None]:
def eval_with_attention_analysis(model, test_dataloader):
    """
    Efficient evaluation: Process first 5 words to predict verb at position 5.
    Works with both CBR-RNN and Transformer models.
    """
    condition_accuracies = defaultdict(int)
    condition_counts = defaultdict(int)
    sentence_details = []
    attention_by_condition = defaultdict(list) 
    
    model.model.eval()
    
    # Detect model type
    is_transformer = hasattr(model, 'token_embedding')
    
    with torch.no_grad():
        for batch in test_dataloader:
            written = batch["sentence"]
            sentence = batch["encoded_sentence"]  
            correct = batch["encoded_correct"]
            wrong = batch["encoded_wrong"]
            condition = batch["condition"]
            batch_size = sentence.size(0)
            limit = sentence.size(1)-1
            context = sentence[:, :limit]
            hidden = model.model.init_hidden(context.size(0), device = 'cpu')
            out,_=model.model(context,hidden)
            log_probs = torch.nn.functional.log_softmax(out[:, -1, :], dim=-1)
            correct_log_probs = log_probs[torch.arange(batch_size), correct]
            wrong_log_probs = log_probs[torch.arange(batch_size), wrong]
            correct_predictions = correct_log_probs >= wrong_log_probs

            for i in range(batch_size):
                cond = condition[i]
                pred = correct_predictions[i].item()
                
                condition_counts[cond] += 1
                condition_accuracies[cond] += pred

                sentence_details.append({
                    "sentence": written[i],
                    "condition": cond,
                    "correct_log_prob": correct_log_probs[i].item(),
                    "wrong_log_prob": wrong_log_probs[i].item(),
                    "model_prefers_correct": pred,
                })
                

    final_accuracies = {
        cond: condition_accuracies[cond] / condition_counts[cond]
        for cond in condition_accuracies
    }
    

    return final_accuracies, None, sentence_details



In [None]:
res = eval_with_attention_analysis(model, test_dataloader)

In [None]:
res[0]

# Eval function 2

In [None]:
init_sentence = " ".join(
    [
        "In service , the aircraft was operated by a crew of five and could accommodate either 30 paratroopers , 32 <unk> and 28 sitting casualties , or 50 fully equipped troops .",
        'He even speculated that technical classes might some day be held " for the better training of workmen in their several crafts and industries .',
        "After the War of the Holy League in 1537 against the Ottoman Empire , a truce between Venice and the Ottomans was created in 1539 .",
        'Moore says : " Tony and I had a good <unk> and off-screen relationship , we are two very different people , but we did share a sense of humour " .',
        "<unk> is also the basis for online games sold through licensed lotteries .",
    ]
)


def feed_input(model, hidden, w):
    if w not in tokenizer.stoi:
        print(f"Warning: '{w}' not in vocabulary, using <unk> token")
        w = '<unk>'
    inp = torch.autograd.Variable(
        torch.LongTensor([[tokenizer.stoi[w]]]).to('cpu')
    )
    out, hidden= model.model(inp, hidden)
    return out, hidden


def feed_sentence(model, h, sentence):
    outs = []
    for w in sentence:
        out, h = feed_input(model, h, w)
        outs.append(torch.nn.functional.log_softmax(out[0]).unsqueeze(0))
    return outs, h 


# evaluation function
def eval(model, test_dataloader, init_sentence):
    condition_accuracies = defaultdict(int)
    condition_counts = defaultdict(int)
    correct_pred = 0
    sentence_details = []

    model.model.eval()

    hidden = model.model.init_hidden(1, device='cpu')
    init_out, init_h= feed_sentence(model, hidden, init_sentence.split(" "))
    with torch.no_grad():
        for batch in test_dataloader:
            out = None
            written = batch["sentence"]
            sentence = batch["encoded_sentence"]
            correct = batch["encoded_correct"]
            wrong = batch["encoded_wrong"]
            condition = batch["condition"]
            batch_size = sentence.size(0)
            hidden = (
                init_h[0].expand(-1, batch_size, -1).contiguous(),
                init_h[1].expand(-1, batch_size, -1).contiguous(),
            )

            for w in range(sentence.shape[1] - 1):

                word = torch.autograd.Variable(sentence[:, w:w+1])#.unsqueeze(0))
                out, hidden = model.model(word, hidden)
            print(out.shape)
            log_probs = torch.nn.functional.log_softmax(out.squeeze(1), dim=-1)
            correct_log_probs = log_probs[torch.arange(batch_size), correct]
            wrong_log_probs = log_probs[torch.arange(batch_size), wrong]
            correct_predictions = correct_log_probs >= wrong_log_probs
            for i in range(batch_size):
                cond = condition[i]
                pred = correct_predictions[i].item()
                condition_counts[cond] += 1
                condition_accuracies[cond] += pred

                sentence_details.append(
                    {
                        "sentence": written[i],
                        "condition": condition[i],
                        "correct_log_prob": correct_log_probs[i],
                        "wrong_log_prob": wrong_log_probs[i],
                        "model_prefers_correct": pred,
                    }
                )
    final_accuracies = {
        cond: condition_accuracies[cond] / condition_counts[cond]
        for cond in condition_accuracies
    }

    return final_accuracies

In [None]:
res = eval(model, test_dataloader, init_sentence)

In [None]:
res

In [None]:
res

In [None]:
# No init_sentence needed anymore!

def eval_no_priming(model, test_dataloader):
    """
    Evaluation without priming - fresh hidden state for each batch
    """
    condition_accuracies = defaultdict(int)
    condition_counts = defaultdict(int)
    sentence_details = []
    model.model.eval()
    
    with torch.no_grad():
        for batch in test_dataloader:
            written = batch["sentence"]
            sentence = batch["encoded_sentence"]
            correct = batch["encoded_correct"]
            wrong = batch["encoded_wrong"]
            condition = batch["condition"]
            batch_size = sentence.size(0)
            
            # Initialize fresh hidden state for this batch (no priming!)
            hidden = model.model.init_hidden(batch_size, device='cpu')
            
            # Process sentence word-by-word
            for w in range(sentence.shape[1] - 1):
                word = sentence[:, w].unsqueeze(1)  # [batch_size, 1]
                out, hidden = model.model(word, hidden)
            
            # Get predictions from final output
            log_probs = torch.nn.functional.log_softmax(out.squeeze(1), dim=-1)  # [batch_size, vocab_size]
            correct_log_probs = log_probs[torch.arange(batch_size), correct]
            wrong_log_probs = log_probs[torch.arange(batch_size), wrong]
            correct_predictions = correct_log_probs >= wrong_log_probs
            
            for i in range(batch_size):
                cond = condition[i]
                pred = correct_predictions[i].item()
                condition_counts[cond] += 1
                condition_accuracies[cond] += pred
                sentence_details.append({
                    "sentence": written[i],
                    "condition": cond,
                    "correct_log_prob": correct_log_probs[i].item(),
                    "wrong_log_prob": wrong_log_probs[i].item(),
                    "model_prefers_correct": pred,
                })
    
    final_accuracies = {
        cond: condition_accuracies[cond] / condition_counts[cond]
        for cond in condition_accuracies
    }
    
    return final_accuracies, sentence_details

In [None]:
res = eval_no_priming(model, test_dataloader)

In [None]:
res

-> so the difference in performance is only explained by priming 

# Testing Transformers with priming 

In [None]:
checkpoint_path = 'checkpoints/job_transformer_2_003/lightning_logs/version_1851787/checkpoints/epoch=49-step=309500.ckpt'
model=SimpleTransformerLM.load_from_checkpoint(checkpoint_path)
model.eval()

In [None]:
def eval_transformer(model, test_dataloader, init_sentence):
    """
    Evaluation with priming for Transformer model
    """
    condition_accuracies = defaultdict(int)
    condition_counts = defaultdict(int)
    sentence_details = []
    model.model.eval()
    max_seq_len = model.model.pos_embedding.num_embeddings  # 64

    # Tokenize and encode init_sentence once
    init_tokens = init_sentence.split(" ")
    init_encoded = torch.LongTensor([[
        tokenizer.stoi.get(w, tokenizer.stoi.get('<unk>', 0)) for w in init_tokens
    ]])  # [1, init_len]
    
    with torch.no_grad():
        for batch in test_dataloader:
            written = batch["sentence"]
            sentence = batch["encoded_sentence"]  # [batch_size, seq_len]
            correct = batch["encoded_correct"]
            wrong = batch["encoded_wrong"]
            condition = batch["condition"]
            batch_size = sentence.size(0)
            device = sentence.device
            seq_len = sentence.size(1)
            
            # Move init_encoded to correct device
            init_encoded = init_encoded.to(device)
            max_init_len = max_seq_len - (seq_len - 1)
            if init_encoded.size(1) > max_init_len:
                # Take the last max_init_len tokens (most recent context)
                init_encoded = init_encoded[:, -max_init_len:]
            
            # Expand init_encoded to batch size
            init_batch = init_encoded.expand(batch_size, -1)  # [batch_size, init_len]
            
            # Concatenate init_sentence with test sentence
            # This gives the model context before processing the test sentence
            full_sequence = torch.cat([init_batch, sentence[:, :-1]], dim=1)  # [batch_size, init_len + seq_len - 1]
            
            # Forward pass through transformer
            # Transpose for model: [batch_size, seq_len] -> [seq_len, batch_size]
            out = model.model(full_sequence.transpose(0, 1))  # [seq_len, batch_size, vocab_size]
            
            # Get output at the position corresponding to the last word of the test sentence
            # We want the prediction after seeing init_sentence + test_sentence[:-1]
            pred_position = -1  # Last position
            log_probs = torch.nn.functional.log_softmax(out[pred_position], dim=-1)  # [batch_size, vocab_size]
            
            correct_log_probs = log_probs[torch.arange(batch_size), correct]
            wrong_log_probs = log_probs[torch.arange(batch_size), wrong]
            correct_predictions = correct_log_probs >= wrong_log_probs
            
            for i in range(batch_size):
                cond = condition[i]
                pred = correct_predictions[i].item()
                condition_counts[cond] += 1
                condition_accuracies[cond] += pred
                sentence_details.append({
                    "sentence": written[i],
                    "condition": cond,
                    "correct_log_prob": correct_log_probs[i].item(),
                    "wrong_log_prob": wrong_log_probs[i].item(),
                    "model_prefers_correct": pred,
                })
    
    final_accuracies = {
        cond: condition_accuracies[cond] / condition_counts[cond]
        for cond in condition_accuracies
    }
    
    return final_accuracies, sentence_details

In [7]:
# tokenizer = WordTokenizer.load("tokenizer.json")
test_dataset = NounPPDataset('tests/test_datasets/nounpp.txt', tokenizer)
test_dataloader = DataLoader(test_dataset, batch_size=1000, collate_fn=collate_fn_nounpp)
init_sentence = " ".join(
    [
        "In service , the aircraft was operated by a crew of five and could accommodate either 30 paratroopers , 32 <unk> and 28 sitting casualties , or 50 fully equipped troops .",
        'He even speculated that technical classes might some day be held " for the better training of workmen in their several crafts and industries .',
        "After the War of the Holy League in 1537 against the Ottoman Empire , a truce between Venice and the Ottomans was created in 1539 .",
        'Moore says : " Tony and I had a good <unk> and off-screen relationship , we are two very different people , but we did share a sense of humour " .',
        "<unk> is also the basis for online games sold through licensed lotteries .",
    ]
)

In [None]:


res = eval_transformer(model, test_dataloader, init_sentence)

In [None]:
res[0]

In [None]:
def eval_transformer_simple(model, test_dataloader):
    """
    Simple evaluation for Transformer - no priming needed
    Transformers don't benefit from priming the same way LSTMs do
    """
    condition_accuracies = defaultdict(int)
    condition_counts = defaultdict(int)
    sentence_details = []
    model.model.eval()
    
    with torch.no_grad():
        for batch in test_dataloader:
            written = batch["sentence"]
            sentence = batch["encoded_sentence"]  # [batch_size, seq_len]
            correct = batch["encoded_correct"]
            wrong = batch["encoded_wrong"]
            condition = batch["condition"]
            batch_size = sentence.size(0)
            
            # Process test sentence directly
            context = sentence[:, :-1]  # [batch_size, seq_len-1]
            
            # Forward pass through transformer
            out = model.model(context.transpose(0, 1))  # [seq_len-1, batch_size, vocab_size]
            
            # Get predictions from last position
            log_probs = torch.nn.functional.log_softmax(out[-1], dim=-1)  # [batch_size, vocab_size]
            
            correct_log_probs = log_probs[torch.arange(batch_size), correct]
            wrong_log_probs = log_probs[torch.arange(batch_size), wrong]
            correct_predictions = correct_log_probs >= wrong_log_probs
            
            for i in range(batch_size):
                cond = condition[i]
                pred = correct_predictions[i].item()
                condition_counts[cond] += 1
                condition_accuracies[cond] += pred
                sentence_details.append({
                    "sentence": written[i],
                    "condition": cond,
                    "correct_log_prob": correct_log_probs[i].item(),
                    "wrong_log_prob": wrong_log_probs[i].item(),
                    "model_prefers_correct": pred,
                })
    
    final_accuracies = {
        cond: condition_accuracies[cond] / condition_counts[cond]
        for cond in condition_accuracies
    }
    
    return final_accuracies, sentence_details

In [None]:
res = eval_transformer_simple(model, test_dataloader)

In [None]:
res[0]

In [None]:
from tensorboard.backend.event_processing import event_accumulator
import pandas as pd
import matplotlib.pyplot as plt

# Load the tensorboard log
log_path = '/scratch2/mrenaudin/Hard-CBR-RNN/checkpoints/job_transformer_2_003/lightning_logs/version_1851787/events.out.tfevents.1758631281.jzxh004.4137051.0'

ea = event_accumulator.EventAccumulator(log_path)
ea.Reload()

print("=== Available Tags ===")
print("Scalars:", ea.Tags()['scalars'])
print()

# Extract key metrics
def get_metric_data(ea, tag):
    try:
        events = ea.Scalars(tag)
        steps = [e.step for e in events]
        values = [e.value for e in events]
        return steps, values
    except:
        return None, None

# Get training and validation metrics
metrics_to_check = [
    'train_loss', 'train_ppl', 'train_accuracy',
    'val_loss', 'val_ppl', 'val_accuracy',
    'learning_rate', 'grad_norm',
    'train_entropy_norm', 'val_entropy_norm',
    'train_confidence', 'val_confidence',
    'train_repetition_ratio', 'val_repetition_ratio'
]

print("=== Training Summary ===\n")
for metric in metrics_to_check:
    steps, values = get_metric_data(ea, metric)
    if steps:
        print(f"{metric:25s}: Start={values[0]:.4f}, Final={values[-1]:.4f}, Best={min(values) if 'loss' in metric else max(values):.4f}, Steps={len(steps)}")

# Plot key metrics
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle('Transformer Training Analysis', fontsize=16)

# 1. Loss curves
ax = axes[0, 0]
steps, train_loss = get_metric_data(ea, 'train_loss')
if steps:
    ax.plot(steps, train_loss, label='Train Loss', alpha=0.7)
steps, val_loss = get_metric_data(ea, 'val_loss')
if steps:
    ax.plot(steps, val_loss, label='Val Loss', linewidth=2)
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Loss Curves')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Perplexity
ax = axes[0, 1]
steps, train_ppl = get_metric_data(ea, 'train_ppl')
if steps:
    ax.plot(steps, train_ppl, label='Train PPL', alpha=0.7)
steps, val_ppl = get_metric_data(ea, 'val_ppl')
if steps:
    ax.plot(steps, val_ppl, label='Val PPL', linewidth=2)
ax.set_xlabel('Step')
ax.set_ylabel('Perplexity')
ax.set_title('Perplexity')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_ylim(0, min(200, max(train_ppl) if train_ppl else 200))

# 3. Learning rate
ax = axes[0, 2]
steps, lr = get_metric_data(ea, 'learning_rate')
if steps:
    ax.plot(steps, lr)
ax.set_xlabel('Step')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
ax.grid(True, alpha=0.3)

# 4. Gradient norm
ax = axes[1, 0]
steps, grad_norm = get_metric_data(ea, 'grad_norm')
if steps:
    ax.plot(steps, grad_norm)
ax.set_xlabel('Step')
ax.set_ylabel('Gradient Norm')
ax.set_title('Gradient Norm (Should be stable)')
ax.grid(True, alpha=0.3)

# 5. Entropy (prediction diversity)
ax = axes[1, 1]
steps, train_ent = get_metric_data(ea, 'train_entropy_norm')
if steps:
    ax.plot(steps, train_ent, label='Train', alpha=0.7)
steps, val_ent = get_metric_data(ea, 'val_entropy_norm')
if steps:
    ax.plot(steps, val_ent, label='Val', linewidth=2)
ax.set_xlabel('Step')
ax.set_ylabel('Normalized Entropy')
ax.set_title('Prediction Diversity')
ax.legend()
ax.grid(True, alpha=0.3)
ax.axhline(y=0.1, color='r', linestyle='--', alpha=0.5, label='Low diversity threshold')

# 6. Repetition ratio
ax = axes[1, 2]
steps, train_rep = get_metric_data(ea, 'train_repetition_ratio')
if steps:
    ax.plot(steps, train_rep, label='Train', alpha=0.7)
steps, val_rep = get_metric_data(ea, 'val_repetition_ratio')
if steps:
    ax.plot(steps, val_rep, label='Val', linewidth=2)
ax.set_xlabel('Step')
ax.set_ylabel('Repetition Ratio')
ax.set_title('Model Collapse Indicator (Lower is better)')
ax.legend()
ax.grid(True, alpha=0.3)
ax.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Collapse threshold')

plt.tight_layout()
plt.savefig('transformer_training_analysis.png', dpi=150, bbox_inches='tight')
plt.show()

print("\n" + "="*60)
print("DIAGNOSIS")
print("="*60)

# Get final values
_, final_train_ppl = get_metric_data(ea, 'train_ppl')
_, final_val_ppl = get_metric_data(ea, 'val_ppl')
_, final_entropy = get_metric_data(ea, 'val_entropy_norm')
_, final_rep = get_metric_data(ea, 'val_repetition_ratio')

if final_train_ppl:
    print(f"\nFinal Training PPL: {final_train_ppl[-1]:.2f}")
    print(f"Final Validation PPL: {final_val_ppl[-1]:.2f}")
    
    if final_val_ppl[-1] > 50:
        print("⚠️  HIGH PERPLEXITY - Model didn't train well!")
        print("   → Model is still very uncertain about predictions")
    elif final_val_ppl[-1] > 30:
        print("⚠️  MODERATE PERPLEXITY - Training could be better")
    else:
        print("✓ Good perplexity - Model trained reasonably")

if final_entropy:
    print(f"\nFinal Validation Entropy: {final_entropy[-1]:.3f}")
    if final_entropy[-1] < 0.2:
        print("⚠️  LOW ENTROPY - Model predictions are NOT diverse")
        print("   → Model might be collapsing to predict same tokens")
    else:
        print("✓ Reasonable entropy")

if final_rep:
    print(f"\nFinal Validation Repetition: {final_rep[-1]:.3f}")
    if final_rep[-1] > 0.3:
        print("⚠️  HIGH REPETITION - Model is repeating same prediction")
        print("   → This explains the attention collapse!")
    else:
        print("✓ Low repetition ratio")

print("\n" + "="*60)
print("RECOMMENDATIONS")
print("="*60)

# Test on CBR

In [15]:
checkpoint_path = '/scratch2/mrenaudin/Hard-CBR-RNN/checkpoints/job_cbr_2_001/lightning_logs/version_1850825/checkpoints/epoch=49-step=309500.ckpt'
model=CBR_LM.load_from_checkpoint(checkpoint_path)
model.eval()

CBR_LM(
  (model): CBR_RNN(
    (drop): Dropout(p=0.5, inplace=False)
    (encoder): Embedding(49999, 1024)
    (q): Linear(in_features=2048, out_features=1024, bias=True)
    (intermediate_h): Linear(in_features=4096, out_features=4096, bias=True)
    (final_h): Linear(in_features=4096, out_features=3072, bias=True)
    (decoder): Linear(in_features=1024, out_features=49999, bias=True)
    (q_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (int_norm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
    (f_norm): LayerNorm((3072,), eps=1e-05, elementwise_affine=True)
  )
  (criterion): CrossEntropyLoss()
)

In [19]:
init_sentence = " ".join(
    [
        "In service , the aircraft was operated by a crew of five and could accommodate either 30 paratroopers , 32 <unk> and 28 sitting casualties , or 50 fully equipped troops .",
        'He even speculated that technical classes might some day be held " for the better training of workmen in their several crafts and industries .',
        "After the War of the Holy League in 1537 against the Ottoman Empire , a truce between Venice and the Ottomans was created in 1539 .",
        'Moore says : " Tony and I had a good <unk> and off-screen relationship , we are two very different people , but we did share a sense of humour " .',
        "<unk> is also the basis for online games sold through licensed lotteries .",
    ]
)

In [31]:
def eval_cbr_batched(model, test_dataloader, init_sentence):
    """
    More efficient batched evaluation for CBR-RNN with priming
    """
    condition_accuracies = defaultdict(int)
    condition_counts = defaultdict(int)
    sentence_details = []
    model.model.eval()
    
    # Tokenize and encode init_sentence
    init_tokens = init_sentence.split(" ")
    init_encoded = torch.LongTensor([
        [tokenizer.stoi.get(w, tokenizer.stoi.get('<unk>', 0)) for w in init_tokens]
    ]).transpose(0, 1)  # [init_len, 1]
    
    with torch.no_grad():
        for batch in test_dataloader:
            written = batch["sentence"]
            sentence = batch["encoded_sentence"]  # [batch_size, seq_len]
            correct = batch["encoded_correct"]
            wrong = batch["encoded_wrong"]
            condition = batch["condition"]
            batch_size = sentence.size(0)
            device = sentence.device
            
            # Move init_encoded to device
            init_encoded_device = init_encoded.to(device)
            
            # Process priming sentence with batch_size=1
            cache = model.model.init_cache(sentence.transpose(0,1))
            # _, primed_states = model.model(init_encoded_device, init_cache)
            
            # # Extract final primed cache
            # primed_hidden = primed_states[-1:, :, :]  # [1, 1, nhid]
            
            # # Expand to batch size
            # batch_cache = tuple(
            #     primed_hidden.expand(-1, batch_size, -1).contiguous()
            #     for _ in range(3)  # (hidden, keys, values)
            # )
            
            # Process test sequence (excluding last token)
            test_context = sentence[:, :-1].transpose(0, 1)  # [seq_len-1, batch_size]
            out, _ = model.model(test_context, cache)
            
            # Get predictions from last position
            # out shape: [seq_len-1, batch_size, vocab_size]
            log_probs = torch.nn.functional.log_softmax(out[-1], dim=-1)  # [batch_size, vocab_size]
            
            correct_log_probs = log_probs[torch.arange(batch_size), correct]
            wrong_log_probs = log_probs[torch.arange(batch_size), wrong]
            correct_predictions = correct_log_probs >= wrong_log_probs
            
            for i in range(batch_size):
                cond = condition[i]
                pred = correct_predictions[i].item()
                condition_counts[cond] += 1
                condition_accuracies[cond] += pred
                sentence_details.append({
                    "sentence": written[i],
                    "condition": cond,
                    "correct_log_prob": correct_log_probs[i].item(),
                    "wrong_log_prob": wrong_log_probs[i].item(),
                    "model_prefers_correct": pred,
                })
    
    final_accuracies = {
        cond: condition_accuracies[cond] / condition_counts[cond]
        for cond in condition_accuracies
    }
    
    return final_accuracies, sentence_details

In [32]:
res, sent = eval_cbr_batched(model, test_dataloader, init_sentence)

In [33]:
res

{'singular singular': 0.295,
 'singular plural': 0.313,
 'plural singular': 0.696,
 'plural plural': 0.707}

In [3]:
import torch
valid = torch.load('valid_data.pt')

FileNotFoundError: [Errno 2] No such file or directory: 'valid_data.pt'

In [None]:
valid

In [4]:
from grid_search.data_utils import WordTokenizer

# Load the tokenizer
tokenizer = WordTokenizer.load("tokenizer.json")

ModuleNotFoundError: No module named 'grid_search.data_utils'

In [5]:
decoded = tokenizer.decode(valid)

NameError: name 'tokenizer' is not defined

In [None]:
decoded

In [2]:
import torch
from grid_search.data_utils import WordTokenizer

# Load tokenizer and data
tokenizer = WordTokenizer.load("tokenizer.json")
train_data = torch.load('train_data.pt')

# Count unknown tokens
total_tokens = len(train_data)
unk_tokens = (train_data == tokenizer.unk_id).sum().item()

# Calculate percentage
unk_percentage = (unk_tokens / total_tokens) * 100

print(f"Total tokens: {total_tokens:,}")
print(f"Unknown tokens: {unk_tokens:,}")
print(f"Unknown percentage: {unk_percentage:.2f}%")

# Also show vocabulary coverage
vocab_size = tokenizer.vocab_size
print(f"\nVocabulary size: {vocab_size:,}")
print(f"Tokens used: {len(torch.unique(train_data)):,}")

ModuleNotFoundError: No module named 'grid_search.data_utils'