### Training VAEs

Running the Train Script: 
- Script: celeba/train.py
- Launch one job: sbatch simple_launch.sbatch   
- Launch multiple: sh loop.sh

Getting the trained VAE Checkpoints: 
- Trained VAE Checkpoints: '/checkpoints_loss_checks/MN_2/' 




### Making the Metadataset

Loading the trained VAE Checkpoints: 
- Trained VAE Checkpoints: '/checkpoints_loss_checks/MN_2/< model_name >_e29.pt' 

Generate vae samples: 
- Script: celeba/generate_vae_samples.py
- Launch: sh launch_playground.sh  # uncomment example 
- Saved: /scratch/mr7401/vaes/generations/MN_2/< model_name >/samples/ < id >.pt

Calculate and Save Log Likelihoods: 
- Script: celeba/calculate_log_likelihoods.py
- Launch: sh launch_playground.sh  # uncomment example
- Saved: /scratch/mr7401/vaes/likelihoods/2/< model_name >/log_liklihood.jsonl

Make VIT encodings: 
- Script: celeba/generate_vae_encodings.py
- Launch: sh launch_playground.sh  # uncomment example
- Saved: /scratch/mr7401/vaes/generations/MN_2/< model_name >/encodings/ < id >.pt 
- Combined version: /scratch/mr7401/vaes/generations/MN_2/< model_name >/encodings/all_encodings.pt

Make the metadataset: 
- I did this in a notebook. I copied the code below if you want to see it. Please change the paths before running so as not to overwrite the files. You can also run with save=False to not write any files. 
- full set of sample information: /scratch/mr7401/meta_comp_data/vaes/metadatasets/MN_2/metadataset.csv

- Saved metadataset training data: these are the paths for n_per_set = 5 and n_samples = 10,000 which is what we used for training. 
    - train: /scratch/mr7401/meta_comp_data/vaes/metadatasets/MN_2/5_100000/train_seed_0.pt, 
    - train_metadata: /scratch/mr7401/meta_comp_data/vaes/metadatasets/MN_2/5_100000/train_metadata_seed_0.pkl

- Saved metadataset test data: the training script looks for everything in the following folder: 
    - /scratch/mr7401/meta_comp_data/vaes/metadatasets/MN_2/5_2000/ 
    - e.g. files would be like 'test_one_ood_models_128_748_seed_2.pt'. This made it easier to run evals for particular model pairs. 
    - test_one_ood_models refers to 1 ID model vs 1 OOD model. 
    - test_both_ood_models refers to 2 OOD models. 

### Training the Meta-Model 

- Script: celeba/train_meta_model.py
- Launch: sh loop.sh, which loops over learning rates, calling simple_launch.sbatch while passing in the learning rate as an exported environment variable. 
- Trained Meta-Models: '/checkpoints_ls/MN_2/5_10000/< wandb_run > -mse_dir/ epoch_checkpoints / epoch_5.pt' 
- A predictions file and a losses file are also supposed to be saved in that directory, but saving them crashed when I last ran it. I assume it ran out of memory. 


In [None]:
########### Make Meta-Datasets ##############
import os 
import pandas as pd 
import random 
import torch 
import pickle

def mean(lst): 
    """Return the arithmetic mean of the values in ``lst``.

    Raises ZeroDivisionError when ``lst`` is empty.
    """
    total = sum(lst)
    count = len(lst)
    return total / count
    
def calculate_kl_diff(m1_m1_ll, m1_m2_ll, m2_m1_ll, m2_m2_ll):
    """Return the negated sum of the means of four log-likelihood lists.

    As called by generate_meta_dataset, ``mA_mB_ll`` holds the values of the
    ``<model B>_ll`` column for rows sampled from model A.

    NOTE(review): all four means enter with the same sign. A quantity named
    "KL diff" would normally subtract some of these terms -- confirm that this
    is the intended label before relying on it.
    """
    # Keep the left-to-right addition order so float results are unchanged.
    total = mean(m1_m1_ll) + mean(m1_m2_ll) + mean(m2_m1_ll) + mean(m2_m2_ll)
    return -total

def get_encodings(encodings_mappings, source_model, ids = []):
    """Stack the embeddings of ``source_model`` that correspond to ``ids``.

    ``encodings_mappings[source_model]`` must provide an "ids" sequence and an
    "embeddings" container indexed in the same order. Each requested id is
    located with ``.index`` (first match), and the matching embeddings are
    stacked, in the order of ``ids``, with ``torch.stack``.

    Raises ValueError if an id is not present in the model's "ids".
    """
    model_entry = encodings_mappings[source_model]

    # Phase 1: resolve every id to its position in the model's id list.
    positions = []
    for sample_id in ids:
        positions.append(model_entry["ids"].index(sample_id))

    # Phase 2: pull the embedding stored at each position.
    selected = []
    for position in positions:
        selected.append(model_entry["embeddings"][position])

    return torch.stack(selected)

