# Testing the EagerPy implementation of AMLS on a toy model

## 1. Imports

In [1]:
import numpy as np 
from time import time
import scipy.stats as stat
import matplotlib.pyplot as plt
import pandas as pd
import os
import argparse
from tqdm import tqdm
from importlib import reload
import torch
from stat_reliability_measure.home import ROOT_DIR
from datetime import datetime
from stat_reliability_measure.dev.utils import  float_to_file_float,str2bool,str2intList,str2floatList

In [2]:
from stat_reliability_measure.dev.amls import amls_ep
from stat_reliability_measure.dev.amls import amls_pyt

In [3]:
# Re-execute the already-imported amls_ep module so that edits made to its
# source file are picked up without restarting the kernel.
reload(amls_ep)

<module 'stat_reliability_measure.dev.amls.amls_ep' from '/home/karim-tito/stat_reliability_measure/dev/amls/amls_ep.py'>

## 2. Config

In [16]:
# Label used to build the names of the log directories and the progress messages.
method_name = "amls_ep"


class config:
    # Plain namespace of experiment settings; attributes are read (and some are
    # overwritten) by the cells that follow.

    # --- repetition / verbosity -------------------------------------------
    n_rep = 200          # repetitions of each configuration (bound of the tqdm loop)
    verbose = 0          # >3 and >=5 unlock extra prints in the cells below

    # --- AMLS tuning knobs --------------------------------------------------
    # NOTE(review): min_rate, decay, gain_rate and allow_zero_est are not read
    # by any visible cell -- presumably consumed elsewhere; TODO confirm.
    min_rate = 0.40
    clip_s = True        # forwarded to the AMLS calls as `clip_s`
    s_min = 8e-3         # forwarded as `s_min`
    s_max = 3            # forwarded as `s_max`
    n_max = 2000         # forwarded as `n_max`
    decay = 0.95
    gain_rate = 1.0001
    allow_zero_est = True

    # --- swept parameters ---------------------------------------------------
    # Each `*_range` list left empty is later replaced by a one-element list
    # holding the matching scalar value.
    N = 100              # forwarded to AMLS as `N`; also the batch size of amls_gen
    N_range = []

    T = 10               # forwarded to AMLS as `T`
    T_range = []

    ratio = 0.5          # K = int(N * ratio) unless last_particle is set
    ratio_range = []

    s = 1                # initial kernel strength, forwarded to AMLS as `s`
    s_range = []

    p_t = 1e-7           # target probability; the threshold is norm.isf(p_t)
    p_range = []

    # --- toy model ----------------------------------------------------------
    d = 1024             # second dimension of the Gaussian samples drawn by amls_gen
    epsilon = 1          # copied to a module-level name; not otherwise used in visible cells

    # --- reporting / bookkeeping --------------------------------------------
    # NOTE(review): tqdm_opt, save_config, print_config, update_aggr_res and
    # aggr_res_path are not read by any visible cell.
    tqdm_opt = True
    save_config = True
    print_config = True
    update_aggr_res = False
    aggr_res_path = None

    track_accept = False  # forwarded to the batch AMLS call
    track_finish = True   # gates the creation of the finish_flags list
    device = None         # passed as `device=` to torch and to the AMLS calls

    # NOTE(review): the seeds are not applied by any visible cell.
    torch_seed = 0
    np_seed = 0

    # --- environment --------------------------------------------------------
    log_dir = ROOT_DIR + "/logs/linear_gaussian_tests"
    batch_opt = True         # chooses ImportanceSplittingPytBatch over ImportanceSplittingPyt
    allow_multi_gpu = True   # when False, CUDA_VISIBLE_DEVICES is forced to "0"
    track_gpu = False        # when True, gpu_name is filled through GPUtil
    track_cpu = False        # when True, cpu_name / cores_number are filled
    core_numbers = None      # never read in visible cells; `cores_number` is the one that gets set
    gpu_name = None
    cpu_name = None
    cores_number = None
    correct_T = False        # not read by any visible cell
    last_particle = False    # when True, K = N - 1 instead of int(N * ratio)

In [17]:
# Replace every empty sweep range by a one-element list holding the matching
# scalar setting, and count the number of configurations in the full grid.
nb_runs = 1
for range_attr, scalar_attr in (
    ("N_range", "N"),
    ("T_range", "T"),
    ("ratio_range", "ratio"),
    ("s_range", "s"),
    ("p_range", "p_t"),
):
    if len(getattr(config, range_attr)) == 0:
        setattr(config, range_attr, [getattr(config, scalar_attr)])
    nb_runs *= len(getattr(config, range_attr))

