In [None]:
%cd ..\src
!python setup.py develop

In [None]:
from collections import OrderedDict

import torch.nn as nn
from torch.utils.data import DataLoader

from echovpr.configs.utils import get_config, get_int_from_config, get_bool_from_config
from echovpr.trainer.classification_task import ClassificationTask
from echovpr.datasets.utils import get_dataset

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

import wandb
import os
import logging

In [None]:
# Weights & Biases setup: tell W&B which notebook this run comes from,
# authenticate, then create the Lightning logger used by the Trainer below.
wandb_notebook_name = "notebooks/train_nordland_hidden_layer.ipynb"
os.environ["WANDB_NOTEBOOK_NAME"] = wandb_notebook_name
wandb.login()

wandb_logger = WandbLogger(
    project="echovpr_nordland_hl",
    log_model=True,  # ask the logger to log model checkpoints to W&B
)

# Route INFO-level log records through the root logger.
logging.basicConfig(level=logging.INFO)

In [None]:
# Load the training configuration. os.path.join yields the platform's separator,
# so this works on Windows and POSIX alike (a literal "\\" only works on Windows).
config = get_config(os.path.join("configs", "train_mlp_nordland_full.ini"))

In [None]:
# Init MLP and Lightning Modules
# The MLP head is an optional hidden Linear layer ('hl'), present only when
# model_hidden_features > 0, followed by the output Linear layer ('out').
in_features = int(config['model_in_features'])
hidden_features = int(config['model_hidden_features'])
out_features = int(config['model_out_features'])

use_hidden_layer = hidden_features > 0

named_layers = OrderedDict()
if use_hidden_layer:
    named_layers['hl'] = nn.Linear(in_features=in_features, out_features=hidden_features, bias=True)
# The output layer consumes the hidden layer's width if there is one, else the raw input width.
named_layers['out'] = nn.Linear(
    in_features=hidden_features if use_hidden_layer else in_features,
    out_features=out_features,
    bias=True,
)

model = nn.Sequential(named_layers)

# Wrap the network in the project's LightningModule so the Trainer can fit/test it.
pl_model = ClassificationTask(model, config)

# Watch Model: have W&B track gradients and parameters ("all") and log the model graph.
wandb_logger.watch(pl_model, log="all", log_graph=True)

In [None]:
# Prepare Datasets
# Training uses the summer NetVLAD representations. Validation and test both use
# the winter representations, each restricted by its own index file from the config.

loader_threads = int(config['dataloader_threads'])
loader_batch_size = int(config['train_batchsize'])


def _load_split(label, *dataset_args, shuffle):
    """Load one dataset split, print its size, and wrap it in a DataLoader.

    Args:
        label: Split name used as the prefix of the printed size message.
        *dataset_args: Positional arguments forwarded unchanged to get_dataset.
        shuffle: Passed straight through to DataLoader.

    Returns:
        A (dataset, dataloader) tuple.
    """
    split = get_dataset(*dataset_args)
    print(f"{label} dataset size: {len(split)}")
    loader = DataLoader(split, num_workers=loader_threads, batch_size=loader_batch_size, shuffle=shuffle)
    return split, loader


train_dataset, train_dataLoader = _load_split(
    "Train",
    config['dataset_nordland_summer_netvlad_repr_file_path'],
    shuffle=True,
)

val_dataset, val_dataLoader = _load_split(
    "Validation",
    config['dataset_nordland_winter_netvlad_repr_file_path'],
    config['dataset_nordland_winter_val_limit_indices_file_path'],
    shuffle=False,
)

test_dataset, test_dataLoader = _load_split(
    "Test",
    config['dataset_nordland_winter_netvlad_repr_file_path'],
    config['dataset_nordland_winter_test_limit_indices_file_path'],
    shuffle=False,
)

In [None]:
# Create PL Trainer
# Both callbacks monitor validation recall@1, where higher is better.

checkpoint_callback = ModelCheckpoint(
    monitor="val_recall@1",
    mode="max",
    auto_insert_metric_name=True,
    filename='checkpoint_{epoch:02d}-{val_recall@1:.4f}',
)
callbacks = [checkpoint_callback]

# Early stopping is opt-in via the config.
if get_bool_from_config(config, 'early_stopping_enabled'):
    early_stopping = EarlyStopping(
        monitor='val_recall@1',
        mode='max',
        patience=int(config['early_stopping_patience']),
        min_delta=float(config['early_stopping_min_delta']),
        # Check after validation rather than at train-epoch end
        # (val_recall@1 is presumably logged during validation).
        check_on_train_epoch_end=False,
    )
    callbacks.append(early_stopping)

# Training-length limits; the third argument to get_int_from_config is the
# fallback value (presumably used when the key is absent from the config).
trainer_limits = dict(
    max_epochs=get_int_from_config(config, 'train_max_epochs', None),
    min_epochs=get_int_from_config(config, 'train_min_epochs', None),
    max_steps=get_int_from_config(config, 'train_max_steps', -1),
    min_steps=get_int_from_config(config, 'train_min_steps', None),
)

trainer = pl.Trainer(
    gpus=1,
    logger=wandb_logger,
    callbacks=callbacks,
    **trainer_limits,
)

In [None]:
# Train on the summer split and validate on the winter validation split each epoch.
# The checkpoint (and optional early-stopping) callbacks monitor val_recall@1.
trainer.fit(pl_model, train_dataLoader, val_dataLoader)

In [None]:
# trainer.save_checkpoint("../checkpoint/nordland_pittsburgh_WPCA4096_mlp_01.ckpt")

In [None]:
# Evaluate on the held-out winter test split using the best checkpoint
# (highest val_recall@1) saved by checkpoint_callback during fit. Testing the
# in-memory model would use the final-epoch weights instead, which may be several
# epochs past the best when early stopping is enabled.
# To evaluate a specific saved checkpoint instead, pass its path, e.g.:
# trainer.test(pl_model, test_dataLoader, ckpt_path="../checkpoint/nordland_pittsburgh_WPCA4096_mlp_01.ckpt")
trainer.test(pl_model, test_dataLoader, ckpt_path="best")

In [None]:
# Mark the W&B run as finished and let it finish uploading its logged data.
wandb.finish()