In [None]:
# Implementation script for Arithmetic Transformer
# This script will be converted to a notebook for Google Colab

# Cell 1: Setup and Installation
"""
# Investigating the Limitations of Transformers with Simple Arithmetic Tasks

This notebook implements the experiments from the paper:
[Nogueira, Jiang, Lin "Investigating the Limitations of Transformers with Simple Arithmetic Tasks", 2021](https://arxiv.org/abs/2102.13019)

It demonstrates how different number representations affect the ability of transformer models to learn arithmetic tasks.

## Setup
First, let's install the required packages and set up the environment.
"""

# Install required packages
!pip install -q torch pytorch-lightning transformers num2words numpy pandas matplotlib tqdm

# Clone the repository if running on Colab
!if [ ! -d "5782_Final_Project" ]; then git clone https://github.com/joshiarnav/5782_Final_Project.git; fi

# Navigate to the code directory
%cd 5782_Final_Project/code

In [None]:
# Cell 2: Import Libraries and Set Parameters
"""
## Configuration
Let's set up the configuration for training the model.
"""

# Standard library
import glob
import json
import os
import random

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
import torch

# --- Reproducibility --------------------------------------------------------
# Seed the RNGs used inside this kernel. train.py runs as a separate process
# and receives the same value through its --seed flag in the training cell.
SEED = 1
random.seed(SEED)
pl.seed_everything(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# --- Experiment selection ---------------------------------------------------
OUTPUT_DIR = './output'
MODEL_NAME = 't5-base'
OPERATION = 'addition'  # one of: 'addition', 'subtraction'
# One of: 'decimal', 'character', 'character_fixed', 'underscore', 'words',
# '10based', '10ebased'
ORTHOGRAPHY = '10ebased'

# --- Problem size (scaled down for a quick Colab run) -----------------------
MAX_DIGITS_TRAIN = 5   # full-scale value: 15
MAX_DIGITS_TEST = 5    # full-scale value: 15
TRAIN_SIZE = 1000      # full-scale value: 100000
VAL_SIZE = 200         # full-scale value: 10000
TEST_SIZE = 200        # full-scale value: 10000

# --- Training loop ----------------------------------------------------------
BATCH_SIZE = 4
MAX_EPOCHS = 5         # full-scale value: 20

# All run artefacts go here; safe to re-run because of exist_ok.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Cell 3: Run Training
"""
## Training
Now let's train the model using the specified parameters.
"""

# Construct the training command
train_command = f"""python train.py \
    --output_dir={OUTPUT_DIR} \
    --model_name_or_path={MODEL_NAME} \
    --operation={OPERATION} \
    --orthography={ORTHOGRAPHY} \
    --balance_train \
    --balance_val \
    --train_size={TRAIN_SIZE} \
    --val_size={VAL_SIZE} \
    --test_size={TEST_SIZE} \
    --min_digits_train=2 \
    --max_digits_train={MAX_DIGITS_TRAIN} \
    --min_digits_test=2 \
    --max_digits_test={MAX_DIGITS_TEST} \
    --base_number=10 \
    --seed={SEED} \
    --train_batch_size={BATCH_SIZE} \
    --accumulate_grad_batches=4 \
    --val_batch_size={BATCH_SIZE*4} \
    --max_seq_length=512 \
    --num_workers=2 \
    --gpus=1 \
    --optimizer=AdamW \
    --lr=3e-4 \
    --weight_decay=5e-5 \
    --scheduler=StepLR \
    --gamma=1.0 \
    --step_size=1000 \
    --max_epochs={MAX_EPOCHS} \
    --check_val_every_n_epoch=1 \
    --precision=32 \
    --gradient_clip_val=1.0"""

# Execute the training command
!{train_command}

In [None]:
# Cell 4: Evaluate Results
"""
## Results
Let's evaluate the results of our training.
"""

# Look for the results file in the run's output directory.
results_file = os.path.join(OUTPUT_DIR, 'results.json')
if not os.path.exists(results_file):
    print("No results file found. Training may have failed or is still in progress.")
else:
    with open(results_file, 'r') as results_handle:
        results = json.load(results_handle)
    print(f"Test Exact Match: {results['test_exact_match']:.4f}")

    # Run metadata echoed back from the results file, as (label, key) pairs.
    metadata_fields = (
        ("Operation", 'operation'),
        ("Orthography", 'orthography'),
        ("Max Digits (Train)", 'max_digits_train'),
        ("Max Digits (Test)", 'max_digits_test'),
    )
    for label, key in metadata_fields:
        print(f"{label}: {results[key]}")

In [None]:
# Cell 5: Visualize Sample Predictions
"""
## Sample Predictions
Let's look at some sample predictions from the model, as recorded in the most recent log file.
"""

NUM_SAMPLES_TO_SHOW = 5

# Log-line marker -> sample field it fills in. Order matters: the first marker
# found in a line wins, mirroring an if/elif chain.
SAMPLE_FIELD_MARKERS = {
    'Sample correct answer:': 'correct',
    'Sample predicted answer:': 'predicted',
    'Exact match:': 'exact_match',
}


def parse_sample_predictions(lines):
    """Group 'Sample ...' log lines into one dict per sample.

    A line containing 'Sample question:' starts a new sample. Later lines
    containing a marker from SAMPLE_FIELD_MARKERS fill in that sample's
    'correct', 'predicted' or 'exact_match' field. Field lines that appear
    before the first question are ignored.

    Args:
        lines: Iterable of text lines (e.g. an open log file).

    Returns:
        List of dicts in log order, each with a 'question' key plus whichever
        of the other fields were found for it.
    """
    samples = []
    current = None
    for line in lines:
        if 'Sample question:' in line:
            # partition keeps everything after the first marker, even if the
            # marker text happens to occur again later in the line.
            current = {'question': line.partition('Sample question:')[2].strip()}
            samples.append(current)
            continue
        if current is None:
            continue
        for marker, field in SAMPLE_FIELD_MARKERS.items():
            if marker in line:
                current[field] = line.partition(marker)[2].strip()
                break
    return samples


# Find the most recently modified log file.
log_files = glob.glob(os.path.join(OUTPUT_DIR, 'logs', '*.txt'))
if log_files:
    latest_log = max(log_files, key=os.path.getmtime)
    print(f"Latest log file: {latest_log}")

    # Stream the file instead of loading it all with readlines().
    with open(latest_log, 'r', encoding='utf-8', errors='replace') as log_file:
        samples = parse_sample_predictions(log_file)

    if not samples:
        print("No sample predictions found in the log file.")

    for i, sample in enumerate(samples[:NUM_SAMPLES_TO_SHOW], start=1):
        print(f"\nSample {i}:")
        print(f"Question: {sample.get('question', 'N/A')}")
        print(f"Correct: {sample.get('correct', 'N/A')}")
        print(f"Predicted: {sample.get('predicted', 'N/A')}")
        print(f"Exact Match: {sample.get('exact_match', 'N/A')}")
else:
    print("No log files found.")

In [None]:
# Cell 6: Visualize Performance Across Different Orthographies
"""
## Orthography Comparison
Let's visualize how different number representations (orthographies) affect model performance.

**Note:** the accuracies plotted below are placeholder values, not results produced by this notebook.
To get real numbers, run the training cell once per orthography and collect `test_exact_match` from each run's `results.json`.
"""

# Placeholder values (not measured here) — replace with one measured accuracy
# per orthography.
orthographies = ['decimal', 'character', 'character_fixed', 'underscore', 'words', '10based', '10ebased']
accuracies = [0.05, 0.35, 0.45, 0.40, 0.60, 0.95, 0.98]
assert len(orthographies) == len(accuracies), "need exactly one accuracy per orthography"

fig, ax = plt.subplots(figsize=(12, 6))
ax.bar(orthographies, accuracies, color='skyblue')

# Horizontal reference line at 50% exact-match accuracy.
ax.axhline(y=0.5, color='red', linestyle='--', alpha=0.7, label='50% reference')

ax.set_xlabel('Number Representation', fontsize=12)
ax.set_ylabel('Accuracy (Exact Match)', fontsize=12)
# The title states that the data is illustrative, so a skimmed figure is not
# mistaken for measured results.
ax.set_title('Model Performance by Number Representation (placeholder data)', fontsize=14)
ax.set_ylim(0, 1.0)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.grid(axis='y', linestyle='--', alpha=0.7)
ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Cell 7: Evaluate Custom Examples
"""
## Custom Evaluation
Let's evaluate the model on some custom examples.
"""

# Find the latest checkpoint
checkpoint_files = glob.glob(os.path.join(OUTPUT_DIR, '*.ckpt'))
if checkpoint_files:
    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    print(f"Latest checkpoint: {latest_checkpoint}")
    
    # Construct the evaluation command
    examples = ["123,456", "7890,1234", "9999,9999"]
    eval_command = f"python evaluate.py \
        --checkpoint_dir={OUTPUT_DIR} \
        --operation={OPERATION} \
        --orthography={ORTHOGRAPHY} \
        --max_digits={MAX_DIGITS_TEST} \
        --examples {' '.join(examples)}"
    
    # Execute the evaluation command
    !{eval_command}
else:
    print("No checkpoint files found.")