<a href="https://colab.research.google.com/github/nlim-uow/aerial_segmentation/blob/main/aerial_imagery_explainability.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Import libraries and install pytorch-grad-cam

In [11]:
from __future__ import print_function
from __future__ import division
import sys
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)
from torchvision.transforms import ToTensor
from scipy.special import softmax
import cv2
from PIL import Image
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
from skimage.measure import label, regionprops

module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
!pip install grad-cam==1.3.7
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

PyTorch Version:  1.12.1+cu113
Torchvision Version:  0.13.1+cu113
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# Helper functions
The following cells define the helper functions for creating and training the classifier, and for visualizing the Grad-CAM explanations of its classifications.

In [12]:
def train_model(model, dataloaders, criterion, optimizer, num_epochs=25, is_inception=False, patience=5):
    """Train ``model`` and return it loaded with its best validation weights.

    Every epoch runs a 'train' phase followed by a 'val' phase. The weights
    with the highest validation accuracy seen so far are remembered and
    restored before returning.

    Args:
        model: Network to optimize. Batches are moved to the device of its
            first parameter.
        dataloaders: Mapping with the keys 'train' and 'val'; each value is
            iterated for ``(inputs, labels)`` batches and must expose
            ``.dataset`` (its length is used to average loss and accuracy).
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        num_epochs: Maximum number of epochs.
        is_inception: When true, the model is expected to return
            ``(outputs, aux_outputs)`` in the train phase; the auxiliary loss
            is added with weight 0.4.
        patience: Training stops once the validation loss has been strictly
            worse than the best one for more than ``patience`` consecutive
            epochs.

    Returns:
        ``(model, val_acc_history)`` where ``val_acc_history`` holds one
        validation accuracy per completed epoch.
    """
    since = time.time()

    # Follow the model's own device instead of depending on a notebook-level
    # `device` global having been defined by some other cell before the call.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    early_stopping = 0
    # float("inf") instead of np.Inf: that alias was removed in NumPy 2.0.
    best_loss = float("inf")
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    if is_inception and phase == 'train':
                        # From https://discuss.pytorch.org/t/how-to-optimize-inception-model-with-auxiliary-classifiers/7958
                        outputs, aux_outputs = model(inputs)
                        loss1 = criterion(outputs, labels)
                        loss2 = criterion(aux_outputs, labels)
                        loss = loss1 + 0.4*loss2
                    else:
                        outputs = model(inputs)
                        loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics (loss is a batch mean, so weight it by the batch size)
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

            if phase == 'val':
                # deep copy the model whenever validation accuracy improves
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())
                val_acc_history.append(epoch_acc)

                # early-stopping bookkeeping is driven by the validation loss
                if epoch_loss > best_loss:
                    early_stopping = early_stopping + 1
                else:
                    best_loss = epoch_loss
                    early_stopping = 0
        if early_stopping > patience:
            break

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, val_acc_history

In [13]:
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze or unfreeze every parameter of ``model``.

    A truthy ``feature_extracting`` turns gradients off for all parameters;
    a falsy value turns gradients on for all of them.
    """
    trainable = not feature_extracting
    for weight in model.parameters():
        weight.requires_grad = trainable
    

In [14]:
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True, visualization=False):
    """Build a torchvision classifier whose output layer has ``num_classes`` units.

    Args:
        model_name: One of "resnet18", "resnet34", "resnet50", "resnet101",
            "resnet152", "alexnet", "vgg" (VGG11 with batch norm),
            "squeezenet" (1.0), "densenet" (121) or "inception" (v3).
        num_classes: Size of the replacement output layer.
        feature_extract: Forwarded to ``set_parameter_requires_grad`` before
            the output layer is replaced, so it affects the backbone only;
            the newly created output layer keeps PyTorch's default
            ``requires_grad=True``.
        use_pretrained: Load the 'IMAGENET1K_V1' weights when true; otherwise
            the network starts from random initialization.
        visualization: Unused; kept so existing call sites keep working.

    Returns:
        ``(model, input_size)`` where ``input_size`` is the expected image
        side length: 299 for "inception", 224 for every other model.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    weights = 'IMAGENET1K_V1' if use_pretrained else None
    input_size = 224

    # Families that share the same head layout are handled by one branch each.
    resnet_builders = {
        "resnet18": models.resnet18,
        "resnet34": models.resnet34,
        "resnet50": models.resnet50,
        "resnet101": models.resnet101,
        "resnet152": models.resnet152,
    }
    classifier6_builders = {
        "alexnet": models.alexnet,
        "vgg": models.vgg11_bn,
    }

    if model_name in resnet_builders:
        model_ft = resnet_builders[model_name](weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)

    elif model_name in classifier6_builders:
        # AlexNet and VGG keep their final linear layer at classifier[6].
        model_ft = classifier6_builders[model_name](weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "squeezenet":
        # SqueezeNet classifies with a 1x1 convolution instead of a linear layer.
        model_ft = models.squeezenet1_0(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(512, num_classes, kernel_size=(1,1), stride=(1,1))
        model_ft.num_classes = num_classes

    elif model_name == "densenet":
        model_ft = models.densenet121(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    elif model_name == "inception":
        # Be careful: expects (299,299) sized images and has an auxiliary output.
        model_ft = models.inception_v3(weights=weights)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Handle the auxiliary net
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        # Handle the primary net
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299

    else:
        # Raise instead of calling exit(): exiting from inside a helper ends
        # the whole notebook session and hides which argument was wrong.
        raise ValueError(f"Invalid model name: {model_name!r}")

    return model_ft, input_size

In [15]:
def display_heatmap(img,model,cam,cls_list,true_label=0,th=0.75,oracle=False):
    """Render an image next to its CAM heatmap with boxes around hot regions.

    The image is resized to 224x224 and classified by ``model``. A CAM is then
    computed for eight flipped/rotated views of the input, each map is mapped
    back to the original orientation, and the maps are fused with an
    element-wise maximum. Connected regions where the fused map exceeds
    ``th`` are outlined and captioned with the explained class and its score.

    Args:
        img: PIL image.
        model: Classifier; moved to CUDA when available (CPU otherwise) for
            the prediction pass.
        cam: Callable used as ``cam(input_tensor=..., targets=...)[0]``;
            presumably a pytorch_grad_cam CAM object built on ``model`` and
            living on the same device -- confirm at the call site.
        cls_list: Class names indexed by class id.
        true_label: Class id to explain when ``oracle`` is true.
        th: Threshold on the fused CAM used to find box regions.
        oracle: When true, explain ``true_label`` (the known ground truth);
            when false, explain the class predicted by ``model``.

    Returns:
        A 448x224 RGB PIL image: the resized input on the left and the
        annotated heatmap overlay on the right.
    """
    img = img.convert('RGB').resize((224, 224))
    rgb = np.array(img)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    x = transform(img).unsqueeze(0)

    # Each entry pairs a transformed input with the numpy operation that maps
    # its CAM (rows, cols) back onto the orientation of the original image.
    x_y = torch.flip(x, [3])   # mirrored along the width axis
    x_x = torch.flip(x, [2])   # mirrored along the height axis
    views = [
        (x, lambda m: m),
        (x_x, lambda m: np.flip(m, 0)),
        (x_y, lambda m: np.flip(m, 1)),
        (torch.flip(x, [2, 3]), lambda m: np.flip(m, (0, 1))),
        (torch.rot90(x, 1, [2, 3]), lambda m: np.rot90(m, 3, (0, 1))),
        (torch.rot90(x, 3, [2, 3]), lambda m: np.rot90(m, 1, (0, 1))),
        (torch.rot90(x_x, 1, [2, 3]), lambda m: np.flip(np.rot90(m, 3, (0, 1)), 0)),
        (torch.rot90(x_y, 1, [2, 3]), lambda m: np.flip(np.rot90(m, 3, (0, 1)), 1)),
    ]

    # Fall back to the CPU when no GPU is present instead of failing in .cuda().
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    predict = model.to(device)(x.to(device)).detach().cpu().numpy()
    pred = np.argmax(predict)

    # Oracle mode explains the known ground-truth class; otherwise explain
    # whatever the network predicted.
    target_idx = true_label if oracle else pred
    predictedClass = [ClassifierOutputTarget(target_idx)]
    predCls = cls_list[target_idx]
    predScore = softmax(predict)[0, target_idx]
    if predScore<0.1:
       print(f"The confidence score of the classifier is less than 0.1 ({predScore:0.4}). It is possible that the network was unable to find the land feature in the image. ")
       print(f"The following heatmap and bounding box assumes that the network correctly identified the land feature.")

    # Test-time augmentation: keep, per pixel, the strongest activation seen
    # in any of the re-aligned views.
    aligned_cams = [undo(cam(input_tensor=view, targets=predictedClass)[0]) for view, undo in views]
    fused_cam = np.maximum.reduce(aligned_cams)

    cam_on_image = show_cam_on_image(rgb / 255, fused_cam, True, cv2.COLORMAP_HOT)

    # Binarize explicitly so region labelling does not depend on how `label`
    # treats the fractional values at or below the threshold.
    mask = (fused_cam > th).astype(np.uint8)
    box_candidates = regionprops(label(mask))

    # Caption with the class that the score and the heatmap actually refer to.
    caption = f"{predCls}-{predScore:0.4}"
    for box_candidate in box_candidates:
        (y1,x1,y2,x2)=box_candidate.bbox
        cv2.rectangle(cam_on_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(cam_on_image, caption, (x1+2, y1+25), cv2.FONT_HERSHEY_SIMPLEX, 0.33, (255, 255, 255), 1)

    visualization = Image.fromarray(cam_on_image.astype(np.uint8))
    new_im = Image.new('RGB', (448,224))
    new_im.paste(img, (0,0))
    new_im.paste(visualization, (224,0))
    return new_im

# Downloading the dataset
In the following cells we download the dataset from https://datasets.cms.waikato.ac.nz/taiao/data/waikato_aerial_imagery_2017 
The tar file contains 13 classes of land types as defined by LCDB4.0. There are 666 images per class in the training set and 334 images per class in the validation/hold-out set.

The class names are taken from the folder names.

In [16]:
#Download Dataset from dataset.cms.waikato.ac.nz
!mkdir results
!mkdir -p data/aerial_imagery/classification
!wget https://datasets.cms.waikato.ac.nz/taiao/data/waikato_aerial_imagery_2017/classification.tar


--2022-11-30 21:02:57--  https://datasets.cms.waikato.ac.nz/taiao/data/waikato_aerial_imagery_2017/classification.tar
Resolving datasets.cms.waikato.ac.nz (datasets.cms.waikato.ac.nz)... 130.217.218.32
Connecting to datasets.cms.waikato.ac.nz (datasets.cms.waikato.ac.nz)|130.217.218.32|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1105899520 (1.0G) [application/x-tar]
Saving to: ‘classification.tar’


2022-11-30 21:04:16 (13.6 MB/s) - ‘classification.tar’ saved [1105899520/1105899520]



In [17]:
# Unpack the archive under /content/data/aerial_imagery/, the parent of the
# `data_dir` used by the training and visualization cells.
!tar -xf classification.tar --directory /content/data/aerial_imagery/

In [18]:
# Work from the results directory: the checkpoints below are saved and
# re-loaded through relative file names.
%cd /content/results

/content/results


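# Training the model
The following code initializes the model and trains it with transfer learning: first only the new classification layer is trained, then the whole network is fine-tuned with a smaller learning rate.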

In [19]:
data_dir = "/content/data/aerial_imagery/classification"
# Names accepted by initialize_model: resnet18, resnet34, resnet50, resnet101,
# resnet152, alexnet, vgg, squeezenet, densenet, inception
model_names = ['resnet18']

batch_size = 32
input_size = 224
pretrain_epochs = 10  # raise this if the model accuracy is insufficient
finetune_epochs = 30  # raise this if the model accuracy is insufficient
runCount = 1
feature_extract = True  # stage 1 trains only the replaced output layer

# Use the GPU when one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Training images get random flips as augmentation; validation images are only
# centre-cropped. Both are normalized with the ImageNet statistics.
imagenet_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
    'val': transforms.Compose([
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        imagenet_normalize,
    ]),
}

print("Initializing Datasets and Dataloaders...")

# One ImageFolder dataset and one shuffling DataLoader per split
image_datasets = {}
dataloaders_dict = {}
for split in ('train', 'val'):
    image_datasets[split] = datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    dataloaders_dict[split] = torch.utils.data.DataLoader(
        image_datasets[split], batch_size=batch_size, shuffle=True, num_workers=4)

clsList, clsIdx = image_datasets['train'].find_classes(f'{data_dir}/train')
num_classes = len(clsList)

for runNo in range(runCount):
    for model_name in model_names:
        model_ft, input_size = initialize_model(model_name, num_classes, feature_extract, use_pretrained=True)
        model_ft = model_ft.to(device)

        # Report the trainable parameters; when feature extracting, only those
        # are handed to the stage-1 optimizer.
        trainable = [(name, param) for name, param in model_ft.named_parameters() if param.requires_grad == True]
        print("Params to learn:")
        for name, _ in trainable:
            print("\t", name)
        if feature_extract:
            params_to_update = [param for _, param in trainable]
        else:
            params_to_update = model_ft.parameters()

        optimizer_pt = optim.Adam(params_to_update, lr=0.0003)
        optimizer_ft = optim.Adam(model_ft.parameters(), lr=0.00001)
        criterion = nn.CrossEntropyLoss()
        uses_aux_output = (model_name == "inception")

        # Stage 1: train with the requires_grad flags chosen by initialize_model
        model_ft, hist = train_model(
            model_ft, dataloaders_dict, criterion, optimizer_pt,
            num_epochs=pretrain_epochs, is_inception=uses_aux_output)

        # Stage 2: make every parameter trainable and fine-tune with the
        # smaller learning rate
        set_parameter_requires_grad(model_ft, False)
        model_ft, hist = train_model(
            model_ft, dataloaders_dict, criterion, optimizer_ft,
            num_epochs=finetune_epochs, is_inception=uses_aux_output)

        torch.save(model_ft.state_dict(), f"aerial_small_auged_noshift_norot_{model_name}_{runNo}.pth")

Initializing Datasets and Dataloaders...
Params to learn:
	 fc.weight
	 fc.bias
Epoch 0/9
----------
train Loss: 2.2517 Acc: 0.2471
val Loss: 2.0784 Acc: 0.2950

Epoch 1/9
----------
train Loss: 1.9597 Acc: 0.3399
val Loss: 1.9455 Acc: 0.3349

Epoch 2/9
----------
train Loss: 1.8461 Acc: 0.3831
val Loss: 1.8988 Acc: 0.3526

Epoch 3/9
----------
train Loss: 1.7940 Acc: 0.3928
val Loss: 1.8606 Acc: 0.3690

Epoch 4/9
----------
train Loss: 1.7626 Acc: 0.4029
val Loss: 1.8498 Acc: 0.3694

Epoch 5/9
----------
train Loss: 1.7344 Acc: 0.4161
val Loss: 1.8322 Acc: 0.3809

Epoch 6/9
----------
train Loss: 1.7198 Acc: 0.4166
val Loss: 1.8245 Acc: 0.3816

Epoch 7/9
----------
train Loss: 1.7012 Acc: 0.4271
val Loss: 1.8243 Acc: 0.3819

Epoch 8/9
----------
train Loss: 1.6928 Acc: 0.4353
val Loss: 1.8035 Acc: 0.3943

Epoch 9/9
----------
train Loss: 1.6747 Acc: 0.4351
val Loss: 1.8097 Acc: 0.3867

Training complete in 7m 22s
Best val Acc: 0.394288
Epoch 0/29
----------
train Loss: 1.6220 Acc: 0.4

# Visualize the Grad-CAM explanation
The following block of code visualizes the Grad-CAM activations on the validation/hold-out set.

In [None]:
data_dir = "/content/data/aerial_imagery/classification"
# Only ResNet variants work in this cell: the Grad-CAM target layer below is `layer4`.
model_names=['resnet18']
runCount=1
# Use the GPU when one is available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
val_data=datasets.ImageFolder(os.path.join(data_dir, 'val'))
train_datasets = datasets.ImageFolder(os.path.join(data_dir, 'train'))

clsList,clsIdx=train_datasets.find_classes(os.path.join(data_dir, 'train'))
num_classes=len(clsList)
for model_name in model_names:
  for runNo in range(runCount):
    # The weights come from the checkpoint, so ImageNet weights are not requested.
    model_infer,_ = initialize_model(model_name, num_classes, feature_extract=False, use_pretrained=False)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    model_infer.load_state_dict(torch.load(f'aerial_small_auged_noshift_norot_{model_name}_{runNo}.pth', map_location=device))
    model_infer.eval()
    folder_name = f"/content/results/aerial-imagery-visualization_{runNo}" ## Output Directory
    Path(folder_name).mkdir(parents=True, exist_ok=True)
    target_layer = [model_infer.layer4[-1]]
    cam = GradCAM(model=model_infer, target_layers=target_layer,use_cuda=torch.cuda.is_available())
    # One output sub-folder per class
    for classes in clsList:
      Path(os.path.join(folder_name, classes)).mkdir(parents=True, exist_ok=True)

    # Iterate over the whole validation set, whatever its size.
    for i in range(len(val_data)):
      (img,true_label)=val_data[i]
      visualization_im=display_heatmap(img,model_infer,cam,clsList,true_label,th=0.75,oracle=True) ## Oracle assumes we know the true label rather than using the predicted label to report the GradCam Activation 
      visualization_im.save(f'{folder_name}/{clsList[true_label]}/{i}.png')