In [1]:
from importlib import reload
import sys


def ensure_on_sys_path(path, prepend=False):
    """Put ``path`` on ``sys.path`` unless it is already listed there.

    Re-running this notebook cell would otherwise add the same entries again
    on every execution; skipping entries that are already present keeps the
    cell idempotent.

    Args:
        path: Directory string to add to the import search path.
        prepend: When True the path is inserted at position 0 (highest
            priority); otherwise it is appended to the end.
    """
    if path in sys.path:
        return
    if prepend:
        sys.path.insert(0, path)
    else:
        sys.path.append(path)


# Parent directory goes first (presumably the repository root that holds the
# `lib`, `config` and `eval` modules imported below -- confirm).
ensure_on_sys_path('../', prepend=True)

# sys.path.remove("/home/users/yixiuz/.local/lib/python3.9/site-packages")
# NOTE(review): the two entries below are absolute, machine-specific paths;
# they only resolve on the host where this environment exists.
ensure_on_sys_path("/home/groups/swl1/yixiuz/torch_fid/lib/python3.9/site-packages")
ensure_on_sys_path("/home/groups/swl1/yixiuz/torch_fid/bin")

In [2]:
import matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import ml_collections

import lib.utils.bookkeeping as bookkeeping
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import lib.utils.utils as utils
import lib.models.models as models
import lib.models.model_utils as model_utils
import lib.datasets.datasets as datasets
import lib.datasets.dataset_utils as dataset_utils
import lib.sampling.sampling as sampling
import lib.sampling.sampling_utils as sampling_utils

import config.eval.piano_hollow as piano

%matplotlib inline

# Load the evaluation config, then the training config it points to.
eval_cfg = piano.get_config()
train_cfg = bookkeeping.load_ml_collections(Path(eval_cfg.train_config_path))

# Apply the (key-path, value) overrides listed in the eval config to the
# training config before the model is built from it.
for item in eval_cfg.train_config_overrides:
    utils.set_in_nested_dict(train_cfg, item[0], item[1])

S = train_cfg.data.S
# Prefer the GPU but fall back to the CPU so the notebook still loads on a
# host without CUDA (previously "cuda" was hard-coded and the cell failed).
# Alternative: device = torch.device(eval_cfg.device)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model_utils.create_model(train_cfg, device)

# NOTE(review): torch.load unpickles the checkpoint file -- only load
# checkpoints from a trusted source.
loaded_state = torch.load(Path(eval_cfg.checkpoint_path),
                          map_location=device)

# Clean the state-dict keys with the project helper before loading
# (presumably strips a DataParallel-style "module." prefix -- confirm).
modified_model_state = utils.remove_module_from_keys(loaded_state['model'])
model.load_state_dict(modified_model_state)

model.eval()

dataset = dataset_utils.get_dataset(eval_cfg, device)
data = dataset.data
test_dataset = np.load(eval_cfg.sampler.test_dataset)
condition_dim = eval_cfg.sampler.condition_dim
# NOTE(review): np.loadtxt returns floats by default, so values produced by
# indexing this key are floats -- confirm the downstream metrics expect that.
descramble_key = np.loadtxt(Path(eval_cfg.pianoroll_dataset_path) / 'descramble_key.txt')
# The mask stays the same: append one extra entry whose value equals its own
# index, so that index maps to itself when descrambling.
descramble_key = np.concatenate([descramble_key, np.array([descramble_key.shape[0]])], axis=0)

def descramble(samples):
    """Look up every entry of ``samples`` in ``descramble_key``, keeping the shape.

    Each element of ``samples`` is used as an index into the module-level
    ``descramble_key`` array; the looked-up values are returned in an array
    with the same shape as ``samples``.
    """
    original_shape = samples.shape
    flat_indices = samples.flatten()
    remapped = descramble_key[flat_indices]
    return remapped.reshape(*original_shape)


# Descrambled copy of the test set, used as the reference when scoring samples.
descrambled_test_dataset = descramble(test_dataset)

2024-05-20 21:43:15.691735: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-20 21:43:15.804940: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /share/software/user/open/cudnn/8.9.0.131/lib:/usr/lib64/nvidia:/share/software/user/open/cuda/12.2.0/targets/x86_64-linux/lib:/share/software/user/open/cuda/12.2.0/lib64:/share/software/user/open/cuda/12.2.0/nvvm/lib64:/share/software/user/open/cuda/12.2.0/extras/Debugger/lib64:/share/software/user/open/cuda/12.2.0/extras/CUPTI/lib64:/share/software/user/open/python/3.9.0/lib:/share/software/u

In [3]:
from eval import outliers, get_dist, hellinger, eval_mse_stats, save_results
from tqdm import tqdm

