In [1]:
import pandas as pd
import numpy as np
import statistics as s
from datamodel import TradingState, Listing, OrderDepth

In [2]:
data = pd.read_csv('results/zero-trades.csv', delimiter=';')

In [3]:
data.head(10)

Unnamed: 0,day,timestamp,product,bid_price_1,bid_volume_1,bid_price_2,bid_volume_2,bid_price_3,bid_volume_3,ask_price_1,ask_volume_1,ask_price_2,ask_volume_2,ask_price_3,ask_volume_3,mid_price,profit_and_loss
0,-1,0,PEARLS,9998,1,9995.0,30.0,,,10005,30,,,,,10001.5,0.0
1,-1,0,BANANAS,4946,1,4945.0,30.0,,,4952,30,,,,,4949.0,0.0
2,-1,100,PEARLS,9996,1,9995.0,30.0,,,10002,6,10004.0,1.0,10005.0,30.0,9999.0,0.0
3,-1,100,BANANAS,4945,31,,,,,4950,7,4952.0,31.0,,,4947.5,0.0
4,-1,200,BANANAS,4945,22,,,,,4951,1,4952.0,21.0,,,4948.0,0.0
5,-1,200,PEARLS,9996,1,9995.0,21.0,,,10004,1,10005.0,21.0,,,10000.0,0.0
6,-1,300,PEARLS,9996,2,9995.0,23.0,,,9998,3,10004.0,2.0,10005.0,23.0,9997.0,0.0
7,-1,300,BANANAS,4945,25,,,,,4952,25,,,,,4948.5,0.0
8,-1,400,PEARLS,9998,5,9996.0,2.0,9995.0,23.0,10004,2,10005.0,23.0,,,10001.0,0.0
9,-1,400,BANANAS,4946,5,4945.0,25.0,,,4952,25,,,,,4949.0,0.0


In [4]:
datad = data.to_dict("records")

In [5]:
trading_states_d = {}
for d in datad:
    timestamp = d['timestamp']
    product = d['product']
    buy_orders = {}
    sell_orders = {}
    # bids
    for i in range(1,4):
        price = d['bid_price_'+str(i)]
        volume = d['bid_volume_'+str(i)]
        
        buy_orders[price] = volume
    # asks
    for i in range(1,4):
        price = d['ask_price_'+str(i)]
        volume = d['ask_volume_'+str(i)]
        sell_orders[price] = volume
    
    listing = Listing(
        product,
        product,
        product
    )

    order_depth = OrderDepth(
        buy_orders,
        sell_orders,
        d['mid_price']
    )


    if timestamp not in trading_states_d:
        trading_states_d[timestamp] = TradingState(
            timestamp,
            {},
            {},
            {},
            {},
            {},
            {}
        )

    trading_states_d[timestamp].listings[product] = listing
    trading_states_d[timestamp].order_depths[product] = order_depth


states = list(trading_states_d.items())
states[0][1].order_depths['PEARLS'].buy_orders

{9998: 1, 9995.0: 30.0, nan: nan}

In [6]:
def wavg_bid_price(states, windows, product):
    ret = []
    mw = max(windows)
    ds = states[-mw:]
    wsum = 0
    vsum = 0
    avg = 0
    for wi, (_, d) in enumerate(ds[::-1]):
        for p, v in d.order_depths[product].buy_orders.items():
            if not np.isnan(p):
                wsum += p*v
                vsum += v
        if wi+1 in windows:
            ret.append(wsum/vsum)
            avg += wsum/vsum

    while len(ret) < len(windows):
        ret.append(avg/len(windows))
    return ret

print(wavg_bid_price(states, [1, 5, 10, 15], "PEARLS"))

def wavg_ask_price(states, windows, product):
    ret = []
    mw = max(windows)
    ds = states[-mw:]
    wsum = 0
    vsum = 0
    avg = 0
    for wi, (_, d) in enumerate(ds[::-1]):
        for p, v in d.order_depths[product].sell_orders.items():
            if not np.isnan(p):
                wsum += p*v
                vsum += v
        if wi+1 in windows:
            ret.append(wsum/vsum)
            avg += wsum/vsum
    while len(ret) < len(windows):
        ret.append(avg/len(windows))
    return ret

print(wavg_ask_price(states, [1, 5, 10, 15], "PEARLS"))

def avg_mid_price(states, windows, product):
    ret = []
    mw = max(windows)
    ds = states[-mw:]
    wsum = 0
    vsum = 0
    avg = 0
    for wi, (_, d) in enumerate(ds[::-1]):
        wsum += d.order_depths[product].mid_price
        vsum += 1
        if wi+1 in windows:
            ret.append(wsum/vsum)
            avg += wsum/vsum
    while len(ret) < len(windows):
        ret.append(avg/len(windows))
    return ret

print(avg_mid_price(states, [1, 5, 10, 15], "PEARLS"))

def volume_diff(states, windows, product):
    ret = []
    mw = max(windows)
    ds = states[-mw:]
    vsum = 0
    avg = 0
    for wi, (_, d) in enumerate(ds[::-1]):
        for p, v in d.order_depths[product].buy_orders.items():
            if not np.isnan(p):
                vsum -= v
        for p, v in d.order_depths[product].sell_orders.items():
            if not np.isnan(p):
                vsum += v
        if wi+1 in windows:
            ret.append(vsum)
            avg = vsum
    while len(ret) < len(windows):
        ret.append(0)
    return ret

print(volume_diff(states, [1, 5, 10, 15], "PEARLS"))

def best_prices(states, windows, product):
    ret = []
    mw = max(windows)
    ds = states[-mw:]
    asks = []
    bids = []
    for wi, (_, d) in enumerate(ds[::-1]):
        for p, _ in d.order_depths[product].buy_orders.items():
            if not np.isnan(p):
                bids.append(p)
        for p, _ in d.order_depths[product].sell_orders.items():
            if not np.isnan(p):
                asks.append(p)
        if wi+1 in windows:
            ret.append(max(bids))
            ret.append(min(asks))
    while len(ret) < len(windows)*2:
        ret.append(0)

    return ret

print(best_prices(states[:1], [1, 5], "PEARLS"))

[9995.79411764706, 9995.265734265735, 9995.253472222223, 9995.217494089835]
[10005.0, 10004.866666666667, 10004.909420289856, 10004.885922330097]
[10001.5, 10000.3, 10000.5, 10000.3]
[-9.0, -8.0, -12.0, -11.0]
[9998, 10005, 0, 0]


In [7]:
def compute_gt(states, product, margins, windows, volumes, shortsell=False):
    ret = {}
    mw = max(windows)
    ds = states[:mw+1]

    # Purchases
    infrom = ds[0][1].order_depths[product].sell_orders
    if shortsell:
        infrom = ds[0][1].order_depths[product].buy_orders
    purchases = {}
    psum = 0
    vsum = 0
    tbuy = 0
    for p, v in infrom.items():
        if tbuy > max(volumes):
            break
        if not np.isnan(p):
            for pv in range(1, int(v)+1):
                if tbuy > max(volumes):
                    break
                psum += p
                vsum += 1
                tbuy += 1
                if tbuy in volumes:
                    purchases[tbuy] = psum/vsum
                    psum = 0
                    vsum = 0

    # Sells  
    for margin in margins:
        ret[margin] = {window:{volume: 0 for volume in volumes} for window in windows}
        paidprice = purchases[1]
        tsold = 0
        
        for wi, (_, d) in enumerate(ds[1:]):
            outto = d.order_depths[product].buy_orders
            if shortsell:
                outto = d.order_depths[product].sell_orders
            
            tsum = 0
            vsum = 0
            avgsell = 0
            for p, v in outto.items():
                if not np.isnan(p):
                    for _ in range(int(v)):
                        tsum += p
                        vsum += 1
                        avgsell = tsum/vsum
                        if (
                            ((avgsell < (paidprice + margin)) and not shortsell) or
                            (((paidprice - margin) < avgsell) and shortsell)
                        ):
                            break
                        
                        tsold += 1
                        if tsold in volumes:
                            if tsold not in purchases:
                                break
                            paidprice = purchases[tsold]
            
                if (
                    ((avgsell < (paidprice + margin)) and not shortsell) or
                    (((paidprice - margin) < avgsell) and shortsell)
                ):
                    break

            # Check Window
            if wi+1 in windows:
                ret[margin][wi+1] = {
                    volume: 1 if tsold >= volume else 0 for volume in volumes 
                }
    return ret

