In [2]:
import wandb
from robust_detection.wandb_config import ENTITY
from robust_detection.data_utils.baselines_data_utils import ObjectsCountDataModule
from robust_detection.baselines.cnn_model import CNN
#from robust_detection.baselines.detection_cnn import objects_detection_cnn
from torchmetrics.detection.map import MeanAveragePrecision
import pandas as pd
import pytorch_lightning as pl
import os
import torch
import numpy as np
from multiprocessing import Pool

from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
import tqdm

In [3]:
gpu = 0
api = wandb.Api()

results = {}

sweep_dict = {"wps6fgrx":CNN}
model_names = ["CNN"]


#data_dict = {"MMSynthetic":SyntheticMMDataModule, "Pendulum":PendulumDataModule, "CV":CVDataModule}
#data_dict = {"molecules/mol_labels":ObjectsCountDataModule}#, "mnist/alldigits_2":MNISTCountDataModule,  "mnist/alldigits_5":MNISTCountDataModule} #, \
            #"mnist/alldigits_large":MNISTCountDataModule, "mnist/alldigits_2_large":MNISTCountDataModule,  "mnist/alldigits_5_large":MNISTCountDataModule,}
#data_dict = {"Pendulum":PendulumDataModule}
data_dict = {"clevr/clevr_all":ObjectsCountDataModule}
fold_name = "fold"
pre_trained = True

In [4]:
import cv2

def find_threshold_number_islands(cam_,nobj):
    
    MAX_ITERATIONS = 25
    sigma = 0
    fixed_mean = 0.5
    thresh = fixed_mean
    
    best_thresh = (fixed_mean,0)
    
    #print(f"Threshold = :{thresh}")
    try:
        num_islands = get_number_islands(cam_,fixed_mean)
    except RecursionError:
        num_islands = 0
        
    iter_idx = 0
    while num_islands != nobj:
        if sigma<0.94:
            sigma += 0.05
        #thresh+=0.1
    
        thresh = fixed_mean + np.random.rand()*sigma - sigma/2
        try:
            num_islands = get_number_islands(cam_,thresh)
        except RecursionError:
            num_islands = num_islands
        #print("Num Islands")
        #print(num_islands)
        
        if np.abs(num_islands-nobj)< np.abs(best_thresh[1]-nobj):
            best_thresh = (thresh,num_islands)
        
        if iter_idx>MAX_ITERATIONS:
            break
            
        iter_idx += 1
                
    if num_islands == nobj:
        return thresh
    else:
        return best_thresh[0]

def get_boxes_from_contour(contours, img):
    
    boxes = []
    scores = []
    for item in range(len(contours)):
        cnt = contours[item]
        if len(cnt)>5:
            #print(len(cnt))
            x,y,w,h = cv2.boundingRect(cnt) # x, y is the top left corner, and w, h are the width and height respectively
            poly_coords = [cnt] # polygon coordinates are based on contours
            
            boxes.append(torch.Tensor([x,y,x+w,y+h]))
            
            scores.append(img[:,y:y+h,x:x+w].max())

        else: print("contour error (too small)")
    return boxes, scores

def get_number_islands(cam_array, threshold):
    graph = np.zeros_like(cam_array)
    graph[cam_array > threshold] = 1
    graph[cam_array <= threshold] = 0

    class Graph:

        def __init__(self, row, col, g):
            self.ROW = row
            self.COL = col
            self.graph = g

        # A function to check if a given cell
        # (row, col) can be included in DFS
        def isSafe(self, i, j, visited):
            # row number is in range, column number
            # is in range and value is 1
            # and not yet visited
            return (i >= 0 and i < self.ROW and
                    j >= 0 and j < self.COL and
                    not visited[i][j] and self.graph[i][j])


        # A utility function to do DFS for a 2D
        # boolean matrix. It only considers
        # the 8 neighbours as adjacent vertices
        def DFS(self, i, j, visited):

            # These arrays are used to get row and
            # column numbers of 8 neighbours
            # of a given cell
            rowNbr = [-1, -1, -1,  0, 0,  1, 1, 1];
            colNbr = [-1,  0,  1, -1, 1, -1, 0, 1];

            # Mark this cell as visited
            visited[i][j] = True

            # Recur for all connected neighbours
            for k in range(8):
                if self.isSafe(i + rowNbr[k], j + colNbr[k], visited):
                    self.DFS(i + rowNbr[k], j + colNbr[k], visited)


        # The main function that returns
        # count of islands in a given boolean
        # 2D matrix
        def countIslands(self):
            # Make a bool array to mark visited cells.
            # Initially all cells are unvisited
            visited = [[False for j in range(self.COL)]for i in range(self.ROW)]

            # Initialize count as 0 and traverse
            # through the all cells of
            # given matrix
            count = 0
            for i in range(self.ROW):
                for j in range(self.COL):
                    # If a cell with value 1 is not visited yet,
                    # then new island found
                    if visited[i][j] == False and self.graph[i][j] == 1:
                        # Visit all cells in this island
                        # and increment island count
                        self.DFS(i, j, visited)
                        count += 1

            return count

    row = len(graph)
    col = len(graph[0])

    g = Graph(row, col, graph)

    return(g.countIslands())


