In [4]:
from importlib import reload
import sys
# Put the parent directory of this notebook first on the import path;
# presumably this is the repository root that provides the `lib`, `config`
# and `eval` modules imported in later cells - TODO confirm.
sys.path.insert(0, '../')

# sys.path.remove("/home/users/yixiuz/.local/lib/python3.9/site-packages")
# Append an extra site-packages directory and its bin directory to the import path.
# NOTE(review): these are absolute, user-specific cluster paths; the notebook
# will not find this environment on another machine.
sys.path.append("/home/groups/swl1/yixiuz/torch_fid/lib/python3.9/site-packages")
sys.path.append("/home/groups/swl1/yixiuz/torch_fid/bin")

In [5]:
import matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import ml_collections

import lib.utils.bookkeeping as bookkeeping
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import lib.utils.utils as utils
import lib.models.models as models
import lib.models.model_utils as model_utils
import lib.datasets.datasets as datasets
import lib.datasets.dataset_utils as dataset_utils
import lib.sampling.sampling as sampling
import lib.sampling.sampling_utils as sampling_utils

import config.eval.piano_hollow as piano

%matplotlib inline

# Load the evaluation config, then the training config stored at the path it names.
eval_cfg = piano.get_config()
train_cfg = bookkeeping.load_ml_collections(Path(eval_cfg.train_config_path))

# Apply the evaluation-time overrides to the training config. Each override is
# indexed as item[0] / item[1] and handed to set_in_nested_dict
# (presumably a key path and the value to store there - TODO confirm).
for item in eval_cfg.train_config_overrides:
    utils.set_in_nested_dict(train_cfg, item[0], item[1])

# S is passed as the third argument to the hellinger/outliers metrics later in
# the notebook (presumably the number of discrete states - TODO confirm).
S = train_cfg.data.S
# device = torch.device(eval_cfg.device)
# NOTE(review): the device is hard-coded to CUDA; the device named in
# eval_cfg (commented out above) is ignored.
device = torch.device("cuda")

# Build the model from the (overridden) training config on the chosen device.
model = model_utils.create_model(train_cfg, device)

# Read the checkpoint, remapping its stored tensors onto `device`.
loaded_state = torch.load(Path(eval_cfg.checkpoint_path),
    map_location=device)

# The checkpoint keeps the weights under the 'model' key; the keys are passed
# through remove_module_from_keys before loading (presumably to strip a
# "module." prefix left by a parallel wrapper - TODO confirm).
modified_model_state = utils.remove_module_from_keys(loaded_state['model'])
model.load_state_dict(modified_model_state)

# Switch the model to evaluation mode before sampling.
model.eval()

# `dataset`/`data` come from the project dataset loader; the held-out test
# array used for conditioning and scoring is loaded separately from disk.
dataset = dataset_utils.get_dataset(eval_cfg, device)
data = dataset.data
test_dataset = np.load(eval_cfg.sampler.test_dataset)
# Number of leading positions of each test row that are fed to the sampler as conditioning.
condition_dim = eval_cfg.sampler.condition_dim
# Lookup table used by descramble(): entry k is the value that scrambled id k maps to.
descramble_key = np.loadtxt(eval_cfg.pianoroll_dataset_path + '/descramble_key.txt')
# The mask stays the same
# One extra entry is appended whose value equals the original table length, so
# the id just past the table maps to itself.
descramble_key = np.concatenate([descramble_key, np.array([descramble_key.shape[0]])], axis=0)

def descramble(samples):
    """Translate an array of scrambled ids through ``descramble_key``.

    Every entry of ``samples`` is used as an index into the module-level
    ``descramble_key`` table; the looked-up values are returned in an array
    with the same shape as ``samples``.
    """
    flat_ids = samples.flatten()
    looked_up = descramble_key[flat_ids]
    return looked_up.reshape(*samples.shape)

# Descramble the whole test set once up front; run_experiments compares its
# rows against descrambled samples.
descrambled_test_dataset = descramble(test_dataset)

