# RankNet

In [2]:
import torch

In [3]:
%%latex
$$C_{ij}=C(o_{ij})=-\bar{P}_{ij}\log(P_{ij})-(1-\bar{P}_{ij})\log(1-P_{ij})$$

<IPython.core.display.Latex object>

In [4]:
%%latex
$$o_{ij}=f(x_i)-f(x_j)$$

<IPython.core.display.Latex object>

In [5]:
%%latex
$$P_{ij}=\frac{e^{o_{ij}}}{1+e^{o_{ij}}}$$

<IPython.core.display.Latex object>

In [6]:
%%latex
$$\text{out}_{i} = \frac{1}{1 + e^{-\text{input}_{i}}}$$

<IPython.core.display.Latex object>

In [7]:
class RankNet(torch.nn.Module):
    """Pairwise ranker built on one shared scoring network.

    Both items of a pair are scored by the same MLP; the sigmoid of the
    score difference is returned, i.e. P_ij computed from o_ij = f(x_i) - f(x_j).
    """

    def __init__(self, num_input_features, hidden_dim=10):
        """
        num_input_features: length of each item's feature vector
        hidden_dim: width of the single hidden layer
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        # Shared scorer f(x): features -> hidden -> one scalar per item.
        scorer_layers = [
            torch.nn.Linear(num_input_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        ]
        self.model = torch.nn.Sequential(*scorer_layers)

        # Maps the score difference onto (0, 1).
        self.out_activation = torch.nn.Sigmoid()

    def forward(self, input_1, input_2):
        """Return sigmoid(f(input_1) - f(input_2)), one value per row pair."""
        score_gap = self.predict(input_1) - self.predict(input_2)
        return self.out_activation(score_gap)

    def predict(self, inp):
        """Return the raw (pre-sigmoid) score for each row of `inp`."""
        return self.model(inp)

In [8]:
# Build a RankNet for 10-dimensional items; hidden_dim keeps its default of 10.
ranknet_model = RankNet(num_input_features=10)

In [9]:
# Two random batches; row k of inp_1 is compared against row k of inp_2.
# No seed is set, so the values differ on every run.
inp_1, inp_2 = torch.rand(4, 10), torch.rand(4, 10)
# batch_size x input_dim
inp_2

tensor([[0.0532, 0.4993, 0.7645, 0.0702, 0.3200, 0.3427, 0.3579, 0.7294, 0.3162,
         0.7681],
        [0.8687, 0.1679, 0.9260, 0.1604, 0.8834, 0.7331, 0.7024, 0.9545, 0.0521,
         0.5836],
        [0.8555, 0.6871, 0.0532, 0.2254, 0.8835, 0.4444, 0.8803, 0.2107, 0.2026,
         0.0426],
        [0.2593, 0.5056, 0.0537, 0.9563, 0.2461, 0.5078, 0.5954, 0.9523, 0.7879,
         0.8331]])

In [10]:
# Forward pass: one sigmoid output per (inp_1[k], inp_2[k]) pair.
preds = ranknet_model(inp_1, inp_2)
preds

tensor([[0.5300],
        [0.5178],
        [0.4611],
        [0.5234]], grad_fn=<SigmoidBackward0>)

In [11]:
ranknet_model.model

Sequential(
  (0): Linear(in_features=10, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=1, bias=True)
)

In [12]:
# Keep a handle on the first Linear layer to inspect its gradients below.
first_linear_layer = ranknet_model.model[0]

In [13]:
# No backward pass has run yet, so .grad is still None (the cell shows no output).
first_linear_layer.weight.grad

In [14]:
# Binary cross-entropy on the sigmoid outputs, with an all-ones target
# (the same shape as preds) for every pair.
criterion = torch.nn.BCELoss()
loss = criterion(preds, torch.ones_like(preds))
# Populate .grad on every parameter that contributed to the loss.
loss.backward()

In [15]:
# After backward() the gradient of the loss w.r.t. the layer's weights is populated.
first_linear_layer.weight.grad

tensor([[ 4.9871e-03, -2.0317e-02, -1.3578e-02, -2.0541e-02, -8.0495e-04,
         -5.8093e-03,  3.2482e-02,  1.9457e-02, -1.3675e-02, -7.3663e-04],
        [ 3.3058e-03,  6.4443e-03,  6.8474e-04,  1.2189e-02,  3.1375e-03,
          6.4726e-03,  7.5897e-03,  1.2139e-02,  1.0043e-02,  1.0619e-02],
        [ 3.0231e-03,  9.3784e-03,  1.4491e-02,  1.7893e-02,  7.3732e-03,
          9.1765e-03,  7.1943e-03,  2.6430e-02,  2.5375e-03,  1.7426e-02],
        [ 3.6658e-03, -1.4934e-02, -9.9808e-03, -1.5099e-02, -5.9169e-04,
         -4.2701e-03,  2.3876e-02,  1.4302e-02, -1.0052e-02, -5.4146e-04],
        [ 2.2958e-03, -9.3529e-03, -6.2507e-03, -9.4558e-03, -3.7055e-04,
         -2.6743e-03,  1.4953e-02,  8.9567e-03, -6.2953e-03, -3.3910e-04],
        [-1.1595e-02, -1.0985e-02, -4.4955e-03, -6.5031e-03, -2.4134e-06,
         -5.4944e-03, -1.5780e-04, -4.3601e-03, -1.0747e-02, -1.9958e-03],
        [ 5.4647e-04,  5.0785e-03,  2.0652e-03,  7.2159e-03,  1.2585e-03,
          3.1506e-03, -1.5551e-0

In [16]:
# Same tensor as above, reached through the model: first_linear_layer is ranknet_model.model[0].
ranknet_model.model[0].weight.grad

tensor([[ 4.9871e-03, -2.0317e-02, -1.3578e-02, -2.0541e-02, -8.0495e-04,
         -5.8093e-03,  3.2482e-02,  1.9457e-02, -1.3675e-02, -7.3663e-04],
        [ 3.3058e-03,  6.4443e-03,  6.8474e-04,  1.2189e-02,  3.1375e-03,
          6.4726e-03,  7.5897e-03,  1.2139e-02,  1.0043e-02,  1.0619e-02],
        [ 3.0231e-03,  9.3784e-03,  1.4491e-02,  1.7893e-02,  7.3732e-03,
          9.1765e-03,  7.1943e-03,  2.6430e-02,  2.5375e-03,  1.7426e-02],
        [ 3.6658e-03, -1.4934e-02, -9.9808e-03, -1.5099e-02, -5.9169e-04,
         -4.2701e-03,  2.3876e-02,  1.4302e-02, -1.0052e-02, -5.4146e-04],
        [ 2.2958e-03, -9.3529e-03, -6.2507e-03, -9.4558e-03, -3.7055e-04,
         -2.6743e-03,  1.4953e-02,  8.9567e-03, -6.2953e-03, -3.3910e-04],
        [-1.1595e-02, -1.0985e-02, -4.4955e-03, -6.5031e-03, -2.4134e-06,
         -5.4944e-03, -1.5780e-04, -4.3601e-03, -1.0747e-02, -1.9958e-03],
        [ 5.4647e-04,  5.0785e-03,  2.0652e-03,  7.2159e-03,  1.2585e-03,
          3.1506e-03, -1.5551e-0

In [17]:
# Clear the gradients left on the model's parameters by the backward pass above.
ranknet_model.zero_grad()

# ListNet

In [18]:
from itertools import combinations
import numpy as np

In [19]:
pwd

'/home/mic/Projects/github/python/ml-hard/1-part'

In [20]:
import sys
sys.path.append('.')
# from utils import ndcg, num_swapped_pairs
import utils

In [21]:
class ListNet(torch.nn.Module):
    """Listwise scorer: a one-hidden-layer MLP emitting one logit per item."""

    def __init__(self, num_input_features, hidden_dim=10):
        """
        num_input_features: length of each item's feature vector
        hidden_dim: width of the single hidden layer
        """
        super().__init__()
        self.hidden_dim = hidden_dim

        # features -> hidden -> one scalar score per item
        scorer_layers = [
            torch.nn.Linear(num_input_features, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, 1),
        ]
        self.model = torch.nn.Sequential(*scorer_layers)

    def forward(self, input_1):
        """Return the unnormalised score for each row of `input_1`."""
        return self.model(input_1)


In [22]:
%%latex
$$CE = -\sum_{j=1}^{N} \left(P_y^i(j) \cdot \log(P_z^i(j))\right)$$

<IPython.core.display.Latex object>

In [23]:
%%latex
$$\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}$$

<IPython.core.display.Latex object>

In [24]:
def listnet_ce_loss(y_i, z_i):
    """
    Cross-entropy between the softmax distributions of one list.

    y_i: (n_i, 1) GT
    z_i: (n_i, 1) preds

    Returns a 0-dim tensor: -sum_j P_y(j) * log P_z(j), where both
    distributions are softmaxes taken over dim 0.
    """
    P_y_i = torch.softmax(y_i, dim=0)
    # log_softmax stays finite where log(softmax(z)) would underflow to
    # log(0) = -inf and turn the loss (and its gradients) into inf/nan.
    log_P_z_i = torch.log_softmax(z_i, dim=0)
    return -torch.sum(P_y_i * log_P_z_i)

def listnet_kl_loss(y_i, z_i):
    """
    KL divergence KL(P_y || P_z) between the softmax distributions of one list.

    y_i: (n_i, 1) GT
    z_i: (n_i, 1) preds

    Returns a 0-dim tensor: sum_j P_y(j) * (log P_y(j) - log P_z(j)), where
    both distributions are softmaxes taken over dim 0. This equals
    -sum_j P_y(j) * log(P_z(j) / P_y(j)).
    """
    P_y_i = torch.softmax(y_i, dim=0)
    # Work in log-space: forming P_z / P_y and then taking the log produces
    # inf/nan (0 * inf) as soon as either softmax underflows to 0.
    log_P_y_i = torch.log_softmax(y_i, dim=0)
    log_P_z_i = torch.log_softmax(z_i, dim=0)
    return torch.sum(P_y_i * (log_P_y_i - log_P_z_i))


def make_dataset(N_train, N_valid, vector_dim, bins=(-1, 0, 1, 2)):
    """
    Build a synthetic learning-to-rank dataset.

    Features are standard-normal vectors. Each item's score is its projection
    onto one random weight vector (shared by both splits) plus standard-normal
    noise, and the score is then discretised into integer relevance grades.

    N_train: number of training items
    N_valid: number of validation items
    vector_dim: number of features per item
    bins: increasing score thresholds passed to np.digitize; k thresholds
        give k + 1 grades (0..k). The default yields 5 grades, and
        bins=(-1, 1) yields 3.

    Returns (X_train, X_valid, ys_train_rel, ys_valid_rel) with shapes
    (N_train, vector_dim), (N_valid, vector_dim), (N_train, 1), (N_valid, 1);
    the relevance tensors are float32.
    """
    fake_weights = torch.randn(vector_dim, 1)

    X_train = torch.randn(N_train, vector_dim)
    X_valid = torch.randn(N_valid, vector_dim)

    ys_train_score = torch.mm(X_train, fake_weights)
    ys_train_score += torch.randn_like(ys_train_score)

    ys_valid_score = torch.mm(X_valid, fake_weights)
    ys_valid_score += torch.randn_like(ys_valid_score)

    def _to_relevance(scores):
        # Bucket continuous scores into integer grades, returned as float32.
        grades = np.digitize(scores.detach().numpy(), bins=bins)
        return torch.as_tensor(grades, dtype=torch.float32)

    return X_train, X_valid, _to_relevance(ys_train_score), _to_relevance(ys_valid_score)

In [25]:
# Number of synthetic items in each split.
N_train = 1000
N_valid = 500

vector_dim = 100  # features per item
epochs = 2

batch_size = 16  # items per mini-batch; the training loss takes its softmax over each batch

# No RNG seed is set, so the data and the initial weights differ on every run.
X_train, X_valid, ys_train, ys_valid = make_dataset(N_train, N_valid, vector_dim)

net = ListNet(num_input_features=vector_dim)
opt = torch.optim.Adam(net.parameters())  # Adam with its default hyperparameters


In [26]:
X_train.shape, ys_train.shape

(torch.Size([1000, 100]), torch.Size([1000, 1]))

In [27]:
torch.unique(ys_train)

tensor([0., 1., 2., 3., 4.])

In [28]:
net.model

Sequential(
  (0): Linear(in_features=100, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=1, bias=True)
)

In [None]:
# Train ListNet: the listwise loss is applied to each mini-batch as one list.
for epoch in range(epochs):
    # Draw a fresh permutation every epoch and index through it rather than
    # overwriting X_train / ys_train, so re-running this cell never mutates
    # the dataset created above.
    perm = torch.randperm(N_train)

    # The trailing partial batch (N_train % batch_size items) is skipped.
    for it in range(N_train // batch_size):
        batch_idx = perm[it * batch_size:(it + 1) * batch_size]
        batch_X = X_train[batch_idx]
        batch_ys = ys_train[batch_idx]

        opt.zero_grad()
        batch_pred = net(batch_X)
        # listnet_ce_loss(batch_ys, batch_pred) is a drop-in alternative here.
        batch_loss = listnet_kl_loss(batch_ys, batch_pred)
        # Every forward pass builds a new graph, so there is nothing to retain.
        batch_loss.backward()
        opt.step()

        # Report validation metrics every 10 iterations.
        if it % 10 == 0:
            with torch.no_grad():
                valid_pred = net(X_valid)
                valid_swapped_pairs = utils.num_swapped_pairs(ys_valid, valid_pred)
                ndcg_score = utils.ndcg(ys_valid, valid_pred)
            print(f"epoch: {epoch + 1}.\tNumber of swapped pairs: "
                  f"{valid_swapped_pairs}/{N_valid * (N_valid - 1) // 2}\t"
                  f"nDCG: {ndcg_score:.4f}")