In [1]:
%load_ext autoreload
%autoreload 2

In [63]:
from tqdm import tqdm

import polars as pl
from polars import col

import torch
from torch import optim
from torch.utils.data import DataLoader, Dataset

from retrieval import InteractionDataset, TwoTowerModel, QueryTower, ItemTower

In [45]:
import jupyter_black

jupyter_black.load()

In [47]:
USE_MPS = False

In [48]:
if USE_MPS and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print("MPS device not found or disabled.")
    device = torch.device("cpu")

MPS device not found or disabled.


## Load and Process Data

In [22]:
articles = pl.read_csv(
    "data/articles.csv",
    columns=["article_id", "garment_group_name", "index_group_name"],
)

customers = pl.read_csv(
    "data/customers.csv", columns=["customer_id", "age", "club_member_status"]
).filter(col("age").is_not_null())

transactions = pl.read_csv(
    "data/transactions_train.csv",
    columns=["customer_id", "article_id", "t_dat", "price"],
)

In [23]:
articles.shape, customers.shape, transactions.shape

((105542, 3), (1356119, 3), (31788324, 4))

### Articles

In [24]:
articles.head(1)

article_id,index_group_name,garment_group_name
i64,str,str
108775015,"""Ladieswear""","""Jersey Basic"""


In [25]:
articles.select(
    [
        col("article_id").n_unique(),
        col("garment_group_name").n_unique(),
        col("index_group_name").n_unique(),
    ]
)

article_id,garment_group_name,index_group_name
u32,u32,u32
105542,21,5


### Customers

In [26]:
customers.head(1)

customer_id,club_member_status,age
str,str,i64
"""00000dbacae5ab…","""ACTIVE""",49


In [27]:
customers.select(
    [
        col("customer_id").n_unique(),
        col("club_member_status").n_unique(),
        col("age").n_unique(),
    ]
)

customer_id,club_member_status,age
u32,u32,u32
1356119,4,84


### Transactions

In [28]:
transactions.head(1)

t_dat,customer_id,article_id,price
str,str,i64,f64
"""2018-09-20""","""000058a12d5b43…",663713001,0.050831


In [29]:
transactions.select(
    [
        col("t_dat").n_unique(),
        col("customer_id").n_unique(),
        col("article_id").n_unique(),
        col("price").n_unique(),
    ]
)

t_dat,customer_id,article_id,price
u32,u32,u32,u32
734,1362281,104547,9857


### Create Queries

In [49]:
customer_id_encoding = (
    customers[["customer_id"]].unique().with_row_index(name="encoded_customer_id")
)

article_id_encoding = (
    articles[["article_id"]].unique().with_row_index(name="encoded_article_id")
)

index_group_name_encoding = (
    articles[["index_group_name"]]
    .unique()
    .with_row_index(name="encoded_index_group_name")
)

garment_group_name_encoding = (
    articles[["garment_group_name"]]
    .unique()
    .with_row_index(name="encoded_garment_group_name")
)

In [54]:
num_customer_ids = customer_id_encoding.shape[0]
num_article_ids = article_id_encoding.shape[0]
num_index_group_names = index_group_name_encoding.shape[0]
num_garment_group_names = garment_group_name_encoding.shape[0]

In [51]:
query = (
    transactions.join(customers, on="customer_id", how="inner")
    .join(articles, on="article_id", how="inner")
    .join(customer_id_encoding, on="customer_id", how="left")
    .join(article_id_encoding, on="article_id", how="left")
    .join(index_group_name_encoding, on="index_group_name", how="left")
    .join(garment_group_name_encoding, on="garment_group_name", how="left")
    .sample(10000)
)

query.shape

(10000, 12)

In [52]:
query.sample(5)

