In [None]:
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import confusion_matrix, f1_score
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from torch.utils.data import ConcatDataset, DataLoader

# Make the project package `src` (one directory up) importable.
sys.path.append('../')

# The original line was `from src import reload(...)`, which is a SyntaxError;
# these are the project classes the notebook actually uses.
from src import (Dataset, ISIDistributionDataset, EdgeDistributionDataset,
                 FRDistributionDataset, ISIFRDistributionDataset)
from src.network import MLP

In [9]:
# Two-way aggregation of the V1 pop_name labels into excitatory ('e') and
# inhibitory ('i') classes: each label maps to its own leading character.
# (load_data() builds its own aggregation maps; this global is kept for
# interactive use.)
aggr_dict = {pop_name: pop_name[0] for pop_name in (
    'e23Cux2', 'i5Sst', 'i5Htr3a', 'e4Scnn1a', 'e4Rorb',
    'e4other', 'e4Nr5a1', 'i6Htr3a', 'i6Sst', 'e6Ntsr1',
    'i23Pvalb', 'i23Htr3a', 'i1Htr3a', 'i4Sst', 'e5Rbp4',
    'e5noRbp4', 'i23Sst', 'i4Htr3a', 'i6Pvalb', 'i5Pvalb',
    'i4Pvalb',
)}

# Load the dataset without forcing re-processing (the cell output shows it
# reusing the processed pickle).
dataset = Dataset('../data', force_process=False)


Found processed pickle. Loading from '../data/processed/dataset.pkl'.


### Run Hyperparameter Grid

In [None]:
# Run every configuration listed in hp_grid_p3.csv. The dataset, scalers and
# loaders are rebuilt only when the data-related parameters (dist_params)
# differ from the previous row; a fresh model is trained for every row.
# NOTE(review): this cell calls load_data/get_scaler/train/test, which are
# defined in later cells -- run those cells first on a fresh kernel.
hp_df = pd.read_csv('hp_grid_p3.csv')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# None (instead of np.asarray([0])) avoids an elementwise list-vs-ndarray
# comparison on the first row while still forcing the first rebuild.
prev_dist_params = None
for _, row in hp_df.iterrows():
    hpc = row['hp_idx']
    k, distribution, cell_split_seed, bin_size = row['k'], row['distribution'], row['cell_split_seed'], row['bin_size']
    isi_dist_bins_start, isi_dist_bins_stop, isi_dist_bins_step = row['isi_dist_bins_start'], row['isi_dist_bins_stop'], row['isi_dist_bins_step']
    fr_dist_bins_start, fr_dist_bins_stop, fr_dist_bins_step = row['fr_dist_bins_start'], row['fr_dist_bins_stop'], row['fr_dist_bins_step']
    lr = row['lr']
    # Comma-separated hidden-layer widths, e.g. "150,75,38". str() guards
    # against pandas having parsed a single-width column as a number.
    n_hiddens = [int(nh) for nh in str(row['n_hiddens']).split(',')]
    dist_params = [k, distribution, cell_split_seed, bin_size,
                   isi_dist_bins_start, isi_dist_bins_stop, isi_dist_bins_step,
                   fr_dist_bins_start, fr_dist_bins_stop, fr_dist_bins_step]
    print(dist_params)
    batch_size = 1
    # Class count is the integer embedded in k, e.g. '17celltypes' -> 17.
    n_class = int(''.join(filter(str.isdigit, k)))
    cell_sample_seed = 1
    # NOTE(review): 0.4/0.4 holds out most cells for val/test (the manual-run
    # cell uses 0.2/0.2) -- confirm this split is intended for this grid.
    test_size, val_size = 0.4, 0.4
    isi_dist_bins = list(np.arange(isi_dist_bins_start, isi_dist_bins_stop, isi_dist_bins_step))
    fr_dist_bins = list(range(int(fr_dist_bins_start), int(fr_dist_bins_stop), int(fr_dist_bins_step)))
    if dist_params != prev_dist_params:
        # load_data()/get_scaler() read the module-level names assigned here
        # (dataset, k, distribution, sampler, test_size, val_size, bins, ...).
        dataset = Dataset('../data', force_process=False)
        dataset.data_source = 'v1'
        dataset.labels_col = 'pop_name'
        sampler = 'R20'
        dataset.num_trials_in_window = 33
        print('Load')
        dataset = load_data()
        print('Scale')
        train_scaler, train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset = get_scaler()
    dropout_p = 0
    model = MLP(input_dims=train_dataset.num_bins, n_hiddens=n_hiddens, n_class=n_class, dropout_p=dropout_p).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_accs, val_accs, test_accs = [], [], []
    epochs = 50
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5)
        train_cm, train_acc = test(model, device, train_loader, 'Train')
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        print('\n')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)
    # test() returns macro-F1, hence the "_f1s" file name. Building the frame
    # from a dict keeps each column name attached to its own series (the
    # previous column list swapped the val and test labels).
    accs_df = pd.DataFrame({'train_acc': train_accs, 'val_acc': val_accs, 'test_acc': test_accs})
    accs_df.to_csv('{}_hp{}_f1s.csv'.format(distribution, str(hpc)), index=False)
    # Confusion matrices from the final epoch, stacked as train/val/test.
    np.save('{}_hp{}_cms.npy'.format(distribution, str(hpc)), np.dstack([train_cm, val_cm, test_cm]))
    prev_dist_params = dist_params

['17celltypes', 'ISI', 1234, 0.2, 0, 0.401, 0.0005, 0, 51, 1]
Found processed pickle. Loading from '../data/processed/dataset.pkl'.


  if dist_params!=prev_dist_params:


Load
Scale
0.0026557706966627848 98.05970326641108
F1: 0.09506426236450011
F1: 0.08589927858726243
F1: 0.09872604154925733


F1: 0.23855750145939295
F1: 0.2193673301452732
F1: 0.2668098101286929


F1: 0.3084532387581299
F1: 0.26466467770024193
F1: 0.31961335853110295


F1: 0.3339077605949544
F1: 0.2827854564698146
F1: 0.32867773879772144


F1: 0.35795884657321553
F1: 0.27680077938623443
F1: 0.33237401270127775


F1: 0.3707129520979938
F1: 0.3038144440831973
F1: 0.32471853833792114


F1: 0.3998974229594419
F1: 0.32037912934484175
F1: 0.3285706037905996


F1: 0.42364913015373906
F1: 0.3354135635392691
F1: 0.3295887076158318


F1: 0.414429788235164
F1: 0.32316112032698113
F1: 0.3186865819051481


F1: 0.43982219809036116
F1: 0.3224979298163308
F1: 0.3412171531957788


F1: 0.44402920157871717
F1: 0.33385894712690933
F1: 0.32428436709570285


F1: 0.48159630982144536
F1: 0.34049061841912875
F1: 0.35108393425214596


F1: 0.4864783561409528
F1: 0.35421568426919003
F1: 0.34555097926965905


F1: 

F1: 0.14320896111037995


F1: 0.07916095648419844
F1: 0.06915718990232878
F1: 0.06591398771113972


F1: 0.121458482455326
F1: 0.09367514377822136
F1: 0.11060070600776817


F1: 0.1186988957035951
F1: 0.08344902680898868
F1: 0.10014240161696086


F1: 0.1123445494798759
F1: 0.07230130711789501
F1: 0.09433804008900323


F1: 0.12532450935254622
F1: 0.08515556942841274
F1: 0.10157279574146441


F1: 0.14824316682907646
F1: 0.09413404505498253
F1: 0.11545374476416115


F1: 0.12496895921100892
F1: 0.0829083372594024
F1: 0.10692418424645317


F1: 0.1380621159000554
F1: 0.09776812510624738
F1: 0.11998982061680703


F1: 0.16159366235427688
F1: 0.11751305401027913
F1: 0.14811288945532605


F1: 0.1262366102321346
F1: 0.1053445739526667
F1: 0.12392179577569422


F1: 0.15055271901272363
F1: 0.10723555439036016
F1: 0.11651268690368781


F1: 0.13901218090536052
F1: 0.11010603432387264
F1: 0.11015155629947704


F1: 0.12020922942542539
F1: 0.08877171770058036
F1: 0.097171258017528


