In [1]:
import matplotlib.pyplot as plt
import sys
import os
import torch
import pickle
sys.path.append("/home/alec/latent-space-localization")
import auto_localization.dataset_management.data_manager_factory as data_manager_factory
import auto_localization.models.model_factory as model_factory
from auto_localization.localization.noise_model_selector import NoiseModelSelector
import numpy as np
%matplotlib inline

In [2]:
def load_localizers(wandb_path):
    """
        Gives the localizers and localizer_metrics for a given wandb_path.

        Both artifacts are read from
        ``$LATENT_PATH/auto_localization/logs/<wandb_path>/``:
        ``localizers.pkl`` and ``localizer_metrics.pkl``.

        Parameters
        ----------
        wandb_path : str
            Name of the wandb run directory inside the logs folder.

        Returns
        -------
        tuple
            ``(localizers, metrics)`` exactly as they were unpickled.

        Raises
        ------
        KeyError
            If the ``LATENT_PATH`` environment variable is not set.
        FileNotFoundError
            If either pickle file is missing.
    """
    run_directory = os.path.join(os.environ["LATENT_PATH"], "auto_localization", "logs", wandb_path)
    # NOTE: pickle.load can execute arbitrary code; only load run artifacts you produced yourself.
    with open(os.path.join(run_directory, "localizers.pkl"), "rb") as f:
        localizers = pickle.load(f)
    with open(os.path.join(run_directory, "localizer_metrics.pkl"), "rb") as f:
        metrics = pickle.load(f)
    return localizers, metrics

def load_experiment_config(wandb_path):
    """
        Loads the experiment parameters pickled for a given wandb run.

        Reads ``$LATENT_PATH/auto_localization/logs/<wandb_path>/params.pkl``.

        Parameters
        ----------
        wandb_path : str
            Name of the wandb run directory inside the logs folder.

        Returns
        -------
        object
            The unpickled experiment config (callers index it like a dict,
            e.g. ``experiment_config["model_config"]``).

        Raises
        ------
        KeyError
            If the ``LATENT_PATH`` environment variable is not set.
        FileNotFoundError
            If ``params.pkl`` is missing.
    """
    run_directory = os.path.join(os.environ["LATENT_PATH"], "auto_localization", "logs", wandb_path)
    params_path = os.path.join(run_directory, "params.pkl")
    # NOTE: pickle.load can execute arbitrary code; only load run artifacts you produced yourself.
    with open(params_path, "rb") as f:
        experiment_config = pickle.load(f)
    return experiment_config

def load_model(wandb_path, strict=True):
    """
        Rebuilds a run's model from its experiment config and loads its best weights.

        The architecture comes from ``load_experiment_config(wandb_path)["model_config"]``
        (built via ``model_factory.get_model_from_config``). The weights come from
        ``$LATENT_PATH/auto_localization/logs/<wandb_path>/best_model.pkl``.

        Parameters
        ----------
        wandb_path : str
            Name of the wandb run directory inside the logs folder.
        strict : bool, default True
            Forwarded to ``load_state_dict``. Pass ``False`` to load a checkpoint
            saved by an older model definition that lacks some parameters/buffers.
            Any missing or unexpected keys are then reported with a warning
            instead of raising.

        Returns
        -------
        torch.nn.Module
            The model, switched to eval mode.

        Raises
        ------
        RuntimeError
            If ``strict`` is True and the checkpoint keys do not match the model.
    """
    run_directory = os.path.join(os.environ["LATENT_PATH"], "auto_localization", "logs", wandb_path)
    experiment_config = load_experiment_config(wandb_path)
    model_config = experiment_config["model_config"]
    model_weight_path = os.path.join(run_directory, "best_model.pkl")
    model = model_factory.get_model_from_config(model_config["model_type"], model_config)
    # Load onto CPU so GPU-saved checkpoints also open on CPU-only machines;
    # load_state_dict copies the tensors into the (freshly built) model's parameters.
    state_dict = torch.load(model_weight_path, map_location="cpu")
    load_result = model.load_state_dict(state_dict, strict=strict)
    if not strict and (load_result.missing_keys or load_result.unexpected_keys):
        import warnings
        warnings.warn(
            "Checkpoint {} does not match the model definition; missing keys: {}, unexpected keys: {}".format(
                model_weight_path, list(load_result.missing_keys), list(load_result.unexpected_keys)
            )
        )
    model.eval()
    return model


In [3]:
def _to_image(array_like):
    """
        Converts a torch tensor (on any device) or an array-like into a squeezed NumPy array for ``imshow``.
    """
    if torch.is_tensor(array_like):
        array_like = array_like.detach().cpu().numpy()
    return np.squeeze(np.asarray(array_like))


