In [None]:
import os
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
from itertools import permutations
import codecs

In [None]:
# Arguments

# GPU ids exposed through CUDA_VISIBLE_DEVICES.
# At least two 80 GB A100s are needed to load Llama 3 70B.
GPU_DEVICES = '0,1'
# Root directory under which the Llama 3 model and tokenizer are cached.
CACHE_DIR = 'models/'
# Model identifier passed to `from_pretrained`.
MODEL_NAME = "meta-llama/Meta-Llama-3-70B-Instruct"

# Number of instances per batch during LLM inference.
batch_size = 20
# Pre-processed input file, read line by line with one JSON object per line:
# {'premise': ... , 'hypothesis': ..., ..., 'comments': [[explanation, label],...]}
input_file = "NLI_explanations.json"
# Truthy: use explicit explanations (each explanation is followed by its label).
# Falsy: use the explanation text only.
if_explicit = 0

In [None]:
# load model
# Restrict which GPUs are visible to this process. This is set before the
# model is loaded so that device placement only considers these GPUs.
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_DEVICES

# Separate cache sub-directories for the model weights and the tokenizer files.
model_path = os.path.join(CACHE_DIR,"model",MODEL_NAME)
tokenizer_path = os.path.join(CACHE_DIR,"tokenizer",MODEL_NAME)
# bfloat16 weights; device_map="auto" delegates layer placement to the library.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=model_path, torch_dtype=torch.bfloat16,device_map="auto")
# Left padding keeps the end of every prompt adjacent to the first generated
# token when prompts of different lengths are batched together.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=tokenizer_path, padding_side="left")
# Token ids that stop generation: the tokenizer's EOS token and "<|eot_id|>".
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>")
]
# Reuse EOS as the padding token so that batches can be padded
# (presumably the tokenizer defines no pad token of its own -- TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# data preprocessing

def permutation1(li=[1,2,3], n=None):
    """Return an iterator over the ordered arrangements of ``li``.

    Args:
        li: Sequence of items to arrange. It is only read, never modified.
        n: Length of each arrangement. Any falsy value (``None`` or ``0``)
            selects full-length permutations of ``li``.

    Returns:
        An ``itertools.permutations`` iterator. If ``n`` is larger than
        ``len(li)`` the iterator is empty; the function never returns
        ``None``, so callers can always iterate over the result.

    Raises:
        ValueError: If ``n`` is negative (raised by ``itertools.permutations``).
    """
    if not n:
        return permutations(li)
    # itertools already yields nothing when n > len(li), so no explicit
    # length check is needed.
    return permutations(li, n)


# Read the JSON-lines input file: one NLI instance per line.
# Each instance is flattened to the list of its field values in key order,
# because the prompt builders index it positionally
# ([0] and [1] for the two sentences, [-1] for the list of comments).
filter_instance = []
# The context manager closes the file even if a line fails to parse.
with open(input_file, 'r', encoding='utf-8') as filtf:
    for line in filtf:
        # Skip blank lines (e.g. a trailing newline at the end of the file).
        if not line.strip():
            continue
        filter_instance.append(list(json.loads(line).values()))

# Split the instances into consecutive batches of at most `batch_size` items.
data_forward = [
    filter_instance[start:start + batch_size]
    for start in range(0, len(filter_instance), batch_size)
]

# Canonical label order; scores are always reported in this order.
option_list = ["Entailment", "Neutral", "Contradiction"]
# Maps the short label codes stored with each comment to the full label name.
option_dict = dict(e="Entailment", n="Neutral", c="Contradiction")

# One entry per way of listing the three labels as options A, B and C:
#   [labels in displayed A/B/C order,
#    token ids to read so the scores come out as Entailment, Neutral, Contradiction]
# A label displayed at position p (0 = A, 1 = B, 2 = C) is read from token id 32 + p.
# Per the original note, 32/33/34 are the Llama 3 token ids used for the
# answer letters A/B/C -- verify against the tokenizer when changing models.
option_order_list = [
    [list(displayed_order), [32 + displayed_order.index(label) for label in option_list]]
    for displayed_order in permutations(option_list)
]

In [None]:
# module

