In [1]:
%load_ext autoreload
%autoreload 2

In [113]:
from tqdm import tqdm

import polars as pl
from polars import col

import torch
from torch import optim
from torch.utils.data import DataLoader, Dataset

from retrieval import InteractionDataset, TwoTowerModel, QueryTower, ItemTower

In [45]:
import jupyter_black

jupyter_black.load()

In [47]:
USE_MPS = False

In [48]:
if USE_MPS and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print("MPS device not found or disabled.")
    device = torch.device("cpu")

MPS device not found or disabled.


## Load and Process Data

In [22]:
articles = pl.read_csv(
    "data/articles.csv",
    columns=["article_id", "garment_group_name", "index_group_name"],
)

customers = pl.read_csv(
    "data/customers.csv", columns=["customer_id", "age", "club_member_status"]
).filter(col("age").is_not_null())

transactions = pl.read_csv(
    "data/transactions_train.csv",
    columns=["customer_id", "article_id", "t_dat", "price"],
)

In [23]:
articles.shape, customers.shape, transactions.shape

((105542, 3), (1356119, 3), (31788324, 4))

### Articles

In [24]:
articles.head(1)

article_id,index_group_name,garment_group_name
i64,str,str
108775015,"""Ladieswear""","""Jersey Basic"""


In [25]:
articles.select(
    [
        col("article_id").n_unique(),
        col("garment_group_name").n_unique(),
        col("index_group_name").n_unique(),
    ]
)

article_id,garment_group_name,index_group_name
u32,u32,u32
105542,21,5


### Customers

In [26]:
customers.head(1)

customer_id,club_member_status,age
str,str,i64
"""00000dbacae5ab…","""ACTIVE""",49


In [27]:
customers.select(
    [
        col("customer_id").n_unique(),
        col("club_member_status").n_unique(),
        col("age").n_unique(),
    ]
)

customer_id,club_member_status,age
u32,u32,u32
1356119,4,84


### Transactions

In [28]:
transactions.head(1)

t_dat,customer_id,article_id,price
str,str,i64,f64
"""2018-09-20""","""000058a12d5b43…",663713001,0.050831


In [29]:
transactions.select(
    [
        col("t_dat").n_unique(),
        col("customer_id").n_unique(),
        col("article_id").n_unique(),
        col("price").n_unique(),
    ]
)

t_dat,customer_id,article_id,price
u32,u32,u32,u32
734,1362281,104547,9857


### Create Queries

In [49]:
customer_id_encoding = (
    customers[["customer_id"]].unique().with_row_index(name="encoded_customer_id")
)

article_id_encoding = (
    articles[["article_id"]].unique().with_row_index(name="encoded_article_id")
)

index_group_name_encoding = (
    articles[["index_group_name"]]
    .unique()
    .with_row_index(name="encoded_index_group_name")
)

garment_group_name_encoding = (
    articles[["garment_group_name"]]
    .unique()
    .with_row_index(name="encoded_garment_group_name")
)

In [114]:
num_customer_ids = customer_id_encoding.shape[0]
num_article_ids = article_id_encoding.shape[0]
num_index_group_names = index_group_name_encoding.shape[0]
num_garment_group_names = garment_group_name_encoding.shape[0]

In [115]:
query = (
    transactions.join(customers, on="customer_id", how="inner")
    .join(articles, on="article_id", how="inner")
    .join(customer_id_encoding, on="customer_id", how="left")
    .join(article_id_encoding, on="article_id", how="left")
    .join(index_group_name_encoding, on="index_group_name", how="left")
    .join(garment_group_name_encoding, on="garment_group_name", how="left")
    .sample(10000)
)

query.shape

(10000, 12)

In [116]:
query.sample(5)

