In [1]:
import os

import numpy as np
import pandas as pd
import scipy as sp
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Global figure styling: seaborn "ticks" theme with fonts scaled by 1.5, text
# rendered through LaTeX, Computer Modern math fontset, sans-serif font family.
sns.set_theme(style='ticks', font_scale=1.5)
mpl.rcParams["text.usetex"] = True
mpl.rcParams["mathtext.fontset"] = 'cm'
mpl.rcParams['font.family'] = ['sans-serif']

# Ensure the figure output directory exists. `exist_ok=True` replaces the
# separate exists()-then-makedirs() pair, which could race with another process
# creating the directory in between and silently accepted a regular file named
# '../figures' (makedirs now raises FileExistsError in that case).
os.makedirs('../figures', exist_ok=True)

In [2]:
# Optimizer names whose saved results are summarised in the cells below
# (presumably `scipy.optimize.minimize` method names — confirm against the
# fitting script). Each name is lower-cased with its hyphens removed to form
# the suffix of the corresponding output file, e.g. 'L-BFGS-B' -> 'lbfgsb'.
solvers = [
    'Nelder-Mead',
    'L-BFGS-B',
    'TNC',
    'SLSQP',
    'Powell',
    'trust-constr',
    'COBYLA',
    'COBYQA',
]

## 2-arm

In [3]:
# Ground-truth values -> softmax policy over the last axis.
# `sp.special.softmax` shifts each row by its maximum before exponentiating, so
# large values cannot overflow to inf/inf = nan the way a plain
# exp(v) / sum(exp(v)) ratio can. Results agree with the plain ratio up to
# floating-point round-off.
values = np.load('../data/2arm/values.npy')
pis = sp.special.softmax(values, axis=-1)

# (header, tags) pairs: the solver-independent estimates come first, printed
# without a header and followed by a blank gap; then one group per solver,
# printed under the solver's name.
vtag_groups = [(None, ['cvx', 'cvx_truc', 'mc'])]
for solver in solvers:
    s_tag = ''.join(solver.split('-')).lower()
    vtag_groups.append((solver, [f'cvx_{s_tag}', f'cvx_truc_{s_tag}', f'direct_{s_tag}']))

for header, vtags in vtag_groups:
    if header is not None:
        print(header)
    for vtag in vtags:
        htvalues = np.load(f'../outputs/2arm/htvalues_{vtag}.npy')
        htpis = sp.special.softmax(htvalues, axis=-1)
        # Element-wise kl_div summed over the last axis gives KL(pis || htpis)
        # (its extra "- x + y" terms cancel because both rows sum to 1); the
        # result is then averaged over the last remaining axis.
        kl = np.mean(np.sum(sp.special.kl_div(pis, htpis), axis=-1), axis=-1)
        q1, q3 = np.quantile(kl, [0.25, 0.75])
        # Reported as: median (first quartile-third quartile).
        print(f'{vtag}: {np.median(kl)} ({q1}-{q3})')
    if header is None:
        print('\n')

cvx: 0.013767522743656956 (0.008119858485275498-0.024228614944488337)
cvx_truc: 0.011329655825298244 (0.00655965427227584-0.023375244213780587)
mc: 0.006152892848866383 (0.003200409338230091-0.010610909819992482)


