In [1]:
import argparse
import datetime
import glob
import hashlib
import json
import multiprocessing
import pickle
import os
import shutil
import subprocess
import uuid
import random

import numpy as np
import pandas as pd
import lightgbm as lgb
from collections import defaultdict
from lightgbm.sklearn import LGBMRanker
from tqdm import tqdm
import sys
sys.path.append('..')

from pyserini.analysis import Analyzer, get_lucene_analyzer
from pyserini.ltr import *
from pyserini.search import get_topics_with_reader

def _print_sample_summary(frame):
    """Print shape, distinct-qid count, mean rows per qid, a preview and dtype info."""
    print(frame.shape)
    print(frame.index.get_level_values('qid').drop_duplicates().shape)
    print(frame.groupby('qid').count().mean())
    print(frame.head(10))
    print(frame.info())


def _sample_from_triples(neg_sample, random_seed):
    """Build the 'triple' sample from the (qid, pos_pid, neg_pid) triples file.

    All distinct positives are kept; for every qid at most ``neg_sample``
    distinct negatives are drawn with ``random_state=random_seed``.
    """
    train = pd.read_csv('collections/msmarco-passage/qidpidtriples.train.full.2.tsv', sep="\t",
                        names=['qid', 'pos_pid', 'neg_pid'], dtype=np.int32)
    pos_half = train[['qid', 'pos_pid']].rename(columns={"pos_pid": "pid"}).drop_duplicates()
    pos_half['rel'] = np.int32(1)
    neg_half = train[['qid', 'neg_pid']].rename(columns={"neg_pid": "pid"}).drop_duplicates()
    neg_half['rel'] = np.int32(0)
    del train  # release the full triples frame before sampling
    sampled_neg_half = [
        group.sample(n=min(neg_sample, len(group)), random_state=random_seed)
        for _, group in tqdm(neg_half.groupby('qid'))
    ]
    sampled = pd.concat([pos_half] + sampled_neg_half, axis=0, ignore_index=True)
    return sampled.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])


def _sample_from_run(qrels_path, run_path, neg_sample, random_seed, keep_negative=None):
    """Build a sample from a qrels file plus a (topicid, docid, rank) run file.

    Run entries listed in the qrels become positives. Every other entry is a
    negative candidate when ``keep_negative`` is None or ``keep_negative(int(rank))``
    is true. At most ``neg_sample`` negatives are drawn per query that has at
    least one positive in the run.
    """
    qrel = defaultdict(set)
    with open(qrels_path) as f:
        for line in f:
            topicid, _, docid, rel = line.strip().split('\t')
            assert rel == "1", line.split(' ')
            qrel[topicid].add(docid)

    qid2pos = defaultdict(list)
    qid2neg = defaultdict(list)
    with open(run_path) as f:
        for line in tqdm(f):
            topicid, docid, rank = line.split()
            assert topicid in qrel
            if docid in qrel[topicid]:
                qid2pos[topicid].append(docid)
            elif keep_negative is None or keep_negative(int(rank)):
                qid2neg[topicid].append(docid)

    # A dedicated generator seeded with random_seed makes the draw repeatable,
    # which the seed embedded in the cache filename already promises.
    rng = random.Random(random_seed)
    rows = []
    for topicid, pos_list in tqdm(qid2pos.items()):
        candidates = qid2neg.get(topicid, [])
        neg_list = rng.sample(candidates, min(len(candidates), neg_sample))
        rows.extend((int(topicid), int(docid), 1) for docid in pos_list)
        rows.extend((int(topicid), int(docid), 0) for docid in neg_list)
    sampled = pd.DataFrame(rows, columns=['qid', 'pid', 'rel'], dtype=np.int32)
    return sampled.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])


def train_data_loader(task='triple', neg_sample=10, cutoff=100, random_seed=12345):
    """Load the sampled training set from its pickle cache, or build and cache it.

    Args:
        task: one of 'triple', 'rank', 'rank_recall' or 'rank_hard'.
            'triple' samples negatives from the triples file; 'rank' samples
            from all non-relevant run entries; 'rank_recall' only from entries
            with rank > cutoff; 'rank_hard' only from entries with rank < cutoff.
        neg_sample: maximum number of negatives drawn per query.
        cutoff: rank threshold, used by 'rank_recall' and 'rank_hard' only.
        random_seed: seed for the negative sampling; also part of the cache name.

    Returns:
        DataFrame indexed by (qid, pid) with a single int32 column 'rel'.

    Raises:
        Exception: if ``task`` is not one of the supported values.
    """
    if task == 'triple' or task == 'rank':
        fn = f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle'
    elif task == 'rank_recall' or task == 'rank_hard':
        fn = f'train_{task}_sampled_with_{neg_sample}_{cutoff}_{random_seed}.pickle'
    else:
        raise Exception('unknown parameters')

    if os.path.exists(fn):
        sampled_train = pd.read_pickle(fn)
        _print_sample_summary(sampled_train)
        return sampled_train

    # NOTE(review): 'triple' and 'rank' read from ./collections and ./runs, while
    # 'rank_recall' and 'rank_hard' read from ../collections and ../runs. The
    # paths are kept exactly as they were; confirm which working directory is intended.
    if task == 'triple':
        sampled_train = _sample_from_triples(neg_sample, random_seed)
    elif task == 'rank':
        sampled_train = _sample_from_run("collections/msmarco-passage/qrels.train.tsv",
                                         "runs/msmarco-passage/run.train.small.tsv",
                                         neg_sample, random_seed)
    elif task == 'rank_recall':
        sampled_train = _sample_from_run("../collections/msmarco-passage/qrels.train.tsv",
                                         "../runs/msmarco-passage/run.train.small.tsv",
                                         neg_sample, random_seed,
                                         keep_negative=lambda rank: rank > cutoff)
    else:  # 'rank_hard' -- every other value was rejected above
        sampled_train = _sample_from_run("../collections/msmarco-passage/qrels.train.tsv",
                                         "../runs/msmarco-passage/run.train.small.tsv",
                                         neg_sample, random_seed,
                                         keep_negative=lambda rank: rank < cutoff)

    _print_sample_summary(sampled_train)
    sampled_train.to_pickle(fn)
    return sampled_train
# Load (or build and cache) the 'triple' training sample with at most 20 negatives per query.
sampled_train = train_data_loader(task='triple',neg_sample=20)

(8385888, 1)
(400782,)
rel    20.923814
dtype: float64
             rel
qid pid         
3   689698     0
    970816     0
    1142680    1
    1519440    0
    1887112    0
    2006462    0
    2019206    0
    2605131    0
    2679073    0
    2963098    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 8385888 entries, (3, 689698) to (1185869, 8626607)
Data columns (total 1 columns):
 #   Column  Dtype
---  ------  -----
 0   rel     int32
dtypes: int32(1)
memory usage: 303.9 MB
None


