# Google drive & Wandb

In [1]:
#from google.colab import drive
#drive.mount('/content/drive')
#!pip install wandb
#import wandb

# Helper functions


In [2]:
import math
import random
from collections import Counter

import matplotlib.pyplot as plt
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset

# Input encoding used by the GCD paper: every operand is written as a sign
# token followed by its decimal digits, e.g. 10 and 150 -> [+,1,0,+,1,5,0].
def encode_feature(a, b):
    """Encode the operand pair ``(a, b)`` as one flat token list.

    Each operand contributes the module-level ``plus_sign_label`` token
    followed by one integer token per character of ``str(operand)``.

    Args:
        a: First operand (non-negative integer).
        b: Second operand (non-negative integer).

    Returns:
        A new list of integer tokens: ``[+, digits of a..., +, digits of b...]``.
    """
    tokens = []
    for operand in (a, b):
        tokens.append(plus_sign_label)
        tokens.extend(int(digit) for digit in str(operand))
    return tokens

# Data generation: a list of (encoded input, encoded target) pairs
def generate_gcd_encoded_data(num_samples, max_value, log_uniform=True):
    """Generate random operand pairs together with their encoded GCD.

    Args:
        num_samples: Number of ``(input, target)`` pairs to produce.
        max_value: Inclusive upper bound for both operands (must be >= 1).
        log_uniform: When true, draw operands log-uniformly from
            ``[1, max_value]`` as suggested by the GCD paper; otherwise draw
            them uniformly from the same range.

    Returns:
        A list of ``(encoded_pair, encoded_gcd)`` tuples, where
        ``encoded_pair`` comes from ``encode_feature`` and ``encoded_gcd`` is
        ``[plus_sign_label, *digits of gcd]``.

    Raises:
        ValueError: If ``max_value`` is smaller than 1 (raised by
            ``math.log`` or ``random.randint``).
    """
    if log_uniform:
        # Loop-invariant upper bound of the exponent.
        log_max = math.log(max_value)

        def draw():
            return int(math.exp(random.uniform(0, log_max)))
    else:
        def draw():
            # Start at 1, like the log-uniform branch: 0 would allow the
            # degenerate gcd(0, 0) == 0 and pairs whose "gcd" is just the
            # other operand.
            return random.randint(1, max_value)

    data = []
    for _ in range(num_samples):
        a = draw()
        b = draw()
        gcd_value = math.gcd(a, b)
        encoded_gcd = [plus_sign_label] + [int(d) for d in str(gcd_value)]
        data.append((encode_feature(a, b), encoded_gcd))
    return data

# Naive right-padding to a fixed length
def pad_sequence(seq, max_len):
    """Return a copy of ``seq`` right-padded with ``padding_label`` to ``max_len``.

    A sequence that is already ``max_len`` long or longer is returned as an
    unpadded (and untruncated) copy.
    """
    missing = max_len - len(seq)
    filler = [padding_label] * missing  # empty when ``missing`` <= 0
    return seq + filler

# Custom Dataset
class GCDDataset(Dataset):
    """Dataset of ``(input tokens, target tokens)`` pairs padded to one length."""

    def __init__(self, data, max_len):
        """Pad every input and target sequence in ``data`` to ``max_len``.

        Args:
            data: Iterable of ``(seq, target)`` token-list pairs.
            max_len: Length every sequence is padded to via ``pad_sequence``.
        """
        padded_pairs = []
        for tokens, answer in data:
            padded_pairs.append(
                (pad_sequence(tokens, max_len), pad_sequence(answer, max_len))
            )
        self.data = padded_pairs

    def __len__(self):
        """Number of stored pairs."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as two ``torch.long`` tensors."""
        tokens, answer = self.data[idx]
        token_tensor = torch.tensor(tokens, dtype=torch.long)
        answer_tensor = torch.tensor(answer, dtype=torch.long)
        return token_tensor, answer_tensor

