In [None]:
# Install the retrieval toolkit (version-pinned) and the JSON Lines writer.
!pip install pyserini==0.10.1.0 --quiet  
# NOTE(review): jsonlines is installed without a version pin, unlike pyserini above.
!pip install jsonlines

# Extract the case files into the directory that get_base_case/index_doc read from.
!unzip /content/task1_files.zip -d /content/Coliee_2021

In [None]:
from pyserini.search import SimpleSearcher
import jsonlines
from os import listdir
from os.path import isfile, join
import json
import spacy
from sklearn.metrics import (accuracy_score, f1_score, classification_report)
from tqdm.notebook import tqdm

# Test labels: the dict keys are iterated as the base cases to run and the
# values are read as the expected candidates of each base case.
# The file is opened in a `with` block so the handle is closed after loading.
with open('/content/task1_test_labels_2021.json', 'r') as labels_file:
    test_labels = json.load(labels_file)

# Blank English pipeline whose only component is the sentence splitter.
# NOTE(review): add_pipe(create_pipe(...)) is the spaCy v2 call form - confirm the installed version.
nlp = spacy.blank("en")
nlp.add_pipe(nlp.create_pipe("sentencizer"))

def example2seg(docs, max_length, stride):
    """
    Split a document into overlapping segments of consecutive sentences.

    Args:
      docs: Document text (string)
      max_length: maximum number of sentences in each segment
      stride: number of sentences to advance between two segment starts
    Returns:
      List of segments; each segment is the sentences of one window joined
      by a single space. The list is empty when no sentence is found.
    """
    doc = nlp(docs)

    # Span.text replaces the deprecated Span.string alias (removed in spaCy v3);
    # after strip() both give the same sentence text.
    sentences = [sent.text.strip() for sent in doc.sents]

    segments = []
    for start in range(0, len(sentences), stride):
        segments.append(' '.join(sentences[start:start + max_length]))
        # A window that already reaches the last sentence is the final segment.
        if start + max_length >= len(sentences):
            break

    return segments

def get_base_case (path):
    """
    Load one base case from the collection directory and segment it.

    Newlines and the 'FRAGMENT_SUPPRESSED' marker are removed, whitespace is
    collapsed, and the text is cut into windows of 10 sentences with stride 5.

    Args:
      path: File name of the base case inside /content/Coliee_2021

    Returns:
      The first 25 segments of the base case (list of strings).
    """
    with open('/content/Coliee_2021/{}'.format(path)) as case_file:
        raw_text = case_file.read()

    # Same removal order as before: newlines first, then the marker.
    for unwanted in ('\n', 'FRAGMENT_SUPPRESSED'):
        raw_text = raw_text.replace(unwanted, '')
    normalized_text = ' '.join(raw_text.split())

    return example2seg(normalized_text, 10, 5)[0:25]

def index_doc(path, save=False, french=False):
    """
    Read one candidate case, segment it and optionally append the segments
    to the JSONL file that is later indexed with Pyserini.

    Args:
      path: File name of the candidate inside /content/Coliee_2021 (string)
      save: When true, append every segment as one JSON line to
            /content/tmp_candidates/candidate.jsonl (bool)
      french: Accepted for compatibility; not used by this function (bool)
    Returns:
      List of dicts {"id": "<path>_segment<k>-2021", "contents": <segment>},
      one per segment, with k starting at 1.
    """
    path_to_file = '/content/Coliee_2021/{}'.format(path)
    with open(path_to_file) as f:
        contents = f.read()

    contents = contents.replace('\n', '').replace('FRAGMENT_SUPPRESSED', '')
    # Documents with at most 10 words are replaced by a fixed placeholder text.
    if len(contents.split()) <= 10:
        contents = 'This is a wrong document. It should have only one segment. Thats it.'

    candidate_case = ' '.join(contents.split())
    segments_candidate = example2seg(candidate_case, 10, 5)

    list_segments_candidate = [
        {"id": "{}_segment{}-2021".format(path, number), "contents": segment}
        for number, segment in enumerate(segments_candidate, start=1)
    ]

    # Open the output file once per document instead of once per segment.
    if save and list_segments_candidate:
        with jsonlines.open('/content/tmp_candidates/candidate.jsonl', mode='a') as writer:
            for record in list_segments_candidate:
                writer.write(record)

    return list_segments_candidate

