In [None]:
from transformers import Trainer, TrainingArguments, EarlyStoppingCallback, TrainerCallback, TrainerState, TrainerControl
import torch
import random
import torch.nn as nn
import numpy as np
from torch.utils.data import Dataset
import csv
import gc
from sklearn.metrics import accuracy_score

# UNCOMMENT CODE BELOW TO USE MLP
# class MLP(nn.Module):
#     def __init__(self, vocab_size, seq_len, num_layers, hidden_dim):
#         super().__init__()
#         self.seq_len = seq_len
#         self.vocab_size = vocab_size
#         input_dim = vocab_size * seq_len
#         output_dim = input_dim

#         layers = []

#         prev_dim = input_dim
#         for layer in range(num_layers):
#             layers.append(nn.Linear(prev_dim, hidden_dim))
#             layers.append(nn.ReLU())
#             prev_dim = hidden_dim

#         layers.append(nn.Linear(prev_dim, output_dim))

#         self.model = nn.Sequential(*layers)

#     def forward(self, x):
#         logits_flat = self.model(x)
#         batch_size = x.size(0)
#         logits = logits_flat.view(batch_size, self.seq_len, self.vocab_size)

#         return logits


# # wrapper class to fit Hugging Face trainer
# class HFWrapper(nn.Module):
#     def __init__(self, model, vocab_size, seq_len):
#         super().__init__()
#         self.model = model
#         self.vocab_size = vocab_size
#         self.seq_len = seq_len
#         self.loss_fn = nn.CrossEntropyLoss()

#     def forward(self, inputs, labels=None):
#         logits = self.model(inputs)
#         logits = logits.view(-1, self.seq_len, self.vocab_size)

#         output = {"logits": logits}

#         logits_flat = logits.view(-1, self.vocab_size)
#         labels_flat = labels.view(-1)

#         loss = self.loss_fn(logits_flat, labels_flat)
#         output["loss"] = loss

#         return output

# class SeqDataset(Dataset):
#     def __init__(self, data_x, data_y):
#         self.data_x = torch.tensor(data_x, dtype=torch.long)
#         self.data_y = torch.tensor(data_y, dtype=torch.long)

#     def __len__(self):
#         return len(self.data_x)

#     def __getitem__(self, idx):
#         return {
#             'inputs': self.data_x[idx],
#             'labels': self.data_y[idx]
#         }

# def tokens_to_flat_onehot(token_ids, vocab_size):
#     device = token_ids.device

#     one_hot = torch.nn.functional.one_hot(token_ids, num_classes=vocab_size).to(device)
#     one_hot_flat = one_hot.view(token_ids.size(0), -1).float()

#     return one_hot_flat

# class OneHotCollator:
#     def __init__(self, vocab_size):
#         self.vocab_size = vocab_size

#     def __call__(self, batch):
#         inputs = torch.stack([item["inputs"] for item in batch])
#         labels = torch.stack([item["labels"] for item in batch])

#         one_hot_inputs = torch.nn.functional.one_hot(inputs, num_classes=self.vocab_size)
#         one_hot_inputs = one_hot_inputs.view(inputs.size(0), -1).float()

#         return {
#             "inputs": one_hot_inputs,
#             "labels": labels
#         }