def get_boxes_single_image(X,y_pred,cam,n_classes = 10):
    """
    cam is the class activation map for all classes
    """
    boxes_img = []
    labels_img = []
    scores_img = []
    for class_i in range(n_classes):
        #print(class_i)
        
        #targets = [ClassifierOutputTarget(class_i)]
        #cam_ = cam(input_tensor=X[None,...], targets=targets)
        cam_ = cam[class_i]
        nobj_pred = y_pred.round()[class_i].detach().long().numpy()

        if nobj_pred>0:
            
            #plt.figure()
            #plt.imshow(cam_[0][...,None])
            #plt.show()
            
            thresh = find_threshold_number_islands(cam_[0],nobj_pred)
            cam_thresh = np.zeros_like(cam_)
            cam_thresh[cam_>thresh]=255
            

            contours,hierarchy = cv2.findContours(cam_thresh[0].astype(np.uint8), 1, 2)

            bbox_coords, scores = get_boxes_from_contour(contours, img = cam_)#, img = cam_thresh.astype(np.uint8))
            boxes_img += bbox_coords
            labels_img += [class_i]*len(bbox_coords)
            scores_img += scores
            
    return torch.stack(boxes_img), labels_img, scores_img

def f(i,X,y_pred,cam_dict,n_classes):
    boxes, labels, scores = get_boxes_single_image(X[i].cpu(),y_pred[i].cpu(),cam = [cam_dict[class_i][i][None,...] for class_i in range(n_classes)], n_classes = n_classes)
    return(boxes,labels, scores)
    
def get_cnn_boxes(X,y_pred,cam,n_classes= 10):
    #boxes_list = []
    #labels_list = []
    #scores_list = []
    
    #print("Computing CAM...")
    cam_dict = {class_i:cam(input_tensor=X, targets=[ClassifierOutputTarget(class_i)]*len(X)) for class_i in range(n_classes)}
    #print("Done")
    
    #print("Computing boxes...")
    
    
    X = X.detach().cpu()
    y_pred = y_pred.detach().cpu()

    from multiprocessing import Pool
        
    with Pool(15) as p:
        res = p.starmap(f,zip([i for i in range(len(X))],[X for _ in range(len(X))],[y_pred for _ in range(len(X))],[cam_dict for _ in range(len(X))],[n_classes for _ in range(len(X))]))
        
    boxes_list = [b[0] for b in res]
    labels_list = [b[1] for b in res]
    scores_list = [b[2] for b in res]

    #for i in range(len(X)):
    #    boxes, labels, scores = get_boxes_single_image(X[i].cpu(),y_pred[i].cpu(),cam = [cam_dict[class_i][i][None,...] for class_i in range(n_classes)], n_classes = n_classes)
    #    boxes_list.append(boxes)
    #    labels_list.append(labels)
    #    scores_list.append(scores)
    return boxes_list, labels_list, scores_list


def objects_detection_cnn(model,dataloader, single_batch = False):

    target_layers = [model.model[1].layer4[-1]]
    cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
    boxes_list = []
    labels_list = []
    scores_list = []
    for i,batch in tqdm.tqdm(enumerate(dataloader)):
        X,y,_ = batch
        y_pred = model(X.to(model.device))

        boxes, labels, scores = get_cnn_boxes(X.to(model.device),y_pred,n_classes = y_pred.shape[1], cam = cam)

        boxes_list.append(boxes)
        labels_list.append(labels)
        scores_list.append(scores)
        
        if single_batch:
            break
    
    return boxes_list, labels_list, scores_list

