In [1]:
import argparse
import datetime
import glob
import hashlib
import json
import multiprocessing
import pickle
import os
import shutil
import subprocess
import uuid
import random

import numpy as np
import pandas as pd
import lightgbm as lgb
from collections import defaultdict
from lightgbm.sklearn import LGBMRanker
from tqdm import tqdm
import sys
sys.path.append('..')

from pyserini.analysis import Analyzer, get_lucene_analyzer
from pyserini.ltr import *
from pyserini.search import get_topics_with_reader

def _print_train_summary(df):
    """Print shape, distinct-qid count, mean rows per qid, a 10-row preview and dtypes of *df*."""
    print(df.shape)
    print(df.index.get_level_values('qid').drop_duplicates().shape)
    print(df.groupby('qid').count().mean())
    print(df.head(10))
    df.info()  # DataFrame.info prints itself and returns None; wrapping it in print() emitted "None"


def _load_train_qrels(path):
    """Read a tab-separated qrels file into ``{qid_str: {pid_str, ...}}``; every label must be "1"."""
    qrel = defaultdict(set)
    with open(path) as f:
        for line in f:
            topicid, _, docid, rel = line.strip().split('\t')
            assert rel == "1", line.split(' ')
            qrel[topicid].add(docid)
    return qrel


def _sample_triples(neg_sample, random_seed):
    """Build the 'triple' training frame from qidpidtriples: all positives, up to neg_sample negatives per qid."""
    train = pd.read_csv('collections/msmarco-passage/qidpidtriples.train.full.2.tsv', sep="\t",
                        names=['qid', 'pos_pid', 'neg_pid'], dtype=np.int32)
    pos_half = train[['qid', 'pos_pid']].rename(columns={"pos_pid": "pid"}).drop_duplicates()
    pos_half['rel'] = np.int32(1)
    neg_half = train[['qid', 'neg_pid']].rename(columns={"neg_pid": "pid"}).drop_duplicates()
    neg_half['rel'] = np.int32(0)
    del train  # the raw triples file is large; free it before grouping
    sampled_neg_half = [
        group.sample(n=min(neg_sample, len(group)), random_state=random_seed)
        for _, group in tqdm(neg_half.groupby('qid'))
    ]
    sampled = pd.concat([pos_half] + sampled_neg_half, axis=0, ignore_index=True)
    return sampled.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])


def _sample_from_run(qrel_path, run_path, neg_sample, random_seed, keep_negative):
    """Build a training frame from a run file (qid, pid, rank per line).

    Run hits judged relevant in *qrel_path* become positives. Other hits whose rank string
    satisfies *keep_negative* are candidate negatives; up to *neg_sample* of them are drawn
    per query with an RNG seeded by *random_seed*. Queries with no relevant hit are dropped.
    """
    qrel = _load_train_qrels(qrel_path)
    qid2pos = defaultdict(list)
    qid2neg = defaultdict(list)
    with open(run_path) as f:
        for line in tqdm(f):
            topicid, docid, rank = line.split()
            assert topicid in qrel
            if docid in qrel[topicid]:
                qid2pos[topicid].append(docid)
            elif keep_negative(rank):
                qid2neg[topicid].append(docid)
    # A dedicated seeded RNG makes the sample reproducible for the seed encoded in the
    # cache file name; the global `random` module state was never seeded.
    rng = random.Random(random_seed)
    rows = []
    for topicid, pos_list in tqdm(qid2pos.items()):
        candidates = qid2neg[topicid]
        neg_list = rng.sample(candidates, min(len(candidates), neg_sample))
        rows.extend((int(topicid), int(docid), 1) for docid in pos_list)
        rows.extend((int(topicid), int(docid), 0) for docid in neg_list)
    sampled = pd.DataFrame(rows, columns=['qid', 'pid', 'rel'], dtype=np.int32)
    return sampled.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])


def train_data_loader(task='triple', neg_sample=10, cutoff=100, random_seed=12345):
    """Load (or build and cache) a sampled MS MARCO training set of (qid, pid) -> rel.

    Tasks:
        'triple'      -- positives and negatives from qidpidtriples.train.full.2.tsv; up to
                         ``neg_sample`` negatives per query via ``DataFrame.sample``.
        'rank'        -- positives are run hits judged relevant in qrels.train.tsv; up to
                         ``neg_sample`` negatives drawn from all other run hits.
        'rank_recall' -- as 'rank', negatives only from hits with rank > ``cutoff``.
        'rank_hard'   -- as 'rank', negatives only from hits with rank <= ``cutoff``
                         (so 'rank_recall' and 'rank_hard' partition the negatives).

    The result is pickled in the working directory. The file name encodes task,
    ``neg_sample``, ``random_seed`` and, for 'rank_recall'/'rank_hard', ``cutoff``.
    An existing file is loaded instead of rebuilding.

    Returns:
        DataFrame indexed by sorted (qid, pid) with one int32 column ``rel`` (1 or 0).

    Raises:
        Exception: ``'unknown parameters'`` for an unrecognised ``task``.
    """
    if task in ('triple', 'rank'):
        fn = f'train_{task}_sampled_with_{neg_sample}_{random_seed}.pickle'
    elif task in ('rank_recall', 'rank_hard'):
        fn = f'train_{task}_sampled_with_{neg_sample}_{cutoff}_{random_seed}.pickle'
    else:
        raise Exception('unknown parameters')

    if os.path.exists(fn):
        sampled_train = pd.read_pickle(fn)
        _print_train_summary(sampled_train)
        return sampled_train

    # NOTE(review): 'triple'/'rank' read inputs relative to the working directory while
    # 'rank_recall'/'rank_hard' read from '../'; kept as in the original - confirm intent.
    if task == 'triple':
        sampled_train = _sample_triples(neg_sample, random_seed)
    elif task == 'rank':
        sampled_train = _sample_from_run('collections/msmarco-passage/qrels.train.tsv',
                                         'runs/msmarco-passage/run.train.small.tsv',
                                         neg_sample, random_seed,
                                         lambda rank: True)
    elif task == 'rank_recall':
        sampled_train = _sample_from_run('../collections/msmarco-passage/qrels.train.tsv',
                                         '../runs/msmarco-passage/run.train.small.tsv',
                                         neg_sample, random_seed,
                                         lambda rank: int(rank) > cutoff)
    else:  # 'rank_hard': top-`cutoff` hits, i.e. rank 1..cutoff inclusive
        sampled_train = _sample_from_run('../collections/msmarco-passage/qrels.train.tsv',
                                         '../runs/msmarco-passage/run.train.small.tsv',
                                         neg_sample, random_seed,
                                         lambda rank: int(rank) <= cutoff)
    _print_train_summary(sampled_train)
    sampled_train.to_pickle(fn)
    return sampled_train
# Sampled training pairs built from the MS MARCO qidpidtriples file (default 10 negatives per query).
sampled_train = train_data_loader('triple')

(4416413, 1)
(400782,)
rel    11.019489
dtype: float64
             rel
