In [2]:
import sys
import os
# Parent of the current working directory.
# NOTE(review): assumes the kernel is started from a direct subfolder of the
# repository root, so that this parent is the root - confirm.
repo_dir = os.path.dirname(os.getcwd())
# Guarded so that re-running this cell does not keep appending duplicate entries.
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

In [3]:
from utils.bins_samplers import GaussianQMCSampler
from models.cm import ContinuousMixture
from torch.utils.data import DataLoader
from utils.datasets import load_debd
import numpy as np
import torch

# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
# `gpus` is 1 on CUDA and None on CPU.
if torch.cuda.is_available():
    device, gpus = 'cuda', 1
else:
    device, gpus = 'cpu', None
print(device)

cuda


### Specify the datasets to evaluate

In [24]:
# Flip a flag to True to include that dataset in the evaluation below.
# The resulting list keeps the order of this table.
DEBD_DATASETS = [
    name
    for name, enabled in (
        ('nltcs', False),
        ('msnbc', False),
        ('kdd', False),
        ('plants', False),
        ('baudio', False),
        ('jester', False),
        ('bnetflix', False),
        ('accidents', False),
        ('tretail', True),
        ('pumsb_star', False),
        ('dna', True),
        ('kosarek', False),
        ('msweb', False),
        ('book', False),
        ('tmovie', False),
        ('cwebkb', False),
        ('cr52', False),
        ('c20ng', False),
        ('bbc', False),
        ('ad', False),
    )
    if enabled
]
print(DEBD_DATASETS)

['tretail', 'dna']


### Number of integration points (bins) to evaluate

In [25]:
# Powers of two from 2**7 (128) up to and including 2**13 (8192).
n_bins_list = [2 ** exponent for exponent in range(7, 14)]
# n_bins_list = [2**13]
print(n_bins_list)

[128, 256, 512, 1024, 2048, 4096, 8192]


### Set `clt` to `False` for a CM of factorisations, or to `True` for a CM of CLTs

In [3]:
# `clt` selects which log folder the checkpoints are read from.
clt = False
if clt:
    log_dir = repo_dir + '/logs/debd/cm_clt/'
else:
    log_dir = repo_dir + '/logs/debd/cm_fact/'
print(log_dir)

/scratch/s3313093/cm-tpm-main/logs/debd/cm_fact/


## Evaluate

In [4]:
def evaluate_lls_dict(lls_dict):
    """Print the average log-likelihood and its spread for every bin count.

    For each key of ``lls_dict`` every run is first reduced to a single
    number with ``np.mean``; the mean and the standard deviation (NumPy's
    default, ``ddof=0``) of those per-run averages are then printed, followed
    by a ready-to-paste LaTeX ``mean$\\pm$std`` string.

    Args:
        lls_dict: Mapping from a number of bins to a list holding one
            array-like of log-likelihood values per evaluated run.

    Returns:
        None. All results are written to stdout.
    """
    for n_bins, lls_per_run in lls_dict.items():
        print('Evaluating using ' + str(n_bins) + ' bins..')
        if len(lls_per_run) == 0:
            # np.mean / np.std on an empty list emit RuntimeWarnings and yield
            # NaN; report the missing runs explicitly instead.
            print('No runs available, skipping.')
            continue
        avg_lls_per_run = [np.mean(ll) for ll in lls_per_run]
        avg_ll = np.mean(avg_lls_per_run)
        std_ll = np.std(avg_lls_per_run)
        print('AVG LL: %f ' % avg_ll + ' STD LL: %f ' % std_ll)
        print('Latex string: %.2f$\\pm$%.2f' % (avg_ll, std_ll))

In [28]:
# If evaluation runs out of memory (OOM), tweak n_chunks and batch_size.
only_test = True  # when False, the validation split is evaluated as well
n_chunks = None   # assigned to model.n_chunks on every loaded checkpoint
batch_size = 32

for dataset_name in DEBD_DATASETS:

    _, valid, test = load_debd(dataset_name)
    valid_loader = DataLoader(valid, batch_size=batch_size)
    test_loader = DataLoader(test, batch_size=batch_size)
    print('Evaluating ' + dataset_name + '..')

    # One list of per-checkpoint results for every number of integration points.
    if not only_test:
        bmv_valid_lls_dict = {n_bins: [] for n_bins in n_bins_list}
    bmv_test_lls_dict = {n_bins: [] for n_bins in n_bins_list}

    exp_runs = 0
    for dir_path, _subdirs, file_names in os.walk(log_dir + dataset_name):
        # Every directory whose path contains 'checkpoints' is counted as one run.
        if 'checkpoints' not in dir_path:
            continue
        exp_runs += 1

        for ckpt in file_names:
            # Only 'best_model_valid' checkpoints contribute results, so skip all
            # other files before paying for loading them and moving them to the device.
            if 'best_model_valid' not in ckpt:
                continue

            model = ContinuousMixture.load_from_checkpoint(os.path.join(dir_path, ckpt)).to(device)
            model.n_chunks = n_chunks
            model.missing = False

            for n_bins in n_bins_list:
                # NOTE(review): latent_dim=4 is hard-coded; presumably it has to match
                # the latent dimension the checkpoints were trained with - confirm.
                test_sampler = GaussianQMCSampler(latent_dim=4, n_bins=n_bins)
                z, log_w = test_sampler(seed=42)
                if not only_test:
                    bmv_valid_lls_dict[n_bins].append(
                        model.eval_loader(valid_loader, z, log_w, device=device).cpu().numpy())
                bmv_test_lls_dict[n_bins].append(
                    model.eval_loader(test_loader, z, log_w, device=device).cpu().numpy())

    if not only_test:
        print('\n --- BMV on VALID ---')
        evaluate_lls_dict(bmv_valid_lls_dict)
    print('\n --- BMV on TEST ---')
    evaluate_lls_dict(bmv_test_lls_dict)

    print('\n' + str(exp_runs) + ' runs found and evaluated for ' + dataset_name + '\n\n')
    print('---------------------------------------------------------------------------\n')

Evaluating tretail..

 --- BMV on TEST ---
Evaluating using 128 bins..
AVG LL: -10.958385  STD LL: 0.029157 
Latex string: -10.96$\pm$0.03
Evaluating using 256 bins..
AVG LL: -10.899335  STD LL: 0.006448 
Latex string: -10.90$\pm$0.01
Evaluating using 512 bins..
AVG LL: -10.861258  STD LL: 0.006889 
Latex string: -10.86$\pm$0.01
Evaluating using 1024 bins..
AVG LL: -10.850134  STD LL: 0.004201 
Latex string: -10.85$\pm$0.00
Evaluating using 2048 bins..
AVG LL: -10.847236  STD LL: 0.004806 
Latex string: -10.85$\pm$0.00
Evaluating using 4096 bins..
AVG LL: -10.846115  STD LL: 0.005021 
Latex string: -10.85$\pm$0.01
Evaluating using 8192 bins..
AVG LL: -10.845361  STD LL: 0.005040 
Latex string: -10.85$\pm$0.01

5 runs found and evaluated for tretail


---------------------------------------------------------------------------

Evaluating dna..

 --- BMV on TEST ---
Evaluating using 128 bins..
AVG LL: -98.789055  STD LL: 0.520151 
Latex string: -98.79$\pm$0.52
Evaluating using 256 bins..