<a href="https://colab.research.google.com/github/newmantic/LLM_transfer_training/blob/main/LLM_transfer_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Install


In [1]:
!pip3 install transformers datasets numpy jaxtyping torch wandb

Collecting datasets
  Downloading datasets-2.21.0-py3-none-any.whl.metadata (21 kB)
Collecting jaxtyping
  Downloading jaxtyping-0.2.33-py3-none-any.whl.metadata (6.4 kB)
Collecting wandb
  Downloading wandb-0.17.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (10 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)
Collecting typeguard==2.13.3 (from jaxtyping)
  Downloading typeguard-2.13.3-py3-none-any.whl.metadata (3.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manyl

In [2]:
import copy
import random
import re
import time
import numpy as np
from typing import Any, Literal, TypedDict
from datasets import Dataset, DatasetDict
from jaxtyping import Int
import wandb

import torch
from torch import FloatTensor, Tensor, nn
from torch.optim import Adam
from tqdm.auto import tqdm
from torch.utils.data import DataLoader

import transformers
from transformers import (
    GenerationConfig,
    GPTNeoConfig,
    GPTNeoForCausalLM,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    AdamW,
    AutoTokenizer,
)
from transformers.modeling_outputs import (
    CausalLMOutputWithCrossAttentions,
)
from pathlib import Path

In [3]:
Path("./data/").mkdir(parents=True, exist_ok=True)

In [4]:
# Load the TinyStories model
tinystories_model = transformers.AutoModelForCausalLM.from_pretrained(
    "roneneldan/TinyStories-1M", revision="8cd14d5", cache_dir="./data/"
)

# Create a random version of this model (by re-calling the i)
random_init_model = transformers.AutoModelForCausalLM.from_pretrained(
    "roneneldan/TinyStories-1M", revision="8cd14d5", cache_dir="./data/"
)
random_init_model.apply(random_init_model._init_weights)  # noqa: SLF001

tokenizer = transformers.AutoTokenizer.from_pretrained(
    "roneneldan/TinyStories-1M",
    revision="8cd14d5",
    cache_dir="./data/",
    padding_side="left",  # Left padding so generate works
)
tokenizer.pad_token = tokenizer.eos_token

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/48.6M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/722 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/798k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

In [15]:
def generate_unary_counting_sequence(number: int) -> str:
    """
    Generates a unary counting sequence up to the given number.
    Ex: if the number is 4, it returns ' 1 11 111 1111'.
    """
    return ' ' + ' '.join(['1' * i for i in range(1, number + 1)])


In [7]:
generate_unary_counting_sequence(4)

'1 11 111 1111'

In [8]:
def convert_to_unary(number: int) -> str:
    """
    Converts an integer to its unary representation.

    For example, convert_to_unary(3) returns '111'.
    """
    return '1' * number

def create_dataset(n_data = 100) -> DatasetDict:
    """Create the training dataset.

    Each item should have keys "prompt" and "completion".
    """
    # Initialize lists to hold prompts and completions for both train and test sets
    train_prompts = []
    train_completions = []
    test_prompts = []
    test_completions = []

    # Loop over the numbers 1 through 100
    for i in range(1, n_data + 1):
        # Generate the unary counting sequence for the current number
        completion = generate_unary_counting_sequence(i)

        # Create the input prompt based on the unary sequence
        unary_stop = convert_to_unary(i)
        prompt = f"Please count up in unary, starting at 1 and stopping at {unary_stop}:"

        # Depending on whether the number ends in 1, 3, or 7 modulo 10, assign it to the test set
        if i % 10 in {1, 3, 7}:
            test_prompts.append(prompt)
            test_completions.append(completion)
        else:
            # Otherwise, assign it to the training set
            train_prompts.append(prompt)
            train_completions.append(completion)

    # Combine the prompts and completions into pairs
    train_data = list(zip(train_prompts, train_completions))
    test_data = list(zip(test_prompts, test_completions))

    # Shuffle the data to ensure randomness
    random.shuffle(train_data)
    random.shuffle(test_data)

    # Select the first 70 examples for the training set
    train_data = train_data[:70]
    # Select the first 30 examples for the testing set
    test_data = test_data[:30]

    # Unzip the pairs back into separate lists for prompts and completions
    train_prompts, train_completions = zip(*train_data)
    test_prompts, test_completions = zip(*test_data)

    # Create the train and test datasets from the prompts and completions
    train_dataset = Dataset.from_dict({"prompt": list(train_prompts), "completion": list(train_completions)})
    test_dataset = Dataset.from_dict({"prompt": list(test_prompts), "completion": list(test_completions)})

    # Combine the train and test datasets into a DatasetDict
    dataset = DatasetDict({"train": train_dataset, "test": test_dataset})

    # Return the final dataset
    return dataset



In [None]:
# Generate the dataset using the `create_dataset` function
n_data = 100
dataset = create_dataset(n_data)


In [18]:
# Test your implementation
example_train_prompt = "Please count up in unary, starting at 1 and stopping at 11:"
example_train_completion = " 1 11"
train_match = [i for i in dataset["train"] if i["prompt"] == example_train_prompt]  # type: ignore
assert len(train_match) == 1
print(train_match[0]["completion"], example_train_completion)
assert train_match[0]["completion"] == example_train_completion  # type: ignore

example_test_prompt = "Please count up in unary, starting at 1 and stopping at 111:"
unary_test_completion = " 1 11 111"
test_match = [i for i in dataset["test"] if i["prompt"] == example_test_prompt]  # type: ignore
assert len(test_match) == 1
assert test_match[0]["completion"] == unary_test_completion  # type: ignore

 1 11  1 11


In [19]:
def evaluate_model(
    model: PreTrainedModel,
    dataset: Dataset,
    pre_trained_tokenizer: PreTrainedTokenizerBase,
    batch_size: int = 8,
    context_window_size: int = 1536,
) -> float:
    """Evaluate the model.

    This function should return the accuracy of the model at temperature = 0, where accuracy is
    calculated using exact match on the whole of the target completion. i.e. for the prompt "Please
    count up in unary, starting at 1 and stopping at 11:", the model is correct if it replies with "
    1 11 111", and false otherwise.
    """
    # ensure the model is in evaluation mode
    # This disables certain layers like dropout, making the model deterministic and suitable for evaluation.
    model.eval()

    # initialize variables to track the number of correct predictions and the total number of examples
    correct_predictions = 0
    total_examples = len(dataset)

    # disable gradient computation for evaluation, with the following consequences:
    # a) it stays deterministic; b) it reduces memory usage and speeds up computation.
    with torch.no_grad():
        # iterate over the dataset in batches
        for i in range(0, total_examples, batch_size):
            # get a batch of examples from the dataset using slicing
            batch = dataset.select(range(i, min(i + batch_size, total_examples)))

            # extract prompts and the corresponding correct completions from the batch
            prompts = [example['prompt'] for example in batch]
            correct_completions = [example['completion'] for example in batch]
            # print(i, correct_completions) # double check correct_completions

            # tokenize the input prompts into input_ids and attention masks
            # this converts the prompts from strings to tensors that the model can process.
            inputs = pre_trained_tokenizer(prompts, return_tensors='pt', padding=True, truncation=True)
            input_ids = inputs['input_ids']
            attention_mask = inputs['attention_mask']

            # get the model's output logits for the input tokens, which are the raw & unnormalized predictions
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            # get the predicted token ids by selecting the token with the highest probability (argmax)
            # which is equal to generating the sequence with a temperature of 0.
            predicted_token_ids = logits.argmax(dim=-1)

            # decode the predicted token ids back into strings
            predicted_completions = [
                pre_trained_tokenizer.decode(predicted_token_ids[j], skip_special_tokens=False)
                for j in range(len(prompts))
            ]

            # compare the predicted completions with the correct completions and increment the correct counts
            # print out erroneous predictions of the first 1% of sample
            for predicted, correct in zip(predicted_completions, correct_completions):
                if predicted == correct:
                    correct_predictions += 1
                elif i <= np.ceil(total_examples*0.01): # show error prediction for 1% of samples
                    print('sampled generation: ', i, predicted[:-100], '\n')

    # Calculate accuracy as the ratio of correct predictions to total examples
    accuracy = correct_predictions / total_examples
    return accuracy

In [22]:
# Create an appropriate dataset
n_full_data = 100
full_text_dataset = create_dataset(n_full_data)

In [23]:
def train_model(
    model: PreTrainedModel,
    text_dataset: DatasetDict,
    learning_rate: float = 0.001,
    num_epochs: int = 3,
    batch_size: int = 1,
    pre_trained_tokenizer: PreTrainedTokenizerBase = None,
) -> None:
    """Train the model with the given dataset and learning rate.

    Args:
        model: The pre-trained language model to fine-tune.
        text_dataset: A DatasetDict containing 'train' and 'test' datasets.
        learning_rate: The learning rate for the optimizer.
        num_epochs: The number of epochs to train the model.
        batch_size: The number of samples per batch.
        pre_trained_tokenizer: The tokenizer associated with the pre-trained model.

    Returns:
        None: The function trains the model and logs metrics, but does not return any values.
    """

    # Set the model to training mode, enabling gradient calculations
    model.train()

    # Prepare the optimizer with the specified learning rate
    optimizer = AdamW(model.parameters(), lr=learning_rate)

    # Create data loaders for the training and validation datasets
    train_loader = DataLoader(text_dataset['train'], batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(text_dataset['test'], batch_size=batch_size)

    # Start the training loop, iterating over the number of epochs
    for epoch in range(num_epochs):
        total_loss = 0  # Initialize the total loss for the epoch
        model.train()  # Ensure the model is in training mode

        # Iterate over each batch in the training data
        for batch in train_loader:
            optimizer.zero_grad()  # Zero the gradients from the previous step

            # Tokenize the input prompts and target completions
            inputs = pre_trained_tokenizer(batch['prompt'], return_tensors='pt', padding=True, truncation=True, max_length=1024)
            labels = pre_trained_tokenizer(batch['completion'], return_tensors='pt', padding=True, truncation=True, max_length=1024).input_ids

            # Ensure that both input_ids and labels are truncated to the same length
            min_len = min(inputs['input_ids'].shape[1], labels.shape[1])
            input_ids = inputs['input_ids'][:, :min_len]
            labels = labels[:, :min_len]

            # Set padding tokens in labels to -100 so they are ignored in the loss calculation
            labels[labels == pre_trained_tokenizer.pad_token_id] = -100

            # Perform a forward pass to compute the model's output and loss
            outputs = model(input_ids=input_ids, attention_mask=inputs['attention_mask'][:, :min_len], labels=labels)
            loss = outputs.loss

            # Perform a backward pass to compute gradients
            loss.backward()

            # Update model parameters based on the gradients
            optimizer.step()

            # Accumulate the loss for this batch
            total_loss += loss.item()

        # Compute the average loss over all batches in the epoch
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

        # Validation step: Evaluate the model on the validation dataset
        model.eval()  # Set the model to evaluation mode (disables dropout layers, etc.)
        correct_predictions = 0  # Initialize a counter for correct predictions
        total_predictions = 0  # Initialize a counter for total predictions

        # Disable gradient calculations for validation (improves efficiency)
        with torch.no_grad():
            for batch in val_loader:
                # Tokenize the validation input prompts and completions
                inputs = pre_trained_tokenizer(batch['prompt'], return_tensors='pt', padding=True, truncation=True, max_length=128)
                labels = pre_trained_tokenizer(batch['completion'], return_tensors='pt', padding=True, truncation=True, max_length=128).input_ids

                # Ensure that both input_ids and labels are truncated to the same length
                min_len = min(inputs['input_ids'].shape[1], labels.shape[1])
                input_ids = inputs['input_ids'][:, :min_len]
                labels = labels[:, :min_len]

                # Set padding tokens in labels to -100 so they are ignored in the loss calculation
                labels[labels == pre_trained_tokenizer.pad_token_id] = -100

                # Perform a forward pass to get model predictions
                outputs = model(input_ids=input_ids, attention_mask=inputs['attention_mask'][:, :min_len])
                logits = outputs.logits

                # Truncate logits to match the labels' sequence length
                logits = logits[:, :labels.size(1), :]

                # Calculate predictions and accuracy
                predictions = logits.argmax(dim=-1)  # Get the token with the highest probability for each position

                for pred, label in zip(predictions, labels):
                    # Filter out padding tokens (-100) before decoding
                    pred = pred[pred != -100]
                    label = label[label != -100]

                    if len(pred) > 0 and len(label) > 0:  # Ensure there are tokens left to decode
                        pred_text = pre_trained_tokenizer.decode(pred, skip_special_tokens=False)
                        label_text = pre_trained_tokenizer.decode(label, skip_special_tokens=False)
                        if pred_text == label_text:
                            correct_predictions += 1
                        total_predictions += 1

        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
        print(f"Validation Accuracy after Epoch {epoch + 1}: {accuracy:.4f}")

    print("Training completed.")


In [24]:
t1 = time.time()

# Train the randomly initialized model
train_model(
    model=random_init_model,
    text_dataset=full_text_dataset,
    learning_rate=0.001,  # A reasonable learning rate for quick testing
    num_epochs=3,  # Limited epochs to ensure the training completes quickly
    batch_size=1,  # Small batch size to keep training time under 5 minutes
    pre_trained_tokenizer=tokenizer
)

# Train the TinyStories model (which is already pre-trained)
train_model(
    model=tinystories_model,
    text_dataset=full_text_dataset,
    learning_rate=0.0001,  # Lower learning rate to fine-tune the pre-trained model gently
    num_epochs=3,  # Same number of epochs for consistency
    batch_size=1,  # Same batch size to keep the comparison fair
    pre_trained_tokenizer=tokenizer
)

t2 = time.time()

print(f"Model training time: {int((t2-t1)/60)} minutes")




Epoch 1/3, Loss: 6.8525
Validation Accuracy after Epoch 1: 0.0000
Epoch 2/3, Loss: 0.8240
Validation Accuracy after Epoch 2: 0.0000
Epoch 3/3, Loss: 0.2115
Validation Accuracy after Epoch 3: 0.0000
Training completed.
Epoch 1/3, Loss: 8.0821
Validation Accuracy after Epoch 1: 0.0000
Epoch 2/3, Loss: 5.0144
Validation Accuracy after Epoch 2: 0.0000
Epoch 3/3, Loss: 2.8975
Validation Accuracy after Epoch 3: 0.0000
Training completed.
Model training time: 0 minutes
