# Memory-Augmented Transformer Training (A100)

**Quick start for Vertex AI Workbench or Colab Enterprise**

This notebook trains the Memory-Augmented Transformer on A100 GPUs.

## Setup
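1. Create an A100 instance in Vertex AI Workbench (or an A100 runtime in Colab Enterprise)
2. Clone this repo (the clone cell below does this if you haven't)
3. Run the cells below in order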

In [None]:
# Check GPU
!nvidia-smi

In [None]:
# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install -q transformers datasets accelerate wandb tokenizers einops safetensors pyyaml tqdm

In [None]:
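# Clone repo (if not already done)
import os
REPO_DIR = "memory-transformer"
# Skip entirely if the notebook is already running from the repo root
if not os.path.exists("scripts/train_cloud.py"):
    if not os.path.exists(REPO_DIR):
        # Replace with your repo URL
        !git clone https://github.com/YOUR_USERNAME/memory-transformer.git {REPO_DIR}
    %cd {REPO_DIR}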

In [None]:
# Login to wandb (optional but recommended)
import wandb
wandb.login()
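
If you skip the login, one way to keep W&B from prompting or failing mid-run is to set the `WANDB_MODE` environment variable before launching any training cell. `WANDB_MODE` is a standard W&B setting (`online`, `offline`, or `disabled`). This sketch assumes the training scripts call `wandb.init()` without forcing a `mode`. Variables set through `os.environ` are inherited by the `!python ...` subprocesses below.

```python
import os

# "offline" writes runs to ./wandb locally (upload later with `wandb sync`);
# "disabled" turns all W&B calls into no-ops.
# Assumes the training scripts don't pass their own mode= to wandb.init().
os.environ["WANDB_MODE"] = "offline"
print(f"WANDB_MODE={os.environ['WANDB_MODE']}")
```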

In [None]:
# Verify setup
import torch
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"BF16 support: {torch.cuda.is_bf16_supported()}")
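
The `BF16 support` line matters because bfloat16 mixed precision needs an Ampere-class GPU (compute capability 8.0 or higher, which includes the A100). The sketch below picks an autocast dtype from the reported capability and falls back to float16 on older cards such as the V100 (7.0) or T4 (7.5). `pick_amp_dtype` is a hypothetical helper, and whether the training scripts choose their precision this way is an assumption.

```python
def pick_amp_dtype(capability: tuple[int, int]) -> str:
    """Return a mixed-precision dtype name for a CUDA compute capability (major, minor)."""
    major, _minor = capability
    # bf16 tensor-core math arrived with Ampere (sm_80); older GPUs use fp16 with loss scaling
    return "bfloat16" if major >= 8 else "float16"


try:
    import torch
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability(0)  # (8, 0) on an A100
        print(f"Compute capability {cap}: use {pick_amp_dtype(cap)}")
    else:
        print("No CUDA device visible")
except ImportError:
    print("PyTorch not installed; run the install cell first")
```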

## Training Options

Choose your experiment:

In [None]:
# Option 1: Quick validation (1 hour, ~$4)
# Train a small model to verify the pipeline works end to end
!python scripts/train_cloud.py \
    --model-config configs/tiny_full.yaml \
    --batch-size 32 \
    --max-steps 5000 \
    --dataset c4 \
    --output-dir outputs/quick-test \
    --run-name "quick-validation"

In [None]:
# Option 2: Medium model training (4-6 hours, ~$20)
# 180M-parameter model on C4
!python scripts/train_cloud.py \
    --model-config configs/medium_a100.yaml \
    --batch-size 32 \
    --max-steps 50000 \
    --dataset c4 \
    --output-dir outputs/medium-c4 \
    --run-name "mat-180m-c4"

In [None]:
# Option 3: Large model training (10-15 hours, ~$50)
# 500M-parameter model on C4
!python scripts/train_cloud.py \
    --model-config configs/large_a100.yaml \
    --batch-size 16 \
    --max-steps 100000 \
    --dataset c4 \
    --output-dir outputs/large-c4 \
    --run-name "mat-500m-c4"

In [None]:
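# Option 4: Hyperparameter sweep (parallel agents, ~$100)
# Step 1 registers the sweep and prints a sweep ID; it does not start any runs
!wandb sweep configs/sweep_config.yaml
# Step 2: start an agent with the ID printed above (one agent per GPU/VM runs trials in parallel)
# !wandb agent YOUR_ENTITY/YOUR_PROJECT/SWEEP_ID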
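
The contents of `configs/sweep_config.yaml` are not shown in this notebook. The sketch below defines a comparable sweep as a Python dict, registers it with `wandb.sweep()`, and runs it with `wandb.agent()`. The metric name `val/loss`, the project name, and the parameter grid are hypothetical. Only the `--model-config`, `--batch-size`, `--max-steps`, and `--dataset` flags come from the commands above. W&B's default `${args}` macro passes each parameter as `--name=value`, which argparse accepts for hyphenated flags.

```python
def build_sweep_config(max_steps: int = 5000) -> dict:
    """Hypothetical random-search sweep over flags that train_cloud.py accepts above."""
    return {
        "program": "scripts/train_cloud.py",
        "method": "random",
        # Assumed metric key: it must match whatever train_cloud.py logs to W&B
        "metric": {"name": "val/loss", "goal": "minimize"},
        "parameters": {
            "model-config": {"values": ["configs/tiny_full.yaml", "configs/medium_a100.yaml"]},
            "batch-size": {"values": [16, 32]},
            "max-steps": {"value": max_steps},
            "dataset": {"value": "c4"},
        },
        # W&B's default command template; ${args} expands to --key=value pairs
        "command": ["${env}", "${interpreter}", "${program}", "${args}"],
    }


def launch_sweep(project: str, count: int = 4) -> str:
    """Register the sweep and run `count` trials in this process, one after another."""
    import wandb  # lazy import so the config builder works without wandb installed
    sweep_id = wandb.sweep(build_sweep_config(), project=project)
    # With no function=, the agent runs the "program" from the config for each trial
    wandb.agent(sweep_id, project=project, count=count)
    return sweep_id

# Usage (hypothetical project name):
# launch_sweep("memory-transformer", count=4)
```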

## Multi-GPU Training (8x A100)

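If you have access to 8x A100 (~$30/hr), use distributed training across all eight GPUs. With near-linear scaling this gives up to ~8x the throughput of a single A100. In practice, communication overhead keeps the speedup somewhat below 8x: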

In [None]:
# 8x A100 distributed training
!accelerate launch --num_processes=8 scripts/train_distributed.py \
    --model-config configs/large_a100.yaml \
    --batch-size 8 \
    --max-steps 50000 \
    --run-name "mat-500m-8gpu"
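
Before paying for 8 GPUs, compare runs by how many sequences each processes and what the same work costs. The sketch below makes two assumptions. First, it treats `--batch-size` as the per-GPU batch in `train_distributed.py`, which this notebook does not confirm. Second, it uses the hourly rates quoted here (~$4/hr for one A100, ~$30/hr for 8x A100). Under those assumptions the 8-GPU cell sees twice Option 3's sequences in half the steps. At ideal scaling it costs about the same as one GPU: you pay for speed, not savings.

```python
def sequences_seen(per_gpu_batch: int, num_gpus: int, max_steps: int) -> int:
    """Training sequences processed: global batch (per-GPU batch x GPUs) x optimizer steps."""
    # Ignores gradient accumulation, which would multiply the global batch further
    return per_gpu_batch * num_gpus * max_steps


def ideal_cost_usd(single_gpu_hours: float, num_gpus: int, hourly_rate_usd: float) -> float:
    """Cost of the same work under perfect linear scaling (a lower bound in practice)."""
    return single_gpu_hours / num_gpus * hourly_rate_usd


# Option 3 (1x A100) vs the 8x A100 cell above, assuming --batch-size is per GPU
option3 = sequences_seen(per_gpu_batch=16, num_gpus=1, max_steps=100_000)
multi = sequences_seen(per_gpu_batch=8, num_gpus=8, max_steps=50_000)
print(f"Option 3: {option3:,} sequences | 8x A100: {multi:,} sequences")

# 12.5 single-GPU hours of work (midpoint of Option 3's 10-15 h estimate)
hours = 12.5
print(f"1x A100 @ ~$4/hr: ${hours * 4.0:.0f} | "
      f"8x A100 @ ~$30/hr (ideal): ${ideal_cost_usd(hours, 8, 30.0):.0f}")
```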

## Evaluation

In [None]:
# Evaluate the trained model (update --checkpoint if you trained a different option)
!python scripts/evaluate.py \
    --checkpoint outputs/medium-c4/checkpoint-best.pt \
    --benchmark recall
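
The checkpoint path above points at the Option 2 output directory. If you trained a different option, the sketch below lists every `checkpoint-best.pt` under `outputs/` and picks the most recently written one. It assumes each run saves that file name in its `--output-dir`, as the commands in this notebook imply. `find_best_checkpoints` is a hypothetical helper. Pass the printed path to `--checkpoint`.

```python
from pathlib import Path


def find_best_checkpoints(root: str = "outputs") -> list[Path]:
    """Return every <root>/<run>/checkpoint-best.pt, newest first by modification time."""
    ckpts = Path(root).glob("*/checkpoint-best.pt")  # yields nothing if root is missing
    return sorted(ckpts, key=lambda p: p.stat().st_mtime, reverse=True)


ckpts = find_best_checkpoints()
for p in ckpts:
    print(p)
if ckpts:
    print(f"Most recent: {ckpts[0]}")
else:
    print("No checkpoint-best.pt found under outputs/")
```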

## Generate Text

In [None]:
# Generate text from the trained model
!python scripts/generate.py \
    --checkpoint outputs/medium-c4/checkpoint-best.pt \
    --prompt "The Memory-Augmented Transformer works by" \
    --max-tokens 100
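
To try several prompts against the same checkpoint, the sketch below loops over `scripts/generate.py` with `subprocess`. It uses only the `--checkpoint`, `--prompt`, and `--max-tokens` flags shown above, and the prompt list is purely illustrative. Passing arguments as a list bypasses the shell, so prompts containing quotes need no escaping.

```python
import os
import subprocess
import sys


def generate_cmd(checkpoint: str, prompt: str, max_tokens: int = 100) -> list[str]:
    """Build the argv for one scripts/generate.py call, using the flags from the cell above."""
    return [sys.executable, "scripts/generate.py",
            "--checkpoint", checkpoint,
            "--prompt", prompt,
            "--max-tokens", str(max_tokens)]


CHECKPOINT = "outputs/medium-c4/checkpoint-best.pt"
prompts = [  # illustrative prompts only
    "The Memory-Augmented Transformer works by",
    "The key idea behind external memory in language models is",
]

if os.path.exists("scripts/generate.py"):
    for prompt in prompts:
        print(f"=== {prompt}")
        subprocess.run(generate_cmd(CHECKPOINT, prompt), check=True)
else:
    print("scripts/generate.py not found; run this from the repo root")
```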

## Save to GCS

Copy checkpoints to Google Cloud Storage so they persist after the VM or runtime is shut down:

In [None]:
# Upload checkpoints to GCS
BUCKET = "your-bucket-name"  # Change this
!gsutil -m cp -r outputs/* gs://{BUCKET}/mat-checkpoints/
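
If you prefer to stay in Python (for example, to upload from inside a training script), the sketch below uses the `google-cloud-storage` client library. That library is not in the install cell above. The sketch assumes it is installed (`pip install google-cloud-storage` otherwise) and that default credentials are available, as they usually are on Vertex AI Workbench. `plan_uploads` and `upload_dir` are hypothetical helpers that mirror the `gs://BUCKET/mat-checkpoints/<run>/...` layout produced by the `gsutil` command.

```python
from pathlib import Path


def plan_uploads(local_dir: str, prefix: str = "mat-checkpoints") -> list[tuple[Path, str]]:
    """Map each file under local_dir to a blob name that keeps its relative path."""
    root = Path(local_dir)
    return [(p, f"{prefix}/{p.relative_to(root).as_posix()}")
            for p in sorted(root.rglob("*")) if p.is_file()]


def upload_dir(bucket_name: str, local_dir: str = "outputs", prefix: str = "mat-checkpoints") -> int:
    """Upload every file under local_dir to gs://bucket_name/prefix/; return the file count."""
    from google.cloud import storage  # lazy import: only needed when actually uploading
    bucket = storage.Client().bucket(bucket_name)
    uploads = plan_uploads(local_dir, prefix)
    for path, blob_name in uploads:
        bucket.blob(blob_name).upload_from_filename(str(path))
    return len(uploads)

# Usage (set your bucket name):
# n = upload_dir("your-bucket-name")
# print(f"Uploaded {n} files")
```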