# Learning-to-Rank example

In [57]:
from sklearn.ensemble import RandomForestRegressor
import numpy as np

## A class for a pointwise learning-to-rank model

In [63]:
class PointWiseLTRModel(object):
    """Pointwise learning-to-rank model.

    Each query-document pair is scored independently by a regressor that
    predicts its relevance label; documents are then ordered by that score.
    """

    def __init__(self, regressor):
        """
        :param regressor: an instance of a scikit-learn regressor (any object
            exposing ``fit(X, y)`` and ``predict(X)``)
        """
        self.regressor = regressor
        # Fitted estimator; stays None until _train() is called, so rank()
        # fails with a clear assertion rather than an AttributeError.
        self.model = None

    def _train(self, X, y):
        """
        Trains an LTR model.

        :param X: features of training instances
        :param y: relevance assessments of training instances
        :return: None; the fitted estimator is stored in ``self.model``
        """
        assert self.regressor is not None, "no regressor provided"
        # scikit-learn's fit() returns the fitted estimator itself.
        self.model = self.regressor.fit(X, y)

    def rank(self, ft, doc_ids):
        """
        Predicts relevance labels and ranks documents for a given query.

        :param ft: a list of features for query-doc pairs
        :param doc_ids: a list of document ids, aligned position-by-position
            with ``ft``
        :return: a list of ``(doc_id, predicted_score)`` tuples ordered by
            descending predicted score
        """
        assert self.model is not None, "model has not been trained; call _train() first"
        rel_labels = self.model.predict(ft)
        # argsort sorts ascending; reverse it so the highest scores come first.
        sort_indices = np.argsort(rel_labels)[::-1]
        return [(doc_ids[i], rel_labels[i]) for i in sort_indices]

## Read data from file

In [59]:
def read_data_from_file(path):
    """
    Reads query-document pairs from a whitespace-separated text file.

    Each non-blank line has the form ``<label> <qid> <doc_id> <name>:<value> ...``.
    For every feature token, the text between the first and any second ``:``
    is parsed as a float, and the feature name before the ``:`` is discarded.
    Features are kept in file order.

    :param path: path of file
    :return: a tuple ``(X, y, qids, doc_ids)`` of parallel lists: X holds one
        numpy feature array per instance, y the integer relevance labels,
        qids the query ids and doc_ids the document ids (both as strings)
    """
    X, y, qids, doc_ids = [], [], [], []
    with open(path, "r") as f:
        for line in f:
            items = line.split()
            if not items:  # skip blank lines, e.g. a trailing empty line
                continue
            label, qid, doc_id = int(items[0]), items[1], items[2]
            features = np.array([float(tok.split(":")[1]) for tok in items[3:]])
            X.append(features)
            y.append(label)
            qids.append(qid)
            doc_ids.append(doc_id)

    return X, y, qids, doc_ids

## Main

#### Read input data

In [64]:
X, y, qids, doc_ids = read_data_from_file(path="sample.txt")
# Sort the distinct query ids so their order, and hence the train/test split
# derived from it, is reproducible. Iteration order of a set of strings varies
# between interpreter runs because of string hash randomization.
qids_unique = sorted(set(qids))

print("#queries: ", len(qids_unique))
print("#query-doc pairs: ", len(y))


#queries:  339
#query-doc pairs:  14013


#### Split data into train and test sets (80% and 20% of the queries, respectively)

In [61]:
# Query-level split: every 5th distinct query (by position in qids_unique)
# is held out for testing, so roughly 20% of the queries form the test set.
train_qids = []
test_qids = []

for i, qid in enumerate(qids_unique):
    if i % 5 == 0:  # test query
        test_qids.append(qid)
    else:  # train query
        train_qids.append(qid)

# A set gives O(1) membership tests in the per-pair loop below; testing
# against the list would cost O(#queries) for every query-doc pair.
train_qid_set = set(train_qids)

train_X, train_y = [], []
test_X, test_y = [], []

for features, label, pair_qid in zip(X, y, qids):
    if pair_qid in train_qid_set:
        train_X.append(features)
        train_y.append(label)
    else:
        test_X.append(features)
        test_y.append(label)

#### Create a regression model and an LTR instance based on that

In [65]:
# Shallow random forest (trees limited to depth 3); the fixed random_state
# means repeated runs build the same forest.
clf = RandomForestRegressor(random_state=0, max_depth=3)
# Wrap the regressor in the pointwise LTR model.
ltr = PointWiseLTRModel(regressor=clf)

#### Train LTR model

In [66]:
ltr._train(X=train_X, y=train_y)  # fit the regressor on the training queries only

#### Generate ranking for a test query

In [67]:
qid = test_qids[0]  # first test query
# Collect the feature vectors and doc_ids of the documents to be ranked for
# this query. Each comprehension scans every query-doc pair, i.e. O(#pairs)
# per query; grouping the pairs by qid once would be cheaper when ranking many
# queries.
# NOTE(review): this rebinds test_X, discarding the full test-set features
# built during the train/test split; use a distinct name if they are needed later.
test_X = [features for features, pair_qid in zip(X, qids) if pair_qid == qid]
test_doc_ids = [doc_id for doc_id, pair_qid in zip(doc_ids, qids) if pair_qid == qid]

r = ltr.rank(test_X, test_doc_ids)
print(r)

[('GX066-05-2211546', 0.55007501419168225), ('GX111-95-16124960', 0.55007501419168225), ('GX265-69-3114302', 0.5130774146669167), ('GX228-89-6124858', 0.47959854798001267), ('GX266-95-13746921', 0.43404533839445475), ('GX234-63-15348127', 0.29647123487437876), ('GX268-49-8192130', 0.29647123487437876), ('GX230-05-2463816', 0.29480247720759362), ('GX240-58-7360101', 0.29480247720759362), ('GX262-53-14945504', 0.29480247720759362), ('GX264-67-6899645', 0.29480247720759362), ('GX272-67-15456013', 0.29480247720759362), ('GX269-14-11574107', 0.29480247720759362), ('GX269-71-1099993', 0.29480247720759362), ('GX266-50-14611773', 0.29398510769686026), ('GX266-83-14646051', 0.28232082058795138), ('GX028-49-4353723', 0.28089334995282539), ('GX031-89-3936030', 0.28089334995282539), ('GX269-15-14813735', 0.27822426863760763), ('GX021-68-10634771', 0.26912890135142448), ('GX059-88-5222624', 0.26746212138462755), ('GX266-22-15189670', 0.2649759942071091), ('GX235-93-4773725', 0.2649759942071091), ('