"""
# Resume or create a new graph
def init_wandb(resume, name, resume_run_id=None):
  if not resume:
    new_run = wandb.init(project="gcd", name=name, resume="allow")
    run_id = new_run.id
    print(f'\n##### RUN ID: {run_id} #####')
    return

  wandb.init(project="gcd", name=name, resume="allow", id=resume_run_id)
  return

# Others

def wandb_log(train_loss,stratified_test_loss):
   wandb.log({
            "train_loss": train_loss,
            "stratified_test_loss": stratified_test_loss,
            })

def save_checkpoint(epoch_num,model,optimizer):
  torch.save({
            'epoch': epoch_num,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, checkpoint_path)
"""

def print_first_element(loader):
    """Print the first sequence and target of the first batch in ``loader``.

    Nothing is printed when ``loader`` yields no batches.
    """
    batches = iter(loader)
    try:
        first_batch = next(batches)
    except StopIteration:
        return

    sequences, targets = first_batch
    print("First Sequence in Batch:", sequences[0].numpy())
    print("First Target in Batch:", targets[0].numpy())

In [3]:
""" Evaluates the model's predictions by calculating the accuracy of predicted sequences
    against true sequences, excluding padding and start tokens. The function aggregates
    sequences of tokens into numerical values for direct comparison """

def report(model, loader, acc=True, freq=False):
    """Print exact-match accuracy of teacher-forced predictions on ``loader``.

    For every example the target is shifted by one position (dropping the
    leading sign token), padding positions are removed, and the remaining
    predicted tokens are compared with the true tokens. An example counts as
    correct only when the two token sequences are identical.

    Relies on the module-level ``device``, ``padding_label`` and ``tgt_mask``.

    Args:
        model: Model called as ``model(src, tgt_input, src_key_padding_mask=...,
            tgt_key_padding_mask=..., tgt_mask=...)`` and returning logits of
            shape ``(batch, tgt_len, vocab)``. It is put in eval mode and
            moved to ``device``.
        loader: Iterable of ``(src, tgt)`` batches of token-id tensors.
        acc: When true, print the overall accuracy.
        freq: When true, print, for every distinct predicted number, how often
            that prediction was right.

    Returns:
        None. Results are printed.
    """
    model.eval()
    model.to(device)

    total_correct_predictions_count = 0
    total_predictions_count = 0
    correct_predictions_count = Counter()
    # Occurrences of every predicted number; counted on the fly so the
    # per-number summary does not have to rescan a list of all predictions.
    predicted_number_count = Counter()

    with torch.no_grad():
        for data in loader:
            src, tgt = data[0].to(device), data[1].to(device)
            tgt_input = tgt[:, :-1]

            src_key_padding_mask = src == padding_label
            tgt_key_padding_mask = tgt_input == padding_label

            outputs = model(
                src,
                tgt_input,
                src_key_padding_mask=src_key_padding_mask,
                tgt_key_padding_mask=tgt_key_padding_mask,
                tgt_mask=tgt_mask
            )

            # One host transfer per batch instead of one per example.
            predicted_rows = outputs.argmax(dim=2).tolist()
            expected_rows = tgt[:, 1:].tolist()

            for predicted_row, expected_row in zip(predicted_rows, expected_rows):
                # Keep only the positions that carry a real target token.
                kept = [
                    (pred_token, true_token)
                    for pred_token, true_token in zip(predicted_row, expected_row)
                    if true_token != padding_label
                ]
                if not kept:
                    # Nothing to score for an all-padding target.
                    continue

                seq_pred = [pred_token for pred_token, _ in kept]
                seq_true = [true_token for _, true_token in kept]

                # Aggregated number, used only as the reporting key.
                pred_number = int(''.join(map(str, seq_pred)))
                predicted_number_count[pred_number] += 1
                total_predictions_count += 1

                # Compare tokens, not the joined integers: a prediction that
                # contains a special token (e.g. [0, 10] -> "010" -> 10) must
                # not be counted as matching the true digits [1, 0].
                if seq_pred == seq_true:
                    total_correct_predictions_count += 1
                    correct_predictions_count[pred_number] += 1

    if acc:
        overall_accuracy = 100 * total_correct_predictions_count / total_predictions_count if total_predictions_count > 0 else 0
        print(f'Overall Accuracy: {overall_accuracy}%\n')

    if freq:
        print("Accuracy for each unique predicted number:")
        for num in sorted(predicted_number_count):
            correct_count = correct_predictions_count[num]
            total_count = predicted_number_count[num]
            accuracy = 100 * correct_count / total_count
            print(f'Accuracy for number {num}: {accuracy:.2f}% (Correct: {correct_count}, Total Prediction: {total_count})')

