# An overview of generating bounding boxes from attention maps 

We first import the required libraries and modules.

In [1]:
import os
import sys
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch

import cv2
from PIL import Image
from torchvision import transforms
from torchvision import models
from torchvision.datasets import VOCSegmentation, VOCDetection
from skimage.measure import label, regionprops
import torch.nn as nn
from torchvision.models import resnet50,resnet18, resnet34, resnet101, resnet152

# Put the parent directory on the import path (once) so that packages living
# next to this notebook's folder can be imported.
module_path = os.path.abspath('..')
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

from activation_maps import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM
from activation_maps.utils.image import show_cam_on_image

In [2]:
# Lookup from integer class index to class name. The index of a name is its
# position in the tuple below (0-based), so the order must not be changed.
imagenet_label = dict(enumerate((
    'broadleaved_indigenous_hardwood',
    'deciduous_hardwood',
    'grose_broom',
    'harvested_forest',
    'herbaceous_freshwater_vege',
    'high_producing_grassland',
    'indigenous_forest',
    'lake_pond',
    'low_producing_grassland',
    'manuka_kanuka',
    'shortrotation_cropland',
    'urban_build_up',
    'urban_parkland',
)))

We import the remaining helper libraries and load the validation split of the dataset.

In [3]:
import IPython.display as display
import scipy as scipy
import random
import pandas as pd
from torchvision import datasets, models, transforms
from shutil import copyfile
from scipy.special import softmax
from sklearn.metrics import confusion_matrix


In [4]:
# Root folder of the classification dataset. The validation split is read
# from its 'val' sub-directory; no transform is passed to ImageFolder.
data_dir = "../data/aerial_imagery2/classification"
val_data = datasets.ImageFolder(root=os.path.join(data_dir, 'val'))

