In [1]:
%load_ext autoreload
%autoreload 2
import torch
import transformers
import tiktoken
import json

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [2]:
# Load the TinyStories-1M causal LM and move it to the selected device.
# The tokenizer comes from the GPT-Neo checkpoint (presumably the tokenizer
# TinyStories was trained with — confirm against the model card).
model_checkpoint = 'roneneldan/TinyStories-1M'
tokenizer_checkpoint = "EleutherAI/gpt-neo-125M"
model = transformers.AutoModelForCausalLM.from_pretrained(model_checkpoint)
model = model.to(device)
tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_checkpoint)

In [None]:
# Load the IOI prompt set. JSON text is UTF-8, so state the encoding explicitly
# instead of relying on the platform's locale-dependent default for open().
with open('../datasets/autocircuit_ioi_prompts.json', encoding='utf-8') as f:
    data = json.load(f)

In [11]:
# Number of IOI prompts in the loaded dataset.
num_prompts = len(data['prompts'])
print(num_prompts)

1000


## Testing IOI like a benchmark

The initial purpose of this notebook was to use the IOI data as a benchmark, to test whether ablation training severely harms the model's performance during training. The results were not as expected: even the regular TinyStories model performs poorly under this evaluation setup.

Ultimately, although this evaluation does not work, I still think we need to demonstrate that ablation training does not harm the model's general capabilities.

In [41]:
def tokenize_prompts(prompts, tokenizer, batch_size=32):
    """
    Split prompts into consecutive chunks and tokenize each chunk.

    Every chunk is tokenized with PyTorch tensors, padding and truncation
    enabled; the last chunk may be shorter than ``batch_size``.

    Parameters:
    - prompts (list): Prompt strings, in evaluation order.
    - tokenizer: Callable tokenizer (e.g. a HuggingFace tokenizer).
    - batch_size (int): Maximum number of prompts per chunk.

    Returns:
    - list: One tokenizer output per chunk, in the same order as ``prompts``.
    """
    return [
        tokenizer(prompts[start:start + batch_size], return_tensors='pt', padding=True, truncation=True)
        for start in range(0, len(prompts), batch_size)
    ]

def evaluate_model_accuracy(model, tokenizer, data, batch_size=32):
    """
    Evaluates the model's next-token accuracy on the IOI task using batch processing.

    For each clean prompt the model greedily generates exactly one token. The
    prediction counts as correct if its decoded, stripped, lower-cased text
    equals one of the prompt's stripped, lower-cased answers.

    Parameters:
    - model: The HuggingFace ModelForCausalLM instance. It is moved to the
      CUDA device when available (else CPU) and put in eval mode.
    - tokenizer: The corresponding tokenizer instance. It must have a pad token
      and must pad on the left, so every row's prompt ends right before the
      generated token.
    - data (dict): The dataset. Each entry of ``data['prompts']`` is read for
      its ``'clean'`` prompt text and its ``'answers'`` list of strings.
    - batch_size (int): Number of prompts to process in each batch.

    Returns:
    - float: Accuracy of the model on the dataset (0.0 if there are no prompts).

    Raises:
    - ValueError: if the tokenizer pads on the right, which would place pad
      tokens between each prompt and its generated token.

    Note:
    - Only the first generated token is compared. Answers that tokenize into
      several tokens will generally never match, and any other plausible
      continuation (e.g. a pronoun) is counted as wrong.
    """
    if getattr(tokenizer, 'padding_side', 'left') != 'left':
        raise ValueError("evaluate_model_accuracy requires tokenizer.padding_side == 'left'")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # Extract all clean prompts and their normalized answers.
    prompts = [prompt['clean'] for prompt in data['prompts']]
    all_answers = [[ans.strip().lower() for ans in prompt['answers']] for prompt in data['prompts']]
    total = len(prompts)

    # Tokenize prompts in batches
    tokenized_batches = tokenize_prompts(prompts, tokenizer, batch_size=batch_size)

    predictions = []
    with torch.no_grad():
        for batch in tokenized_batches:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Greedy decoding: with do_sample=True the score changed from run to run.
            output_ids = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=1,
                do_sample=False,
            )

            # Slice with ``-1:`` so batch_decode receives one id sequence per row.
            new_token_ids = output_ids[:, -1:]
            decoded = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)
            predictions.extend(token.strip().lower() for token in decoded)

    # Batches preserve prompt order, so predictions line up with all_answers.
    correct = sum(pred in answers for pred, answers in zip(predictions, all_answers))

    accuracy = correct / total if total else 0.0
    print(f"Model Accuracy: {accuracy * 100:.2f}%")
    return accuracy

In [43]:
# Decoder-only generation needs left padding so the new token directly follows
# each prompt. The tokenizer reuses its EOS token for padding (presumably it has
# no dedicated pad token — verify), and generation is told about that pad id.
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token
model.generation_config.pad_token_id = tokenizer.pad_token_id

EVAL_BATCH_SIZE = 64
accuracy = evaluate_model_accuracy(model, tokenizer, data, batch_size=EVAL_BATCH_SIZE)

Model Accuracy: 0.80%


This does not work well because the model frequently continues the prompt with a pronoun rather than a name, and the expected answers do not account for pronouns.