In [1]:
import numpy as np

import data_io
import data_preprocessing
from implementations import *
import validation
import attribute_selection
import evaluators
import metrics

# Autoreload modules
%load_ext autoreload
%autoreload 2

In [2]:
from google.colab import drive
# Mount Google Drive inside the Colab runtime at /content/drive.
drive.mount('/content/drive')

# Prefix prepended to the CSV file names when loading data below.
# NOTE(review): absolute Colab path; the notebook only runs as-is where this Drive folder exists.
DATA_FILE_PREFIX = '/content/drive/My Drive/mlproject1_higgs_data/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Load the training CSV. The third returned value is discarded here
# (the disabled test-set call below names it `ids_test`).
y_train, x_train, _, cols = data_io.load_csv_data(f'{DATA_FILE_PREFIX}train.csv')
# Test-set loading is currently disabled.
#_, x_test, ids_test, cols_train = data_io.load_csv_data(f'{DATA_FILE_PREFIX}test.csv')

In [4]:
# Map every column name to a column position: the first two entries of `cols` are skipped
# and the remaining positions are shifted down by 2
# (presumably the id and label columns, which are not part of x_train -- verify against data_io).
col_to_index_mapping = {col_name: index - 2 for index, col_name in enumerate(cols) if index >= 2}
# Rescale the labels with (y + 1) // 2.
# Assumes labels are encoded as {-1, 1}, which this turns into {0, 1} -- TODO confirm.
y_train = (y_train + 1) // 2

In [5]:
import operator

def tune_lambda(y, x, grid, seed=42, history=True, sgd_args=(5, 1000, 0.5), k_fold=5):
    """Pick the regularisation strength with the best cross-validated score.

    For every candidate in ``grid`` the global NumPy RNG is re-seeded, a model is
    trained with ``reg_logistic_regression_sgd`` inside
    ``validation.cross_validation``, and the mean of the first value returned by
    the cross-validation is stored as the candidate's score.

    Args:
        y: Target vector handed to the cross-validation.
        x: Feature matrix; ``x.shape[1]`` sets the length of the zero start weights.
        grid: Iterable of candidate lambda values.
        seed: Value passed to ``np.random.seed`` before each candidate is evaluated.
        history: When true, print ``"<lambda>: <score>"`` after each candidate.
        sgd_args: The three positional arguments forwarded to
            ``reg_logistic_regression_sgd`` right after the start weights.
            Their meaning is defined in ``implementations`` (presumably iteration
            count, batch size and step size -- verify against that module).
        k_fold: Last argument passed to ``validation.cross_validation``
            (presumably the number of folds -- verify).

    Returns:
        The ``(lambda_, score)`` pair with the highest score.

    Raises:
        ValueError: If ``grid`` contains no candidates.
    """
    candidates = list(grid)
    if not candidates:
        raise ValueError("grid must contain at least one lambda value")

    w_init = np.zeros(x.shape[1])
    res = {}
    for lambda_ in candidates:
        np.random.seed(seed)
        # Every fit receives its own copy of the start weights, so a solver that
        # updates the weights in place cannot leak state into the next fold or candidate.
        train_model = lambda y_, x_: make_predictor(reg_logistic_regression_sgd(
            y_, x_, lambda_, w_init.copy(), *sgd_args
        )[0])
        res[lambda_] = validation.cross_validation(y, x, train_model, k_fold)[0].mean()
        if history:
            print(f"{lambda_}: {res[lambda_]:.4f}")
    return max(res.items(), key=operator.itemgetter(1))

In [6]:
def make_predictor(w):
  """Wrap a weight vector in a function that predicts binary labels.

  The returned callable computes ``features @ w`` and yields an integer array
  holding 1 where the score is strictly positive and 0 elsewhere.
  """
  def predict(features):
    scores = features @ w
    is_positive = scores > 0
    return is_positive.astype(int)
  return predict

In [7]:
def train_model(y, x):
  """Fit the regularised logistic model on (y, x) and return a label predictor.

  Training starts from all-zero weights with lambda fixed to 1e-5; the first
  element returned by ``reg_logistic_regression_sgd`` is wrapped by
  ``make_predictor``.
  """
  start_weights = np.zeros(x.shape[1])
  reg_strength = 1e-5
  fit_result = reg_logistic_regression_sgd(
      y, x, reg_strength, start_weights, 5, 1000, 0.5)
  return make_predictor(fit_result[0])

In [30]:
def build_pairwise_plus(x, column_idx):
    """Append the pairwise sums of the selected columns to ``x``.

    For every unordered pair ``(i, j)`` with ``i < j`` among the columns picked by
    ``column_idx``, the element-wise sum ``x[:, i] + x[:, j]`` is appended as a new
    column. Pairs are ordered (0, 1), (0, 2), ..., (1, 2), ... with respect to the
    positions inside ``column_idx``.

    Args:
        x: 1-D or 2-D array; a 1-D input is treated as a single column.
        column_idx: Index (list, array or slice) selecting the columns to combine.

    Returns:
        A new 2-D array made of the original columns followed by the pairwise sums.
        When fewer than two columns are selected, no column is added.
    """
    if x.ndim == 1:
        x = x[:, np.newaxis]

    columns = x[:, column_idx]
    # Strictly-upper-triangular index pairs enumerate every (i, j) with i < j,
    # including pairs that involve the last selected column.
    left, right = np.triu_indices(columns.shape[1], k=1)
    pairwise = columns[:, left] + columns[:, right]
    # concatenate always allocates a new array, so the caller's x is left untouched.
    return np.concatenate([x, pairwise], axis=1)