# Config

In [4]:
# Token ids of the two special symbols; digit tokens 0-9 use their own value.
plus_sign_label = 10
padding_label = 11

sample_size = 400_000
max_value = 10_000_000  # upper bound of data generation

# Limited training set suggested by Neel for grokking.
# NOTE(review): the original note said "around 40%", but test_size = 0.5
# leaves 50% of the samples for training - confirm which one is intended.
test_size = 0.5
batch_size = 256

# Padded length: two operands, each with up to max_num_length digits plus a
# sign token.
max_num_length = len(str(max_value))
max_len = 2 * (max_num_length + 1)

#checkpoint_path = '/content/drive/MyDrive/CSCI567/2layers_256d_8heads_lr1e-4.pth'
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Data preprocessing

In [None]:
# Seed Python's `random` (used by the data generator) and reuse the same
# value for the scikit-learn split so both steps are reproducible.
# NOTE(review): torch's RNG is not seeded here, so the shuffled order of
# train_loader presumably differs between runs - confirm if that matters.
random.seed(52)
random_state = 52

# Generate the encoded pairs, split them into train/test, pad each split to
# max_len and wrap it in a DataLoader (only the training loader shuffles).
raw_data = generate_gcd_encoded_data(sample_size, max_value)
train_data, test_data = train_test_split(raw_data, test_size=test_size, random_state=random_state)
train_dataset = GCDDataset(train_data, max_len)
test_dataset = GCDDataset(test_data, max_len)
train_loader = DataLoader(train_dataset, batch_size=batch_size , shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size , shuffle=False)

# Sanity check: show one padded example from each loader.
print("Train Loader:")
print_first_element(train_loader)
print("\n")
print("Test Loader:")
print_first_element(test_loader)

In [None]:
def generate_coprime_pair(max_value):
    """Draw random ``(a, b)`` from ``[1, max_value]`` until they are coprime.

    Both numbers are redrawn together (rejection sampling) whenever their
    GCD is not 1.

    Returns:
        A tuple ``(a, b)`` with ``math.gcd(a, b) == 1``.
    """
    a = random.randint(1, max_value)
    b = random.randint(1, max_value)
    while math.gcd(a, b) != 1:
        a = random.randint(1, max_value)
        b = random.randint(1, max_value)
    return a, b

# Stratified test data suggested by the GCD paper: accuracy on a plain random
# test set is misleading for GCD.
def generate_stratified_gcd_test_set(num_samples_per_k, max_value, k_max=100):
    """Build a test set with the same number of samples for every GCD value.

    For each ``k`` in ``1..k_max`` a coprime pair is drawn from
    ``[1, max_value // k]`` and both members are multiplied by ``k``, so the
    resulting operands have GCD exactly ``k`` and do not exceed ``max_value``.

    Args:
        num_samples_per_k: Number of samples generated for each GCD value.
        max_value: Upper bound for the generated operands.
        k_max: Largest GCD value to include.

    Returns:
        A list of ``(encoded_pair, encoded_gcd)`` tuples, ordered by ``k``.
    """
    samples = []
    for k in range(1, k_max + 1):
        coprime_limit = max_value // k
        gcd_tokens = [plus_sign_label] + [int(digit) for digit in str(k)]
        for _ in range(num_samples_per_k):
            x, y = generate_coprime_pair(coprime_limit)
            # Every sample gets its own copy of the target token list.
            samples.append((encode_feature(x * k, y * k), list(gcd_tokens)))
    return samples

# 1000 samples for each GCD value 1..100, padded to max_len and served in
# generation order (no shuffling).
stratified_test_set = generate_stratified_gcd_test_set(num_samples_per_k=1000, max_value=max_value, k_max=100)
stratified_dataset = GCDDataset(stratified_test_set , max_len)
stratified_test_loader = DataLoader(stratified_dataset, batch_size=batch_size , shuffle=False)
# Sanity check: show one padded example.
print("Stratified Test Loader:")
print_first_element(stratified_test_loader)

# Model constructor

In [7]:
import torch
import torch.nn as nn
import math

