# In [None]:
"""
Task 4: The Intervention - Method 4: Gradient Reversal Layer (GRL)

Adversarial approach that forces the network to learn color-invariant features
by adding a "color predictor" head with reversed gradients.
"""

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import os

# ==============================================================================
# Configuration & Hyperparameters
# ==============================================================================
class Config:
    """Static settings for the GRL experiment.

    All values are class attributes, so the class itself (not an instance)
    can be handed to the functions that read them.
    """

    # --- Dataset location -----------------------------------------------
    DATA_DIR = r"/kaggle/input/cmnistneo1"
    TRAIN_FILE = "train_data_rg95z.npz"
    TEST_FILE = "test_data_gr95z.npz"

    # --- Optimisation ---------------------------------------------------
    BATCH_SIZE = 128
    EPOCHS = 20
    LR = 1e-3
    WEIGHT_DECAY = 1e-4

    # --- Gradient reversal ----------------------------------------------
    # Multiplier applied to the sign-flipped gradient of the color head.
    GRL_LAMBDA = 1.0

    # --- Runtime --------------------------------------------------------
    # Resolved once, at class-definition time: GPU if visible, else CPU.
    DEVICE = (
        torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    )
    SEED = 42

# ==============================================================================
# Gradient Reversal Layer
# ==============================================================================
class GradientReversalFunction(torch.autograd.Function):
    """Autograd function that is an identity forward and a negation backward.

    Forward:  returns a same-shaped view of ``x`` (values untouched).
    Backward: returns ``-lambda_ * grad_output`` for ``x`` and ``None`` for
              ``lambda_`` (the scale receives no gradient).
    """

    @staticmethod
    def forward(ctx, x, lambda_):
        # Remember the scale so backward can apply it.
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Flip the sign and scale the incoming gradient.
        flipped = -(grad_output * ctx.lambda_)
        return flipped, None

class GradientReversalLayer(nn.Module):
    """``nn.Module`` wrapper around :class:`GradientReversalFunction`.

    Holds the reversal scale ``lambda_`` as a plain attribute (not a
    parameter) and applies the function to whatever tensor it is given.
    """

    def __init__(self, lambda_=1.0):
        super().__init__()
        self.lambda_ = lambda_

    def forward(self, x):
        scale = self.lambda_
        return GradientReversalFunction.apply(x, scale)

