In [None]:
import os
import torch
import torchvision
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import numpy as np
import csv
import pandas as pd
import seaborn as sns
#import matplotlib.pylab as plt
import cv2
import math
from tqdm import tqdm


from kiwissenbase.data import dataloaders
from kiwissenbase.data.datasets import CaltechPedestrian, CityPersons
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import numpy as np

In [None]:
# Inverse of a torchvision `Normalize` step, used to turn normalised tensors back into
# displayable images: stage 1 multiplies every channel by the std, stage 2 adds the mean.
# NOTE(review): these are the usual ImageNet statistics, whereas the loader config in this
# notebook passes normal_mean/normal_std of 0.5 — confirm which normalisation the dataset applies.
channel_mean = [0.485, 0.456, 0.406]
channel_std = [0.229, 0.224, 0.225]
undo_std = torchvision.transforms.Normalize(mean=[0., 0., 0.],
                                            std=[1 / s for s in channel_std])
undo_mean = torchvision.transforms.Normalize(mean=[-m for m in channel_mean],
                                             std=[1., 1., 1.])
invTrans = torchvision.transforms.Compose([undo_std, undo_mean])

## Create data loaders

In [None]:
# Keyword arguments forwarded to the data loader constructor in the next cell.
# NOTE(review): root_dir is an absolute, machine-specific path — adjust for your environment.
data_loader_args = dict(
    root_dir="/data/anna/data/caltech_dataset/",
    batch_size=8,
    validation_batch_size=4,
    num_workers=0,
    pin_memory=True,
    # collate_fn=collate_fn,
    normal_mean=(0.5, 0.5, 0.5),
    normal_std=(0.5, 0.5, 0.5),
    different_size_target=True,
    subset="annotated-pedestrians",
    # Dotted reference to the callable used as target transform.
    target_transform=dict(
        module="kiwissenbase.models.object_detection",
        class_name="FasterRCNN",
        method_name="target_transform",
    ),
)

In [None]:
# Build the data loader on the CPU; later cells iterate `loader.test` and index `loader.test.dataset`.
loader = dataloaders.CaltechPedastrianDataLoader(device="cpu", **data_loader_args)

## Functions for evaluation

In [None]:
def match_bbs(target_boxes, predicted_boxes, iou_thres=0):
    """Greedily pair target boxes with predicted boxes by descending IoU.

    Args:
        target_boxes: sequence of ground-truth boxes (stacked with ``torch.stack``
            when both sequences are non-empty).
        predicted_boxes: sequence of predicted boxes, same layout.
        iou_thres: pairs whose IoU is below this value are never matched.

    Returns:
        List of dicts with keys ``target_index``, ``predicted_index`` and ``iou``.
        Matched pairs come first (highest IoU first), followed by unmatched targets
        (``predicted_index == -1``) and unmatched predictions (``target_index == -1``),
        both with ``iou`` 0.
    """
    matches = []

    if len(target_boxes) > 0 and len(predicted_boxes) > 0:
        target_boxes = torch.stack(target_boxes)
        predicted_boxes = torch.stack(predicted_boxes)

        iou_matrix = torchvision.ops.box_iou(target_boxes, predicted_boxes)
        # Pairs below the threshold are zeroed so they can never be selected.
        iou_matrix[iou_matrix < iou_thres] = 0

        # Repeatedly take the best remaining pair, then retire its row and column
        # so each target and each prediction is used at most once.
        while (iou_matrix > 0).any():
            best_iou = torch.max(iou_matrix)
            row, col = (iou_matrix == best_iou).nonzero()[0].tolist()
            matches.append({"target_index": row, "predicted_index": col, "iou": best_iou.item()})
            iou_matrix[row, :] = 0
            iou_matrix[:, col] = 0

    used_targets = {entry["target_index"] for entry in matches}
    matches += [{"target_index": idx, "predicted_index": -1, "iou": 0}
                for idx in range(len(target_boxes)) if idx not in used_targets]

    used_predictions = {entry["predicted_index"] for entry in matches}
    matches += [{"target_index": -1, "predicted_index": idx, "iou": 0}
                for idx in range(len(predicted_boxes)) if idx not in used_predictions]
    return matches
    
    
