# A notebook for running experiments with generalized kernel thinning of Dwivedi and Mackey 2021 https://arxiv.org/pdf/2110.01593.pdf and standard thinning

In [2]:
import numpy as np
import numpy.random as npr
import numpy.linalg as npl

from scipy.spatial.distance import pdist

from argparse import ArgumentParser
import pickle as pkl
import pathlib
import os
import os.path

# import kernel thinning
from goodpoints import kt # kt.thin is the main thinning function; kt.split and kt.swap are other important functions
from goodpoints.util import isnotebook # Check whether this file is being executed as a script or as a notebook
from goodpoints.util import fprint  # for printing while flushing buffer
from goodpoints.tictoc import tic, toc # for timing blocks of code


# utils for generating samples, evaluating kernels, and mmds
from util_sample import sample, compute_params_p, sample_string
from util_k_mmd import kernel_eval, compute_params_k, compute_power_kernel_params_k
from util_k_mmd import p_kernel, ppn_kernel, pp_kernel, pnpn_kernel, squared_mmd, get_combined_results_filename
from util_parse import init_parser, convert_arg_flags
# for partial functions, to use kernel_eval for kernel
from functools import partial

# experiment functions
from util_experiments import run_kernel_thinning_experiment, kt_split_best, kt_split_rand
from util_experiments import run_standard_thinning_experiment, run_iid_thinning_experiment

# set things a bit when running the notebook
if isnotebook():
    # Autoreload packages that are modified
    %load_ext autoreload
    %autoreload 2
    %matplotlib inline
    %load_ext line_profiler
    # https://jakevdp.github.io/PythonDataScienceHandbook/01.07-timing-and-profiling.html

In [3]:

# for relevant parameters check util_parse file
parser = init_parser()
args, opt = parser.parse_known_args()
args = convert_arg_flags(args)

In [4]:
print(args, opt, parser)

Namespace(M=None, P='gauss', computemmd=True, computepower=True, d=2, filename='/accounts/projects/binyu/raaz.rsk/.jupyter/runtime/kernel-e4741f33-abf3-43d7-aec9-624ca3744dec.json', kernel='gauss', ktplus=True, m=6, nu=0.5, power=0.5, powerkt=True, rep0=0, repn=1, rerun=False, save_combined_results=False, stdthin=False, store_K=False, targetkt=True) [] ArgumentParser(prog='ipykernel_launcher.py', usage=None, description=None, formatter_class=<class 'argparse.HelpFormatter'>, conflict_handler='error', add_help=True)


# Set parameters for thinning experiments

