In [None]:
%cd ..\src

In [None]:
from os.path import isfile, join

import wandb
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset, Subset
import numpy as np

from configs.utils import get_config, get_bool_from_config
from echovpr.datasets.utils import get_dataset, get_subset_dataset
from echovpr.models.utils import get_sparsity
from echovpr.models.single_esn import SingleESN
from echovpr.models.sparce_layer import SpaRCe

import logging

# Show INFO-level log records in the notebook output (presumably emitted by the project helpers).
logging.basicConfig(level=logging.INFO)

In [None]:
# Training/evaluation configuration. The path is built with os.path.join instead of a
# hard-coded Windows "\\" separator so the notebook also runs on POSIX systems.
config = get_config(join("configs", "train_esn_nordland_full.ini"), log=False)

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# W&B model artifact (path with ":version" alias) whose trained weights are evaluated below.
artifact_name = 'uos_ml/echovpr/esn_9km0ic3z:v0'

In [None]:
# Start a W&B run and download the model artifact into a local directory.
# NOTE(review): no run.finish() is visible in this notebook — the run stays open until the kernel exits.
run = wandb.init()
artifact = run.use_artifact(artifact_name, type='model')
artifact_dir = artifact.download()

# Readout weights (and, for "all in one" checkpoints, the ESN weights as well).
model_file = join(artifact_dir, 'model.pt')
# ESN weights saved separately from the readout; may be absent.
esn_model_file = join(artifact_dir, 'esn_model.pt')

# Without a separate ESN checkpoint, the ESN is expected to be stored inside model.pt.
all_in_one = not isfile(esn_model_file)

In [None]:
# Model hyper-parameters, read from the config file.
in_features=int(config['model_in_features'])
reservoir_size=int(config['model_reservoir_size'])
out_features=int(config['model_out_features'])

esn_alpha = float(config['model_esn_alpha'])
esn_gamma = float(config['model_esn_gamma'])
esn_rho = float(config['model_esn_rho'])
esn_num_connections = int(config['model_esn_num_connections'])
sparce_enabled = get_bool_from_config(config, 'model_sparce_enabled')

# Container for the restored modules; its keys must match those saved in model.pt,
# because load_state_dict is called on it in the next cell.
model = nn.ModuleDict()

# Echo-state reservoir; its sparsity is derived from the configured connection count.
esn_model = SingleESN(
  in_features, 
  reservoir_size, 
  alpha=esn_alpha, 
  gamma=esn_gamma, 
  rho=esn_rho,
  sparsity=get_sparsity(esn_num_connections, reservoir_size),
  device=device
)

# "All in one" checkpoints keep the ESN inside model.pt, so register it under "esn".
if all_in_one:
  model["esn"] = esn_model

# Optional SpaRCe layer, applied to reservoir states before the readout (see eval_esn).
if sparce_enabled:
  model["sparce"] = SpaRCe(reservoir_size)

# Linear readout from reservoir states to out_features scores, ranked with topk in eval_esn.
model["out"] = nn.Linear(in_features=reservoir_size, out_features=out_features, bias=True)

In [None]:
# Restore the trained weights. map_location=device lets checkpoints that were saved from a
# GPU load on a CPU-only machine (device falls back to the CPU above).
# NOTE(review): torch.load unpickles the file — only load artifacts from a trusted W&B project.
if not all_in_one:
  esn_model.load_state_dict(torch.load(esn_model_file, map_location=device))

model.load_state_dict(torch.load(model_file, map_location=device))

In [None]:
# Put the modules in inference mode and move their weights onto the chosen device.
if not all_in_one:
  esn_model.eval()
  esn_model.to(device)

model.eval()
model.to(device)

In [None]:
# Load the summer and winter Nordland hidden-representation datasets named in the config.
summer_dataset = get_dataset(config['dataset_nordland_summer_hidden_repr_file_path'])
winter_dataset = get_dataset(config['dataset_nordland_winter_hidden_repr_file_path'])