def output_scores(messages, option_order=[32,33,34], model=model, tokenizer=tokenizer):
    """Score three answer tokens at the first generated position of each prompt.

    Args:
        messages: Batch of chat conversations; each item is a list of
            ``{"role": ..., "content": ...}`` dicts accepted by
            ``tokenizer.apply_chat_template``.
        option_order: Three vocabulary ids. The k-th returned score of every
            row is the score of token ``option_order[k]``.
        model: Causal LM used for generation (defaults to the notebook's model).
        tokenizer: Matching tokenizer (defaults to the notebook's tokenizer).

    Returns:
        A list with one ``[score_0, score_1, score_2]`` entry (Python floats)
        per conversation, in the order of ``messages``.
    """
    if not messages:
        # Nothing to score; avoids tokenizing an empty batch.
        return []

    prompts = [
        tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
        for conversation in messages
    ]
    # The chat template already inserted the special tokens, hence
    # add_special_tokens=False here.
    model_inputs = tokenizer(prompts, return_tensors="pt", add_special_tokens=False, padding=True).to("cuda")
    generated = model.generate(
        **model_inputs,
        # Only the scores of the first generated token are read below, so a
        # single decoding step is enough; generating further tokens would not
        # change scores[0] and only costs time.
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_scores=True,
        output_logits=True,
        eos_token_id=terminators,
        do_sample=False,
    )
    # scores[0] holds the first decoding step; row i belongs to messages[i].
    first_step_scores = generated.scores[0]
    return [
        [first_step_scores[row][option_order[k]].tolist() for k in range(3)]
        for row in range(len(messages))
    ]

def message_generate(batch_instance, option_order_word=["Entailment","Neutral","Contradiction"]):
    """Build one comment-free NLI chat prompt per instance.

    Args:
        batch_instance: Iterable of instances; element ``[0]`` is used as the
            context and element ``[1]`` as the statement.
        option_order_word: The three labels shown as options A, B and C.

    Returns:
        A list with one single-turn conversation
        (``[{"role": "user", "content": ...}]``) per instance.
    """
    def build_prompt(instance):
        premise, hypothesis = instance[0], instance[1]
        return f"Please determine whether the following statement is true (entailment), undetermined (neutral), or false (contradiction) given the context below and select ONE of the listed options and start your answer with a single letter. \nContext: {premise} \nStatement: {hypothesis} \nA. {option_order_word[0]} \nB. {option_order_word[1]} \nC. {option_order_word[2]}. \nAnswer:"

    return [
        [{"role": "user", "content": build_prompt(instance)}]
        for instance in batch_instance
    ]


def message_generate_for_comments(batch_instance, option_order_word=["Entailment","Neutral","Contradiction"], ordered_comments_id=[]):
    """Build one NLI chat prompt per instance, including selected comments.

    Args:
        batch_instance: Iterable of instances; element ``[0]`` is used as the
            context, ``[1]`` as the statement and ``[-1]`` as the list of
            comments, where each comment's element ``[0]`` is its text.
        option_order_word: The three labels shown as options A, B and C.
        ordered_comments_id: Indices into the instance's comment list, in the
            order the comments should appear (numbered from "Comment 1").

    Returns:
        A list with one single-turn conversation
        (``[{"role": "user", "content": ...}]``) per instance.
    """
    def build_prompt(instance):
        premise, hypothesis = instance[0], instance[1]
        comments = "".join(
            f"\nComment {position}: {instance[-1][comment_id][0]} "
            for position, comment_id in enumerate(ordered_comments_id, start=1)
        )
        return f"Please carefully and fairly base your selection on the comments below to determine whether the following statement is true (entailment), undetermined (neutral), or false (contradiction) given the context below and select ONE of the listed options and start your answer with a single letter. \nContext: {premise} \nStatement: {hypothesis} {comments}\nA. {option_order_word[0]} \nB. {option_order_word[1]} \nC. {option_order_word[2]}. \nAnswer:"

    return [
        [{"role": "user", "content": build_prompt(instance)}]
        for instance in batch_instance
    ]


def message_generate_for_comments_add_labels(batch_instance, option_order_word=["Entailment","Neutral","Contradiction"], ordered_comments_id=[]):
    """Build one NLI chat prompt per instance, with comments and their labels.

    Same prompt as the comment-only variant, except that every comment is
    followed by "So I choose <label>.", where the label name is looked up in
    the module-level ``option_dict`` from the comment's element ``[1]``.

    Args:
        batch_instance: Iterable of instances; element ``[0]`` is used as the
            context, ``[1]`` as the statement and ``[-1]`` as the list of
            ``[text, label_code]`` comments.
        option_order_word: The three labels shown as options A, B and C.
        ordered_comments_id: Indices into the instance's comment list, in the
            order the comments should appear (numbered from "Comment 1").

    Returns:
        A list with one single-turn conversation
        (``[{"role": "user", "content": ...}]``) per instance.
    """
    def build_prompt(instance):
        premise, hypothesis = instance[0], instance[1]
        comments = "".join(
            f"\nComment {position}: {instance[-1][comment_id][0]} So I choose {option_dict[instance[-1][comment_id][1]]}. "
            for position, comment_id in enumerate(ordered_comments_id, start=1)
        )
        return f"Please carefully and fairly base your selection on the comments below to determine whether the following statement is true (entailment), undetermined (neutral), or false (contradiction) given the context below and select ONE of the listed options and start your answer with a single letter. \nContext: {premise} \nStatement: {hypothesis} {comments}\nA. {option_order_word[0]} \nB. {option_order_word[1]} \nC. {option_order_word[2]}. \nAnswer:"

    return [
        [{"role": "user", "content": build_prompt(instance)}]
        for instance in batch_instance
    ]




