# Experiments: Temporal Stability of Gradient-Based Explanations

In this notebook we run all training experiments for our project.
We train CNN models on binary subsets of MNIST (3 vs 8) and CIFAR-10 (cat vs dog)
and track gradient-based explanations (saliency and integrated gradients) over training epochs.

We vary:
- Dataset: MNIST vs CIFAR-10  
- Model size: Small vs Deep CNN  
- Learning rate: 0.001 vs 0.01  
- Random seeds: 10 per configuration  

After each epoch, we compute explanations on a fixed set of test samples and
store both accuracy and explanations for later analysis.


In [1]:
import random
import numpy as np
import torch


def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Applies ``seed`` to Python's ``random`` module, NumPy's global
    generator and PyTorch (CPU plus all CUDA devices), then puts cuDNN
    into deterministic mode with benchmark autotuning switched off.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN speed for run-to-run repeatability.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")


Using device: cpu


# Models #

In [2]:
import torch.nn as nn
import torch.nn.functional as F

class SmallCNN(nn.Module):
    """One 3x3 convolution, global average pooling, and a 2-way linear head."""

    def __init__(self, in_channels):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # weight initialisation stays identical.
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=16, out_features=2)

    def forward(self, x):
        """Map a batch of images to two logits per sample."""
        pooled = self.pool(F.relu(self.conv(x)))
        # Collapse the 1x1 spatial map into one feature vector per sample.
        features = pooled.view(pooled.size(0), -1)
        return self.fc(features)