def evaluate_prediction(target, prediction, score_thres=0, iou_thres=0):
    """Count TP/FP/FN for one image and derive precision and recall.

    Only boxes with label 1 are considered. Predictions are additionally filtered
    to those with ``score >= score_thres`` before being matched against the
    targets with ``match_bbs``.

    Args:
        target: dict with ``"boxes"`` and ``"labels"``.
        prediction: dict with ``"boxes"``, ``"labels"`` and ``"scores"``.
        score_thres: minimum score for a prediction to be kept.
        iou_thres: minimum IoU for a target/prediction pair to count as a match.

    Returns:
        Dict with keys ``tps``, ``fps``, ``fns``, ``precision`` and ``recall``.
        Precision (recall) is 0 when there is nothing to divide by.
    """
    gt_boxes = [box for box, label in zip(target["boxes"], target["labels"]) if label==1]
    # filter the predictions based on a minimum score threshold
    kept_boxes = [
        box
        for box, label, score in zip(prediction["boxes"], prediction["labels"], prediction["scores"])
        if label==1 and score>=score_thres
    ]

    matches = match_bbs(gt_boxes, kept_boxes, iou_thres)

    # A match without a target is a false positive, one without a prediction a
    # false negative, everything else a true positive.
    fps = sum(1 for m in matches if m["target_index"] == -1)
    fns = sum(1 for m in matches if m["target_index"] != -1 and m["predicted_index"] == -1)
    tps = len(matches) - fps - fns

    precision = tps/(tps + fps) if tps + fps > 0 else 0
    recall = tps/(tps + fns) if tps + fns > 0 else 0
    return {"tps": tps, "fps": fps, "fns": fns, "precision": precision, "recall": recall}

## Select paths for trained model and netdissect results (run one of the following blocks)

In [None]:
# Option A: checkpoint and NetDissect result folder for the Caltech model (per the file names).
# NOTE(review): absolute, machine-specific paths — adjust for your environment.
path_to_saved_model = "/data/anna/pedestrian_detection_models/faster_rcnn_tuned_Caltech.pth"

path_to_result = "/data/anna/results/pytorch_fasterrcnn_resnet50_fpn_caltech_backbone.body.layer4[1].conv2/"
# Attribute path, relative to `model`, of the layer that gets a forward hook further down.
layer_dissected = "backbone.body.layer4[1].conv2"

In [None]:
# Option B: checkpoint and NetDissect result folder for the Citypersons model (per the file names).
# This cell assigns the same three names as the previous one, so when both cells are
# executed in order (e.g. "Run All") these values are the ones that take effect.
# NOTE(review): absolute, machine-specific paths — adjust for your environment.
path_to_saved_model = "/data/anna/pedestrian_detection_models/faster_rcnn_tuned_Citypersons_all_single_person_classes_positive.pth"
path_to_result = "/data/anna/results/pytorch_fasterrcnn_resnet50_fpn_citypersons/"
layer_dissected = "backbone.body.layer4[1].conv2"

## Select dataset on which to check the results

In [None]:
# NOTE(review): no other cell in this notebook reads `dataset`; the loader above is always
# built with CaltechPedastrianDataLoader, so changing this value currently has no effect.
dataset = "caltech" # caltech or citypersons

## setup model, netdissect results and data

### load and test saved model

In [None]:
# Map all tensors to the CPU while loading so a checkpoint saved from a GPU also loads on
# a machine without CUDA; the model is moved to the selected device in a later cell.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(path_to_saved_model, map_location="cpu")

In [None]:
## load model architecture and adapt output
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Build the torchvision Faster R-CNN with its default arguments; the fine-tuned
# weights are loaded from `checkpoint` in the next cell.
# NOTE(review): depending on the torchvision version the defaults may download pretrained
# backbone weights that are overwritten right afterwards — confirm and disable if so.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn()
# Swap the box predictor for a 2-output head so the layout matches the checkpoint.
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 2)

In [None]:
# Copy the saved parameters (stored under the 'model_state' key) into the model.
model.load_state_dict(checkpoint['model_state'])

In [None]:
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Switch to inference mode and move the parameters to the selected device.
model.eval().to(device)

### register forward hooks on the dissected layer

In [None]:
def get_activation(name):
    """Return a forward hook that stores the hooked module's output under ``name``.

    The hook writes the detached output into the notebook-level ``unit_activations``
    dict, overwriting whatever a previous forward pass stored under the same key.
    """
    def _store_output(module, inputs, output):
        unit_activations[name] = output.detach()

    return _store_output