F1: 0.1242279335886592

F1: 0.3447660524861507
F1: 0.3698795247811715


F1: 0.5931720894693094
F1: 0.3354904874741833
F1: 0.3511922159381375


F1: 0.5635264367904899
F1: 0.3020598264737299
F1: 0.3349228966490636


F1: 0.5796287712473612
F1: 0.31606754828172473
F1: 0.3516118236345993


F1: 0.6283257279963179
F1: 0.3444437604985498
F1: 0.3676711771194793


F1: 0.616154662695668
F1: 0.346070208914648
F1: 0.3743367942787544


F1: 0.587129723094351
F1: 0.30128066583293495
F1: 0.3401319636309592


F1: 0.5981106495791971
F1: 0.3499893201943922
F1: 0.37120536244355257


F1: 0.6331353788416125
F1: 0.3419767689100136
F1: 0.3534669933852181


F1: 0.634211010039683
F1: 0.34432979094722377
F1: 0.3544539502831643


F1: 0.6466274481059027
F1: 0.3368636304447218
F1: 0.3598009481923026


F1: 0.6470880881232978
F1: 0.3386105962513632
F1: 0.35820767420695754


F1: 0.6518494144910152
F1: 0.31503503875795635
F1: 0.3448327432988081


F1: 0.6336104118201897
F1: 0.34570745842090916
F1: 0.360846108194461


F1: 0.6588869941214398
F1: 

F1: 0.2896714259811686


F1: 0.7095277753206143
F1: 0.2714144806146934
F1: 0.2911972454061526


F1: 0.7199771206440151
F1: 0.2710649872221109
F1: 0.2819544858540344


F1: 0.7234386266345963
F1: 0.27004032529538924
F1: 0.27410345758510796


F1: 0.7156130204258456
F1: 0.2556869863705514
F1: 0.2759479943006763


F1: 0.7281198575981812
F1: 0.26728290051166953
F1: 0.2771216680072937


F1: 0.7331098045086075
F1: 0.2665482211191679
F1: 0.2808332576133368


F1: 0.733143886661524
F1: 0.2632726925232006
F1: 0.2822676476639413


F1: 0.7298197355022698
F1: 0.2639583318877008
F1: 0.26948338886009904


F1: 0.7429124596605318
F1: 0.2645685087760452
F1: 0.278872289864091


F1: 0.7350946472083545
F1: 0.25989977984180446
F1: 0.26427555754363136


F1: 0.7442127211887027
F1: 0.2638152055442473
F1: 0.273030899936208


F1: 0.7367653157701319
F1: 0.27440740285449516
F1: 0.2750747873470043


F1: 0.7337890902176081
F1: 0.26480639645058074
F1: 0.27458621154269397


F1: 0.7523637542612774
F1: 0.2750616266263249


In [79]:
# Configuration for a single manual run. load_data() and get_scaler() read
# these module-level names.
dataset = Dataset('../data', force_process=False)
dataset.data_source = 'v1'
dataset.labels_col = 'pop_name'

k = '17celltypes'
distribution = 'ISIFR'
sampler = 'R20'
dataset.num_trials_in_window = 33
batch_size = 1
# Class count is the integer embedded in k ('17celltypes' -> 17).
n_class = int(''.join(ch for ch in k if ch.isdigit()))

cell_sample_seed = 1
cell_split_seed = 1234  # other split seeds tried: 2345, 3456, 4567, 5678

test_size = val_size = 0.2
bin_size = 0.2  # seconds, per the comment in load_data()

# ISI histogram edges: 0 to 0.4 in steps of 0.0005. Alternatives tried: an int
# (e.g. 800), which get_train_scaler turns into quantile-based edges, and
# np.arange(-5, 5.1, 0.1), presumably for log-ISI.
isi_dist_bins = np.arange(0, 0.401, 0.0005).tolist()
# Firing-rate histogram edges 0, 1, ..., 50 (an int such as 50 was also tried).
fr_dist_bins = list(range(51))

Found processed pickle. Loading from '../data/processed/dataset.pkl'.


### Loading dataset

In [4]:
# Label-aggregation schemes for the V1 'pop_name' labels, keyed by the value of
# the global `k`. '17celltypes' keeps the 17 kept labels unchanged.
_CELL_CLASS_AGGREGATIONS = {
    '13celltypes': {'e23Cux2': 'e23', 'i5Sst': 'i5Sst', 'i5Htr3a': 'i5Htr3a', 'e4Scnn1a': 'e4', 'e4Rorb': 'e4',
                    'e4other': 'e4', 'e4Nr5a1': 'e4', 'i6Htr3a': 'i6Htr3a', 'i6Sst': 'i6Sst', 'e6Ntsr1': 'e6',
                    'i23Pvalb': 'i23Pvalb', 'i23Htr3a': 'i23Htr3a', 'i1Htr3a': 'i1Htr3a', 'i4Sst': 'i4Sst', 'e5Rbp4': 'e5',
                    'e5noRbp4': 'e5', 'i23Sst': 'i23Sst', 'i4Htr3a': 'i4Htr3a', 'i6Pvalb': 'i6Pvalb', 'i5Pvalb': 'i5Pvalb',
                    'i4Pvalb': 'i4Pvalb'},
    '11celltypes': {'e23Cux2': 'Cux2', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'Scnn1a', 'e4Rorb': 'Rorb',
                    'e4other': 'other', 'e4Nr5a1': 'Nr5a1', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'Ntsr1',
                    'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a', 'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'Rbp4',
                    'e5noRbp4': 'noRbp4', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a', 'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb',
                    'i4Pvalb': 'Pvalb'},
    '4celltypes': {'e23Cux2': 'e', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'e', 'e4Rorb': 'e', 'e4other': 'e',
                   'e4Nr5a1': 'e', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'e', 'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a',
                   'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'e', 'e5noRbp4': 'e', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a',
                   'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb', 'i4Pvalb': 'Pvalb'},
    '5layers': {'e23Cux2': '23', 'i5Sst': '5', 'i5Htr3a': '5', 'e4Scnn1a': '4', 'e4Rorb': '4', 'e4other': '4',
                'e4Nr5a1': '4', 'i6Htr3a': '6', 'i6Sst': '6', 'e6Ntsr1': '6', 'i23Pvalb': '23', 'i23Htr3a': '23',
                'i1Htr3a': '1', 'i4Sst': '4', 'e5Rbp4': '5', 'e5noRbp4': '5', 'i23Sst': '23', 'i4Htr3a': '4',
                'i6Pvalb': '6', 'i5Pvalb': '5', 'i4Pvalb': '4'},
    '2celltypes': {'e23Cux2': 'e', 'i5Sst': 'i', 'i5Htr3a': 'i', 'e4Scnn1a': 'e', 'e4Rorb': 'e',
                   'e4other': 'e', 'e4Nr5a1': 'e', 'i6Htr3a': 'i', 'i6Sst': 'i', 'e6Ntsr1': 'e',
                   'i23Pvalb': 'i', 'i23Htr3a': 'i', 'i1Htr3a': 'i', 'i4Sst': 'i', 'e5Rbp4': 'e',
                   'e5noRbp4': 'e', 'i23Sst': 'i', 'i4Htr3a': 'i', 'i6Pvalb': 'i', 'i5Pvalb': 'i',
                   'i4Pvalb': 'i'},
}