class DeepCNN(nn.Module):
    """Two stacked 3x3 convolutions, global average pooling, and a 2-way linear head."""

    def __init__(self, in_channels):
        super().__init__()
        # Layers are created in the same order as before so that seeded
        # weight initialisation stays identical.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=64, out_features=2)

    def forward(self, x):
        """Map a batch of images to two logits per sample."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        pooled = self.pool(hidden)
        # Collapse the 1x1 spatial map into one feature vector per sample.
        features = pooled.view(pooled.size(0), -1)
        return self.fc(features)



# Dataset #

In [3]:
from torchvision import datasets, transforms

def get_dataset(name):
    """Load MNIST or CIFAR-10 and reduce it to a two-class problem.

    Args:
        name: "mnist" (digits 3 vs 8) or "cifar" (classes 3 vs 5, cat vs dog).

    Returns:
        (train, test): the torchvision datasets, filtered in place so that
        only samples of the two chosen classes remain. Labels are remapped
        to 0 for the first class of the pair and 1 for the second.

    Raises:
        ValueError: if ``name`` is not one of the supported datasets.
    """
    # Validate before downloading anything, so a typo fails immediately
    # with a clear message instead of an unbound-variable error later on.
    class_pairs = {
        "mnist": [3, 8],
        "cifar": [3, 5],  # cat vs dog
    }
    if name not in class_pairs:
        raise ValueError(
            f"Unknown dataset: {name!r} (expected one of {sorted(class_pairs)})"
        )
    classes = class_pairs[name]

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    dataset_cls = datasets.MNIST if name == "mnist" else datasets.CIFAR10
    train = dataset_cls("data", train=True, download=True, transform=transform)
    test = dataset_cls("data", train=False, download=True, transform=transform)

    def filter_binary(ds):
        # Read the labels straight from ds.targets: iterating the dataset
        # itself would decode and transform every image just to get its label.
        labels = [int(y) for y in ds.targets]
        keep = [i for i, y in enumerate(labels) if y in classes]
        ds.data = ds.data[keep]
        ds.targets = [0 if labels[i] == classes[0] else 1 for i in keep]
        return ds

    return filter_binary(train), filter_binary(test)


# Gradient Methods

In [4]:
def vanilla_gradients(model, inputs, targets):
    """
    Saliency map: gradient of each sample's target logit w.r.t. its input.

    inputs: Tensor [N, C, H, W]
    targets: Tensor [N] (class indices)
    returns: Tensor [N, C, H, W]
    """
    # Work on a private leaf copy so the caller's tensor is never touched.
    x = inputs.detach().clone().requires_grad_(True)
    model.zero_grad()

    target_column = targets.view(-1, 1)
    # Summing the per-sample target logits lets a single backward pass
    # produce every sample's input gradient at once.
    torch.gather(model(x), 1, target_column).sum().backward()

    return x.grad.detach()


def integrated_gradients(model, inputs, targets, steps=20):
    """
    Integrated Gradients with zero baseline

    Averages the gradient of each sample's target logit at ``steps`` evenly
    spaced points on the straight line from the all-zero baseline to the
    input (both endpoints included), then scales by (input - baseline).

    inputs: Tensor [N, C, H, W]
    targets: Tensor [N] (class indices)
    steps: number of interpolation points, must be >= 1
    returns: Tensor [N, C, H, W], detached from the autograd graph

    Raises ValueError if ``steps`` is smaller than 1.
    """
    if steps < 1:
        # With no interpolation points the average would be 0 / 0 -> NaN.
        raise ValueError("steps must be a positive integer")

    # Detach so every interpolated tensor is a leaf: for a non-leaf tensor
    # ``.grad`` stays None, which broke inputs that already required grad.
    inputs = inputs.detach()
    baseline = torch.zeros_like(inputs)
    delta = inputs - baseline
    target_column = targets.view(-1, 1)
    total_gradients = torch.zeros_like(inputs)

    for alpha in torch.linspace(0, 1, steps):
        interpolated = (baseline + alpha * delta).requires_grad_(True)

        model.zero_grad()
        outputs = model(interpolated)
        outputs.gather(1, target_column).sum().backward()

        total_gradients += interpolated.grad.detach()

    avg_gradients = total_gradients / steps
    return delta * avg_gradients


# Training 

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss


def train_with_explanations(
    model,
    train_dataset,
    test_dataset,
    device,
    epochs,
    learning_rate,
    seed,
    explain_samples=32,
    explain_method="vanilla",      # "vanilla" or "integrated"
    ig_steps=20,
    batch_size=128
):
    """
    Trains a model and computes explanations after each epoch.

    Args:
        model: network to train (moved to ``device``)
        train_dataset: dataset used for optimisation
        test_dataset: dataset used for accuracy and for the explanation samples
        device: device string/object for model and tensors
        epochs: number of training epochs
        learning_rate: Adam learning rate
        seed: seeds the selection of the fixed explanation samples
        explain_samples: number of test samples to explain each epoch
        explain_method: "vanilla" or "integrated"
        ig_steps: interpolation steps for integrated gradients
        batch_size: batch size for both loaders

    Returns:
        results: dict with
            - accuracy: list[epochs] of test accuracy
            - explanations: list[epochs] of CPU Tensor [N, C, H, W]
            - method, learning_rate, seed: experiment metadata

    Raises:
        ValueError: if ``explain_method`` is not a known method
    """

    # Fail fast: previously a typo here only surfaced after the first
    # full training epoch had already been spent.
    if explain_method not in ("vanilla", "integrated"):
        raise ValueError("Unknown explanation method")

    # Setup
    model.to(device)
    optimizer = Adam(model.parameters(), lr=learning_rate)
    criterion = CrossEntropyLoss()

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False
    )

    # Fixed explanation samples: a dedicated generator keeps the choice
    # independent of the global RNG state used for training.
    rng = torch.Generator().manual_seed(seed)
    idx = torch.randperm(len(test_dataset), generator=rng)[:explain_samples].tolist()

    # Fetch every sample once (plain int indices) and split images/labels.
    samples = [test_dataset[i] for i in idx]
    explain_x = torch.stack([image for image, _ in samples]).to(device)
    explain_y = torch.tensor(
        [label for _, label in samples], device=device
    )

    accuracy_history = []
    explanation_history = []

    # Training Loop

    for epoch in range(epochs):
        # Train
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            optimizer.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optimizer.step()

        # Evaluate
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x).argmax(dim=1)
                correct += (preds == y).sum().item()
                total += y.size(0)

        acc = correct / total
        accuracy_history.append(acc)

        # Explain (model is still in eval mode from the evaluation step)
        with torch.enable_grad():
            if explain_method == "vanilla":
                expl = vanilla_gradients(
                    model, explain_x, explain_y
                )
            else:  # "integrated" -- validated at the top of the function
                expl = integrated_gradients(
                    model, explain_x, explain_y, steps=ig_steps
                )

        explanation_history.append(expl.cpu())


    # Return raw results
    return {
        "accuracy": accuracy_history,                  # [epochs]
        "explanations": explanation_history,            # [epochs, N, C, H, W]
        "method": explain_method,
        "learning_rate": learning_rate,
        "seed": seed
    }


# Experiments

In [None]:
import os
import torch
# Run a single training experiment for one specific configuration.
#
# This function:
# 1) Sets the random seed for reproducibility,
# 2) Loads the chosen dataset,
# 3) Initializes the specified model,
# 4) Trains the model for a fixed number of epochs while
#    computing explanations after each epoch,
# 5) Saves accuracy, explanations, and all metadata to a .pt file.


def run_single_experiment(
    dataset_name="mnist",
    model_type="small",          # "small" or "deep"
    learning_rate=0.001,
    seed=0,
    explain_method="vanilla",    # "vanilla" or "integrated"
    epochs=50,
    device="cuda",
    save_dir="results"
):
    """Train one configuration and save its accuracy/explanation history.

    Seeds all RNGs, loads the dataset, builds the requested model, trains
    it while recording explanations after every epoch, and writes the
    results together with the configuration to a ``.pt`` file in
    ``save_dir``. The file name encodes dataset, model, learning rate,
    seed and explanation method.

    Raises:
        ValueError: if ``model_type`` or ``explain_method`` is unknown.
    """

    # Validate the configuration before seeding, creating directories or
    # loading data, so a typo does not cost a dataset download/filtering pass.
    if model_type not in ("small", "deep"):
        raise ValueError("Unknown model type")
    if explain_method not in ("vanilla", "integrated"):
        raise ValueError("Unknown explanation method")

    set_seed(seed)
    os.makedirs(save_dir, exist_ok=True)

    train_ds, test_ds = get_dataset(dataset_name)

    # MNIST is single-channel; anything else is treated as 3-channel.
    in_channels = 1 if dataset_name == "mnist" else 3
    if model_type == "small":
        model = SmallCNN(in_channels)
    else:  # "deep" -- validated above
        model = DeepCNN(in_channels)


    results = train_with_explanations(
        model=model,
        train_dataset=train_ds,
        test_dataset=test_ds,
        device=device,
        epochs=epochs,
        learning_rate=learning_rate,
        seed=seed,
        explain_method=explain_method
    )


    filename = (
        f"{dataset_name}_"
        f"{model_type}_"
        f"lr{learning_rate}_"
        f"seed{seed}_"
        f"{explain_method}.pt"
    )

    save_path = os.path.join(save_dir, filename)

    torch.save({
        "dataset": dataset_name,
        "model": model_type,
        "learning_rate": learning_rate,
        "seed": seed,
        "explain_method": explain_method,
        "epochs": epochs,
        "accuracy": results["accuracy"],           # list[epochs]
        "explanations": results["explanations"]     # list[epochs] of Tensor [N, C, H, W]
    }, save_path)

    print(f"Saved results to: {save_path}")



*One Experiment*

In [20]:
# Single run: deep CNN on the CIFAR subset, lr 0.01, seed 6, integrated
# gradients. Uses the GPU when available, otherwise the CPU; epochs and
# save_dir keep the function's defaults.
run_single_experiment(
    dataset_name="cifar",
    model_type="deep",
    learning_rate=0.01,
    seed=6,
    explain_method="integrated",
    device="cuda" if torch.cuda.is_available() else "cpu"
)

Saved results to: results\cifar_deep_lr0.01_seed6_integrated.pt


*All at once*

In [None]:
# Run the full experiment grid over all configurations:
# - Dataset (MNIST, CIFAR)
# - Model size (small, deep)
# - Learning rate (0.001, 0.01)
# - Random seeds (10 per setting)
# - Explanation method (saliency / vanilla gradients, integrated gradients)
#
# For each configuration, exactly one training run is executed and saved to disk.

# Axes of the experiment grid.
DATASETS = ["mnist", "cifar"]
MODELS = ["small", "deep"]
LEARNING_RATES = [0.001, 0.01]
SEEDS = [*range(10)]
EXPL_METHODS = ["vanilla", "integrated"]

# Settings shared by every run.
EPOCHS = 50
SAVE_DIR = "results"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


def run_all():
    """Run every combination of the grid constants, one experiment at a time.

    The grid is the Cartesian product of DATASETS, MODELS, LEARNING_RATES,
    SEEDS and EXPL_METHODS, with the explanation method varying fastest.
    A progress line is printed before each run.
    """
    grid = [
        (dataset, model, lr, seed, expl)
        for dataset in DATASETS
        for model in MODELS
        for lr in LEARNING_RATES
        for seed in SEEDS
        for expl in EXPL_METHODS
    ]
    total_runs = len(grid)
    print(f"Total runs: {total_runs}")

    for run_id, (dataset, model, lr, seed, expl) in enumerate(grid, start=1):
        print(
            f"[{run_id}/{total_runs}] "
            f"dataset={dataset}, model={model}, "
            f"lr={lr}, seed={seed}, expl={expl}"
        )

        run_single_experiment(
            dataset_name=dataset,
            model_type=model,
            learning_rate=lr,
            seed=seed,
            explain_method=expl,
            epochs=EPOCHS,
            device=DEVICE,
            save_dir=SAVE_DIR
        )


# Execute the full grid; every run saves its own results file under SAVE_DIR.
run_all()


Total runs: 80
[1/80] dataset=cifar, model=small, lr=0.001, seed=0, expl=vanilla
Saved results to: results\cifar_small_lr0.001_seed0_vanilla.pt
[2/80] dataset=cifar, model=small, lr=0.001, seed=0, expl=integrated
Saved results to: results\cifar_small_lr0.001_seed0_integrated.pt
[3/80] dataset=cifar, model=small, lr=0.001, seed=1, expl=vanilla
Saved results to: results\cifar_small_lr0.001_seed1_vanilla.pt
[4/80] dataset=cifar, model=small, lr=0.001, seed=1, expl=integrated
Saved results to: results\cifar_small_lr0.001_seed1_integrated.pt
[5/80] dataset=cifar, model=small, lr=0.001, seed=2, expl=vanilla
Saved results to: results\cifar_small_lr0.001_seed2_vanilla.pt
[6/80] dataset=cifar, model=small, lr=0.001, seed=2, expl=integrated
Saved results to: results\cifar_small_lr0.001_seed2_integrated.pt
[7/80] dataset=cifar, model=small, lr=0.001, seed=3, expl=vanilla
Saved results to: results\cifar_small_lr0.001_seed3_vanilla.pt
[8/80] dataset=cifar, model=small, lr=0.001, seed=3, expl=integr