In [1]:
import hyperopt
from hyperopt import fmin, tpe, hp, space_eval

import pandas as pd
import numpy as np
from pyids import IDS
from pyids.algorithms import mine_CARs
from pyids.algorithms.ids_multiclass import IDSOneVsAll
from pyids.data_structures import IDSRuleSet

from pyarc.qcba.data_structures import QuantitativeDataFrame
from pyarc.data_structures import TransactionDB
from pyarc import CBA

import random
import logging
import time

import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# Read the CSV at ../data/iris0.csv into a pandas DataFrame and wrap it in
# the QuantitativeDataFrame container that the IDS calls below consume.
df = pd.read_csv(filepath_or_buffer="../data/iris0.csv")
quant_dataframe = QuantitativeDataFrame(df)


In [3]:

# Baseline run: fit a one-vs-all IDS model with a hand-picked lambda array
# and a rule cutoff of 200, then score AUC on the same data it was fit on.
ids_multiclass = IDSOneVsAll(algorithm="RUSM")
ids_multiclass.fit(
    quant_dataframe,
    rule_cutoff=200,
    lambda_array=[1, 1, 0, 0, 100000, 10000, 1000000],
)
auc = ids_multiclass.score_auc(quant_dataframe)



In [4]:
def is_solution_interpretable(
    metrics,
    max_fraction_overlap=0.10,
    required_fraction_classes=1.0,
    max_fraction_uncovered=0.15,
    max_average_rule_width=8,
    max_ruleset_length=10,
):
    """Check a set of interpretability metrics against configurable thresholds.

    Parameters
    ----------
    metrics : mapping
        Must provide the keys ``"fraction_overlap"``, ``"fraction_classes"``,
        ``"fraction_uncovered"``, ``"average_rule_width"`` and
        ``"ruleset_length"``.
    max_fraction_overlap : number, default 0.10
        Inclusive upper bound for ``metrics["fraction_overlap"]``.
    required_fraction_classes : number, default 1.0
        Value ``metrics["fraction_classes"]`` must equal exactly.
    max_fraction_uncovered : number, default 0.15
        Inclusive upper bound for ``metrics["fraction_uncovered"]``.
    max_average_rule_width : number, default 8
        Exclusive upper bound for ``metrics["average_rule_width"]``.
    max_ruleset_length : number, default 10
        Inclusive upper bound for ``metrics["ruleset_length"]``.

    Returns
    -------
    Truthy only when every metric satisfies its threshold. The defaults
    reproduce the thresholds this notebook originally hard-coded.

    Raises
    ------
    KeyError
        If a required key is missing from ``metrics`` (only reached when all
        earlier checks passed, because the chain short-circuits).
    """
    # The comparisons are written positively (``<=`` / ``<`` / ``==``) on
    # purpose: a NaN metric compares False and therefore rejects the solution.
    return (
        metrics["fraction_overlap"] <= max_fraction_overlap and
        metrics["fraction_classes"] == required_fraction_classes and
        metrics["fraction_uncovered"] <= max_fraction_uncovered and
        metrics["average_rule_width"] < max_average_rule_width and
        metrics["ruleset_length"] <= max_ruleset_length
    )

In [5]:
def objective(args):
    """Hyperopt objective: fit IDS with the sampled lambdas and return a loss.

    Parameters
    ----------
    args : dict
        One sampled point of the search space; must contain the keys
        ``"lambda1"`` .. ``"lambda7"``.

    Returns
    -------
    ``0`` when the fitted rule set fails ``is_solution_interpretable``,
    otherwise the negated AUC (hyperopt minimizes, so a higher AUC gives a
    lower loss).

    Notes
    -----
    Both the fit and the AUC use the module-level ``quant_dataframe``, so the
    score is measured on the same data the model was trained on.
    """
    # Look the weights up by name so that position i of the array is always
    # lambda(i+1), instead of depending on the iteration order of the dict
    # hyperopt hands in.
    lambda_array = [args["lambda{}".format(index)] for index in range(1, 8)]
    print(lambda_array)

    ids_multiclass = IDSOneVsAll(algorithm="RUSM")
    ids_multiclass.fit(quant_dataframe, lambda_array=lambda_array, rule_cutoff=30)

    metrics = ids_multiclass.score_interpretability_metrics(quant_dataframe)

    # Uninterpretable solutions get a loss of 0, which is worse than any
    # solution with a positive AUC (those have a negative loss).
    if not is_solution_interpretable(metrics):
        return 0

    auc = ids_multiclass.score_auc(quant_dataframe)
    print("AUC", auc)

    return -auc

# Search space: one independent uniform prior per lambda weight. The first
# argument of hp.uniform is the hyperopt label (these labels are the keys of
# the dict printed at the end); the dict keys are what `objective` receives.
space = dict(
    lambda1=hp.uniform("l1", 0, 1000),
    lambda2=hp.uniform("l2", 0, 10000000),
    lambda3=hp.uniform("l3", 0, 500),
    lambda4=hp.uniform("l4", 0, 500),
    lambda5=hp.uniform("l5", 0, 10000000),
    lambda6=hp.uniform("l6", 0, 10000000),
    lambda7=hp.uniform("l7", 0, 10000000),
)

# Minimize the objective with the TPE algorithm over 500 evaluations.
best = fmin(
    fn=objective,
    space=space,
    algo=tpe.suggest,
    max_evals=500,
)

print(best)

[292.40542823258585, 8558250.548718462, 8.533490940484, 135.69820408402106, 7301306.304354309, 585465.6252552149, 741571.0589130553]
  0%|                                                                          | 0/500 [00:00<?, ?trial/s, best loss=?]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[210.35152142869327, 5605702.827383794, 372.48857826673265, 231.26660692287084, 4582144.260252676, 2283808.7438924527, 6482859.685696906]