print(compute_gt(states[325:], 'BANANAS', [1, 2, 4], [1, 2, 4, 8, 16, 32, 64], [1, 2, 4, 8, 16, 20], shortsell=False))
print(compute_gt(states[325:], 'BANANAS', [1, 2], [1,100], [1, 2], shortsell=True))

{1: {1: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 2: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 4: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 8: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 16: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 32: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 64: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}}, 2: {1: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 2: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 4: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 8: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 16: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 32: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 64: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}}, 4: {1: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 2: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 4: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 8: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 16: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 32: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}, 64: {1: 0, 2: 0, 4: 0, 8: 0, 16: 0, 20: 0}}}
{1: {1: {1: 0, 2: 0}, 100: {1: 1, 2: 1}}, 2: {1: {1: 0, 2: 0}, 100: {1: 

In [8]:
print(states[50][1].order_depths['BANANAS'].buy_orders)
states[50][1].order_depths['BANANAS'].sell_orders

maxb = 0
mins = 10000000
for i,state in enumerate(states):
    maxb = max(maxb, max(state[1].order_depths['BANANAS'].buy_orders.keys()))
    mins = min(mins, min(state[1].order_depths['BANANAS'].sell_orders.keys()))
    # if min(state[1].order_depths['BANANAS'].sell_orders.keys()) < 10002:
    #     print(i)
    #     break
    if max(state[1].order_depths['BANANAS'].buy_orders.keys()) == 4955:
        print(i)
        break
print(maxb)
print(mins)

{4947: 4, 4946.0: 22.0, nan: nan}
325
4955
4943


In [9]:
products = ['PEARLS', 'BANANAS']
inds = [wavg_bid_price, wavg_ask_price, avg_mid_price, volume_diff, best_prices]
windows = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]

def compute_single(states, inds, windows, product):
    x = []
    for ind in inds:
        x.extend(ind(states, windows, product))
    return x

def compute_training(states, products, inds, windows):
    xss = {}
    yss = {}
    for product in products:
        ys = []
        xs = []
        for i in range(len(states)):
            xws = states[:i]
            yws = states[i:]
            xs.append(np.array(compute_single(xws, inds, windows, product)))
            ys.append(
                {
                    "buy":compute_gt(yws, product, [1, 2, 4, 8, 16, 32, 64], [1, 2, 4, 8, 16, 32, 64], [1, 2, 4, 8, 16, 20]),
                    "borrow":compute_gt(yws, product, [1, 2, 4, 8, 16, 32, 64], [1, 2, 4, 8, 16, 32, 64], [1, 2, 4, 8, 16, 20], shortsell=True)
                }
            )
        xss[product] = np.array(xs)
        yss[product] = ys
        
    for i, product_i in enumerate(products):
        for product_j in products[i+1:]:
            xss[product_i+product_j] = xss[product_i] - xss[product_j]
    return xss, yss

xss, yss = compute_training(states, products, inds, windows)

In [10]:
print(xss['BANANAS'][150])
print(yss['BANANAS'][150]['buy'])

[ 4.94206667e+03  4.94205769e+03  4.94209615e+03  4.94233491e+03
  4.94286247e+03  4.94310216e+03  4.94414937e+03  4.94528633e+03
  3.95439557e+03  3.95439557e+03  4.94857895e+03  4.94873770e+03
  4.94880180e+03  4.94913270e+03  4.94961468e+03  4.94986576e+03
  4.95092804e+03  4.95200944e+03  3.95976691e+03  3.95976691e+03
  4.94500000e+03  4.94550000e+03  4.94537500e+03  4.94562500e+03
  4.94625000e+03  4.94648438e+03  4.94736719e+03  4.94857812e+03
  3.95701797e+03  3.95701797e+03  8.00000000e+00  9.00000000e+00
  7.00000000e+00 -1.00000000e+00  7.00000000e+00 -2.00000000e+00
  3.00000000e+00 -2.00000000e+00  0.00000000e+00  0.00000000e+00
  4.94300000e+03  4.94700000e+03  4.94300000e+03  4.94700000e+03
  4.94300000e+03  4.94700000e+03  4.94300000e+03  4.94700000e+03
  4.94800000e+03  4.94700000e+03  4.94900000e+03  4.94700000e+03
  4.95000000e+03  4.94600000e+03  4.95300000e+03  4.94600000e+03
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
{1: {1: {1: 1, 2: 1, 4: 

In [11]:
xss['BANANAS'] = xss['BANANAS'] / np.linalg.norm(xss['BANANAS'])
xss['PEARLS'] = xss['PEARLS'] / np.linalg.norm(xss['PEARLS'])

In [12]:
print(xss['BANANAS'][150])
print(yss['BANANAS'][150]['buy'])

[ 3.22131560e-03  3.22130975e-03  3.22133482e-03  3.22149045e-03
  3.22183432e-03  3.22199055e-03  3.22267314e-03  3.22341423e-03
  2.57753629e-03  2.57753629e-03  3.22556041e-03  3.22566389e-03
  3.22570567e-03  3.22592135e-03  3.22623551e-03  3.22639917e-03
  3.22709158e-03  3.22779646e-03  2.58103740e-03  2.58103740e-03
  3.22322759e-03  3.22355350e-03  3.22347203e-03  3.22363498e-03
  3.22404236e-03  3.22419513e-03  3.22477056e-03  3.22555987e-03
  2.57924560e-03  2.57924560e-03  5.21452391e-06  5.86633940e-06
  4.56270843e-06 -6.51815489e-07  4.56270843e-06 -1.30363098e-06
  1.95544647e-06 -1.30363098e-06  0.00000000e+00  0.00000000e+00
  3.22192396e-03  3.22453123e-03  3.22192396e-03  3.22453123e-03
  3.22192396e-03  3.22453123e-03  3.22192396e-03  3.22453123e-03
  3.22518304e-03  3.22453123e-03  3.22583486e-03  3.22453123e-03
  3.22648667e-03  3.22387941e-03  3.22844212e-03  3.22387941e-03
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]
{1: {1: {1: 1, 2: 1, 4: 

In [13]:
# The GT is converted to classification with 7 classes - [MustBuy StrongBuy SoftBuy Neutral SoftSell StrongSell MustSell]
# Look at below setup for definitions
# Hold is everything else
# margins, windows, volumes
#[1, 2, 4, 8, 16, 32, 64], [1, 2, 4, 8, 16, 32, 64], [1, 2, 4, 8, 16, 20]

def compute_classes(yss, products, definitions):
    gts = {}
    names = []
    for d in definitions:
        names.append(d['name'])
    names.append('Neutral')
    for product in products:
        neutral = 0
        gts[product] = []

        for state in yss[product]:
            gt = [0] * (len(definitions)+1)
            for i, definition in enumerate(definitions):
                typehold = state['buy'] if definition['type'] == 'buy' else state['borrow'] 
                margins = definition['margins']
                mvolume = definition['minvol']
                wnstates = definition['withinstates']

                for margin in margins:
                    if typehold[margin][wnstates][mvolume] == 1:
                        gt[i] = 1
                
                if gt[i] == 1:
                    break
            if gt[i] == 0:
                gt[-1] = 1
                neutral += 1
            gts[product].append(np.array(gt))
        print(product, neutral)
    return names, gts

definitions = [
    {
        'name': 'FlipBuy',
        'type': 'buy',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':2,
        'withinstates':8
    },
    {
        'name': 'ShortBuy',
        'type': 'buy',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':4,
        'withinstates':16
    },
    {
        'name': 'MediumBuy',
        'type': 'buy',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':4,
        'withinstates':32
    },
    {
        'name': 'LongBuy',
        'type': 'buy',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':8,
        'withinstates':64
    },
    {
        'name': 'CrashShort',
        'type': 'borrow',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':2,
        'withinstates':8
    },
    {
        'name': 'ShortShort',
        'type': 'borrow',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':4,
        'withinstates':16
    },
    {
        'name': 'MediumShort',
        'type': 'borrow',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':4,
        'withinstates':32
    },
    {
        'name': 'LongShort',
        'type': 'borrow',
        'margins': [64, 32, 16, 8, 4, 2, 1],
        'minvol':8,
        'withinstates':64
    },
]

names, gts = compute_classes(yss, products, definitions)

print(names)
print(gts['BANANAS'])

PEARLS 1723
BANANAS 1176
['FlipBuy', 'ShortBuy', 'MediumBuy', 'LongBuy', 'CrashShort', 'ShortShort', 'MediumShort', 'LongShort', 'Neutral']
[array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 1, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 1, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 1, 0, 0, 0, 0, 0, 0]), array([0, 0, 1, 0, 0, 0, 0, 0, 0]), array([0, 0, 1, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([1, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 1, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([1, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([1, 0, 0, 0, 0, 0, 0, 0, 0]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 1]), array([0, 0, 0, 0, 0, 0, 0, 0, 

In [14]:
from src.logreg import LogisticRegression, serialize_logreg, deserialize_logreg
from imblearn.over_sampling import SMOTE

print(xss['BANANAS'].shape)
print(np.array(gts['BANANAS']).shape)

hardcode = {}
for product in products:
    hardcode[product] = []
    for i in range(9):
        logreg = LogisticRegression(learning_rate=0.01,num_iter=10000)
        sm = SMOTE(random_state=42)
        X_res, y_res = sm.fit_resample(xss[product], np.array(gts[product])[:,i])
        logreg.train(X_res, y_res)
        print(logreg.eval(X_res, y_res))
        hardcode[product].append(
            serialize_logreg(logreg)
        )
print(hardcode)

(2000, 60)
(2000, 9)
[[1420  483]
 [1339  564]]
0.5212821860220704
[[ 509 1479]
 [ 323 1665]]
0.5467806841046278
[[ 506 1467]
 [ 371 1602]]
0.5342118601115053
[[ 255 1734]
 [ 142 1847]]
0.5284062342885872
[[1423  480]
 [1269  634]]
0.5404624277456648
[[1740  253]
 [ 846 1147]]
0.7242849974912192
[[1731  250]
 [1329  652]]
0.6014639071176173
[[ 512 1481]
 [   0 1993]]
0.6284495735072755
[[ 497 1226]
 [ 427 1296]]
0.52031340684852
[[1651  239]
 [1585  305]]
0.5174603174603175
[[1594  372]
 [1251  715]]
0.5872329603255341
[[1450  460]
 [ 797 1113]]
0.6709424083769634
[[1439  467]
 [ 983  923]]
0.6196222455403987
[[1732  117]
 [1697  152]]
0.5094645754461872
[[ 506 1437]
 [ 200 1743]]
0.57874420998456
[[ 491 1382]
 [ 335 1538]]
0.5416444207154298
[[ 492 1347]
 [ 249 1590]]
0.566068515497553
[[1126   50]
 [1085   91]]
0.5174319727891157
{'PEARLS': [{'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [0.00012108514716673519, 4.270478513400809e-05, 8

In [15]:
hard = {'PEARLS': [{'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [0.00012108514716673519, 4.270478513400809e-05, 8.088298035840683e-05, -0.00031317241274268895, -0.0001948135309731971, 8.011828845798724e-06, 0.00034610567109686637, -0.0016209362270616084, -0.0017618692809446225, -0.0010705909471213338, -0.001297389013698461, 4.307600460111574e-05, 8.095678099710229e-05, -0.00031361058024811494, -0.00019497518689210372, 7.93675677557806e-06, 0.00034635104518326745, -0.0016225205303439755, -0.0017635675682644435, -0.001071682813936673, -0.001298689202738018, 4.206433317022106e-05, 7.942827302116973e-05, -0.0003141754621731288, -0.00019484569076575728, 8.430295143008256e-06, 0.00034659543818893235, -0.001621491061880873, -0.0017626447425563027, -0.0010710664332956672, -0.0012981013858160487, -4.2967398981632414e-06, -7.412281073912612e-06, -1.2081471570877044e-05, -2.2533002260467078e-05, -4.228784726078983e-05, -3.4475550537570106e-05, -1.8060725378010682e-05, -2.0639481692421055e-05, 1.9063585676564748e-05, 7.508928799217469e-05, 4.188979029698553e-05, 4.22388760434606e-05, 8.360457960841642e-05, 8.288759486376957e-05, -0.00040883865194031165, -0.00041105026399927447, -0.0002393865219111018, -0.0002402812807527513, 0.00010099533922825871, 0.00010022681994689624, 0.0007766001816196412, 0.0007768429884211554, -0.004144391529523069, -0.004142734104396281, -0.004614455923100261, -0.004612610509813698, -0.0011558578407659965, -0.0011553955900798316, -0.0034252772456762736, -0.003423907408745403]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.0002350702801734173, 3.948295739439951e-05, 7.638039789526808e-05, 0.00014123953252262637, 0.00025459590077989575, 0.00044831281621482996, 0.0007718055129237232, 0.0012898129913334444, 0.0020665562351121075, 0.0024158985397585006, 0.0025705840698031297, 4.0831746475328364e-05, 7.68886527392358e-05, 0.00014164041977441516, 0.00025458206112785446, 0.00044884609816993424, 0.0007727559179826786, 0.0012910540938936039, 0.002068634278220837, 0.002418314115568252, 0.0025732077332384956, 4.2541717775739064e-05, 7.78103052119027e-05, 0.0001432921172432725, 0.00025560766615707474, 0.00044854447863625085, 0.0007715400410869455, 0.001290054558807756, 0.0020673775278406367, 0.002416948066601522, 0.0025721003871066648, -9.802609485815335e-06, 3.5365275947735184e-06, -1.784742508582452e-05, 2.5670230727580355e-05, 2.7008230425257355e-05, 0.0001094491477120416, 9.736768383116222e-05, 0.0001208261911206194, 4.2213237087795973e-05, -0.0003510936468479941, 3.763644380188787e-05, 4.744699174958541e-05, 7.843639977599389e-05, 8.814385278577964e-05, 0.00016613807680286894, 0.00016728368871784448, 0.0003247101414320252, 0.0003285643495162854, 0.0006443513389221645, 0.0006515439680956704, 0.0012884217082100005, 0.0012959095782779891, 0.002590322307643235, 0.002589286385904518, 0.005181430765784127, 0.0051793586079093784, 0.006929256005191129, 0.006926484857018697, 0.00847667101813286, 0.00847328102772367]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.00020801859524815526, 4.0349593295385676e-05, 7.722661921625073e-05, 0.00014237148437395452, 0.00025629259843317495, 0.0004519220577764948, 0.0007778084208878992, 0.001299686952907222, 0.002082454180791184, 0.0023313925075755123, 0.0021355658026691532, 3.859937414052817e-05, 7.629902304086549e-05, 0.0001416485320555816, 0.0002561167735212321, 0.0004520988073464678, 0.0007785585160442415, 0.0013009753499187845, 0.002084498182571359, 0.002333677594095769, 0.002137489917206194, 3.728659589669659e-05, 7.629435814596953e-05, 0.00014131351805279574, 0.0002553130757253513, 0.00045201607483529434, 0.0007787488513289177, 0.0013007608841257617, 0.002083591836951402, 0.002332532836368801, 0.002136359570568306, -1.2344267806413814e-05, -1.4209895270890083e-05, -2.2468386929977017e-05, -3.888640192453001e-05, 1.5928363997213603e-05, 7.432144740315159e-05, 3.8762011367316796e-05, 7.430095425833452e-05, 0.00022351756267000717, 0.000637177003814771, 4.106227626467171e-05, 3.351091552872384e-05, 8.357802647416004e-05, 7.711880942072852e-05, 0.00016571493457595687, 0.00015806758885445732, 0.0003226903886912208, 0.00032451059477975737, 0.000652042211970808, 0.0006516223324849798, 0.0013057413930661298, 0.001305168344476802, 0.002610123860033771, 0.0026090800192579157, 0.00522094652814815, 0.005218858567129088, 0.006466484390511827, 0.006463898313970926, 0.004507086992273545, 0.004505284517971477]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.0001859044380700035, 3.8611001377488225e-05, 7.547755400406442e-05, 0.00014037562010193823, 0.0002539653054900584, 0.00044809978659194504, 0.0007714669589385605, 0.001289160814313736, 0.002065599806347213, 0.0023483353438243522, 0.001747825484213474, 4.068248467263889e-05, 7.661279901548091e-05, 0.00014138190675991522, 0.0002545374000645148, 0.00044862349732292063, 0.0007722266177455712, 0.0012904059191804185, 0.0020676774285276916, 0.0023506523098595143, 0.0017495327697986222, 4.068693235725406e-05, 7.692481137042032e-05, 0.00014256857503333844, 0.000254954190350412, 0.00044867776407010897, 0.0007708959682367635, 0.0012891132478213437, 0.002066229198781172, 0.002349746828328401, 0.0017484214595940622, 4.037414104513738e-06, 9.04917581317917e-06, 1.0555808319270253e-05, 7.534147509092673e-06, -1.2419689661714761e-05, 2.4997539634009066e-05, 0.0001632640071989032, 0.00016650004732090371, 0.0001742759642819403, 0.00022827535436113877, 3.722174821426997e-05, 4.415211650023842e-05, 7.732890482105071e-05, 8.50429413874579e-05, 0.0001674507017892729, 0.00016473267520099294, 0.000327586840455025, 0.00032055286072572907, 0.0006491393653943373, 0.0006491426226867785, 0.0012953098781813565, 0.001292947811374576, 0.002589205667962598, 0.002588170192790446, 0.00517903806738826, 0.005176966866401503, 0.0065929431885637995, 0.006590306538618346, 0.0005839074426457701, 0.0005836739263719618]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [0.00010004893092460815, 4.243868324766983e-05, 8.074006427346752e-05, 0.00014818687826829286, 0.0002666165128532574, 0.00046951792799857373, 0.0005191058647502989, -0.001447904854659705, -0.000971591103954069, -0.001069304691862621, -0.0014158413775722937, 4.293773718455025e-05, 8.07844375597944e-05, 0.00014839741010160346, 0.0002668448651297012, 0.0004699768594315219, 0.0005195950727546434, -0.0014493250377397562, -0.0009726170765442093, -0.001070420537262794, -0.0014172642211183796, 4.3860001974954377e-05, 8.124137232972772e-05, 0.0001474607861461866, 0.00026655915737802444, 0.0004697948952805912, 0.0005196399365552628, -0.00144848692296705, -0.0009721813440701562, -0.0010699013528017668, -0.001416303604495175, -7.89849399506407e-06, -7.053206746089782e-06, -9.770807494869907e-07, 7.196530101885496e-06, 4.918017831552928e-06, -4.266176703536481e-06, -1.0484212836981637e-05, -1.0621837277891634e-05, 5.235047611459717e-05, 1.92087899953521e-05, 4.434220569923848e-05, 4.337779825067513e-05, 8.608794872247272e-05, 8.531011633107925e-05, 0.00016737276763373522, 0.00016762410747828266, 0.00033555318861565173, 0.00034003515336547026, 0.000675868456333675, 0.000678118558120563, 0.0007762553881225593, 0.0007772547601951349, -0.004144441060593384, -0.004142783615658136, -0.0025556977693123217, -0.002554675694619532, -0.0030445949718607003, -0.003043377377390837, -0.006512745587906236, -0.006510141010586533]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.00014990032191053952, 3.834218174994118e-05, 7.50302589399594e-05, 0.00013958960888286884, 0.00025268601973628814, 0.00044723072348939997, 0.0007699256659734434, 0.0012866082544051455, 0.0020617102809874373, 0.003094906895717021, 0.0007646562320502917, 3.836848631871534e-05, 7.471334073017945e-05, 0.00013945966606439592, 0.00025336397385459245, 0.00044746866221823215, 0.0007707337283108282, 0.0012880227204665209, 0.002063693650909896, 0.0030978665346303984, 0.0007652176441952196, 3.2167068653257604e-05, 6.728104328187622e-05, 0.00013128274949650707, 0.0002476273973964695, 0.00044664042689959856, 0.0007696752391439992, 0.0012871230370042422, 0.0020624582729850674, 0.003096224049706152, 0.0007630874580993274, 2.4843425783783413e-05, 8.14053013861794e-06, 2.9717420971310193e-05, 5.365551668200633e-05, 3.207880065534162e-05, -1.2379095313457034e-05, 0.00014234860949986495, 0.00027815153391195195, 0.0003452578627041186, 0.00017773924359845353, 2.9536077375290165e-05, 3.479805993122457e-05, 6.860651808651504e-05, 5.9679577774939636e-05, 0.0001425110793840805, 0.00014441991836951101, 0.0003027156976795659, 0.00031199735104138314, 0.0006482851529785625, 0.0006443935214689988, 0.0012928355772150683, 0.0012916855702953495, 0.0025841450218016637, 0.002583111570483207, 0.005168800441400427, 0.005166733334645232, 0.010338166811113249, 0.010334032371276752, -0.012977715389330636, -0.012972525341184529]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [0.0002524850340500397, 4.008524401905281e-05, 7.677960061583568e-05, 0.00014242504295797407, 0.00025659906634827067, 0.0004516512108823461, 0.000776552801327396, 0.0012963485962288579, -0.00434257835872331, -0.0041761094322787864, -0.0037753033830143492, 4.072104812534623e-05, 7.769449026754834e-05, 0.00014286122443741352, 0.0002567520615563347, 0.00045188144733353144, 0.0007771358092418363, 0.0012973622368360497, -0.0043469524742065484, -0.004180246858653241, -0.003779011191002241, 3.741021260899161e-05, 7.588624209114853e-05, 0.00014240830236650007, 0.0002565610930496162, 0.0004523539203986642, 0.0007771922109126116, 0.0012973348889050817, -0.004344506077048099, -0.0041781825897713705, -0.0037769889756361576, 1.6635548569124824e-05, 2.340286370753749e-05, 2.1333981554155616e-05, 4.181217388154324e-05, 5.162401227873911e-05, -4.11363188942081e-05, -6.738028995635459e-05, -0.00010135437995495729, -0.0002210936712863456, -0.00040078088598871186, 3.587174223288266e-05, 3.8948682985097406e-05, 7.29254853680051e-05, 8.20145733181976e-05, 0.0001568017279619126, 0.00016237580309806621, 0.0003218030346210599, 0.0003283619593243029, 0.0006510745223314647, 0.0006547214633066599, 0.0012996021951346567, 0.0012998489772514886, 0.002601291624129303, 0.0026002513155413617, -0.016207453261500093, -0.01620097157653235, -0.015374096510793215, -0.015367948101870687, -0.011363102356156032, -0.011358558024079975]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.0003453320236914957, 3.79026019147202e-05, 7.513182730978656e-05, 0.00013987144271411137, 0.00025295718638508087, 0.000446693208581189, 0.0007697586343089202, 0.0012862584849223824, 0.0020610622020930647, 0.0030943253001908767, 0.0041274583898027895, 3.85705491168831e-05, 7.530685393908089e-05, 0.00014009765501019408, 0.00025382832986059467, 0.00044770170670849224, 0.0007706090623708022, 0.0012876978623504903, 0.002063244808696544, 0.00309741021233472, 0.00413153461615415, 3.19390221843995e-05, 7.146272203370107e-05, 0.00013827270698107377, 0.00025363087923554073, 0.0004479788108100761, 0.0007706036808398916, 0.0012873590379900207, 0.002062158426474741, 0.003096198331053754, 0.004129834606710335, 1.3122814497408266e-05, 6.511047637149191e-06, -2.188974647039185e-05, -8.972805315093274e-05, -9.433436027929878e-05, -3.283611353109092e-05, -0.0001409385096796766, -0.00036880143803829767, -0.0008095735528914655, -0.001161088652221248, 3.124590361729083e-05, 3.2632140751506914e-05, 7.815067876686532e-05, 7.90069600565254e-05, 0.000166446193840701, 0.00014905881779562165, 0.0003238728079629366, 0.00031472518015009603, 0.0006464142866516958, 0.0006445106571612404, 0.0012870192908054304, 0.0012916379895915127, 0.0025834051050411043, 0.0025823719496301726, 0.005167953979989802, 0.005165887211751454, 0.01033711145141943, 0.010332977433642414, 0.020675623991475104, 0.020667355395597713]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-9.627215824320833e-05, -4.67415304737409e-05, -8.895383735059803e-05, 0.0001135074631914607, -1.7230773844753355e-05, -0.00024132886667251418, -0.000470444962402992, 0.0010493889246082208, 0.0013434652449924986, 0.0011193696549421169, 0.0011660645933206798, -4.723571246237637e-05, -8.917020554125386e-05, 0.00011365016579866178, -1.7277577569808834e-05, -0.0002415292129231078, -0.000470876607120238, 0.0010504206949409369, 0.0013447779606264622, 0.001120482209700527, 0.0011672101111221557, -4.6717658630234634e-05, -8.841069900769584e-05, 0.00011425616465122782, -1.7033220508761337e-05, -0.0002417551426711906, -0.0004710010286277518, 0.0010496592771632042, 0.0013440564245290054, 0.0011198459757627986, 0.0011665899357198048, 4.386264942804899e-06, 3.455994872714523e-06, 5.402634056976346e-06, 4.98967833780485e-06, 8.947558889735357e-06, 8.06851843352919e-06, 1.699510078917947e-05, 4.291703081921476e-05, 1.974939583019996e-05, -1.0392881794480473e-05, -4.662797288164442e-05, -4.680734437882589e-05, -9.227542648344544e-05, -9.271998494979382e-05, 0.00016016016669163972, 0.00016234411539484, -2.525947297150948e-05, -2.7159628946683592e-05, -0.00040125050126775785, -0.0004023727330791856, -0.000859238579571355, -0.0008604916498173304, 0.0029425374814697623, 0.0029413607018331177, 0.003923455948019327, 0.003921886879453848, 0.0028022598387675997, 0.002801139158968048, 0.0032693936949300434, 0.0032680861989512508]}], 'BANANAS': [{'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [0.00013799587277872274, 4.0058847204610545e-05, 7.806080984987613e-05, 0.00014632447952682147, 0.0002655291422054703, 7.090108774687172e-05, -0.0008680269283080763, -0.0021871276106522107, -0.0013694960751220529, -0.0014477294309909851, -0.0014361448287137871, 3.8188127879924085e-05, 7.755974950173224e-05, 0.0001465564392854816, 0.0002661233450522742, 7.098055638836418e-05, -0.0008691794171706897, -0.0021900842435588386, -0.0013714161668148192, -0.0014496993097916457, -0.001438107985680801, 3.9578699796914636e-05, 7.836270329994781e-05, 0.00014629581457161313, 0.0002654609322100176, 7.139586411215254e-05, -0.0008683732404525326, -0.0021884525804909526, -0.0013703273797157772, -0.0014486217336104061, -0.0014373437148260722, -8.08642656097738e-06, -1.1382182102641808e-05, -1.0573430913393925e-05, -1.0824991358140573e-05, -7.23150376283629e-05, -4.827443547463047e-05, -8.232063154810088e-05, -0.00013715758152542647, -0.00015371441394992436, -0.00022521753331422925, 4.1991110259009285e-05, 3.716628933482045e-05, 8.401269193440076e-05, 8.383846146139367e-05, 0.00017099826846302772, 0.0001689207896940545, 0.0003395298934410034, 0.00033898067099384936, 1.4994468693705895e-05, 1.427637007796844e-05, -0.0018672551801989941, -0.0018592481158983837, -0.005170751881790631, -0.0051616909583054316, -0.0024394700117847416, -0.0024315208720918882, -0.002825388892427679, -0.002816222305132237, -0.0027333652654653027, -0.0027298625278373554]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.00012805584112260743, 5.982561678550817e-05, 9.545662165179044e-05, 0.00015944520832976384, 0.00027281304872997747, 0.00046552089513965964, 0.0007919300170462094, 0.0013193320759912672, 0.002106094258879178, 0.0020782277377468386, 0.0006821555504160617, 5.904592922842698e-05, 9.567481566927456e-05, 0.00015949519992210787, 0.0002728545770598869, 0.00046631731349662466, 0.0007930267890329117, 0.0013210145287903157, 0.0021089774286243783, 0.0020810987040980945, 0.0006826931495403977, 5.7396138408089925e-05, 9.412158619407141e-05, 0.00015884937511871858, 0.0002726255160086925, 0.00046574240349291844, 0.0007922947108749558, 0.00132052973948587, 0.0021073349168533304, 0.0020791982447270394, 0.0006819853673938048, 8.409626374049139e-06, 1.2825802455054653e-05, 1.1340371333657707e-05, -1.765885626236787e-06, 6.84244565651144e-05, 6.0256258686696576e-05, 5.328898452011753e-05, 2.852708035622106e-05, -2.152151465808788e-05, -0.00035335462759486107, 5.699216698692451e-05, 5.7800109829255556e-05, 9.703394624105798e-05, 9.975922185851674e-05, 0.0001758402148167197, 0.00017870426494073194, 0.0003336660886634651, 0.0003423763634204641, 0.000653935823061917, 0.0006708764690563243, 0.0013132043448453859, 0.0013297605927653294, 0.002640801110869429, 0.002638380582073029, 0.005268056766196346, 0.005255506035108699, 0.005119158406865119, 0.005111899405814131, -0.008824429558616136, -0.008793267114277463]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [0.00035581464303341827, 0.00010529348811269654, -0.00043154860701215427, -0.0003640927278347702, -0.0009039135541489292, -0.0021447577585889128, -0.0018053808378910665, -0.001260077263940595, -0.00044556338222421273, -0.003626977195443346, -0.0049447933709106445, 0.00010630103891537262, -0.00043146947226689863, -0.00036402109774155646, -0.0009046467641035592, -0.0021473600988181564, -0.001807884159233223, -0.0012617372458309892, -0.00044607663430443266, -0.003631718742165857, -0.004951504409075833, 0.00010563091497323009, -0.00043153414324462045, -0.0003640199629558542, -0.0009037414688910164, -0.002146133776793198, -0.001806449708224439, -0.001261486975962629, -0.0004464502445902711, -0.0036296240427836023, -0.004947972895049306, 3.7618006995670515e-06, 9.267980867955757e-06, 1.3759090134546909e-05, 1.4946435684940517e-05, 2.2832896256120024e-05, -1.862085642814698e-05, 5.216699480012382e-05, 8.671619733086267e-05, 7.314577769620154e-05, -0.00016661209437181514, 0.00010459831514572045, 0.00010666351480073887, -0.0004916443375831845, -0.0004894416174854866, -0.0004095281676355863, -0.0004051877552784038, -0.001178434566892141, -0.0011706695767063522, -0.003248481069957642, -0.0032417557298103023, -0.002567616715414053, -0.002564029890639079, -0.0012110617656422545, -0.0012124942840368196, 0.001502627535284072, 0.0014817900183288187, -0.014418621892930993, -0.014393390349203922, -0.027532099298135968, -0.0274606438694508]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [0.00019375588162360953, 3.929153249035551e-05, 7.789304635650217e-05, 0.00014723018027758245, 0.0002673765221024724, 0.0004744145807765925, 0.0008167820473674003, 0.001362344401289637, 3.000160687954533e-05, -0.004086723968029392, -0.003785854851894844, 3.934443064650022e-05, 7.805519831392084e-05, 0.00014689858275380017, 0.00026714948083686213, 0.0004745562893229534, 0.0008176198397345717, 0.0013641072136931045, 3.0113454787473456e-05, -0.004092353772049635, -0.0037911073534832584, 4.013228325378671e-05, 7.94147904296205e-05, 0.0001478333497560289, 0.0002681066980534187, 0.0004758326594853494, 0.000817536783663504, 0.0013628205854042247, 2.9927349689758656e-05, -0.004088709360442366, -0.0037872047171718885, -1.4405373486092979e-06, -9.163456648595455e-06, -1.338383863847979e-05, -6.9395294029333625e-06, 1.8220828604265495e-05, 6.976725447667537e-05, 0.00012023808662470731, -0.0002467791447223442, -0.000565684190287518, -0.000692762480215143, 3.791994238577296e-05, 4.2344624121803094e-05, 8.255809801415225e-05, 8.782987382454902e-05, 0.00016898068003049036, 0.0001733646619514829, 0.000340359378640191, 0.00034338507146070045, 0.0006913662607939873, 0.0006875477221779192, 0.0013681220659178955, 0.0013638963144779242, 0.0027327106611170797, 0.0027232591570176498, -0.0017194388992616645, -0.0017248125659942426, -0.02234089114233674, -0.02230196975183249, -0.01929875238013893, -0.01925277850695962]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [3.5791384857163247e-06, 4.2563095108128545e-05, 8.241494020908916e-05, 0.00015252639664979715, 0.00027484527470069444, 0.0002693288516215949, 0.0001989516970182213, -0.0003670153660141054, -0.00044466427319635047, -1.7068820700432916e-05, -8.75650591984557e-05, 4.3184290752379575e-05, 8.24456670059488e-05, 0.00015258725131055319, 0.00027507196533292335, 0.00026939887989912336, 0.00019910888428927983, -0.00036757654550509134, -0.0004453667550353266, -1.714474973133509e-05, -8.76626578455971e-05, 4.23833247511605e-05, 8.231644386217438e-05, 0.00015387921943723476, 0.0002763897545538797, 0.0002695126876910743, 0.0001989708100024013, -0.00036737666583804067, -0.00044508885662528563, -1.716301977123959e-05, -8.775414313273798e-05, 6.8430783770244144e-06, 1.59420586727426e-05, 1.9814758085671552e-05, -4.599631392524645e-06, 1.431551089712985e-05, 2.565779081781298e-05, -7.4029973913871886e-06, -3.6163677607375584e-05, 2.18967104562139e-05, 2.7317476254754573e-05, 4.23660715055856e-05, 4.240057799673534e-05, 8.681422076935954e-05, 8.691518345515988e-05, 0.00017781831332483794, 0.0001770712270362556, 0.0003533721249671589, 0.0003527659708136686, 0.0003403732526598514, 0.00034139565255295284, 0.0001972202248312423, 0.00020150797846271146, -0.001216681439981021, -0.001213530208867203, -0.0014691315347759263, -0.0014633681097165128, 0.0006872164916587408, 0.0006908122851815693, -4.93238154055421e-05, -4.7738867100789935e-05]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.0002655127057702568, 3.174656252951731e-05, 6.977496301371796e-05, 0.0001362638707136045, 0.00025202824101007036, 0.00045000222939101147, 0.0007824395821028743, 0.0013182091958924903, 0.0017638191093714076, 0.0025896089961454575, 0.003093292014807137, 3.294049633546918e-05, 7.1091684155787e-05, 0.00013715826572034, 0.0002527907696003097, 0.0004511050591004985, 0.0007838398203268838, 0.0013202997160156032, 0.0017663939150141523, 0.002593543827691595, 0.0030977436589989206, 3.500006310284429e-05, 7.071988508092905e-05, 0.00013767474671395877, 0.00025184521363326754, 0.0004494098032081528, 0.0007822849473711657, 0.0013186096011537027, 0.0017649248384377595, 0.002591674597398785, 0.0030951910456094166, -8.67953380069673e-06, -2.4347134647171216e-05, -2.541843575656679e-05, -1.6095558197761995e-06, -1.2279676364049685e-05, 1.813688583247394e-05, 2.9480917334561206e-05, -6.251207253152667e-06, 0.0001076271127895772, 0.0010463477200804729, 3.5478708417087756e-05, 3.452141778860399e-05, 7.58688183994626e-05, 7.376549414370064e-05, 0.0001595935388971832, 0.00015762251359053878, 0.00031999062436069447, 0.00032248862939480565, 0.0006485777758045591, 0.0006560042127501199, 0.001309590825068871, 0.0013168509071368511, 0.0026517033940880804, 0.0026430537384263624, 0.004129930759627547, 0.004118131616233089, 0.008274136122267933, 0.008249006978005636, 0.013186165430217977, 0.013135504021927914]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.00022007033276085357, 6.741188630456947e-05, 0.00010588107031254828, 0.00017388692108014935, 0.00029260973776713486, 0.0004977307983116063, 0.0008440061479526952, 0.0013940582826850044, 0.0016304253869118134, 0.002441138806966474, 0.0023627131774831456, 6.782689677344624e-05, 0.00010641367262110943, 0.00017452076488491173, 0.00029329382946806116, 0.0004986365045054533, 0.0008452215644445947, 0.0013959806328899334, 0.001632792912254347, 0.002444742426065879, 0.0023660268684306354, 6.666689705928692e-05, 0.00010471516736965635, 0.00017211161468314268, 0.00029096544497064303, 0.0004967226679170384, 0.0008437004400010646, 0.00139450982952561, 0.0016314467179382912, 0.002443119230197807, 0.0023643715418159746, 6.685411629224607e-07, 9.356113870831331e-06, 2.5958550511638873e-05, 6.511223616799924e-05, 7.750670195596114e-05, 3.271940842679326e-05, 8.421432976198557e-05, -1.4958110067083641e-05, 0.0001839854170987654, 0.0008027825141433052, 6.457096336882936e-05, 6.876283074974546e-05, 0.00010460248782976573, 0.00011170902324514566, 0.0001863749275076584, 0.00019716870128771703, 0.0003573048557286362, 0.00036900654687196553, 0.0006997202816953443, 0.0007101504238068832, 0.001399584631761089, 0.0013986744301856772, 0.0027757944273815367, 0.0027729861265240684, 0.0035661625281708863, 0.0035650741756760116, 0.0076421048212345155, 0.0076232702738757985, 0.0067533847341205155, 0.006727566975545083]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [-0.00029971279056151354, 7.296832106443126e-05, 0.00011238067514615004, 0.00018216261702515694, 0.00030460604540772355, 0.0005150295367975858, 0.00086388868497397, 0.0014214503800790786, 0.002258779399927897, 0.0033818616352523224, 0.0033410167206225565, 7.381778431445952e-05, 0.00011327825487367847, 0.00018326789811037874, 0.00030562242153622267, 0.0005162088158287414, 0.0008652058441931494, 0.0014236006682439468, 0.0022622037714059716, 0.003386739171761919, 0.003346003819809099, 7.460062305643008e-05, 0.00011331331512594478, 0.00018284094691661893, 0.0003047700993459223, 0.0005160235153570169, 0.0008653180663202729, 0.0014229523325269837, 0.002260737967698846, 0.0033845521579243146, 0.0033434638360396803, -1.0092516517687348e-06, 5.402644589684079e-08, -2.716037759057826e-06, -5.8476626057645765e-06, -4.270522433303899e-05, -0.000131058952533746, -0.00025011471693232266, -0.000266387266300188, 0.00012339013344451753, 0.0008821385042446264, 7.471091176152841e-05, 7.449033435133377e-05, 0.00011737863939144062, 0.00011900715032577418, 0.0002045952600756944, 0.0002083886678194486, 0.0003798260185086785, 0.00037610353269356577, 0.0007315864219880854, 0.0007240315666462013, 0.0014277111486954478, 0.0014205934384877528, 0.002827367372916897, 0.002830682528350565, 0.0056353740379553804, 0.005627806932521464, 0.011264406589294356, 0.011232409879024552, 0.010724018090486736, 0.010667434772186365]}, {'learning_rate': 0.01, 'num_iter': 10000, 'fit_intercept': True, 'verbose': False, 'weights': [5.340056718172212e-05, -8.531878326886056e-05, 2.8962663548776517e-06, -0.00010699599402258378, -0.00020288459592953217, -2.742676701166232e-05, -0.00022371912534133943, -0.0005540979690680574, -0.0010466067079170336, -0.0005791585560436769, -0.00038877290466474254, -8.567790597858214e-05, 2.596337239165766e-06, -0.00010742227165541007, -0.0002033442392482459, -2.753618650935982e-05, -0.00022402368191068995, -0.0005549417358419435, -0.0010482049829204114, -0.0005801534798126651, -0.00038949028848020754, -8.55256094290553e-05, 2.758578594776838e-06, -0.00010726798545815329, -0.00020311770803554293, -2.7489972198126905e-05, -0.00022400771939594825, -0.0005544037898261347, -0.0010473006714623084, -0.0005799150448927915, -0.0003892580300082992, -3.5948657871686584e-06, -5.691668137369121e-06, -8.66013189886557e-06, -9.491209585891701e-06, 8.131079911876542e-07, 2.7098955446654794e-05, 5.4045960392804795e-05, 0.00017662706797362817, 6.293000282265298e-05, -0.0003355224672563369, -8.565234652174287e-05, -8.53988723363683e-05, 1.2890416957543464e-05, 1.1282798887747147e-05, -0.00012418508376510858, -0.00012801083324956046, -0.00026013441997822096, -0.000265813783152717, 3.238035971841646e-05, 2.864147962791065e-05, -0.0003585642251397186, -0.0003618749263539387, -0.0011855399979763038, -0.0011852662300306461, -0.0028328909125426276, -0.0028252913111161096, -0.0005015390483106573, -0.0004962865520478136, 0.0014747003954061271, 0.0014854889377716209]}]}

In [16]:
for product in products:
    logregs = [
        deserialize_logreg(model_dict) for model_dict in hardcode[product]
    ]
    for i, logreg in enumerate(logregs):
        print(logreg.eval(xss[product], np.array(gts[product])[:,i]))

[[1420  483]
 [  68   29]]
0.7245
[[ 509 1479]
 [   3    9]]
0.259
[[ 506 1467]
 [   6   21]]
0.2635
[[ 255 1734]
 [   1   10]]
0.1325
[[1423  480]
 [  65   32]]
0.7275
[[1740  253]
 [   4    3]]
0.8715
[[1731  250]
 [  13    6]]
0.8685
[[ 512 1481]
 [   0    7]]
0.2595
[[  85  192]
 [ 427 1296]]
0.6905
[[1651  239]
 [  93   17]]
0.834
[[1594  372]
 [  22   12]]
0.803
[[1450  460]
 [  38   52]]
0.751
[[1439  467]
 [  49   45]]
0.742
[[1732  117]
 [ 140   11]]
0.8715
[[ 506 1437]
 [   6   51]]
0.2785
[[ 491 1382]
 [  21  106]]
0.2985
[[ 492 1347]
 [  20  141]]
0.3165
[[ 787   37]
 [1085   91]]
0.439


In [17]:
from src.dectree import DecisionTreeClassifier, serialize_decision_tree, deserialize_decision_tree

hardcode = {}
for product in products:
    hardcode[product] = []
    for i in range(9):
        clf = DecisionTreeClassifier(max_depth=6, min_samples_leaf=5)
        sm = SMOTE(random_state=42)
        X_res, y_res = sm.fit_resample(xss[product], np.array(gts[product])[:,i])
        clf.fit(X_res, y_res)
        print(clf.eval(X_res, y_res))
        hardcode[product].append(
            serialize_decision_tree(clf)
        )
print(hardcode)

KeyboardInterrupt: 

In [None]:
from numpy import array
hard = {'PEARLS': [{'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 17, 'threshold': 10004.859290438966, 'probas': array([0.9515, 0.0485]), 'is_terminal': False, 'depth': 1, 'left': {'column': 27, 'threshold': 10000.28515625, 'probas': array([0.95197599, 0.04802401]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.95245245, 0.04754755]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 2}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 8, 'threshold': 9995.249345740041, 'probas': array([0.977, 0.023]), 'is_terminal': False, 'depth': 1, 'left': {'column': 38, 'threshold': -175.0, 'probas': array([0.97747748, 0.02252252]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.8, 0.2]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.97837022, 0.02162978]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 0, 'threshold': 9995.0, 'probas': array([0.5, 0.5]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([1., 0.]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 37, 'threshold': -145.0, 'probas': array([0.973, 0.027]), 'is_terminal': False, 'depth': 1, 'left': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 2}, 'right': {'column': 34, 'threshold': 60.0, 'probas': array([0.97348674, 0.02651326]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.97396094, 0.02603906]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.5, 0.5]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 15, 'threshold': 10004.92857142857, 'probas': array([0.9515, 0.0485]), 'is_terminal': False, 'depth': 1, 'left': {'column': 33, 'threshold': 44.0, 'probas': array([0.95197599, 0.04802401]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.95245245, 0.04754755]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 2}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 21, 'threshold': 10003.25, 'probas': array([0.978, 0.022]), 'is_terminal': False, 'depth': 1, 'left': {'column': 23, 'threshold': 9998.875, 'probas': array([0.97847848, 0.02152152]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.89361702, 0.10638298]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.98052281, 0.01947719]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 0, 'threshold': 9995.225806451614, 'probas': array([0.5, 0.5]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([1., 0.]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 15, 'threshold': 10004.915151515152, 'probas': array([0.974, 0.026]), 'is_terminal': False, 'depth': 1, 'left': {'column': 31, 'threshold': 26.0, 'probas': array([0.97574533, 0.02425467]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.97622661, 0.02377339]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.5, 0.5]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 11, 'threshold': 10004.98076923077, 'probas': array([0.80952381, 0.19047619]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.89473684, 0.10526316]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 22, 'threshold': 9999.125, 'probas': array([0.195, 0.805]), 'is_terminal': False, 'depth': 1, 'left': {'column': 6, 'threshold': 9995.274643874644, 'probas': array([0.26153846, 0.73846154]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.24505929, 0.75494071]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.85714286, 0.14285714]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 6, 'threshold': 9995.237190558433, 'probas': array([0.18505747, 0.81494253]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.1984184, 0.8015816]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.13180516, 0.86819484]), 'is_terminal': True, 'depth': 3}}}}], 'BANANAS': [{'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 36, 'threshold': 85.0, 'probas': array([0.945, 0.055]), 'is_terminal': False, 'depth': 1, 'left': {'column': 52, 'threshold': 4931.0, 'probas': array([0.94594595, 0.05405405]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.83695652, 0.16304348]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.95120672, 0.04879328]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 2}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 33, 'threshold': -30.0, 'probas': array([0.959, 0.041]), 'is_terminal': False, 'depth': 1, 'left': {'column': 6, 'threshold': 4927.189895470383, 'probas': array([0.25, 0.75]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([1., 0.]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 9, 'threshold': 4945.953479933231, 'probas': array([0.96042084, 0.03957916]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.96258847, 0.03741153]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.72222222, 0.27777778]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 57, 'threshold': 4925.0, 'probas': array([0.8395, 0.1605]), 'is_terminal': False, 'depth': 1, 'left': {'column': 27, 'threshold': 4947.8046875, 'probas': array([0.63476562, 0.36523438]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.72636816, 0.27363184]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.3, 0.7]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 19, 'threshold': 4456.784879893033, 'probas': array([0.90994624, 0.09005376]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.48333333, 0.51666667]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.92787115, 0.07212885]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 30, 'threshold': 13.0, 'probas': array([0.9305, 0.0695]), 'is_terminal': False, 'depth': 1, 'left': {'column': 16, 'threshold': 4956.32618510158, 'probas': array([0.93176116, 0.06823884]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.93222892, 0.06777108]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 30, 'threshold': 15.0, 'probas': array([0.57142857, 0.42857143]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([1., 0.]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 39, 'threshold': 232.0, 'probas': array([0.9565, 0.0435]), 'is_terminal': False, 'depth': 1, 'left': {'column': 21, 'threshold': 4955.0, 'probas': array([0.96452328, 0.03547672]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.96505824, 0.03494176]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0., 1.]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 46, 'threshold': 4944.0, 'probas': array([0.88265306, 0.11734694]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.9197861, 0.0802139]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.11111111, 0.88888889]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 15, 'threshold': 4935.1658823529415, 'probas': array([0.7605, 0.2395]), 'is_terminal': False, 'depth': 1, 'left': {'column': 8, 'threshold': 4930.033496967947, 'probas': array([0.94954955, 0.05045045]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.97490347, 0.02509653]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.59459459, 0.40540541]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 29, 'threshold': 4455.3755859375, 'probas': array([0.68788927, 0.31211073]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.84660767, 0.15339233]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.63924051, 0.36075949]), 'is_terminal': True, 'depth': 3}}}}, {'max_depth': 3, 'min_samples_split': 2, 'min_samples_leaf': 1, 'classes': array([0, 1]), 'Tree': {'column': 17, 'threshold': 4935.9699583581205, 'probas': array([0.609, 0.391]), 'is_terminal': False, 'depth': 1, 'left': {'column': 9, 'threshold': 4928.902495479204, 'probas': array([0.49455865, 0.50544135]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.38693467, 0.61306533]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.77391304, 0.22608696]), 'is_terminal': True, 'depth': 3}}, 'right': {'column': 35, 'threshold': 32.0, 'probas': array([0.68968457, 0.31031543]), 'is_terminal': False, 'depth': 2, 'left': {'column': None, 'threshold': None, 'probas': array([0.72533849, 0.27466151]), 'is_terminal': True, 'depth': 3}, 'right': {'column': None, 'threshold': None, 'probas': array([0.42446043, 0.57553957]), 'is_terminal': True, 'depth': 3}}}}]}


In [None]:
for product in products:
    logregs = [
        deserialize_decision_tree(model_dict) for model_dict in hardcode[product]
    ]
    for i, dectree in enumerate(logregs):
        print(dectree.eval(xss[product], np.array(gts[product])[:,i]))

[[1720  183]
 [  68   29]]
0.8745
[[1952   36]
 [   5    7]]
0.9795
[[1721  252]
 [   9   18]]
0.8695
[[1929   60]
 [   0   11]]
0.97
[[1837   66]
 [  81   16]]
0.9265
[[1955   38]
 [   0    7]]
0.981
[[1804  177]
 [   5   14]]
0.909
[[1920   73]
 [   1    6]]
0.963
[[  63  214]
 [ 220 1503]]
0.783
[[1107  783]
 [  44   66]]
0.5865
[[1784  182]
 [  12   22]]
0.903
[[1664  246]
 [  12   78]]
0.871
[[1805  101]
 [   7   87]]
0.946
[[1406  443]
 [  75   76]]
0.741
[[1837  106]
 [  17   40]]
0.9385
[[1390  483]
 [   9  118]]
0.754
[[1571  268]
 [  17  144]]
0.8575
[[ 426  398]
 [  92 1084]]
0.755


In [None]:
from src.knn import KNN
import json
from json import JSONEncoder
import numpy

class NumpyArrayEncoder(JSONEncoder):
    def default(self, obj):
        if isinstance(obj, numpy.ndarray):
            return obj.tolist()
        return JSONEncoder.default(self, obj)

clf = KNN(xss['BANANAS'], np.array(gts[product]), K=20)

print(clf.predict(xss['BANANAS'][150]))
print(np.array(gts[product])[150])

with open("xss.txt", "w") as fp:
    json.dump(xss, fp, cls=NumpyArrayEncoder)  # encode dict into JSON

with open("gts.txt", "w") as fp:
    json.dump(gts, fp, cls=NumpyArrayEncoder)  # encode dict into JSON


[1.80260998e-03 2.43956732e-04 7.27224769e-01 2.67531041e-01
 6.63143151e-04 2.43956732e-04 2.43956732e-04 2.43956732e-04
 1.80260998e-03]
[1 0 0 0 0 0 0 0 0]


In [25]:
import numpy as np
from statistics import NormalDist
from scipy.stats import norm
from sklearn.metrics import confusion_matrix

def serialize_gnb(model):
    serialized_model = {
        'classes':model.classes,
        'mean':model.mean.tolist(),
        'std':model.std.tolist(),
        'c_mean':model.c_mean.tolist(),
        'c_std':model.c_std.tolist(),
        'prior':model.prior.tolist()
    }

    return serialized_model

def deserialize_gnb(model_dict):
    deserialized_gnb = NaiveBayes()
    deserialized_gnb.classes = model_dict['classes']
    deserialized_gnb.mean = np.array(model_dict['mean'])
    deserialized_gnb.std = np.array(model_dict['std'])
    deserialized_gnb.c_mean = np.array(model_dict['c_mean'])
    deserialized_gnb.c_std = np.array(model_dict['c_std'])
    deserialized_gnb.prior = np.array(model_dict['prior'])
    
    return deserialized_gnb

class NaiveBayes(object):

    def train (self, X, y):

        """
            Calculates population and class-wise mean and standard deviation
        """

        # Population mean and standard deviation

        self.classes = set(y)
        self.mean = np.mean(X, axis=0)
        self.std = np.std(X, axis=0)

        # Class mean and standard deviation

        self.c_mean = np.zeros((len(self.classes), X.shape[1]))
        self.c_std = np.zeros((len(self.classes), X.shape[1]))
        self.prior = np.zeros((len(self.classes),))

        for c in self.classes:
            indices = np.where(y == c)
            self.prior[c] = indices[0].shape[0] / float(y.shape[0])
            self.c_mean[c] = np.mean(X[indices], axis=0)
            self.c_std[c] = np.std(X[indices], axis=0)

        return

    def predict (self, X):

        """
            Calculates observations' posteriors and returns class with 
            maximum posterior.
        """

        p = []

        for obs in X:
            tiled = np.repeat([obs], len(self.classes), axis=0)

            # Probability of observation in population

            evidence = norm.pdf((self.mean - obs) / self.std)
            evidence = np.prod(evidence)

            # Probability of observation in each class

            likelihood = norm.pdf((tiled - self.c_mean) / self.c_std)
            likelihood = np.prod(likelihood, axis=1)

            # Probability of each class given observation

            posterior = self.prior * likelihood / evidence
            p.append(np.argmax(posterior))

        return p
    
    def eval(self, X, y):
        """"Evaluate accuracy on dataset."""
        p = self.predict(X)
        print(confusion_matrix(y, p))
        return np.sum(p == y) / X.shape[0]

In [26]:
from src.gnb import serialize_gnb, deserialize_gnb

hardcode = {}
for product in products:
    hardcode[product] = []
    for i in range(9):
        clf = NaiveBayes()
        sm = SMOTE(random_state=42)
        X_res, y_res = sm.fit_resample(xss[product], np.array(gts[product])[:,i])
        clf.train(X_res, y_res)
        print(clf.eval(X_res, y_res))
        hardcode[product].append(
            serialize_gnb(clf)
        )
print(hardcode)

  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1772  131]
 [1659  244]]
0.5296899632159747


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1980    8]
 [1988    0]]
0.49798792756539234


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1965    8]
 [1973    0]]
0.4979726305119108


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1981    8]
 [1989    0]]
0.49798893916540976


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1841   62]
 [1755  148]]