def transformation_pipeline_median_selected_pairwise(x, col_to_index_mapping=col_to_index_mapping, transformation_memory=None):
    """Turn a raw feature matrix into the engineered design matrix used by the model.

    Steps, in order:
      1. Build 0/1 indicator columns for every column that contained the -999
         sentinel in the training data.
      2. Replace -999 with NaN and standardize every column whose name does not
         contain 'PRI_jet_num', using NaN-aware mean and standard deviation.
      3. (Training only) derive per-column clip bounds from the 0.1 / 0.9 quantiles
         and store them under 'CLIP'. The code that would apply them is commented
         out, so the bounds are recorded but not used in this function.
      4. Replace the remaining NaNs with 0.
      5. One-hot encode 'PRI_jet_num' via data_preprocessing.one_hot_transformation
         and drop the last column of the result.
      6. Append sin and cos of every column, then keep a hard-coded list of columns.
      7. Build and standardize polynomial features (called with [2, 3]).
      8. Concatenate the output of data_preprocessing.build_pairwise_alt, the
         polynomial features and the missing-value indicators, then prepend a bias column.

    Args:
        x: 2-D feature matrix; it is copied, not modified.
        col_to_index_mapping: dict from column name to column index in ``x``.
            The default is bound to the module-level mapping at definition time.
        transformation_memory: ``None`` to fit the statistics on ``x`` (training),
            or the dict returned by a previous training call to reuse them.

    Returns:
        ``(tx, transformation_memory)``: the transformed matrix and the dict holding
        'columns_with_missing_values', 'base_mean', 'base_stddev', 'CLIP',
        'poly_mean' and 'poly_stddev' (all written only in training mode).

    NOTE(review): despite 'median' in the name, no median is computed in this function.
    NOTE(review): a zero standard deviation would make the standardization divide by zero.
    """
    # Memory is required in order to apply same transformation on training and test data
    training = transformation_memory is None
    if training:
      transformation_memory = {}

    tx = np.copy(x) # Work on a copy so the caller's array is not changed

    # Creating binary column indicating whether given column is missing for 
    #   each column that contains missing values
    if training:
      # Boolean per column: True when at least one entry equals the -999 sentinel
      columns_with_missing_values = np.max((tx == -999), axis=0)
      transformation_memory['columns_with_missing_values'] = columns_with_missing_values
    missing_columns_binary = (tx[:, transformation_memory['columns_with_missing_values']] == -999)\
              .astype(int)
    
    # Replace the missing-value sentinel with NaN so the NaN-aware statistics below skip it
    tx[tx == -999.] = np.nan

    # Calculate mean and standard deviation in order to later standardize data
    #   (every column except those whose name contains 'PRI_jet_num')
    base_standardize_col_idx = [col_to_index_mapping[key] for key in col_to_index_mapping if 'PRI_jet_num' not in key]
    base_standardize_cols = tx[:, base_standardize_col_idx]
    if training:
      mean = np.nanmean(base_standardize_cols, axis=0)
      stddev = np.nanstd(base_standardize_cols, axis=0)
      transformation_memory['base_mean'] = mean
      transformation_memory['base_stddev'] = stddev

    # Standardize data with the stored (training) statistics
    tx[:, base_standardize_col_idx] = (base_standardize_cols - transformation_memory['base_mean']) \
          / transformation_memory['base_stddev']
    
    d = tx.shape[1]
    
    # Find which columns need their values clipped.
    # NOTE(review): the result is only stored in the memory dict; the code applying it is commented out below.
    if training:
      CLIP = [None] * d
      for c in range(d):
        if c != col_to_index_mapping['PRI_jet_num']:
          v = tx[:, c]
          v = v[~np.isnan(v)]
          min_, max_ = np.min(v), np.max(v)
          do_clip = False
          # Upper bound becomes the 0.9 quantile when the max lies more than 100% (relative to q9) above it
          q9 = np.quantile(v, 0.9)
          if (max_ - q9)/q9 > 1.:
            max_ = q9
            do_clip = True
              
          # Lower bound becomes the 0.1 quantile when the min deviates from it by more than 100% (relative to q1)
          q1 = np.quantile(v, 0.1)
          if np.abs((q1 - min_)/q1) > 1.:
            min_ = q1  
            do_clip = True
          if do_clip:
            CLIP[c] = (min_, max_)
      transformation_memory['CLIP'] = CLIP

    # Impute the remaining NaNs with 0 (0 = mean after standardization). The missing-value
    #   indicators were already captured above, before any transformation touched the data.
    tx[np.isnan(tx)] = 0

    # Clipping values for columns that need to be clipped and adding binary column indicating
    #   whether the given column for the sample was clipped (currently disabled)
    #CLIP = transformation_memory['CLIP']
    #is_clipped_column = []
    #for c in range(d):
    #  if CLIP[c] is not None:
    #    col = tx[:, c]
    #    mask = np.logical_or(col < CLIP[c][0], col > CLIP[c][1])
    #    tx[mask, c] = 0
    #    is_clipped_column.append(mask[:, np.newaxis].astype(int))

    #is_clipped_binary = np.concatenate(is_clipped_column, axis=1)
     
    # onehot for categorical and drop one level
    #   (presumably the one-hot columns are placed last, so [:, :-1] drops one level -- verify in data_preprocessing)
    # The updated mapping returned by the helper is not used afterwards.
    tx, col_to_index_mapping_upd = data_preprocessing.one_hot_transformation(tx, 'PRI_jet_num', col_to_index_mapping)
    tx = tx[:, :-1]

    # Augment features using sin and cos of every current column
    sins = np.sin(tx)
    coses = np.cos(tx)
    tx = np.concatenate((tx, sins, coses), axis=1)
    
    # Select best features (determined using backwards attribute selection)
    #   Indices refer to the columns of the augmented matrix built just above.
    first_selection_attr = [1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 14, 15, 16, 19, 21, 24, 29, 30, 32, 33, 34, 35, 39, 40, 41, 43, 44, 46, 48, 49, 51, 56, 57, 58, 61, 62, 64, 65, 66, 67, 68, 71, 73, 74, 75, 78, 80, 81, 86, 87, 90, 93, 94, 95]
    tx = tx[:, first_selection_attr]
    
    d = tx.shape[1]
    
    # Add polynomial degrees 2 and 3 for the selected features
    poly = data_preprocessing.build_poly(tx, list(range(d)), [2, 3])


    if training:
      poly_mean = np.nanmean(poly, axis=0)
      poly_stddev = np.nanstd(poly, axis=0)
      transformation_memory['poly_mean'] = poly_mean
      transformation_memory['poly_stddev'] = poly_stddev

    # Standardize value of polynomial degrees with the stored (training) statistics
    poly = (poly - transformation_memory['poly_mean']) / transformation_memory['poly_stddev']

    # Add features multiplied with each other, standardized polynomial degrees 2 and 3 and
    #   binary columns for missing features (the clipped-indicator columns are disabled)
    tx = np.c_[
               data_preprocessing.build_pairwise_alt(tx, list(range(d))), 
               poly, missing_columns_binary]#, is_clipped_binary]

    # Add bias
    tx = data_preprocessing.prepend_bias_column(tx)
    
    return tx, transformation_memory

In [31]:
# Fit the preprocessing on the training features. `memory` holds the fitted statistics
# and can be passed back as `transformation_memory` to transform other data identically.
tx_train_2, memory = transformation_pipeline_median_selected_pairwise(x_train)


In [33]:
tx_train_2.shape

(250000, 1605)

In [32]:
del x_train

In [None]:
# NOTE(review): the cells below (shape print, lambda tuning, final training) still read
# tx_train_2; running this cell before them makes them fail with NameError.
del tx_train_2

In [None]:
print(tx_train_2.shape)

(250000, 1654)


In [None]:
import operator

def tune_lambda(y, x, grid, seed=42, history=True, sgd_args=(100, 2000, 0.1), k_fold=5):
    """Pick the regularisation strength with the best cross-validated score.

    NOTE(review): this cell redefines ``tune_lambda`` from an earlier cell; the two
    definitions differ only in the solver settings, which are exposed here through
    ``sgd_args`` so one definition can serve both uses.

    For every candidate in ``grid`` the global NumPy RNG is re-seeded, a model is
    trained with ``reg_logistic_regression_sgd`` inside
    ``validation.cross_validation``, and the mean of the first value returned by
    the cross-validation is stored as the candidate's score.

    Args:
        y: Target vector handed to the cross-validation.
        x: Feature matrix; ``x.shape[1]`` sets the length of the zero start weights.
        grid: Iterable of candidate lambda values.
        seed: Value passed to ``np.random.seed`` before each candidate is evaluated.
        history: When true, print ``"<lambda>: <score>"`` after each candidate.
        sgd_args: The three positional arguments forwarded to
            ``reg_logistic_regression_sgd`` right after the start weights.
            Their meaning is defined in ``implementations`` (presumably iteration
            count, batch size and step size -- verify against that module).
        k_fold: Last argument passed to ``validation.cross_validation``
            (presumably the number of folds -- verify).

    Returns:
        The ``(lambda_, score)`` pair with the highest score.

    Raises:
        ValueError: If ``grid`` contains no candidates.
    """
    candidates = list(grid)
    if not candidates:
        raise ValueError("grid must contain at least one lambda value")

    w_init = np.zeros(x.shape[1])
    res = {}
    for lambda_ in candidates:
        np.random.seed(seed)
        # Every fit receives its own copy of the start weights, so a solver that
        # updates the weights in place cannot leak state into the next fold or candidate.
        train_model = lambda y_, x_: make_predictor(reg_logistic_regression_sgd(
            y_, x_, lambda_, w_init.copy(), *sgd_args
        )[0])
        res[lambda_] = validation.cross_validation(y, x, train_model, k_fold)[0].mean()
        if history:
            print(f"{lambda_}: {res[lambda_]:.4f}")
    return max(res.items(), key=operator.itemgetter(1))

