In [None]:
# Work from the repository's src directory so the `configs` / `echovpr` imports and the
# relative paths used below (configs\..., checkpoints\...) resolve.
# NOTE(review): backslash separators assume Windows, like the rest of this notebook.
%cd ..\src

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset
from torchmetrics import MetricCollection

from configs.utils import get_config, get_int_from_config, get_float_from_config, get_bool_from_config
from echovpr.datasets.utils import get_dataset, get_subset_dataset, save_tensor
from echovpr.models.utils import get_sparsity
from echovpr.models.single_esn import SingleESN
from echovpr.models.hier_esn import HierESN
from echovpr.models.sparce_layer import SpaRCe
from echovpr.trainer.metrics.recall_top_k_metric import RecallTopKMetric

import logging

# Emit INFO-level and higher log records (basicConfig's default handler writes to stderr).
logging.basicConfig(level=logging.INFO)

In [None]:
# Experiment configuration. Swap the active path for the commented one to use the 1k Nordland config.
CONFIG_PATH = "configs\\train_esn_nordland_full.ini"
# CONFIG_PATH = "configs\\train_esn_nordland_1k.ini"
config = get_config(CONFIG_PATH)

In [None]:
# Prefer the GPU when one is visible; everything below moves tensors/modules to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the echo-state network (single or hierarchical) described by the config.

in_features = int(config['model_in_features'])
out_features = int(config['model_out_features'])
esn_num_connections = int(config['model_esn_num_connections'])
esn_hier = get_bool_from_config(config, 'model_esn_hier', False)
sparce_enabled = get_bool_from_config(config, 'model_sparce_enabled')

if esn_hier:
    # Two reservoirs, each with its own size and (alpha, gamma, rho) hyper-parameters.
    reservoir1_size = int(config['model_reservoir1_size'])
    reservoir2_size = int(config['model_reservoir2_size'])

    esn1_alpha = float(config['model_esn1_alpha'])
    esn1_gamma = float(config['model_esn1_gamma'])
    esn1_rho = float(config['model_esn1_rho'])

    esn2_alpha = float(config['model_esn2_alpha'])
    esn2_gamma = float(config['model_esn2_gamma'])
    esn2_rho = float(config['model_esn2_rho'])

    # Sparsity is derived from the target number of connections per reservoir size.
    esn1_sparsity = get_sparsity(esn_num_connections, reservoir1_size)
    esn2_sparsity = get_sparsity(esn_num_connections, reservoir2_size)

    model_esn = HierESN(
        in_features,
        nReservoir1=reservoir1_size,
        nReservoir2=reservoir2_size,
        alpha1=esn1_alpha,
        alpha2=esn2_alpha,
        gamma1=esn1_gamma,
        gamma2=esn2_gamma,
        rho1=esn1_rho,
        rho2=esn2_rho,
        sparsity1=esn1_sparsity,
        sparsity2=esn2_sparsity,
        device=device,
    )

    # The readout sees the states of both reservoirs side by side.
    reservoir_output_size = reservoir1_size + reservoir2_size
else:
    reservoir_size = int(config['model_reservoir_size'])
    esn_alpha = float(config['model_esn_alpha'])
    esn_gamma = float(config['model_esn_gamma'])
    esn_rho = float(config['model_esn_rho'])

    esn_sparsity = get_sparsity(esn_num_connections, reservoir_size)

    model_esn = SingleESN(
        in_features,
        reservoir_size,
        alpha=esn_alpha,
        gamma=esn_gamma,
        rho=esn_rho,
        sparsity=esn_sparsity,
        device=device,
    )

    reservoir_output_size = reservoir_size

model_esn.to(device)

In [None]:
# Load the precomputed hidden representations for both seasons.
summer_dataset = get_dataset(config['dataset_nordland_summer_hidden_repr_file_path'])
winter_dataset = get_dataset(config['dataset_nordland_winter_hidden_repr_file_path'])

# Scale both seasons' inputs in place by the maximum of the summer inputs, so they share one scale.
# Output is suppressed with a trailing ';' rather than by binding `_`: a user-assigned `_` is no
# longer overwritten by IPython and would keep the raw winter tensor alive after the dataset is deleted.
max_n = summer_dataset.tensors[0].max()
summer_dataset.tensors[0].divide_(max_n)
winter_dataset.tensors[0].divide_(max_n);

In [None]:
def process(model, dataLoader, device: torch.device):
    """Run ``model`` over every batch of ``dataLoader`` and collect the outputs and targets.

    Used here to turn the raw hidden-representation datasets into ESN features once, up front.
    Gradient tracking is disabled: the outputs are stored as fixed features, so recording an
    autograd graph per batch would only keep memory alive (and leak that graph into training).

    Args:
        model: callable applied to each input batch (the ESN in this notebook).
        dataLoader: iterable of ``(x, y_target)`` batches, consumed in order.
        device: device each input batch is moved to before calling ``model``.

    Returns:
        ``(x_processed, y_target)``: the per-batch results stacked with ``torch.vstack``.
        Model outputs are moved to CPU; targets are kept as yielded by the loader.
    """
    x_processed_list = []
    y_target_list = []

    with torch.no_grad():
        for x, y_target in dataLoader:
            x_processed = model(x.to(device))

            x_processed_list.append(x_processed.cpu())
            y_target_list.append(y_target)

    return torch.vstack(x_processed_list), torch.vstack(y_target_list)

