In [None]:
%cd ..\src
!python setup.py develop

In [None]:
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import TensorDataset, Subset

from configs.utils import get_config_wandb, get_int_from_config
from echovpr.trainer.metrics.recall import compute_recall
from echovpr.datasets.utils import load_np_file, get_1_hot_encode
from echovpr.trainer.eval import eval

import wandb
import os
import logging

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Export the notebook path through the WANDB_NOTEBOOK_NAME environment variable
# before logging in to Weights & Biases.
os.environ.update(WANDB_NOTEBOOK_NAME="notebooks/train_oxford_hidden_layer.ipynb")
wandb.login()

# Root logger at INFO level; `log` is the logger passed to get_config_wandb.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

In [None]:
# Load the training configuration and obtain the run handle for the
# "echovpr_oxford_hl" project. The path is assembled with os.path.join so the
# separator is correct on every platform (a literal backslash only works on Windows).
run, config = get_config_wandb(os.path.join("configs", "train_mlp_oxford.ini"), log, project="echovpr_oxford_hl")

In [None]:
# Prepare Datasets
# Loader settings shared by the train, validation and test loaders; read once.
loader_kwargs = {
    'num_workers': int(config['dataloader_threads']),
    'batch_size': int(config['train_batchsize']),
}

# --- Training data (day dataset) ---
day_dataset_info = load_np_file(config['dataset_oxford_day_dataset_file_path'])

day_gt = day_dataset_info['ground_truth_indices']

day_image_idx = torch.from_numpy(day_dataset_info['image_indices'])
# One-hot targets cast to float for the BCE-with-logits loss used in training.
# NOTE(review): the second argument is len(image_indices) -- presumably the number
# of classes; confirm against get_1_hot_encode.
image_1_hot = torch.from_numpy(
    get_1_hot_encode(day_dataset_info['image_indices'], len(day_dataset_info['image_indices']))
).type(torch.float)
day_netvlad_repr = torch.from_numpy(load_np_file(config['dataset_oxford_day_netvlad_repr_file_path']))

# Each training item is (features, one-hot target, image index).
train_dataset = TensorDataset(day_netvlad_repr, image_1_hot, day_image_idx)
print(f"Train dataset size: {len(train_dataset)}")
train_dataLoader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# --- Validation / test data (night dataset, split by precomputed indices) ---
val_test_splits = load_np_file(config['dataset_oxford_night_val_test_splits_indices_file_path'])
night_dataset_info = load_np_file(config['dataset_oxford_night_dataset_file_path'])

night_gt = night_dataset_info['ground_truth_indices']
night_netvlad_repr = torch.from_numpy(load_np_file(config['dataset_oxford_night_netvlad_repr_file_path']))
# Kept so the original variable name still refers to the night features, as before.
netvlad_repr = night_netvlad_repr
night_image_idx = torch.from_numpy(night_dataset_info['image_indices'])

# Each evaluation item is (features, image index).
night_dataset = TensorDataset(night_netvlad_repr, night_image_idx)

# Evaluation loaders are not shuffled: shuffling adds nothing here and makes the
# iteration order non-deterministic.
# NOTE(review): assumes eval() does not depend on sample order -- every item carries
# its own image index; confirm against echovpr.trainer.eval.
val_dataset = Subset(night_dataset, val_test_splits['val_indices'])
print(f"Val dataset size: {len(val_dataset)}")
val_dataLoader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

test_dataset = Subset(night_dataset, val_test_splits['test_indices'])
print(f"Test dataset size: {len(test_dataset)}")
test_dataLoader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Build the MLP, its optimizer and the training loss.
in_features = int(config['model_in_features'])
hidden_features = int(config['model_hidden_features'])
out_features = int(config['model_out_features'])

# A hidden layer is only inserted when a positive width is configured; otherwise
# the output layer reads the input features directly.
use_hidden_layer = hidden_features > 0
out_layer_in_features = hidden_features if use_hidden_layer else in_features

# NOTE(review): no activation is placed between 'hl' and 'out', so the two Linear
# layers compose to a single affine map -- confirm this is intended.
layers = []
if use_hidden_layer:
    layers.append(('hl', nn.Linear(in_features, hidden_features, bias=True)))