In [2]:
def dev_data_loader(task='pygaggle'):
    """Return the labelled dev candidates and the dev qrels, using pickle caches.

    When ``dev_{task}.pickle`` exists, it and ``dev_qrel.pickle`` are loaded.
    Otherwise the candidate file selected by ``task`` ('rerank', 'anserini' or
    'pygaggle') is read, left-joined with the dev qrels on (qid, pid), given
    rel=0 where no judgement matched, indexed by (qid, pid), and both frames
    are written to their pickle files.

    Returns:
        Tuple ``(dev, dev_qrel)``.

    Raises:
        Exception: if no cache exists and ``task`` is not a known source.
    """
    dev_cache = f'dev_{task}.pickle'
    qrel_cache = 'dev_qrel.pickle'

    def summarize(frame):
        # Same five diagnostics for both the cached and the freshly built frame.
        print(frame.shape)
        print(frame.index.get_level_values('qid').drop_duplicates().shape)
        print(frame.groupby('qid').count().mean())
        print(frame.head(10))
        print(frame.info())

    if os.path.exists(dev_cache):
        cached_dev = pd.read_pickle(dev_cache)
        summarize(cached_dev)
        cached_qrel = pd.read_pickle(qrel_cache)
        return cached_dev, cached_qrel

    # task -> (candidate file, read_csv column options)
    sources = {
        'rerank': ('collections/msmarco-passage/top1000.dev',
                   {'names': ['qid', 'pid', 'query', 'doc'], 'usecols': ['qid', 'pid']}),
        'anserini': ('runs/msmarco-passage/run.msmarco-passage.dev.small.tsv',
                     {'names': ['qid', 'pid', 'rank']}),
        'pygaggle': ('../pygaggle/data/msmarco_ans_entire/run.dev.small.tsv',
                     {'names': ['qid', 'pid', 'rank']}),
    }
    if task not in sources:
        raise Exception('unknown parameters')
    run_path, column_options = sources[task]
    candidates = pd.read_csv(run_path, sep="\t", dtype=np.int32, **column_options)

    dev_qrel = pd.read_csv('collections/msmarco-passage/qrels.dev.small.tsv', sep="\t",
                           names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'],
                           dtype=np.int32)
    labelled = candidates.merge(dev_qrel, on=['qid', 'pid'], how='left')
    # Candidates without a matching judgement get relevance 0.
    labelled['rel'] = labelled['rel'].fillna(0).astype(np.int32)
    dev = labelled.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])

    summarize(dev)

    dev.to_pickle(dev_cache)
    dev_qrel.to_pickle(qrel_cache)
    return dev, dev_qrel
# Load (or build and cache) the dev candidates from the 'pygaggle' run together with the dev qrels.
dev, dev_qrel = dev_data_loader(task='pygaggle')

(6974598, 2)
(6980,)
rank    999.226074
rel     999.226074
dtype: float64
            rank  rel
qid pid              
2   55860    345    0
    72202    557    0
    72210    213    0
    98589    278    0
    98590    323    0
    98593    580    0
    98595    553    0
    112123   108    0
    112126   469    0
    112127    21    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 6974598 entries, (2, 55860) to (1102400, 8830447)
Data columns (total 2 columns):
 #   Column  Dtype
---  ------  -----
 0   rank    int32
 1   rel     int32
dtypes: int32(2)
memory usage: 282.7 MB
None


In [3]:
def _load_query_file(path, queries):
    """Parse one JSON-lines query file and store its records in ``queries`` by id."""
    # JSON text is UTF-8; do not depend on the platform default encoding.
    with open(path, encoding='utf-8') as f:
        for line in f:
            query = json.loads(line)
            qid = query.pop('id')
            query['analyzed'] = query['analyzed'].split(" ")
            # NOTE(review): 'text' is filled from 'text_unlemm', not from a 'text'
            # field of the record -- kept as in the original; confirm it is intended.
            query['text'] = query['text_unlemm'].split(" ")
            query['text_unlemm'] = query['text_unlemm'].split(" ")
            query['text_bert_tok'] = query['text_bert_tok'].split(" ")
            assert 'raw' in query
            assert 'entity' in query
            queries[qid] = query


def query_loader():
    """Load the train, dev and eval query files into one dict keyed by query id.

    Each line of a file is a JSON object. Its 'id' is removed and used as the
    key; 'analyzed', 'text_unlemm' and 'text_bert_tok' are split on single
    spaces into token lists, and 'text' is set to the 'text_unlemm' tokens.
    Every record must also carry 'raw' and 'entity'. Files are read in the
    order train, dev, eval, so a later file overwrites an earlier id.

    Returns:
        dict mapping query id to the processed query dict.
    """
    queries = {}
    for path in ('queries.train.small.entity.json',
                 'queries.dev.small.entity.json',
                 'queries.eval.small.entity.json'):
        _load_query_file(path, queries)
    return queries
# Merge the train, dev and eval query files into one dict keyed by query id.
queries = query_loader()