2024-05-21 12:18:05.548310: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-21 12:18:08.707020: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /share/software/user/open/cudnn/8.9.0.131/lib:/usr/lib64/nvidia:/share/software/user/open/cuda/12.2.0/targets/x86_64-linux/lib:/share/software/user/open/cuda/12.2.0/lib64:/share/software/user/open/cuda/12.2.0/nvvm/lib64:/share/software/user/open/cuda/12.2.0/extras/Debugger/lib64:/share/software/user/open/cuda/12.2.0/extras/CUPTI/lib64:/share/software/user/open/python/3.9.0/lib:/share/software/u

In [9]:
from eval import outliers, get_dist, hellinger, eval_mse_stats, save_results
from tqdm import tqdm

In [14]:
def run_experiments(num_repeats, sample_size, batch_size,
                    results_file,
                    method, g_steps, tau_steps,
                    c_steps, c_stepsize, corrector,
                    max_retries=None):
    """Evaluate one sampler configuration several times and save the metrics.

    The function writes the requested settings into the module-level
    ``eval_cfg`` (a deliberate side effect: the sampler is built from it),
    then repeatedly samples continuations for the first ``sample_size`` rows
    of ``test_dataset`` and scores them against ``descrambled_test_dataset``
    with the ``hellinger`` and ``outliers`` metrics.

    Args:
        num_repeats: number of successful repeats to collect.
        sample_size: number of test rows to evaluate per repeat.
        batch_size: number of rows sampled per call to the sampler.
        results_file: path handed to ``save_results``.
        method: ``"gillespies"`` or ``"tauleaping"``.
        g_steps: stored as ``eval_cfg.sampler.updates_per_eval``.
        tau_steps: stored as ``eval_cfg.sampler.num_steps``.
        c_steps: stored as ``eval_cfg.sampler.num_corrector_steps``.
        c_stepsize: stored as ``eval_cfg.sampler.corrector_step_size_multiplier``.
        corrector: stored as ``eval_cfg.sampler.balancing_function``.
        max_retries: maximum number of consecutive failed attempts tolerated
            for a single repeat. ``None`` (the default) retries indefinitely,
            which is the historical behaviour.

    Raises:
        ValueError: if ``method`` is not a supported sampler name. Raised
            before ``eval_cfg`` is modified.
        RuntimeError: if a repeat fails more than ``max_retries`` times in a
            row; chained from the last underlying exception. Results gathered
            so far are not saved in that case.
    """
    sampler_names = {
        "gillespies": "ConditionalPCMultiGillespies",
        "tauleaping": "ConditionalPCTauLeapingAbsorbingInformed",
    }
    # Validate first so that a typo cannot leave the shared config half-updated.
    if method not in sampler_names:
        raise ValueError(
            f"Unknown method {method!r}; expected one of {sorted(sampler_names)}")

    # Drop progress bars left over from earlier (possibly interrupted) runs.
    tqdm._instances.clear()

    # Specific to each method
    eval_cfg.sampler.num_steps = tau_steps
    eval_cfg.sampler.updates_per_eval = g_steps

    # Generic corrector fields
    eval_cfg.sampler.num_corrector_steps = c_steps
    eval_cfg.sampler.corrector_entry_time = 0.9
    eval_cfg.sampler.corrector_step_size_multiplier = c_stepsize
    eval_cfg.sampler.balancing_function = corrector

    eval_cfg.sampler.name = sampler_names[method]

    sampler = sampling_utils.get_sampler(eval_cfg)

    # The number of function evaluations depends only on the configuration,
    # so it is computed once instead of on every repeat. Keeping it outside
    # the retry loop also means a configuration error (e.g. g_steps == 0)
    # surfaces immediately instead of being retried forever.
    D = eval_cfg.data.shape[0] - eval_cfg.sampler.condition_dim
    if method == "gillespies":
        nfe = D / eval_cfg.sampler.updates_per_eval
    else:
        nfe = eval_cfg.sampler.num_steps
    # Corrector steps add evaluations in proportion to the corrector entry time.
    nfe += nfe * eval_cfg.sampler.corrector_entry_time * eval_cfg.sampler.num_corrector_steps

    results = []

    test_size = sample_size
    repeat = 0
    consecutive_failures = 0
    while repeat < num_repeats:
        print("Repeat:", repeat)
        try:
            h_dists = []
            outlier_rates = []
            for start in range(0, test_size, batch_size):
                end = min(start + batch_size, test_size)
                size = end - start

                conditioner = torch.from_numpy(test_dataset[start:end, :condition_dim]).to(device)
                samples, out = sampler.sample(model, size, 0, conditioner)
                # !Important to descramble!
                samples = descramble(samples)

                for i in range(size):
                    h = hellinger(descrambled_test_dataset[start+i, :], samples[i, :], S)
                    r = outliers(descrambled_test_dataset[start+i, :], samples[i, :], S)
                    h_dists.append(h)
                    outlier_rates.append(r)
        except Exception as err:
            # Sampling occasionally diverges (NaNs), so a failed repeat is
            # retried. The real exception is reported so that unrelated
            # failures are not mistaken for numerical blow-ups.
            consecutive_failures += 1
            if max_retries is not None and consecutive_failures > max_retries:
                raise RuntimeError(
                    f"Repeat {repeat} still failing after {max_retries} retries") from err
            print(f"Repeat {repeat} failed ({type(err).__name__}: {err}). Retrying")
            continue

        consecutive_failures = 0
        new_result = {
                'method': method,
                'g_steps': 0 if method != "gillespies" else eval_cfg.sampler.updates_per_eval,
                'tau_steps': 0 if method != "tauleaping" else eval_cfg.sampler.num_steps,
                'use_corrector': eval_cfg.sampler.corrector_entry_time > 0
                             and eval_cfg.sampler.num_corrector_steps > 0,
                'corrector': eval_cfg.sampler.balancing_function,
                'c_stepsize': eval_cfg.sampler.corrector_step_size_multiplier,
                'c_steps': eval_cfg.sampler.num_corrector_steps,
                'nfe': nfe,
                'h_dist': np.mean(h_dists),
                'outlier_rate': np.mean(outlier_rates),
            }
        print(new_result)
        results.append(new_result)
        repeat += 1

    save_results(results, results_file)

