In [1]:
import hyperopt
from hyperopt import fmin, tpe, hp, space_eval

import pandas as pd
import numpy as np
from pyids import IDS
from pyids.algorithms import mine_CARs
from pyids.algorithms.ids_multiclass import IDSOneVsAll
from pyids.data_structures import IDSRuleSet

from pyarc.qcba.data_structures import QuantitativeDataFrame
from pyarc.data_structures import TransactionDB
from pyarc import CBA

import random
import logging
import time

import matplotlib.pyplot as plt

%matplotlib inline

In [2]:
# Read iris0.csv and wrap the frame in the QuantitativeDataFrame that the
# IDS fit/score calls below take as input.
df = pd.read_csv(filepath_or_buffer="../data/iris0.csv")
quant_dataframe = QuantitativeDataFrame(df)


In [3]:

# Baseline run with hand-written lambda weights. The AUC is scored on the
# same QuantitativeDataFrame the model was fitted on.
ids_multiclass = IDSOneVsAll(algorithm="RUSM")
ids_multiclass.fit(
    quant_dataframe,
    lambda_array=[1, 1, 0, 0, 100000, 10000, 1000000],
    rule_cutoff=200,
)
auc = ids_multiclass.score_auc(quant_dataframe)



In [4]:
def is_solution_interpretable(
    metrics,
    max_fraction_overlap=0.10,
    required_fraction_classes=1.0,
    max_fraction_uncovered=0.15,
    max_average_rule_width=8,
    max_ruleset_length=10,
):
    """Decide whether a set of interpretability metrics passes every threshold.

    Parameters
    ----------
    metrics : mapping
        Must contain the keys ``"fraction_overlap"``, ``"fraction_classes"``,
        ``"fraction_uncovered"``, ``"average_rule_width"`` and
        ``"ruleset_length"``.
    max_fraction_overlap : float, optional
        Inclusive upper bound for ``metrics["fraction_overlap"]``.
    required_fraction_classes : float, optional
        Value ``metrics["fraction_classes"]`` must equal exactly.
    max_fraction_uncovered : float, optional
        Inclusive upper bound for ``metrics["fraction_uncovered"]``.
    max_average_rule_width : float, optional
        Exclusive upper bound for ``metrics["average_rule_width"]``.
    max_ruleset_length : float, optional
        Inclusive upper bound for ``metrics["ruleset_length"]``.

    Returns
    -------
    bool
        True only when all five conditions hold.

    Raises
    ------
    KeyError
        If ``metrics`` lacks one of the required keys.

    The defaults reproduce the thresholds that were previously hard-coded,
    so existing single-argument calls behave exactly as before.
    """
    return (
        metrics["fraction_overlap"] <= max_fraction_overlap and
        metrics["fraction_classes"] == required_fraction_classes and
        metrics["fraction_uncovered"] <= max_fraction_uncovered and
        # The width bound is strict (<), unlike the other upper bounds.
        metrics["average_rule_width"] < max_average_rule_width and
        metrics["ruleset_length"] <= max_ruleset_length
    )

In [22]:
def objective(args):
    """Loss function handed to ``fmin``: fit one-vs-all IDS and score it.

    Parameters
    ----------
    args : dict
        Sampled point of the search space with keys ``"lambda1"`` through
        ``"lambda7"``.

    Returns
    -------
    float
        ``-auc`` when the fitted model passes ``is_solution_interpretable``;
        ``0`` otherwise.

    Raises
    ------
    KeyError
        If ``args`` is missing one of the ``"lambda1"``..``"lambda7"`` keys.
    """
    # Read the weights by explicit key so that position i always holds
    # lambda(i+1), independent of the iteration order of the dict passed in.
    lambda_array = [args["lambda{}".format(i)] for i in range(1, 8)]
    print(lambda_array)

    ids_multiclass = IDSOneVsAll(algorithm="RUSM")
    ids_multiclass.fit(quant_dataframe, lambda_array=lambda_array, rule_cutoff=30)

    # Metrics and AUC are both computed on the global `quant_dataframe`,
    # i.e. the same data the model was just fitted on.
    metrics = ids_multiclass.score_interpretability_metrics(quant_dataframe)

    if not is_solution_interpretable(metrics):
        # Rejected rule sets get loss 0, which is worse than -auc for any
        # positive AUC, so the minimizer prefers interpretable solutions.
        return 0

    auc = ids_multiclass.score_auc(quant_dataframe)
    print("AUC", auc)

    # fmin minimizes, so the AUC is negated.
    return -auc

