In [1]:
'''
A (hopefully) Simple API for serving explanation score requests.

input_string = (
    f"{question} answer: {gold_label}. "
    + f" explanation: {abstr_expl}."
)

here are some example input strings:

If you feel like everything is spinning while climbing you are experiencing what? answer: vertigo. explanation: Vertigo is often experienced while climbing or at heights.
Where do you get clothes in a shopping bag? answer: retail store. explanation: For any large item where convenience is beneficial, one might go to a retail store, either a regular one or a big-box store like walmart.
Where should a cat be in a house? answer: floor. explanation: A cat should be on the floor, not on a rug.
'''
import pdb
import argparse
import torch
import transformers
import os
import tqdm
import numpy as np

_model, _tokenizer = None, None

model2url = {
    'large': 'https://storage.googleapis.com/ai2-mosaic-public/projects/few-shot-explanations/pretrained_models/commonsense_qa/valloss%3D0.28665~model%3Dt5-large~lr%3D0.0001~seed%3D1~labelagg%3D0_just_weights.pt',
    '3b': 'https://storage.googleapis.com/ai2-mosaic-public/projects/few-shot-explanations/pretrained_models/commonsense_qa/valloss%3D0.28925~model%3Dt5-3b~lr%3D0.0001~seed%3D1~labelagg%3D0_just_weights.pt',
    '11b': 'https://storage.googleapis.com/ai2-mosaic-public/projects/few-shot-explanations/pretrained_models/commonsense_qa/cose_deepspeed_valloss%3D0.00000~model%3Dt5-11b~lr%3D0.00001~seed%3D1~labelagg%3D0.pt',
}

def get_model(model_type, device=None):
    global _model, model2url
    if model_type not in {'11b', '3b', 'large'}:
        raise NotImplementedError('{} is not a valid model please use "3b" or "large"'.format(model_type))

    if _model is None:
        hf_model_name = 't5-' + model_type
        print('Loading model: this will run only once.')

        if model_type == 'large':
            model_path = 'snli_models/t5-large.pt'
        elif model_type == '3b':
            model_path = 'snli_models/valloss=0.24209~model=t5-3b~lr=0.0001~seed=1~labelagg=0_just_weights.pt'
        elif model_type == '11b':
            model_path = 'snli_models/cose_deepspeed_valloss=0.00000~model=t5-11b~lr=0.00001~seed=1~labelagg=0.pt'

        if not os.path.exists(model_path):
            print('Please download weights for {} model and put in current directory.'.format(model_path))
            print('for example, wget {}'.format(model2url[model_type]))
            quit()

        state = torch.load(model_path)
        if 'model_state_dict' in state:
            state = state['model_state_dict']

        _model = transformers.AutoModelForSeq2SeqLM.from_pretrained(hf_model_name)
        if model_type == '11b': # need to resize due to deepspeed, these entires are not accessed.
            _model.resize_token_embeddings(len(transformers.AutoTokenizer.from_pretrained(hf_model_name)))
        _model.load_state_dict(state)
        _model.eval()
        if device is not None:
            _model = _model.to(device)

    return _model


def get_tokenizer(model_type):
    global _tokenizer
    if model_type not in {'3b', 'large', '11b'}:
        raise NotImplementedError('{} is not a valid model please use "3b" or "large" or "11b"'.format(model_type))

    if _tokenizer is None:
        hf_model_name = 't5-' + model_type
        _tokenizer = transformers.T5TokenizerFast.from_pretrained(hf_model_name)

    return _tokenizer


class T5Dataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        res = self.tokenizer(self.data[idx]['input'], truncation=True)
        res['labels'] = self.tokenizer(self.data[idx]['label']).input_ids
        return res

    def __len__(self):
        return len(self.data)


