In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from tqdm import tqdm

import polars as pl
from polars import col

import torch
from torch import optim
from torch.utils.data import DataLoader, Dataset

from retrieval import InteractionDataset, TwoTowerModel, QueryTower, ItemTower

In [3]:
USE_MPS = True

In [4]:
if USE_MPS and torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    print ("MPS device not found or disabled.")
    device = torch.device("cpu")

## Load and Process Data

In [5]:
articles = pl.read_csv("data/articles.csv", columns=["article_id", "garment_group_name", "index_group_name"])
customers = pl.read_csv("data/customers.csv", columns=["customer_id", "age", "club_member_status"]).filter(col("age").is_not_null())
transactions = pl.read_csv("data/transactions_train.csv", columns=["customer_id", "article_id", "t_dat", "price"])

In [6]:
articles.shape, customers.shape, transactions.shape

((105542, 3), (1356119, 3), (31788324, 4))

### Articles

In [7]:
articles.head(1)

article_id,index_group_name,garment_group_name
i64,str,str
108775015,"""Ladieswear""","""Jersey Basic"""


In [8]:
articles.select([
    col("article_id").n_unique(),
    col("garment_group_name").n_unique(),
    col("index_group_name").n_unique(),
])

article_id,garment_group_name,index_group_name
u32,u32,u32
105542,21,5


### Customers

In [9]:
customers.head(1)

customer_id,club_member_status,age
str,str,i64
"""00000dbacae5ab…","""ACTIVE""",49


In [10]:
customers.select([
    col("customer_id").n_unique(),
    col("club_member_status").n_unique(),
    col("age").n_unique(),
])

customer_id,club_member_status,age
u32,u32,u32
1356119,4,84


### Transactions

In [11]:
transactions.head(1)

t_dat,customer_id,article_id,price
str,str,i64,f64
"""2018-09-20""","""000058a12d5b43…",663713001,0.050831


In [12]:
transactions.select([
    col("t_dat").n_unique(),
    col("customer_id").n_unique(),
    col("article_id").n_unique(),
    col("price").n_unique()
])

t_dat,customer_id,article_id,price
u32,u32,u32,u32
734,1362281,104547,9857


### Create Queries

In [None]:
query = transactions.join(customers, on="customer_id", how="inner").join(articles, on="article_id", how="inner").with_columns([
    col("article_id").cast(pl.Utf8)
]).sample(10000)

query.shape

In [14]:
query.sample(5)

t_dat,customer_id,article_id,price,club_member_status,age,index_group_name,garment_group_name
str,str,i64,f64,str,i64,str,str
"""2020-06-20""","""40b2b9014d7423…",866837004,0.022017,"""ACTIVE""",34,"""Ladieswear""","""Jersey Fancy"""
"""2020-01-01""","""ccdea851a6bddf…",717773001,0.013542,"""ACTIVE""",29,"""Ladieswear""","""Accessories"""
"""2020-06-17""","""885af1de025c20…",745219002,0.010831,"""ACTIVE""",20,"""Divided""","""Blouses"""
"""2020-01-08""","""47685766345f2d…",568601007,0.050831,"""ACTIVE""",28,"""Ladieswear""","""Dressed"""
"""2018-10-10""","""cca6a80fc7a233…",657395002,0.027102,"""ACTIVE""",25,"""Ladieswear""","""Blouses"""


## Retrieval Model

In [15]:
all_customer_ids = customers["customer_id"].unique().to_list()
all_item_ids = articles["article_id"].unique().cast(pl.Utf8).to_list()
all_index_groups = articles["index_group_name"].unique().to_list()
all_garment_groups = articles["garment_group_name"].unique().to_list()

In [16]:
query_model = QueryTower(all_customer_ids, device).to(device)
item_model = ItemTower(all_item_ids, all_index_groups, all_garment_groups, device).to(device)
model = TwoTowerModel(query_model, item_model, mps_device).to(device)
dataset = InteractionDataset(query, device)
optimizer = optim.Adam(model.parameters())

NameError: name 'mps_device' is not defined

In [None]:
dataloader = DataLoader(dataset, batch_size=512, shuffle=True)

In [None]:
num_epochs = 100

for epoch in range(num_epochs):
    model.train()

    running_loss = 0.0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        customer_ids, article_ids, ages, index_groups, garment_groups = batch
        loss = model(customer_ids, article_ids, ages, index_groups, garment_groups)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader)}")