In [None]:
import argparse
import copy
import random
from argparse import Namespace
from datetime import datetime
from os.path import expanduser
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import wandb
import numpy as np
import yaml
import os
import importlib

from maskedvae.model.models import ModelGLVM
from maskedvae.model.networks import GLVM_VAE
from maskedvae.utils.utils import ensure_directory, save_pickle, save_yaml
from maskedvae.model.masks import MultipleDimGauss


# Set up configuration and directories

In [None]:
# Notebook-level run settings. These are merged over the YAML data config in a
# later cell, where entries defined here win on key clashes.
config = dict(
    # --- general run control ---
    seed=42,  # random seed
    epochs=100,  # number of epochs
    beta=1.0,  # beta value for beta-vae
    all_obs=0,  # 1 for all observed data
    visualise=0,  # save visualisations during training
    shorttest=0,  # load shorter dataset for testing run setup
    all=1,  # run all methods in the list
    method=0,  # specify which method if all=0
    # --- imputation of masked values ---
    one_impute=0,  # impute 1 for all masked values
    val_impute=1,  # impute a specific value
    mean_impute=0,  # impute the respective mean
    random_impute=0,  # impute a random value
    # --- experiment bookkeeping ---
    list_len=1,  # only one data condition C, noise
    task_name="glvm",  # task name
    exp="glvm",  # experiment name wandb
    offline=1,  # 1 sets wandb offline
    # --- alternative losses / regularisation (all off) ---
    cross_loss=0,  # alternative training loss off
    full_loss=0,  # alternative training loss on all data off
    dropout=0.0,  # dropout to compare to simple dropout off
)

# Switch Weights & Biases to offline mode when requested in the config.
if config["offline"]:
    os.environ["WANDB_MODE"] = "offline"


In [None]:

# Locations of the input data and of everything written by this run.
data_directory = "../data/"
run_directory = "../runs/"
ensure_directory(run_directory)

# NOTE(review): yaml.Loader can construct arbitrary Python objects, so only
# trusted config files should be read here. The loaded value is accessed via
# its attribute dict below, so presumably the file stores a serialised Python
# object - confirm before switching to a safe loader.
fit_config_path = '../configs/glvm/fit_config.yml'
with open(fit_config_path, "r") as config_file:
    data_conf = yaml.load(config_file, Loader=yaml.Loader)

# Merge both sources into one namespace: start from the attributes of the YAML
# config, then let the notebook `config` overwrite any key present in both.
args_dict = dict(vars(data_conf))
args_dict.update(config)
args = Namespace(**args_dict)

# Sanity check: fraction_full has to lie within 0.05 of 1 / (unique_masks + 1).
expected_fraction_full = 1 / (args.unique_masks + 1)
assert np.abs(args.fraction_full - expected_fraction_full) <= 0.05


# Set seeds to ensure reproducibility

In [None]:
# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
torch.set_num_threads(1)

# Seed every random number generator used in this notebook with the same value.
for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
    seed_fn(args.seed)
if cuda_available:
    torch.cuda.manual_seed(args.seed)

# Remember the configured method value; args.method is overwritten with the
# method name inside the training loop.
method_handle = copy.copy(args.method)

# Load the training, validation and test datasets

In [None]:
# Index of the data condition whose files are loaded.
c_id = 0

# All three splits live in the same folder and differ only in the split name.
# NOTE(review): the loaded objects expose attributes (C, d, noise) and len(),
# so these files hold pickled dataset objects rather than plain tensors.
# Recent PyTorch releases default torch.load to weights_only=True, which may
# reject such files - confirm against the installed version.
glvm_data_root = f"{data_directory:s}/glvm/20_dim_1_dim"
dataset_train, dataset_valid, dataset_test = (
    torch.load(f"{glvm_data_root}/data_{split}_{c_id:d}.pt")
    for split in ("train", "valid", "test")
)

# Copy the data parameters stored on the training set onto the run arguments.
args.C = dataset_train.C
args.d = dataset_train.d
args.noise = dataset_train.noise

# Two stacked rows of length z_dim: zeros first, ones second.
args.z_prior = np.stack([np.zeros(args.z_dim), np.ones(args.z_dim)])

print(
    "train ", len(dataset_train),
    "valid ", len(dataset_valid),
    "test ", len(dataset_test),
)

# Specify the running modes: masked vs naive

In [None]:
# One row per running mode: (method name, short name for WandB, plot label).
method_specs = [
    # masked with all_obs = 0
    ("zero_imputation_mask_concatenated_encoder_only", "enc-mask-", "masked"),
    # naive with all_obs = 1
    ("zero_imputation", "zero-", "naive"),
]

# Split the table into the three parallel lists used by the rest of the notebook.
methods = [name for name, _, _ in method_specs]
method_short = [prefix for _, prefix, _ in method_specs]
meth_labels = [label for _, _, label in method_specs]

# Logging of losses

In [None]:

# Names of the quantities recorded per method.
log_keys = (
    "elbo",
    "kl",
    "rec_loss",
    "elbo_val",
    "kl_val",
    "rec_loss_val",
    "masked_rec_loss",
    "masked_rec_loss_val",
    "observed_mse",
    "masked_mse",
    "time_stamp_val",
)

# Every method gets its own dictionary with a fresh empty list per quantity.
logs = {method: {key: [] for key in log_keys} for method in methods}


# Specify the mask generator and run parameters

In [None]:
# Timestamp of this run, formatted as day-month-year_hour-minute-second.
args.ts = datetime.now().strftime("%d%m%Y_%H%M%S")

