In [None]:
%cd ..\src
!python setup.py develop

In [None]:
import wandb
import os
import numpy as np
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from configs.utils import get_config
from echovpr.datasets.utils import get_dataset, save_tensor


In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the path with the platform's separator so the notebook also runs outside Windows.
config = get_config(os.path.join("configs", "train_mlp_nordland_full.ini"))

In [None]:
# Start the W&B run through which load_model() fetches the checkpoint artifact.
run = wandb.init(
    project="echovpr_nordland_hl",
)

In [None]:
def load_model(artifact_name: str, model_name: str, wandb_run=None) -> str:
    """Download a W&B model artifact and return the local path of one file inside it.

    Args:
        artifact_name: artifact reference, e.g. ``entity/project/model-id:v0``.
        model_name: file name inside the downloaded artifact directory.
        wandb_run: run used to declare the artifact dependency. Defaults to the
            notebook-level ``run`` so existing two-argument calls keep working.

    Returns:
        The artifact's download directory joined with ``model_name``.
    """
    # Resolve the run lazily so the function no longer hard-wires the global.
    if wandb_run is None:
        wandb_run = run
    model_artifact = wandb_run.use_artifact(artifact_name, type='model')
    model_dir = model_artifact.download()
    return os.path.join(model_dir, model_name)

In [None]:
# Fetch the trained checkpoint artifact and keep the local path to its model.ckpt file.
model_path = load_model(
    artifact_name='mscerri/echovpr_nordland_hl/model-qsp4802p:v0',
    model_name='model.ckpt',
)

In [None]:
# Init MLP and Lightning Modules
# Two-layer MLP: 'hl' is the hidden layer (reused further down as the encoder), 'out' the output layer.
in_features = int(config['model_in_features'])
hidden_features = int(config['model_hidden_features'])
out_features = int(config['model_out_features'])

model = nn.Sequential(OrderedDict([
    ('hl', nn.Linear(in_features=in_features, out_features=hidden_features, bias=True)),
    ('out', nn.Linear(in_features=hidden_features, out_features=out_features, bias=True)),
]))

# NOTE(review): ClassificationTask is not imported by any cell in this notebook, so a fresh
# kernel raises NameError here -- import it from the project package that defines it.
# map_location=device remaps every checkpoint tensor onto the device selected earlier,
# so the checkpoint also loads on machines without a GPU.
pl_model = ClassificationTask.load_from_checkpoint(model_path, map_location=device, model=model, config=config)

In [None]:
# Inference-only setup: eval mode, then freeze() (LightningModule API) to disable parameter gradients.
pl_model.eval()
pl_model.freeze()

# Follow the device chosen at the top of the notebook instead of hard-coding CUDA.
pl_model = pl_model.to(device)

In [None]:
def process(model, dataLoader, device=None):
    """Run ``model`` over every batch of ``dataLoader`` and collect the results on the CPU.

    Args:
        model: module applied to each input batch.
        dataLoader: iterable yielding ``(x, y_target)`` pairs of tensors.
        device: device each input batch is moved to before the forward pass.
            Defaults to the device of ``model``'s first parameter (CPU if the
            model has no parameters).

    Returns:
        ``(outputs, targets)``: the per-batch model outputs and the per-batch
        targets, each concatenated along dim 0, both on the CPU.

    Raises:
        ValueError: if ``dataLoader`` yields no batches.
    """
    if device is None:
        # Follow the model rather than hard-coding CUDA, so CPU-only hosts work too.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    x_processed_list = []
    y_target_list = []

    # Pure inference: skip autograd bookkeeping whether or not the module was frozen.
    with torch.no_grad():
        for x, y_target in dataLoader:
            x_processed = model(x.to(device))
            x_processed_list.append(x_processed.cpu())
            y_target_list.append(y_target)

    if not x_processed_list:
        raise ValueError("dataLoader yielded no batches")

    # cat(dim=0) equals vstack for >=2-D batches, and unlike vstack it also joins
    # 1-D target batches correctly, including a shorter final batch.
    return (torch.cat(x_processed_list, dim=0), torch.cat(y_target_list, dim=0))

In [None]:
# Prepare Datasets
# Every season gets its own dataset plus an unshuffled loader, so the encoded
# outputs produced later stay in dataset order.

def _build_season_loader(dataset_key: str, label: str):
    """Load the dataset at ``config[dataset_key]``, print its size, and wrap it in a DataLoader.

    Returns:
        ``(dataset, loader)`` where the loader iterates in dataset order (``shuffle=False``).
    """
    season_dataset = get_dataset(config[dataset_key])
    print(f"{label} dataset size: {len(season_dataset)}")
    season_loader = DataLoader(
        season_dataset,
        num_workers=int(config['dataloader_threads']),
        batch_size=int(config['train_batchsize']),
        shuffle=False,
    )
    return season_dataset, season_loader


summer_dataset, summer_dataLoader = _build_season_loader('dataset_nordland_summer_netvlad_repr_file_path', "Summer")
winter_dataset, winter_dataLoader = _build_season_loader('dataset_nordland_winter_netvlad_repr_file_path', "Winter")
spring_dataset, spring_dataLoader = _build_season_loader('dataset_nordland_spring_netvlad_repr_file_path', "Spring")
fall_dataset, fall_dataLoader = _build_season_loader('dataset_nordland_fall_netvlad_repr_file_path', "Fall")

In [None]:
# Reuse the trained hidden layer as a standalone encoder.
# NOTE(review): assumes load_from_checkpoint loaded the weights into this same `model`
# object (it was passed in via model=model) -- confirm against ClassificationTask.
encoder = model.hl

In [None]:
def _encode_and_save(data_loader, output_key: str):
    """Encode every batch of ``data_loader`` with ``encoder``, save it to ``config[output_key]``, and return it."""
    season_repr = process(encoder, data_loader)
    save_tensor(season_repr, config[output_key])
    return season_repr


nordland_summer_repr = _encode_and_save(summer_dataLoader, 'dataset_nordland_summer_hidden_repr_file_path')
nordland_winter_repr = _encode_and_save(winter_dataLoader, 'dataset_nordland_winter_hidden_repr_file_path')
nordland_spring_repr = _encode_and_save(spring_dataLoader, 'dataset_nordland_spring_hidden_repr_file_path')
nordland_fall_repr = _encode_and_save(fall_dataLoader, 'dataset_nordland_fall_hidden_repr_file_path')

In [None]:
# Close the W&B run opened at the top of the notebook.
run.finish()