In [None]:
%cd ..\src

In [None]:
from os.path import isfile, join

import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
import numpy as np

from configs.utils import get_config, get_bool_from_config
from echovpr.datasets.utils import get_dataset, get_subset_dataset
from echovpr.models.utils import get_sparsity
from echovpr.models.single_esn import SingleESN
from echovpr.models.sparce_layer import SpaRCe

import logging

logging.basicConfig(level=logging.INFO)

In [None]:
# Build the config path with os.path.join so the separator matches the host OS
# (the literal backslash form only resolved on Windows).
config = get_config(join("configs", "train_esn_nordland_full_sweep.ini"), log=False)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Identifier of the W&B artifact (requested below with type='model') that holds the
# trained checkpoint(s) evaluated by this notebook.
# NOTE(review): looks like the usual "entity/project/name:version" form - confirm.
artifact_name = 'uos_ml/echovpr/esn_4wkrv7z1:v0'

In [None]:
# Start a W&B run and download the model artifact through it.
run = wandb.init()
artifact = run.use_artifact(artifact_name, type='model')
artifact_dir = artifact.download()

# Checkpoint files expected inside the downloaded artifact directory.
esn_model_file = join(artifact_dir, 'esn_model.pt')
model_file = join(artifact_dir, 'model.pt')

# When there is no separate ESN checkpoint, the ESN is treated as part of
# 'model.pt' (it is registered inside `model` and loaded together with it below).
all_in_one = not isfile(esn_model_file)

In [None]:
# Layer sizes read from the experiment configuration.
in_features = int(config['model_in_features'])
reservoir_size = int(config['model_reservoir_size'])
out_features = int(config['model_out_features'])

# Echo-state-network hyper-parameters and the SpaRCe switch.
esn_alpha = float(config['model_esn_alpha'])
esn_gamma = float(config['model_esn_gamma'])
esn_rho = float(config['model_esn_rho'])
esn_num_connections = int(config['model_esn_num_connections'])
sparce_enabled = get_bool_from_config(config, 'model_sparce_enabled')

# The reservoir is always constructed; whether it also lives inside `model`
# depends on how the checkpoint was saved (see `all_in_one`).
esn_model = SingleESN(
  in_features,
  reservoir_size,
  alpha=esn_alpha, gamma=esn_gamma, rho=esn_rho,
  sparsity=get_sparsity(esn_num_connections, reservoir_size),
  device=device,
)

# Assemble the remaining layers under the keys used when loading the state dict:
# optional "esn", optional "sparce", and the linear read-out "out".
model = nn.ModuleDict()
if all_in_one:
  model["esn"] = esn_model
if sparce_enabled:
  model["sparce"] = SpaRCe(reservoir_size)
model["out"] = nn.Linear(reservoir_size, out_features, bias=True)

In [None]:
# NOTE(review): torch.load unpickles the file, so only load checkpoints that come
# from a trusted artifact.
# map_location remaps the stored tensors onto `device`, so a checkpoint saved on a
# GPU can also be restored when this notebook falls back to the CPU.
if not all_in_one:
  # The ESN was saved separately from the rest of the model.
  esn_model.load_state_dict(torch.load(esn_model_file, map_location=device))

model.load_state_dict(torch.load(model_file, map_location=device))

In [None]:
# Put the networks into evaluation mode and move them onto the selected device.
# The ESN is only handled separately when it is not registered inside `model`.
if not all_in_one:
  esn_model.eval().to(device)
  
# Last expression of the cell, so the notebook also displays the result.
model.eval().to(device)

In [None]:
# Load the stored hidden representations for the summer and winter traversals.
summer_dataset = get_dataset(config['dataset_nordland_summer_hidden_repr_file_path'])
winter_dataset = get_dataset(config['dataset_nordland_winter_hidden_repr_file_path'])

# Scale both datasets in place by the largest value found in the summer inputs.
# The results are bound to `_` only to keep the cell from echoing the tensors.
max_n = summer_dataset.tensors[0].max()
_ = summer_dataset.tensors[0].div_(max_n)
_ = winter_dataset.tensors[0].div_(max_n)

