In [None]:
!pip install pygaggle --quiet
!git clone --recursive https://github.com/castorini/pygaggle.git --quiet
!pip install -r /content/pygaggle/requirements.txt --quiet
!pip install jsonlines --quiet
!pip install spacy==2.3.0 --quiet

!unzip "/content/task2_2021_test_nolabels.zip" -d /content/task2_test # path to test set
label = json.load(open('/content/task2_test_labels_2021.json', 'r')) # path to labels


In [2]:
import torch

# Pick the compute device once. `device` has to exist on both branches because
# the model-loading cell calls `.to(device)`; previously it was only defined
# when CUDA was available.
dev = "cuda:0" if torch.cuda.is_available() else "cpu"
if dev == "cpu":
    print(dev)
else:
    print(dev, torch.cuda.get_device_name(0))
device = torch.device(dev)

from pyserini.search import SimpleSearcher
from pygaggle.rerank.base import hits_to_texts
from pygaggle.rerank.base import Query, Text
from pygaggle.rerank.transformer import MonoT5
from transformers import T5ForConditionalGeneration
import json
from pyserini.search import SimpleSearcher
import spacy
from sklearn.metrics import (accuracy_score, f1_score, classification_report)
import jsonlines
import os
from statistics import mean
import statistics
import math
from tqdm.notebook import tqdm
import numpy as np


def int2str(i):
    """
    Zero-pad a case or candidate number to the three-character form used in
    the dataset paths.

    Args:
      i: A non-negative integer (the dataset uses numbers between 1 and 650)

    Returns:
      The number as a string, left-padded with zeros to at least three
      characters (7 -> '007', 42 -> '042', 426 -> '426').
    """
    return str(i).zfill(3)


def get_base_case(path_to_file):
    """
    Read one base case and normalize its whitespace.

    Args:
      path_to_file: Path to file

    Returns:
      The file content as one string: newlines become spaces, every
      'FRAGMENT_SUPPRESSED' marker is removed and runs of whitespace are
      collapsed to a single space. No sentence segmentation happens here.
    """
    # The context manager closes the file; no explicit close() is needed.
    with open(path_to_file) as f:
        contents = f.read()

    contents = contents.replace('\n', ' ').replace('FRAGMENT_SUPPRESSED', '')
    return ' '.join(contents.split())


def get_candidate(path_to_file, i, base_case_number, save):
    """
    Read one candidate paragraph, normalize it and optionally append it to the
    JSONL collection that Pyserini indexes.

    Args:
      path_to_file: Path to candidate file (string)
      i: Candidate case number, used to build the document id
      base_case_number: Base case number, used to build the document id
      save: Append the candidate to /content/tmp_candidates/candidate.jsonl (bool)

    Returns:
      The candidate text as one string with 'FRAGMENT_SUPPRESSED' removed and
      whitespace collapsed. Candidates with 10 words or fewer are replaced by
      a fixed placeholder sentence.
    """
    # The context manager closes the file; no explicit close() is needed.
    with open(path_to_file) as f:
        contents = f.read()

    # NOTE(review): newlines are dropped without inserting a space, so the last
    # word of a line is glued to the first word of the next one, whereas
    # get_base_case() substitutes a space. Kept unchanged because the indexed
    # text depends on it; confirm whether this is intended.
    contents = contents.replace('\n', '').replace('FRAGMENT_SUPPRESSED', '')
    if len(contents.split()) <= 10:
        # Too short to be a real paragraph: use the placeholder text instead.
        contents = 'This is a wrong document. It should have only one segment. Thats it.'

    candidate_case = ' '.join(contents.split())

    if save:
        record = {
            "id": "{}_candidate{}.txt_task2".format(str(base_case_number), str(i)),
            "contents": candidate_case,
        }
        # Append mode: every call adds one line to the collection file.
        with jsonlines.open('/content/tmp_candidates/candidate.jsonl', mode='a') as writer:
            writer.write(record)

    return candidate_case