# Data used for the lambda search: the full engineered training matrix and its labels.
tx_train_to_tune = tx_train_2
ty_train_to_tune = y_train

# Evaluate each candidate lambda; scores are printed per candidate and the returned
# best (lambda, score) pair is discarded.
_ = tune_lambda(ty_train_to_tune, tx_train_to_tune, [0, 1e-5, 1e-7, 1e-8, 1e-9])

  exp_x = np.exp(x)
  exp_x / (1 + exp_x))


0: 0.8409
1e-05: 0.8412
1e-07: 0.8414
1e-08: 0.8415
1e-09: 0.8416


In [35]:
# Module-level slot for the most recently fitted weights; train_model below overwrites it
# so that later cells can read the weights.
w = None
def train_model(y_, x_):
  """Fit the regularised logistic model and return a label predictor.

  Side effect: stores the fitted weights in the module-level variable ``w``.
  NOTE(review): this redefines ``train_model`` from an earlier cell with different
  lambda and solver settings.
  """
  global w
  w_init = np.zeros(x_.shape[1])
  # 1e-9 is the best-scoring candidate in the lambda search output shown above
  lambda_ = 1e-9
  # The second returned value (loss) is not used.
  weights, loss = reg_logistic_regression_sgd(
    y_, x_, lambda_, w_init, 400, 2000, 0.04,
  )
  w = weights
  return make_predictor(weights)

# Train on the full training set, then score on that same data
# (this is training accuracy, not a held-out estimate).
predict = train_model(y_train, tx_train_2)
metrics.accuracy(y_train, predict(tx_train_2))

  exp_x = np.exp(x)
  exp_x / (1 + exp_x))


0.841888

In [36]:
tx_train_2.shape

(250000, 1605)

In [37]:
import numpy as np
# Raise the element-count threshold at which NumPy abbreviates printed arrays.
np.set_printoptions(threshold=np.inf)

# Dump every fitted weight as a Python list (the w8418 literal below holds the
# same values).
# NOTE(review): the label hard-codes the accuracy shown by the training cell
# (0.841888); it becomes stale if the model is retrained.
print('84.1888% weights:', list(w))

