In [5]:
import sys
import os

# Make the project root (the parent of the notebook's working directory)
# importable so `from src import ...` resolves to this project's package.
# Prepend rather than append: `src` is a generic name, and a different
# package with that name elsewhere on sys.path would otherwise shadow ours.
project_root = os.path.abspath(os.path.dirname(os.getcwd()))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

In [6]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import numpy as np
from torchcam.methods import SmoothGradCAMpp
from torchcam.utils import overlay_mask
from src.model import TrashNetClassifier
from src import config

In [None]:
# Instantiate the classifier, restore its trained weights (mapped onto the
# configured device), then move it to that device and put it in eval mode.
model = TrashNetClassifier()
model.load_state_dict(
    torch.load(config.MODEL_SAVE_PATH, map_location=config.DEVICE)
)
model = model.to(config.DEVICE).eval()


def load_image(image_path, image_size):
    """Load an image from disk and preprocess it for the classifier.

    Args:
        image_path: Path to the image file.
        image_size: Side length, in pixels, of the square the image is resized to.

    Returns:
        A tuple ``(image, tensor)``. ``image`` is the RGB ``PIL.Image`` at its
        original resolution. ``tensor`` is the resized, normalized image with a
        leading batch dimension, shape ``(1, 3, image_size, image_size)``.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        # Maps [0, 1] to [-1, 1]; presumably mirrors the training
        # preprocessing — TODO confirm against the training pipeline.
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ])
    # Close the file handle promptly; convert() returns an independent,
    # fully loaded image that remains valid after the handle is closed.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")
    return image, transform(image).unsqueeze(0)


# NOTE(review): this path is resolved against the current working directory,
# while the first cell treats the parent of the cwd as the project root.
# Confirm `data/` is reachable from here; otherwise build it from `project_root`.
img_path = "data/inference_test/colabottle.png"
pil_img, input_tensor = load_image(img_path, config.IMAGE_SIZE)
input_tensor = input_tensor.to(config.DEVICE)
# Gradient-based CAM needs an autograd graph through the target layer;
# requiring grad on the input guarantees one even if some weights are frozen.
input_tensor.requires_grad_(True)

# Hook the full model rather than only `model.backbone`: SmoothGrad re-runs
# the hooked module on noisy inputs and back-propagates `output[:, pred_class]`,
# so that module must produce the same class scores that `pred_class` indexes.
# The target layer is auto-located; `input_shape` matches the preprocessing size.
cam_extractor = SmoothGradCAMpp(
    model, input_shape=(3, config.IMAGE_SIZE, config.IMAGE_SIZE))

# Deliberately NOT under torch.no_grad(): the extractor registers gradient
# hooks during this forward pass and later back-propagates the class score.
output = model(input_tensor)
pred_class = output.argmax(dim=1).item()

# One CAM per target layer, each shaped (batch, H, W); take the single image.
cams = cam_extractor(pred_class, output)
activation_map = cams[0].squeeze(0).detach().cpu().numpy()
cam_extractor.remove_hooks()
assert activation_map.ndim == 2, f"unexpected CAM shape {activation_map.shape}"

# overlay_mask upsamples the mask to the image's size itself and expects a
# float ('F') mask in [0, 1]; an 8-bit 'L' mask breaks its colormap step.
mask_img = Image.fromarray(activation_map.astype(np.float32))
heatmap = overlay_mask(pil_img, mask_img, alpha=0.6)

fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(10, 5))
ax_orig.set_title("Original Image")
ax_orig.imshow(pil_img)
ax_orig.axis("off")

ax_cam.set_title("Grad-CAM Overlay")
ax_cam.imshow(heatmap)
ax_cam.axis("off")

# `pred_class` is the raw class index; map it to a label here if the
# project exposes one (not visible from this notebook).
fig.suptitle(
    f"Predicted: {pred_class} | Image: {os.path.basename(img_path)}", fontsize=14)
fig.tight_layout()
plt.show()