In [1]:
pip install catboost

Collecting catboost
  Downloading catboost-1.2-cp310-cp310-manylinux2014_x86_64.whl (98.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m98.6/98.6 MB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: catboost
Successfully installed catboost-1.2


In [2]:
import math

import numpy as np
import torch
from catboost.datasets import msrank
from sklearn.preprocessing import StandardScaler

from typing import List

In [25]:
from catboost.datasets import msrank_10k

In [3]:
from math import log2

In [4]:
from sklearn.metrics import ndcg_score

In [5]:
class ListNet(torch.nn.Module):
    """Small MLP that maps every input feature row to a single ranking score.

    Args:
        num_input_features: length of the last dimension of the input rows.
        hidden_dim: width of the single ReLU hidden layer.
    """

    def __init__(self, num_input_features: int, hidden_dim: int):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Layers are created in input -> hidden -> output order so parameter
        # initialisation consumes the RNG in a fixed sequence.
        layers = [
            torch.nn.Linear(num_input_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        ]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, input_1: torch.Tensor) -> torch.Tensor:
        """Return one raw (un-normalised) score per row; last dimension is 1."""
        return self.model(input_1)

In [29]:
class Solution:
    """Train and evaluate a ListNet ranker on the catboost ``msrank`` dataset.

    Features are standardised independently inside each query group, the
    model is trained one query at a time with the ListNet top-one
    cross-entropy loss, and quality is reported as the mean NDCG@k over the
    test queries.
    """

    def __init__(self, n_epochs: int = 5, listnet_hidden_dim: int = 30,
                 lr: float = 0.001, ndcg_top_k: int = 10):
        """
        Args:
            n_epochs: number of passes over the training queries in ``fit``.
            listnet_hidden_dim: hidden-layer width of the ListNet scorer.
            lr: Adam learning rate.
            ndcg_top_k: cut-off ``k`` used for NDCG@k on the test set.
        """
        self._prepare_data()
        self.num_input_features = self.X_train.shape[1]
        self.ndcg_top_k = ndcg_top_k
        self.n_epochs = n_epochs

        self.model = self._create_model(
            self.num_input_features, listnet_hidden_dim)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    @staticmethod
    def _group_indices_by_query(query_ids: np.ndarray) -> dict:
        """Map each distinct query id (sorted ascending) to the row indices it owns."""
        groups = {query: [] for query in np.unique(query_ids)}
        for idx, query in enumerate(query_ids):
            groups[query].append(idx)
        return groups

    def _get_data(self) -> List[np.ndarray]:
        """Load msrank and split each frame into features, labels and query ids.

        Column 0 is used as the label, column 1 as the (integer) query id, and
        every remaining column as a feature.
        """
        train_df, test_df = msrank()

        X_train = train_df.drop([0, 1], axis=1).values
        y_train = train_df[0].values
        query_ids_train = train_df[1].values.astype(int)

        X_test = test_df.drop([0, 1], axis=1).values
        y_test = test_df[0].values
        query_ids_test = test_df[1].values.astype(int)

        return [X_train, y_train, query_ids_train, X_test, y_test, query_ids_test]

    def _prepare_data(self) -> None:
        """Load data, scale features per query and store float32 tensors on self."""
        (X_train, y_train, self.query_ids_train,
            X_test, y_test, self.query_ids_test) = self._get_data()
        self.X_train = torch.FloatTensor(
            self._scale_features_in_query_groups(X_train, self.query_ids_train))
        self.y_train = torch.FloatTensor(y_train)

        self.X_test = torch.FloatTensor(
            self._scale_features_in_query_groups(X_test, self.query_ids_test))
        self.y_test = torch.FloatTensor(y_test)

    def _scale_features_in_query_groups(self, inp_feat_array: np.ndarray,
                                        inp_query_ids: np.ndarray) -> np.ndarray:
        """Apply ``StandardScaler`` separately to the rows of each query.

        Works on a float64 copy, so the caller's array is left untouched and
        integer inputs cannot silently truncate the scaled values.
        """
        scaled = np.array(inp_feat_array, dtype=np.float64)
        scaler = StandardScaler()
        for ids in self._group_indices_by_query(inp_query_ids).values():
            scaled[ids] = scaler.fit_transform(scaled[ids])
        return scaled

    def _create_model(self, listnet_num_input_features: int,
                      listnet_hidden_dim: int) -> torch.nn.Module:
        """Build a ListNet scorer with a fixed torch seed for reproducibility."""
        torch.manual_seed(0)
        net = ListNet(listnet_num_input_features, listnet_hidden_dim)
        return net

    def fit(self) -> List[float]:
        """Train for ``n_epochs`` epochs, printing test NDCG@k after each one.

        Returns:
            The mean test NDCG@k measured after every epoch, in order.
        """
        history = []
        for epoch in range(self.n_epochs):
            self._train_one_epoch()
            ndcg = self._eval_test_set()
            history.append(ndcg)
            print(f'epoch_{epoch}: {ndcg}')
        return history

    def _calc_loss(self, batch_ys: torch.FloatTensor,
                   batch_pred: torch.FloatTensor) -> torch.FloatTensor:
        """ListNet top-one cross-entropy between label and score distributions.

        Both inputs are expected to be 1-D vectors over the documents of one
        query. ``log_softmax`` is used instead of ``log(softmax(...))`` so a
        large score gap cannot underflow to ``log(0) = -inf``.
        """
        P_y_i = torch.softmax(batch_ys, dim=0)
        log_P_z_i = torch.log_softmax(batch_pred, dim=0)
        return -torch.sum(P_y_i * log_P_z_i)

    def _train_one_epoch(self) -> None:
        """One pass over the training set with one optimiser step per query."""
        self.model.train(True)
        for ids in self._group_indices_by_query(self.query_ids_train).values():
            self.optimizer.zero_grad()

            ys_pred = self.model(self.X_train[ids]).reshape(-1)
            ys_true = self.y_train[ids]
            loss = self._calc_loss(ys_true, ys_pred)
            # Each query builds a fresh graph, so nothing needs to be retained.
            loss.backward()
            self.optimizer.step()

    def _eval_test_set(self) -> float:
        """Mean NDCG@``ndcg_top_k`` over all test queries of this instance.

        Queries with no positive-gain document contribute 0.
        """
        self.model.eval()
        ndcgs = []
        with torch.no_grad():
            for ids in self._group_indices_by_query(self.query_ids_test).values():
                ys_pred = self.model(self.X_test[ids]).reshape(-1)
                ndcg = self._ndcg_k(self.y_test[ids], ys_pred, self.ndcg_top_k)
                if np.isnan(ndcg):  # defensive; _ndcg_k already maps 0/0 to 0
                    ndcg = 0.0
                ndcgs.append(ndcg)
        return float(np.mean(ndcgs))

    def _ndcg_k(self, ys_true: torch.Tensor, ys_pred: torch.Tensor,
                ndcg_top_k: int) -> float:
        """NDCG@k with gain ``2**rel - 1`` and discount ``log2(rank + 1)``.

        Predictions may be 1-D or a column vector; both are flattened.
        Returns 0.0 when the ideal DCG is zero (no document has positive gain).
        """
        ys_true = ys_true.reshape(-1)
        ys_pred = ys_pred.reshape(-1)

        def dcg(order_by: torch.Tensor) -> float:
            _, order = torch.sort(order_by, descending=True, dim=0)
            gains = 2 ** ys_true[order[:ndcg_top_k]].double() - 1
            ranks = torch.arange(2, gains.numel() + 2, dtype=torch.float64)
            return (gains / torch.log2(ranks)).sum().item()

        ideal_dcg = dcg(ys_true)
        if ideal_dcg == 0:
            return 0.0
        return dcg(ys_pred) / ideal_dcg

In [30]:
# 20 epochs with a smaller Adam step than the class default (lr=0.001).
sol = Solution(n_epochs=20, lr=5e-4)

In [31]:
# Per-epoch test NDCG is printed by fit(); the trailing ';' hides the cell's return-value repr.
sol.fit();

epoch_0: 0.40996675501950086
epoch_1: 0.41451518119871616
epoch_2: 0.4179455491635017
epoch_3: 0.41965697512403133
epoch_4: 0.41931505596265195
epoch_5: 0.4212415689053014
epoch_6: 0.4209454887192696
epoch_7: 0.4206902553280816
epoch_8: 0.4206857800697908
epoch_9: 0.42220401181466877
epoch_10: 0.4225088845090941
epoch_11: 0.42127809999603777
epoch_12: 0.4223796743499115
epoch_13: 0.42314490593690424
epoch_14: 0.4225236871326342
epoch_15: 0.42267579738236966
epoch_16: 0.42296321080904453
epoch_17: 0.4231152542619966
epoch_18: 0.42317845740402116
epoch_19: 0.4241229712856002