In [None]:
def process(model, dataLoader, device: torch.device):
    """Run ``model`` over every batch of ``dataLoader`` and collect the results.

    Args:
        model: Callable applied to each input batch.
        dataLoader: Iterable yielding ``(x, y_target)`` batches.
        device: Device each input batch is moved to before calling ``model``.

    Returns:
        Tuple ``(outputs, targets)``: the per-batch model outputs (moved to the
        CPU) and the per-batch targets, each stacked with ``torch.vstack``.
    """
    x_processed_list = []
    y_target_list = []

    # Pure inference: without no_grad every stored output would keep its autograd
    # graph alive for the whole dataset.
    with torch.no_grad():
        for x, y_target in dataLoader:
            x = x.to(device)
            x_processed = model(x)

            x_processed_list.append(x_processed.cpu())
            y_target_list.append(y_target)

    return (torch.vstack(x_processed_list), torch.vstack(y_target_list))

In [None]:
print(f"Winter dataset size: {len(winter_dataset)}")

# Sequential (unshuffled) loader so the outputs stay aligned with the dataset order.
winter_dataLoader = DataLoader(
    winter_dataset,
    num_workers=int(config['dataloader_threads']),
    batch_size=int(config['train_batchsize']),
    shuffle=False,
)

In [None]:
# Replace the raw winter inputs with the ESN outputs computed for them; the
# targets are carried over unchanged by `process`.
winter_dataset = TensorDataset(
    *process(model=esn_model, dataLoader=winter_dataLoader, device=device)
)
dataset_size = len(winter_dataset)

In [None]:
# Prepare the validation and test subsets of the processed winter data.

# Both loaders share the same settings and keep the dataset order.
loader_kwargs = dict(
    num_workers=int(config['dataloader_threads']),
    batch_size=int(config['train_batchsize']),
    shuffle=False,
)

val_dataset = get_subset_dataset(winter_dataset, config['dataset_nordland_winter_val_limit_indices_file_path'])
print(f"Validation dataset size: {len(val_dataset)}")
val_dataLoader = DataLoader(val_dataset, **loader_kwargs)

test_dataset = get_subset_dataset(winter_dataset, config['dataset_nordland_winter_test_limit_indices_file_path'])
print(f"Test dataset size: {len(test_dataset)}")
test_dataLoader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
# Stays None when SpaRCe is disabled.
val_dataset_quantiles = None

if sparce_enabled:
    # Quantile of the absolute inputs, taken along dim 0 over every sample of the
    # *validation* subset, then moved to the evaluation device.
    # NOTE(review): computed from the validation split, not the training split - confirm intended.
    quantile = float(config['model_sparce_quantile'])
    val_dataset_quantiles = torch.quantile(
        torch.vstack([sample[0] for sample in val_dataset]).abs(),
        quantile,
        dim=0,
    ).to(device)

