In [None]:
import faiss
import pyterrier as pt
import ujson
import numpy as np

import itertools
import threading
import queue

from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from pyterrier_colbert import load_checkpoint
# Monkeypatch ColBERT so that it uses pyterrier_colbert's `load_checkpoint`
# (imported above) instead of its own - "our downloading version".
# NOTE(review): presumably this is what allows `checkpoint` to be given as a
# URL rather than a local path - confirm against pyterrier_colbert.
import colbert.evaluation.loaders

# Rebind the name on the loaders module itself ...
colbert.evaluation.loaders.load_checkpoint = load_checkpoint
# ... and in the global namespace that `load_model` resolves its names from,
# which is a different dict if `load_model` was defined in another module.
colbert.evaluation.loaders.load_model.__globals__['load_checkpoint'] = load_checkpoint
from colbert.utils.utils import print_message
import pickle
from colbert.indexing.index_manager import IndexManager
from warnings import warn

In [None]:
# Start PyTerrier (and its JVM) only if this kernel has not done so already,
# so the cell can be re-executed without a fresh kernel. This is the guard
# recommended in the PyTerrier documentation for notebooks.
if not pt.started():
    pt.init()

In [None]:
from pyterrier_colbert.preprocessing import DatasetPreprocessor, TokenRemover, HFTokenizer, NLTKTokenizer
from transformers import AutoTokenizer, AutoModelForMaskedLM


In [None]:
class Object:
    """Empty attribute container; instances accept arbitrary attributes."""

In [None]:
# Location of the ColBERT checkpoint archive used by every factory below.
# NOTE(review): this is a plain-HTTP URL - consider whether an https mirror
# or a checksum is available.
checkpoint = "http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip"

In [None]:
# ColBERT-style argument bundle, filled in one step.
# NOTE(review): `args` is not referenced again in the visible cells.
args = Object()
vars(args).update(
    similarity='cosine',
    dim=128,
    query_maxlen=32,
    doc_maxlen=180,
    checkpoint=checkpoint,
    mask_punctuation=False,
)

In [None]:
# WordPiece tokenizer loaded by name through transformers, then wrapped in
# the preprocessing package's HFTokenizer.
wordpiece = AutoTokenizer.from_pretrained(
    "google-bert/bert-base-uncased"
)
hf_tokenizer = HFTokenizer(tokenizer=wordpiece)

# NLTK-backed tokenizer configured with the 'treebank' tokenizer type; this
# is the tokenizer passed to the removers and to `clean` below.
nltk_tokenizer = NLTKTokenizer(tokenizer_type='treebank')

In [None]:
from pyterrier_colbert.ranking import ColBERTFactory

In [None]:
# Registry of retrieval pipelines, keyed by a short label; insertion order
# determines the system order passed to pt.Experiment.
retrievers = dict()

In [None]:
# Factory over the unmodified ("base") Vaswani index, opened with memtype='mmap'.
pyterrier_colbert_factory = ColBERTFactory(
    checkpoint,
    "./indexes/",
    "index.base.vaswani",
    memtype='mmap',
)
# Register the end-to-end pipeline first so it becomes system 0.
colbert_e2e = retrievers['base'] = pyterrier_colbert_factory.end_to_end()

In [None]:
# Labels of the cleaned indexes to load (index.clean.<label>.vaswani).
cleaner_names = [
    'en',
    'en2',
    'en4',
    'few',
]

In [None]:
# Load one end-to-end pipeline per cleaned index and register it under its label.
# NOTE(review): the base factory is created with memtype='mmap' while these use
# the factory default; if that changes how the index is held in memory, the
# 'mrt' timings below are not like-for-like - confirm.
for name in cleaner_names:
    factory = ColBERTFactory(
        checkpoint,
        './indexes',
        f'index.clean.{name}.vaswani',
    )
    rete2e = retrievers[name] = factory.end_to_end()

In [None]:
# Test collection that supplies the topics and qrels for the experiments.
dataset = pt.get_dataset('vaswani')

In [None]:
# Compare every registered retriever on the original topics. baseline=0 picks
# the first entry of `retrievers`, which is 'base' when the cells are run top
# to bottom.
pt.Experiment(
    [*retrievers.values()],
    dataset.get_topics(),
    dataset.get_qrels(),
    names=[*retrievers],
    eval_metrics=['recip_rank', 'ndcg_cut_10', 'mrt'],
    baseline=0,
)

