In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "02bfa0a8",
   "metadata": {},
   "source": [
    "# Comparing Explanation Methods with deit_tiny_finetuned\n",
    "\n",
    "This notebook loads a finetuned deit_tiny_patch16_224 model (a DeiT-tiny backbone whose head is replaced by a custom two-class head with its own weights), then runs several explanation methods on the first few images of a dataset (e.g., COVID-Q). For each image it creates a composite that places the original next to one heatmap overlay per explanation method.\n",
    "\n",
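    "Make sure that the project modules imported below (`vit_rollout`, `vit_grad_rollout`, `vit_explainability`, `vit_LRPmimic`, `vit_LRPexact`, `load_deit`, and `vit_explain_modulable`, which provides `run_explanation`) are importable (e.g., on your PYTHONPATH) before running this notebook. Helpers such as `deit_tiny_finetuned`, `show_mask_on_image` and `create_composite` are defined in the notebook itself."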
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7ac3ad3b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import cv2\n",
    "import json\n",
    "import numpy as np\n",
    "import torch\n",
    "from PIL import Image\n",
    "from torchvision import transforms\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Import your explanation functions and model loading functions\n",
    "# Adjust these imports to point to your modules\n",
    "from vit_rollout import VITAttentionRollout\n",
    "from vit_grad_rollout import VITAttentionGradRollout\n",
    "from vit_explainability import VITTransformerExplainability\n",
    "from vit_LRPmimic import VITTransformerLRPmimic\n",
    "from vit_LRPexact import load_model_LRP, VITTransformerLRPexact\n",
    "from load_deit import load_deit  # if used in your fine-tuning function\n",
    "\n",
    "# Custom finetuned model loader\n",
    "def deit_tiny_finetuned(head_weights_path, pretrained=True, **kwargs):\n",
    "    \"\"\"Build a DeiT-tiny (patch 16) VisionTransformer with a finetuned two-output head.\n",
    "\n",
    "    Args:\n",
    "        head_weights_path: Path to the state dict of the replacement head; it is\n",
    "            read with ``torch.load`` and mapped to the CPU.\n",
    "        pretrained: If true, the public DeiT-tiny checkpoint is downloaded and\n",
    "            loaded into the backbone before the head is replaced.\n",
    "        **kwargs: Passed through unchanged to ``VisionTransformer``.\n",
    "\n",
    "    Returns:\n",
    "        The model, with ``model.head`` replaced by a ``LinearRelprop`` layer that\n",
    "        has two outputs and the weights stored at ``head_weights_path``.\n",
    "    \"\"\"\n",
    "    from vit_explainability import VisionTransformer, _cfg  # adjust as needed\n",
    "    import torch.nn as nn\n",
    "\n",
    "    class LinearRelprop(nn.Linear):\n",
    "        \"\"\"``nn.Linear`` that remembers its input and can redistribute relevance.\"\"\"\n",
    "\n",
    "        def forward(self, x):\n",
    "            # Keep the input around; relprop() reads it after the forward pass.\n",
    "            self.input = x\n",
    "            return super().forward(x)\n",
    "\n",
    "        def relprop(self, cam, epsilon=1e-6, **kwargs):\n",
    "            \"\"\"Propagate relevance ``cam`` from the outputs back to the stored input.\"\"\"\n",
    "            # The pre-activation used here leaves out the bias term.\n",
    "            z = torch.matmul(self.input, self.weight.t())\n",
    "            # Stabilise away from zero in the direction of z's sign. A fixed\n",
    "            # +epsilon moves small negative values towards zero and makes the\n",
    "            # division below hit exactly zero when z == -epsilon.\n",
    "            sign = torch.where(z >= 0, torch.ones_like(z), -torch.ones_like(z))\n",
    "            z = z + epsilon * sign\n",
    "            s = cam / z\n",
    "            c = torch.matmul(s, self.weight)\n",
    "            return self.input * c\n",
    "\n",
    "    model = VisionTransformer(\n",
    "        patch_size=16,\n",
    "        embed_dim=192,\n",
    "        depth=12,\n",
    "        num_heads=3,\n",
    "        mlp_ratio=4,\n",
    "        qkv_bias=True,\n",
    "        **kwargs\n",
    "    )\n",
    "    model.default_cfg = _cfg()\n",
    "    if pretrained:\n",
    "        # Backbone weights come from the public DeiT-tiny checkpoint.\n",
    "        checkpoint = torch.hub.load_state_dict_from_url(\n",
    "            url=\"https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth\",\n",
    "            map_location=\"cpu\",\n",
    "            check_hash=True\n",
    "        )\n",
    "        model.load_state_dict(checkpoint[\"model\"])\n",
    "    # Swap the classification head for a two-output layer that supports relprop.\n",
    "    in_features = model.head.in_features\n",
    "    model.head = LinearRelprop(in_features, 2)\n",
    "    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints\n",
    "    # should be passed here; consider weights_only=True if the installed torch has it.\n",
    "    head_state_dict = torch.load(head_weights_path, map_location=torch.device('cpu'))\n",
    "    model.head.load_state_dict(head_state_dict)\n",
    "    return model\n",
    "\n",
    "# Function to load model based on the model name\n",
    "def load_model(model_name, parameters):\n",
    "    \"\"\"Load one of the supported models by name.\n",
    "\n",
    "    Args:\n",
    "        model_name: ``\"deit_tiny_finetuned\"`` or ``\"deit_tiny\"``.\n",
    "        parameters: Dict with a ``\"pretrained\"`` entry and, for the finetuned\n",
    "            model, a ``\"weights_path\"`` entry. A JSON string encoding such a\n",
    "            dict is accepted as well and decoded first.\n",
    "\n",
    "    Returns:\n",
    "        The loaded model.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If ``model_name`` is not one of the supported names.\n",
    "        KeyError: If a required entry is missing from ``parameters``.\n",
    "    \"\"\"\n",
    "    print(\"Loading model:\", model_name)\n",
    "    # The parameters used to arrive as a JSON string; keep supporting that form.\n",
    "    if isinstance(parameters, str):\n",
    "        parameters = json.loads(parameters)\n",
    "    if model_name == \"deit_tiny_finetuned\":\n",
    "        model = deit_tiny_finetuned(parameters[\"weights_path\"], pretrained=parameters[\"pretrained\"])\n",
    "    elif model_name == \"deit_tiny\":\n",
    "        # The pretrained flag is handed to the hub entry point positionally.\n",
    "        model = torch.hub.load('facebookresearch/deit:main', \"deit_tiny_patch16_224\", parameters[\"pretrained\"])\n",
    "    else:\n",
    "        raise ValueError(\"Unknown model: {}\".format(model_name))\n",
    "    return model\n",
    "\n",
    "# A simple function to overlay a heatmap on an image\n",
    "def show_mask_on_image(img, mask):\n",
    "    \"\"\"Blend a JET-coloured heatmap of ``mask`` onto ``img``.\n",
    "\n",
    "    Args:\n",
    "        img: Image on a 0-255 scale (it is divided by 255 here). The heatmap\n",
    "            from ``cv2.applyColorMap`` is in BGR order, so ``img`` should be too.\n",
    "        mask: Saliency map, expected in [0, 1]. Values outside that range are\n",
    "            clipped and NaNs are treated as 0.\n",
    "\n",
    "    Returns:\n",
    "        ``uint8`` array: heatmap plus image, rescaled so that its maximum is 255.\n",
    "    \"\"\"\n",
    "    # np.uint8() wraps out-of-range values and is undefined for NaN, so the mask\n",
    "    # is sanitised first; masks already inside [0, 1] pass through unchanged.\n",
    "    mask = np.clip(np.nan_to_num(mask), 0.0, 1.0)\n",
    "    img = np.float32(img) / 255\n",
    "    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)\n",
    "    heatmap = np.float32(heatmap) / 255\n",
    "    cam = heatmap + np.float32(img)\n",
    "    # Rescale by the global maximum so the sum fits back into [0, 1].\n",
    "    cam = cam / np.max(cam)\n",
    "    return np.uint8(255 * cam)\n",
    "\n",
    "# A helper function to create a composite image from the original and masks\n",
    "def create_composite(original_img, results_dict, size=(224, 224)):\n",
    "    \"\"\"Tile the original image and one labelled heatmap overlay per method, left to right.\n",
    "\n",
    "    Args:\n",
    "        original_img: RGB image that ``np.array`` can convert (e.g. a PIL image).\n",
    "        results_dict: Mapping from method name to its mask. Each mask is resized\n",
    "            to the tile size with ``cv2.resize`` before being overlaid.\n",
    "        size: ``(width, height)`` of every tile. Defaults to 224x224, the value\n",
    "            that used to be hard-coded.\n",
    "\n",
    "    Returns:\n",
    "        A single BGR array with the resized original first, followed by the\n",
    "        overlays in the iteration order of ``results_dict``.\n",
    "    \"\"\"\n",
    "    orig = np.array(original_img)  # in RGB\n",
    "    # OpenCV drawing and writing functions work in BGR order.\n",
    "    orig = cv2.cvtColor(orig, cv2.COLOR_RGB2BGR)\n",
    "    orig = cv2.resize(orig, tuple(size))\n",
    "    composite_images = [orig]\n",
    "    for method, mask in results_dict.items():\n",
    "        # cv2.resize takes (width, height), i.e. (columns, rows).\n",
    "        m = cv2.resize(mask, (orig.shape[1], orig.shape[0]))\n",
    "        overlay = show_mask_on_image(orig, m)\n",
    "        # Label each tile with the method name in green.\n",
    "        cv2.putText(overlay, method, (10, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)\n",
    "        composite_images.append(overlay)\n",
    "    comp = np.hstack(composite_images)\n",
    "    return comp\n",
    "\n",
    "# Explanation methods to compare, in the order they appear in the composite\n",
    "methods = [\n",
    "    \"attention\",\n",
    "    \"gradient\",\n",
    "    \"explainability\",\n",
    "    \"LRP_mimic\",\n",
    "    \"LRP_exact\",\n",
    "]\n",
    "\n",
    "# Model options as a plain dictionary (this was originally passed as a JSON string)\n",
    "model_parameters = {\n",
    "    \"pretrained\": True,\n",
    "    \"weights_path\": \"weights/deit_tiny_head_weights.pth\",\n",
    "}\n",
    "\n",
    "# Use the GPU when torch can see one, otherwise fall back to the CPU\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "# Load the finetuned model, move it to the device and switch it to eval mode\n",
    "model = load_model(\"deit_tiny_finetuned\", model_parameters).to(device).eval()\n",
    "\n",
    "# Preprocessing applied to every input image\n",
    "transform = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize((224, 224)),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Stand-in for the command-line arguments that run_explanation receives\n",
    "args = {\n",
    "    \"category_index\": None,\n",
    "    \"discard_ratio\": 0.9,\n",
    "    \"head_fusion\": \"max\",\n",
    "    \"attention_layer_name\": \"attn_drop\",\n",
    "    \"model_name\": \"deit_tiny_finetuned\",\n",
    "}\n",
    "\n",
    "# ## Process Images and Compare Explanation Methods\n",
    "# Loop over a subset of images from the dataset and generate composite images.\n",
    "\n",
    "# Imported once, outside the per-method try block, so that a missing module fails\n",
    "# immediately instead of being reported as a failure of every method on every image.\n",
    "from vit_explain_modulable import run_explanation\n",
    "\n",
    "dataset_dir = \"/path/to/covidqu\"  # update to your dataset directory\n",
    "output_dir = \"comparisons\"\n",
    "image_extensions = {\".jpg\", \".jpeg\", \".png\"}  # compared case-insensitively\n",
    "max_images = 5  # process only a few images for demonstration\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "# Filter by extension and sort, so the selected subset is the same on every run.\n",
    "image_files = sorted(\n",
    "    path\n",
    "    for path in glob.glob(os.path.join(dataset_dir, \"*\"))\n",
    "    if os.path.splitext(path)[1].lower() in image_extensions\n",
    ")\n",
    "\n",
    "print(f\"Found {len(image_files)} images.\")\n",
    "\n",
    "for img_path in image_files[:max_images]:\n",
    "    print(\"Processing\", img_path)\n",
    "    # Close the file handle as soon as the pixels have been converted.\n",
    "    with Image.open(img_path) as img_file:\n",
    "        img = img_file.convert(\"RGB\")\n",
    "    input_tensor = transform(img).unsqueeze(0).to(device)\n",
    "    results = {}\n",
    "    for method in methods:\n",
    "        try:\n",
    "            mask, _ = run_explanation(method, model, input_tensor, args)\n",
    "            results[method] = mask\n",
    "        except Exception as e:\n",
    "            # Best effort: one failing method must not stop the comparison.\n",
    "            print(f\"Method {method} failed on {img_path}: {e}\")\n",
    "    composite = create_composite(img, results)\n",
    "    base_name = os.path.basename(img_path)\n",
    "    out_path = os.path.join(output_dir, f\"composite_{base_name}\")\n",
    "    # cv2.imwrite signals an unsuccessful write through its return value.\n",
    "    if cv2.imwrite(out_path, composite):\n",
    "        print(\"Saved composite image to\", out_path)\n",
    "    else:\n",
    "        print(\"Failed to save composite image to\", out_path)\n",
    "    fig = plt.figure(figsize=(12,6))\n",
    "    plt.imshow(cv2.cvtColor(composite, cv2.COLOR_BGR2RGB))\n",
    "    plt.axis(\"off\")\n",
    "    plt.title(f\"Comparison for {base_name}\")\n",
    "    plt.show()\n",
    "    plt.close(fig)  # release the figure so memory does not grow across iterations\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.x"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