# Conventional seq2seq auto-regressive Transformer with positional encoding
class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder Transformer over one shared token vocabulary.

    Source and target tokens go through the same embedding table (its padding
    index is the module-level ``padding_label``), receive fixed sinusoidal
    positional encodings covering the module-level ``max_len`` positions, and
    are processed by ``nn.TransformerEncoder`` / ``nn.TransformerDecoder``
    stacks in sequence-first layout. A linear layer maps decoder states to
    ``output_dim`` logits.
    """

    def __init__(self, input_dim, output_dim, d_model=512, nhead=8, num_encoder_layers=4, num_decoder_layers=4, dim_feedforward=2048, dropout=0.0):
        """Build the embedding, encoder/decoder stacks and output projection.

        Args:
            input_dim: Vocabulary size of the shared embedding.
            output_dim: Number of output classes per target position.
            d_model: Embedding / model width.
            nhead: Number of attention heads.
            num_encoder_layers: Number of encoder layers.
            num_decoder_layers: Number of decoder layers.
            dim_feedforward: Hidden width of the feed-forward sublayers.
            dropout: Dropout probability inside the Transformer layers.
        """
        super().__init__()
        self.d_model = d_model

        self.embedding = nn.Embedding(input_dim, d_model, padding_idx=padding_label)
        self.positional_encoding = self.create_positional_encoding(max_len, d_model)

        # Encoder
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_encoder_layers, enable_nested_tensor=False)

        # Decoder
        decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_decoder_layers)

        self.output_layer = nn.Linear(d_model, output_dim)

    # From the PyTorch documentation
    def create_positional_encoding(self, max_len, d_model):
        """Return fixed sinusoidal encodings of shape ``(1, max_len, d_model)``.

        The table is wrapped in a non-trainable ``nn.Parameter`` so it moves
        with the module and is stored in its ``state_dict``.
        """
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        return nn.Parameter(pe, requires_grad=False)

    def forward(self, src, tgt, tgt_mask=None, src_key_padding_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Compute logits for every target position.

        Args:
            src: Source token ids, shape ``(batch, src_len)``.
            tgt: Decoder input token ids, shape ``(batch, tgt_len)``.
            tgt_mask: Optional 2-D causal attention mask. A mask larger than
                ``tgt_len`` is cropped to ``(tgt_len, tgt_len)`` and any mask
                is moved to the device of ``tgt``; when omitted, a square
                subsequent (causal) mask is created.
            src_key_padding_mask: Boolean mask, true at padded source
                positions, shape ``(batch, src_len)``.
            tgt_key_padding_mask: Accepted for call-site compatibility but
                not forwarded to the decoder.
                NOTE(review): with right-padded targets the causal mask
                already keeps real positions from attending to trailing
                padding, so this looks intentional - confirm before wiring
                it in.
            memory_key_padding_mask: Padding mask for the encoder output as
                seen by the decoder; defaults to ``src_key_padding_mask``.

        Returns:
            Logits of shape ``(batch, tgt_len, output_dim)``.
        """
        tgt_len = tgt.size(1)
        if tgt_mask is None:
            # The decoder below is always called with the causal hint, which
            # PyTorch only accepts together with an explicit mask.
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len)
        elif tgt_mask.dim() == 2:
            # Tolerate a mask built for a longer sequence (e.g. the padded
            # length) when the decoder input is shorter.
            tgt_mask = tgt_mask[:tgt_len, :tgt_len]
        tgt_mask = tgt_mask.to(tgt.device)

        # Previously this argument was ignored; an explicit value now wins.
        if memory_key_padding_mask is None:
            memory_key_padding_mask = src_key_padding_mask

        # Embedding scaling from "Attention Is All You Need"
        scale = math.sqrt(self.d_model)
        src = self.embedding(src) * scale
        tgt = self.embedding(tgt) * scale

        # Add positional encoding to src and tgt embeddings
        src = src + self.positional_encoding[:, :src.size(1), :]
        tgt = tgt + self.positional_encoding[:, :tgt.size(1), :]

        # The encoder/decoder layers are sequence-first, hence the transposes.
        memory = self.transformer_encoder(src.transpose(0, 1), src_key_padding_mask=src_key_padding_mask)
        output = self.transformer_decoder(
            tgt.transpose(0, 1),
            memory,
            tgt_mask=tgt_mask,
            memory_key_padding_mask=memory_key_padding_mask,
            tgt_is_causal=True,
        )
        output = self.output_layer(output.transpose(0, 1))
        return output

