In [1]:
import json
import random
from transformers import AutoTokenizer, AdamW, get_linear_schedule_with_warmup, BertModel, AutoModel, LongformerModel
import numpy as np
from allennlp.common.util import import_module_and_submodules as import_submodules
from allennlp.models.archival import load_archive
from allennlp.predictors import Predictor
from scipy.spatial import distance
from nltk.tokenize import sent_tokenize
import nltk
import torch
import tqdm
import re
nltk.download('punkt')

import sys
import os
sys.path.append(os.path.abspath('..'))
from scipy.special import softmax

import_submodules("allennlp_lib")

DATASET="ecthr"
MODEL_NAME="nlpaueb/legal-bert-base-uncased"
model = AutoModel.from_pretrained(MODEL_NAME, return_dict=True)
model_path = "/home/irs38/Negative-Precedent-in-Legal-Outcome-Prediction/results/Outcome/joint_model/legal_bert/facts/7f8014c5df0f432eb3f6c551ecee9ed1/model.pt"
model = torch.load(model_path)
model_state_dict = model.state_dict()

#archive = load_archive(model_path + '/model.tar.gz')
#print(archive.config)
#archive.config['dataset_reader']['type'] = 'ecthr'
#archive.config['model']['output_hidden_states'] = True
#model = archive.model
#model._output_hidden_states = True
#predictor = Predictor.from_archive(archive, 'ecthr')


def text_preprocessing(text):
    """
    - Remove entity mentions (eg. '@united')
    - Correct errors (eg. '&amp;' to '&')
    @param    text (str): a string to be processed.
    @return   text (Str): the processed string.
    """
    # Remove '@name'
    text = re.sub(r'(@.*?)[\s]', ' ', text)

    # Replace '&amp;' with '&'
    text = re.sub(r'&amp;', '&', text)

    # Remove trailing whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text

def preprocessing_for_bert(data, tokenizer, max=512):
    """Perform required preprocessing steps for pretrained BERT.
    @param    data (np.array): Array of texts to be processed.
    @return   input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
    @return   attention_masks (torch.Tensor): Tensor of indices specifying which
                  tokens should be attended to by the model.
    """

    # For every sentence...
    input_ids = []
    attention_masks = []

    for sent in tqdm(data):
        sent = " ".join(sent)
        sent = sent[:500000] # Speeds the process up for documents with a lot of precedent we would truncate anyway.
        # `encode_plus` will:
        #    (1) Tokenize the sentence
        #    (2) Add the `[CLS]` and `[SEP]` token to the start and end
        #    (3) Truncate/Pad sentence to max length
        #    (4) Map tokens to their IDs
        #    (5) Create attention mask
        #    (6) Return a dictionary of outputs
        encoded_sent = tokenizer.encode_plus(
            text=text_preprocessing(sent),  # Preprocess sentence
            add_special_tokens=True,  # Add `[CLS]` and `[SEP]`
            max_length=max,  # Max length to truncate/pad
            pad_to_max_length=True,  # Pad sentence to max length
            # return_tensors='pt',           # Return PyTorch tensor
            return_attention_mask=True,  # Return attention mask
            truncation=True,
        )

        # Add the outputs to the lists
        input_ids.append([encoded_sent.get('input_ids')])
        attention_masks.append([encoded_sent.get('attention_mask')])

    # Convert lists to tensors
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

with open(model_path + "/label2index.json", "r") as f:
    label2index = json.load(f)
    index2label = {label2index[k]: k for k in label2index}
label2index


  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to /home/irs38/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


FileNotFoundError: file ../experiments/models/ecthr/allenai/longformer-base-4096/model.tar.gz not found

In [7]:
def all_masks(tokenized_text):
    # https://stackoverflow.com/questions/1482308/how-to-get-all-subsets-of-a-set-powerset
    # WITHOUT empty and full sets!
    s = list(range(len(tokenized_text)))
    x = len(s)
    masks = [1 << i for i in range(x)]
    #     for i in range(1 << x):  # empty and full sets included here
    for i in range(1, 1 << x - 1):
        yield [ss for mask, ss in zip(masks, s) if i & mask]
        
