# SSIDVAE - Example of usage

- In the following cells, SSIDVAE is trained on Cars3D.
- The model is trained with the hyperparameters specified in the `config` dictionary.
- Hyperparameter-tuning libraries such as Ray Tune (https://docs.ray.io/en/latest/tune/index.html) can be used to search for better values.
- After training, the weights of the model and of its conditional prior are saved in the `stored_models` directory. If logging is enabled (the default), the training time and losses are also written as JSON to the `logs` directory.

In [None]:
import os
import torch
import random
import time
import json
import numpy as np

In [None]:
from disentanglement.data_models.cars3d import Cars3D
from disentanglement.models.ssidvae import SSIDVAE, ConditionalPrior, ssidvae_train
from disentanglement.models.utils import weights_init

In [None]:
import warnings
# Silence every warning category for the rest of the notebook session.
warnings.simplefilter("ignore")

In [None]:
config = {
    "model": "ssidvae", 
    "stored_model_path": "../stored_models", # path of the dir where trained models are stored
    "log_path": "../logs", # path of the dir were logs are stored
    "dataset": "cars3d", # or any other dataset
    "data_path": "../data/cars3d", # path taken as input by the dataset class (it can be a dir or a filename)
    "num_channels": 3, # number of channels of the dataset above
    "batch_size": 64, 
    "labeled_percentage": .01, # percentage of labeled instances (.01 = 1%)
    "seed": 17,
    "u_dim": 3, # to be changed according to the selected dataset 
    "u_idx": [0, 1, 2], # to be changed according to the selected dataset
    "z_dim": 3, # to be changed according to the selected dataset
    "hidden_dim": 256, 
    "c_hidden_dim": 1000, 
    "training_steps": 10, 
    "beta": 1, 
    "gamma": 1, 
    "m_l_rate": 1e-4, 
    "m_eps": 1e-8, 
    "m_beta_1": .9, 
    "m_beta_2": .999, 
    "c_l_rate": 1e-4, 
    "c_eps": 1e-8,
    "c_beta_1": .5, 
    "c_beta_2": .9, 
    "device": torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    "print_every": 2
}

In [None]:
def write_log(output_file_path, training_time, m_loss_list, l_loss_list, u_loss_list, c_loss_list, round_digits=2):
    """Write the training time and the recorded losses to a JSON file.

    Args:
        output_file_path: destination of the JSON file (overwritten if it exists).
        training_time: elapsed training time, stored unrounded under ``'training_time'``.
        m_loss_list: overall model losses, stored under ``'loss'``.
        l_loss_list: losses stored under ``'l_loss'`` (presumably the labeled-data term — confirm in ssidvae_train).
        u_loss_list: losses stored under ``'u_loss'`` (presumably the unlabeled-data term — confirm in ssidvae_train).
        c_loss_list: conditional-prior losses, stored under ``'c_loss'``.
        round_digits: number of decimal digits every loss value is rounded to.
    """
    def _rounded(values):
        # float() first so numpy scalars (e.g. float32) and 0-d tensors become
        # plain Python floats that json can serialize.
        return [round(float(v), round_digits) for v in values]

    res = {
        'training_time': training_time,
        'loss': _rounded(m_loss_list),
        'l_loss': _rounded(l_loss_list),
        'u_loss': _rounded(u_loss_list),
        'c_loss': _rounded(c_loss_list),
    }
    with open(output_file_path, 'w', encoding='utf-8') as fp:
        json.dump(res, fp)

In [None]:
def _build_base_filename(config):
    """Return the file-name stem shared by the saved weights and the loss log.

    The format is unchanged so previously saved files keep their names, e.g.
    ``ssidvae_cars3d_seed_17_beta_01_gamma_01_uidx_[0, 1, 2]_labeled_0.01``.
    """
    return "{}_{}_seed_{}_beta_{}_gamma_{}_uidx_{}_labeled_{}".format(
        config["model"],
        config["dataset"],
        str(config["seed"]).zfill(2),
        str(config["beta"]).zfill(2),
        str(config["gamma"]).zfill(2),
        str(config["u_idx"]),
        str(config["labeled_percentage"]),
    )


def run(dataset, config, log=True):
    """Train an SSIDVAE model and its conditional prior, then save the results.

    Args:
        dataset: dataset class. It is instantiated as ``dataset(path=config["data_path"])``
            and must provide ``get_dataloader(batch_size=...)``.
        config: hyperparameter dictionary (see the ``config`` cell for the keys read
            here). ``config["x_dim"]`` is optional and defaults to ``64*64``.
        log: when True, also write the training time and losses as JSON into
            ``config["log_path"]``.

    Side effects:
        Seeds the torch, NumPy and Python RNGs. Creates ``config["stored_model_path"]``
        (and ``config["log_path"]`` when logging) if missing. Writes
        ``<base>.pth`` (model) and ``<base>_c.pth`` (conditional prior) there.
    """
    # Seed the torch, NumPy and Python RNGs with the same value.
    torch.manual_seed(config["seed"])
    np.random.seed(config["seed"])
    random.seed(config["seed"])

    model = SSIDVAE(
        num_channels=config["num_channels"],
        # 64*64 was previously hard-coded (presumably flattened 64x64 images);
        # set config["x_dim"] to override it for other resolutions.
        x_dim=config.get("x_dim", 64 * 64),
        hidden_dim=config["hidden_dim"],
        z_dim=config["z_dim"],
        u_dim=config["u_dim"]
    ).apply(weights_init)

    m_optimizer = torch.optim.Adam(
        model.parameters(),
        lr=config["m_l_rate"],
        betas=(config["m_beta_1"], config["m_beta_2"]),
        eps=config["m_eps"]
    )

    conditional_prior = ConditionalPrior(
        u_dim=config["u_dim"],
        hidden_dim=config["c_hidden_dim"],
        z_dim=config["z_dim"]
    ).apply(weights_init)

    c_optimizer = torch.optim.Adam(
        conditional_prior.parameters(),
        lr=config["c_l_rate"],
        betas=(config["c_beta_1"], config["c_beta_2"]),
        eps=config["c_eps"]
    )

    # Load dataset
    data = dataset(path=config["data_path"])
    dataloader = data.get_dataloader(batch_size=config["batch_size"])

    # Train; perf_counter is the monotonic clock intended for measuring durations.
    start_time = time.perf_counter()
    train_loss_list, l_loss_list, u_loss_list, c_loss_list = ssidvae_train(
        model,
        m_optimizer,
        conditional_prior,
        c_optimizer,
        dataloader,
        config["u_idx"],
        config["device"],
        beta=config["beta"],
        gamma=config["gamma"],
        labeled_percentage=config["labeled_percentage"],
        training_steps=config["training_steps"],
        print_every=config["print_every"]
    )
    training_time = time.perf_counter() - start_time

    base_filename = _build_base_filename(config)

    # Save the model and the conditional prior; create the target dir on first use.
    os.makedirs(config["stored_model_path"], exist_ok=True)
    torch.save(model.state_dict(), os.path.join(config["stored_model_path"], base_filename + ".pth"))
    torch.save(conditional_prior.state_dict(), os.path.join(config["stored_model_path"], base_filename + "_c.pth"))

    # Save the losses
    if log:
        os.makedirs(config["log_path"], exist_ok=True)
        write_log(
            os.path.join(config["log_path"], base_filename + "_losses.json"),
            training_time,
            train_loss_list,
            l_loss_list,
            u_loss_list,
            c_loss_list,
            round_digits=2
        )

In [None]:
# Train SSIDVAE on Cars3D with the settings above; pass log=False to skip writing the loss JSON.
run(dataset=Cars3D, config=config, log=True)