In [1]:
from importlib import reload
import sys
# Make the directory one level above the working directory (the repo root when
# the notebook is run from its own folder) importable, so `lib.*` and `config.*`
# below resolve.
sys.path.insert(0, '../')

# sys.path.remove("/home/users/yixiuz/.local/lib/python3.9/site-packages")
# NOTE(review): machine-specific absolute paths into a separate `torch_fid`
# environment; they are appended (lowest priority) and only exist on this
# cluster. Consider reading them from an environment variable instead.
sys.path.append("/home/groups/swl1/yixiuz/torch_fid/lib/python3.9/site-packages")
sys.path.append("/home/groups/swl1/yixiuz/torch_fid/bin")

In [2]:
import matplotlib
import torch
import torch.nn as nn
import torch.nn.functional as F
import ml_collections

import lib.utils.bookkeeping as bookkeeping
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

import lib.utils.utils as utils
import lib.models.models as models
import lib.models.model_utils as model_utils
import lib.datasets.datasets as datasets
import lib.datasets.dataset_utils as dataset_utils
import lib.sampling.sampling as sampling
import lib.sampling.sampling_utils as sampling_utils

import config.eval.piano_hollow as piano

%matplotlib inline

# --- Configs: the eval config and the training config it points at ---------
eval_cfg = piano.get_config()
train_cfg = bookkeeping.load_ml_collections(Path(eval_cfg.train_config_path))

# Each override entry carries a nested key path (element 0) and a value
# (element 1) that is written into train_cfg.
for item in eval_cfg.train_config_overrides:
    nested_key, new_value = item[0], item[1]
    utils.set_in_nested_dict(train_cfg, nested_key, new_value)

# Number of states `S`, taken from the (overridden) training config.
S = train_cfg.data.S

# The device is pinned to CUDA; eval_cfg.device is not consulted.
# device = torch.device(eval_cfg.device)
device = torch.device("cuda")

# --- Model: build, restore checkpoint weights, switch to eval mode ---------
model = model_utils.create_model(train_cfg, device)

checkpoint_path = Path(eval_cfg.checkpoint_path)
loaded_state = torch.load(checkpoint_path, map_location=device)

# Presumably strips a wrapper prefix (e.g. DataParallel's `module.`) from the
# parameter names so they match the bare model — TODO confirm in lib.utils.
modified_model_state = utils.remove_module_from_keys(loaded_state['model'])
model.load_state_dict(modified_model_state)

model.eval()

dataset = dataset_utils.get_dataset(eval_cfg, device)
data = dataset.data
# Held-out sequences; run_experiments feeds their first `condition_dim`
# columns to the sampler as the conditioner.
test_dataset = np.load(eval_cfg.sampler.test_dataset)
condition_dim = eval_cfg.sampler.condition_dim
# Lookup table indexed by scrambled token id: descramble_key[k] is the
# descrambled value of token k.
# NOTE(review): np.loadtxt returns float64 by default, so descramble() yields
# float arrays — confirm hellinger/outliers are happy with float tokens.
descramble_key = np.loadtxt(eval_cfg.pianoroll_dataset_path + '/descramble_key.txt')
# The mask stays the same
# (append an identity entry so the extra id len(key) maps to itself)
descramble_key = np.concatenate([descramble_key, np.array([descramble_key.shape[0]])], axis=0)

def descramble(samples):
    """Map every scrambled token id in `samples` through `descramble_key`.

    Elementwise lookup; the result has the same shape as `samples`.
    Assumes `samples` holds integer ids in [0, len(descramble_key)) and supports
    `.flatten()` / `.shape` (NumPy array; TODO confirm for sampler output).
    """
    return descramble_key[samples.flatten()].reshape(*samples.shape)

descrambled_test_dataset = descramble(test_dataset)