def hook_feature(module, input, output):
    """Forward hook that appends the module output, as a NumPy array, to ``features_blobs``.

    NOTE(review): ``features_blobs`` is not defined anywhere in the visible notebook and
    this hook is never registered; calling it as-is would raise NameError.
    """
    output_array = output.data.cpu().numpy()
    features_blobs.append(output_array)

In [None]:
# Storage the hook writes into, keyed by layer path.
unit_activations = dict()
conv_unit = layer_dissected
# Resolve the attribute path (it may contain indexing such as "layer4[1]") on the model;
# eval is applied to the notebook's own configuration string, not to external input.
dissected_module = eval(f"model.{conv_unit}")
dissected_module.register_forward_hook(get_activation(conv_unit))

### test model and data

In [None]:
# Smoke test: run the first test sample through the model. The forward pass also fires
# the registered hook, which fills unit_activations[layer_dissected].
image, targets = next(iter(loader.test.dataset))
# Tensor.to() returns a new tensor instead of moving in place, so keep the result.
image = image.to(device)
with torch.no_grad():  # inference only — no autograd graph needed
    predictions = model([image])

In [None]:
# Show the model output next to the ground truth of the smoke-test sample.
predictions,targets

In [None]:
# Shape of the activation tensor the hook captured during the forward pass above.
unit_activations[layer_dissected].shape

### load relevant netdissect results

#### tally.csv

In [None]:
# tally.csv is expected inside the selected NetDissect result folder.
path_to_tally = os.path.join(path_to_result,"tally.csv")

In [None]:
# Read the whole tally file; the context manager closes the handle once the rows are in
# memory (newline="" is what the csv module recommends for files handed to csv.reader).
with open(path_to_tally, newline="") as tally_file:
    data = list(csv.reader(tally_file))
# data[0] is skipped (presumably the header row). Column 0 is the unit number, stored
# minus one (presumably 1-based in the file), column 2 the concept, column 3 the score
# (kept as the string read from the file).
data_all = [{"unit": int(item[0])-1, "concept":item[2], "score":item[3]} for item in data[1:]]
# First 30 rows — these are the top-scoring units only if the file is sorted by score.
data_top = data_all[:30]
print("Top scoring units\n")
pd.DataFrame(data_top)

#### filter the relevant concept units

In [None]:
print("all concepts in layer")
# Column 2 of every tally row after the first holds the concept name.
print({row[2] for row in data[1:]})

In [None]:
# Concepts kept when filtering the tally rows in the next cell.
relevant_concepts = ["head", "hair", "arm", "wheel", "car", "sidewalk","road","neck","mouth","person","leg","back","foot"]

In [None]:
# Tally rows whose concept (column 2) is in relevant_concepts, in file order.
relevant_rows = (row for row in data[1:] if row[2] in relevant_concepts)
data_top_relevant = [
    {"unit": int(row[0])-1, "concept": row[2], "score": row[3]}
    for row in relevant_rows
][:30]
print("Top scoring relevant units\n")
pd.DataFrame(data_top_relevant)

#### quantile.npy

In [None]:
# Load quantile.npy from the result folder; it is indexed by unit further down to
# threshold activations (presumably one threshold per unit — see the shape printed below).
quantile = np.load(os.path.join(path_to_result,"quantile.npy"))

In [None]:
# Quick look at the array: its shape and its first ten entries.
print(quantile.shape)
print(quantile[:10])

## Collect the targets, predictions on test set

In [None]:
pedestrian_image = []            # per image: ground truth contains at least one label == 1
detected_pedestrian_image = []   # per image: model output contains at least one label == 1
all_targets = []
all_predictions = []
with torch.no_grad():
    for local_batch in tqdm(loader.test):
        # local_batch[0] holds the images of the batch, local_batch[1] their targets.
        batch_targets = local_batch[1]
        pedestrian_image.extend([any(entry["labels"]==1) for entry in batch_targets])
        images_on_device = torch.stack(local_batch[0]).to(device)
        all_targets.extend(batch_targets)
        output = model(images_on_device)
        # Keep CPU copies of the results so they can be evaluated later.
        all_predictions.extend([
            {key: detection[key].to("cpu") for key in ("boxes", "labels", "scores")}
            for detection in output
        ])
        detected_pedestrian_image.extend([any(detection["labels"]==1) for detection in output])

