# Libraries and device setup

In [1]:
import cv2
import matplotlib.pyplot as plt
import matplotlib as mpl 
import numpy as np
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.optim
import torchvision.models as models
import torchvision.transforms as transforms
import warnings
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix, ConfusionMatrixDisplay
from torch.utils.data import DataLoader, TensorDataset
from torchvision.transforms.functional import InterpolationMode
from torch.nn.parallel import DataParallel
from tqdm import tqdm

# NOTE(review): this silences *all* warnings (including deprecations) for the whole kernel.
warnings.filterwarnings("ignore")

# Fall back to CPU so the notebook still runs on machines without CUDA;
# torch.device("cuda") alone only fails later, when tensors are moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

n_gpus = torch.cuda.device_count()
# NOTE: `parallel` is not consumed by any later cell (the model is never wrapped in DataParallel).
parallel = n_gpus > 1
if parallel:
    print("Using", n_gpus, "GPUs!")

In [2]:
# Every split lives in the same directory and follows the same file-name pattern.
split_dir = 'C:/crisis_vision_benchmarks/tasks/damage_severity/consolidated'
train_path, test_path, dev_path = (
    f'{split_dir}/consolidated_damage_{split}_final.tsv' for split in ('train', 'test', 'dev')
)

# Read in train -> test -> dev order, then report each table's (rows, columns).
train_label, test_label, dev_label = (
    pd.read_table(tsv) for tsv in (train_path, test_path, dev_path)
)
print(train_label.shape, test_label.shape, dev_label.shape, sep='\n')

(28319, 4)
(3865, 4)
(2712, 4)


# Preprocessing

In [3]:
class_label_map = dict(severe=2, mild=1, little_or_none=0)  # damage-severity label -> integer class id