def plot_final_localization_images(localizers, model, data_manager, title="Title",
                                   sample_size=100, max_rows=6, device="cuda"):
    """
        Plots one row per localizer (at most ``max_rows``) with three panels:
        the decoded final posterior mean, the sampled test image whose similarity
        vector is closest to that mean, and the localizer's reference image.

        Parameters
        ----------
        localizers : sequence
            Indexable objects exposing ``posterior_means`` (the last entry is used)
            and ``reference_data``.
        model :
            Must provide ``forward(image)`` returning a 4-tuple whose third element
            is the similarity vector, and ``decode(latent)``.
        data_manager :
            Must expose an indexable ``image_test`` collection of image tensors.
        title : str
            Figure title.
        sample_size : int, default 100
            Number of distinct test images used as nearest-neighbour candidates,
            capped at ``len(data_manager.image_test)``.
        max_rows : int, default 6
            Maximum number of localizers to plot.
        device : str or torch.device, default "cuda"
            Device that the model lives on; inputs are moved there.

        Raises
        ------
        ValueError
            If there are no candidate test images to search.
    """
    num_rows = min(len(localizers), max_rows)
    if num_rows == 0:
        # Nothing to draw; plt.subplots(0, 3) would raise.
        return

    num_test_images = len(data_manager.image_test)
    num_candidates = min(sample_size, num_test_images)
    if num_candidates <= 0:
        raise ValueError("need at least one test image to search for nearest neighbours")
    # Sample without replacement so the candidate pool has no duplicate images.
    image_indices = np.random.choice(num_test_images, size=num_candidates, replace=False)

    # Inference only: no_grad avoids building an autograd graph for every forward pass.
    with torch.no_grad():
        similarity_vectors = []
        for index in image_indices:
            image = data_manager.image_test[index].to(device)
            _, _, similarity_vector, _ = model.forward(image)
            similarity_vectors.append(similarity_vector)
        similarity_vectors = torch.stack(similarity_vectors).to(device)

        def get_closest_image(query):
            """Returns the sampled test image whose similarity vector is closest (L2) to ``query``."""
            distances = torch.norm(query - similarity_vectors, dim=-1)
            closest_index = torch.argmin(distances).item()
            return data_manager.image_test[image_indices[closest_index]]

        # squeeze=False keeps ``axs`` 2-D even when only a single row is plotted.
        fig, axs = plt.subplots(num_rows, 3, figsize=(3, num_rows), squeeze=False)
        # suptitle titles the whole figure; plt.title would only title the last subplot.
        fig.suptitle(title)
        for column, label in enumerate(("estimate", "nearest", "reference")):
            axs[0, column].set_title(label, fontsize="small")
        for i in range(num_rows):
            localizer = localizers[i]
            last_mean = torch.Tensor(localizer.posterior_means[-1]).unsqueeze(0).to(device)
            # Column 0: decoded final estimate of the localizer.
            axs[i, 0].imshow(_to_image(model.decode(last_mean)))
            # Column 1: nearest sampled test image in similarity space.
            axs[i, 1].imshow(_to_image(get_closest_image(last_mean)))
            # Column 2: the reference image the localizer was searching for.
            axs[i, 2].imshow(_to_image(localizer.reference_data))
            for ax in axs[i]:
                ax.set_axis_off()

In [4]:

# MorphoMNIST data restricted to the digit 1, with attribute-returning, indexed triplet loaders.
# NOTE(review): component_weighting has six equal entries; confirm which morphological
# attribute each position weights (an earlier comment mentioned only slant and thickness).
dataset_config = dict(
    component_weighting=[1.0] * 6,
    attribute_return=True,
    which_digits=[1],
    one_two_ratio=0.0,
    batch_size=256,
    indexed=True,
    num_workers=6,
    single_feature_triplet=False,
    inject_triplet_noise=0.0,
    dataset_name="MorphoMNIST",
    input_shape=(32, 32),
    triplet_batch_size=256,
)
data_manager, localization_metadata_oracle = data_manager_factory.construct_morpho_mnist(dataset_config)

inject
0.0
Setting up data


In [5]:
# Load a trained model and one run's localizers, then, for each localizer, show the decoded
# final estimate, its nearest test-set neighbour, and the reference image.
# NOTE(review): the model and the localizers come from two different runs, and this cell's
# recorded RuntimeError (missing uncertainty_linear / similarity_batchnorm keys) means the
# checkpoint does not match the current IsolatedVAE definition — confirm which run's weights are intended.
model_path = "polar-hill-1041"
wandb_path = "splendid-mountain-1050"
model = load_model(model_path).cuda()
localizers, metrics = load_localizers(wandb_path)
plot_final_localization_images(localizers, model, data_manager)

RuntimeError: Error(s) in loading state_dict for IsolatedVAE:
	Missing key(s) in state_dict: "uncertainty_linear.weight", "uncertainty_linear.bias", "similarity_batchnorm.running_mean", "similarity_batchnorm.running_var". 