def all_consecutive_masks(tokenized_text, max_length = -1):
    # WITHOUT empty and full sets!
    s = list(range(len(tokenized_text)))
    x = len(s)
    for i in range(x):
        for j in range(i+1, x):
            mask = s[:i] + s[j:]
            if max_length > 0:
                if j - i >= max_length:
                    yield mask
            else:
                yield mask
                
def all_consecutive_masks2(tokenized_text, max_length = -1):
    # WITHOUT empty and full sets!
    s = list(range(len(tokenized_text)))
    x = len(s)
    for i in range(x+1):
        for j in range(i+1, x+1):
            mask = s[i:j]
            if max_length > 0:
                if j - i <= max_length:
                    yield mask
            else:
                yield mask

def precisionAtK(actual, predicted, k):
    act_set = set(actual)
    pred_set = set(predicted[:k])
    result = len(act_set & pred_set) / float(k)
    return result

def recallAtK(actual, predicted, k):
    act_set = set(actual)
    pred_set = set(predicted[:k])
    result = len(act_set & pred_set) / float(len(act_set))
    return result

def meanPrecisionAtK(actual, predicted, k):
    return np.mean([precisionAtK(a, p, k) for a, p in zip(actual, predicted)])

def meanRecallAtK(actual, predicted, k):
    return np.mean([recallAtK(a, p, k) for a, p in zip(actual, predicted)])


In [4]:
#read in the validation data, which is a json dict in each new line
with open("/home/irs38/contrastive-explanations/data/ecthr/Chalkidis/simple_val.jsonl", "r") as f:
    val_data = [json.loads(line) for line in f.readlines()]
with open("/home/irs38/contrastive-explanations/data/ecthr/Chalkidis/dev.jsonl", "r") as f:
    val_meta_data = [json.loads(line) for line in f.readlines()]

for item in val_data: 
    item["facts_sentences"] = [i for i in val_meta_data if i["case_no"] == item["case_no"]][0]["facts"]

articles = ['10', '11', '13', '14', '18', '2', '3', '4', '5', '6', '7', '8', '9', 'P1-1', 'P4-2', 'P7-1', 'P7-4']

