# In [None]:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os
from itertools import cycle

# ==============================================================================
# Configuration & Hyperparameters
# ==============================================================================
class Config:
    """Experiment settings, consumed as plain class attributes (e.g. ``Config.LR``)."""

    # ---- Dataset location ---------------------------------------------------
    DATA_DIR = r"/kaggle/input/cmnistneo1"
    TRAIN_FILE = "train_data_rg95z.npz"   # biased training split
    TEST_FILE = "test_data_gr95z.npz"     # held-out "hard" test split

    # ---- Optimisation -------------------------------------------------------
    BATCH_SIZE = 128
    EPOCHS = 20
    LR = 0.001
    WEIGHT_DECAY = 0.0001

    # ---- REx ----------------------------------------------------------------
    REX_PENALTY_WEIGHT = 100.0   # multiplier on the variance-of-risks term
    REX_ANNEAL_EPOCHS = 5        # the penalty stays off for this many initial epochs

    # ---- Runtime ------------------------------------------------------------
    DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    SEED = 42

# ==============================================================================
# Model Architecture
# ==============================================================================
class CNN3Layer(nn.Module):
    """Small CNN: three conv/pool stages followed by a two-layer MLP head.

    Each stage is a 3x3 convolution with padding=1 (spatial size preserved),
    a ReLU and a 2x2 max-pool (spatial size floor-halved). The head's input
    width of ``64 * 3 * 3`` therefore assumes 3-channel inputs whose spatial
    size reduces to 3x3 after three poolings (e.g. 28x28).
    """

    def __init__(self, num_classes=10):
        super().__init__()

        def conv_stage(in_ch, out_ch):
            # conv (same padding) -> ReLU -> 2x2 max-pool
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]

        self.features = nn.Sequential(
            *conv_stage(3, 32),
            *conv_stage(32, 64),
            *conv_stage(64, 64),
        )

        hidden_units = 128
        self.classifier = nn.Sequential(
            nn.Linear(64 * 3 * 3, hidden_units),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_units, num_classes),
        )

    def forward(self, x):
        feature_maps = self.features(x)
        # Flatten everything but the batch dimension before the MLP head.
        flat = feature_maps.reshape(feature_maps.size(0), -1)
        return self.classifier(flat)

# ==============================================================================
# Utils: REx Penalty & Data Loading
# ==============================================================================
def get_dominant_color(img_tensor):
    """Return the index of the channel with the highest mean value.

    Expects a channel-first ``(C, H, W)`` tensor (means are taken over dims 1
    and 2). With the file's RGB-ordered data, 0/1/2 presumably correspond to
    R/G/B. Ties resolve to the lowest channel index.
    """
    channel_means = img_tensor.mean(dim=(1, 2))
    return int(channel_means.argmax())

