In [6]:
%load_ext autoreload
%autoreload 2

In [43]:
import polars as pl
from polars import col
import torch

from torch import optim
from torch.utils.data import DataLoader, Dataset

## Load and Process Data

In [8]:
# Read only the columns used in this notebook from each CSV file.
articles = pl.read_csv(
    "data/articles.csv",
    columns=["article_id", "garment_group_name", "index_group_name"],
)
# Customers whose age is null are dropped right after loading.
customers = pl.read_csv(
    "data/customers.csv",
    columns=["customer_id", "age", "club_member_status"],
).filter(pl.col("age").is_not_null())
transactions = pl.read_csv(
    "data/transactions_train.csv",
    columns=["customer_id", "article_id", "t_dat", "price"],
)

In [9]:
# Show the (rows, columns) shape of each of the three loaded frames.
tuple(frame.shape for frame in (articles, customers, transactions))

((105542, 3), (1356119, 3), (31788324, 4))

### Articles

In [10]:
# Display the first article row together with the column names and dtypes.
articles.head(n=1)

article_id,index_group_name,garment_group_name
i64,str,str
108775015,"""Ladieswear""","""Jersey Basic"""


In [11]:
# Number of distinct values in each article column.
articles.select(
    [
        col(name).n_unique()
        for name in ("article_id", "garment_group_name", "index_group_name")
    ]
)

article_id,garment_group_name,index_group_name
u32,u32,u32
105542,21,5


### Customers

In [12]:
# Display the first customer row together with the column names and dtypes.
customers.head(n=1)

customer_id,club_member_status,age
str,str,i64
"""00000dbacae5ab…","""ACTIVE""",49


In [13]:
# Number of distinct values in each customer column.
customers.select(
    [
        col(name).n_unique()
        for name in ("customer_id", "club_member_status", "age")
    ]
)

customer_id,club_member_status,age
u32,u32,u32
1356119,4,84


### Transactions

In [14]:
# Display the first transaction row together with the column names and dtypes.
transactions.head(n=1)

t_dat,customer_id,article_id,price
str,str,i64,f64
"""2018-09-20""","""000058a12d5b43…",663713001,0.050831


In [15]:
# Number of distinct values in each transaction column.
transactions.select(
    [
        col(name).n_unique()
        for name in ("t_dat", "customer_id", "article_id", "price")
    ]
)

t_dat,customer_id,article_id,price
u32,u32,u32,u32
734,1362281,104547,9857


### Create Queries

In [16]:
# Attach customer and article attributes to every transaction. Both joins are
# inner joins, so a transaction is kept only when its customer_id is present in
# `customers` and its article_id is present in `articles`.
query = (
    transactions
    .join(customers, on="customer_id", how="inner")
    .join(articles, on="article_id", how="inner")
)
query.shape

(31648066, 8)

In [17]:
# Display five randomly drawn rows of the joined frame.
query.sample(n=5)