# Training loop

In [8]:
import os

# resume or start a new graph
#init_wandb(False, "2layers_256d_8heads")
#init_wandb(True, "2layers_256d_8heads", resume_run_id="pvle5tmk")

# Causal mask for auto-regressive decoding. The decoder is always fed
# tgt[:, :-1], i.e. max_len - 1 tokens, so the mask is sized to match and
# kept on the same device as the batches.
tgt_mask = torch.nn.Transformer.generate_square_subsequent_mask(max_len - 1).to(device)

def train_nn_model(model, epochs, optimizer, loss_func):
    """Train ``model`` with teacher forcing and print per-epoch losses.

    Each epoch runs one pass over the module-level ``train_loader`` followed
    by a gradient-free pass over ``stratified_test_loader``; the mean batch
    loss of both passes is printed. Also relies on the module-level
    ``device``, ``padding_label`` and ``tgt_mask``.

    Args:
        model: Seq2seq model called with source ids, shifted target ids and
            the padding / causal masks.
        epochs: Epoch index to stop at (exclusive).
        optimizer: Optimizer stepping the model parameters.
        loss_func: Criterion applied to flattened logits and target ids.
    """

    def batch_loss(batch):
        # Forward one batch and return its loss tensor.
        src = batch[0].to(device)
        tgt = batch[1].to(device)
        # Drop the last token so position i predicts token i + 1
        # (teacher forcing); the expected output is the target shifted by one.
        decoder_input = tgt[:, :-1]
        expected = tgt[:, 1:].reshape(-1)

        logits = model(
            src,
            decoder_input,
            src_key_padding_mask=(src == padding_label).to(device),
            tgt_key_padding_mask=(decoder_input == padding_label).to(device),
            tgt_mask=tgt_mask
        )
        flat_logits = logits.reshape(-1, logits.shape[-1])
        return loss_func(flat_logits, expected)

    first_epoch = 0
    # Disabled: resume from a checkpoint if one is available
    # if os.path.exists(checkpoint_path):
    #     checkpoint = torch.load(checkpoint_path)
    #     model.load_state_dict(checkpoint['model_state_dict'])
    #     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    #     start_epoch = checkpoint['epoch'] + 1
    #     print(f"Resuming from epoch {start_epoch}")

    for epoch_num in range(first_epoch, epochs):
        print(f"Epoch: {epoch_num}")

        # Training pass
        model.train()
        running_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = batch_loss(batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        print(f"Train loss: {train_loss}")

        # Evaluation pass on the stratified test set
        model.eval()
        with torch.no_grad():
            eval_loss_sum = sum(batch_loss(batch).item() for batch in stratified_test_loader)

        stratified_test_loss = eval_loss_sum / len(stratified_test_loader)
        print(f"Stratified Test loss: {stratified_test_loss}\n")

        #wandb_log(train_loss,stratified_test_loss)
        #save_checkpoint(epoch_num,model,optimizer)

    #wandb.finish()

# Training config

In [None]:
# Vocabulary of 12 tokens: digits 0~9 plus the two special tags (sign and
# padding); the same size is used for input and output.
toy = Seq2SeqTransformer(12, 12).to(device)
# Padded target positions do not contribute to the loss.
loss_func = nn.CrossEntropyLoss(ignore_index=padding_label).to(device)
# Neel claims L2 is necessary for grokking to occur.
# NOTE(review): no weight_decay is passed, so AdamW's library default
# (decoupled weight decay) applies - confirm that is the intended strength.
optimizer = torch.optim.AdamW(toy.parameters(),lr=1e-4)
# Training loop (3000 epochs)
train_nn_model(toy, 3000, optimizer, loss_func)

# A report for GCD

In [None]:
# Overall accuracy plus the per-predicted-number breakdown on the stratified
# test set (the training-set report is left disabled).
#report(toy, train_loader,freq=True)
report(toy, stratified_test_loader,freq=True)