Nelder-Mead
cvx_neldermead: 0.011589707177244935 (0.0062473795665774065-0.02050473703489347)
cvx_truc_neldermead: 0.011955115325366996 (0.005858795534808449-0.025320373047068436)
direct_neldermead: 0.014058021883983345 (0.006836548917296422-0.024186330740267967)
L-BFGS-B
cvx_lbfgsb: 0.01155454706333826 (0.006202031936138126-0.02023855848244677)
cvx_truc_lbfgsb: 0.011955008766158851 (0.005858506281465259-0.02532005394425112)
direct_lbfgsb: 0.014155827792720916 (0.006959489836341849-0.024706374848310002)
TNC
cvx_tnc: 0.011554584082414839 (0.006202028238093751-0.020346636317331655)
cvx_truc_tnc: 0.011955024599833965 (0.005858533233261652-0.025319930237902574)
direct_tnc: 0.013965933313387848 (0.006806991247767034-0.024277962069386153)
SLSQP
cvx_slsqp: 0.011555497332841103 (0.00

In [4]:
# Ground-truth alphas for the 2-arm data set.
alphas = np.load('../data/2arm/alphas.npy')

# (header, tags) pairs: the 'mc' estimate is printed first without a header and
# followed by a blank gap; each solver then gets its own headed group.
report_groups = [(None, ['mc'])]
for solver in solvers:
    suffix = solver.replace('-', '').lower()
    report_groups.append(
        (solver, [f'{kind}_{suffix}' for kind in ('cvx', 'cvx_truc', 'direct')])
    )

for header, tags in report_groups:
    if header is not None:
        print(header)
    for tag in tags:
        htalphas = np.load(f'../outputs/2arm/htalphas_{tag}.npy')
        # Euclidean norm of the estimation error along the last axis.
        alpha_res = np.sqrt(np.square(alphas - htalphas).sum(axis=-1))
        lower_q, upper_q = np.quantile(alpha_res, [0.25, 0.75])
        # Reported as: median (first quartile-third quartile).
        print(f'{tag}: {np.median(alpha_res)} ({lower_q}-{upper_q})')
    if header is None:
        print('\n')

mc: 0.22335357305490813 (0.1376816183439598-0.34097474511832626)


Nelder-Mead
cvx_neldermead: 0.2968979486838254 (0.16584467231761454-0.48329342244070594)
cvx_truc_neldermead: 0.28434200420641986 (0.16563803557543907-0.4273618612956912)
direct_neldermead: 0.3703750067980357 (0.2178892786227986-0.6087172830785499)
L-BFGS-B
cvx_lbfgsb: 0.2934422683155258 (0.16332991907118366-0.4682940473853815)
cvx_truc_lbfgsb: 0.284337071285686 (0.16494215220210945-0.4278312798891608)
direct_lbfgsb: 0.3805333839226913 (0.22053445109634617-0.6234053924129284)
TNC
cvx_tnc: 0.29275739899113107 (0.16274535742860397-0.46462964035670623)
cvx_truc_tnc: 0.28433706505205425 (0.16494211290722363-0.42783202313892216)
direct_tnc: 0.36801837295748463 (0.2113233075658593-0.6051569777697059)
SLSQP
cvx_slsqp: 0.29273549797392123 (0.16273261796063668-0.46761787210128125)
cvx_truc_slsqp: 0.2843373247623523 (0.16456551549459109-0.4273602067504178)
direct_slsqp: 0.38363752116412364 (0.22051477814505277-0.6162424234370794)

In [5]:
# Ground-truth betas for the 2-arm data set.
betas = np.load('../data/2arm/betas.npy')

# (header, tags) pairs: the 'mc' estimate is printed first without a header and
# followed by a blank gap; each solver then gets its own headed group.
report_groups = [(None, ['mc'])]
for solver in solvers:
    suffix = solver.replace('-', '').lower()
    report_groups.append(
        (solver, [f'{kind}_{suffix}' for kind in ('cvx', 'cvx_truc', 'direct')])
    )

for header, tags in report_groups:
    if header is not None:
        print(header)
    for tag in tags:
        htbetas = np.load(f'../outputs/2arm/htbetas_{tag}.npy')
        # Euclidean norm of the estimation error along the last axis.
        beta_res = np.sqrt(np.square(betas - htbetas).sum(axis=-1))
        lower_q, upper_q = np.quantile(beta_res, [0.25, 0.75])
        # Reported as: median (first quartile-third quartile).
        print(f'{tag}: {np.median(beta_res)} ({lower_q}-{upper_q})')
    if header is None:
        print('\n')

mc: 0.6371320935520897 (0.399232748277047-0.9969182535048107)


Nelder-Mead
cvx_neldermead: 1.0466363254764626 (0.5955283020827813-1.7891990767020647)
cvx_truc_neldermead: 1.1247940560042282 (0.5897242501065096-2.239556978082648)
direct_neldermead: 0.9815320824096152 (0.5786744664275265-1.735647166564007)
L-BFGS-B
cvx_lbfgsb: 1.0569774213292207 (0.6048593816441822-1.7953967956203274)
cvx_truc_lbfgsb: 1.019079694089032 (0.5455070451282276-1.8695327120258352)
direct_lbfgsb: 1.0514168169975786 (0.6183701880199575-1.8063492508613055)
TNC
cvx_tnc: 1.056978295057739 (0.6002373423859528-1.7953152875056158)
cvx_truc_tnc: 1.0655628529680965 (0.5810181327827494-1.8859552039969234)
direct_tnc: 1.0396153336152025 (0.5877881506123179-1.748320629407159)
SLSQP
cvx_slsqp: 1.061569861429752 (0.6005807550034029-1.7953992622272281)
cvx_truc_slsqp: 1.0435828209982372 (0.5572326433698126-1.8815718665487302)
direct_slsqp: 1.0479605715850666 (0.6059298332488928-1.8494149985974315)
Powell
cvx_powell: 1.061559

In [6]:
# (header, log names) pairs: the solver-independent logs are printed first
# without a header and followed by a blank gap; each solver then gets its own
# headed group.
log_groups = [(None, ['log_cvx', 'log_cvx_truc', 'log_mc'])]
for solver in solvers:
    s_tag = ''.join(solver.split('-')).lower()
    log_groups.append((solver, [f'log_cvx_{s_tag}', f'log_cvx_truc_{s_tag}', f'log_direct_{s_tag}']))

for header, logfs in log_groups:
    if header is not None:
        print(header)
    for logf in logfs:
        df = pd.read_csv(f'../outputs/2arm/{logf}.csv')
        # Per-solver cvx logs have no single `time` column: their total is the
        # sum of the two stage timings. All other logs provide `time` directly.
        if header is not None and 'cvx' in logf:
            time_ms = (df['s1_time'] + df['s2_time']) * 1000
        else:
            time_ms = df['time'] * 1000
        # The factor 1000 presumably converts seconds to milliseconds — confirm
        # against the script that writes the logs.
        # One quantile call replaces three describe() calls and yields the same
        # 25% / 50% / 75% values. Keeping the column lookups out of the f-string
        # also avoids reusing the enclosing quote inside a replacement field,
        # which is a SyntaxError on Python < 3.12.
        q1, med, q3 = time_ms.quantile([0.25, 0.5, 0.75])
        print(f'{logf}: {med} ({q1}-{q3})')
    if header is None:
        print('\n')

log_cvx: 79.86472499999999 (64.25272625-96.57589425)
log_cvx_truc: 4.5175775 (3.87876275-5.302203499999999)
log_mc: 1811.1768960952759 (1740.8093810081482-1901.2468457221985)


Nelder-Mead
log_cvx_neldermead: 107.75412723864746 (90.46300977124022-124.72376582772826)
log_cvx_truc_neldermead: 10.40316427020255 (9.43296603955075-11.414138262756275)
log_direct_neldermead: 480.58295249938965 (409.71153974533075-559.2297911643982)
L-BFGS-B
log_cvx_lbfgsb: 157.193884507019 (137.57413367245474-175.93290631158442)
log_cvx_truc_lbfgsb: 74.40443370849606 (71.01581642031857-76.58539116122427)
log_direct_lbfgsb: 142.78721809387204 (118.2436347007751-178.5556077957153)
TNC
log_cvx_tnc: 117.03320077917473 (100.82175480834957-134.58015966531372)
log_cvx_truc_tnc: 15.286396059204051 (13.872232747619599-17.05042442868035)
log_direct_tnc: 926.6899824142456 (855.2671670913696-1154.2896628379822)
SLSQP
log_cvx_slsqp: 93.4010716881713 (76.89077953616331-110.26486112307735)
log_cvx_truc_slsqp: 10.30797966009

## 10-arm

In [7]:
# Ground-truth values -> softmax policy over the last axis.
# `sp.special.softmax` shifts each row by its maximum before exponentiating, so
# large values cannot overflow to inf/inf = nan the way a plain
# exp(v) / sum(exp(v)) ratio can. Results agree with the plain ratio up to
# floating-point round-off.
values = np.load('../data/10arm/values.npy')
pis = sp.special.softmax(values, axis=-1)

# (header, tags) pairs: the solver-independent estimates come first, printed
# without a header and followed by a blank gap; then one group per solver,
# printed under the solver's name.
vtag_groups = [(None, ['cvx', 'cvx_truc', 'mc'])]
for solver in solvers:
    s_tag = ''.join(solver.split('-')).lower()
    vtag_groups.append((solver, [f'cvx_{s_tag}', f'cvx_truc_{s_tag}', f'direct_{s_tag}']))

for header, vtags in vtag_groups:
    if header is not None:
        print(header)
    for vtag in vtags:
        htvalues = np.load(f'../outputs/10arm/htvalues_{vtag}.npy')
        htpis = sp.special.softmax(htvalues, axis=-1)
        # Element-wise kl_div summed over the last axis gives KL(pis || htpis)
        # (its extra "- x + y" terms cancel because both rows sum to 1); the
        # result is then averaged over the last remaining axis.
        kl = np.mean(np.sum(sp.special.kl_div(pis, htpis), axis=-1), axis=-1)
        q1, q3 = np.quantile(kl, [0.25, 0.75])
        # Reported as: median (first quartile-third quartile).
        print(f'{vtag}: {np.median(kl)} ({q1}-{q3})')
    if header is None:
        print('\n')

cvx: 0.09411877885120183 (0.04266488891474052-0.17699980016115657)
cvx_truc: 0.08301200452920424 (0.03855678852447575-0.16801155762439646)
mc: 0.0113275282059967 (0.00554408211667462-0.018475336386460504)


Nelder-Mead
cvx_neldermead: 0.033954425452211 (0.0170796689704529-0.05911401930462007)
cvx_truc_neldermead: 0.036560199833521684 (0.017026697359701033-0.06790875916322475)
direct_neldermead: 0.03334294662597885 (0.015970262690265465-0.05395930266761721)
L-BFGS-B
cvx_lbfgsb: 0.03385203171768979 (0.01707986851623123-0.05887616047643358)
cvx_truc_lbfgsb: 0.036431226840524446 (0.01702684624173995-0.06790730889152853)
direct_lbfgsb: 0.03376573493147783 (0.0158153191407922-0.05442629269357478)
TNC
cvx_tnc: 0.033852029302570336 (0.017079867190948533-0.05887616160504534)
cvx_truc_tnc: 0.036431136255236994 (0.01702684600080932-0.0679073072895888)
direct_tnc: 0.033423133684103115 (0.015815282317235523-0.05519352549475217)
SLSQP
cvx_slsqp: 0.03385202437943062 (0.017081053364559548-0.0588760450

In [8]:
# Ground-truth alphas for the 10-arm data set.
alphas = np.load('../data/10arm/alphas.npy')

# (header, tags) pairs: the 'mc' estimate is printed first without a header and
# followed by a blank gap; each solver then gets its own headed group.
report_groups = [(None, ['mc'])]
for solver in solvers:
    suffix = solver.replace('-', '').lower()
    report_groups.append(
        (solver, [f'{kind}_{suffix}' for kind in ('cvx', 'cvx_truc', 'direct')])
    )

for header, tags in report_groups:
    if header is not None:
        print(header)
    for tag in tags:
        htalphas = np.load(f'../outputs/10arm/htalphas_{tag}.npy')
        # Euclidean norm of the estimation error along the last axis.
        alpha_res = np.sqrt(np.square(alphas - htalphas).sum(axis=-1))
        lower_q, upper_q = np.quantile(alpha_res, [0.25, 0.75])
        # Reported as: median (first quartile-third quartile).
        print(f'{tag}: {np.median(alpha_res)} ({lower_q}-{upper_q})')
    if header is None:
        print('\n')

mc: 0.7124816772804101 (0.5849006520110306-0.8294858948900082)


Nelder-Mead
cvx_neldermead: 1.1587079879424609 (0.9060780508346171-1.4488581179738957)
cvx_truc_neldermead: 1.080480090952117 (0.8525242435962093-1.3725272503473338)
direct_neldermead: 1.3376057491803603 (1.133675327360376-1.5253100410634013)
L-BFGS-B
cvx_lbfgsb: 1.159298090165855 (0.9059388929233246-1.448858550488687)
cvx_truc_lbfgsb: 1.0804798734592125 (0.8525271402609037-1.3725386750796313)
direct_lbfgsb: 1.1905213155844705 (1.0170031435959124-1.3654268370569014)
TNC
cvx_tnc: 1.158704905608528 (0.9061720490493667-1.4487865520648615)
cvx_truc_tnc: 1.0804798835953677 (0.8525272593821642-1.3725386812003695)
direct_tnc: 1.2052858007355298 (1.0220665227636332-1.365174154002582)
SLSQP
cvx_slsqp: 1.158629094855157 (0.906052332259826-1.4482677727489803)
cvx_truc_slsqp: 1.0804778506272936 (0.8525278094678513-1.3725382089086098)
direct_slsqp: 1.1788212284320707 (1.0125543713448386-1.3663068208658582)
Powell
cvx_powell: 1.1586791

In [9]:
# Ground-truth betas for the 10-arm data set.
betas = np.load('../data/10arm/betas.npy')

# (header, tags) pairs: the 'mc' estimate is printed first without a header and
# followed by a blank gap; each solver then gets its own headed group.
report_groups = [(None, ['mc'])]
for solver in solvers:
    suffix = solver.replace('-', '').lower()
    report_groups.append(
        (solver, [f'{kind}_{suffix}' for kind in ('cvx', 'cvx_truc', 'direct')])
    )

for header, tags in report_groups:
    if header is not None:
        print(header)
    for tag in tags:
        htbetas = np.load(f'../outputs/10arm/htbetas_{tag}.npy')
        # Euclidean norm of the estimation error along the last axis.
        beta_res = np.sqrt(np.square(betas - htbetas).sum(axis=-1))
        lower_q, upper_q = np.quantile(beta_res, [0.25, 0.75])
        # Reported as: median (first quartile-third quartile).
        print(f'{tag}: {np.median(beta_res)} ({lower_q}-{upper_q})')
    if header is None:
        print('\n')

mc: 4.213815395457319 (3.693010702314796-4.667618650499081)


Nelder-Mead
cvx_neldermead: 8.173948620627904 (7.1810476552128115-8.999709628968457)
cvx_truc_neldermead: 7.917314318693062 (7.002407774493868-8.88390487870348)
direct_neldermead: 7.659895616286741 (6.679666687778061-8.518150458746256)
L-BFGS-B
cvx_lbfgsb: 8.175045823903636 (7.181944175108767-9.020387559359365)
cvx_truc_lbfgsb: 8.08445891132033 (7.138491397537776-8.975347059436217)
direct_lbfgsb: 6.653706396370258 (5.690833903051983-7.550870975462506)
TNC
cvx_tnc: 8.152551255697457 (7.13432547293712-8.979627462716294)
cvx_truc_tnc: 7.9554291236821335 (7.00175861497876-8.849093568893071)
direct_tnc: 6.416092285821627 (5.559535445729248-7.381353966717488)
SLSQP
cvx_slsqp: 8.165471913259191 (7.181052640411231-9.016392861343903)
cvx_truc_slsqp: 8.083626846499733 (7.13263784325836-8.964588069708057)
direct_slsqp: 6.641618986169201 (5.78496530862119-7.498177908180278)
Powell
cvx_powell: 8.174980059915578 (7.183984044542628-9.02027

In [10]:
# (header, log names) pairs: the solver-independent logs are printed first
# without a header and followed by a blank gap; each solver then gets its own
# headed group.
log_groups = [(None, ['log_cvx', 'log_cvx_truc', 'log_mc'])]
for solver in solvers:
    s_tag = ''.join(solver.split('-')).lower()
    log_groups.append((solver, [f'log_cvx_{s_tag}', f'log_cvx_truc_{s_tag}', f'log_direct_{s_tag}']))

for header, logfs in log_groups:
    if header is not None:
        print(header)
    for logf in logfs:
        df = pd.read_csv(f'../outputs/10arm/{logf}.csv')
        # Per-solver cvx logs have no single `time` column: their total is the
        # sum of the two stage timings. All other logs provide `time` directly.
        if header is not None and 'cvx' in logf:
            time_ms = (df['s1_time'] + df['s2_time']) * 1000
        else:
            time_ms = df['time'] * 1000
        # The factor 1000 presumably converts seconds to milliseconds — confirm
        # against the script that writes the logs.
        # One quantile call replaces three describe() calls and yields the same
        # 25% / 50% / 75% values. Keeping the column lookups out of the f-string
        # also avoids reusing the enclosing quote inside a replacement field,
        # which is a SyntaxError on Python < 3.12.
        q1, med, q3 = time_ms.quantile([0.25, 0.5, 0.75])
        print(f'{logf}: {med} ({q1}-{q3})')
    if header is None:
        print('\n')

log_cvx: 171.424039 (119.70571675-408.1179305)
log_cvx_truc: 28.9950975 (25.858091499999997-31.8265425)
log_mc: 3827.5080919265743 (3572.5343227386475-4111.123204231262)


Nelder-Mead
log_cvx_neldermead: 200.587754329956 (142.46140349578852-449.8156818457642)
log_cvx_truc_neldermead: 35.32748828143305 (32.29413596093745-38.803422256042474)
log_direct_neldermead: 7594.64967250824 (7194.132566452026-7646.17383480072)
L-BFGS-B
log_cvx_lbfgsb: 343.03699729602044 (224.2196249633789-634.2948731660156)
log_cvx_truc_lbfgsb: 236.87006595758055 (146.6072930765991-328.57389763937374)
log_direct_lbfgsb: 1425.4329204559326 (761.5240812301636-2283.9733958244324)
TNC
log_cvx_tnc: 213.1529572905273 (152.43665258508298-458.65021087619016)
log_cvx_truc_tnc: 41.83787055804435 (38.431758891540454-45.48371371713255)
log_direct_tnc: 7625.915288925171 (5205.575466156006-7871.224641799927)
SLSQP
log_cvx_slsqp: 189.6738939714355 (134.09710971859738-443.192450720642)
log_cvx_truc_slsqp: 34.43520172045895 (31.60