<a href="https://colab.research.google.com/github/chrisjmccormick/shared-subspaces/blob/main/subspace_decoder/scripts/run_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ▂▂▂▂▂▂▂▂▂▂▂▂

# Overview

This notebook demonstrates how to run the pre-training and fine-tuning scripts from the command line. It is also set up so that you can run them from within the notebook.

**Training Arguments**

The scripts are designed so that everything is specified through `.json` config files rather than on the command line. However, there is a command-line utility for defining new configurations. See the "Defining a New Run" section at the end of this notebook.

The examples below run one of the existing configurations.

**Weights and Biases**

The scripts log to wandb by default. If you don't have an account, or don't want to log online, set `use_wandb = False` in the Weights & Biases setup cell below. This sets `wandb_mode` to 'offline'.


# ▂▂▂▂▂▂▂▂▂▂▂▂

# S1. Setup

In [1]:
# Set a flag we can use to determine if we are running within a Colab instance.
is_colab = "google.colab" in str(get_ipython())
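The check above relies on the string form of the IPython shell object. Below is a minimal sketch of an alternative. It assumes that the `google.colab` package is importable on Colab runtimes and is not installed anywhere else.

```python
import importlib.util


def running_in_colab() -> bool:
    """Return True if the google.colab package can be found.

    Assumption: google.colab is only installed on Colab runtimes.
    """
    try:
        return importlib.util.find_spec("google.colab") is not None
    except ModuleNotFoundError:
        # find_spec raises this when the parent package ("google") is missing.
        return False


print("Running on Colab:", running_in_colab())
```

This version doesn't need `get_ipython()`, so it also works inside a plain Python script.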

## 1.1. Clone Repository

In [2]:
!git clone https://github.com/chrisjmccormick/shared-subspaces.git

Cloning into 'shared-subspaces'...
remote: Enumerating objects: 647, done.
remote: Counting objects: 100% (384/384), done.
remote: Compressing objects: 100% (223/223), done.
remote: Total 647 (delta 244), reused 270 (delta 159), pack-reused 263 (from 1)
Receiving objects: 100% (647/647), 11.06 MiB | 9.97 MiB/s, done.
Resolving deltas: 100% (371/371), done.


Provide the full path to the subspace_decoder folder.

This will be added to the PYTHONPATH when executing the scripts so that they can import the classes from the local files.

This variable is also used to construct paths to config files and scripts.

In [3]:
# For Google Colab,
base_path = "/content/shared-subspaces/subspace_decoder"

# For Lambda / Ubuntu,
#base_path = "/home/ubuntu/shared-subspaces/subspace_decoder"

To switch to a development branch:

In [4]:
%cd shared-subspaces

# Change to branch 'gpt-2-scale'
!git checkout gpt-2-scale

/content/shared-subspaces
Branch 'gpt-2-scale' set up to track remote branch 'gpt-2-scale' from 'origin'.
Switched to a new branch 'gpt-2-scale'


## 1.2. FlashAttention on Colab

FlashAttention does not come pre-installed on Colab instances. Installing it manually is very time-consuming because it has to be built from source.

The GitHub repo below, however, provides pre-built wheels that make setup easy.

https://github.com/mjun0812/flash-attention-prebuild-wheels/releases

The only challenge is identifying the correct wheel in the very long list of releases.

We need the wheel that matches our versions of Python, PyTorch, and CUDA. First we'll print those out:

In [5]:
import torch
import sys

print("Python:", sys.version)
print("PyTorch:", torch.__version__)
print("CUDA:", torch.version.cuda)
print("")
print("GPU:", torch.cuda.get_device_name(0))

Python: 3.12.11 (main, Jun  4 2025, 08:56:18) [GCC 11.4.0]
PyTorch: 2.8.0+cu126
CUDA: 12.6

GPU: NVIDIA A100-SXM4-80GB


It's difficult to find the correct wheel because the wheels are spread across many different releases. Searching the page also doesn't find them unless each release's asset list is expanded.
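One way to narrow the search is to compute the filename you expect and then look for it. The sketch below infers the naming pattern from the single wheel used in this notebook (`flash_attn-2.8.3+cu126torch2.8-cp312-cp312-linux_x86_64.whl`), so other releases may not follow it exactly. The function name is hypothetical. You still have to pick the FlashAttention version (2.8.3 here) and find the release tag (v0.4.11 here) by hand.

```python
def flash_attn_wheel_name(flash_version: str, torch_version: str,
                          cuda_version: str, python_version: tuple) -> str:
    """Build the expected wheel filename (pattern inferred from one example)."""
    # "2.8.0+cu126" -> "2.8": keep major.minor and drop the local "+cu..." suffix.
    torch_major_minor = ".".join(torch_version.split("+")[0].split(".")[:2])
    # "12.6" -> "cu126"
    cuda_tag = "cu" + cuda_version.replace(".", "")
    # (3, 12) -> "cp312"
    py_tag = f"cp{python_version[0]}{python_version[1]}"
    # Assumes a Linux x86_64 runtime, as on Colab.
    return (f"flash_attn-{flash_version}+{cuda_tag}torch{torch_major_minor}"
            f"-{py_tag}-{py_tag}-linux_x86_64.whl")


# Values taken from the version printout above.
print(flash_attn_wheel_name("2.8.3", "2.8.0+cu126", "12.6", (3, 12)))
# flash_attn-2.8.3+cu126torch2.8-cp312-cp312-linux_x86_64.whl
```

Inside the notebook you could pass `torch.__version__`, `torch.version.cuda`, and `sys.version_info[:2]` directly instead of typing the values.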

With some hunting, I was able to find the correct version for Colab's current configuration:

In [6]:
# This wheel is specific to Colab.
if is_colab:
    # Define the wheel details
    WHEEL_URL = "https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu126torch2.8-cp312-cp312-linux_x86_64.whl"
    WHEEL_NAME = "flash_attn-2.8.3+cu126torch2.8-cp312-cp312-linux_x86_64.whl"

    # Download and install the wheel
    !wget {WHEEL_URL}
    !pip install {WHEEL_NAME}

    # Clean up the downloaded file
    import os
    os.remove(WHEEL_NAME)

    print("\n✅ FlashAttention 2 installed successfully!")

--2025-10-03 16:43:22--  https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.4.11/flash_attn-2.8.3+cu126torch2.8-cp312-cp312-linux_x86_64.whl
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://release-assets.githubusercontent.com/github-production-release-asset/878958395/9d318eca-33f7-4c0d-9b3f-6e90e52421f8?sp=r&sv=2018-11-09&sr=b&spr=https&se=2025-10-03T17%3A43%3A34Z&rscd=attachment%3B+filename%3Dflash_attn-2.8.3%2Bcu126torch2.8-cp312-cp312-linux_x86_64.whl&rsct=application%2Foctet-stream&skoid=96c2d410-5711-43a1-aedd-ab1947aa7ab0&sktid=398a6654-997b-47e9-b12b-9515b896b4de&skt=2025-10-03T16%3A42%3A47Z&ske=2025-10-03T17%3A43%3A34Z&sks=b&skv=2018-11-09&sig=Cw8AAOdQqzVHHx35ixkgLS9q60UQfUFb0Le6AM17ytE%3D&jwt=eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmVsZWFzZS1hc3NldHMuZ2l0aHVidXNlcmNvbnRlbn
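If the wheel doesn't install, or you're on a machine without FlashAttention, a config that sets `"attention_backend": "flash_attention_2"` will presumably fail at model initialization. Below is a minimal sketch of a fallback. It rests on three assumptions: the scripts also accept `"sdpa"` (the value used in the config shown in the Appendix), the backend is read from `config["model"]["attention_backend"]`, and the installed wheel's import name is `flash_attn`. The helper names are hypothetical.

```python
import copy
import importlib.util


def flash_attn_available() -> bool:
    """True if a package named flash_attn can be found on the import path."""
    return importlib.util.find_spec("flash_attn") is not None


def with_attention_fallback(config: dict, flash_available=None) -> dict:
    """Return a copy of config whose attention backend falls back to "sdpa".

    Assumption: config["model"]["attention_backend"] is the field the scripts read.
    """
    if flash_available is None:
        flash_available = flash_attn_available()
    new_config = copy.deepcopy(config)
    model_cfg = new_config.get("model", {})
    if model_cfg.get("attention_backend") == "flash_attention_2" and not flash_available:
        model_cfg["attention_backend"] = "sdpa"
    return new_config
```

The scripts read their settings from a file, so you would need to write the returned dict to a new `.json` file and pass that path to `--config`.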

## 1.3. Weights & Biases

To provide your wandb API key to the scripts, either:
1. Paste it in manually on the training command lines further down, or
2. Use the Secrets panel (the key icon on the left edge of the notebook) and:
    * Define your wandb API key as `wandb_api_key`.
    * Grant this notebook access to it.
    * Run the cell below to retrieve it.

Outside of Colab, the cell below reads `wandb_api_key` from a `.env` file using python-dotenv.

In [7]:
# Set to False if you don't want to log to wandb online.
# The scripts will still log using the wandb library, but to a local directory.
use_wandb = True

if use_wandb:

    # Enable Weights & Biases logging (online mode)
    wandb_mode = "online"

    if is_colab:
        # Get key from colab secrets
        from google.colab import userdata

        # Get wandb API key from Colab secrets
        wandb_key = userdata.get("wandb_api_key")
    else:
        # Get the key from the environment (or a local .env file).
        import os

        # Requires the python-dotenv package.
        from dotenv import load_dotenv

        # Load the environment variables from the .env file
        load_dotenv()

        wandb_key = os.getenv("wandb_api_key")

# Set to offline if you don't want to log in.
else:
    wandb_mode = "offline"

    wandb_key = ""
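One possible refinement is to fall back to offline mode when no key is found. Otherwise the scripts would start in online mode with an empty key. This is a minimal sketch, and the helper name is hypothetical. Besides the notebook's own `wandb_api_key` variable, it also checks `WANDB_API_KEY`, the environment variable that the wandb library itself reads.

```python
import os


def resolve_wandb_settings(use_wandb: bool, key=None):
    """Return (mode, key), switching to offline mode when no API key is found."""
    if not use_wandb:
        return "offline", ""
    # Prefer an explicitly passed key, then the environment.
    key = key or os.environ.get("wandb_api_key") or os.environ.get("WANDB_API_KEY")
    if not key:
        print("No wandb API key found; falling back to offline mode.")
        return "offline", ""
    return "online", key
```

Usage after the cell above: `wandb_mode, wandb_key = resolve_wandb_settings(use_wandb, wandb_key)`.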

# ▂▂▂▂▂▂▂▂▂▂▂▂

# S2. Run a Single Config

## 2.1. Choose Config

In [12]:
import os
import json

# Choose which config file to run.
pretrain_config_path = f"{base_path}/configs/tiny_mla_wiki103.json"

# Make sure it's a valid path
if not os.path.exists(pretrain_config_path):
    raise ValueError(f"Config file {pretrain_config_path} does not exist.")

# Print it out.
with open(pretrain_config_path, "r") as f:
    pretrain_config = json.load(f)

print(f"\n======== {pretrain_config_path} ========\n")

# Print out the configuration with spacing.
json_str = json.dumps(pretrain_config, indent=4)
print(json_str)



{
    "shorthand": "model.256.lyr.6 - seqlen.128 - mla.96.64.0 - ah8.32 - rd16.16",
    "notes": "Tiny MLA baseline configuration trained on wikitext103",
    "model": {
        "hidden_size": 256,
        "num_hidden_layers": 6,
        "intermediate_size": 672,
        "vocab_size": 50257,
        "tie_word_embeddings": true,
        "max_position_embeddings": 128,
        "norm_type": "rmsnorm",
        "layer_norm_eps": 1e-12,
        "rms_norm_eps": 1e-06,
        "num_dense_layers": 0,
        "num_attention_heads": 8,
        "q_shared_dim": 96,
        "kv_shared_dim": 64,
        "o_shared_dim": null,
        "qk_private_dim": 32,
        "vo_private_dim": 32,
        "rope_dims": 16,
        "nope_dims": 16,
        "rope_theta": 10000.0,
        "rope_scaling": null,
        "attention_bias": false,
        "attention_backend": "flash_attention_2",
        "ffn_decompose": false,
        "ffn_rank": null,
        "vocab_subspace": false,
        "vocab_rank": null,
       

## 2.2. Run Pre-Training

In [None]:
print("\n======= Pre-Train ========\n")

# Construct the command line
train_command = (
    f"TRANSFORMERS_NO_TF=1 "
    f"PYTHONPATH={base_path} "
    f"WANDB_MODE={wandb_mode} "
    f'WANDB_API_KEY="{wandb_key}" '
    f"python {base_path}/scripts/train.py --config {pretrain_config_path}"
)

# Run pre-training
!{train_command}



TF available (Transformers thinks): False
Importing Packages...

PROJECT_ROOT /content/shared-subspaces/subspace_decoder
model.256.lyr.6 - seqlen.128 - mla.96.64.0 - ah8.32 - rd16.16
✓ BFloat16 is supported on this hardware
✓ torch.compile enabled:
  Backend: inductor
  Mode: default
  Note: First training step will be slower due to compilation.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mchrismccormick[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 1801350
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})
Initializing model...
config.q_shared_dim 96
config.q_shared_dim 96
config.q_shared_dim 96
config.q_shared_dim 96
config.q_shared_dim 96
config.q_shar
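The command above wraps the API key in double quotes but leaves the paths unquoted. Below is a slightly more defensive builder that uses `shlex.quote`, so paths or keys containing spaces or shell metacharacters are passed through intact. The function name and signature are hypothetical. Only the environment variables and the `--config` flag come from the cell above.

```python
import shlex


def build_train_command(base_path: str, config_path: str,
                        wandb_mode: str, wandb_key: str) -> str:
    """Assemble the pre-training command line with shell-safe quoting."""
    env_prefix = " ".join([
        "TRANSFORMERS_NO_TF=1",
        f"PYTHONPATH={shlex.quote(base_path)}",
        f"WANDB_MODE={shlex.quote(wandb_mode)}",
        # An empty key becomes '' so the assignment stays well-formed.
        f"WANDB_API_KEY={shlex.quote(wandb_key)}",
    ])
    script = shlex.quote(f"{base_path}/scripts/train.py")
    return f"{env_prefix} python {script} --config {shlex.quote(config_path)}"
```

Usage: `train_command = build_train_command(base_path, pretrain_config_path, wandb_mode, wandb_key)`, followed by `!{train_command}` as before.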

**Optional: Save the checkpoints to Google Drive**

In [None]:
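import shutil

# Set to True to copy the checkpoints to Google Drive.
save_to_drive = False

if save_to_drive:
    # Mount Google Drive (prompts for authorization on Colab).
    from google.colab import drive
    drive.mount("/content/drive")

    # Copy whatever's in the checkpoints folder over to Google Drive.
    shutil.copytree(
        "/content/checkpoints",
        "/content/drive/MyDrive/decoder-pretrain-wiki103/checkpoints",
        dirs_exist_ok=True,
    )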


## 2.3. Fine-Tune on SST-2

In [None]:
import json
import os

# =================
#    SFT Config
# =================

# Specify the path to the SFT config. (The pre-trained model config is retrieved below.)
sft_config_path = f"{base_path}/configs/fine-tune-sst2.json"

# Make sure paths are valid
if not os.path.exists(sft_config_path):
    raise ValueError(f"SFT config file {sft_config_path} does not exist.")

# Print out the SFT configuration
with open(sft_config_path, "r") as f:
    sft_config = json.load(f)

print(f"\n======== SFT Config: {sft_config_path} ========\n")
json_str = json.dumps(sft_config, indent=4)
print(json_str)

# ================================
#     Pretrained Model Config
# ================================

# Retrieve the config from the pre-training result

# Verify the path still exists.
if not os.path.exists(pretrain_config_path):
    raise ValueError(f"Pre-training config file {pretrain_config_path} does not exist.")

# Load the original pre-training configuration,
with open(pretrain_config_path, "r") as f:
    pretrain_config = json.load(f)

# Then retrieve the path to the resulting `full_config.json` based on
# the training output directory.
saved_model_cfg_path = pretrain_config["pre_train"]["output_dir"] + '/full_config.json'

# Review the updated configuration.
print(f"\n======== Config From Saved Model: {saved_model_cfg_path} ========\n")
with open(saved_model_cfg_path, "r") as f:
    print(json.dumps(json.load(f), indent=4))

# ===================================
#         Run fine-tuning
# ===================================
print("\n======= Fine-Tune on SST-2 ========\n")

# Construct the command line with separate config files
train_command = (
    f"TRANSFORMERS_NO_TF=1 "
    f"PYTHONPATH={base_path} "
    f"WANDB_MODE={wandb_mode} "
    f'WANDB_API_KEY="{wandb_key}" '
    f"python {base_path}/scripts/finetune_sst2.py --sft_config {sft_config_path} --model_config {saved_model_cfg_path}"
)

# Run fine-tuning
!{train_command}
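If pre-training didn't finish, or wrote its output somewhere else, the cell above fails with a bare `FileNotFoundError` while printing the saved config. Below is a minimal sketch that checks for the file up front and gives a clearer message. It uses the same assumption as the cell above: `full_config.json` lands in `pre_train.output_dir`, and a relative `output_dir` is resolved against the current working directory. The helper name is hypothetical.

```python
import json
import os


def saved_model_config_path(pretrain_config_path: str) -> str:
    """Locate the full_config.json written by a finished pre-training run."""
    with open(pretrain_config_path, "r") as f:
        pretrain_config = json.load(f)
    output_dir = pretrain_config["pre_train"]["output_dir"]
    path = os.path.join(output_dir, "full_config.json")
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"{path} not found; has pre-training for {pretrain_config_path} finished?"
        )
    return path
```

Usage: `saved_model_cfg_path = saved_model_config_path(pretrain_config_path)`.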

# ▂▂▂▂▂▂▂▂▂▂▂▂

# S3. Run Multiple Configs

The cell below combines the steps above so that you can run several experiments (pre-training followed by SST-2 fine-tuning) in sequence.

> Note: In my experience, any run that takes an hour or more is impractical on Colab, even with Colab Pro. The notebook usually becomes sluggish and eventually crashes. I suspect, but haven't confirmed, that this is caused by the volume of output printed to the cell.
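If output volume is the cause, one workaround is to send each run's output to a log file instead of the notebook cell. Below is a minimal sketch that uses `subprocess.run` with an explicit environment dict. The helper name and log-file naming are hypothetical. The environment variables match the command lines used in this notebook.

```python
import os
import subprocess


def run_logged(args, log_path, extra_env=None):
    """Run a command, writing stdout and stderr to log_path instead of the cell.

    Returns the process exit code and prints only a one-line summary.
    """
    env = os.environ.copy()
    env.update(extra_env or {})
    with open(log_path, "w") as log_file:
        result = subprocess.run(args, stdout=log_file, stderr=subprocess.STDOUT,
                                env=env, check=False)
    print(f"{' '.join(args)} -> exit code {result.returncode}, log: {log_path}")
    return result.returncode


# Example use inside the loop below (names come from the notebook):
# run_logged(
#     ["python", f"{base_path}/scripts/train.py", "--config", pretrain_config_path],
#     log_path="train_log.txt",
#     extra_env={"TRANSFORMERS_NO_TF": "1", "PYTHONPATH": base_path,
#                "WANDB_MODE": wandb_mode, "WANDB_API_KEY": wandb_key},
# )
```

You can then check on progress with `!tail -n 20 train_log.txt`, which keeps the notebook's own output small.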

In [None]:
pretrain_config_paths = [
    f"{base_path}/configs/gpt-2_mha.json",
    f"{base_path}/configs/gpt-2_mla_192.96.192.json",
    f"{base_path}/configs/gpt-2_mla_192.96.0.json",
    f"{base_path}/configs/gpt-2_mla_0.96.192.json",
]

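# SFT config used for every run (the same file as in Section 2.3).
sft_config_path = f"{base_path}/configs/fine-tune-sst2.json"

# Quickly validate the filenames.
for pretrain_config_path in pretrain_config_paths:
    if not os.path.exists(pretrain_config_path):
        raise ValueError(f"Model config file {pretrain_config_path} does not exist.")

if not os.path.exists(sft_config_path):
    raise ValueError(f"SFT config file {sft_config_path} does not exist.")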


# For each of the pre-training configurations...
for pretrain_config_path in pretrain_config_paths:

    # ===========================
    #      Run pre-training
    # ===========================
    print("\n======= Pre-Train ========\n")

    # Construct the command line.
    # Optionally append " > train_log.txt 2>&1" to send the output to a log file.
    train_command = (
        f"TRANSFORMERS_NO_TF=1 "
        f"PYTHONPATH={base_path} "
        f"WANDB_MODE={wandb_mode} "
        f'WANDB_API_KEY="{wandb_key}" '
        f"python {base_path}/scripts/train.py --config {pretrain_config_path}"
    )

    # Run pre-training
    !{train_command}


    # ===================================
    #    Retrieve the updated config
    # ===================================
    with open(pretrain_config_path, "r") as f:
        pretrain_config = json.load(f)

    saved_model_cfg_path = pretrain_config["pre_train"]["output_dir"] + '/full_config.json'

    print(f"\n======== Config From Saved Model: {saved_model_cfg_path} ========\n")
    with open(saved_model_cfg_path, "r") as f:
        print(json.dumps(json.load(f), indent=4))


    # ===================================
    #         Run fine-tuning
    # ===================================
    print("\n======= Fine-Tune on SST-2 ========\n")

    # Construct the command line with separate config files
    train_command = (
        f"TRANSFORMERS_NO_TF=1 "
        f"PYTHONPATH={base_path} "
        f"WANDB_MODE={wandb_mode} "
        f'WANDB_API_KEY="{wandb_key}" '
        f"python {base_path}/scripts/finetune_sst2.py --sft_config {sft_config_path} --model_config {saved_model_cfg_path}"
    )

    # Run fine-tuning
    !{train_command}


# ▂▂▂▂▂▂▂▂▂▂▂▂

# Appendix

## Defining a New Run

To modify the parameters from the command line, I created a command-line utility, `/configs/create_new_config.py`. It copies one of the existing config files and lets you specify any parameter changes.

See the [script](https://github.com/chrisjmccormick/shared-subspaces/blob/main/subspace_encoder/configs/create_new_config.py) for documentation, check out the baseline config [here](https://github.com/chrisjmccormick/shared-subspaces/blob/main/subspace_encoder/configs/best_mla-o.json) to see all of the hyperparameters that are defined, and see the [Config](https://github.com/chrisjmccormick/shared-subspaces/blob/main/subspace_encoder/models/shared_space_config.py#L81) class for documentation of the model parameters.
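The core operation here is copying a base config and applying `section.key=value` overrides. Below is a minimal sketch of that kind of override. It is not the script's actual code. It assumes that dotted keys address nested sections of the JSON config and that values are parsed as JSON where possible. The function name is hypothetical.

```python
import copy
import json


def apply_overrides(base_config: dict, overrides: list) -> dict:
    """Return a copy of base_config with dotted "section.key=value" overrides applied.

    Values are parsed as JSON when possible (96 -> int, null -> None, true -> True),
    otherwise kept as plain strings.
    """
    new_config = copy.deepcopy(base_config)
    for item in overrides:
        dotted_key, raw_value = item.split("=", 1)
        try:
            value = json.loads(raw_value)
        except json.JSONDecodeError:
            value = raw_value
        *parents, leaf = dotted_key.split(".")
        node = new_config
        # Walk (or create) the nested sections, then set the final key.
        for name in parents:
            node = node.setdefault(name, {})
        node[leaf] = value
    return new_config


base = {"model": {"o_latent_dim": 64}}
print(json.dumps(apply_overrides(base, ["model.o_latent_dim=96"]), indent=2))
```

Deep-copying the base config leaves the original dict untouched. This mirrors copying the base file to a new config rather than editing it in place.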

Below is an example of defining a new run that increases the output latent size to 96.

In [None]:
!python {base_path}/configs/create_new_config.py \
    mla-o_baseline_o96 \
    --base {base_path}/configs/mla-o_baseline.json \
    --shorthand "rd.32 - 6.mla.64.32.96 - mlp.1024 - model.256.lyr.6 - ah.8.32" \
    --notes "Trying increasing the output subspace size from 64 to 96" \
    --set model.o_latent_dim=96

Wrote new config to /content/shared-subspaces/subspace_encoder/configs/mla-o_baseline_o96.json


In [None]:
!cat {base_path}/configs/mla-o_baseline_o96.json

{
  "shorthand": "rd.32 - 6.mla.64.32.96 - mlp.1024 - model.256.lyr.6 - ah.8.32",
  "notes": "Trying increasing the output subspace size from 64 to 96",
  "model": {
    "hidden_size": 256,
    "num_hidden_layers": 6,
    "intermediate_size": 1024,
    "hidden_dropout_prob": 0.1,
    "attention_dropout_prob": 0.1,
    "classifier_dropout": null,
    "initializer_range": 0.02,
    "layer_norm_eps": 1e-12,
    "rms_norm_eps": 1e-06,
    "vocab_size": 30522,
    "rope_theta": 10000.0,
    "rope_scaling": null,
    "max_position_embeddings": 128,
    "num_dense_layers": 0,
    "q_latent_dim": 64,
    "kv_latent_dim": 32,
    "num_attention_heads": 8,
    "head_dim": 32,
    "rope_dims": 32,
    "attention_bias": false,
    "output_subspace": true,
    "o_latent_dim": 96,
    "attention_backend": "sdpa",
    "ffn_decompose": false,
    "ffn_rank": null,
    "vocab_subspace": false,
    "vocab_rank": 128
  },
  "pre_train": {
    "output_dir": "checkpoints/mla-o_baseline_o96",
    "seed": 42

# ▂▂▂▂▂▂▂▂▂▂▂▂