def get_scores(inputs, model_type, device=None, batch_size=32, verbose=False):
    '''
    Inputs:
      - a list of explanations to score, e.g.,:
        premise: A man getting a tattoo on his back. hypothesis: A woman is getting a tattoo. answer: contradiction. explanation: Because the tattoo artist is a man, the person getting the tattoo is not a woman.
      - model type, either "3b" or "large" or "11b"
      - device: which torch device to load model on, e.g., "cuda:3"
    Outputs:
      - P(good explanation); higher is better
    '''
    assert model_type in {'large', '3b', '11b'}

    if isinstance(inputs, str):
        inputs = [inputs]

    model = get_model(model_type, device=device)
    tokenizer = get_tokenizer(model_type)

    score_itr = T5Dataset([{'input': inp, 'label': 'x'} for inp in inputs], tokenizer) # dummy labels for inference
    data_collator = transformers.DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
#         return_tensors='pt'
    )
    score_itr = torch.utils.data.DataLoader(score_itr, shuffle=False, collate_fn=data_collator, batch_size=batch_size)
    score_itr = score_itr if not verbose else tqdm.tqdm(score_itr, total=len(score_itr))

    good_idx, bad_idx = tokenizer('good').input_ids[0], tokenizer('bad').input_ids[0]
    scores = []
    
    with torch.no_grad():
        for batch in score_itr:
            if device is not None:
                input_ids, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)
            model_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=targets)
            logits_pos = model_output['logits'][:, 0, good_idx].cpu().numpy()
            logits_neg = model_output['logits'][:, 0, bad_idx].cpu().numpy()
            exp_logit_pos, exp_logit_neg = np.exp(logits_pos), np.exp(logits_neg)
            score = list([float(x) for x in exp_logit_pos / (exp_logit_pos + exp_logit_neg)])
            #pdb.set_trace()
            scores.extend(score)
    return scores


# def parse_args():
#     '''
#     Optional args for main function, mostly just to test.
#     '''
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         'model_type',
#         default='large',
#         choices={'large', '3b', '11b'})
#     parser.add_argument(
#         '--batch_size',
#         default=32,
#         type=int)

#     args = parser.parse_args(['--batch_size', '1'])
#     return args

In [2]:
# args = parse_args()
# parser = argparse.ArgumentParser()
# parser.add_argument(
#     'model_type',
#     default='large',
#     choices={'large', '3b', '11b'})
# parser.add_argument(
#     '--batch_size',
#     default=32,
#     type=int)

# args = parser.parse_args(["--model_type", "3b"])
# args.device = 'cpu'#'cuda' if torch.cuda.is_available() else 'cpu'
np.random.seed(1)
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# scores = get_scores(
#     ['If you feel like everything is spinning while climbing you are experiencing what? answer: vertigo. explanation: Vertigo is often experienced while climbing or at heights.',
#      'Where do you get clothes in a shopping bag? answer: retail store. explanation: For any large item where convenience is beneficial, one might go to a retail store, either a regular one or a big-box store like walmart.',
#      'Where should a cat be in a house? answer: floor. explanation: A cat should be on the floor, not on a rug.'],
#     'large',
#     device='cuda:0',
#     batch_size=1,
#     verbose=False)
# print(scores)


In [3]:
import json
from tqdm import tqdm
# with open("../../scripts/results/dev_rationale_pair.json") as f:
#     rationale_pair_dev_data = json.load(f)
import json
file_path = "../../scripts/results/24shots_esnli_t5_3b_chatgpt_rationales_generator_test_rationale_pair.json"
with open(file_path, 'r') as f:
    rationale_pair_dev_data = json.load(f)
    

In [4]:
rationale_pair_dev_data[0].keys(), len(rationale_pair_dev_data)

(dict_keys(['premise', 'hypothesis', 'label', 'explanation_1', 'explanation_2', 'explanation_3', 'our_explanation', 'common_expl_list', 'input_ids', 'attention_mask', 'labels', 'decoder_attention_mask', 'question_encoding', 'generated_explanation']),
 640)

