In [1]:
# Reload imported modules automatically before each cell runs, so edits to local
# source files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Move the working directory one level up from where the notebook started, so the
# relative `.data/...` paths used below resolve.
# A bare `os.chdir('..')` is not idempotent: every re-run of the cell climbs one more
# directory. Remember the starting directory the first time and always change to its
# parent, so re-running this cell in the same kernel is harmless.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [3]:
from typing import Callable

import numpy as np
import polars as pl
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

In [4]:
# Read the raw CSV tables; paths are relative to the current working directory.
# Dates are parsed only for the transactions table (its `t_dat` column).
relations = pl.read_csv(
    source=".data/transactions_train.csv",
    try_parse_dates=True,
)
users = pl.read_csv(source=".data/customers.csv")
items = pl.read_csv(source=".data/articles.csv")
sample_submission = pl.read_csv(source=".data/sample_submission.csv")

In [5]:
# Order transactions chronologically (ascending transaction date).
relations = relations.sort(by="t_dat")

In [6]:
# Derive integer ids for customers and articles from the physical codes of a
# Categorical cast. `article_id` is cast to String first, as in the original cell,
# before it is made Categorical.
relations = relations.with_columns(
    pl.col("customer_id")
    .cast(pl.Categorical)
    .to_physical()
    .alias("customer_id_fact"),
    pl.col("article_id")
    .cast(pl.String)
    .cast(pl.Categorical)
    .to_physical()
    .alias("article_id_fact"),
)

# Sanity check: the codes must be dense, i.e. the number of distinct codes equals
# the largest code plus one.
assert relations["customer_id_fact"].n_unique() == relations["customer_id_fact"].max() + 1
assert relations["article_id_fact"].n_unique() == relations["article_id_fact"].max() + 1

In [7]:
# Split the transaction date into calendar parts (day, month, year) as new columns.
relations = relations.with_columns(
    pl.col("t_dat").dt.day().alias("day"),
    pl.col("t_dat").dt.month().alias("month"),
    pl.col("t_dat").dt.year().alias("year"),
)

# One-row frame holding the (month, year) of the most recent transaction.
# A column-wise `.max()` over `month` and `year` takes each maximum independently,
# which can pair the largest month seen in any year with the latest year and so
# describe a period that does not match the last transaction. Anchor both values
# on the row(s) with the latest `t_dat` instead.
lasts = (
    relations
    .filter(pl.col("t_dat") == pl.col("t_dat").max())
    .select("month", "year")
    .head(1)
)

In [8]:
# Number of transactions per (year, month), keeping the frame's existing row order.
(
    relations
    .group_by(["year", "month"], maintain_order=True)
    .len()
)

year,month,len
i32,i8,u32
2018,9,594776
2018,10,1397040
2018,11,1270619
2018,12,1148827
2019,1,1263471
…,…,…
2020,5,1361815
2020,6,1764507
2020,7,1351502
2020,8,1237192


In [9]:
# Time-based split: everything before 2020-07-01 is training data, everything on or
# after that date is validation data.
train_df = relations.filter(pl.col("t_dat").lt(pl.date(2020, 7, 1)))
valid_df = relations.filter(pl.col("t_dat").ge(pl.date(2020, 7, 1)))

In [10]:
# Preview the training split; the rendered table below is truncated to the first
# and last rows.
train_df

t_dat,customer_id,article_id,price,sales_channel_id,customer_id_fact,article_id_fact,day,month,year
date,str,i64,f64,i64,u32,u32,i8,i8,i32
2018-09-20,"""000058a12d5b43e67d225668fa1f8d…",663713001,0.050831,2,0,0,20,9,2018
2018-09-20,"""000058a12d5b43e67d225668fa1f8d…",541518023,0.030492,2,0,1,20,9,2018
2018-09-20,"""00007d2de826758b65a93dd24ce629…",505221004,0.015237,2,1,2,20,9,2018
2018-09-20,"""00007d2de826758b65a93dd24ce629…",685687003,0.016932,2,1,3,20,9,2018
2018-09-20,"""00007d2de826758b65a93dd24ce629…",685687004,0.016932,2,1,4,20,9,2018
…,…,…,…,…,…,…,…,…,…
2020-06-30,"""fffb2ba21d4a2f5938d5b955662d81…",851010006,0.016932,1,1128213,92319,30,6,2020
2020-06-30,"""fffb2ba21d4a2f5938d5b955662d81…",880238002,0.016932,1,1128213,95295,30,6,2020
2020-06-30,"""fffb2ba21d4a2f5938d5b955662d81…",780297002,0.025407,1,1128213,63616,30,6,2020
2020-06-30,"""fffb2ba21d4a2f5938d5b955662d81…",878794001,0.025407,1,1128213,90777,30,6,2020


In [15]:
# For each customer in the validation split: the list of purchased article ids and
# the number of transaction rows, with groups in order of first appearance.
(
    valid_df
    .group_by(["customer_id_fact"], maintain_order=True)
    .agg(
        [
            pl.col("article_id_fact"),
            pl.len(),
        ]
    )
)

customer_id_fact,article_id_fact,len
u32,list[u32],u32
140340,"[87006, 88658, … 63952]",80
3,"[89468, 86624, … 100494]",10
656209,"[51711, 51711, … 51711]",5
13990,"[91741, 798, … 102717]",61
27154,"[21883, 59769, … 103192]",13
…,…,…
1362280,"[102215, 102661, 68956]",3
347946,"[99220, 99220]",2
1147080,"[85573, 79711, … 96030]",6
957020,"[64373, 101668, … 97478]",5
