In [None]:
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import timm
import torch
import torch.nn as nn

import torch.nn.functional as F
import torchvision
from PIL import Image
from sklearn.preprocessing import LabelEncoder
from torchvision import transforms, models

In [None]:
# Every image lives at <base_folder>/<Fish Name>/<Path>; the relative "Path"
# column read from the metadata CSV is overwritten with that joined location.
base_folder = r"/kaggle/input/manhmeo/GG/xzyx7pbr3w-1"
df = pd.read_csv(r"/kaggle/input/data-split/meta_data_64_16_20.csv")


def _absolute_image_path(row):
    """Return the image location for one metadata row, rooted at ``base_folder``."""
    return os.path.join(base_folder, row["Fish Name"], row["Path"])


df["Path"] = df.apply(_absolute_image_path, axis=1)

In [None]:
# Fit the encoder on the "Label" column, then overwrite that column with the
# encoded values. The fitted encoder is kept in `label_encoder`.
label_encoder = LabelEncoder().fit(df["Label"])
df["Label"] = label_encoder.transform(df["Label"])

# Distinct encoded labels present in the frame and how many there are.
classes = df["Label"].unique()
num_classes = len(classes)
print(classes)

# Use the GPU when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Per-channel normalisation statistics (three values each). They are passed to
# transforms.Normalize and to deprocess_image in later cells.
# NOTE(review): these match the commonly published ImageNet RGB statistics.
mean = np.array((0.485, 0.456, 0.406), dtype=np.float64)
std = np.array((0.229, 0.224, 0.225), dtype=np.float64)