def get_one_candidate(base_case_number, k, save):  # get candidates
    """
    Locate the paragraph file of one candidate and load it via get_candidate.

    Args:
      base_case_number: Base case number (cases up to 425 are read from the
        train folder, later ones from the test folder)
      k: Candidate case number
      save: Save candidate case as json
    Returns:
      The normalized text of the candidate, as returned by get_candidate.
    """
    if int(base_case_number) > 425:
        template = '/content/task2_test/{}/paragraphs/{}.txt'
    else:
        template = '/content/task2_train/{}/paragraphs/{}.txt'

    return get_candidate(template.format(base_case_number, k), k, base_case_number, save)


def main(n_case, n_candidate, save = False):
    """
    Load the entailed fragment of a base case together with one of its
    candidate paragraphs.

    Args:
      n_case: Base case number (cases up to 425 are read from the train
        folder, later ones from the test folder)
      n_candidate: Candidate case number
      save: Save candidate case as json
    Returns:
      A tuple (base case text, candidate text), both normalized strings.
    """
    case_id = int2str(n_case)
    candidate_id = int2str(n_candidate)

    if int(case_id) > 425:
        fragment_path = '/content/task2_test/{}/entailed_fragment.txt'.format(case_id)
    else:
        fragment_path = '/content/task2_train/{}/entailed_fragment.txt'.format(case_id)

    # The base case is read first, then the candidate (which may write to disk).
    return get_base_case(fragment_path), get_one_candidate(case_id, candidate_id, save)


def get_top_scores(sorted_list_dict_candidate):
    """
    Keep only the first entry of each candidate.

    Args:
      sorted_list_dict_candidate: List of dicts with a 'candidate' key,
        expected to be sorted by descending score so that the first entry of
        a candidate is its best one. Candidate ids must be hashable.

    Returns:
      The input dicts with later duplicates of a candidate removed; the
      original order is preserved.
    """
    seen = set()
    top_sorted = []
    for entry in sorted_list_dict_candidate:
        candidate = entry['candidate']
        # A set keeps the duplicate check O(1) instead of rebuilding a list
        # of already-kept candidates for every entry.
        if candidate not in seen:
            seen.add(candidate)
            top_sorted.append(entry)
    return top_sorted


def get_correct(n_case, top_sorted_list_dict):
    """
    Keep only the entries whose candidate id belongs to a given base case.

    Args:
      n_case: Base case number; compared as str(n_case) with the part of the
        candidate id before the first underscore
      top_sorted_list_dict: List of dicts with a 'candidate' key

    Returns:
      The entries of the given base case, in their original order.
    """
    wanted_prefix = str(n_case)
    return [
        entry
        for entry in top_sorted_list_dict
        if entry['candidate'].split('_')[0] == wanted_prefix
    ]

def evaluate_one_base_case(n_case):
    """
    Retrieve the candidates of one base case and rerank them with monoT5.

    The base case is split into one-sentence segments; each segment (cut to
    its first 1024 characters) is used as a query against the global
    `searcher`. Hits belonging to this base case are then reranked by the
    global `reranker`, using the whole base case as the query.

    Args:
      n_case: Base case number

    Returns:
      List of dicts {'candidate': <docid>, 'score': <score>} sorted by
      descending score, with one entry per candidate.
    """
    base_case, _ = main(n_case, 1, save=False)
    segments_base = example2seg(base_case, 1, 1)

    # The retrieval depth depends only on the number of segments, so it is
    # chosen once instead of on every loop iteration.
    n_segments = len(segments_base)
    if n_segments > 8:
        depth = 150000
    elif n_segments > 2:
        depth = 500000
    else:
        depth = 1000000

    list_hits = []
    for segment in segments_base:
        list_hits.extend(searcher.search(segment[0:1024], k=depth))

    # NOTE(review): a candidate hit by several segments appears several times
    # here and is scored once per occurrence; get_top_scores() removes the
    # duplicates afterwards.
    output = prepare_data_t5(list_hits, int2str(n_case))
    texts = hits_to_texts(output)
    reranked = reranker.rerank(Query(base_case), texts)
    reranked.sort(key=lambda x: x.score, reverse=True)

    # exp() is monotonic, so the list built from the sorted results is already
    # in descending 'score' order and needs no second sort.
    # Presumably the reranker score is a log-probability, which would make the
    # value a percentage; confirm against the reranker.
    list_t5 = [
        {'candidate': text.metadata["docid"], 'score': math.exp(text.score) * 100}
        for text in reranked
    ]

    return get_top_scores(list_t5)