In [4]:
# Feature extractor over the MS MARCO passage index. The second argument is half
# of the CPU count, at least 1 (presumably the worker count -- confirm against
# pyserini.ltr.FeatureExtractor).
fe = FeatureExtractor('../../anserini/indexes/msmarco-passage/lucene-index-msmarco-flex_ent',
                      max(multiprocessing.cpu_count() // 2, 1))

# Pooler and statistic classes; every statistic is registered once per pooler.
pooler_classes = (AvgPooler, MedianPooler, SumPooler, MinPooler,
                  MaxPooler, VarPooler, MaxMinRatioPooler, ConfidencePooler)
stat_classes = (tfStat, tfIdfStat, scqStat, normalizedTfStat, idfStat, ictfStat)
# Pair features; every class is registered with each of the three integer arguments.
pair_classes = (UnorderedSequentialPairs, OrderedSequentialPairs,
                UnorderedQueryPairs, OrderedQueryPairs)

# The registration order below is identical for every (query field, index field)
# pair and must not change: it fixes the order in which features are added.
for qfield, ifield in [('analyzed','contents'),
                       ('text','text'),
                       ('text_unlemm','text_unlemm'),
                       ('text_bert_tok','text_bert_tok')]:
    print(qfield, ifield)
    both = {'field': ifield, 'qfield': qfield}

    for k1, b in ((0.9, 0.4), (1.2, 0.75), (2.0, 0.75)):
        fe.add(BM25(k1=k1, b=b, **both))
    for mu in (1000, 1500, 2500):
        fe.add(LMDir(mu=mu, **both))
    for lam in (0.1, 0.4, 0.7):
        fe.add(LMJM(lam, **both))

    for scorer in (NTFIDF, ProbalitySum,
                   DFR_GL2, DFR_In_expB2, DPH,
                   Proximity, TPscore, tpDist):
        fe.add(scorer(**both))

    fe.add(DocSize(field=ifield))

    fe.add(QueryLength(qfield=qfield))
    fe.add(QueryCoverageRatio(qfield=qfield))
    fe.add(UniqueTermCount(qfield=qfield))
    fe.add(MatchingTermCount(**both))
    fe.add(SCS(**both))

    for stat in stat_classes:
        for pooler in pooler_classes:
            fe.add(stat(pooler(), **both))

    for pair in pair_classes:
        for gap in (3, 8, 15):
            fe.add(pair(gap, **both))

# fe.add(EntityHowLong())
# fe.add(EntityHowMany())
# fe.add(EntityHowMuch())
# fe.add(EntityWhen())
# fe.add(EntityWhere())
# fe.add(EntityWho())
# fe.add(EntityWhereMatch())
# fe.add(EntityWhoMatch())

# for ent_type in ['PERSON','NORP','FAC','ORG','GPE','LOC','PRODUCT','EVENT','WORK_OF_ART','LAW',
#                 'LANGUAGE','DATE','TIME','PERCENT','MONEY','QUANTITY','ORDINAL','CARDINAL']:
#     fe.add(EntityDocCount(ent_type))

# fe.add(QueryRegex("^[0-9.+_ ]*what.*$"))

# Four IBM Model 1 features, added after all per-field features; each tuple holds
# the constructor arguments in their original positional order.
for ibm_args in (
    ("../FlexNeuART/collections/msmarco_doc/derived_data/giza/title_unlemm","text_unlemm","title_unlemm","text_unlemm"),
    ("../FlexNeuART/collections/msmarco_doc/derived_data/giza/url_unlemm","text_unlemm","url_unlemm","text_unlemm"),
    ("../FlexNeuART/collections/msmarco_doc/derived_data/giza/body","text_unlemm","body","text_unlemm"),
    ("../FlexNeuART/collections/msmarco_doc/derived_data/giza/text_bert_tok","text_bert_tok","text_bert_tok","text_bert_tok"),
):
    fe.add(IBMModel1(*ibm_args))
    print('IBM model Loaded')

analyzed contents
text text
text_unlemm text_unlemm
text_bert_tok text_bert_tok
IBM model Loaded
IBM model Loaded
IBM model Loaded
IBM model Loaded


In [5]:
def batch_extract(df, queries, fe):
    """Run feature extraction over ``df`` in batches of 10000 queries.

    ``df`` is grouped by 'qid'. For each group a task dict (qid, docIds, an
    empty 'rels' list and the query's entry from ``queries``) is queued, and
    the queue is passed to ``fe.batch_extract`` every 10000 queries and once
    more for the remainder. Diagnostics are printed after each batch.

    Returns:
        Tuple ``(info, features, groups)``: the concatenated (qid, pid, rel)
        rows, the concatenated frames returned by ``fe.batch_extract``, and
        the concatenated (qid, count) rows giving the docs per query.
    """
    batch_size = 10000
    info_parts, feature_parts, group_parts = [], [], []

    def run_batch(batch_tasks, batch_rows, batch_sizes):
        # Extract one batch, print its diagnostics and stash the three frames.
        features = fe.batch_extract(batch_tasks)
        infos = pd.DataFrame(batch_rows, columns=['qid', 'pid', 'rel'])
        sizes = pd.DataFrame(batch_sizes, columns=['qid', 'count'])
        print(features.shape)
        print(infos.qid.drop_duplicates().shape)
        print(sizes.mean())
        print(features.head(10))
        print(features.info())
        info_parts.append(infos)
        feature_parts.append(features)
        group_parts.append(sizes)

    pending_tasks, pending_rows, pending_sizes = [], [], []
    for qid, rows in tqdm(df.groupby('qid')):
        task = {
            "qid": str(qid),
            "docIds": [],
            "rels": [],
            "query_dict": queries[str(qid)]
        }
        for row in rows.reset_index().itertuples():
            task["docIds"].append(str(row.pid))
            pending_rows.append((qid, row.pid, row.rel))
        pending_tasks.append(task)
        pending_sizes.append((qid, len(task['docIds'])))
        if len(pending_tasks) == batch_size:
            run_batch(pending_tasks, pending_rows, pending_sizes)
            # Start fresh lists instead of clearing the ones just handed over.
            pending_tasks, pending_rows, pending_sizes = [], [], []

    # Remainder that did not fill a whole batch.
    if len(pending_tasks) > 0:
        run_batch(pending_tasks, pending_rows, pending_sizes)

    all_infos = pd.concat(info_parts, axis=0, ignore_index=True)
    all_features = pd.concat(feature_parts, axis=0, ignore_index=True, copy=False)
    all_groups = pd.concat(group_parts, axis=0, ignore_index=True)
    return all_infos, all_features, all_groups

In [6]:
def hash_df(df):
    """Return a hex string fingerprint of ``df``.

    The per-row hashes from ``pd.util.hash_pandas_object`` are summed and the
    sum, cast to uint64, is formatted with ``hex``.
    """
    row_hashes = pd.util.hash_pandas_object(df)
    total = row_hashes.sum()
    return hex(total.astype(np.uint64))


def hash_anserini_jar():
    """Return the MD5 hex digest of the Anserini fat jar.

    The jar is the single file matching ``*fatjar.jar`` directly inside the
    directory named by the ``ANSERINI_CLASSPATH`` environment variable.

    Raises:
        KeyError: if ``ANSERINI_CLASSPATH`` is not set.
        AssertionError: if the directory does not hold exactly one match.
    """
    find = glob.glob(os.environ['ANSERINI_CLASSPATH'] + "/*fatjar.jar")
    if len(find) != 1:
        # Raised explicitly so the check still runs under ``python -O``.
        raise AssertionError(f'expected exactly one *fatjar.jar, found {len(find)}')
    md5_hash = hashlib.md5()
    # Hash in 1 MiB chunks and close the handle, instead of reading the whole
    # jar into memory through a file object that is never closed.
    with open(find[0], 'rb') as jar:
        for chunk in iter(lambda: jar.read(1 << 20), b''):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()


def hash_fe(fe):
    """Return the MD5 hex digest of the extractor's sorted, comma-joined feature names."""
    ordered_names = sorted(fe.feature_names())
    joined = ','.join(ordered_names)
    digest = hashlib.md5(joined.encode())
    return digest.hexdigest()


def data_loader(task, df, queries, fe):
    """Return extracted features for ``df``, cached on disk by content hashes.

    The cache file is named from ``task`` and the hashes of ``df``, of the
    Anserini jar and of the extractor's feature names, so a change in any of
    them leads to a different file name.

    Args:
        task: 'train' or 'dev'; only checked when no cache file exists.
        df: frame passed on to ``batch_extract``.
        queries: query dict passed on to ``batch_extract``.
        fe: feature extractor passed on to ``batch_extract``.

    Returns:
        dict with keys 'info', 'data', 'group', 'df_hash', 'jar_hash' and
        'fe_hash' when freshly built, or whatever object the cache file holds.

    Raises:
        Exception: if no cache file exists and ``task`` is not 'train' or 'dev'.
    """
    df_hash = hash_df(df)
    jar_hash = hash_anserini_jar()
    fe_hash = hash_fe(fe)
    cache_path = f'{task}_{df_hash}_{jar_hash}_{fe_hash}.pickle'

    if os.path.exists(cache_path):
        # NOTE: pickle executes code on load; only read cache files written by this notebook.
        with open(cache_path, 'rb') as cache_file:
            res = pickle.load(cache_file)
        print(res['info'].shape)
        print(res['info'].qid.drop_duplicates().shape)
        print(res['group'].mean())
        return res

    if task != 'train' and task != 'dev':
        raise Exception('unknown parameters')

    info, data, group = batch_extract(df, queries, fe)
    obj = {'info': info, 'data': data, 'group': group,
           'df_hash': df_hash, 'jar_hash': jar_hash, 'fe_hash': fe_hash}
    print(info.shape)
    print(info.qid.drop_duplicates().shape)
    print(group.mean())
    # The context manager flushes and closes the file as soon as the dump ends.
    with open(cache_path, 'wb') as cache_file:
        pickle.dump(obj, cache_file)
    return obj

In [7]:
import json
def export(df, fn, queries):
    """Write one JSON line per query of ``df`` to the file ``fn``.

    Each line holds 'qid', 'docIds' (the group's distinct pids as strings, in
    order of first appearance) and every key of ``queries[str(qid)]``.

    Raises:
        KeyError: if a qid of ``df`` has no entry in ``queries``.
        AssertionError: if a query dict already has a 'qid' or 'docIds' key.
    """
    with open(fn,'w') as f:
        for qid, group in tqdm(df.groupby('qid')):
            line = {}
            line['qid'] = qid
            line['docIds'] = [str(did) for did in group.reset_index().pid.drop_duplicates().tolist()]
            query = queries[str(qid)]  # looked up once instead of three times
            # The query's own keys must not overwrite the two fields set above.
            assert 'qid' not in query
            assert 'docIds' not in query
            line.update(query)
            f.write(json.dumps(line)+'\n')

In [8]:
# Extract features for the training sample and the dev candidates (or load them from cache).
train_extracted = data_loader('train', sampled_train, queries, fe)
dev_extracted = data_loader('dev', dev, queries, fe)
# The raw frames are no longer referenced below; drop the names.
del sampled_train, dev
# Save the feature names before the extractor's name is deleted.
feature_name = fe.feature_names()
del queries, fe

  3%|▎         | 10068/400782 [04:02<239:31:07,  2.21s/it]

(210283, 336)
(10000,)
qid      16070.1173
count       21.0283
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              16.065151   
1                              22.893564   
2                              26.467836   
3                              16.092194   
4                              17.071817   
5                              26.280588   
6                              19.257418   
7                              13.279017   
8                              13.294892   
9                              19.064749   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              16.743813   
1                              23.464577   
2                              27.822981   
3                              17.323364   
4                              18.976582   
5                              28.163824   
6                              21.266153   
7                              15.103254   
8                              13.668623 

  5%|▌         | 20081/400782 [04:38<27:27:55,  3.85it/s] 

(210507, 336)
(10000,)
qid      43586.6803
count       21.0507
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              12.922753   
1                              12.513888   
2                              12.434628   
3                              13.131429   
4                              14.483192   
5                              12.844487   
6                              20.265697   
7                              18.062141   
8                              17.035084   
9                              13.037727   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.139851   
1                              13.831271   
2                              13.694565   
3                              14.182406   
4                              15.425590   
5                              13.905419   
6                              21.310802   
7                              19.409094   
8                              17.027304 

  8%|▊         | 30206/400782 [05:04<2:32:47, 40.42it/s] 

(210463, 336)
(10000,)
qid      71878.3534
count       21.0463
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.622095   
1                              16.603477   
2                              11.843141   
3                              15.527351   
4                              27.521973   
5                              16.122030   
6                              11.660427   
7                              13.312406   
8                              13.046857   
9                              10.598083   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              15.136847   
1                              16.832077   
2                              13.368475   
3                              17.201197   
4                              30.654858   
5                              16.723196   
6                              12.876721   
7                              13.845608   
8                              14.057477 

 10%|█         | 40150/400782 [05:28<2:15:27, 44.37it/s]

(210376, 336)
(10000,)
qid      102505.6404
count        21.0376
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              11.739278   
1                              14.907814   
2                               9.844326   
3                              15.783961   
4                              10.652197   
5                              11.436790   
6                              12.812359   
7                              11.428194   
8                               8.527626   
9                              11.162281   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              13.057232   
1                              16.272871   
2                               9.578762   
3                              17.434904   
4                              11.800447   
5                              13.539322   
6                              12.816894   
7                              12.319444   
8                               9.55459

 13%|█▎        | 50162/400782 [05:50<1:46:21, 54.95it/s]

(210279, 336)
(10000,)
qid      131703.5651
count        21.0279
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.653768   
1                              10.598051   
2                              13.082002   
3                               8.600472   
4                               8.558200   
5                               9.122086   
6                               9.170490   
7                              10.252264   
8                              10.222827   
9                              10.830508   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.094537   
1                              10.902937   
2                              13.128061   
3                               9.005003   
4                               8.905561   
5                              10.455590   
6                              10.593972   
7                              10.143054   
8                              11.59557

 15%|█▌        | 60160/400782 [06:13<2:47:24, 33.91it/s]

(210697, 336)
(10000,)
qid      159810.2550
count        21.0697
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.145366   
1                              11.561275   
2                              10.872355   
3                               9.024568   
4                              10.203894   
5                              12.585082   
6                              10.808022   
7                              10.420988   
8                               9.939776   
9                              12.968530   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.109741   
1                              11.977177   
2                              12.508373   
3                               9.539717   
4                              10.570965   
5                              13.619138   
6                              11.593347   
7                              11.293334   
8                               9.97793

 17%|█▋        | 70047/400782 [06:35<3:26:12, 26.73it/s]

(210576, 336)
(10000,)
qid      188667.0197
count        21.0576
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              19.041761   
1                              15.286544   
2                              11.359656   
3                              14.190989   
4                              17.077909   
5                              16.837852   
6                              13.992632   
7                              21.122465   
8                              11.497820   
9                              13.849800   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              18.860228   
1                              15.182791   
2                               9.848079   
3                              15.245861   
4                              15.056050   
5                              18.394512   
6                              14.743046   
7                              20.633486   
8                              11.82902

 20%|█▉        | 80130/400782 [06:59<2:44:58, 32.39it/s]

(210526, 336)
(10000,)
qid      221530.6843
count        21.0526
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              18.649384   
1                              13.232711   
2                              11.592625   
3                              10.385224   
4                              11.001211   
5                              13.167721   
6                              13.436407   
7                              13.072525   
8                              12.666722   
9                              14.405597   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              19.288710   
1                              12.959960   
2                              12.440346   
3                              11.309126   
4                              13.023666   
5                              12.938331   
6                              14.274032   
7                              14.227904   
8                              13.23176

 23%|██▎       | 90192/400782 [07:23<2:02:21, 42.31it/s]

(210194, 336)
(10000,)
qid      255886.0298
count        21.0194
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              24.202383   
1                              24.236105   
2                              17.378906   
3                              14.533926   
4                              15.652447   
5                              21.532755   
6                              14.985152   
7                              13.223838   
8                              19.116377   
9                              22.196043   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              24.836401   
1                              24.911392   
2                              17.762716   
3                              14.197733   
4                              16.986910   
5                              21.040609   
6                              16.910011   
7                              14.577914   
8                              19.64689

 25%|██▍       | 100172/400782 [07:46<2:21:36, 35.38it/s]

(210245, 336)
(10000,)
qid      285440.4913
count        21.0245
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.001701   
1                              15.987255   
2                              10.649349   
3                              13.886163   
4                              18.710411   
5                              10.756891   
6                              14.259351   
7                              10.500708   
8                              14.140326   
9                              13.498837   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.193398   
1                              15.324629   
2                              11.677783   
3                              14.515985   
4                              19.792707   
5                              11.960898   
6                              14.281878   
7                              11.514482   
8                              14.62507

 27%|██▋       | 110203/400782 [08:09<1:37:02, 49.91it/s]

(210174, 336)
(10000,)
qid      316130.7138
count        21.0174
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.883745   
1                               8.000873   
2                              10.672473   
3                               8.748243   
4                              15.274649   
5                              11.317259   
6                              14.161619   
7                               8.878933   
8                              11.691252   
9                               8.574700   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.139491   
1                               9.344815   
2                              11.986936   
3                               9.517762   
4                              15.729774   
5                              11.453630   
6                              15.154417   
7                               9.858629   
8                              13.14597

 30%|██▉       | 120109/400782 [08:31<2:12:00, 35.43it/s]

(210233, 336)
(10000,)
qid      356953.6546
count        21.0233
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              14.345051   
1                               9.749559   
2                               9.385248   
3                              12.657712   
4                              16.171034   
5                              13.311099   
6                              10.640512   
7                              11.760098   
8                              12.367405   
9                              13.919641   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.593833   
1                              10.752101   
2                               9.842383   
3                              12.192801   
4                              16.341833   
5                              14.195478   
6                              10.975368   
7                              12.548599   
8                              13.11542

 32%|███▏      | 130089/400782 [08:53<1:57:22, 38.44it/s]

(210242, 336)
(10000,)
qid      402634.5494
count        21.0242
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.106102   
1                              13.769972   
2                              13.013794   
3                              17.426910   
4                              13.960839   
5                              13.888522   
6                              19.395971   
7                              11.882025   
8                              12.887510   
9                              11.438029   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.149033   
1                              15.369852   
2                              13.481947   
3                              18.021683   
4                              15.099513   
5                              15.047424   
6                              20.817280   
7                              13.103846   
8                              14.46426

 35%|███▍      | 140190/400782 [09:15<1:19:01, 54.96it/s]

(210258, 336)
(10000,)
qid      428279.2893
count        21.0258
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               7.181495   
1                              20.478367   
2                               7.579915   
3                              15.472593   
4                               7.572382   
5                               7.181495   
6                              15.045549   
7                              12.666134   
8                               8.018781   
9                              10.184351   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               8.080312   
1                              22.300184   
2                               8.255768   
3                              16.622730   
4                               8.236500   
5                               8.080312   
6                              16.738909   
7                              13.389154   
8                               9.19818

 37%|███▋      | 150097/400782 [09:37<1:54:14, 36.57it/s]

(210540, 336)
(10000,)
qid      455094.8035
count        21.0540
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               8.981194   
1                              15.790057   
2                              12.145004   
3                              16.347235   
4                              13.027354   
5                              14.213879   
6                              12.033302   
7                              12.924859   
8                              12.314436   
9                              10.736749   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               9.713717   
1                              15.643287   
2                              11.817376   
3                              17.211443   
4                              14.138248   
5                              14.817186   
6                              11.591328   
7                              13.532791   
8                              12.89219

 40%|███▉      | 160171/400782 [10:01<1:36:57, 41.36it/s]

(210357, 336)
(10000,)
qid      482473.7514
count        21.0357
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.945461   
1                              11.282845   
2                              16.034340   
3                              15.923967   
4                               9.718058   
5                              11.224901   
6                              12.355948   
7                              10.289945   
8                               9.277023   
9                              10.288551   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.249607   
1                              12.273409   
2                              17.179432   
3                              16.908325   
4                               9.919733   
5                              12.126390   
6                              12.734758   
7                              11.304367   
8                              10.01090

 42%|████▏     | 170145/400782 [10:23<1:40:21, 38.30it/s]

(210518, 336)
(10000,)
qid      511360.3605
count        21.0518
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.230773   
1                              13.469619   
2                              14.030810   
3                              29.180943   
4                               9.514769   
5                               8.032072   
6                               8.006005   
7                               8.751742   
8                               8.897099   
9                               8.084717   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              10.276155   
1                              14.922128   
2                              14.953142   
3                              32.590218   
4                              10.593243   
5                               8.938391   
6                               8.869010   
7                               9.677322   
8                               9.07824

 45%|████▍     | 180100/400782 [10:45<1:32:46, 39.64it/s]

(210685, 336)
(10000,)
qid      540155.3225
count        21.0685
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              19.126642   
1                              20.910910   
2                              19.980577   
3                              16.807257   
4                              20.280333   
5                              19.383894   
6                              17.569389   
7                              20.667038   
8                              19.244720   
9                              20.177389   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              20.438959   
1                              20.727804   
2                              19.712608   
3                              18.240446   
4                              20.108335   
5                              21.372740   
6                              16.992607   
7                              21.140038   
8                              20.72876

 47%|████▋     | 190075/400782 [11:07<2:08:28, 27.34it/s]

(209971, 336)
(10000,)
qid      567539.0918
count        20.9971
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               4.938383   
1                               5.938067   
2                               6.606776   
3                               4.474925   
4                               5.713381   
5                               4.756877   
6                               5.422701   
7                               8.763223   
8                               4.992814   
9                               5.455355   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               5.947136   
1                               6.669422   
2                               7.100610   
3                               4.685407   
4                               6.088757   
5                               5.410365   
6                               5.420225   
7                               9.639617   
8                               6.12063

 50%|████▉     | 200155/400782 [11:28<1:23:31, 40.04it/s]

(210628, 336)
(10000,)
qid      596183.7539
count        21.0628
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.769319   
1                              20.879616   
2                              17.364222   
3                              20.750620   
4                              20.951624   
5                              17.625082   
6                              16.453005   
7                              20.140434   
8                              18.469524   
9                              13.769319   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              15.117961   
1                              19.719410   
2                              18.759676   
3                              21.077116   
4                              22.746914   
5                              18.900932   
6                              18.232979   
7                              21.683268   
8                              20.35272

 52%|█████▏    | 210169/400782 [11:50<1:18:08, 40.66it/s]

(210131, 336)
(10000,)
qid      617184.4926
count        21.0131
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              18.334089   
1                              19.255672   
2                              18.471027   
3                              31.033031   
4                              18.779118   
5                              20.431904   
6                              20.961042   
7                              18.896551   
8                              18.529640   
9                              18.221893   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              19.268406   
1                              20.975502   
2                              19.765936   
3                              34.378223   
4                              20.457172   
5                              21.880999   
6                              23.101276   
7                              20.117466   
8                              19.91145

 55%|█████▍    | 220189/400782 [12:11<52:44, 57.07it/s]  

(207231, 336)
(10000,)
qid      643118.5746
count        20.7231
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.968495   
1                              14.568353   
2                               9.956192   
3                               9.971744   
4                              10.107975   
5                              13.302074   
6                               8.770848   
7                              16.772018   
8                              12.662716   
9                              13.739955   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              10.746173   
1                              15.367369   
2                              11.257224   
3                              11.300247   
4                              11.338397   
5                              13.933182   
6                              10.575800   
7                              18.425468   
8                              13.73342

 57%|█████▋    | 230156/400782 [12:33<1:11:14, 39.92it/s]

(208417, 336)
(10000,)
qid      670637.7320
count        20.8417
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.434587   
1                              12.609879   
2                              22.150017   
3                              12.037066   
4                              14.533061   
5                              12.844075   
6                              13.779434   
7                              13.605639   
8                              12.462814   
9                              12.287621   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              13.972656   
1                              12.866620   
2                              23.154692   
3                              12.826653   
4                              15.790530   
5                              14.034810   
6                              14.932918   
7                              14.499546   
8                              13.87234

 60%|█████▉    | 240156/400782 [12:54<1:03:53, 41.90it/s]

(205932, 336)
(10000,)
qid      698227.5791
count        20.5932
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.472915   
1                               9.495005   
2                               4.605760   
3                               4.605760   
4                               9.607020   
5                              10.975333   
6                              14.414341   
7                               9.600055   
8                              10.214158   
9                               5.162002   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              10.243314   
1                              10.299256   
2                               4.948124   
3                               4.948124   
4                              10.588392   
5                              11.950302   
6                              15.489265   
7                               9.923700   
8                              11.46171

 62%|██████▏   | 250052/400782 [13:14<1:19:51, 31.46it/s]

(204773, 336)
(10000,)
qid      725219.6245
count        20.4773
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              11.409188   
1                              10.973421   
2                              16.117134   
3                              11.409188   
4                              11.241779   
5                              11.880997   
6                              11.187063   
7                              11.523592   
8                              10.767786   
9                              10.667832   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.021054   
1                              11.025069   
2                              17.537308   
3                              12.021054   
4                              11.627163   
5                              13.214861   
6                              11.501540   
7                              12.298818   
8                              10.58650

 65%|██████▍   | 260170/400782 [13:34<50:24, 46.49it/s]  

(203566, 336)
(10000,)
qid      751686.7022
count        20.3566
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.053447   
1                              10.123779   
2                               9.547470   
3                              10.040483   
4                               9.961176   
5                              11.845825   
6                               9.547470   
7                              11.090663   
8                              10.789062   
9                              10.650106   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              10.804271   
1                              10.981289   
2                              10.825148   
3                              10.140624   
4                              10.576937   
5                              13.115676   
6                              10.825148   
7                              11.900352   
8                              11.94840

 67%|██████▋   | 270169/400782 [13:54<47:08, 46.18it/s]  

(203651, 336)
(10000,)
qid      778116.5249
count        20.3651
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              14.548978   
1                              13.713540   
2                              10.959221   
3                              10.295222   
4                              13.870667   
5                              11.286929   
6                              10.361267   
7                              10.488949   
8                              10.237761   
9                              10.579325   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.767600   
1                              14.016228   
2                              11.952472   
3                              11.461152   
4                              13.816588   
5                              12.850908   
6                              11.318102   
7                              11.652645   
8                              11.30859

 70%|██████▉   | 280130/400782 [14:16<51:34, 38.99it/s]  

(207031, 336)
(10000,)
qid      804282.9324
count        20.7031
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              16.228338   
1                              14.548123   
2                              11.452213   
3                              10.465665   
4                              10.540585   
5                              12.007640   
6                               7.784500   
7                               9.927494   
8                              10.659349   
9                              11.242373   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              15.346920   
1                              14.260353   
2                              10.957336   
3                              11.771371   
4                              11.613224   
5                              11.682191   
6                               7.891679   
7                              10.768379   
8                              11.92940

 72%|███████▏  | 290186/400782 [14:38<36:31, 50.46it/s]  

(210247, 336)
(10000,)
qid      831473.0054
count        21.0247
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.746423   
1                              12.129684   
2                               9.791244   
3                              39.540009   
4                              10.949347   
5                              12.817673   
6                              10.182705   
7                               8.863401   
8                               4.525095   
9                              12.254058   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               9.342265   
1                              13.031332   
2                               9.431617   
3                              42.485306   
4                              12.091469   
5                              13.904117   
6                              10.846827   
7                               9.840732   
8                               5.45631

 75%|███████▍  | 300078/400782 [15:00<55:43, 30.12it/s] 

(207190, 336)
(10000,)
qid      858783.2012
count        20.7190
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.692012   
1                               9.980751   
2                               9.936747   
3                              11.513998   
4                              10.637863   
5                              21.254005   
6                              11.904793   
7                              15.697970   
8                              18.530319   
9                              10.996590   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.643206   
1                               9.972322   
2                              10.766882   
3                              12.771692   
4                              11.505508   
5                              22.207335   
6                              13.168494   
7                              15.095241   
8                              18.99250

 77%|███████▋  | 310168/400782 [15:22<40:13, 37.54it/s] 

(210490, 336)
(10000,)
qid      887075.5948
count        21.0490
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               7.463603   
1                               7.568528   
2                               7.158256   
3                               7.252001   
4                               7.348234   
5                               7.509875   
6                              31.739603   
7                              37.374413   
8                               7.773911   
9                               7.379804   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               8.291652   
1                               8.576859   
2                               7.826912   
3                               8.073287   
4                               8.335677   
5                               8.416034   
6                              34.407017   
7                              40.854321   
8                               8.55987

 80%|███████▉  | 320105/400782 [15:44<36:37, 36.72it/s] 

(210380, 336)
(10000,)
qid      915881.9003
count        21.0380
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              19.506008   
1                              18.851322   
2                              18.201654   
3                              16.414865   
4                              15.719904   
5                              17.012140   
6                              18.586378   
7                              42.124058   
8                              20.304773   
9                              18.113211   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              21.194323   
1                              20.753080   
2                              20.039795   
3                              17.680841   
4                              16.864059   
5                              16.981924   
6                              19.735315   
7                              43.559387   
8                              21.28676

 82%|████████▏ | 330156/400782 [16:06<31:17, 37.63it/s] 

(210217, 336)
(10000,)
qid      945539.8186
count        21.0217
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               3.665089   
1                               3.532672   
2                               3.699759   
3                               2.787631   
4                               2.956191   
5                               0.000000   
6                               2.869438   
7                               9.715895   
8                              14.864882   
9                               2.986287   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               3.858847   
1                               3.555855   
2                               3.942838   
3                               2.785277   
4                               3.175936   
5                               0.000000   
6                               2.967806   
7                              10.620687   
8                              16.88688

 85%|████████▍ | 340083/400782 [16:28<35:12, 28.73it/s] 

(209050, 336)
(10000,)
qid      972541.7927
count        20.9050
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.126706   
1                              11.849751   
2                              11.551775   
3                               9.306126   
4                              18.314100   
5                               9.031683   
6                              10.016306   
7                              12.083521   
8                              12.855180   
9                               8.772963   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.414212   
1                              13.010316   
2                              12.260604   
3                              10.065152   
4                              17.631853   
5                               9.398263   
6                              12.062316   
7                              13.635662   
8                              13.94404

 87%|████████▋ | 350177/400782 [16:49<14:16, 59.07it/s] 

(208367, 336)
(10000,)
qid      993600.8354
count        20.8367
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               8.506632   
1                              10.557448   
2                              13.158268   
3                               9.931385   
4                              10.489802   
5                              12.949499   
6                              10.942862   
7                              14.016057   
8                              12.512739   
9                               8.595962   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               9.531067   
1                              11.591445   
2                              14.446989   
3                              10.210564   
4                              11.416915   
5                              13.915185   
6                              11.199997   
7                              13.981413   
8                              12.87270

 90%|████████▉ | 360123/400782 [17:11<18:48, 36.04it/s] 

(209599, 336)
(10000,)
qid      1.019772e+06
count    2.095990e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              15.001380   
1                              10.495945   
2                              15.928602   
3                              15.114327   
4                              10.712572   
5                              14.758161   
6                              11.918958   
7                              10.147246   
8                              19.369432   
9                              20.247404   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              16.169838   
1                              11.326258   
2                              17.105331   
3                              16.454340   
4                              12.031974   
5                              14.226324   
6                              12.809858   
7                              10.599897   
8                              21.323

 92%|█████████▏| 370168/400782 [17:32<12:21, 41.31it/s] 

(208685, 336)
(10000,)
qid      1.045876e+06
count    2.086850e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.601457   
1                              12.581927   
2                              16.790445   
3                               9.650691   
4                               9.628238   
5                              13.031651   
6                              28.393496   
7                              20.390751   
8                              10.043204   
9                               9.951025   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.710818   
1                              12.370104   
2                              18.284174   
3                              10.468129   
4                              10.411269   
5                              13.166841   
6                              32.050430   
7                              21.621803   
8                              11.239

 95%|█████████▍| 380202/400782 [17:54<06:19, 54.26it/s] 

(209872, 336)
(10000,)
qid      1.093821e+06
count    2.098720e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               7.467671   
1                              18.373993   
2                               8.081828   
3                               8.133282   
4                               7.980849   
5                              11.186552   
6                              10.284881   
7                               6.778780   
8                               8.264829   
9                               8.211703   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               7.267464   
1                              18.890501   
2                               8.648403   
3                               8.776690   
4                               8.402762   
5                              12.011958   
6                               9.997471   
7                               7.831010   
8                               9.114

 97%|█████████▋| 390208/400782 [18:16<03:01, 58.19it/s] 

(206766, 336)
(10000,)
qid      1.154245e+06
count    2.067660e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              12.598808   
1                              11.535626   
2                              10.715259   
3                              10.851265   
4                              13.531940   
5                              16.659880   
6                              29.143301   
7                              19.719337   
8                              14.089610   
9                              10.542268   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.880904   
1                              11.588848   
2                              11.326918   
3                              11.660480   
4                              14.735783   
5                              17.610874   
6                              28.546869   
7                              22.065397   
8                              16.541

100%|█████████▉| 400074/400782 [18:37<00:25, 27.63it/s] 

(210080, 336)
(10000,)
qid      1.174363e+06
count    2.100800e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              33.433941   
1                              14.954256   
2                              16.924046   
3                              16.246906   
4                              16.311111   
5                              17.528164   
6                              17.285612   
7                              14.378944   
8                              15.746342   
9                              21.426487   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              34.950954   
1                              15.878670   
2                              18.190603   
3                              17.737755   
4                              17.902990   
5                              18.603712   
6                              19.036957   
7                              15.593912   
8                              14.772

100%|██████████| 400782/400782 [18:38<00:00, 358.42it/s]


(16461, 336)
(782,)
qid      1.185035e+06
count    2.104987e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.269936   
1                              10.223487   
2                              11.566961   
3                              11.202425   
4                               9.792343   
5                              10.503739   
6                              11.671648   
7                              16.880190   
8                              10.664545   
9                              15.841070   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              13.485291   
1                              11.868622   
2                              12.111023   
3                              12.574932   
4                              10.688994   
5                              11.334428   
6                              11.046934   
7                              15.679298   
8                              11.747111

(8385888, 3)
(400782,)
qid      594038.702983
count        20.923814
dtype: float64


100%|██████████| 6980/6980 [00:18<00:00, 386.90it/s]


(6974598, 336)
(6980,)
qid      741638.817908
count       999.226074
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              12.348820   
1                              10.927653   
2                              13.675473   
3                              12.699286   
4                              12.492470   
5                              11.077914   
6                              11.181725   
7                              15.955744   
8                              11.468307   
9                              21.200821   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              13.151269   
1                              11.041017   
2                              13.716912   
3                              14.042357   
4                              13.508439   
5                              11.527296   
6                              11.772854   
7                              17.375216   
8                              12.4

(6974598, 3)
(6980,)
qid      741638.817908
count       999.226074
dtype: float64


In [9]:
def eval_mrr(dev_data, cutoff=10):
    """Print and return the mean reciprocal rank at ``cutoff`` of a scored run.

    For every query the candidates are ordered by descending ``score`` and the
    reciprocal rank of the first row with ``rel > 0`` inside the top ``cutoff``
    positions is recorded (0 if there is none). Score ties met while scanning
    are counted and reported alongside the metric.

    Args:
        dev_data: DataFrame that can be grouped by ``qid`` and that exposes
            ``pid``, ``score`` and ``rel`` columns after ``reset_index()``.
        cutoff: Number of top-ranked rows inspected per query (default 10,
            the depth the original implementation hard-coded).

    Returns:
        The mean of the per-query reciprocal ranks.

    Raises:
        AssertionError: if a query lists the same ``pid`` more than once.
    """
    score_tie_counter = 0
    score_tie_query = set()

    reciprocal_ranks = []
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        assert group['pid'].is_unique
        # stable sort is also used in LightGBM
        ranked = group.sort_values('score', ascending=False, kind='mergesort')

        reciprocal_rank = 0.
        prev_score = None
        # Only the top `cutoff` rows can contribute, so nothing below is scanned.
        for rank, t in enumerate(ranked.head(cutoff).itertuples(), start=1):
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            if t.rel > 0:
                reciprocal_rank = 1.0 / rank
                break
        reciprocal_ranks.append(reciprocal_rank)

    mean_mrr = np.mean(reciprocal_ranks)
    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie, mean_mrr)
    return mean_mrr


In [10]:
def eval_recall(dev_qrel, dev_data, recall_point=(10, 20, 50, 100, 200, 500, 1000)):
    """Print and return the average recall of a scored run at several depths.

    For every query the candidates are ordered by descending ``score``; for
    each depth ``p`` in ``recall_point`` the number of rows with ``rel > 0``
    ranked at or above ``p`` is divided by the number of positive rows the
    query has in ``dev_qrel``. Queries without any positive row in
    ``dev_qrel`` contribute a recall of 0 at every depth.

    Args:
        dev_qrel: DataFrame with ``qid`` and ``rel`` columns (the judgments).
        dev_data: DataFrame that can be grouped by ``qid`` and that exposes
            ``pid``, ``score`` and ``rel`` columns after ``reset_index()``.
        recall_point: Depths at which recall is reported.

    Returns:
        Dict mapping each depth to the recall averaged over all queries.

    Raises:
        AssertionError: if a query lists the same ``pid`` more than once.
    """
    # Number of positive judgments per query; queries with none are absent.
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

    score_tie_counter = 0
    score_tie_query = set()

    recall_point = list(recall_point)
    recall_curve = {p: [] for p in recall_point}
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        assert group['pid'].is_unique
        # .get() rather than .loc[]: a query with no positive judgment is
        # missing from dev_rel_num and must score 0 instead of raising KeyError.
        total_rel = dev_rel_num.get(qid, 0)
        query_recall = [0] * len(recall_point)
        prev_score = None
        # stable sort is also used in LightGBM
        ranked = group.sort_values('score', ascending=False, kind='mergesort')
        for rank, t in enumerate(ranked.itertuples(), start=1):
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            if t.rel > 0:
                # A hit at this rank counts towards every depth that covers it.
                for i, p in enumerate(recall_point):
                    if rank <= p:
                        query_recall[i] += 1
        for i, p in enumerate(recall_point):
            if total_rel > 0:
                recall_curve[p].append(query_recall[i] / total_rel)
            else:
                recall_curve[p].append(0.)

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)

    averages = {}
    for k, v in recall_curve.items():
        avg = np.mean(v)
        averages[k] = avg
        print(f'recall@{k}:{avg}')
    return averages


