In [None]:
# Extract Visualization.zip (must already be in the working directory) into
# ./vizdata, quietly (-q). The commented-out pip lines below list the packages
# this notebook relies on; enable them only when setting up a new environment.
!unzip -q Visualization.zip -d vizdata # download the zip file from email
# !pip install --upgrade pip setuptools
# !pip install --use-pep517 torch torchvision pytorch-gradcam matplotlib lime timm grad-cam

In [None]:
import os

from utils import download, suppress_stdout_stderr

# Checkpoint file names, grouped three per dataset (nirmal*, pungliya*, mendeley)
# in the order maianet / soyatrans / tswinf.
# The order matters: this list is zipped positionally with `model_objects`
# further down, so entry i here must correspond to entry i there.
model_files = [
    "maianet_nirmal.pth",
    "soyatrans_nirmal.pth",
    "tswinf_nirmalsankana.pth",
    "maianet_pungliya.pth",
    "soyatrans_pungliya.pth",
    "tswinf_pungliyavithika.pth",
    "maianet_mendeley.pth",
    "soyatrans_mendeley.pth",
    "tswinf_mendeley.pth",
]

# Project helper from utils; presumably fetches the listed checkpoints into
# ./models (where they are loaded from later) — TODO confirm against utils.download.
download(model_files)


In [None]:
from torchvision import datasets, transforms

# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise each channel with the per-channel mean / std listed below.
# Anything that turns a transformed tensor back into a viewable image must undo
# exactly these statistics.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# Class-name -> index mappings recorded for two of the datasets
# (the mapping for mendeley is not recorded here).
# nirmalsankana
# {'Healthy': 0, 'Mosaic': 1, 'RedRot': 2, 'Rust': 3, 'Yellow': 4}

# pungliyavithika
# {'Healthy': 0, 'RedRot': 1, 'RedRust': 2}

# One ImageFolder per dataset directory extracted under vizdata/Visualization;
# all three share the same preprocessing.
vdata = datasets.ImageFolder("vizdata/Visualization/pungliyavithika", transform=transform)
ndata = datasets.ImageFolder("vizdata/Visualization/nirmalsankana", transform=transform)
mdata = datasets.ImageFolder("vizdata/Visualization/mendeley", transform=transform)


def match_dataset(model_name):
    """Pick the dataset that belongs to a checkpoint, based on its file name.

    Args:
        model_name: Checkpoint name; matched by substring.

    Returns:
        ``ndata`` when the name contains "nirmal", ``mdata`` when it contains
        "mendeley" (checked in that order), and ``vdata`` for anything else.
    """
    if "nirmal" in model_name:
        return ndata
    if "mendeley" in model_name:
        return mdata
    # Fallback: every remaining name is treated as a pungliyavithika model.
    return vdata

In [None]:
import torch
from build import build_model
from config import get_config

from maianet import MaiaNet
from soyatrans import SoyaTrans

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One freshly constructed network per checkpoint, in the same order as
# `model_files` (the two lists are zipped positionally when loading weights).
# The class count passed to each constructor must match the dataset the
# checkpoint was trained on.
model_objects = [
    # nirmalsankana: 5 classes
    MaiaNet(5),
    SoyaTrans(5),
    build_model(get_config(), 5),
    # pungliyavithika: 3 classes (Healthy / RedRot / RedRust)
    MaiaNet(3),
    SoyaTrans(3),
    # Was 5, which disagrees with the 3-class dataset and with the two sibling
    # models above.
    # NOTE(review): confirm tswinf_pungliyavithika.pth has a 3-way head.
    build_model(get_config(), 3),
    # mendeley: 11 classes
    MaiaNet(11),
    SoyaTrans(11),
    build_model(get_config(), 11),
]