def prepare_data_t5(list_hits, example):
    """
    Select the hits that belong to one base case and shorten their ids.

    A hit is kept when its docid ends with 'task2' and the part before
    '_candidate' equals `example`. The docid of every kept hit is rewritten in
    place to the text between '_candidate' and '_task2'.

    Args:
      list_hits: Search hits; each object must expose a writable `docid`
      example: Base case id as it appears at the start of the docid

    Returns:
      List of the kept hit objects, in their original order.
    """
    kept = []
    for hit in list_hits:
        if not hit.docid.endswith('task2'):
            continue
        if hit.docid.split('_candidate')[0] != example:
            continue
        # Strip the '_task2' suffix first, then the '<case>_candidate' prefix.
        hit.docid = hit.docid.split('_task2')[0]
        hit.docid = hit.docid.split('_candidate')[1]
        kept.append(hit)

    return kept


def run_bm25(n_base):
    """
    Evaluate consecutive test base cases, starting at case 426.

    Args:
      n_base: Number of base cases to run. The notebook passes 100, which
        covers cases 426 to 525 (the previously hard-coded range).

    Returns:
      One list per base case, each holding dicts
      {candidate: <candidate>, score: <score>} as returned by
      evaluate_one_base_case.
    """
    first_case = 426  # first base case of the test split
    list_score_t5 = []
    for casos_base in tqdm(np.arange(first_case, first_case + n_base)):
        list_score_t5.append(evaluate_one_base_case(casos_base))

    return list_score_t5


# Blank English pipeline with only a rule-based sentence splitter.
nlp = spacy.blank("en")
nlp.add_pipe(nlp.create_pipe("sentencizer"))
def example2seg(docs, max_length, stride):
    """
    Split a document into segments of consecutive sentences.

    Args:
      docs: Document text
      max_length: number of sentences in each segment
      stride: number of sentences between the starts of two segments
    Returns:
      List of segments, each made of up to max_length sentences joined by a
      single space. Empty when the document has no sentences.
    """
    doc = nlp(docs)
    # Span.text is used instead of the deprecated Span.string; after strip()
    # both give the same sentence text.
    sentences = [sent.text.strip() for sent in doc.sents]
    segments = []
    for i in range(0, len(sentences), stride):
        segments.append(' '.join(sentences[i:i + max_length]))
        # Stop once a window has reached the last sentence, so that no
        # shorter trailing windows are produced.
        if i + max_length >= len(sentences):
            break

    return segments


def get_top_n(list_score_all, n):
    """
    Keep the n best-scored entries of every base case.

    Args:
      list_score_all: One list of {'candidate', 'score'} dicts per base case
      n: number of top candidates to be chosen

    Returns:
      One list per base case with at most n dicts, sorted by descending score
    """
    def by_score(entry):
        return entry['score']

    return [
        sorted(case_scores, key=by_score, reverse=True)[:n]
        for case_scores in list_score_all
    ]


def my_classification_report(list_label_ohe, list_answer_ohe):
    """
    Calculate F1, Precision and Recall over all base cases pooled together.

    Args:
      list_label_ohe: list of one hot encodings of the labels
      list_answer_ohe: list of one hot encodings of the answers

    Returns:
      F1, Precision, Recall. A metric whose denominator is zero (no predicted
      positives, no labelled positives, or precision + recall == 0) is
      reported as 0.0 instead of raising ZeroDivisionError.
    """
    true_positive = 0
    false_positive = 0
    false_negative = 0

    for list_label, list_ohe in zip(list_label_ohe, list_answer_ohe):
        for l, o in zip(list_label, list_ohe):
            if o == 1 and l == 1:
                true_positive += 1
            elif o == 0 and l == 1:
                false_negative += 1
            elif o == 1 and l == 0:
                false_positive += 1

    predicted_positive = true_positive + false_positive
    actual_positive = true_positive + false_negative

    precision = true_positive / predicted_positive if predicted_positive else 0.0
    recall = true_positive / actual_positive if actual_positive else 0.0
    if precision + recall:
        f1 = 2*((precision*recall)/(precision + recall))
    else:
        f1 = 0.0

    return f1, precision, recall