In [5]:
def process_results():
    
    df = pd.DataFrame()
    for i_mod, sweep_name in enumerate(sweep_dict.keys()):

        pd_dict = {"Model":model_names[i_mod] + " (MSE)"}
        pd_dict_acc = {"Model":model_names[i_mod] + " (Acc)"}

        model_cls = sweep_dict[sweep_name]
        sweep = api.sweep(f"{ENTITY}/object_detection/{sweep_name}")

        for ood in [False,True]:

            pd_dict["Type"] = "OOD" if ood else "In-distribution"
            pd_dict_acc["Type"] = "OOD" if ood else "In-distribution"

            for data_key in data_dict.keys():

                best_runs = []
                for fold in [0,1,2,3,4]:
                    runs_fold = [r for r in sweep.runs if (r.config.get(fold_name)==fold) and (r.config.get("data_dir")==data_key) and (r.config.get("pre_trained")==pre_trained)]
                    runs_fold_sorted = sorted(runs_fold,key = lambda run: run.summary.get("restored_val_loss"), reverse = False)
                    best_runs.append(runs_fold_sorted[0])

                mses = []
                accuracies = []
                mAPs = []
                for run in best_runs:
                    fname = [f.name for f in run.files() if "ckpt" in f.name][0]
                    run.file(fname).download(replace = True, root = ".")
                    model = model_cls.load_from_checkpoint(fname)
                    os.remove(fname)

                    hparams = dict(model.hparams)

                    dataset = data_dict[data_key](**hparams)
                    dataset.prepare_data()
                    trainer = pl.Trainer(logger = False, gpus = 1)

                    if ood:
                        preds = trainer.predict(model, dataset.test_ood_dataloader())
                    else:
                        preds = trainer.predict(model, dataset.test_dataloader())
                    Y = torch.cat([pred["Y"] for pred in preds]).cpu()
                    Y_hat = torch.cat([pred["Y_pred"] for pred in preds]).cpu()
                    M = torch.cat([pred["M"] for pred in preds]).cpu()

                    mse = model.compute_mse(Y_hat,Y,M)
                    accuracy = model.compute_accuracy(Y,Y_hat,M)
                    mses.append(mse)
                    accuracies.append(accuracy)

                    map_metric = MeanAveragePrecision()

                    if ood:
                        boxes, labels, scores = objects_detection_cnn(model,dataset.test_ood_dataloader())
                    else:
                        boxes, labels, scores = objects_detection_cnn(model,dataset.test_dataloader())#, single_batch = True)

                    #pred_map_full = [dict(boxes=boxes[i],scores=scores[i],labels=labels[i]) for i in range(len(boxes))]

                    #base_idx = 0
                    for i_pred,pred in enumerate(preds):
                        target_map = [dict(boxes=pred["boxes_true"][i],labels=pred["targets"][i]) for i in range(len(pred["targets"]))]
                        pred_map = [dict(boxes=boxes[i_pred][i],scores=torch.Tensor(scores[i_pred][i]),labels=torch.Tensor(labels[i_pred][i])) for i in range(len(boxes[i_pred]))]

                        map_metric.update(pred_map,target_map)
                        #return map_metric, boxes, labels, scores, pred, target_map, pred_map
                        

                    mAP = map_metric.compute()
                    mAPs.append(mAP["map"])
                    print(mAP)


                mses = np.array(mses)
                mse_mu = mses.mean()
                mse_std = mses.std()

                accuracies = np.array(accuracies)
                acc_mu = accuracies.mean()
                acc_std = accuracies.std()

                mAPs = np.array(mAPs)
                map_mu = mAPs.mean()
                map_std = mAPs.std()

                mse_str = "$ " + str(mse_mu.round(3))+ "\pm" +str(mse_std.round(3)) +" $"
                acc_str = "$ " + str(acc_mu.round(3))+ "\pm" +str(acc_std.round(3)) +" $"
                map_str = "$ " + str(map_mu.round(3))+ "\pm" +str(map_std.round(3)) +" $"

                pd_dict[data_key] = mse_str
                pd_dict_acc[data_key] = acc_str
                pd_dict_acc[data_key + " mAP"] = map_str

            df = df.append(pd_dict,ignore_index =True)
            df = df.append(pd_dict_acc,ignore_index =True)
    return df