t_dat,customer_id,article_id,price,club_member_status,age,index_group_name,garment_group_name,encoded_customer_id,encoded_article_id,encoded_index_group_name,encoded_garment_group_name
str,str,i64,f64,str,i64,str,str,u32,u32,u32,u32
"""2020-09-07""","""51176ae62fe2e2…",828991003,0.033203,"""ACTIVE""",28,"""Ladieswear""","""Dresses Ladies…",991468,6659,4,13
"""2020-07-12""","""a5c972f559812a…",818768001,0.033881,"""ACTIVE""",24,"""Divided""","""Jersey Basic""",983166,96842,1,16
"""2018-12-23""","""d9d2b8aba219f7…",692586001,0.042356,"""ACTIVE""",49,"""Divided""","""Outdoor""",8521,23564,1,5
"""2019-05-26""","""29e76dcc1f2d48…",675270001,0.030492,"""ACTIVE""",28,"""Divided""","""Unknown""",444312,102269,1,4
"""2019-02-22""","""73dffe80e941a9…",611146010,0.005746,"""ACTIVE""",30,"""Baby/Children""","""Knitwear""",672413,46226,3,10


## Retrieval Model

In [117]:
# fmt: off

query_model = QueryTower(num_customer_ids, device).to(device)
item_model = ItemTower(num_article_ids, num_index_group_names, num_garment_group_names, device).to(device)
model = TwoTowerModel(query_model, item_model, device).to(device)
dataset = InteractionDataset(query, device)
optimizer = optim.Adam(model.parameters())

# fmt: on

In [118]:
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)

In [119]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()

    running_loss = 0.0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        customer_ids, article_ids, ages, index_groups, garment_groups = batch        
        loss = model(customer_ids, article_ids, ages, index_groups, garment_groups)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")

100%|████████████████████████████████████| 20/20 [00:00<00:00, 24.53it/s]


Epoch [1/100], Loss: 6.245253396034241


100%|████████████████████████████████████| 20/20 [00:00<00:00, 28.55it/s]


Epoch [2/100], Loss: 6.2249671697616575


100%|████████████████████████████████████| 20/20 [00:00<00:00, 28.72it/s]


Epoch [3/100], Loss: 6.215025877952575


100%|████████████████████████████████████| 20/20 [00:00<00:00, 28.81it/s]


Epoch [4/100], Loss: 6.2080933332443236


100%|████████████████████████████████████| 20/20 [00:00<00:00, 28.94it/s]


Epoch [5/100], Loss: 6.202776575088501


100%|████████████████████████████████████| 20/20 [00:00<00:00, 28.91it/s]


Epoch [6/100], Loss: 6.198720550537109


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.80it/s]


Epoch [7/100], Loss: 6.194621324539185


100%|████████████████████████████████████| 20/20 [00:00<00:00, 28.90it/s]


Epoch [8/100], Loss: 6.19010009765625


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.42it/s]


Epoch [9/100], Loss: 6.185738158226013


100%|████████████████████████████████████| 20/20 [00:00<00:00, 28.66it/s]


Epoch [10/100], Loss: 6.180938386917115


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.55it/s]


Epoch [11/100], Loss: 6.175008034706115


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.79it/s]


Epoch [12/100], Loss: 6.168134808540344


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.62it/s]


Epoch [13/100], Loss: 6.160735607147217


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.19it/s]


Epoch [14/100], Loss: 6.151558685302734


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.22it/s]


Epoch [15/100], Loss: 6.141666531562805


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.68it/s]


Epoch [16/100], Loss: 6.130298495292664


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.49it/s]


Epoch [17/100], Loss: 6.1165199995040895


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.52it/s]


Epoch [18/100], Loss: 6.102178049087525


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.15it/s]


Epoch [19/100], Loss: 6.085036396980286


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.54it/s]


Epoch [20/100], Loss: 6.06866819858551


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.75it/s]


Epoch [21/100], Loss: 6.047522163391113


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.42it/s]


Epoch [22/100], Loss: 6.027482438087463


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.91it/s]


Epoch [23/100], Loss: 6.005876755714416


100%|████████████████████████████████████| 20/20 [00:00<00:00, 27.92it/s]


Epoch [24/100], Loss: 5.983120465278626


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.87it/s]


Epoch [25/100], Loss: 5.960227704048156


100%|████████████████████████████████████| 20/20 [00:00<00:00, 26.96it/s]


Epoch [26/100], Loss: 5.933237385749817


 20%|███████▍                             | 4/20 [00:00<00:00, 24.97it/s]


KeyboardInterrupt: 