class TransformerDecoderModel(nn.Module):
    """Token-level sequence model built from ``nn.TransformerDecoder`` layers.

    Token ids are embedded, summed with a learned positional embedding, run
    through a stack of decoder layers under a causal self-attention mask, and
    projected to per-position vocabulary logits. The embedded input is also
    passed as the decoder's ``memory`` argument, so every layer cross-attends
    to the same sequence it is decoding.

    Args:
        vocab_size: Number of distinct token ids; also the width of the logits.
        hidden_dim: Embedding / model dimension (``d_model``).
        d_ff: Width of each decoder layer's feed-forward sublayer.
        num_layers: Number of stacked decoder layers.
        nhead: Number of attention heads per layer.
        seq_len: Longest sequence the positional embedding table supports.
            Defaults to ``vocab_size``, matching the test runs where the two
            are equal.
    """

    def __init__(self, vocab_size, hidden_dim, d_ff, num_layers, nhead, seq_len=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_dim = hidden_dim
        # historical default: seq_len = vocab_size for the test runs
        self.seq_len = vocab_size if seq_len is None else seq_len

        # embed tokens in hidden dimension
        self.embedding = nn.Embedding(vocab_size, hidden_dim)
        # include positional encoding
        self.pos_encoder = PositionalEncoding(max_len=self.seq_len, d_model=hidden_dim)

        # define decoder portion of transformer
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nhead,
            dim_feedforward=d_ff,
            batch_first=True
        )
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.output_layer = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x):
        """Map a (batch, seq) tensor of token ids to (batch, seq, vocab_size) logits."""
        # size the causal mask from the actual input rather than the assumed
        # seq_len, so inputs shorter than the positional table still work
        seq_len = x.size(1)
        x = self.embedding(x)
        x = self.pos_encoder(x)
        # causal mask (applies to self-attention only; no memory mask is passed,
        # so cross-attention over `memory` sees every position)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(x.device)
        hidden = self.transformer_decoder(x, x, tgt_mask=tgt_mask)
        return self.output_layer(hidden)

# positional encoding for tokens in sequence
class PositionalEncoding(nn.Module):
    """Learned positional embedding that is added to an already-embedded batch.

    Args:
        max_len: Number of positions in the embedding table.
        d_model: Width of each positional vector (must match the input's last axis).
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        self.pos_embedding = nn.Embedding(max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the embedding of each position along its middle axis."""
        # x must be 3-D; only the middle (sequence) axis length is needed
        _batch, length, _width = x.size()
        position_ids = torch.arange(length, device=x.device).unsqueeze(0)
        # the (1, length, d_model) positional block broadcasts over the batch axis
        return x + self.pos_embedding(position_ids)

# prepares model for proper Hugging Face trainer format
class HFWrapper(nn.Module):
    """Wrap a logits-producing model so its output is a dict with ``logits`` and ``loss``.

    Args:
        model: Module called as ``model(inputs)`` that returns logits whose
            last axis has size ``vocab_size``.
        vocab_size: Number of classes per position; used to flatten the logits
            for the cross-entropy loss.
    """

    def __init__(self, model, vocab_size):
        super().__init__()
        self.model = model
        self.vocab_size = vocab_size
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, labels=None):
        """Run the wrapped model and, when labels are given, attach the loss.

        Args:
            inputs: Batch forwarded unchanged to the wrapped model.
            labels: Optional integer class targets with one entry per logit row
                after flattening. When omitted, no loss is computed.

        Returns:
            ``{"logits": ...}`` and, if ``labels`` was provided, also
            ``"loss"``: the cross-entropy averaged over every position.
        """
        logits = self.model(inputs)
        output = {"logits": logits}

        # labels default to None (e.g. pure inference), so only build the loss
        # when they are actually supplied instead of crashing on None
        if labels is not None:
            # reshape (unlike view) also accepts non-contiguous tensors
            logits_flat = logits.reshape(-1, self.vocab_size)
            labels_flat = labels.reshape(-1)
            output["loss"] = self.loss_fn(logits_flat, labels_flat)

        return output


# prepares data into proper Hugging Face trainer format
class SeqDataset(Dataset):
    """Dataset of paired input/label sequences stored as ``torch.long`` tensors.

    Args:
        data_x: Input sequences, convertible by ``torch.tensor``.
        data_y: Label sequences, convertible by ``torch.tensor``.
    """

    def __init__(self, data_x, data_y):
        as_long = {"dtype": torch.long}
        self.data_x = torch.tensor(data_x, **as_long)
        self.data_y = torch.tensor(data_y, **as_long)

    def __len__(self):
        # the dataset length is the number of input rows
        return len(self.data_x)

    def __getitem__(self, idx):
        """Return the ``idx``-th example as a dict with 'inputs' and 'labels'."""
        sample = {}
        sample['inputs'] = self.data_x[idx]
        sample['labels'] = self.data_y[idx]
        return sample