def load_data():
    """Filter, relabel, split and bin the global `dataset`, then return it.

    Reads the module-level names `dataset`, `k`, `test_size`, `val_size`,
    `cell_split_seed` and `bin_size`. It calls methods on the global `dataset`
    object and returns that same object.

    For the V1 source with 'pop_name' labels, it drops dead cells, keeps the
    17 listed pop_name classes and, unless k == '17celltypes', aggregates
    them according to `_CELL_CLASS_AGGREGATIONS[k]`.

    Raises:
        ValueError: if `k` is not a known label granularity (checked before
            `dataset` is modified, so a typo cannot leave it half-processed
            or silently keep 17 classes while n_class is derived from `k`).
    """
    if dataset.data_source == 'v1' and dataset.labels_col == 'pop_name':
        if k != '17celltypes' and k not in _CELL_CLASS_AGGREGATIONS:
            raise ValueError(
                "Unknown label granularity k={!r}; expected '17celltypes' or one of {}".format(
                    k, sorted(_CELL_CLASS_AGGREGATIONS)))
        dataset.drop_dead_cells(cutoff=30)
        keepers = ['e5Rbp4', 'e23Cux2', 'i6Pvalb', 'e4Scnn1a', 'i23Pvalb', 'i23Htr3a',
                   'e4Rorb', 'e4other', 'i5Pvalb', 'i4Pvalb', 'i23Sst', 'i4Sst', 'e4Nr5a1',
                   'i1Htr3a', 'e5noRbp4', 'i6Sst', 'e6Ntsr1']
        dataset.drop_other_classes(classes_to_keep=keepers)
        if k in _CELL_CLASS_AGGREGATIONS:
            # Pass a copy so the shared table cannot be altered by the dataset.
            dataset.aggregate_cell_classes(dict(_CELL_CLASS_AGGREGATIONS[k]))

    # Split cells into train/val/test sets. A trial-level split
    # (dataset.split_trial_train_val_test) was an earlier alternative.
    dataset.split_cell_train_val_test(test_size=test_size, val_size=val_size, seed=cell_split_seed)

    # Binning: bin_size is in seconds (0.2 -> 200 ms).
    dataset.set_bining_parameters(bin_size=bin_size)
    return dataset

In [5]:
def get_train_scaler(dataset, sampler='R20', transform='interspike_interval', cell_random_seed=1, bins=list(np.arange(0, 0.402, 0.002)), scaler_type='Z'):
    """Fit a scaler on per-sample histograms of the training split.

    Every trial of the 'train' split is drawn with ``dataset.sample`` (restricted
    to the train cell mask). Each returned sample is histogrammed with `bins`,
    and the scaler is fitted on the stacked histograms.

    Args:
        dataset: project Dataset providing get_trials('train'),
            _cell_split['train'] and sample(...).
        sampler: sampler name forwarded to dataset.sample.
        transform: transform name forwarded to dataset.sample (e.g.
            'interspike_interval', 'log_interspike_interval', 'firing_rate').
        cell_random_seed: unused; kept for interface compatibility.
        bins: an int n builds n quantile bins from all pooled training values
            (pd.qcut). A list, tuple or ndarray is used as the bin edges.
        scaler_type: 'Z' for StandardScaler, 'MinMax' for MinMaxScaler.

    Returns:
        (train_scaler, bins): the fitted scaler and the bin edges actually used.
        These are the adaptive edges when `bins` was an int, otherwise `bins`
        unchanged.

    Raises:
        ValueError: on an unknown `scaler_type` or unsupported `bins` type.
            Both previously fell through to an UnboundLocalError.
    """
    # Validate up front so bad arguments fail before the expensive sampling.
    if scaler_type not in ('Z', 'MinMax'):
        raise ValueError("scaler_type must be 'Z' or 'MinMax', got {!r}".format(scaler_type))
    use_adaptive_bins = isinstance(bins, (int, np.integer)) and not isinstance(bins, bool)
    if not use_adaptive_bins and not isinstance(bins, (list, tuple, np.ndarray)):
        raise ValueError("bins must be an int or a sequence of bin edges, got {!r}".format(type(bins)))

    trials = dataset.get_trials('train')
    train_mask = dataset._cell_split['train']
    X_bank = []
    for trial in trials:
        X, _, _ = dataset.sample(mode='train', trial_id=trial, sampler=sampler,
                                 transform=transform, preselected_mask=train_mask)
        X_bank.extend(X)

    # Diagnostic: overall min/max of the sampled values (skipped when every
    # sample is empty, where min()/max() of nothing would raise).
    non_empty = [x for x in X_bank if len(x) > 0]
    if non_empty:
        print(min(np.min(x) for x in non_empty), max(np.max(x) for x in non_empty))

    if use_adaptive_bins:
        # Equal-count edges computed from all training values pooled together.
        _, adaptive_bins = pd.qcut(np.hstack(X_bank).ravel(), bins, retbins=True)
        xi_hists = [np.histogram(xi, bins=adaptive_bins)[0] if len(xi) > 0 else np.zeros(bins)
                    for xi in X_bank]
        bins = adaptive_bins
    else:
        xi_hists = [np.histogram(xi, bins=bins)[0] for xi in X_bank]

    scaler = StandardScaler() if scaler_type == 'Z' else MinMaxScaler()
    train_scaler = scaler.fit(np.vstack(xi_hists))
    return train_scaler, bins


In [6]:
def get_scaler():
    """Fit train-set scaler(s) for the selected distribution and build datasets/loaders.

    Reads the module-level names `distribution`, `dataset`, `sampler`,
    `isi_dist_bins`, `fr_dist_bins`, `cell_sample_seed` and `batch_size`.

    Every dataset is built with the bin edges returned by get_train_scaler. When
    the configured bins are an int, those are the adaptive edges the scaler was
    fitted on, so the dataset histograms stay consistent with the scaler.

    Returns:
        (train_scaler, train_loader, val_loader, test_loader,
         train_dataset, val_dataset, test_dataset). For 'ISIFR',
        train_scaler is [isi_scaler, fr_scaler].

    Raises:
        ValueError: if `distribution` is not 'ISIFR', 'FR', or a name containing 'ISI'.
    """
    if distribution == 'ISIFR':
        fr_train_scaler, fr_bins = get_train_scaler(dataset, sampler=sampler, transform='firing_rate', bins=fr_dist_bins)
        isi_train_scaler, isi_bins = get_train_scaler(dataset, sampler=sampler, transform='interspike_interval', bins=isi_dist_bins)
        train_scaler = [isi_train_scaler, fr_train_scaler]
        dataset_cls = ISIFRDistributionDataset
        dataset_kwargs = dict(isi_bins=isi_bins, fr_bins=fr_bins,
                              isi_scaler=isi_train_scaler, fr_scaler=fr_train_scaler)

    elif 'ISI' in distribution:
        transform = 'log_interspike_interval' if distribution == 'log_ISI' else 'interspike_interval'
        train_scaler, isi_bins = get_train_scaler(dataset, sampler=sampler, transform=transform, bins=isi_dist_bins)
        # NOTE(review): min_isi/max_isi are fixed at 0/0.4 for 'log_ISI' too --
        # confirm ISIDistributionDataset handles log-ISI bins as intended.
        dataset_cls = ISIDistributionDataset
        dataset_kwargs = dict(min_isi=0, max_isi=0.4, bins=isi_bins, scaler=train_scaler)

    elif distribution == 'FR':
        train_scaler, fr_bins = get_train_scaler(dataset, sampler=sampler, transform='firing_rate', bins=fr_dist_bins)
        dataset_cls = FRDistributionDataset
        dataset_kwargs = dict(bins=fr_bins, scaler=train_scaler)

    else:
        raise ValueError(
            "Unknown distribution {!r}; expected 'ISIFR', 'FR' or an 'ISI' variant".format(distribution))

    train_dataset = dataset_cls(dataset, mode='train', sampler=sampler, **dataset_kwargs)
    # Fix the sampled population for the validation and test sets (they will
    # still differ from each other).
    val_dataset = dataset_cls(dataset, mode='val', sampler=sampler, cell_random_seed=cell_sample_seed, **dataset_kwargs)
    test_dataset = dataset_cls(dataset, mode='test', sampler=sampler, cell_random_seed=cell_sample_seed, **dataset_kwargs)

    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=train_dataset.collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=val_dataset.collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=test_dataset.collate_fn)
    return train_scaler, train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset

In [11]:
def train(model, device, train_loader, optimizer, epoch, log_interval=None):
    """Run one training epoch, minimising the NLL loss batch by batch.

    Args:
        model: network under training. Its output is fed to F.nll_loss, so it
            is expected to produce log-probabilities.
        device: torch device each batch is moved to.
        train_loader: iterable of (data, target) batches.
        optimizer: optimizer over the model's parameters.
        epoch: current epoch number (currently unused; progress logging is off).
        log_interval: currently unused; progress logging is off.
    """
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