In [16]:
results_file = 'piano_results.csv'

num_repeats = 3
sample_size = test_dataset.shape[0]
batch_size = 200

methods = ["tauleaping",
           "gillespies"]

# Paired by index j: updates_per_evals[j] is passed as g_steps and
# num_sample_steps[j] as tau_steps for the same run.
updates_per_evals = [3, 2, 1]
num_sample_steps = [50, 100, 200]

# Barker intermittently NaNs out for some reason; run_experiments retries
# failed repeats, so it is still included in this sweep.
correctors = ["barker"]
corrector_stepsizes = [.05]
corrector_steps = 1

# Resume support: set start_flag = True to skip every configuration that comes
# before (start_method, correctors[start_i], updates_per_evals[start_j]) in
# iteration order and continue the sweep from there.
# NOTE(review): start_i = 1 is out of range for the current one-entry
# `correctors` list; update it before enabling start_flag.
start_flag = False
start_method = "tauleaping"
start_i = 1
start_j = 2

if start_flag and (start_method not in methods
                   or not 0 <= start_i < len(correctors)
                   or not 0 <= start_j < len(updates_per_evals)):
    # Without this check an invalid resume point would silently skip the whole sweep.
    raise ValueError(
        f"Invalid resume point ({start_method!r}, {start_i}, {start_j}) for the current sweep grid")

# 1 corrector step
c_steps = corrector_steps
for method in methods:

    for i in range(len(correctors)):

        corrector = correctors[i]
        c_stepsize = corrector_stepsizes[i]

        for j in range(len(updates_per_evals)):

            if start_flag:
                # Skip (rather than overwrite the loop variables) until the
                # resume point is reached, so nothing is run twice afterwards.
                if (method, i, j) != (start_method, start_i, start_j):
                    continue
                print("Resuming sweep at method", method, "i", i, "j", j)
                start_flag = False

            g_steps = updates_per_evals[j]
            tau_steps = num_sample_steps[j]

            run_experiments(num_repeats=num_repeats,
                            sample_size=sample_size, batch_size=batch_size,
                            results_file=results_file,
                            method=method,
                            g_steps=g_steps,
                            tau_steps=tau_steps,
                            c_steps=c_steps,
                            c_stepsize=c_stepsize,
                            corrector=corrector)

Repeat: 0


49it [00:06,  8.10it/s]
22it [00:02,  8.37it/s]


NaNed out! Retrying
Repeat: 0


