In [None]:
from models.model import Student, Teacher
import torchvision
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import cv2

In [None]:
# Prefer the CUDA device when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

In [None]:
# Checkpoint locations.
# NOTE(review): the original cell loaded the same "weights.pth" into BOTH
# networks. Student and Teacher expose different submodules (only the student
# has transition1-3), so one state_dict presumably cannot pass both strict
# load_state_dict calls. These two paths are placeholders -- point each one at
# its real checkpoint.
STUDENT_WEIGHTS_PATH = "student_weights.pth"
TEACHER_WEIGHTS_PATH = "teacher_weights.pth"

# map_location=device lets a checkpoint that was saved from a GPU load on a
# CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles arbitrary objects -- only load trusted
# checkpoints, or pass weights_only=True where the installed PyTorch supports it.
student = Student(num_classes=2, num_layers=[4, 6, 8, 10], growth_rate=16)
student.load_state_dict(torch.load(STUDENT_WEIGHTS_PATH, map_location=device))

teacher = Teacher(2)
teacher.load_state_dict(torch.load(TEACHER_WEIGHTS_PATH, map_location=device))

In [None]:
class End(nn.Module):
    """Closing residual stage of an auxiliary branch, rebuilt from ``model.branch``.

    Computes ``relu(batchnorm(conv(x) + downsample(x)))``. The modules are
    references to the submodules of ``model.branch``, not copies:

    - ``conv``: ``branch.conv`` (main path).
    - ``downsample``: ``branch.downsample`` (shortcut path).
    - ``batchnorm``: ``branch.avgpool[0]``.
    - ``relu``: ``branch.avgpool[1]``.

    The attribute names suggest ``batchnorm`` and ``relu`` are BatchNorm and
    ReLU layers -- confirm against the model definition.

    Args:
        model: network exposing ``branch.conv``, ``branch.downsample`` and an
            indexable ``branch.avgpool``.
    """

    def __init__(self, model):
        super(End, self).__init__()
        self.conv = model.branch.conv
        self.downsample = model.branch.downsample
        self.batchnorm = model.branch.avgpool[0]
        self.relu = model.branch.avgpool[1]

    def forward(self, x):
        # Out-of-place add on purpose. ``main += shortcut`` would modify the
        # conv output in place. That breaks autograd when the output is saved
        # for backward (e.g. a trailing ReLU), and Grad-CAM backprops through
        # this module. It would also overwrite the caller's tensor whenever
        # ``conv`` returns its input unchanged.
        main = self.conv(x)
        shortcut = self.downsample(x)
        out = main + shortcut
        return self.relu(self.batchnorm(out))

class StudentAuxiliary(nn.Module):
    """Student network cut off at its auxiliary branch, as features -> pool -> fc.

    ``features`` runs the student trunk from ``stem`` through ``transition3``,
    followed by an ``End`` block built from ``model.branch``. Pooling uses
    ``model.branch.avgpool[2]`` and classification uses ``model.branch.fc``.
    Every submodule is shared with ``model``, not copied.
    """

    def __init__(self, model):
        super().__init__()
        trunk_names = ("stem", "layer1", "transition1", "layer2",
                       "transition2", "layer3", "transition3")
        stages = [getattr(model, name) for name in trunk_names]
        stages.append(End(model))
        self.features = nn.Sequential(*stages)

        self.avgpool = model.branch.avgpool[2]
        self.fc = model.branch.fc

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)


class StudentMain(nn.Module):
    """Student network's main classifier, as features -> pool -> fc.

    ``features`` holds the whole student backbone (``stem`` through
    ``layer4``) plus the first two entries of ``model.avgpool``. This puts
    those two layers inside the module a Grad-CAM hook observes.
    ``model.avgpool[2]`` serves as the pooling step and ``model.fc`` as the
    classifier. All modules are shared with ``model``.
    """

    def __init__(self, model):
        super().__init__()
        backbone_names = ("stem", "layer1", "transition1", "layer2",
                          "transition2", "layer3", "transition3", "layer4")
        stages = [getattr(model, name) for name in backbone_names]
        stages.extend([model.avgpool[0], model.avgpool[1]])
        self.features = nn.Sequential(*stages)

        self.avgpool = model.avgpool[2]
        self.fc = model.fc

    def forward(self, x):
        feature_map = self.features(x)
        pooled = self.avgpool(feature_map)
        return self.fc(pooled.view(pooled.size(0), -1))