2024-05-20 21:43:15.691735: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-05-20 21:43:15.804940: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /share/software/user/open/cudnn/8.9.0.131/lib:/usr/lib64/nvidia:/share/software/user/open/cuda/12.2.0/targets/x86_64-linux/lib:/share/software/user/open/cuda/12.2.0/lib64:/share/software/user/open/cuda/12.2.0/nvvm/lib64:/share/software/user/open/cuda/12.2.0/extras/Debugger/lib64:/share/software/user/open/cuda/12.2.0/extras/CUPTI/lib64:/share/software/user/open/python/3.9.0/lib:/share/software/u

In [3]:
from eval import outliers, get_dist, hellinger, eval_mse_stats, save_results
from tqdm import tqdm

In [4]:
def run_experiments(num_repeats, sample_size, batch_size,
                    results_file,
                    method, g_steps, tau_steps,
                    c_steps, c_stepsize, corrector,
                    corrector_entry_time=0.9):
    """Evaluate one sampler configuration on the conditional test set and save it.

    Configures the notebook-global ``eval_cfg.sampler`` for a single
    (method, corrector) setting. It then samples continuations for the first
    ``sample_size`` rows of ``test_dataset``, conditioning on their first
    ``condition_dim`` columns, and repeats this ``num_repeats`` times. Each repeat
    is scored by the mean Hellinger distance and mean outlier rate against the
    descrambled test rows. All repeats are written with ``save_results``.

    Args:
        num_repeats: number of independent repeats of the whole evaluation.
        sample_size: number of test rows to use; capped at ``len(test_dataset)``.
        batch_size: rows passed to ``sampler.sample`` per call.
        results_file: destination handed to ``save_results``.
        method: ``"gillespies"`` or ``"tauleaping"``.
        g_steps: value for ``sampler.updates_per_eval`` (Gillespie's).
        tau_steps: value for ``sampler.num_steps`` (tau-leaping).
        c_steps: value for ``sampler.num_corrector_steps``.
        c_stepsize: value for ``sampler.corrector_step_size_multiplier``.
        corrector: value for ``sampler.balancing_function``.
        corrector_entry_time: value for ``sampler.corrector_entry_time``
            (defaults to the previously hard-coded 0.9).

    Returns:
        The list of per-repeat result dicts that was saved.

    Raises:
        ValueError: if ``method`` is not one of the supported names.

    Note:
        Mutates the global ``eval_cfg`` in place.
    """
    sampler_names = {
        "gillespies": "ConditionalPCMultiGillespies",
        "tauleaping": "ConditionalPCTauLeapingAbsorbingInformed",
    }
    # Validate before touching eval_cfg. The old `assert(False)` is stripped by
    # `python -O`, which would reuse the previous sampler and then fail on an
    # unbound `nfe`.
    if method not in sampler_names:
        raise ValueError(
            f"unknown method {method!r}; expected one of {sorted(sampler_names)}")

    # Drop progress bars left over from an interrupted run. `_instances` is
    # private tqdm state, so don't assume it exists.
    leftover_bars = getattr(tqdm, "_instances", None)
    if leftover_bars is not None:
        leftover_bars.clear()

    # Specific to each method
    eval_cfg.sampler.name = sampler_names[method]
    eval_cfg.sampler.num_steps = tau_steps
    eval_cfg.sampler.updates_per_eval = g_steps

    # Generic corrector fields
    eval_cfg.sampler.num_corrector_steps = c_steps
    eval_cfg.sampler.corrector_entry_time = corrector_entry_time
    eval_cfg.sampler.corrector_step_size_multiplier = c_stepsize
    eval_cfg.sampler.balancing_function = corrector

    sampler = sampling_utils.get_sampler(eval_cfg)

    # Never read past the end of the test set. A short slice would give a
    # conditioner with fewer rows than `size`, failing the sampler's
    # `conditioner.shape[0] == N` check.
    test_size = min(sample_size, test_dataset.shape[0])

    # NFE depends only on the configuration, so compute it once. The number of
    # generated dimensions uses np.prod(shape), like the sampler's own total_D.
    D = int(np.prod(eval_cfg.data.shape)) - eval_cfg.sampler.condition_dim
    if method == "gillespies":
        nfe = D / g_steps
    else:  # "tauleaping"
        nfe = tau_steps
    # Corrector estimate (formula unchanged): c_steps extra evaluations per step,
    # scaled by corrector_entry_time.
    nfe += nfe * corrector_entry_time * c_steps
    use_corrector = corrector_entry_time > 0 and c_steps > 0

    results = []
    for _ in range(num_repeats):
        h_dists = []
        outlier_rates = []
        for start in range(0, test_size, batch_size):
            end = min(start + batch_size, test_size)
            size = end - start

            conditioner = torch.from_numpy(
                test_dataset[start:end, :condition_dim]).to(device)
            samples, _out = sampler.sample(model, size, 0, conditioner)
            # !Important to descramble! Metrics compare against descrambled rows.
            samples = descramble(samples)

            for i in range(size):
                reference = descrambled_test_dataset[start + i, :]
                h_dists.append(hellinger(reference, samples[i, :], S))
                outlier_rates.append(outliers(reference, samples[i, :], S))

        new_result = {
            'method': method,
            'g_steps': g_steps if method == "gillespies" else 0,
            'tau_steps': tau_steps if method == "tauleaping" else 0,
            'use_corrector': use_corrector,
            'corrector': corrector,
            'c_stepsize': c_stepsize,
            'c_steps': c_steps,
            'nfe': nfe,
            'h_dist': np.mean(h_dists),
            'outlier_rate': np.mean(outlier_rates),
        }
        print(new_result)
        results.append(new_result)

    save_results(results, results_file)
    return results