In [5]:
# answer = wt5_esnli_label_mapping[item["label"]]
# input_string = f'explain why the relation is {answer} between hypothesis: {hypothesis.lower()} and premise: {premise.lower()}'  
# #input_string = f"explain {datasource} question: {question} answer: {answer}" + f" {explanation_sep} <extra_id_0>"
# answer_string = f"<extra_id_0> {item['our_explanation']} <extra_id_1>"
wt5_esnli_label_mapping = {0: 'entailment', 1: 'neutral', 2: 'contradiction'} 
for kk, item in enumerate(rationale_pair_dev_data):
    answer = wt5_esnli_label_mapping[item["label"]]
    hypothesis = item['hypothesis']
    premise = item['premise']
    input_string = f'explain why the relation is {answer} between hypothesis: {hypothesis.lower()} and premise: {premise.lower()}' 
    print("#######")
    print(f"sample {kk}: {input_string}")
    print("**")
    print(f"generated explanation: {item['generated_explanation']}")
    print("**")
    print(f"common_expl_list: {item['common_expl_list']}")
    print("**")
    print(f"explanation_1: {item['explanation_1']}")
    print("#######")
    

#######
sample 0: explain why the relation is entailment between hypothesis: there is a black dog. and premise: a black dog running really fast with a blond woman running after him in the background.
**
generated explanation: <pad> <extra_id_0> The relation is entailment because the hypothesis states that there is a black dog, which implies that there is a black dog running really fast with a blond woman running after him. The premise describes a black dog running really fast with a blond woman running after him, which implies that there is a black dog running really fast with a blond woman running after him. The premise does not confirm or deny the hypothesis that there is a black dog, but it does provide additional information about the hypothesis that the hypothesis states that the hypothesis states that the hypothesis.<extra_id_1> </s>
**
common_expl_list: ["If there's a black dog running really fast, then, yes, there is a black dog.", 'If the dog is black, then it is black.']
**
e

explanation_1: A skateboard cannot operate in the snow, nor can a snowboard operate in a garage.  One must either be skateboarding in the garage or snowboarding, doing both is not possible.
#######
#######
sample 242: explain why the relation is neutral between hypothesis: two women have been surfing. and premise: 2 woman surfers carrying their boards across the beach.
**
generated explanation: <pad> <extra_id_0> The relation is neutral between the hypothesis "two women have been surfing" and the premise "2 woman surfers carrying boards across the beach" because the premise does not provide any information about the women's activity. The hypothesis states that the women have been surfing, but the premise does not provide any information about the women's activity. The premise does not provide any information about the women's activity, but it does provide information about the women's presence on the beach.<extra_id_1> </s>
**
common_expl_list: ['While surfing is likely, carrying a sur

sample 594: explain why the relation is neutral between hypothesis: the women are older and premise: several women wearing dresses dance in the forest.
**
generated explanation: <pad> <extra_id_0> The relation is neutral because the hypothesis states that the women are older, but the premise does not provide any information that would indicate whether the women are older or younger. The premise does not provide any information that would indicate whether the women are older or younger, but it does provide information about the women's clothing and clothing style, which could indicate whether they are older or younger.<extra_id_1> </s>
**
common_expl_list: ['Women wearing dresses is not a sufficient indication that the women are older. It could also be a younger woman dancing or a golden-aged woman dancing around the house to a fancy dress party. Some elderly women do not wear dresses.', 'Dresses could be worn regardless of how old one is.', 'Dresses are worn by a variety of people of v

In [17]:
wt5_esnli_label_mapping = {0: 'entailment', 1: 'neutral', 2: 'contradiction'} 
from tqdm import tqdm
import pdb
qae_list = []
score_list = []
for da in tqdm(rationale_pair_dev_data[0:200], total=len(rationale_pair_dev_data[0:200])):
    qae = ["premise: {} hypothesis: {} answer: {}, explanation: {}".format(da['premise'], 
                                                 da['hypothesis'], 
                                                wt5_esnli_label_mapping[da['label']],
                                                expl) for expl in da['common_expl_list']]

    scores = get_scores(
        qae,
        '3b',
        device='cuda:0',
        batch_size=1,
        verbose=False)
    score_list.append(scores[0])
np.mean(score_list), np.median(score_list)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:15<00:00, 13.12it/s]


(0.8201008582115173, 0.8638935983181)

In [16]:
wt5_esnli_label_mapping = {0: 'entailment', 1: 'neutral', 2: 'contradiction'} 
from tqdm import tqdm
import pdb
qae_list = []
score_list = []
for da in tqdm(rationale_pair_dev_data[0:200], total=len(rationale_pair_dev_data[0:200])):
    qae = "premise: {} hypothesis: {} answer: {}, explanation: {}".format(da['premise'], 
                                                 da['hypothesis'], 
                                                wt5_esnli_label_mapping[da['label']],
                                                da['explanation_1'])

    scores = get_scores(
        [qae],
        '3b',
        device='cuda:0',
        batch_size=1,
        verbose=False)
    score_list.append(scores[0])
np.mean(score_list), np.median(score_list)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:09<00:00, 21.53it/s]