0.522595901208618


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1985    8]
 [1993    0]]
0.4979929754139488


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1946   35]
 [1520  461]]
0.6075214538112065


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1985    8]
 [1993    0]]
0.4979929754139488


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[ 238 1485]
 [ 147 1576]]
0.526407428903076


  posterior = self.prior * likelihood / evidence


[[1797   93]
 [1717  173]]
0.5211640211640212


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1958    8]
 [1964    2]]
0.49847405900305186


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1660  250]
 [1148  762]]
0.6340314136125654


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1898    8]
 [1906    0]]
0.49790136411332636


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1842    7]
 [1847    2]]
0.49864791779340184


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1901   42]
 [1877   66]]
0.5061760164693773


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1812   61]
 [1732  141]]
0.5213561131873999


  posterior = self.prior * likelihood / evidence
  posterior = self.prior * likelihood / evidence


[[1817   22]
 [1834    5]]
0.4953779227841218


  posterior = self.prior * likelihood / evidence


[[  69 1107]
 [  44 1132]]
0.5106292517006803
{'PEARLS': [{'classes': {0, 1}, 'mean': [0.003216514353444957, 0.003215752573647912, 0.003206519213292203, 0.003204152986496387, 0.0032000944506728707, 0.003193331002204896, 0.0031370823321421803, 0.0031038309663591513, 0.003075389486142913, 0.003029938886201451, 0.003219605356972329, 0.0032188432096581765, 0.003209597816716491, 0.003207231481816508, 0.003203166610002467, 0.003196396420708358, 0.0031400945647799285, 0.0031068119442319363, 0.003078342290530324, 0.0030328474951299574, 0.0032180416077796335, 0.0032172680316867497, 0.003208041296417588, 0.003205690283029769, 0.0032016358522961243, 0.003194867815700389, 0.003138590017392849, 0.003105320497013142, 0.0030768657169977015, 0.0030313904227205965, -6.462152846679185e-08, -1.1373842810560479e-07, -1.7938178536489597e-07, -3.349602171720904e-07, -6.047187869111561e-07, -2.989497849419073e-07, 5.408606697758823e-07, 1.4553005953647904e-06, 3.7554643924260896e-06, 2.8130111915097004e-06, 

In [None]:
# from src.dectree import Forest
# for product in products:
#     for i in range(9):
#         clf = Forest(max_depth=3, no_trees=2)
#         sm = SMOTE(random_state=42)
#         X_res, y_res = sm.fit_resample(xss[product], np.array(gts[product])[:,i])
#         clf.train(X_res, y_res)
#         print(clf.eval(X_res, y_res))

Training Forest...


Training Decision Tree no 1...


Training Decision Tree no 2...



IndexError: invalid index to scalar variable.

In [39]:
# from src.lstm import LSTM
# seqlen = 10
# for product in products:
#     for i in range(9):
#         clf = LSTM(60, 100, 1)
#         sm = SMOTE(random_state=42)
#         samxss = np.array([xss[product][j-10:j] for j in range(seqlen, len(xss[product]))]).reshape((-1, 600))
#         ugts = np.array([np.array(gts[product])[seqlen:][:,i]])
#         print(samxss.shape)
#         X_res, y_res = sm.fit_resample(samxss, ugts)
#         print(y_res)
#         clf.train(X_res.reshape((-1, 10, 60)), y_res, 100, 0.1)
#         print(clf.eval(X_res, y_res))

(1990, 600)


ValueError: Imbalanced-learn currently supports binary, multiclass and binarized encoded multiclasss targets. Multilabel and multioutput targets are not supported.