if config.device is None:
    # NOTE(review): this binds a module-level `device` name only;
    # `config.device` is left untouched by this cell, while the cells below
    # pass `config.device` (not `device`) to torch and to the AMLS routines.
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'


if not config.allow_multi_gpu:
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

if config.track_gpu:
    import GPUtil
    gpus = GPUtil.getGPUs()
    if len(gpus) > 1:
        print("Multi gpus detected, only the first GPU will be tracked.")
    # Leave gpu_name at its current value when no GPU is reported instead of
    # failing with an IndexError.
    if gpus:
        config.gpu_name = gpus[0].name

if config.track_cpu:
    import cpuinfo
    # Query the CPU description once and reuse the result.
    cpu_info = cpuinfo.get_cpu_info()
    brand_keys = [key for key in cpu_info.keys() if 'brand' in key]
    config.cpu_name = cpu_info[brand_keys[0]]
    config.cores_number = os.cpu_count()


epsilon = config.epsilon
d = config.d

# Create the log directory tree. makedirs creates missing parents and does not
# fail on directories that already exist, so the cell can be re-executed and
# `log_dir` may be nested arbitrarily deep.
os.makedirs(ROOT_DIR + '/logs', exist_ok=True)
os.makedirs(config.log_dir, exist_ok=True)
raw_logs = os.path.join(config.log_dir, 'raw_logs/')
raw_logs_path = os.path.join(raw_logs, method_name)
os.makedirs(raw_logs_path, exist_ok=True)

# Timestamp truncated to the second (the fractional part is dropped).
loc_time = datetime.today().isoformat().split('.')[0]

exp_log_path = os.path.join(config.log_dir, method_name + '_t_' + loc_time.split('_')[0])
os.makedirs(exp_log_path, exist_ok=True)
exp_res = []


## 3. Test of AMLS

In [22]:
epsilon = config.epsilon

# Every tensor and AMLS call below receives `config.device`, so make sure it
# holds an actual device: use the first CUDA device when available, CPU otherwise.
if config.device is None:
    config.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# First vector of the canonical basis of R^d. torch.tensor is used because the
# legacy torch.Tensor(data, device=...) constructor rejects non-CPU devices.
e_1 = torch.tensor([1.0] + [0.0] * (d - 1), device=config.device)


def get_c_norm(p):
    """Return the threshold c such that a standard normal exceeds c with probability p."""
    return stat.norm.isf(p)


i_run = 0
config.n_rep = 200


In [23]:
# Build the toy-model ingredients (score functions, sampler, kernels) for each
# target probability. The names are rebound at every iteration, so after the
# loop they correspond to the last element of config.p_range.
for p_t in config.p_range:
    c = get_c_norm(p_t)
    P_target = stat.norm.sf(c)
    if config.verbose >= 5:
        print(f"P_target:{P_target}")
    # pretty useless a priori but should not hurt results
    arbitrary_thresh = 40

    def v_batch_pyt(X, c=c):
        """Return c minus the first coordinate of each row of X, clamped from below (torch)."""
        gap = c - X[:, 0]
        return torch.clamp(input=gap, min=-arbitrary_thresh, max=None)

    def v_batch_ep(X, c=c):
        """Same quantity as v_batch_pyt, computed through the tensor's own `clip` method."""
        gap = c - X[:, 0]
        return gap.clip(min_=-arbitrary_thresh, max_=None)

    def amls_gen(N):
        """Draw N standard Gaussian rows of dimension d on config.device."""
        return torch.randn(size=(N, d), device=config.device)

    def normal_kernel(x, s):
        """Normal law kernel, applicable to vectors: (x + s * noise) / sqrt(1 + s^2)."""
        noise = torch.randn(size=x.shape, device=config.device)
        return (x + s * noise) / np.sqrt(1 + s ** 2)

    def normal_kernel_ep(x, s):
        """Same kernel, with the noise drawn through the tensor's own `normal` method."""
        noise = x.normal(shape=x.shape)
        return (x + s * noise) / np.sqrt(1 + s ** 2)

    def h_V_batch_pyt(x):
        """Return the negated v_batch_pyt values as a column of shape (x.shape[0], 1)."""
        column = v_batch_pyt(x).reshape((x.shape[0], 1))
        return -column

    def h_V_batch_ep(x):
        """Return the negated v_batch_ep values as a column of shape (x.shape[0], 1)."""
        column = v_batch_ep(x).reshape((x.shape[0], 1))
        return -column