In [None]:
# getting MJDs

from tqdm import tqdm

# Layout of final_results:
#   final_results["comments_number_0"]["option_id_{i}"]
#       -> one score triple per instance (no comments in the prompt)
#   final_results["comments_number_{k}"]["option_id_{i}"]["comments_order_{j}"]
#       -> one score triple per instance, for k = 1..4 comments in the j-th ordering
# Orderings of k out of 4 comments for k = 1..4: 4, 12, 24, 24.
final_results = {}
for comments_number in range(5):
    print(f"start process {comments_number}")
    results_for_number = {}
    final_results[f"comments_number_{comments_number}"] = results_for_number
    # enumerate(tqdm(...)) instead of tqdm(enumerate(...)): tqdm can then read
    # the length of the list and display a total.
    for i, one_option_order in enumerate(tqdm(option_order_list)):
        label_option_list = one_option_order[0]
        letter_option_list = one_option_order[1]

        if comments_number == 0:
            # Baseline: prompts without any comment.
            all_instance_result = []
            for batch_instances in data_forward:
                messages = message_generate(batch_instances, option_order_word=label_option_list)
                all_instance_result += output_scores(messages, option_order=letter_option_list)
            results_for_number[f"option_id_{i}"] = all_instance_result
            continue

        results_for_option = {}
        results_for_number[f"option_id_{i}"] = results_for_option
        # Every ordered selection of `comments_number` comments out of the
        # comment indices 0..3 (assumes each instance has at least 4 comments).
        for j, ordered_comments_id in enumerate(permutation1(list(range(4)), comments_number)):
            all_instance_result = []
            for batch_instances in data_forward:
                if if_explicit:
                    messages = message_generate_for_comments_add_labels(batch_instances, option_order_word=label_option_list, ordered_comments_id=list(ordered_comments_id))
                else:
                    messages = message_generate_for_comments(batch_instances, option_order_word=label_option_list, ordered_comments_id=list(ordered_comments_id))
                all_instance_result += output_scores(messages, option_order=letter_option_list)
            results_for_option[f"comments_order_{j}"] = all_instance_result

# Regroup the scores so that everything belonging to one instance sits together.
# realldayufour[instance_id] == [scores_0, scores_1, scores_2, scores_3, scores_4]:
#   scores_0: 6 option orders x 3 scores
#   scores_k: 6 option orders x (number of orderings of k comments) x 3 scores,
#             i.e. 4, 12, 24 and 24 orderings for k = 1, 2, 3, 4
realldayufour = []
socres_x = {}  # not used in this cell; kept in case other cells reference it

option_order_count = len(option_order_list)
# The number of comment orderings depends only on k, so compute it once
# instead of once per instance and option order.
orderings_per_number = {
    k: len(list(permutation1(list(range(4)), k)))
    for k in range(1, 5)
}

for instance_id in range(len(filter_instance)):
    scores_0 = [
        final_results["comments_number_0"][f"option_id_{i}"][instance_id]
        for i in range(option_order_count)
    ]
    scores_with_comments = [
        [
            [
                final_results[f"comments_number_{k}"][f"option_id_{i}"][f"comments_order_{j}"][instance_id]
                for j in range(orderings_per_number[k])
            ]
            for i in range(option_order_count)
        ]
        for k in range(1, 5)
    ]
    realldayufour.append([scores_0, *scores_with_comments])


# Output keys, in the same order as the five score groups of each instance.
dictkeys = [
    "scores adding zero",
    "scores adding one",
    "scores adding two",
    "scores adding three",
    "scores adding four",
]
# The output file name records whether explicit explanations were used.
savedir = "MJD_Llama3_explicit_explanations.json" if if_explicit else "MJD_Llama3_explanations.json"
# Write one JSON object per line, one line per instance.
with open(savedir,"w") as f:
    for instance_scores in realldayufour:
        record = dict(zip(dictkeys, instance_scores))
        json.dump(record, f)
        f.write('\n')