In [5]:
# Smoke test: 2 repeats on the first 10 test rows, using only the first
# (fewest-evaluation) step setting for every method/corrector pair.
results_file = 'piano_results_test.csv'

num_repeats = 2
sample_size = 10
batch_size = 10

methods = ["tauleaping", "gillespies"]

updates_per_evals = [3, 2, 1]
num_sample_steps = [50, 100, 200]

correctors = ["mpf", "barker", "birthdeath"]
corrector_stepsizes = [1.0, 1.0, 0.1]
corrector_steps = 2

# None of these vary inside the sweep, so fix them once up front.
g_steps = updates_per_evals[0]
tau_steps = num_sample_steps[0]
c_steps = corrector_steps

for method in methods:
    for i, corrector in enumerate(correctors):
        c_stepsize = corrector_stepsizes[i]
        run_experiments(num_repeats=num_repeats,
                        sample_size=sample_size,
                        batch_size=batch_size,
                        results_file=results_file,
                        method=method,
                        g_steps=g_steps,
                        tau_steps=tau_steps,
                        c_steps=c_steps,
                        c_stepsize=c_stepsize,
                        corrector=corrector)

49it [00:01, 33.77it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.2668313174703481, 'outlier_rate': 0.083203125}


49it [00:01, 37.76it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.2600486274351373, 'outlier_rate': 0.082421875}
Experiment results saved to  piano_results_test.csv


49it [00:01, 37.70it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.2857392213506149, 'outlier_rate': 0.09921875}


49it [00:01, 37.61it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.29041439957098325, 'outlier_rate': 0.095703125}
Experiment results saved to  piano_results_test.csv


49it [00:01, 37.72it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.3186536161480142, 'outlier_rate': 0.1296875}


49it [00:01, 37.74it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 140.0, 'h_dist': 0.28446260233681525, 'outlier_rate': 0.10859375}
Experiment results saved to  piano_results_test.csv


225it [00:02, 106.60it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.42027543564354497, 'outlier_rate': 0.0859375}


225it [00:02, 108.07it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.3994891652684006, 'outlier_rate': 0.08984375}
Experiment results saved to  piano_results_test.csv


225it [00:02, 108.01it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.29192670713085234, 'outlier_rate': 0.091796875}


