In [1]:
import argparse
import datetime
import glob
import hashlib
import json
import multiprocessing
import pickle
import os
import shutil
import subprocess
import uuid
import random

import numpy as np
import pandas as pd
import lightgbm as lgb
from collections import defaultdict
from lightgbm.sklearn import LGBMRanker
from tqdm import tqdm
import sys
sys.path.append('..')

from pyserini.analysis import Analyzer, get_lucene_analyzer
from pyserini.ltr import *
from pyserini.search import get_topics_with_reader

def _print_train_summary(sampled_train):
    """Print shape, distinct-qid count, mean rows per qid, head and info of a sample."""
    print(sampled_train.shape)
    print(sampled_train.index.get_level_values('qid').drop_duplicates().shape)
    print(sampled_train.groupby('qid').count().mean())
    print(sampled_train.head(10))
    print(sampled_train.info())


def _sample_from_triples(triples_path, neg_sample, random_seed):
    """Build a (qid, pid) -> rel sample from a (qid, pos_pid, neg_pid) triples TSV.

    Every distinct (qid, positive pid) pair is kept with rel=1. For each qid, up to
    ``neg_sample`` distinct negative pids are drawn with rel=0, using pandas
    ``sample`` seeded with ``random_seed``.
    """
    train = pd.read_csv(triples_path, sep="\t",
                        names=['qid', 'pos_pid', 'neg_pid'], dtype=np.int32)
    pos_half = train[['qid', 'pos_pid']].rename(columns={"pos_pid": "pid"}).drop_duplicates()
    pos_half['rel'] = np.int32(1)
    neg_half = train[['qid', 'neg_pid']].rename(columns={"neg_pid": "pid"}).drop_duplicates()
    neg_half['rel'] = np.int32(0)
    # Drop the reference to the full triples frame before the per-qid sampling loop.
    del train
    sampled_neg_half = [
        group.sample(n=min(neg_sample, len(group)), random_state=random_seed)
        for _, group in tqdm(neg_half.groupby('qid'))
    ]
    sampled_train = pd.concat([pos_half] + sampled_neg_half, axis=0, ignore_index=True)
    return sampled_train.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])


def _sample_from_run(qrels_path, run_path, neg_sample, random_seed, keep_negative=None):
    """Build a (qid, pid) -> rel sample from a train qrels file and a run file.

    Run lines are ``qid pid rank``. A pid judged relevant in the qrels for that qid
    becomes a positive (rel=1). Any other pid is a candidate negative, kept only if
    ``keep_negative`` is None or ``keep_negative(int(rank))`` is true.

    Only queries with at least one positive in the run are emitted. For each such
    query, up to ``neg_sample`` negatives are drawn with a
    ``random.Random(random_seed)`` generator, so the result is reproducible for a
    given seed.
    """
    qrel = defaultdict(set)
    with open(qrels_path) as f:
        for line in f:
            topicid, _, docid, rel = line.strip().split('\t')
            assert rel == "1", line.split(' ')
            qrel[topicid].add(docid)

    qid2pos = defaultdict(list)
    qid2neg = defaultdict(list)
    with open(run_path) as f:
        for line in tqdm(f):
            topicid, docid, rank = line.split()
            assert topicid in qrel
            if docid in qrel[topicid]:
                qid2pos[topicid].append(docid)
            elif keep_negative is None or keep_negative(int(rank)):
                qid2neg[topicid].append(docid)

    # Seeded RNG: previously the module-level random state was used, so the
    # random_seed embedded in the cache filename did not determine the sample.
    rng = random.Random(random_seed)
    rows = []
    for topicid, pos_list in tqdm(qid2pos.items()):
        negatives = qid2neg[topicid]
        neg_list = rng.sample(negatives, min(len(negatives), neg_sample))
        rows.extend((int(topicid), int(docid), 1) for docid in pos_list)
        rows.extend((int(topicid), int(docid), 0) for docid in neg_list)
    sampled_train = pd.DataFrame(rows, columns=['qid', 'pid', 'rel'], dtype=np.int32)
    return sampled_train.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])


def train_data_loader(task='triple', neg_sample=10, cutoff=100, random_seed=12345):
    """Load a cached sampled training set, or build, print and cache it.

    Args:
        task: sampling strategy.
            'triple' samples negatives from the official triples file.
            'rank' samples from all non-relevant run entries.
            'rank_recall' samples only negatives with rank > cutoff.
            'rank_hard' samples only negatives with rank < cutoff.
        neg_sample: maximum number of negatives kept per query.
        cutoff: rank threshold used by 'rank_recall' / 'rank_hard'.
        random_seed: seed for negative sampling; also part of the cache filename.

    Returns:
        DataFrame indexed by (qid, pid) with an int32 ``rel`` column.

    Raises:
        Exception: for an unknown ``task``.
    """
    if task in ('triple', 'rank'):
        fn = f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle'
    elif task in ('rank_recall', 'rank_hard'):
        fn = f'train_{task}_sampled_with_{neg_sample}_{cutoff}_{random_seed}.pickle'
    else:
        raise Exception('unknown parameters')

    if os.path.exists(fn):
        sampled_train = pd.read_pickle(fn)
        _print_train_summary(sampled_train)
        return sampled_train

    if task == 'triple':
        sampled_train = _sample_from_triples(
            'collections/msmarco-passage/qidpidtriples.train.full.2.tsv',
            neg_sample, random_seed)
    elif task == 'rank':
        sampled_train = _sample_from_run(
            'collections/msmarco-passage/qrels.train.tsv',
            'runs/msmarco-passage/run.train.small.tsv',
            neg_sample, random_seed)
    elif task == 'rank_recall':
        # NOTE(review): this branch (and 'rank_hard') reads from '../' while
        # 'triple'/'rank' do not; kept as-is, confirm the intended working directory.
        sampled_train = _sample_from_run(
            '../collections/msmarco-passage/qrels.train.tsv',
            '../runs/msmarco-passage/run.train.small.tsv',
            neg_sample, random_seed,
            keep_negative=lambda rank: rank > cutoff)
    else:  # 'rank_hard'
        # NOTE(review): strict '<' keeps ranks below cutoff only; confirm vs '<='.
        sampled_train = _sample_from_run(
            '../collections/msmarco-passage/qrels.train.tsv',
            '../runs/msmarco-passage/run.train.small.tsv',
            neg_sample, random_seed,
            keep_negative=lambda rank: rank < cutoff)

    _print_train_summary(sampled_train)
    sampled_train.to_pickle(fn)
    return sampled_train
# Load the cached 'triple' training sample (built and cached on first run),
# keeping at most 10 negatives per query.
sampled_train = train_data_loader(neg_sample=10, task='triple')

(4416413, 1)
(400782,)
rel    11.019489
dtype: float64
             rel
qid pid         
3   970816     0
    1142680    1
    2019206    0
    2605131    0
    2963098    0
    2971685    0
    3783924    0
    5067083    0
    5904778    0
    6176208    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 4416413 entries, (3, 970816) to (1185869, 7770561)
Data columns (total 1 columns):
 #   Column  Dtype
---  ------  -----
 0   rel     int32
dtypes: int32(1)
memory usage: 166.9 MB
None