# Specify the masking generator
args.generator = MultipleDimGauss

# Translate the integer code in args.nonlin into an activation module;
# codes without an entry fall back to the identity.
nonlin_mapping = {
    0: nn.Identity(),
    1: nn.ELU(),
    2: nn.ReLU(),
    3: nn.Sigmoid(),
}
args.nonlin_fn = nonlin_mapping.get(args.nonlin, nn.Identity())

# one_impute takes priority: when it is set, mean imputation is switched off.
# Otherwise mean_impute keeps its configured value.
if args.one_impute:
    args.mean_impute = False


# Start training for both naive and masked

In [None]:
# Train one model per entry in `methods`. `args` is mutated in place on every
# pass, so the model, the wandb run and the saved args files all see the
# settings of the method currently being trained.
# NOTE(review): the "all" / "method" config entries are not consulted here -
# every method in the list is always run. Confirm whether that is intended.
for i, method in enumerate(methods):
    # pass the right logs to network
    args.method = method

    # The method name decides whether all data is treated as observed and
    # which experiment label is used for the wandb group.
    if args.method == "zero_imputation":
        args.all_obs = 1
        args.exp = "naive"
    elif args.method == "zero_imputation_mask_concatenated_encoder_only":
        args.all_obs = 0
        args.exp = "masked"

    # Tag built from the data dimensionality and the first entries of C and noise.
    name_tag = f"xdims_{args.x_dim}_C_{args.C[0][0]:.2f}_sig_{args.noise[0][0]:.2f}"

    # A frozen decoder is only combined with the regular loss; the override
    # persists on `args` for the remaining iterations.
    if args.freeze_decoder and args.loss_type != "regular":
        args.loss_type = "regular"
        print("The decoder is frozen -> switching loss type to regular...")
        name_tag = f"_frozen{name_tag}"
    print(name_tag)

    # Initialize the Weights and Biases (wandb) run
    run = wandb.init(
        project=args.project_name,
        group=f"{args.exp}{name_tag}",
        name=f"{method_short[i]}{args.ts}",
        reinit=True,
        config=args,
        dir=run_directory,
    )

    # Setup the directory for storing figures inside the wandb run directory
    figs_directory = os.path.join(wandb.run.dir, "figs")
    os.makedirs(figs_directory, exist_ok=True)  # no error if it already exists
    args.fig_root = run_directory
    # Update wandb configuration with the figures directory path
    args.fig_dir = figs_directory
    wandb.config.update({"fig_dir": args.fig_dir}, allow_val_change=True)

    # Announce which variant is being built. This used to print
    # "Masked model..." for the naive run as well.
    print(f"{args.exp.capitalize()} model...")
    model = ModelGLVM(
        args=args,
        dataset_train=dataset_train,
        dataset_valid=dataset_valid,
        dataset_test=dataset_test,
        logs=logs[method],
        device=device,
        inference_model=GLVM_VAE,
        Generator=args.generator,
        nonlin=args.nonlin_fn,
        dropout=args.dropout,
    )

    print(" ---------- begin model fit ---------------")
    print(method)

    model.fit()
    print(" ---------- end model fit ---------------")
    print("  ")

    # Per-method output folder: <fig_root>/<model.ts>/<model.method>
    model_path = os.path.join(model.args.fig_root, str(model.ts), model.method)
    ensure_directory(model_path)

    # save logs and args dictionary
    save_pickle(model.logs, os.path.join(model_path, "logs.pkl"))
    save_pickle(model.args, os.path.join(model_path, "args.pkl"))
    save_yaml(model.args, os.path.join(model_path, "args.yml"))

    # Replace the data loaders with placeholders before the model object is
    # written out - presumably so they are not serialised with it (confirm).
    model.train_loader = 0
    model.test_loader = 0
    model.validation_loader = 0

    # save the model (skipped when args.watchmodel is set), once next to the
    # logs and once inside the wandb run directory
    if not model.args.watchmodel:
        model_filepath = os.path.join(model_path, "model_end_of_training.pt")
        torch.save(model, model_filepath)
        torch.save(model, os.path.join(wandb.run.dir, "model_end_of_training.pt"))

    run.finish()

# ------------------------ end loop over methods ----------------------------------------

# `model` is the model of the last loop iteration at this point.
# NOTE(review): this writes that model's `logs` attribute into the shared
# timestamp directory. If the combined logs of all methods are wanted here,
# the `logs` dictionary may be the intended object - confirm before changing.
joint_directory_path = os.path.join(model.args.fig_root, str(model.ts))
ensure_directory(joint_directory_path)
save_pickle(model.logs, os.path.join(joint_directory_path, "logs.pkl"))

# Visualise the loss curves

In [None]:
from maskedvae.plotting.plotting import plot_losses
from maskedvae.plotting.plotting_utils import cm2inch

# One row of four panels; the figure size is given in centimetres via cm2inch.
fig, axes = plt.subplots(1, 4, figsize=cm2inch((30, 6)))

# (log key, y-axis label, x-axis label) for each panel, left to right.
panel_specs = [
    ("elbo", "- elbo", "iteration"),
    ("elbo_val", "- elbo", "epoch"),
    ("kl", "KL", "iteration"),
    ("rec_loss", "reconstruction loss", "iteration"),
]
for panel_ax, (log_key, y_label, x_label) in zip(axes, panel_specs):
    plot_losses(logs, methods, ax=panel_ax, log=log_key, ylabel=y_label, xlabel=x_label)

plt.tight_layout()