# Learning-to-Rank example

In [1]:
from sklearn.ensemble import RandomForestRegressor
import numpy as np

## A class for pointwise-based learning to rank model

In [2]:
class PointWiseLTRModel(object):
    def __init__(self, regressor):
        """
        :param classifier: an instance of scikit-learn regressor
        """
        self.regressor = regressor

    def _train(self, X, y):
        """
        Trains and LTR model.
        :param X: features of training instances
        :param y: relevance assessments of training instances
        :return:
        """
        assert self.regressor is not None
        self.model = self.regressor.fit(X, y)

    def rank(self, ft, doc_ids):
        """
        Predicts relevance labels and rank documents for a given query
        :param ft: a list of features for query-doc pairs
        :param ft: a list of document ids
        :return:
        """
        assert self.model is not None
        rel_labels = self.model.predict(ft)
        sort_indices = np.argsort(rel_labels)[::-1]

        results = []
        for i in sort_indices:
            results.append((doc_ids[i], rel_labels[i]))
        return results

## Read data from file

In [3]:
def read_data_from_file(path):
    """
    :param path: path of file
    :return: X features of data, y labels of data, group a list of numbers indicate how many instances for each query
    """
    X, y, qids, doc_ids = [], [], [], []
    with open(path, "r") as f:
        i, s_qid = 0, None
        for line in f:
            items = line.strip().split()
            label = int(items[0])
            qid = items[1]
            doc_id = items[2]
            features = np.array([float(i.split(":")[1]) for i in items[3:]])
            X.append(features)
            y.append(label)
            qids.append(qid)
            doc_ids.append(doc_id)

    return X, y, qids, doc_ids

## Main

#### Read input data

In [4]:
X, y, qids, doc_ids = read_data_from_file(path="data/features_sample.txt")
qids_unique= list(set(qids))

print("#queries: ", len(qids_unique))
print("#query-doc pairs: ", len(y))


#queries:  339
#query-doc pairs:  14013


#### Split data into train and test sets (80% and 20%, respectively)

In [5]:
train_qids = []
test_qids = []

for i in range(len(qids_unique)):
    qid = qids_unique[i]
    if i % 5 == 0:  # test query
        test_qids.append(qid)
    else:  # train query
        train_qids.append(qid)
    
train_X, train_y = [], []
test_X, test_y = [], []

for i in range(len(X)):
    if qids[i] in train_qids:
        train_X.append(X[i])
        train_y.append(y[i])
    else:
        test_X.append(X[i])
        test_y.append(y[i])

#### Create a regression model and an LTR instance based on that

In [6]:
clf = RandomForestRegressor(max_depth=3, random_state=0)
ltr = PointWiseLTRModel(clf)

#### Train LTR model

In [7]:
ltr._train(train_X, train_y)

#### Generate ranking for a test query

In [8]:
qid = test_qids[0]  # first test query
# get the doc_ids and feature vectors for the documents that are to be ranked for this query
# Note that this is a really inefficient way of doing this!
test_X = []
test_doc_ids = []

for i in range(len(X)):
    if qids[i] == qid:
        test_X.append(X[i])
        test_doc_ids.append(doc_ids[i])

r = ltr.rank(test_X, test_doc_ids)
print(r)

[('GX271-93-3327909', 0.33622395312903264), ('GX236-37-11249339', 0.27539716807585607), ('GX264-92-1269066', 0.27539716807585607), ('GX233-88-9264368', 0.27539716807585607), ('GX132-25-6127667', 0.24858251233029405), ('GX104-75-14667053', 0.24652013303482176), ('GX261-94-0633628', 0.24652013303482176), ('GX051-59-12400045', 0.24652013303482176), ('GX036-88-3105280', 0.24652013303482176), ('GX230-97-12877355', 0.24652013303482176), ('GX232-43-7086097', 0.24652013303482176), ('GX230-96-2438447', 0.24652013303482176), ('GX252-20-5351758', 0.24652013303482176), ('GX254-22-4180355', 0.24652013303482176), ('GX260-93-1403917', 0.23757130553338385), ('GX012-85-8988825', 0.23757130553338385), ('GX025-11-11819187', 0.23757130553338385), ('GX054-19-10953443', 0.23757130553338385), ('GX252-62-4793191', 0.23757130553338385), ('GX270-90-5058390', 0.23757130553338385), ('GX256-61-13960422', 0.23757130553338385), ('GX097-96-1976846', 0.23757130553338385), ('GX084-00-1975997', 0.23757130553338385), ('G