49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
7it [00:00, 10.58it/s]


NaNed out! Retrying
Repeat: 0


49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.26it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.05, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.39763854498071954, 'outlier_rate': 0.15198644655704008}
Repeat: 1


49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.25it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.05, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3987160968902745, 'outlier_rate': 0.15435508735868447}
Repeat: 2


39it [00:04,  8.02it/s]


NaNed out! Retrying
Repeat: 2


41it [00:05,  8.00it/s]


NaNed out! Retrying
Repeat: 2


49it [00:06,  8.11it/s]
11it [00:01,  9.28it/s]


NaNed out! Retrying
Repeat: 2


49it [00:06,  8.09it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.09it/s]
49it [00:05,  9.23it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.05, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3969865136074619, 'outlier_rate': 0.1530222250770812}
Experiment results saved to  piano_results.csv
Repeat: 0


99it [00:12,  8.00it/s]
68it [00:08,  8.13it/s]


KeyboardInterrupt: 

In [None]:
# {'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.39283093610583686, 'outlier_rate': 0.14403343396711202}

In [17]:
# Try different rates
# Re-runs the tau-leaping sweep with a different corrector step size. The step
# grids, repeat count, batch sizing, results file and c_steps are the values
# already defined by the earlier sweep cell.
correctors = ["barker"]
corrector_stepsizes = [0.1]

methods = ["tauleaping"]

for method in methods:
    for i, corrector in enumerate(correctors):
        # Step sizes are paired with correctors by position.
        c_stepsize = corrector_stepsizes[i]

        for j, g_steps in enumerate(updates_per_evals):
            # g_steps and tau_steps share the same index in their grids.
            tau_steps = num_sample_steps[j]

            run_experiments(
                num_repeats=num_repeats,
                sample_size=sample_size,
                batch_size=batch_size,
                results_file=results_file,
                method=method,
                g_steps=g_steps,
                tau_steps=tau_steps,
                c_steps=c_steps,
                c_stepsize=c_stepsize,
                corrector=corrector,
            )

Repeat: 0


49it [00:06,  8.08it/s]
49it [00:06,  8.09it/s]
39it [00:04,  8.01it/s]


NaNed out! Retrying
Repeat: 0


49it [00:06,  8.10it/s]
29it [00:03,  8.16it/s]


NaNed out! Retrying
Repeat: 0


49it [00:06,  8.11it/s]
28it [00:03,  8.19it/s]


NaNed out! Retrying
Repeat: 0


49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
27it [00:03,  8.11it/s]


NaNed out! Retrying
Repeat: 0


49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
44it [00:04,  9.08it/s]


NaNed out! Retrying
Repeat: 0


49it [00:06,  8.10it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.09it/s]
49it [00:06,  8.10it/s]
49it [00:05,  9.25it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3980528411952264, 'outlier_rate': 0.1532550745118191}
Repeat: 1


49it [00:06,  8.10it/s]
33it [00:04,  8.08it/s]


NaNed out! Retrying
Repeat: 1


49it [00:06,  8.09it/s]
49it [00:06,  8.09it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.07it/s]
49it [00:05,  9.24it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.39810894279452225, 'outlier_rate': 0.15282952209660844}
Repeat: 2


49it [00:06,  8.10it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.08it/s]
49it [00:06,  8.08it/s]
49it [00:05,  9.23it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.39933662275745935, 'outlier_rate': 0.1537247880267215}
Experiment results saved to  piano_results.csv
Repeat: 0


86it [00:10,  8.00it/s]


NaNed out! Retrying
Repeat: 0


95it [00:11,  7.97it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.00it/s]
57it [00:06,  8.25it/s]


NaNed out! Retrying
Repeat: 0


