In [1]:
from google.colab import drive 
drive.mount("/content/gdrive", force_remount=True)

Mounted at /content/gdrive


In [2]:
import os
os.chdir("/content/gdrive/MyDrive/explanations-for-computer-vision/")

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor, fasterrcnn_resnet50_fpn
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# our dataset has two classes only - background and person
num_classes = 2

In [3]:
# load an instance segmentation model pre-trained pre-trained on COCO
# model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True, num_classes=num_classes)
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)

# get number of input features for the classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# replace the pre-trained head with a new one
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

# load model checkpoint
path = os.path.join(os.getcwd(), "./checkpoints/faster_rcnn_10_epochs.ckpt")
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['model_state_dict'])

model.to(device) 

FasterRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(

In [4]:
# from src.sodexplainer import SODExplainer
# explainer_sodex = SODExplainer(load_from='./checkpoints/faster_rcnn_10_epochs.ckpt')
# # prob = explainer.get_class_probability(dataset_test[0])

import torch
import torchvision
import numpy as np

from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

from src.utils import jaccard

class SODExplainer:
    def __init__(
        self, detector='FasterRCNN',load_from=None
    ):
        """[summary]
        Args:
            detector (string): Object detector. Defaults to "FasterRCNN".
            load_from: Model path. Defaults to None.
        """
        self.model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) # pretrained model         
        # get number of input features for the classifier
        in_features = self.model.roi_heads.box_predictor.cls_score.in_features
        # replace the pre-trained head with a new one
        self.model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes=2)
        
        if load_from is not None:
            # load checkpoint 
            self.model.load_state_dict(torch.load(load_from, map_location=torch.device('cpu'))['model_state_dict'])
                   
            
    def get_class_probability(self, data_obj):
      def get_probabilities(image_as_array):
        img, target = data_obj # ground truth
          
        self.model.eval() # set the module in evaluation mode
        boxes = self.model(img.unsqueeze(0))[0]['boxes']
        scores = self.model(img.unsqueeze(0))[0]['scores']
        
        if len(boxes) == 0: 
            # if there's no object detected
            prob = 0
            print("No object detected!")
        else:
            ious = jaccard(target['boxes'],boxes) # ious with shape (num_objs, num_boxes)
            ious = ious[ious > 0.4].unsqueeze(1) 
            if len(ious) == 0: # No score above the threshold
                prob = 0
            else:
                obj_idx, box_idx = np.unravel_index(torch.argmax(ious), ious.shape) # retrieve argmax-indices in 2d
                prob = scores[box_idx] 
        probabilities = [prob, 1 - prob]      
        return np.array(probabilities)

      return get_probabilities

In [5]:
 explainer_sodex = SODExplainer(load_from='./checkpoints/faster_rcnn_10_epochs.ckpt')

In [6]:
import torch
from detection.pennfudan_dataset import PennFudanDataset, get_transform

# use our dataset and defined transformations
dataset = PennFudanDataset('./PennFudanPed', get_transform(train=True))
dataset_test = PennFudanDataset('./PennFudanPed', get_transform(train=False))
# changing to array


# split the dataset in train and test set
indices = torch.randperm(len(dataset)).tolist()
dataset = torch.utils.data.Subset(dataset, indices[:-50])
dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:])

In [7]:
# import cv2
# test_image = os.path.join(os.getcwd(), "./PennFudanPed/PNGImages/FudanPed00052.png")

# def image_as_array(image):
#     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#     return image

In [8]:
!pip install lime



In [12]:
from lime.lime_image import LimeImageExplainer
import numpy as np

image_test = dataset_test[0][0].detach().numpy().reshape(dataset_test[0][0].shape[1],dataset_test[0][0].shape[2],3)
image_test = image_test.astype('double')
print(type(image_test))
# image_test = image_test.as_type('double')

def get_class_probability_explanation(test_image):
      explainer = LimeImageExplainer(verbose=True)
      # self.logger.info("Explaining object: ")
      explanation = explainer.explain_instance(
          image= image_test,
          classifier_fn=explainer_sodex.get_class_probability(dataset_test[0]),
          num_samples=100)
      return explanation

<class 'numpy.ndarray'>


In [13]:
dataset[0][0].shape
print(dataset[0][0].shape[1])

378


In [14]:
explanation = get_class_probability_explanation(image_test)

  0%|          | 0/100 [00:00<?, ?it/s]

RuntimeError: ignored