def id_to_ohe(lista, n_ids=242):
    """
    Convert a list of ids into a multi-hot vector over the ids 1..n_ids.

    Args:
      lista: list of integers
      n_ids: length of the vector; position p (0-based) stands for id p + 1.
        Defaults to 242, the size that was previously hard-coded.

    Returns:
      List of n_ids values: 1 where the id is in lista, 0 otherwise. Ids
      outside 1..n_ids are ignored.
    """
    wanted = set(lista)  # constant-time membership test
    return [1 if candidate_id in wanted else 0 for candidate_id in range(1, n_ids + 1)]


def evaluate(list_answers, label):
    """
    Encode predicted and gold paragraph names as one hot vectors.

    Each name is turned into an id by taking the integer before '.txt'.

    Args:
      list_answers: list of answers (names such as '001.txt')
      label: list of labels (names such as '001.txt')

    Returns:
      Tuple (answer one hot encoding, label one hot encoding)
    """
    def to_id(name):
        return int(name.split('.txt')[0])

    answer_ohe = id_to_ohe(list(map(to_id, list_answers)))
    label_ohe = id_to_ohe(list(map(to_id, label)))

    return answer_ohe, label_ohe


def get_list_answers(list_score):
    """
    Extract the candidate ids from a list of score dicts.

    Args:
      list_score: list of dicts with a 'candidate' key

    Returns:
      List with the 'candidate' value of every dict, in the same order
    """
    return list(map(lambda entry: entry['candidate'], list_score))


def get_max_percentage(list_score_all, percent):
    """
    Select answers above a threshold relative to the best score of each base
    case. If score_candidate > max(score_candidate)*percent -> select candidate

    Args:
      list_score_all: One list of {'candidate', 'score'} dicts per base case
      percent: Threshold, as a fraction of the best score

    Returns:
      List of selected dicts, one list per base case. A base case without any
      candidate yields an empty list.
    """
    list_dict_percent = []
    for sample in list_score_all:
        if not sample:
            # Nothing to select from. The previous fallback (maximum = 1)
            # crashed on the very next line when indexing maximum['score'].
            list_dict_percent.append([])
            continue

        cutoff = max(entry['score'] for entry in sample) * percent
        list_dict_percent.append([entry for entry in sample if entry['score'] > cutoff])

    return list_dict_percent


def apply_threshold(list_score_all, threshold):
    """
    Select answers above an absolute threshold.
    If score_candidate > threshold -> select candidate

    Args:
      list_score_all: One list of {'candidate', 'score'} dicts per base case
      threshold: Threshold (strict lower bound on the score)

    Returns:
      List of selected dicts, one list per base case
    """
    def above(entry):
        return entry['score'] > threshold

    return [list(filter(above, case_scores)) for case_scores in list_score_all]


def grid(threshold, top_n, percentss, list_dict):
    """
    Apply the three selection heuristics and score the result.

    The filters are applied in this order: relative cut-off (percentss),
    absolute threshold, then top N. Gold labels are read from the global
    `label`, indexed by the three-character base case id; position 0 of
    list_dict is taken to be base case 426.

    Args:
      threshold: Select score above this threshold
      top_n: Select top N scores
      percentss: Select score above max(candidate_score)*percentss
      list_dict: One list of {'candidate', 'score'} dicts per base case

    Returns:
      F1, Precision, Recall
    """
    list_dict_percent = get_max_percentage(list_dict, percentss)
    list_dict_threshold = apply_threshold(list_dict_percent, threshold)
    new_list_score_all = get_top_n(list_dict_threshold, top_n)

    list_answer_ohe = []
    list_label_ohe = []
    for offset, case_scores in enumerate(new_list_score_all):
        answers = get_list_answers(case_scores)
        answer_ohe, label_ohe = evaluate(answers, label[int2str(offset + 426)])
        list_answer_ohe.append(answer_ohe)
        list_label_ohe.append(label_ohe)

    # The sklearn micro-F1 that used to be computed here was never used and
    # has been removed.
    return my_classification_report(list_label_ohe, list_answer_ohe)

