In [None]:
%load_ext autoreload
%autoreload 2

import glob
import json
from omegaconf import OmegaConf
import os
import numpy as np
import pickle
from sklearn.metrics import confusion_matrix
import torch
from tqdm import tqdm
from pathlib import Path

import sys
sys.path.append("../")
from classification.train.train_classification import MultiPartitioningClassifier, load_yaml

In [None]:
# Change working directory so that relative paths are based from `ROOT/g3/` and not
# `ROOT/g3/notebooks/`. This is necessary for MultiPartitioningClassifier to work.
#
# The directory the kernel started in is remembered the first time this cell runs,
# so re-executing the cell always lands in the same place instead of climbing one
# more level up on every run.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = Path.cwd()
os.chdir(NOTEBOOK_DIR.parent)

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
def run_val(model):
    """Run `model` over its validation dataloader and collect the results.

    Args:
        model: Object providing `val_dataloader()` whose batches unpack to
            `(images, target, lats, lngs, ids)`, and which is callable as
            `model((images, ids))`, returning a dict with an "output" entry and,
            optionally, an "attn" entry holding "attn_scores".

    Returns:
        Tuple `(predictions, labels, attentions, batch_ids)`:
        predictions -- `torch.cat` of `output["output"][0]` over all batches.
        labels      -- `torch.cat` of the per-batch targets.
        attentions  -- `torch.cat` of the per-batch attention scores, or an empty
                       list when the model never returned an "attn" entry.
        batch_ids   -- flat list of all ids, in iteration order.
    """
    # Index used both for `output["output"]` and for list-valued targets
    # (presumably one entry per partitioning -- confirm against the model).
    head_idx = 0
    predictions = []
    attentions = []
    labels = []
    batch_ids = []

    with torch.no_grad():
        for batch in tqdm(model.val_dataloader()):
            images, target, _lats, _lngs, ids = batch
            # A list-valued target has no `.to()`, so the entry must be selected
            # BEFORE the device transfer; doing it afterwards can never be reached.
            if isinstance(target, (list, tuple)):
                target = target[head_idx]
            images, target = images.to(device), target.to(device)

            output = model((images, ids))
            predictions.append(output["output"][head_idx])
            if "attn" in output:
                attentions.append(output["attn"]["attn_scores"])
            labels.append(target)
            batch_ids.extend(ids)

    predictions = torch.cat(predictions)
    labels = torch.cat(labels)
    if attentions:
        attentions = torch.cat(attentions)
    return predictions, labels, attentions, batch_ids

def get_class_accuracies(y_true, y_pred, labels):
    """Return the per-class accuracy for every entry of `labels`.

    Builds the confusion matrix of the flattened `y_true` / `y_pred` over `labels`,
    divides each row by its row sum and returns the diagonal. A class whose row sum
    is zero yields NaN from the 0/0 division.
    """
    counts = confusion_matrix(y_true.flatten(), y_pred.flatten(), labels=labels)
    # Row sums as a column vector so the division broadcasts across each row.
    row_totals = counts.sum(axis=1)[:, np.newaxis]
    row_normalized = counts.astype('float') / row_totals
    return np.diagonal(row_normalized)

def validate(predictions, labels, num_classes=249):
    """Print top-1/5/10 accuracy and the mean per-class accuracy.

    Args:
        predictions: Tensor of per-class scores with the classes on the last
            dimension (it is reduced with `topk` / `argmax` over `dim=-1`).
        labels: Tensor of ground-truth class indices, one per row of `predictions`.
        num_classes: Per-class accuracy is computed for class ids
            `0 .. num_classes - 1`. The default of 249 is the value that was
            previously hard-coded here.

    Returns:
        None. The metrics are printed.
    """
    n_scores = predictions.shape[-1]
    for k in [1, 5, 10]:
        # `topk` raises when k exceeds the number of scores, so clamp it; with every
        # class selected the reported top-k accuracy is simply 1.0 for valid labels.
        chosen_predictions = predictions.topk(k=min(k, n_scores), dim=-1).indices
        correct = torch.any(chosen_predictions == labels.unsqueeze(dim=-1), dim=-1).sum()
        correct = correct.item() / len(labels)
        print(f"top-{k} acc:", correct)
    labels = labels.detach().cpu().numpy()
    final_predictions = predictions.argmax(dim=-1).detach().cpu().numpy()
    class_accs = get_class_accuracies(labels, final_predictions, range(num_classes))
    # nanmean skips NaN entries in the per-class accuracies.
    print("avg class acc:", np.nanmean(class_accs))
        
