In [1]:
import argparse
import datetime
import glob
import hashlib
import json
import multiprocessing
import pickle
import os
import shutil
import subprocess
import uuid
import random

import numpy as np
import pandas as pd
import lightgbm as lgb
from collections import defaultdict
from lightgbm.sklearn import LGBMRanker
from tqdm import tqdm

from pyserini.analysis import Analyzer, get_lucene_analyzer
from pyserini.ltr import *
from pyserini.search import get_topics_with_reader

def train_data_loader(task='triple', neg_sample=10, random_seed=12345):
    if os.path.exists(f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle'):
        sampled_train = pd.read_pickle(f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle')
        print(sampled_train.shape)
        print(sampled_train.index.get_level_values('qid').drop_duplicates().shape)
        print(sampled_train.groupby('qid').count().mean())
        print(sampled_train.head(10))
        print(sampled_train.info())
        return sampled_train
    else:
        if task == 'triple':
            train = pd.read_csv('../collections/msmarco-passage/qidpidtriples.train.full.2.tsv', sep="\t",
                                names=['qid', 'pos_pid', 'neg_pid'], dtype=np.int32)
            pos_half = train[['qid', 'pos_pid']].rename(columns={"pos_pid": "pid"}).drop_duplicates()
            pos_half['rel'] = np.int32(1)
            neg_half = train[['qid', 'neg_pid']].rename(columns={"neg_pid": "pid"}).drop_duplicates()
            neg_half['rel'] = np.int32(0)
            del train
            sampled_neg_half = []
            for qid, group in tqdm(neg_half.groupby('qid')):
                sampled_neg_half.append(group.sample(n=min(neg_sample, len(group)), random_state=random_seed))
            sampled_train = pd.concat([pos_half] + sampled_neg_half, axis=0, ignore_index=True)
            sampled_train = sampled_train.sort_values(['qid','pid']).set_index(['qid','pid'])
            print(sampled_train.shape)
            print(sampled_train.index.get_level_values('qid').drop_duplicates().shape)
            print(sampled_train.groupby('qid').count().mean())
            print(sampled_train.head(10))
            print(sampled_train.info())

            sampled_train.to_pickle(f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle')
        elif task == 'rank':
            qrel = defaultdict(list)
            with open("../collections/msmarco-passage/qrels.train.tsv") as f:
                for line in f:
                    topicid, _, docid, rel = line.strip().split('\t')
                    assert rel == "1", line.split(' ')
                    qrel[topicid].append(docid)
            
            qid2pos = defaultdict(list)
            qid2neg = defaultdict(list)
            with open("../runs/msmarco-passage/run.train.small.tsv") as f:
                for line in tqdm(f):
                    topicid, docid, rank = line.split()
                    assert topicid in qrel
                    if docid in qrel[topicid]:
                        qid2pos[topicid].append(docid)
                    else:
                        qid2neg[topicid].append(docid)
            sampled_train = []
            for topicid, pos_list in tqdm(qid2pos.items()):
                neg_list = random.sample(qid2neg[topicid], min(len(qid2neg[topicid]), neg_sample))
                for positive_docid in pos_list:
                    sampled_train.append((int(topicid), int(positive_docid), 1))
                for negative_docid in neg_list:
                    sampled_train.append((int(topicid), int(negative_docid), 0))
            sampled_train = pd.DataFrame(sampled_train,columns=['qid','pid','rel'],dtype=np.int32)
            sampled_train = sampled_train.sort_values(['qid','pid']).set_index(['qid','pid'])
            print(sampled_train.shape)
            print(sampled_train.index.get_level_values('qid').drop_duplicates().shape)
            print(sampled_train.groupby('qid').count().mean())
            print(sampled_train.head(10))
            print(sampled_train.info())

            sampled_train.to_pickle(f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle')
        else:
            raise Exception('unknown parameters')
        return sampled_train
sampled_train = train_data_loader(task='triple', neg_sample=10)

(4416413, 1)
(400782,)
rel    11.019489
dtype: float64
             rel
qid pid         
3   970816     0
    1142680    1
    2019206    0
    2605131    0
    2963098    0
    2971685    0
    3783924    0
    5067083    0
    5904778    0
    6176208    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 4416413 entries, (3, 970816) to (1185869, 7770561)
Data columns (total 1 columns):
 #   Column  Dtype
---  ------  -----
 0   rel     int32
dtypes: int32(1)
memory usage: 166.9 MB
None


In [2]:
def dev_data_loader(task='pygaggle'):
    if os.path.exists(f'dev_{task}.pickle'):
        dev = pd.read_pickle(f'dev_{task}.pickle')
        print(dev.shape)
        print(dev.index.get_level_values('qid').drop_duplicates().shape)
        print(dev.groupby('qid').count().mean())
        print(dev.head(10))
        print(dev.info())
        dev_qrel = pd.read_pickle(f'dev_qrel.pickle')
        return dev, dev_qrel
    else:
        if task == 'rerank':
            dev = pd.read_csv('../collections/msmarco-passage/top1000.dev', sep="\t",
                              names=['qid', 'pid', 'query', 'doc'], usecols=['qid', 'pid'], dtype=np.int32)
        elif task == 'anserini':
            dev = pd.read_csv('../runs/msmarco-passage-expanded/run.msmarco-passage-expanded.dev.small.txt',sep="\t",
                            names=['qid','pid','rank'], dtype=np.int32)
        else:
            raise Exception('unknown parameters')
        dev_qrel = pd.read_csv('../collections/msmarco-passage/qrels.dev.small.tsv', sep="\t",
                               names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'], dtype=np.int32)
        dev = dev.merge(dev_qrel, left_on=['qid', 'pid'], right_on=['qid', 'pid'], how='left')
        dev['rel'] = dev['rel'].fillna(0).astype(np.int32)
        dev = dev.sort_values(['qid','pid']).set_index(['qid','pid'])
        
        print(dev.shape)
        print(dev.index.get_level_values('qid').drop_duplicates().shape)
        print(dev.groupby('qid').count().mean())
        print(dev.head(10))
        print(dev.info())

        dev.to_pickle(f'dev_{task}.pickle')
        dev_qrel.to_pickle(f'dev_qrel.pickle')
        return dev, dev_qrel
dev, dev_qrel = dev_data_loader(task='anserini')

(6974999, 2)
(6980,)
rank    999.283524
rel     999.283524
dtype: float64
           rank  rel
qid pid             
2   35571   745    0
    55860   657    0
    62497   354    0
    63138   979    0
    72210   961    0
    98587   316    0
    98589   256    0
    98590   425    0
    98592   637    0
    98595   340    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 6974999 entries, (2, 35571) to (1102400, 8841784)
Data columns (total 2 columns):
 #   Column  Dtype
---  ------  -----
 0   rank    int32
 1   rel     int32
dtypes: int32(2)
memory usage: 286.8 MB
None


In [3]:
def query_loader(choice='default'):
    if os.path.exists(f'query_{choice}_tokenized.pickle'):
        return pickle.load(open(f'query_{choice}_tokenized.pickle','rb'))
    else:
        if choice == 'default':
            analyzer = Analyzer(get_lucene_analyzer())
            nonStopAnalyzer = Analyzer(get_lucene_analyzer(stopwords=False))
            queries = get_topics_with_reader('io.anserini.search.topicreader.TsvIntTopicReader', \
                                             '../collections/msmarco-passage/queries.train.tsv')
            queries.update(get_topics_with_reader('io.anserini.search.topicreader.TsvIntTopicReader', \
                                                  '../collections/msmarco-passage/queries.dev.tsv'))
            for qid,value in queries.items():
                assert 'tokenized' not in value
                value['tokenized'] = analyzer.analyze(value['title'])
                assert 'nonSW' not in value
                value['nonSW'] = nonStopAnalyzer.analyze(value['title'])
        else:
            raise Exception('unknown parameters')

        pickle.dump(queries,open(f'query_{choice}_tokenized.pickle','wb'))

        return queries
queries = query_loader()

In [4]:
fe = FeatureExtractor('../indexes/msmarco-passage-expanded/lucene-index-msmarco-passage-expanded_combined//',max(multiprocessing.cpu_count()//2,1))
fe.add(BM25(k1=0.9,b=0.4))
fe.add(BM25(k1=1.2,b=0.75))
fe.add(BM25(k1=2.0,b=0.75))

fe.add(LMDir(mu=1000))
fe.add(LMDir(mu=1500))
fe.add(LMDir(mu=2500))

fe.add(LMJM(0.1))
fe.add(LMJM(0.4))
fe.add(LMJM(0.7))

fe.add(NTFIDF())
fe.add(ProbalitySum())

fe.add(DFR_GL2())
fe.add(DFR_In_expB2())
fe.add(DPH())

# fe.add(ContextDFR_GL2(AvgPooler()))
# fe.add(ContextDFR_GL2(VarPooler()))
# fe.add(ContextDFR_In_expB2(AvgPooler()))
# fe.add(ContextDFR_In_expB2(VarPooler()))
# fe.add(ContextDPH(AvgPooler()))
# fe.add(ContextDPH(VarPooler()))

fe.add(Proximity())
fe.add(TPscore())
fe.add(tpDist())
# fe.add(SDM())

fe.add(DocSize())
fe.add(Entropy())
fe.add(StopCover())
fe.add(StopRatio())

fe.add(QueryLength())
fe.add(QueryLengthNonStopWords())
fe.add(QueryCoverageRatio())
fe.add(UniqueTermCount())
fe.add(MatchingTermCount())
fe.add(SCS())

fe.add(tfStat(AvgPooler()))
fe.add(tfStat(SumPooler()))
fe.add(tfStat(MinPooler()))
fe.add(tfStat(MaxPooler()))
fe.add(tfStat(VarPooler()))
fe.add(tfIdfStat(AvgPooler()))
fe.add(tfIdfStat(SumPooler()))
fe.add(tfIdfStat(MinPooler()))
fe.add(tfIdfStat(MaxPooler()))
fe.add(tfIdfStat(VarPooler()))
fe.add(scqStat(AvgPooler()))
fe.add(scqStat(SumPooler()))
fe.add(scqStat(MinPooler()))
fe.add(scqStat(MaxPooler()))
fe.add(scqStat(VarPooler()))
fe.add(normalizedTfStat(AvgPooler()))
fe.add(normalizedTfStat(SumPooler()))
fe.add(normalizedTfStat(MinPooler()))
fe.add(normalizedTfStat(MaxPooler()))
fe.add(normalizedTfStat(VarPooler()))
# fe.add(normalizedDocSizeStat(AvgPooler()))
# fe.add(normalizedDocSizeStat(SumPooler()))
# fe.add(normalizedDocSizeStat(MinPooler()))
# fe.add(normalizedDocSizeStat(MaxPooler()))
# fe.add(normalizedDocSizeStat(VarPooler()))

fe.add(idfStat(AvgPooler()))
fe.add(idfStat(SumPooler()))
fe.add(idfStat(MinPooler()))
fe.add(idfStat(MaxPooler()))
fe.add(idfStat(VarPooler()))
fe.add(idfStat(MaxMinRatioPooler()))
fe.add(idfStat(ConfidencePooler()))
fe.add(ictfStat(AvgPooler()))
fe.add(ictfStat(SumPooler()))
fe.add(ictfStat(MinPooler()))
fe.add(ictfStat(MaxPooler()))
fe.add(ictfStat(VarPooler()))
fe.add(ictfStat(MaxMinRatioPooler()))
fe.add(ictfStat(ConfidencePooler()))

fe.add(UnorderedSequentialPairs(3))
fe.add(UnorderedSequentialPairs(8))
fe.add(UnorderedSequentialPairs(15))
fe.add(OrderedSequentialPairs(3))
fe.add(OrderedSequentialPairs(8))
fe.add(OrderedSequentialPairs(15))
fe.add(UnorderedQueryPairs(3))
fe.add(UnorderedQueryPairs(8))
fe.add(UnorderedQueryPairs(15))
fe.add(OrderedQueryPairs(3))
fe.add(OrderedQueryPairs(8))
fe.add(OrderedQueryPairs(15))

In [5]:
def extract(df, queries, fe):
    df_pieces = []
    fetch_later = []
    qidpid2rel = defaultdict(dict)
    need_rows = 0
    for qid,group in tqdm(df.groupby('qid')):
        for t in group.reset_index().itertuples():
            assert t.pid not in qidpid2rel[t.qid]
            qidpid2rel[t.qid][t.pid] = t.rel
            need_rows += 1
        #test.py has bug here, it does not convert pid to str, not sure why it does not cause problem in java
        fe.lazy_extract(str(qid),
                        queries[qid]['nonSW'], 
                        queries[qid]['tokenized'],
                        [str(pid) for pid in qidpid2rel[t.qid].keys()])
        fetch_later.append(str(qid))
        if len(fetch_later) == 10000:
            info = np.zeros(shape=(need_rows,3), dtype=np.int32)
            feature = np.zeros(shape=(need_rows,len(fe.feature_names())), dtype=np.float32)
            idx = 0
            for qid in fetch_later:
                for doc in fe.get_result(qid):
                    info[idx,0] = int(qid)
                    info[idx,1] = int(doc['pid'])
                    info[idx,2] = qidpid2rel[int(qid)][int(doc['pid'])]
                    feature[idx,:] = doc['features']
                    idx += 1
            info = pd.DataFrame(info, columns=['qid','pid','rel'])
            feature = pd.DataFrame(feature, columns=fe.feature_names())
            df_pieces.append(pd.concat([info,feature], axis=1))
            fetch_later = []
            need_rows = 0
    #deal with rest
    if len(fetch_later) > 0:
        info = np.zeros(shape=(need_rows,3), dtype=np.int32)
        feature = np.zeros(shape=(need_rows,len(fe.feature_names())), dtype=np.float32)
        idx = 0
        for qid in fetch_later:
            for doc in fe.get_result(qid):
                info[idx,0] = int(qid)
                info[idx,1] = int(doc['pid'])
                info[idx,2] = qidpid2rel[int(qid)][int(doc['pid'])]
                feature[idx,:] = doc['features']
                idx += 1
        info = pd.DataFrame(info, columns=['qid','pid','rel'])
        feature = pd.DataFrame(feature, columns=fe.feature_names())
        df_pieces.append(pd.concat([info,feature], axis=1))
    data = pd.concat(df_pieces, axis=0, ignore_index=True)
    data = data.sort_values(by='qid', kind='mergesort')
    group = data.groupby('qid').agg(count=('pid', 'count'))['count']
    return data,group

In [6]:
def hash_df(df):
    h = pd.util.hash_pandas_object(df)
    return hex(h.sum().astype(np.uint64))


def hash_anserini_jar():
    find = glob.glob(os.environ['ANSERINI_CLASSPATH'] + "/*fatjar.jar")
    assert len(find) == 1
    md5Hash = hashlib.md5(open(find[0], 'rb').read())
    return md5Hash.hexdigest()


def hash_fe(fe):
    return hashlib.md5(','.join(sorted(fe.feature_names())).encode()).hexdigest()


def data_loader(task, df, queries, fe):
    df_hash = hash_df(df)
    jar_hash = hash_anserini_jar()
    fe_hash = hash_fe(fe)
    print(f'check {task}_{df_hash}_{jar_hash}_{fe_hash}.pickle')
    if os.path.exists(f'{task}_{df_hash}_{jar_hash}_{fe_hash}.pickle'):
        res = pickle.load(open(f'{task}_{df_hash}_{jar_hash}_{fe_hash}.pickle','rb'))
        print(res['data'].shape)
        print(res['data'].qid.drop_duplicates().shape)
        print(res['group'].mean())
        print(res['data'].head(10))
        print(res['data'].info())
        return res
    else:
        if task == 'train' or task == 'dev': 
            data,group = extract(df, queries, fe)
            obj = {'data':data,'group':group,'df_hash':df_hash,'jar_hash':jar_hash,'fe_hash':fe_hash}
            print(data.shape)
            print(data.qid.drop_duplicates().shape)
            print(group.mean())
            print(data.head(10))
            print(data.info())
            pickle.dump(obj,open(f'{task}_{df_hash}_{jar_hash}_{fe_hash}.pickle','wb'))
            return obj
        else:
            raise Exception('unknown parameters')

In [7]:
import json
def export(df, fn):
    with open(fn,'w') as f:
        for qid, group in tqdm(df.groupby('qid')):
            line = {}
            line['qid'] = qid
            line['queryTokens'] = queries[qid]['tokenized']
            line['queryText'] = queries[qid]['nonSW']
            line['docIds'] = [str(did) for did in group.reset_index().pid.drop_duplicates().tolist()]
            f.write(json.dumps(line)+'\n')

In [8]:
train_extracted = data_loader('train', sampled_train, queries, fe)
# export(sampled_train, 'sampled_train_export.json')
dev_extracted = data_loader('dev', dev, queries, fe)
# export(dev, 'sampled_dev_export.json')
del sampled_train, dev

check train_0xa0471fd494eedb15_ff7f12482f8f0fe1a91997fdb7ce8f59_b89200e852f02ebe37e423d4c4eed8ff.pickle
(4416413, 76)
(400782,)
11.019489398226467
   qid      pid  rel  contents_BM25_k1_0.90_b_0.40  \
0    3   970816    0                     28.065395   
1    3  1142680    1                     30.463688   
2    3  2019206    0                     26.527094   
3    3  2605131    0                     20.248068   
4    3  2963098    0                     27.040712   
5    3  2971685    0                     21.828558   
6    3  3783924    0                     22.618692   
7    3  5067083    0                     13.735486   
8    3  5904778    0                     11.439550   
9    3  6176208    0                     21.785690   

   contents_BM25_k1_1.20_b_0.75  contents_BM25_k1_2.00_b_0.75  \
0                     31.157480                     38.675735   
1                     34.685135                     45.659901   
2                     29.210737                     35.150719  

check dev_0xcd314f29dd04512c_ff7f12482f8f0fe1a91997fdb7ce8f59_b89200e852f02ebe37e423d4c4eed8ff.pickle
(6974999, 76)
(6980,)
999.2835243553009
   qid    pid  rel  contents_BM25_k1_0.90_b_0.40  \
0    2  35571    0                     14.072955   
1    2  55860    0                     14.509807   
2    2  62497    0                     15.519955   
3    2  63138    0                     13.816697   
4    2  72210    0                     13.710751   
5    2  98587    0                     15.653041   
6    2  98589    0                     15.776281   
7    2  98590    0                     15.307396   
8    2  98592    0                     14.562378   
9    2  98595    0                     15.583447   

   contents_BM25_k1_1.20_b_0.75  contents_BM25_k1_2.00_b_0.75  \
0                     16.006453                     19.918795   
1                     16.314312                     20.537521   
2                     17.792782                     23.548132   
3                     15.

In [9]:
def eval_mrr(dev_data):
    score_tie_counter = 0
    score_tie_query = set()

    MRR = []
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # stable sort is also used in LightGBM

        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            prev_pid = t.pid
            rank += 1
            if t.rel>0:
                MRR.append(1.0/rank)
                break
            elif rank == 10 or rank == len(group):
                MRR.append(0.)
                break

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    print(np.mean(MRR))


In [10]:
def eval_recall(dev_qrel, dev_data):
    dev_rel_num = dev_qrel[dev_qrel['rel']>0].groupby('qid').count()['rel']

    score_tie_counter = 0
    score_tie_query = set()
    
    recall_point = [10,20,50,100,200,500,1000]
    recall_curve = {k:[] for k in recall_point}
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        rank = 0
        prev_score = None
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # stable sort is also used in LightGBM
        total_rel = dev_rel_num.loc[qid]
        query_recall = [0 for k in recall_point]
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel>0:
                for i,p in enumerate(recall_point):
                    if rank <= p:
                        query_recall[i] += 1
        for i,p in enumerate(recall_point):
            if total_rel>0:
                recall_curve[p].append(query_recall[i]/total_rel)
            else:
                recall_curve[p].append(0.)

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)
    
    for k,v in recall_curve.items():
        avg = np.mean(v)
        print(f'recall@{k}:{avg}')


In [11]:
feature_name = fe.feature_names()
train_X = train_extracted['data'].loc[:, feature_name]
train_Y = train_extracted['data']['rel']
dev_X = dev_extracted['data'].loc[:, feature_name]
dev_Y = dev_extracted['data']['rel']
lgb_train = lgb.Dataset(train_X,label=train_Y,group=train_extracted['group'])
lgb_valid = lgb.Dataset(dev_X,label=dev_Y,group=dev_extracted['group'])

In [12]:
for i,n in enumerate(feature_name):
    if np.isnan(train_X.iloc[:,i]).any():
        print(n)

contents_DPH
Entropy
contents_IDF_confidence
contents_ICTF_confidence


In [13]:
dev_rel_num = dev_qrel[dev_qrel['rel']>0].groupby('qid').count()['rel']
gid = 0
sample_id = 0
dev_group_rel_num = []
for qid,group in dev_extracted['data'].groupby('qid'):
    group = group.sort_values(['pid'])
    assert len(group) == dev_extracted['group'].iloc[gid]
    assert np.isclose(group.iloc[0,:].loc[feature_name],
                      dev_X.iloc[sample_id,:], equal_nan=True).all()
    dev_group_rel_num.append(dev_rel_num.loc[qid])
    gid += 1
    sample_id += len(group)
dev_group=dev_extracted['group']

In [14]:
def recall_at_200(preds, dataset):
    global dev_group_rel_num
    global dev_group
    labels = dataset.get_label()
    groups = dataset.get_group()
    assert np.equal(groups, dev_group).all()
    idx = 0
    recall = 0
    for g,gnum in zip(groups, dev_group_rel_num):
        top_preds = labels[idx:idx+g][np.argsort(preds[idx:idx+g])]
        recall += np.sum(top_preds[-200:])/gnum
        idx += g
    assert idx == len(preds)
    return 'recall@200', recall/len(groups), True

In [12]:
params = {
    'boosting_type': 'gbdt',
    'objective': 'lambdarank',
    'max_bin':255,
    'num_leaves':63,
    'max_depth':-1,
    'min_data_in_leaf':50,
    'min_sum_hessian_in_leaf':0,
    'bagging_fraction':0.8,
    'bagging_freq':50,
    'feature_fraction':1,
    'learning_rate':0.1,
    'num_boost_round':1000,
    'metric':['map'],
    'eval_at':[10],
    'label_gain':[0,1],
    'lambdarank_truncation_level':20,
    'force_col_wise':True,
    'seed':12345,
    'num_threads':max(multiprocessing.cpu_count()//2,1)
}

num_boost_round = params.pop('num_boost_round')
eval_results={}
cv_gbm = lgb.cv(params, lgb_train, nfold=5, 
                num_boost_round=num_boost_round,
                feature_name=feature_name,
                verbose_eval=False,
                return_cvbooster=True)
dev_extracted['data']['score'] = 0.
for gbm in cv_gbm['cvbooster'].boosters:
    dev_extracted['data']['score']+=gbm.predict(dev_X)
feature_importances = sorted(list(zip(feature_name,gbm.feature_importance().tolist())),key=lambda x:x[1],reverse=True)
print(feature_importances)
eval_recall(dev_qrel, dev_extracted['data'])
eval_mrr(dev_extracted['data'])

[LightGBM] [Info] Total groups: 320625, total data: 3533129
[LightGBM] [Info] Total Bins 14680
[LightGBM] [Info] Number of data points in the train set: 3533129, number of used features: 65
[LightGBM] [Info] Total groups: 80157, total data: 883284
[LightGBM] [Info] Total groups: 320626, total data: 3533131
[LightGBM] [Info] Total Bins 14680
[LightGBM] [Info] Number of data points in the train set: 3533131, number of used features: 65
[LightGBM] [Info] Total groups: 80156, total data: 883282
[LightGBM] [Info] Total groups: 320626, total data: 3533131
[LightGBM] [Info] Total Bins 14680
[LightGBM] [Info] Number of data points in the train set: 3533131, number of used features: 65
[LightGBM] [Info] Total groups: 80156, total data: 883282
[LightGBM] [Info] Total groups: 320626, total data: 3533131
[LightGBM] [Info] Total Bins 14680
[LightGBM] [Info] Number of data points in the train set: 3533131, number of used features: 65
[LightGBM] [Info] Total groups: 80156, total data: 883282
[LightGB

100%|██████████| 6980/6980 [02:41<00:00, 43.13it/s] 


score_tie occurs 270979 times in 2013 queries
recall@10:0.5749880611270296
recall@20:0.6790472779369627
recall@50:0.7868194842406877
recall@100:0.8449140401146131
recall@200:0.8901981852913085
recall@500:0.9321513849092646
recall@1000:0.9470869149952243


100%|██████████| 6980/6980 [01:16<00:00, 91.79it/s] 

score_tie occurs 24 times in 9 queries 0.2974463887751853





In [15]:
dev_extracted['data']['score'] = 0.
for seed in [12345,25523,23543,12352,64654,23123]:
    params = {
            'boosting_type': 'gbdt',
            'objective': 'lambdarank',
            'max_bin':255,
            'num_leaves':63,
            'max_depth':-1,
            'min_data_in_leaf':30,
            'min_sum_hessian_in_leaf':0,
            'bagging_fraction':0.8,
            'bagging_freq':50,
            'feature_fraction':1,
            'learning_rate':0.1,
            'num_boost_round':1000,
            'early_stopping_round':200,
            'metric':'custom',
            'label_gain':[0,1],
            'lambdarank_truncation_level':20,
            'seed':seed,
            'num_threads':max(multiprocessing.cpu_count()//2,1)
    }
    num_boost_round = params.pop('num_boost_round')
    early_stopping_round = params.pop('early_stopping_round')
    gbm = lgb.train(params, lgb_train,
                    valid_sets=lgb_valid,
                    num_boost_round=num_boost_round,
                    early_stopping_rounds=early_stopping_round,
                    feval=recall_at_200,
                    feature_name=feature_name,
                    verbose_eval=True)
    dev_extracted['data']['score'] += gbm.predict(dev_X)
    best_score = gbm.best_score['valid_0']['recall@200']
    print(best_score)
    best_iteration = gbm.best_iteration
    print(best_iteration)
    feature_importances = sorted(list(zip(feature_name,gbm.feature_importance().tolist())),
                                 key=lambda x:x[1],reverse=True)
    print(feature_importances)
eval_recall(dev_qrel, dev_extracted['data'])
eval_mrr(dev_extracted['data'])

You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 14680
[LightGBM] [Info] Number of data points in the train set: 4416413, number of used features: 65
[1]	valid_0's recall@200: 0.85394
Training until validation scores don't improve for 200 rounds
[2]	valid_0's recall@200: 0.867467
[3]	valid_0's recall@200: 0.874809
[4]	valid_0's recall@200: 0.878415
[5]	valid_0's recall@200: 0.8782
[6]	valid_0's recall@200: 0.878629
[7]	valid_0's recall@200: 0.878665
[8]	valid_0's recall@200: 0.879
[9]	valid_0's recall@200: 0.880253
[10]	valid_0's recall@200: 0.881865
[11]	valid_0's recall@200: 0.882653
[12]	valid_0's recall@200: 0.883226
[13]	valid_0's recall@200: 0.88356
[14]	valid_0's recall@200: 0.883274
[15]	valid_0's recall@200: 0.883405
[16]	valid_0's recall@200: 0.883739
[17]	valid_0's recall@200: 0.883775
[18]	valid_0's recall@200: 0.88405
[19]	valid_0's recall@200: 0.884229
[20]	valid_0's reca

[215]	valid_0's recall@200: 0.889076
[216]	valid_0's recall@200: 0.889219
[217]	valid_0's recall@200: 0.889291
[218]	valid_0's recall@200: 0.889291
[219]	valid_0's recall@200: 0.889291
[220]	valid_0's recall@200: 0.889291
[221]	valid_0's recall@200: 0.889434
[222]	valid_0's recall@200: 0.889434
[223]	valid_0's recall@200: 0.889577
[224]	valid_0's recall@200: 0.889649
[225]	valid_0's recall@200: 0.889362
[226]	valid_0's recall@200: 0.889219
[227]	valid_0's recall@200: 0.889362
[228]	valid_0's recall@200: 0.889506
[229]	valid_0's recall@200: 0.889434
[230]	valid_0's recall@200: 0.889219
[231]	valid_0's recall@200: 0.889219
[232]	valid_0's recall@200: 0.889219
[233]	valid_0's recall@200: 0.889362
[234]	valid_0's recall@200: 0.889362
[235]	valid_0's recall@200: 0.889291
[236]	valid_0's recall@200: 0.889506
[237]	valid_0's recall@200: 0.889649
[238]	valid_0's recall@200: 0.889577
[239]	valid_0's recall@200: 0.889434
[240]	valid_0's recall@200: 0.889434
[241]	valid_0's recall@200: 0.889434
[

[438]	valid_0's recall@200: 0.888754
[439]	valid_0's recall@200: 0.888754
[440]	valid_0's recall@200: 0.888754
[441]	valid_0's recall@200: 0.888754
[442]	valid_0's recall@200: 0.888754
[443]	valid_0's recall@200: 0.888754
[444]	valid_0's recall@200: 0.888754
[445]	valid_0's recall@200: 0.888754
[446]	valid_0's recall@200: 0.888968
[447]	valid_0's recall@200: 0.888968
[448]	valid_0's recall@200: 0.888968
[449]	valid_0's recall@200: 0.888968
[450]	valid_0's recall@200: 0.888968
[451]	valid_0's recall@200: 0.888968
[452]	valid_0's recall@200: 0.889112
[453]	valid_0's recall@200: 0.889112
[454]	valid_0's recall@200: 0.889112
[455]	valid_0's recall@200: 0.889112
[456]	valid_0's recall@200: 0.889255
[457]	valid_0's recall@200: 0.889255
[458]	valid_0's recall@200: 0.889255
[459]	valid_0's recall@200: 0.889255
[460]	valid_0's recall@200: 0.889112
[461]	valid_0's recall@200: 0.889255
[462]	valid_0's recall@200: 0.889398
[463]	valid_0's recall@200: 0.889613
[464]	valid_0's recall@200: 0.889685
[

[99]	valid_0's recall@200: 0.887572
[100]	valid_0's recall@200: 0.887572
[101]	valid_0's recall@200: 0.887858
[102]	valid_0's recall@200: 0.888216
[103]	valid_0's recall@200: 0.887787
[104]	valid_0's recall@200: 0.887428
[105]	valid_0's recall@200: 0.887572
[106]	valid_0's recall@200: 0.888145
[107]	valid_0's recall@200: 0.888145
[108]	valid_0's recall@200: 0.888073
[109]	valid_0's recall@200: 0.888288
[110]	valid_0's recall@200: 0.888252
[111]	valid_0's recall@200: 0.888181
[112]	valid_0's recall@200: 0.888037
[113]	valid_0's recall@200: 0.888037
[114]	valid_0's recall@200: 0.888037
[115]	valid_0's recall@200: 0.888181
[116]	valid_0's recall@200: 0.888324
[117]	valid_0's recall@200: 0.888324
[118]	valid_0's recall@200: 0.888395
[119]	valid_0's recall@200: 0.887751
[120]	valid_0's recall@200: 0.887751
[121]	valid_0's recall@200: 0.888037
[122]	valid_0's recall@200: 0.887894
[123]	valid_0's recall@200: 0.888181
[124]	valid_0's recall@200: 0.888109
[125]	valid_0's recall@200: 0.888467
[1

[321]	valid_0's recall@200: 0.889577
[322]	valid_0's recall@200: 0.889434
[323]	valid_0's recall@200: 0.889434
[324]	valid_0's recall@200: 0.889434
[325]	valid_0's recall@200: 0.889649
[326]	valid_0's recall@200: 0.889649
[327]	valid_0's recall@200: 0.889649
[328]	valid_0's recall@200: 0.889649
[329]	valid_0's recall@200: 0.889792
[330]	valid_0's recall@200: 0.889792
[331]	valid_0's recall@200: 0.889506
[332]	valid_0's recall@200: 0.889936
[333]	valid_0's recall@200: 0.889936
[334]	valid_0's recall@200: 0.889936
[335]	valid_0's recall@200: 0.889936
[336]	valid_0's recall@200: 0.889792
[337]	valid_0's recall@200: 0.889792
[338]	valid_0's recall@200: 0.889506
[339]	valid_0's recall@200: 0.889362
[340]	valid_0's recall@200: 0.889362
[341]	valid_0's recall@200: 0.889076
[342]	valid_0's recall@200: 0.889076
[343]	valid_0's recall@200: 0.889076
[344]	valid_0's recall@200: 0.889076
[345]	valid_0's recall@200: 0.888933
[346]	valid_0's recall@200: 0.888933
[347]	valid_0's recall@200: 0.888933
[

You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 14680
[LightGBM] [Info] Number of data points in the train set: 4416413, number of used features: 65
[1]	valid_0's recall@200: 0.852352
Training until validation scores don't improve for 200 rounds
[2]	valid_0's recall@200: 0.868469
[3]	valid_0's recall@200: 0.872302
[4]	valid_0's recall@200: 0.875167
[5]	valid_0's recall@200: 0.877674
[6]	valid_0's recall@200: 0.879083
[7]	valid_0's recall@200: 0.878701
[8]	valid_0's recall@200: 0.880062
[9]	valid_0's recall@200: 0.879799
[10]	valid_0's recall@200: 0.880492
[11]	valid_0's recall@200: 0.881925
[12]	valid_0's recall@200: 0.882223
[13]	valid_0's recall@200: 0.883369
[14]	valid_0's recall@200: 0.883966
[15]	valid_0's recall@200: 0.883465
[16]	valid_0's recall@200: 0.884133
[17]	valid_0's recall@200: 0.884396
[18]	valid_0's recall@200: 0.884348
[19]	valid_0's recall@200: 0.883775
[20]	valid_

[215]	valid_0's recall@200: 0.889613
[216]	valid_0's recall@200: 0.889613
[217]	valid_0's recall@200: 0.889685
[218]	valid_0's recall@200: 0.889542
[219]	valid_0's recall@200: 0.889542
[220]	valid_0's recall@200: 0.889398
[221]	valid_0's recall@200: 0.889255
[222]	valid_0's recall@200: 0.889255
[223]	valid_0's recall@200: 0.889255
[224]	valid_0's recall@200: 0.889255
[225]	valid_0's recall@200: 0.889398
[226]	valid_0's recall@200: 0.889398
[227]	valid_0's recall@200: 0.889542
[228]	valid_0's recall@200: 0.889542
[229]	valid_0's recall@200: 0.889542
[230]	valid_0's recall@200: 0.889542
[231]	valid_0's recall@200: 0.889542
[232]	valid_0's recall@200: 0.889685
[233]	valid_0's recall@200: 0.889828
[234]	valid_0's recall@200: 0.889828
[235]	valid_0's recall@200: 0.889971
[236]	valid_0's recall@200: 0.889828
[237]	valid_0's recall@200: 0.889828
[238]	valid_0's recall@200: 0.889828
[239]	valid_0's recall@200: 0.889828
[240]	valid_0's recall@200: 0.889828
[241]	valid_0's recall@200: 0.8899
[24

[438]	valid_0's recall@200: 0.89095
[439]	valid_0's recall@200: 0.89095
[440]	valid_0's recall@200: 0.89095
[441]	valid_0's recall@200: 0.890807
[442]	valid_0's recall@200: 0.890807
[443]	valid_0's recall@200: 0.890915
[444]	valid_0's recall@200: 0.890915
[445]	valid_0's recall@200: 0.890915
[446]	valid_0's recall@200: 0.890915
[447]	valid_0's recall@200: 0.890915
[448]	valid_0's recall@200: 0.890771
[449]	valid_0's recall@200: 0.890485
[450]	valid_0's recall@200: 0.890485
[451]	valid_0's recall@200: 0.890485
[452]	valid_0's recall@200: 0.890485
[453]	valid_0's recall@200: 0.890485
[454]	valid_0's recall@200: 0.890485
[455]	valid_0's recall@200: 0.890485
[456]	valid_0's recall@200: 0.890485
[457]	valid_0's recall@200: 0.890485
[458]	valid_0's recall@200: 0.890485
[459]	valid_0's recall@200: 0.890485
[460]	valid_0's recall@200: 0.890485
[461]	valid_0's recall@200: 0.890485
[462]	valid_0's recall@200: 0.890485
[463]	valid_0's recall@200: 0.890485
[464]	valid_0's recall@200: 0.890485
[465

[662]	valid_0's recall@200: 0.89101
[663]	valid_0's recall@200: 0.89101
[664]	valid_0's recall@200: 0.89101
[665]	valid_0's recall@200: 0.89101
[666]	valid_0's recall@200: 0.89101
[667]	valid_0's recall@200: 0.890867
[668]	valid_0's recall@200: 0.890867
[669]	valid_0's recall@200: 0.890867
[670]	valid_0's recall@200: 0.890867
[671]	valid_0's recall@200: 0.890867
[672]	valid_0's recall@200: 0.890867
[673]	valid_0's recall@200: 0.890938
[674]	valid_0's recall@200: 0.89101
[675]	valid_0's recall@200: 0.890938
[676]	valid_0's recall@200: 0.891082
[677]	valid_0's recall@200: 0.891082
[678]	valid_0's recall@200: 0.891225
[679]	valid_0's recall@200: 0.891297
[680]	valid_0's recall@200: 0.891297
[681]	valid_0's recall@200: 0.891153
[682]	valid_0's recall@200: 0.891153
[683]	valid_0's recall@200: 0.89101
[684]	valid_0's recall@200: 0.89101
[685]	valid_0's recall@200: 0.89101
[686]	valid_0's recall@200: 0.890723
[687]	valid_0's recall@200: 0.890723
[688]	valid_0's recall@200: 0.890867
[689]	vali

You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 14680
[LightGBM] [Info] Number of data points in the train set: 4416413, number of used features: 65
[1]	valid_0's recall@200: 0.8532
Training until validation scores don't improve for 200 rounds
[2]	valid_0's recall@200: 0.870308
[3]	valid_0's recall@200: 0.875036
[4]	valid_0's recall@200: 0.877531
[5]	valid_0's recall@200: 0.877627
[6]	valid_0's recall@200: 0.878152
[7]	valid_0's recall@200: 0.879513
[8]	valid_0's recall@200: 0.880396
[9]	valid_0's recall@200: 0.881913
[10]	valid_0's recall@200: 0.881996
[11]	valid_0's recall@200: 0.881901
[12]	valid_0's recall@200: 0.88202
[13]	valid_0's recall@200: 0.882593
[14]	valid_0's recall@200: 0.881948
[15]	valid_0's recall@200: 0.882736
[16]	valid_0's recall@200: 0.882593
[17]	valid_0's recall@200: 0.883883
[18]	valid_0's recall@200: 0.88473
[19]	valid_0's recall@200: 0.885351
[20]	valid_0's 

[215]	valid_0's recall@200: 0.888825
[216]	valid_0's recall@200: 0.888825
[217]	valid_0's recall@200: 0.88904
[218]	valid_0's recall@200: 0.88861
[219]	valid_0's recall@200: 0.888467
[220]	valid_0's recall@200: 0.888395
[221]	valid_0's recall@200: 0.888395
[222]	valid_0's recall@200: 0.888539
[223]	valid_0's recall@200: 0.888539
[224]	valid_0's recall@200: 0.888539
[225]	valid_0's recall@200: 0.888682
[226]	valid_0's recall@200: 0.888825
[227]	valid_0's recall@200: 0.888825
[228]	valid_0's recall@200: 0.88861
[229]	valid_0's recall@200: 0.888754
[230]	valid_0's recall@200: 0.888467
[231]	valid_0's recall@200: 0.888467
[232]	valid_0's recall@200: 0.888467
[233]	valid_0's recall@200: 0.888467
[234]	valid_0's recall@200: 0.888467
[235]	valid_0's recall@200: 0.888467
[236]	valid_0's recall@200: 0.888682
[237]	valid_0's recall@200: 0.888754
[238]	valid_0's recall@200: 0.888754
[239]	valid_0's recall@200: 0.888754
[240]	valid_0's recall@200: 0.888754
[241]	valid_0's recall@200: 0.88861
[242]

[26]	valid_0's recall@200: 0.885136
[27]	valid_0's recall@200: 0.884921
[28]	valid_0's recall@200: 0.884706
[29]	valid_0's recall@200: 0.884563
[30]	valid_0's recall@200: 0.884277
[31]	valid_0's recall@200: 0.884491
[32]	valid_0's recall@200: 0.885064
[33]	valid_0's recall@200: 0.885709
[34]	valid_0's recall@200: 0.885852
[35]	valid_0's recall@200: 0.886211
[36]	valid_0's recall@200: 0.886497
[37]	valid_0's recall@200: 0.885817
[38]	valid_0's recall@200: 0.886139
[39]	valid_0's recall@200: 0.885996
[40]	valid_0's recall@200: 0.88596
[41]	valid_0's recall@200: 0.886318
[42]	valid_0's recall@200: 0.885888
[43]	valid_0's recall@200: 0.88639
[44]	valid_0's recall@200: 0.886819
[45]	valid_0's recall@200: 0.886784
[46]	valid_0's recall@200: 0.886211
[47]	valid_0's recall@200: 0.886712
[48]	valid_0's recall@200: 0.886999
[49]	valid_0's recall@200: 0.887213
[50]	valid_0's recall@200: 0.887213
[51]	valid_0's recall@200: 0.886784
[52]	valid_0's recall@200: 0.886282
[53]	valid_0's recall@200: 0.8

[250]	valid_0's recall@200: 0.888336
[251]	valid_0's recall@200: 0.888407
[252]	valid_0's recall@200: 0.888407
[253]	valid_0's recall@200: 0.888372
[254]	valid_0's recall@200: 0.888372
[255]	valid_0's recall@200: 0.888336
[256]	valid_0's recall@200: 0.888372
[257]	valid_0's recall@200: 0.888192
[258]	valid_0's recall@200: 0.888479
[259]	valid_0's recall@200: 0.888479
[260]	valid_0's recall@200: 0.888479
[261]	valid_0's recall@200: 0.888479
[262]	valid_0's recall@200: 0.888622
[263]	valid_0's recall@200: 0.888622
[264]	valid_0's recall@200: 0.888622
[265]	valid_0's recall@200: 0.888479
[266]	valid_0's recall@200: 0.888479
[267]	valid_0's recall@200: 0.888479
[268]	valid_0's recall@200: 0.888479
[269]	valid_0's recall@200: 0.888192
[270]	valid_0's recall@200: 0.888336
[271]	valid_0's recall@200: 0.888479
[272]	valid_0's recall@200: 0.888479
[273]	valid_0's recall@200: 0.888551
[274]	valid_0's recall@200: 0.888551
[275]	valid_0's recall@200: 0.888837
[276]	valid_0's recall@200: 0.888837
[

[36]	valid_0's recall@200: 0.886127
[37]	valid_0's recall@200: 0.886485
[38]	valid_0's recall@200: 0.88627
[39]	valid_0's recall@200: 0.886557
[40]	valid_0's recall@200: 0.887202
[41]	valid_0's recall@200: 0.887058
[42]	valid_0's recall@200: 0.886915
[43]	valid_0's recall@200: 0.886987
[44]	valid_0's recall@200: 0.887416
[45]	valid_0's recall@200: 0.88756
[46]	valid_0's recall@200: 0.887775
[47]	valid_0's recall@200: 0.887345
[48]	valid_0's recall@200: 0.887488
[49]	valid_0's recall@200: 0.887202
[50]	valid_0's recall@200: 0.887918
[51]	valid_0's recall@200: 0.88756
[52]	valid_0's recall@200: 0.88756
[53]	valid_0's recall@200: 0.888276
[54]	valid_0's recall@200: 0.887989
[55]	valid_0's recall@200: 0.887989
[56]	valid_0's recall@200: 0.887918
[57]	valid_0's recall@200: 0.887775
[58]	valid_0's recall@200: 0.888061
[59]	valid_0's recall@200: 0.888061
[60]	valid_0's recall@200: 0.888204
[61]	valid_0's recall@200: 0.888061
[62]	valid_0's recall@200: 0.888348
[63]	valid_0's recall@200: 0.887

[260]	valid_0's recall@200: 0.88941
[261]	valid_0's recall@200: 0.889195
[262]	valid_0's recall@200: 0.889195
[263]	valid_0's recall@200: 0.889195
[264]	valid_0's recall@200: 0.888909
[265]	valid_0's recall@200: 0.88898
[266]	valid_0's recall@200: 0.888837
[267]	valid_0's recall@200: 0.889267
[268]	valid_0's recall@200: 0.889267
[269]	valid_0's recall@200: 0.889267
[270]	valid_0's recall@200: 0.889267
[271]	valid_0's recall@200: 0.889267
[272]	valid_0's recall@200: 0.889267
[273]	valid_0's recall@200: 0.88898
[274]	valid_0's recall@200: 0.889124
[275]	valid_0's recall@200: 0.889124
[276]	valid_0's recall@200: 0.88898
[277]	valid_0's recall@200: 0.888837
[278]	valid_0's recall@200: 0.888694
[279]	valid_0's recall@200: 0.888837
[280]	valid_0's recall@200: 0.888694
[281]	valid_0's recall@200: 0.888837
[282]	valid_0's recall@200: 0.888694
[283]	valid_0's recall@200: 0.888694
[284]	valid_0's recall@200: 0.888694
[285]	valid_0's recall@200: 0.888694
[286]	valid_0's recall@200: 0.888694
[287]

100%|██████████| 6980/6980 [01:51<00:00, 62.66it/s]


score_tie occurs 303652 times in 2244 queries
recall@10:0.5726957975167144
recall@20:0.6808142311365807
recall@50:0.783691499522445
recall@100:0.8446036294173831
recall@200:0.8894340974212034
recall@500:0.9333452722063037
recall@1000:0.9470869149952243


100%|██████████| 6980/6980 [00:45<00:00, 153.01it/s]

score_tie occurs 9 times in 7 queries 0.29788221449038066