In [11]:
# Report every feature column of the training matrix that contains NaN,
# together with the first ten (qid, pid) pairs where the value is missing.
for col_idx, col_name in enumerate(feature_name):
    column = train_extracted['data'].iloc[:, col_idx]
    if not np.isnan(column).any():
        continue
    print(col_name)
    missing_rows = train_extracted['info'].loc[column.isna(), ['qid', 'pid']]
    print(missing_rows.head(10))

text_text_LMJM_lambda_0.10
            qid      pid
17745      3671  6416650
164965    24130  4209198
532270    72865  5018776
551592    75806  5323609
578451    79357  8026198
693707    95916  4705000
1081291  149467  2267954
1307485  180247  2872095
1320859  181991  7165701
1357277  187376  7229004
text_text_LMJM_lambda_0.40
            qid      pid
17745      3671  6416650
164965    24130  4209198
532270    72865  5018776
551592    75806  5323609
578451    79357  8026198
693707    95916  4705000
1081291  149467  2267954
1307485  180247  2872095
1320859  181991  7165701
1357277  187376  7229004
text_text_LMJM_lambda_0.70
            qid      pid
17745      3671  6416650
164965    24130  4209198
532270    72865  5018776
551592    75806  5323609
578451    79357  8026198
693707    95916  4705000
1081291  149467  2267954
1307485  180247  2872095
1320859  181991  7165701
1357277  187376  7229004
text_text_Prob
            qid      pid