def _load_npz_split(path, split_name):
    """Load the ``images`` / ``labels`` arrays of an ``.npz`` file as torch tensors.

    Images are scaled by 1/255 to float32 and permuted with (0, 3, 1, 2), i.e.
    they are assumed to be stored channel-last (N, H, W, C) and are returned
    channel-first (N, C, H, W). Labels are returned as int64.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"{split_name} file not found: {path}")
    # NpzFile keeps the underlying file open until closed; the context manager
    # releases it once the arrays have been read into memory.
    with np.load(path) as data:
        images = data['images'].astype('float32') / 255.0
        labels = data['labels']
    X = torch.from_numpy(images).permute(0, 3, 1, 2)
    y = torch.as_tensor(labels).long()
    return X, y


def load_data(config, max_bias_samples=100):
    """Load the datasets and split the training set into two REx environments.

    For each digit, the "bias colour" is the majority dominant channel
    (argmax of per-channel mean) among its first ``max_bias_samples`` training
    images. Ties go to the lowest channel index. Training images whose
    dominant channel equals their digit's bias colour form environment 1
    (aligned); the rest form environment 2 (conflict).

    Args:
        config: object exposing ``DATA_DIR``, ``TRAIN_FILE`` and ``TEST_FILE``.
        max_bias_samples: images per digit inspected when estimating the bias
            colour (default 100; ``None`` uses all of them).

    Returns:
        ``(ds_env1, ds_env2, ds_test)`` ``TensorDataset``s of (images, labels),
        with images as channel-first float32 in [0, 1].

    Raises:
        FileNotFoundError: if the train or test file is missing.
        ValueError: if a training label lies outside 0..9.
    """
    print(f"\n[Data] Loading from {config.DATA_DIR}...")

    # Load Train (Biased)
    X_train, y_train = _load_npz_split(
        os.path.join(config.DATA_DIR, config.TRAIN_FILE), "Train")
    if y_train.numel() and (int(y_train.min()) < 0 or int(y_train.max()) > 9):
        raise ValueError("Training labels must be digits in the range 0..9.")

    # 1. Identify Spurious Correlation (Color Bias)
    print("[Data] Analyzing bias...")
    # Dominant channel of every training image, computed in one vectorized pass.
    img_colors = X_train.mean(dim=(2, 3)).argmax(dim=1)
    num_channels = X_train.size(1)
    # digit -> bias colour lookup table; -1 (never a valid channel) marks
    # digits absent from the training set, so none of their images can align.
    digit_bias_color = torch.full((10,), -1, dtype=torch.long)
    for d in range(10):
        sample_colors = img_colors[y_train == d][:max_bias_samples]
        if sample_colors.numel() == 0:
            print(f"  [Warn] Digit {d} has no training samples; no bias colour estimated.")
            continue
        # Most frequent colour; argmax returns the first (lowest) index on ties.
        digit_bias_color[d] = torch.bincount(sample_colors, minlength=num_channels).argmax()

    # 2. Split into Environments (Aligned vs Conflict)
    aligned_mask = img_colors == digit_bias_color[y_train]
    X_env1, y_env1 = X_train[aligned_mask], y_train[aligned_mask]
    X_env2, y_env2 = X_train[~aligned_mask], y_train[~aligned_mask]

    print(f"  Env 1 (Aligned/Biased): {len(y_env1)} samples")
    print(f"  Env 2 (Conflict/OOD):   {len(y_env2)} samples")

    # Load Test (Hard OOD)
    X_test, y_test = _load_npz_split(
        os.path.join(config.DATA_DIR, config.TEST_FILE), "Test")

    ds_env1 = TensorDataset(X_env1, y_env1)
    ds_env2 = TensorDataset(X_env2, y_env2)
    ds_test = TensorDataset(X_test, y_test)
    return ds_env1, ds_env2, ds_test

# ==============================================================================
# Training Loop with REx
# ==============================================================================
# Accuracy (%) on the hard test set that the run is expected to reach.
_TARGET_ACC = 70


def _infinite_batches(loader):
    """Yield batches from ``loader`` endlessly, starting a fresh pass each time.

    ``itertools.cycle`` caches the batches of its first pass and replays them
    verbatim, so a ``shuffle=True`` DataLoader would never be reshuffled.
    Re-iterating the loader avoids that. ``loader`` must be non-empty,
    otherwise this generator loops forever without yielding.
    """
    while True:
        yield from loader


def _train_one_epoch(model, optimizer, criterion, loader1, env2_batches,
                     penalty_active, penalty_weight, device):
    """Run one REx epoch, pairing every env-1 batch with the next env-2 batch.

    Returns ``(avg_risk, avg_penalty, train_acc)``:
      - ``avg_risk``: per-batch mean of the averaged environment risk.
      - ``avg_penalty``: per-batch mean of the variance penalty (logged as 0
        while the penalty is inactive).
      - ``train_acc``: accuracy (%) measured on the env-1 batches only.
    """
    model.train()
    total_risk = 0.0
    total_penalty = 0.0
    correct = 0
    total = 0

    for x1, y1 in loader1:
        x2, y2 = next(env2_batches)
        x1, y1 = x1.to(device), y1.to(device)
        x2, y2 = x2.to(device), y2.to(device)

        logits1 = model(x1)
        logits2 = model(x2)

        # Per-environment risks and their mean.
        loss1 = criterion(logits1, y1)
        loss2 = criterion(logits2, y2)
        mean_loss = (loss1 + loss2) / 2

        if penalty_active:
            # V-REx: variance of the two environment risks.
            penalty = ((loss1 - mean_loss) ** 2 + (loss2 - mean_loss) ** 2) / 2
            objective = mean_loss + penalty_weight * penalty
        else:
            # Annealing phase: optimise the mean risk only.
            penalty = torch.zeros((), device=device)
            objective = mean_loss

        optimizer.zero_grad()
        objective.backward()
        optimizer.step()

        total_risk += mean_loss.item()
        total_penalty += penalty.item()
        preds = torch.max(logits1, 1)[1]
        correct += (preds == y1).sum().item()
        total += y1.size(0)

    n_batches = len(loader1)
    return total_risk / n_batches, total_penalty / n_batches, 100 * correct / total


def _plot_history(history, anneal_epochs):
    """Plot risk, variance penalty and test accuracy per epoch; save task4_rex_training.png."""
    plt.figure(figsize=(15, 4))

    plt.subplot(1, 3, 1)
    plt.plot(history['loss'], label='CE Loss', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training Loss (REx)')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(history['penalty'], label='Variance Penalty', color='orange', linewidth=2)
    # x is the 0-based epoch index, so the penalty turns on at x == anneal_epochs.
    plt.axvline(anneal_epochs, color='r', linestyle='--', alpha=0.5, label='Penalty Start')
    plt.xlabel('Epoch')
    plt.ylabel('Variance')
    plt.title('REx Variance Penalty')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 3)
    plt.plot(history['test_acc'], label='Hard Test Acc', color='green', linewidth=2)
    plt.axhline(_TARGET_ACC, color='r', linestyle='--', alpha=0.7, label='Target (70%)')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Performance on Hard Test Set')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('task4_rex_training.png', dpi=150, bbox_inches='tight')
    print("  Saved: task4_rex_training.png")
    plt.show()


def _plot_confusion_matrix(model, loader, device, num_classes=10):
    """Plot the test-set confusion matrix and save task4_rex_confusion.png."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for x, y in loader:
            out = model(x.to(device))
            all_preds.extend(torch.max(out, 1)[1].cpu().numpy())
            all_labels.extend(y.cpu().numpy())

    class_ids = list(range(num_classes))
    # Passing ``labels`` keeps the matrix num_classes x num_classes even when a
    # class is never predicted or absent, so it always matches the tick labels.
    cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_ids, yticklabels=class_ids)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - REx on Hard Test Set')
    plt.savefig('task4_rex_confusion.png', dpi=150, bbox_inches='tight')
    print("  Saved: task4_rex_confusion.png")
    plt.show()


