In [1]:
cd ../

/Users/hoangle/Projects/recsys


In [2]:
from typing import Literal
from pathlib import Path
from datetime import datetime
import sys

import polars as pl
import torch
import torch.nn as nn
# import numpy as np
import torch.nn.functional as F
import lightning as L
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.nn import Module
from torch import Tensor
from polars import DataFrame
from loguru import logger
from lightning.pytorch.callbacks import RichProgressBar, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
# from sklearn import metrics

In [3]:
logger.remove()
logger.add(sys.stdout, level="INFO")

1

# Read splits

In [4]:
path = Path("data/processed/ml-1m/train_temporal-loo.parquet")
train = pl.read_parquet(path)

path = Path("data/processed/ml-1m/test_temporal-loo.parquet")
test = pl.read_parquet(path)

path = Path("data/processed/ml-1m/ml-1m_items.parquet")
items = pl.read_parquet(path)

In [5]:
test.head()

user_id,item_id,is_positive,genre_id
i64,i64,bool,list[u32]
1,48,True,"[0, 1, … 12]"
2,1687,True,"[15, 7]"
2,434,True,"[15, 9, 6]"
2,1544,True,"[15, 9, … 7]"
2,1917,True,"[15, 9, … 7]"


In [6]:
items.head()

item_id,genre_id
i64,list[u32]
545,[2]
2388,"[2, 12]"
3695,"[5, 8]"
1837,[5]
950,[14]


# Define data loader

In [7]:
MAX_N_GENRES = 9
DEVICE = torch.device("mps")
BSZ = 50
N_EPOCHS = 20

N_USERS = 6041
N_ITEMS = 3953
N_GENRES = 19
MAX_N_GENRES = 6

In [8]:
class DatasetRS(Dataset):
    def __init__(
        self,
        inters: DataFrame,
        genre_pad_id: int = 18,
        is_training: bool = True
    ) -> None:
        super().__init__()

        self.genre_pad_id = genre_pad_id
        self.inters_pos = inters.filter(pl.col('is_positive'))
        self.inters_neg = inters.filter(~pl.col('is_positive'))
        self.n_sample = 1 if is_training else 100

    def __len__(self):
        return len(self.inters_pos)
    
    def _pad(self, t: list) -> Tensor:
        num_pad = MAX_N_GENRES - len(t)
        return F.pad(torch.tensor(t), (0, num_pad), value=self.genre_pad_id)

    def __getitem__(self, index):
        row_pos = self.inters_pos.row(index, named=True)
        rows_neg = (
            self.inters_neg
            .filter(pl.col('user_id') == row_pos['user_id'])
            .sample(self.n_sample)
            .to_dicts()
        )

        item_id_pos = torch.tensor(row_pos['item_id']).unsqueeze(0)
        item_id_neg = torch.tensor([r['item_id'] for r in rows_neg])
        genre_id_pos = self._pad(row_pos['genre_id']).unsqueeze(0)
        pad_genre_pos = (genre_id_pos != self.genre_pad_id).int()
        genre_id_neg = torch.vstack([self._pad(r['genre_id']) for r in rows_neg])
        pad_genre_neg = (genre_id_neg != self.genre_pad_id).int()
    

        return {
            'user_id': row_pos['user_id'],
            'item_id_pos': item_id_pos,
            'item_id_neg': item_id_neg,
            'genre_id_pos': genre_id_pos,
            'pad_genre_pos': pad_genre_pos,
            'genre_id_neg': genre_id_neg,
            'pad_genre_neg': pad_genre_neg,
        }

# ds_test = DatasetRS(test, is_training=True)
# ds_test[10]
# loader_test = DataLoader(ds_test, batch_size=BSZ, shuffle=True)
# for r in loader_test:
#     break

# for k, v in r.items():
#     print(f"{k} --> {v.shape}")

# Define model