def test(model, device, loader, tag, labels=None):
    """Evaluate `model` on `loader`; print and return the macro-averaged F1.

    Args:
        model: network whose output's argmax over dim 1 is the predicted class.
        device: torch device the inputs are moved to.
        loader: iterable of (data, target) batches.
        tag: split name ('Train'/'Val'/'Test'); currently not printed.
        labels: unused, kept for backward compatibility. It used to default to
            `dataset.cell_type_labels`, which was evaluated at definition time
            and required a global `dataset` to exist before this cell could run.

    Returns:
        (cm, f1): the confusion matrix normalized over true labels (rows), and
        the macro-averaged F1 score. Callers name the second value `acc`, but
        it is an F1 score, not accuracy.
    """
    model.eval()
    preds, targets = [], []
    with torch.no_grad():
        for data, target in loader:
            output = model(data.to(device))
            # Index of the max log-probability = predicted class.
            pred = output.argmax(dim=1, keepdim=True)
            preds.append(pred.cpu().numpy().ravel())
            targets.append(target.cpu().numpy())
    targets = np.hstack(targets)
    preds = np.hstack(preds)
    f1 = f1_score(targets, preds, average='macro')
    print('F1:', f1)
    cm = confusion_matrix(targets, preds, normalize='true')
    return cm, f1

In [8]:
# Filter/aggregate, split and bin the global `dataset`, then fit the train-set
# scaler(s) and build the loaders. The training cell below needs
# train_dataset/train_loader/val_loader/test_loader, which previously were
# never created on a fresh run because this call was commented out.
dataset = load_data()
train_scaler, train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset = get_scaler()

NameError: name 'dataset' is not defined

In [85]:
# Hyperparameters for the manual training run below.
lr = 1e-3                  # Adam learning rate
n_hiddens = [150, 75, 38]  # hidden-layer widths passed to MLP
dropout_p = 0              # dropout probability passed to MLP

In [86]:
# Train a fresh MLP on the loaders from get_scaler(), scoring every split after
# each epoch. test() returns macro-F1, so the *_accs lists hold F1 scores; the
# *_cm variables keep the last epoch's confusion matrices.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MLP(input_dims=train_dataset.num_bins, n_hiddens=n_hiddens,
            n_class=n_class, dropout_p=dropout_p).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

train_accs, val_accs, test_accs = [], [], []
epochs = 200
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval=5)
    train_cm, train_acc = test(model, device, train_loader, 'Train')
    val_cm, val_acc = test(model, device, val_loader, 'Val')
    test_cm, test_acc = test(model, device, test_loader, 'Test')
    print('\n')
    for history, score in ((train_accs, train_acc), (val_accs, val_acc), (test_accs, test_acc)):
        history.append(score)

Train set: Accuracy: 9531/34000 (28%)
0.27290448786732086
Val set: Accuracy: 9985/34000 (29%)
0.2825267604158765
Test set: Accuracy: 9336/34000 (27%)
0.26041305010600674


Train set: Accuracy: 10599/34000 (31%)
0.28864395385648217
Val set: Accuracy: 10992/34000 (32%)
0.29975322957577916
Test set: Accuracy: 9931/34000 (29%)
0.2716949101412439


Train set: Accuracy: 11479/34000 (34%)
0.3323395614747944
Val set: Accuracy: 12377/34000 (36%)
0.35273207423905356
Test set: Accuracy: 10455/34000 (31%)
0.2936477138781286


Train set: Accuracy: 11032/34000 (32%)
0.3259366045335737
Val set: Accuracy: 11886/34000 (35%)
0.3535055722757058
Test set: Accuracy: 10234/34000 (30%)
0.29600568332834803


Train set: Accuracy: 12409/34000 (36%)
0.3387952209079051
Val set: Accuracy: 12821/34000 (38%)
0.3456019192021726
Test set: Accuracy: 11410/34000 (34%)
0.30307644866287026


Train set: Accuracy: 12443/34000 (37%)
0.34405733980221
Val set: Accuracy: 13087/34000 (38%)
0.3590100511678045
Test set: Accuracy: 

Train set: Accuracy: 12709/34000 (37%)
0.35101204829409616
Val set: Accuracy: 13540/34000 (40%)
0.37376567782815817
Test set: Accuracy: 11583/34000 (34%)
0.31262776927422625


Train set: Accuracy: 12598/34000 (37%)
0.370741567266677
Val set: Accuracy: 13780/34000 (41%)
0.4003114659393385
Test set: Accuracy: 11497/34000 (34%)
0.33128717971972194


Train set: Accuracy: 12488/34000 (37%)
0.35083606400437517
Val set: Accuracy: 13309/34000 (39%)
0.37852391468565344
Test set: Accuracy: 11167/34000 (33%)
0.30914528792934903


Train set: Accuracy: 12859/34000 (38%)
0.37095510232287066
Val set: Accuracy: 13767/34000 (40%)
0.39490403528501616
Test set: Accuracy: 11783/34000 (35%)
0.33144071634670763


Train set: Accuracy: 13156/34000 (39%)
0.37077322442000604
Val set: Accuracy: 13618/34000 (40%)
0.38570633619144673
Test set: Accuracy: 12251/34000 (36%)
0.3419735138138966


Train set: Accuracy: 12825/34000 (38%)
0.38074301667784927
Val set: Accuracy: 13947/34000 (41%)
0.40719189130685907
Test set

Train set: Accuracy: 13301/34000 (39%)
0.38447101209012124
Val set: Accuracy: 14256/34000 (42%)
0.40942030038040406
Test set: Accuracy: 11904/34000 (35%)
0.3411690464081941


Train set: Accuracy: 13117/34000 (39%)
0.3757563812406784
Val set: Accuracy: 13822/34000 (41%)
0.3926697782382235
Test set: Accuracy: 11968/34000 (35%)
0.3358665374668157


Train set: Accuracy: 13561/34000 (40%)
0.3813378791333838
Val set: Accuracy: 13902/34000 (41%)
0.39021573943922794
Test set: Accuracy: 12623/34000 (37%)
0.35111911271591345


Train set: Accuracy: 13433/34000 (40%)
0.38000293255585227
Val set: Accuracy: 14277/34000 (42%)
0.40132634118674054
Test set: Accuracy: 12648/34000 (37%)
0.3517569426993295


Train set: Accuracy: 13662/34000 (40%)
0.39434866915788297
Val set: Accuracy: 14126/34000 (42%)
0.40802884267136064
Test set: Accuracy: 12516/34000 (37%)
0.35703317446848304


Train set: Accuracy: 13675/34000 (40%)
0.38800951898142627
Val set: Accuracy: 14195/34000 (42%)
0.40193377795719976
Test set: 

Train set: Accuracy: 13876/34000 (41%)
0.3898286016503253
Val set: Accuracy: 14254/34000 (42%)
0.40406260527235466
Test set: Accuracy: 12504/34000 (37%)
0.3489936935932534


Train set: Accuracy: 13908/34000 (41%)
0.4011561197733715
Val set: Accuracy: 14126/34000 (42%)
0.40864352197398535
Test set: Accuracy: 12569/34000 (37%)
0.3607269657391274


Train set: Accuracy: 13910/34000 (41%)
0.3956226950597656
Val set: Accuracy: 14359/34000 (42%)
0.4074750318194489
Test set: Accuracy: 12718/34000 (37%)
0.35563993735955785


Train set: Accuracy: 13824/34000 (41%)
0.38781078540499175
Val set: Accuracy: 14377/34000 (42%)
0.4065834587447972
Test set: Accuracy: 12654/34000 (37%)
0.3534329383801751


Train set: Accuracy: 13763/34000 (40%)
0.3872318563273937
Val set: Accuracy: 14414/34000 (42%)
0.4026809909645811
Test set: Accuracy: 12585/34000 (37%)
0.34868288593349195


Train set: Accuracy: 13769/34000 (40%)
0.3870149868141903
Val set: Accuracy: 14450/34000 (42%)
0.4037936088534463
Test set: Accura

Train set: Accuracy: 13973/34000 (41%)
0.39664258982319867
Val set: Accuracy: 14265/34000 (42%)
0.40739915095550255
Test set: Accuracy: 12738/34000 (37%)
0.3564988702226697