In [None]:
class FishClassifier(nn.Module):
    """DenseNet-121 whose classifier is replaced by a small two-layer head.

    The backbone is built with ``weights=None`` (no pretrained weights are
    requested here). Its original ``classifier`` is swapped for
    ``Linear -> ReLU -> Dropout(0.5) -> Linear`` ending in ``num_classes``
    outputs.

    The backbone attribute is spelled ``desnet``. Parameter names in a state
    dict are derived from this attribute name, so it must not be renamed
    without also converting existing checkpoints.

    Args:
        num_classes: Size of the final linear layer's output (default 3).
    """

    def __init__(self, num_classes=3):
        super().__init__()
        backbone = models.densenet121(weights=None)
        # Read the width of the stock classifier before it is replaced.
        head_in_features = backbone.classifier.in_features
        backbone.classifier = nn.Sequential(
            nn.Linear(head_in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )
        self.desnet = backbone

    def forward(self, x):
        """Run ``x`` through the backbone and return the head's raw outputs."""
        logits = self.desnet(x)
        return logits
        
# Device used both for deserialising the checkpoint and for running the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loaded_model = FishClassifier(num_classes=num_classes)

model_save_path = '/kaggle/input/dense-fish/DenseNet/result_DenseNet121/fish_classifier_DenseNet121.pth'
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code
# if the checkpoint is untrusted. If the installed PyTorch supports it, consider
# passing weights_only=True (confirm the checkpoint is a plain state dict first).
state_dict = torch.load(model_save_path, map_location=device)

# Strip a leading "module." from each key (presumably added by an
# nn.DataParallel / DistributedDataParallel wrapper at training time - confirm).
# Only the prefix is removed: a blanket str.replace would also mangle any key
# that merely contains "module." further along its path (e.g. "submodule.").
_wrapper_prefix = "module."
new_state_dict = {
    (k[len(_wrapper_prefix):] if k.startswith(_wrapper_prefix) else k): v
    for k, v in state_dict.items()
}

loaded_model.load_state_dict(new_state_dict)
loaded_model = loaded_model.to(device)
# Inference mode: disables dropout and uses running batch-norm statistics.
loaded_model.eval()

print("Load model Dense thành công theo đúng cấu trúc!")

In [None]:
class GradCAM:
    """Grad-CAM heat-map generator bound to one layer of a model.

    A forward hook and a full backward hook are registered on ``target_layer``.
    They record the layer's output and the gradient of the target score with
    respect to that output. Calling the instance runs one forward and one
    backward pass through ``model`` and returns the normalised activation map.

    Lifetime note: the hooks are bound methods of this object and are stored on
    ``target_layer``, so the layer keeps this object alive while they are
    registered. ``del instance`` alone therefore does not run ``__del__`` and
    does not detach the hooks. Call :meth:`remove_hooks`, or use the instance
    as a context manager, to detach them deterministically.

    Args:
        model: Module that is called with the input tensor.
        target_layer: Sub-module of ``model`` whose output is explained.
    """

    def __init__(self, model, target_layer):
        # Created first so cleanup is safe even if hook registration fails.
        self.hook_handles = []
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        self.hook_handles.append(target_layer.register_forward_hook(self._save_activation))
        self.hook_handles.append(target_layer.register_full_backward_hook(self._save_gradient))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.remove_hooks()
        return False  # never swallow exceptions raised inside the with-body

    def remove_hooks(self):
        """Detach every hook this object registered. Safe to call repeatedly."""
        handles = getattr(self, "hook_handles", None) or []
        while handles:
            handles.pop().remove()

    def _save_activation(self, module, input, output):
        # Forward hook: keep a detached copy of the target layer's output.
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        # Backward hook: keep the gradient w.r.t. the target layer's output.
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, target_category=None):
        """Compute the Grad-CAM map for ``input_tensor``.

        Args:
            input_tensor: Batched model input; its dimensions from index 2
                onward are used as the output map size. When
                ``target_category`` is ``None`` the batch must contain exactly
                one sample, because the predicted class is read with
                ``.item()``.
            target_category: Index of the class score to explain. Defaults to
                the arg-max of the model output.

        Returns:
            ``numpy.ndarray`` with values in ``[0, 1]`` (all zeros when the map
            has no positive evidence). For a single-sample 4-D input the shape
            is ``(H, W)`` of the input.

        Raises:
            RuntimeError: If the hooks captured nothing, e.g. they were removed
                or ``target_layer`` is not on the forward/backward path.
        """
        # Forget captures from earlier calls so stale data is never reused.
        self.activations = None
        self.gradients = None

        # Grad-CAM needs gradients even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            self.model.zero_grad()
            output = self.model(input_tensor)

            if target_category is None:
                target_category = torch.argmax(output, dim=1).item()

            # Select the target class score(s) so backward() starts from them.
            one_hot = torch.zeros_like(output)
            one_hot[:, target_category] = 1

            target_score = (output * one_hot).sum()
            target_score.backward()

        gradients = self.gradients
        activations = self.activations
        if gradients is None or activations is None:
            raise RuntimeError(
                "GradCAM hooks did not capture activations/gradients; "
                "were the hooks removed, or is target_layer unused by the model?"
            )

        # Channel weights: spatial average of the gradients.
        alpha = torch.mean(gradients, dim=(2, 3), keepdim=True)
        cam = (alpha * activations).sum(dim=1, keepdim=True)

        # Keep positive evidence only, then upsample to the input resolution.
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)
        cam = cam.squeeze(0)

        # Min-max normalise; a (near-)flat map becomes all zeros.
        cam = cam - cam.min()
        if cam.max() > 1e-6:
            cam = cam / cam.max()
        else:
            cam = torch.zeros_like(cam)

        return cam.squeeze().cpu().numpy()

    def __del__(self):
        # Best-effort fallback only; see the class docstring for why explicit
        # remove_hooks() is required in practice.
        self.remove_hooks()

