In [1]:
import os, sys
os.chdir(os.path.abspath('..'))

In [2]:
import json
import pandas as pd
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()

In [3]:
%load_ext autoreload
%autoreload 2

from src.vae_models import CVAE
import src.utils as utils
import src.preprocess_lib as preprocess_lib
import src.testing_lib as testing_lib

In [4]:
config_dir = 'runs/sweep_runs_corrected'
config_file = 'config.json'

In [5]:
device = "cuda:0"
NUM_REC_SAMPLES = 100
NUM_IMP_SAMPLES = {"prior": 20, "likelihood": 5}
IMPUTATION_STYLE = "mean" ## "samples" or "mean"

In [6]:
folders = os.listdir(config_dir)

In [7]:
for i, folder in enumerate(folders):
    # if os.path.exists(os.path.join(config_dir, folder, "test_results.json")):
    #     print(f"Test results already exist for {folder}. Skipping...")
    #     continue

    test_results = {}
    # Load config file
    print(f"Loading config file for {folder}...")
    with open(os.path.join(config_dir, folder, config_file), 'r') as f: config = json.load(f)
    
    trainset, valset, conditioner, user_ids, condition_set, X_test, num_missing_days, nonzero_mean, nonzero_std = preprocess_lib.prepare_data(config["data"])
    # num_missing_days = [x.shape[0] for x in X_test["list"]]
    num_users = len(num_missing_days)
    
    # Load model
    model = CVAE(input_dim=valset.inputs.shape[1], conditioner=conditioner, **config["model"])
    model.load(os.path.join(config_dir, folder))
    
    log_space = config["data"]["scaling"]["log_space"]
    zero_id = config["data"]["scaling"]["zero_id"]
    shift = config["data"]["scaling"]["shift"]
    
    print(f"Preparing test data for {folder}...")
    x_test = utils.zero_preserved_log_normalize(X_test*1.0, nonzero_mean, nonzero_std, log_output=log_space, zero_id=zero_id, shift=shift)
    x_test = torch.tensor(x_test).float()
    conditions_test =  torch.tensor(conditioner.transform(condition_set["test"].copy())).float()
    
    print("Reconstructing...")
    x_rec, z_rec = testing_lib.mass_reconstruction(model, x_test, conditions_test, num_mc_samples=NUM_REC_SAMPLES, batch_size=20000, device=device)

    print("Calculating probabilistic metrics...")
    test_results["prob_metrics"] = testing_lib.get_probabilistic_metrics(model, x_test, x_rec, z_rec, aggregate=True, device=device)

    print("Imputing...")
    x_imp = testing_lib.mass_imputation(model, conditions_test, num_mc_samples_prior=NUM_IMP_SAMPLES["prior"], num_mc_samples_likelihood=NUM_IMP_SAMPLES["likelihood"], batch_size=20000, device=device)

    x_imp_denormalized = testing_lib.mass_denormalization(model=model, x_imp=x_imp, nonzero_mean=nonzero_mean, nonzero_std=nonzero_std, zero_id=zero_id, shift=shift, log_space=log_space, deviation=2, device=device)

    print("Calculating sample metrics...")
    test_results["sample_metrics"] = {}
    for imputation_style in ["samples", "mean"]:
        test_results["sample_metrics"][imputation_style] = {}
        test_results["sample_metrics"][imputation_style] = testing_lib.get_sample_metrics(x_test, x_imp_denormalized, imputation_style=imputation_style, aggregate=True)

    print("Saving results...")
    with open(os.path.join(config_dir, folder, "test_results.json"), 'w') as f:
        json.dump(test_results, f, indent=4)

    print(f"Finished testing {folder} ({i+1}/{len(folders)})")

Loading config file for sweep_Aug12_04-53-27...
Dataset: goi4_dp_full_Gipuzkoa
Loaded 2522880 consumption profiles from 365 dates and 6912 users.
Loaded metadata for 1 provinces
Uniqe provinces are: ['Gipuzkoa']
Removing 15 users with all-zero consumption profiles
Removing 67 users with any-negative consumption profiles
Number of (subsampled/filtered) users....6830
Number of (subsampled) days...............365
Number of (aggregated) features............24
Mean of enrolments: 10.06
Number of Training Points: 2173908
Number of Validation Points: 249295
Number of Testing Points: 69747
USING SIGMA_LIM!
USING SIGMA_LIM!
Preparing test data for sweep_Aug12_04-53-27...
Reconstructing...
Calculating probabilistic metrics...
Imputing...
Calculating sample metrics...
Saving results...
Finished testing sweep_Aug12_04-53-27 (1/161)
Test results already exist for sweep_Aug09_19-03-46. Skipping...
Test results already exist for sweep_Aug07_01-49-53. Skipping...
Test results already exist for sweep_A

KeyboardInterrupt: 