def _plot_sample_predictions(model, ds_test, device):
    """Show up to 10 random test images with true/predicted labels; save task4_rex_samples.png.

    Reuses the already-loaded test tensors instead of re-reading the .npz file.
    """
    X_test, y_test = ds_test.tensors
    n_shown = min(10, len(X_test))  # the figure grid is 2 x 5
    if n_shown == 0:
        print("  [Warn] Empty test set; skipping sample predictions.")
        return
    sample_indices = torch.as_tensor(np.random.choice(len(X_test), n_shown, replace=False))

    model.eval()
    with torch.no_grad():
        preds = torch.max(model(X_test[sample_indices].to(device)), 1)[1].cpu()

    plt.figure(figsize=(15, 6))
    for i, idx in enumerate(sample_indices.tolist()):
        pred = int(preds[i])
        true_label = int(y_test[idx])

        plt.subplot(2, 5, i+1)
        # Stored channel-first in [0, 1]; imshow expects (H, W, C).
        plt.imshow(X_test[idx].permute(1, 2, 0).numpy())
        plt.title(f'True: {true_label}\nPred: {pred}',
                  color='green' if pred == true_label else 'red')
        plt.axis('off')

    plt.suptitle('REx: Sample Predictions on Hard Test Set', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig('task4_rex_samples.png', dpi=150, bbox_inches='tight')
    print("  Saved: task4_rex_samples.png")
    plt.show()


def _print_summary(history):
    """Print the final and best hard-test accuracy and whether the target was met."""
    accs = history['test_acc']
    if not accs:
        print("\n[Summary] No epochs were run; nothing to report.")
        return
    final_acc = accs[-1]
    print(f"\n{'='*60}")
    print("FINAL RESULTS - Risk Extrapolation (REx)")
    print(f"{'='*60}")
    print(f"Final Test Accuracy: {final_acc:.2f}%")
    print(f"Target Achieved: {'✓ YES' if final_acc >= _TARGET_ACC else '✗ NO'}")
    print(f"Best Test Accuracy: {max(accs):.2f}% (Epoch {np.argmax(accs)+1})")
    print(f"{'='*60}")


def train(config):
    """Train ``CNN3Layer`` with V-REx on the two colour environments, then report.

    Workflow:
      1. Seed torch and NumPy, then load data via ``load_data``.
      2. For ``config.EPOCHS`` epochs, minimise the mean env risk plus
         ``config.REX_PENALTY_WEIGHT`` times the variance of env risks.
         The penalty is off for the first ``config.REX_ANNEAL_EPOCHS`` epochs.
      3. After each epoch, evaluate on the test set.
      4. Save weights to ``task4_rex_model.pth``.
      5. Write training curves, the confusion matrix and sample predictions
         as PNG files.
      6. Print a summary.

    Raises:
        ValueError: if either training environment is empty.
    """
    torch.manual_seed(config.SEED)
    # NumPy picks the sample images shown at the end; seed it for reproducibility.
    np.random.seed(config.SEED)

    # Load Data
    ds_env1, ds_env2, ds_test = load_data(config)
    if len(ds_env1) == 0 or len(ds_env2) == 0:
        raise ValueError(
            "Both REx environments need at least one sample "
            f"(env1={len(ds_env1)}, env2={len(ds_env2)})."
        )

    # Loaders
    loader1 = DataLoader(ds_env1, batch_size=config.BATCH_SIZE, shuffle=True)
    loader2 = DataLoader(ds_env2, batch_size=config.BATCH_SIZE, shuffle=True)
    loader_test = DataLoader(ds_test, batch_size=config.BATCH_SIZE, shuffle=False)

    # Env 2 (conflict) is typically much smaller than env 1. It is re-iterated,
    # and reshuffled on every pass, so each env-1 batch gets an env-2 partner.
    env2_batches = _infinite_batches(loader2)

    # Model & Optim
    model = CNN3Layer().to(config.DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
    criterion = nn.CrossEntropyLoss()

    print(f"\n[Train] Starting REx Training on {config.DEVICE}...")
    print(f"  Penalty: {config.REX_PENALTY_WEIGHT} (Annealed for {config.REX_ANNEAL_EPOCHS} epochs)")

    # Track history for plotting
    history = {'loss': [], 'penalty': [], 'test_acc': []}

    for epoch in range(config.EPOCHS):
        penalty_active = epoch >= config.REX_ANNEAL_EPOCHS
        avg_risk, avg_penalty, train_acc = _train_one_epoch(
            model, optimizer, criterion, loader1, env2_batches,
            penalty_active, config.REX_PENALTY_WEIGHT, config.DEVICE,
        )

        # Evaluate on test set
        val_acc = evaluate(model, loader_test, config.DEVICE)

        history['loss'].append(avg_risk)
        history['penalty'].append(avg_penalty)
        history['test_acc'].append(val_acc)

        print(f"Epoch [{epoch+1}/{config.EPOCHS}] Risk: {avg_risk:.4f} | Variance: {avg_penalty:.6f} | Train Acc: {train_acc:.1f}% | Test Acc: {val_acc:.2f}%")

    # Final Save
    save_path = "task4_rex_model.pth"
    torch.save(model.state_dict(), save_path)
    print(f"\n[Done] Model saved to {save_path}")

    # ========== VISUALIZATIONS ==========
    print("\n[Visualizations] Generating plots...")
    _plot_history(history, config.REX_ANNEAL_EPOCHS)
    _plot_confusion_matrix(model, loader_test, config.DEVICE)
    _plot_sample_predictions(model, ds_test, config.DEVICE)

    _print_summary(history)

def evaluate(model, loader, device):
    """Return the top-1 accuracy (in percent) of ``model`` over ``loader``.

    Puts the model in eval mode and disables autograd. Each batch is an
    ``(inputs, targets)`` pair. The prediction is the argmax over dim 1 of
    the model output. An empty loader raises ZeroDivisionError.
    """
    model.eval()
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predicted = torch.max(model(inputs), dim=1).indices
            hits += int((predicted == targets).sum())
            seen += len(targets)
    return 100 * hits / seen

if __name__ == "__main__":
    # Script entry point: run training and reporting with the default settings.
    train(config=Config)
