# Testing EagerPy implementation of H-SMC on toy model

## 1. Imports

In [5]:

import scipy.stats as stat
import numpy as np
import eagerpy as ep
from tqdm import tqdm
from time import time
from datetime import datetime
import os
import matplotlib.pyplot as plt
import torch
import GPUtil
import cpuinfo
import pandas as pd
import argparse
from stat_reliability_measure.dev.utils import str2bool,str2floatList,str2intList,float_to_file_float,dichotomic_search
from scipy.special import betainc
from importlib import reload
from stat_reliability_measure.home import ROOT_DIR

In [6]:
import stat_reliability_measure.dev.smc.smc_ep as smc_ep
import stat_reliability_measure.dev.ep_utils as e_u
import stat_reliability_measure.dev.smc.smc_pyt as smc_pyt
import stat_reliability_measure.dev.torch_utils as t_u

In [7]:
reload(smc_ep)
reload(e_u)

<module 'stat_reliability_measure.dev.ep_utils' from '/home/karim-tito/stat_reliability_measure/dev/ep_utils.py'>

## 2. Config

In [8]:
method_name="smc_ep"

In [10]:
class config:
    N=100
    N_range=[]
    T=1
    T_range=[]
    L=1
    L_range=[]
    min_rate=0.2
    
    alpha=0.2
    alpha_range=[]
    ess_alpha=0.8
    e_range=[]
    p_range=[]
    p_t=1e-6
    n_rep=10
    
    save_config=False 
    print_config=True
    d=1024
    verbose=0
    log_dir=ROOT_DIR+'/logs/linear_gaussian_tests'
    aggr_res_path = None
    update_agg_res=True
    sigma=1
    v1_kernel=True
    torch_seed=None
    gpu_name=None
    cpu_name=None
    cores_number=None
    track_gpu=False
    track_cpu=False
    device=None
    n_max=10000 
    allow_multi_gpu=False
    tqdm_opt=True
    allow_zero_est=True
    track_accept=True
    track_calls=False
    mh_opt=False
    adapt_dt=False
    adapt_dt_mcmc=False
    target_accept=0.574
    accept_spread=0.1
    dt_decay=0.999
    dt_gain=None
    dt_min=1e-3
    dt_max=0.5
    v_min_opt=False
    ess_opt=False
    only_duplicated=False
    np_seed=None
    lambda_0=0.5
    test2=False

    s_opt=False
    s=1
    clip_s=True
    s_min=1e-3
    s_max=3
    s_decay=0.95
    s_gain=1.0001

    track_dt=False
    mult_last=True
    linear=True

    track_ess=True
    track_beta=True
    track_dt=True
    track_v_means=True
    track_ratios=False

    kappa_opt=True

    adapt_func='ESS'
    M_opt = False
    adapt_step=True
    FT=True
    sig_dt=0.02
    L_min=1
    skip_mh=False
    GV_opt=False
   

In [18]:

if config.adapt_func.lower()=='simp_ess':
    adapt_func = lambda beta,v : smc_pyt.nextBetaSimpESS(beta_old=beta,v=v,lambda_0=config.lambda_0,max_beta=1e6)
prblm_str='linear_gaussian' if config.linear else 'gaussian'
if not config.linear:
    config.log_dir=config.log_dir.replace('linear_gaussian','gaussian')
if len(config.p_range)==0:
    config.p_range= [config.p_t]

if len(config.e_range)==0:
    config.e_range= [config.ess_alpha]


if len(config.N_range)==0:
    config.N_range= [config.N]


if len(config.T_range)==0:
    config.T_range= [config.T]

if len(config.L_range)==0:
    config.L_range= [config.L]
if len(config.alpha_range)==0:
    config.alpha_range= [config.alpha]


if not config.allow_multi_gpu:
    os.environ["CUDA_VISIBLE_DEVICES"]="0"



if config.torch_seed is None:
    config.torch_seed=int(time())
torch.manual_seed(seed=config.torch_seed)

if config.np_seed is None:
    config.np_seed=int(time())
torch.manual_seed(seed=config.np_seed)



if config.track_gpu:
    gpus=GPUtil.getGPUs()
    if len(gpus)>1:
        print("Multi gpus detected, only the first GPU will be tracked.")
    config.gpu_name=gpus[0].name

if config.track_cpu:
    config.cpu_name=cpuinfo.get_cpu_info()[[key for key in cpuinfo.get_cpu_info().keys() if 'brand' in key][0]]
    config.cores_number=os.cpu_count()


if config.device is None:
    config.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
    if config.verbose>=5:
        print(config.device)
    device=config.device
else:
    device=config.device

d=config.d
#epsilon=config.epsilon


if not os.path.exists(ROOT_DIR+'/logs'):
    os.mkdir(ROOT_DIR+'/logs')
if not os.path.exists(config.log_dir):
    os.mkdir(config.log_dir)

results_path=ROOT_DIR+'/logs/'+ prblm_str+'_tests/results.csv' 
if os.path.exists(results_path):
    results_g=pd.read_csv(results_path)
