In [None]:
import faiss
import pyterrier as pt
import ujson
import numpy as np
import pandas as pd

import itertools
import threading
import queue

from colbert.modeling.inference import ModelInference
from colbert.evaluation.loaders import load_colbert
from pyterrier_colbert import load_checkpoint
# monkeypatch to use our downloading version
import colbert.evaluation.loaders

# Swap ColBERT's load_checkpoint for pyterrier_colbert's version (the "downloading version").
# The name is written into the loaders module namespace and into load_model's globals, so
# load_model resolves the replacement when it looks the name up as a global.
for _patched_namespace in (vars(colbert.evaluation.loaders),
                           colbert.evaluation.loaders.load_model.__globals__):
    _patched_namespace['load_checkpoint'] = load_checkpoint
del _patched_namespace
from colbert.utils.utils import print_message
import pickle
from colbert.indexing.index_manager import IndexManager
from warnings import warn

In [None]:
# Start PyTerrier (and its JVM) once. PyTerrier refuses a second pt.init() in the same
# process, so guard the call to keep this cell safe to re-run.
if not pt.started():
    pt.init()

In [None]:
from pyterrier_colbert.preprocessing import DatasetPreprocessor, TokenRemover, HFTokenizer, NLTKTokenizer
from transformers import AutoTokenizer, AutoModelForMaskedLM
import ir_datasets, ir_measures

In [None]:
class Object:
    """Bare attribute container; arbitrary attributes can be assigned on instances (see `args` below)."""

In [None]:
# ColBERT checkpoint (a remote zip) handed to every ColBERTIndexer in this notebook.
checkpoint = (
    "http://www.dcs.gla.ac.uk/~craigm/colbert.dnn.zip"
)

In [None]:
# Ad-hoc ColBERT argument namespace; attributes are set in the same order as before.
args = Object()
vars(args).update(
    similarity='cosine',
    dim=128,
    query_maxlen=32,
    doc_maxlen=180,
    checkpoint=checkpoint,
    mask_punctuation=False,
)

In [None]:
# NLTK Treebank tokenizer: shared by every TokenRemover and DatasetPreprocessor below.
nltk_tokenizer = NLTKTokenizer(tokenizer_type='treebank')

# BERT WordPiece tokenizer, wrapped in pyterrier_colbert's HFTokenizer adapter.
wordpiece = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
hf_tokenizer = HFTokenizer(tokenizer=wordpiece)

In [None]:
from ir_datasets import create_dataset

In [None]:
# Local MiniMARCO collection assembled by ir_datasets from on-disk docs/queries/qrels files
# (presumably the TREC DL 2019 passage subset, judging by the file names).
dataset = create_dataset(
    docs_tsv='minimarco/msmarco-passage-trec-dl-2019-docs.tsv',
    queries_tsv='minimarco/msmarco-passage-trec-dl-2019-queries.tsv',
    qrels_trec='minimarco/msmarco-passage-trec-dl-2019-qrels.tsv',
)

In [None]:
# Wrap the locally created ir_datasets dataset as a PyTerrier dataset.
# defer_load=True presumably stops IRDSDataset from resolving 'irds:minimarco' through the
# ir_datasets registry at construction time (it is not a registered id — TODO confirm); the
# wrapped dataset is then injected by hand through the private `_irds_ref` attribute.
# NOTE(review): relies on a private PyTerrier attribute; may break across PyTerrier versions.
irds_dataset = pt.datasets.IRDSDataset(irds_id='irds:minimarco', defer_load=True)
irds_dataset._irds_ref = dataset

In [None]:
# Sanity check that the injected dataset serves qrels; preview instead of dumping the whole frame.
irds_dataset.get_qrels().head()

In [None]:
# Stopword list files combined by the removers below.
STOPWORDS_EN = 'stopwords/stopwords-en.txt'
STOPWORDS_LIMITED = 'stopwords/stopwords-limited.txt'
STOPWORDS_PUNC = 'stopwords/stopwords-punctuations.txt'


def make_remover(stopwords_files, extra_stopwords=(), **remover_kwargs):
    """Build a TokenRemover over the shared NLTK tokenizer.

    Args:
        stopwords_files: stopword list files passed to TokenRemover.
        extra_stopwords: literal tokens unioned into the remover's stopword set after construction.
        **remover_kwargs: forwarded unchanged to TokenRemover (e.g. ``stopword_max_length``);
            keywords that are omitted keep TokenRemover's own defaults.

    Returns:
        The configured TokenRemover.
    """
    remover = TokenRemover(tokenizer=nltk_tokenizer, stopwords_files=list(stopwords_files), **remover_kwargs)
    if extra_stopwords:
        remover.stopwords = remover.stopwords | set(extra_stopwords)
    return remover


