# 00 - Setup and Configuration

This notebook sets up the environment for the Git Re-Basin spurious features experiment.

## What this notebook does:
1. Validates all dependencies are installed
2. Loads and prints the global configuration from `src.config`
3. Sets deterministic seeds for reproducibility
4. Creates necessary directories
5. Verifies GPU availability

## 1. Add project root to path and validate imports

In [None]:
import sys
from pathlib import Path

# Add project root to path
# The project root is the directory that holds the `src` package. The kernel
# may be started either in that directory or one level below it, so look in
# the current directory first and fall back to its parent (the previous,
# unconditional behaviour).
_cwd = Path.cwd()
PROJECT_ROOT = _cwd if (_cwd / "src").is_dir() else _cwd.parent
if str(PROJECT_ROOT) not in sys.path:
    # Front of the list, so `import src...` resolves against this checkout first.
    sys.path.insert(0, str(PROJECT_ROOT))

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Check that every third-party package the experiment needs can be imported
import importlib

# Pairs of (importable module name, label shown in the report)
dependencies = [
    ('torch', 'PyTorch'),
    ('torchvision', 'TorchVision'),
    ('numpy', 'NumPy'),
    ('scipy', 'SciPy'),
    ('sklearn', 'Scikit-learn'),
    ('matplotlib', 'Matplotlib'),
    ('seaborn', 'Seaborn'),
    ('tqdm', 'tqdm'),
    ('PIL', 'Pillow'),
]

print("Checking dependencies...\n")

# Labels of the packages whose import raised ImportError
missing_packages = []
for import_name, label in dependencies:
    try:
        imported = importlib.import_module(import_name)
        # Not every package exposes __version__; report 'unknown' in that case.
        print(f"  [OK] {label}: {getattr(imported, '__version__', 'unknown')}")
    except ImportError:
        print(f"  [MISSING] {label}")
        missing_packages.append(label)

all_ok = not missing_packages
print(
    "\nAll dependencies are installed!"
    if all_ok
    else "\nSome dependencies are missing. Run: pip install -r requirements.txt"
)

In [None]:
# Validate src module imports
print("Checking src modules...\n")

# Sub-modules of the project-local `src` package to be checked
src_modules = ['config', 'data', 'models', 'train', 'rebasin', 'interp', 'metrics', 'plotting']

# Record failures so the closing message reflects what actually happened
# instead of always claiming success.
failed_src_modules = []
for module_name in src_modules:
    try:
        importlib.import_module(f'src.{module_name}')
        print(f"  [OK] src.{module_name}")
    except ImportError as e:
        print(f"  [ERROR] src.{module_name}: {e}")
        failed_src_modules.append(module_name)

if failed_src_modules:
    print(
        f"\n{len(failed_src_modules)} src module(s) failed to import: "
        f"{', '.join(failed_src_modules)}"
    )
else:
    print("\nAll src modules loaded successfully!")

## 2. Load Configuration

In [None]:
from src.config import get_config, CONFIG, set_seed, get_device, setup_directories

# Fetch the experiment configuration via src.config.get_config
config = get_config()

print("Global Configuration:")
print("=" * 50)
for section_name, section_body in config.items():
    print(f"\n[{section_name}]")
    # Dict sections are listed entry by entry; anything else is shown as a single value.
    if not isinstance(section_body, dict):
        print(f"  {section_body}")
        continue
    for option, setting in section_body.items():
        print(f"  {option}: {setting}")

## 3. Set Deterministic Seeds

In [None]:
import torch
import numpy as np
import random

# Apply the configured global seed through the project's set_seed helper
seed_table = config['seeds']
GLOBAL_SEED = seed_table['global']
set_seed(GLOBAL_SEED)

print(f"Global seed set to: {GLOBAL_SEED}")

# Per-model seeds; the labels are padded so the values line up in the output.
print(f"\nModel seeds:")
for seed_label, seed_key in (
    ("Model A1 (spurious):", 'model_A1'),
    ("Model A2 (spurious):", 'model_A2'),
    ("Model R1 (robust):  ", 'model_R1'),
    ("Model R2 (robust):  ", 'model_R2'),
):
    print(f"  {seed_label} {seed_table[seed_key]}")

# Report the cuDNN flags currently in effect
cudnn_backend = torch.backends.cudnn
print(f"\nDeterminism settings:")
print(f"  torch.backends.cudnn.deterministic: {cudnn_backend.deterministic}")
print(f"  torch.backends.cudnn.benchmark: {cudnn_backend.benchmark}")

## 4. Create Directory Structure

In [None]:
# Create all necessary directories
dirs = setup_directories()

print("Directory structure:")
for name, path in dirs.items():
    # setup_directories() has already run at this point, so a path that is
    # still absent was NOT created; report it as missing rather than
    # (incorrectly) as freshly created.
    status = "[EXISTS]" if path.exists() else "[MISSING]"
    print(f"  {status} {name}: {path}")