#ex = {"facts": "5.  The applicant was born in 1983 and is detained in Sztum. 6.  At the time of the events in question, the applicant was serving a prison sentence in the Barczewo prison. 7.  On 8 January 2011 the applicant\u2019s grandmother died. On 10 January 2011 the applicant lodged a request with the Director of Prison and the Penitentiary judge for leave to attend her funeral which was to take place on 12 January 2011. Together with his application he submitted a statement from his sister E.K. who confirmed that she would personally collect the applicant from prison and bring him back after the funeral. 8.  On 11 January 2011 the Penitentiary judge of the Olsztyn Regional Court (S\u0119dzia Penitencjarny S\u0105du Okr\u0119gowego w Olsztynie) allowed the applicant to attend the funeral under prison officers\u2019 escort. The reasoning of the decision read as follows:\n\u201cIn view of [the applicant\u2019s] multiple convictions and his long term of imprisonment there is no guarantee that he will return to prison\u201d 9.  The applicant refused to attend the funeral, since he believed his appearance under escort of uniformed officers would create a disturbance during the ceremony. 10.  On the same day the applicant lodged an appeal with the Olsztyn Regional Court (S\u0105d Okr\u0119gowy) complaining that the compassionate leave was granted under escort and also that he was only allowed to participate in the funeral (not the preceding church service). 11.  On 3 February 2011 the Olsztyn Regional Court upheld the Penitentiary judge\u2019s decision and dismissed the appeal. The court stressed that the applicant had been allowed to participate in the funeral under prison officers\u2019 escort. It further noted that the applicant was a habitual offender sentenced to a long term of imprisonment therefore there was no positive criminological prognosis and no guarantee that he would have returned to prison after the ceremony.", "claims": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], "outcomes": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "case_no": "20488/11"}
#ex = {"facts": "4.  The applicant was born in 1960 and lives in Oleksandrivka, Kirovograd Region. 5.  On 3 February 2007 the applicant was assaulted. According to the subsequent findings of medical experts, she sustained haematomas on her jaw, shoulder and hip, a bruise under her right eye, concussion, and a displaced rib fracture. The applicant alleges that her assailants were Mr and Mrs K., her daughter\u2019s former parents-in-law, whereas the domestic authorities found that it was only Mrs K. who had assaulted the applicant. The incident occurred in front of the applicant\u2019s two-year-old granddaughter. 6.  On 4 February 2007 the applicant lodged a complaint with the police. 7.  On 5 February 2007 a forensic medical expert examined the applicant. He found that she had haematomas which he classified as \u201cminor bodily injuries\u201d. 8.  On 14 February 2007 the Oleksandrivka District Police Department (\u201cthe Oleksandrivka police\u201d) refused to institute criminal proceedings in connection with the incident. 9.  On 22 February 2007 a forensic medical examination of the applicant was carried out. The expert found that in addition to the previously noted haematomas, the applicant had also suffered concussion and a displaced rib fracture. The expert classified the injuries as \u201cbodily harm of medium severity\u201d. 10.  On 20 March 2007 the Oleksandrivka prosecutor overruled the decision of 14 February 2007 as premature and on 21 March 2007 instituted criminal proceedings in connection with the infliction of bodily harm of medium severity on the applicant. 11.  On 20 May 2007 the investigator suspended the investigation for failure to identify the perpetrator. 12.  On 29 August and 3 October 2007 the Oleksandrivka prosecutor\u2019s office issued two decisions in which it overruled the investigator\u2019s decision of 20 May 2007 as premature. 13.  On 6 October 2007 the investigator questioned Mr and Mrs K. 14.  On 1 December 2007 the investigator again suspended the investigation for failure to identify the perpetrator. 15.  On 10 December 2007 the Oleksandrivka prosecutor\u2019s office, in response to the applicant\u2019s complaint about the progress of the investigation, asked the Kirovograd Regional Police Department to have the police officers in charge of the investigation disciplined. 16.  On 21 January 2008 the Kirovograd Regional Police Department instructed the Oleksandrivka police to immediately resume the investigation. 17.  On 7 April 2008 the investigator decided to ask a forensic medical expert to determine the degree of gravity of the applicant\u2019s injuries. On 22 September 2008 the expert drew up a report generally confirming the findings of 22 February 2007. 18.  On 15 May 2008 the Kirovograd Regional Police Department informed the applicant that the police officers in charge of the case had been disciplined for omissions in the investigation. 19.  On 23 October 2008 the Oleksandrivka Court absolved Mrs K. from criminal liability under an amnesty law, on the grounds that she had an elderly mother who was dependent on her. On 24 February 2009 the Kirovograd Regional Court of Appeal (\u201cthe Court of Appeal\u201d) quashed that judgment, finding no evidence that Mrs K.\u2019s mother was dependent on her. 20.  On 1 July 2009 the investigator refused to institute criminal proceedings against Mr K. 21.  On 7 July 2009 the Novomyrgorod prosecutor issued a bill of indictment against Mrs K. 22.  On 24 July 2009 the Oleksandrivka Court remitted the case against Mrs K. for further investigation, holding that the applicant had not been informed about the completion of the investigation until 3 July 2009 and had therefore not been given enough time to study the case file. It also held that the refusal to institute criminal proceedings against Mr K. had contravened the law. 23.  On 13 November 2009 the Novomyrgorod prosecutor quashed the decision of 1 July 2009 not to institute criminal proceedings against Mr K. Subsequently the investigator again refused to institute criminal proceedings against Mr K. 24.  On 21 December 2009 the new round of pre-trial investigation in the case against Mrs K. was completed and another bill of indictment was issued by the Novomyrgorod prosecutor. 25.  On 29 March 2010 the Oleksandrivka Court remitted the case against Mrs K. for further investigation, holding in particular that the decision not to institute criminal proceedings against Mr K. had been premature, since his role in the incident had not been sufficiently clarified. 26.  On 13 July 2010 the Novomyrgorod prosecutor quashed the decision not to institute criminal proceedings against Mr K. On 26 May 2011 the investigator again refused to institute criminal proceedings against Mr K. 27.  On 20 December 2011 the Znamyanka Court convicted Mrs K. of inflicting bodily harm of medium severity on the applicant, sentencing her to restriction of liberty for two years, suspended for a one-year probationary period. The court found that the decision not to institute criminal proceedings against Mr K. in connection with the same incident had been correct. Mrs K., the prosecutor and the applicant appealed. 28.  On 6 March 2012 the Court of Appeal quashed the judgment and discontinued the criminal proceedings against Mrs K. as time-barred.", "claims": [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "outcomes": [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], "case_no": "27454/11"}