t_dat,customer_id,article_id,price,club_member_status,age,index_group_name,garment_group_name,encoded_customer_id,encoded_article_id,encoded_index_group_name,encoded_garment_group_name
str,str,i64,f64,str,i64,str,str,u32,u32,u32,u32
"""2019-07-28""","""1a9280dddd1da2…",766439002,0.050831,"""ACTIVE""",56,"""Ladieswear""","""Dresses Ladies…",1259892,78417,4,13
"""2019-02-04""","""0465527224bb14…",621467001,0.035576,"""ACTIVE""",50,"""Ladieswear""","""Trousers""",191546,80835,4,11
"""2018-12-30""","""af12fc9321af42…",179950001,0.025407,"""ACTIVE""",59,"""Ladieswear""","""Accessories""",784371,65801,4,12
"""2020-06-04""","""3d9c0869c9b8a5…",852672001,0.059305,"""ACTIVE""",48,"""Ladieswear""","""Blouses""",1338757,31545,4,1
"""2020-03-22""","""4eb095226a3a5c…",851400008,0.059305,"""ACTIVE""",35,"""Ladieswear""","""Skirts""",359152,12296,4,18


## Retrieval Model

In [64]:
# fmt: off

query_model = QueryTower(num_customer_ids, device).to(device)
item_model = ItemTower(num_article_ids, num_index_group_names, num_garment_group_names, device).to(device)
model = TwoTowerModel(query_model, item_model, device).to(device)
dataset = InteractionDataset(query, device)
optimizer = optim.Adam(model.parameters())

# fmt: on

In [65]:
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)

In [66]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()

    running_loss = 0.0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        customer_ids, article_ids, ages, index_groups, garment_groups = batch
        loss = model(customer_ids, article_ids, ages, index_groups, garment_groups)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.21it/s]


Epoch [1/100], Loss: 6.244333624839783


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.78it/s]


Epoch [2/100], Loss: 6.222913551330566


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.83it/s]


Epoch [3/100], Loss: 6.212193608283997


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.80it/s]


Epoch [4/100], Loss: 6.205841612815857


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.75it/s]


Epoch [5/100], Loss: 6.1996814727783205


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.80it/s]


Epoch [6/100], Loss: 6.194971656799316


100%|████████████████████████████████████| 20/20 [00:00<00:00, 25.70it/s]


Epoch [7/100], Loss: 6.190672469139099


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.32it/s]


Epoch [8/100], Loss: 6.185712099075317


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.44it/s]


Epoch [9/100], Loss: 6.180507349967956


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.28it/s]


Epoch [10/100], Loss: 6.174689292907715


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.76it/s]


Epoch [11/100], Loss: 6.168560147285461


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.81it/s]


Epoch [12/100], Loss: 6.160663723945618


100%|████████████████████████████████████| 20/20 [00:00<00:00, 25.48it/s]


Epoch [13/100], Loss: 6.152734708786011


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.67it/s]


Epoch [14/100], Loss: 6.142882490158081


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.12it/s]


Epoch [15/100], Loss: 6.132381463050843


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.14it/s]


Epoch [16/100], Loss: 6.119173288345337


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.51it/s]


Epoch [17/100], Loss: 6.105819416046143


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.77it/s]


Epoch [18/100], Loss: 6.090452194213867


100%|████████████████████████████████████| 20/20 [00:00<00:00, 24.98it/s]


Epoch [19/100], Loss: 6.07329363822937


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.97it/s]


Epoch [20/100], Loss: 6.054527974128723


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.42it/s]


Epoch [21/100], Loss: 6.035891532897949


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.48it/s]


Epoch [22/100], Loss: 6.014799189567566


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.03it/s]


Epoch [23/100], Loss: 5.99369068145752


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.38it/s]


Epoch [24/100], Loss: 5.968440818786621


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.16it/s]


Epoch [25/100], Loss: 5.9432446479797365


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.46it/s]


Epoch [26/100], Loss: 5.918284058570862


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.08it/s]


Epoch [27/100], Loss: 5.892315244674682


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.78it/s]


Epoch [28/100], Loss: 5.865921831130981


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.68it/s]


Epoch [29/100], Loss: 5.838477444648743


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.62it/s]


Epoch [30/100], Loss: 5.808508586883545


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.25it/s]


Epoch [31/100], Loss: 5.781992173194885


 70%|█████████████████████████▏          | 14/20 [00:00<00:00, 23.95it/s]


KeyboardInterrupt: 