In [1]:
!curl https://zenodo.org/record/3974431/files/vanilla_bert_tiny_on_MSMARCO.tar.gz >> vanilla_bert_tiny_on_MSMARCO.tar.gz

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 42.2M  100 42.2M    0     0  7181k      0  0:00:06  0:00:06 --:--:-- 9226k


In [2]:
# Unpack the downloaded checkpoint archive; the files land under
# bert_models_onMSMARCO/vanilla_bert_tiny_on_MSMARCO/ (model.ckpt-1600000.*).
!tar -zxvf vanilla_bert_tiny_on_MSMARCO.tar.gz

bert_models_onMSMARCO/vanilla_bert_tiny_on_MSMARCO/
bert_models_onMSMARCO/vanilla_bert_tiny_on_MSMARCO/model.ckpt-1600000.data-00000-of-00001
bert_models_onMSMARCO/vanilla_bert_tiny_on_MSMARCO/model.ckpt-1600000.meta
bert_models_onMSMARCO/vanilla_bert_tiny_on_MSMARCO/model.ckpt-1600000.index


In [27]:
!curl https://storage.googleapis.com/bert_models/2020_02_20/uncased_L-2_H-128_A-2.zip >> bert.zip
!unzip bert.zip

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 15.7M  100 15.7M    0     0  6607k      0  0:00:02  0:00:02 --:--:-- 6610k
Archive:  bert.zip
  inflating: bert_model.ckpt.data-00000-of-00001  
  inflating: bert_config.json        
replace vocab.txt? [y]es, [n]o, [A]ll, [N]one, [r]ename: ^C


In [60]:
import random

import pyterrier as pt
if not pt.started():
    pt.init(version="snapshot")

import tensorflow as tf

from pyterrier.transformer import TransformerBase

from run_reranking import model_fn_builder
from input_parser_pyterrier import input_fn_builder
from bert.modeling import BertConfig
from generate_data import PointwiseInstance
from bert import tokenization

In [61]:
def create_instance_pointwise(tokenizer, max_seq_length, qid, docno, query, doc, label):
    """Build one ``PointwiseInstance`` for a (query, document) pair.

    The document is cut into passages with ``get_passages(doc, 150, 50)``; the
    query and every passage are then run through
    ``tokenization.convert_to_bert_input`` with ``convert_to_id=False``.

    :param tokenizer: tokenizer handed to ``convert_to_bert_input``.
    :param max_seq_length: total budget per query+passage pair; each passage
        gets ``max_seq_length`` minus the length of the converted query.
    :param qid: query identifier, used in the example id and in the warning.
    :param docno: document identifier, used in the example id and in the warning.
    :param query: raw query text.
    :param doc: raw document text.
    :param label: value stored as ``relation_label`` on the instance.
    :return: the constructed ``PointwiseInstance``.
    """
    query_text = tokenization.convert_to_unicode(query)
    doc_text = tokenization.convert_to_unicode(doc)

    raw_passages = get_passages(doc_text, 150, 50)
    if not raw_passages:
        tf.logging.warn("Passage length is 0 in qid {} docno {}".format(qid, docno))

    # The query side is capped at 64 and carries the [CLS] marker (add_cls=True).
    query_tokens = tokenization.convert_to_bert_input(
        text=query_text,
        max_seq_length=64,
        tokenizer=tokenizer,
        add_cls=True,
        convert_to_id=False,
    )

    # Whatever the converted query does not use is left for each passage.
    passage_budget = max_seq_length - len(query_tokens)
    passage_tokens = []
    for passage in raw_passages:
        passage_tokens.append(
            tokenization.convert_to_bert_input(
                text=passage,
                max_seq_length=passage_budget,
                tokenizer=tokenizer,
                add_cls=False,
                convert_to_id=False,
            )
        )

    return PointwiseInstance(
        exampleid="{}-{}".format(qid, docno),
        tokens_a=query_tokens,
        tokens_b_list=passage_tokens,
        relation_label=label,
    )