# prepares data into proper batches for training
class TokenCollator:
    """Stack per-example 'inputs' and 'labels' tensors into batch tensors."""

    def __call__(self, batch):
        # stack each field along a new leading batch axis
        return {
            field: torch.stack([example[field] for example in batch])
            for field in ("inputs", "labels")
        }

# generates training set given a fixed vocab size, sequence length, and subset size (dataset size)
def generate_data(vocab_size, seq_len, subset_size):
    """Build a random input -> label mapping of ``subset_size`` sequence pairs.

    Every sequence is a tuple of ``seq_len`` tokens drawn uniformly from
    ``0 .. vocab_size - 1``. The inputs are pairwise distinct, the labels are
    pairwise distinct, and no label sequence equals any input sequence.

    Args:
        vocab_size: Number of distinct token values.
        seq_len: Number of tokens in each sequence.
        subset_size: Number of (input, label) pairs to produce.

    Returns:
        ``(input_sequences, label_sequences)``: two lists of tuples, each of
        length ``subset_size``, paired index-wise.

    Raises:
        ValueError: If fewer than ``2 * subset_size`` distinct sequences exist,
            which would make the request impossible to satisfy.
    """
    # generate a sequence of random tokens of length seq_len
    def generate_sequence(vocab_size, seq_len):
        return tuple(random.randint(0, vocab_size - 1) for _ in range(seq_len))

    # inputs and labels together need 2 * subset_size different sequences;
    # fail fast instead of looping forever when that many do not exist
    total_sequences = vocab_size ** max(seq_len, 0)
    if subset_size > 0 and 2 * subset_size > total_sequences:
        raise ValueError(
            f"cannot draw {subset_size} distinct inputs and {subset_size} distinct, "
            f"disjoint labels from only {total_sequences} possible sequences"
        )

    # grow the set one draw at a time so a single collision never forces the
    # whole set to be regenerated
    input_sequences = set()
    while len(input_sequences) < subset_size:
        input_sequences.add(generate_sequence(vocab_size, seq_len))

    # labels must not coincide with any input; rejecting as we go also
    # guarantees exactly subset_size labels (duplicates cannot shrink the set)
    label_sequences = set()
    while len(label_sequences) < subset_size:
        candidate = generate_sequence(vocab_size, seq_len)
        if candidate not in input_sequences:
            label_sequences.add(candidate)

    input_sequences = list(input_sequences)
    label_sequences = list(label_sequences)
    # randomize the index-wise mapping
    random.shuffle(label_sequences)

    return input_sequences, label_sequences

# compute the fraction of individual tokens predicted correctly
# (reported as "sequence_accuracy"; it is 1.0 only when every token of every sequence matches)
def compute_metrics(eval_pred):
    """Return token-level accuracy for a ``(logits, labels)`` evaluation pair.

    The prediction at each position is the argmax over the last logits axis.
    Predictions and labels are flattened and compared element-wise, so the
    score is the share of matching tokens, not of fully-correct sequences.
    """
    logits, labels = eval_pred
    predicted_tokens = np.argmax(logits, axis=-1).flatten()
    matches = predicted_tokens == labels.flatten()
    return {"sequence_accuracy": matches.mean().item()}

# stop training when accuracy is 1 (perfect accuracy)
class StopOnPerfectAccuracyCallback(TrainerCallback):
    """Ask the trainer to stop once an evaluation reports an accuracy of exactly 1.0."""

    def on_evaluate(self, args, state: TrainerState, control: TrainerControl, metrics, **kwargs):
        # the metric is looked up under its "eval_"-prefixed name; a missing
        # key yields None, which never equals 1.0
        reached_perfect = metrics.get("eval_sequence_accuracy") == 1.0
        if reached_perfect:
            control.should_training_stop = True
        return control