- Helpful to check first init_parser function in util_parse.py to become familiar with list of arguments
- Code allows args.P supported by compute_params_p in util_sample, currently {mog, gauss, mcmc}
- Code allows args.kernel supported by compute_params_k in kernel_eval function, currently {gauss, bspline
- - For args.P="gauss", the only degree of freedom is in setting args.d (arbitrary d allowed), for "mog", its only args.M (supports only M = 4, 6, 8), and for "mcmc" its args.filename
- - For args.kernel, var_k is computed automatically based on P, set to 2d for gauss/mog, and median BW^2 for mcmc; It is equal to sigma^2 for Gauss, Laplace kernel in the notation of https://arxiv.org/pdf/2110.01593.pdf, and gamma^2 for IMQ/Matern kernels, 1/theta^2 for sinc, and it scales the distance in bspline kernel. args.nu is another parameter used only for IMQ/Matern/Bspline kernels. It denotes the nu parameter for IMQ/Matern kernel in the notation of https://arxiv.org/pdf/2110.01593.pdf, and the beta paraemter for the bspline kernel.
- - For args.power, there is a check in the code to ensure power kernel is valid based on the theory related to table 3 of the paper
- One can allow for general P and K by making changes in the functions listed above, and also checking feasilibility of MMD computations, and making changes to get_combined_results_filename

In [5]:
#
# Choose sample and kernel parameters
#

if isnotebook():
    args.d = 2
    args.M = 4
    args.P = "mog"
    args.kernel = "bspline"
    args.nu = 2
    args.computepower = True
    args.power = 2/3.
    # for bspline nu is same as beta, and with even nu, power should be (nu+2)/(2*nu+2)
    args.rep0 = 0 # starting rep index
    args.repn = 2 # number of reps
    args.m = 5 # size of input is n = 4^m, and output size is sqrt(n) = 2^m 
    
    # args.filename = 'Hinch_P_seed_1_temp_1_scaled'
    # collection of all MCMC filenames
    # ['Goodwin_RW', 'Goodwin_ADA-RW', 'Goodwin_MALA', 'Goodwin_PRECOND-MALA', 
    # 'Lotka_RW', 'Lotka_ADA-RW', 'Lotka_MALA', 'Lotka_PRECOND-MALA', 
    # 'Hinch_P_seed_1_temp_1', 'Hinch_P_seed_2_temp_1', 
    # 'Hinch_TP_seed_1_temp_8', 'Hinch_TP_seed_2_temp_8',
    # 'Hinch_P_seed_1_temp_1_scaled', 'Hinch_P_seed_2_temp_1_scaled', 
    # 'Hinch_TP_seed_1_temp_8_scaled', 'Hinch_TP_seed_2_temp_8_scaled']

d, params_p, var_k = compute_params_p(args)
# d can change from args when mcmc filename is specified
    
args.d = d
params_k, params_k_power = compute_params_k(args, var_k, power_kernel=args.computepower,power=args.power)

if args.ktplus: # if running KT+, need to define the KT+ kernel called as params_k_combo
    assert(args.power is not None)
    params_k_combo = dict()
    params_k_combo["name"] = "combo_"  + params_k["name"] + f"_{args.power}"
    params_k_combo["k"] = params_k.copy()
    params_k_combo["kpower"] = params_k_power.copy()
    params_k_combo["var"] = params_k["var"]
    params_k_combo["d"] = args.d

# if isnotebook():
print("p", params_p)
print("k", params_k)
print("kpower", params_k_power)
if args.ktplus:
    print("combo", params_k_combo)

p {'name': 'diag_mog', 'weights': array([0.25, 0.25, 0.25, 0.25]), 'means': array([[ 3.,  3.],
       [-3.,  3.],
       [-3., -3.],
       [ 3., -3.]]), 'covs': array([1., 1., 1., 1.]), 'd': 2, 'mean_sqdist': 40.0, 'saved_samples': False, 'flip_Pnmax': False}
k {'name': 'bspline', 'var': 4.0, 'd': 2, 'nu': 2}
kpower {'name': 'bspline0.6666666666666666_rt', 'd': 2, 'var': 4.0, 'nu': 1}
combo {'name': 'combo_bspline_0.6666666666666666', 'k': {'name': 'bspline', 'var': 4.0, 'd': 2, 'nu': 2}, 'kpower': {'name': 'bspline0.6666666666666666_rt', 'd': 2, 'var': 4.0, 'nu': 1}, 'var': 4.0, 'd': 2}


In [6]:
#
# Choose experiment parameters
#

# List of replicate ID numbers
rep_ids = np.arange(args.rep0, args.rep0+args.repn)

# List of halving round numbers m to evaluate
ms = range(args.m)

# Failure probability
delta = .5

if isnotebook():
    args.rerun = False
    rep_ids = range(10)

In [7]:
# initialize result matrices
# by default our code returns mmd(P, Pout), mmd(Pin, Pout), Pf-Poutf, Pinf-Poutf for f = k(0, .)
# if k is not Gauss, Pf-Poutf is set equal to Pinf-Pout f in the run_X_experiments function

if args.stdthin: #
    mmds_st = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P
    mmds_st_sin = np.zeros((max(ms)+1, len(rep_ids))) # mmds from Sin
    fun_diff_st  = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from P
    fun_diff_st_sin = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from Sin
    
if args.targetkt: 
    mmds_kt = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P
    mmds_kt_sin = np.zeros((max(ms)+1, len(rep_ids))) # mmds from Sin
    fun_diff_kt = np.zeros((max(ms)+1, len(rep_ids)))# fun diff from P
    fun_diff_kt_sin = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from Sin

if args.powerkt: 
    mmds_kt_krt = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P
    mmds_kt_krt_sin = np.zeros((max(ms)+1, len(rep_ids))) # mmds from Sin
    fun_diff_kt_krt = np.zeros((max(ms)+1, len(rep_ids)))# fun diff from P
    fun_diff_kt_krt_sin = np.zeros((max(ms)+1, len(rep_ids)))# fun diff from Sin


if args.ktplus: 
    mmds_ktplus = np.zeros((max(ms)+1, len(rep_ids))) # mmds from P
    mmds_ktplus_sin = np.zeros((max(ms)+1, len(rep_ids))) # mmds from Sin
    fun_diff_ktplus = np.zeros((max(ms)+1, len(rep_ids)))# fun diff from P
    fun_diff_ktplus_sin = np.zeros((max(ms)+1, len(rep_ids))) # fun diff from Sin

# Deploy thinning experiments

In [8]:
print(f"Exp setting: k = {params_k},  P = {params_p}, m = {ms}")       
tic()

# print(args.rerun, args)
for m in ms:
    #
    # Run experiments and store quality of the 2^m thinned coreset
    #
    if args.stdthin:
        mmd_st, mmd_st_sin, fd_st, fd_st_sin = run_standard_thinning_experiment(m, params_p=params_p, rerun=args.rerun,
                                                                                params_k_mmd=params_k, rep_ids=rep_ids,
                                                                               compute_mmds=args.computemmd)
        mmds_st[m, :] = mmd_st[m, :]
        mmds_st_sin[m, :] = mmd_st_sin[m, :]
        fun_diff_st[m, :] = fd_st[m, :]
        fun_diff_st_sin[m, :] = fd_st_sin[m, :]
        
    if args.targetkt: 
        mmd_kt, mmd_kt_sin, fd_kt, fd_kt_sin = run_kernel_thinning_experiment(m, thin_fun=kt.thin, thin_str="", params_p=params_p, rerun=args.rerun,
                                                            params_k_split=params_k, params_k_swap=params_k, rep_ids=rep_ids, 
                                                            delta=delta, store_K=args.store_K,
                                                                              compute_mmds=args.computemmd
                                                                             )
        mmds_kt[m, :] = mmd_kt[m, :]
        mmds_kt_sin[m, :] = mmd_kt_sin[m, :]
        fun_diff_kt[m, :] = fd_kt[m, :]
        fun_diff_kt_sin[m, :] = fd_kt_sin[m, :]
        

        
    if args.powerkt: 
        mmd_kt_krt, mmd_kt_krt_sin, fd_kt_krt, fd_kt_krt_sin = run_kernel_thinning_experiment(m, thin_fun=kt.thin, thin_str="", params_p=params_p, rerun=args.rerun,
                                                            params_k_split=params_k_power, params_k_swap=params_k, rep_ids=rep_ids, 
                                                            delta=delta, store_K=args.store_K,
                                                                compute_mmds=args.computemmd)
        mmds_kt_krt[m, :] = mmd_kt_krt[m, :]
        mmds_kt_krt_sin[m, :] = mmd_kt_krt_sin[m, :]
        fun_diff_kt_krt[m, :] = fd_kt_krt[m, :]
        fun_diff_kt_krt_sin[m, :] = fd_kt_krt_sin[m, :]
    
    if args.ktplus:
        mmd_ktplus, mmd_ktplus_sin, fd_ktplus, fd_ktplus_sin = run_kernel_thinning_experiment(m, thin_fun=kt.thin, thin_str="-plus", params_p=params_p, rerun=args.rerun,
                                                            params_k_split=params_k_combo, params_k_swap=params_k, rep_ids=rep_ids, 
                                                            delta=delta, store_K=args.store_K,
                                                                              compute_mmds=args.computemmd
                                                                             )
        mmds_ktplus[m, :] = mmd_ktplus[m, :]
        mmds_ktplus_sin[m, :] = mmd_ktplus_sin[m, :]
        fun_diff_ktplus[m, :] = fd_ktplus[m, :]
        fun_diff_ktplus_sin[m, :] = fd_ktplus_sin[m, :]
    
    if args.targetkt:
        print('mmd target_kt', mmds_kt)
toc()

Exp setting: k = {'name': 'bspline', 'var': 4.0, 'd': 2, 'nu': 2},  P = {'name': 'diag_mog', 'weights': array([0.25, 0.25, 0.25, 0.25]), 'means': array([[ 3.,  3.],
       [-3.,  3.],
       [-3., -3.],
       [ 3., -3.]]), 'covs': array([1., 1., 1., 1.]), 'd': 2, 'mean_sqdist': 40.0, 'saved_samples': False, 'flip_Pnmax': False}, m = range(0, 5)
Running kernel thinning  experiment with template results_new/kt-coresets-diag_mog_comp4_seed1234567-splitbspline_var4.000_seed9876543-swapbspline_var4.000-d2-m0-delta0.5-rep{}.pkl.....
-elapsed time: 0.156 (s)
-elapsed time: 0.106 (s)
-elapsed time: 0.0917 (s)
-elapsed time: 0.0763 (s)
-elapsed time: 0.123 (s)
-elapsed time: 0.116 (s)
-elapsed time: 0.0707 (s)
-elapsed time: 0.0823 (s)
-elapsed time: 0.0784 (s)
-elapsed time: 0.0981 (s)
-elapsed time: 1.01 (s)
Running kernel thinning  experiment with template results_new/kt-coresets-diag_mog_comp4_seed1234567-splitbspline0.6666666666666666_rt_var4.000_seed9876543-swapbspline_var4.000-d2-m0-del

In [9]:
if isnotebook():
    print(mmds_kt.mean(1), mmds_kt_krt.mean(1), mmds_ktplus.mean(1))

[0.         0.46281486 0.32777102 0.21444014 0.13400115] [0.         0.22355672 0.12696227 0.07392604 0.03372447] [0.         0.22355672 0.12827152 0.07413888 0.03622624]


# Save MMD and fun diff results

In [33]:
#
# Save all combined results
#
if isnotebook():
    # change this code to save results manually when running notebook
    save_combined_results = True #True if args is None else args.save_combined_results
else:
    save_combined_results = False if args is None else args.save_combined_results

generic_prefixes = ["-combinedmmd-", "-sin-combinedmmd-", "-combinedfundiff-", "-sin-combinedfundiff-"]

if save_combined_results:
    
    if args.stdthin:
        prefixes = ["mc" + prefix for prefix in generic_prefixes]
        data_arrays = [mmds_st, mmds_st_sin, fun_diff_st, fun_diff_st_sin]
        for prefix, data_array in zip(prefixes, data_arrays):
            filename = get_combined_results_filename(prefix, ms, params_p, params_k, params_k, rep_ids, delta)
            with open(filename, 'wb') as file:
                print(f"Saving {prefix} to {filename}")
                pkl.dump(data_array, file, protocol=pkl.HIGHEST_PROTOCOL)
    

    if args.targetkt:
        prefixes = ["kt" + prefix for prefix in generic_prefixes]
        data_arrays = [mmds_kt, mmds_kt_sin, fun_diff_kt, fun_diff_kt_sin]
        for prefix, data_array in zip(prefixes, data_arrays):
            filename = get_combined_results_filename(prefix, ms, params_p, params_k_split=params_k, params_k_swap=params_k, rep_ids=rep_ids, delta=delta)
            with open(filename, 'wb') as file:
                print(f"Saving {prefix} to {filename}")
                pkl.dump(data_array, file, protocol=pkl.HIGHEST_PROTOCOL)


    if args.powerkt:
        temp = "kt_krt" if args.power == 0.5 else f"kt_power{args.power}"
        prefixes = [temp + prefix for prefix in generic_prefixes]
        data_arrays = [mmds_kt_krt, mmds_kt_krt_sin, fun_diff_kt_krt, fun_diff_kt_krt_sin]
        for prefix, data_array in zip(prefixes, data_arrays):
            filename = get_combined_results_filename(prefix, ms, params_p, params_k_split=params_k_power, params_k_swap=params_k, rep_ids=rep_ids, delta=delta)
            with open(filename, 'wb') as file:
                print(f"Saving {prefix} to {filename}")
                pkl.dump(data_array, file, protocol=pkl.HIGHEST_PROTOCOL)
                
    if args.ktplus:
        prefixes = [f"kt-plus{args.power}" + prefix for prefix in generic_prefixes]
        data_arrays = [mmds_ktplus, mmds_ktplus_sin, fun_diff_ktplus, fun_diff_ktplus_sin]
        for prefix, data_array in zip(prefixes, data_arrays):
            filename = get_combined_results_filename(prefix, ms, params_p, params_k_split=params_k_combo, params_k_swap=params_k, rep_ids=rep_ids, delta=delta)
            with open(filename, 'wb') as file:
                print(f"Saving {prefix} to {filename}")
                pkl.dump(data_array, file, protocol=pkl.HIGHEST_PROTOCOL)
                
         

Saving kt-combinedmmd- to results_new/combined/kt-combinedmmd--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl
Saving kt-sin-combinedmmd- to results_new/combined/kt-sin-combinedmmd--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl
Saving kt-combinedfundiff- to results_new/combined/kt-combinedfundiff--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl
Saving kt-sin-combinedfundiff- to results_new/combined/kt-sin-combinedfundiff--gauss_var1.0_seed1234567--split_imq_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl
Saving kt_krt-combinedmmd- to results_new/combined/kt_krt-combinedmmd--gauss_var1.0_seed1234567--split_imq_rt_var4.0_seed9876543_nu0.5-swap_imq_var4.0_nu0.5-d2-m4-delta0.5-rep2.pkl
Saving kt_krt-sin-combinedmmd- to results_new/combined/kt_krt-sin-combinedmmd--gauss_var1.0_seed1234567--sp