In [1]:
import polars as pl
# Directory that holds the candidate submissions to be ensembled.
SUBMIT_DIR = '/root/suzhaopei/otto/submit'

# Submission versions to combine; the trailing number is the score the author
# recorded next to each file (presumably a leaderboard score -- TODO confirm).
SUBMISSION_VERSIONS = (
    'v3',  # 0.578
    'v4',  # 0.576
    'v5',  # 0.577
)

paths = [f'{SUBMIT_DIR}/submission_{version}.csv.gz' for version in SUBMISSION_VERSIONS]

In [4]:
def read_sub(path, weight=1):
    '''Load one submission file and reshape it to one row per predicted aid.

    Parameters
    ----------
    path : str
        Location of a submission CSV (the files listed in `paths` are `.csv.gz`).
        The file must provide a `labels` column holding space-separated aids;
        all other columns (e.g. `session_type`) are passed through unchanged.
    weight : int, default 1
        Vote attached to every prediction of this submission. With the default
        of 1 for each submission the ensemble is a standard vote ensemble.
        The value is stored as `UInt8`, so fractional or out-of-range weights
        do not fit this dtype.

    Returns
    -------
    pl.DataFrame
        The input columns with `labels` exploded and renamed to `aid` (`Int32`),
        plus a `vote` column (`UInt8`) set to `weight` on every row.
    '''
    return (
        pl.read_csv(path)
            # `with_columns` replaces the deprecated `with_column`; passing a list
            # keeps the call valid on both older and newer polars releases.
            .with_columns([
                pl.col('labels').str.split(by=' '),
                pl.lit(weight).cast(pl.UInt8).alias('vote'),
            ])
            .explode('labels')
            .rename({'labels': 'aid'})
            # Int32 instead of the default integer width keeps the exploded frame small;
            # memory matters because every session contributes many rows.
            .with_columns([pl.col('aid').cast(pl.Int32)])
    )

In [6]:
# Load every submission listed in `paths` with the default weight of 1.
# The list gets its own name so that this cell runs on a fresh kernel and can be
# re-run: `subs` is only ever bound to the joined frame, never to the list.
sub_frames = [read_sub(path) for path in paths]

# Outer joins keep every (session_type, aid) pair that appears in at least one
# submission. The clashing `vote` column of the second frame receives the default
# `_right` suffix, the third frame's gets the explicit `_right2` suffix.
subs = (
    sub_frames[0]
        .join(sub_frames[1], how='outer', on=['session_type', 'aid'])
        .join(sub_frames[2], how='outer', on=['session_type', 'aid'], suffix='_right2')
)
subs.head()

session_type,aid,vote,vote_right,vote_right2
str,i32,u8,u8,u8
"""12899779_click...",59625,1,1,1
"""12899779_click...",1253524,1,1,1
"""12899779_click...",737445,1,1,1
"""12899779_click...",731692,1,1,1
"""12899779_click...",1660529,1,1,1


In [7]:
# Collapse the per-submission votes into one score per (session_type, aid) pair.
vote_cols = ['vote', 'vote_right', 'vote_right2']

subs = (subs
    # after the outer joins a pair missing from a submission has a null vote; count it as 0
    .fill_null(0)
    # `with_columns` replaces the deprecated `with_column`
    .with_columns([
        (pl.col('vote') + pl.col('vote_right') + pl.col('vote_right2')).alias('vote_sum')
    ])
    .drop(vote_cols)
    # ascending sort followed by a full reverse puts the highest vote_sum first;
    # kept as two steps so the relative order of tied rows is unchanged
    .sort(by='vote_sum')
    .reverse()
)

subs.head()

session_type,aid,vote_sum
str,i32,u8
"""14571581_order...",1547466,3
"""14571581_order...",1556465,3
"""14571581_order...",1464940,3
"""14571581_order...",984794,3
"""14571581_order...",555996,3


In [8]:
%%time
preds = subs.groupby('session_type').agg([
    pl.col('aid').head(20).alias('labels')
])

preds = preds.with_column(pl.col('labels').apply(lambda lst: ' '.join([str(aid) for aid in lst])))

CPU times: user 5min 55s, sys: 9.38 s, total: 6min 4s
Wall time: 5min 51s


In [10]:
# Display the assembled predictions: one row per session_type with its space-joined `labels` string.
preds

session_type,labels
str,str
"""12937866_order...","""495060 1400323..."
"""14510895_order...","""1023972 142105..."
"""12937920_order...","""1241431 147076..."
"""13056703_carts...","""1371633 537923..."
"""13470369_order...","""1370091 109499..."
"""14475991_click...","""976527 1435567..."
"""14084361_click...","""218523 1692937..."
"""14324893_carts...","""249809 1197666..."
"""14410757_carts...","""1578596 137164..."
"""13691791_order...","""258957 838580 ..."


In [11]:
# Write the ensembled predictions (session_type, labels) to `submission.csv`;
# the relative path is resolved against the current working directory.
preds.write_csv('submission.csv')