In [None]:
# Loaders used to push each season through the ESN in a fixed order (no shuffling).
print(f"Summer dataset size: {len(summer_dataset)}")
extraction_loader_kwargs = {
    'num_workers': int(config['dataloader_threads']),
    'batch_size': int(config['train_batchsize']),
    'shuffle': False,
}
summer_dataLoader = DataLoader(summer_dataset, **extraction_loader_kwargs)

print(f"Winter dataset size: {len(winter_dataset)}")
winter_dataLoader = DataLoader(winter_dataset, **extraction_loader_kwargs)

In [None]:
# Run the ESN over both seasons once and keep the resulting features on CPU.
summer_x, summer_y = process(model_esn, summer_dataLoader, device)
winter_x, winter_y = process(model_esn, winter_dataLoader, device)

nordland_summer_repr_x_cpu, nordland_summer_repr_y_target_cpu = summer_x.cpu(), summer_y.cpu()
nordland_winter_repr_x_cpu, nordland_winter_repr_y_target_cpu = winter_x.cpu(), winter_y.cpu()

# Drop temporaries plus the raw-representation datasets/loaders so their memory can be reclaimed.
del summer_x, summer_y, winter_x, winter_y
del summer_dataset, summer_dataLoader, winter_dataset, winter_dataLoader

torch.cuda.empty_cache()

# From here on the season datasets hold ESN features instead of the raw representations.
summer_dataset = TensorDataset(nordland_summer_repr_x_cpu, nordland_summer_repr_y_target_cpu)
winter_dataset = TensorDataset(nordland_winter_repr_x_cpu, nordland_winter_repr_y_target_cpu)

In [None]:
# Prepare Datasets