def generate_meta_dataset(samples, n_samples, n_per_set, save = True, seed = 42, save_dir = "/", models_to_include = [], set_name = "other", verbose = False):
    """Build a meta-dataset of paired sample sets drawn from two source models.

    Every example picks two distinct source models at random, draws
    ``n_per_set`` rows of ``samples`` for each of them, looks up the rows'
    encodings via ``get_encodings`` and labels the pair with
    ``calculate_kl_diff`` applied to the rows' log-likelihood columns.

    Args:
        samples: DataFrame with the columns "gen_source_model", "id",
            "combined_encodings_location" and one "<model name>_ll" column
            per source model.
        n_samples: number of examples (model pairs) to generate.
        n_per_set: number of rows drawn per model for each example.
        save: when true, write the dataset (torch.save) and the metadata
            (pickle) under ``{save_dir}/{n_per_set}_{n_samples}/``.
        seed: seeds the ``random`` module; also part of the output file names.
        save_dir: root directory for the saved files.
        models_to_include: if non-empty, only rows whose "gen_source_model"
            is in this list are used.
        set_name: file-name prefix for the saved files.
        verbose: print the selected models and sample ids for every example.

    Returns:
        ``(dataset, metadata)``. ``dataset`` is a dict with the stacked
        encodings "x1" and "x2", the labels "y" and "sample_index";
        ``metadata`` is a dict with "sample_index" plus, per example, the
        sampled ids ("m1_ids", "m2_ids") and model names ("m1s", "m2s").

    Raises:
        ValueError: if examples are requested but fewer than two source
            models are available after filtering.
    """
    # Per-example bookkeeping that ends up in the metadata dict.
    m1_ids = []
    m2_ids = []
    m1s = []
    m2s = []

    # Per-example tensors / labels that end up in the dataset dict.
    x1_data = []
    x2_data = []
    y_data = []

    if models_to_include:
        print(f"Filtering for only models in list: {models_to_include}")
        samples = samples[samples["gen_source_model"].isin(models_to_include)]
    source_models = list(samples['gen_source_model'].unique())

    if n_samples > 0 and len(source_models) < 2:
        # random.sample would raise the same exception type later, but with a
        # message that does not mention models; fail early and clearly.
        raise ValueError(f"Need at least 2 source models to build pairs, found {len(source_models)}")

    # Split the table once per model instead of re-filtering it in every iteration.
    rows_by_model = {name: samples[samples["gen_source_model"] == name] for name in source_models}

    encodings_mappings = {} # model name: {embeddings: [], ids: []}
    for model_name, model_rows in rows_by_model.items():
        combined_encodings_path = model_rows["combined_encodings_location"].iloc[0]
        encodings_mappings[model_name] = torch.load(combined_encodings_path)

    print("Starting Metadaset Generation...")

    random.seed(seed)
    for i in range(n_samples):
        if i % 100 == 0:
            print(f"Completed: {i}", flush = True)

        # Randomly select 2 distinct models to compare.
        model1, model2 = random.sample(source_models, 2)
        if verbose:
            print(f"Selected: {model1}, {model2}")

        # Draw n_per_set rows per model. A constant random_state would return
        # the very same rows for a model in every iteration, so a fresh value
        # is taken from the seeded `random` stream for each draw. The result
        # stays reproducible for a given seed.
        model1_samples = rows_by_model[model1].sample(n_per_set, random_state=random.randrange(2**32))
        model2_samples = rows_by_model[model2].sample(n_per_set, random_state=random.randrange(2**32))
        model1_sample_ids = model1_samples["id"].tolist()
        model2_sample_ids = model2_samples["id"].tolist()
        if verbose:
            # Ids are bound to names first: reusing the enclosing quote
            # character inside an f-string is a SyntaxError before Python 3.12.
            print(f"Samples: \n{model1_sample_ids}, \n{model2_sample_ids}")

        # Get encodings
        sample1_encodings = get_encodings(encodings_mappings = encodings_mappings, source_model = model1, ids = model1_sample_ids)
        sample2_encodings = get_encodings(encodings_mappings = encodings_mappings, source_model = model2, ids = model2_sample_ids)

        # Get label: each model's rows scored under both models.
        m1_m1_ll, m1_m2_ll = model1_samples[f"{model1}_ll"].to_list(), model1_samples[f"{model2}_ll"].to_list()
        m2_m1_ll, m2_m2_ll = model2_samples[f"{model1}_ll"].to_list(), model2_samples[f"{model2}_ll"].to_list()
        kldiff = calculate_kl_diff(m1_m1_ll, m1_m2_ll, m2_m1_ll, m2_m2_ll)

        # Add to sets for saving
        m1_ids.append(model1_sample_ids)
        m2_ids.append(model2_sample_ids)

        m1s.append([model1] * len(model1_samples))
        m2s.append([model2] * len(model2_samples))

        x1_data.append(sample1_encodings)
        x2_data.append(sample2_encodings)
        y_data.append(kldiff)

    x1_stacked = torch.stack(x1_data)
    x2_stacked = torch.stack(x2_data)
    y_stacked = torch.tensor(y_data)
    sample_indexes = torch.arange(n_samples)

    # Make objects
    dataset = {"x1": x1_stacked, "x2": x2_stacked, "y": y_stacked, "sample_index": sample_indexes} # can move to/ from GPU, shuffled.
    metadata = {"sample_index": sample_indexes, "m1_ids": m1_ids, "m2_ids": m2_ids, "m1s": m1s, "m2s": m2s} # use the sample_indexes to get the metadata.

    if save:
        output_dir = f"{save_dir}/{n_per_set}_{n_samples}"
        os.makedirs(output_dir, exist_ok = True)
        dataset_save_path = f"{output_dir}/{set_name}_seed_{seed}.pt"
        metadata_save_path = f"{output_dir}/{set_name}_metadata_seed_{seed}.pkl"

        torch.save(dataset, dataset_save_path)
        with open(metadata_save_path, "wb") as f:
            pickle.dump(metadata, f)
        print(f"Saved to {dataset_save_path} and {metadata_save_path}")

    return dataset, metadata 