def instance(model, file):
    """Load checkpoint weights into ``model`` and choose its CAM target layer.

    Args:
        model: An already constructed network whose architecture matches the
            checkpoint.
        file: Checkpoint file name inside ``models/``. It must contain one of
            "maianet", "soyatrans" or "tswinf"; that keyword decides which
            layer is used for CAM.

    Returns:
        ``(model, target_layer)`` where ``model`` is in eval mode on ``device``
        and ``target_layer`` is a one-element list holding the layer to hook.

    Raises:
        ValueError: If ``file`` matches none of the known architecture keywords.
    """
    # Keyword -> accessor for the layer to visualise. Checked in this order.
    layer_getters = (
        ("maianet", lambda m: m.maia_4.conv3[0]),
        ("soyatrans", lambda m: m.stage1.downsample),
        # Alternative for tswinf: m.stage4[0].attns[0].get_v
        ("tswinf", lambda m: m.LCA.conv1[0]),
    )
    get_layer = next((getter for key, getter in layer_getters if key in file), None)
    if get_layer is None:
        # Fail before touching the disk instead of hitting an unbound
        # `target_layer` after the weights have already been loaded.
        known = ", ".join(key for key, _ in layer_getters)
        raise ValueError(f"Cannot infer architecture from {file!r}; expected one of: {known}")

    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(f"models/{file}", map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    model.to(device)
    print(f"{file} loaded")

    target_layer = [get_layer(model)]
    return model, target_layer


# Map each checkpoint name to its (loaded model, target layers) pair; the two
# lists are paired positionally.
models = dict(
    (ckpt_name, instance(net, ckpt_name))
    for ckpt_name, net in zip(model_files, model_objects)
)

  from .autonotebook import tqdm as notebook_tqdm


maianet_mendeley.pth loaded


In [None]:
import gc

import numpy as np
import torch
from PIL import Image
from pytorch_grad_cam import (
    GradCAM,
    GradCAMPlusPlus,
    ScoreCAM,  # Needed for isinstance check
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from tqdm import tqdm


def tensor_to_rgb_image(tensor, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Undo channel normalisation and return an HxWxC float image in [0, 1].

    Args:
        tensor: Normalised image tensor laid out as (C, H, W).
        mean: Per-channel mean that was subtracted during normalisation.
            Defaults to the values used by this notebook's ``transform``.
        std: Per-channel std that was divided by during normalisation.
            Defaults to the values used by this notebook's ``transform``.

    Returns:
        A numpy array of shape (H, W, C) with the tensor's dtype, clipped to
        [0, 1].
    """
    img = tensor.detach().cpu()
    # new_tensor keeps the dtype of `img`, so the result stays float32 for
    # float32 inputs; reshape to (C, 1, 1) to broadcast over H and W.
    mean_t = img.new_tensor(mean).view(-1, 1, 1)
    std_t = img.new_tensor(std).view(-1, 1, 1)
    # Exact inverse of Normalize: x = x_norm * std + mean. The previous
    # `* 0.5 + 0.5` only matched a mean/std of 0.5 and distorted the colours.
    img = img * std_t + mean_t
    img = img.permute(1, 2, 0).numpy()
    return np.clip(img, 0, 1)


def plot(model, image, image_path, class_index, cams, output_dir="data/cam_outputs"):
    """Write the de-normalised input image and one overlay per CAM method.

    Files are saved as ``<stem>_original.jpg`` and ``<stem>_<cam name>.jpg``
    inside ``output_dir``, where ``<stem>`` is the file name of ``image_path``
    without its extension.

    Args:
        model: Network being explained; only used to find the device to run on.
        image: Preprocessed image tensor without a batch dimension.
        image_path: Path of the source image, used for naming the outputs.
        class_index: Class whose score the CAM methods explain.
        cams: Mapping of display name -> CAM object.
        output_dir: Directory for the generated files; created if missing.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    os.makedirs(output_dir, exist_ok=True)

    # Run on whichever device the model's parameters live on.
    batch = image.unsqueeze(0).to(next(model.parameters()).device)
    background = tensor_to_rgb_image(image)
    targets = [ClassifierOutputTarget(class_index)]

    def _save(pixels, suffix):
        # Persist a uint8 image array as "<stem>_<suffix>.jpg".
        Image.fromarray(pixels).save(os.path.join(output_dir, f"{stem}_{suffix}.jpg"))

    _save((background * 255).astype(np.uint8), "original")

    for label, explainer in cams.items():
        if isinstance(explainer, ScoreCAM):
            # Score-CAM is run with autograd disabled.
            with torch.no_grad():
                heatmaps = explainer(input_tensor=batch, targets=targets)
        else:
            heatmaps = explainer(input_tensor=batch, targets=targets)

        # Only one image is in the batch, so take its heatmap.
        _save(show_cam_on_image(background, heatmaps[0], use_rgb=True), label)


In [None]:

import torch

# For every loaded model, write the original image plus one overlay per CAM
# method for each image of the matching dataset, under output/<model_name>/.
for model_name, model_item in models.items():

    model, target_layers = model_item
    dataset = match_dataset(model_name)

    model = model.cuda() if torch.cuda.is_available() else model.cpu()
    model.eval()

    # Create CAM methods only once per model
    cams = {
        "Grad-CAM": GradCAM(model=model, target_layers=target_layers),
        "Grad-CAM++": GradCAMPlusPlus(model=model, target_layers=target_layers),
        "Score-CAM": ScoreCAM(model=model, target_layers=target_layers),
    }

    output_dir = f"output/{model_name}"

    for idx in tqdm(range(len(dataset)), desc=f"{model_name}", leave=True):
        image_path, _ = dataset.imgs[idx]
        base_name = os.path.splitext(os.path.basename(image_path))[0]

        # Resume support: skip an image only when ALL of its outputs exist.
        # Checking just "<name>_original.jpg" is not enough, because plot()
        # writes that file first and an interrupted run would leave the CAM
        # overlays missing forever.
        expected_files = [
            os.path.join(output_dir, f"{base_name}_{suffix}.jpg")
            for suffix in ("original", *cams)
        ]
        if all(os.path.isfile(path) for path in expected_files):
            continue

        # Load and preprocess the image only when there is work to do.
        image, label = dataset[idx]

        # Suppress plot outputs
        with suppress_stdout_stderr():
            plot(model, image, image_path, label, cams, output_dir=output_dir)

    # Clean up. Collect garbage first so the CUDA cache can actually be
    # released afterwards. Note that `models` still references the network,
    # so this only drops the CAM objects and the loop-local names.
    del model, cams
    gc.collect()
    torch.cuda.empty_cache()


Starting: maianet_mendeley.pth


maianet_mendeley.pth: 100%|██████████| 165/165 [02:21<00:00,  1.17it/s] 