qid pid         
3   970816     0
    1142680    1
    2019206    0
    2605131    0
    2963098    0
    2971685    0
    3783924    0
    5067083    0
    5904778    0
    6176208    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 4416413 entries, (3, 970816) to (1185869, 7770561)
Data columns (total 1 columns):
 #   Column  Dtype
---  ------  -----
 0   rel     int32
dtypes: int32(1)
memory usage: 166.9 MB
None


In [2]:
def dev_data_loader(task='pygaggle'):
    """Load the dev candidate list for *task* joined with the dev relevance labels.

    Candidate sources:
        'rerank'   -- qid/pid columns of collections/msmarco-passage/top1000.dev
        'anserini' -- runs/msmarco-passage/run.msmarco-passage.dev.small.tsv (qid, pid, rank)
        'pygaggle' -- ../pygaggle/data/msmarco_ans_entire/run.dev.small.tsv (qid, pid, rank)

    Candidates are left-joined to qrels.dev.small.tsv, and missing labels become rel=0.
    The result is indexed by (qid, pid) in sorted order. The joined frame is cached as
    ``dev_{task}.pickle`` and the qrels as ``dev_qrel.pickle`` in the working directory.
    When ``dev_{task}.pickle`` exists, both cached frames are returned.

    Returns:
        (dev, dev_qrel)

    Raises:
        Exception: 'unknown parameters' for an unrecognised task with no cache file.
    """
    cache_fn = f'dev_{task}.pickle'
    qrel_cache_fn = 'dev_qrel.pickle'

    def _report(frame):
        print(frame.shape)
        print(frame.index.get_level_values('qid').drop_duplicates().shape)
        print(frame.groupby('qid').count().mean())
        print(frame.head(10))
        print(frame.info())

    if os.path.exists(cache_fn):
        cached = pd.read_pickle(cache_fn)
        _report(cached)
        return cached, pd.read_pickle(qrel_cache_fn)

    run_columns = ['qid', 'pid', 'rank']
    sources = {
        'rerank': ('collections/msmarco-passage/top1000.dev',
                   {'names': ['qid', 'pid', 'query', 'doc'], 'usecols': ['qid', 'pid']}),
        'anserini': ('runs/msmarco-passage/run.msmarco-passage.dev.small.tsv',
                     {'names': run_columns}),
        'pygaggle': ('../pygaggle/data/msmarco_ans_entire/run.dev.small.tsv',
                     {'names': run_columns}),
    }
    if task not in sources:
        raise Exception('unknown parameters')
    path, read_kwargs = sources[task]
    candidates = pd.read_csv(path, sep="\t", dtype=np.int32, **read_kwargs)

    labels = pd.read_csv('collections/msmarco-passage/qrels.dev.small.tsv', sep="\t",
                         names=["qid", "q0", "pid", "rel"], usecols=['qid', 'pid', 'rel'],
                         dtype=np.int32)
    joined = candidates.merge(labels, on=['qid', 'pid'], how='left')
    # Unjudged candidates are treated as non-relevant.
    joined['rel'] = joined['rel'].fillna(0).astype(np.int32)
    joined = joined.sort_values(['qid', 'pid']).set_index(['qid', 'pid'])

    _report(joined)

    joined.to_pickle(cache_fn)
    labels.to_pickle(qrel_cache_fn)
    return joined, labels
# Dev candidates from the pygaggle run, labelled with the dev qrels.
dev, dev_qrel = dev_data_loader('pygaggle')

(6974598, 2)
(6980,)
rank    999.226074
rel     999.226074
dtype: float64
            rank  rel
qid pid              
2   55860    345    0
    72202    557    0
    72210    213    0
    98589    278    0
    98590    323    0
    98593    580    0
    98595    553    0
    112123   108    0
    112126   469    0
    112127    21    0
<class 'pandas.core.frame.DataFrame'>
MultiIndex: 6974598 entries, (2, 55860) to (1102400, 8830447)
Data columns (total 2 columns):
 #   Column  Dtype
---  ------  -----
 0   rank    int32
 1   rel     int32
dtypes: int32(2)
memory usage: 282.7 MB
None


In [3]:
def query_loader(filenames=('queries.train.small.Flex.json',
                            'queries.dev.small.Flex.json',
                            'queries.eval.small.Flex.json')):
    """Load JSON-lines query files into one dict keyed by the query ``id``.

    Each non-blank line is a JSON object with an ``id`` and the string fields
    ``analyzed``, ``text``, ``text_unlemm`` and ``text_bert_tok``. Those four fields are
    split on single spaces into token lists. ``id`` is removed from the record and used
    as the key, and any other fields are kept unchanged. A later file overwrites an
    earlier one on a duplicate id.

    Args:
        filenames: files to read, in order. The default is the train, dev and eval
            files that were previously hard-coded.

    Returns:
        dict mapping query id (as stored in the file) to the parsed query record.
    """
    token_fields = ('analyzed', 'text', 'text_unlemm', 'text_bert_tok')
    queries = {}
    for filename in filenames:
        with open(filename, encoding='utf-8') as f:
            for line in f:
                if not line.strip():  # tolerate blank lines instead of failing in json.loads
                    continue
                query = json.loads(line)
                qid = query.pop('id')
                for field in token_fields:
                    query[field] = query[field].split(" ")
                queries[qid] = query
    return queries
# Train, dev and eval queries in one dict keyed by query id string.
queries = query_loader()