def get_correct(top_sorted_list_dict, example):
    """
    Filter the search answers that are valid for one base case.

    An answer is kept when its 'candidate' id ends with '2021' and the part
    of the id before '.txt' differs from the base case itself.

    Args:
      top_sorted_list_dict: List of dicts containing the BM25's searcher answers
      example: base case id (text before '.txt') to be excluded
    Returns:
      New list with the kept dicts, in their original order.
    """
    return [
        answer
        for answer in top_sorted_list_dict
        if answer['candidate'].endswith('2021')
        and answer['candidate'].split('.txt')[0] != example
    ]

def fix_names(top_sorted_list_dict):
    """
    Shorten every 'candidate' id to the text before its first '.txt'.

    The dicts are modified in place.

    Args:
     top_sorted_list_dict: BM25's answers

    Returns:
      The same list object that was passed in.
    """
    for answer in top_sorted_list_dict:
        # partition keeps everything before the first '.txt' (or the whole id).
        case_id, _, _ = answer['candidate'].partition('.txt')
        answer['candidate'] = case_id

    return top_sorted_list_dict

def get_top_scores(sorted_list_dict_candidate):
    """
    Keep only the first answer found for each candidate.

    When the input is sorted by descending score, the first answer of a
    candidate is its best-scoring segment.

    Args:
      sorted_list_dict_candidate: List of dicts containing the BM25's searcher answers

    Returns:
      List with one dict per candidate, in order of first appearance.
    """
    # A set gives constant-time "already seen" checks instead of rebuilding
    # the list of kept candidates for every answer.
    seen_candidates = set()
    top_sorted = []
    for answer in sorted_list_dict_candidate:
        candidate = answer['candidate']
        if candidate not in seen_candidates:
            seen_candidates.add(candidate)
            top_sorted.append(answer)
    return top_sorted

  
def sum_scores(list_dict):
    """
    Get the sum of segments score for each candidate

    Args:
      list_dict: List of dicts containing the BM25's searcher answers

    Returns:
      List of new dicts {'candidate': <id>, 'score': <sum of its scores>},
      one per candidate, in order of first appearance. The input dicts are
      not modified.
    """
    # Single pass accumulation; dicts keep insertion order, so the output
    # order matches the first appearance of each candidate and the scores of
    # a candidate are added in list order.
    totals = {}
    for item in list_dict:
        candidate = item['candidate']
        totals[candidate] = totals.get(candidate, 0) + item['score']

    return [{'candidate': candidate, 'score': total} for candidate, total in totals.items()]

def evaluate_one_base_case(n_case):
    """
    Search every segment of one base case with the global searcher and
    aggregate the answers per candidate.

    Args:
      n_case: Base case path

    Returns:
      Tuple (top_sorted, sum_dict_list): the first (highest scoring) answer
      of each candidate, and the per-candidate score sums sorted by
      descending score.
    """
    # Each segment is truncated to 1024 characters before being used as query.
    segment_hits = []
    for segment in get_base_case(n_case):
        segment_hits.extend(searcher.search(segment[0:1024], k=10000))

    scored_answers = [{'candidate': hit.docid, 'score': hit.score} for hit in segment_hits]
    scored_answers.sort(key=lambda answer: answer['score'], reverse=True)

    # Drop the base case itself, then shorten the ids for evaluation.
    base_case_id = n_case.split('.txt')[0]
    valid_answers = fix_names(get_correct(scored_answers, base_case_id))

    top_sorted = get_top_scores(valid_answers)
    sum_dict_list = sorted(
        sum_scores(valid_answers),
        key=lambda answer: answer['score'],
        reverse=True,
    )

    return top_sorted, sum_dict_list

def run_bm25 (list_paths, save = False):
    """
    Evaluate every base case with BM25, showing a progress bar.

    Args:
      list_paths: iterable of base case paths to run bm25
      save: accepted but not used by this function

    Returns:
      Tuple (list_score_top, list_score_sum) with one entry per base case.
      Each entry is a list of dicts {candidate: <candidate>, score: <score>}.
    """
    results_per_case = [evaluate_one_base_case(base_case) for base_case in tqdm(list_paths)]

    list_score_top = [top_answers for top_answers, _ in results_per_case]
    list_score_sum = [summed_answers for _, summed_answers in results_per_case]

    return list_score_top, list_score_sum