In [2]:
def dev_data_loader(task='pygaggle'):
    """Return ``(dev, dev_qrel)``, building and caching both on first use.

    ``dev`` holds the dev candidate list for ``task``. It is indexed by (qid, pid)
    and has an int32 ``rel`` column: 1 for judged-relevant pairs, 0 otherwise.
    ``dev_qrel`` is the raw dev qrels frame with columns qid, pid, rel.

    Supported tasks when no cache exists are 'rerank' (top1000.dev), 'anserini'
    and 'pygaggle' (run files).

    Raises:
        Exception: for an unknown ``task`` when ``dev_{task}.pickle`` is absent.
    """
    cache_fn = f'dev_{task}.pickle'
    qrel_fn = 'dev_qrel.pickle'

    def _report(frame):
        # Shape, distinct qid count, mean rows per qid and the first rows; then info().
        summary = (
            frame.shape,
            frame.index.get_level_values('qid').drop_duplicates().shape,
            frame.groupby('qid').count().mean(),
            frame.head(10),
        )
        for item in summary:
            print(item)
        print(frame.info())

    if not os.path.exists(cache_fn):
        run_files = {
            'anserini': 'runs/msmarco-passage/run.msmarco-passage.dev.small.tsv',
            'pygaggle': '../pygaggle/data/msmarco_ans_entire/run.dev.small.tsv',
        }
        if task == 'rerank':
            dev = pd.read_csv('collections/msmarco-passage/top1000.dev', sep="\t",
                              names=['qid', 'pid', 'query', 'doc'],
                              usecols=['qid', 'pid'], dtype=np.int32)
        elif task in run_files:
            dev = pd.read_csv(run_files[task], sep="\t",
                              names=['qid', 'pid', 'rank'], dtype=np.int32)
        else:
            raise Exception('unknown parameters')

        dev_qrel = pd.read_csv('collections/msmarco-passage/qrels.dev.small.tsv', sep="\t",
                               names=["qid", "q0", "pid", "rel"],
                               usecols=['qid', 'pid', 'rel'], dtype=np.int32)
        # Left join so every candidate survives; unjudged pairs get rel 0.
        dev = dev.merge(dev_qrel, on=['qid', 'pid'], how='left')
        dev['rel'] = dev['rel'].fillna(0).astype(np.int32)
        dev = dev.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])

        _report(dev)
        dev.to_pickle(cache_fn)
        dev_qrel.to_pickle(qrel_fn)
        return dev, dev_qrel

    dev = pd.read_pickle(cache_fn)
    _report(dev)
    return dev, pd.read_pickle(qrel_fn)
# Dev candidates from the pygaggle run (cached after the first build), plus dev qrels.
dev, dev_qrel = dev_data_loader(task='pygaggle')

(6974598, 2)
(6980,)
rank    999.226074
rel     999.226074
dtype: float64
            rank  rel
qid pid              
2   55860    345    0
    72202    557    0
    72210    213    0
    98589    278    0
    98590    323    0
    98593    580    0
    98595    553    0
    112123   108    0
    112126   469    0
    112127    21    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 6974598 entries, (2, 55860) to (1102400, 8830447)
Data columns (total 2 columns):
 #   Column  Dtype
---  ------  -----
 0   rank    int32
 1   rel     int32
dtypes: int32(2)
memory usage: 282.7 MB
None


In [3]:
def query_loader():
    """Load train, dev and eval queries into one dict keyed by query id.

    Reads the JSON-lines files ``queries.{train,dev,eval}.small.entity.json`` from
    the current directory, in that order. A later file overwrites an earlier
    entry with the same id.

    For each record:
    - ``id`` is removed and used as the key.
    - ``analyzed``, ``text_unlemm`` and ``text_bert_tok`` are split on single
      spaces into token lists.
    - ``text`` is set to the split ``text_unlemm``.

    Returns:
        dict mapping query id to the processed query dict.

    Raises:
        AssertionError: if a record lacks ``raw`` or ``entity``.
    """
    queries = {}
    for split in ('train', 'dev', 'eval'):
        # Explicit utf-8: the default text encoding is locale-dependent.
        with open(f'queries.{split}.small.entity.json', encoding='utf-8') as f:
            for line in f:
                query = json.loads(line)
                qid = query.pop('id')
                query['analyzed'] = query['analyzed'].split(" ")
                # NOTE(review): 'text' is derived from 'text_unlemm', not from a
                # 'text' field of the record; kept as-is because cached features
                # were computed this way. Confirm this is intended.
                query['text'] = query['text_unlemm'].split(" ")
                query['text_unlemm'] = query['text_unlemm'].split(" ")
                query['text_bert_tok'] = query['text_bert_tok'].split(" ")
                assert 'raw' in query
                assert 'entity' in query
                queries[qid] = query
    return queries
# All train/dev/eval queries, keyed by query id string.
queries = query_loader()