def get_stats(dataset, model, collator):
    """Train ``model`` on ``dataset`` and report how well it fits that same data.

    The dataset serves as both the training and the evaluation set. Training
    ends at the epoch limit, when the perfect-accuracy callback fires, or when
    the early-stopping callback triggers on the evaluation loss.

    Args:
        dataset: Dataset used for training and for evaluation.
        model: Module handed to the Hugging Face ``Trainer``.
        collator: Callable that turns a list of examples into a batch.

    Returns:
        ``(accuracy, max_epoch_reached)``: the ``eval_sequence_accuracy`` from
        a final evaluation pass, and whether the trainer's epoch counter
        equals the configured number of epochs.
    """
    # training arguments - adjust as needed
    trainer_config = dict(
        output_dir="./results_32",
        report_to="none",
        # evaluate, log and checkpoint once per epoch
        eval_strategy="epoch",
        logging_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=1,
        # reload the checkpoint with the lowest evaluation loss at the end
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=True,
        # optimisation settings
        num_train_epochs=500,
        per_device_train_batch_size=32,
        weight_decay=0.01,
        learning_rate=5e-4, # use this learning rate for transformer
        # learning_rate=1e-2, # use this learning rate for MLP
    )
    training_args = TrainingArguments(**trainer_config)

    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(target_device)

    # training stops early if perfect accuracy is reached or validation loss plateaus
    stop_callbacks = [
        StopOnPerfectAccuracyCallback(),
        EarlyStoppingCallback(
            early_stopping_patience=3,
            early_stopping_threshold=0.0001,
        ),
    ]
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=collator,
        train_dataset=dataset,
        eval_dataset=dataset,
        compute_metrics=compute_metrics,
        callbacks=stop_callbacks,
    )

    trainer.train()
    hit_epoch_limit = (trainer.state.epoch == training_args.num_train_epochs)

    final_metrics = trainer.evaluate(dataset)
    print("Test Metrics:", final_metrics)
    return final_metrics.get("eval_sequence_accuracy"), hit_epoch_limit

def find_subset_size(hidden_size, d_ff, num_layers, num_head, VOCAB_SIZE, SEQ_LEN):
    """Search for the training-set size at which the model stops memorizing.

    The size is doubled from 2 until a freshly built model fails to reach
    perfect accuracy (or the size exceeds ``VOCAB_SIZE ** SEQ_LEN``). A binary
    search then runs between the last doubled size that succeeded and the
    first that failed. Every probe trains a new model on newly generated
    random data.

    Args:
        hidden_size: Model dimension of the transformer.
        d_ff: Feed-forward width of each decoder layer.
        num_layers: Number of decoder layers.
        num_head: Number of attention heads.
        VOCAB_SIZE: Number of distinct tokens.
        SEQ_LEN: Length of every generated sequence.

    Returns:
        ``(threshold, check_flag)``: the size the binary search converges to,
        and whether any probe (in either phase) ran to the maximum epoch.
    """
    def can_memorize(subset_size):
        """Train a fresh model on ``subset_size`` pairs; return (memorized, max_epoch_reached)."""
        print()
        print(f"Testing training size: {subset_size}")

        # create training set
        input_sequences, label_sequences = generate_data(VOCAB_SIZE, SEQ_LEN, subset_size)
        dataset = SeqDataset(input_sequences, label_sequences)
        collator = TokenCollator()

        # create model with input params
        base_model = TransformerDecoderModel(VOCAB_SIZE, hidden_size, d_ff, num_layers, num_head)
        model = HFWrapper(base_model, VOCAB_SIZE)

        # TO USE MLP, uncomment the lines below in place of the collator/model
        # lines above (and uncomment the MLP classes at the top of the file)
        # collator = OneHotCollator(VOCAB_SIZE)
        # base_model = MLP(VOCAB_SIZE, SEQ_LEN, num_layers, hidden_size)
        # model = HFWrapper(base_model, VOCAB_SIZE, SEQ_LEN)

        # get the model's accuracy on the given subset size
        test_accuracy, max_epoch_reached = get_stats(dataset, model, collator)

        # delete model for space; drop both references so the weights can
        # actually be released before the CUDA cache is emptied
        del model, base_model
        gc.collect()
        torch.cuda.empty_cache()

        return test_accuracy == 1.0, max_epoch_reached

    superset_size = VOCAB_SIZE ** SEQ_LEN
    # records whether any probe, in either phase, hit the maximum epoch
    check_flag = False

    max_subset_size = 2
    # find the upper bound for which the model fails to memorize
    while max_subset_size <= superset_size:
        can_mem, max_epoch_reached = can_memorize(max_subset_size)
        check_flag = check_flag or max_epoch_reached
        if not can_mem:
            break
        max_subset_size *= 2
    low = max_subset_size // 2
    high = min(max_subset_size, superset_size)

    # conduct binary search to find threshold
    while low < high:
        mid = (low + high) // 2
        can_mem, max_epoch_reached = can_memorize(mid)
        if can_mem:
            low = mid + 1
        else:
            high = mid
        check_flag = check_flag or max_epoch_reached

    threshold = low
    return threshold, check_flag