225it [00:02, 107.84it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.289521658416117, 'outlier_rate': 0.087890625}
Experiment results saved to  piano_results_test.csv


225it [00:02, 107.72it/s]                         


{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.2763986583731081, 'outlier_rate': 0.093359375}


225it [00:02, 108.37it/s]                         

{'method': 'gillespies', 'g_steps': 3, 'tau_steps': 0, 'use_corrector': True, 'corrector': 'birthdeath', 'c_stepsize': 0.1, 'c_steps': 2, 'nfe': 209.06666666666666, 'h_dist': 0.2865092655555353, 'outlier_rate': 0.087890625}
Experiment results saved to  piano_results_test.csv





In [10]:
# Full sweep over the whole test set: every method x corrector x step setting,
# each repeated `num_repeats` times and saved to `results_file`.
# NOTE: in the recorded run, a tauleaping / barker / c_stepsize=1.0 repeat
# failed inside the sampler's corrector Poisson step with a negative rate
# (traceback below).
results_file = 'piano_results.csv'

num_repeats = 3
sample_size = test_dataset.shape[0]
batch_size = 200

methods = ["tauleaping",
           "gillespies"]

# Paired by index: iteration j uses updates_per_evals[j] (Gillespie's) and
# num_sample_steps[j] (tau-leaping).
updates_per_evals = [3, 2, 1]
num_sample_steps = [50, 100, 200]

# Paired by index: correctors[i] runs with corrector_stepsizes[i].
correctors = ["mpf", "barker", "birthdeath"]
corrector_stepsizes = [1.0, 1.0, 0.1]
corrector_steps = 1

# zip() would silently truncate mismatched lists (no strict= on Python 3.9).
if len(correctors) != len(corrector_stepsizes):
    raise ValueError("correctors and corrector_stepsizes must have equal length")
if len(updates_per_evals) != len(num_sample_steps):
    raise ValueError("updates_per_evals and num_sample_steps must have equal length")

# Single source of truth: previously a second hard-coded `c_steps = 1` was
# used and `corrector_steps` was ignored.
c_steps = corrector_steps
for method in methods:
    for corrector, c_stepsize in zip(correctors, corrector_stepsizes):
        for g_steps, tau_steps in zip(updates_per_evals, num_sample_steps):
            run_experiments(num_repeats=num_repeats,
                            sample_size=sample_size,
                            batch_size=batch_size,
                            results_file=results_file,
                            method=method,
                            g_steps=g_steps,
                            tau_steps=tau_steps,
                            c_steps=c_steps,
                            c_stepsize=c_stepsize,
                            corrector=corrector)

49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.24it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3794871595882483, 'outlier_rate': 0.12499197070914697}


49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.24it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3780450319462578, 'outlier_rate': 0.12188062050359712}


49it [00:06,  8.11it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:05,  9.25it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.3792639253193196, 'outlier_rate': 0.12286420863309352}
Experiment results saved to  piano_results.csv


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.05it/s]
99it [00:10,  9.19it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.38032061559315516, 'outlier_rate': 0.12338209789311408}


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:10,  9.19it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.3800044233128485, 'outlier_rate': 0.12243464157245632}


99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.06it/s]
99it [00:12,  8.05it/s]
99it [00:10,  9.19it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 100, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 190.0, 'h_dist': 0.37835053119944406, 'outlier_rate': 0.12143900950668036}
Experiment results saved to  piano_results.csv


199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:21,  9.16it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 200, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 380.0, 'h_dist': 0.37845285473918566, 'outlier_rate': 0.12135068730729702}


199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:21,  9.17it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 200, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 380.0, 'h_dist': 0.37843752512551715, 'outlier_rate': 0.12156346351490237}


199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:24,  8.04it/s]
199it [00:21,  9.16it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 200, 'use_corrector': True, 'corrector': 'mpf', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 380.0, 'h_dist': 0.3785529563636009, 'outlier_rate': 0.12169193216855087}
Experiment results saved to  piano_results.csv