else:
    results_g=pd.DataFrame(columns=['p_t','mean_est','mean_time','mean_err','stdtime','std_est','T','N','rho','alpha','n_rep','min_rate','method'])
    results_g.to_csv(results_path,index=False)
raw_logs = os.path.join(config.log_dir,'raw_logs/')
if not os.path.exists(raw_logs):
    os.mkdir(raw_logs)
raw_logs_path=os.path.join(config.log_dir,'raw_logs/'+method_name)
if not os.path.exists(raw_logs_path):
    os.mkdir(raw_logs_path)

loc_time= datetime.today().isoformat().split('.')[0]
log_name=method_name+'_'+'_'+loc_time
exp_log_path=os.path.join(raw_logs_path,log_name)
if os.path.exists(exp_log_path):
    exp_log_path = exp_log_path +'_'+ str(np.random.randint(low=0,high=9))
os.mkdir(path=exp_log_path)

# if config.aggr_res_path is None:
#     aggr_res_path=os.path.join(config.log_dir,'aggr_res.csv')
# else:
#     aggr_res_path=config.aggr_res_path

if config.dt_gain is None:
    config.dt_gain=1/config.dt_decay
#if config.print_config:
    #print(config.json)

param_ranges = [config.N_range,config.T_range,config.alpha_range,config.p_range,config.L_range,config.e_range]
param_lens=np.array([len(l) for l in param_ranges])
nb_runs= np.prod(param_lens)

mh_str="adjusted" 
method=method_name+'_'+mh_str
save_every = 1
#adapt_func= smc_pyt.ESSAdaptBetaPyt if config.ess_opt else smc_pyt.SimpAdaptBetaPyt

kernel_str='v1_kernel' if config.v1_kernel else 'v2_kernel'


run_nb=0
iterator= tqdm(range(config.n_rep))
exp_res=[]

  0%|          | 0/10 [00:00<?, ?it/s]

### 3. Test H-SMC on linear toy model

In [19]:
assert config.linear

In [20]:
for p_t in config.p_range:
    if config.linear:
        
        get_c_norm= lambda p:stat.norm.isf(p)
        c=get_c_norm(p_t)
        if config.verbose>=1.:
            print(f'c:{c}')
        e_1= torch.Tensor([1]+[0]*(d-1)).to(device)
        V = lambda X: torch.clamp(input=c-X[:,0], min=0, max=None)
        V_ep = lambda X: (c-X[:,0]).clip(min_=0, max_ = None)
        
        gradV= lambda X: -((X[:,0]<c)[:,None]*e_1)
        gradV_ep = lambda X: -((X[:,0]<c)[:,None]*e_1)
        
        norm_gen = lambda N: torch.randn(size=(N,d)).to(device)
       
    else:
        epsilon=1
        p_target_f=lambda h: 0.5*betainc(0.5*(d-1),0.5,(2*epsilon*h-h**2)/(epsilon**2))
        h,P_target = dichotomic_search(f=p_target_f,a=0,b=epsilon,thresh=p_t,n_max=100)
        c=epsilon-h
        print(f'c:{c}',f'P_target:{P_target}')
        e_1= torch.Tensor([1]+[0]*(d-1)).to(device)
        V = lambda X: torch.clamp(input=torch.norm(X,p=2,dim=-1)*c-X[:,0], min=0, max=None)
        V_ep = lambda X: (ep.norms.l2(X,p=2,dim=-1)*c-X[:,0]).clip(min_=0, max_=None)
        gradV= lambda X: (c*X/torch.norm(X,p=2,dim=-1)[:,None] -e_1[None,:])*(X[:,0]<c*torch.norm(X,p=2,dim=1))[:,None]
        
        norm_gen = lambda N: torch.randn(size=(N,d)).to(device)

In [44]:
reload(e_u)
reload(smc_ep)

<module 'stat_reliability_measure.dev.smc.smc_ep' from '/home/karim-tito/stat_reliability_measure/dev/smc/smc_ep.py'>