#shuffle val_data
random.shuffle(val_data)

interesting_items = []

non_zero = 0
for e in val_data: 
    out = predictor.predict_json(e)
    claims = e["claims"]
    outcomes = e["outcomes"]
    gold = ["not_claimed" if c == 0 else "claimed_not_violated" if c == 1 and o == 0 else "claimed_and_violated" for c, o in zip(claims, outcomes)]
    gold_id = e["case_no"]
    silver_rationales = [i for i in val_meta_data if i["case_no"] == gold_id][0]["silver_rationales"]
    if len(set(out["labels"])) != 1 and out["labels"] == gold and silver_rationales:
        #print(out["labels"])
        non_zero += 1
        ex = e
        interesting_items.append({"out":out, "claims":claims, "outcomes":outcomes, "ex":ex, "gold":gold, "silver_rationales":silver_rationales})
        #break

        

KeyboardInterrupt: 

In [5]:
all_interesting_results = []

for interesting_item in interesting_items: 
    out = interesting_item["out"]
    claims = interesting_item["claims"]
    outcomes = interesting_item["outcomes"]
    ex = interesting_item["ex"]
    gold = interesting_item["gold"]
    silver_rationales = interesting_item["silver_rationales"]
    
    encoded_orig = out['encoded_representations']

    facts = out['labels']
    #print('Predicted: ', facts)

    tokenizer.convert_tokens_to_string(out['tokens'])


    facts_sentences = ex["facts_sentences"]

    masks1 = [[]]  # change this if you also want to mask out parts of the premise.
    masks2 = list(all_consecutive_masks2(facts_sentences, max_length=1))
    encoded = []
    mask_mapping = []
    preds = np.zeros(shape=(len(masks1), len(masks2)))

    for m1_i, m1 in enumerate(masks1):
        masked1 = []
        for i in m1:
            masked1[i] = '<mask>'
        masked1 = ' '.join(masked1)
        masked_sentence = []
        for m2_i, m2 in enumerate(masks2):
            masked2 = facts_sentences.copy()
            for i in m2:
                masked_sentence.append(masked2[i])
                sentence_length = len(tok.tokenize(masked2[i]))
                masked2[i] = '<mask> '*sentence_length
            masked2 = tok.tokenize(' '.join(masked2))
                
            masked_ex = {
                "facts": masked2,
                "claims": claims,
                "outcomes": outcomes,
                "case_no": ex['case_no']
            }
            
            masked_out = predictor.predict_json(masked_ex)

            #print("indices", m1_i, m2_i)
            #print("case facts with masks in them", f"{masked1}\n{masked2}")
            #print("gold labels", masked_out['labels'])
            #print("masked out sentence", masked_sentence)
            encoded.append(masked_out['encoded_representations'])
            mask_mapping.append((m1_i, m2_i))
            
            #print("====")
            
    encoded = np.array(encoded)

    # replace some random f in the following list with another option from
    # ["not_claimed", "claimed_and_violated", "claimed_not_violated"] at random
    label_options = ["not_claimed", "claimed_and_violated", "claimed_not_violated"]
    interesting_label_options = ["claimed_and_violated", "claimed_not_violated"]
    article_id = random.choice([i for i in range(len(facts)) if facts[i] in interesting_label_options or gold[i] in interesting_label_options])
    foils = [f if i != article_id else random.choice([o for o in label_options if o != f]) for i,f in enumerate(facts)]

    fact_idx = label2index[facts[article_id]]
    foil_idx = label2index[foils[article_id]]
    #print("article number", articles[article_id])
    #print('fact:', index2label[fact_idx])
    #print('foil:', index2label[foil_idx])

    fact_idx = article_id * len(label_options) + fact_idx
    foil_idx = article_id * len(label_options) + foil_idx

    classifier_w = model_state_dict["classifier.3.weight"].numpy()
    classifier_b = model_state_dict["classifier.3.bias"].numpy()
    #classifier_w = np.load(f"{model_path}/w.npy")
    #classifier_b = np.load(f"{model_path}/b.npy")

    u = classifier_w[fact_idx] - classifier_w[foil_idx]
    contrastive_projection = np.outer(u, u) / np.dot(u, u)

    #print(contrastive_projection.shape)

    z_all = encoded_orig 
    z_h = encoded 
    z_all_row = encoded_orig @ contrastive_projection
    z_h_row = encoded @ contrastive_projection

    prediction_probabilities = softmax(z_all_row @ classifier_w.T + classifier_b)
    prediction_probabilities = np.tile(prediction_probabilities, (z_h_row.shape[0], 1))

    prediction_probabilities_del = softmax(z_h_row @ classifier_w.T + classifier_b, axis=1)

    p = prediction_probabilities[:, [fact_idx, foil_idx]]
    q = prediction_probabilities_del[:, [fact_idx, foil_idx]]

    p = p / p.sum(axis=1).reshape(-1, 1)
    q = q / q.sum(axis=1).reshape(-1, 1)
    distances = (p[:, 0] - q[:, 0])

    #print("the case", ex['facts'])
    #print("silver rationales", silver_rationales)
    #print("=========\n=======Farthest masks:=======")    
        
    highlight_rankings = np.argsort(-distances)
    explained_indices = []

    for i in range(len(facts_sentences)):
        rank = highlight_rankings[i]
        m1_i, m2_i = mask_mapping[rank]
        
        masked_sentence = []
        masked2 = facts_sentences.copy()
        for k in masks2[m2_i]:
            masked_sentence.append(masked2[k])
            masked2[k] = '<mask>'
        explained_indices.append(k)
        masked2 = ' '.join(masked2)
        #print("input with sentence masked out \n",masked2)
        #print("the sentence that has been omitted\n", masked_sentence)
        #print("omitted index\n", i)
        #print(np.round(distances[rank], 4))
        
    #print(explained_indices)
    all_interesting_results.append({"ex":ex, "silver_rationales":silver_rationales, "explained_indices":explained_indices})
    print({"ex":ex, "silver_rationales":silver_rationales, "explained_indices":explained_indices})