t_dat,customer_id,article_id,price,club_member_status,age,index_group_name,garment_group_name
str,str,i64,f64,str,i64,str,str
"""2018-09-26""","""2c72d85b7b7d5b…",572928001,0.06778,"""ACTIVE""",53,"""Sport""","""Jersey Fancy"""
"""2018-10-31""","""db5963e9b5c5b6…",530729006,0.025407,"""ACTIVE""",26,"""Ladieswear""","""Socks and Tigh…"
"""2019-01-31""","""21a77eae717280…",355072002,0.005068,"""ACTIVE""",21,"""Divided""","""Jersey Basic"""
"""2019-05-31""","""9e1d4b82072bca…",772794003,0.050831,"""ACTIVE""",46,"""Ladieswear""","""Dresses Ladies…"
"""2019-03-01""","""10fbd650290677…",729936002,0.050831,"""ACTIVE""",66,"""Ladieswear""","""Blouses"""


## Retrieval Model

In [None]:
# NOTE(review): this cell has no execution count and `net` is not assigned in
# any visible cell, so running it as-is would presumably raise NameError.
# `pipe` is also not referenced by any other visible cell — confirm whether
# this cell is still needed.
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Two-step pipeline: a StandardScaler step named 'scale' followed by `net`.
pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

In [72]:
# Draw a fixed-size random subset of the joined frame for quick experiments.
# The sample is seeded so that re-running the notebook yields the same rows
# (and therefore the same tensors and embeddings downstream).
QUERY_SAMPLE_SEED = 42
query_sample = query.sample(100000, seed=QUERY_SAMPLE_SEED)

# Per-feature columns of the sample. Ids and group names stay as polars Series;
# ages are converted to a float32 tensor.
query_customer_ids = query_sample["customer_id"]
query_article_ids = query_sample["article_id"]
query_ages = torch.tensor(query_sample["age"].to_list(), dtype=torch.float)
query_index_groups = query_sample["index_group_name"]
query_garment_groups = query_sample["garment_group_name"]

In [18]:
# Run the query tower on the sampled customer ids and ages, then show the
# size of the result.
# NOTE(review): `query_model` is only constructed in a later cell
# (`QueryTower(...)`), so that cell has to be executed before this one.
query_emb = query_model(query_customer_ids, query_ages)
query_emb.size()

torch.Size([100, 10])

In [19]:
# Run the item tower on the sampled article ids and their group names, then
# show the size of the result.
# NOTE(review): `item_model` is only constructed in a later cell
# (`ItemTower(...)`), so that cell has to be executed before this one.
item_emb = item_model(query_article_ids, query_index_groups, query_garment_groups)
item_emb.size()

torch.Size([100, 10])

In [47]:
# NOTE(review): `scores` is not assigned in any visible cell; it presumably
# comes from a cell that was removed — confirm before relying on this output.
scores.size()

torch.Size([100, 100])

In [164]:
# Distinct values of every categorical feature, passed to the tower
# constructors below. polars' `unique()` does not guarantee an output order,
# so each result is sorted to make the lists identical from run to run.
all_customer_ids = customers["customer_id"].unique().sort().to_list()
# Article ids are sorted numerically first and then converted to strings.
# NOTE(review): `article_id` stays an integer column in `query`; confirm that
# the dataset/model converts ids the same way before looking them up.
all_item_ids = articles["article_id"].unique().sort().cast(pl.Utf8).to_list()
all_index_groups = articles["index_group_name"].unique().sort().to_list()
all_garment_groups = articles["garment_group_name"].unique().sort().to_list()

In [165]:
from retrieval import InteractionDataset, TwoTowerModel, QueryTower, ItemTower

In [166]:
# Build one tower per side from the lists of known ids / group names, then
# wrap both towers in the joint two-tower model.
query_model = QueryTower(all_customer_ids)
item_model = ItemTower(all_item_ids, all_index_groups, all_garment_groups)
model = TwoTowerModel(query_model, item_model)

# Training examples are taken from the joined transaction frame.
dataset = InteractionDataset(query)

# Adam with its default hyperparameters over every parameter of the joint model.
optimizer = torch.optim.Adam(params=model.parameters())

In [167]:
# Serve the dataset in shuffled mini-batches of 512 examples.
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=512,
)

In [169]:
# Train the two-tower model for a fixed number of passes over `dataloader`.
num_epochs = 5

for epoch in range(num_epochs):
    # Put the model in training mode at the start of every epoch.
    model.train()

    # Sum of the per-batch losses seen in this epoch.
    running_loss = 0.0
    for batch in dataloader:
        optimizer.zero_grad()
        customer_ids, article_ids, ages, index_groups, garment_groups = batch
        # The model's forward pass returns the loss for the batch directly.
        loss = model(customer_ids, article_ids, ages, index_groups, garment_groups)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Report the mean loss per batch. The divisor is the number of batches in
    # `dataloader`; the previous code used the undefined name `train_loader`,
    # which raised NameError as soon as the first epoch finished.
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.float32
TwoTower: torch.

KeyboardInterrupt: 