In [9]:
class Block(Module):
    def __init__(self, n: int, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        self.seq1 = nn.Sequential(
            nn.Linear(n, n),
            nn.LayerNorm(n),
            nn.GELU(),
        )
        self.seq2 = nn.Sequential(
            nn.LayerNorm(n),
            nn.GELU(),
        )

    def forward(self, X: Tensor) -> Tensor:
        out = self.seq2(self.seq1(X) + X)

        return out

class TwoTower(Module):
    def __init__(
        self,
        n_users: int,
        n_items: int,
        n_genres: int,
        d_hid: int = 128,
        n_blocks: int = 2,
    ) -> None:
        super().__init__()

        self.embd_users = nn.Embedding(n_users, d_hid)
        self.embd_items = nn.Embedding(n_items, d_hid)
        self.embd_genres = nn.Embedding(n_genres, d_hid)

        self.block_user = nn.ModuleList([Block(d_hid)] * n_blocks)
        self.block_item = nn.ModuleList([Block(d_hid)] * n_blocks)
        self.lin1 = nn.Linear(2 * d_hid, d_hid)

        # self.ff1 = nn.Sequential(
        #     nn.Linear(4 * d_hid, 4 * d_hid),
        #     nn.LayerNorm(4 * d_hid),
        #     nn.GELU(),
        # )
        self.lin2 = nn.Linear(4 * d_hid, 1)

    def forward(self,
        users: Tensor,
        items: Tensor,
        genres: Tensor,
        pads: Tensor
    ) -> Tensor:
        # TODO: HoangLe [Jan-22]: Continue modifying this to match with data
        # users: [bz]
        # items: [bz, l]
        # genres, pads: [bz, l, MAX_N_GENRES]

        ##################################
        # Step 1: Embed
        ##################################
        users = self.embd_users(users)
        # [bz, d_hid]

        items = self.embd_items(items)
        # [bz, l, d_hid]
        genres = self.embd_genres(genres)
        # [bz,, l, MAX_N_GENRES, d_hid]
        genres = (genres * pads.unsqueeze(-1)).sum(dim=2)
        # [bz, l, d_hid]

        items = torch.concat([items, genres], dim=-1)
        # [bz, l,  2 * d_hid]

        ##################################
        # Step 2: Block
        ##################################
        for block in self.block_user:
            users = block(users)
            # [bz, d_hid]

        items = self.lin1(items)
        # [bz, l, d_hid]
        for block in self.block_item:
            items = block(items)
            # [bz, l, d_hid]

        ##################################
        # Step 3: Alignment
        ##################################
        # aligned = torch.concat([users, items, users - items, users * items], dim=-1)
        # [bz, 4 * d_hid]
        # aligned = self.ff1(aligned)
        # logit = F.sigmoid(self.lin2(aligned))

        users = users.unsqueeze(1)
        # [bz, 1, d_hid]
        items = items.transpose(1, 2)
        # [bz, d_hid, l]

        logger.debug(f"users: {users.shape}")
        logger.debug(f"items: {items.shape}")
  
        aligned = (users @ items).squeeze(1)
        # [bz, l]

        return aligned

class LitTwoTower(L.LightningModule):
    def __init__(
        self,
        d_hid: int = 128,
        n_blocks: int = 2,
        lr: float = 1e-4,
        criterion: Literal["BPR", "LogLoss"] = "BPR",
        k: int = 5,
    ) -> None:
        super().__init__()

        self.lr = lr

        self.model = TwoTower(N_USERS, N_ITEMS, N_GENRES, d_hid=d_hid, n_blocks=n_blocks)
        self.criterion = criterion

        self.ndcg = []
        self.k = k

    def _calc_loss(self, scores_pos: Tensor, scores_neg: Tensor, THETA: float = 1e-9) -> Tensor:
        match (self.criterion):
            case "BPR":
                distance = torch.clamp(scores_pos - scores_neg, THETA)
                loss = -torch.sum(torch.log(F.sigmoid(distance)))
            case "LogLoss":
                pass
            case _:
                raise NotImplementedError()

        return loss

    def training_step(self, batch, batch_idx) -> Tensor:
        scores_pos = self.model(batch['user_id'], batch['item_id_pos'], batch['genre_id_pos'], batch['pad_genre_pos'])
        scores_neg = self.model(batch['user_id'], batch['item_id_neg'], batch['genre_id_neg'], batch['pad_genre_neg'])

        loss = self._calc_loss(scores_pos, scores_neg)

        self.log("train_loss", loss, on_epoch=True, prog_bar=True, on_step=True)

        return loss
    
    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.lr)

        return optimizer
    
    def validation_step(self, batch, batch_idx):
        scores_pos = self.model(batch['user_id'], batch['item_id_pos'], batch['genre_id_pos'], batch['pad_genre_pos'])
        scores_neg = self.model(batch['user_id'], batch['item_id_neg'], batch['genre_id_neg'], batch['pad_genre_neg'])

        scores = torch.concat([scores_pos, scores_neg], dim=-1)
        relevances = torch.zeros_like(scores, device=self.device)
        relevances[:, 0] = 1

        relevances_sorted = relevances[:, torch.argsort(scores, dim=-1, descending=True)][0, :, : self.k]
        indices = 1 / torch.log2(torch.arange(2, relevances_sorted.shape[-1] + 2, device=self.device)).unsqueeze(0)

        ndcg = torch.mean(relevances_sorted @ indices.T)
        self.log("val_ndcg", ndcg, on_epoch=True, prog_bar=False)

        # self.scores_pos.append(scores_pos.detach().cpu())
        # self.scores_neg.append(scores_neg.detach().cpu())


# ds = DatasetRS(test, is_training=False)
# loader = DataLoader(ds, batch_size=BSZ, shuffle=True)
# for x in loader:
#     break
# model = TwoTower(N_USERS, N_ITEMS, N_GENRES)
# scores_pos = model(x['user_id'], x['item_id_pos'], x['genre_id_pos'], x['pad_genre_pos'])
# scores_neg = model(x['user_id'], x['item_id_neg'], x['genre_id_neg'], x['pad_genre_neg'])

In [10]:
loader_train = DataLoader(DatasetRS(train), batch_size=BSZ, shuffle=True,)
loader_val = DataLoader(DatasetRS(test, is_training=False), batch_size=BSZ)

# criterion = nn.BCEWithLogitsLoss()
# optimizer = AdamW(model.parameters())
litmodel = LitTwoTower()

trainer = L.Trainer(
    accelerator="mps",
    devices=1,
    max_epochs=N_EPOCHS,
    callbacks=[RichProgressBar(leave=True)],
    logger=TensorBoardLogger("tb_logs", name="two_tower", version=datetime.now().strftime("%m%d-%H%M%S")),
    check_val_every_n_epoch=1
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


# Start training

In [11]:
# model.train() 

# for epch in range(1, N_EPOCHS + 1):
#     for x in loader_train:
#         optimizer.zero_grad()

#         out = model(x['user'], x['product'], x['genres'], x['pad'], x['price'])
#         tgt = torch.ones_like(out, device=DEVICE, dtype=torch.float32)

#         loss = criterion(out, tgt)

#         loss.backward()
#         optimizer.step()
#         logger.info(f"loss: {loss.item():.4f}")

trainer.fit(litmodel, train_dataloaders=loader_train, val_dataloaders=loader_val)

Output()

ShapeError: cannot take a larger sample than the total population when `with_replacement=false`