In [24]:
from itertools import product

# Benchmark the EagerPy implementation (amls_ep) over the (T, N, s, ratio) grid.
# NOTE(review): `p_t` and the model functions are whatever the definitions
# loop left behind, i.e. those of the last element of config.p_range.
#
# The grid is materialised so that the "run i/total" message is numbered from
# the configurations this cell actually executes, independently of any counter
# state left by previously executed cells.
run_grid = list(product(config.T_range, config.N_range,
                        config.s_range, config.ratio_range))
for i_run, (T, N, s, ratio) in enumerate(run_grid, start=1):
    loc_time = datetime.today().isoformat().split('.')[0]
    log_name = (f'{method_name}_N_{N}_T_{T}_s_{float_to_file_float(s)}'
                f'_r_{float_to_file_float(ratio)}_t__{loc_time}')
    log_path = os.path.join(exp_log_path, log_name)
    # exist_ok: re-running the cell within the same second must not fail.
    os.makedirs(log_path, exist_ok=True)

    K = int(N * ratio) if not config.last_particle else N - 1
    print(f"Starting {method_name} run {i_run}/{len(run_grid)}, with p_t= {p_t},N={N},K={K},T={T},s={s}")
    if config.verbose > 3:
        print(f"K/N:{K/N}")
    times = []
    rel_error = []   # kept for compatibility; never filled in this cell
    ests = []
    calls = []
    if config.track_finish:
        finish_flags = []   # created but never filled in this cell

    for _rep in tqdm(range(config.n_rep)):
        t = time()
        if config.batch_opt:
            amls_res = amls_ep.ImportanceSplittingPytBatch(
                amls_gen, normal_kernel_ep, K=K, N=N, s=s, h=h_V_batch_ep,
                tau=1e-15, n_max=config.n_max, clip_s=config.clip_s, T=T,
                s_min=config.s_min, s_max=config.s_max, verbose=config.verbose,
                device=config.device, track_accept=config.track_accept)
        else:
            amls_res = amls_ep.ImportanceSplittingPyt(
                amls_gen, normal_kernel_ep, K=K, N=N, s=s, h=h_V_batch_ep,
                tau=0, n_max=config.n_max, clip_s=config.clip_s, T=T,
                s_min=config.s_min, s_max=config.s_max, verbose=config.verbose,
                device=config.device)
        t = time() - t
        # First element: probability estimate; second: dict of diagnostics.
        est = amls_res[0]
        dict_out = amls_res[1]
        times.append(t)
        ests.append(est)
        calls.append(dict_out['Count_h'])

    # Summary statistics over the repetitions.
    times = np.array(times)
    ests = np.array(ests)
    calls = np.array(calls)
    abs_errors = np.abs(ests - p_t)
    rel_errors = abs_errors / p_t
    errs = abs_errors   # same quantity, kept under its historical name
    bias = np.mean(ests) - p_t
    q_1, med_est, q_3 = np.quantile(a=ests, q=[0.25, 0.5, 0.75])
    mean_calls = calls.mean()
    std_calls = calls.std()
    MSE = np.mean(abs_errors ** 2)
    MSE_adj = MSE * mean_calls          # MSE scaled by the mean number of calls
    MSE_rel = MSE / p_t ** 2
    MSE_rel_adj = MSE_rel * mean_calls

    print(f"mean est:{ests.mean()}, std est:{ests.std()}")
    print(f"mean rel error:{rel_errors.mean()}")
    print(f"MSE rel:{MSE_rel}")
    print(f"MSE adj.:{MSE_adj}")
    print(f"MSE rel. adj.:{MSE_rel_adj}")
    print(f"mean calls:{mean_calls}")

  0%|          | 1/200 [00:00<00:30,  6.55it/s]

Starting {method_name} run 1/1, with p_t= 1e-07,N=100,K=50,T=10,s=1


100%|██████████| 200/200 [00:26<00:00,  7.43it/s]

mean est:9.85465943813324e-08, std est:6.679141662058736e-08
mean rel error:0.5135169434547424
MSE rel:0.4463205722077249
MSE adj.:5.2040978719420716e-11
MSE rel. adj.:5204.097871942073
mean calls:11660.0