# Model names are "vae_ldim_<latent dim>"; the lists are spelled out via their dims.
# Models named here are not part of train_models.
test_models = [f"vae_ldim_{ldim}" for ldim in (4, 1384, 748, 2048)]

# Models passed as models_to_include when the train datasets are generated.
train_models = [
    f"vae_ldim_{ldim}"
    for ldim in (8, 16, 32, 64, 128, 256, 512, 718, 1024)
]


In [None]:
######## Make Train Datasets ##########

# Table of all samples, including the encoding location, ids, and log likelihoods.
samples = pd.read_csv("/scratch/mr7401/meta_comp_data/vaes/metadatasets/MN_2/metadataset.csv", index_col = 0)

# Every (set size, seed) combination, set size varying slowest.
train_grid = [(set_size, run_seed) for set_size in [5, 10, 20, 30] for run_seed in [0, 1, 2]]

for n_per_set, seed in train_grid:
    # save is off, so nothing is written to save_dir; flip it to write the files.
    data, metadata = generate_meta_dataset(
        samples,
        n_samples = 10000,
        n_per_set = n_per_set,
        seed = seed,
        models_to_include = train_models,
        save = False,
        set_name = "train",
        save_dir = "/scratch/mr7401/meta_comp_data/vaes/metadatasets/MN_2",
        verbose = False,
    )

In [None]:
# ############ Generate test datasets ##################
from itertools import combinations

train_models = [
    f"vae_ldim_{ldim}"
    for ldim in (8, 16, 32, 64, 128, 256, 512, 718, 1024)
]

test_models = [f"vae_ldim_{ldim}" for ldim in (4, 1384, 748, 2048)]

# One job per model pair, as (set-name prefix, model1, model2).
# OOD: test vs test comes first, then OOD mixed: train vs test.
pairings = [("test_both_ood_models", first, second) for first, second in combinations(test_models, 2)]
pairings += [("test_one_ood_models", id_model, ood_model) for id_model in train_models for ood_model in test_models]

for prefix, model1, model2 in pairings:
    print(f"Pairing: {model1} vs {model2}")
    # The latent dimension is the last "_"-separated token of the model name.
    dim1 = model1.split("_")[-1]
    dim2 = model2.split("_")[-1]
    for n_per_set in [5]: #,10,20,30]:
        for seed in [2]:
            # save is off, so nothing is written to save_dir; flip it to write the files.
            data, metadata = generate_meta_dataset(
                samples,
                n_samples = 2000,
                n_per_set = n_per_set,
                seed = seed,
                models_to_include = [model1, model2],
                save = False,
                set_name = f"{prefix}_{dim1}_{dim2}",
                save_dir = "/scratch/mr7401/meta_comp_data/vaes/metadatasets/MN_2",
                verbose = False,
            )