# ==============================================================================
# Model Architecture with GRL
# ==============================================================================
class CNNWithGRL(nn.Module):
    """Small CNN with a digit head and an adversarial color head.

    Both heads read the same flattened convolutional features. The color
    head sees them through a gradient reversal layer, so the gradient of its
    loss reaches the shared extractor with its sign flipped.

    Args:
        num_classes: number of outputs of the digit head.
        num_colors: number of outputs of the color head.
        grl_lambda: scale passed to the gradient reversal layer.
    """

    def __init__(self, num_classes=10, num_colors=3, grl_lambda=1.0):
        super().__init__()

        # Size of the flattened feature vector both heads consume.
        # NOTE(review): 64 * 3 * 3 presumably corresponds to 28x28 inputs
        # after three 2x2 poolings - confirm against the dataset.
        flat_dim = 64 * 3 * 3

        # Shared extractor: three (conv 3x3 -> ReLU -> max-pool 2x2) stages.
        # Layers are created in the same order and positions as a literal
        # listing, so Sequential indices stay 0..8.
        stages = []
        for in_ch, out_ch in ((3, 32), (32, 64), (64, 64)):
            stages += [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            ]
        self.features = nn.Sequential(*stages)

        # Main task head: digit logits.
        self.digit_classifier = nn.Sequential(
            nn.Linear(flat_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

        # Sign-flipping bridge between the extractor and the color head.
        self.grl = GradientReversalLayer(lambda_=grl_lambda)

        # Adversarial head: color logits.
        self.color_discriminator = nn.Sequential(
            nn.Linear(flat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_colors),
        )

    def forward(self, x):
        """Return ``(digit_logits, color_logits)`` for a batch ``x``."""
        feature_maps = self.features(x)
        batch = feature_maps.size(0)
        flat = feature_maps.reshape(batch, -1)

        digit_logits = self.digit_classifier(flat)

        # The color head only ever sees the features through the GRL.
        color_logits = self.color_discriminator(self.grl(flat))

        return digit_logits, color_logits

# ==============================================================================
# Utils: Data Loading
# ==============================================================================
def get_dominant_color(img_tensor):
    """Return the index of the channel with the highest mean intensity.

    The mean is taken over dims 1 and 2 of ``img_tensor``, leaving one value
    per entry of dim 0; the position of the largest one is returned as a
    Python int (0=R, 1=G, 2=B for a channels-first RGB image).
    """
    per_channel = img_tensor.mean(dim=(1, 2))
    return per_channel.argmax().item()

def _load_split(path):
    """Read one ``.npz`` split and return ``(images, labels)`` tensors.

    The archive must contain an ``images`` array and a ``labels`` array.
    Images are converted to float32, divided by 255 and permuted from
    (N, H, W, C) to (N, C, H, W); labels are converted to int64.
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open;
    # the context manager guarantees it is closed once the arrays are read.
    with np.load(path) as archive:
        images = archive['images'].astype('float32') / 255.0
        labels = archive['labels']
    x = torch.tensor(images).permute(0, 3, 1, 2)
    y = torch.tensor(labels).long()
    return x, y


def load_data(config):
    """Load the train and test splits and derive color labels for training.

    Args:
        config: object exposing ``DATA_DIR``, ``TRAIN_FILE`` and
            ``TEST_FILE``.

    Returns:
        ``(ds_train, ds_test)`` where ``ds_train`` is a ``TensorDataset`` of
        ``(image, digit_label, color_label)`` and ``ds_test`` is a
        ``TensorDataset`` of ``(image, digit_label)``.

    The color label of a training image is the index of the channel with the
    largest spatial mean, the same rule ``get_dominant_color`` applies to a
    single image, computed here for the whole batch in one tensor operation.
    """
    print(f"\n[Data] Loading from {config.DATA_DIR}...")

    train_path = os.path.join(config.DATA_DIR, config.TRAIN_FILE)
    X_train, y_train = _load_split(train_path)

    # Per-image, per-channel mean over H and W, then arg-max over channels.
    print("[Data] Extracting color labels...")
    color_train = X_train.mean(dim=(2, 3)).argmax(dim=1)

    test_path = os.path.join(config.DATA_DIR, config.TEST_FILE)
    X_test, y_test = _load_split(test_path)

    print(f"  Train: {X_train.shape}, Colors: {color_train.shape}")
    print(f"  Test: {X_test.shape}")

    ds_train = TensorDataset(X_train, y_train, color_train)
    ds_test = TensorDataset(X_test, y_test)

    return ds_train, ds_test

# ==============================================================================
# Training Loop with GRL
# ==============================================================================
# Accuracy (in percent) drawn as the target line and used in the summary.
_TARGET_ACCURACY = 70


def _run_training_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(avg_digit_loss, avg_color_loss, train_acc)``: the two losses
        averaged over batches and the digit accuracy in percent.
    """
    model.train()
    digit_loss_sum = 0
    color_loss_sum = 0
    correct = 0
    total = 0

    for x, y_digit, y_color in loader:
        x = x.to(device)
        y_digit = y_digit.to(device)
        y_color = y_color.to(device)

        optimizer.zero_grad()

        digit_out, color_out = model(x)

        digit_loss = criterion(digit_out, y_digit)
        color_loss = criterion(color_out, y_color)

        # The losses are simply summed: the model routes the color head
        # through its gradient reversal layer, so the color-loss gradient is
        # already sign-flipped by the time it reaches the shared features.
        total_loss = digit_loss + color_loss
        total_loss.backward()
        optimizer.step()

        digit_loss_sum += digit_loss.item()
        color_loss_sum += color_loss.item()
        _, preds = torch.max(digit_out, 1)
        correct += (preds == y_digit).sum().item()
        total += y_digit.size(0)

    n_batches = len(loader)
    return digit_loss_sum / n_batches, color_loss_sum / n_batches, 100 * correct / total


def _plot_training_history(history):
    """Plot digit loss, color loss and test accuracy per epoch; save as PNG."""
    plt.figure(figsize=(15, 4))

    plt.subplot(1, 3, 1)
    plt.plot(history['digit_loss'], label='Digit Loss', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Digit Classification Loss (GRL)')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(history['color_loss'], label='Color Loss (Adversarial)', color='orange', linewidth=2)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Color Discriminator Loss (Reversed)')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 3)
    plt.plot(history['test_acc'], label='Hard Test Acc', color='green', linewidth=2)
    plt.axhline(_TARGET_ACCURACY, color='r', linestyle='--', alpha=0.7,
                label=f'Target ({_TARGET_ACCURACY}%)')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.title('Performance on Hard Test Set')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('task4_grl_training.png', dpi=150, bbox_inches='tight')
    print("  Saved: task4_grl_training.png")
    plt.show()


def _plot_confusion_matrix(model, loader, device):
    """Predict digits for every batch in ``loader`` and plot the 10x10 matrix."""
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            digit_out, _ = model(x)
            preds = torch.max(digit_out, 1)[1]
            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(y.tolist())

    # Pin the label set so the matrix is always 10x10 and lines up with the
    # tick labels even if some digit never appears in labels or predictions.
    digits = list(range(10))
    cm = confusion_matrix(all_labels, all_preds, labels=digits)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Greens', xticklabels=digits, yticklabels=digits)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix - GRL on Hard Test Set')
    plt.savefig('task4_grl_confusion.png', dpi=150, bbox_inches='tight')
    print("  Saved: task4_grl_confusion.png")
    plt.show()