In [None]:
def get_predictions(model, dataLoader, sparce_enabled, quantiles, device=None):
    """Apply the (optional) SpaRCe layer and the read-out layer to every batch.

    Args:
        model: Mapping with an ``"out"`` entry and, when ``sparce_enabled`` is
            true, a ``"sparce"`` entry.
        dataLoader: Iterable yielding ``(x, y_target)`` batches.
        sparce_enabled: Whether ``model["sparce"](x, quantiles)`` is applied
            before the read-out.
        quantiles: Second argument passed to ``model["sparce"]``; unused when
            ``sparce_enabled`` is false.
        device: Device the input batches are moved to. When omitted it is taken
            from the first parameter of ``model["out"]`` (CPU if that module has
            no parameters), so the function no longer depends on a notebook
            global named ``device``.

    Returns:
        Tuple ``(predictions, ground_truths)``: the read-out outputs (moved to
        the CPU) and the targets, each stacked with ``torch.vstack``.
    """
    if device is None:
        # Inputs must live where the read-out weights live.
        first_param = next(iter(model["out"].parameters()), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    predictions = []
    ground_truths = []

    with torch.no_grad():
        for x, y_target in dataLoader:
            x = x.to(device)

            if sparce_enabled:
                x = model["sparce"](x, quantiles)

            preds = model["out"](x)

            predictions.append(preds.cpu())
            ground_truths.append(y_target)

    return (torch.vstack(predictions), torch.vstack(ground_truths))

In [None]:
# Read-out predictions for both splits; note that the test split is also passed
# the quantiles that were computed from the validation subset.
val_preds = TensorDataset(*get_predictions(
    model, val_dataLoader, sparce_enabled=sparce_enabled, quantiles=val_dataset_quantiles))
test_preds = TensorDataset(*get_predictions(
    model, test_dataLoader, sparce_enabled=sparce_enabled, quantiles=val_dataset_quantiles))

In [None]:
def eval_top_k(predictions, top_k = 100):
    """Reduce each ``(scores, target)`` pair to top-k indices and a target index.

    Args:
        predictions: Iterable of ``(preds, y_target)`` pairs, one per sample.
        top_k: Number of highest-scoring indices kept per sample.

    Returns:
        Tuple ``(top_k_predictions, ground_truths)`` of stacked tensors: the
        indices returned by ``torch.topk`` for every sample, and the argmax
        index of every target.
    """
    per_sample = [
        (torch.topk(scores, top_k)[1].cpu(), torch.argmax(target, keepdim=True))
        for scores, target in predictions
    ]

    top_k_stack = torch.vstack([top_idx for top_idx, _ in per_sample])
    ground_truth_stack = torch.vstack([gt_idx for _, gt_idx in per_sample])
    return (top_k_stack, ground_truth_stack)

In [None]:
# Top-k indices and ground-truth indices for the validation and test predictions.
val_top_k_preds, test_top_k_preds = [
    TensorDataset(*eval_top_k(split_preds)) for split_preds in (val_preds, test_preds)
]

In [None]:
# Memory cleanup: drop the notebook's references to the model, datasets and
# loaders that are not used by the remaining cells.
del (
    model,
    val_dataset,
    val_dataLoader,
    test_dataset,
    test_dataLoader,
    winter_dataset,
    winter_dataLoader,
)

# Ask PyTorch to release the cached GPU memory it no longer needs.
torch.cuda.empty_cache()

In [None]:
def compute_recall(gt, predictions, numQ, n_values, recall_str=''):
    """Compute and print recall@N for a set of queries.

    Args:
        gt: Indexable by query position; ``gt[q]`` holds the indices accepted as
            correct for query ``q``.
        predictions: Iterable with one sequence of candidate indices per query;
            the first ``n`` entries are treated as the top-``n`` candidates.
        numQ: Number of queries used as the denominator of the recall.
        n_values: The ``N`` cut-offs to evaluate.
        recall_str: Label inserted into the printed lines.

    Returns:
        Dict mapping each ``N`` in ``n_values`` to its recall value.
    """
    correct_at_n = np.zeros(len(n_values))
    for qIx, pred in enumerate(predictions):
        for i, n in enumerate(n_values):
            # if in top N then also in top NN, where NN > N
            # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
            if np.any(np.isin(pred[:n], gt[qIx])):
                correct_at_n[i:] += 1
                break
    recall_at_n = correct_at_n / numQ
    all_recalls = {}  # make dict for output
    for i, n in enumerate(n_values):
        all_recalls[n] = recall_at_n[i]
        print("====> Recall {}@{}: {:.4f}".format(recall_str, n, recall_at_n[i]))
    return all_recalls

In [None]:
# NOTE(review): this hard-coded value replaces the dataset_size computed earlier
# from len(winter_dataset) - confirm the two are meant to agree.
dataset_size = 27592
# Number of neighbouring indices on each side of the ground truth that still count as a match.
dataset_tolerance = 10

def get_positives(gt, dataset_tolerance, dataset_size):
    """List, for every ground-truth index, the indices accepted as correct.

    For each element ``i`` of ``gt`` (read with ``.item()``) the result holds all
    integers from ``i - dataset_tolerance`` to ``i + dataset_tolerance`` that lie
    inside ``[0, dataset_size)``.

    Returns:
        A list with one list of ints per element of ``gt``.
    """
    positives = []
    for gt_index in gt:
        centre = gt_index.item()
        # Clamp the tolerance window to the valid index range.
        first = max(centre - dataset_tolerance, 0)
        last = min(centre + dataset_tolerance, dataset_size - 1)
        positives.append(list(range(first, last + 1)))
    return positives

In [None]:
# Recall cut-offs to report.
n_values = [1, 5, 10, 20, 50, 100]

# Validation split: tensors[1] holds the ground-truth indices, tensors[0] the top-k indices.
gt_val = get_positives(val_top_k_preds.tensors[1], dataset_tolerance, dataset_size)
val_recalls = compute_recall(
    gt=gt_val,
    predictions=val_top_k_preds.tensors[0],
    numQ=len(val_top_k_preds),
    n_values=n_values,
    recall_str='for Val EchoVPR',
)

# Test split, evaluated the same way.
gt_test = get_positives(test_top_k_preds.tensors[1], dataset_tolerance, dataset_size)
test_recalls = compute_recall(
    gt=gt_test,
    predictions=test_top_k_preds.tensors[0],
    numQ=len(test_top_k_preds),
    n_values=n_values,
    recall_str='for Test EchoVPR',
)