def get_top_n(list_score_all, n):
    """
    Keep the n highest scoring answers of every base case.

    Args:
      list_score_all: BM25's answers (one list of dicts per base case)
      n: number of top candidates to be chosen

    Returns:
      list with, for each base case, its answers sorted by descending score
      and cut to the first n
    """
    def score_of(answer):
        return answer['score']

    return [
        sorted(answers, key=score_of, reverse=True)[:n]
        for answers in list_score_all
    ]

def my_classification_report(list_label_ohe, list_answer_ohe):
    """
    Calculate F1, Precision and Recall over all base cases together.

    Counts are accumulated position by position across every pair of
    encodings before the metrics are computed.

    Args:
      list_label_ohe: list of one hot encodings of the labels
      list_answer_ohe: list of one hot encodings of the answers

    Returns:
      F1, Precision, Recall. A metric whose denominator is zero (no answers,
      no labels, or no true positives) is reported as 0.0 instead of raising
      ZeroDivisionError.
    """
    true_positive = 0
    false_positive = 0
    false_negative = 0

    for labels, answers in zip(list_label_ohe, list_answer_ohe):
        for label, answer in zip(labels, answers):
            if answer == 1 and label == 1:
                true_positive += 1
            elif answer == 0 and label == 1:
                false_negative += 1
            elif answer == 1 and label == 0:
                false_positive += 1

    predicted_positive = true_positive + false_positive
    actual_positive = true_positive + false_negative

    precision = true_positive / predicted_positive if predicted_positive else 0.0
    recall = true_positive / actual_positive if actual_positive else 0.0

    if precision + recall == 0:
        f1 = 0.0
    else:
        f1 = 2*((precision*recall)/(precision + recall))

    return f1, precision, recall

def id_to_ohe(lista, max_id=99999):
    """
    Convert a list of ids into a fixed-length 0/1 encoding.

    Position k (0-based) of the result is 1 when id k + 1 is in the list.
    Ids outside 1..max_id are ignored.

    Args:
      lista: list of integers
      max_id: largest id represented; the result has max_id positions

    Returns:
      List of 0/1 values of length max_id
    """
    # Set membership avoids scanning the whole input list for every id.
    present_ids = set(lista)
    return [1 if case_id in present_ids else 0 for case_id in range(1, max_id + 1)]

def evaluate(list_answers, label):
    """
    Encode the answers and the labels of one base case as 0/1 lists.

    Every name is reduced to the text before '.txt' and parsed as an integer
    id before being encoded.

    Args:
      list_answers: list of answers
      label: list of labels


    Returns:
      Tuple (answer encoding, label encoding)
    """
    def to_ids(names):
        return [int(name.split('.txt')[0]) for name in names]

    answer_ohe = id_to_ohe(to_ids(list_answers))
    label_ohe = id_to_ohe(to_ids(label))

    return answer_ohe, label_ohe

def get_list_answers(list_score):
    """
    Extract the candidate ids from a list of answer dicts.

    Args:
      list_score: list of dicts that each have a 'candidate' key

    Returns:
      List with the 'candidate' value of every dict, in the same order
    """
    answers = []
    for entry in list_score:
        answers.append(entry['candidate'])
    return answers

def get_max_percentage(list_score_all, percent):
    """
    Select answers above a threshold relative to the best score of each
    base case. If score_candidate > max(score_candidate)*percent -> select candidate

    Args:
      list_score_all: BM25's answers (one list of dicts per base case)
      percent: Threshold, as a fraction of the best score

    Returns:
      List with the selected dicts of each base case; a base case without
      answers yields an empty list.
    """
    list_dict_percent = []
    for sample in list_score_all:
        # Explicit empty check instead of a bare except around max(), which
        # also swallowed unrelated errors and left the maximum undefined.
        if len(sample) == 0:
            list_dict_percent.append([])
            continue

        cutoff = max(sample, key=lambda answer: answer['score'])['score'] * percent
        list_dict_percent.append([answer for answer in sample if answer['score'] > cutoff])

    return list_dict_percent