In [25]:
from itertools import product

# Benchmark the PyTorch implementation (amls_pyt) over the (T, N, s, ratio) grid.
# NOTE(review): `p_t` and the model functions are whatever the definitions
# loop left behind, i.e. those of the last element of config.p_range.
#
# This cell runs amls_pyt, so its messages and log directories carry that
# label rather than the module-level `method_name`.
pyt_method_name = "amls_pyt"

# The grid is materialised so that the "run i/total" message is numbered from
# the configurations this cell actually executes, independently of any counter
# state left by previously executed cells.
run_grid = list(product(config.T_range, config.N_range,
                        config.s_range, config.ratio_range))
for i_run, (T, N, s, ratio) in enumerate(run_grid, start=1):
    loc_time = datetime.today().isoformat().split('.')[0]
    log_name = (f'{pyt_method_name}_N_{N}_T_{T}_s_{float_to_file_float(s)}'
                f'_r_{float_to_file_float(ratio)}_t__{loc_time}')
    log_path = os.path.join(exp_log_path, log_name)
    # exist_ok: re-running the cell within the same second must not fail.
    os.makedirs(log_path, exist_ok=True)

    K = int(N * ratio) if not config.last_particle else N - 1
    print(f"Starting {pyt_method_name} run {i_run}/{len(run_grid)}, with p_t= {p_t},N={N},K={K},T={T},s={s}")
    if config.verbose > 3:
        print(f"K/N:{K/N}")
    times = []
    rel_error = []   # kept for compatibility; never filled in this cell
    ests = []
    calls = []
    if config.track_finish:
        finish_flags = []   # created but never filled in this cell

    for _rep in tqdm(range(config.n_rep)):
        t = time()
        if config.batch_opt:
            amls_res = amls_pyt.ImportanceSplittingPytBatch(
                amls_gen, normal_kernel, K=K, N=N, s=s, h=h_V_batch_pyt,
                tau=1e-15, n_max=config.n_max, clip_s=config.clip_s, T=T,
                s_min=config.s_min, s_max=config.s_max, verbose=config.verbose,
                device=config.device, track_accept=config.track_accept)
        else:
            amls_res = amls_pyt.ImportanceSplittingPyt(
                amls_gen, normal_kernel, K=K, N=N, s=s, h=h_V_batch_pyt,
                tau=0, n_max=config.n_max, clip_s=config.clip_s, T=T,
                s_min=config.s_min, s_max=config.s_max, verbose=config.verbose,
                device=config.device)
        t = time() - t
        # First element: probability estimate; second: dict of diagnostics.
        est = amls_res[0]
        dict_out = amls_res[1]
        times.append(t)
        ests.append(est)
        calls.append(dict_out['Count_h'])

    # Summary statistics over the repetitions.
    times = np.array(times)
    ests = np.array(ests)
    calls = np.array(calls)
    abs_errors = np.abs(ests - p_t)
    rel_errors = abs_errors / p_t
    errs = abs_errors   # same quantity, kept under its historical name
    bias = np.mean(ests) - p_t
    q_1, med_est, q_3 = np.quantile(a=ests, q=[0.25, 0.5, 0.75])
    mean_calls = calls.mean()
    std_calls = calls.std()
    MSE = np.mean(abs_errors ** 2)
    MSE_adj = MSE * mean_calls          # MSE scaled by the mean number of calls
    MSE_rel = MSE / p_t ** 2
    MSE_rel_adj = MSE_rel * mean_calls

    print(f"mean est:{ests.mean()}, std est:{ests.std()}")
    print(f"mean rel error:{rel_errors.mean()}")
    print(f"MSE rel:{MSE_rel}")
    print(f"MSE adj.:{MSE_adj}")
    print(f"MSE rel. adj.:{MSE_rel_adj}")
    print(f"mean calls:{mean_calls}")

  0%|          | 1/200 [00:00<00:33,  5.89it/s]

Starting {method_name} run 2/1, with p_t= 1e-07,N=100,K=50,T=10,s=1


100%|██████████| 200/200 [00:23<00:00,  8.51it/s]

mean est:9.848549962043761e-08, std est:6.642025233599491e-08
mean rel error:0.5180104899406434
MSE rel:0.44139436317769315
MSE adj.:5.14886524646779e-11
MSE rel. adj.:5148.865246467791
mean calls:11665.0