(0.5679544609179721, 0.6270546019077301)

In [7]:
from tqdm import tqdm
import pdb
wt5_esnli_label_mapping = {0: 'entailment', 1: 'neutral', 2: 'contradiction'} 
qae_list = []
score_list = []
for da in tqdm(rationale_pair_dev_data[0:200], total=len(rationale_pair_dev_data[0:200])):
    qae = "premise: {} hypothesis: {} answer: {}, explanation: {}".format(da['premise'], 
                                                 da['hypothesis'], 
                                                wt5_esnli_label_mapping[da['label']],
                                                da['generated_explanation'])

    scores = get_scores(
        [qae],
        '3b',
        device='cuda:0',
        batch_size=1,
        verbose=False)
    score_list.append(scores[0])
#     if scores[0] > 0.7 or scores[0] < 0.2:
#         print("question: {}".format(da['question']))
#         print("answer: {}".format(da['answer']))
#         print("common_expl_list: {}".format(da['common_expl_list']))
#         print("generated_explanation: {}".format(da['generated_explanation']))
#         print("score: {}".format(scores[0]))\
np.mean(score_list), np.median(score_list)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:17<00:00, 11.51it/s]


(0.3917398187145591, 0.38663966953754425)

## evaluate generated rationale with bert-score

In [8]:
score_list

