In [60]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to imported .py files are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [70]:
import polars as pl
import torch
from polars import col

## Load and Process Data

In [79]:
# Read only the columns that are used later in the notebook.
articles = pl.read_csv(
    "data/articles.csv",
    columns=["article_id", "garment_group_name", "index_group_name"],
)
# Customers without an age are removed right after loading.
customers = pl.read_csv(
    "data/customers.csv",
    columns=["customer_id", "age", "club_member_status"],
).filter(pl.col("age").is_not_null())
transactions = pl.read_csv(
    "data/transactions_train.csv",
    columns=["customer_id", "article_id", "t_dat", "price"],
)

In [80]:
# (rows, columns) for each table, in load order.
tuple(frame.shape for frame in (articles, customers, transactions))

((105542, 3), (1356119, 3), (31788324, 4))

### Articles

In [81]:
# Show the first row to check column names and dtypes.
articles.head(n=1)

article_id,index_group_name,garment_group_name
i64,str,str
108775015,"""Ladieswear""","""Jersey Basic"""


In [82]:
# Number of distinct values in each articles column.
articles.select(
    [
        col(name).n_unique()
        for name in ("article_id", "garment_group_name", "index_group_name")
    ]
)

article_id,garment_group_name,index_group_name
u32,u32,u32
105542,21,5


### Customers

In [83]:
# Show the first row to check column names and dtypes.
customers.head(n=1)

customer_id,club_member_status,age
str,str,i64
"""00000dbacae5ab…","""ACTIVE""",49


In [84]:
# Number of distinct values in each customers column.
customers.select(
    [
        col(name).n_unique()
        for name in ("customer_id", "club_member_status", "age")
    ]
)

customer_id,club_member_status,age
u32,u32,u32
1356119,4,84


### Transactions

In [85]:
# Show the first row to check column names and dtypes.
transactions.head(n=1)

t_dat,customer_id,article_id,price
str,str,i64,f64
"""2018-09-20""","""000058a12d5b43…",663713001,0.050831


In [86]:
# Number of distinct values in each transactions column.
transactions.select(
    [
        col(name).n_unique()
        for name in ("t_dat", "customer_id", "article_id", "price")
    ]
)

t_dat,customer_id,article_id,price
u32,u32,u32,u32
734,1362281,104547,9857


### Create Queries

In [169]:
# Attach customer and article attributes to every transaction. Both joins are
# inner, so transactions with no matching customer or article row are dropped.
query = (
    transactions
    .join(customers, how="inner", on="customer_id")
    .join(articles, how="inner", on="article_id")
)
query.shape

(31648066, 8)

In [170]:
# Random preview of the joined table; no seed is given, so rows vary per run.
query.sample(n=5)

t_dat,customer_id,article_id,price,club_member_status,age,index_group_name,garment_group_name
str,str,i64,f64,str,i64,str,str
"""2019-10-23""","""a575ce13ded7fb…",759847005,0.027441,"""ACTIVE""",29,"""Ladieswear""","""Jersey Fancy"""
"""2020-04-22""","""4e6d535d689fcd…",750424009,0.042356,"""ACTIVE""",54,"""Ladieswear""","""Trousers"""
"""2019-03-26""","""7c698781597861…",732842002,0.06778,"""ACTIVE""",28,"""Divided""","""Trousers Denim…"
"""2020-02-01""","""9ee4797472de94…",788328003,0.035576,"""ACTIVE""",23,"""Ladieswear""","""Knitwear"""
"""2018-12-19""","""e22e3b44689348…",641008016,0.011847,"""ACTIVE""",34,"""Baby/Children""","""Dresses/Skirts…"


## Retrieval Model

In [173]:
# Distinct customer IDs as a plain Python list.
all_customer_ids = customers.get_column("customer_id").unique().to_list()

In [174]:
from retrieval import ItemTower, QueryTower

# Build the query tower from the list of distinct customer IDs.
query_model = QueryTower(all_customer_ids)

In [175]:
# Draw a small batch of joined rows to smoke-test the query tower.
# The fixed seed keeps the batch identical across Restart & Run All.
query_sample = query.sample(100, seed=0)

# Tower inputs: customer IDs stay a polars Series, ages become a float32 tensor.
query_customer_ids = query_sample["customer_id"]
query_ages = torch.tensor(query_sample["age"].to_list(), dtype=torch.float)

In [176]:
# Forward pass of the query tower on the sampled batch built from `query_sample`.
query_model(query_customer_ids, query_ages)

tensor([[-0.5613,  0.2989,  0.6827,  ..., -1.5160,  0.6059, -0.0585],
        [ 1.3982,  0.3761, -0.6977,  ..., -0.8727,  0.5480, -0.3728],
        [-0.2185, -0.7900, -0.0162,  ..., -0.6052,  0.0153,  0.5145],
        ...,
        [ 0.4666, -0.0780,  0.1500,  ..., -0.1500, -0.1304, -0.1255],
        [-0.9520, -0.4970,  0.3099,  ..., -0.0451,  0.2399,  0.1738],
        [-0.5519, -0.2644,  0.3391,  ...,  0.5079, -0.8182, -1.2151]],
       grad_fn=<AddmmBackward0>)