In [45]:
for ess_t in config.e_range:
        if config.adapt_func.lower()=='ess':
            adapt_func = lambda beta,v : smc_ep.nextBetaESS(beta_old=beta,v=v,ess_alpha=ess_t,max_beta=1e6)
        for T in config.T_range:
            for L in config.L_range:
                for alpha in config.alpha_range:       
                    for N in config.N_range:
                        loc_time= datetime.today().isoformat().split('.')[0]
                        log_name=method_name+f'_N_{N}_T_{T}_L_{L}_a_{float_to_file_float(alpha)}_ess_{float_to_file_float(ess_t)}'+'_'+loc_time.split('_')[0]
                        log_path=os.path.join(exp_log_path,log_name)
                        if os.path.exists(log_path):
                            log_path=log_path+'_'+str(np.random.randint(low=0,high=10))
                        
                        
                        os.mkdir(path=log_path)
                        run_nb+=1
                        print(f'Run {run_nb}/{nb_runs}')
                        times=[]
                        ests = []
                        calls=[]
                        finished_flags=[]
                        iterator= tqdm(range(config.n_rep)) if config.tqdm_opt else range(config.n_rep)
                        print(f"Starting simulations with p_t:{p_t},ess_t:{ess_t},T:{T},alpha:{alpha},N:{N},L:{L}")
                        for i in iterator:
                            t=time()
                            sampler=smc_ep.SamplerSMC
                            p_est,res_dict,=sampler(gen=norm_gen,V= V,gradV=gradV,adapt_func=adapt_func,min_rate=config.min_rate,N=N,T=T,L=L,
                            alpha=alpha,n_max=10000,
                            verbose=config.verbose, track_accept=config.track_accept,track_beta=config.track_beta,track_v_means=config.track_v_means,
                            track_ratios=config.track_ratios,track_ess=config.track_ess,kappa_opt=config.kappa_opt
                            ,gaussian =True,accept_spread=config.accept_spread, 
                            adapt_dt=config.adapt_dt, dt_decay=config.dt_decay,only_duplicated=config.only_duplicated,
                            dt_gain=config.dt_gain,dt_min=config.dt_min,dt_max=config.dt_max,
                            v_min_opt=config.v_min_opt, lambda_0= config.lambda_0,
                            track_dt=config.track_dt,M_opt=config.M_opt,adapt_step=config.adapt_step,FT=config.FT,
                            sig_dt=config.sig_dt,L_min=config.L_min,skip_mh=config.skip_mh,
                            GV_opt=config.GV_opt
                            )
                            t1=time()-t

                            #print(p_est)
                  
                            
                            times.append(t1)
                            ests.append(p_est)
                            calls.append(res_dict['calls'])

  0%|          | 0/10 [00:00<?, ?it/s]

Run 11/1
Starting simulations with p_t:1e-06,ess_t:0.8,T:1,alpha:0.2,N:100,L:1
[ True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True  True  True  True  True  True  True  True  True
  True  True  True  True]





TypeError: only integer scalar arrays can be converted to a scalar index

In [45]:
for ess_t in config.e_range:
        if config.adapt_func.lower()=='ess':
            adapt_func = lambda beta,v : smc_pyt.nextBetaESS(beta_old=beta,v=v,ess_alpha=ess_t,max_beta=1e6)
        for T in config.T_range:
            for L in config.L_range:
                for alpha in config.alpha_range:       
                    for N in config.N_range:
                        loc_time= datetime.today().isoformat().split('.')[0]
                        log_name=method_name+f'_N_{N}_T_{T}_L_{L}_a_{float_to_file_float(alpha)}_ess_{float_to_file_float(ess_t)}'+'_'+loc_time.split('_')[0]
                        log_path=os.path.join(exp_log_path,log_name)
                        if os.path.exists(log_path):
                            log_path=log_path+'_'+str(np.random.randint(low=0,high=10))
                        
                        
                        os.mkdir(path=log_path)
                        run_nb+=1
                        print(f'Run {run_nb}/{nb_runs}')
                        times=[]
                        ests = []
                        calls=[]
                        finished_flags=[]
                        iterator= tqdm(range(config.n_rep)) if config.tqdm_opt else range(config.n_rep)
                        print(f"Starting simulations with p_t:{p_t},ess_t:{ess_t},T:{T},alpha:{alpha},N:{N},L:{L}")
                        for i in iterator:
                            t=time()
                            sampler=smc_pyt.SamplerSMC
                            p_est,res_dict,=sampler(gen=norm_gen,V= V,gradV=gradV,adapt_func=adapt_func,min_rate=config.min_rate,N=N,T=T,L=L,
                            alpha=alpha,n_max=10000,
                            verbose=config.verbose, track_accept=config.track_accept,track_beta=config.track_beta,track_v_means=config.track_v_means,
                            track_ratios=config.track_ratios,track_ess=config.track_ess,kappa_opt=config.kappa_opt
                            ,gaussian =True,accept_spread=config.accept_spread, 
                            adapt_dt=config.adapt_dt, dt_decay=config.dt_decay,only_duplicated=config.only_duplicated,
                            dt_gain=config.dt_gain,dt_min=config.dt_min,dt_max=config.dt_max,
                            v_min_opt=config.v_min_opt, lambda_0= config.lambda_0,
                            track_dt=config.track_dt,M_opt=config.M_opt,adapt_step=config.adapt_step,FT=config.FT,
                            sig_dt=config.sig_dt,L_min=config.L_min,skip_mh=config.skip_mh,
                            GV_opt=config.GV_opt
                            )
                            t1=time()-t

                            #print(p_est)
                  
                            
                            times.append(t1)
                            ests.append(p_est)
                            calls.append(res_dict['calls'])



Run 3/1
Starting simulations with p_t:1e-06,ess_t:0.8,T:1,alpha:0.2,N:100,L:1


100%|██████████| 10/10 [00:01<00:00,  6.71it/s]
