In [1]:
from transformers import GPT2LMHeadModel
from torchtyping import TensorType
import torch
from transformers import GPT2Tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments
from torch.utils.data import Dataset

class AdditionDataset(Dataset):
    def __init__(self, tokenizer, examples, max_length=50):
        self.tokenizer = tokenizer
        self.examples = examples
        self.max_length = max_length

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        question, answer = self.examples[idx]
        inputs = self.tokenizer.encode_plus(
            question,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        labels = self.tokenizer.encode_plus(
            answer,
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        return {
            'input_ids': inputs['input_ids'].flatten(),
            'attention_mask': inputs['attention_mask'].flatten(),
            'labels': labels['input_ids'].flatten()
        }

In [3]:
from transformers import TrainingArguments
from sklearn.model_selection import train_test_split

def generate_addition_validation(num_examples=100, max_number=100000000):
    import random
    examples = []
    for _ in range(num_examples):
        a = random.randint(0, max_number)
        b = random.randint(0, max_number)
        question = f"{a}+{b}="
        answer = str(a + b)
        examples.append((question, answer))
    return examples

def generate_random_addition_training(num_examples=100, max_number=100000000):
    import random
    examples = []
    for _ in range(num_examples):
        a = random.randint(0, max_number)
        b = random.randint(0, max_number)
        question = f"{a}+{b}="
        answer = str(a + b)
        examples.append((question, answer))
    return examples

def generate_ordered_addition_training(num_examples=100, max_number=100000000):
    import random
    examples = []
    for _ in range(num_examples):
        a = random.randint(0, max_number)
        b = random.randint(0, max_number)
        question = f"{a}+{b}="
        answer = str(a + b)
        examples.append((question, answer))
    # Sort the examples based on the sum of the numbers in the question
    examples.sort(key=lambda x: sum(int(num) for num in x[0].split('+') if num.strip().isdigit()))
    return examples

def generate_reverse_ordered_addition_training(num_examples=100, max_number=100000000):
    import random
    examples = []
    for _ in range(num_examples):
        a = random.randint(0, max_number)
        b = random.randint(0, max_number)
        question = f"{a}+{b}="
        answer = str(a + b)
        examples.append((question, answer))
    # Sort the examples based on the sum of the numbers in the question in reverse order
    examples.sort(key=lambda x: sum(int(num) for num in x[0].split('+') if num.strip().isdigit()), reverse=True)
    return examples

In [4]:
import matplotlib.pyplot as plt


datasets = {
    "random": generate_random_addition_training(),
    "ordered": generate_ordered_addition_training(),
    "reverse_ordered": generate_reverse_ordered_addition_training()
}

training_losses = {}

# Initialize the GPT2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # Add this line

examples = generate_addition_validation()
val_dataset = AdditionDataset(tokenizer, examples)

for dataset_name, examples in datasets.items():

    # Create a AdditionDataset instance for the validation set
    train_dataset = AdditionDataset(tokenizer, examples)

    model = GPT2LMHeadModel.from_pretrained('gpt2')

    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        predictions = logits.argmax(-1)
        accuracy = (predictions == labels).mean().item()
        return {"accuracy": accuracy}

    training_args = TrainingArguments(
        output_dir="./results",  # output directory
        num_train_epochs=30,  # total number of training epochs
        per_device_train_batch_size=2,  # batch size per device during training
        per_device_eval_batch_size=8,  # batch size for evaluation
        warmup_steps=100,  # number of warmup steps for learning rate scheduler
        weight_decay=0.01,  # strength of weight decay
        logging_dir="./logs",  # directory for storing logs
        evaluation_strategy="epoch"
    )

    # Define the trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )

    # Train the model
    trainer_state = trainer.train()

    # Store the training losses
    training_losses[dataset_name] = [log.get('eval_loss', None) for log in trainer.state.log_history if 'eval_loss' in log]

# Plot the training losses
for dataset_name, losses in training_losses.items():
    epochs = range(1, len(losses) + 1)
    plt.plot(epochs, losses, label=f'{dataset_name} training loss')

for dataset_name, losses in training_losses.items():
    print(f"Final loss for {dataset_name}: {losses[-1]}")

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()

                                                 
  3%|▎         | 51/1500 [00:06<05:44,  4.21it/s]

{'eval_loss': 0.36602112650871277, 'eval_accuracy': 0.9246, 'eval_runtime': 1.0008, 'eval_samples_per_second': 99.917, 'eval_steps_per_second': 12.989, 'epoch': 1.0}


                                                 
  7%|▋         | 101/1500 [00:10<04:33,  5.11it/s]

{'eval_loss': 0.344858318567276, 'eval_accuracy': 0.9248, 'eval_runtime': 0.7372, 'eval_samples_per_second': 135.649, 'eval_steps_per_second': 17.634, 'epoch': 2.0}


                                                  
 10%|█         | 151/1500 [00:15<04:14,  5.30it/s]

{'eval_loss': 0.3599613308906555, 'eval_accuracy': 0.9246, 'eval_runtime': 0.7199, 'eval_samples_per_second': 138.907, 'eval_steps_per_second': 18.058, 'epoch': 3.0}


                                                  
 13%|█▎        | 201/1500 [00:19<04:04,  5.32it/s]

{'eval_loss': 0.37265831232070923, 'eval_accuracy': 0.9252, 'eval_runtime': 0.6879, 'eval_samples_per_second': 145.368, 'eval_steps_per_second': 18.898, 'epoch': 4.0}


                                                  
 17%|█▋        | 251/1500 [00:23<03:52,  5.37it/s]

{'eval_loss': 0.3792695701122284, 'eval_accuracy': 0.9248, 'eval_runtime': 0.684, 'eval_samples_per_second': 146.197, 'eval_steps_per_second': 19.006, 'epoch': 5.0}


                                                  
 20%|██        | 301/1500 [00:27<03:43,  5.37it/s]

{'eval_loss': 0.3961324393749237, 'eval_accuracy': 0.9248, 'eval_runtime': 0.6804, 'eval_samples_per_second': 146.969, 'eval_steps_per_second': 19.106, 'epoch': 6.0}


                                                  
 23%|██▎       | 351/1500 [00:32<03:53,  4.92it/s]

{'eval_loss': 0.40267840027809143, 'eval_accuracy': 0.925, 'eval_runtime': 0.7869, 'eval_samples_per_second': 127.089, 'eval_steps_per_second': 16.522, 'epoch': 7.0}


                                                  
 27%|██▋       | 401/1500 [00:36<03:54,  4.69it/s]

{'eval_loss': 0.41852977871894836, 'eval_accuracy': 0.9246, 'eval_runtime': 0.7941, 'eval_samples_per_second': 125.936, 'eval_steps_per_second': 16.372, 'epoch': 8.0}


                                                  
 30%|███       | 451/1500 [00:41<03:39,  4.77it/s]

{'eval_loss': 0.45478156208992004, 'eval_accuracy': 0.9248, 'eval_runtime': 0.8083, 'eval_samples_per_second': 123.716, 'eval_steps_per_second': 16.083, 'epoch': 9.0}


 33%|███▎      | 500/1500 [00:45<01:14, 13.42it/s]

{'loss': 0.5543, 'learning_rate': 3.571428571428572e-05, 'epoch': 10.0}


                                                  
 33%|███▎      | 501/1500 [00:47<07:31,  2.21it/s]

{'eval_loss': 0.5226693153381348, 'eval_accuracy': 0.9248, 'eval_runtime': 0.7058, 'eval_samples_per_second': 141.688, 'eval_steps_per_second': 18.419, 'epoch': 10.0}


                                                  
 37%|███▋      | 551/1500 [00:52<03:04,  5.14it/s]

{'eval_loss': 0.5360320806503296, 'eval_accuracy': 0.925, 'eval_runtime': 0.7111, 'eval_samples_per_second': 140.636, 'eval_steps_per_second': 18.283, 'epoch': 11.0}


                                                  
 40%|████      | 601/1500 [00:56<02:42,  5.52it/s]

{'eval_loss': 0.5643739700317383, 'eval_accuracy': 0.9246, 'eval_runtime': 0.6722, 'eval_samples_per_second': 148.764, 'eval_steps_per_second': 19.339, 'epoch': 12.0}


                                                  
 43%|████▎     | 651/1500 [01:00<02:35,  5.46it/s]

{'eval_loss': 0.5904717445373535, 'eval_accuracy': 0.9246, 'eval_runtime': 0.6719, 'eval_samples_per_second': 148.841, 'eval_steps_per_second': 19.349, 'epoch': 13.0}


                                                  
 47%|████▋     | 701/1500 [01:05<02:24,  5.54it/s]

{'eval_loss': 0.6216340065002441, 'eval_accuracy': 0.9252, 'eval_runtime': 0.6687, 'eval_samples_per_second': 149.553, 'eval_steps_per_second': 19.442, 'epoch': 14.0}


                                                  
 50%|█████     | 751/1500 [01:09<02:19,  5.38it/s]

{'eval_loss': 0.6309906244277954, 'eval_accuracy': 0.9244, 'eval_runtime': 0.6977, 'eval_samples_per_second': 143.334, 'eval_steps_per_second': 18.633, 'epoch': 15.0}


                                                  
 53%|█████▎    | 801/1500 [01:13<02:07,  5.50it/s]

{'eval_loss': 0.6367547512054443, 'eval_accuracy': 0.9256, 'eval_runtime': 0.6803, 'eval_samples_per_second': 146.987, 'eval_steps_per_second': 19.108, 'epoch': 16.0}


 57%|█████▋    | 849/1500 [01:17<00:45, 14.33it/s]