Train set: Accuracy: 13670/34000 (40%)
0.3973964104995746
Val set: Accuracy: 14302/34000 (42%)
0.4138739800907197
Test set: Accuracy: 11828/34000 (35%)
0.3409878643270814


Train set: Accuracy: 14028/34000 (41%)
0.3953709826131102
Val set: Accuracy: 14242/34000 (42%)
0.4015201367389712
Test set: Accuracy: 12711/34000 (37%)
0.35470789395222035


Train set: Accuracy: 13898/34000 (41%)
0.39249240683730224
Val set: Accuracy: 14256/34000 (42%)
0.402565845731484
Test set: Accuracy: 12656/34000 (37%)
0.3540537257561452


Train set: Accuracy: 14011/34000 (41%)
0.40291993308422425
Val set: Accuracy: 14268/34000 (42%)
0.4106421113408585
Test set: Accuracy: 12090/34000 (36%)
0.3480456893856529


Train set: Accuracy: 14033/34000 (41%)
0.4015186283064188
Val set: Accuracy: 14224/34000 (42%)
0.40813384936084024
Test set: Accura

Train set: Accuracy: 14005/34000 (41%)
0.394522459253553
Val set: Accuracy: 14059/34000 (41%)
0.39675386607402563
Test set: Accuracy: 12415/34000 (37%)
0.3465951358496613


Train set: Accuracy: 14031/34000 (41%)
0.39805815711430964
Val set: Accuracy: 14239/34000 (42%)
0.40406914363686
Test set: Accuracy: 12492/34000 (37%)
0.3511740623446818


Train set: Accuracy: 14223/34000 (42%)
0.4097332922677556
Val set: Accuracy: 14310/34000 (42%)
0.4107050162403696
Test set: Accuracy: 12467/34000 (37%)
0.3576700958904717


Train set: Accuracy: 14359/34000 (42%)
0.40889428975837605
Val set: Accuracy: 14340/34000 (42%)
0.4097549426478711
Test set: Accuracy: 12490/34000 (37%)
0.3551155866509842


Train set: Accuracy: 14268/34000 (42%)
0.4084695582157747
Val set: Accuracy: 14230/34000 (42%)
0.4065341644738173
Test set: Accuracy: 12693/34000 (37%)
0.3582423065756025


Train set: Accuracy: 13979/34000 (41%)
0.3886406321944693
Val set: Accuracy: 14109/34000 (41%)
0.3951663328549074
Test set: Accuracy: 1

Test set: Accuracy: 12481/34000 (37%)
0.35679527802788474


Train set: Accuracy: 14381/34000 (42%)
0.4154046317240015
Val set: Accuracy: 14433/34000 (42%)
0.41750279305903587
Test set: Accuracy: 12657/34000 (37%)
0.36300665164460694


Train set: Accuracy: 13970/34000 (41%)
0.38905117311535115
Val set: Accuracy: 14065/34000 (41%)
0.39469587746852547
Test set: Accuracy: 12705/34000 (37%)
0.34823458642083294


Train set: Accuracy: 14368/34000 (42%)
0.4047791832364056
Val set: Accuracy: 14317/34000 (42%)
0.40092137428225927
Test set: Accuracy: 12472/34000 (37%)
0.3446693557531842


Train set: Accuracy: 14167/34000 (42%)
0.39784848221621394
Val set: Accuracy: 14243/34000 (42%)
0.40148619188072443
Test set: Accuracy: 12377/34000 (36%)
0.3465508270885919


Train set: Accuracy: 14464/34000 (43%)
0.41722734187969435
Val set: Accuracy: 14274/34000 (42%)
0.4068099366138369
Test set: Accuracy: 12546/34000 (37%)
0.35457543574256034


Train set: Accuracy: 14523/34000 (43%)
0.4176440981297884
Val set

Train set: Accuracy: 14488/34000 (43%)
0.41627686375091794
Val set: Accuracy: 14211/34000 (42%)
0.40466894583473023
Test set: Accuracy: 12345/34000 (36%)
0.35011821500350565


Train set: Accuracy: 14527/34000 (43%)
0.41775022968231607
Val set: Accuracy: 14244/34000 (42%)
0.40712056598730717
Test set: Accuracy: 12541/34000 (37%)
0.35720343088641443


Train set: Accuracy: 14266/34000 (42%)
0.4048873347618577
Val set: Accuracy: 14009/34000 (41%)
0.39831676822394174
Test set: Accuracy: 12046/34000 (35%)
0.3397071710122392


Train set: Accuracy: 14305/34000 (42%)
0.4062925507325441
Val set: Accuracy: 14296/34000 (42%)
0.39980697985198405
Test set: Accuracy: 12344/34000 (36%)
0.3414985717739776


Train set: Accuracy: 14594/34000 (43%)
0.4178376354469687
Val set: Accuracy: 14241/34000 (42%)
0.4042025183875164
Test set: Accuracy: 12409/34000 (36%)
0.3491336144545622


Train set: Accuracy: 14611/34000 (43%)
0.41856242312175235
Val set: Accuracy: 14204/34000 (42%)
0.40562548644061214
Test set: A

Train set: Accuracy: 14801/34000 (44%)
0.42538355335136574
Val set: Accuracy: 14084/34000 (41%)
0.4027823818581985
Test set: Accuracy: 12464/34000 (37%)
0.35130238993742274


Train set: Accuracy: 14705/34000 (43%)
0.42348948357124394
Val set: Accuracy: 14101/34000 (41%)
0.40660582491110653
Test set: Accuracy: 12353/34000 (36%)
0.3532686975863456


Train set: Accuracy: 14645/34000 (43%)
0.42163535322991025
Val set: Accuracy: 14387/34000 (42%)
0.41190475605438176
Test set: Accuracy: 12434/34000 (37%)
0.3552118578679983


Train set: Accuracy: 14833/34000 (44%)
0.4283472303351042
Val set: Accuracy: 14131/34000 (42%)
0.4073943030578375
Test set: Accuracy: 12390/34000 (36%)
0.3569159614386592


Train set: Accuracy: 14725/34000 (43%)
0.4211693156808482
Val set: Accuracy: 14282/34000 (42%)
0.4074518325455361
Test set: Accuracy: 12188/34000 (36%)
0.3490939281909411


Train set: Accuracy: 14689/34000 (43%)
0.4202737245801098
Val set: Accuracy: 14247/34000 (42%)
0.4089643915038202
Test set: Accur

Train set: Accuracy: 14781/34000 (43%)
0.42141664204619694
Val set: Accuracy: 14011/34000 (41%)
0.39963482914729126
Test set: Accuracy: 12347/34000 (36%)
0.35100152425303977


Train set: Accuracy: 14668/34000 (43%)
0.4226202934147749
Val set: Accuracy: 14059/34000 (41%)
0.4021876950613564
Test set: Accuracy: 12448/34000 (37%)
0.3517568562908076


Train set: Accuracy: 14703/34000 (43%)
0.4260413126552281
Val set: Accuracy: 14060/34000 (41%)
0.408443604894868
Test set: Accuracy: 12370/34000 (36%)
0.3575261539046638


Train set: Accuracy: 14775/34000 (43%)
0.4250718167945729
Val set: Accuracy: 14198/34000 (42%)
0.40386825370091134
Test set: Accuracy: 12473/34000 (37%)
0.351375886218896


Train set: Accuracy: 14777/34000 (43%)
0.4193914283124464
Val set: Accuracy: 14064/34000 (41%)
0.3998611220438112
Test set: Accuracy: 12262/34000 (36%)
0.3463892255163062


Train set: Accuracy: 14585/34000 (43%)
0.4194264626338582
Val set: Accuracy: 13914/34000 (41%)
0.3958258441464261
Test set: Accuracy:

Train set: Accuracy: 14968/34000 (44%)
0.4347659710691349
Val set: Accuracy: 13772/34000 (41%)
0.395949841445067
Test set: Accuracy: 12257/34000 (36%)
0.34945657513939343


Train set: Accuracy: 14847/34000 (44%)
0.4260111456674584
Val set: Accuracy: 13823/34000 (41%)
0.3918484399956942
Test set: Accuracy: 12317/34000 (36%)
0.34396775665801954