NameError: name 'interesting_items' is not defined

In [12]:
print(len(all_interesting_results))
print(len(val_data))
actual = [a["explained_indices"] for a in all_interesting_results]
predicted = [p["silver_rationales"] for p in all_interesting_results]
for i in range(2, 10):
    print("meanPrecision@", i, " ", meanPrecisionAtK(actual, predicted, i))
    print("meanRecall@", i, " ", meanRecallAtK(actual, predicted, i))


87
985
meanPrecision@ 2   0.9022988505747126
meanRecall@ 2   0.14178598326617356
meanPrecision@ 3   0.7892720306513408
meanRecall@ 3   0.1763480453128247
meanPrecision@ 4   0.7155172413793104
meanRecall@ 4   0.2056083138118052
meanPrecision@ 5   0.6505747126436783
meanRecall@ 5   0.22842611003506907
meanPrecision@ 6   0.6034482758620688
meanRecall@ 6   0.24904954889156178
meanPrecision@ 7   0.5533661740558292
meanRecall@ 7   0.2615054696612314
meanPrecision@ 8   0.5100574712643678
meanRecall@ 8   0.271469390094073
meanPrecision@ 9   0.47126436781609204
meanRecall@ 9   0.2781305787759071


0
985
meanPrecision@ 2   nan
meanRecall@ 2   nan
meanPrecision@ 3   nan
meanRecall@ 3   nan
meanPrecision@ 4   nan
meanRecall@ 4   nan
meanPrecision@ 5   nan
meanRecall@ 5   nan
meanPrecision@ 6   nan
meanRecall@ 6   nan
meanPrecision@ 7   nan
meanRecall@ 7   nan
meanPrecision@ 8   nan
meanRecall@ 8   nan
meanPrecision@ 9   nan
meanRecall@ 9   nan


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [5]:
incorrect_items = []
non_zero = 0
for e in val_data: 
    out = predictor.predict_json(e)
    claims = e["claims"]
    outcomes = e["outcomes"]
    gold = ["not_claimed" if c == 0 else "claimed_not_violated" if c == 1 and o == 0 else "claimed_and_violated" for c, o in zip(claims, outcomes)]
    gold_id = e["case_no"]
    silver_rationales = [i for i in val_meta_data if i["case_no"] == gold_id][0]["silver_rationales"]
    if out["labels"] != gold and silver_rationales:
        non_zero += 1
        ex = e
        incorrect_items.append({"out":out, "claims":claims, "outcomes":outcomes, "ex":ex, "gold":gold, "silver_rationales":silver_rationales})