en_remover = make_remover([STOPWORDS_EN])
en_punc_remover = make_remover([STOPWORDS_EN, STOPWORDS_PUNC])

en2_remover = make_remover([STOPWORDS_EN], stopword_max_length=2)
en2_punc_remover = make_remover([STOPWORDS_EN, STOPWORDS_PUNC], stopword_max_length=2)

en4_remover = make_remover([STOPWORDS_EN], stopword_max_length=4)
en4_punc_remover = make_remover([STOPWORDS_EN, STOPWORDS_PUNC], stopword_max_length=4)

lim_remover = make_remover([STOPWORDS_LIMITED])
lim_punc_remover = make_remover([STOPWORDS_LIMITED, STOPWORDS_PUNC])

en2the_remover = make_remover([STOPWORDS_EN], extra_stopwords=['the'], stopword_max_length=2)
# 'en2thepunc' is 'en2the' plus punctuation, so it is built from the English list (not
# stopwords-limited.txt), like every other *punc variant.
# NOTE: an index.clean.en2thepunc.minimarco directory built with the limited list will be
# skipped by the indexing loop's existence check — delete it so it is rebuilt.
en2the_punc_remover = make_remover([STOPWORDS_EN, STOPWORDS_PUNC], extra_stopwords=['the'], stopword_max_length=2)

punc_only = make_remover([STOPWORDS_PUNC])

In [None]:
# (strategy name, remover) pairs; the name becomes part of each cleaned index's directory name.
cleaners = list({
    'en': en_remover,
    'enpunc': en_punc_remover,
    'en2': en2_remover,
    'en2punc': en2_punc_remover,
    'en2the': en2the_remover,
    'en2thepunc': en2the_punc_remover,
    'en4': en4_remover,
    'en4punc': en4_punc_remover,
    'few': lim_remover,
    'punc': punc_only,
    'fewpunc': lim_punc_remover,
}.items())

In [None]:
# One DatasetPreprocessor per cleaning strategy, keyed by strategy name.
datasets_cleaned = dict(
    (cleaner_name, DatasetPreprocessor(dataset=irds_dataset,
                                       tokenizer=nltk_tokenizer,
                                       preprocessor=token_cleaner))
    for cleaner_name, token_cleaner in cleaners
)

In [None]:
import pyterrier_colbert.indexing
import torch
import os

In [None]:
# Build the uncleaned baseline index once; skipped when its directory already exists.
# NOTE: existence is not a completeness check — delete a partially written index to rebuild it.
if not os.path.exists('./indexes/index.base.minimarco/'):
    base_indexer = pyterrier_colbert.indexing.ColBERTIndexer(checkpoint, "./indexes", "index.base.minimarco", chunksize=3, num_partitions=256)
    base_indexer.index(irds_dataset.get_corpus_iter())
    # Drop the indexer (which presumably holds the ColBERT model) so its GPU memory can be
    # reclaimed before more indexes are built.
    del base_indexer
    torch.cuda.empty_cache()

In [None]:
# Build one ColBERT index per cleaning strategy, skipping any whose directory already exists.
# NOTE: existence is not a completeness check — delete a stale or partial index to rebuild it.
for name, data_iter in datasets_cleaned.items():
    if not os.path.exists(f'./indexes/index.clean.{name}.minimarco/'):
        cleaned_indexer = pyterrier_colbert.indexing.ColBERTIndexer(checkpoint, './indexes', f'index.clean.{name}.minimarco', chunksize=3, num_partitions=256)
        cleaned_indexer.index(data_iter)
        # Remove the only reference to the indexer (previously a misspelled `clean_indexer`
        # was cleared instead) so its CUDA allocations can be freed before empty_cache().
        del cleaned_indexer
        torch.cuda.empty_cache()

In [None]:
# Completion report: list the index directories now present under ./indexes.
if os.path.isdir('./indexes'):
    built_indexes = sorted(entry for entry in os.listdir('./indexes') if entry.startswith('index.'))
else:
    built_indexes = []
print(f"Indexing finished: {len(built_indexes)} index(es) in ./indexes")
built_indexes