def get_passages(text, plen, overlap, max_passages=8):
    """Split ``text`` into overlapping word windows and cap how many are kept.

    Modified from https://github.com/AdeDZY/SIGIR19-BERT-IR/blob/master/tools/gen_passages.py

    :param text: document text; it is stripped and then split on single spaces.
    :param plen: number of words per passage.
    :param overlap: number of words shared by two consecutive passages.
    :param max_passages: maximum number of passages returned (at least 2).
        When more windows are produced, the first and the last are always kept
        and the remaining slots are filled with a random sample of the interior
        windows, returned in document order.
    :return: list of passage strings in document order.
    :raises ValueError: if ``max_passages`` is smaller than 2.
    """
    if max_passages < 2:
        raise ValueError("max_passages must be at least 2, got {}".format(max_passages))

    words = text.strip().split(' ')
    stride = plen - overlap
    passages = []
    start = 0
    while start < len(words):
        end = min(start + plen, len(words))
        # A trailing window no longer than 'overlap' is already contained in the previous passage.
        if passages and end - start <= overlap:
            break
        passages.append(' '.join(words[start:end]))
        start += stride

    if len(passages) > max_passages:
        # Keep both ends; draw the rest from the interior without replacement.
        last = len(passages) - 1
        interior = sorted(random.sample(range(1, last), max_passages - 2))
        passages = [passages[i] for i in [0] + interior + [last]]

    return passages

In [65]:
def df_to_instances(df, tokenizer):
    """Turn each row of ``df`` into a pointwise instance.

    Rows are read with ``df.itertuples()`` and must expose the attributes
    ``qid``, ``docno``, ``query`` and ``body``.

    :param df: frame of query/document pairs to convert.
    :param tokenizer: tokenizer forwarded to ``create_instance_pointwise``.
    :return: list with one instance per row, in row order.
    """
    # The label is a fixed placeholder string; every row gets the same value.
    return [
        create_instance_pointwise(
            tokenizer=tokenizer,
            max_seq_length=512,
            qid=row.qid,
            docno=row.docno,
            query=row.query,
            doc=row.body,
            label='hope i dont need this',
        )
        for row in df.itertuples()
    ]

        
class PARADEPipeline(TransformerBase):
    """PyTerrier transformer that runs a BERT re-ranking estimator over query/document rows.

    Construction builds a tokenizer, a run config, a model function (via
    ``model_fn_builder``) and a ``TPUEstimator`` with ``use_tpu=False``.
    ``transform`` converts the incoming frame to instances and calls
    ``estimator.predict`` on them.
    """

    def __init__(self, aggregation_method):
        """Set up tokenizer, run config, model function and estimator.

        :param aggregation_method: forwarded unchanged to ``model_fn_builder``;
            one of 'cls_max', 'cls_avg', 'cls_attn' or 'cls_transformer'.
        """
        self.aggregation_method = aggregation_method #'cls_max',  'cls_avg', 'cls_attn' or 'cls_transformer'
        # NOTE(review): the tokenizer receives the path string 'vocab.txt'. In the upstream
        # BERT tokenization module WordpieceTokenizer expects an already-loaded vocabulary and
        # FullTokenizer(vocab_file=...) is the class that reads the file -- confirm which one
        # this project's `bert.tokenization` expects.
        self.tokenizer = tokenization.WordpieceTokenizer(vocab='vocab.txt')
        # model_dir=None: the recorded run log shows a temporary directory under /tmp being used.
        self.run_config = tf.estimator.tpu.RunConfig(
            cluster=None,
            model_dir=None,
            save_checkpoints_steps=1000,
            keep_checkpoint_max=1,
            tpu_config=None)
            
        '''
        tpu_config is set to None since we aren't using a tpu, but in case we actually need this
        i'll just keep this commented out instead of deleting it
            
            tpu_config=tf.estimator.tpu.TPUConfig(
                iterations_per_loop=1000,
                num_shards=8,
                per_host_input_for_training=is_per_host))
        '''
        
        # Config and weights are both read from the extracted MS MARCO checkpoint directory.
        # num_train_steps / num_warmup_steps are None; this class only ever calls predict().
        self.model_fn = model_fn_builder(
            bert_config=BertConfig.from_json_file('bert_models_onMSMARCO/vanilla_bert_tiny_on_MSMARCO/bert_config.json'),
            num_labels=2,
            init_checkpoint='bert_models_onMSMARCO/vanilla_bert_tiny_on_MSMARCO/model.ckpt-1600000',
            learning_rate=5e-5,
            num_train_steps=None,
            num_warmup_steps=None,
            use_tpu=False,
            use_one_hot_embeddings=False,
            aggregation_method=self.aggregation_method,
            pretrained_model='bert',
            from_distilled_student=False)
        
        # TPUEstimator with use_tpu=False; all batch sizes are fixed at 32.
        self.estimator = tf.estimator.tpu.TPUEstimator(
            use_tpu=False,
            model_fn=self.model_fn,
            config=self.run_config,
            train_batch_size=32,
            eval_batch_size=32,
            predict_batch_size=32)
        
        
    def transform(self, queries_and_docs):
        """Run prediction over ``queries_and_docs``.

        :param queries_and_docs: frame with ``qid``, ``docno``, ``query`` and
            ``body`` columns (the attributes ``df_to_instances`` reads).
        :return: currently ``None`` -- the predictions are iterated and discarded.

        NOTE(review): a PyTerrier transformer is presumably expected to return a
        DataFrame of scored rows; the loop below still needs to collect scores
        into ``results`` and return them. The keys of each prediction item are
        not visible from this file.
        """
        # NOTE(review): max_seq_length is 128 here but df_to_instances builds instances with
        # max_seq_length=512 -- confirm which value the input parser expects.
        # NOTE(review): the argument is named dataset_df yet receives the list returned by
        # df_to_instances; verify against input_parser_pyterrier.input_fn_builder.
        eval_input_fn = input_fn_builder(
            dataset_df=df_to_instances(queries_and_docs, self.tokenizer),
            max_num_segments_perdoc=8,
            max_seq_length=128,
            is_training=False)
            
        result = self.estimator.predict(input_fn=eval_input_fn, yield_single_examples=True)
        
        # Placeholder: `results` is never filled and nothing is returned.
        results = []
        for item in result:
            pass
            