89it [00:11,  8.01it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.05it/s]
77it [00:09,  8.08it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
26it [00:02,  9.20it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.06it/s]
36it [00:04,  8.70it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
78it [00:09,  8.08it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.06it/s]
30it [00:03,  8.95it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.06it/s]
97it [00:12,  7.98it/s]


NaNed out! Retrying
Repeat: 0


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.05it/s]
99it [00:12,  8.06it/s]
99it [00:10,  9.20it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.39061603200047323, 'outlier_rate': 0.14056478031860226}
Repeat: 1


99it [00:12,  8.06it/s]
96it [00:12,  7.96it/s]


NaNed out! Retrying
Repeat: 1


55it [00:06,  8.26it/s]


NaNed out! Retrying
Repeat: 1


99it [00:12,  8.05it/s]
99it [00:12,  8.01it/s]
24it [00:02,  9.29it/s]


NaNed out! Retrying
Repeat: 1


99it [00:12,  8.05it/s]
99it [00:12,  8.03it/s]
90it [00:11,  7.99it/s]


NaNed out! Retrying
Repeat: 1


45it [00:05,  8.44it/s]


NaNed out! Retrying
Repeat: 1


99it [00:12,  7.90it/s]
99it [00:12,  8.04it/s]
99it [00:12,  8.04it/s]
99it [00:12,  8.00it/s]
99it [00:10,  9.17it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.3921691045483312, 'outlier_rate': 0.14192173047276466}
Repeat: 2


99it [00:12,  8.04it/s]
99it [00:12,  8.04it/s]
51it [00:06,  8.33it/s]


NaNed out! Retrying
Repeat: 2


72it [00:08,  8.10it/s]


NaNed out! Retrying
Repeat: 2


92it [00:11,  7.96it/s]


NaNed out! Retrying
Repeat: 2


99it [00:12,  8.01it/s]
97it [00:12,  7.96it/s]


NaNed out! Retrying
Repeat: 2


21it [00:02,  9.61it/s]


NaNed out! Retrying
Repeat: 2


99it [00:12,  8.05it/s]
99it [00:12,  8.04it/s]
99it [00:12,  8.00it/s]
99it [00:12,  8.04it/s]
99it [00:10,  9.14it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.3909310115923907, 'outlier_rate': 0.14262429342240493}
Experiment results saved to  piano_results.csv
Repeat: 0


199it [00:24,  8.02it/s]
199it [00:24,  8.00it/s]
199it [00:24,  8.02it/s]
199it [00:24,  8.02it/s]
199it [00:21,  9.15it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 200, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 380.0, 'h_dist': 0.3858292281506808, 'outlier_rate': 0.13467529547790338}
Repeat: 1


199it [00:24,  8.01it/s]
199it [00:24,  8.02it/s]
199it [00:24,  8.03it/s]
199it [00:24,  8.03it/s]
199it [00:21,  9.17it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 200, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 380.0, 'h_dist': 0.38419458643839044, 'outlier_rate': 0.13174460431654678}
Repeat: 2


199it [00:24,  8.01it/s]
199it [00:24,  8.03it/s]
199it [00:24,  8.03it/s]
199it [00:24,  8.04it/s]
199it [00:21,  9.17it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 200, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 0.1, 'c_steps': 1, 'nfe': 380.0, 'h_dist': 0.38476495294708696, 'outlier_rate': 0.13372783915724562}
Experiment results saved to  piano_results.csv


In [5]:
# results_file = 'piano_results_test.csv'

# num_repeats = 2
# sample_size = 10
# batch_size = 10

# methods = ["tauleaping", 
#            "gillespies"]

# updates_per_evals = [3, 2, 1]
# num_sample_steps = [50, 100, 200]

# correctors = ["mpf", "barker", "birthdeath"]
# corrector_stepsizes = [1.0, 1.0, 0.1]
# corrector_steps = 2

# for method in methods:
#     for i in range(len(correctors)):
#         corrector = correctors[i]
#         c_stepsize = corrector_stepsizes[i]
#         g_steps = updates_per_evals[0]
#         tau_steps = num_sample_steps[0]
#         c_steps = corrector_steps
#         run_experiments(num_repeats=num_repeats, 
#                         sample_size=sample_size, batch_size=batch_size,
#                         results_file=results_file,
#                         method=method, 
#                         g_steps=g_steps, 
#                         tau_steps=tau_steps, 
#                         c_steps=c_steps, 
#                         c_stepsize=c_stepsize, 
#                         corrector=corrector)

49it [00:01, 34.43it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.26276069890124965, 'outlier_rate': 0.077734375}


49it [00:01, 37.88it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.26873401892136234, 'outlier_rate': 0.0921875}
Experiment results saved to  piano_results_test.csv


9it [00:00, 54.73it/s]


KeyboardInterrupt: 