def _plot_sample_predictions(model, dataset, device, seed):
    """Show up to 10 random test images with digit and color-head predictions.

    Images are taken from ``dataset.tensors`` (the already-normalised tensors
    used for evaluation) rather than re-read from disk, so the model sees the
    exact same preprocessing as during evaluation and no file is left open.
    """
    images, labels = dataset.tensors
    # Never ask for more samples than exist; sampling is without replacement.
    n_samples = min(10, len(images))
    # Local generator: reproducible picks without touching NumPy's global RNG.
    rng = np.random.default_rng(seed)
    sample_indices = rng.choice(len(images), n_samples, replace=False)
    color_names = ['Red', 'Green', 'Blue']

    model.eval()
    plt.figure(figsize=(15, 6))
    for i, idx in enumerate(sample_indices):
        idx = int(idx)
        img = images[idx]
        with torch.no_grad():
            digit_out, color_out = model(img.unsqueeze(0).to(device))
            pred = torch.max(digit_out, 1)[1].item()
            color_pred = torch.max(color_out, 1)[1].item()
        true_label = int(labels[idx])

        plt.subplot(2, 5, i + 1)
        # Back to (H, W, C) for display; values are floats in [0, 1].
        plt.imshow(img.permute(1, 2, 0).numpy())
        plt.title(f'True: {true_label}, Pred: {pred}\nColor: {color_names[color_pred]}',
                  color='green' if pred == true_label else 'red', fontsize=9)
        plt.axis('off')

    plt.suptitle('GRL: Sample Predictions (showing color confusion)', fontsize=14, y=1.02)
    plt.tight_layout()
    plt.savefig('task4_grl_samples.png', dpi=150, bbox_inches='tight')
    print("  Saved: task4_grl_samples.png")
    plt.show()


