In [None]:
%cd ..\src

In [None]:
import argparse
import logging
import os

import torch
from torch import nn

from configs.utils import (get_bool_from_config, get_config, get_config_wandb,
                           get_value_from_namespace,
                           get_value_from_namespace_or_raise,
                           update_config_wandb)
from echovpr.models.single_esn import SingleESN
from echovpr.models.sparce_layer import SpaRCe
from echovpr.models.utils import get_sparsity
from echovpr.trainer.eval import run_eval
from echovpr.trainer.metrics.recall import compute_recall
from echovpr.trainer.prepare_esn_datasets import prepare_esn_datasets
from echovpr.trainer.prepare_final_datasets import (get_dataset_infos,
                                                    prepare_final_datasets)
from echovpr.trainer.process_patchnetvlad import local_matcher

# Root logger at INFO level; `log` is reused by the cells below.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

# Device selection: an explicit TORCH_DEVICE environment variable takes precedence,
# otherwise CUDA is used when available and the CPU is the fallback.
env_torch_device = os.environ.get("TORCH_DEVICE")
if env_torch_device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    log.info('Setting default available device')
else:
    device = torch.device(env_torch_device)
    log.info(f'Setting device set by environment to {env_torch_device}')

# Set before any W&B call is made in the cells below.
os.environ["WANDB_SILENT"] = "true"

In [None]:
def load_model(run, artifact_name: str, model_name: str, map_location=None) -> torch.Tensor:
    """Fetch a model artifact through `run` and deserialize one file from it.

    Args:
        run: Object exposing ``use_artifact(name, type=...)`` whose result has a
            ``download()`` method returning a local directory
            (presumably a W&B run - confirm against the caller).
        artifact_name: Name passed to ``run.use_artifact``.
        model_name: File name inside the downloaded artifact directory.
        map_location: Forwarded to ``torch.load``. When left as ``None`` and no CUDA
            device is available, storages are remapped to the CPU so that a
            checkpoint saved on a GPU can still be loaded on a CPU-only machine.

    Returns:
        Whatever ``torch.load`` deserializes from the file.
    """
    model_artifact = run.use_artifact(artifact_name, type='model')
    model_dir = model_artifact.download()

    # Without this fallback torch.load raises when a CUDA-saved checkpoint is opened
    # on a machine that has no CUDA device.
    if map_location is None and not torch.cuda.is_available():
        map_location = "cpu"

    # SECURITY NOTE: torch.load unpickles the file. The artifact is downloaded data, so
    # only use artifacts from a trusted source (or pass weights_only=True where the
    # installed PyTorch version supports it).
    return torch.load(os.path.join(model_dir, model_name), map_location=map_location)

def load_mode_from_dir(model_dir: str, model_name: str, map_location=None) -> torch.Tensor:
    """Deserialize `model_name` from the local directory `model_dir` with torch.load.

    Args:
        model_dir: Directory that contains the serialized file.
        model_name: File name inside `model_dir`.
        map_location: Forwarded to ``torch.load``. When left as ``None`` and no CUDA
            device is available, storages are remapped to the CPU so that a
            checkpoint saved on a GPU can still be loaded on a CPU-only machine.

    Returns:
        Whatever ``torch.load`` deserializes from the file; in this notebook the
        result is handed to ``load_state_dict``.
    """
    # Without this fallback torch.load raises when a CUDA-saved checkpoint is opened
    # on a machine that has no CUDA device.
    if map_location is None and not torch.cuda.is_available():
        map_location = "cpu"

    # SECURITY NOTE: torch.load unpickles the file - only load checkpoints you trust.
    return torch.load(os.path.join(model_dir, model_name), map_location=map_location)

def model_forward(model, x, kwargs):
    """Run `x` through the optional "sparce" stage and then the "out" stage of `model`.

    Args:
        model: Mapping of stage name to callable; "out" is always used, "sparce"
            only when enabled.
        x: Input handed to the first active stage.
        kwargs: Mapping that must contain 'sparce_enabled'; when that entry is
            truthy, 'dataset_quantiles' is read too and passed to the "sparce" stage.

    Returns:
        The result of calling ``model["out"]`` on the (possibly transformed) input.
    """
    features = x
    if kwargs['sparce_enabled']:
        sparce_stage = model["sparce"]
        features = sparce_stage(features, kwargs['dataset_quantiles'])

    return model["out"](features)