Train set: Accuracy: 14909/34000 (44%)
0.42688200749859795
Val set: Accuracy: 14274/34000 (42%)
0.40684456097502664
Test set: Accuracy: 12453/34000 (37%)
0.3553571551065616


Train set: Accuracy: 14873/34000 (44%)
0.4312111046228529
Val set: Accuracy: 13828/34000 (41%)
0.395237653886439
Test set: Accuracy: 12296/34000 (36%)
0.34661905025171336


Train set: Accuracy: 14699/34000 (43%)
0.42873355647851225
Val set: Accuracy: 13907/34000 (41%)
0.3991885168756652
Test set: Accuracy: 12305/34000 (36%)
0.35688081389122384


Train set: Accuracy: 14815/34000 (44%)
0.4269924452435194
Val set: Accuracy: 14118/34000 (42%)
0.4043314825990923
Test set: Accura

Train set: Accuracy: 14862/34000 (44%)
0.4216762034079622
Val set: Accuracy: 13956/34000 (41%)
0.3952802567309893
Test set: Accuracy: 12213/34000 (36%)
0.3419829530537443


Train set: Accuracy: 15072/34000 (44%)
0.43982982396368353
Val set: Accuracy: 14112/34000 (42%)
0.405729247489244
Test set: Accuracy: 12378/34000 (36%)
0.35717093346449075


Train set: Accuracy: 14995/34000 (44%)
0.42856809210111546
Val set: Accuracy: 14071/34000 (41%)
0.4007621461464077
Test set: Accuracy: 12631/34000 (37%)
0.3573883905266274


Train set: Accuracy: 14793/34000 (44%)
0.4330518260066095
Val set: Accuracy: 13645/34000 (40%)
0.3993686161148913
Test set: Accuracy: 12120/34000 (36%)
0.35640449885368214


Train set: Accuracy: 15083/34000 (44%)
0.4381179575197328
Val set: Accuracy: 14017/34000 (41%)
0.40276852221445086
Test set: Accuracy: 12360/34000 (36%)
0.3569857344763916


Train set: Accuracy: 15014/34000 (44%)
0.4328077047402161
Val set: Accuracy: 13773/34000 (41%)
0.3938003503872074
Test set: Accurac

Train set: Accuracy: 15043/34000 (44%)
0.42799046305600447
Val set: Accuracy: 13933/34000 (41%)
0.391582189332897
Test set: Accuracy: 12237/34000 (36%)
0.34099349941649093


Train set: Accuracy: 15060/34000 (44%)
0.4331385658450202
Val set: Accuracy: 13917/34000 (41%)
0.39923662624951395
Test set: Accuracy: 12131/34000 (36%)
0.3460152886829183


Train set: Accuracy: 15086/34000 (44%)
0.4293953748196878
Val set: Accuracy: 13654/34000 (40%)
0.38826394110432233
Test set: Accuracy: 12136/34000 (36%)
0.34578054347588144


Train set: Accuracy: 15067/34000 (44%)
0.43540778524819185
Val set: Accuracy: 13916/34000 (41%)
0.40151496004782916
Test set: Accuracy: 12092/34000 (36%)
0.3460016305955897


Train set: Accuracy: 15076/34000 (44%)
0.4266011857297418
Val set: Accuracy: 13782/34000 (41%)
0.38926506190140975
Test set: Accuracy: 12104/34000 (36%)
0.34124004320652573


Train set: Accuracy: 15241/34000 (45%)
0.44033571904276736
Val set: Accuracy: 13880/34000 (41%)
0.3987310982128153
Test set: Ac

Train set: Accuracy: 14888/34000 (44%)
0.4221802981271864
Val set: Accuracy: 13508/34000 (40%)
0.377293957496199
Test set: Accuracy: 12109/34000 (36%)
0.33699411291492265


Train set: Accuracy: 15071/34000 (44%)
0.4309238117267906
Val set: Accuracy: 13903/34000 (41%)
0.3923420200382475
Test set: Accuracy: 12287/34000 (36%)
0.34622722475668194


Train set: Accuracy: 15318/34000 (45%)
0.4445242599588463
Val set: Accuracy: 14031/34000 (41%)
0.40247463419248536
Test set: Accuracy: 12103/34000 (36%)
0.347423771937273


Train set: Accuracy: 15275/34000 (45%)
0.44134445915456866
Val set: Accuracy: 13907/34000 (41%)
0.4006489535231265
Test set: Accuracy: 11994/34000 (35%)
0.34557384779201833


Train set: Accuracy: 15236/34000 (45%)
0.4341184493371083
Val set: Accuracy: 13750/34000 (40%)
0.390526530915081
Test set: Accuracy: 12225/34000 (36%)
0.34469299589986363


Train set: Accuracy: 15208/34000 (45%)
0.44245560753731106
Val set: Accuracy: 13733/34000 (40%)
0.39403123819870245
Test set: Accura

Train set: Accuracy: 15195/34000 (45%)
0.43586152016642465
Val set: Accuracy: 13662/34000 (40%)
0.39124427530620226
Test set: Accuracy: 12288/34000 (36%)
0.3473822439769562


Train set: Accuracy: 15068/34000 (44%)
0.4416730822527841
Val set: Accuracy: 13707/34000 (40%)
0.39477070525439684
Test set: Accuracy: 11923/34000 (35%)
0.34057944118806566


Train set: Accuracy: 15018/34000 (44%)
0.42946914822885257
Val set: Accuracy: 13496/34000 (40%)
0.3850745904833783
Test set: Accuracy: 11820/34000 (35%)
0.3325074767026101


Train set: Accuracy: 15337/34000 (45%)
0.44715168881150247
Val set: Accuracy: 13593/34000 (40%)
0.39434751459729517
Test set: Accuracy: 11950/34000 (35%)
0.3447561506569407


Train set: Accuracy: 15353/34000 (45%)
0.4468466087365768
Val set: Accuracy: 13676/34000 (40%)
0.3940662578919066
Test set: Accuracy: 12042/34000 (35%)
0.34634902220042524


Train set: Accuracy: 15361/34000 (45%)
0.44179624001891377
Val set: Accuracy: 13696/34000 (40%)
0.38732509373630547
Test set: A

Test set: Accuracy: 12023/34000 (35%)
0.3480622131078493


Train set: Accuracy: 15190/34000 (45%)
0.43619710106208665
Val set: Accuracy: 13679/34000 (40%)
0.38924036221081804
Test set: Accuracy: 12143/34000 (36%)
0.34352730994233055


Train set: Accuracy: 15353/34000 (45%)
0.44514287102115263
Val set: Accuracy: 13862/34000 (41%)
0.39733556261734265
Test set: Accuracy: 12086/34000 (36%)
0.34244735380570607


Train set: Accuracy: 15378/34000 (45%)
0.4455868324307931
Val set: Accuracy: 13693/34000 (40%)
0.3926854482369305
Test set: Accuracy: 12149/34000 (36%)
0.3469324228645469


Train set: Accuracy: 15544/34000 (46%)
0.45280455068511777
Val set: Accuracy: 13804/34000 (41%)
0.39629510612960056
Test set: Accuracy: 12041/34000 (35%)
0.3447131942820224


Train set: Accuracy: 15318/34000 (45%)
0.4416414342456171
Val set: Accuracy: 13769/34000 (40%)
0.3931113524333243
Test set: Accuracy: 12105/34000 (36%)
0.3427459321069524


Train set: Accuracy: 15337/34000 (45%)
0.44454312461332124
Val set: 

Train set: Accuracy: 15376/34000 (45%)
0.4437273818457264
Val set: Accuracy: 13621/34000 (40%)
0.3901963243084409
Test set: Accuracy: 12052/34000 (35%)
0.34538447944541195


Train set: Accuracy: 15386/34000 (45%)
0.44838675549213397
Val set: Accuracy: 13698/34000 (40%)
0.3971040943591516
Test set: Accuracy: 11762/34000 (35%)
0.34256707317729135


Train set: Accuracy: 15576/34000 (46%)
0.45314114319309184
Val set: Accuracy: 13642/34000 (40%)
0.3942485800610734
Test set: Accuracy: 11898/34000 (35%)
0.3439542028593831