84.1888% weights: [-0.030417508408354384, 0.10698475436687996, 0.04853731440605264, 0.09403148293082435, 0.001772579046417737, -0.0007317566482108598, 0.0018049564709619367, 0.14316584941318072, 0.06757262855573323, -0.12918320817008014, -0.010013475649925318, 0.007470658485405547, -0.02326616626672536, 0.03450872493233192, 0.045698003057938495, -0.13581897962644812, 0.0012143915409809963, 0.022873801697797354, 0.00921668491550147, 0.304611600100228, 0.030757886433071417, 0.20663165259010188, -0.00590828174903174, 0.07039962701947476, -0.11418978294462259, 0.10574828134492362, -0.0335966205712955, 0.14783672716120957, 0.005437426897650698, -0.0027443129993442467, -0.038874495375593385, 0.03549554475331292, -0.0033716859372763543, 0.08619737351507845, -0.01195324180253799, 0.019247640440945825, 0.007755572932511067, -0.08390828589631631, -0.044579804286668634, 0.09444513412086476, -0.05922948758780334, 0.00765293068877501, 0.02759101287901558, -0.023119982339465182, -0.02382021441942353

In [42]:
w8418 = [-0.030417508408354384, 0.10698475436687996, 0.04853731440605264, 0.09403148293082435, 0.001772579046417737, -0.0007317566482108598, 0.0018049564709619367, 0.14316584941318072, 0.06757262855573323, -0.12918320817008014, -0.010013475649925318, 0.007470658485405547, -0.02326616626672536, 0.03450872493233192, 0.045698003057938495, -0.13581897962644812, 0.0012143915409809963, 0.022873801697797354, 0.00921668491550147, 0.304611600100228, 0.030757886433071417, 0.20663165259010188, -0.00590828174903174, 0.07039962701947476, -0.11418978294462259, 0.10574828134492362, -0.0335966205712955, 0.14783672716120957, 0.005437426897650698, -0.0027443129993442467, -0.038874495375593385, 0.03549554475331292, -0.0033716859372763543, 0.08619737351507845, -0.01195324180253799, 0.019247640440945825, 0.007755572932511067, -0.08390828589631631, -0.044579804286668634, 0.09444513412086476, -0.05922948758780334, 0.00765293068877501, 0.02759101287901558, -0.023119982339465182, -0.02382021441942353, 0.0004272444584879405, 0.044218510520661294, 0.021535040453953942, 0.07490714246655718, 0.07529879077299574, -0.024628128924143344, -0.09730903114271042, -0.04093254230486105, -0.03465439721154954, -0.04258132698267704, -0.053708621389371894, -0.18635870823155612, 0.1430409811408772, -0.07378478344803052, 0.0714323034729632, -0.19997137474137971, -0.023454605953292282, 0.11415819964246339, -0.0063465991471776005, -0.005919690714240191, -0.02350339627655959, 0.024175067460868185, 0.10332451362745668, -0.03532256368630053, -0.05303241029877695, 0.18161662373024068, 0.090723907387561, -0.02784990054040276, 0.0435159805808886, -0.10038299214866504, 0.045887364027421135, 0.15668722080090922, -0.17350464129757942, -0.14356984823422947, 0.0325671185996022, -0.03890773690921194, 0.022440511560682233, -0.10148020994414804, -0.051242685957674824, -0.07182817774071239, 0.017470638932649357, 0.05253866122293118, 0.0441976221758716, 0.15282511922777176, 0.07634153569503137, -0.20876711163378184, 
0.02030976571697162, 0.03138758955840536, 0.02922229257607449, -0.0446815241723879, -0.0007392668518153056, -0.10086655847088548, 0.11891006923778114, -0.04185144158770778, 0.0476663033461591, 0.050636430527424135, -0.07793351713323295, 0.020546899273687132, 0.2010677861874382, 0.11300147875906646, 0.02349601122207645, 0.06527918333818686, -0.07581676300635416, -0.019287480976983525, 0.06662055332232418, -0.055675810347769104, 0.04102297153129655, -0.23761538140629543, -0.005039478693532261, -0.18302776239116478, 0.03499333096624771, -0.028518761392508085, 0.033819425080504666, 0.050126615547243784, 0.1069561686585426, -0.0375065893594325, 0.06880238499826542, -0.11982541525513259, 0.03819953271434217, -0.12107351489401487, -0.1696993436392965, 0.09511253179571874, -0.19181014528699095, -0.3270081370357077, 0.013874830032718334, 0.18506518555090648, 0.032362475868833875, 0.022053839099662145, 0.021649509271411065, 0.06376732205861689, -0.04625456173662549, 0.20185904740415334, -0.06561857459595283, 0.031111274595353754, -0.06271375437775034, -0.10082961017975077, 0.03214379841233955, -0.11222163041461736, -0.12324065872085284, -0.051022840255700655, -0.02265236305484843, -0.0653566307181537, -0.2412307588543227, 0.08738113734576328, -0.02977445797094535, -0.20172578434764785, -0.24720034044335298, -0.13923263645400064, -0.00801607355781126, -0.05018072212720818, -0.08847166518153919, 0.016348666862595458, 0.10362078149723111, 0.030977077300355527, -0.1274210264992648, -0.008530149785439268, 0.023953533046537318, 0.006367805006892738, 0.2265493421939355, 0.015535408481595436, 0.1695627677534966, -0.0376645091295444, -0.018283712098606165, -0.06355396763505017, 0.041145969219857255, -0.024901532725283403, 0.001399747140198019, -0.03542734494491079, -0.01207154245290296, -0.07482658227478223, 0.03543207986469278, -0.1578920458939345, -0.270873085053754, 0.11639514597108512, -0.21703934129898395, -0.1300249222583874, 0.09518196587543698, -0.022372876703231733, 
0.08264796537345305, 0.11516004480942735, -0.07134268030533165, -0.09795994413512708, 0.006256919670894581, -0.03157902608368529, 0.018557762485209284, 0.049436183964085896, -0.010157852715994786, -0.06296439787656866, 0.4200208867412803, 0.13031431626691733, 0.18589469751200774, 0.05654938795427327, 0.07398370516089334, -0.1718306309748734, 0.02519339181857983, 0.09763892401844176, -0.03276284052120803, 0.0339382818661237, 0.05594560071168605, 0.025174271930648605, -0.026304638264456153, -0.11112627024616081, -0.025268042501487068, 0.09958074316103832, 0.12842909026230923, 0.06117541643501024, 0.28696485868744087, 0.09003985269918643, -0.03049766729756727, 0.014398476962289724, 0.03211408396064861, -0.052785311407016325, 0.0028397755399605126, -0.0016709040589089328, -0.021878068337641138, -0.0064104540455587064, 0.014221560513617088, 0.01453874309706127, 0.0, 0.0, 0.12649495126974003, 0.15028685981044473, 0.16657043842803082, -0.008245721102541077, 0.06531038246842048, 0.018659064832656586, 0.13878547084544965, -0.0673031968856314, 0.41956764648270867, -0.02802436936329725, 0.03800375678863625, -0.0166443528893352, -0.003350054219630499, 0.04601168945287874, 0.11609324148637527, 0.018101073502174568, 0.0, 0.0, 0.09257522610518741, -0.05459423196413813, 0.01660286172458275, 0.04300643213810379, 0.11318858808405377, 0.12347129861453443, -0.025884367301037478, -0.04389514381560708, 0.0633725495169307, -0.02648994064366663, 0.031395933549596876, -0.10780522347467836, -0.1319169949550658, -0.0006870576048508036, 0.04295357817935623, 0.001772579046417737, 0.001772579046417737, -0.2945362123861617, -0.10194548481339855, -0.0019718247942958067, 0.013183594356956232, -0.02885426386408212, 0.033508491145503395, 0.03169250423341391, 0.05427314702581292, 0.026557596715991394, -0.022690239918779085, -0.023852192895922924, -0.013488395103115478, 0.0, 0.0, 0.09925564858375255, -0.022940552969185044, 0.05507772306711934, -0.05448381741977441, 0.061659572113770864, 
0.07678954134422532, 0.06220051316017565, -0.015391847615903467, 0.3896619661251733, -0.005693282773845754, -0.02701193626139759, -0.020088759717855745, 0.016156712492904096, 0.015848856488888792, -0.15492353805082607, -0.032870565194043404, 0.0, 0.0, 0.08372348706510989, 0.031572017743860585, 0.06787554775907625, 0.028236577182781445, 0.14716281636593834, -0.04599015699240966, 0.04576027396562107, -0.02865196768724156, -0.04748614629873007, 0.031677423842812906, -0.014291189237678465, -0.11683462715020168, 0.017373630028131876, 0.0007710012435746558, -0.12193932780366175, -0.0007317566482108598, -0.0007317566482108598, -0.27279206469004585, 0.05198041368040978, 0.049621153877430754, -0.04666790934545857, -0.07190524323155731, 0.04404815581672309, 0.03363831693115084, 0.0543214788309677, -0.005788034646473869, 0.029906185501278694, -8.354651978539164e-05, 0.0, 0.0, -0.004597260880318397, 0.006531179553334406, 0.016306646244812966, -0.031415598694072754, -0.007414551469891326, 0.07491911219692057, 0.06807120787900872, -0.042768498251672074, -0.03895513056470884, -0.00663641253326814, 0.005249831746810455, -0.017365208154179003, 0.004677139529893586, 0.027566498314228558, -0.04031220418215219, 0.005474800493468779, 0.0, 0.0, -0.18617120584184224, 0.010890216545229429, -0.00871332465642246, 0.0030511766824557517, -0.15617878911526967, 0.02559846354763585, -0.0442491926310965, -0.042161586101797884, 0.015778326828466357, -0.2663815418357105, 0.07726319300192057, 0.05763431044217493, -0.048457299594908604, 0.0484751914941105, -0.08486352078987845, 0.0018049564709619367, 0.0018049564709619367, 0.1694628743811785, -0.08892645467765649, -0.05125023362631209, 0.09786039711184806, 0.015342257192634913, -0.025265538568197927, 0.14601677834692295, -0.09558723948286246, 0.051836539273841506, 0.0927457711008788, 0.10247402850987596, -0.017380875759085843, -0.3281144736457683, -0.1828344129548803, -0.16420150121427518, 0.2910167310918835, -0.04184483345272783, 
-0.11109227532022668, -0.08263813503605236, 0.10403544743998366, 0.09604231337667125, 0.06788955766194837, -0.0742724104197062, 0.006926801990509334, 0.02536003619766391, -0.006241388522294266, -0.004703576698174736, 0.01175733175521789, 0.08622892168743672, -0.014625502641821865, -0.09631022018034244, 0.09774892992304196, -0.042880069162954795, -0.19047229285807188, 0.04150287391791661, -0.03049229308103073, -0.19915324209786187, 0.08624489507202368, 0.23515697177740724, 0.029865808021721767, 0.02882160186531683, 0.008612001614909433, 0.03683318656809599, -0.09720947183200487, 0.014112043192680213, 0.09605877479878867, 0.15115579792162437, -0.0032873231859547342, 0.04672793984356739, -0.08309788687022693, 0.004465304607873394, 0.0031491986267496636, -0.05596239417626765, 0.003467594563292936, -0.042196601933865485, -0.024400473410586664, 0.005530395505778111, -0.06837291663667452, -0.2630200876970097, 0.14485497304829723, 0.269442266558456, -0.04781635321947454, 0.028325397812362187, -0.00627783404721028, -0.03857259567449627, -0.010427168123460015, 0.028982082059992698, -0.03444895973803705, 0.015994919106109822, 0.011882213329434045, 0.027752560176571028, 0.002610912555196132, 0.07995691064304279, -0.023448216983752827, 0.004653667352624224, -0.05753382549645075, -0.372291057529727, 0.018790565300045886, 0.24280612291420126, 0.002816418680272112, -0.07910069655078361, -0.010785079233451253, 0.09031219345839912, 0.030621977251016265, 0.04473900248225063, 0.025783727629015516, 0.020681294205735896, 0.01018205761155228, -0.08226674933877172, -0.023826084217342735, 0.12491302570984845, 0.06503031849408848, 0.09900350067468143, -0.015556753167709509, 0.17428978825554584, -0.018090331708451788, 0.02158105284451238, -0.020500154725312046, -0.14589192827859448, 0.019272556025286252, -0.017784014164734952, -0.2539201142084656, 0.0439404324185693, -0.35280095825747176, -0.3525694703885614, -0.1584526005374421, -0.03101454050622611, -0.019651209332567626, 
-0.0049100906885678796, 0.1687022085993064, 0.06248425192278774, -0.0010862513512989904, 0.038566976654225385, 0.08507008565606462, 0.0016379309473001597, -0.04175999355767108, -0.007087041453250672, 0.05167958956610045, -0.01449372762859737, -0.21366640856553187, 0.036974598940137424, -0.23897910643553968, -0.03315034541807198, -0.27878097662040185, 0.12086554798320179, -0.0814817620175191, -0.015966735546973636, 0.0047621803996514105, 0.005808543037531779, -0.07487945276171457, -0.01089328409516517, -0.014283504949605444, 0.1102993466491066, -0.019134417072808885, -0.05755170644897098, -0.1293321498987785, -0.012456717174749565, -0.14938252363205604, -0.1061832126286139, -0.003794904998939635, -0.008422303046686971, -0.06273035140184398, 0.05711593656512154, -0.033340886061208656, -0.016358040577051196, -0.003841192205254939, -0.05231099000314698, 0.05983362073525047, 0.01802653480125416, -0.02599408009172148, -0.00025914237260979075, -0.3549042648911852, -0.0038004792492729532, -0.0564687543732893, -0.044363824350370167, 0.0647651582859597, -0.03279730460171117, -0.03541235963611442, -0.008041717062062533, -0.04898600078569023, 0.008789066291565704, 0.09621686305685785, 0.0036343975289951456, -0.003232251787792154, -0.04401818027422381, 0.0021570354643178245, 0.06771674194513107, -0.01693734593929081, 0.054489675438127216, 0.04682385825506152, 0.1108001509848728, 0.09020896910765001, 0.10939588629895276, -0.07377449506062689, 0.006768333960869007, -0.0105788939551037, -0.0037686792417556915, 0.010456654131193168, -0.024331872508850827, -0.03568310007369948, -0.008247688450452487, 0.014033765832275539, 0.03841251525477445, 0.008124893808791372, -0.0534165323064568, -0.025243410276357472, -0.03880097312858317, 0.05872921900467513, 0.028138131892671964, 0.054595099112115346, -0.019776365069047543, -0.033248900444320405, 0.013166800573216871, -0.049910694859458746, -0.054684900866970705, -0.04011025492231782, 0.032593463679041416, 0.024148037750699264, 
-0.015980002356301353, -0.08753392556549461, -0.022550131277835308, 0.21804837835392224, 0.029718803380437493, -0.0427291050919677, -0.010829184924951583, 0.004593029946613685, 0.023677421554381118, 0.04594019181555617, 0.025491535234769814, 0.03650460705274911, -0.048851451376920166, 0.06713674957354547, -0.0003194745798423904, 0.034956808009451945, -0.06617586387600945, -0.014405627814340351, -0.047347774584970026, -0.00576624216032557, 0.031223608342123656, 0.017601056438877893, -0.0007715484952153823, -0.03789245317703534, 0.014139339275988403, -0.0054643758628339775, -0.017626582687334344, -0.04443201543092565, 0.001511454441135266, 0.041340871386529365, -0.014173345057696471, 0.03531204374231429, -0.0012888160600987703, -0.004955780398549631, 0.025329661424380923, 0.006480158631655059, -0.025374797048894566, -0.013729221804398534, 0.009343369894050128, -0.001144821590552023, 0.054407541311424196, 0.036935606825669806, -7.454705693553496e-05, -0.028506407332699688, -0.019410448919897773, 0.016628043076668892, -0.021919345885447767, -0.025777495557084463, -0.07109474005250589, 0.008558118245639078, -0.001084501319327623, -0.004170145412459193, 0.002597723904515014, 0.02588082882692707, 0.00036054089135900344, -0.0029924780721019037, -0.005667068928366136, -0.00865578577410618, -0.00755134665223284, 0.04360270982905648, 0.043180239181769164, 0.010802067061227731, -0.04546942586276935, -0.009077287328542237, -0.034327970338441406, 0.025447825181654542, -0.013500821490744896, -0.022673700495738254, -0.02098800544488838, 0.043098080676868794, -0.01577944844002139, 0.0013863783557353, 0.01263877344044284, 0.0798132810296951, 0.003584825630611725, 0.24755812301614816, 0.28496117941763704, 0.12293523645815227, -0.027820199692610296, 0.03172239007282255, 0.043231308335762116, -0.15969267408727617, -0.11460984923384322, 0.06666869434226143, 0.05813980160781537, -0.08656617373286594, 0.017567076671012465, 0.015586588942499488, 0.0126707739049997, -0.08201458907706444, 
0.026152737215532247, 0.06716056018880577, 0.003016526753755206, 0.2571802117164192, 0.15334981464034347, 0.13007630620622296, -0.03482169425353779, 0.0052928752839903675, 0.07684480347923602, -0.05293106136869285, -0.13089813091333863, -0.014548338345269135, -0.016852133384837695, 0.04657224889710973, -0.10714080453299224, 0.02954468681804204, 0.02803804791103171, 0.07512630836742842, -0.002181256318116523, 0.03286078885607531, 0.029776017124485868, -0.036985232140503393, 0.018192377583704317, -0.009238175888963121, 0.037667975121221664, -0.18019637000391522, -0.12105585560557851, 0.03306335766362759, -0.0978819552743932, -0.31368001879304347, 0.046453044021907836, -0.040640549298319235, 0.028632434916810157, -0.032165417373525254, 0.00013977636549288556, 0.0677981722972943, 0.027586349919469733, -0.027070290326466736, 0.0064484852544300164, -0.051055371705650436, -0.029505826996059036, -0.007773656963114285, 0.031696508120973584, 0.044097756443453116, -0.12331173675768826, -0.062449780245263695, -0.03386956914585024, 0.00325850879749234, -0.1957861955555303, 0.03850514213921512, -0.07387028402952978, -0.30547179220115267, -0.09789544553647864, -0.08295376897239157, -0.09485482244864588, -0.00829454241459497, -0.05419594630549541, -0.02608124601507583, 0.04994477121208008, 0.028382121752096704, 0.036060649201983694, 0.04502312679519704, 0.11331463735379678, 0.013308226077693244, 0.1412808977179485, 0.036130914852276354, -0.04317640108448147, 0.1317888249295688, 0.005011945833827478, -0.0674258888899364, -0.25240944863426207, 0.07680636506326745, -0.1457788576799806, 0.025298011400325997, 0.03179463047760283, 0.03701545875170609, -0.030771327280294313, -0.006753042489316663, -0.05655872703910011, 0.03168849485983838, 0.09535097948724941, 0.01119848610364255, 0.043933785763484605, 0.053066549680989034, -0.1022183908042909, 0.03073123773801467, 0.05683161844302711, 0.02324893279951825, -0.20949137701310214, 0.08753703025197193, 0.08502743796545402, 
0.11993612576020438, -0.018322313780350912, 0.0792647674680625, 0.03987771944500674, 0.08115102696190335, -0.041409623271153384, -0.18790945712937496, -0.14193674046735336, 0.09493598839162093, 0.0, 0.002733535114013111, 0.043502014749810844, 0.011116188404419138, -0.009007010107651897, -0.024246645598899427, -0.06690966832292707, 0.05072337293272978, 0.008360154526921747, 0.014031491603085668, 0.0034916421134264385, 0.007527076044686314, 0.00546787359976859, 0.02469969357154198, 0.04262598885900899, -0.0025645880651662557, -0.08326558383047983, 0.04106142659506542, 0.0, 0.0023001904843955225, 0.02690009807144462, -0.01660750726793947, -0.0442788939390051, -0.004262088314474908, -0.018398858412461336, 0.05077412885934176, -0.01727569359262269, -0.02116856446361714, -0.02583083718921103, 0.006918470600378068, -0.03657024721835553, 0.05214159679524781, 0.021307134624789243, 0.03461387646479233, 0.0015255374184626241, 0.0012143915409809963, -4.220824775923613e-05, -0.06205021440935549, 0.0, -0.2005164076035529, -0.14942674636090192, -0.0998264121877149, 0.01708923884234134, 0.19470294358454968, -0.20078391981615126, 0.05050581876153967, -0.0781430448304508, 0.0, 0.030742833225408148, -0.01660442318730058, -0.03482739462093398, 0.0012178472759798492, 0.0, 0.0, 0.0, 0.019247640440945825, 0.0, 0.05391300357919464, -0.16868346109446888, 0.1437262994490441, -0.043907016330975455, 0.022873801697797354, -0.14073163721459853, 0.04848675603601539, -0.044926802023342025, -0.04722591670042532, 0.08143743305910595, 0.04105565635828828, 0.16864029390808613, 0.022873801697797354, 0.022873801697797354, 0.022873801697797354, 0.012358767801290634, 0.022873801697797354, 0.022873801697797354, 0.0747934358899426, -0.027459490227564526, -0.009385197473198615, -0.09910148044769543, 0.010145365253768433, 0.14228781148051206, 0.06285582276504374, -0.02581000594818955, 0.0, -0.019041126004519404, -0.04354675351504377, -0.0021482416015433385, -0.030276321676117685, 0.0360131865174545, 0.0, 
0.0, 0.0, 0.007755572932511067, -0.14778915768366419, 0.027008590901417762, 0.1130487155667161, -0.014314851086261541, 0.00921668491550147, 0.07684828337471519, -0.06162766622005462, 0.07101462976811, 0.021963777690990442, 0.01835114588015032, 0.035766599754608046, 0.1438113302006597, 0.02944468449199043, -0.15155567315096075, 0.00921668491550147, 0.00921668491550147, 0.004979796112305482, 0.00921668491550147, -0.24679461938850197, -0.003644005877404959, -0.08997702590501673, -0.09561748452959498, -0.010759836160285361, 0.29083584991272227, 0.10243384616814558, -0.033233058206710256, -0.03524827732112034, 0.093493164930961, 0.08888425763146804, -0.09874764000887767, -0.012236832608212416, 0.009731078532636736, -0.08758072367116448, -0.16872873897630083, 0.06293650615547572, 0.5481291718981843, -0.4031196692576848, 0.3634152922871679, 0.31066355720875155, 0.017213262269086153, -0.33374862692288154, 0.1925356849179622, -0.12376953394558599, 0.18150387499420184, 0.10649484024214091, 0.05481784959173895, 0.09836915879837864, 0.0005820053732746545, 0.19610946707109653, -0.07449789862894246, 0.3967885303111857, 0.27022923008542327, 0.058613734608145104, -0.39855315338038255, -0.025058125892537104, 0.10019129031310162, 0.2795666884387779, 0.30808538264741564, 0.06372330349746459, -0.006654809286403457, -0.013452827315862784, 0.0918060172661081, 0.06594754652000737, -0.10823375598692749, 0.031153112294660962, -0.06430680544631162, 0.010274857972767665, -0.12573827141694804, -0.023106364284111544, -0.22368683536784564, -0.23515889813310845, -0.08060772770242307, 0.13368515146463222, -0.03297797661224561, -0.09537101769356657, 0.03994245935053071, -0.24924276964561043, -0.15454366428444885, -0.004988403463861493, -0.13086364116948832, -0.009628970385086316, -0.10645864333807141, -0.024414437296709376, 0.009813226873127441, 0.09944901717680242, 0.043380950772719244, -0.28548875724966943, -0.11925837498495268, -0.03556378840314326, 0.17115113658109227, 0.1845430708029888, 
-0.013397312597287023, -0.0723490158286244, 0.027077963895942473, -0.07994419087111562, -0.02399300233202436, -0.10415457830524188, -0.04893685269038831, 0.023708548476586235, 0.11319372082980665, -0.08400102937343469, -0.007897371360388879, 0.4884270925582387, -0.18386574649028428, 0.17435916470930699, 0.30418683592087264, -0.03819047017231294, -0.289375499362056, 0.06417339750515609, -0.07403863775598332, -0.03182083284236713, -0.05039612866753495, 0.07314464798756509, -0.02326718672293286, -0.04403638112890262, 0.06810615804085256, -0.025414609662796385, 0.2525216240862521, 0.21094600622750523, -0.07097964402700534, 0.14344841306972067, 0.08037220701614273, -0.014258182628453067, 0.027180227646816376, -0.08064322425116202, 0.011071640940029005, -0.07422640136654333, -0.08187809224763325, -0.0696751946049071, -0.049566410600461984, -0.015256771277733375, 0.041372082161507735, 0.014380098638282454, -0.08339102034824385, 0.004678342936141683, 0.06034274369534799, -0.20398572773315485, 0.03346789078575346, 0.08448970931370499, 0.10360102751035744, -0.17327617270709925, 0.04250960159315527, -0.07871822720125946, 0.03168733889348665, -0.1098972162364184, -0.055922029313491606, -0.016213394143880407, 0.0008549082254816074, -0.014612232110780634, -0.0137641654393248, 0.039648440297828366, 0.10967557693349332, 0.053924422889471316, -0.24825428730348062, -0.25083841403812834, 0.023762013342990922, 0.0022743389774641442, -0.09895554908704991, -0.07656552597466369, -0.20515694285521419, -0.07885621324349082, -0.04475261173099617, 0.020455298666018477, 0.16383687768308702, 0.008537030491324375, -0.3263170456759125, -0.03221211883776766, -0.19606626504288147, 0.14661827963102833, 0.06970887806075653, 0.04978025041362037, -0.015540695573607351, 0.025041763102846493, 0.36181736406462656, 0.03794054976889716, 0.009207690708588538, -0.051774088856603735, 0.12954835696425687, -0.002025246188377725, 0.03065587105678505, -0.01910486718702892, 0.06573582600619128, 
-0.006595075136553358, 0.0085895435708305, 0.06920135172697703, -0.07077048060887235, 0.07404261517254992, 0.00213458818994969, -0.026622064329858725, 0.028156689422497286, -0.01634201048206037, -0.021440780819765002, 0.0019213589609921673, -0.1689538427412878, 0.11973106485266785, -0.10074019798021928, 0.07363790713265034, 0.005654149165750255, 0.1555212498570646, -0.07631656828848457, -0.09957934532165916, -0.06445570441838624, -0.015028627238972537, -0.011255919093238914, -0.02178150579571202, 0.014157193069523486, -0.07889041958578788, -0.10583429548339936, 0.22652236104595674, -0.024223753024859462, -0.021889877986380372, -0.17959916178528182, 0.1397713646529108, -0.009449044849304441, 0.10551325626351296, 0.03354345243347617, 0.12334192881706613, 0.024555755640642357, 0.014886508311545817, -0.0369501045034401, 0.17632221127432, -0.04185104649663768, 0.04249918105180196, 0.05289135108301231, 0.397417044733905, 0.12016065842260486, 0.024212756011152567, 0.039871118163771285, 0.1306391643221717, 0.14972364698915752, -0.017024832163643904, 0.05290388460302188, 0.10578094847424983, 0.15434817341413662, -0.02185746561466995, -0.02960481608665671, -0.024968677310444765, -0.19479297921168215, -0.0006616184235034366, 0.08253087292000301, 0.07685360455707245, 0.01671438730826793, -0.004949949372194439, 0.010120528867834908, -0.06298995748135543, 0.0017318319337784184, 0.06135748393951149, 0.018978604161162885, 0.04605400480303644, -0.006359637644156889, -0.06575510488936771, -0.021718371123120726, 0.00757751128268976, -0.008833793810991652, -0.01723161694062917, 0.07117245652606327, 0.009668949646528159, 0.09547734497827848, 0.04833891165491316, 0.07821589565237123, -0.06687110333815122, 0.01133963829203416, -0.025900381534867095, 0.014236239838760621, -0.004775350591759419, -0.06742741312521551, -0.07132893901368503, 0.0023255569497046356, -0.021731820351383848, 0.01069930481458036, 0.04792602643803743, -0.04889178935123811, 0.009363513504453885, 0.009199423525106292, 
-0.039307425515044694, 0.13869513252815668, 0.005307146307832181, 0.0, 0.0, 0.09491552349612761, 0.11128718654611927, 0.011758740676323568, 0.014423671625220002, 0.006820327755804305, -0.009912023049155732, 0.025644554484963902, 0.02067915571103017, 0.03432102907260143, -0.14257104669358067, 0.040349345249580455, 0.0696051642607851, 0.012103353713688881, -0.16047460036480082, -0.02887766814551416, 0.14783672716120957, 0.14783672716120957, -0.08816483413707722, 0.021575421382941787, 0.2079519416540272, 0.051758348071577004, -0.06674113170463378, -0.05125594549772451, 0.06562249397219584, 0.025869202149968747, -0.016022555050874264, -0.023697070769941795, -0.029262565041622113, -0.013482133408567003, 0.06878684371016794, -0.01336586922303224, 0.029657458030545505, -0.09487121142054954, 0.04279901197649842, 0.051760536010428064, -0.016282032847518904, -0.01624390002680394, 0.027088013949487785, -0.008583502364763036, 0.040673227293763735, 0.015581436276797797, -0.008694982647149465, 0.01419058861560246, -0.05129024525405726, 0.005847898283847335, -0.027670648053899575, 0.02169878116869464, -0.041521305283077625, -0.04069115963070253, -0.013972140331585082, -0.03664332956549042, -0.01668795715477744, 0.12473207713398804, -0.11341097977708085, 0.04578968064295831, 0.0036251836047938527, 0.016726556432396372, 0.06982405735883698, -0.04375767308893363, -0.045396997734330286, -0.010507243836129776, -0.0010684630884352907, 0.0029421552186152697, -0.09241878090359193, 0.029465045159794482, 0.1022928864951515, 0.0048887020522475405, 0.017274029178449786, -0.1294934097903273, -0.019211054682330145, -0.030339058738509745, -0.06322650057569365, 0.0074916111410408956, -0.029306242049970243, -0.0018076829760559593, 0.011613161984500745, -0.016447378724397384, -0.011138911406053902, -0.10750243131664694, 0.0232972313364968, -0.026215986901422173, 0.004732380611276889, 0.018977226370426845, -0.026632797254629137, 0.014853359670983286, -0.006458902588012046, -0.023618677400370023, 
0.13088510070707618, 0.02860674038410909, 0.04100387008676423, -0.02286442237572992, -0.03788695366492616, 0.052561929847997346, -0.019017139794251518, 0.0894472172549519, -0.0004344025407158048, 0.0010247831466642695, -0.025476646217163456, 0.31646142589887516, -0.25464751980451633, -0.08515523700145396, -0.026241626758538594, 0.04412268773555134, 0.20257707708036735, 0.044258829081575954, -0.12589970678456203, -0.07932048129298883, -0.08999903654235446, -0.044826284394636845, 0.06573531045003773, -0.0144534583574492, -0.027322583437023314, -0.03130544526369554, 0.03493570316874078, 0.049413500014619575, 0.026742453254821573, 0.02043163179770762, -0.011447337838467481, 0.0, 0.03030405152491291, 0.07510778792960654, 0.01919147618774618, -0.006621020712517427, -0.01899613658395026, -0.011666061221967715, -0.00474956134348689, -0.049083269911932674, 4.657268086176317e-05, -0.0014268291070105008, 0.010528162968799907, 0.011244743034335173, -0.03786559892380961, -0.0031646634255997463, -0.005856524603728285, 0.007290043266007071, -0.0033716859372763543, -0.01992686473769063, -0.03178032603435096, -0.0065567546708905754, 0.0, 0.0, 0.018639278793212827, 0.03588052134616851, 0.02596166235581439, -0.08523165790251558, -0.009379571072350277, -0.13952188682937972, -0.06509362813873604, 0.023756195419512526, 0.015795337293188903, -0.0071082460508350275, -0.03527507038784272, -0.005184514909385065, -0.01960300712197206, 0.022432837212423548, -0.047293045709524424, 0.08619737351507845, 0.08619737351507845, -0.19204334838701292, 0.0, 0.0, -0.026362294378987502, -0.054294707016437786, 0.0048184609051364746, -0.042288460411455425, -0.04012120076438216, -0.022726980314200295, -0.02371353262015413, -0.0014767914709746257, 0.026308285654373344, -0.02896852574983715, 0.06626617564396646, 0.06854879956770522, 0.12778368038050397, -0.03166080489971032, -0.02978057900864301, -0.01195324180253799, -0.01195324180253799, 0.018626108570973892, 0.0, 0.0453662282157371, -0.14194223812796625, 
0.12094151074018226, -0.03694648027200212, 0.019247640440945825, -0.11842158936059376, 0.04080019835176612, -0.037804600342851886, -0.03973923863436229, 0.06852723699647291, 0.03454714358774346, 0.14190591419313106, 0.019247640440945825, 0.019247640440945825, 0.019247640440945825, 0.01039954451276403, 0.019247640440945825, 0.019247640440945825, -0.12436028806000256, 0.022726945584089706, 0.09512721401919268, -0.012045531840935075, 0.007755572932511067, 0.06466560069211798, -0.051857892985601936, 0.05975675044673969, 0.018481881643739508, 0.015441956796123346, 0.030096555918739623, 0.12101306165048371, 0.024776847656832676, -0.12752970153956233, 0.007755572932511067, 0.007755572932511067, 0.004190353938764285, 0.007755572932511067, -0.0260175085808357, -0.05461718355893401, -0.3044593927311357, -0.16362453625770015, -0.002739235981540115, -0.05637406639356998, -0.004502518535956583, -0.14295507882526984, 0.036656841842191895, -0.09907636873151626, -0.01941986162161012, -0.02241433851680251, -0.10290995260475452, -0.17860900187624984, -0.10869196932539349, -0.015969950891444436, -0.11560851982578027, -0.05615489853948229, 0.08723160402700803, -0.05015302379366111, 0.040648457536661124, 0.15559685898425113, 0.010748338979581029, -0.011548009539377199, -0.02054597274902739, 0.08102287398832242, 0.06221067416118408, -0.07403648421844508, -0.02093350459641128, -0.0946864446550252, 0.03296359381663998, -0.0569955912458018, -0.15958920821061845, -0.078979648991116, 0.07778760427679861, -0.13826369688439727, 0.16684497835769913, 0.013475254469988665, 0.08716206175437495, 0.030360254050106564, -0.09469686973019444, 0.08932671535056191, 0.09911322857606174, -0.046933708925787014, 0.02772005897769178, 0.02837448567803362, 0.042476900250275965, 0.05400132875274487, 0.04570340830854614, -0.0687718463203423, 0.07631188388807757, -0.043345653653779126, 0.026053040474198964, -0.07567722323590645, 0.03213130926487396, -0.12134702677491793, -0.02385329877264947, -0.09538762158612794, 
0.006548146059150682, -0.0390455334242438, -0.052648983551607464, -0.07872220350809091, 0.06225359088769512, -0.014214510042144142, -0.048186508421908074, -0.01151734843593039, 0.03408072901823199, 0.09829730607049575, 0.1122508741785472, -0.002284052274346486, -0.03657692210676226, -0.009640613867486617, -0.002862103207732163, 0.0034160418855789195, -0.0016958277800898103, 0.13744374225737835, 0.12956464509408483, -0.09033990423517552, 0.0017776419743839242, 0.10709703728553202, 0.05587200943477678, 0.04427782201011278, -0.10588320577643205, -0.06693145044872077, 0.09228502199796781, -0.007735965786332623, -0.1202537465918832, -0.07983130058089125, -0.09935521558678541, 0.05681300372036922, 0.07351438372136035, -0.009227325600247657, 0.045728381257034385, 0.0030176171866757883, -0.04575148657198219, -0.04540923228515526, 0.005210113716622661, 0.11811018163544827, -0.03359426394774805, -0.0042531568202319025, -0.06166999441943469, 0.05599739527847818, -0.03219313618086765, 0.0053081841174363814, -0.1158248271061419, -0.003167467124574752, -0.0564654759734514, -0.0013084278270323983, 0.06476435091870113, 0.029577341774073042, 0.1263900798772922, 0.03285119786604272, 0.1342091042343228, -0.1387806035759304, 0.022136889468936678, -0.009669453500485417, -0.017183254929171866, 0.021126455682816132, 0.20901941842634383, -0.01398847189745519, -0.1629606691056826, -0.008070661027421783, 0.006781910327373049, 0.03578253107487891, -0.028249458957329717, 0.059195735029026274, -0.06314572785812621, 0.1025828975337944, -0.0035041087388625863, 0.002661849894978878, 0.0050932170198235895, -0.06319224594116173, -0.00628687474876506, -0.19068159087287015, 0.06925777580549931, -0.002616411780708895, 0.008797405583277983, 0.07915849907213142, -0.12706071579457956, 0.10661518352501373, 0.0647837568764885, 0.061763137207587145, 0.03834649348425925, -0.12868688526280891, -0.03514316282065014, 0.04504166455595559, -0.0373047817083413, -0.10782406503921686, -0.10154591994590548, 
0.014691376665599862, -0.04516943110805597, -0.05309636087918374, -0.0468182157858726, -0.011241763812914975, 0.19797526921817302, 0.06601519332216674, -0.04832943164739122, 0.10321600019336877, -0.08647821487728075, -0.13774620492142797, 0.07766277358191734, 0.13379772984986826, -0.10297980438087927, -0.17598633089807328, -0.0005048977729174961, 0.01337664040630627, -0.013757306317115536, 0.06517560209411426, -0.0028357785812421037, 0.07151552522387951, 0.040321782943748656, -0.07811001201219386, 0.014970789446854491, -0.1010343473298984, 0.018125248593217323, 0.06723576554017571, -0.03569281306421115, -0.07331560533310898, 0.008527529711385748, -0.07407455168471634, -0.013788459189112878, -0.024140579647612986, -0.015058085095765034, -0.007660948128662173, 0.015625772125446816, -0.0044182018921559395, 0.025388936482329018, 0.07151552522388108, 0.040321782943849915, 0.07811001201219722, -0.014970789446855069, 0.10103434732989774, -0.01812524859321612, 0.19878549273185392, -0.06723576554017446, 0.0733156053331123, -0.023965601388981003, -0.008527529711387753, 0.01378845918911312, 0.02414057964761201, 0.01505808509576557, -0.0031071820755258894, -0.011217847955922264, -0.025388936482328425, -0.07151552522354009, -0.040321782943611634, -0.0812451895789809, -0.05335824523826379, -0.08549834592579833, 0.0029710231817100838, -0.0375925937509511, -0.13790384590857527, 0.08620873965249008, 0.15669897643904818, 0.04307231383743708, -0.12954377313379287, 0.13246217086725653, -0.0062470552733099715, 0.010989860492063139, -0.11255249657730061, 0.021263661868520533, -0.0024594600047082274, 0.023820779714444795, 0.07151552522387951, 0.040321782943748656, -0.3428586467070401, -0.06335269458216841, -0.490738267324018, -0.04404317119236691, 0.054256852565768295, 0.13948914939184826, 0.09990681860335321, 0.004277484117364523, 0.1287883771147364, -0.0143057438390212, 0.058935045441565755, 0.0017772265283297608, -0.0035467709131289014, 0.018017432210953616, 0.030494011320987082, 
-0.0011948697158381649, 0.07151552522376337, 0.04032178294374156, 0.690674352211244, 0.12252396100111633, 0.036259043511949196, 0.09422298753300104, -0.2427441046686153, 0.013576658657567733, 0.06212935835231884, 0.010267575684985398, -0.1115088068319613, -0.09736155227758334, -0.04558575031484687, -0.03996659342074583, -0.0014692907752585245, 0.02901351758336788, 0.006352298256041709, -0.07151552522372549, -0.0403217829437984, -0.08124518957895789, -0.8939229818804559, 0.03209048661329847, 0.03209048661329847, 0.03209048661329847, 0.03209048661329847, 0.022873801697797354, 0.022873801697797354, 0.022873801697797354, 0.03209048661329847, 0.03209048661329847, 0.03209048661329847]

len(w8418)

1605

In [38]:
metrics.accuracy(y_train, predict(tx_train_2))

0.841888

In [39]:
# Drop the transformed training matrix and the labels before the test set is
# loaded. `predict` stays usable: it closes over the fitted weights only.
del tx_train_2
del y_train

In [40]:
# Load the test file. The first return value (used as the labels for train.csv)
# is discarded; ids_test is kept for the submission file.
# NOTE(review): `cols_train` here receives the column names of test.csv.
_, x_test, ids_test, cols_train = data_io.load_csv_data(f'{DATA_FILE_PREFIX}test.csv')

In [45]:
import time

# Transform and predict the test set in fixed-size row batches so the expanded
# feature matrix (1605 columns on the training set) is never built for all test
# rows at once. Each batch reuses the transformation state from training.
batch_size = 100000
pred = []
# range() stops before x_test.shape[0], so no empty trailing batch is produced
# when the row count is an exact multiple of batch_size.
for current_ind in range(0, x_test.shape[0], batch_size):
  # Slicing past the end of the array is safe: the last batch is simply shorter.
  x_test_batch = x_test[current_ind:current_ind + batch_size]
  tx_test_batch, _ = transformation_pipeline_median_selected_pairwise(
      x_test_batch, transformation_memory=memory)
  pred.append(predict(tx_test_batch))
  # Release the transformed batch before the next one is created.
  del tx_test_batch

In [46]:
prediction_test = np.concatenate(pred)

In [47]:
# Map the 0/1 predictions to the -1/+1 labels of the submission format.
# Derived from `pred` rather than from `prediction_test` itself so that
# re-running this cell cannot apply the mapping twice.
prediction_test = np.concatenate(pred) * 2 - 1

In [48]:
prediction_test[:5]

array([-1, -1,  1,  1, -1])

In [49]:
x_test.shape

(568238, 30)

In [50]:
prediction_test.shape

(568238,)

In [51]:
# Write the -1/+1 test predictions together with their ids to the submission file.
data_io.create_csv_submission(ids_test, prediction_test, '8418acc_memory.csv')