49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:06,  8.11it/s]
49it [00:06,  8.10it/s]
49it [00:05,  9.23it/s]


{'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.39283093610583686, 'outlier_rate': 0.14403343396711202}


28it [00:03,  8.18it/s]


ValueError: Expected parameter rate (Tensor of shape (200, 224, 130)) of distribution Poisson(rate: torch.Size([200, 224, 130])) to satisfy the constraint GreaterThanEq(lower_bound=0.0), but found invalid values:
tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4803e-02],
         [1.9946e-21, 5.8667e-20, 1.3536e-24,  ..., 5.2780e-26,
          9.9655e-18, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.5329e-02],
         ...,
         [2.7588e-27, 6.3831e-25, 2.2352e-23,  ..., 4.5161e-27,
          8.8701e-22, 0.0000e+00],
         [9.7245e-28, 9.7345e-26, 3.3701e-25,  ..., 5.2593e-29,
          6.8246e-23, 0.0000e+00],
         [2.8975e-26, 1.6522e-28, 2.9208e-24,  ..., 1.1348e-28,
          1.5967e-24, 0.0000e+00]],

        [[2.8028e-24, 1.2704e-23, 1.0938e-24,  ..., 1.8976e-20,
          1.2571e-23, 0.0000e+00],
         [9.9945e-20, 3.9277e-21, 2.2082e-23,  ..., 3.9252e-20,
          1.7560e-21, 0.0000e+00],
         [5.7163e-21, 3.5796e-21, 1.0110e-23,  ..., 1.3082e-18,
          7.6886e-20, 0.0000e+00],
         ...,
         [6.1273e-23, 1.8870e-24, 3.9121e-25,  ..., 3.7225e-23,
          7.2968e-24, 0.0000e+00],
         [2.2943e-23, 1.1735e-24, 5.6519e-29,  ..., 8.4486e-22,
          9.0745e-24, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4649e-02]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 3.4128e-02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4705e-02],
         ...,
         [7.0503e-22, 4.1223e-19, 6.3445e-18,  ..., 3.4697e-22,
          4.1230e-19, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 2.5906e-02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4817e-02]],

        ...,

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02],
         [3.5123e-31, 2.4257e-29, 1.0306e-31,  ..., 1.3900e-27,
          1.5538e-31, 0.0000e+00],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02],
         [1.5888e-24, 2.0520e-27, 6.3884e-33,  ..., 7.7668e-34,
          1.7116e-28, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4836e-02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4828e-02],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4830e-02],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4735e-02],
         [1.9268e-17, 2.5908e-17, 4.7962e-19,  ..., 9.9308e-11,
          1.2739e-07, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.5022e-02]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02],
         [8.1822e-29, 1.0511e-31, 1.0836e-32,  ..., 2.9833e-34,
          2.0394e-33, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02],
         ...,
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 1.4647e-02],
         [7.0793e-23, 1.5494e-24, 2.2797e-22,  ..., 3.7854e-26,
          2.6089e-20, 0.0000e+00],
         [8.6550e-26, 2.4491e-26, 1.5469e-25,  ..., 3.5657e-28,
          1.2517e-23, 0.0000e+00]]], device='cuda:0')

In [None]:
# Post-mortem debugger for the sampler exception raised above. This is interactive
# only; remove this cell (or skip it) for a non-interactive Restart & Run All.
%debug