In [None]:
# pedestrian_image holds one boolean per test image, so sum() counts the True entries.
print("total images", len(pedestrian_image))
print("images with at least one labelled pedestrian", sum(pedestrian_image))

In [None]:
# Counts images where the raw model output contains a label-1 detection, whatever its score.
print("images with at least one detected pedestrian(no score threshold)", sum(detected_pedestrian_image))

### check performance on different score/iou thresholds 

In [None]:
# Grid of operating points. For every (score, IoU) pair the per-image precision and
# recall are averaged over all test images (a mean of per-image values, not a
# dataset-level AP).
score_thresholds = [0, 0.2, 0.4, 0.6, 0.8]
iou_thresholds = [0, 0.2, 0.4, 0.6, 0.8]
for score_thres in score_thresholds:
    for iou_thres in iou_thresholds:
        precision_sum = 0
        recall_sum = 0
        for target, prediction in zip(all_targets, all_predictions):
            metrics = evaluate_prediction(target, prediction,
                                          score_thres=score_thres, iou_thres=iou_thres)
            precision_sum += metrics["precision"]
            recall_sum += metrics["recall"]
        mean_precision = precision_sum/len(all_targets)
        mean_recall = recall_sum/len(all_targets)
        print(f"score threshold {score_thres}, iou threshold {iou_thres}")
        print(f"average precision {mean_precision}, average recall {mean_recall}")
        print("-------------------------------")


### select thresholds and collect the performance for all images 

In [None]:
# Operating point used for the rest of the notebook. The sweep cell above uses the same
# two names as loop variables, so re-run this cell after re-running the sweep.
score_thres=0.6
iou_thres=0.4

In [None]:
# Evaluate every test image at the selected operating point and keep the per-image values.
per_image_metrics = [
    evaluate_prediction(target, prediction, score_thres=score_thres, iou_thres=iou_thres)
    for target, prediction in zip(all_targets, all_predictions)
]
all_precision = [metrics["precision"] for metrics in per_image_metrics]
all_recall = [metrics["recall"] for metrics in per_image_metrics]



## Project activation back to image space

In [None]:
# From here on single images are processed on the CPU, so move the model there.
model.to("cpu")

### select image index and convolutional unit index

In [None]:
# img_ind indexes loader.test.dataset; unit_ind selects the channel of the hooked
# activation (and the matching entry of `quantile`) to visualise.
img_ind = 16
unit_ind = 427

In [None]:
# Fetch the selected sample and its annotations from the test dataset.
img, target = loader.test.dataset[img_ind]

In [None]:
# we need inverse transform to display image
# Result is a NumPy array in the same channel-first layout as `img`; plotting cells
# transpose it to (H, W, C).
img_display = invTrans(img.cpu()).numpy()

### get output and evaluate

In [None]:
# Forward pass for the selected image. It also fires the registered hook, refreshing
# unit_activations for this image.
with torch.no_grad():  # inference only — avoids building an autograd graph
    output = model(torch.stack([img]))

In [None]:
# Show the raw detections next to the ground-truth annotations of the selected image.
output, target

In [None]:
# Evaluate this image at the operating point selected above instead of repeating the
# numbers as literals, so the two cannot drift apart.
evaluate_prediction(target, output[0], score_thres=score_thres, iou_thres=iou_thres)

### get activation, upsample to image dim, create mask based on quantile

In [None]:
# Output of the hooked layer from the most recent forward pass (the hook overwrites the
# entry on every call); [0] takes the first entry along dim 0 (presumably the batch axis).
activation = unit_activations[conv_unit][0]
# Resize the activation maps to the spatial size of the image tensor (img.shape[1:]).
# NOTE(review): torchvision detection models may rescale/pad the input internally, so the
# feature map might not align exactly with the original image — confirm.
activation_resized = torchvision.transforms.functional.resize(activation,img.shape[1:])
# Boolean mask: locations where the selected unit exceeds its entry in `quantile`.
mask = activation_resized[unit_ind]>quantile[unit_ind]

### Display results

#### show image

In [None]:
# matplotlib expects channel-last images, hence the transpose to (H, W, C).
plt.imshow(np.transpose(img_display, (1,2,0)))

#### show image with detections/labels

In [None]:
# Keep exactly the detections evaluate_prediction counts: label 1 and score >= threshold
# (this cell previously used a strict ">" and no label check).
output_boxes = [
    box
    for box, label, score in zip(output[0]["boxes"], output[0]["labels"], output[0]["scores"])
    if label == 1 and score >= score_thres
]