# save items from the incorrect_items list to a file
with open(f"./incorrect_items.txt", "w") as f:
    for item in incorrect_items:
        f.write(json.dumps(item) + "\n")

# if the file f"./incorrect_results.txt" exists then remove it
if os.path.exists(f"./incorrect_results.txt"):
    os.remove(f"./incorrect_results.txt")

FileNotFoundError: [Errno 2] No such file or directory: '../experiments/ecthr/incorrect_items_allenai/longformer-base-4096.jsonl'

In [10]:
# read in incorrect items
with open(f"./incorrect_items.txt", "r") as f:
    incorrect_items = [json.loads(line) for line in f.readlines()]

#check if a file in the path ./incorrect_results.txt exists
if os.path.exists(f"./incorrect_results.txt"):
    # load incorrect results from the incorrect_results file
    with open(f"./incorrect_results.txt", "r") as f:
        incorrect_results = [json.loads(line) for line in f.readlines()]
else: 
    incorrect_results = []
    
saved_exes = [i["ex"] for i in incorrect_results]

all_incorrect_results = []

for incorrect_item in incorrect_items: 
    out = incorrect_item["out"]
    claims = incorrect_item["claims"]
    outcomes = incorrect_item["outcomes"]
    ex = incorrect_item["ex"]
    gold = incorrect_item["gold"]
    silver_rationales = incorrect_item["silver_rationales"]
    
    if ex not in saved_exes:

        encoded_orig = out['encoded_representations']

        facts = out['labels']
        #print('Predicted: ', facts)

        tok.convert_tokens_to_string(out['tokens'])


        facts_sentences = ex["facts_sentences"]

        masks1 = [[]]  # change this if you also want to mask out parts of the premise.
        masks2 = list(all_consecutive_masks2(facts_sentences, max_length=1))
        encoded = []
        mask_mapping = []
        preds = np.zeros(shape=(len(masks1), len(masks2)))

        for m1_i, m1 in enumerate(masks1):
            masked1 = []
            for i in m1:
                masked1[i] = '<mask>'
            masked1 = ' '.join(masked1)
            masked_sentence = []
            for m2_i, m2 in enumerate(masks2):
                masked2 = facts_sentences.copy()
                for i in m2:
                    masked_sentence.append(masked2[i])
                    sentence_length = len(tok.tokenize(masked2[i]))
                    masked2[i] = '<mask> '*sentence_length
                masked2 = tok.tokenize(' '.join(masked2))
                    
                masked_ex = {
                    "facts": masked2,
                    "claims": claims,
                    "outcomes": outcomes,
                    "case_no": ex['case_no']
                }
                
                masked_out = predictor.predict_json(masked_ex)

                #print("indices", m1_i, m2_i)
                #print("case facts with masks in them", f"{masked1}\n{masked2}")
                #print("gold labels", masked_out['labels'])
                #print("masked out sentence", masked_sentence)
                encoded.append(masked_out['encoded_representations'])
                mask_mapping.append((m1_i, m2_i))
                
                #print("====")
            
        encoded = np.array(encoded)

        # replace some random f in the following list with another option from
        # ["not_claimed", "claimed_and_violated", "claimed_not_violated"] at random
        label_options = ["not_claimed", "claimed_and_violated", "claimed_not_violated"]
        article_id = random.choice([i for i in range(len(facts))])
        foils = [f if i != article_id else random.choice([o for o in label_options if o != f]) for i,f in enumerate(facts)]

        fact_idx = label2index[facts[article_id]]
        foil_idx = label2index[foils[article_id]]
        #print("article number", articles[article_id])
        #print('fact:', index2label[fact_idx])
        #print('foil:', index2label[foil_idx])

        fact_idx = article_id * len(label_options) + fact_idx
        foil_idx = article_id * len(label_options) + foil_idx

        classifier_w = np.load(f"{model_path}/w.npy")
        classifier_b = np.load(f"{model_path}/b.npy")

        u = classifier_w[fact_idx] - classifier_w[foil_idx]
        contrastive_projection = np.outer(u, u) / np.dot(u, u)

        #print(contrastive_projection.shape)

        z_all = encoded_orig 
        z_h = encoded 
        z_all_row = encoded_orig @ contrastive_projection
        z_h_row = encoded @ contrastive_projection

        prediction_probabilities = softmax(z_all_row @ classifier_w.T + classifier_b)
        prediction_probabilities = np.tile(prediction_probabilities, (z_h_row.shape[0], 1))

        prediction_probabilities_del = softmax(z_h_row @ classifier_w.T + classifier_b, axis=1)

        p = prediction_probabilities[:, [fact_idx, foil_idx]]
        q = prediction_probabilities_del[:, [fact_idx, foil_idx]]

        p = p / p.sum(axis=1).reshape(-1, 1)
        q = q / q.sum(axis=1).reshape(-1, 1)
        distances = (p[:, 0] - q[:, 0])

        #print("the case", ex['facts'])
        #print("silver rationales", silver_rationales)
        #print("=========\n=======Farthest masks:=======")    
            
        highlight_rankings = np.argsort(-distances)
        explained_indices = []

        for i in range(len(facts_sentences)):
            rank = highlight_rankings[i]
            m1_i, m2_i = mask_mapping[rank]
            
            masked_sentence = []
            masked2 = facts_sentences.copy()
            for k in masks2[m2_i]:
                masked_sentence.append(masked2[k])
                masked2[k] = '<mask>'
            explained_indices.append(k)
            masked2 = ' '.join(masked2)
            #print("input with sentence masked out \n",masked2)
            #print("the sentence that has been omitted\n", masked_sentence)
            #print("omitted index\n", i)
            #print(np.round(distances[rank], 4))
            
        #print(explained_indices)
        all_incorrect_results.append({"ex":ex, "silver_rationales":silver_rationales, "explained_indices":explained_indices})
        print({"ex":ex, "silver_rationales":silver_rationales, "explained_indices":explained_indices})
        # append incorrect result to an external incorrect_results file
        with open(f"./incorrect_results.txt", "a") as f:
            f.write(json.dumps({"ex":ex, "silver_rationales":silver_rationales, "explained_indices":explained_indices}) + "\n")