# Scale both feature tensors in place by the maximum summer feature value, so both
# traverses share one scale.
# NOTE(review): divide_ mutates the loaded tensors; a zero maximum would yield inf/NaN.
max_n = summer_dataset.tensors[0].max()
_ = summer_dataset.tensors[0].divide_(max_n)  # assigning to `_` just discards the returned tensor
_ = winter_dataset.tensors[0].divide_(max_n)

In [None]:
def process(model, dataLoader, device: torch.device):
    """Run ``model`` over every batch of ``dataLoader`` and stack the results.

    Used here to turn raw features into ESN reservoir states once, so the
    readout can then be evaluated on a cached ``TensorDataset``.

    Args:
        model: Callable (e.g. an ``nn.Module``) applied to each input batch.
        dataLoader: Iterable of ``(x, y_target)`` batch pairs.
        device: Device each input batch is moved to before calling ``model``.

    Returns:
        Tuple ``(outputs, targets)``: the model outputs moved to CPU and the
        unchanged targets, each concatenated with ``torch.vstack``.
    """
    outputs = []
    targets = []

    # Inference only: without no_grad, a call made outside a no_grad block would keep
    # every batch's autograd graph alive and return a tensor that requires grad.
    with torch.no_grad():
        for x, y_target in dataLoader:
            outputs.append(model(x.to(device)).cpu())
            targets.append(y_target)

    return (torch.vstack(outputs), torch.vstack(targets))