In [66]:
import pandas as pd

# Two toy documents for a single query, used to exercise the pipeline by hand.
q = "chemical reactions"
doc1 = "professor proton demonstrated the chemical reaction"
doc2 = "chemical brothers is great techno music"

# Built column-wise; the key order fixes the column order qid, query, docno, body.
df = pd.DataFrame(
    {
        "qid": ["q1", "q1"],
        "query": [q, q],
        "docno": ["doc1", "doc2"],
        "body": [doc1, doc2],
    }
)
df

Unnamed: 0,qid,query,docno,body
0,q1,chemical reactions,doc1,professor proton demonstrated the chemical rea...
1,q1,chemical reactions,doc2,chemical brothers is great techno music


In [67]:
# Instantiate the re-ranker with the 'cls_max' aggregation method.
pipeline = PARADEPipeline(aggregation_method='cls_max')

# Disabled: loading the vaswani test collection instead of the toy frame.
#vaswani  = pt.datasets.get_dataset("vaswani")
#vaswani.get_corpus()


# Apply the pipeline to the toy frame. In the recorded run this call ends with
# "ValueError: Unsupported string type: <class 'list'>".
pipeline(df)

INFO:tensorflow:Using config: {'_model_dir': '/tmp/tmp3_f0x6ey', '_tf_random_seed': None, '_save_summary_steps': 100, '_save_checkpoints_steps': 1000, '_save_checkpoints_secs': None, '_session_config': allow_soft_placement: true
graph_options {
  rewrite_options {
    meta_optimizer_iterations: ONE
  }
}
, '_keep_checkpoint_max': 1, '_keep_checkpoint_every_n_hours': 10000, '_log_step_count_steps': None, '_train_distribute': None, '_device_fn': None, '_protocol': None, '_eval_distribute': None, '_experimental_distribute': None, '_experimental_max_worker_delay_secs': None, '_session_creation_timeout_secs': 7200, '_service': None, '_cluster_spec': <tensorflow.python.training.server_lib.ClusterSpec object at 0x7f57fe112b70>, '_task_type': 'worker', '_task_id': 0, '_global_id_in_cluster': 0, '_master': '', '_evaluation_master': '', '_is_chief': True, '_num_ps_replicas': 0, '_num_worker_replicas': 1, '_tpu_config': TPUConfig(iterations_per_loop=2, num_shards=None, num_cores_per_replica=None,

ValueError: Unsupported string type: <class 'list'>

In [50]:
# Smoke-test instance creation on the toy frame. The tokenizer class is reached through
# the imported `tokenization` module; the bare name WordpieceTokenizer is never imported
# and raises NameError on a fresh kernel.
x = df_to_instances(df, tokenization.WordpieceTokenizer(vocab='vocab.txt'))