# Uniform priors on [0, upper] for lambda1..lambda7; dict keys are
# "lambda<i>" and the hyperopt labels are "l<i>".
space = {
    "lambda{}".format(idx): hp.uniform("l{}".format(idx), 0, upper)
    for idx, upper in enumerate(
        [1000, 10000000, 500, 500, 10000000, 10000000, 10000000],
        start=1,
    )
}

# TPE search over the space, 500 evaluations of `objective`.
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=500)

print(best)

[138.2082790871304, 8308729.050872044, 322.78171046291317, 242.86569454325618, 1909781.7414962582, 7002306.589150845, 2025818.6788223465]
  0%|                                                                          | 0/500 [00:00<?, ?trial/s, best loss=?]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[849.6105512665187, 2035281.0542537093, 84.82396161933752, 125.8617170761922, 2620495.2809052574, 5802723.131644956, 7557280.737228815]
[959.5121597185232, 6412076.864095054, 398.3369299331096, 66.1946467289305, 4977263.40670847, 9053588.764637658, 5036704.472668461]
[121.20522969837599, 9974696.938841835, 477.7494507966899, 162.3598540347679, 5301581.716003613, 3120517.3466048143, 384683.00908104825]
[933.8154477297634, 2283511.6636913153, 241.6192586001711, 492.1421316546601, 7255363.914616135, 8202697.409336181, 5175229.512048409]
[815.6847746308061, 9540450.081579225, 469.86240365462504, 165.52430099569648, 7451129.723449913, 4768241.680237365, 1469455.2354153611]
[733.8456995143431, 2293664.959894444, 313.95224888939447, 104.08066588370745, 8400224.123791799, 5608234.875865357, 3854054.0664198184]
[895.6871395914758, 4186334.177397889, 313.66557417115854, 92.69921952220162, 2482811.4007606083, 6960885.700845002, 3982262.9433339974]
[102.02901883446857, 4950764.2545914175, 42.59481

[753.895909753707, 1677213.6978516947, 302.02814100593355, 206.89789364792705, 756762.3023023081, 7205671.523981301, 777511.1008618893]
[591.1844098961061, 6958955.098820581, 56.34555186105058, 474.1558972952337, 7493879.41596016, 7860059.20564609, 2248184.0322027868]
[195.52417379160858, 5208270.554788439, 59.321377607962916, 375.2034749763146, 9015920.109762479, 5036810.793470791, 3536741.801248366]
[546.5404876780538, 9128235.155163804, 475.99953079566365, 35.80046074481493, 2829316.5234112004, 1432875.2578873206, 5198255.1538335]
[767.2151393318458, 9699017.233044146, 465.12655151528884, 350.6795911337624, 6344483.909462382, 9118700.64898912, 5968488.606980775]
[848.049293401149, 1072165.4316899348, 373.2676348059137, 262.87789149908406, 7037092.13320354, 5945157.736633354, 4233121.392721029]
[379.70997657933464, 1973620.6488050572, 365.4209175384954, 241.07344786530578, 3667346.099783199, 9066670.54588892, 4677472.592318496]
[322.02803540534705, 2683096.980109918, 272.73368050604,

[607.8102108470382, 1811194.1522370572, 358.9768149098992, 132.05482305810034, 1472495.0577416094, 7818993.179369319, 702611.0989036931]
[570.212847754329, 7090832.395586479, 399.0912722834219, 476.3798759690699, 7743936.365672914, 7725936.145929555, 3420679.0481317705]
[486.1403423702328, 4999063.0526957065, 74.99329823446323, 387.2068865295539, 8208691.470817933, 5867772.394365091, 4087390.729668136]
[406.7703017964943, 9019956.209770601, 387.94431674166697, 12.083172398349006, 2874425.686241884, 2826182.8255917067, 3837420.8313919324]
[534.1372258625594, 8045255.270722391, 382.10691052837154, 438.87954119484743, 824.6531307306141, 427140.32394549437, 4257671.387031117]
[354.3236984170918, 9595656.359128252, 425.95863165218947, 363.8860235272002, 5224165.866221623, 9358003.865939831, 5508890.454985001]
[852.4603853851377, 8421184.56680335, 457.7187220682053, 267.28194443139955, 5541376.573704444, 9510603.707849292, 8089372.246429218]
[889.4092155317762, 685293.0255302443, 422.5633404

[617.8099367105701, 4744014.72597926, 491.2854186153743, 221.76008248602813, 3642842.347770659, 810790.0712058002, 5991768.107630711]
[635.8552505958995, 3756445.6993236295, 195.8632890249513, 401.4946992089804, 3053070.7679321594, 212537.55904078903, 9895121.816108558]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[686.590432232598, 4257570.826420295, 94.19310520964841, 410.0196355673875, 3010532.9935272825, 245402.4720936739, 9526389.547087213]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[724.2963932577429, 4283305.053965779, 16.608535200002848, 414.31748345252623, 3019678.0903095817, 170223.889813394

[998.4823530839741, 628902.1215246093, 114.68521131970839, 431.5888311362083, 5164531.6171203265, 415809.404233078, 6905055.497110632]
[911.4807223126078, 4070091.152726236, 49.54598846273657, 454.70905404613524, 3847444.0051480825, 1012482.6747973888, 7485594.593839022]
[779.439813105743, 2074490.2307384284, 53.75967172822025, 448.35768069614943, 3563837.966129378, 2857910.413427383, 8868757.183830345]
[979.2988339881342, 935839.8630444561, 38.86386093512747, 474.11302248294675, 6552563.142655648, 73286.45007951446, 6683164.998901712]
AUC                                                                                                                    
0.9592592592592593                                                                                                     
[865.7084995972333, 2431894.289155802, 99.25899024847791, 499.263053735379, 2836352.3507633237, 1490042.0200144271, 8064298.580313867]
[700.694208632041, 2213086.4426687863, 6.649079582617995, 359.13543313743463, 41681

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[899.7356087361849, 457508.0152595643, 126.20270496495858, 460.30649831227106, 4804753.126865287, 1131545.1949206027, 8573695.343927091]
[800.5716077051197, 38190.864546610974, 176.6549815207024, 329.78229369126177, 6202134.883200276, 360221.03397632344, 7608458.420730384]
[739.3815275090192, 4400346.729608926, 29.62149383596403, 413.2556226265299, 2294922.964203226, 716726.8970233587, 9557226.409232775]
[760.1365621592155, 465290.87569425424, 65.08484129290953, 376.3550592454344, 5113979.821116563, 2153679.321229218, 6396502.276251493]
[718.0250129363674, 3518168.386706298, 161.52457354769143, 386.9830269364028, 5737919.5099955555, 1338666.3279308702, 9261596.098119536]
[599.8360970237136, 3102043.758998212, 150.05940346559296, 481.54224345722645, 3277858.0147838714, 895610.3071454429, 9887436.232876053]
[833.9161598664869, 2678243.1271991404, 211.11724781443064, 394.11875757885076, 3818291.3806690383, 6362.477101476223, 9294715.0458436]
AUC                                            

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[790.5403308457869, 1822280.2994860187, 212.606211236294, 366.41067712292596, 5317852.514147614, 2596253.6380535043, 7833941.774077346]
[934.9599693892235, 5913376.747355937, 87.47065899576404, 468.1744182132276, 6530892.619725712, 1757840.4387570892, 9943781.497562788]
[851.2902452991167, 4872083.078046646, 70.89195370298583, 488.0893125088118, 7397697.528606812, 477094.3422968751, 8756644.623485861]
[905.5734925329042, 3358280.0659451694, 140.05402430208733, 439.9022518443, 6757922.450893939, 36432.153434700245, 8322759.027687539]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[538.9282377201016, 4243850.030231014, 139.5289454992296, 404.53022310004053, 7549185.035036599, 221014.36757919117, 9720959.677028786]
AUC                                                                                  

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[934.9186620416326, 2504478.3088662513, 173.0031710551074, 338.6104291816353, 7496446.476323623, 3012294.38191454, 7784489.1506836]
[992.3896318327224, 2737328.5139235836, 228.33848193793924, 470.3096836210594, 6246483.144122229, 6631.222257226123, 6572477.234920573]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[478.8020722684412, 4620690.643343081, 225.91807302315578, 420.48281231845317, 2466976.154850466, 754677.7595157628, 5730451.575052347]
[691.0855098121701, 5578956.388082052, 195.9860399240086, 319.3323616754459, 1812686.6428373086, 1657333.650802738, 9045448.832016034]
[814.5490591580311, 3027151.6632462544, 18.473499423561435, 391.7911848832732, 4956258.310979417, 972952.3788565171, 8470965.658904659]
[774.3347067742637, 128938.53187450208, 119.8438617471929, 497.45431826631096, 602371

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[603.7948726949821, 2189344.6222414686, 265.5130875656868, 317.54447135061025, 5592163.217001233, 1813868.2398453697, 6129947.847153554]
[866.1541450004873, 5252729.795853368, 102.01101627179703, 492.33660602471036, 3579157.6862012567, 4020860.6865543304, 9367464.518801894]
[793.5692305803988, 4529346.425542534, 59.73514733991652, 457.5147451417149, 2236868.9198255828, 577097.2524215109, 8297868.502611205]
[827.8049394055881, 3187501.088073838, 115.12006535931252, 435.91946758056747, 4953632.162609314, 68184.79931400585, 6371470.898879377]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[542.700991236681, 3618611.9194926433, 121.14054955051887, 395.6630933755015, 8048242.9501422215, 253731.5470761229, 5648807.475264994]
[550.9014342327596, 4196502.164502997, 104.8433506464832, 372.2730943425293, 4

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[567.9132618974496, 4254723.658897803, 143.12139319634122, 445.59319064722825, 7350840.485766269, 1407679.461733882, 7919916.06906918]
[991.3975301077787, 1434690.8838887487, 241.64455431034904, 362.5262279403621, 9374705.640835993, 450871.5315208201, 6811624.5559502775]
[946.2418625557647, 2833015.0065079667, 40.53983967312861, 476.5594115234105, 5917129.771682247, 56518.13333231208, 6450855.156463418]
AUC                                                                                                                    
0.9592592592592593                                                                                                     
[851.0778579518924, 3083970.080848404, 12.150447190808606, 498.56198162578204, 4079779.8121167673, 50052.87346700209, 6253809.208474738]
AUC                                                                                                                    
0.9611111111111112                                                                              

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[815.9814002434377, 1373230.3080121868, 180.56105075181202, 327.266008325826, 3789903.430006638, 2053673.4024553988, 8238003.759893374]
[730.725935661533, 6186365.463241382, 72.16735466181984, 310.79602919379533, 3626237.7368792244, 1689531.0180984244, 5551476.13114018]
[585.3488239552187, 3883466.056221437, 127.7782550575362, 480.6662625947082, 1727867.4339153373, 334358.3806417242, 9302883.988097403]
[701.6056396915759, 1964038.6636785134, 114.95761623771976, 280.4147921951303, 2719752.9602504857, 1088397.197120586, 7624533.579006979]
[622.8296055416884, 4422942.010330372, 159.70904288000662, 412.4728920208438, 5008196.7886095885, 2682996.22283724, 6382494.000557477]
[880.3849444062523, 1662203.6819662913, 226.93383523808473, 452.6513785231622, 6152410.041340118, 868410.2373435909, 7044527.399677883]
[480.7649028273754, 3449551.98423608, 201.53586546867672, 354.4134581817988, 1936342.0725409975, 2237101.050625538, 8849465.186303645]
[751.9674864506655, 5252875.067000117, 181.26294863

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[748.9413539836767, 2533598.562345085, 38.300773679972444, 463.02115315885914, 4287546.171533669, 1733337.6419748592, 4624842.566088577]
[921.4017301235219, 3499219.9936813116, 0.5066399398183421, 428.84265454848133, 3579212.3649704223, 2944198.8065032223, 4279952.547646346]
[978.6397089296445, 2723664.346663239, 166.20943411985252, 472.50518460208616, 5708552.962290479, 4965.240212552928, 6614187.529678049]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[849.3229937449273, 657384.036467989, 14.38985332559914, 400.68536086561994, 5731692.7018416505, 35701.64688246511, 6067933.286585345]
AUC                                                                                                                    
0.9611111111111112                                                                           

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[805.6463592161592, 2695910.136613712, 111.89010748121115, 344.26725176121283, 5153338.243030295, 1468418.5681959088, 7179351.053636318]
[998.9346526983222, 3301160.0242732065, 141.29286179440015, 269.3274877817677, 6740572.029859109, 794741.6679535167, 8158891.418527095]
[775.7802289675722, 1555504.0836004326, 153.89683107411892, 246.69548333294682, 5512911.326004745, 3565346.2889449093, 6306773.221407926]
[789.1275156109716, 2157452.1519835996, 0.9420945541345329, 450.25700094692127, 4146992.933079867, 195331.6033886787, 4983987.532393713]
[935.3940205873078, 2781217.399023709, 122.68499825360314, 362.76473620118327, 4766941.768229737, 410793.7009666179, 8694965.235040301]
[926.0385422680306, 3061318.6619179677, 173.97440200284558, 497.99579271461477, 4943647.514984414, 28953.638265912297, 5876513.251909201]
AUC                                                                                                                    
0.9611111111111112                                        

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[899.3076581443318, 1876013.2003495195, 268.3215434185269, 296.14390794614883, 6997526.148190539, 2131348.6757886587, 5700834.083067996]
[521.4084319578387, 7061708.560849633, 210.60461339185218, 373.6748527985661, 6708570.675681347, 252471.230192204, 7771477.833573922]
[866.842082447843, 2388836.01638465, 7.724384965653186, 499.6483432302749, 4425102.113767055, 4338750.967986946, 3943881.123887669]
[732.0111956710509, 5866385.33939393, 274.89206184926934, 314.82457819011046, 5327139.531688652, 1326017.174473404, 4413217.202002741]
[669.291466500466, 537681.9362986188, 198.56744193818366, 346.68670329393626, 5944601.705059775, 2743346.62266156, 3336416.555046045]
[397.93189113385256, 6352758.452952772, 235.18720955423166, 362.3669235204128, 4626665.200978842, 1931067.011288303, 9202771.981329512]
[363.157395318958, 5584293.967666792, 194.25523689796506, 371.93197364750677, 7616401.2903791955, 1522378.5140876628, 9713205.727691755]
[823.2586428554147, 2058621.289904566, 45.0731243570951

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[749.3362256727289, 5171345.075444216, 0.6860674282485864, 428.19627523049377, 3698128.5357268406, 2598985.6006919453, 4168131.8462346396]
[585.8614211480847, 6801945.340201354, 6.001250943705215, 491.66977037762484, 1439716.7534787548, 830723.7329605968, 5050856.723488212]
[677.951917184587, 5805265.749767642, 30.808878313044854, 453.3259563501524, 1939752.3228546844, 517359.61439016333, 3325001.8885933673]
[854.8336473860343, 4066005.076019764, 61.83226418266566, 462.8373191758155, 5664560.6200141255, 77551.96179602336, 5920601.618114213]
AUC                                                                                                                    
0.9592592592592593                                                                                                     
[815.3163725883494, 4698502.701847013, 54.629305019688225, 490.0878107327695, 2793418.574830746, 3144325.8838695907, 9421242.874808386]
[949.5270894691371, 2535844.887412153, 20.059581105053788, 467.21319595396443

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[972.9890429093388, 798007.1011350653, 44.12899150351785, 401.9598637900979, 6444641.096347874, 27974.140129355783, 6732259.80223835]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[996.6293431953333, 7339008.406177368, 293.9526076695837, 392.29791838388564, 3956095.3724811613, 4919786.64802671, 5494276.055938011]
 76%|██████████████████████████████████▍          | 382/500 [33:34<09:45,  4.96s/trial, best loss: -0.9611111111111112]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[903.2177480841418, 6520859.6498292815, 2.077436415101835, 416.7212647832616, 2963794.914457595, 371008.94829277997, 2947523.619858572]
[837.0256679222439, 4114378.5785206244, 282.84166178739855, 446.2290394433059, 2565713.0616493747, 1692684.255051234, 4706632.202497765]
[929.2973976528497, 2700755.1654460533, 105.43986753215017, 496.03088134724356, 6147612.768043813, 613673.699547667, 6191633.291535203]
[898.4434164951596, 3559983.98696257, 329.5066860904486, 460.963689037678, 9174433.752760114, 463838.25153600913, 7191310.927126424]
[996.4765071843242, 1743057.2140855985, 305.59965320618846, 432.92345715126214, 8344643.234268356, 1030309.86010238, 6874267.997518104]
[808.1097355005676, 1556208.8840113364, 120.86911566426615, 256.4187568171213, 6877292.550008034, 3686749.57212481, 8139898.356094957]
[748.5229500307755, 2257766.2326355167, 20.836164489781186, 383.86523793505785, 4221335.342769414, 1943373.2456358627, 5119426.196354972]
[849.9973040001023, 1091084.1421384178, 211.22709

0.9611111111111112                                                                                                     
[936.5773500965084, 3264989.9085931433, 298.7024183722017, 189.7390383194079, 5646171.382348528, 1201852.1336942296, 5230084.851740396]
[874.3496231744842, 3591434.6700800885, 97.4659830343123, 473.693056042545, 7006967.592984016, 522540.0380349963, 8943692.402522698]
[906.9267991969629, 1618893.390263585, 48.68505731474057, 472.95120036436, 6428708.388770505, 4086.5980011006104, 7376130.822352876]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[927.2327198255389, 1583596.639143632, 269.3948710481455, 491.8078818145643, 6404537.091418238, 702244.4178612523, 5741759.376417241]
[886.0741806830488, 2182664.8906492693, 116.01405874668546, 144.28440619573905, 4487611.321423681, 18882

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[853.9608070234394, 2321570.5000002664, 29.377910043366704, 232.45324182833494, 6044915.483334363, 303375.86720401974, 7869083.198090753]
[913.1916558032543, 1072460.4997461399, 157.94584711766365, 463.139435385583, 5285700.026398128, 528.0479181476439, 6505943.935059153]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[891.3830328561073, 2621111.2336489134, 231.7465016690782, 262.02650865715117, 5907912.9107750645, 1044073.5066925295, 8104865.429485394]
[940.1708450917803, 4021061.494921075, 202.43148087913855, 435.3584747029109, 6751281.928507235, 1286190.1372913907, 8186146.57819544]
[668.596793993172, 5892083.34912776, 60.14087477417287, 277.10397357435403, 4129555.8594389763, 1782755.6583061381, 9091773.086635642]
[815.0220378139805, 2844639.6861428786, 82.94246918154829, 270.4848934581693, 3

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[768.127786067424, 2678066.5156914624, 379.46064699302394, 491.6219609183929, 4576713.915341146, 3695.6421059344934, 5607599.197804618]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[819.8905016708003, 1699073.3602749673, 33.02717793211268, 421.57003394531733, 4463460.349993203, 407603.1457929059, 5830001.334989432]
[817.6525343225549, 1972289.6120046224, 10.09515331990481, 408.217754500543, 4880222.443488992, 13075.389082428384, 6053981.160235081]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[785.7451638239176, 2101973.6303782803, 32.546383826764725, 418.82768232984415, 4841786.637206962, 711776.9695462

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[953.1614454786948, 872790.6819258733, 242.2286983178548, 153.99576647709222, 5562127.044761443, 2434293.0821779994, 7582475.949273564]
[724.3401565253182, 3990080.8567022565, 220.53549626702892, 112.78782170681197, 1773682.539448082, 2213764.8805645984, 8256377.752585012]
[510.7031460555134, 3778868.175960495, 324.8183105480733, 220.94215739924343, 6251447.001108172, 4447144.14813601, 3217441.393566303]
[895.3683859001123, 949601.0517985143, 128.43618209910358, 498.79168333034124, 4732831.673286068, 1113868.1944051122, 8684778.891437156]
[994.3964302947829, 1545900.9393330356, 266.5775000457997, 208.6668673275571, 6701729.405621368, 226604.5681924916, 7528072.121293519]
[917.3222575183236, 3352740.478760764, 175.23714048696382, 129.74924924608132, 7109801.494453706, 1546655.6533115944, 7485734.867477074]
[999.2354349898706, 1176236.6486229838, 146.73983037881862, 152.0664983630021, 6947373.388362186, 1317240.9092572317, 7800502.761041216]
[999.9256292767642, 1413952.7677837997, 298.84

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[999.8495757324309, 397403.5221152068, 294.0808721038868, 256.3157717077719, 6358538.003959675, 4019.6487609602336, 6402166.947806738]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[982.2186988490346, 231583.15833681682, 315.0740206865238, 245.62158474325645, 6536405.424954997, 705200.9124048713, 6487781.366017882]
[952.978944671603, 1109156.2680850457, 159.88790257890642, 295.49319564415697, 7167758.585514297, 13363.502127944032, 6892433.134120076]
AUC                                                                                                                    
0.9611111111111112                                                                                                     
[854.8016280925659, 6671917.469152666, 4.0191023512079385, 298.06589235462184, 4035876.221594482, 1679703.711180

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[738.4596013007131, 1263678.0643518595, 459.386899835318, 386.6627415608148, 8363220.279791051, 3847300.448958816, 8519328.320066787]
[531.2801467309989, 6250208.350482729, 207.17007006130362, 369.71182332833985, 9823049.989077289, 2032674.313371106, 9142867.20240085]
[448.7489630211658, 4277243.2440559445, 185.66544689655484, 199.98767000862279, 7947378.618047136, 235528.8559534872, 9639856.901150681]
[843.7968612928801, 1448782.8175016881, 105.34522198906558, 363.4225111496437, 9267849.862595556, 503567.0070845749, 4826913.835809221]
[873.0025007070765, 1724601.214605547, 131.75706939270145, 487.08775552804155, 8991671.841779724, 3295297.274689992, 9353808.022587106]
[883.582662627417, 1672869.9190037928, 438.24767412622435, 216.75900712443752, 5149983.332683923, 591126.145571944, 5514471.937578401]
[754.4173230558936, 3596741.284903556, 490.28949192082905, 310.8252623474201, 5135414.437724288, 5096836.926188763, 5884363.6329818005]
[658.0697127346446, 352663.8580365799, 152.39750613

In [28]:
# Evaluate the search space at the point returned by fmin and print the
# resulting parameter dict.
print(space_eval(space, best))

{'lambda1': 635.8552505958995, 'lambda2': 3756445.6993236295, 'lambda3': 195.8632890249513, 'lambda4': 401.4946992089804, 'lambda5': 3053070.7679321594, 'lambda6': 212537.55904078903, 'lambda7': 9895121.816108558}


In [33]:
# Refit on the best point found by the search. The weights are read by
# explicit key so that best_lambda_array[i] is always lambda(i+1), instead
# of depending on the iteration order of the dict returned by space_eval.
best_params = space_eval(space, best)
best_lambda_array = [best_params["lambda{}".format(i)] for i in range(1, 8)]

ids_multiclass = IDSOneVsAll(algorithm="RUSM")
ids_multiclass.fit(quant_dataframe, lambda_array=best_lambda_array, rule_cutoff=30)
# Scored on the same QuantitativeDataFrame that was used for fitting.
auc = ids_multiclass.score_auc(quant_dataframe)

In [34]:
# AUC of the refit model, scored on the same data it was fitted on.
auc

0.9611111111111112

In [35]:
# Interpretability metrics of the refit model; these are the quantities
# that is_solution_interpretable checks against its thresholds.
ids_multiclass.score_interpretability_metrics(quant_dataframe)

{'fraction_overlap': 0.029253380364491478,
 'fraction_classes': 1.0,
 'fraction_uncovered': 0.14320987654320985,
 'average_rule_width': 2.3333333333333335,
 'ruleset_length': 6.333333333333333}

In [36]:
# CBA baseline fitted and evaluated on the same DataFrame.
# NOTE(review): this reports rule_model_accuracy, while the IDS cells
# report score_auc, so the two numbers are different metrics.
txns = TransactionDB.from_DataFrame(df)
cba = CBA(support=0, confidence=0)
cba.fit(txns)

cba.rule_model_accuracy(txns)

0.9407407407407408

In [39]:
# Apply the lambdas tuned earlier to segment0.csv.
# NOTE: this rebinds the globals `df` and `quant_dataframe`, so every later
# cell that reads them (including `objective`) now sees the segment data.
df = pd.read_csv(filepath_or_buffer="../data/segment0.csv")
quant_dataframe = QuantitativeDataFrame(df)

ids_multiclass = IDSOneVsAll(algorithm="RUSM")
ids_multiclass.fit(
    quant_dataframe,
    lambda_array=best_lambda_array,
    rule_cutoff=50,
)
auc = ids_multiclass.score_auc(quant_dataframe)

auc

0.7550104216770883

In [None]:
def objective(args):
    """Loss function handed to ``fmin``: fit one-vs-all IDS and score it.

    NOTE(review): this redefines ``objective`` from an earlier cell; consider
    keeping a single definition. ``rule_cutoff`` is 30 here although the
    preceding segment cell fits with 50 -- confirm which is intended.

    Parameters
    ----------
    args : dict
        Sampled point of the search space with keys ``"lambda1"`` through
        ``"lambda7"``.

    Returns
    -------
    float
        ``-auc`` when the fitted model passes ``is_solution_interpretable``;
        ``0`` otherwise.

    Raises
    ------
    KeyError
        If ``args`` is missing one of the ``"lambda1"``..``"lambda7"`` keys.
    """
    # Read the weights by explicit key so that position i always holds
    # lambda(i+1), independent of the iteration order of the dict passed in.
    lambda_array = [args["lambda{}".format(i)] for i in range(1, 8)]
    print(lambda_array)

    ids_multiclass = IDSOneVsAll(algorithm="RUSM")
    ids_multiclass.fit(quant_dataframe, lambda_array=lambda_array, rule_cutoff=30)

    # Uses whatever the global `quant_dataframe` is bound to when fmin runs.
    metrics = ids_multiclass.score_interpretability_metrics(quant_dataframe)

    if not is_solution_interpretable(metrics):
        # Rejected rule sets get loss 0, which is worse than -auc for any
        # positive AUC, so the minimizer prefers interpretable solutions.
        return 0

    auc = ids_multiclass.score_auc(quant_dataframe)
    print("AUC", auc)

    # fmin minimizes, so the AUC is negated.
    return -auc

# Uniform priors on [0, upper] for lambda1..lambda7; dict keys are
# "lambda<i>" and the hyperopt labels are "l<i>".
space = {
    "lambda{}".format(idx): hp.uniform("l{}".format(idx), 0, upper)
    for idx, upper in enumerate(
        [1000, 10000000, 500, 500, 10000000, 10000000, 10000000],
        start=1,
    )
}

# TPE search over the space, 500 evaluations of `objective`.
best = fmin(fn=objective, space=space, algo=tpe.suggest, max_evals=500)

print(best)

[442.48862621002604, 8834404.471111432, 21.972009978520724, 320.6296234622158, 3015889.8991860896, 6406103.853795737, 4350889.335956591]
  0%|                                                                          | 0/500 [00:00<?, ?trial/s, best loss=?]

  out=out, **kwargs)

  ret = ret.dtype.type(ret / rcount)



[251.94683560311293, 8051666.39883145, 68.84048302506591, 446.2580024778189, 9281826.855039258, 3437583.7398656015, 6545617.731651106]
[396.55539422147444, 1180469.360109423, 31.351678579854713, 105.56515345640283, 279775.36734457687, 3503689.46222713, 1364672.437998905]
[536.7986651247595, 5897432.001282422, 332.1666582847025, 37.937498677792505, 9611973.527110798, 5999578.050770043, 2134256.9417648204]
[704.066696427469, 9905654.425629245, 241.87412047313887, 202.13821191757785, 6635623.652335165, 5572807.111875406, 3752267.803056445]
[733.5899356836424, 379742.7811253862, 143.84385875206974, 348.8082665777181, 4931230.077194873, 9988152.365659613, 5747594.908894304]
[991.1581394540857, 1532130.2418167505, 290.4829658445207, 377.51640851688416, 5945214.245429829, 4235033.037064519, 3814570.9896928594]
[205.29024358227622, 4635584.214882325, 63.14220620358574, 102.14724507109108, 340204.32462215645, 8362490.91015642, 1738963.9820665203]
[162.03123313213862, 2375161.9649860878, 62.5470

[151.57635234332076, 6705484.375381629, 103.03684862174333, 305.73411821507506, 4410491.463555221, 5853866.002599034, 362655.1649045069]
[694.5728801958473, 6456495.310714527, 178.51321270153082, 21.866635991729883, 7625198.515294598, 9776116.447957722, 5929817.4099199]
[667.3452566209378, 5600916.878431991, 41.245368097098236, 8.462409638091113, 6308786.206034159, 9538034.564548451, 4895330.954532336]
 13%|███████▌                                                    | 63/500 [27:53<4:06:26, 33.84s/trial, best loss: 0.0]