In [None]:
def eval_esn(model, dataLoader, sparce_enabled, quantiles, top_k = 100, device = None):
    """Rank the readout outputs for every query in ``dataLoader``.

    Args:
        model: ``nn.ModuleDict`` with an ``"out"`` readout layer and, when
            ``sparce_enabled``, a ``"sparce"`` layer called as ``sparce(x, quantiles)``.
        dataLoader: Iterable of ``(x, y_target)`` batches. ``y_target`` is reduced
            with ``argmax`` over dim 1 to obtain the ground-truth index.
        sparce_enabled: Whether to apply ``model["sparce"]`` before the readout.
        quantiles: Forwarded to the SpaRCe layer; unused when SpaRCe is disabled.
        top_k: Number of highest-scoring output indices kept per query. It is
            clamped to the number of readout outputs, so ``topk`` never fails.
        device: Device the inputs are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none), instead of relying on
            a notebook-global ``device`` variable.

    Returns:
        ``(predictions, ground_truths)``: ``(N, k)`` ranked output indices (best
        first, moved to CPU) and ``(N, 1)`` ground-truth indices.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    predictions = []
    ground_truths = []

    with torch.no_grad():
        for x, y_target in dataLoader:
            x = x.to(device)

            if sparce_enabled:
                x = model["sparce"](x, quantiles)

            preds = model["out"](x)

            # torch.topk raises when k exceeds the number of outputs; rank them all instead.
            k = min(top_k, preds.shape[1])
            _, pred_idx = torch.topk(preds, k, dim=1)

            predictions.append(pred_idx.cpu())
            ground_truths.append(torch.argmax(y_target, dim=1, keepdim=True))

    return (torch.vstack(predictions), torch.vstack(ground_truths))

In [None]:
# Evaluation settings used by get_positives and the recall@N computation below.
dataset_size = len(winter_dataset)  # number of winter (query) frames
dataset_tolerance = 10  # indices within +/-10 of the true index count as a correct match
n_values = [1, 5, 10, 20, 50, 100]  # the N values of recall@N

def get_positives(gt, dataset_tolerance, dataset_size):
    """List the acceptable match indices for each ground-truth index in ``gt``.

    An index is acceptable when it lies within ``dataset_tolerance`` of the true
    index and inside ``[0, dataset_size)``. Elements of ``gt`` must provide
    ``.item()`` (e.g. a 1-D tensor).
    """
    positives = []
    for true_idx in gt:
        centre = true_idx.item()
        lo = max(centre - dataset_tolerance, 0)
        hi = min(centre + dataset_tolerance + 1, dataset_size)
        positives.append(list(range(lo, hi)))
    return positives

In [None]:
def p(starting_point, sim_length, esn_model, model, sparce_enabled, winter_dataset, gt, n_values, device):
    """Compute per-query recall@N hits for one contiguous window of winter queries.

    The window ``winter_dataset[starting_point : starting_point + sim_length]`` is
    passed through ``esn_model`` (via ``process``) and then ranked by the readout
    in ``model`` (via ``eval_esn``). Query ``q`` of the window is a hit at ``N``
    when any of its top-``N`` predicted indices is in ``gt[starting_point + q]``.

    Reads ``dataloader_threads``, ``train_batchsize`` and, when ``sparce_enabled``,
    ``model_sparce_quantile`` from the notebook-global ``config``.

    Args:
        starting_point: Index of the first winter frame in the window.
        sim_length: Number of consecutive frames in the window.
        esn_model: Reservoir model applied to the raw features.
        model: ``nn.ModuleDict`` with the readout (and optional SpaRCe layer).
        sparce_enabled: Whether to apply the SpaRCe layer before the readout.
        winter_dataset: Full winter dataset of ``(features, target)`` pairs.
        gt: Per-frame lists of acceptable match indices (see ``get_positives``).
        n_values: Increasing list of N values to evaluate.
        device: Device the ESN inputs are moved to.

    Returns:
        ``np.ndarray`` of shape ``(window_len, len(n_values))``. Entry ``[q, i]`` is
        1 when query ``q`` is a hit within the top ``n_values[i]``, otherwise 0.
    """
    num_workers = int(config['dataloader_threads'])
    batch_size = int(config['train_batchsize'])

    with torch.no_grad():
        # Window of consecutive winter queries beginning at starting_point.
        winter_dataset_subset = Subset(winter_dataset, range(starting_point, starting_point + sim_length))

        print(f"Winter dataset size: {len(winter_dataset_subset)}")

        winter_dataLoader = DataLoader(winter_dataset_subset, num_workers=num_workers, batch_size=batch_size, shuffle=False)

        # Run the reservoir once and cache its states alongside the original targets.
        winter_esn_dataset = TensorDataset(*process(esn_model, winter_dataLoader, device))

        winter_dataset_quantiles = None

        if sparce_enabled:
            # NOTE(review): these quantiles are computed from the winter window being
            # evaluated, not from a training set — confirm this is intended.
            quantile = float(config['model_sparce_quantile'])
            states = winter_esn_dataset.tensors[0]
            # Same rows as re-stacking every sample with vstack, without a Python-level loop.
            states = states.reshape(-1, states.shape[-1])
            winter_dataset_quantiles = torch.quantile(torch.abs(states), quantile, dim=0).to(device)

        dataLoader = DataLoader(winter_esn_dataset, num_workers=num_workers, batch_size=batch_size, shuffle=False)
        predictions, _ = eval_esn(model, dataLoader, sparce_enabled, winter_dataset_quantiles)

        correct_at_n = np.zeros((len(predictions), len(n_values)))

        for qIx, pred in enumerate(predictions):
            positives = gt[starting_point + qIx]
            for i, n in enumerate(n_values):
                # A hit in the top N is also a hit for every larger N, so fill the rest of the row.
                # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
                if np.isin(pred[:n].numpy(), positives).any():
                    correct_at_n[qIx, i:] += 1
                    break

        del winter_dataset_subset
        del winter_dataLoader
        del winter_esn_dataset
        del dataLoader

        torch.cuda.empty_cache()

        return correct_at_n

In [None]:
# Release PyTorch's cached, unused CUDA memory before running the evaluation sweep.
torch.cuda.empty_cache()

In [None]:
# For every winter frame: all indices within dataset_tolerance of the argmax of its target.
gt = get_positives(torch.argmax(winter_dataset.tensors[1], dim=1), dataset_tolerance, dataset_size)

# Single-window example (sim_length is defined in a later cell):
# correct_at_n = p(0, sim_length, esn_model, model, sparce_enabled, winter_dataset, gt, n_values, device)

In [None]:
# 500 evenly spaced window start indices over the winter traverse.
# NOTE(review): the 10000-frame tail margin is far larger than sim_length (500); it looks like
# a leftover from longer windows (cf. the [-10000:] slices at the end of the notebook) — confirm.
starting_positions = np.linspace(0, dataset_size - 10000, 500).astype(int)
sim_length = 500  # number of consecutive queries evaluated per window

In [None]:
# Per-window correct_at_n matrices collected by the sweep below.
lists = []

In [None]:
# Evaluate every window and keep its per-query hit matrix.
# NOTE(review): [160:] presumably resumes a partially finished sweep (500 - 160 = 340 windows,
# matching the "_340" file name used when saving); the save cell must use the same slice.
for starting_point in starting_positions[160:]:
    print(f"Starting point: {starting_point}")
    correct_at_n = p(starting_point, sim_length, esn_model, model, sparce_enabled, winter_dataset, gt, n_values, device)
    lists.append(correct_at_n)

In [None]:
# Save the evaluated window starts, the settings and the per-window hit matrices.
# NOTE(review): assumes results/ exists and every entry of `lists` has the same shape
# (so np.savez stores correct_at_n as a single stacked array).
np.savez_compressed('results/multiple_startingpoints_340.npz', starting_positions=starting_positions[160:], sim_length=sim_length, n_values=n_values, correct_at_n=lists)

In [None]:
import numpy as np

# Load the stacked per-window hit matrices. Using the NpzFile as a context manager closes
# the underlying file handle once the array has been read into memory.
# NOTE(review): this reads "..._full.npz", not the "..._340.npz" written above — presumably
# produced by a separate or merged run; confirm this is the intended file.
with np.load('results/multiple_startingpoints_full.npz') as results:
    correct_at_n_list = results['correct_at_n']

In [None]:
# Sum the per-window hit matrices element-wise. Summing over axis 0 takes the result shape
# from the loaded data rather than from the sim_length / n_values globals of earlier cells,
# which need not match the file loaded above.
final_correct_at_n = np.sum(np.asarray(correct_at_n_list, dtype=float), axis=0)

In [None]:
# Mean hit rate per query position within a window (rows) and per N (columns), over all windows.
f_c = final_correct_at_n / len(correct_at_n_list)

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Recall@N as a function of the query's position within its window, one curve per N.
fig, ax = plt.subplots()
ax.plot(f_c)
ax.set(title="Recall@N vs. position in window (mean over windows)", xlabel="Query index within window", ylabel="Recall")
ax.legend([f"N={n}" for n in n_values]);

In [None]:
# Moving average of each recall@N curve, to smooth out per-query noise.
smooth_window = 20
fig, ax = plt.subplots()
for i, n in enumerate(n_values):
    ax.plot(np.convolve(f_c[:, i], np.ones(smooth_window) / smooth_window, mode='valid'), label=f"N={n}")
ax.set(title=f"Recall@N ({smooth_window}-query moving average)", xlabel="Query index within window", ylabel="Recall")
ax.legend();

In [None]:
# Smoothed recall for column 5 of f_c (N = n_values[5]) over the first 200 queries of each window.
fig, ax = plt.subplots()
ax.plot(np.convolve(f_c[:200, 5], np.ones(20) / 20, mode='valid'))
ax.set(title=f"Recall@{n_values[5]} (20-query moving average), first 200 queries", xlabel="Query index within window", ylabel="Recall");

In [None]:
# NOTE(review): re-assigns n_values with the same list already defined earlier; redundant.
n_values = [1, 5, 10, 20, 50, 100]

In [None]:
# Smoothed recall@1 (column 0) for the fourth evaluated window.
# NOTE(review): each window has sim_length (500) rows, so [:1000] keeps the whole matrix;
# likely left over from longer windows.
plt.plot(np.convolve(lists[3][:1000, 0], np.ones(20)/20, mode='valid'))

In [None]:
# Mean recall@1 (column 0) for the first five evaluated windows.
# [-10000:] keeps only the last 10000 queries, i.e. the whole window when it is shorter.
for run_hits in lists[:4]:
    print(np.mean(run_hits[-10000:, 0]))
print(np.mean(lists[4][:, 0]))