In [1]:
import time
import torch
import pickle
import os
import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

from pathlib import Path
from utils import tools, callbacks, metrics, supported_preprocessing_transforms
from modules.cae_base_module import CAEBaseModule
from modules.aae_base_module import AAEBaseModule
from modules.vae_base_module import VAEBaseModule
from datasets import supported_datamodules
from models import supported_models
from functools import reduce

# Landscape Detection with Generic Reconstructions

In this experiment, we investigate the ability of various reconstruction-based algorithms to rank landscapes that contain novel content.

In [3]:
# Locate every archived training run (directories named archive_v*) beneath the
# LunarAnalogueDataModule log directory. Each archive is expected to contain a
# configuration.yaml and a checkpoints/ folder.
root = Path.cwd() / '..'
log_path = root / 'logs' / 'LunarAnalogueDataModule'


def find_archived_models(log_dir):
    """Return the archived-model directories found under ``log_dir``, sorted by path.

    Any directory, at any depth, whose name starts with ``archive_v`` is returned.
    Plain files matching the pattern are ignored. Sorting makes the result
    independent of the filesystem's directory-listing order, so re-running the
    notebook loads and reports models in the same order every time.

    Parameters
    ----------
    log_dir : str or pathlib.Path
        Root directory to search recursively.

    Returns
    -------
    list of pathlib.Path
        Matching directories. The list is empty if ``log_dir`` does not exist or
        contains no archives.
    """
    return sorted(p for p in Path(log_dir).glob('**/archive_v*') if p.is_dir())


paths_to_archived_models = find_archived_models(log_path)

print('Found archived models:\n------')
print('\n'.join(f'{p.parent.name}/{p.name}' for p in paths_to_archived_models))

Found archived models:
------
BaselineVAE/archive_v2_2021-05-06
BaselineVAE/archive_v1_2021-05-06
BaselineCAE/archive_v3_2021-05-06
BaselineCAE/archive_v2_2021-05-06
BaselineCAE/archive_v1_2021-04-12


In [5]:
# Load the checkpoints for all the training modules and save them in a nested
# dictionary: module_catalog[model_type][archive_name] -> restored module.
module_catalog = {}

for pth in paths_to_archived_models:
    config = tools.load_config(pth / 'configuration.yaml', silent=True)
    model_type = pth.parent.name  # e.g. 'BaselineVAE'; selects the architecture branch below
    model_name = pth.name         # archive directory name, used as the catalog key

    # Pick a validation checkpoint deterministically, and fail with a clear message
    # if there is none (a bare next(...) would raise an uninformative StopIteration).
    # NOTE(review): when several val_* files exist, the lexicographically first one
    # is used. Confirm that this is the intended checkpoint.
    ckpt_dir = pth / 'checkpoints'
    ckpt_candidates = sorted(ckpt_dir.glob('val_*'))
    if not ckpt_candidates:
        raise FileNotFoundError(f'No val_* checkpoint found in {ckpt_dir}')
    ckpt_path = ckpt_candidates[0]

    # Unsupervised region proposal is called implicitly in this line, see utils/preprocessing.py
    preprocessing_transforms = supported_preprocessing_transforms[config['data-parameters']['preprocessing']]

    datamodule = supported_datamodules[config['experiment-parameters']['datamodule']](
        data_transforms=preprocessing_transforms,
        **config['data-parameters'])
    datamodule.setup('test')

    # Handle the various model instantiations (dispatch on a substring of the run directory name)
    if 'AAE' in model_type:
        # The AAE model is sized by the total element count of one sample (product of data_shape)
        model = supported_models[model_type](
            in_nodes=reduce(lambda x, y: x*y, datamodule.data_shape),
            latent_nodes=config['module-parameters']['latent_nodes'])
        module = AAEBaseModule(model, **config['module-parameters'])
    elif 'VAE' in model_type:
        model = supported_models[model_type](
            in_shape=datamodule.data_shape,
            latent_nodes=config['module-parameters']['latent_nodes'])
        module = VAEBaseModule(model, **config['module-parameters'])
    elif 'CAE' in model_type:
        model = supported_models[model_type](in_shape=datamodule.data_shape)
        module = CAEBaseModule(model, **config['module-parameters'])
    else:
        raise ValueError(f'Model substring not found, got {model_type}')

    # Load the state_dict into the module architecture. With map_location='cpu',
    # checkpoints written on a GPU machine can be restored without CUDA, because
    # load_state_dict copies the values into the module's existing parameters.
    # torch.load unpickles the file, so only load checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    module.load_state_dict(checkpoint['state_dict'])
    # Switch to inference mode. This experiment ranks landscapes with trained
    # models rather than training them, so dropout / batch-norm updates must be off.
    module.eval()

    module_catalog.setdefault(model_type, {})[model_name] = module
    print(f'Loaded state dict for: {model_type}/{model_name}')

Loaded state dict for: BaselineVAE/archive_v2_2021-05-06
Loaded state dict for: BaselineVAE/archive_v1_2021-05-06
Loaded state dict for: BaselineCAE/archive_v3_2021-05-06
Loaded state dict for: BaselineCAE/archive_v2_2021-05-06
Loaded state dict for: BaselineCAE/archive_v1_2021-04-12