17745      3671  6416650
164965    24130  4209198
5322

In [12]:
def gen_dev_group_rel_num(dev_qrel, dev_extracted, k=200):
    """Build a LightGBM custom metric computing mean recall@k over query groups.

    The number of positive judgments of each query group is looked up once,
    in the order the groups appear in ``dev_extracted['info']``, and captured
    by the returned closure.

    Args:
        dev_qrel: DataFrame with ``qid`` and ``rel`` columns (the judgments).
        dev_extracted: Mapping whose ``'info'`` entry is a DataFrame with a
            ``qid`` column; consecutive rows sharing a ``qid`` form one group.
        k: Ranking depth of the recall metric (default 200).

    Returns:
        A function ``(preds, dataset) -> (name, value, is_higher_better)``
        where ``dataset`` provides ``get_label()`` and ``get_group()``.

    Raises:
        AssertionError: if the number of groups differs from the number of
            distinct ``qid`` values in ``dev_qrel``.
    """
    # Number of positive judgments per query; queries with none are absent.
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

    # A new group starts on every row whose qid differs from the previous row
    # (the first row always differs from the NaN introduced by shift()).
    qid_col = dev_extracted['info']['qid']
    group_qids = qid_col[qid_col != qid_col.shift()]
    # .get() so that a query without positive judgments yields 0, not KeyError.
    dev_rel_num_list = [dev_rel_num.get(qid, 0) for qid in group_qids]
    assert len(dev_rel_num_list) == dev_qrel.qid.drop_duplicates().shape[0]

    def recall_at_k(preds, dataset):
        """Return ``('recall@k', mean recall over groups, True)``."""
        labels = dataset.get_label()
        groups = dataset.get_group()
        assert len(dev_rel_num_list) == len(groups)
        start = 0
        recall = 0.
        for g, gnum in zip(groups, dev_rel_num_list):
            stop = start + g
            # Groups without positives contribute 0, as in eval_recall.
            if gnum > 0:
                # Stable descending order keeps tied scores in input order,
                # the same tie-breaking as the mergesort used in eval_recall.
                order = np.argsort(-preds[start:stop], kind='stable')[:k]
                recall += np.sum(labels[start:stop][order]) / gnum
            start = stop
        assert start == len(preds)
        return f'recall@{k}', recall / len(groups), True

    return recall_at_k