[0.4052284061908722,
 0.5191667079925537,
 0.2200496792793274,
 0.6166287660598755,
 0.23566164076328278,
 0.23137959837913513,
 0.4726582467556,
 0.44587987661361694,
 0.3216574192047119,
 0.5206500887870789,
 0.374889075756073,
 0.3831893503665924,
 0.39179790019989014,
 0.23738797008991241,
 0.3433012068271637,
 0.4466983675956726,
 0.1248847022652626,
 0.23148657381534576,
 0.29242706298828125,
 0.43564972281455994,
 0.24475640058517456,
 0.5609840154647827,
 0.20569974184036255,
 0.33773431181907654,
 0.7527338862419128,
 0.5272515416145325,
 0.6047167778015137,
 0.22808103263378143,
 0.4841300845146179,
 0.5092390179634094,
 0.6105713844299316,
 0.39877912402153015,
 0.2403331995010376,
 0.4602493643760681,
 0.2712806761264801,
 0.3767451047897339,
 0.4169318675994873,
 0.3975336253643036,
 0.2173650860786438,
 0.41740041971206665,
 0.3692389130592346,
 0.5197189450263977,
 0.4667205512523651,
 0.4045630395412445,
 0.18840569257736206,
 0.36257457733154297,
 0.5066966414451599,
 

In [9]:
import datasets
import numpy as np
bertscore_metric = datasets.load_metric("bertscore")
rouge_metric = datasets.load_metric('rouge')
bleu_metric = datasets.load_metric('sacrebleu')

In [10]:
import pdb

bert_scores = []
bleu_scores = []
rouge1_scores = []
rouge2_scores = []
rougeL_scores = []

for da in tqdm(rationale_pair_dev_data, total=len(rationale_pair_dev_data)):
    generated_expl = da['generated_explanation']
    common_expl_list = da['common_expl_list']
    pred_expl = generated_expl.split("<extra_id_0> ")[1].split("<extra_id_1>")[0]
    list_gold_expl = [l.lower() for l in common_expl_list]
    
    bert_score = bertscore_metric.compute(predictions=[pred_expl.lower()], references=[list_gold_expl], lang="en")["f1"][0]*100
    bleu_score = bleu_metric.compute(predictions=[pred_expl.lower()], references=[list_gold_expl])['score']
    rouge_score = rouge_metric.compute(predictions=[pred_expl.lower()]*len(list_gold_expl), references=list_gold_expl)
    rouge1_score = rouge_score["rouge1"].mid.fmeasure
    rouge2_score = rouge_score["rouge2"].mid.fmeasure
    rougeL_score = rouge_score["rougeL"].mid.fmeasure
    bert_scores.append(bert_score)
    bleu_scores.append(bleu_score)
    rouge1_scores.append(rouge1_score)
    rouge2_scores.append(rouge2_score)
    rougeL_scores.append(rougeL_score)
    
    
#     #print(generated_expl)
#     #print(generated_expl.split("<extra_id_0> ")[1].split("<extra_id_1>")[0])
#     instance_bertscores = []
#     for gold_expl in list_gold_expl: 
#         score = bertscore_metric.compute(predictions=[pred_expl.lower()]*len(), references=[gold_expl.lower()], lang="en")["f1"][0]*100
#         instance_bertscores.append(score)
#     bertscores.append(np.mean(instance_bertscores))
    
#     bleuscore = bleu_score(pred_expl, list_gold_expl)
#     bleuscores.append(bleuscore)
    
#     rougescore = rouge(pred_expl, list_gold_expl)
#     rouge1_scores.append(rougescore['rouge1_fmeasure'].numpy()[0])
#     rouge2_scores.append(rougescore['rouge2_fmeasure'].numpy()[0])
#     rougeL_scores.append(rougescore['rougeL_fmeasure'].numpy()[0])
    
    


    #pdb.set_trace()

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 640/640 [02:43<00:00,  3.91it/s]


In [11]:
print("bert_score: {}".format(np.mean(bert_scores)))
print("bleu_score: {}".format(np.mean(bleu_scores)))
print("rouge1_score: {}".format(np.mean(rouge1_scores)))
print("rouge2_score: {}".format(np.mean(rouge2_scores)))
print("rougeL_score: {}".format(np.mean(rougeL_scores)))


bert_score: 86.09047032892704
bleu_score: 6.661188993161209
rouge1_score: 0.23278395041583275
rouge2_score: 0.11325451058290263
rougeL_score: 0.19008936096033716


In [12]:
rouge_score

{'rouge1': AggregateScore(low=Score(precision=0.18309859154929578, recall=0.52, fmeasure=0.27083333333333337), mid=Score(precision=0.18309859154929578, recall=0.52, fmeasure=0.27083333333333337), high=Score(precision=0.18309859154929578, recall=0.52, fmeasure=0.27083333333333337)),
 'rouge2': AggregateScore(low=Score(precision=0.11428571428571428, recall=0.3333333333333333, fmeasure=0.1702127659574468), mid=Score(precision=0.11428571428571428, recall=0.3333333333333333, fmeasure=0.1702127659574468), high=Score(precision=0.11428571428571428, recall=0.3333333333333333, fmeasure=0.1702127659574468)),
 'rougeL': AggregateScore(low=Score(precision=0.16901408450704225, recall=0.48, fmeasure=0.25), mid=Score(precision=0.16901408450704225, recall=0.48, fmeasure=0.25), high=Score(precision=0.16901408450704225, recall=0.48, fmeasure=0.25)),
 'rougeLsum': AggregateScore(low=Score(precision=0.16901408450704225, recall=0.48, fmeasure=0.25), mid=Score(precision=0.16901408450704225, recall=0.48, fmea

In [13]:
bertscore

NameError: name 'bertscore' is not defined

In [None]:
bertscores

In [None]:
np.mean(score_list)