[901.9783083466112, 1933613.1905032129, 477.28740142398846, 254.57561389473915, 6962877.353924441, 3041733.5613727192, 8759535.63126139]
[498.3469601781455, 5985058.206348775, 274.2015076689583, 428.19220363165016, 8416011.710204463, 3991191.2380608083, 2389703.8917295076]
[443.3342569825347, 8231990.9365935065, 487.7841341692449, 301.2300307058133, 4393479.11484199, 6633727.2242137715, 4613333.465376459]
[323.6479631255702, 1654043.078122418, 51.8356636254021, 413.4509111222296, 7876217.435816454, 6702244.502650456, 3553247.0251272386]
[418.7853535248185, 9589903.410622295, 180.31495973574025, 75.2259785488702, 8974218.021353954, 5989009.763281103, 1688089.616587809]
[994.2517901493234, 1533624.3575885245, 429.98165116767274, 177.194602876156, 8710352.984735278, 9196078.459551906, 2936197.9184338106]
[335.9942855751555, 3511276.0145960567, 351.08455

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[42.134716606060266, 4748439.877859656, 320.2218500793018, 130.5558872788932, 332499.0932336361, 113349.45767442742, 6133949.151933845]
[265.2726198799679, 8991080.185983514, 225.2620019912568, 136.06899153570876, 3621199.1389952167, 1014061.5266011065, 7232749.8071969515]
[869.5903267939296, 89716.81458992604, 97.1731142480796, 12.384636813567639, 5666065.2915170565, 12403.955646035261, 8434903.385405613]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[837.8011256646976, 476016.2147081923, 122.84567476358258, 5.24789119572651, 1388795.9577555675, 1999202.2689391675, 9947369.04830456]
[986.9698764101963, 2384579.148182185, 84.39365699213295, 29.36160864232567, 5662756.644828149, 108413.42552242824, 8229273.623690988]
AUC                                                                             

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[999.5474969557994, 2273998.3564536553, 2.8220563915134846, 323.22541066250363, 7201179.276652841, 2676587.376875956, 7876584.336163379]
[895.2358837165114, 1255030.4751896912, 112.14193888233912, 271.5357357273806, 7763501.4773793, 3300253.551327821, 9055823.559690477]
[753.2913009020072, 2818254.02990701, 53.25900540221648, 19.504973621934802, 5190583.475620553, 28967.446586156173, 9963899.215214]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[765.3953131162507, 6106972.034999791, 413.1756495404933, 162.36296196454506, 1227271.685247663, 1096921.3340338857, 6429656.110250409]
[631.717543805137, 3847728.714412966, 336.16360733243073, 104.48902731260304, 3691698.472438813, 2181596.255123038, 9903613.446705762]
[556.0146167202226, 2805123.338912218, 47.97988506551816, 67.41896114024718, 4395174.4

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[756.9421149397647, 5336344.107737496, 204.77080029497034, 203.280702641827, 2826476.7476326595, 7616927.260469876, 5061188.4984955415]
[395.8998812403119, 4697718.685893874, 498.7440180226644, 256.95666439991396, 1292087.8902201175, 4267256.3746517375, 5643884.263936335]
[869.6955917196491, 1733185.0609736373, 153.6755997263612, 498.775127066305, 7967370.901711881, 512846.797315496, 7782453.356859719]
[525.9495951417978, 854716.9636826054, 31.370471990739333, 360.1779421552841, 5098780.611269627, 3275856.7813415835, 9970585.810492467]
[652.245327863582, 7571299.122824111, 271.67159959090634, 171.0719572189219, 1925454.7129758864, 1653795.139971062, 6956638.324252634]
[814.4288661137897, 5857716.396784759, 415.8106589925587, 78.26370037535189, 3182316.188901849, 6093916.731640964, 5830972.420690481]
[945.5661176742285, 8967227.1976991, 362.02262582452937, 283.8988619976765, 6333946.711792178, 2517588.8744666735, 4474512.475895173]
[383.8000843032383, 6960529.88248969, 465.8612013471756

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[922.0463998988727, 9979079.445105165, 284.3462035244799, 68.68767148123945, 5964853.848078274, 1628856.6380372453, 9104605.650534427]
[845.5974150947485, 2981645.4104432883, 227.97820320490678, 86.37835607505068, 5457963.856738486, 413221.14145405096, 5355344.430575407]
[668.4961760520446, 4514953.330443063, 168.61995710975367, 37.97709423970113, 7057453.285944384, 1008428.4491785598, 4052212.1891776705]
[936.3963042828949, 389784.50699265767, 93.70750341103802, 14.529543513722313, 4850428.662314666, 2385162.9386483207, 7846787.778942297]
[623.4371894345824, 6256375.528812705, 204.3971759403174, 0.20606766653242659, 4305496.587259686, 3485255.8828182807, 6933441.328379568]
[809.9783225536792, 1344264.209861428, 261.40934313546927, 109.26078740305073, 5861451.153940906, 4036921.5968966447, 6562198.210319573]
[901.3487740785502, 3378404.9030344086, 112.62921303612704, 60.47619307594448, 2999859.7239900916, 2033167.512659269, 7595167.067187395]
[738.2905809865322, 8234581.8232525345, 435

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[227.0067396287564, 5089902.386825148, 480.35732766148385, 159.9949179716681, 2262005.637853187, 3024122.5731986775, 4926548.3658916615]
[953.8366731269242, 277210.42733667744, 215.0291411667867, 95.07207873835488, 6305132.561842864, 6478724.5790520925, 9728247.44551172]
[999.1546823903627, 2457531.139176952, 137.3084058983553, 210.07318146952056, 4881167.902210843, 2559964.08483411, 3312954.96834005]
[908.7524562769493, 1508891.326177332, 41.835393272245106, 79.692348468799, 3211342.834217693, 9564164.497893255, 6433389.400467234]
[826.3971344999883, 3588065.906375035, 399.703020730465, 107.88473524966714, 8633054.129519193, 7843606.640948081, 8231201.413735321]
[717.9981935511154, 3202645.563364691, 74.08136879977539, 234.28475679199582, 5757413.525027227, 4860077.999020215, 7036352.32999102]
[583.2682049585483, 7869517.416911727, 312.7643000762688, 123.36188583311993, 268848.38590667024, 4399434.334496716, 4222305.21748961]
[658.3162897463943, 8376908.930326967, 261.14934277466114, 

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[298.7727018036711, 6978795.360091776, 336.66248337145174, 175.9068089103719, 2929737.2807526365, 7128.941290125733, 1846453.597397102]
[894.8424142111735, 4926360.479429819, 418.5652799882608, 100.98732931548062, 3201826.7020573355, 1249395.190081839, 1405311.18370834]
[231.90289703793667, 1482729.6880825013, 28.659969218328254, 162.61810085214913, 8096205.3477276275, 2289471.9103753413, 9346891.180204278]
[918.4871358677814, 683619.3462361465, 64.24167730368674, 49.55948571909775, 7614028.908428158, 8205337.640262209, 7129743.592291132]
[795.7243784906559, 7233025.080627622, 185.03149370288992, 114.55564074308498, 4033384.203322005, 287017.2329532737, 6012491.008625494]
[818.6191417110354, 8828879.529896725, 215.8282960246566, 132.0182460452581, 2509828.6754332404, 4115425.9022708163, 4482056.577373692]
[532.8882971406939, 3976441.6290025357, 292.5680411242692, 149.0714677022409, 4787239.235759202, 4576748.656995322, 2349035.8006503317]
[706.8837293047534, 4341446.150228813, 461.8973

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[795.1942481047927, 4585292.49400687, 272.7335459519732, 126.50327510928759, 1845843.3804017857, 1298939.4908638124, 4969747.412742743]
[573.445987911369, 3295286.9734781054, 387.630172731859, 335.1833400224207, 1119228.5632244726, 1986780.9221675599, 3981695.524697561]
[993.1568373427239, 146135.35769762524, 175.77108429249085, 95.63164638058068, 6596135.680825268, 9693123.13264654, 8928001.113997186]
[867.1121127138564, 265376.8826339546, 114.63490368758642, 41.67310868964762, 6415594.679291558, 6928096.781750567, 9987091.811207455]
[553.0326182641791, 5475229.742118704, 328.03019869952925, 203.911799752084, 2961863.350222553, 3625545.5693568396, 5898627.410933905]
[363.12812554799564, 5785688.67845852, 344.65635704218033, 282.917539813829, 7298606.311548775, 675678.0571622517, 6634701.478879758]
[771.3517961201516, 4853991.263455927, 494.9268499320931, 143.5784171034615, 2045564.0432039164, 1081203.7482904287, 7655251.480390474]
[694.863813177809, 4184920.083669519, 257.244171710225

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[881.8320033425415, 157235.94660509843, 198.9605153944707, 28.985837988455017, 4609011.74151933, 7599793.048082445, 4337917.099451236]
[465.48548666976666, 2313117.6806087214, 37.82363964193979, 87.89514875119545, 6243657.286386119, 39673.82679946534, 9659181.652318118]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[397.86866590783006, 1427852.6379772918, 248.4118752791213, 155.01480278853998, 6788428.410194888, 1344891.1524884375, 8759861.416359467]
[983.4092382245994, 729451.7641284065, 105.67801167935599, 11.664137983325837, 5961832.510309883, 488758.06664725335, 9957748.381285645]
[923.5299821126666, 1997723.9055204466, 141.55475345743127, 39.10140905035263, 5400438.801358786, 5919947.120782058, 8930042.569172908]
 34%|███████████████▍                             | 172/500 [12:02<24:52,  4.5

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[959.302896285405, 1721595.045740845, 78.26516649899632, 1.2801943006451837, 5685792.531694844, 36019.158308639584, 8372793.329614877]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[733.7095270175053, 2826475.411342703, 50.88835767081627, 19.038400475381277, 2384876.84250341, 2452176.653182574, 9527970.679178586]
[836.0961406138406, 5599028.412416184, 220.74686199457835, 76.04845697556561, 4437085.695791672, 9279979.985677933, 7504014.23827447]
 35%|███████████████▋                             | 175/500 [12:13<21:53,  4.04s/trial, best loss: -0.9611111111111112]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[992.2898072428159, 1006948.7913882136, 207.11685019691208, 34.51777580908477, 3501853.2297370783, 871278.1961558748, 4559128.039763006]
[1.2387804769019226, 671791.6856416216, 235.74301398987075, 96.11758962613175, 4967866.726183722, 1090026.3673832016, 148162.03143241908]
[254.06903983677523, 2292547.4353494016, 0.4016669866085678, 172.42204018133089, 7853398.750070582, 1899090.222473431, 9810326.981934385]
[705.0850210227984, 3751700.424889843, 152.75534413590984, 209.50305083941714, 4211291.285747017, 8415270.126662368, 7828391.387829848]
[600.675950456939, 1146523.098639297, 301.73385671990144, 72.22970208362804, 7505688.708783865, 634338.418706702, 2691104.268416788]
[278.94090649735745, 2452464.9557782845, 22.643423776277714, 54.798528646779445, 6534661.42974052, 3363657.595322665, 8177608.90641857]
[907.1549063297492, 1847315.562549892, 114.19283663727455, 361.26785662303723, 5491865.015462037, 4331424.3923269035, 7996855.085819598]
[206.81210062527128, 3062824.839485877, 75.89

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[946.0162409398175, 383148.59931123664, 105.57890198313571, 3.7921477948222257, 5900466.033541579, 8967754.366132468, 7386333.521085244]
[474.9300393499738, 8075660.811132845, 32.320944493729044, 139.49492517208395, 6964008.408636611, 40956.867738407076, 939439.3485366339]
[538.2771918310142, 4720778.585479915, 184.57647090482007, 86.6979597215515, 6719463.862080762, 1564547.630844075, 5980008.919288889]
[579.3454904865437, 4046674.066442179, 434.75946898046993, 126.11716506343906, 1581817.425908572, 2641210.215889739, 6707616.959135943]
[424.47009101784477, 42125.08349069973, 328.8391769468673, 111.03736098074923, 3947847.4139212323, 1720013.039319099, 1406799.8111587875]
[300.5360969647831, 1354979.4524189115, 60.168606095521845, 67.3385163863859, 9148468.659792949, 4806610.464387637, 8326859.690144523]
[337.79080292578766, 9966930.775417982, 378.4411979379122, 77.40148790890595, 7628.179403944872, 316559.26446631766, 9380139.854433635]
[605.1441583235289, 1258690.4410586446, 426.086

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[243.89236769144986, 9007665.052850887, 317.0616971806604, 124.65419796388352, 489257.61436018685, 5354609.259991799, 3521420.4363155714]
[444.07728149921894, 2247826.2545310743, 162.06734492140288, 38.202624625315984, 4373176.591794068, 482706.3595979505, 9769044.314563556]
[92.62880872018997, 3166968.165401112, 47.423301345890806, 446.9447711494628, 7390604.953359092, 7605904.813930548, 9932520.45276984]
[676.4957719792357, 3642617.5132720917, 59.364735128237015, 324.6520985016499, 6184485.632850899, 1000692.0707957211, 9729656.41768583]
[641.7186428905163, 2088521.0277225068, 12.267014177574097, 61.46453783009053, 7731616.39071677, 284455.34709392756, 9361810.841478493]
[530.2833104051517, 1309425.364679368, 26.275781116095356, 116.05002837350365, 1907133.8115217488, 6524848.954079392, 9199379.48746329]
[42.08466460707484, 4185939.9823411554, 98.84795871527322, 290.8149331625756, 7269666.756509247, 3615548.7439947454, 7826072.459784375]
[191.26108952624134, 3342490.1568834474, 15.54

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[164.85509939402618, 6926559.376344545, 339.01898501112385, 138.37399986801168, 8246314.641248516, 6854391.757225336, 4107381.595924929]
[947.574939458476, 5104180.216196846, 228.15701554144465, 406.918982088348, 3121641.7843446354, 1782407.0199697153, 5470001.754501461]
[570.5671051226857, 1933033.9357547155, 105.32167134704862, 84.83245562997885, 5849648.594201812, 7197361.641995359, 9446591.390771475]
[15.198508106568568, 5406885.27083295, 82.92246878839474, 487.43928647009295, 8499487.570320435, 9606.891509790628, 4747703.5068701105]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[320.6639808478013, 928544.0003989318, 50.80674593126634, 168.29734374098481, 4658819.290790814, 2050736.6705782106, 7563950.442441113]
[350.8776252235347, 6309945.810125727, 37.01379526386086, 192.48357359981173, 75

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[463.8001502453537, 4336533.627188437, 191.26460385733077, 267.43621996283747, 6726736.089102671, 5628845.851732643, 3229412.579953117]
[418.0876713858494, 6068669.91630268, 248.67601307596516, 199.2284046971086, 7265510.481199657, 1017085.0723905805, 407395.1299402071]
[713.2281641351763, 5260869.5065424545, 289.1408904404317, 149.44377713097703, 6460536.961887346, 2732589.744928702, 5797034.747605033]
[386.43341278437066, 3221677.8313336596, 74.82338901571076, 332.96602905978875, 5629619.86472974, 350529.706262179, 7437161.920679174]
[912.0990490106074, 41598.415762383986, 407.0381361150347, 72.42717431359132, 1464514.8938877778, 564310.4270304962, 9284381.747904785]
[493.9626311040766, 7851576.120233848, 238.17168618155398, 113.23921505221205, 208199.20377921034, 1588925.0550749036, 3418295.8860544083]
[630.2557021976678, 7650064.323729895, 300.4147663322663, 73.18547300347379, 2522539.4884090065, 857006.4113498013, 3876764.992392529]
[797.1504309393733, 6406623.938419611, 280.35827

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[436.77643214544, 3505473.8443629225, 150.94740395580013, 45.11476324954568, 4521559.184179968, 626056.4594805171, 8912765.188161977]
[2.5890213645607396, 3290314.963193855, 0.23223111906459337, 433.48296629845856, 9062164.405135648, 3829.6523970208546, 7003382.391514527]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[867.2499028378969, 3282277.5222235844, 0.3386506191500497, 33.6432488020783, 5164768.319215043, 1617873.2963247714, 7024007.870171331]
[80.58304899687337, 388612.85412557435, 90.30576676169186, 417.68825673330434, 8782940.352661958, 448034.6410629449, 6687012.523582763]
[554.0991582752796, 2064432.3652247954, 104.69835139885055, 138.36566344803066, 6450667.644819001, 1943617.580084869, 7615378.489496044]
[357.9581165961221, 2634693.2307809833, 17.4820447222303, 211.72454450718465, 

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[491.82011461518175, 1819097.0645305188, 428.8114428727226, 7.94708441318588, 9288304.432565706, 1140965.7299044405, 9979001.898262724]
[69.62563726938366, 881087.1665272752, 142.92982770138516, 350.7187123450845, 7802807.106042649, 722304.960111735, 8220487.6759375725]
[264.1001975595966, 3068290.5059365397, 117.51840604682201, 292.4904928286629, 6899769.35213475, 1353922.7908212245, 5982554.371772846]
[830.4356651883331, 4492951.052590484, 11.940557157828682, 198.23946441872332, 5613398.479398605, 228228.88012062176, 6399012.312492722]
[806.2939290173442, 6883512.129529918, 36.48478798980451, 23.886158797694364, 4264613.275953526, 1798224.9316336215, 5032020.2379955]
[998.9042809217182, 34440.44477478927, 90.42413430417602, 48.97453871101344, 4290855.421634678, 7402168.763421798, 2743668.9196684025]
[476.50578687844086, 1164996.2849120025, 401.79927271834106, 305.6722468743934, 8190166.728406252, 2097757.9312575324, 7795405.402990134]
[910.1800138407127, 237115.104538491, 134.8899672

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[783.4880502432244, 2077227.1385554038, 10.981546446424284, 182.95586586973513, 7391849.158197949, 1225584.2811617109, 6102143.830910934]
[766.741915325173, 3653748.8604043797, 339.41910698954104, 210.2448211554475, 6706557.012610147, 319848.52054278884, 7156413.726353303]
[4.0030611090440384, 6415565.807922811, 69.01045112207495, 385.51391389335663, 7988762.97741895, 12499.692002547827, 7453159.155473828]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[193.78388246390855, 2798135.829981519, 24.608931864506083, 374.429584725222, 8017414.051070486, 42478.979591239535, 7471149.219246611]
AUC                                                                                                                    
0.9611111111111112                                                                            

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[327.8228882897962, 2580844.8942819135, 47.203605203520716, 323.23101807490775, 7698682.351630816, 19935.325863446044, 3723990.9097610395]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[281.7028197729254, 943534.1768206242, 256.200541943143, 147.9298875376594, 6223453.641654169, 1795048.2730515106, 9319895.621743143]
[174.3221621704766, 4174324.521183081, 378.50923505994723, 434.65500481105096, 9605375.654725458, 1067087.089976113, 1746399.5765828441]
 63%|████████████████████████████▏                | 313/500 [22:00<11:25,  3.67s/trial, best loss: -0.9611111111111112]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[533.9530773844157, 3874485.876610702, 78.6298771435187, 19.618498291342306, 4931634.417137712, 574641.3775586109, 8565844.745224744]
[310.80324356404816, 1517773.4752930594, 309.70168743018735, 232.2226068840552, 2148959.2399887573, 4428478.249437366, 7712905.156020782]
[841.809359255534, 1932775.854066436, 109.63742564820618, 169.06914027664376, 5778135.967069828, 6678954.360605422, 6656894.81813102]
[809.9191606462191, 5777244.632118517, 98.07108510485942, 110.19738425730773, 2847224.776265994, 8200184.0600451, 5757215.203855877]
[942.9980078216099, 5520696.649023288, 115.59833611470299, 36.985080724781184, 4749890.538779651, 9424380.47986433, 2887768.321871899]
[586.194772919918, 6576665.283751216, 241.99461387701365, 259.38818652151326, 8696408.972005172, 399020.31048305536, 4896275.084200165]
[406.65815309625947, 2379522.9006782332, 0.9954194888710005, 226.4811837628993, 1831773.546965828, 1381320.1371912826, 8966030.74700797]
[147.00179753518466, 2189140.3948748745, 225.27793054

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[601.08717683646, 6704942.9525501, 89.98173158580887, 392.05017763278306, 7893157.979518362, 2072266.7771004862, 9253849.11113936]
[389.88648649660996, 4047992.702558782, 136.52277160363423, 64.30566493042991, 6764204.117999544, 790901.2558969768, 9277995.073056832]
[192.6196122377104, 3520490.824597721, 24.13590446908493, 236.25584090554776, 7420402.340724593, 6411723.454572632, 5919498.732775476]
[537.5169439152735, 4182832.9362525158, 97.42752711581714, 89.60441159389873, 8322479.676013347, 206037.75275293874, 8890687.290900392]
[613.1433131685911, 1258981.0904920697, 365.09434767282187, 280.59314868917215, 5930665.673392737, 1504111.1194226062, 8519900.25349213]
[712.3798160415076, 4598726.8231075145, 337.0697116335574, 69.98327532464174, 2733838.4902193528, 657217.0437043621, 6579270.954834865]
[286.94345225602103, 2082589.9250554023, 18.918272552005202, 188.9860385727149, 7192030.261078501, 6062986.617357823, 8121397.027752201]
[334.5475733682623, 553668.7026487022, 77.7616570616

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[57.72206080138518, 5889064.261148452, 107.41451536081554, 425.57175771986283, 8576232.56004017, 5478886.074565172, 4511548.132906945]
[732.1439345814711, 1890132.0364748095, 297.51373995093337, 41.59372531120068, 571667.3727340654, 2663715.754739949, 9969634.364137528]
[29.318124651600083, 4865807.762348308, 318.5623726284923, 439.17723426057006, 3028684.845156554, 4183194.844961575, 5437434.425882416]
[653.3478498917885, 6879215.416014432, 390.40551013029693, 301.8559935919076, 7809753.317055678, 1255583.1535169445, 6487462.587065581]
[161.5252454233359, 2681343.9976399126, 9.105433158413035, 452.4289033143989, 8974613.83121605, 804255.4445841184, 6970911.718035897]
[262.01505148091724, 3416341.437127077, 2.093628823814626, 420.8254449365287, 7539006.035521614, 2167389.2992828162, 7601305.75919828]
[248.81945084692606, 3226609.928564564, 31.708185021015566, 135.81682816441398, 3885686.7410973785, 2945.251105881631, 8343484.76961102]
AUC                                                

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[551.5874965158569, 730782.8425677086, 273.1556437062319, 176.05000074756728, 2631344.6435658652, 7994584.4327833615, 1204195.8699036539]
[688.9165841198834, 5115804.4200891275, 286.3481987435252, 95.50027335509715, 2358312.188187879, 1900166.0606985372, 9780151.197127694]
[185.9464227235864, 2544522.0543149626, 85.04251433153584, 310.24533091400474, 8264794.295744314, 1676605.686181723, 3605890.587229098]
[7.86133015565396, 8715120.344856646, 121.50387412570718, 499.08868475398145, 9162767.612902332, 481967.5677402537, 6393709.301760913]
[223.66439955267703, 6823365.1222850345, 66.74540897536959, 376.6066765646226, 7823735.748747427, 759393.0293415603, 4080296.1992772045]
[639.4176715708038, 2959813.8584164083, 0.24495418437021724, 74.94416434223106, 9681029.912268637, 618.7376143741785, 9454854.615845839]
AUC                                                                                                                    
0.9611111111111112                                           

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[919.6383362911888, 352816.0598725462, 152.2159109028919, 56.89649115077363, 5269804.893204296, 424409.7919436007, 9092471.178282782]
[944.2362543329388, 590198.6732145594, 203.48568089196965, 28.468500506255793, 4777663.463142999, 2442560.89057976, 9596494.928737542]
[39.93105558512718, 4431121.028897151, 351.1437442722509, 240.8619332996002, 8453733.902537325, 192119.99082363799, 5138942.905164563]
[99.78329258958125, 4154594.5166477785, 374.3629698409661, 200.41703472057822, 7636259.887543933, 1163137.5037424718, 5322222.257830656]
[440.3758657212789, 3929471.5236686454, 420.33034057880286, 251.1240785162656, 7361592.6101541035, 3459511.8801587373, 4608276.039771268]
[568.0414891557309, 8489546.32030787, 261.53385444607625, 280.55864781564617, 3622015.8080565296, 3028.7810124384123, 5703692.110642493]
AUC                                                                                                                    
0.9611111111111112                                              

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[957.8116291956924, 1924110.3130025617, 55.9366361779322, 415.7520062442071, 5797083.820571987, 1968989.7440808397, 7139954.27851894]
[906.8733124177156, 3663176.790348753, 42.9386262804578, 15.848516566087856, 9861571.715836832, 1557282.637274194, 8842768.80421067]
[594.1525877572582, 1203491.8371928725, 361.23019887065107, 474.00863617750224, 8799398.328681331, 901115.8678811474, 9072201.10882108]
[487.75958173012083, 2427290.0427860385, 432.19474456377253, 463.86131201171486, 8867360.91094564, 36386.49253930731, 8572444.203243403]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[380.33541611591846, 9296198.150555028, 428.16160986779244, 321.7287005357404, 9288301.967262799, 15348.164645488218, 3242524.1504805665]
[507.7959628395889, 2773767.938548894, 77.02938010781313, 10.804903436681695, 7050

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[152.92062699579785, 8169342.373430837, 27.509746266120473, 361.69220684851916, 6517384.337171574, 7373.615523257093, 5002436.991001138]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[165.15301193629145, 7989741.331110171, 488.23500549336785, 380.58425494008577, 8138862.507604183, 4707697.304434439, 7515224.294891603]
 78%|███████████████████████████████████▏         | 391/500 [27:21<06:37,  3.65s/trial, best loss: -0.9611111111111112]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[200.2126884233471, 1602871.3889115413, 102.4826504297524, 226.4038550953523, 5616942.573244777, 1689585.1538520418, 7963861.416841675]
[581.0574551701086, 3005074.624586132, 10.219543395423063, 257.8045298438338, 9548381.2548319, 4938397.770080197, 9429588.836641397]
[126.76607855030166, 2184365.6441511554, 118.49929897309977, 445.66825096665167, 7790348.29986527, 533017.6438953064, 2464586.10761823]
[204.6516352198846, 969652.5899012389, 142.59621938406372, 352.0528104522814, 8712277.52273911, 1165762.6257976203, 2816566.099792095]
[700.9098983081421, 3048964.7939749686, 279.09382226502, 338.46315174493094, 7217498.150453182, 5760314.241782723, 5994183.779570963]
[244.00946429734245, 5309650.331729924, 82.72922717307115, 493.3572915779479, 8561063.747586733, 261429.80298066325, 3364525.7317444803]
[790.569001771191, 330530.5115072648, 341.67751929601843, 463.8421346056686, 4312822.136721268, 852148.9492010777, 6788662.581819391]
[76.96053804257544, 766234.4747691494, 414.585840845119

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[268.62807396013545, 3063833.3716267883, 33.8028574016309, 136.51185208449914, 6478628.797572218, 31657.09783420921, 3696993.916378486]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[260.8938178117885, 1023243.8313288607, 34.442725881778486, 83.53839101220794, 24786.875109139597, 1051830.2705331566, 4207203.877341269]
[772.4206907550819, 1141159.6098026615, 94.05953936501527, 50.14142453836796, 4592412.609941633, 8429959.424841959, 8604027.969708169]
 83%|█████████████████████████████████████▎       | 414/500 [28:42<05:14,  3.65s/trial, best loss: -0.9611111111111112]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[837.264703737525, 2378502.4408186297, 497.3375828615472, 62.47205606256555, 1922994.2657663254, 733522.3276972612, 8533224.728628274]
[813.8937440754491, 1600283.2966330198, 51.76541282024447, 36.414954087972475, 1611714.6841542255, 426412.57173926954, 9632735.397194084]
[922.3475659984056, 1944339.7790201085, 25.654353414415365, 52.616755158608285, 5276593.09393852, 1199878.3579772064, 9668150.244067779]
[431.7886682867038, 2455498.702608619, 473.2314291429819, 239.18981953945405, 6699351.068188698, 283268.09432456875, 9931000.451843772]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[537.973786008085, 9337152.843844857, 482.27639072733155, 456.3913361490605, 8139881.3424924, 14945.527847749347, 9983180.552404005]
AUC                                                                              

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[358.56709945748514, 2601056.3400996113, 401.99086987614214, 272.65595375430655, 7980406.70333063, 798681.930442414, 3067555.4975875225]
[614.7919315984534, 1222482.8412892127, 307.99333923808064, 113.9098651949677, 1176703.4241764296, 1440734.2763898137, 9022112.762963856]
[485.8023543621257, 3418684.79261754, 150.10305756899808, 463.87646777630005, 9042897.161694089, 9818532.541290652, 7959998.580597619]
[569.8428934012835, 2066813.3572993067, 133.72363541813087, 499.6261296632334, 9410046.24076554, 604897.2603761433, 8144296.634268055]
[852.9441939546825, 3590905.4773639888, 66.07743111057077, 316.8607126333307, 8949877.542157205, 931208.6622369154, 6693777.401741801]
[136.0179668838169, 8461467.016664712, 440.3956600214923, 441.36180930641706, 2979882.336119162, 141754.1674046893, 8808940.981611375]
AUC                                                                                                                    
0.9555555555555556                                               

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[619.859230086586, 4849005.40890322, 60.10706595598748, 401.87630148270205, 8664777.505261052, 195363.7461838602, 7946035.218307029]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[296.79476308550846, 1671843.709839299, 44.07693033470934, 364.06418255640006, 356349.85715939803, 665779.101874653, 7539439.859379895]
[474.3390253354752, 7941331.978708191, 223.94117829489446, 167.67921304729154, 6855779.485633865, 4704.904352834576, 5438515.431774726]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[380.73827413116874, 2260963.5387660433, 51.483228706622235, 159.7497466700827, 6924186.539541762, 394729.510086382

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[86.19151303109005, 6323990.6580815995, 19.008418903055308, 400.27947282465556, 7967599.253581689, 119719.72938411884, 6992436.387462063]
[180.04797113829454, 5980345.504448844, 28.852157418295228, 370.8621069373319, 7687229.990815896, 1663847.7503662696, 6742072.720690006]
[159.73120324749567, 982721.4154973417, 330.3305021786088, 358.32487069729393, 140041.35531592323, 2596736.4563583788, 7636117.419713736]
[307.64957789644234, 3700801.241582075, 162.35465570405603, 468.69490476950006, 9973130.476424122, 1486053.142422681, 1673237.005376563]
[378.66250565373457, 2457692.199502694, 178.87230272846767, 458.93718519567454, 9363359.412690025, 951665.4030108644, 4009074.2481055395]
[433.10140576648485, 1278022.0869340273, 476.77822300876187, 296.59345843954327, 7195982.76311802, 320926.34029373375, 9938106.372284098]
[524.6944726005988, 1930.6246098706033, 466.0175537877155, 68.95469911011473, 1433634.8345711292, 1136033.5142243446, 9822149.062299648]
[118.23504280728231, 8632915.55466683

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[215.41806931935005, 7181353.634305693, 246.84196520092485, 381.9314100714499, 945581.3677353048, 431499.24516918376, 6044237.926064593]
[140.8336380878522, 598887.4142310019, 71.7845022591362, 391.5642794056921, 6596356.81164488, 5250093.000889993, 8647525.00153234]
[680.8899169553301, 5172125.584712764, 62.58355169526027, 344.3214841156057, 6933181.981727301, 802217.5523107147, 4724801.092956771]
[595.1517478423988, 6149642.702415193, 87.80249805407676, 438.50188177389805, 7075455.491841009, 237879.68543940145, 4851681.648836884]
[244.0420970409582, 4774115.116984793, 303.4805502966738, 142.05039717628475, 3521081.40626883, 1022571.4048682876, 5490875.784297299]
[457.3033371483875, 7106640.492407519, 216.015261879539, 166.57926398164193, 5483822.460175575, 661444.114812434, 4527865.566028584]
[544.3397386989349, 5659831.172196671, 257.7886138628386, 115.68299657240668, 8674472.123918634, 8842143.87668995, 7836407.553518532]
[547.4608046948139, 7832525.547013944, 179.5341045988824, 40

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[493.3829152925646, 5908256.861352971, 208.88635952666863, 189.42265023502884, 8112202.835424853, 1450695.773190093, 9300062.385576963]
[614.5438443629673, 4120800.9200376105, 54.26812092055967, 173.63799417171458, 9613736.880847018, 14075.743166531494, 7688404.73783168]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[95.71747398262566, 6612443.874432099, 70.09388327303708, 485.80366754886484, 9532667.391018378, 11724.742464369163, 7695377.4920183625]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[632.7397222405715, 7768987.901437537, 4.311859275402941, 160.45649683900277, 9667779.018406456, 9087921.628429

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[424.4521496168387, 8319195.490534725, 7.939414591722823, 148.69174858577225, 7321635.510188403, 5826.651144454414, 2866656.4206394595]
[318.9300657535001, 2078634.0057047647, 422.7190861703999, 76.13186421571493, 6389257.99757855, 559388.0382012734, 967531.3790065167]
[398.05578616764933, 965127.4946628932, 284.83994849811944, 92.63414410731252, 7421793.018497104, 4394228.693544245, 6249075.782806657]
[352.80465532871415, 2669088.3046555487, 42.52599669769479, 222.69950287855306, 6758396.193200496, 888306.182026738, 3261231.7365908884]
[295.3575985888935, 5295192.075203415, 364.3075762189358, 444.5205632754592, 9880338.214844422, 4018323.313686598, 5680463.632933]
[49.30849834230414, 8963641.375906676, 0.6451704593086731, 354.0943782755237, 8001822.993259977, 1254744.8552508678, 6409868.869000269]
[669.1529376954484, 3054817.942439114, 23.461037481870815, 451.8991506843601, 8371790.117940014, 474341.71191193594, 8038965.018815139]
[442.4521319412638, 3173151.7839172515, 33.71282559352

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[335.10578181280255, 368827.43696876935, 143.54269945252898, 267.6691181713047, 9134708.242237754, 416845.6912860306, 3138279.589552654]
[74.03971084600309, 33846.10192195425, 343.8524217139242, 289.15143025313125, 7581951.491238579, 1064818.6872564212, 280703.817294647]
[266.0305546498891, 721554.8723516909, 483.35619080266565, 428.73589109623157, 7904724.756135503, 611833.6576990313, 2784482.7936718697]
[315.7868967053068, 383999.7810339472, 397.75972557399916, 252.0881088680385, 7819896.153291337, 359338.7678455263, 3520215.9365658644]
[213.55651507620843, 5851857.346556622, 99.37449984144811, 187.51325207224656, 5141315.613443314, 2457414.191087633, 8332263.814859004]
[117.31831585898948, 1858347.8267978814, 435.9700492665606, 474.6713251719964, 9244797.883700829, 1474876.2665041643, 8735130.408271315]
[489.2635992880279, 2316984.883079314, 499.91970723568477, 495.7442222234139, 8948567.176321635, 846073.7294534562, 8886948.66338098]
[115.6479573770763, 5513167.384720858, 78.801798

In [6]:
# `best` holds values keyed by the hp labels ("l1".."l7"); space_eval maps them
# back onto the named search space ("lambda1".."lambda7") before printing.
print(space_eval(space=space, hp_assignment=best))

{'lambda1': 779.5700009689626, 'lambda2': 4830131.9980255505, 'lambda3': 312.1580702974475, 'lambda4': 134.63105864272865, 'lambda5': 3329905.077922123, 'lambda6': 59003.45150548895, 'lambda7': 6090330.444138394}


In [7]:
# Refit the one-vs-all IDS model with the best hyperparameters found by the search.
best_params = space_eval(space, best)

# Build the positional lambda array by explicit key (lambda1..lambda7) instead of
# relying on the iteration order of the dict returned by space_eval, so that each
# weight is guaranteed to land in its intended slot.
best_lambda_array = [best_params["lambda%d" % i] for i in range(1, 8)]

ids_multiclass = IDSOneVsAll(algorithm="RUSM")
ids_multiclass.fit(quant_dataframe, lambda_array=best_lambda_array, rule_cutoff=30)

# NOTE: the AUC is computed on the same data the model was fitted on.
auc = ids_multiclass.score_auc(quant_dataframe)

In [8]:
# Display the AUC of the model refit with `best_lambda_array`
# (scored on the same `quant_dataframe` it was fitted on).
auc

0.9611111111111112

In [9]:
# Display the interpretability metrics of the refit model; these are the same
# quantities that is_solution_interpretable() thresholds during the search.
ids_multiclass.score_interpretability_metrics(quant_dataframe)

{'fraction_overlap': 0.03753086419753086,
 'fraction_classes': 1.0,
 'fraction_uncovered': 0.14320987654320985,
 'average_rule_width': 2.277777777777778,
 'ruleset_length': 6.0}

In [10]:
# Baseline for comparison: a CBA classifier built from the same DataFrame `df`.
txns = TransactionDB.from_DataFrame(df)
cba = CBA(support=0, confidence=0)

cba.fit(txns)
# NOTE(review): this reports accuracy on the fitting data, whereas the IDS model
# above is scored with AUC -- confirm the two numbers are meant to be compared.
cba.rule_model_accuracy(txns)

0.9407407407407408

In [11]:
# Rebind `df` / `quant_dataframe` to the segment dataset. Later cells that read
# the global `quant_dataframe` will therefore operate on this dataset.
df = pd.read_csv("../data/segment0.csv")
quant_dataframe = QuantitativeDataFrame(df)

# Reuse the existing `best_lambda_array` as-is (no new search is run in this
# cell); only the rule cutoff differs from the earlier fit (50 instead of 30).
ids_multiclass = IDSOneVsAll(algorithm="RUSM")
ids_multiclass.fit(quant_dataframe, lambda_array=best_lambda_array, rule_cutoff=50)
auc = ids_multiclass.score_auc(quant_dataframe)

# Display the AUC. Stray characters after `auc` were removed: subtracting a
# string from the float would raise TypeError when the cell is re-run.
auc

0.7567740901074235

In [None]:
def objective(args):
    """Hyperopt objective: fit an IDS one-vs-all model and return its loss.

    Parameters
    ----------
    args : dict
        Sampled hyperparameters keyed "lambda1" .. "lambda7".

    Returns
    -------
    float
        0 when the fitted model fails ``is_solution_interpretable``; otherwise
        the negated AUC (the search minimises, so a higher AUC gives a lower loss).

    Raises
    ------
    KeyError
        If any of "lambda1" .. "lambda7" is missing from ``args``.

    Notes
    -----
    Reads the module-level ``quant_dataframe``, so the search is run on whatever
    dataset that name is bound to when ``fmin`` is called. The model is fitted
    and scored on that same data.
    """
    # Index by explicit key rather than relying on the dict's iteration order,
    # so every weight is guaranteed to occupy its intended position.
    lambda_keys = (
        "lambda1", "lambda2", "lambda3", "lambda4",
        "lambda5", "lambda6", "lambda7",
    )
    lambda_array = [args[key] for key in lambda_keys]
    print(lambda_array)

    ids_multiclass = IDSOneVsAll(algorithm="RUSM")
    ids_multiclass.fit(quant_dataframe, lambda_array=lambda_array, rule_cutoff=30)

    metrics = ids_multiclass.score_interpretability_metrics(quant_dataframe)

    # Non-interpretable rule sets get loss 0, which is never better than any
    # interpretable solution with a positive AUC.
    if not is_solution_interpretable(metrics):
        return 0

    auc = ids_multiclass.score_auc(quant_dataframe)
    print("AUC", auc)

    return -auc

# Search space: one independent uniform prior per lambda weight, each starting
# at 0. The hp labels ("l1".."l7") are the keys that appear in the raw `best`
# result returned by fmin.
space = dict(
    lambda1=hp.uniform("l1", 0, 1000),
    lambda2=hp.uniform("l2", 0, 10000000),
    lambda3=hp.uniform("l3", 0, 500),
    lambda4=hp.uniform("l4", 0, 500),
    lambda5=hp.uniform("l5", 0, 10000000),
    lambda6=hp.uniform("l6", 0, 10000000),
    lambda7=hp.uniform("l7", 0, 10000000),
)

# Run 500 TPE trials of `objective` over the space above.
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=500)

print(best)

[652.1517688857958, 9203529.891177068, 362.1448861373701, 334.85981287483975, 2729536.943221206, 8560488.263530202, 6236394.897953363]
  0%|                                                                          | 0/500 [00:00<?, ?trial/s, best loss=?]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[532.3604827502271, 4599084.024118777, 324.7828639335177, 266.771735265762, 4881386.842201881, 5482306.634920353, 5376633.515823349]
[639.5488138255362, 9916540.353570009, 233.8197599706261, 449.23264403906086, 2227958.130621458, 5352683.450235693, 7640859.492372396]
[125.16016240387262, 9610788.497466125, 435.05508792775464, 35.16318721376172, 6168504.520121843, 1721679.860977925, 5964110.784077803]
[856.1363653226862, 5208693.96235933, 419.65586054654005, 230.25993119772082, 3376231.042693162, 4310241.967044583, 4179037.2885578196]
[899.7262913936803, 7596074.342321355, 111.16895023511353, 186.02706472757185, 3990787.0834798063, 9531594.04674724, 1456184.3794973728]
[358.108780951569, 5511303.396257774, 73.2014723158907, 109.6980449223891, 8362511.728903069, 7723060.041963535, 4164331.3566381023]
[909.512194490104, 3389636.9233002667, 313.14512228528554, 322.2510868447779, 6913858.5481052585, 1175410.5269593108, 2243957.0370007334]
[0.4324988936867946, 894006.5430888921, 111.11554261

[175.0326819817857, 7501325.160242845, 413.9394005579171, 445.3448345787462, 2708003.355252241, 185790.94363875967, 3363648.82735882]
[954.3118178135837, 6898999.774727123, 233.97772043305318, 307.5733697329649, 1361822.4214115764, 8660523.030750658, 1610987.7692419067]
[667.7481215396022, 5427641.288058732, 232.94436055697346, 345.03771021662504, 2456345.6483083796, 9372462.587027367, 1809504.0754237236]
[917.4242759223309, 2602153.254585025, 454.2755343266711, 290.57692235458535, 5919067.961369919, 7984586.783630852, 7002641.931713537]
[38.287901878321975, 4788830.987450316, 476.1210210369132, 323.7375838500887, 6647661.995219993, 4495570.252779115, 7234430.954266065]
[563.0801956677927, 6391127.880230173, 298.5938683681425, 345.7350557441001, 7833474.3255964145, 4144485.642313118, 6208846.2440676065]
[524.5114118035353, 5366765.846427033, 284.88890339341293, 386.6623120045391, 7043529.177017278, 2342701.489946878, 5058665.485734254]
[383.00963530165404, 6108732.855754371, 121.255322