In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports all
# modules before each cell runs, so edits to local project code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [16]:
import polars as pl
from polars import col
import torch

## Load and Process Data

In [3]:
# Read only the columns this notebook uses from each raw CSV.
articles = pl.read_csv(
    "data/articles.csv",
    columns=["article_id", "garment_group_name", "index_group_name"],
)

# Customers with a missing age are discarded immediately after loading.
customers = pl.read_csv(
    "data/customers.csv",
    columns=["customer_id", "age", "club_member_status"],
).filter(col("age").is_not_null())

transactions = pl.read_csv(
    "data/transactions_train.csv",
    columns=["customer_id", "article_id", "t_dat", "price"],
)

In [4]:
# (rows, columns) of the three loaded frames, in the order articles, customers, transactions.
tuple(frame.shape for frame in (articles, customers, transactions))

((105542, 3), (1356119, 3), (31788324, 4))

### Articles

In [5]:
# Peek at the first article row to check column names and dtypes.
articles.head(1)

article_id,index_group_name,garment_group_name
i64,str,str
108775015,"""Ladieswear""","""Jersey Basic"""


In [6]:
# Number of distinct values in each article column.
articles.select(
    [
        col(name).n_unique()
        for name in ("article_id", "garment_group_name", "index_group_name")
    ]
)

article_id,garment_group_name,index_group_name
u32,u32,u32
105542,21,5


### Customers

In [7]:
# Peek at the first customer row to check column names and dtypes.
customers.head(1)

customer_id,club_member_status,age
str,str,i64
"""00000dbacae5ab…","""ACTIVE""",49


In [8]:
# Number of distinct values in each customer column.
customers.select(
    [
        col(name).n_unique()
        for name in ("customer_id", "club_member_status", "age")
    ]
)

customer_id,club_member_status,age
u32,u32,u32
1356119,4,84


### Transactions

In [9]:
# Peek at the first transaction row to check column names and dtypes.
transactions.head(1)

t_dat,customer_id,article_id,price
str,str,i64,f64
"""2018-09-20""","""000058a12d5b43…",663713001,0.050831


In [10]:
# Number of distinct values in each transaction column.
transactions.select(
    [
        col(name).n_unique()
        for name in ("t_dat", "customer_id", "article_id", "price")
    ]
)

t_dat,customer_id,article_id,price
u32,u32,u32,u32
734,1362281,104547,9857


### Create Queries

In [11]:
# Enrich every transaction with its customer and article attributes.
# Both joins are inner joins, so a transaction is kept only if its customer
# and its article are present in the respective lookup frames.
query = (
    transactions
    .join(customers, on="customer_id", how="inner")
    .join(articles, on="article_id", how="inner")
)
query.shape

(31648066, 8)

In [12]:
# Preview a few random joined rows. The seed is pinned so the preview shown
# here is the same on every Restart & Run All.
query.sample(n=5, seed=0)