def save(name, config, predictions, labels, attentions, batch_ids):
    """Pickle one annotation dict per sample into the checkpoint's run folder.

    Each annotation holds "label", "predictions", optionally "attn", and "id".
    The output is always written with `pickle`, whatever extension `name` carries.

    Args:
        name: File name of the output file.
        config: Object exposing `config.model_params.weights`, the checkpoint path.
            The file is written to the directory containing the checkpoint, or to
            that directory's parent when the checkpoint sits in a `ckpts` folder.
        predictions: Tensor indexed by sample along the first dimension.
        labels: Tensor of labels; `labels.shape[0]` is the number of samples.
        attentions: Tensor indexed by sample, or an empty list / None when there
            are no attention scores to store.
        batch_ids: Sequence of sample ids, aligned with `labels`.
    """
    # `len()` works for both the empty-list placeholder and a tensor, unlike
    # comparing a tensor against `[]`.
    has_attn = attentions is not None and len(attentions) > 0

    anns = []
    for i in range(labels.shape[0]):
        ann = {
            "label": labels[i].item(),
            "predictions": predictions[i].cpu().numpy(),
        }
        if has_attn:
            ann["attn"] = attentions[i].cpu().numpy()
        ann["id"] = batch_ids[i]
        anns.append(ann)

    # Strip only a trailing `ckpts` directory. A plain substring replace would also
    # mangle unrelated path components such as `/ckpts_v2/`.
    ckpt_dir = os.path.dirname(config.model_params.weights)
    if os.path.basename(ckpt_dir) == "ckpts":
        folder = os.path.dirname(ckpt_dir)
    else:
        folder = ckpt_dir

    # Context manager so the file is flushed and closed even if pickling fails.
    with open(os.path.join(folder, name), "wb") as f:
        pickle.dump(anns, f)

In [None]:
# Folder holding the model weights; stop right here if it is missing so the
# evaluation loop does not silently find no checkpoints.
out_dir = Path("../g3/weights").resolve()
assert out_dir.exists(), str(out_dir)

# `configs` maps a folder searched for checkpoints to the name of the config file
# (looked up under `./config/`) used to build the model for those checkpoints.

# Evaluate trained weights
# configs = {
#     f"{out_dir}/resnet50_image": "resnet50_image.yml",
#     f"{out_dir}/resnet50_image_and_clip": "resnet50_image_and_clip.yml",
#     f"{out_dir}/resnet50_image_and_clues": "resnet50_image_and_clues.yml",
#     f"{out_dir}/resnet50_image_clip_clues": "resnet50_image_clip_clues.yml",
# }

# Evaluate publicly shared weights
configs = {str(out_dir / "g3"): "resnet50_image_clip_clues.yml"}

In [None]:
# When True, the evaluation loop points the validation paths of the config at the
# test split and names the output file "predictions_test.json"; when False the
# config's own validation settings are used and the name is "predictions_val.json".
eval_test = True
# When True, the per-sample results of each evaluated checkpoint are written to
# disk via `save`; when False only the metrics are printed.
save_predictions = False

In [None]:
# Evaluate the `last.ckpt` of every run directory found under each configured folder.
for folder, config_name in configs.items():
    # glob returns matches in arbitrary order; sort so the printed metrics appear
    # in the same order on every run.
    ckpt_paths = sorted(glob.glob(f"{folder}/*/ckpts/last.ckpt"))
    if not ckpt_paths:
        # Surface a wrong folder instead of silently evaluating nothing.
        print(f"No checkpoints found matching {folder}/*/ckpts/last.ckpt")

    for ckpt in ckpt_paths:
        # `validate` prints bare metrics, so say which checkpoint they belong to.
        print(f"Evaluating {ckpt} (config: {config_name})")

        config = load_yaml(f"./config/{config_name}")
        config.model_params.weights = ckpt

        if eval_test:
            # Point the "validation" inputs at the test split. The `${data_dir}`
            # placeholders are left for the config object to resolve (presumably
            # OmegaConf-style interpolation -- confirm in `load_yaml`).
            config.model_params.msgpack_val_dir = "${data_dir}/dataset/test/msgpack"
            config.model_params.val_meta_path = "${data_dir}/dataset/test/test.csv"
            config.model_params.val_label_mapping = "${data_dir}/dataset/test/label_mapping/countries.json"
            name = "predictions_test.json"
        else:
            name = "predictions_val.json"

        model = MultiPartitioningClassifier(config["model_params"], None)
        model = model.to(device)
        model = model.eval()

        predictions, labels, attentions, batch_ids = run_val(model)
        validate(predictions, labels)

        if save_predictions:
            # NOTE: `save` writes a pickle, despite the ".json" suffix of `name`.
            save(name, config, predictions, labels, attentions, batch_ids)