In [5]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a backbone whose classification head has ``num_classes`` outputs.

    Args:
        model_name: One of "resnet18", "resnet34", "resnet50", "resnet101",
            "resnet152", "alexnet", "vgg", "squeezenet", "densenet",
            "inception" or "vits8".
        num_classes: Number of outputs of the replacement classification layer.
        feature_extract: Accepted for interface compatibility only; this
            function does not read it.
        use_pretrained: Forwarded as ``pretrained=`` to the torchvision
            constructor. Not used for "vits8", which is always fetched
            through ``torch.hub.load``.

    Returns:
        Tuple ``(model, input_size)`` where ``input_size`` is 299 for
        "inception" and 224 for every other supported name.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    resnet_variants = ("resnet18", "resnet34", "resnet50", "resnet101", "resnet152")
    supported = resnet_variants + (
        "alexnet", "vgg", "squeezenet", "densenet", "inception", "vits8",
    )
    input_size = 224

    if model_name in resnet_variants:
        # All ResNet depths expose the head as `.fc`; the constructor is
        # looked up by name on torchvision.models.
        model_ft = getattr(models, model_name)(pretrained=use_pretrained)
        model_ft.fc = nn.Linear(model_ft.fc.in_features, num_classes)

    elif model_name == "alexnet":
        model_ft = models.alexnet(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "vgg":
        # VGG11 with batch normalisation.
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "squeezenet":
        # SqueezeNet classifies with a 1x1 convolution instead of a Linear layer.
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes

    elif model_name == "densenet":
        model_ft = models.densenet121(pretrained=use_pretrained)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    elif model_name == "inception":
        # Inception v3 expects (299, 299) inputs and has an auxiliary output,
        # so both the auxiliary and the primary heads are replaced.
        model_ft = models.inception_v3(pretrained=use_pretrained)
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    elif model_name == "vits8":
        # DINO ViT-S/8 backbone followed by a linear head on its 384-d output.
        vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
        model_ft = nn.Sequential(vits8, nn.Linear(384, num_classes))

    else:
        # Raise instead of print + exit(): exit() would terminate the whole
        # interpreter / notebook kernel and cannot be handled by the caller.
        raise ValueError(
            f"Invalid model name {model_name!r}; expected one of: {', '.join(supported)}"
        )

    return model_ft, input_size


In [6]:
def cam_from_model(model_path):
    """Load a 13-class ResNet-18 checkpoint and wrap it in a GradCAM object.

    Args:
        model_path: Path of a state-dict file accepted by ``torch.load``.

    Returns:
        A ``GradCAM`` instance targeting the last block of ``layer4`` of the
        loaded model, which has been put into eval mode.
    """
    # No pretrained weights are requested: load_state_dict (strict by default)
    # replaces every parameter and buffer, so downloading them would be wasted.
    model0 = resnet18(pretrained=False)
    model0.fc = nn.Linear(model0.fc.in_features, 13)

    # map_location="cpu" lets checkpoints that were saved from a GPU be read
    # on a machine without one; the values are copied into model0 either way.
    state_dict = torch.load(model_path, map_location="cpu")
    model0.load_state_dict(state_dict)
    model0.eval()

    target_layer = model0.layer4[-1]
    # Ask for CUDA only when it is actually present instead of hard-coding True.
    # NOTE(review): assumes GradCAM moves the model to the GPU itself when
    # use_cuda is set - confirm against activation_maps.
    return GradCAM(
        model=model0,
        target_layer=target_layer,
        use_cuda=torch.cuda.is_available(),
    )

In [7]:
def initialize_cam(cam_name, model, target_layer, use_cuda=False):
    """Instantiate the CAM implementation selected by ``cam_name``.

    Args:
        cam_name: One of "GradCAM", "ScoreCAM", "GradCAMPlusPlus",
            "AblationCAM" or "XGradCAM" (case-sensitive).
        model: Forwarded to the CAM constructor as ``model``.
        target_layer: Forwarded to the CAM constructor as ``target_layer``.
        use_cuda: Forwarded to the CAM constructor as ``use_cuda``.

    Returns:
        The constructed CAM object.

    Raises:
        ValueError: If ``cam_name`` is not one of the supported names.
    """
    if cam_name == "GradCAM":
        cam_class = GradCAM
    elif cam_name == "ScoreCAM":
        cam_class = ScoreCAM
    elif cam_name == "GradCAMPlusPlus":
        cam_class = GradCAMPlusPlus
    elif cam_name == "AblationCAM":
        cam_class = AblationCAM
    elif cam_name == "XGradCAM":
        cam_class = XGradCAM
    else:
        # Raise instead of print + exit(): exit() would terminate the whole
        # interpreter / notebook kernel and cannot be handled by the caller.
        raise ValueError(
            f"Invalid CAM method {cam_name!r}; expected one of: "
            "GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM"
        )
    return cam_class(model=model, target_layer=target_layer, use_cuda=use_cuda)

In [8]:
def gscam_calc(camList,img,mylabel,mode=3):
    """Average CAMs over an ensemble of CAM objects and augmented views of ``img``.

    Every CAM object is evaluated on the original image and, depending on
    ``mode``, on augmented copies of it. Each resulting map is mapped back to
    the coordinate frame of the original image *before* it is accumulated, so
    that all maps being averaged are spatially aligned:

    * maps of flipped / rotated inputs are flipped / rotated back;
    * maps of shifted inputs are cropped to the window the shrunken image
      actually occupies for that particular shift, then resized to 224x224.

    Args:
        camList: Non-empty sequence of callables invoked as
            ``cam(input_tensor=..., target_category=...)``.
            NOTE(review): each call is assumed to return a 2-D (224, 224)
            NumPy array whose axes follow the height/width axes of the input
            tensor - the fixed-size accumulator and the crop indexing depend
            on it; confirm against activation_maps.
        img: PIL image (``.resize`` is called on it); assumed to be 224x224.
        mylabel: Passed through as ``target_category``.
        mode: Augmentation level. 0 = original only; >0 adds the three flips;
            >1 adds four 90-degree rotations (two of them of flipped inputs);
            >2 adds eight shifted copies of the image shrunk to 200x200.

    Returns:
        (224, 224) array: the mean of all aligned maps.

    Raises:
        ValueError: If ``camList`` is empty (the mean would be 0/0).

    The names ``IOps``, ``ImChOps`` and ``sktransform`` are read from module
    scope and must be imported before this function is called with mode > 2.
    """
    if len(camList) == 0:
        raise ValueError("camList must contain at least one CAM object")

    full_size = 224    # side length of the network input and of the result
    inner_size = 200   # side length of the shrunken image used for shifting
    border = 12        # padding that brings inner_size back up to full_size

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    x = transform(img).unsqueeze(0)

    # Each entry pairs an input tensor with the function that maps the CAM
    # computed on it back into the original image frame.
    views = [(x, lambda cam: cam)]

    if mode > 0:
        # Tensor dims 2 / 3 (height / width) correspond to CAM axes 0 / 1.
        views += [
            (torch.flip(x, [2]), lambda cam: np.flip(cam, axis=0)),
            (torch.flip(x, [3]), lambda cam: np.flip(cam, axis=1)),
            (torch.flip(x, [2, 3]), lambda cam: np.flip(cam, axis=(0, 1))),
        ]

    if mode > 1:
        x_x = torch.flip(x, [2])
        x_y = torch.flip(x, [3])
        # torch.rot90 and np.rot90 share the same direction convention, so a
        # rotation by k is undone by a rotation by -k; for the flipped inputs
        # the rotation is undone first and the flip second.
        views += [
            (torch.rot90(x, 1, [2, 3]), lambda cam: np.rot90(cam, -1)),
            (torch.rot90(x, 3, [2, 3]), lambda cam: np.rot90(cam, -3)),
            (torch.rot90(x_x, 1, [2, 3]),
             lambda cam: np.flip(np.rot90(cam, -1), axis=0)),
            (torch.rot90(x_y, 1, [2, 3]),
             lambda cam: np.flip(np.rot90(cam, -1), axis=1)),
        ]

    if mode > 2:
        # Shrink to 200x200, pad back to 224x224, then move the content by
        # (dx, dy) pixels. After the move the content occupies rows
        # [border+dy, border+dy+200) and columns [border+dx, border+dx+200).
        padded = IOps.expand(img.resize((inner_size, inner_size)), border)
        shifts = [(12, 12), (0, 12), (-12, 12), (12, 0),
                  (-12, 0), (12, -12), (0, -12), (-12, -12)]
        for dx, dy in shifts:
            shifted = ImChOps.offset(padded, dx, dy)
            top, left = border + dy, border + dx

            # Defaults bind the current window; a plain closure would see
            # only the values of the last loop iteration.
            def crop_and_resize(cam, top=top, left=left):
                window = cam[top:top + inner_size, left:left + inner_size]
                return sktransform.resize(window, (full_size, full_size))

            views.append((transform(shifted).unsqueeze(0), crop_and_resize))

    grayscale_cam = np.zeros((full_size, full_size))
    for thisCam in camList:
        for input_tensor, to_original_frame in views:
            cam_map = thisCam(input_tensor=input_tensor, target_category=mylabel)
            grayscale_cam = grayscale_cam + to_original_frame(cam_map)

    return grayscale_cam / (len(camList) * len(views))

In [9]:
# Numeric suffixes of the checkpoint files, in the order in which the
# corresponding CAM objects are placed in masterCamList.
_checkpoint_order = (
    13, 1, 29, 8, 6, 23, 24, 28, 12, 21,
    14, 20, 22, 4, 0, 16, 2, 7, 5, 25,
    17, 27, 19, 15, 3, 26, 9, 10, 18, 11,
)
masterCamList = [
    cam_from_model(f'aerial_small_auged_new_resnet18_{suffix}.pth')
    for suffix in _checkpoint_order
]
# Also bind every entry to its individual camN name.
(cam0, cam1, cam2, cam3, cam4, cam5, cam6, cam7, cam8, cam9,
 cam10, cam11, cam12, cam13, cam14, cam15, cam16, cam17, cam18, cam19,
 cam20, cam21, cam22, cam23, cam24, cam25, cam26, cam27, cam28, cam29) = masterCamList

In [10]:
import skimage.transform as sktransform 
import PIL.ImageOps as IOps
import PIL.ImageChops as ImChOps
# Sweep over ensemble size (first 1..30 entries of masterCamList) and
# augmentation mode (0..3). For every validation image the averaged CAM is
# thresholded at each value of the threshold sweep and compared pixel-wise
# with the ground-truth mask; per image, one precision/recall CSV and one
# TPR/FPR CSV are written into the folder of the current configuration.
for ensSize in range(1,31):
    ensemble = masterCamList[0:ensSize]
    for mode in range(0,4):
        folder_name = f"../results/aerial-imagery-resnet18-GradCam-batch-all-mode_{mode}-ens_{ensSize}"
        Path(folder_name).mkdir(parents=True, exist_ok=True)
        print(folder_name)
        # NOTE(review): 4342 is hard-coded - presumably the number of
        # validation images; confirm against len(val_data.imgs).
        for imgNo in range(0,4342):
            # The mask path is derived from the image path by string substitution.
            maskPath=val_data.imgs[imgNo][0].replace('classification','segmentation').replace('/val','')
            img, mylabel = val_data[imgNo]
            grayscale_cam = gscam_calc(ensemble, img, mylabel, mode)

            gtmaskMap = cv2.imread(maskPath)
            if gtmaskMap is None:
                # cv2.imread signals an unreadable file by returning None.
                raise FileNotFoundError(f"Could not read ground-truth mask: {maskPath}")
            gtmaskMap = cv2.cvtColor(gtmaskMap, cv2.COLOR_BGR2GRAY)
            # Clear the one-pixel frame of the mask (indices 0 and 223), then
            # collapse every value above 1 to 1 so the mask holds only 0/1.
            gtmaskMap[:, 0] = 0
            gtmaskMap[:, 223] = 0
            gtmaskMap[0, :] = 0
            gtmaskMap[223, :] = 0
            gtmaskMap[gtmaskMap>1] = 1

            # Ground-truth quantities do not depend on the threshold.
            gt_positive = gtmaskMap == 1
            gt_negative = gtmaskMap == 0
            gt_positive_count = np.sum(gt_positive)

            precision_recall_list = []
            tpr_fpr_list = []
            for th in np.arange(1.00, -0.01, -0.01):
                # Zero everything below the threshold, then binarise what is left.
                thresholded_cam = grayscale_cam.copy()
                thresholded_cam[grayscale_cam < th] = 0
                mask = thresholded_cam.copy()
                mask[mask > 0] = 1

                tp = np.sum(np.logical_and(gt_positive, mask==1))
                fn = gt_positive_count - tp
                fp = np.sum(mask==1) - tp
                tn = np.sum(np.logical_and(gt_negative, mask==0))

                # NOTE(review): unlike precision and FPR this division is not
                # guarded; with an all-zero ground-truth mask it evaluates
                # 0/0 (NaN plus a NumPy warning). Left as is so the CSV
                # contents stay unchanged - confirm what downstream expects.
                recall = tp/(tp+fn)
                tpr = recall
                precision = tp/(tp+fp) if (tp+fp) > 0 else 0
                fpr = fp/(fp+tn) if (fp+tn) > 0 else 0

                precision_recall_list.append([precision,recall])
                tpr_fpr_list.append([tpr,fpr])

            pr_df = pd.DataFrame(precision_recall_list,columns=['Precision','Recall'])
            tf_df = pd.DataFrame(tpr_fpr_list,columns=['TPR','FPR'])
            csv_file = f"{imgNo:04d}_{mylabel:02d}"
            pr_df.to_csv(f"{folder_name}/pr_{csv_file}.csv", index=False)
            tf_df.to_csv(f"{folder_name}/tf_{csv_file}.csv", index=False)
        


../results/aerial-imagery-resnet18-GradCam-batch-all-mode_0-ens_1


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
  cam = cam / np.max(cam)


../results/aerial-imagery-resnet18-GradCam-batch-all-mode_1-ens_1
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_2-ens_1
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_3-ens_1
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_0-ens_2
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_1-ens_2
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_2-ens_2
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_3-ens_2
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_0-ens_3
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_1-ens_3
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_2-ens_3
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_3-ens_3
../results/aerial-imagery-resnet18-GradCam-batch-all-mode_0-ens_4


KeyboardInterrupt: 

In [None]:
# Inspect the dimensions of the averaged CAM currently bound to grayscale_cam.
grayscale_cam.shape

In [None]:
#x.shape

In [None]:
# Print the image object currently bound to img.
print(img)