cuda:0 Tesla K80


2021-09-22 18:27:43 [INFO] loader: Loading faiss with AVX2 support.
2021-09-22 18:27:43 [INFO] loader: Could not load library with AVX2 support due to:
ModuleNotFoundError("No module named 'faiss.swigfaiss_avx2'")
2021-09-22 18:27:43 [INFO] loader: Loading faiss.
2021-09-22 18:27:43 [INFO] loader: Successfully loaded faiss.


## Main code

In [3]:
## Index data to be further retrieved
!mkdir tmp_candidates

list_hits = []
list_dict_candidate = []
for casos in range(426,526): 
    paragraphs = os.listdir('/content/task2_test/{}/paragraphs'.format(int2str(casos)))
    for candidatos in range(1, len(paragraphs)+1): 
        base_case, list_candidates = main(casos, candidatos, save=True)

!python -m pyserini.index -collection JsonCollection -generator DefaultLuceneDocumentGenerator \
-threads 1 -input /content/tmp_candidates \
-index /content/candidates/indexes -storePositions -storeDocvectors -storeRaw

searcher = SimpleSearcher('/content/candidates/indexes')

2021-09-22 18:27:51,924 INFO  [main] index.IndexCollection (IndexCollection.java:631) - Setting log level to INFO
2021-09-22 18:27:51,927 INFO  [main] index.IndexCollection (IndexCollection.java:634) - Starting indexer...
2021-09-22 18:27:51,928 INFO  [main] index.IndexCollection (IndexCollection.java:636) - DocumentCollection path: /content/tmp_candidates
2021-09-22 18:27:51,929 INFO  [main] index.IndexCollection (IndexCollection.java:637) - CollectionClass: JsonCollection
2021-09-22 18:27:51,929 INFO  [main] index.IndexCollection (IndexCollection.java:638) - Generator: DefaultLuceneDocumentGenerator
2021-09-22 18:27:51,930 INFO  [main] index.IndexCollection (IndexCollection.java:639) - Threads: 1
2021-09-22 18:27:51,930 INFO  [main] index.IndexCollection (IndexCollection.java:640) - Stemmer: porter
2021-09-22 18:27:51,931 INFO  [main] index.IndexCollection (IndexCollection.java:641) - Keep stopwords? false
2021-09-22 18:27:51,931 INFO  [main] index.IndexCollection (IndexCollection.ja

In [4]:
## Model: monoT5 reranker, moved to the selected device and set to eval mode
model = T5ForConditionalGeneration.from_pretrained('castorini/monot5-large-msmarco-10k').to(device).eval()
reranker =  MonoT5(model)

## Run retrieval and monoT5 reranking for 100 base cases
# `searcher` and `reranker` are read as globals by evaluate_one_base_case().
# NOTE(review): the searcher is opened again here, presumably so that this cell
# can run on an existing index without re-running the indexing cell.
searcher = SimpleSearcher('/content/candidates/indexes')  
list_score_t5 = run_bm25(100)

# Arguments: absolute threshold 0, top 3 candidates, relative cut-off 0.995.
f1, precision, recall = grid(0, 3, 0.995, list_score_t5)

print('F1:', f1)
print('Precision:', precision)
print('Recall:', recall)

Downloading:   0%|          | 0.00/1.36k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.95G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.39M [00:00<?, ?B/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."


F1: 0.6956521739130435
Precision: 0.7079646017699115
Recall: 0.6837606837606838