In [4]:
# Learning-to-rank feature extractor over the Anserini MS MARCO passage "flex" index.
# NOTE(review): the second argument is presumably the number of worker threads
# (half the CPU cores, at least 1) - confirm against pyserini.ltr.FeatureExtractor.
fe = FeatureExtractor('../../anserini/indexes/msmarco-passage/lucene-index-msmarco-flex',max(multiprocessing.cpu_count()//2,1))

# Register the same feature families once per (query field, index field) pair. Feature
# names embed both fields (e.g. "contents_analyzed_BM25_k1_0.90_b_0.40"), so every pair
# contributes its own set of columns.
# NOTE(review): the order of the fe.add calls presumably fixes the order of
# fe.feature_names() and thus the feature-matrix column order - keep it stable when editing.
for qfield, ifield in [('analyzed','contents'),
                       ('text','text'),
                       ('text_unlemm','text_unlemm'),
                       ('text_bert_tok','text_bert_tok')]:
    print(qfield, ifield)
    # BM25 at three (k1, b) settings.
    fe.add(BM25(k1=0.9,b=0.4, field=ifield, qfield=qfield))
    fe.add(BM25(k1=1.2,b=0.75, field=ifield, qfield=qfield))
    fe.add(BM25(k1=2.0,b=0.75, field=ifield, qfield=qfield))

    # LMDir at three mu values.
    fe.add(LMDir(mu=1000, field=ifield, qfield=qfield))
    fe.add(LMDir(mu=1500, field=ifield, qfield=qfield))
    fe.add(LMDir(mu=2500, field=ifield, qfield=qfield))

    # LMJM at three lambda values (names appear as e.g. "text_text_LMJM_lambda_0.10").
    fe.add(LMJM(0.1, field=ifield, qfield=qfield))
    fe.add(LMJM(0.4, field=ifield, qfield=qfield))
    fe.add(LMJM(0.7, field=ifield, qfield=qfield))

    # NOTE: `ProbalitySum` is spelled as the imported pyserini.ltr class name.
    fe.add(NTFIDF(field=ifield, qfield=qfield))
    fe.add(ProbalitySum(field=ifield, qfield=qfield))

    # DFR-family scorers.
    fe.add(DFR_GL2(field=ifield, qfield=qfield))
    fe.add(DFR_In_expB2(field=ifield, qfield=qfield))
    fe.add(DPH(field=ifield, qfield=qfield))

    # Term-proximity scorers.
    fe.add(Proximity(field=ifield, qfield=qfield))
    fe.add(TPscore(field=ifield, qfield=qfield))
    fe.add(tpDist(field=ifield, qfield=qfield))

    # Document-only feature (depends on the index field alone).
    fe.add(DocSize(field=ifield))

    # Query-only features, then query/document term-overlap features.
    fe.add(QueryLength(qfield=qfield))
    fe.add(QueryCoverageRatio(qfield=qfield))
    fe.add(UniqueTermCount(qfield=qfield))
    fe.add(MatchingTermCount(field=ifield, qfield=qfield))
    fe.add(SCS(field=ifield, qfield=qfield))

    # Six per-term statistics (tf, tf-idf, scq, normalized tf, idf, ictf), each aggregated
    # over query terms with the same eight poolers in the same order.
    fe.add(tfStat(AvgPooler(), field=ifield, qfield=qfield))
    fe.add(tfStat(MedianPooler(), field=ifield, qfield=qfield))
    fe.add(tfStat(SumPooler(), field=ifield, qfield=qfield))
    fe.add(tfStat(MinPooler(), field=ifield, qfield=qfield))
    fe.add(tfStat(MaxPooler(), field=ifield, qfield=qfield))
    fe.add(tfStat(VarPooler(), field=ifield, qfield=qfield))
    fe.add(tfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
    fe.add(tfStat(ConfidencePooler(), field=ifield, qfield=qfield))

    fe.add(tfIdfStat(AvgPooler(), field=ifield, qfield=qfield))
    fe.add(tfIdfStat(MedianPooler(), field=ifield, qfield=qfield))
    fe.add(tfIdfStat(SumPooler(), field=ifield, qfield=qfield))
    fe.add(tfIdfStat(MinPooler(), field=ifield, qfield=qfield))
    fe.add(tfIdfStat(MaxPooler(), field=ifield, qfield=qfield))
    fe.add(tfIdfStat(VarPooler(), field=ifield, qfield=qfield))
    fe.add(tfIdfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
    fe.add(tfIdfStat(ConfidencePooler(), field=ifield, qfield=qfield))

    fe.add(scqStat(AvgPooler(), field=ifield, qfield=qfield))
    fe.add(scqStat(MedianPooler(), field=ifield, qfield=qfield))
    fe.add(scqStat(SumPooler(), field=ifield, qfield=qfield))
    fe.add(scqStat(MinPooler(), field=ifield, qfield=qfield))
    fe.add(scqStat(MaxPooler(), field=ifield, qfield=qfield))
    fe.add(scqStat(VarPooler(), field=ifield, qfield=qfield))
    fe.add(scqStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
    fe.add(scqStat(ConfidencePooler(), field=ifield, qfield=qfield))

    fe.add(normalizedTfStat(AvgPooler(), field=ifield, qfield=qfield))
    fe.add(normalizedTfStat(MedianPooler(), field=ifield, qfield=qfield))
    fe.add(normalizedTfStat(SumPooler(), field=ifield, qfield=qfield))
    fe.add(normalizedTfStat(MinPooler(), field=ifield, qfield=qfield))
    fe.add(normalizedTfStat(MaxPooler(), field=ifield, qfield=qfield))
    fe.add(normalizedTfStat(VarPooler(), field=ifield, qfield=qfield))
    fe.add(normalizedTfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
    fe.add(normalizedTfStat(ConfidencePooler(), field=ifield, qfield=qfield))

    fe.add(idfStat(AvgPooler(), field=ifield, qfield=qfield))
    fe.add(idfStat(MedianPooler(), field=ifield, qfield=qfield))
    fe.add(idfStat(SumPooler(), field=ifield, qfield=qfield))
    fe.add(idfStat(MinPooler(), field=ifield, qfield=qfield))
    fe.add(idfStat(MaxPooler(), field=ifield, qfield=qfield))
    fe.add(idfStat(VarPooler(), field=ifield, qfield=qfield))
    fe.add(idfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
    fe.add(idfStat(ConfidencePooler(), field=ifield, qfield=qfield))

    fe.add(ictfStat(AvgPooler(), field=ifield, qfield=qfield))
    fe.add(ictfStat(MedianPooler(), field=ifield, qfield=qfield))
    fe.add(ictfStat(SumPooler(), field=ifield, qfield=qfield))
    fe.add(ictfStat(MinPooler(), field=ifield, qfield=qfield))
    fe.add(ictfStat(MaxPooler(), field=ifield, qfield=qfield))
    fe.add(ictfStat(VarPooler(), field=ifield, qfield=qfield))
    fe.add(ictfStat(MaxMinRatioPooler(), field=ifield, qfield=qfield))
    fe.add(ictfStat(ConfidencePooler(), field=ifield, qfield=qfield))

    # Term-pair features, each with arguments 3, 8 and 15
    # (NOTE(review): presumably window sizes - confirm in pyserini.ltr).
    fe.add(UnorderedSequentialPairs(3, field=ifield, qfield=qfield))
    fe.add(UnorderedSequentialPairs(8, field=ifield, qfield=qfield))
    fe.add(UnorderedSequentialPairs(15, field=ifield, qfield=qfield))
    fe.add(OrderedSequentialPairs(3, field=ifield, qfield=qfield))
    fe.add(OrderedSequentialPairs(8, field=ifield, qfield=qfield))
    fe.add(OrderedSequentialPairs(15, field=ifield, qfield=qfield))
    fe.add(UnorderedQueryPairs(3, field=ifield, qfield=qfield))
    fe.add(UnorderedQueryPairs(8, field=ifield, qfield=qfield))
    fe.add(UnorderedQueryPairs(15, field=ifield, qfield=qfield))
    fe.add(OrderedQueryPairs(3, field=ifield, qfield=qfield))
    fe.add(OrderedQueryPairs(8, field=ifield, qfield=qfield))
    fe.add(OrderedQueryPairs(15, field=ifield, qfield=qfield))

# Disabled: BM25 score-aggregate features (kept for reference, not registered).
#     fe.add(BM25Mean(MaxPooler()))
#     fe.add(BM25Mean(MinPooler()))
#     fe.add(BM25Min(MaxPooler()))
#     fe.add(BM25Min(MinPooler()))
#     fe.add(BM25Max(MaxPooler()))
#     fe.add(BM25Max(MinPooler()))
#     fe.add(BM25HMean(MaxPooler()))
#     fe.add(BM25HMean(MinPooler()))
#     fe.add(BM25Var(MaxPooler()))
#     fe.add(BM25Var(MinPooler()))
#     fe.add(BM25Quartile(MaxPooler()))
#     fe.add(BM25Quartile(MinPooler()))

# IBM Model 1 features built from FlexNeuART GIZA output directories. These are added once,
# outside the per-field loop.
# NOTE(review): the meaning of the three field arguments is presumed (query/index fields) -
# confirm against pyserini.ltr.IBMModel1.
fe.add(IBMModel1("../FlexNeuART/collections/msmarco_doc/derived_data/giza/title_unlemm","text_unlemm","title_unlemm","text_unlemm"))
fe.add(IBMModel1("../FlexNeuART/collections/msmarco_doc/derived_data/giza/url_unlemm","text_unlemm","url_unlemm","text_unlemm"))
fe.add(IBMModel1("../FlexNeuART/collections/msmarco_doc/derived_data/giza/body","text_unlemm","body","text_unlemm"))
fe.add(IBMModel1("../FlexNeuART/collections/msmarco_doc/derived_data/giza/text_bert_tok","text_bert_tok","text_bert_tok","text_bert_tok"))


analyzed contents
text text
text_unlemm text_unlemm
text_bert_tok text_bert_tok


In [5]:
def _collect_batch(fe, qids, n_rows, qidpid2rel, names):
    """Fetch lazily-extracted results for *qids* into one frame of qid, pid, rel + feature columns."""
    info = np.zeros(shape=(n_rows, 3), dtype=np.int32)
    feature = np.zeros(shape=(n_rows, len(names)), dtype=np.float32)
    idx = 0
    for qid in qids:
        rels = qidpid2rel[int(qid)]
        for doc in fe.get_result(qid):
            pid = int(doc['pid'])
            info[idx] = (int(qid), pid, rels[pid])
            feature[idx, :] = doc['features']
            idx += 1
    # Keep only rows actually returned. Otherwise a short result would leave phantom
    # all-zero rows (qid=0, pid=0) in the training/dev data.
    info = pd.DataFrame(info[:idx], columns=['qid', 'pid', 'rel'])
    feature = pd.DataFrame(feature[:idx], columns=names)
    return pd.concat([info, feature], axis=1)


def extract(df, queries, fe, batch_size=10000):
    """Extract LTR features for every (qid, pid) row of *df* with the feature extractor *fe*.

    Queries are submitted with ``fe.lazy_extract`` and collected with ``fe.get_result``
    in batches of ``batch_size`` queries, which bounds the size of the temporary arrays.

    Args:
        df: frame grouped by ``qid``; after ``reset_index`` each row exposes ``qid``,
            ``pid`` and ``rel``. Every (qid, pid) must be unique.
        queries: mapping ``str(qid)`` -> query record passed to ``fe.lazy_extract``.
        fe: extractor providing ``feature_names``, ``lazy_extract`` and ``get_result``.
        batch_size: queries per fetch batch (default 10000, as previously hard-coded).

    Returns:
        (data, group): ``data`` holds columns qid, pid, rel followed by one float32
        column per feature name, stably sorted by qid. ``group`` is the row count per
        qid in that same order; later cells pass it as LightGBM's ``group`` argument.
    """
    names = fe.feature_names()
    df_pieces = []
    fetch_later = []
    qidpid2rel = defaultdict(dict)
    need_rows = 0
    for qid, group in tqdm(df.groupby('qid')):
        pid2rel = qidpid2rel[qid]
        for t in group.reset_index().itertuples():
            assert t.pid not in pid2rel
            pid2rel[t.pid] = t.rel
            need_rows += 1
        # pids are passed as strings, matching the string qid.
        fe.lazy_extract(str(qid), [str(pid) for pid in pid2rel], queries[str(qid)])
        fetch_later.append(str(qid))
        if len(fetch_later) == batch_size:
            df_pieces.append(_collect_batch(fe, fetch_later, need_rows, qidpid2rel, names))
            fetch_later = []
            need_rows = 0
    if fetch_later:  # remaining partial batch
        df_pieces.append(_collect_batch(fe, fetch_later, need_rows, qidpid2rel, names))

    if df_pieces:
        data = pd.concat(df_pieces, axis=0, ignore_index=True)
    else:  # empty input: pd.concat([]) would raise
        data = pd.DataFrame(columns=['qid', 'pid', 'rel'] + list(names))
    del df_pieces
    # A stable sort keeps each query's rows in the order returned by the extractor.
    data = data.sort_values(by='qid', kind='mergesort')
    group = data.groupby('qid').agg(count=('pid', 'count'))['count']
    return data, group

In [6]:
def hash_df(df):
    h = pd.util.hash_pandas_object(df)
    return hex(h.sum().astype(np.uint64))


def hash_anserini_jar():
    """Return the MD5 hex digest of the single ``*fatjar.jar`` in ``$ANSERINI_CLASSPATH``.

    The digest becomes part of the feature-cache file name, so a rebuilt Anserini jar
    produces a different cache file.

    Raises:
        KeyError: if ``ANSERINI_CLASSPATH`` is not set.
        AssertionError: unless exactly one matching jar is found.
    """
    find = glob.glob(os.path.join(os.environ['ANSERINI_CLASSPATH'], '*fatjar.jar'))
    assert len(find) == 1, f'expected exactly one *fatjar.jar, found {find}'
    md5_hash = hashlib.md5()
    # Stream in 1 MiB chunks with a context manager. The old code never closed the
    # handle and read the whole jar into memory at once.
    with open(find[0], 'rb') as jar:
        for chunk in iter(lambda: jar.read(1 << 20), b''):
            md5_hash.update(chunk)
    return md5_hash.hexdigest()


def hash_fe(fe):
    """Fingerprint a feature extractor by its feature names (order-insensitive MD5 hex digest)."""
    joined_names = ','.join(sorted(fe.feature_names()))
    return hashlib.md5(joined_names.encode()).hexdigest()


def data_loader(task, df, queries, fe):
    """Return extracted features for *df*, loading them from a pickle cache when present.

    The cache file ``{task}_{df_hash}_{jar_hash}_{fe_hash}.pickle`` in the working
    directory is keyed on the input frame, the Anserini jar and the feature names.
    Changing any of them triggers a fresh extraction.

    Args:
        task: 'train' or 'dev' (only checked when no cache file exists).
        df: (qid, pid) -> rel frame passed to ``extract``.
        queries: query records passed to ``extract``.
        fe: feature extractor.

    Returns:
        dict with keys 'data', 'group', 'df_hash', 'jar_hash' and 'fe_hash'.

    Raises:
        Exception: 'unknown parameters' for any other task when nothing is cached.
    """
    df_hash = hash_df(df)
    jar_hash = hash_anserini_jar()
    fe_hash = hash_fe(fe)
    cache_fn = f'{task}_{df_hash}_{jar_hash}_{fe_hash}.pickle'
    if os.path.exists(cache_fn):
        # NOTE: unpickling can execute arbitrary code; only load caches this pipeline wrote.
        with open(cache_fn, 'rb') as f:
            res = pickle.load(f)
        print(res['data'].shape)
        print(res['data'].qid.drop_duplicates().shape)
        print(res['group'].mean())
        print(res['data'].head(10))
        res['data'].info()  # prints itself; print(...) around it only added "None"
        return res

    if task not in ('train', 'dev'):
        raise Exception('unknown parameters')
    data, group = extract(df, queries, fe)
    obj = {'data': data, 'group': group, 'df_hash': df_hash, 'jar_hash': jar_hash, 'fe_hash': fe_hash}
    print(data.shape)
    print(data.qid.drop_duplicates().shape)
    print(group.mean())
    print(data.head(10))
    data.info()
    # A context manager guarantees the pickle is flushed and the handle closed.
    with open(cache_fn, 'wb') as f:
        pickle.dump(obj, f)
    return obj

In [7]:
import json
def export(df, fn, max_queries=1000, query_map=None):
    """Write up to *max_queries* queries of *df* to *fn* as JSON lines (for inspection).

    Each line holds ``qid``, the distinct candidate ``docIds`` (as strings) and every
    field of that query's record.

    Args:
        df: frame grouped by ``qid`` whose ``reset_index()`` exposes a ``pid`` column.
        fn: output path (overwritten).
        max_queries: number of queries to write (default 1000, the limit previously
            hard-coded). ``None`` writes all queries.
        query_map: ``{str(qid): query_record}``. The default is the module-level
            ``queries``; pass it explicitly once the notebook has ``del``'d that global.
    """
    if query_map is None:
        query_map = queries
    with open(fn, 'w') as f:
        for written, (qid, group) in enumerate(tqdm(df.groupby('qid'))):
            if max_queries is not None and written >= max_queries:
                break
            query = query_map[str(qid)]
            # The record is merged into the line, so it must not clobber these keys.
            assert 'qid' not in query
            assert 'docIds' not in query
            line = {
                'qid': int(qid),  # plain int: numpy integer scalars are not JSON-serializable
                'docIds': [str(did) for did in group.reset_index().pid.drop_duplicates().tolist()],
            }
            line.update(query)
            f.write(json.dumps(line) + '\n')

In [8]:
# Extract (or load cached) LTR features for the sampled training pairs and the dev candidates.
train_extracted = data_loader('train', sampled_train, queries, fe)
dev_extracted = data_loader('dev', dev, queries, fe)
# export(sampled_train, 'sampled_train_export.json')
# export(dev, 'sampled_dev_export.json')
feature_name = fe.feature_names()
# Only the extracted frames and the feature name list are needed from here on.
del sampled_train, dev, queries, fe

100%|██████████| 400782/400782 [24:40<00:00, 270.71it/s]   


(4416413, 339)
(400782,)
11.019489398226467
   qid      pid  rel  contents_analyzed_BM25_k1_0.90_b_0.40  \
0    3   970816    0                              22.893564   
1    3  1142680    1                              26.467836   
2    3  2019206    0                              19.257418   
3    3  2605131    0                              13.279017   
4    3  2963098    0                              19.064749   
5    3  2971685    0                              13.103558   
6    3  3783924    0                              18.477646   
7    3  5067083    0                              13.021508   
8    3  5904778    0                              12.939322   
9    3  6176208    0                              18.572485   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              23.464577   
1                              27.822981   
2                              21.266153   
3                              15.103254   
4                              20.615219  

100%|██████████| 6980/6980 [00:57<00:00, 120.93it/s]


(6974598, 339)
(6980,)
999.2260744985673
   qid     pid  rel  contents_analyzed_BM25_k1_0.90_b_0.40  \
0    2   55860    0                              12.348820   
1    2   72202    0                              10.927653   
2    2   72210    0                              13.675473   
3    2   98589    0                              12.699286   
4    2   98590    0                              12.492470   
5    2   98593    0                              11.077914   
6    2   98595    0                              11.181725   
7    2  112123    0                              15.955744   
8    2  112126    0                              11.468307   
9    2  112127    0                              21.200821   

   contents_analyzed_BM25_k1_1.20_b_0.75  \
0                              13.151269   
1                              11.041017   
2                              13.716912   
3                              14.042357   
4                              13.508439   
5           

In [9]:
def eval_mrr(dev_data, cutoff=10):
    """Print and return MRR@cutoff over the queries of *dev_data*.

    Each query's rows are ranked by ``score`` in descending order with a stable
    (mergesort) sort, so tied scores keep their existing row order. A query scores 1/rank
    of its first ``rel > 0`` row within the top ``cutoff``, otherwise 0. The number of
    adjacent near-equal scores (|diff| < 1e-8) met while scanning is printed as a
    tie diagnostic.

    Args:
        dev_data: frame with ``qid``, ``pid``, ``rel`` and ``score`` columns.
        cutoff: rank cutoff (default 10, as previously hard-coded).

    Returns:
        The mean reciprocal rank.
    """
    score_tie_counter = 0
    score_tie_query = set()

    reciprocal_ranks = []
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        assert group['pid'].is_unique
        ranked = group.sort_values('score', ascending=False, kind='mergesort')
        prev_score = None
        rank = 0
        for t in ranked.itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                reciprocal_ranks.append(1.0 / rank)
                break
            elif rank == cutoff or rank == len(group):
                reciprocal_ranks.append(0.)
                break

    mrr = np.mean(reciprocal_ranks)
    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie, mrr)
    return mrr


In [10]:
def eval_recall(dev_qrel, dev_data, recall_points=(10, 20, 50, 100, 200, 500, 1000)):
    """Print and return mean recall@k over the queries of *dev_data* for each k in *recall_points*.

    Each query's rows are ranked by ``score`` in descending order with a stable
    (mergesort) sort. Recall@k is the number of ``rel > 0`` rows in the top k divided by
    the query's number of relevant judgements in *dev_qrel*. Queries without any relevant
    judgement get recall 0; they previously raised KeyError. Adjacent near-equal scores
    (|diff| < 1e-8) are counted and printed as a tie diagnostic.

    Args:
        dev_qrel: judgements with ``qid``, ``pid`` and ``rel`` columns.
        dev_data: frame with ``qid``, ``pid``, ``rel`` and ``score`` columns.
        recall_points: cutoffs to report (default: the previously hard-coded list).

    Returns:
        dict mapping each cutoff to its mean recall.
    """
    dev_rel_num = dev_qrel[dev_qrel['rel'] > 0].groupby('qid').count()['rel']

    score_tie_counter = 0
    score_tie_query = set()

    recall_point = list(recall_points)
    recall_curve = {k: [] for k in recall_point}
    for qid, group in tqdm(dev_data.groupby('qid')):
        group = group.reset_index()
        assert group['pid'].is_unique
        total_rel = dev_rel_num.get(qid, 0)
        query_recall = [0] * len(recall_point)
        prev_score = None
        rank = 0
        for t in group.sort_values('score', ascending=False, kind='mergesort').itertuples():
            if prev_score is not None and abs(t.score - prev_score) < 1e-8:
                score_tie_counter += 1
                score_tie_query.add(qid)
            prev_score = t.score
            rank += 1
            if t.rel > 0:
                for i, p in enumerate(recall_point):
                    if rank <= p:
                        query_recall[i] += 1
        for i, p in enumerate(recall_point):
            recall_curve[p].append(query_recall[i] / total_rel if total_rel > 0 else 0.)

    score_tie = f'score_tie occurs {score_tie_counter} times in {len(score_tie_query)} queries'
    print(score_tie)

    averages = {}
    for k, v in recall_curve.items():
        avg = np.mean(v)
        print(f'recall@{k}:{avg}')
        averages[k] = avg
    return averages


In [11]:
# Feature matrices, relevance labels and per-query group sizes for LightGBM ranking.
train_X = train_extracted['data'].loc[:, feature_name]
train_Y = train_extracted['data'].loc[:, 'rel']
lgb_train = lgb.Dataset(train_X, label=train_Y, group=train_extracted['group'])

dev_X = dev_extracted['data'].loc[:, feature_name]
dev_Y = dev_extracted['data'].loc[:, 'rel']
lgb_valid = lgb.Dataset(dev_X, label=dev_Y, group=dev_extracted['group'])

In [12]:
# Report each feature column of train_X that contains NaN, with up to ten offending (qid, pid) rows.
for i, n in enumerate(feature_name):
    column = train_X.iloc[:, i]
    if not np.isnan(column).any():
        continue
    print(n)
    print(train_extracted['data'].loc[column.isna(), ['qid', 'pid']].head(10))

text_text_LMJM_lambda_0.10
            qid      pid
9317       3671  6416650
303572    79357  8026198
686383   180247  2872095
826682   221234  5903693
841544   227704   736288
843270   228333  2267954
863024   233763  4359635
949649   259178  7363614
1185354  323324  8330771
1204429  329037  8026199
text_text_LMJM_lambda_0.40
            qid      pid
9317       3671  6416650
303572    79357  8026198
686383   180247  2872095
826682   221234  5903693
841544   227704   736288
843270   228333  2267954
863024   233763  4359635
949649   259178  7363614
1185354  323324  8330771
1204429  329037  8026199
text_text_LMJM_lambda_0.70
            qid      pid
9317       3671  6416650
303572    79357  8026198
686383   180247  2872095
826682   221234  5903693
841544   227704   736288
843270   228333  2267954
863024   233763  4359635
949649   259178  7363614
1185354  323324  8330771
1204429  329037  8026199
text_text_Prob
            qid      pid
9317       3671  6416650
303572    79357  8026198
6863

In [13]:
# Relevant-passage count per dev query (judgements with rel > 0 only).
dev_rel_num = dev_qrel.loc[dev_qrel['rel'] > 0].groupby('qid').count()['rel']
gid = 0          # index of the current query among the dev groups
sample_id = 0    # offset of the current query's first row in dev_X
dev_group_rel_num = []
# Walk the dev data one query at a time. Check that each query's row count matches its
# group size and that its lowest-pid row lines up with dev_X at the running offset. Then
# record the query's relevant count in group order (read by the recall@200 feval).
for qid, group in dev_extracted['data'].groupby('qid'):
    group = group.sort_values(['pid'])
    expected_size = dev_extracted['group'].iloc[gid]
    assert len(group) == expected_size
    first_row_features = group.iloc[0, :].loc[feature_name]
    assert np.isclose(first_row_features, dev_X.iloc[sample_id, :], equal_nan=True).all()
    dev_group_rel_num.append(dev_rel_num.loc[qid])
    gid += 1
    sample_id += len(group)
dev_group = dev_extracted['group']

In [14]:
def recall_at_200(preds, dataset, k=200, group_rel_num=None, expected_group=None):
    """LightGBM ``feval``: mean per-query recall of the top-``k`` scored rows.

    Rows of ``dataset`` are split into consecutive query groups. Within each
    group they are ranked by ``preds`` (highest first). The number of relevant
    rows (label > 0) in the top ``k`` is divided by that group's entry of
    ``group_rel_num``. The result is averaged over groups.

    Parameters
    ----------
    preds : array-like
        One score per row of ``dataset``, in dataset row order.
    dataset : lightgbm.Dataset
        Evaluation set. Only ``get_label()`` and ``get_group()`` are used.
    k : int, default 200
        Ranking cutoff; must be >= 1. The metric is named ``f'recall@{k}'``.
    group_rel_num : sequence of numbers, optional
        Recall denominator for each group, aligned by position with the
        dataset's groups. Defaults to the notebook-global ``dev_group_rel_num``.
    expected_group : array-like, optional
        Group sizes the dataset must have. Defaults to the notebook-global
        ``dev_group``.

    Returns
    -------
    tuple
        ``(metric_name, mean_recall, True)``; ``True`` marks higher as better.

    Raises
    ------
    ValueError
        If ``k < 1``.
    """
    if k < 1:
        # `ranked[-0:]` would select every row, so k=0 must be rejected.
        raise ValueError(f'k must be >= 1, got {k}')
    if group_rel_num is None:
        group_rel_num = dev_group_rel_num
    if expected_group is None:
        expected_group = dev_group

    preds = np.asarray(preds)
    labels = np.asarray(dataset.get_label())
    groups = np.asarray(dataset.get_group())
    # group_rel_num is matched to groups purely by position, so the grouping of
    # the dataset must be exactly the one the denominators were built for.
    assert np.array_equal(groups, np.asarray(expected_group))
    assert len(groups) == len(group_rel_num)

    recall_sum = 0.0
    start = 0
    for size, rel_num in zip(groups, group_rel_num):
        end = start + size
        # argsort is ascending, so the last k entries are the k highest scores.
        ranked = labels[start:end][np.argsort(preds[start:end])]
        # Count hits as label > 0 (not the label sum), consistent with the
        # rel > 0 definition of the denominator; identical for 0/1 labels.
        recall_sum += np.count_nonzero(ranked[-k:] > 0) / rel_num
        start = end
    assert start == len(preds)
    return f'recall@{k}', recall_sum / len(groups), True

In [15]:
# Train one LambdaRank model per seed on every extracted feature, early-stopping
# on dev recall@200 (recall_at_200 as feval). Each model's dev predictions are
# summed into the 'score' column before the final recall / MRR evaluation.
dev_extracted['data']['score'] = 0.
for seed in [12345]:
    params = dict(
        boosting_type='gbdt',
        objective='lambdarank',
        max_bin=255,
        num_leaves=63,
        max_depth=-1,
        min_data_in_leaf=50,
        min_sum_hessian_in_leaf=0,
        # bagging_fraction=0.8,
        # bagging_freq=50,
        feature_fraction=1,
        learning_rate=0.1,
        num_boost_round=1000,
        early_stopping_round=200,
        metric='custom',  # evaluation comes from the feval passed below
        label_gain=[0, 1],
        lambdarank_truncation_level=20,
        seed=seed,
        num_threads=max(multiprocessing.cpu_count() // 2, 1),
    )
    # These two are handed to lgb.train() as keyword arguments instead.
    num_boost_round, early_stopping_round = (params.pop('num_boost_round'),
                                             params.pop('early_stopping_round'))
    gbm = lgb.train(
        params,
        lgb_train,
        valid_sets=lgb_valid,
        num_boost_round=num_boost_round,
        early_stopping_rounds=early_stopping_round,
        feval=recall_at_200,
        feature_name=feature_name,
        verbose_eval=True,
    )
    dev_extracted['data']['score'] += gbm.predict(dev_X)

    best_score = gbm.best_score['valid_0']['recall@200']
    print(best_score)
    best_iteration = gbm.best_iteration
    print(best_iteration)

    # (feature, importance) pairs, most important first.
    feature_importances = sorted(
        zip(feature_name, gbm.feature_importance().tolist()),
        key=lambda pair: pair[1],
        reverse=True,
    )
    print(feature_importances)
eval_recall(dev_qrel, dev_extracted['data'])
eval_mrr(dev_extracted['data'])

You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 68782
[LightGBM] [Info] Number of data points in the train set: 4416413, number of used features: 336
[1]	valid_0's recall@200: 0.735375
Training until validation scores don't improve for 200 rounds
[2]	valid_0's recall@200: 0.761378
[3]	valid_0's recall@200: 0.769759
[4]	valid_0's recall@200: 0.776063
[5]	valid_0's recall@200: 0.779035
[6]	valid_0's recall@200: 0.780886
[7]	valid_0's recall@200: 0.781769
[8]	valid_0's recall@200: 0.782987
[9]	valid_0's recall@200: 0.784766
[10]	valid_0's recall@200: 0.785793
[11]	valid_0's recall@200: 0.785864
[12]	valid_0's recall@200: 0.785769
[13]	valid_0's recall@200: 0.7867
[14]	valid_0's recall@200: 0.787416
[15]	valid_0's recall@200: 0.788276
[16]	valid_0's recall@200: 0.787416
[17]	valid_0's recall@200: 0.7867
[18]	valid_0's recall@200: 0.787631
[19]	valid_0's recall@200: 0.789064
[20]	valid_0's recall@200: 0.789136
[21]	valid_0's recall@200: 0.79021
[22]	va

[217]	valid_0's recall@200: 0.802913
[218]	valid_0's recall@200: 0.802913
[219]	valid_0's recall@200: 0.802985
[220]	valid_0's recall@200: 0.803271
[221]	valid_0's recall@200: 0.803128
[222]	valid_0's recall@200: 0.802985
[223]	valid_0's recall@200: 0.802985
[224]	valid_0's recall@200: 0.803056
[225]	valid_0's recall@200: 0.802913
[226]	valid_0's recall@200: 0.803056
[227]	valid_0's recall@200: 0.802913
[228]	valid_0's recall@200: 0.803056
[229]	valid_0's recall@200: 0.8032
[230]	valid_0's recall@200: 0.803343
[231]	valid_0's recall@200: 0.8032
[232]	valid_0's recall@200: 0.803343
[233]	valid_0's recall@200: 0.803056
[234]	valid_0's recall@200: 0.803056
[235]	valid_0's recall@200: 0.8032
[236]	valid_0's recall@200: 0.8032
[237]	valid_0's recall@200: 0.803056
[238]	valid_0's recall@200: 0.802913
[239]	valid_0's recall@200: 0.802913
[240]	valid_0's recall@200: 0.802627
[241]	valid_0's recall@200: 0.802627
[242]	valid_0's recall@200: 0.802627
[243]	valid_0's recall@200: 0.802627
[244]	val

[440]	valid_0's recall@200: 0.80234
[441]	valid_0's recall@200: 0.802197
[442]	valid_0's recall@200: 0.802053
[443]	valid_0's recall@200: 0.80234
[444]	valid_0's recall@200: 0.80234
[445]	valid_0's recall@200: 0.802483
[446]	valid_0's recall@200: 0.802698
[447]	valid_0's recall@200: 0.802555
[448]	valid_0's recall@200: 0.802555
[449]	valid_0's recall@200: 0.802412
[450]	valid_0's recall@200: 0.802412
[451]	valid_0's recall@200: 0.802412
[452]	valid_0's recall@200: 0.802268
[453]	valid_0's recall@200: 0.802268
[454]	valid_0's recall@200: 0.802412
[455]	valid_0's recall@200: 0.802412
[456]	valid_0's recall@200: 0.802627
[457]	valid_0's recall@200: 0.80277
[458]	valid_0's recall@200: 0.802698
[459]	valid_0's recall@200: 0.802555
[460]	valid_0's recall@200: 0.80234
[461]	valid_0's recall@200: 0.802412
[462]	valid_0's recall@200: 0.802555
[463]	valid_0's recall@200: 0.802555
[464]	valid_0's recall@200: 0.802412
[465]	valid_0's recall@200: 0.802268
[466]	valid_0's recall@200: 0.802197
[467]	

100%|██████████| 6980/6980 [06:35<00:00, 17.63it/s] 


score_tie occurs 247511 times in 6908 queries
recall@10:0.4895893027698185
recall@20:0.5785816618911175
recall@50:0.6844794651384909
recall@100:0.753689111747851
recall@200:0.8036055396370582
recall@500:0.8449976122254059
recall@1000:0.8573424068767909


100%|██████████| 6980/6980 [02:46<00:00, 42.01it/s] 


score_tie occurs 252 times in 220 queries 0.24848552553781783


In [25]:
# feature_importances is sorted by importance, highest first; keep the top 90
# feature names and rebuild the train/dev LightGBM datasets on just those.
selected_feature_name = [name for name, _importance in feature_importances[:90]]
print(selected_feature_name)

train_X = train_extracted['data'].loc[:, selected_feature_name]
train_Y = train_extracted['data']['rel']
lgb_train = lgb.Dataset(train_X, label=train_Y, group=train_extracted['group'])

dev_X = dev_extracted['data'].loc[:, selected_feature_name]
dev_Y = dev_extracted['data']['rel']
lgb_valid = lgb.Dataset(dev_X, label=dev_Y, group=dev_extracted['group'])

['text_unlemm_text_unlemm_IBMModel1_body', 'text_bert_tok_text_bert_tok_IBMModel1_text_bert_tok', 'text_bert_tok_text_bert_tok_ICTF_sum', 'text_unlemm_text_unlemm_IBMModel1_title_unlemm', 'text_bert_tok_text_bert_tok_UnorderedQueryPairs_3', 'text_unlemm_text_unlemm_IDF_sum', 'text_text_SCQ_sum', 'text_unlemm_text_unlemm_IBMModel1_url_unlemm', 'text_unlemm_text_unlemm_NTFIDF', 'text_bert_tok_text_bert_tok_OrderedQueryPairs_15', 'text_bert_tok_text_bert_tok_NTFIDF', 'text_bert_tok_text_bert_tok_LMJM_lambda_0.10', 'text_unlemm_text_unlemm_ICTF_sum', 'text_text_ICTF_sum', 'text_unlemm_text_unlemm_DFR_GL2', 'text_bert_tok_text_bert_tok_NormalizedTF_avg', 'contents_analyzed_NormalizedTF_min', 'text_text_LMJM_lambda_0.10', 'contents_analyzed_DFR_GL2', 'text_text_IDF_sum', 'text_bert_tok_text_bert_tok_NormalizedTF_min', 'text_bert_tok_text_bert_tok_NormalizedTF_sum', 'text_unlemm_text_unlemm_Prob', 'contents_analyzed_NTFIDF', 'text_bert_tok_text_bert_tok_UnorderedQueryPairs_8', 'text_unlemm_te

In [27]:
# Retrain the same LambdaRank configuration on the selected feature subset
# (selected_feature_name), early-stopping on dev recall@200. Dev predictions of
# each seed's model are summed into the 'score' column before evaluation.
dev_extracted['data']['score'] = 0.
for seed in [12345]:
    params = {
            'boosting_type': 'gbdt',
            'objective': 'lambdarank',
            'max_bin':255,
            'num_leaves':63,
            'max_depth':-1,
            'min_data_in_leaf':50,
            'min_sum_hessian_in_leaf':0,
#             'bagging_fraction':0.8,
#             'bagging_freq':50,
            'feature_fraction':1,
            'learning_rate':0.1,
            'num_boost_round':1000,
            'early_stopping_round':200,
            'metric':'custom',  # evaluation comes from the feval passed below
            'label_gain':[0,1],
            'lambdarank_truncation_level':20,
            'seed':seed,
            'num_threads':max(multiprocessing.cpu_count()//2,1)
    }
    # These two are passed to lgb.train() as keyword arguments instead.
    num_boost_round = params.pop('num_boost_round')
    early_stopping_round = params.pop('early_stopping_round')
    # NOTE(review): `early_stopping_rounds` / `verbose_eval` appear to be removed
    # from lgb.train() in LightGBM 4.x (callbacks replace them) -- confirm the
    # installed version before upgrading.
    gbm = lgb.train(params, lgb_train,
                    valid_sets=lgb_valid,
                    num_boost_round=num_boost_round,
                    early_stopping_rounds=early_stopping_round,
                    feval=recall_at_200,
                    feature_name=selected_feature_name,
                    verbose_eval=True)
    dev_extracted['data']['score'] += gbm.predict(dev_X)
    best_score = gbm.best_score['valid_0']['recall@200']
    print(best_score)
    best_iteration = gbm.best_iteration
    print(best_iteration)
    # This model was trained on `selected_feature_name`, so its importances must
    # be paired with that list. Zipping with the full `feature_name` list
    # silently truncated it to its first names and mislabeled every importance.
    assert len(selected_feature_name) == gbm.num_feature()
    feature_importances = sorted(zip(selected_feature_name, gbm.feature_importance().tolist()),
                                 key=lambda x: x[1], reverse=True)
    print(feature_importances)
eval_recall(dev_qrel, dev_extracted['data'])
eval_mrr(dev_extracted['data'])

You can set `force_row_wise=true` to remove the overhead.
And if memory is not enough, you can set `force_col_wise=true`.
[LightGBM] [Info] Total Bins 20478
[LightGBM] [Info] Number of data points in the train set: 4416413, number of used features: 90
[1]	valid_0's recall@200: 0.736103
Training until validation scores don't improve for 200 rounds
[2]	valid_0's recall@200: 0.76164
[3]	valid_0's recall@200: 0.771466
[4]	valid_0's recall@200: 0.775836
[5]	valid_0's recall@200: 0.777256
[6]	valid_0's recall@200: 0.779298
[7]	valid_0's recall@200: 0.781519
[8]	valid_0's recall@200: 0.782474
[9]	valid_0's recall@200: 0.783405
[10]	valid_0's recall@200: 0.784097
[11]	valid_0's recall@200: 0.785602
[12]	valid_0's recall@200: 0.786294
[13]	valid_0's recall@200: 0.786617
[14]	valid_0's recall@200: 0.786258
[15]	valid_0's recall@200: 0.785375
[16]	valid_0's recall@200: 0.786163
[17]	valid_0's recall@200: 0.786234
[18]	valid_0's recall@200: 0.787309
[19]	valid_0's recall@200: 0.787667
[20]	valid_0

[215]	valid_0's recall@200: 0.802137
[216]	valid_0's recall@200: 0.80228
[217]	valid_0's recall@200: 0.801851
[218]	valid_0's recall@200: 0.801851
[219]	valid_0's recall@200: 0.801707
[220]	valid_0's recall@200: 0.801851
[221]	valid_0's recall@200: 0.801922
[222]	valid_0's recall@200: 0.801755
[223]	valid_0's recall@200: 0.801683
[224]	valid_0's recall@200: 0.80154
[225]	valid_0's recall@200: 0.80154
[226]	valid_0's recall@200: 0.801755
[227]	valid_0's recall@200: 0.801683
[228]	valid_0's recall@200: 0.801755
[229]	valid_0's recall@200: 0.801755
[230]	valid_0's recall@200: 0.802185
[231]	valid_0's recall@200: 0.802471
[232]	valid_0's recall@200: 0.802471
[233]	valid_0's recall@200: 0.802686
[234]	valid_0's recall@200: 0.802471
[235]	valid_0's recall@200: 0.802471
[236]	valid_0's recall@200: 0.802471
[237]	valid_0's recall@200: 0.802615
[238]	valid_0's recall@200: 0.802615
[239]	valid_0's recall@200: 0.802615
[240]	valid_0's recall@200: 0.8024
[241]	valid_0's recall@200: 0.802113
[242]	

100%|██████████| 6980/6980 [06:32<00:00, 17.78it/s] 


score_tie occurs 421260 times in 6958 queries
recall@10:0.4786652340019102
recall@20:0.5722301814708692
recall@50:0.6813992359121299
recall@100:0.7482688634192932
recall@200:0.8035458452722063
recall@500:0.8449976122254059
recall@1000:0.8573424068767909


100%|██████████| 6980/6980 [02:52<00:00, 40.54it/s] 


score_tie occurs 360 times in 317 queries 0.2415089598399054