def set_summary_props(run, recall_dic, key_prefix):
    """Copy every entry of `recall_dic` into `run.summary` under a prefixed key.

    Each entry ``k -> v`` is stored as ``run.summary["<key_prefix>@<k>"] = v``.

    Args:
        run: Object whose ``summary`` attribute supports item assignment.
        recall_dic: Mapping whose entries are written to the summary.
        key_prefix: Text placed before the ``@`` in every summary key.
    """
    for n, recall_value in recall_dic.items():
        summary_key = f"{key_prefix}@{n}"
        run.summary[summary_key] = recall_value


In [None]:
# Run configuration. Paths are relative to the current working directory and are
# built with os.path.join so they resolve on POSIX systems as well as on Windows
# (on Windows the resulting strings are identical to the former hard-coded ones).
config_file = os.path.join('configs', 'train_esn_nordland.ini')
project = 'echovpr_nordland'
entity = 'uos_ml'

# Directory holding the pretrained 'esn_model.pt' and 'model.pt' files.
model_dir = os.path.join('pretrained_models', 'nordland_netvlad_esn')
# Alternative checkpoints:
# model_dir = os.path.join('pretrained_models', 'nordland_netvlad_esn_sparse')
# model_dir = os.path.join('pretrained_models', 'oxford_netvlad_esn')
# model_dir = os.path.join('pretrained_models', 'oxford_netvlad_esn_sparse')

In [None]:
# Setup config: get_config_wandb hands back the run handle together with the configuration.
run, config = get_config_wandb(config_file, project=project, entity=entity, logger=log, log=False)

# Setup ESN: every hyper-parameter is read from the configuration and cast explicitly.
in_features = int(config['model_in_features'])
reservoir_size = int(config['model_reservoir_size'])
out_features = int(config['model_out_features'])

esn_alpha = float(config['model_esn_alpha'])
esn_gamma = float(config['model_esn_gamma'])
esn_rho = float(config['model_esn_rho'])
esn_num_connections = int(config['model_esn_num_connections'])
sparce_enabled = get_bool_from_config(config, 'model_sparce_enabled')

model_esn = SingleESN(
    in_features,
    reservoir_size,
    alpha=esn_alpha,
    gamma=esn_gamma,
    rho=esn_rho,
    sparsity=get_sparsity(esn_num_connections, reservoir_size),
    device=device
)

# Restore the pretrained reservoir (strict loading: every key has to match).
esn_model_tensor = load_mode_from_dir(model_dir, 'esn_model.pt')
model_esn.load_state_dict(esn_model_tensor)

# Readout: optional SpaRCe layer followed by a linear output layer.
model = nn.ModuleDict()

if sparce_enabled:
    model["sparce"] = SpaRCe(reservoir_size)

model["out"] = nn.Linear(in_features=reservoir_size, out_features=out_features, bias=True)

model_tensor = load_mode_from_dir(model_dir, 'model.pt')
# strict=False tolerates key mismatches between the checkpoint and the modules built above.
# Report any mismatch instead of dropping it silently, so that an inconsistent
# model_dir / config pairing (e.g. SpaRCe enabled but no SpaRCe weights in the
# checkpoint) shows up in the log rather than as unexplained results.
load_result = model.load_state_dict(model_tensor, strict=False)
if load_result.missing_keys:
    log.warning(f'Parameters missing from checkpoint (left at their initial values): {load_result.missing_keys}')
if load_result.unexpected_keys:
    log.warning(f'Checkpoint entries not used by the model: {load_result.unexpected_keys}')

# Move to device
model_esn.eval().to(device)
model.eval().to(device)

# Load datasets, normalize and process through ESN

esn_descriptors = prepare_esn_datasets(model_esn, config, device, log, eval_only=True)

# The reservoir is not referenced again below; drop it and release cached GPU memory.
del model_esn

torch.cuda.empty_cache()

_, _, val_dataset, val_dataLoader, test_dataLoader, _, eval_gt = prepare_final_datasets(esn_descriptors, config, eval_only=True)

val_dataset_quantiles = None
if sparce_enabled:
    # Column-wise quantile of the absolute values of the first element of every val_dataset item.
    # NOTE(review): this comment used to read "Calculate Training Dataset Quantiles", yet the
    # values are computed from val_dataset - confirm which split is intended.
    quantile = float(config['model_sparce_quantile'])
    val_dataset_quantiles = torch.quantile(torch.abs(torch.vstack([t[0] for t in val_dataset])), quantile, dim=0).to(device)


In [None]:
# Cut-offs passed to run_eval (and to the commented-out compute_recall calls).
n_values = [1, 5, 10, 20, 50, 100]
# NOTE(review): magic number - presumably the number of candidates scored per query
# for this dataset; confirm and consider deriving it from the data.
top_k = 27592