In [4]:
def run_experiments(num_repeats, sample_size, batch_size,
                    results_file,
                    method, g_steps, tau_steps,
                    c_steps, c_stepsize, corrector):
    """Sample conditionally with one sampler setting and score against the test set.

    The function configures the module-level ``eval_cfg.sampler`` (this is a
    side effect that persists after the call), builds a sampler from it, and
    for each repeat draws samples conditioned on the first ``condition_dim``
    entries of each test row. Every sample is descrambled and compared with
    its descrambled test row via ``hellinger`` and ``outliers``; the
    per-repeat means are printed, and all repeats are written with
    ``save_results`` at the end.

    Args:
        num_repeats: Number of independent passes over the evaluated rows;
            one result row is produced per pass.
        sample_size: Number of leading rows of ``test_dataset`` to evaluate.
            Values larger than the test set are clamped to its length.
        batch_size: Number of rows handed to the sampler per call.
        results_file: Destination passed through to ``save_results``.
        method: ``"gillespies"`` or ``"tauleaping"``; selects the sampler name.
        g_steps: Stored as ``eval_cfg.sampler.updates_per_eval``.
        tau_steps: Stored as ``eval_cfg.sampler.num_steps``.
        c_steps: Stored as ``eval_cfg.sampler.num_corrector_steps``.
        c_stepsize: Stored as ``eval_cfg.sampler.corrector_step_size_multiplier``.
        corrector: Stored as ``eval_cfg.sampler.balancing_function``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    sampler_names = {
        "gillespies": "ConditionalPCMultiGillespies",
        "tauleaping": "ConditionalPCTauLeapingAbsorbingInformed",
    }
    # Validate before touching the shared config. The previous `assert(False)`
    # is stripped under `python -O`, which would let an unknown method run
    # with a stale sampler name and fail only after sampling had finished.
    if method not in sampler_names:
        raise ValueError(
            f"Unknown sampling method {method!r}; "
            f"expected one of {sorted(sampler_names)}"
        )

    # Reset tqdm's internal instance registry (private API); presumably keeps
    # progress bars from stacking across calls in the notebook -- confirm.
    tqdm._instances.clear()

    # Specific to each method
    eval_cfg.sampler.num_steps = tau_steps
    eval_cfg.sampler.updates_per_eval = g_steps

    # Generic corrector fields
    eval_cfg.sampler.num_corrector_steps = c_steps
    eval_cfg.sampler.corrector_entry_time = 0.9
    eval_cfg.sampler.corrector_step_size_multiplier = c_stepsize
    eval_cfg.sampler.balancing_function = corrector

    eval_cfg.sampler.name = sampler_names[method]

    sampler = sampling_utils.get_sampler(eval_cfg)

    # `nfe` (presumably the number of network function evaluations) depends
    # only on the configuration, so it is computed once rather than per repeat.
    D = eval_cfg.data.shape[0] - eval_cfg.sampler.condition_dim
    if method == "gillespies":
        nfe = D / eval_cfg.sampler.updates_per_eval
    else:  # "tauleaping"
        nfe = eval_cfg.sampler.num_steps
    nfe += nfe * eval_cfg.sampler.corrector_entry_time * eval_cfg.sampler.num_corrector_steps

    # Never index past the end of the test set: a too-large `sample_size`
    # would otherwise yield a conditioner batch smaller than the requested
    # number of samples.
    test_size = min(sample_size, len(test_dataset))

    results = []

    for _ in range(num_repeats):

        h_dists = []
        outlier_rates = []
        for start in range(0, test_size, batch_size):
            end = min(start + batch_size, test_size)
            size = end - start

            conditioner = torch.from_numpy(test_dataset[start:end, :condition_dim]).to(device)
            samples, _unused = sampler.sample(model, size, 0, conditioner)
            # !Important to descramble!
            samples = descramble(samples)

            for offset in range(size):
                target = descrambled_test_dataset[start + offset, :]
                generated = samples[offset, :]
                h_dists.append(hellinger(target, generated, S))
                outlier_rates.append(outliers(target, generated, S))

        new_result = {
                'method': method,
                'g_steps': 0 if method != "gillespies" else eval_cfg.sampler.updates_per_eval,
                'tau_steps': 0 if method != "tauleaping" else eval_cfg.sampler.num_steps,
                'use_corrector': eval_cfg.sampler.corrector_entry_time > 0
                             and eval_cfg.sampler.num_corrector_steps > 0,
                'corrector': eval_cfg.sampler.balancing_function,
                'c_stepsize': eval_cfg.sampler.corrector_step_size_multiplier,
                'c_steps': eval_cfg.sampler.num_corrector_steps,
                'nfe': nfe,
                'h_dist': np.mean(h_dists),
                'outlier_rate': np.mean(outlier_rates),
            }
        print(new_result)
        results.append(new_result)

    save_results(results, results_file)

In [5]:
# Small run (2 repeats over 10 test rows) written to its own CSV file.
results_file = 'piano_results_test.csv'

num_repeats = 2
sample_size = 10
batch_size = 10

methods = ["tauleaping", "gillespies"]

updates_per_evals = [3, 2, 1]
num_sample_steps = [50, 100, 200]

# correctors[i] is paired with corrector_stepsizes[i].
correctors = ["mpf", "barker", "birthdeath"]
corrector_stepsizes = [1.0, 1.0, 0.1]
corrector_steps = 2

for method in methods:
    for i, corrector in enumerate(correctors):
        c_stepsize = corrector_stepsizes[i]
        # Only the first step setting of each list is exercised here.
        g_steps, tau_steps = updates_per_evals[0], num_sample_steps[0]
        c_steps = corrector_steps
        run_experiments(
            num_repeats=num_repeats,
            sample_size=sample_size,
            batch_size=batch_size,
            results_file=results_file,
            method=method,
            g_steps=g_steps,
            tau_steps=tau_steps,
            c_steps=c_steps,
            c_stepsize=c_stepsize,
            corrector=corrector,
        )

49it [00:01, 33.77it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.2668313174703481, 'outlier_rate': 0.083203125}


49it [00:01, 37.76it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.2600486274351373, 'outlier_rate': 0.082421875}
Experiment results saved to  piano_results_test.csv


49it [00:01, 37.70it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.2857392213506149, 'outlier_rate': 0.09921875}


49it [00:01, 37.61it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.29041439957098325, 'outlier_rate': 0.095703125}
Experiment results saved to  piano_results_test.csv


49it [00:01, 37.72it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.3186536161480142, 'outlier_rate': 0.1296875}


49it [00:01, 37.74it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.28446260233681525, 'outlier_rate': 0.10859375}
Experiment results saved to  piano_results_test.csv


225it [00:02, 106.60it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.42027543564354497, 'outlier_rate': 0.0859375}


225it [00:02, 108.07it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.3994891652684006, 'outlier_rate': 0.08984375}
Experiment results saved to  piano_results_test.csv


225it [00:02, 108.01it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.29192670713085234, 'outlier_rate': 0.091796875}


225it [00:02, 107.84it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.289521658416117, 'outlier_rate': 0.087890625}
Experiment results saved to  piano_results_test.csv


225it [00:02, 107.72it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.2763986583731081, 'outlier_rate': 0.093359375}


225it [00:02, 108.37it/s]                         

{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.2865092655555353, 'outlier_rate': 0.087890625}
Experiment results saved to  piano_results_test.csv





In [None]:
# Full sweep: every test row, 3 repeats, all corrector / step-count settings.
results_file = 'piano_results.csv'

num_repeats = 3
sample_size = test_dataset.shape[0]
batch_size = 200

methods = ["tauleaping", "gillespies"]

# updates_per_evals[j] (used by "gillespies") is paired with
# num_sample_steps[j] (used by "tauleaping").
updates_per_evals = [3, 2, 1]
num_sample_steps = [50, 100, 200]

# correctors[i] is paired with corrector_stepsizes[i].
correctors = ["mpf", "barker", "birthdeath"]
corrector_stepsizes = [1.0, 1.0, 0.1]
corrector_steps = 1

# 1 corrector step
c_steps = corrector_steps
for method in methods:
    for i, corrector in enumerate(correctors):
        c_stepsize = corrector_stepsizes[i]

        for j, g_steps in enumerate(updates_per_evals):
            tau_steps = num_sample_steps[j]

            run_experiments(
                num_repeats=num_repeats,
                sample_size=sample_size,
                batch_size=batch_size,
                results_file=results_file,
                method=method,
                g_steps=g_steps,
                tau_steps=tau_steps,
                c_steps=c_steps,
                c_stepsize=c_stepsize,
                corrector=corrector,
            )

49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.24it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3794871595882483, 'outlier_rate': 0.12499197070914697}


49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.24it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3780450319462578, 'outlier_rate': 0.12188062050359712}


49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.25it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3792639253193196, 'outlier_rate': 0.12286420863309352}
Experiment results saved to  piano_results.csv


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.05it/s]
99it [00:10,  9.19it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.38032061559315516, 'outlier_rate': 0.12338209789311408}


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:10,  9.19it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.3800044233128485, 'outlier_rate': 0.12243464157245632}


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
53it [00:06,  7.61it/s]

In [None]:
# Try different rates
# This cell reuses `methods`, `updates_per_evals`, `num_sample_steps`,
# `num_repeats`, `sample_size`, `batch_size`, `results_file` and `c_steps`
# from the cells above; only the corrector list and step sizes change.
correctors = ["mpf", "barker"]
corrector_stepsizes = [0.1, 0.1]

for method in methods:
    for i in range(len(correctors)):

        corrector = correctors[i]
        c_stepsize = corrector_stepsizes[i]

        for j in range(len(updates_per_evals)):

            g_steps = updates_per_evals[j]
            tau_steps = num_sample_steps[j]

            # `batch_size` is a required argument of run_experiments; it was
            # missing from this call, which raised TypeError immediately.
            run_experiments(num_repeats=num_repeats,
                            sample_size=sample_size, batch_size=batch_size,
                            results_file=results_file,
                            method=method,
                            g_steps=g_steps,
                            tau_steps=tau_steps,
                            c_steps=c_steps,
                            c_stepsize=c_stepsize,
                            corrector=corrector)