In [None]:
# Channel-last, C-contiguous copy of the display image that OpenCV can draw on without
# touching img_display. This replaces two consecutive RGB2BGR conversions that cancelled
# each other out and only served to produce such a copy.
img_cv2 = img_display.transpose(1, 2, 0).copy()

In [None]:
def _draw_labelled_boxes(canvas, boxes, prefix, color):
    """Draw every box and a ``<prefix>_<index>`` caption onto ``canvas`` in place.

    Boxes are read as (x1, y1, x2, y2) and handed to OpenCV as (x, y, width, height).
    """
    for box_ind, box in enumerate(boxes):
        x1, y1, x2, y2 = (int(coord.item()) for coord in box)
        cv2.rectangle(canvas, (x1, y1, x2 - x1, y2 - y1), color=color, thickness=2)
        cv2.putText(canvas, f"{prefix}_{box_ind}", (x1, y1 - 5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, thickness=2)


# Draw on a copy so re-running this cell does not pile rectangles onto img_cv2.
annotated = img_cv2.copy()
_draw_labelled_boxes(annotated, target["boxes"], "t", (0, 255, 0))   # ground truth
_draw_labelled_boxes(annotated, output_boxes, "p", (255, 0, 0))      # detections

fig, axs = plt.subplots(ncols=1, squeeze=False, figsize=(20, 10))
axs[0, 0].imshow(annotated)

#### show activation

In [None]:
# 2x2 overview: input image, raw unit activation, resized activation, thresholded mask.
fig, axs = plt.subplots(ncols=2, nrows=2, squeeze=False, figsize=(20, 10))
panels = [
    (axs[0, 0], np.transpose(img_display, (1,2,0)), {}, "image"),
    (axs[0, 1], activation[unit_ind].cpu().numpy(), {"cmap": 'gray'}, "raw activation"),
    (axs[1, 0], activation_resized[unit_ind].cpu().numpy(), {"cmap": 'gray'},
     "activation interpolated to image dimensions"),
    (axs[1, 1], mask.cpu().numpy(), {}, "masked activation (significant)"),
]
for ax, content, imshow_kwargs, title in panels:
    ax.imshow(content, **imshow_kwargs)
    ax.set_title(title)
plt.show()

#### show image, overlaid with masked activation

In [None]:
fig, axs = plt.subplots(ncols=1, squeeze=False, figsize=(15, 15))
axs[0, 0].imshow(np.transpose(img_display, (1,2,0)))
axs[0, 0].imshow(mask.cpu().numpy(), cmap='gray', alpha=0.4)
# unit_ind is not guaranteed to be among the top relevant units, so fall back to a
# placeholder instead of raising IndexError on an empty lookup.
unit_concept = next(
    (item["concept"] for item in data_top_relevant if item["unit"] == unit_ind),
    "unknown concept",
)
axs[0, 0].set_title(f"image with activation of unit {unit_ind} ({unit_concept})")
plt.show()

### Plot multiple units/detectors

In [None]:
# Units to plot in the grid below: the filtered top relevant units from the tally.
detectors = data_top_relevant

In [None]:
ncols = 4
# Enough rows to hold every detector; at least one row so plt.subplots does not fail
# when the detector list is empty.
nrows = max(1, math.ceil(len(detectors)/ncols))

In [None]:
# One panel per detector: the image overlaid with that unit's thresholded activation.
fig, axs = plt.subplots(ncols=ncols, nrows=nrows, squeeze=False, figsize=(30, 55))
image_hwc = np.transpose(img_display, (1,2,0))
grid_axes = axs.ravel()  # row-major, i.e. the left-to-right, top-to-bottom panel order
for ax, detector in zip(grid_axes, detectors):
    unit = detector["unit"]
    # Cell-local name: the single-unit `mask` from the earlier cell must not be overwritten,
    # otherwise re-running the overlay cells would show the wrong unit.
    unit_mask = activation_resized[unit] > quantile[unit]
    ax.imshow(image_hwc)
    ax.imshow(unit_mask.cpu().numpy(), cmap='gray', alpha=0.4)
    ax.set_title(f"{unit}_{detector['concept']}")
# Hide the grid cells that have no detector instead of showing empty axes.
for ax in grid_axes[len(detectors):]:
    ax.axis("off")
plt.show()