In [4]:
# Feature extractor over the flex_ent Lucene index. The second argument is
# half the CPU count (at least 1); presumably a worker count, confirm in FeatureExtractor.
fe = FeatureExtractor('../../anserini/indexes/msmarco-passage/lucene-index-msmarco-flex_ent',
                      max(multiprocessing.cpu_count() // 2, 1))

# Poolers applied, in this order, to each per-term statistic feature.
_POOLERS = (AvgPooler, MedianPooler, SumPooler, MinPooler,
            MaxPooler, VarPooler, MaxMinRatioPooler, ConfidencePooler)
_TERM_STATS = (tfStat, tfIdfStat, scqStat, normalizedTfStat, idfStat, ictfStat)
_PAIR_FEATURES = (UnorderedSequentialPairs, OrderedSequentialPairs,
                  UnorderedQueryPairs, OrderedQueryPairs)

# Feature order matters for the resulting columns. All per-field features are
# added for one (query field, index field) pair before moving to the next.
for qfield, ifield in (('analyzed', 'contents'),
                       ('text', 'text'),
                       ('text_unlemm', 'text_unlemm'),
                       ('text_bert_tok', 'text_bert_tok')):
    print(qfield, ifield)
    field_kwargs = dict(field=ifield, qfield=qfield)

    for bm25_k1, bm25_b in ((0.9, 0.4), (1.2, 0.75), (2.0, 0.75)):
        fe.add(BM25(k1=bm25_k1, b=bm25_b, **field_kwargs))
    for dir_mu in (1000, 1500, 2500):
        fe.add(LMDir(mu=dir_mu, **field_kwargs))
    for jm_lambda in (0.1, 0.4, 0.7):
        fe.add(LMJM(jm_lambda, **field_kwargs))

    for scorer_cls in (NTFIDF, ProbalitySum,
                       DFR_GL2, DFR_In_expB2, DPH,
                       Proximity, TPscore, tpDist):
        fe.add(scorer_cls(**field_kwargs))

    fe.add(DocSize(field=ifield))

    for query_feature_cls in (QueryLength, QueryCoverageRatio, UniqueTermCount):
        fe.add(query_feature_cls(qfield=qfield))
    fe.add(MatchingTermCount(**field_kwargs))
    fe.add(SCS(**field_kwargs))

    for stat_cls in _TERM_STATS:
        for pooler_cls in _POOLERS:
            fe.add(stat_cls(pooler_cls(), **field_kwargs))

    for pair_cls in _PAIR_FEATURES:
        for window in (3, 8, 15):
            fe.add(pair_cls(window, **field_kwargs))

for entity_cls in (EntityHowLong, EntityHowMany, EntityHowMuch, EntityWhen,
                   EntityWhere, EntityWho, EntityWhereMatch, EntityWhoMatch):
    fe.add(entity_cls())

_GIZA_DIR = "../FlexNeuART/collections/msmarco_doc/derived_data/giza"
for ibm_args in ((f"{_GIZA_DIR}/title_unlemm", "text_unlemm", "title_unlemm", "text_unlemm"),
                 (f"{_GIZA_DIR}/url_unlemm", "text_unlemm", "url_unlemm", "text_unlemm"),
                 (f"{_GIZA_DIR}/body", "text_unlemm", "body", "text_unlemm"),
                 (f"{_GIZA_DIR}/text_bert_tok", "text_bert_tok", "text_bert_tok", "text_bert_tok")):
    fe.add(IBMModel1(*ibm_args))
    print('IBM model Loaded')

analyzed contents
text text
text_unlemm text_unlemm
text_bert_tok text_bert_tok
IBM model Loaded
IBM model Loaded
IBM model Loaded
IBM model Loaded


In [5]:
def batch_extract(df, queries, fe, batch_size=10000):
    """Extract features for every row of ``df``, sending queries to ``fe`` in batches.

    Rows are grouped by ``qid``. Each group becomes one task dict holding the qid
    string, the candidate ``docIds`` and ``queries[str(qid)]``. Whenever
    ``batch_size`` tasks are pending, they are passed to ``fe.batch_extract`` and
    the result is collected; a final partial batch is flushed at the end. Each
    flush prints a short summary.

    Args:
        df: frame groupable by ``qid`` whose rows expose ``pid`` and ``rel`` after
            ``reset_index()``.
        queries: mapping from ``str(qid)`` to the query dict passed to ``fe``.
        fe: feature extractor providing ``batch_extract(tasks)`` returning a DataFrame.
        batch_size: number of queries per ``fe.batch_extract`` call (default 10000).

    Returns:
        Tuple ``(info, features, group)``:
        - ``info`` has one row per (qid, pid, rel), in extraction order.
        - ``features`` is the concatenation of the ``fe.batch_extract`` outputs.
        - ``group`` has one (qid, count) row per query.

    Raises:
        ValueError: if ``batch_size`` < 1 or ``df`` yields no groups.
    """
    if batch_size < 1:
        raise ValueError(f'batch_size must be >= 1, got {batch_size}')

    info_dfs = []
    feature_dfs = []
    group_dfs = []

    def _flush(tasks, task_infos, group_lst):
        # Extract one batch, print its summary and collect the three frames.
        features = fe.batch_extract(tasks)
        info = pd.DataFrame(task_infos, columns=['qid', 'pid', 'rel'])
        counts = pd.DataFrame(group_lst, columns=['qid', 'count'])
        print(features.shape)
        print(info.qid.drop_duplicates().shape)
        print(counts.mean())
        print(features.head(10))
        print(features.info())
        info_dfs.append(info)
        feature_dfs.append(features)
        group_dfs.append(counts)

    tasks, task_infos, group_lst = [], [], []
    for qid, group in tqdm(df.groupby('qid')):
        task = {
            "qid": str(qid),
            "docIds": [],
            # NOTE(review): 'rels' is never filled here; relevance goes to task_infos.
            "rels": [],
            "query_dict": queries[str(qid)]
        }
        for t in group.reset_index().itertuples():
            task["docIds"].append(str(t.pid))
            task_infos.append((qid, t.pid, t.rel))
        tasks.append(task)
        group_lst.append((qid, len(task['docIds'])))
        if len(tasks) == batch_size:
            _flush(tasks, task_infos, group_lst)
            tasks, task_infos, group_lst = [], [], []
    # Flush the final partial batch.
    if tasks:
        _flush(tasks, task_infos, group_lst)

    if not info_dfs:
        raise ValueError('batch_extract: df contains no rows to extract features for')
    info_df = pd.concat(info_dfs, axis=0, ignore_index=True)
    feature_df = pd.concat(feature_dfs, axis=0, ignore_index=True, copy=False)
    group_df = pd.concat(group_dfs, axis=0, ignore_index=True)
    return info_df, feature_df, group_df

In [6]:
def hash_df(df):
    """Return a hex-string fingerprint of ``df`` for cache keys.

    Each row is hashed with ``pd.util.hash_pandas_object``, which includes the
    index by default. The row hashes are then summed as uint64, so the
    fingerprint does not depend on row order.
    """
    row_hashes = pd.util.hash_pandas_object(df)
    combined = row_hashes.sum().astype(np.uint64)
    return hex(combined)


def hash_anserini_jar():
    """Return the MD5 hex digest of the single ``*fatjar.jar`` in $ANSERINI_CLASSPATH.

    The jar is read in 1 MiB chunks inside a ``with`` block. The previous version
    read the whole file in one call and never closed the handle.

    Raises:
        KeyError: if ANSERINI_CLASSPATH is not set.
        RuntimeError: if the directory does not contain exactly one matching jar.
    """
    pattern = os.path.join(os.environ['ANSERINI_CLASSPATH'], '*fatjar.jar')
    matches = glob.glob(pattern)
    if len(matches) != 1:
        raise RuntimeError(
            f'expected exactly one jar matching {pattern!r}, found {len(matches)}: {matches}')
    md5_hash = hashlib.md5()
    with open(matches[0], 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()


def hash_fe(fe):
    """Return an MD5 hex digest identifying the set of feature names in ``fe``.

    Names are sorted and comma-joined before hashing, so the digest does not
    depend on the order in which features were added.
    """
    joined_names = ','.join(sorted(fe.feature_names()))
    digest = hashlib.md5()
    digest.update(joined_names.encode())
    return digest.hexdigest()


def data_loader(task, df, queries, fe):
    """Return extracted features for ``df``, backed by an on-disk pickle cache.

    The cache filename combines ``task`` with three fingerprints: the input frame
    (``hash_df``), the Anserini jar (``hash_anserini_jar``) and the feature set
    (``hash_fe``). On a cache hit the stored dict is loaded, summarized and
    returned.

    Returns:
        dict with keys ``info``, ``data`` and ``group`` (from ``batch_extract``),
        plus ``df_hash``, ``jar_hash`` and ``fe_hash``.

    Raises:
        Exception: on a cache miss when ``task`` is not 'train' or 'dev'.
    """
    df_hash = hash_df(df)
    jar_hash = hash_anserini_jar()
    fe_hash = hash_fe(fe)
    cache_fn = f'{task}_{df_hash}_{jar_hash}_{fe_hash}.pickle'

    if os.path.exists(cache_fn):
        # Local cache produced by this function; never point pickle at untrusted files.
        with open(cache_fn, 'rb') as f:
            res = pickle.load(f)
        print(res['info'].shape)
        print(res['info'].qid.drop_duplicates().shape)
        print(res['group'].mean())
        return res

    if task not in ('train', 'dev'):
        raise Exception('unknown parameters')

    info, data, group = batch_extract(df, queries, fe)
    obj = {'info': info, 'data': data, 'group': group,
           'df_hash': df_hash, 'jar_hash': jar_hash, 'fe_hash': fe_hash}
    print(info.shape)
    print(info.qid.drop_duplicates().shape)
    print(group.mean())
    # Dump to a temp file and rename, so an interrupted write cannot leave a
    # truncated cache file that the exists() check above would later accept.
    tmp_fn = cache_fn + '.tmp'
    with open(tmp_fn, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp_fn, cache_fn)
    return obj

In [4]:
import json
def export(df, fn, queries):
    """Write ``df`` to ``fn`` as JSON lines, one line per query.

    Each line holds:
    - ``qid``.
    - ``docIds``: the de-duplicated candidate pids as strings, in group order.
    - Every field of ``queries[str(qid)]``.

    Raises:
        AssertionError: if a query dict already contains a 'qid' or 'docIds' key,
            which would clobber the exported values.
    """
    with open(fn, 'w', encoding='utf-8') as f:
        for qid, group in tqdm(df.groupby('qid')):
            query = queries[str(qid)]
            assert 'qid' not in query
            assert 'docIds' not in query
            # numpy scalars are not JSON-serializable; unwrap them to Python values.
            qid_value = qid.item() if isinstance(qid, np.generic) else qid
            line = {
                'qid': qid_value,
                'docIds': [str(did) for did in group.reset_index().pid.drop_duplicates().tolist()],
            }
            line.update(query)
            f.write(json.dumps(line) + '\n')

In [5]:
# Dump the sampled training set and the dev candidates as per-query JSON lines.
export(df=sampled_train, fn='sampled_train_export.json', queries=queries)
export(df=dev, fn='sampled_dev_export.json', queries=queries)

100%|██████████| 400782/400782 [07:50<00:00, 851.79it/s]
100%|██████████| 6980/6980 [00:10<00:00, 653.83it/s]


In [8]:
# Extract (or load from cache) features for the train and dev candidate sets.
train_extracted = data_loader(task='train', df=sampled_train, queries=queries, fe=fe)
dev_extracted = data_loader(task='dev', df=dev, queries=queries, fe=fe)
# The raw candidate frames are no longer needed; drop the references.
del sampled_train, dev
# Keep the feature names before releasing the extractor and the query dict.
feature_name = fe.feature_names()
del queries, fe

  3%|▎         | 10082/400782 [02:11<38:26:10,  2.82it/s]

(110346, 344)
(10000,)
qid      16070.1173
count       11.0346
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              22.893564   
1                              26.467836   
2                              19.257418   
3                              13.279017   
4                              19.064749   
5                              13.103558   
6                              18.477646   
7                              13.021508   
8                              12.939322   
9                              18.572485   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              23.464577   
1                              27.822981   
2                              21.266153   
3                              15.103254   
4                              20.615219   
5                              13.239493   
6                              20.432346   
7                              14.290549   
8                              14.188528 

  5%|▌         | 20178/400782 [02:57<13:17:02,  7.96it/s]

(110523, 344)
(10000,)
qid      43586.6803
count       11.0523
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              12.922753   
1                              12.513888   
2                              12.434628   
3                              12.844487   
4                              20.265697   
5                              18.062141   
6                              17.785980   
7                              19.244509   
8                              18.777145   
9                              15.182590   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.139851   
1                              13.831271   
2                              13.694565   
3                              13.905419   
4                              21.310802   
5                              19.409094   
6                              18.421442   
7                              20.222982   
8                              21.289972 

  8%|▊         | 30083/400782 [03:31<12:44:12,  8.08it/s]

(110506, 344)
(10000,)
qid      71878.3534
count       11.0506
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.622095   
1                              11.843141   
2                              27.521973   
3                              11.660427   
4                              12.787571   
5                              14.953897   
6                              10.778217   
7                              18.952822   
8                              12.468709   
9                              13.049401   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              15.136847   
1                              13.368475   
2                              30.654858   
3                              12.876721   
4                              13.099853   
5                              15.677228   
6                              12.160014   
7                              20.245050   
8                              13.324585 

 10%|█         | 40176/400782 [04:00<3:37:23, 27.65it/s] 

(110448, 344)
(10000,)
qid      102505.6404
count        11.0448
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              14.907814   
1                               9.844326   
2                              11.428194   
3                               8.527626   
4                              10.580976   
5                              10.796414   
6                              25.112209   
7                              15.314219   
8                              11.009486   
9                              10.687610   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              16.272871   
1                               9.578762   
2                              12.319444   
3                               9.554590   
4                              11.219783   
5                              11.756897   
6                              26.177946   
7                              14.916430   
8                              12.79349

 12%|█▏        | 50093/400782 [04:20<2:17:25, 42.53it/s]

(110284, 344)
(10000,)
qid      131703.5651
count        11.0284
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.653768   
1                               8.600472   
2                               9.122086   
3                               9.170490   
4                              10.830508   
5                              10.864552   
6                              10.178338   
7                              13.547235   
8                               9.369357   
9                               9.268857   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.094537   
1                               9.005003   
2                              10.455590   
3                              10.593972   
4                              11.448747   
5                              11.531213   
6                               9.905457   
7                              14.752438   
8                              11.18617

 15%|█▌        | 60151/400782 [04:41<2:18:59, 40.85it/s]

(110710, 344)
(10000,)
qid      159810.255
count        11.071
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.145366   
1                              11.561275   
2                               9.024568   
3                               9.939776   
4                              16.656654   
5                              11.715054   
6                              12.030111   
7                              10.696552   
8                              11.581792   
9                              11.743159   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.109741   
1                              11.977177   
2                               9.539717   
3                               9.977930   
4                              17.006994   
5                              13.172741   
6                              13.120940   
7                              11.785064   
8                              12.024822 

 18%|█▊        | 70154/400782 [05:01<1:58:56, 46.33it/s]

(110616, 344)
(10000,)
qid      188667.0197
count        11.0616
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              19.041761   
1                              15.286544   
2                              11.359656   
3                              17.077909   
4                              21.122465   
5                              14.739319   
6                              20.413633   
7                              17.717188   
8                              14.694174   
9                              12.421024   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              18.860228   
1                              15.182791   
2                               9.848079   
3                              15.056050   
4                              20.633486   
5                              15.549834   
6                              22.273041   
7                              17.585810   
8                              16.12192

 20%|█▉        | 80144/400782 [05:21<1:50:02, 48.56it/s]

(110534, 344)
(10000,)
qid      221530.6843
count        11.0534
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              18.649384   
1                              10.385224   
2                              11.001211   
3                              13.167721   
4                              13.436407   
5                              13.072525   
6                              12.666722   
7                              14.405597   
8                              12.226311   
9                              12.825982   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              19.288710   
1                              11.309126   
2                              13.023666   
3                              12.938331   
4                              14.274032   
5                              14.227904   
6                              13.231762   
7                              14.858941   
8                              14.13082

 23%|██▎       | 90180/400782 [05:40<1:15:10, 68.86it/s]

(110194, 344)
(10000,)
qid      255886.0298
count        11.0194
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              24.202383   
1                              14.533926   
2                              15.652447   
3                              14.985152   
4                              13.223838   
5                              22.196043   
6                              14.802955   
7                              26.406204   
8                              17.906281   
9                              15.091978   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              24.836401   
1                              14.197733   
2                              16.986910   
3                              16.910011   
4                              14.577914   
5                              23.135389   
6                              14.651818   
7                              26.618204   
8                              17.78747

 25%|██▍       | 100181/400782 [06:00<1:11:14, 70.32it/s]

(110255, 344)
(10000,)
qid      285440.4913
count        11.0255
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.001701   
1                              15.987255   
2                              13.886163   
3                              10.756891   
4                              14.140326   
5                              13.498837   
6                              10.768123   
7                              16.138807   
8                              14.950524   
9                              10.152766   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.193398   
1                              15.324629   
2                              14.515985   
3                              11.960898   
4                              14.625075   
5                              14.450054   
6                              12.187647   
7                              17.611979   
8                              15.35526

 27%|██▋       | 110143/400782 [06:19<1:33:52, 51.60it/s]

(110205, 344)
(10000,)
qid      316130.7138
count        11.0205
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.883745   
1                              10.672473   
2                               8.748243   
3                              15.274649   
4                              11.317259   
5                              14.161619   
6                               8.878933   
7                               8.335343   
8                              15.445353   
9                              22.306890   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              11.139491   
1                              11.986936   
2                               9.517762   
3                              15.729774   
4                              11.453630   
5                              15.154417   
6                               9.858629   
7                               9.069818   
8                              16.19533

 30%|██▉       | 120181/400782 [06:37<1:00:41, 77.05it/s]

(110377, 344)
(10000,)
qid      356953.6546
count        11.0377
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.749559   
1                              12.657712   
2                              16.171034   
3                              12.367405   
4                              13.919641   
5                              13.204207   
6                              12.568716   
7                              10.469444   
8                              10.584461   
9                               9.718123   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              10.752101   
1                              12.192801   
2                              16.341833   
3                              13.115420   
4                              13.408381   
5                              13.341187   
6                              13.301322   
7                              11.251337   
8                              10.84744

 32%|███▏      | 130206/400782 [06:55<54:28, 82.78it/s]  

(110291, 344)
(10000,)
qid      402634.5494
count        11.0291
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.769972   
1                              17.426910   
2                              13.960839   
3                              13.888522   
4                              19.395971   
5                              11.882025   
6                              12.887510   
7                              13.124071   
8                              16.769426   
9                              15.464694   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              15.369852   
1                              18.021683   
2                              15.099513   
3                              15.047424   
4                              20.817280   
5                              13.103846   
6                              14.464269   
7                              14.657386   
8                              17.96540

 35%|███▍      | 140100/400782 [07:13<1:11:42, 60.59it/s]

(110333, 344)
(10000,)
qid      428279.2893
count        11.0333
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               7.579915   
1                              15.045549   
2                              12.666134   
3                               8.018781   
4                              10.184351   
5                              23.176189   
6                              14.908622   
7                               7.141113   
8                              12.731263   
9                               7.127753   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               8.255768   
1                              16.738909   
2                              13.389154   
3                               9.198184   
4                              10.941393   
5                              22.751377   
6                              15.911568   
7                               7.970811   
8                              13.77917

 37%|███▋      | 150147/400782 [07:31<1:06:12, 63.10it/s]

(110576, 344)
(10000,)
qid      455094.8035
count        11.0576
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               8.981194   
1                              15.790057   
2                              16.347235   
3                              13.027354   
4                              14.213879   
5                              12.033302   
6                              11.254561   
7                              13.456677   
8                              17.293251   
9                              14.085408   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               9.713717   
1                              15.643287   
2                              17.211443   
3                              14.138248   
4                              14.817186   
5                              11.591328   
6                              12.093851   
7                              13.444471   
8                              18.66435

 40%|███▉      | 160120/400782 [07:48<1:02:15, 64.42it/s]

(110377, 344)
(10000,)
qid      482473.7514
count        11.0377
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              11.282845   
1                              16.034340   
2                              15.923967   
3                               9.718058   
4                              11.224901   
5                              12.355948   
6                              10.289945   
7                              10.723938   
8                              10.076472   
9                              10.743705   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.273409   
1                              17.179432   
2                              16.908325   
3                               9.919733   
4                              12.126390   
5                              12.734758   
6                              11.304367   
7                              11.667245   
8                              10.76349

 42%|████▏     | 170171/400782 [08:09<42:55, 89.54it/s]  

(110564, 344)
(10000,)
qid      511360.3605
count        11.0564
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.469619   
1                              14.030810   
2                              29.180943   
3                               8.006005   
4                               8.084717   
5                               8.629291   
6                               7.980108   
7                               9.248075   
8                               9.388862   
9                               8.164991   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.922128   
1                              14.953142   
2                              32.590218   
3                               8.869010   
4                               9.080462   
5                               9.360213   
6                               8.800697   
7                              10.322574   
8                              10.70959

 45%|████▍     | 180090/400782 [08:26<56:35, 64.99it/s]  

(110746, 344)
(10000,)
qid      540155.3225
count        11.0746
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              20.910910   
1                              19.980577   
2                              20.280333   
3                              20.667038   
4                              20.874048   
5                              22.176519   
6                              21.272221   
7                              32.147045   
8                              19.185482   
9                              17.515734   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              20.727804   
1                              19.712608   
2                              20.108335   
3                              21.140038   
4                              20.866550   
5                              23.393833   
6                              21.843918   
7                              34.088524   
8                              20.58276

 47%|████▋     | 190130/400782 [08:43<51:52, 67.67it/s]  

(110620, 344)
(10000,)
qid      567539.0918
count        11.0620
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               5.713381   
1                               4.756877   
2                               5.422701   
3                               4.992814   
4                              30.922735   
5                               6.412817   
6                               5.659842   
7                               5.749640   
8                               5.713381   
9                               6.570801   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               6.088757   
1                               5.410365   
2                               5.420225   
3                               6.120634   
4                              31.605267   
5                               7.067897   
6                               5.959053   
7                               6.178410   
8                               6.08875

 50%|████▉     | 200150/400782 [09:01<53:47, 62.17it/s]  

(110670, 344)
(10000,)
qid      596183.7539
count        11.0670
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.769319   
1                              20.879616   
2                              17.625082   
3                              16.453005   
4                              20.140434   
5                              14.351643   
6                              15.995559   
7                              19.086216   
8                              17.003054   
9                              18.855227   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              15.117961   
1                              19.719410   
2                              18.900932   
3                              18.232979   
4                              21.683268   
5                              15.714008   
6                              16.398987   
7                              20.025995   
8                              18.68691

 52%|█████▏    | 210089/400782 [09:18<49:40, 63.99it/s]  

(110294, 344)
(10000,)
qid      617184.4926
count        11.0294
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              18.334089   
1                              18.471027   
2                              31.033031   
3                              18.779118   
4                              18.896551   
5                              18.221893   
6                              18.240242   
7                              18.891279   
8                              19.285879   
9                              19.036123   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              19.268406   
1                              19.765936   
2                              34.378223   
3                              20.457172   
4                              20.117466   
5                              20.828125   
6                              19.204525   
7                              19.375668   
8                              21.09405

 55%|█████▍    | 220210/400782 [09:36<32:46, 91.82it/s]  

(109644, 344)
(10000,)
qid      643118.5746
count        10.9644
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              14.568353   
1                               9.971744   
2                              16.772018   
3                              13.739955   
4                               9.833502   
5                              13.320570   
6                               9.485090   
7                               9.626752   
8                               9.886217   
9                               9.243431   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              15.367369   
1                              11.300247   
2                              18.425468   
3                              14.916780   
4                              10.924479   
5                              13.898643   
6                              11.220706   
7                              10.831611   
8                              10.75416

 57%|█████▋    | 230081/400782 [09:53<1:01:10, 46.50it/s]

(110036, 344)
(10000,)
qid      670637.7320
count        11.0036
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              12.609879   
1                              22.150017   
2                              12.844075   
3                              13.779434   
4                              13.605639   
5                              17.408365   
6                              12.885204   
7                              15.727073   
8                              13.478141   
9                              12.926886   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.866620   
1                              23.154692   
2                              14.034810   
3                              14.932918   
4                              14.499546   
5                              17.190912   
6                              14.147182   
7                              16.556192   
8                              14.19067

 60%|█████▉    | 240176/400782 [10:10<28:12, 94.87it/s]  

(109460, 344)
(10000,)
qid      698227.5791
count        10.9460
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.472915   
1                               9.495005   
2                               9.607020   
3                              10.975333   
4                              14.414341   
5                               9.600055   
6                               5.162002   
7                               9.634127   
8                               9.629742   
9                              10.100142   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              10.243314   
1                              10.299256   
2                              10.588392   
3                              11.950302   
4                              15.489265   
5                               9.923700   
6                               4.842900   
7                              10.002655   
8                              10.64817

 62%|██████▏   | 250210/400782 [10:27<24:37, 101.94it/s]

(109246, 344)
(10000,)
qid      725219.6245
count        10.9246
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              11.409188   
1                              10.973421   
2                              16.117134   
3                              11.409188   
4                              11.241779   
5                              11.880997   
6                              10.767786   
7                              10.667832   
8                              11.079212   
9                              11.352835   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.021054   
1                              11.025069   
2                              17.537308   
3                              12.021054   
4                              11.627163   
5                              13.214861   
6                              10.586506   
7                              10.380053   
8                              11.25826

 65%|██████▍   | 260207/400782 [10:44<22:36, 103.63it/s]

(109063, 344)
(10000,)
qid      751686.7022
count        10.9063
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.053447   
1                              11.845825   
2                              11.090663   
3                              10.650106   
4                               9.547470   
5                              10.318881   
6                              10.123779   
7                              10.514685   
8                               9.707587   
9                               9.183918   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              10.804271   
1                              13.115676   
2                              11.900352   
3                              11.586919   
4                              10.825148   
5                              10.386416   
6                              10.981289   
7                              11.246667   
8                              11.28063

 67%|██████▋   | 270153/400782 [11:00<28:59, 75.10it/s] 

(109016, 344)
(10000,)
qid      778116.5249
count        10.9016
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              14.548978   
1                              10.959221   
2                              13.870667   
3                              10.488949   
4                              10.237761   
5                              10.579325   
6                              28.549007   
7                              13.008965   
8                              10.505131   
9                              14.021282   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.767600   
1                              11.952472   
2                              13.816588   
3                              11.652645   
4                              11.308593   
5                              11.727820   
6                              31.837971   
7                              13.870960   
8                              11.69585

 70%|██████▉   | 280151/400782 [11:17<29:45, 67.55it/s] 

(109737, 344)
(10000,)
qid      804282.9324
count        10.9737
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              14.548123   
1                              10.465665   
2                               7.784500   
3                               9.927494   
4                              10.659349   
5                              11.242373   
6                               9.949830   
7                              24.469971   
8                               8.207112   
9                              10.796282   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.260353   
1                              11.771371   
2                               7.891679   
3                              10.768379   
4                              11.929405   
5                              12.360866   
6                              10.823364   
7                              25.298441   
8                               8.89413

 72%|███████▏  | 290189/400782 [11:35<21:54, 84.16it/s] 

(110432, 344)
(10000,)
qid      831473.0054
count        11.0432
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.746423   
1                              12.129684   
2                               9.791244   
3                              39.540009   
4                              10.949347   
5                              10.182705   
6                               5.247441   
7                               8.790057   
8                              14.180614   
9                              17.369658   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               9.342265   
1                              13.031332   
2                               9.431617   
3                              42.485306   
4                              12.091469   
5                              10.846827   
6                               5.723025   
7                               9.648158   
8                              15.87022

 75%|███████▍  | 300138/400782 [11:53<25:59, 64.56it/s] 

(109776, 344)
(10000,)
qid      858783.2012
count        10.9776
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               9.980751   
1                               9.936747   
2                              11.513998   
3                              10.637863   
4                              21.254005   
5                              11.904793   
6                              10.996590   
7                              10.223645   
8                              12.429700   
9                              10.531194   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               9.972322   
1                              10.766882   
2                              12.771692   
3                              11.505508   
4                              22.207335   
5                              13.168494   
6                              12.075178   
7                              10.511040   
8                              12.93371

 77%|███████▋  | 310205/400782 [12:11<17:15, 87.49it/s] 

(110558, 344)
(10000,)
qid      887075.5948
count        11.0558
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               7.463603   
1                               7.568528   
2                               7.252001   
3                               7.348234   
4                              31.739603   
5                              37.374413   
6                              10.512916   
7                               7.631649   
8                               7.432776   
9                               7.376201   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               8.291652   
1                               8.576859   
2                               8.073287   
3                               8.335677   
4                              34.407017   
5                              40.854321   
6                              11.720140   
7                               8.497198   
8                               8.57454

 80%|███████▉  | 320060/400782 [12:28<30:46, 43.71it/s] 

(110469, 344)
(10000,)
qid      915881.9003
count        11.0469
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              19.506008   
1                              16.414865   
2                              15.719904   
3                              17.012140   
4                              42.124058   
5                              18.113211   
6                              20.732162   
7                              20.614870   
8                              17.190655   
9                              22.227768   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              21.194323   
1                              17.680841   
2                              16.864059   
3                              16.981924   
4                              43.559387   
5                              21.048336   
6                              22.933777   
7                              22.009260   
8                              18.55057

 82%|████████▏ | 330139/400782 [12:46<19:07, 61.56it/s] 

(110276, 344)
(10000,)
qid      945539.8186
count        11.0276
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               3.665089   
1                               3.532672   
2                               2.787631   
3                               0.000000   
4                               2.869438   
5                              14.864882   
6                               0.000000   
7                               9.842429   
8                               4.212320   
9                              12.228148   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               3.858847   
1                               3.555855   
2                               2.785277   
3                               0.000000   
4                               2.967806   
5                              16.886881   
6                               0.000000   
7                              10.953025   
8                               4.73645

 85%|████████▍ | 340167/400782 [13:04<11:07, 90.80it/s] 

(110144, 344)
(10000,)
qid      972541.7927
count        11.0144
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              10.126706   
1                              11.849751   
2                              11.551775   
3                              18.314100   
4                              12.083521   
5                              12.749410   
6                              12.795503   
7                               9.802573   
8                               8.088405   
9                               9.750557   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.414212   
1                              13.010316   
2                              12.260604   
3                              17.631853   
4                              13.635662   
5                              12.876570   
6                              13.793385   
7                              11.415162   
8                               8.51317

 87%|████████▋ | 350078/400782 [13:20<16:21, 51.66it/s] 

(109974, 344)
(10000,)
qid      993600.8354
count        10.9974
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.158268   
1                               9.931385   
2                              10.489802   
3                              10.942862   
4                              12.512739   
5                               8.595962   
6                              10.324421   
7                              12.140491   
8                              11.615752   
9                              10.333761   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.446989   
1                              10.210564   
2                              11.416915   
3                              11.199997   
4                              12.872707   
5                               9.776853   
6                              11.002748   
7                              12.051901   
8                              12.84422

 90%|████████▉ | 360081/400782 [13:38<15:39, 43.31it/s] 

(110239, 344)
(10000,)
qid      1.019772e+06
count    1.102390e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              15.001380   
1                              15.928602   
2                              15.114327   
3                              10.712572   
4                              10.147246   
5                              19.369432   
6                              11.735179   
7                              11.865397   
8                              10.529464   
9                              11.681620   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              16.169838   
1                              17.105331   
2                              16.454340   
3                              12.031974   
4                              10.599897   
5                              21.323959   
6                              13.195370   
7                              13.387648   
8                              11.410

 92%|█████████▏| 370199/400782 [13:56<05:29, 92.80it/s] 

(110009, 344)
(10000,)
qid      1.045876e+06
count    1.100090e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.601457   
1                              28.393496   
2                              20.390751   
3                               9.951025   
4                              12.922303   
5                              10.440565   
6                              15.122248   
7                              17.302208   
8                              13.914014   
9                              11.430306   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              14.710818   
1                              32.050430   
2                              21.621803   
3                              10.544174   
4                              12.911390   
5                              11.811163   
6                              14.026752   
7                              15.456553   
8                              15.257

 95%|█████████▍| 380193/400782 [14:13<03:52, 88.47it/s] 

(110276, 344)
(10000,)
qid      1.093821e+06
count    1.102760e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                               7.467671   
1                              18.373993   
2                               8.081828   
3                              11.186552   
4                               6.778780   
5                               8.264829   
6                              17.231302   
7                               8.107473   
8                               7.931300   
9                               7.834025   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                               7.267464   
1                              18.890501   
2                               8.648403   
3                              12.011958   
4                               7.831010   
5                               9.114697   
6                              18.351025   
7                               8.712074   
8                               8.285

 97%|█████████▋| 390160/400782 [14:31<02:38, 67.14it/s] 

(109639, 344)
(10000,)
qid      1.154245e+06
count    1.096390e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              12.598808   
1                              13.531940   
2                              16.659880   
3                              29.143301   
4                              14.089610   
5                              10.542268   
6                               9.927992   
7                              12.173290   
8                              14.830498   
9                              12.169173   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              12.880904   
1                              14.735783   
2                              17.610874   
3                              28.546869   
4                              16.541084   
5                              10.533364   
6                               9.358626   
7                              13.192322   
8                              15.507

100%|█████████▉| 400162/400782 [14:48<00:09, 66.32it/s] 

(110309, 344)
(10000,)
qid      1.174363e+06
count    1.103090e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              33.433941   
1                              14.954256   
2                              16.246906   
3                              17.528164   
4                              15.746342   
5                              17.343775   
6                              15.521967   
7                              23.093643   
8                              15.276166   
9                              13.909384   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              34.950954   
1                              15.878670   
2                              17.737755   
3                              18.603712   
4                              14.772942   
5                              19.187943   
6                              16.459084   
7                              23.623775   
8                              15.878

100%|██████████| 400782/400782 [14:49<00:00, 450.69it/s]


(8641, 344)
(782,)
qid      1.185035e+06
count    1.104987e+01
dtype: float64
   contents_analyzed_BM25_k1_0.90_b_0.40  \
0                              13.269936   
1                              10.223487   
2                              11.202425   
3                               9.792343   
4                              10.503739   
5                              10.664545   
6                              15.841070   
7                              10.049994   
8                               9.769454   
9                              10.663472   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              13.485291   
1                              11.868622   
2                              12.574932   
3                              10.688994   
4                              11.334428   
5                              11.747111   
6                              16.366621   
7                              10.649042   
8                              10.630297 

In [9]:
def eval_mrr(dev_data, cutoff=10):
    """Print and return MRR@cutoff over the scored candidates in ``dev_data``.

    Candidates of each query are ranked by descending ``score`` using a stable
    sort (mergesort), so tied scores keep their input order.

    - Each query contributes the reciprocal rank of its first candidate with
      ``rel > 0``, if that candidate appears within the top ``cutoff``
      positions.
    - Otherwise the query contributes 0.
    - Adjacent ranked candidates whose scores differ by less than 1e-8 are
      counted as score ties. Only ties seen before the scan for a query stops
      are counted. The tie count is printed together with the MRR.

    Args:
        dev_data: DataFrame groupable by ``qid`` with ``pid``, ``score`` and
            ``rel`` fields; ``pid`` must be unique within each query.
        cutoff: rank depth of the metric (default 10, i.e. MRR@10).

    Returns:
        The mean reciprocal rank over all queries in ``dev_data``.
    """
    tie_count = 0
    tie_queries = set()

    reciprocal_ranks = []
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # stable sort is also used in LightGBM
        ranked = group.sort_values('score', ascending=False, kind='mergesort')
        prev_score = None
        for rank, t in enumerate(ranked.itertuples(), start=1):
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                tie_count += 1
                tie_queries.add(qid)
            prev_score = t.score
            if t.rel > 0:
                reciprocal_ranks.append(1.0 / rank)
                break
            # No relevant hit within the cutoff, or the list is exhausted.
            if rank == cutoff or rank == len(group):
                reciprocal_ranks.append(0.)
                break

    mrr = np.mean(reciprocal_ranks)
    score_tie = f'score_tie occurs {tie_count} times in {len(tie_queries)} queries'
    print(score_tie, mrr)
    return mrr


In [10]:
def eval_recall(dev_qrel, dev_data, recall_point=(10, 20, 50, 100, 200, 500, 1000)):
    """Print and return recall at several cutoffs over scored dev candidates.

    Candidates of each query in ``dev_data`` are ranked by descending
    ``score`` with a stable sort (mergesort), so ties keep their input order.

    - Recall@k is the number of candidates with ``rel > 0`` ranked within the
      top k, divided by the number of ``rel > 0`` rows for that query in
      ``dev_qrel``.
    - A query with no relevant qrel row contributes 0 instead of failing the
      lookup.
    - Adjacent ranked candidates whose scores differ by less than 1e-8 are
      counted as score ties and reported.

    Args:
        dev_qrel: DataFrame with ``qid`` and ``rel`` columns.
        dev_data: DataFrame groupable by ``qid`` with ``pid``, ``score`` and
            ``rel`` fields; ``pid`` must be unique within each query.
        recall_point: cutoffs k at which recall is reported.

    Returns:
        dict mapping each cutoff to the mean recall over all queries.
    """
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

    score_tie_counter = 0
    score_tie_query = set()

    recall_curve = {k: [] for k in recall_point}
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        assert len(group['pid'].tolist()) == len(set(group['pid'].tolist()))
        # stable sort is also used in LightGBM
        ranked = group.sort_values('score', ascending=False, kind='mergesort')

        # Ties between neighbouring ranked scores (NaN differences never match).
        scores = ranked['score'].to_numpy(dtype=float)
        ties = int(np.count_nonzero(np.abs(np.diff(scores)) < 1e-8))
        if ties:
            score_tie_counter += ties
            score_tie_query.add(qid)

        # 1-based ranks of the relevant candidates in this query's ranking.
        rel_ranks = np.flatnonzero(ranked['rel'].to_numpy() > 0) + 1
        # dev_rel_num only lists queries with >= 1 relevant qrel row.
        total_rel = dev_rel_num.get(qid, 0)
        for k in recall_point:
            if total_rel > 0:
                recall_curve[k].append(np.count_nonzero(rel_ranks <= k) / total_rel)
            else:
                recall_curve[k].append(0.)

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)

    averages = {}
    for k, v in recall_curve.items():
        avg = np.mean(v)
        averages[k] = avg
        print(f'recall@{k}:{avg}')
    return averages


In [11]:
# Report every feature column of the training data that contains NaN values,
# with the first (qid, pid) rows of train_extracted['info'] where it does.
for n in feature_name:
    # Select the column by name, the same way the lgb.Dataset cell does, so the
    # printed name always matches the column that was checked.
    nan_mask = train_extracted['data'][n].isna()
    if nan_mask.any():
        print(n)
        print(train_extracted['info'].loc[nan_mask, ['qid', 'pid']].head(10))

text_text_LMJM_lambda_0.10
            qid      pid
9317       3671  6416650
303572    79357  8026198
686383   180247  2872095
826682   221234  5903693
841544   227704   736288
843270   228333  2267954
863024   233763  4359635
949649   259178  7363614
1185354  323324  8330771
1204429  329037  8026199
text_text_LMJM_lambda_0.40
            qid      pid
9317       3671  6416650
303572    79357  8026198
686383   180247  2872095
826682   221234  5903693
841544   227704   736288
843270   228333  2267954
863024   233763  4359635
949649   259178  7363614
1185354  323324  8330771
1204429  329037  8026199
text_text_LMJM_lambda_0.70
            qid      pid
9317       3671  6416650
303572    79357  8026198
686383   180247  2872095
826682   221234  5903693
841544   227704   736288
843270   228333  2267954
863024   233763  4359635
949649   259178  7363614
1185354  323324  8330771
1204429  329037  8026199
text_text_Prob
            qid      pid
9317       3671  6416650
303572    79357  8026198
6863

In [12]:
def gen_dev_group_rel_num(dev_qrel, dev_extracted, k=200):
    """Build a LightGBM ``feval`` that reports mean recall@k on the dev set.

    The number of relevant passages per query (``rel > 0`` rows in
    ``dev_qrel``) is looked up once per query group of
    ``dev_extracted['info']``, in row order. The lookup therefore lines up
    with the group sizes of the validation Dataset.

    Args:
        dev_qrel: DataFrame with ``qid`` and ``rel`` columns.
        dev_extracted: dict whose ``'info'`` DataFrame has a ``qid`` column.
            Rows of one query are assumed to be contiguous, as LightGBM
            groups require.
        k: ranking cutoff of the recall metric (default 200).

    Returns:
        ``recall_at_200(preds, dataset)``, usable as ``lgb.train(feval=...)``.
        It returns ``(f'recall@{k}', mean_recall, True)``.
    """
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

    # A new group starts on every row whose qid differs from the previous row.
    qids = dev_extracted['info']['qid'].to_numpy()
    group_start = np.ones(len(qids), dtype=bool)
    group_start[1:] = qids[1:] != qids[:-1]
    # Queries without any relevant qrel row get 0 instead of a KeyError.
    dev_rel_num_list = dev_rel_num.reindex(qids[group_start], fill_value=0).to_numpy()
    assert len(dev_rel_num_list) == dev_qrel.qid.drop_duplicates().shape[0]

    metric_name = f'recall@{k}'

    def recall_at_200(preds, dataset):
        labels = dataset.get_label()
        groups = dataset.get_group()
        assert len(dev_rel_num_list) == len(groups)
        idx = 0
        recall = 0.
        for g, gnum in zip(groups, dev_rel_num_list):
            g = int(g)
            # Descending stable ranking: tied predictions keep their input
            # order, matching the mergesort ranking used by eval_recall.
            top = np.argsort(-preds[idx:idx + g], kind='mergesort')[:k]
            if gnum > 0:
                recall += np.count_nonzero(labels[idx:idx + g][top] > 0) / gnum
            idx += g
        assert idx == len(preds)
        return metric_name, recall / len(groups), True

    return recall_at_200
eval_fn = gen_dev_group_rel_num(dev_qrel=dev_qrel, dev_extracted=dev_extracted)  # custom recall feval for lgb.train

In [13]:
# Ranking datasets: feature columns selected by name, labels from 'rel', and
# per-query group sizes from the 'count' column of the group frame.
train_features = train_extracted['data'].loc[:, feature_name]
lgb_train = lgb.Dataset(
    train_features,
    label=train_extracted['info']['rel'],
    group=train_extracted['group']['count'],
)
# Keep the raw dev features so lgb_valid.get_data() can be fed to predict().
dev_features = dev_extracted['data'].loc[:, feature_name]
lgb_valid = lgb.Dataset(
    dev_features,
    label=dev_extracted['info']['rel'],
    group=dev_extracted['group']['count'],
    free_raw_data=False,
)
del train_features, dev_features

In [14]:
# Train a LambdaRank model per seed with early stopping on the custom recall
# feval, then score the dev candidates and report recall and MRR.
dev_extracted['info']['score'] = 0.
for seed in [12345]:
    params = {
        'boosting_type': 'gbdt',
        'objective': 'lambdarank',
        'max_bin': 255,
        'num_leaves': 63,
        'max_depth': -1,
        'min_data_in_leaf': 30,
        'min_sum_hessian_in_leaf': 0,
        # Optional bagging settings, currently disabled:
        # 'bagging_fraction': 0.8,
        # 'bagging_freq': 50,
        # 'feature_fraction': 0.8,
        'learning_rate': 0.1,
        'num_boost_round': 1000,
        'early_stopping_round': 200,
        'metric': 'custom',  # disable built-in metrics; only feval is reported
        'label_gain': [0, 1],
        'lambdarank_truncation_level': 20,
        'seed': seed,
        'num_threads': max(multiprocessing.cpu_count() // 2, 1),
    }
    # These two are passed to lgb.train as arguments, not as booster params.
    num_boost_round = params.pop('num_boost_round')
    early_stopping_round = params.pop('early_stopping_round')
    gbm = lgb.train(params, lgb_train,
                    valid_sets=lgb_valid,
                    num_boost_round=num_boost_round,
                    early_stopping_rounds=early_stopping_round,
                    feval=eval_fn,
                    feature_name=feature_name,
                    verbose_eval=True)
    dev_extracted['info']['score'] = gbm.predict(lgb_valid.get_data())
    best_score = gbm.best_score['valid_0']['recall@200']
    print(best_score)
    best_iteration = gbm.best_iteration
    print(best_iteration)
    feature_importances = sorted(list(zip(feature_name, gbm.feature_importance().tolist())),
                                 key=lambda x: x[1], reverse=True)
    print(feature_importances)
# Free the training set only after every seed has trained on it. Deleting it
# inside the loop made any additional seed fail with a NameError.
del lgb_train
eval_recall(dev_qrel, dev_extracted['info'])
eval_mrr(dev_extracted['info'])

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 69012
[LightGBM] [Info] Number of data points in the train set: 4416413, number of used features: 344
[1]	valid_0's recall@200: 0.731363
Training until validation scores don't improve for 200 rounds
[2]	valid_0's recall@200: 0.75671
[3]	valid_0's recall@200: 0.766738
[4]	valid_0's recall@200: 0.771263
[5]	valid_0's recall@200: 0.771836
[6]	valid_0's recall@200: 0.775322
[7]	valid_0's recall@200: 0.7774
[8]	valid_0's recall@200: 0.778677
[9]	valid_0's recall@200: 0.780181
[10]	valid_0's recall@200: 0.780146
[11]	valid_0's recall@200: 0.781256
[12]	valid_0's recall@200: 0.781805
[13]	valid_0's recall@200: 0.781137
[14]	valid_0's recall@200: 0.78128
[15]	valid_0's recall@200: 0.782426
[16]	valid_0's recall@200: 0.784432
[17]	valid_0's recall@200: 0.784575
[18]	valid_0's recall@200: 0.78522
[19]	valid_0's recall@200: 0.785721
[20]	valid_0's recall@200: 0.785076
[21]	valid_0's recall@200: 0.785148
[22]	va

[217]	valid_0's recall@200: 0.801648
[218]	valid_0's recall@200: 0.801218
[219]	valid_0's recall@200: 0.801504
[220]	valid_0's recall@200: 0.801648
[221]	valid_0's recall@200: 0.801504
[222]	valid_0's recall@200: 0.801648
[223]	valid_0's recall@200: 0.801648
[224]	valid_0's recall@200: 0.801648
[225]	valid_0's recall@200: 0.801791
[226]	valid_0's recall@200: 0.801648
[227]	valid_0's recall@200: 0.801791
[228]	valid_0's recall@200: 0.801612
[229]	valid_0's recall@200: 0.802077
[230]	valid_0's recall@200: 0.80203
[231]	valid_0's recall@200: 0.802173
[232]	valid_0's recall@200: 0.802603
[233]	valid_0's recall@200: 0.802567
[234]	valid_0's recall@200: 0.802233
[235]	valid_0's recall@200: 0.802376
[236]	valid_0's recall@200: 0.80271
[237]	valid_0's recall@200: 0.802662
[238]	valid_0's recall@200: 0.802949
[239]	valid_0's recall@200: 0.803056
[240]	valid_0's recall@200: 0.803104
[241]	valid_0's recall@200: 0.803032
[242]	valid_0's recall@200: 0.803462
[243]	valid_0's recall@200: 0.803462
[24

[439]	valid_0's recall@200: 0.803092
[440]	valid_0's recall@200: 0.803092
[441]	valid_0's recall@200: 0.802949
[442]	valid_0's recall@200: 0.802949
Early stopping, best iteration is:
[242]	valid_0's recall@200: 0.803462
0.8034622731614134
242
[('text_unlemm_text_unlemm_IBMModel1_body', 1498), ('text_bert_tok_text_bert_tok_IBMModel1_text_bert_tok', 798), ('text_unlemm_text_unlemm_IDF_sum', 536), ('text_unlemm_text_unlemm_IBMModel1_title_unlemm', 404), ('text_bert_tok_text_bert_tok_ICTF_sum', 371), ('text_unlemm_text_unlemm_SCQ_sum', 352), ('text_unlemm_text_unlemm_ICTF_sum', 329), ('text_bert_tok_text_bert_tok_UnorderedQueryPairs_3', 293), ('text_unlemm_text_unlemm_DFR_GL2', 230), ('text_unlemm_text_unlemm_IBMModel1_url_unlemm', 209), ('text_bert_tok_text_bert_tok_OrderedQueryPairs_15', 204), ('text_unlemm_text_unlemm_NTFIDF', 194), ('text_bert_tok_text_bert_tok_LMJM_lambda_0.10', 194), ('text_bert_tok_text_bert_tok_NTFIDF', 190), ('text_unlemm_text_unlemm_LMJM_lambda_0.10', 181), ('tex

100%|██████████| 6980/6980 [00:18<00:00, 378.86it/s]


score_tie occurs 302131 times in 6956 queries
recall@10:0.48494508118433616
recall@20:0.5794293218720152
recall@50:0.685243553008596
recall@100:0.751277459407832
recall@200:0.8034622731614136
recall@500:0.8447110792741165
recall@1000:0.8573424068767909


100%|██████████| 6980/6980 [00:10<00:00, 663.04it/s]

score_tie occurs 254 times in 220 queries 0.24342992450084142