Train set: Accuracy: 15595/34000 (46%)
0.45371844067781986
Val set: Accuracy: 13769/34000 (40%)
0.39767673863927133
Test set: Accuracy: 12039/34000 (35%)
0.348435574264231


Train set: Accuracy: 15352/34000 (45%)
0.44302996270923584
Val set: Accuracy: 13635/34000 (40%)
0.38551962905976767
Test set: Accuracy: 12149/34000 (36%)
0.3427159595943947


Train set: Accuracy: 15440/34000 (45%)
0.448543832161568
Val set: Accuracy: 13438/34000 (40%)
0.3882022548967662
Test set: Accur

Train set: Accuracy: 15440/34000 (45%)
0.44295343018716515
Val set: Accuracy: 13912/34000 (41%)
0.3967607504693377
Test set: Accuracy: 11888/34000 (35%)
0.3427162363520383


Train set: Accuracy: 15670/34000 (46%)
0.46118147336876786
Val set: Accuracy: 13586/34000 (40%)
0.3930946824685413
Test set: Accuracy: 11783/34000 (35%)
0.3406527184724596


Train set: Accuracy: 15531/34000 (46%)
0.44857891007709927
Val set: Accuracy: 13575/34000 (40%)
0.3883220320817417
Test set: Accuracy: 11916/34000 (35%)
0.3416569081743481


Train set: Accuracy: 15457/34000 (45%)
0.4481662739877739
Val set: Accuracy: 13536/34000 (40%)
0.3886680591361795
Test set: Accuracy: 11740/34000 (35%)
0.33578171825083425


Train set: Accuracy: 15571/34000 (46%)
0.45504460271291913
Val set: Accuracy: 13856/34000 (41%)
0.3992230137382823
Test set: Accuracy: 11954/34000 (35%)
0.345842828046412


Train set: Accuracy: 15408/34000 (45%)
0.44652708427950394
Val set: Accuracy: 13634/34000 (40%)
0.38610090858490315
Test set: Accur

Train set: Accuracy: 15771/34000 (46%)
0.4534623814546266
Val set: Accuracy: 13853/34000 (41%)
0.3974055267495552
Test set: Accuracy: 11983/34000 (35%)
0.3417327547067772


Train set: Accuracy: 15619/34000 (46%)
0.4496447888266002
Val set: Accuracy: 13752/34000 (40%)
0.3922901581059012
Test set: Accuracy: 11908/34000 (35%)
0.33803757936807777


Train set: Accuracy: 15398/34000 (45%)
0.4402002121710584
Val set: Accuracy: 13457/34000 (40%)
0.3835286630879351
Test set: Accuracy: 11628/34000 (34%)
0.3284659433175666


Train set: Accuracy: 15721/34000 (46%)
0.45978013999342054
Val set: Accuracy: 13486/34000 (40%)
0.39090420448093866
Test set: Accuracy: 11709/34000 (34%)
0.33550074492026344


Train set: Accuracy: 15508/34000 (46%)
0.4552569399703719
Val set: Accuracy: 13557/34000 (40%)
0.39129049538079314
Test set: Accuracy: 11836/34000 (35%)
0.3425277777527349


Train set: Accuracy: 15475/34000 (46%)
0.4447156922978885
Val set: Accuracy: 13652/34000 (40%)
0.3896743254123133
Test set: Accura

Train set: Accuracy: 15796/34000 (46%)
0.4577692823384884
Val set: Accuracy: 13662/34000 (40%)
0.39242827537889236
Test set: Accuracy: 11511/34000 (34%)
0.33443370301828823


Train set: Accuracy: 15813/34000 (47%)
0.46031211712872633
Val set: Accuracy: 13805/34000 (41%)
0.39790227883380186
Test set: Accuracy: 11999/34000 (35%)
0.34760528453814643


Train set: Accuracy: 15663/34000 (46%)
0.4556898803097503
Val set: Accuracy: 13540/34000 (40%)
0.39084695080529197
Test set: Accuracy: 11825/34000 (35%)
0.3427601911830467


Train set: Accuracy: 15632/34000 (46%)
0.4512244067811751
Val set: Accuracy: 13616/34000 (40%)
0.3896376070205289
Test set: Accuracy: 11759/34000 (35%)
0.33835846500757294


Train set: Accuracy: 15721/34000 (46%)
0.4601037729844991
Val set: Accuracy: 13682/34000 (40%)
0.395620070133457
Test set: Accuracy: 11855/34000 (35%)
0.34030537853504594




KeyboardInterrupt: 

plot_cm(test_cm,cell_type_labels=dataset.cell_type_labels,outfilename='4bestfr20nh_adaptive_cm.png')
plot_accs(train_accs,val_accs,test_accs,outfilename='4bestfr20nh_adaptive_acc.png')

In [None]:
def plot_cm(cm, cell_type_labels=None, outfilename='ex_cm.png'):
    """Draw a confusion matrix as an annotated heatmap, then save or show it.

    Parameters
    ----------
    cm : 2-D array-like
        Confusion matrix; element ``cm[row, col]`` is drawn at x=col, y=row.
        The colour scale is pinned to [0, 1], so a normalised matrix is
        expected (NOTE(review): confirm `test` returns a normalised matrix).
    cell_type_labels : sequence of str, optional
        Tick labels for both axes, one per class. If None, falls back to the
        global ``dataset.cell_type_labels`` looked up at call time.
    outfilename : str or None, default 'ex_cm.png'
        Where to save the figure (300 dpi). If None, the figure is shown.
    """
    if cell_type_labels is None:
        # Resolved lazily: the old default was frozen to whatever `dataset`
        # existed when this cell ran, while the sweep cells rebind `dataset`.
        cell_type_labels = dataset.cell_type_labels
    cm = np.asarray(cm)
    ticks = range(len(cell_type_labels))

    fig, ax = plt.subplots(figsize=(10, 7))
    # Pass the colormap locally instead of plt.set_cmap, which would also
    # change the global default colormap for every later image.
    image = ax.imshow(cm, vmin=0, vmax=1, cmap='Reds')
    fig.colorbar(image, ax=ax)
    # Use the labels actually passed in (the old code ignored the parameter).
    ax.set_xticks(ticks)
    ax.set_xticklabels(cell_type_labels, rotation='vertical')
    ax.set_yticks(ticks)
    ax.set_yticklabels(cell_type_labels)

    for (row, col), value in np.ndenumerate(cm):
        # Diagonal (correctly classified) cells are emphasised in bold.
        emphasis = {'weight': 'bold'} if row == col else {}
        ax.text(col, row, np.round(value, 2), ha='center', va='center',
                fontsize='small', **emphasis)

    if outfilename is None:
        plt.show()
    else:
        fig.savefig(outfilename, dpi=300)
    # Close rather than clear: plt.clf() left every figure open, so the
    # sweep loops accumulated figures (and memory) across iterations.
    plt.close(fig)

def plot_accs(train_accs, val_accs, test_accs, outfilename='ex_acc.png'):
    """Plot train/validation/test score curves on a fresh figure.

    Parameters
    ----------
    train_accs, val_accs, test_accs : sequence of float
        One score per epoch (the sweep cells append one value per epoch).
        They are plotted against epoch numbers starting at 1, to match
        ``range(1, epochs + 1)`` in the training loops.
    outfilename : str or None, default 'ex_acc.png'
        Where to save the figure (300 dpi). If None, the figure is shown.
    """
    # A dedicated figure: the old version drew on whatever figure happened to
    # be current (e.g. the cleared one left behind by plot_cm).
    fig, ax = plt.subplots()
    for scores, label in ((train_accs, 'Train'), (val_accs, 'Validate'), (test_accs, 'Test')):
        ax.plot(range(1, len(scores) + 1), scores, label=label)
    ax.set_xlabel('Epoch')
    # NOTE(review): the y label assumes `test` returns an F-measure; confirm.
    ax.set_ylabel('F-Measure')
    ax.legend()
    if outfilename is None:
        plt.show()
    else:
        fig.savefig(outfilename, dpi=300)
    # Close the figure so repeated calls in sweep loops do not leak figures.
    plt.close(fig)