def show_cam_on_image(img: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Blend a JET-coloured rendering of ``mask`` onto ``img``.

    ``mask`` is multiplied by 255 and cast to ``uint8`` before colour-mapping,
    so it is presumably expected to lie in ``[0, 1]`` (confirm against
    callers). The colour map is converted from BGR to RGB channel order and
    mixed with ``img`` using weights 0.6 (image) and 0.4 (heat map).

    Args:
        img: Image to draw on, passed unchanged to ``cv2.addWeighted``.
        mask: Heat-map values to visualise.

    Returns:
        The weighted blend produced by ``cv2.addWeighted``.
    """
    mask_u8 = np.uint8(255 * mask)
    jet_bgr = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    jet_rgb = cv2.cvtColor(jet_bgr, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(img, 0.6, jet_rgb, 0.4, 0)

def deprocess_image(img_tensor, mean, std):
    """Undo per-channel normalisation and return a channel-last uint8 image.

    Computes ``img_tensor * std + mean`` per channel, moves the channel axis
    last, scales by 255 and clips to ``[0, 255]``.

    Args:
        img_tensor: 3-D tensor in channel-first layout (it is permuted with
            ``(1, 2, 0)``). It may require grad; it is detached before the
            conversion to NumPy, which would otherwise raise.
        mean: Three per-channel means (sequence, array or tensor).
        std: Three per-channel standard deviations (sequence, array or tensor).

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the channel axis last.
    """
    target_device = img_tensor.device
    # as_tensor accepts lists, arrays and tensors without copy-construct warnings.
    mean_tensor = torch.as_tensor(mean, device=target_device).reshape(3, 1, 1)
    std_tensor = torch.as_tensor(std, device=target_device).reshape(3, 1, 1)
    # detach(): .numpy() is not allowed on tensors that are part of a graph.
    img = img_tensor.detach() * std_tensor + mean_tensor
    img = img.permute(1, 2, 0).cpu().numpy()
    img = np.clip(img * 255, 0, 255).astype(np.uint8)
    return img

In [None]:
single_image_path = "/kaggle/input/manhmeo/GG/xzyx7pbr3w-1/Chanos Chanos - Highly Fresh/IMG_20190930_070337.jpg"

# Load the image as RGB and remember its size so results can be shown at the
# original resolution. The file is closed as soon as the pixels are converted.
with Image.open(single_image_path) as opened_img:
    original_img_pil = opened_img.convert("RGB")
width, height = original_img_pil.size
original_img_np = np.array(original_img_pil)

# Resize to 224x224, convert to a tensor and normalise with mean/std.
data_transform = transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

img_tensor = data_transform(original_img_pil).unsqueeze(0).to(device)

# Explain the output of the DenseNet feature extractor's `denseblock4`.
target_layer = loaded_model.desnet.features.denseblock4
grad_cam = GradCAM(loaded_model, target_layer)
try:
    heatmap = grad_cam(img_tensor)
finally:
    # The hooks hold references to grad_cam, so `del grad_cam` alone never
    # triggers its __del__; detach them explicitly to avoid stacking a new set
    # of hooks on the layer every time this cell is re-run.
    for handle in grad_cam.hook_handles:
        handle.remove()
    del grad_cam

# Build the overlay at network resolution from the de-normalised input.
denormalized_img_np = deprocess_image(img_tensor.squeeze(0), mean, std)
overlayed_img = show_cam_on_image(denormalized_img_np, heatmap)

# Upscale both results back to the original image size for display.
heatmap = Image.fromarray(np.uint8(255 * heatmap)).resize((width, height), resample=Image.BILINEAR)
overlayed_img = Image.fromarray(overlayed_img).resize((width, height), resample=Image.BILINEAR)

fig, axs = plt.subplots(1, 3, figsize=(12, 4))

axs[0].imshow(original_img_np)
axs[0].set_title("Original Image")
axs[0].axis('off')

axs[1].imshow(heatmap, cmap='jet')
axs[1].set_title("Grad-CAM Heatmap")
axs[1].axis('off')

axs[2].imshow(overlayed_img)
axs[2].set_title("Overlay")
axs[2].axis('off')

plt.show()