class TeacherAuxiliary(nn.Module):
    """Teacher network cut off at its auxiliary branch, as features -> pool -> fc.

    ``features`` is the teacher trunk (``stem``, ``layer1``-``layer3``)
    followed by an ``End`` block built from ``model.branch``. Pooling uses
    ``model.branch.avgpool[2]`` and classification uses ``model.branch.fc``.
    Modules are shared with ``model``.
    """

    def __init__(self, model):
        super().__init__()
        trunk = [model.stem, model.layer1, model.layer2, model.layer3]
        self.features = nn.Sequential(*trunk, End(model))

        self.avgpool = model.branch.avgpool[2]
        self.fc = model.branch.fc

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.fc(pooled.view(pooled.size(0), -1))

class TeacherMain(nn.Module):
    """Teacher network's main classifier, as features -> avgpool -> fc.

    ``features`` is the teacher backbone (``stem`` then ``layer1``-``layer4``).
    ``avgpool`` and ``fc`` are the teacher's own modules. Everything is
    referenced from ``model``, not copied.
    """

    def __init__(self, model):
        super().__init__()
        backbone = (model.stem, model.layer1, model.layer2, model.layer3, model.layer4)
        self.features = nn.Sequential(*backbone)

        self.avgpool = model.avgpool
        self.fc = model.fc

    def forward(self, x):
        feature_map = self.features(x)
        pooled = self.avgpool(feature_map)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

