#### This notebook fetches the models from Comet ML and runs them on the holdout set. Its output is a CSV file with the metrics for each model for a given dataset, either Cityscapes or NYUv2.

In [8]:
import os
import argparse
import yaml
from pathlib import Path

import pandas as pd

import torch
from comet_ml import API
%autoreload 2

In [9]:
from vision_mtl.utils.comet_utils import load_artifacts_from_comet, model_to_exp_name
from vision_mtl.lit_module import MTLModule
from vision_mtl.utils.pipeline_utils import build_model
from vision_mtl.utils.pipeline_utils import load_ckpt_model
from vision_mtl.cfg import cfg, root_dir
from vision_mtl.lit_datamodule import MTLDataModule
from vision_mtl.training_lit import predict
from vision_mtl.utils.pipeline_utils import fetch_data_cfg

In [10]:
# Comet ML API client used by the cells below to download experiment artifacts.
# The key is read from the COMET_API_KEY environment variable; indexing
# os.environ raises KeyError if the variable is not set.
comet_api = API(
    api_key=os.environ["COMET_API_KEY"],
)

In [11]:
# Dataset whose models are evaluated. Switch datasets by swapping which of the
# two assignments below is commented out.
dataset_name = "cityscapes"
# dataset_name = "nyuv2"

# Fail fast on an unsupported name. Previously any value other than
# "cityscapes" (including a typo) silently selected the NYUv2 config while the
# rest of the notebook kept using the raw `dataset_name`.
supported_datasets = ("cityscapes", "nyuv2")
if dataset_name not in supported_datasets:
    raise ValueError(
        f"Unsupported dataset_name {dataset_name!r}; "
        f"expected one of {supported_datasets}"
    )

# Dataset-specific configuration; later cells read `num_classes`,
# `train_transform` and `test_transform` from it.
data_cfg = fetch_data_cfg(dataset_name)

In [None]:
# Evaluate every model registered in `model_to_exp_name` on `dataset_name`.
# For each model: download its Comet artifacts, rebuild the model from the
# stored training args, load the checkpoint weights, run `predict` on the
# datamodule's predict dataloader, save the predictions next to the artifacts
# and collect the metrics under the model's name.
metrics_df_dict = {}

# Artifact names are the same for every experiment, so define them once.
args_name_no_ext = "train_args"
model_artifact_name = "model"

for model_name, exp_names_by_dataset in model_to_exp_name.items():
    exp_name = exp_names_by_dataset[dataset_name]
    artifacts_dir = f"{root_dir}/artifacts/{exp_name}"
    artifacts = load_artifacts_from_comet(
        exp_name=exp_name,
        local_artifacts_dir=artifacts_dir,
        model_artifact_name=model_artifact_name,
        args_name_no_ext=args_name_no_ext,
        api=comet_api,
    )

    # Read the training args under a context manager so the file handle is
    # closed on every iteration instead of being leaked.
    with open(artifacts["args_path"], "r") as args_file:
        args = argparse.Namespace(**yaml.safe_load(args_file)["args"])

    model = build_model(args, data_cfg)
    module = MTLModule(model=model, num_classes=data_cfg.num_classes)

    # `load_ckpt_model` is given the directory that contains the checkpoint
    # file; the weights are read from the "model" entry of what it returns.
    model_ckpt = load_ckpt_model(Path(artifacts["checkpoint_path"]).parent)
    module.load_state_dict(model_ckpt["model"])

    # The datamodule is rebuilt per model because batch size and the overfit
    # flag come from that model's own training args.
    datamodule = MTLDataModule(
        dataset_name=dataset_name,
        batch_size=args.batch_size,
        do_overfit=args.do_overfit,
        train_transform=data_cfg.train_transform,
        test_transform=data_cfg.test_transform,
        num_workers=0,
    )
    datamodule.setup()

    preds, predict_metrics = predict(
        datamodule.predict_dataloader(),
        module,
        cfg.device,
        do_plot_preds=False,
        do_show_preds=True,
    )
    # Keep the raw predictions alongside the downloaded artifacts.
    torch.save(preds, os.path.join(artifacts_dir, "preds.pt"))
    metrics_df_dict[model_name] = predict_metrics
    print(predict_metrics)
    print(f'Done with {model_name}')

In [None]:
# Build the comparison table from the collected metrics: the keys of
# `metrics_df_dict` (model names) become the columns. Values are rounded to
# three decimals.
metrics_df = pd.DataFrame(metrics_df_dict)
metrics_df = metrics_df.round(3)
# Trailing bare expression so the notebook renders the table.
metrics_df