In [None]:
import wandb
from robust_detection.wandb_config import ENTITY
from robust_detection.data_utils.baselines_data_utils import ObjectsCountDataModule
from robust_detection.baselines.cnn_model import CNN
import pandas as pd
import pytorch_lightning as pl
import os
import torch
import numpy as np

In [None]:
# --- Evaluation configuration ------------------------------------------------
gpu = 0  # NOTE(review): confirm this matches the device the Trainer uses in the evaluation cell.
api = wandb.Api()

results = {}  # NOTE(review): appears unused in this notebook.

# wandb sweep id -> model class used to restore that sweep's checkpoints.
sweep_dict = {"dub96i86": CNN}
# Display names for the table, parallel to sweep_dict (same order).
model_names = ["CNN"]

# Dataset path (passed to the DataModule as `data_dir`) -> DataModule class that loads it.
# Other datasets evaluated previously (e.g. "mnist/alldigits_large" with MNISTCountDataModule)
# can be added here once their DataModule is imported in the imports cell.
data_dict = {"molecules/molecules_skip": ObjectsCountDataModule}

# Run-config key holding the cross-validation fold index.
fold_name = "fold"
# Only runs whose config has this `pre_trained` value are considered.
pre_trained = True

In [None]:
# Results table (one row per model and split), populated by the evaluation cell below.
df = pd.DataFrame(data=None)

In [None]:
def _select_best_runs(runs, fold_key, pre_trained_flag, folds=(0, 1, 2, 3, 4), metric="restored_val_acc"):
    """Pick, for every fold, the run with the highest value of `metric`.

    Args:
        runs: wandb runs belonging to one sweep.
        fold_key: run-config key that holds the fold index.
        pre_trained_flag: only runs whose `pre_trained` config equals this are eligible.
        folds: fold indices to select a run for.
        metric: run-summary key used for ranking (higher is better).

    Returns:
        A list with one run per fold, in the order of `folds`.

    Raises:
        ValueError: if a fold has no eligible run that reports `metric`.
    """
    best = []
    for fold in folds:
        # Runs that never logged the metric (e.g. crashed) are skipped; comparing
        # their None against floats would otherwise raise a TypeError.
        candidates = [
            r
            for r in runs
            if r.config.get(fold_key) == fold
            and r.config.get("pre_trained") == pre_trained_flag
            and r.summary.get(metric) is not None
        ]
        if not candidates:
            raise ValueError(
                f"No run for fold {fold} (pre_trained={pre_trained_flag}) reports '{metric}'."
            )
        # max() returns the first of tied runs, matching the stable reverse sort it replaces.
        best.append(max(candidates, key=lambda r: r.summary.get(metric)))
    return best


def _load_model(run, model_cls):
    """Download the run's first file whose name contains "ckpt", restore it, then delete the local copy.

    Raises:
        ValueError: if the run has no checkpoint file.
    """
    ckpt_names = [f.name for f in run.files() if "ckpt" in f.name]
    if not ckpt_names:
        raise ValueError(f"Run {run.id} has no checkpoint file.")
    fname = ckpt_names[0]
    handle = run.file(fname).download(replace=True, root=".")
    # wandb's File.download returns an open file object; close it so it does not
    # leak (an open handle would also block os.remove on Windows).
    if hasattr(handle, "close"):
        handle.close()
    try:
        return model_cls.load_from_checkpoint(fname)
    finally:
        # Remove the downloaded checkpoint even if loading fails.
        os.remove(fname)


def _evaluate_accuracy(model, datamodule_cls, data_key, ood, gpu_index):
    """Return `model.compute_accuracy` on the in-distribution or OOD test split of `data_key`."""
    hparams = dict(model.hparams)  # copy, so the model's own hparams are not mutated
    hparams["data_dir"] = data_key
    dataset = datamodule_cls(**hparams)
    dataset.prepare_data()
    trainer = pl.Trainer(logger=False, gpus=[gpu_index])
    loader = dataset.test_ood_dataloader() if ood else dataset.test_dataloader()
    preds = trainer.predict(model, loader)
    Y = torch.cat([p["Y"] for p in preds]).cpu()
    Y_hat = torch.cat([p["Y_pred"] for p in preds]).cpu()
    M = torch.cat([p["M"] for p in preds]).cpu()
    return model.compute_accuracy(Y, Y_hat, M)


def _format_mean_std(values):
    r"""Format `values` as the LaTeX cell "$ <mean>\pm<std> $" (population std, rounded to 3 places)."""
    arr = np.array(values)
    # Raw string: "\pm" is not a valid Python escape sequence (SyntaxWarning on 3.12+).
    return "$ " + str(arr.mean().round(3)) + r"\pm" + str(arr.std().round(3)) + " $"


rows = []
for i_mod, sweep_name in enumerate(sweep_dict):
    print(sweep_name)
    model_name = model_names[i_mod]
    model_cls = sweep_dict[sweep_name]
    sweep = api.sweep(f"{ENTITY}/object_detection/{sweep_name}")
    sweep_runs = list(sweep.runs)

    # Run selection does not depend on the dataset or the OOD flag (the data_path
    # filter is disabled), so pick the best runs and load their checkpoints once.
    best_runs = _select_best_runs(sweep_runs, fold_name, pre_trained)
    models = [_load_model(run, model_cls) for run in best_runs]

    for ood in [False, True]:
        row = {"Model": model_name + " (Acc)", "Type": "OOD" if ood else "In-distribution"}
        for data_key, datamodule_cls in data_dict.items():
            accuracies = [_evaluate_accuracy(m, datamodule_cls, data_key, ood, gpu) for m in models]
            row[data_key] = _format_mean_std(accuracies)
        rows.append(row)

# DataFrame.append was removed in pandas 2.0: build the table once from the records.
# Rebuilding (rather than appending) also keeps this cell idempotent on re-run.
df = pd.DataFrame(rows)

In [None]:
# Accuracy rows only; print (not display) so the raw LaTeX source can be copied verbatim.
acc_mask = df["Model"].str.contains("Acc")
print(df[acc_mask].to_latex(escape=False, index=False))