> [0;32m/home/groups/swl1/yixiuz/torch_fid/lib/python3.9/site-packages/torch/distributions/distribution.py[0m(68)[0;36m__init__[0;34m()[0m
[0;32m     66 [0;31m                [0mvalid[0m [0;34m=[0m [0mconstraint[0m[0;34m.[0m[0mcheck[0m[0;34m([0m[0mvalue[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     67 [0;31m                [0;32mif[0m [0;32mnot[0m [0mvalid[0m[0;34m.[0m[0mall[0m[0;34m([0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 68 [0;31m                    raise ValueError(
[0m[0;32m     69 [0;31m                        [0;34mf"Expected parameter {param} "[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     70 [0;31m                        [0;34mf"({type(value).__name__} of shape {tuple(value.shape)}) "[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> u
> [0;32m/home/groups/swl1/yixiuz/torch_fid/lib/python3.9/site-packages/torch/distributions/poisson.py[0m(51)[0;36m__init__[0;34m()[0m
[0;32m     49 [0;31m        [0;

ipdb> u
> [0;32m/home/groups/swl1/yixiuz/torch_fid/tauLDR/lib/sampling/sampling.py[0m(1288)[0;36msample[0;34m()[0m
[0;32m   1286 [0;31m                            [0mc_rate_hist[0m[0;34m.[0m[0mappend[0m[0;34m([0m[0mcorrector_rate[0m[0;34m.[0m[0mdetach[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0mnumpy[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1287 [0;31m[0;34m[0m[0m
[0m[0;32m-> 1288 [0;31m                        x = take_poisson_step(x, corrector_rate, 
[0m[0;32m   1289 [0;31m                            corrector_step_size_multiplier * h)
[0m[0;32m   1290 [0;31m                [0;32melif[0m [0mt[0m [0;32min[0m [0msave_ts[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> ll
[1;32m   1160 [0m    [0;32mdef[0m [0msample[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0mN[0m[0;34m,[0m [0mnum_intermediates[0m[0;34m,[0m [0mconditioner[

ipdb> score
*** NameError: name 'score' is not defined
ipdb> d
> [0;32m/home/groups/swl1/yixiuz/torch_fid/tauLDR/lib/sampling/sampling.py[0m(1251)[0;36mtake_poisson_step[0;34m()[0m
[0;32m   1249 [0;31m                [0;32mdef[0m [0mtake_poisson_step[0m[0;34m([0m[0min_x[0m[0;34m,[0m [0min_reverse_rates[0m[0;34m,[0m [0min_h[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m   1250 [0;31m                    [0mdiffs[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0marange[0m[0;34m([0m[0mS[0m[0;34m,[0m [0mdevice[0m[0;34m=[0m[0mdevice[0m[0;34m)[0m[0;34m.[0m[0mview[0m[0;34m([0m[0;36m1[0m[0;34m,[0m[0;36m1[0m[0;34m,[0m[0mS[0m[0;34m)[0m [0;34m-[0m [0min_x[0m[0;34m.[0m[0mview[0m[0;34m([0m[0mN[0m[0;34m,[0m[0msample_D[0m[0;34m,[0m[0;36m1[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m-> 1251 [0;31m                    [0mpoisson_dist[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mdistributions[0m[0;34m.

ipdb> scores 
tensor([[[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 7.5449e-01],
         [5.7945e-20, 1.7043e-18, 3.9322e-23,  ..., 1.5333e-24,
          2.8950e-16, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
          0.0000e+00, 8.0280e-01],
         ...,
         [8.0146e-26, 1.8543e-23, 6.4933e-22,  ..., 1.3120e-25,
          2.5768e-20, 0.0000e+00],
         [2.8250e-26, 2.8279e-24, 9.7903e-24,  ..., 1.5278e-27,
          1.9826e-21, 0.0000e+00],
         [8.4174e-25, 4.7998e-27, 8.4850e-23,  ..., 3.2967e-27,
          4.6386e-23, 0.0000e+00]],

        [[8.1423e-23, 3.6905e-22, 3.1777e-23,  ..., 5.5128e-19,
          3.6519e-22, 0.0000e+00],
         [2.9034e-18, 1.1410e-19, 6.4149e-22,  ..., 1.1403e-18,
          5.1014e-20, 0.0000e+00],
         [1.6606e-19, 1.0399e-19, 2.9371e-22,  ..., 3.8004e-17,
          2.2336e-18, 0.0000e+00],
         ...,
         [1.7800e-21, 5.4817e-23, 1.1365e-23,  ..., 1.0814e-21,
 

ipdb> ll
[1;32m   1160 [0m    [0;32mdef[0m [0msample[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mmodel[0m[0;34m,[0m [0mN[0m[0;34m,[0m [0mnum_intermediates[0m[0;34m,[0m [0mconditioner[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1161 [0m        [0;32massert[0m [0mconditioner[0m[0;34m.[0m[0mshape[0m[0;34m[[0m[0;36m0[0m[0;34m][0m [0;34m==[0m [0mN[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1162 [0m[0;34m[0m[0m
[1;32m   1163 [0m        [0mt[0m [0;34m=[0m [0;36m1.0[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1164 [0m        [0mcondition_dim[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0mcfg[0m[0;34m.[0m[0msampler[0m[0;34m.[0m[0mcondition_dim[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1165 [0m        [0mtotal_D[0m [0;34m=[0m [0mnp[0m[0;34m.[0m[0mprod[0m[0;34m([0m[0mself[0m[0;34m.[0m[0mcfg[0m[0;34m.[0m[0mdata[0m[0;34m.[0m[0mshape[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[1;32m   1166 [0m        [0msamp

ipdb> t
0.44571428571428573
ipdb> u
> [0;32m/tmp/ipykernel_21529/1918133229.py[0m(40)[0;36mrun_experiments[0;34m()[0m
[0;32m     38 [0;31m[0;34m[0m[0m
[0m[0;32m     39 [0;31m            [0mconditioner[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mfrom_numpy[0m[0;34m([0m[0mtest_dataset[0m[0;34m[[0m[0mstart[0m[0;34m:[0m[0mend[0m[0;34m,[0m [0;34m:[0m[0mcondition_dim[0m[0;34m][0m[0;34m)[0m[0;34m.[0m[0mto[0m[0;34m([0m[0mdevice[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 40 [0;31m            [0msamples[0m[0;34m,[0m [0mout[0m [0;34m=[0m [0msampler[0m[0;34m.[0m[0msample[0m[0;34m([0m[0mmodel[0m[0;34m,[0m [0msize[0m[0;34m,[0m [0;36m0[0m[0;34m,[0m [0mconditioner[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     41 [0;31m            [0;31m# !Important to descramble![0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     42 [0;31m            [0msamples[0m [0;34m=[0m [0mdescramble[0m[0;34m([0

In [None]:
# Last result recorded by the full sweep before it crashed (tauleaping + barker, tau_steps=50, c_steps=1):
# {'method': 'tauleaping', 'g_steps': 0, 'tau_steps': 50, 'use_corrector': True, 'corrector': 'barker', 'c_stepsize': 1.0, 'c_steps': 1, 'nfe': 95.0, 'h_dist': 0.39283093610583686, 'outlier_rate': 0.14403343396711202}

In [None]:
# Try different rates: rerun the mpf and barker correctors with a smaller
# corrector step size (0.1).
# Reuses results_file, num_repeats, sample_size, batch_size, methods,
# updates_per_evals, num_sample_steps and c_steps from the full-sweep cell.
correctors = ["mpf", "barker"]
corrector_stepsizes = [0.1, 0.1]

for method in methods:
    for corrector, c_stepsize in zip(correctors, corrector_stepsizes):
        for g_steps, tau_steps in zip(updates_per_evals, num_sample_steps):
            # batch_size is a required positional parameter of run_experiments;
            # omitting it (as before) raised a TypeError on the first call.
            run_experiments(num_repeats=num_repeats,
                            sample_size=sample_size,
                            batch_size=batch_size,
                            results_file=results_file,
                            method=method,
                            g_steps=g_steps,
                            tau_steps=tau_steps,
                            c_steps=c_steps,
                            c_stepsize=c_stepsize,
                            corrector=corrector)