# read in all_incorrect_results from the incorrect_results file
with open(f"./incorrect_results.txt", "r") as f:
    all_incorrect_results = [json.loads(line) for line in f.readlines()]

with open("/home/irs38/contrastive-explanations/data/ecthr/Chalkidis/simple_val.jsonl", "r") as f:
    val_data = [json.loads(line) for line in f.readlines()]       

print(len(all_incorrect_results))
print(len(val_data))
actual = [a["explained_indices"] for a in all_incorrect_results]
predicted = [p["silver_rationales"] for p in all_incorrect_results]
for i in range(2, 10):
    print("meanPrecision@", i, " ", meanPrecisionAtK(actual, predicted, i))
    print("meanRecall@", i, " ", meanRecallAtK(actual, predicted, i))


584
985
meanPrecision@ 2   0.9477739726027398
meanRecall@ 2   0.09767005583962775
meanPrecision@ 3   0.8978310502283103
meanRecall@ 3   0.135352180169421
meanPrecision@ 4   0.8441780821917808
meanRecall@ 4   0.16501159742012056
meanPrecision@ 5   0.7948630136986301
meanRecall@ 5   0.1890894937220949
meanPrecision@ 6   0.7471461187214612
meanRecall@ 6   0.2077014659264444
meanPrecision@ 7   0.7045009784735813
meanRecall@ 7   0.2227548321363115
meanPrecision@ 8   0.6673801369863014
meanRecall@ 8   0.23588981179558793
meanPrecision@ 9   0.6322298325722984
meanRecall@ 9   0.2463872776379505
