In [1]:
import pandas as pd, numpy as np
from tqdm.notebook import tqdm
import gc

In [2]:
# ! pip install polars
import polars as pl

In [10]:
# Directory the individual submission CSVs are read from; the ensembled
# submission is written to the same directory.
submissionPath = "../submissions/"
# File stem (without the '.csv' extension) of the ensembled submission.
submissionNotes = "final_1_ensemble_9models"

In [4]:
# File stems (no directory, no '.csv' extension) of the submissions to ensemble.
# The order matters: it is the order in which the frames are loaded and joined.
subNames = [
    'Xgb_top_100covisit_20_20_20_newSuggester2_drop_12_add_last3CovScore',
    'Xgb_top_100_final_3',
    'Xgb_top_100covisit_20_20_20_newSuggester2_addLast3_cfW2vSim_drop_12',
    'Xgb_top_100_suggester_addLast_8',
    'Xgb_top_100covisit_20_20_20_newSuggester2_covWgt_3x_addLast3_cfW2vSim',
    'Xgb_top_100_final_4',
    'Xgb_top_100_suggester_addLast_3',
    'Xgb_top_100_suggester_addLast',
    'final_2_ensemble',
]

In [5]:
def read_sub(path, weight=1):
    '''Load one submission CSV and reshape it to one row per predicted aid.

    The CSV must have a `labels` column holding space-separated aids; every
    other column is carried through unchanged.

    Parameters
    ----------
    path : path of the submission CSV to read.
    weight : vote assigned to every prediction of this submission. With the
        default of 1 for each submission, summing the votes amounts to a
        plain vote ensemble.
        NOTE(review): the weight is cast to UInt8 below, so fractional or
        out-of-range weights would not survive the cast - confirm before
        using weighted votes.

    Returns
    -------
    A polars DataFrame in which `labels` has been exploded and renamed to
    `aid` (UInt32), plus a constant `vote` column (UInt8).
    '''
    raw = pl.read_csv(path)

    # Turn the space-separated label string into a list and tag every row
    # with the vote of this submission.
    tagged = (
        raw
        .with_column(pl.col('labels').str.split(by=' '))
        .with_column(pl.lit(weight).alias('vote'))
    )

    # One row per aid instead of one list per row.
    exploded = tagged.explode('labels').rename({'labels': 'aid'})

    # Narrow unsigned dtypes keep the exploded frames small; memory is the
    # limiting resource once several submissions are loaded and joined.
    return (
        exploded
        .with_column(pl.col('aid').cast(pl.UInt32))
        .with_column(pl.col('vote').cast(pl.UInt8))
    )

In [6]:
# Load every submission named in `subNames`, each with the default vote of 1,
# and preview the first one.
subs = list(map(read_sub, (f'{submissionPath}{subName}.csv' for subName in subNames)))
subs[0].head()

session_type,aid,vote
str,u32,u8
"""12899779_click...",59625,1
"""12899779_click...",737445,1
"""12899779_click...",731692,1
"""12899779_click...",1790770,1
"""12899779_click...",1660529,1


In [7]:
def _join_votes(frames):
    '''Outer-join all submission frames on (session_type, aid).

    Each frame contributes one vote column. The first keeps the name `vote`,
    the second becomes `vote_right`, and the k-th (k >= 3) becomes
    `vote_right{k-1}`, i.e. `vote_right2`, `vote_right3`, ...
    A (session_type, aid) pair missing from a submission gets a null in that
    submission's vote column.

    Works for any number of frames instead of exactly nine.

    Raises
    ------
    TypeError : if `frames` is not a non-empty list/tuple. This typically
        means the cell was re-run after `subs` had already been replaced by
        the joined frame.
    '''
    if not isinstance(frames, (list, tuple)) or len(frames) == 0:
        raise TypeError(
            '`subs` must be the non-empty list of per-submission frames; '
            're-run the cell that loads the submissions before this one'
        )

    joined = frames[0]
    for position, frame in enumerate(frames[1:], start=1):
        # '_right' is what polars appends by default on the first join; later
        # joins need a distinct suffix so the vote columns do not collide.
        vote_suffix = '_right' if position == 1 else f'_right{position}'
        joined = joined.join(frame, how='outer', on=['session_type', 'aid'], suffix=vote_suffix)
    return joined


# NOTE: `subs` is rebound from the list of frames to the single joined frame,
# which also releases the individual frames.
subs = _join_votes(subs)
subs.head()

session_type,aid,vote,vote_right,vote_right2,vote_right3,vote_right4,vote_right5,vote_right6,vote_right7,vote_right8
str,u32,u8,u8,u8,u8,u8,u8,u8,u8,u8
"""12899779_click...",59625,1,1,1,1,1,1,1,1,1
"""12899779_click...",737445,1,1,1,1,1,1,1,1,1
"""12899779_click...",731692,1,1,1,1,1,1,1,1,1
"""12899779_click...",1790770,1,1,1,1,1,1,1,1,1
"""12899779_click...",1253524,1,1,1,1,1,1,1,1,1


In [8]:
# Collapse the per-submission vote columns into a single `vote_sum` and order
# the candidates from most to least voted.
#
# The vote columns are discovered from the frame ('vote', 'vote_right',
# 'vote_right2', ...) rather than listed by hand, so the cell keeps working
# when the number of ensembled submissions changes.
vote_cols = [col for col in subs.columns if col == 'vote' or col.startswith('vote_right')]
if not vote_cols:
    raise ValueError(
        'no vote columns found in `subs`; re-run the loading and joining cells before this one'
    )

subs = (subs
    .fill_null(0)  # a submission that did not predict this aid contributes 0 votes
    # Built-in sum with an explicit start adds the expressions left to right,
    # starting from the first vote column, so the result keeps the columns' dtype.
    # NOTE(review): the votes are UInt8, so a total above 255 would not fit - confirm
    # before ensembling with large weights.
    .with_column(sum((pl.col(col) for col in vote_cols[1:]), pl.col(vote_cols[0])).alias('vote_sum'))
    .drop(vote_cols)
    # Ascending sort followed by a full reversal gives descending vote_sum.
    # NOTE(review): if the sort keeps tied rows in their incoming order, the reversal
    # flips the order of ties as well - confirm this tie-breaking is intended.
    .sort(by='vote_sum')
    .reverse()
)
subs.head()

session_type,aid,vote_sum
str,u32,u8
"""14571581_order...",940217,9
"""14571581_order...",984794,9
"""14571581_order...",1236674,9
"""14571581_order...",1497245,9
"""14571581_order...",1791780,9


In [9]:
%%time
preds = subs.groupby('session_type').agg([
    pl.col('aid').head(20).alias('labels')
])

preds = preds.with_column(pl.col('labels').apply(lambda lst: ' '.join([str(aid) for aid in lst])))

CPU times: user 2min 27s, sys: 4.22 s, total: 2min 31s
Wall time: 1min 50s


In [11]:
# Write the ensembled predictions to <submissionPath><submissionNotes>.csv,
# then display the frame.
preds.write_csv(f'{submissionPath}{submissionNotes}.csv')
preds

session_type,labels
str,str
"""14281146_click...","""1142000 150212..."
"""13254183_carts...","""1476166 243042..."
"""14326586_click...","""171112 1685420..."
"""13964885_click...","""488046 268509 ..."
"""13755465_order...","""596931 1385870..."
"""13183497_carts...","""607652 536592 ..."
"""13636355_click...","""329209 408241 ..."
"""14160771_order...","""1783610 109941..."
"""14240010_carts...","""124383 1532105..."
"""13114197_order...","""897426 1490438..."