In [None]:
class GradCAM:
    """Grad-CAM class-activation maps for a single input image.

    A forward hook on ``target_layer`` stores that layer's output
    (``activations``). It also attaches a tensor hook that stores the gradient
    of the selected class score with respect to that same output
    (``gradients``). The map is computed as:

    - the ReLU of the channel-wise sum of activations, each channel weighted
      by its spatially averaged gradient;
    - bilinearly resized to the input's spatial size;
    - min-max normalised to [0, 1].

    Args:
        model: module whose output is indexed as ``output[:, k]`` to select
            class ``k`` (presumably logits of shape (batch, num_classes)).
        target_layer: submodule of ``model`` to explain. Its output is
            assumed to be (batch, channels, H, W), because gradients are
            averaged over dims 2 and 3.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Only a forward hook is registered on the module. The gradient is
        # captured by a tensor hook on the layer output instead of
        # register_full_backward_hook. A tensor hook does not wrap the
        # module's inputs/outputs, and it fires whenever d(score)/d(output) is
        # computed. It fires even if nothing fed into the module requires grad
        # (the input image here does not); on some PyTorch versions a full
        # backward hook does not fire in that case.
        self.hook_handles = [target_layer.register_forward_hook(self._save_activation)]

    def _save_activation(self, module, input, output):
        self.activations = output.detach()
        if output.requires_grad:
            output.register_hook(self._save_gradient)

    def _save_gradient(self, grad):
        self.gradients = grad.detach()

    def __call__(self, input_tensor, target_category=None):
        """Compute the Grad-CAM map for ``input_tensor``.

        Args:
            input_tensor: model input. Only a batch of one is supported:
                ``.item()`` on the argmax and ``squeeze(0)`` assume it.
            target_category: class index to explain. Defaults to the model's
                highest-scoring class.

        Returns:
            numpy array of shape ``input_tensor.shape[2:]`` with values in
            [0, 1]. It is all zeros when the map is (near) constant.

        Raises:
            RuntimeError: if no activations/gradients were captured for this
                call, e.g. ``target_layer`` is not used by ``model``.
        """
        # Reset so that a missed capture can never silently reuse tensors
        # from a previous call.
        self.gradients = None
        self.activations = None

        self.model.zero_grad()
        output = self.model(input_tensor)

        if target_category is None:
            target_category = torch.argmax(output, dim=1).item()

        # Summing the selected column has the same gradient as the original
        # one-hot product, without allocating a mask.
        target_score = output[:, target_category].sum()
        target_score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "GradCAM: target_layer produced no activations/gradients; "
                "check that it is part of the model's forward pass"
            )

        alpha = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = F.relu((alpha * self.activations).sum(dim=1, keepdim=True))
        cam = F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)
        cam = cam.squeeze(0)

        cam = cam - cam.min()
        peak = cam.max()
        if peak > 1e-6:
            cam = cam / peak
        else:
            cam = torch.zeros_like(cam)

        return cam.squeeze().cpu().numpy()

    def remove_hooks(self):
        """Detach the forward hook from ``target_layer``; safe to call repeatedly."""
        for handle in self.hook_handles:
            handle.remove()
        self.hook_handles = []

    def __del__(self):
        # getattr guards against __init__ having failed before hook_handles existed.
        for handle in getattr(self, "hook_handles", ()):
            handle.remove()

def show_cam_on_image(img: np.ndarray, mask: np.ndarray,
                      image_weight: float = 0.6, heatmap_weight: float = 0.4,
                      colormap=None) -> np.ndarray:
    """Blend a colour-mapped saliency mask over an RGB image.

    Args:
        img: RGB image; must match the colour-mapped mask in height, width and
            dtype for ``cv2.addWeighted`` (uint8 HxWx3 in this notebook).
        mask: 2-D saliency map, expected in [0, 1]. Values outside that range
            are clipped, so the uint8 conversion cannot wrap around.
        image_weight: blend weight for ``img`` (default 0.6).
        heatmap_weight: blend weight for the colour-mapped mask (default 0.4).
        colormap: OpenCV colormap id; ``None`` means ``cv2.COLORMAP_JET``.

    Returns:
        The blended RGB image produced by ``cv2.addWeighted``.
    """
    if colormap is None:
        colormap = cv2.COLORMAP_JET
    mask_u8 = (np.clip(mask, 0.0, 1.0) * 255).astype(np.uint8)
    heatmap = cv2.applyColorMap(mask_u8, colormap)
    # applyColorMap returns BGR; convert so it matches the RGB input image.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    return cv2.addWeighted(img, image_weight, heatmap, heatmap_weight, 0)

def deprocess_image(img_tensor, mean, std):
    """Undo per-channel normalisation and convert a CHW tensor to an HWC uint8 array.

    Computes ``img_tensor * std + mean`` channel-wise, scales by 255, clips
    to [0, 255] and casts to uint8.

    Args:
        img_tensor: (C, H, W) tensor. It may require grad; it is detached
            before conversion to numpy.
        mean: per-channel means, a sequence or array of length C.
        std: per-channel standard deviations, a sequence or array of length C.

    Returns:
        numpy uint8 array of shape (H, W, C).
    """
    # reshape(-1, 1, 1) broadcasts over any channel count, not only RGB.
    mean_tensor = torch.as_tensor(mean, device=img_tensor.device).reshape(-1, 1, 1)
    std_tensor = torch.as_tensor(std, device=img_tensor.device).reshape(-1, 1, 1)
    # Detach so .numpy() also works on tensors that are part of an autograd graph.
    img = img_tensor.detach() * std_tensor + mean_tensor
    img = img.permute(1, 2, 0).cpu().numpy()
    return np.clip(img * 255, 0, 255).astype(np.uint8)

In [None]:
# Expose the main and auxiliary classifiers of each network as standalone
# modules. The wrappers hold references to the original submodules, so
# .to()/.eval() on a wrapper also act on `student` / `teacher`.
student_aux, student_main = StudentAuxiliary(student), StudentMain(student)
teacher_aux, teacher_main = TeacherAuxiliary(teacher), TeacherMain(teacher)

for wrapper in (student_aux, student_main, teacher_aux, teacher_main):
    wrapper.to(device)
    wrapper.eval()
del wrapper

In [None]:
# Grad-CAM for student_main, explained at the output of student_main.features.
single_image_path = ""  # TODO: set to the image to visualise
if not single_image_path:
    raise ValueError("single_image_path is empty; set it to an image file before running this cell")

target_layer = student_main.features
grad_cam = GradCAM(student_main, target_layer)

original_img_pil = Image.open(single_image_path).convert("RGB")
width, height = original_img_pil.size

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

data_transform = transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

img_tensor = data_transform(original_img_pil).unsqueeze(0).to(device)

heatmap = grad_cam(img_tensor)
denormalized_img_np = deprocess_image(img_tensor.squeeze(0), mean, std)
overlayed_img = show_cam_on_image(denormalized_img_np, heatmap)

# Bring the 224x224 overlay back to the source image's resolution for display.
overlayed_img = Image.fromarray(overlayed_img).resize((width, height), resample=Image.BILINEAR)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(overlayed_img)
ax.set_title("Grad-CAM: student_main")
ax.axis('off')
fig.tight_layout()

plt.show()

In [None]:
# Grad-CAM for student_aux, explained at the output of student_aux.features.
single_image_path = ""  # TODO: set to the image to visualise
if not single_image_path:
    raise ValueError("single_image_path is empty; set it to an image file before running this cell")

target_layer = student_aux.features
grad_cam = GradCAM(student_aux, target_layer)

original_img_pil = Image.open(single_image_path).convert("RGB")
width, height = original_img_pil.size

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

data_transform = transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

img_tensor = data_transform(original_img_pil).unsqueeze(0).to(device)

heatmap = grad_cam(img_tensor)
denormalized_img_np = deprocess_image(img_tensor.squeeze(0), mean, std)
overlayed_img = show_cam_on_image(denormalized_img_np, heatmap)

# Bring the 224x224 overlay back to the source image's resolution for display.
overlayed_img = Image.fromarray(overlayed_img).resize((width, height), resample=Image.BILINEAR)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(overlayed_img)
ax.set_title("Grad-CAM: student_aux")
ax.axis('off')
fig.tight_layout()

plt.show()

In [None]:
# Grad-CAM for teacher_main, explained at the output of teacher_main.features.
single_image_path = ""  # TODO: set to the image to visualise
if not single_image_path:
    raise ValueError("single_image_path is empty; set it to an image file before running this cell")

target_layer = teacher_main.features
grad_cam = GradCAM(teacher_main, target_layer)

original_img_pil = Image.open(single_image_path).convert("RGB")
width, height = original_img_pil.size

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

data_transform = transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

img_tensor = data_transform(original_img_pil).unsqueeze(0).to(device)

heatmap = grad_cam(img_tensor)
denormalized_img_np = deprocess_image(img_tensor.squeeze(0), mean, std)
overlayed_img = show_cam_on_image(denormalized_img_np, heatmap)

# Bring the 224x224 overlay back to the source image's resolution for display.
overlayed_img = Image.fromarray(overlayed_img).resize((width, height), resample=Image.BILINEAR)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(overlayed_img)
ax.set_title("Grad-CAM: teacher_main")
ax.axis('off')
fig.tight_layout()

plt.show()

In [None]:
# Grad-CAM for teacher_aux, explained at the output of teacher_aux.features.
single_image_path = ""  # TODO: set to the image to visualise
if not single_image_path:
    raise ValueError("single_image_path is empty; set it to an image file before running this cell")

target_layer = teacher_aux.features
grad_cam = GradCAM(teacher_aux, target_layer)

original_img_pil = Image.open(single_image_path).convert("RGB")
width, height = original_img_pil.size

mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])

data_transform = transforms.Compose([
        transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

img_tensor = data_transform(original_img_pil).unsqueeze(0).to(device)

heatmap = grad_cam(img_tensor)
denormalized_img_np = deprocess_image(img_tensor.squeeze(0), mean, std)
overlayed_img = show_cam_on_image(denormalized_img_np, heatmap)

# Bring the 224x224 overlay back to the source image's resolution for display.
overlayed_img = Image.fromarray(overlayed_img).resize((width, height), resample=Image.BILINEAR)

fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(overlayed_img)
ax.set_title("Grad-CAM: teacher_aux")
ax.axis('off')
fig.tight_layout()

plt.show()