def _print_final_summary(history):
    """Print final and best test accuracy and whether the target was reached."""
    final_acc = history['test_acc'][-1]
    print(f"\n{'='*60}")
    print("FINAL RESULTS - Gradient Reversal Layer (GRL)")
    print(f"{'='*60}")
    print(f"Final Test Accuracy: {final_acc:.2f}%")
    print(f"Target Achieved: {'✓ YES' if final_acc >= _TARGET_ACCURACY else '✗ NO'}")
    print(f"Best Test Accuracy: {max(history['test_acc']):.2f}% (Epoch {np.argmax(history['test_acc'])+1})")
    print(f"{'='*60}")


def train(config):
    """Train the GRL model, save it, and produce the diagnostic plots.

    Steps: seed torch, load the datasets, train for ``config.EPOCHS`` epochs
    (evaluating on the test loader after each one), save the final weights to
    ``task4_grl_model.pth``, then write three PNG figures and print a summary.

    Args:
        config: object exposing ``SEED``, ``BATCH_SIZE``, ``EPOCHS``, ``LR``,
            ``WEIGHT_DECAY``, ``GRL_LAMBDA``, ``DEVICE`` and the data-path
            attributes read by ``load_data``.

    Returns:
        None. Results are reported through prints, saved files and figures.

    NOTE(review): the per-epoch "Test Acc" and the reported best epoch are
    both measured on the test split; there is no separate validation split.
    """
    torch.manual_seed(config.SEED)

    ds_train, ds_test = load_data(config)

    loader_train = DataLoader(ds_train, batch_size=config.BATCH_SIZE, shuffle=True)
    loader_test = DataLoader(ds_test, batch_size=config.BATCH_SIZE, shuffle=False)

    model = CNNWithGRL(grl_lambda=config.GRL_LAMBDA).to(config.DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=config.LR, weight_decay=config.WEIGHT_DECAY)
    # One stateless criterion shared by both heads, built once instead of
    # twice per batch.
    criterion = nn.CrossEntropyLoss()

    print(f"\n[Train] Starting GRL Training on {config.DEVICE}...")
    print(f"  GRL Lambda: {config.GRL_LAMBDA}")

    history = {'digit_loss': [], 'color_loss': [], 'test_acc': []}

    for epoch in range(config.EPOCHS):
        avg_digit_loss, avg_color_loss, train_acc = _run_training_epoch(
            model, loader_train, optimizer, criterion, config.DEVICE
        )

        val_acc = evaluate(model, loader_test, config.DEVICE)

        history['digit_loss'].append(avg_digit_loss)
        history['color_loss'].append(avg_color_loss)
        history['test_acc'].append(val_acc)

        print(f"Epoch [{epoch+1}/{config.EPOCHS}] Digit Loss: {avg_digit_loss:.4f} | Color Loss: {avg_color_loss:.4f} | Train Acc: {train_acc:.1f}% | Test Acc: {val_acc:.2f}%")

    # Weights after the last epoch (not the best epoch) are what gets saved.
    save_path = "task4_grl_model.pth"
    torch.save(model.state_dict(), save_path)
    print(f"\n[Done] Model saved to {save_path}")

    print("\n[Visualizations] Generating plots...")
    _plot_training_history(history)
    _plot_confusion_matrix(model, loader_test, config.DEVICE)
    _plot_sample_predictions(model, ds_test, config.DEVICE, config.SEED)
    _print_final_summary(history)

def evaluate(model, loader, device):
    """Return the digit-classification accuracy of ``model`` in percent.

    Args:
        model: callable returning a pair whose first element holds the digit
            logits; it is switched to eval mode before use.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        ``100 * correct / total`` over every sample seen in ``loader``.
    """
    model.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits, _ = model(inputs)
            predicted = torch.max(logits, dim=1).indices
            hits += int(predicted.eq(targets).sum())
            seen += targets.size(0)
    return 100 * hits / seen

# Script entry point: run the full GRL pipeline, passing the Config class
# itself (its settings are class attributes, so no instance is needed).
if __name__ == "__main__":
    train(Config)