In [None]:
# Stopword removers over the shared NLTK tokenizer. They differ only in the
# stopword files they load and in the optional stopword_max_length argument.
# NOTE(review): stopword_max_length presumably limits removal to stopwords of
# at most that many characters - confirm against TokenRemover.
en_remover = TokenRemover(
    tokenizer=nltk_tokenizer,
    stopwords_files=['stopwords/stopwords-en.txt'],
)
en2_remover = TokenRemover(
    tokenizer=nltk_tokenizer,
    stopwords_files=['stopwords/stopwords-en.txt'],
    stopword_max_length=2,
)
en4_remover = TokenRemover(
    tokenizer=nltk_tokenizer,
    stopwords_files=['stopwords/stopwords-en.txt'],
    stopword_max_length=4,
)
en_punc_remover = TokenRemover(
    tokenizer=nltk_tokenizer,
    stopwords_files=[
        'stopwords/stopwords-en.txt',
        'stopwords/stopwords-punctuations.txt',
    ],
)
lim_remover = TokenRemover(
    tokenizer=nltk_tokenizer,
    stopwords_files=['stopwords/stopwords-limited.txt'],
)
lim_punc_remover = TokenRemover(
    tokenizer=nltk_tokenizer,
    stopwords_files=[
        'stopwords/stopwords-limited.txt',
        'stopwords/stopwords-punctuations.txt',
    ],
)

In [None]:
# Ensure 'the' is part of the en2 stopword set. A new set is bound rather
# than mutating the existing one in place.
en2_remover.stopwords = en2_remover.stopwords | {'the'}

In [None]:
# Stopword-set sizes, in order: en, en2, en4, en+punctuation, limited,
# limited+punctuation.
tuple(
    len(remover.stopwords)
    for remover in (
        en_remover,
        en2_remover,
        en4_remover,
        en_punc_remover,
        lim_remover,
        lim_punc_remover,
    )
)

In [None]:
# (label, remover) pairs used to build the cleaned datasets.
cleaners = [
    ('en', en_remover),
    ('en2', en2_remover),
    ('en4', en4_remover),
    ('few', lim_remover),
]

In [None]:
# One DatasetPreprocessor per cleaner, keyed by the cleaner's label.
datasets_cleaned = dict(
    (
        label,
        DatasetPreprocessor(
            dataset=dataset,
            tokenizer=nltk_tokenizer,
            preprocessor=remover,
        ),
    )
    for label, remover in cleaners
)

In [None]:
# Build the stopword set applied to the queries: one stopword per line,
# merged across all listed files.
stopwords = set()
for file in ['stopwords/stopwords-limited.txt']:
    # Pin the encoding so the words read do not depend on the platform's
    # locale default (which is not UTF-8 everywhere).
    with open(file, 'r', encoding='utf-8') as f:
        stopwords.update(f.read().splitlines())

In [None]:
def clean(tokenizer, stopwords, maxl, x):
    """Remove stopwords from a piece of text.

    Args:
        tokenizer: object providing ``tokenize(text)`` and ``detokenize(tokens)``.
        stopwords: container of tokens to drop; membership is an exact match.
        maxl: accepted for interface compatibility; not used by this function.
        x: the text to clean.

    Returns:
        Whatever ``tokenizer.detokenize`` produces for the surviving tokens,
        in their original order.
    """
    kept = []
    for token in tokenizer.tokenize(x):
        if token in stopwords:
            continue
        kept.append(token)
    return tokenizer.detokenize(kept)

In [None]:
# Copy of the topics whose 'query' column has had the stopwords removed; the
# frame returned by get_topics() is left untouched.
# The 512 is passed as `maxl`.
clean_topics = dataset.get_topics().assign(
    query=lambda topics: topics['query'].map(
        lambda text: clean(nltk_tokenizer, stopwords, 512, text)
    )
)

In [None]:
# Same comparison as before, but with the stopword-stripped queries.
# baseline=0 again refers to the first entry of `retrievers`.
pt.Experiment(
    [*retrievers.values()],
    clean_topics,
    dataset.get_qrels(),
    names=[*retrievers],
    eval_metrics=['recip_rank', 'ndcg_cut_10', 'mrt'],
    baseline=0,
)

In [None]:
# Display the checkpoint location used throughout this notebook.
checkpoint

In [None]:
# Display the label -> pipeline registry that was evaluated above.
retrievers