def apply_threshold(list_score_all, threshold):
    """
    Select answers above an absolute threshold. If score_candidate > threshold -> select candidate

    Args:
      list_score_all: BM25's answers (one list of dicts per base case)
      threshold: Threshold

    Returns:
      List with the selected dicts of each base case
    """
    return [
        [answer for answer in answers if answer['score'] > threshold]
        for answers in list_score_all
    ]

def grid(threshold, top_n, percentss, list_dict):
    """
    Apply the three selection heuristics with one parameter combination and
    score the selected answers against the global test_labels.

    The heuristics are applied in this order: relative threshold
    (percentss), absolute threshold, then top N.

    Args:
      threshold: Select score above this threshold
      top_n: Select top N scores
      percentss: Select score above max(candidate_score)*percentss
      list_dict: BM25's answers, one list per base case, in the same order
                 as the values of test_labels

    Returns:
      F1, Precision, Recall
    """
    list_dict_percent = get_max_percentage(list_dict, percentss)
    list_dict_threshold = apply_threshold(list_dict_percent, threshold)
    new_list_score_all = get_top_n(list_dict_threshold, top_n)

    # Built once; previously this list was rebuilt for every base case.
    labels_per_case = list(test_labels.values())

    list_answer_ohe = []
    list_label_ohe = []
    for nn, selected in enumerate(new_list_score_all):
        answer_ohe, label_ohe = evaluate(get_list_answers(selected), labels_per_case[nn])
        list_answer_ohe.append(answer_ohe)
        list_label_ohe.append(label_ohe)

    # The sklearn micro-F1 that used to be computed here was never read, so
    # only the metrics that are actually returned are calculated.
    f1, precision, recall = my_classification_report(list_label_ohe, list_answer_ohe)

    return f1, precision, recall

## Main code


In [None]:
def index_all():
    '''
    Create documents in json format to be indexed before applying BM25.

    Removes the previous index and temporary directories, recreates the
    temporary directory, and calls index_doc with save=True for every
    regular file found directly inside /content/Coliee_2021.
    '''
    # Start from a clean state: index_doc appends to the JSONL file.
    !rm -r /content/candidates
    !rm -r /content/tmp_candidates
    !mkdir /content/tmp_candidates
    mypath = '/content/Coliee_2021'
    onlyfiles = [f for f in listdir(mypath) if isfile(join(mypath, f))] # file names only, sub-directories are skipped
    # Counter of processed files; it is not read anywhere in this function.
    i = 0
    for path_candidate in tqdm(onlyfiles):
        index_doc(path_candidate, True)
    
        i+=1
    
# Write every file of the collection as JSONL segments into /content/tmp_candidates.
index_all()

# Index the JSONL directory with Pyserini; the index is written to /content/candidates/indexes.
!python -m pyserini.index -collection JsonCollection -generator DefaultLuceneDocumentGenerator \
-threads 1 -input /content/tmp_candidates \
-index /content/candidates/indexes -storePositions -storeDocvectors -storeRaw

# Searcher over the index built above; evaluate_one_base_case reads it as a global.
searcher = SimpleSearcher('/content/candidates/indexes')

### Run BM25

In [None]:
# Iterating the test_labels dict yields its keys, which are used as the base case paths.
# The second argument (save) is accepted by run_bm25 but not used in its body.
list_score_top, list_score_sum = run_bm25(test_labels, True) 

### Evaluate

In [None]:
# Parameter values tried for each heuristic.
list_threshold = [10, 20, 40, 60, 80, 90, 95]  # absolute score threshold
list_top = [5, 10, 15, 20, 200]  # number of top answers kept
list_percent = [0, 0.2, 0.4, 0.6, 0.8, 0.9]  # fraction of the best score
list_tuples = []

# Exhaustive search over all combinations. Only the best-segment scores
# (list_score_top) are evaluated here; list_score_sum is not used in this cell.
for thresholds in list_threshold:
    for tops in list_top:
        for percents in list_percent:
            f1, precision, recall = grid(thresholds, tops, percents, list_score_top)           
            tuples = (thresholds, tops, percents, f1, precision, recall)
            list_tuples.append(tuples)
         
# Rank the combinations by F1 (tuple position 3) and display the five best.
sorted_list = sorted(list_tuples, key=lambda x: x[3], reverse=True)
print('Threshold, Top_N, Percent, F1, Precision, Recall')
sorted_list[0:5]