In [None]:
# One-off training run: a 3x20 MLP with dropout 0.2, Adam at lr=1e-2 and
# 100 epochs, evaluated every epoch on the train/val/test loaders suffixed "2".
# NOTE(review): train_dataset, n_class, train, test and the *_loader2 objects
# come from cells outside this view. This cell also rebinds the global `lr`
# that the sweep cells read.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = MLP(
    input_dims=train_dataset.num_bins,
    n_hiddens=[20] * 3,
    n_class=n_class,
    dropout_p=0.2,
).to(device)
print(model)

lr = 0.01
optimizer = optim.Adam(model.parameters(), lr=lr)

epochs = 100
for epoch in range(1, epochs + 1):
    train(model, device, train_loader2, optimizer, epoch, log_interval=5)
    # `test` prints its own metrics (see the logged output); its return
    # values are discarded in this run.
    test(model, device, train_loader2, 'Train')
    test(model, device, val_loader2, 'Val')
    test(model, device, test_loader2, 'Test')
    print('\n')

In [None]:
# Debug check: show the current global `isi_dist_bins`, i.e. the last value bound
# by whichever sweep cell ran most recently. NOTE(review): the output depends on
# execution order; consider removing this cell from the final notebook.
print(isi_dist_bins)

In [None]:
# Sweep: adaptive ISI binning with 50/100/200/400 bins. Each pass reloads the
# data (load_data()/get_scaler() presumably read the global `isi_dist_bins`
# bound by this loop; TODO confirm), trains a fresh MLP for 20 epochs, and saves
# final-epoch confusion matrices plus per-epoch score curves.
isi_dist_bins_list = [50, 100, 200, 400]
for isi_dist_bins in isi_dist_bins_list:
    dataset = load_data()
    (train_scaler, train_loader, val_loader, test_loader,
     train_dataset, val_dataset, test_dataset) = get_scaler()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = MLP(
        input_dims=train_dataset.num_bins,
        n_hiddens=n_hiddens,
        n_class=n_class,
        dropout_p=dropout_p,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_accs = []
    val_accs = []
    test_accs = []
    epochs = 20
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5)
        train_cm, train_acc = test(model, device, train_loader, 'Train')
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        print('\n')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    # Output files are keyed by the bin count.
    plot_cm(test_cm, outfilename='testcm_adaptive%s.png' % isi_dist_bins)
    plot_cm(train_cm, outfilename='traincm_adaptive%s.png' % isi_dist_bins)
    plot_cm(val_cm, outfilename='valcm_adaptive%s.png' % isi_dist_bins)
    plot_accs(train_accs, val_accs, test_accs, outfilename='acc_adaptive%s.png' % isi_dist_bins)

In [None]:
# Sweep: evenly spaced ISI bin edges from 0 to ~0.4 (presumably seconds) with
# step 0.002, 0.004, 0.008 and 0.001. Each pass reloads the data with the new
# global `isi_dist_bins`, trains a fresh MLP for 20 epochs, and saves plots
# keyed by the number of bins (edges - 1).
isi_dist_bins_list = [
    np.arange(0, stop, step).tolist()
    for stop, step in ((0.402, 0.002), (0.402, 0.004), (0.402, 0.008), (0.401, 0.001))
]
for isi_dist_bins in isi_dist_bins_list:
    dataset = load_data()
    (train_scaler, train_loader, val_loader, test_loader,
     train_dataset, val_dataset, test_dataset) = get_scaler()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = MLP(
        input_dims=train_dataset.num_bins,
        n_hiddens=n_hiddens,
        n_class=n_class,
        dropout_p=dropout_p,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_accs = []
    val_accs = []
    test_accs = []
    epochs = 20
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5)
        train_cm, train_acc = test(model, device, train_loader, 'Train')
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        print('\n')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    plot_cm(test_cm, outfilename='testcm_even%s.png' % (len(isi_dist_bins) - 1))
    plot_cm(train_cm, outfilename='traincm_even%s.png' % (len(isi_dist_bins) - 1))
    plot_cm(val_cm, outfilename='valcm_even%s.png' % (len(isi_dist_bins) - 1))
    plot_accs(train_accs, val_accs, test_accs, outfilename='acc_even%s.png' % (len(isi_dist_bins) - 1))

In [None]:
# Sweep: sampler settings R10..R100, with the ISI bins fixed at 0.002-wide
# edges up to 0.4. NOTE(review): what the "R<N>" codes mean is defined by
# load_data()/get_scaler(), which presumably read the global `sampler`; confirm.
sampler_list = ['R%d' % size for size in (10, 20, 40, 60, 80, 100)]
isi_dist_bins = np.arange(0, 0.402, 0.002).tolist()
for sampler in sampler_list:
    dataset = load_data()
    (train_scaler, train_loader, val_loader, test_loader,
     train_dataset, val_dataset, test_dataset) = get_scaler()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = MLP(
        input_dims=train_dataset.num_bins,
        n_hiddens=n_hiddens,
        n_class=n_class,
        dropout_p=dropout_p,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_accs = []
    val_accs = []
    test_accs = []
    epochs = 20
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5)
        train_cm, train_acc = test(model, device, train_loader, 'Train')
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        print('\n')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    # Output files are keyed by the sampler code.
    plot_cm(test_cm, outfilename='testcm_%s.png' % sampler)
    plot_cm(train_cm, outfilename='traincm_%s.png' % sampler)
    plot_cm(val_cm, outfilename='valcm_%s.png' % sampler)
    plot_accs(train_accs, val_accs, test_accs, outfilename='acc_%s.png' % sampler)

In [None]:
# Sweep: hidden-layer width (three layers of 10..50 units), with the ISI bins
# and sampler ('R40') held fixed. Output files are keyed by the layer width.
isi_dist_bins = np.arange(0, 0.402, 0.002).tolist()
sampler = 'R40'
n_hiddens_list = [[width] * 3 for width in (10, 20, 30, 40, 50)]
for n_hiddens in n_hiddens_list:
    dataset = load_data()
    (train_scaler, train_loader, val_loader, test_loader,
     train_dataset, val_dataset, test_dataset) = get_scaler()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = MLP(
        input_dims=train_dataset.num_bins,
        n_hiddens=n_hiddens,
        n_class=n_class,
        dropout_p=dropout_p,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_accs = []
    val_accs = []
    test_accs = []
    epochs = 20
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5)
        train_cm, train_acc = test(model, device, train_loader, 'Train')
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        print('\n')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    plot_cm(test_cm, outfilename='testcm_nh%s.png' % n_hiddens[0])
    plot_cm(train_cm, outfilename='traincm_nh%s.png' % n_hiddens[0])
    plot_cm(val_cm, outfilename='valcm_nh%s.png' % n_hiddens[0])
    plot_accs(train_accs, val_accs, test_accs, outfilename='acc_nh%s.png' % n_hiddens[0])

In [None]:
# Sweep: network depth (1..4 hidden layers of 20 units), with the ISI bins and
# sampler ('R40') held fixed. Output files are keyed by the number of layers.
# NOTE(review): `sampler_list` is not used by this cell; it only rebinds the
# global, which later cells might read.
sampler_list = ['R%d' % size for size in (10, 20, 60, 80, 100, 40)]
isi_dist_bins = np.arange(0, 0.402, 0.002).tolist()
sampler = 'R40'
n_hiddens_list = [[20] * depth for depth in range(1, 5)]
for n_hiddens in n_hiddens_list:
    dataset = load_data()
    (train_scaler, train_loader, val_loader, test_loader,
     train_dataset, val_dataset, test_dataset) = get_scaler()
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = MLP(
        input_dims=train_dataset.num_bins,
        n_hiddens=n_hiddens,
        n_class=n_class,
        dropout_p=dropout_p,
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_accs = []
    val_accs = []
    test_accs = []
    epochs = 20
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5)
        train_cm, train_acc = test(model, device, train_loader, 'Train')
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        print('\n')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    plot_cm(test_cm, outfilename='testcm_nl%s.png' % len(n_hiddens))
    plot_cm(train_cm, outfilename='traincm_nl%s.png' % len(n_hiddens))
    plot_cm(val_cm, outfilename='valcm_nl%s.png' % len(n_hiddens))
    plot_accs(train_accs, val_accs, test_accs, outfilename='acc_nl%s.png' % len(n_hiddens))