# Prepare Datasets: train on summer ESN features; validate/test on winter subsets chosen by the
# configured indices files.


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the configured worker count and batch size."""
    return DataLoader(
        dataset,
        num_workers=int(config['dataloader_threads']),
        batch_size=int(config['train_batchsize']),
        shuffle=shuffle,
    )


train_dataset = summer_dataset
print(f"Train dataset size: {len(train_dataset)}")
train_dataLoader = _make_loader(train_dataset, shuffle=True)

val_dataset = get_subset_dataset(winter_dataset, config['dataset_nordland_winter_val_limit_indices_file_path'])
print(f"Validation dataset size: {len(val_dataset)}")
val_dataLoader = _make_loader(val_dataset, shuffle=False)

test_dataset = get_subset_dataset(winter_dataset, config['dataset_nordland_winter_test_limit_indices_file_path'])
print(f"Test dataset size: {len(test_dataset)}")
test_dataLoader = _make_loader(test_dataset, shuffle=False)

In [None]:
# Trainable readout on the ESN features: an optional SpaRCe layer followed by a linear layer.
# Entries are listed in the order they are applied ("sparce" first, then "out").
readout_layers = []

if sparce_enabled:
    readout_layers.append(("sparce", SpaRCe(reservoir_output_size)))

readout_layers.append(("out", nn.Linear(reservoir_output_size, out_features, bias=True)))

model = nn.ModuleDict(readout_layers)

model.to(device)

In [None]:
def _abs_quantile(features, q):
    """Per-column `q`-quantile of |features| taken over dim 0 (samples), placed on `device`."""
    return torch.quantile(features.abs(), q, dim=0).to(device)


# Quantiles handed to the SpaRCe layer; left as None when SpaRCe is disabled.
train_dataset_quantiles = val_dataset_quantiles = None

if sparce_enabled:
    quantile = float(config['model_sparce_quantile'])
    train_dataset_quantiles = _abs_quantile(train_dataset.tensors[0], quantile)
    # val_dataset is accessed sample by sample, so gather its inputs into one tensor first.
    val_dataset_quantiles = _abs_quantile(torch.vstack([sample[0] for sample in val_dataset]), quantile)

In [None]:
# Build the optimizer for the trainable readout. SpaRCe parameters (if enabled) get a reduced lr.
optimizer_params = []

lr = float(config['train_lr'])

if sparce_enabled:
    lr_sparce = lr / get_int_from_config(config, 'train_lr_sparce_divide_by', 1000)
    optimizer_params.append({'params': model["sparce"].parameters(), 'lr': lr_sparce})

# The linear readout uses the optimizer-level default lr.
optimizer_params.append({'params': model["out"].parameters()})

# Always (re)bind `scheduler` so a scheduler left over from an earlier SGD run of this cell
# (bound to a previous optimizer) cannot survive a switch to Adam; None means "no LR schedule".
scheduler = None

if config['train_optimizer'] == 'SGD':
    optimizer = torch.optim.SGD(optimizer_params, lr=lr, momentum=float(config['train_momentum']), weight_decay=float(config['train_weight_decay']))
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(config['train_lr_step']), gamma=float(config['train_lr_gamma']))
else:
    # NOTE(review): the Adam path does not read train_momentum / train_weight_decay — confirm intended.
    optimizer = torch.optim.Adam(optimizer_params, lr=lr)

In [None]:
# Binary cross-entropy on raw logits against the target vectors, averaged over all elements.
criterion = nn.BCEWithLogitsLoss(reduction='mean')
criterion = criterion.to(device)

In [None]:
def eval(model, dataLoader, metrics, sparce_enabled, quantiles, device=None):
    """Evaluate the readout over ``dataLoader`` and return the computed metrics.

    The name shadows Python's builtin ``eval`` inside this notebook; it is kept so existing
    callers keep working.

    Args:
        model: ``nn.ModuleDict`` with an ``"out"`` layer and, when ``sparce_enabled``, a
            ``"sparce"`` layer called as ``model["sparce"](x, quantiles)``.
        dataLoader: iterable of ``(x, y_target)`` batches.
        metrics: metric collection exposing ``reset()``, ``update(preds, target)`` and
            ``compute()`` (e.g. a torchmetrics ``MetricCollection``). It is reset first, so
            results never mix with a previous evaluation.
        sparce_enabled: apply the SpaRCe layer before the linear readout.
        quantiles: forwarded to the SpaRCe layer; unused when ``sparce_enabled`` is False.
        device: device batches are moved to; defaults to the device of the model's parameters.

    Returns:
        Whatever ``metrics.compute()`` returns. For a ``MetricCollection`` created with
        ``prefix='val_'`` this is a dict such as ``{'val_recall@1': tensor(...), ...}``.
    """
    if device is None:
        device = next(model.parameters()).device

    metrics.reset()

    with torch.no_grad():
        for x, y_target in dataLoader:
            x = x.to(device)
            y_target = y_target.to(device)

            if sparce_enabled:
                x = model["sparce"](x, quantiles)

            preds = model["out"](x)

            # Same convention as the training loop: raw logits plus integer targets.
            metrics.update(preds, y_target.int())

    return metrics.compute()

In [None]:
ds_tolerance = get_int_from_config(config, 'dataset_tolerance', 10)

# Recall@k for each k below; val/test collections are clones with their own key prefix.
recall_top_ks = (1, 5, 10, 20, 50, 100)

train_metrics = MetricCollection(
    {f'recall@{k}': RecallTopKMetric(top_k=k, tolerance=ds_tolerance) for k in recall_top_ks},
    prefix='train_',
).to(device)
val_metrics = train_metrics.clone(prefix='val_').to(device)
test_metrics = train_metrics.clone(prefix='test_').to(device)

best_val_recall_at_1 = 0
save_best_checkpoint = True

# Tag used in the checkpoint filename.
# NOTE(review): hard-coded; presumably encodes this run's settings — keep in sync with the config.
run_id = '8000_0.0005_1'

num_epochs = 40

for epoch in range(num_epochs):

    train_metrics.reset()
    epoch_loss_sum = 0.0
    num_batches = 0

    for x, y_target in train_dataLoader:

        x = x.to(device)
        y_target = y_target.to(device)

        optimizer.zero_grad()

        if sparce_enabled:
            x = model["sparce"](x, train_dataset_quantiles)

        y = model["out"](x)

        loss = criterion(y, y_target)

        loss.backward()
        optimizer.step()

        # Accumulate on-device so there is no host sync on every batch.
        epoch_loss_sum += loss.detach()
        num_batches += 1

        # Metrics only need the logits' values; detach so no autograd graph is retained by them.
        train_metrics.update(y.detach(), y_target.int())

    if config['train_optimizer'] == 'SGD':
        scheduler.step()

    # Report the mean loss over the epoch rather than only the last batch's loss.
    epoch_loss = float(epoch_loss_sum) / num_batches if num_batches else float('nan')
    print(f"Epoch: {epoch}, Loss: {epoch_loss}, Train Metrics: {train_metrics.compute()}")

    with torch.no_grad():
        val_metrics_dic = eval(model, val_dataLoader, val_metrics, sparce_enabled, val_dataset_quantiles)
        print(f"Epoch: {epoch}, Val Metric: {val_metrics_dic}")

        current_val_recall_at_1 = float(val_metrics_dic['val_recall@1'])

        is_better = current_val_recall_at_1 > best_val_recall_at_1

        if is_better:
            # Remember the new best so test evaluation and checkpointing only happen on genuine
            # improvements instead of on every epoch with non-zero recall.
            best_val_recall_at_1 = current_val_recall_at_1

            # NOTE(review): the test split is evaluated with the *validation* quantiles — confirm intended.
            test_metrics_dic = eval(model, test_dataLoader, test_metrics, sparce_enabled, val_dataset_quantiles)
            print(f"Epoch: {epoch}, Test Metric: {test_metrics_dic}")

            if save_best_checkpoint:
                save_tensor(model.state_dict(), f'checkpoints\\checkpoint_{run_id}.pt')