# Task 4: The Intervention - Method 1: IRM (Gradient Penalty)

This notebook implements an **Invariant Risk Minimization (IRM)**-style gradient penalty (IRMv1). The penalty discourages the model from relying on spurious color correlations and pushes it to classify digits by their shape.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch_directml
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
import os

# Fix all RNGs up front so Restart & Run All is reproducible: torch's default CPU
# generator drives weight init (the model is constructed on CPU before `.to(device)`)
# and DataLoader shuffling. NOTE(review): whether ops executed on the DirectML device
# (e.g. dropout masks) are also covered by torch.manual_seed is not established here — confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch_directml.device()
print(f"Using DirectML device: {device}")
print(f"Device name: {torch_directml.device_name(0)}")

## 1. Load Biased Data

In [None]:
TRAIN_PATH = '../mydata/dataset/train_data_rg.npz'
TEST_PATH = '../mydata/dataset/test_data_rg.npz'

def load_npz(path):
    """Load an ``.npz`` archive containing ``images`` and ``labels`` arrays.

    Args:
        path: Path to the ``.npz`` file.

    Returns:
        ``(X, y)``. ``X`` is ``images`` cast to float32 and divided by 255; this assumes
        8-bit pixel values, so the result lands in [0, 1]. ``y`` is ``labels`` unchanged.
    """
    # Use the NpzFile as a context manager so its file handle is closed. Indexing an
    # NpzFile reads the member into an in-memory array, so X and y stay valid after
    # the archive is closed.
    with np.load(path) as data:
        X = data['images'].astype('float32') / 255.0
        y = data['labels']
    return X, y

X_train, y_train = load_npz(TRAIN_PATH)
X_test, y_test = load_npz(TEST_PATH)

# Images are stored channels-last (N, H, W, C); permute to PyTorch's (N, C, H, W).
X_train_tensor = torch.FloatTensor(X_train).permute(0, 3, 1, 2)
y_train_tensor = torch.LongTensor(y_train)
X_test_tensor = torch.FloatTensor(X_test).permute(0, 3, 1, 2)
y_test_tensor = torch.LongTensor(y_test)

# Fail fast on data the CNN cannot consume correctly: it takes 3 input channels
# and emits 10 logits, so labels must lie in [0, 9]. Out-of-range targets may not
# raise a clear error on every backend.
for split_name, X_t, y_t in (('train', X_train_tensor, y_train_tensor),
                             ('test', X_test_tensor, y_test_tensor)):
    assert X_t.shape[0] == y_t.shape[0], \
        f"{split_name}: {X_t.shape[0]} images vs {y_t.shape[0]} labels"
    assert X_t.shape[1] == 3, \
        f"{split_name}: expected 3 channels after permute, got {X_t.shape[1]}"
    assert int(y_t.min()) >= 0 and int(y_t.max()) < 10, \
        f"{split_name}: labels outside [0, 9]"

print(f"Train shape: {X_train_tensor.shape}")
print(f"Test shape: {X_test_tensor.shape}")

## 2. Define Model Architecture
Using the same 3-layer CNN as Task 1 for a fair comparison.

In [None]:
class CNN3Layer(nn.Module):
    """Three conv blocks followed by a two-layer MLP head that produces 10 logits.

    Each conv block is a 3x3 convolution (padding 1), then ReLU, then a 2x2 max-pool,
    so each block halves the spatial size (floor division). ``fc1`` expects the
    64-channel, 3x3 feature map left after the third block. The attribute names
    (conv1..conv3, fc1, fc2) become the ``state_dict`` keys, so keep them stable for
    checkpoint compatibility.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order so seeded random init is reproducible.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=64 * 3 * 3, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)
        self.dropout = nn.Dropout(p=0.5)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, 10) logits."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.shape[0], -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

model = CNN3Layer()
model = model.to(device)  # nn.Module.to moves parameters in place and returns the same module

## 3. Custom IRM Loss Implementation

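The IRM penalty (IRMv1, Arjovsky et al., 2019) is the squared gradient of the cross-entropy risk with respect to a dummy scalar classifier $w$. The dummy classifier multiplies the model's logits and is evaluated at $w = 1.0$:

$$\text{penalty} = \left\| \nabla_{w \mid w=1.0}\, \mathcal{R}\big(w \cdot f(x),\, y\big) \right\|^2$$

A small penalty means that rescaling the classifier would not lower the loss, i.e. the classifier on top of the learned representation is already (locally) optimal. In full IRM, this penalty is summed over several training environments. That encourages a representation whose optimal classifier is the same in every environment. Note that the training code below computes one penalty per mini-batch and uses no environment labels, so the whole training set is treated as a single environment.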

In [None]:
def compute_irm_penalty(logits, y):
    """IRMv1 penalty: squared gradient of the CE risk w.r.t. a dummy scalar scale.

    A scalar ``scale`` fixed at 1.0 multiplies the logits, and the cross-entropy
    gradient with respect to it is taken. The penalty is zero when rescaling the
    classifier would not change the loss at first order.

    Args:
        logits: Model outputs for a batch, shape (N, C).
        y: Integer class targets, shape (N,).

    Returns:
        A scalar tensor that stays attached to the graph (``create_graph=True``),
        so it can be back-propagated into the model parameters.
    """
    # Build the dummy scale on the logits' own device/dtype instead of relying on a
    # notebook-global `device`, so the function works wherever the logits live.
    scale = torch.tensor(1.0, device=logits.device, dtype=logits.dtype, requires_grad=True)
    loss = F.cross_entropy(logits * scale, y)
    grad = torch.autograd.grad(loss, [scale], create_graph=True)[0]
    return torch.sum(grad ** 2)

def train_irm(model, loader, optimizer, penalty_weight, device):
    """Run one epoch of IRM-style training.

    Each batch minimizes ``CE + penalty_weight * penalty``, where ``penalty`` comes
    from ``compute_irm_penalty``. When ``penalty_weight > 1`` the total loss is divided
    by ``penalty_weight``, as in the reference IRM Colored-MNIST code. This keeps the
    gradient magnitude from jumping by roughly ``penalty_weight`` times at the moment
    the full penalty is switched on, a jump Adam's running second-moment estimate
    reacts to with oversized steps.

    Args:
        model: The network being trained; set to train mode here.
        loader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        penalty_weight: Multiplier on the IRM penalty.
        device: Device that batches are moved to.

    Returns:
        ``(mean CE loss, mean penalty)`` over the epoch's batches. Both are unweighted
        and unscaled.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.train()
    total_ce = 0.0
    total_penalty = 0.0
    n_batches = 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()

        logits = model(x)
        ce_loss = F.cross_entropy(logits, y)
        penalty = compute_irm_penalty(logits, y)

        loss = ce_loss + penalty_weight * penalty
        if penalty_weight > 1.0:
            # Rescale so the overall gradient magnitude stays comparable to warm-up.
            loss = loss / penalty_weight

        loss.backward()
        optimizer.step()

        total_ce += ce_loss.item()
        total_penalty += penalty.item()
        n_batches += 1

    if n_batches == 0:
        raise ValueError("train_irm: loader yielded no batches")
    return total_ce / n_batches, total_penalty / n_batches

## 4. Training with Increasing Penalty
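In IRM it is common to first train with only a weak penalty, so the model learns something useful before it is constrained by invariance. Here the penalty weight is 1.0 for the first `penalty_anneal_iters` epochs and then jumps to `penalty_weight`. This is a step schedule, not a gradual ramp.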

In [None]:
batch_size = 128
# Score the test set in chunks rather than one device-sized forward pass,
# which can exhaust device memory on larger test sets.
eval_batch_size = 1024
train_loader = DataLoader(TensorDataset(X_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test_tensor, y_test_tensor), batch_size=eval_batch_size, shuffle=False)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

epochs = 20
penalty_anneal_iters = 5  # counted in EPOCHS, not optimizer steps: epochs [0, 5) use weight 1.0
penalty_weight = 100.0

history = {'loss': [], 'penalty': [], 'test_acc': []}


def evaluate_accuracy(model, loader, device):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Puts ``model`` in eval mode (dropout off) and runs without autograd. Predictions
    are moved back to the CPU so labels never have to be copied to the device.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            preds = model(x.to(device)).argmax(dim=1).cpu()
            correct += (preds == y).sum().item()
            total += y.numel()
    return correct / total


for epoch in range(epochs):
    # Warm up with a weak penalty, then enforce invariance at full strength.
    weight = penalty_weight if epoch >= penalty_anneal_iters else 1.0
    avg_loss, avg_penalty = train_irm(model, train_loader, optimizer, weight, device)

    acc = evaluate_accuracy(model, test_loader, device)

    history['loss'].append(avg_loss)
    history['penalty'].append(avg_penalty)
    history['test_acc'].append(acc)

    print(f"Epoch {epoch+1}/{epochs} | Loss: {avg_loss:.4f} | Penalty: {avg_penalty:.6f} | Test Acc: {acc:.4%}")

# Save CPU copies of the weights so the checkpoint can be loaded without DirectML.
torch.save({k: v.detach().cpu() for k, v in model.state_dict().items()}, 'task4_method1.pth')
print("Model saved as task4_method1.pth")

## 5. Result Analysis
Compare these results with the baseline performance from Task 1. The red dashed line marks the 0.7 target accuracy on the hard test set.

In [None]:
TARGET_ACC = 0.7  # target accuracy on the hard test set, drawn as a reference line

fig, (ax_loss, ax_penalty, ax_acc) = plt.subplots(1, 3, figsize=(16, 4))
epoch_axis = range(1, len(history['loss']) + 1)  # 1-based, matching the training log
penalty_start = penalty_anneal_iters + 1  # first (1-based) epoch trained with the full penalty weight

ax_loss.plot(epoch_axis, history['loss'], label='CE Loss')
ax_loss.set(title='Training Loss', xlabel='Epoch', ylabel='Cross-entropy')

ax_penalty.plot(epoch_axis, history['penalty'], label='IRM Penalty')
ax_penalty.set(title='IRM Penalty (unweighted)', xlabel='Epoch', ylabel='Penalty')

ax_acc.plot(epoch_axis, history['test_acc'], label='Hard Test Acc')
ax_acc.axhline(TARGET_ACC, color='r', linestyle='--', label='Target')
ax_acc.set(title='Performance on Hard Set', xlabel='Epoch', ylabel='Accuracy')

for ax in (ax_loss, ax_penalty, ax_acc):
    ax.axvline(penalty_start, color='gray', linestyle=':', label='Full penalty on')
    ax.legend()

fig.tight_layout()
plt.show()