t_dat,customer_id,article_id,price,club_member_status,age,index_group_name,garment_group_name
str,str,i64,f64,str,i64,str,str
"""2020-04-18""","""229010662b2cf3…",612800018,0.033017,"""ACTIVE""",56,"""Ladieswear""","""Blouses"""
"""2019-08-08""","""430ab7dff6da7a…",694119001,0.06778,"""ACTIVE""",27,"""Menswear""","""Outdoor"""
"""2020-07-27""","""1585f1b9e407e6…",717490010,0.008458,"""PRE-CREATE""",33,"""Divided""","""Jersey Basic"""
"""2018-11-07""","""299a10a292c3a6…",399223025,0.028797,"""ACTIVE""",28,"""Divided""","""Trousers Denim…"
"""2020-04-03""","""60c70dc4d2d0ab…",887770002,0.016932,"""ACTIVE""",47,"""Ladieswear""","""Blouses"""


## Retrieval Model

In [13]:
# Distinct values of each categorical feature, materialised as Python lists;
# these are handed to the tower constructors below.
all_customer_ids = customers.get_column("customer_id").unique().to_list()
all_item_ids = articles.get_column("article_id").unique().to_list()
all_index_groups = articles.get_column("index_group_name").unique().to_list()
all_garment_groups = articles.get_column("garment_group_name").unique().to_list()

In [14]:
from retrieval import QueryTower, ItemTower

# Instantiate the two towers from the vocabularies built above: the query
# tower receives the customer ids, the item tower receives the article ids
# plus the index-group and garment-group vocabularies.
# NOTE(review): the construction order is left untouched. If the towers draw
# random initial weights (presumably they do; confirm in the `retrieval`
# module), swapping these lines would change the numbers below.
query_model = QueryTower(all_customer_ids)
item_model = ItemTower(all_item_ids, all_index_groups, all_garment_groups)

In [17]:
# Draw a small batch of joined transactions to exercise the towers.
# The seed is fixed so the same 100 rows are selected after every kernel
# restart; without it the batch (and anything computed from it) changes
# from run to run.
query_sample = query.sample(n=100, seed=42)

# Query-side inputs: customer id plus age, with age converted to a float tensor.
query_customer_ids = query_sample["customer_id"]
query_ages = torch.tensor(query_sample["age"].to_list(), dtype=torch.float)

# Item-side inputs taken from the same rows, so row i of the query inputs and
# row i of the item inputs belong to the same transaction.
query_article_ids = query_sample["article_id"]
query_index_groups = query_sample["index_group_name"]
query_garment_groups = query_sample["garment_group_name"]

In [18]:
# Embed the sampled customers (id + age) with the query tower and check the
# resulting shape: one row per sampled transaction.
query_emb = query_model(query_customer_ids, query_ages)
query_emb.shape

torch.Size([100, 10])

In [19]:
# Embed the purchased articles (id + index group + garment group) with the
# item tower; the output shape matches the query embeddings, so the two can
# be compared with a dot product.
item_emb = item_model(query_article_ids, query_index_groups, query_garment_groups)
item_emb.shape

torch.Size([100, 10])

In [30]:
# Shape of the item embeddings after swapping the two axes, i.e. the layout
# needed on the right-hand side of the query/item matrix product.
item_emb.transpose(0, 1).shape

torch.Size([10, 100])

In [36]:
# Number of rows (batch entries) on the query side and on the item side.
num_queries, num_items = query_emb.size(0), item_emb.size(0)

In [32]:
# Dot-product similarity of every query with every item in the batch:
# scores[i, j] compares query i against item j.
scores = query_emb @ item_emb.t()
scores

tensor([[-0.0370, -0.8167, -0.0748,  ..., -0.3836,  0.5784,  0.4594],
        [-0.5915, -1.7951, -0.1887,  ..., -0.7878,  0.2614,  0.3080],
        [-0.2042, -0.2387, -0.4825,  ..., -0.0380,  0.6027,  0.3154],
        ...,
        [-0.0509, -0.1221, -0.0986,  ..., -0.1526, -0.0404, -0.0020],
        [-0.2119, -0.8352, -0.1478,  ..., -0.1535,  0.2651,  0.4234],
        [-0.3182, -0.9747, -0.2806,  ..., -0.2679,  0.3345,  0.4815]],
       grad_fn=<MmBackward0>)

In [47]:
# Sanity check: the score matrix has one row per query and one column per item.
scores.shape

torch.Size([100, 100])

In [38]:
# In-batch targets. Query i and item i come from the same sampled
# transaction, so the positive for row i sits on the diagonal and the target
# matrix is the identity.
labels = torch.eye(n=num_queries, m=num_items)
labels

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 0.,  ..., 0., 0., 0.],
        [0., 0., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 0.],
        [0., 0., 0.,  ..., 0., 0., 1.]])

In [46]:
# Sanity check: the targets have the same shape as the score matrix.
labels.shape

torch.Size([100, 100])

In [42]:
import torch.nn as nn

# Scratch tensors of shape (1, 1024, 15): `t` is all ones, `o` is drawn from
# a standard normal. In the visible notebook they are only used for shape checks.
t = torch.ones((1, 1024, 15))
o = torch.randn(t.shape)

In [48]:
# Softmax cross-entropy, averaged over the batch (the default reduction, spelled out here).
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [49]:
# Cross-entropy of each score row against the matching row of `labels`.
# Because `labels` is a float matrix with the same shape as `scores`, its rows
# are interpreted as class-probability targets.
loss_fn(input=scores, target=labels)

tensor(4.6964, grad_fn=<DivBackward1>)

In [44]:
# Shape check of the all-ones scratch tensor.
t.shape

torch.Size([1, 1024, 15])

In [45]:
# Shape check of the random-normal scratch tensor.
o.shape

torch.Size([1, 1024, 15])

In [41]:
# Column index of the positive item for each query; since `labels` is the
# identity this is simply 0, 1, ..., N-1.
labels.argmax(dim=1)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89,
        90, 91, 92, 93, 94, 95, 96, 97, 98, 99])

In [53]:
# Row-wise log-probabilities: log-softmax over the item axis for every query.
scores.log_softmax(dim=1)

tensor([[-4.6404, -5.4200, -4.6782,  ..., -4.9870, -4.0250, -4.1440],
        [-4.8047, -6.0083, -4.4020,  ..., -5.0010, -3.9519, -3.9053],
        [-4.7517, -4.7862, -5.0300,  ..., -4.5854, -3.9448, -4.2321],
        ...,
        [-4.6033, -4.6745, -4.6510,  ..., -4.7050, -4.5928, -4.5544],
        [-4.7564, -5.3798, -4.6924,  ..., -4.6981, -4.2794, -4.1212],
        [-4.9530, -5.6095, -4.9154,  ..., -4.9027, -4.3003, -4.1533]],
       grad_fn=<LogSoftmaxBackward0>)

In [58]:
from retrieval import TwoTowerModel

In [59]:
# Combine the query tower and the item tower into a single model object.
# NOTE(review): what the combined forward pass computes is defined in the
# `retrieval` module and is not visible here.
model = TwoTowerModel(query_model, item_model)

In [60]:
# Run the combined model on the sampled batch. The scalar it returns equals
# the cross-entropy obtained from the manual scores/labels computation in the
# earlier cells, so the model presumably computes the same in-batch loss
# internally (confirm in the `retrieval` module).
model(query_customer_ids, query_ages, query_article_ids, query_index_groups, query_garment_groups)

tensor(4.6964, grad_fn=<NllLossBackward0>)