In [14]:
#map_metric, boxes, labels, scores, pred, target_map, pred_map = process_results()


#map_metric = MeanAveragePrecision()#iou_thresholds = [0.5,0.75])
#map_metric.update(pred_map[:2],target_map[:2])
#mAP = map_metric.compute()
#print(mAP)

#import torchvision
#import matplotlib.pyplot as plt
#img_idx = 3
#img_with_boxes = torchvision.utils.draw_bounding_boxes(torch.Tensor(pred["X"][img_idx]).permute(1,2,0).to(torch.uint8).permute(2,0,1),boxes[0][img_idx])
#img_with_true_boxes = torchvision.utils.draw_bounding_boxes(torch.Tensor(pred["X"][img_idx]).permute(1,2,0).to(torch.uint8).permute(2,0,1),pred["boxes_true"][img_idx], colors = "red")



#fig, ax = plt.subplots(figsize=(10, 10))
#ax.imshow(pred["X"][img_idx].permute(1,2,0) + img_with_boxes.permute(1,2,0).numpy() + img_with_true_boxes.permute(1,2,0).numpy(), interpolation='nearest')
#plt.tight_layout()

In [None]:
df = process_results()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

1it [00:28, 28.81s/it]

contour error (too small)


3it [01:31, 30.41s/it]

contour error (too small)


4it [01:59, 29.43s/it]

contour error (too small)


6it [02:55, 28.48s/it]

contour error (too small)


11it [05:20, 29.25s/it]

contour error (too small)


26it [12:36, 29.44s/it]

contour error (too small)


32it [15:00, 28.13s/it]


{'map': tensor(0.0162), 'map_50': tensor(0.0969), 'map_75': tensor(0.0002), 'map_small': tensor(-1.), 'map_medium': tensor(0.0162), 'map_large': tensor(-1.), 'mar_1': tensor(0.0544), 'mar_10': tensor(0.0592), 'mar_100': tensor(0.0592), 'mar_small': tensor(-1.), 'mar_medium': tensor(0.0592), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.)}


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

14it [06:56, 30.44s/it]

contour error (too small)


25it [12:14, 29.84s/it]

contour error (too small)


30it [14:39, 29.03s/it]

contour error (too small)


32it [15:20, 28.76s/it]


{'map': tensor(0.0348), 'map_50': tensor(0.1858), 'map_75': tensor(0.0011), 'map_small': tensor(-1.), 'map_medium': tensor(0.0349), 'map_large': tensor(-1.), 'mar_1': tensor(0.0843), 'mar_10': tensor(0.0937), 'mar_100': tensor(0.0937), 'mar_small': tensor(-1.), 'mar_medium': tensor(0.0937), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.)}


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

10it [05:03, 30.91s/it]

contour error (too small)


11it [05:32, 30.31s/it]

contour error (too small)


16it [08:02, 30.28s/it]

contour error (too small)


27it [13:26, 29.85s/it]

contour error (too small)


30it [14:51, 29.01s/it]

contour error (too small)


32it [15:27, 28.98s/it]


{'map': tensor(0.0262), 'map_50': tensor(0.1438), 'map_75': tensor(0.0015), 'map_small': tensor(-1.), 'map_medium': tensor(0.0263), 'map_large': tensor(-1.), 'mar_1': tensor(0.0696), 'mar_10': tensor(0.0786), 'mar_100': tensor(0.0786), 'mar_small': tensor(-1.), 'mar_medium': tensor(0.0786), 'mar_large': tensor(-1.), 'map_per_class': tensor(-1.), 'mar_100_per_class': tensor(-1.)}


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

0it [00:00, ?it/s]

contour error (too small)


10it [04:58, 29.45s/it]

contour error (too small)


27it [13:20, 29.74s/it]

In [None]:
print(df.loc[df.Model.str.contains("Acc")].to_latex(escape = False,index= False))

In [9]:
df

Unnamed: 0,Model,Type,clevr/clevr_all,clevr/clevr_all mAP
0,CNN (MSE),In-distribution,$ 0.003\pm0.0 $,
1,CNN (Acc),In-distribution,$ 0.97\pm0.005 $,$ 0.036\pm0.014 $
2,CNN (MSE),OOD,$ 0.016\pm0.001 $,
3,CNN (Acc),OOD,$ 0.831\pm0.016 $,$ 0.029\pm0.01 $