layers.append(('out', nn.Linear(out_layer_in_features, out_features, bias=True)))

model = nn.Sequential(OrderedDict(layers))
model.to(device)

lr = float(config['train_lr'])

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The model emits raw logits; the sigmoid is applied inside the loss.
criterion = nn.BCEWithLogitsLoss(reduction='mean').to(device)

# Register the model and loss with W&B for logging.
wandb.watch(model, criterion=criterion, log="all", idx=1, log_graph=True)

In [None]:
# Cut-offs N passed to compute_recall and eval when computing the recall metrics.
n_values = [1, 5, 10, 20, 50, 100]
# Number of top-scoring predictions kept per sample: the largest cut-off in n_values.
top_k = max(n_values)

In [None]:
# Training loop: one optimisation pass over the training loader per epoch, then
# validation; the test split is evaluated (and the weights checkpointed) only when
# validation recall at key 1 improves. Metrics are logged to W&B once per epoch.
max_epochs = get_int_from_config(config, 'train_max_epochs', 1)
num_batches = len(train_dataLoader)

steps = 0
# Start below any finite recall so the first evaluated epoch always counts as an
# improvement; this ensures a checkpoint exists at model_path after training even
# if validation recall never rises above 0.
best_val_recall_at_1 = float('-inf')
save_best_checkpoint = True

model_path = os.path.join(wandb.run.dir, 'model.pt')

for epoch in range(1, max_epochs + 1):
    # eval() is defined outside this notebook and may change the module's mode, so
    # training mode is restored every epoch (it does not alter Linear layers).
    model.train()

    epoch_loss = 0.0

    # (image index, top-k predicted indices) pairs collected over the epoch.
    predictions = []

    for x, y_target, y_idx in train_dataLoader:
        steps += 1

        x = x.to(device)
        y_target = y_target.to(device)

        optimizer.zero_grad()

        y = model(x)

        # Record the top-k predictions for the training recall; no gradients needed.
        with torch.no_grad():
            _, predIdx = torch.topk(y, top_k)
            predictions += zip(y_idx.numpy(), predIdx.cpu().numpy())

        loss = criterion(y, y_target)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    avg_loss = epoch_loss / num_batches
    train_recalls = compute_recall(day_gt, predictions, len(predictions), n_values)

    # Report the epoch-average loss (the same value that is logged as train_loss),
    # not the loss of whichever batch happened to come last.
    print(f"Epoch: {epoch}, Loss: {avg_loss}, Train Metrics: {train_recalls}")

    with torch.no_grad():
        val_recalls = eval(model, val_dataLoader, night_gt, n_values, top_k)
        print(f"Epoch: {epoch}, Val Metric: {val_recalls}")

        # Model selection is driven by the entry stored under key 1.
        current_val_recall_at_1 = val_recalls[1]

        is_better = current_val_recall_at_1 > best_val_recall_at_1

        if is_better:
            best_val_recall_at_1 = current_val_recall_at_1
            test_recalls = eval(model, test_dataLoader, night_gt, n_values, top_k)
            print(f"Epoch: {epoch}, Test Metric: {test_recalls}")

            if save_best_checkpoint:
                torch.save(model.state_dict(), model_path)
        else:
            # No test evaluation this epoch, so no test metrics are logged.
            test_recalls = {}

    log_dic = {'train_loss': avg_loss, "epoch": epoch}

    for k, v in train_recalls.items():
        log_dic[f"train_recall@{k}"] = v

    for k, v in val_recalls.items():
        log_dic[f"val_recall@{k}"] = v

    for k, v in test_recalls.items():
        log_dic[f"test_recall@{k}"] = v

    # The W&B step is the cumulative number of training batches seen so far.
    wandb.log(log_dic, step=steps)

In [None]:
# Upload the checkpoint file at model_path as a W&B artifact of type "model",
# named after the run id and tagged with the "best" alias.
artifact_name = f'hl_model_{run.id}'
model_artifact = wandb.Artifact(artifact_name, type="model", metadata=config)
model_artifact.add_file(model_path)
wandb.log_artifact(model_artifact, aliases=["best"])

In [None]:
# Close out the run object that get_config_wandb returned at the start of the notebook.
run.finish()