# if options.validation:
#     _, val_predictions = run_eval(model, val_dataLoader, eval_gt, n_values, top_k, device, model_forward=model_forward, sparce_enabled=sparce_enabled, dataset_quantiles=val_dataset_quantiles)
#     val_recalls = compute_recall(eval_gt, val_predictions, len(val_predictions), n_values, print_recall=True, recall_str='Eval on Validation Set')
#     set_summary_props(run, val_recalls, 'best_val_recall')

# Evaluate on the test loader; only the second return value (the predictions) is kept.
_, test_predictions = run_eval(
    model,
    test_dataLoader,
    eval_gt,
    n_values,
    top_k,
    device,
    model_forward=model_forward,
    sparce_enabled=sparce_enabled,
    dataset_quantiles=val_dataset_quantiles,
)
# test_recalls = compute_recall(eval_gt, test_predictions, len(test_predictions), n_values, print_recall=True, recall_str='Eval on Test Set')
# set_summary_props(run, test_recalls, 'best_test_recall')

# if 'patchnetvlad_config_file' in options:
#     patchnetvlad_config = get_config(options.patchnetvlad_config_file, log)
#     train_dataset_info, val_test_dataset_info = get_dataset_infos(config)

#     input_index_local_features_prefix = os.path.join(get_value_from_namespace_or_raise(options, 'index_input_features_dir'), 'patchfeats')

#     if options.validation:
#         input_val_local_features_prefix = os.path.join(get_value_from_namespace_or_raise(options, 'val_input_features_dir'), 'patchfeats')
        
#         reranked_val_predictions = local_matcher(val_predictions, patchnetvlad_config, input_val_local_features_prefix, input_index_local_features_prefix, val_test_dataset_info, train_dataset_info, device)
#         val_patch_recalls = compute_recall(eval_gt, reranked_val_predictions, len(val_predictions), n_values, print_recall=True, recall_str='PatchNetVLAD Eval on Validation Set')
#         set_summary_props(run, val_patch_recalls, 'best_val_patch_recall')    

#     input_test_local_features_prefix = os.path.join(get_value_from_namespace_or_raise(options, 'test_input_features_dir'), 'patchfeats')

#     reranked_test_predictions = local_matcher(test_predictions, patchnetvlad_config, input_test_local_features_prefix, input_index_local_features_prefix, val_test_dataset_info, train_dataset_info, device)
#     test_patch_recalls = compute_recall(eval_gt, reranked_test_predictions, len(test_predictions), n_values, print_recall=True, recall_str='PatchNetVLAD Eval on Test Set')
#     set_summary_props(run, test_patch_recalls, 'best_test_patch_recall')

In [None]:
import numpy as np
import scipy
from sklearn.metrics import precision_recall_curve, auc

In [None]:
# Build the inputs of a precision-recall curve over top-1 retrievals:
#   lr_probs[i] - softmax value of the highest-scoring candidate of query i
#   testy[i]    - whether that candidate is listed in eval_gt for query i
testy = []
lr_probs = []
len_probs = len(test_predictions)

# NOTE(review): assumes test_predictions yields (query index, 1-D score vector) pairs -
# confirm against run_eval.
for qIx, pred in test_predictions:
    # axis=None normalises over all scores of this query at once.
    sm = scipy.special.softmax(pred, axis=None)
    # Index of the best-scoring candidate, computed once and reused below.
    top1 = np.argmax(pred)

    # Alternative confidence: the raw top-1 score instead of its softmax value,
    # i.e. lr_probs.append(pred[top1]).
    lr_probs.append(sm.max())
    # np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
    testy.append(np.isin(top1, eval_gt[qIx]).any())

lr_probs = np.stack(lr_probs)
testy = np.stack(testy)

In [None]:
# Precision/recall pairs over the top-1 confidences; the thresholds are not used.
lr_precision, lr_recall, _ = precision_recall_curve(testy, lr_probs)
# Area under the curve with recall on the x-axis and precision on the y-axis.
AUC = auc(x=lr_recall, y=lr_precision)

In [None]:
# Show the area under the precision-recall curve as this cell's output.
AUC

In [None]:
from sklearn.metrics import PrecisionRecallDisplay
import matplotlib as mpl
# Render figures at a higher resolution (applies to every later figure in this session).
mpl.rcParams['figure.dpi'] = 200
# The result is bound to `pr_display` rather than `display`: assigning to `display`
# would shadow IPython's built-in display() helper for the rest of the session.
pr_display = PrecisionRecallDisplay.from_predictions(testy, lr_probs, name="EchoVPR nordland")
# Assigning to `_` keeps the returned Text object out of the cell output.
_ = pr_display.ax_.set_title("2-class Precision-Recall curve")