In [1]:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
import argparse
from itertools import cycle

# ==============================================================================
# Configuration & Hyperparameters
# ==============================================================================
class Config:
    # Data Paths
    DATA_DIR = r"/kaggle/input/cmnistneo1"
    TRAIN_FILE = "train_data_rg95z.npz"
    TEST_FILE = "test_data_gr95z.npz"
    
    # Training Hyperparameters
    BATCH_SIZE = 256
    EPOCHS = 50
    LR = 1e-3
    WEIGHT_DECAY = 1e-4
    
    # IRM Specifics
    IRM_PENALTY_WEIGHT = 10000.0
    IRM_ANNEAL_EPOCHS = 10  # Epochs before penalty kicks in
    
    # System
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    SEED = 42

# ==============================================================================
# Model Architecture
# ==============================================================================
class CNN3Layer(nn.Module):
    def __init__(self, num_classes=10):
        super(CNN3Layer, self).__init__()
        
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(64 * 3 * 3, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

# ==============================================================================
# Utils: IRM Penalty & Data Loading
# ==============================================================================
def irm_penalty(logits, y):
    """
    Computes the IRM v1 penalty: gradient norm of the loss w.r.t a fixed scalar 1.0.
    This effectively asks: "If I multiplied the classifier output by a scalar 'w',
    would the optimal 'w' result in 0 gradient at w=1.0 across all environments?"
    """
    scale = torch.tensor(1.).to(logits.device).requires_grad_()
    loss = nn.CrossEntropyLoss()(logits * scale, y)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return torch.sum(grad**2)

def get_dominant_color(img_tensor):
    """Returns dominant channel index (0=R, 1=G, 2=B) for a single image tensor."""
    means = torch.mean(img_tensor, dim=(1, 2))
    return torch.argmax(means).item()

def load_data(config):
    """Loads data and splits training set into two IRM environments."""
    print(f"\n[Data] Loading from {config.DATA_DIR}...")
    
    train_path = os.path.join(config.DATA_DIR, config.TRAIN_FILE)
    if not os.path.exists(train_path):
        raise FileNotFoundError(f"Train file not found: {train_path}")

    # Load Train (Biased)
    train_data = np.load(train_path)
    X_train = torch.tensor(train_data['images'].astype('float32') / 255.0).permute(0, 3, 1, 2)
    y_train = torch.tensor(train_data['labels']).long()

    # 1. Identify Spurious Correlation (Color Bias)
    print("[Data] Analyzing bias...")
    digit_bias_color = {}
    for d in range(10):
        # Sample subset to determine majority color
        indices = (y_train == d).nonzero(as_tuple=True)[0][:100]
        colors = [get_dominant_color(X_train[i]) for i in indices]
        majority = max(set(colors), key=colors.count)
        digit_bias_color[d] = majority
    
    # 2. Split into Environments (Aligned vs Conflict)
    # Calculate dominant color for entire batch efficiently
    img_means = torch.mean(X_train, dim=(2, 3)) # (N, 3)
    img_colors = torch.argmax(img_means, dim=1) # (N,)
    
    # Expected color based on label
    expected_colors = torch.tensor([digit_bias_color[y.item()] for y in y_train], device=X_train.device)
    
    # Mask: True if image follows bias (Environment 1), False if conflict (Environment 2)
    aligned_mask = (img_colors == expected_colors.cpu())
    
    env1_idx = torch.nonzero(aligned_mask, as_tuple=True)[0]
    env2_idx = torch.nonzero(~aligned_mask, as_tuple=True)[0]
    
    print(f"  Env 1 (Aligned/Biased): {len(env1_idx)} samples")
    print(f"  Env 2 (Conflict/OOD):   {len(env2_idx)} samples")

    # Load Test (Hard OOD)
    test_path = os.path.join(config.DATA_DIR, config.TEST_FILE)
    test_data = np.load(test_path)
    X_test = torch.tensor(test_data['images'].astype('float32') / 255.0).permute(0, 3, 1, 2)
    y_test = torch.tensor(test_data['labels']).long()
    
    # Create Datasets
    ds_env1 = TensorDataset(X_train[env1_idx], y_train[env1_idx])
    ds_env2 = TensorDataset(X_train[env2_idx], y_train[env2_idx])
    ds_test = TensorDataset(X_test, y_test)
    
    return ds_env1, ds_env2, ds_test

# ==============================================================================
# Training Loop
# ==============================================================================
def train(config):
    torch.manual_seed(config.SEED)
    
    # Load Data
    ds_env1, ds_env2, ds_test = load_data(config)
    
    # Loaders
    loader1 = DataLoader(ds_env1, batch_size=config.BATCH_SIZE, shuffle=True)
    loader2 = DataLoader(ds_env2, batch_size=config.BATCH_SIZE, shuffle=True)
    loader_test = DataLoader(ds_test, batch_size=config.BATCH_SIZE, shuffle=False)
    
    # Env 2 is smaller (5%), so we cycle it to match Env 1 iterations
    iter2 = cycle(loader2)
    
    # Model & Optim
    model = CNN3Layer().to(config.DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
    
    print(f"\n[Train] Starting IRM Training on {config.DEVICE}...")
    print(f"  Penalty: {config.IRM_PENALTY_WEIGHT} (Annealed for {config.IRM_ANNEAL_EPOCHS} epochs)")
    
    for epoch in range(config.EPOCHS):
        model.train()
        total_loss = 0
        total_penalty = 0
        correct = 0
        total = 0
        
        for x1, y1 in loader1:
            x2, y2 = next(iter2)
            
            x1, y1 = x1.to(config.DEVICE), y1.to(config.DEVICE)
            x2, y2 = x2.to(config.DEVICE), y2.to(config.DEVICE)
            
            # Forward
            logits1 = model(x1)
            logits2 = model(x2)
            
            # 1. Standard ERM Risk (Cross Entropy)
            nll1 = nn.CrossEntropyLoss()(logits1, y1)
            nll2 = nn.CrossEntropyLoss()(logits2, y2)
            risk = (nll1 + nll2) / 2
            
            # 2. Invariance Penalty
            penalty = torch.tensor(0.).to(config.DEVICE)
            if epoch >= config.IRM_ANNEAL_EPOCHS:
                p1 = irm_penalty(logits1, y1)
                p2 = irm_penalty(logits2, y2)
                penalty = (p1 + p2) / 2
            
            # Total Loss
            loss = risk + (config.IRM_PENALTY_WEIGHT if epoch >= config.IRM_ANNEAL_EPOCHS else 1.0) * penalty
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            # Stats
            total_loss += risk.item()
            total_penalty += penalty.item()
            _, preds = torch.max(logits1, 1)
            correct += (preds == y1).sum().item()
            total += y1.size(0)
            
        # Logging
        avg_risk = total_loss / len(loader1)
        avg_penalty = total_penalty / len(loader1)
        train_acc = 100 * correct / total
        
        print(f"Epoch [{epoch+1}/{config.EPOCHS}] Risk: {avg_risk:.4f} | Penalty: {avg_penalty:.6f} | Train Acc (Env1): {train_acc:.1f}%")
        
        # Validation
        if (epoch + 1) % 5 == 0:
            val_acc = evaluate(model, loader_test, config.DEVICE)
            print(f"  >>> Hard Test Acc (OOD): {val_acc:.2f}%")

    # Final Save
    save_path = "task4_irm_model.pth"
    torch.save(model.state_dict(), save_path)
    print(f"\n[Done] Model saved to {save_path}")

def evaluate(model, loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            out = model(x)
            _, preds = torch.max(out, 1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return 100 * correct / total

if __name__ == "__main__":
    train(Config)



[Data] Loading from /kaggle/input/cmnistneo1...
[Data] Analyzing bias...
  Env 1 (Aligned/Biased): 57948 samples
  Env 2 (Conflict/OOD):   2052 samples

[Train] Starting IRM Training on cuda...
  Penalty: 10000.0 (Annealed for 10 epochs)
Epoch [1/50] Risk: 0.5146 | Penalty: 0.000000 | Train Acc (Env1): 88.1%
Epoch [2/50] Risk: 0.0805 | Penalty: 0.000000 | Train Acc (Env1): 98.1%
Epoch [3/50] Risk: 0.0382 | Penalty: 0.000000 | Train Acc (Env1): 98.9%
Epoch [4/50] Risk: 0.0225 | Penalty: 0.000000 | Train Acc (Env1): 99.3%
Epoch [5/50] Risk: 0.0161 | Penalty: 0.000000 | Train Acc (Env1): 99.4%
  >>> Hard Test Acc (OOD): 94.48%
Epoch [6/50] Risk: 0.0130 | Penalty: 0.000000 | Train Acc (Env1): 99.5%
Epoch [7/50] Risk: 0.0100 | Penalty: 0.000000 | Train Acc (Env1): 99.6%
Epoch [8/50] Risk: 0.0089 | Penalty: 0.000000 | Train Acc (Env1): 99.7%
Epoch [9/50] Risk: 0.0075 | Penalty: 0.000000 | Train Acc (Env1): 99.7%
Epoch [10/50] Risk: 0.0071 | Penalty: 0.000000 | Train Acc (Env1): 99.7%
  >>> 