## 5. Check Device (GPU/CPU)

In [None]:
# Ask the project helper which compute device to use and describe it
device = get_device()

print(f"Using device: {device}")

backend = device.type
if backend not in ('cuda', 'mps'):
    print(f"\nNo GPU available, using CPU. Training will be slower.")
elif backend == 'mps':
    print(f"\nUsing Apple Metal Performance Shaders (MPS)")
else:
    # CUDA: details are queried for device index 0; memory is shown in units of 1e9 bytes.
    cuda_index = 0
    bytes_per_gb = 1e9
    print(f"\nCUDA Details:")
    print(f"  Device name: {torch.cuda.get_device_name(cuda_index)}")
    print(f"  CUDA version: {torch.version.cuda}")
    print(f"  Memory allocated: {torch.cuda.memory_allocated(cuda_index) / bytes_per_gb:.2f} GB")
    print(f"  Memory reserved: {torch.cuda.memory_reserved(cuda_index) / bytes_per_gb:.2f} GB")

## 6. Quick Sanity Check

In [None]:
# Test that we can create a model and do a forward pass
from src.models import create_model, count_parameters

# Build the configured model and move it to the selected device
model = create_model(config)
model = model.to(device)

# Create dummy input: a batch of 2 random tensors of shape (3, 32, 32)
dummy_input = torch.randn(2, 3, 32, 32).to(device)

# Only the output shape is inspected, so run without autograd; otherwise the
# notebook-global `output` would keep the whole graph alive in the kernel.
with torch.no_grad():
    output = model(dummy_input)

print("Model sanity check:")
print(f"  Model architecture: {config['model']['architecture']}")
print(f"  Parameters: {count_parameters(model):,}")
print(f"  Input shape: {dummy_input.shape}")
print(f"  Output shape: {output.shape}")
print(f"  Forward pass: OK")

In [None]:
# Test data loading
from src.data import create_env_a_dataset, create_no_patch_dataset

print("Testing data loading...")

# Build the three dataset variants (the first call downloads CIFAR-10 if not present)
env_a_train = create_env_a_dataset(train=True, config=config)
env_a_test = create_env_a_dataset(train=False, config=config)
no_patch_test = create_no_patch_dataset(train=False, config=config)

# Report how many samples each split holds
print(f"\nDataset sizes:")
for split_label, split in (
    ("Env A train", env_a_train),
    ("Env A test (ID)", env_a_test),
    ("No patch test (OOD)", no_patch_test),
):
    print(f"  {split_label}: {len(split)}")

# Show the measured alignment rate next to the configured one
expected_rate = config['patch']['p_align_env_a']
alignment_rate = env_a_train.get_alignment_rate()
print(f"\nEnv A alignment rate: {alignment_rate:.3f} (expected: {expected_rate})")

## 7. Summary

In [None]:
# Derive the status lines from what the earlier cells actually found instead of
# hard-coding "OK".
deps_status = "OK" if all_ok else "MISSING"
# A src module that imported successfully is registered in sys.modules;
# one whose import raised is not.
src_status = "OK" if all(f"src.{name}" in sys.modules for name in src_modules) else "ERROR"
dirs_status = "OK" if all(path.exists() for path in dirs.values()) else "MISSING"
setup_ok = deps_status == src_status == dirs_status == "OK"

print("=" * 60)
print("SETUP COMPLETE" if setup_ok else "SETUP INCOMPLETE")
print("=" * 60)
print(f"""
Environment:
  - Device: {device}
  - Global seed: {GLOBAL_SEED}
  - All dependencies: {deps_status}
  - All src modules: {src_status}
  - Directory structure: {dirs_status}

Configuration:
  - Dataset: {config['data']['dataset']}
  - Patch size: {config['patch']['size']}x{config['patch']['size']}
  - Env A alignment: {config['patch']['p_align_env_a']}
  - Env B alignment: {config['patch']['p_align_env_b']}
  - Model: {config['model']['architecture']}
  - Training epochs: {config['training']['num_epochs']}
  - Batch size: {config['training']['batch_size']}

Next steps:
  1. Run 01_data_spurious_envs.ipynb to visualize datasets
  2. Run 02_train_models.ipynb to train all 4 models
  3. Continue with remaining notebooks in order
""")

In [None]:
# Save config to results for reference
import json
from src.config import RESULTS_DIR

config_path = RESULTS_DIR / 'config.json'

# Serialize first, so a value json cannot encode raises before the file is
# opened and cannot leave a truncated config.json behind.
# json already writes tuples as arrays at any nesting depth; `default=str`
# stores the string form of anything it cannot encode natively
# (for example pathlib.Path objects).
config_text = json.dumps(config, indent=2, default=str)

# Make sure the target directory exists before writing
config_path.parent.mkdir(parents=True, exist_ok=True)
config_path.write_text(config_text, encoding='utf-8')

print(f"Configuration saved to: {config_path}")