# In-Context Learning Training on Google Colab

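This notebook trains GPT-2-style transformer models (or LSTMs) to perform in-context learning of simple function classes such as linear regression, sparse linear regression, and decision trees.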

**Setup Instructions:**
1. Runtime → Change runtime type → GPU (T4, L4, A100, or V100)
2. Run cells sequentially
3. Authenticate with Weights & Biases when prompted

**Note:** This notebook uses the new YAML-based configuration system (no quinine dependency).


## 1. Check GPU and Python Environment


In [1]:
# Check GPU availability
!nvidia-smi

import sys
print(f"\nPython version: {sys.version}")


Tue Nov 11 06:50:53 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   56C    P8             13W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## 2. Install Required Packages


In [2]:
# Install required packages
print("Installing packages...\n")

# Core ML packages
%pip install -q "transformers>=4.30.0"  # quoted so the shell does not treat ">" as a redirect
%pip install -q wandb
%pip install -q xgboost
%pip install -q matplotlib seaborn tqdm
%pip install -q pyyaml
%pip install -q munch

# PyTorch usually comes pre-installed in Colab
try:
    import torch
    print(f"✓ PyTorch already installed: {torch.__version__}")
except ImportError:
    print("Installing PyTorch...")
    %pip install -q torch torchvision torchaudio

print("\n" + "="*60)
print("✓ All required packages installed successfully!")
print("="*60)

# Verify key packages
import torch
import transformers
import wandb
import yaml

print(f"\nPackage Versions:")
print(f"  PyTorch: {torch.__version__}")
print(f"  Transformers: {transformers.__version__}")
print(f"  Wandb: {wandb.__version__}")

print(f"\nGPU Information:")
print(f"  CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  CUDA version: {torch.version.cuda}")
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("  ⚠️  No GPU detected! Enable GPU: Runtime → Change runtime type → T4 GPU")

print("\n✓ Ready to proceed!")


Installing packages...

✓ PyTorch already installed: 2.8.0+cu126

✓ All required packages installed successfully!

Package Versions:
  PyTorch: 2.8.0+cu126
  Transformers: 4.57.1
  Wandb: 0.22.3

GPU Information:
  CUDA available: True
  CUDA version: 12.6
  GPU: NVIDIA L4
  GPU Memory: 23.80 GB

✓ Ready to proceed!


## 3. Clone/Setup Repository

Choose one of the following options:


In [3]:
# Option A: Clone from GitHub
import os
import subprocess

REPO_URL = "https://github.com/hingma/in-context-learning.git"  # UPDATE THIS!

if not os.path.exists("in-context-learning"):
    print(f"Cloning repository from {REPO_URL}...")
    result = subprocess.run(["git", "clone", REPO_URL], capture_output=True, text=True)
    if result.returncode == 0:
        print("✓ Repository cloned successfully")
    else:
        print(f"Error cloning repository: {result.stderr}")
else:
    print("✓ Repository already exists")

%cd in-context-learning


Cloning repository from https://github.com/hingma/in-context-learning.git...
✓ Repository cloned successfully
/content/in-context-learning


In [None]:
# Option B: Mount Google Drive (uncomment if needed)
# from google.colab import drive
# drive.mount('/content/drive')
# %cd /content/drive/MyDrive/in-context-learning


## 4. Setup Weights & Biases


In [20]:
import wandb

# Login to W&B (you'll need to paste your API key)
wandb.login()

print("✓ W&B authenticated")


✓ W&B authenticated


## 5. Configuration Setup

Define the training configuration as a Python dictionary. The cell below writes it to `train_config.yaml`, so you do not need to edit config files inside the repository.


In [18]:
import yaml
import os

# Define your configuration
config = {
    'model': {
        'family': 'gpt2',  # Options: 'gpt2' or 'lstm'
        'n_positions': 256,  # Maximum context length
        'n_dims': 20,  # Input dimension (length of each x vector)
        'n_embd': 256,  # Embedding dimension
        'n_layer': 12,  # Number of layers
        'n_head': 8,  # Number of attention heads
    },
    'training': {
        'task': 'linear_regression',  # Task type
        # Options: linear_regression, sparse_linear_regression,
        #          linear_classification, relu_2nn_regression, decision_tree
        'task_kwargs': {},  # Task-specific arguments
        'num_tasks': None,  # Number of tasks (None = unlimited)
        'num_training_examples': None,  # Training examples (None = unlimited)
        'data': 'gaussian',  # Data distribution
        'batch_size': 64,  # Batch size
        'learning_rate': 3e-4,  # Learning rate
        'train_steps': 10000,  # Total training steps
        'save_every_steps': 1000,  # Checkpoint frequency
        'keep_every_steps': -1,  # Permanent checkpoint frequency (-1 = disabled)
        'resume_id': None,  # Resume from run ID (None = new run)
        'curriculum': {
            'dims': {
                'start': 5,  # Initial dimensions
                'end': 20,  # Final dimensions
                'inc': 1,  # Increment per update
                'interval': 100,  # Update every N steps
            },
            'points': {
                'start': 10,  # Initial points
                'end': 41,  # Final points
                'inc': 1,  # Increment per update
                'interval': 100,  # Update every N steps
            },
        },
    },
    'wandb': {
        'project': 'in-context-training',  # W&B project name
        'entity': None,  # W&B entity/team name (None = your default entity)
        'notes': 'Training run from Colab',  # Run notes
        'name': None,  # Run name (None = auto-generated)
        'log_every_steps': 10,  # Logging frequency (steps)
    },
}

# Save config to file
config_path = 'train_config.yaml'
with open(config_path, 'w') as f:
    yaml.dump(config, f, default_flow_style=False)

print("Configuration saved to:", config_path)
print("\nConfiguration:")
print(yaml.dump(config, default_flow_style=False))


Configuration saved to: train_config.yaml

Configuration:
model:
  family: gpt2
  n_dims: 20
  n_embd: 256
  n_head: 8
  n_layer: 12
  n_positions: 256
training:
  batch_size: 64
  curriculum:
    dims:
      end: 20
      inc: 1
      interval: 100
      start: 5
    points:
      end: 41
      inc: 1
      interval: 100
      start: 10
  data: gaussian
  keep_every_steps: -1
  learning_rate: 0.0003
  num_tasks: null
  num_training_examples: null
  resume_id: null
  save_every_steps: 1000
  task: linear_regression
  task_kwargs: {}
  train_steps: 10000
wandb:
  entity: null
  log_every_steps: 10
  name: null
  notes: Training run from Colab
  project: in-context-training
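The curriculum only works if its final sizes fit the model: `curriculum.dims.end` should not exceed `model.n_dims`, and `curriculum.points.end` should not exceed `model.n_positions`. The helper below is a hypothetical sanity check (it is not part of the repository) that assumes nothing beyond the dictionary layout shown above; in the notebook you would call it as `check_curriculum_fits(config)`.

```python
def check_curriculum_fits(config):
    """Hypothetical sanity check; returns a list of problems (empty if none)."""
    model = config["model"]
    curriculum = config["training"]["curriculum"]
    problems = []

    # Inputs have model.n_dims coordinates, so the curriculum cannot reveal more.
    if curriculum["dims"]["end"] > model["n_dims"]:
        problems.append(
            f"dims end {curriculum['dims']['end']} > n_dims {model['n_dims']}"
        )
    # The longest prompt must fit in the model's context.
    if curriculum["points"]["end"] > model["n_positions"]:
        problems.append(
            f"points end {curriculum['points']['end']} > n_positions {model['n_positions']}"
        )
    # Each schedule must grow from start towards end, with a positive interval.
    for name in ("dims", "points"):
        sched = curriculum[name]
        if sched["start"] > sched["end"]:
            problems.append(f"{name}: start > end")
        if sched["interval"] <= 0:
            problems.append(f"{name}: interval must be positive")
    return problems


example = {
    "model": {"n_dims": 20, "n_positions": 256},
    "training": {
        "curriculum": {
            "dims": {"start": 5, "end": 20, "inc": 1, "interval": 100},
            "points": {"start": 10, "end": 41, "inc": 1, "interval": 100},
        }
    },
}
print(check_curriculum_fits(example))  # []
```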



## 6. Import Training Modules


In [14]:
# Add src to path
import sys
sys.path.insert(0, './src')

# Import required modules
import torch
from random import randint
import uuid
from tqdm import tqdm

from eval import get_run_metrics
from tasks import get_task_sampler
from samplers import get_data_sampler
from curriculum import Curriculum
from models import build_model
from config import ConfigDict, validate_config, set_defaults

torch.backends.cudnn.benchmark = True

print("✓ All modules imported successfully")


✓ All modules imported successfully


## 7. Define Training Functions


In [21]:
def train_step(model, xs, ys, optimizer, loss_func):
    """Execute a single training step."""
    optimizer.zero_grad()
    output = model(xs, ys)
    loss = loss_func(output, ys)
    loss.backward()
    optimizer.step()
    return loss.detach().item(), output.detach()


def sample_seeds(total_seeds, count):
    """Sample random seeds for reproducible training examples."""
    seeds = set()
    while len(seeds) < count:
        seeds.add(randint(0, total_seeds - 1))
    return seeds


def train(model, args):
    """Main training loop."""
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.learning_rate)
    curriculum = Curriculum(args.training.curriculum)

    starting_step = 0
    state_path = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_path):
        state = torch.load(state_path)
        model.load_state_dict(state["model_state_dict"])
        optimizer.load_state_dict(state["optimizer_state_dict"])
        starting_step = state["train_step"]
        for i in range(state["train_step"] + 1):
            curriculum.update()
        print(f"✓ Resumed from step {starting_step}")

    n_dims = model.n_dims
    bsize = args.training.batch_size
    data_sampler = get_data_sampler(args.training.data, n_dims=n_dims)
    task_sampler = get_task_sampler(
        args.training.task,
        n_dims,
        bsize,
        num_tasks=args.training.num_tasks,
        **args.training.task_kwargs,
    )
    pbar = tqdm(range(starting_step, args.training.train_steps))

    num_training_examples = args.training.num_training_examples

    for i in pbar:
        data_sampler_args = {}
        task_sampler_args = {}

        if "sparse" in args.training.task:
            task_sampler_args["valid_coords"] = curriculum.n_dims_truncated
        if num_training_examples is not None:
            assert num_training_examples >= bsize
            seeds = sample_seeds(num_training_examples, bsize)
            data_sampler_args["seeds"] = seeds
            task_sampler_args["seeds"] = [s + 1 for s in seeds]

        xs = data_sampler.sample_xs(
            curriculum.n_points,
            bsize,
            curriculum.n_dims_truncated,
            **data_sampler_args,
        )
        task = task_sampler(**task_sampler_args)
        ys = task.evaluate(xs)

        loss_func = task.get_training_metric()

        loss, output = train_step(model, xs.cuda(), ys.cuda(), optimizer, loss_func)

        point_wise_tags = list(range(curriculum.n_points))
        point_wise_loss_func = task.get_metric()
        point_wise_loss = point_wise_loss_func(output, ys.cuda()).mean(dim=0)

        baseline_loss = (
            sum(
                max(curriculum.n_dims_truncated - ii, 0)
                for ii in range(curriculum.n_points)
            )
            / curriculum.n_points
        )

        if i % args.wandb.log_every_steps == 0 and not args.test_run:
            wandb.log(
                {
                    "overall_loss": loss,
                    "excess_loss": loss / baseline_loss,
                    "pointwise/loss": dict(
                        zip(point_wise_tags, point_wise_loss.cpu().numpy())
                    ),
                    "n_points": curriculum.n_points,
                    "n_dims": curriculum.n_dims_truncated,
                },
                step=i,
            )

        curriculum.update()

        pbar.set_description(f"loss {loss:.4f}")
        if i % args.training.save_every_steps == 0 and not args.test_run:
            training_state = {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "train_step": i,
            }
            torch.save(training_state, state_path)

        if (
            args.training.keep_every_steps > 0
            and i % args.training.keep_every_steps == 0
            and not args.test_run
            and i > 0
        ):
            torch.save(model.state_dict(), os.path.join(args.out_dir, f"model_{i}.pt"))

    print("\n✓ Training completed!")


def main(args):
    """Main training function."""
    if args.test_run:
        curriculum_args = args.training.curriculum
        curriculum_args['points']['start'] = curriculum_args['points']['end']
        curriculum_args['dims']['start'] = curriculum_args['dims']['end']
        args.training.train_steps = 100
        print("Running in test mode (100 steps)")
    else:
        wandb.init(
            dir=args.out_dir,
            project=args.wandb.project,
            entity=args.wandb.entity,  # None falls back to your default entity
            config=dict(args),
            notes=args.wandb.notes,
            name=args.wandb.name,
        )
        print(f"✓ W&B run initialized: {wandb.run.name}")

    model = build_model(args.model)
    model.cuda()
    model.train()

    print(f"\n{'='*60}")
    print(f"Model: {args.model.family}")
    print(f"Task: {args.training.task}")
    print(f"Training steps: {args.training.train_steps}")
    print(f"Batch size: {args.training.batch_size}")
    print(f"Learning rate: {args.training.learning_rate}")
    print(f"{'='*60}\n")

    train(model, args)

    if not args.test_run:
        _ = get_run_metrics(args.out_dir)  # Precompute metrics for eval
        print("✓ Metrics computed")


print("✓ Training functions defined")


✓ Training functions defined
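`excess_loss` divides the training loss by `baseline_loss`, which is the average over prompt positions i = 0, ..., n_points - 1 of max(d - i, 0), with d the current (truncated) input dimension. A plausible reading, assuming noiseless linear regression with standard Gaussian x and w, is that d - i is the expected squared error of the minimum-norm least-squares fit after seeing i examples, so an `excess_loss` close to 1 would mean the model matches that reference. The snippet restates the formula on its own to make the numbers concrete.

```python
def compute_baseline_loss(n_dims_truncated, n_points):
    """Same formula as baseline_loss in train(): mean over positions i of max(d - i, 0)."""
    return sum(max(n_dims_truncated - i, 0) for i in range(n_points)) / n_points


# Start of the curriculum: d = 5, 10 points -> (5 + 4 + 3 + 2 + 1) / 10
print(compute_baseline_loss(5, 10))             # 1.5
# End of the curriculum: d = 20, 41 points -> (20 + 19 + ... + 1) / 41 = 210 / 41
print(round(compute_baseline_loss(20, 41), 4))  # 5.122
```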


## 8. Prepare Configuration


In [22]:
# Load and prepare configuration
with open(config_path, 'r') as f:
    config_dict = yaml.safe_load(f)

# Set defaults and validate
set_defaults(config_dict)
validate_config(config_dict)

# Add required fields
config_dict['out_dir'] = './outputs'  # Output directory
config_dict['test_run'] = False  # Set to True for quick test run (100 steps)

# Convert to ConfigDict for attribute access
args = ConfigDict(config_dict)

# Verify model family
assert args.model.family in ["gpt2", "lstm"], f"Invalid model family: {args.model.family}"

# Create output directory with unique run ID
if not args.test_run:
    run_id = args.training.resume_id
    if run_id is None:
        run_id = str(uuid.uuid4())

    out_dir = os.path.join(args.out_dir, run_id)
    os.makedirs(out_dir, exist_ok=True)
    args.out_dir = out_dir

    # Save config to output directory
    with open(os.path.join(out_dir, "config.yaml"), "w") as yaml_file:
        yaml.dump(dict(args), yaml_file, default_flow_style=False)

    print(f"✓ Output directory: {out_dir}")
    print(f"✓ Run ID: {run_id}")
else:
    print("Running in test mode (no output saved)")

print("\n✓ Configuration prepared")


✓ Output directory: ./outputs/27bbcdfc-aa77-4072-9f50-32ac973d01c2
✓ Run ID: 27bbcdfc-aa77-4072-9f50-32ac973d01c2

✓ Configuration prepared


## 9. Start Training


In [23]:
# Run training
print("Starting training...\n")
main(args)


Starting training...



✓ W&B run initialized: zesty-grass-1

Model: gpt2
Task: linear_regression
Training steps: 10000
Batch size: 64
Learning rate: 0.0003



loss 19.9302:  21%|██        | 2061/10000 [01:41<06:29, 20.37it/s]


KeyboardInterrupt: 

## 10. Monitor Training

You can monitor training progress in real-time using Weights & Biases:
- Click on the W&B run link printed above
- View loss curves, metrics, and system stats
- Compare with other runs


## 11. Save/Download Model

After training completes (or after you stop it), you can download the latest checkpoint and the run configuration:


In [None]:
# Download trained model to local machine
from google.colab import files

# Download the latest checkpoint (state.pt is overwritten every save_every_steps steps)
if not args.test_run:
    state_file = os.path.join(args.out_dir, "state.pt")
    if os.path.exists(state_file):
        print(f"Downloading {state_file}...")
        files.download(state_file)

    # Also download the config
    config_file = os.path.join(args.out_dir, "config.yaml")
    if os.path.exists(config_file):
        print(f"Downloading {config_file}...")
        files.download(config_file)

    print("✓ Model files downloaded")
else:
    print("Test run - no files to download")


## 12. Optional: Run Quick Test

Run a quick 100-step test to verify that everything works. Test mode skips W&B logging and checkpointing and starts the curriculum at its final size:


In [None]:
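# Quick test run (100 steps, no W&B logging or checkpoints)
import copy
test_config = copy.deepcopy(config_dict)  # deep copy: main() mutates the nested curriculum settings in test mode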
test_config['test_run'] = True
test_args = ConfigDict(test_config)

print("Running quick test (100 steps)...\n")
main(test_args)
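Why a deep copy matters here: in test mode `main()` overwrites `curriculum['points']['start']` and `curriculum['dims']['start']` in place. With a shallow `dict.copy()`, the nested `training`/`curriculum` dictionaries are shared, so the change leaks back into `config_dict` (and into `args` as well, if `ConfigDict` keeps references to the nested dicts instead of copying them; that depends on `src/config.py` and is an assumption here). The plain-dict sketch below reproduces the effect with a stand-in named `cfg` so it does not clobber the notebook's own variables.

```python
import copy

# Minimal stand-in for the nested config used in this notebook
cfg = {"training": {"curriculum": {"points": {"start": 10, "end": 41}}}}

shallow = cfg.copy()
pts = shallow["training"]["curriculum"]["points"]
pts["start"] = pts["end"]  # what main() does in test mode
print(cfg["training"]["curriculum"]["points"]["start"])  # 41: original modified

cfg["training"]["curriculum"]["points"]["start"] = 10  # restore
deep = copy.deepcopy(cfg)
pts = deep["training"]["curriculum"]["points"]
pts["start"] = pts["end"]
print(cfg["training"]["curriculum"]["points"]["start"])  # 10: original untouched
```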


## Configuration Options Reference

### Model Configuration
- `family`: 'gpt2' or 'lstm'
- `n_positions`: Maximum sequence length
- `n_dims`: Input dimension (length of each x vector)
- `n_embd`: Embedding dimension
- `n_layer`: Number of transformer/LSTM layers
- `n_head`: Number of attention heads (for GPT-2)
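To get a feel for model size, the sketch below estimates the parameters in a standard GPT-2 transformer stack from `n_embd` and `n_layer`. It assumes the usual GPT-2 block layout (fused QKV projection, output projection, a 4x-wide MLP, two LayerNorms, all with biases) plus a final LayerNorm, and it deliberately leaves out positional/token embeddings and whatever input/output projection layers the repository adds, so treat the result as a lower bound. `n_head` does not change the count; it only splits `n_embd` across heads (so `n_embd` must be divisible by `n_head`, e.g. 256 / 8 = 32).

```python
def gpt2_stack_params(n_embd, n_layer):
    """Approximate parameter count of the GPT-2 blocks plus the final LayerNorm."""
    d = n_embd
    attn = (d * 3 * d + 3 * d) + (d * d + d)     # fused QKV + output projection
    mlp = (d * 4 * d + 4 * d) + (4 * d * d + d)  # up- and down-projection
    layer_norms = 2 * (2 * d)                    # ln_1 and ln_2 (weight + bias)
    per_block = attn + mlp + layer_norms         # = 12*d^2 + 13*d
    final_ln = 2 * d
    return n_layer * per_block + final_ln


n_params = gpt2_stack_params(n_embd=256, n_layer=12)
print(f"{n_params:,}")  # 9,477,632
```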

### Training Tasks
Available tasks:
1. `linear_regression`: Linear regression
2. `sparse_linear_regression`: Sparse linear regression
3. `linear_classification`: Linear classification
4. `relu_2nn_regression`: 2-layer ReLU neural network regression
5. `decision_tree`: Decision tree learning
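As an illustration of what a training batch for `linear_regression` with `data: gaussian` might look like, here is a minimal NumPy sketch. It assumes (rather than copies from the repository's `samplers.py` or `tasks.py`) that inputs are standard Gaussian, that coordinates beyond the curriculum's current dimension are zeroed out, and that each prompt gets its own weight vector w ~ N(0, I) with noiseless targets y = x·w. The function name `sample_linear_regression_batch` is hypothetical.

```python
import numpy as np


def sample_linear_regression_batch(batch_size, n_points, n_dims, n_dims_truncated, rng):
    """Hypothetical sketch of one batch of in-context linear regression prompts."""
    # xs: (batch, points, dims), standard Gaussian inputs
    xs = rng.standard_normal((batch_size, n_points, n_dims))
    # Curriculum: only the first n_dims_truncated coordinates are active
    xs[:, :, n_dims_truncated:] = 0.0
    # One weight vector per prompt: (batch, dims, 1)
    w = rng.standard_normal((batch_size, n_dims, 1))
    # Noiseless targets: (batch, points)
    ys = (xs @ w)[:, :, 0]
    return xs, ys, w


rng = np.random.default_rng(0)
xs, ys, w = sample_linear_regression_batch(64, 10, 20, 5, rng)
print(xs.shape, ys.shape)  # (64, 10, 20) (64, 10)
```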

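### Curriculum Learning
Both `curriculum.dims` (number of active input dimensions) and `curriculum.points` (number of in-context examples per prompt) take the following fields: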
- `start`: Initial value
- `end`: Final value
- `inc`: Increment per update
- `interval`: Update frequency (steps)
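The sketch below assumes each schedule works as follows: every training step calls `update()` once, every `interval` steps the value grows by `inc`, and it never exceeds `end`. It is a hypothetical stand-in for the repository's `Curriculum` class (the class name `Schedule` is made up), useful for working out when the full problem size is reached.

```python
class Schedule:
    """Hypothetical stand-in for one curriculum variable (dims or points)."""

    def __init__(self, start, end, inc, interval):
        self.value = start
        self.end = end
        self.inc = inc
        self.interval = interval
        self.step_count = 0

    def update(self):
        # Called once per training step: grow by `inc` every `interval` steps,
        # never going past `end`.
        self.step_count += 1
        if self.step_count % self.interval == 0:
            self.value = min(self.value + self.inc, self.end)


dims_schedule = Schedule(start=5, end=20, inc=1, interval=100)
points_schedule = Schedule(start=10, end=41, inc=1, interval=100)
for step in range(1, 10_001):
    dims_schedule.update()
    points_schedule.update()
    if step in (1500, 3100):
        print(step, dims_schedule.value, points_schedule.value)
# 1500 20 25
# 3100 20 41
```

With the settings used above, the dimension reaches its final value after (20 - 5) / 1 × 100 = 1,500 steps and the prompt length after (41 - 10) / 1 × 100 = 3,100 steps, so roughly the last 6,900 of the 10,000 steps run at full size.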

### Training Parameters
- `batch_size`: Batch size (default: 64)
- `learning_rate`: Learning rate (default: 3e-4)
- `train_steps`: Total training steps
- `save_every_steps`: How often `state.pt` (model and optimizer state) is overwritten
- `keep_every_steps`: Frequency of permanent `model_{step}.pt` snapshots that are never overwritten (-1 to disable)
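Because the loop in `train()` runs `i` over `range(starting_step, train_steps)` and writes `state.pt` whenever `i % save_every_steps == 0`, the last checkpoint lands on the last multiple of `save_every_steps` below `train_steps`, not on the final step. The small helper below (hypothetical, mirroring that condition) lists the save points, which is handy when deciding how far an interrupted run will rewind on resume.

```python
def checkpoint_steps(train_steps, save_every_steps, starting_step=0):
    """Steps at which train() overwrites state.pt (same modulo test as the loop)."""
    return [i for i in range(starting_step, train_steps) if i % save_every_steps == 0]


save_points = checkpoint_steps(train_steps=10_000, save_every_steps=1000)
print(save_points[0], save_points[-1], len(save_points))  # 0 9000 10

# The interrupted run above stopped around step 2061, so the newest state.pt is from:
print(max(checkpoint_steps(2062, 1000)))  # 2000
```

On resume, `train()` restarts the loop at the saved `train_step` itself, so that step is executed once more.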