def get_data(pd_label, class_label_map, root='C:/crisis_vision_benchmarks/', image_size=(224, 224)):
    """Load, resize and ImageNet-normalize every image listed in ``pd_label`` with its class id.

    Args:
        pd_label: DataFrame with an ``image_path`` column (paths relative to ``root``) and a
            ``class_label`` column.
        class_label_map: mapping from class-label string to integer class id.
        root: directory the ``image_path`` entries are relative to.
        image_size: (height, width) every image is resized to.

    Returns:
        X: array stacked from the per-image tensors, shape (N, 3, H, W).
        y: int64 array of shape (N,), aligned row-for-row with ``X``.

    Raises:
        ValueError: if no row yields a usable (image, label) pair.

    Rows with an unknown label or an unreadable image are reported and skipped as a unit,
    so ``X`` and ``y`` always stay aligned.
    """
    tfms = transforms.Compose([
        transforms.Resize(image_size, interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    X, y = [], []
    # Iterate column values directly: positional, so it works for any DataFrame index.
    rows = zip(pd_label['image_path'], pd_label['class_label'])
    for rel_path, label in tqdm(rows, total=len(pd_label)):
        path = os.path.join(root, rel_path)
        # Validate the label before loading the image. Skipping only the label (and keeping
        # the image) would shift every later label onto the wrong image.
        if label not in class_label_map:
            print(f"Error: Unknown class label: {label}")
            continue
        try:
            # `with` releases the file handle; convert() forces the pixel data to load first.
            with Image.open(path) as img:
                X.append(tfms(img.convert("RGB")))
        except Exception as e:
            print(f"Error opening image: {path} - {str(e)}")
            continue
        y.append(class_label_map[label])

    if not X:
        raise ValueError("get_data: no valid (image, label) pairs were loaded")

    print(X[0].shape)
    X = np.stack(X)
    print(X.shape)

    y = np.array(y, dtype=np.int64)
    print(y.shape)

    return X, y

# Only the test split is materialized; the cells below use X_test / y_test exclusively.
X_test, y_test = get_data(test_label, class_label_map)

100%|██████████| 3865/3865 [00:41<00:00, 92.06it/s] 


torch.Size([3, 256, 256])
(3865, 3, 256, 256)


100%|██████████| 3865/3865 [00:00<00:00, 773166.64it/s]

(3865,)





In [4]:
# torch.as_tensor reuses the NumPy buffer when the dtype already matches, instead of making
# a second full copy of the image array the way torch.tensor does. It is also a no-op if this
# cell is re-run on tensors.
X_test = torch.as_tensor(X_test, dtype=torch.float32)
y_test = torch.as_tensor(y_test, dtype=torch.int64)
testset = TensorDataset(X_test, y_test)
# shuffle=False keeps batch order equal to the TSV row order (the saved GradCAM_<n> file numbering depends on it).
testloader = DataLoader(testset, batch_size=32, shuffle=False, num_workers=4)

In [5]:
model = models.efficientnet_b1(weights=models.EfficientNet_B1_Weights.IMAGENET1K_V1)
num_ftrs = model.classifier[-1].in_features
# Replace the ImageNet head with one output per damage-severity class.
model.classifier[-1] = nn.Linear(num_ftrs, len(class_label_map))

# The new head is randomly initialized: without fine-tuned weights the predictions and the
# Grad-CAM maps below are not meaningful.
checkpoint_path = None  # TODO: set to the fine-tuned EfficientNet-B1 state_dict file
if checkpoint_path is not None:
    # torch.load unpickles the file, so only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

model = model.to(device)
# Inference only: eval() uses BatchNorm running statistics (and leaves them untouched) and
# disables dropout. Gradients for Grad-CAM are still computed in eval mode.
model.eval()

# Grad-CAM

In [6]:
class GradCamModel(nn.Module):
    """Wrap a classifier so a forward pass returns its logits and a hooked layer's activations.

    A forward hook on ``target_layer`` stores that layer's output in ``selected_out``. When the
    output requires grad, it also attaches a tensor hook that stores the gradient reaching the
    output during ``backward()`` in ``gradients``. Only the most recent forward pass is tracked.

    Args:
        backbone: classifier to explain. It is wrapped, not copied, so hooks are installed on the
            caller's model. Defaults to the notebook-level ``model`` (original behavior).
        target_layer: module whose output is explained. Defaults to ``backbone.features[-1]``.
    """

    def __init__(self, backbone=None, target_layer=None):
        super().__init__()
        self.gradients = None
        self.tensorhook = []  # tensor-hook handle(s) from the latest forward pass only
        self.layerhook = []   # forward-hook handle(s) on the target layer
        self.selected_out = None

        # PRETRAINED MODEL
        self.pretrained = model if backbone is None else backbone
        layer = self.pretrained.features[-1] if target_layer is None else target_layer
        self.layerhook.append(layer.register_forward_hook(self.forward_hook()))

        # The inputs never require grad, so trainable parameters are what make the hooked
        # activations part of the autograd graph.
        for p in self.pretrained.parameters():
            p.requires_grad = True

    def activations_hook(self, grad):
        """Tensor hook: remember the gradient w.r.t. the hooked activations."""
        self.gradients = grad

    def get_act_grads(self):
        """Return the gradient from the last backward() through the latest forward, or None."""
        return self.gradients

    def forward_hook(self):
        """Build the forward hook that captures the target layer's output."""
        def hook(module, inp, out):
            self.selected_out = out
            # Never hand out gradients that belong to an earlier batch.
            self.gradients = None
            # Drop the previous handle instead of accumulating one per forward pass.
            for handle in self.tensorhook:
                handle.remove()
            self.tensorhook.clear()
            # register_hook raises on tensors that don't require grad (e.g. under no_grad).
            if out.requires_grad:
                self.tensorhook.append(out.register_hook(self.activations_hook))
        return hook

    def remove_hooks(self):
        """Detach every hook this wrapper installed on the (shared) wrapped model."""
        for handle in self.tensorhook + self.layerhook:
            handle.remove()
        self.tensorhook.clear()
        self.layerhook.clear()

    def forward(self, x):
        out = self.pretrained(x)
        return out, self.selected_out
    
# Wraps the global `model` (the same module objects, not a copy). Every forward pass therefore
# also returns the activations captured by the hook on model.features[-1].
gcmodel = GradCamModel().to(device)

# Inference: Grad-CAM overlays for the test set

In [7]:
output_folder = 'D:/Research/Exploration/crisis_vision_benchmarks/gradcam'
os.makedirs(output_folder, exist_ok=True)  # savefig fails if the folder does not exist
image_index = 1

# Which class to explain: the ground-truth label (what the original loss targeted) or,
# if True, the model's own prediction.
EXPLAIN_PREDICTED_CLASS = False
SHOW_FIGURES = True  # rendering every test image inline is slow; False = only save files

idx_to_class = {v: k for k, v in class_label_map.items()}
cmap = mpl.colormaps['inferno']
# Same statistics as the Normalize transform in get_data; needed to undo it for display.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def _gradcam_heatmaps(acts, grads):
    """Per-image Grad-CAM maps scaled to [0, 1].

    acts, grads: (B, C, H, W) CPU tensors holding the hooked activations and their gradients.
    Returns a float32 NumPy array of shape (B, H, W).
    """
    # Channel weights are the spatial mean of each image's own gradients (not pooled over the batch).
    weights = grads.mean(dim=(2, 3), keepdim=True)
    # Weighted sum of activation maps; ReLU keeps only evidence *for* the target class.
    cam = torch.relu((weights * acts).sum(dim=1))
    # Normalize each image by its own maximum; the clamp avoids 0/0 for all-zero maps.
    cam_max = cam.flatten(1).max(dim=1).values.clamp_min(1e-12)
    return (cam / cam_max[:, None, None]).numpy()


def _to_display_image(img_chw):
    """Undo ImageNet normalization of a (3, H, W) RGB tensor; return an (H, W, 3) array in [0, 1]."""
    img = img_chw.permute(1, 2, 0).numpy()
    return np.clip(img * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0)


gcmodel.eval()  # BatchNorm must not mix samples or update running stats during inference
for images, labels in testloader:
    images = images.to(device)
    labels = labels.to(device)

    gcmodel.zero_grad(set_to_none=True)
    # One forward pass gives both the logits and the hooked activations.
    out, acts = gcmodel(images)
    predicted = out.argmax(dim=1)
    target = predicted if EXPLAIN_PREDICTED_CLASS else labels
    # Backpropagate the raw class score, not the loss: Grad-CAM weights are d(score)/d(acts).
    # The cross-entropy gradient points the opposite way.
    out.gather(1, target[:, None]).sum().backward()

    acts = acts.detach().cpu()
    grads = gcmodel.get_act_grads().detach().cpu()
    heatmaps = _gradcam_heatmaps(acts, grads)

    images_cpu = images.detach().cpu()
    predicted = predicted.cpu().numpy()
    labels = labels.cpu().numpy()

    for i in range(len(heatmaps)):
        # Labels belong to image i, not to the first image of the batch.
        predicted_class = idx_to_class[int(predicted[i])]
        true_class = idx_to_class[int(labels[i])]
        print(f"Predicted: {predicted_class}, Ground Truth: {true_class}")

        # Images were loaded as RGB via PIL, so no BGR->RGB swap is applied.
        image_np = _to_display_image(images_cpu[i])
        heatmap = cv2.resize(heatmaps[i], (image_np.shape[1], image_np.shape[0]))

        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.axis('off')
        ax.imshow(image_np)
        ax.imshow(cmap(heatmap, alpha=0.5))
        fig.savefig(os.path.join(output_folder, f'GradCAM_{image_index}.jpg'))
        image_index += 1
        if SHOW_FIGURES:
            plt.show()
        plt.close(fig)