# Custom dev-set recall metric; passed to lgb.train as `feval` in the training cell.
eval_fn = gen_dev_group_rel_num(dev_qrel, dev_extracted)

In [13]:
# Training set: the feature columns named in `feature_name`, the `rel` labels,
# and the `count` column of the group table (presumably rows per query).
lgb_train = lgb.Dataset(
    train_extracted['data'].loc[:, feature_name],
    label=train_extracted['info']['rel'],
    group=train_extracted['group']['count'],
)
# Validation set, built the same way. free_raw_data=False is set here because
# the raw features are read back later through lgb_valid.get_data().
lgb_valid = lgb.Dataset(
    dev_extracted['data'].loc[:, feature_name],
    label=dev_extracted['info']['rel'],
    group=dev_extracted['group']['count'],
    free_raw_data=False,
)

In [14]:
import re
# Strip every run of characters other than ASCII letters, digits and
# underscore from each feature name (e.g. the dots in "k1_0.90").
filtered_feature_name = [re.sub('[^A-Za-z0-9_]+', '', name) for name in feature_name]

In [18]:
# Train a LambdaRank model (GOSS boosting) on lgb_train, early-stopped on the
# custom metric `eval_fn` evaluated on lgb_valid, then score the dev set and
# report recall and MRR.
# The score column is initialised here and overwritten inside the loop.
dev_extracted['info']['score'] = 0.
# NOTE(review): scores are assigned, not accumulated, so with more than one
# seed only the last seed's predictions would be kept — confirm intent.
for seed in [12345]:
    params = {
            'boosting_type': 'goss',
            'objective': 'lambdarank',
            'max_bin':255,
            'num_leaves':100,
            'max_depth':-1,
            'min_data_in_leaf':50,
            'min_sum_hessian_in_leaf':0,
            # Sub-sampling options below are left disabled.
#             'bagging_fraction':0.8,
#             'bagging_freq':50,
#             'feature_fraction':0.5,
            'learning_rate':0.1,
            'num_boost_round':1000,
            'early_stopping_round':200,
            # 'custom': the only validation metric is the `feval` passed below.
            'metric':'custom',
            'label_gain':[0,1],
            'lambdarank_truncation_level':20,
            'lambdarank_norm':True,
            'seed':seed,
            # Half of the reported CPUs, but never fewer than one thread.
            'num_threads':max(multiprocessing.cpu_count()//2,1)
    }
    # These two are removed from `params` and handed to lgb.train as explicit
    # keyword arguments instead.
    num_boost_round = params.pop('num_boost_round')
    early_stopping_round = params.pop('early_stopping_round')
    gbm = lgb.train(params, lgb_train,
                    valid_sets=lgb_valid,
                    num_boost_round=num_boost_round,
                    early_stopping_rounds=early_stopping_round,
                    feval=eval_fn,
                    feature_name=filtered_feature_name,
                    verbose_eval=True)

    # Score the dev rows with the trained model, using the raw validation data.
    dev_extracted['info']['score'] = gbm.predict(lgb_valid.get_data())
    best_score = gbm.best_score['valid_0']['recall@200']
    print(best_score)
    best_iteration = gbm.best_iteration
    print(best_iteration)
    # Pair each original (unfiltered) feature name with its importance and
    # list them from most to least important.
    # NOTE(review): assumes feature_importance() follows feature_name order — confirm.
    feature_importances = sorted(list(zip(feature_name,gbm.feature_importance().tolist())),
                                 key=lambda x:x[1],reverse=True)
    print(feature_importances)
# Final dev-set report based on the scores written above.
eval_recall(dev_qrel, dev_extracted['info'])
eval_mrr(dev_extracted['info'])

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 68509
[LightGBM] [Info] Number of data points in the train set: 8385888, number of used features: 336
[LightGBM] [Info] Using GOSS
[1]	valid_0's recall@200: 0.737954
Training until validation scores don't improve for 200 rounds
[2]	valid_0's recall@200: 0.741965
[3]	valid_0's recall@200: 0.746036
[4]	valid_0's recall@200: 0.746335
[5]	valid_0's recall@200: 0.750215
[6]	valid_0's recall@200: 0.754465
[7]	valid_0's recall@200: 0.757426
[8]	valid_0's recall@200: 0.758309
[9]	valid_0's recall@200: 0.762369
[10]	valid_0's recall@200: 0.764959
[11]	valid_0's recall@200: 0.768995
[12]	valid_0's recall@200: 0.770738
[13]	valid_0's recall@200: 0.772385
[14]	valid_0's recall@200: 0.772194
[15]	valid_0's recall@200: 0.773914
[16]	valid_0's recall@200: 0.774558
[17]	valid_0's recall@200: 0.775609
[18]	valid_0's recall@200: 0.775967
[19]	valid_0's recall@200: 0.77666
[20]	valid_0's recall@200: 0.777161
[21]	valid

[216]	valid_0's recall@200: 0.803044
[217]	valid_0's recall@200: 0.803188
[218]	valid_0's recall@200: 0.803116
[219]	valid_0's recall@200: 0.802615
[220]	valid_0's recall@200: 0.802615
[221]	valid_0's recall@200: 0.802615
[222]	valid_0's recall@200: 0.802471
[223]	valid_0's recall@200: 0.802328
[224]	valid_0's recall@200: 0.802543
[225]	valid_0's recall@200: 0.802543
[226]	valid_0's recall@200: 0.80283
[227]	valid_0's recall@200: 0.802901
[228]	valid_0's recall@200: 0.803044
[229]	valid_0's recall@200: 0.803044
[230]	valid_0's recall@200: 0.803044
[231]	valid_0's recall@200: 0.802901
[232]	valid_0's recall@200: 0.802901
[233]	valid_0's recall@200: 0.802901
[234]	valid_0's recall@200: 0.802901
[235]	valid_0's recall@200: 0.80283
[236]	valid_0's recall@200: 0.802973
[237]	valid_0's recall@200: 0.80283
[238]	valid_0's recall@200: 0.802973
[239]	valid_0's recall@200: 0.802973
[240]	valid_0's recall@200: 0.803116
[241]	valid_0's recall@200: 0.803259
[242]	valid_0's recall@200: 0.803331
[243

[438]	valid_0's recall@200: 0.80585
[439]	valid_0's recall@200: 0.80585
[440]	valid_0's recall@200: 0.80585
[441]	valid_0's recall@200: 0.80585
[442]	valid_0's recall@200: 0.806137
[443]	valid_0's recall@200: 0.806137
[444]	valid_0's recall@200: 0.806137
[445]	valid_0's recall@200: 0.805993
[446]	valid_0's recall@200: 0.80628
[447]	valid_0's recall@200: 0.80628
[448]	valid_0's recall@200: 0.806423
[449]	valid_0's recall@200: 0.80628
[450]	valid_0's recall@200: 0.806137
[451]	valid_0's recall@200: 0.80628
[452]	valid_0's recall@200: 0.80628
[453]	valid_0's recall@200: 0.806137
[454]	valid_0's recall@200: 0.805993
[455]	valid_0's recall@200: 0.806137
[456]	valid_0's recall@200: 0.806065
[457]	valid_0's recall@200: 0.806065
[458]	valid_0's recall@200: 0.806065
[459]	valid_0's recall@200: 0.805778
[460]	valid_0's recall@200: 0.805778
[461]	valid_0's recall@200: 0.805922
[462]	valid_0's recall@200: 0.806065
[463]	valid_0's recall@200: 0.805922
[464]	valid_0's recall@200: 0.806065
[465]	vali

[661]	valid_0's recall@200: 0.805635
[662]	valid_0's recall@200: 0.805635
[663]	valid_0's recall@200: 0.805635
[664]	valid_0's recall@200: 0.805778
[665]	valid_0's recall@200: 0.805778
[666]	valid_0's recall@200: 0.805778
[667]	valid_0's recall@200: 0.805778
[668]	valid_0's recall@200: 0.80585
[669]	valid_0's recall@200: 0.80585
[670]	valid_0's recall@200: 0.80585
[671]	valid_0's recall@200: 0.80585
[672]	valid_0's recall@200: 0.80585
[673]	valid_0's recall@200: 0.80585
[674]	valid_0's recall@200: 0.805707
[675]	valid_0's recall@200: 0.80585
[676]	valid_0's recall@200: 0.80585
[677]	valid_0's recall@200: 0.80585
[678]	valid_0's recall@200: 0.80585
[679]	valid_0's recall@200: 0.805993
[680]	valid_0's recall@200: 0.805993
[681]	valid_0's recall@200: 0.805993
[682]	valid_0's recall@200: 0.806137
[683]	valid_0's recall@200: 0.80585
[684]	valid_0's recall@200: 0.80585
[685]	valid_0's recall@200: 0.805993
[686]	valid_0's recall@200: 0.805993
[687]	valid_0's recall@200: 0.806137
[688]	valid_0

100%|██████████| 6980/6980 [00:18<00:00, 377.22it/s]


score_tie occurs 180865 times in 6780 queries
recall@10:0.48933858643744027
recall@20:0.5828319006685769
recall@50:0.6869866284622732
recall@100:0.7529727793696275
recall@200:0.8067812798471825
recall@500:0.8464302769818529
recall@1000:0.8573424068767909


100%|██████████| 6980/6980 [00:10<00:00, 647.51it/s]

score_tie occurs 210 times in 176 queries 0.2501120548505935



