In [None]:
import wandb
from robust_detection import wandb_config
from robust_detection.models.rcnn import RCNN
from robust_detection.data_utils.rcnn_data_utils import Objects_RCNN
from torchmetrics.detection.map import MeanAveragePrecision
import pandas as pd
import pytorch_lightning as pl
import os
import torch
import numpy as np

In [None]:
# Evaluation configuration: which sweeps to evaluate, how to label them, and on which data.
gpu = 0  # intended GPU index for evaluation
api = wandb.Api()

results = {}  # NOTE(review): not read anywhere in this notebook

# W&B sweep id -> LightningModule class used to load that sweep's checkpoints.
sweep_dict = {"abdv8kfl": RCNN}
# Display names for the table, paired with `sweep_dict` entries by position.
model_names = ["RCNN (DPL track)"]
# Positional pairing is fragile: fail fast instead of mislabelling rows or raising IndexError later.
assert len(model_names) == len(sweep_dict), "model_names needs exactly one entry per sweep in sweep_dict"

# `data_path` value passed to the data module -> data module class used to build the test loaders.
data_dict = {"molecules/molecules_skip": Objects_RCNN}

# Key in each W&B run config that holds the run's cross-validation fold index.
fold_name = "fold"

In [None]:
# Results table (starts empty); the evaluation cell appends one row per (model, split) to it.
df = pd.DataFrame(data=None)

In [None]:
for i_mod, sweep_name in enumerate(sweep_dict.keys()):

    pd_dict_acc = {"Model":model_names[i_mod] + " (Acc)"}
    pd_dict_map = {"Model":model_names[i_mod] + " (mAP)"}


    #model_cls = sweep_dict[sweep_names]
    #sweep_runs = []
    #for sweep_name in sweep_names:
    #    sweep_runs += api.sweep(f"{ENTITY}/object_detection/{sweep_name}").runs
    model_cls = sweep_dict[sweep_name]
    sweep = api.sweep(f"{ENTITY}/object_detection/{sweep_name}")
    sweep_runs = []
    sweep_runs += api.sweep(f"{ENTITY}/object_detection/{sweep_name}").runs
    print(sweep_runs)
    for ood in [False,True]:

        pd_dict_acc["Type"] = "OOD" if ood else "In-distribution"
        pd_dict_map["Type"] = "OOD" if ood else "In-distribution"

        
        for data_key in data_dict.keys():

            best_runs = []
            for fold in [0,1,2,3,4]:
                #runs_fold = [r for r in sweep_runs if (r.config.get(fold_name)==fold) and (r.config.get("target_data_path")==data_key)]
                runs_fold = [r for r in sweep_runs if (r.config.get(fold_name)==fold)]
                runs_fold_sorted = sorted(runs_fold,key = lambda run: run.summary.get("restored_val_acc"), reverse = True)
                best_runs.append(runs_fold_sorted[0])

            accuracies = []
            mAPs = []
            for run in best_runs:
                fname = [f.name for f in run.files() if "ckpt" in f.name][0]
                run.file(fname).download(replace = True, root = ".")
                model = model_cls.load_from_checkpoint(fname)
                os.remove(fname)

                hparams = dict(model.hparams)
                hparams["re_train"] = False
                hparams["data_path"]= data_key
                dataset = data_dict[data_key](**hparams)
                dataset.prepare_data()
                trainer = pl.Trainer(logger = False, gpus = 1)

                if ood:
                    preds = trainer.predict(model, dataset.test_ood_dataloader())
                else:
                    preds = trainer.predict(model, dataset.test_dataloader())
                
                Y = []
                Y_hat = []
                map_metric = MeanAveragePrecision()
                for pred in preds:
                    Y += pred["targets"]
                    Y_hat += pred["preds"]
                    
                    pred_map = [dict(boxes=pred["boxes"][i],scores=pred["scores"][i],labels=pred["preds"][i]) for i in range(len(pred["targets"]))]
                    target_map = [dict(boxes=pred["boxes_true"][i],labels=pred["targets"][i]) for i in range(len(pred["targets"]))]
                    map_metric.update(pred_map,target_map)
                
                mAP = map_metric.compute()
                accuracy = np.array([torch.equal(Y[i].sort()[0],Y_hat[i].sort()[0]) for i in range(len(Y))]).mean()
                print(accuracy)
                print(mAP)

                accuracies.append(accuracy)
                mAPs.append(mAP["map"])

            accuracies = np.array(accuracies)
            acc_mu = accuracies.mean()
            acc_std = accuracies.std()
            
            mAPs = np.array(mAPs)
            map_mu = mAPs.mean()
            map_std = mAPs.std()

            acc_str = "$ " + str(acc_mu.round(3))+ "\pm" +str(acc_std.round(3)) +" $"
            map_str = "$ " + str(map_mu.round(3))+ "\pm" +str(map_std.round(3)) +" $"


            pd_dict_acc[data_key] = acc_str
            pd_dict_acc[data_key + " mAP"] = map_str

            print(pd_dict_acc)

        df = df.append(pd_dict_acc,ignore_index =True)

In [None]:
# Emit the accuracy rows as a LaTeX table; escape=False keeps the "$ ... \pm ... $" cells as math.
acc_rows = df[df["Model"].str.contains("Acc")]
print(acc_rows.to_latex(index=False, escape=False))