# adjust parameters as fit
# BASE_NUM_LAYERS and BASE_DFF are paired index-wise; each pair is run with
# every head count in BASE_NUM_HEADS
BASE_NUM_LAYERS = [1, 2, 3]
BASE_NUM_HEADS = [1, 2, 4, 8]
BASE_DFF = [380, 128, 40]
HIDDEN_SIZE = 32

VOCAB_SIZE = 10
SEQ_LEN = 10

# number of independent threshold searches per configuration
NUM_RUNS = 5


with open(f'transformer_results_{HIDDEN_SIZE}.csv', 'w', newline='') as f:
    writer = csv.writer(f)
    # one row per configuration: its settings, whether any search hit the
    # epoch limit, then the threshold found by each repeated search
    header = ['hidden_size', 'd_ff', 'num_heads', 'num_layers', 'max_epoch_reached']
    header += [f'subset{run}' for run in range(1, NUM_RUNS + 1)]
    writer.writerow(header)

    for num_layers, d_ff in zip(BASE_NUM_LAYERS, BASE_DFF):
        for num_head in BASE_NUM_HEADS:
            check_flag = False
            subset_sizes = []
            # repeat runs NUM_RUNS times
            for _ in range(NUM_RUNS):
                subset_size, check = find_subset_size(HIDDEN_SIZE, d_ff, num_layers, num_head, VOCAB_SIZE, SEQ_LEN)
                subset_sizes.append(subset_size)
                check_flag = check or check_flag
            writer.writerow([HIDDEN_SIZE, d_ff, num_head, num_layers, check_flag, *subset_sizes])
            # each row costs many training runs; flush so finished rows are
            # on disk even if the process is killed later
            f.flush()


Testing training size: 2


Epoch,Training Loss,Validation Loss,Sequence Accuracy
1,2.5273,2.409162,0.15
2,2.4603,2.335792,0.15
3,2.3826,2.265625,0.25
4,2.2848,2.19868,0.3
5,2.2244,2.134938,0.35
6,2.1187,2.073853,0.35
7,2.0897,2.015904,0.35
8,2.0324,1.960835,0.4
9,2.016,1.908605,0.45
10,1.9034,1.858921,0.5


Test Metrics: {'eval_loss': 0.7343631982803345, 'eval_sequence_accuracy': 1.0, 'eval_runtime': 0.0062, 'eval_samples_per_second': 321.378, 'eval_steps_per_second': 160.689, 'epoch': 44.0}

Testing training size: 4


KeyboardInterrupt: 