In [1]:
'''
A (hopefully) Simple API for serving explanation score requests.

input_string = (
    f"{question} answer: {gold_label}. "
    + f" explanation: {abstr_expl}."
)

here are some example input strings:

If you feel like everything is spinning while climbing you are experiencing what? answer: vertigo. explanation: Vertigo is often experienced while climbing or at heights.
Where do you get clothes in a shopping bag? answer: retail store. explanation: For any large item where convenience is beneficial, one might go to a retail store, either a regular one or a big-box store like walmart.
Where should a cat be in a house? answer: floor. explanation: A cat should be on the floor, not on a rug.
'''
import pdb
import argparse
import torch
import transformers
import os
import tqdm
import numpy as np

_model, _tokenizer = None, None

model2url = {
    'large': 'https://storage.googleapis.com/ai2-mosaic-public/projects/few-shot-explanations/pretrained_models/commonsense_qa/valloss%3D0.28665~model%3Dt5-large~lr%3D0.0001~seed%3D1~labelagg%3D0_just_weights.pt',
    '3b': 'https://storage.googleapis.com/ai2-mosaic-public/projects/few-shot-explanations/pretrained_models/commonsense_qa/valloss%3D0.28925~model%3Dt5-3b~lr%3D0.0001~seed%3D1~labelagg%3D0_just_weights.pt',
    '11b': 'https://storage.googleapis.com/ai2-mosaic-public/projects/few-shot-explanations/pretrained_models/commonsense_qa/cose_deepspeed_valloss%3D0.00000~model%3Dt5-11b~lr%3D0.00001~seed%3D1~labelagg%3D0.pt',
}

def get_model(model_type, device=None):
    global _model, model2url
    if model_type not in {'11b', '3b', 'large'}:
        raise NotImplementedError('{} is not a valid model please use "3b" or "large"'.format(model_type))

    if _model is None:
        hf_model_name = 't5-' + model_type
        print('Loading model: this will run only once.')

        if model_type == 'large':
            model_path = 'csqa_models/t5-large.pt'
        elif model_type == '3b':
            model_path = 'csqa_models/valloss=0.28925~model=t5-3b~lr=0.0001~seed=1~labelagg=0_just_weights.pt'
        elif model_type == '11b':
            model_path = 'csqa_models/cose_deepspeed_valloss=0.00000~model=t5-11b~lr=0.00001~seed=1~labelagg=0.pt'

        if not os.path.exists(model_path):
            print('Please download weights for {} model and put in current directory.'.format(model_path))
            print('for example, wget {}'.format(model2url[model_type]))
            quit()

        state = torch.load(model_path)
        if 'model_state_dict' in state:
            state = state['model_state_dict']

        _model = transformers.AutoModelForSeq2SeqLM.from_pretrained(hf_model_name)
        if model_type == '11b': # need to resize due to deepspeed, these entires are not accessed.
            _model.resize_token_embeddings(len(transformers.AutoTokenizer.from_pretrained(hf_model_name)))
        _model.load_state_dict(state)
        _model.eval()
        if device is not None:
            _model = _model.to(device)

    return _model


def get_tokenizer(model_type):
    global _tokenizer
    if model_type not in {'3b', 'large', '11b'}:
        raise NotImplementedError('{} is not a valid model please use "3b" or "large" or "11b"'.format(model_type))

    if _tokenizer is None:
        hf_model_name = 't5-' + model_type
        _tokenizer = transformers.T5TokenizerFast.from_pretrained(hf_model_name)

    return _tokenizer


class T5Dataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    def __getitem__(self, idx):
        res = self.tokenizer(self.data[idx]['input'], truncation=True)
        res['labels'] = self.tokenizer(self.data[idx]['label']).input_ids
        return res

    def __len__(self):
        return len(self.data)


def get_scores(inputs, model_type, device=None, batch_size=32, verbose=False):
    '''
    Inputs:
      - a list of explanations to score, e.g.,:
        premise: A man getting a tattoo on his back. hypothesis: A woman is getting a tattoo. answer: contradiction. explanation: Because the tattoo artist is a man, the person getting the tattoo is not a woman.
      - model type, either "3b" or "large" or "11b"
      - device: which torch device to load model on, e.g., "cuda:3"
    Outputs:
      - P(good explanation); higher is better
    '''
    assert model_type in {'large', '3b', '11b'}

    if isinstance(inputs, str):
        inputs = [inputs]

    model = get_model(model_type, device=device)
    tokenizer = get_tokenizer(model_type)

    score_itr = T5Dataset([{'input': inp, 'label': 'x'} for inp in inputs], tokenizer) # dummy labels for inference
    data_collator = transformers.DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        label_pad_token_id=-100,
#         return_tensors='pt'
    )
    score_itr = torch.utils.data.DataLoader(score_itr, shuffle=False, collate_fn=data_collator, batch_size=batch_size)
    score_itr = score_itr if not verbose else tqdm.tqdm(score_itr, total=len(score_itr))

    good_idx, bad_idx = tokenizer('good').input_ids[0], tokenizer('bad').input_ids[0]
    scores = []
    
    with torch.no_grad():
        for batch in score_itr:
            if device is not None:
                input_ids, attention_mask, targets = batch['input_ids'].to(device), batch['attention_mask'].to(device), batch['labels'].to(device)
            model_output = model(input_ids=input_ids, attention_mask=attention_mask, labels=targets)
            logits_pos = model_output['logits'][:, 0, good_idx].cpu().numpy()
            logits_neg = model_output['logits'][:, 0, bad_idx].cpu().numpy()
            exp_logit_pos, exp_logit_neg = np.exp(logits_pos), np.exp(logits_neg)
            score = list([float(x) for x in exp_logit_pos / (exp_logit_pos + exp_logit_neg)])
            #pdb.set_trace()
            scores.extend(score)
    return scores


# def parse_args():
#     '''
#     Optional args for main function, mostly just to test.
#     '''
#     parser = argparse.ArgumentParser()
#     parser.add_argument(
#         'model_type',
#         default='large',
#         choices={'large', '3b', '11b'})
#     parser.add_argument(
#         '--batch_size',
#         default=32,
#         type=int)

#     args = parser.parse_args(['--batch_size', '1'])
#     return args

In [2]:
# args = parse_args()
# parser = argparse.ArgumentParser()
# parser.add_argument(
#     'model_type',
#     default='large',
#     choices={'large', '3b', '11b'})
# parser.add_argument(
#     '--batch_size',
#     default=32,
#     type=int)

# args = parser.parse_args(["--model_type", "3b"])
# args.device = 'cpu'#'cuda' if torch.cuda.is_available() else 'cpu'
np.random.seed(1)
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# scores = get_scores(
#     ['If you feel like everything is spinning while climbing you are experiencing what? answer: vertigo. explanation: Vertigo is often experienced while climbing or at heights.',
#      'Where do you get clothes in a shopping bag? answer: retail store. explanation: For any large item where convenience is beneficial, one might go to a retail store, either a regular one or a big-box store like walmart.',
#      'Where should a cat be in a house? answer: floor. explanation: A cat should be on the floor, not on a rug.'],
#     'large',
#     device='cuda:0',
#     batch_size=1,
#     verbose=False)
# print(scores)


In [3]:
import json
from tqdm import tqdm
# with open("../../scripts/results/dev_rationale_pair.json") as f:
#     rationale_pair_dev_data = json.load(f)
import json
file_path = "../../scripts/results/24shots_cose_t5_3b_chatgpt_rationales_generator_test_rationale_pair.json"
with open(file_path, 'r') as f:
    rationale_pair_dev_data = json.load(f)
    

In [4]:
rationale_pair_dev_data[0].keys(), len(rationale_pair_dev_data)

(dict_keys(['id', 'question', 'choices', 'answer', 'abstractive_explanation', 'extractive_explanation', 'our_explanation', 'input_ids', 'attention_mask', 'labels', 'decoder_attention_mask', 'question_encoding', 'common_expl_list', 'generated_explanation']),
 201)

In [5]:
for da in rationale_pair_dev_data:
    print(da['abstractive_explanation'])

book store book
bringing legal action results in arguments and aggravation.
boredom - wikipedia
knowledge is a familiarity,
canada is located to the north of usa
sam spent most of his time standing up. his job was supermarket never got any rest. but he was the best cashier at his workplace. where might he work
rivers flow trough valleys.
one-day leave application samples -
if a referee isn't sure of something he will be hired for he can still be sure of this.
when something bounces of a wall it comes back towards the person that threw it.
health complications
tests have equations.
movers - local & long distance moving services | moving.com
rivers flow trough valleys.
the first world war saw large-scale chemical warfare.
disease often spread but shouldn't be hospital.
trust company - drop to zero - youtube
a food movie times
files are kept in a drawer
i have money
gate house is in a subdivision.
this word was most relevant.
rivers flow trough valleys.
this word is more relevant
b.loose®

In [17]:
#rationale unlabeled data construction 991
import pdb 
import pandas as pd
import datasets
scr_csqa_unlabeled_test_file="/cognitive_comp/huangyongfeng/evaluate_LM_with_rationalization/few_shot_explanations/data/acceptability_annotations/commonsenseqa_test.csv"
fse_csqa_dev_dataset = datasets.load_dataset('csv', data_files=scr_csqa_unlabeled_test_file)
scr_csqa_unlabeled_test_df=pd.read_csv(scr_csqa_unlabeled_test_file)
fse_csqa_dev_data_dict={}
for kk, da in enumerate(fse_csqa_dev_dataset['train']):
    #pdb.set_trace()
    id_=da['Input.id']
    if da['Answer.acceptable']:
        answer_accept=set(da['Answer.acceptable'].split('|'))
    else:
        answer_accept=set()
    explanation_list=[da['Input.explanation_1'], 
                      da['Input.explanation_2'],
                      da['Input.explanation_3'],
                      da['Input.explanation_4'],
                      da['Input.explanation_5']]
    sample_list=[da['Input.source_sample_1'], 
                      da['Input.source_sample_2'],
                      da['Input.source_sample_3'],
                      da['Input.source_sample_4'],
                      da['Input.source_sample_5']]
    if id_ not in fse_csqa_dev_data_dict.keys():
        fse_csqa_dev_data_dict[id_]={"index":kk,"id":id_,"question":da["Input.question"],'answer':da['Input.gold_label'],"accept_set_list":[answer_accept],"explanation_list":explanation_list,"sample_list":sample_list,"greedy_expanation":explanation_list[sample_list.index('4')]}

        #[[kk, id_, answer_accept, explanation_list]]
    else:
        fse_csqa_dev_data_dict[id_]["accept_set_list"].append(answer_accept)
    if len(fse_csqa_dev_data_dict[id_]["accept_set_list"])==3:
        fse_csqa_dev_data_dict[id_]["common_expl_list"]=[]
        common_accept_expl_sample=set.intersection(fse_csqa_dev_data_dict[id_]["accept_set_list"][0], 
                                                   fse_csqa_dev_data_dict[id_]["accept_set_list"][1], 
                                                   fse_csqa_dev_data_dict[id_]["accept_set_list"][2])
        for idx in list(common_accept_expl_sample):
            idx=int(idx)-1
            fse_csqa_dev_data_dict[id_]["common_expl_list"].append(fse_csqa_dev_data_dict[id_]["explanation_list"][idx])
        
#discriminate 3/3 id
id_accept_expl_list=[]
our_accept_expl_list=[]
id_unaccept_expl_list=[]
for k,v in fse_csqa_dev_data_dict.items():
    accept_set_list=v['accept_set_list']
    assert len(accept_set_list)==3
    #pdb.set_trace()
    common_accept_expl_sample=set.intersection(accept_set_list[0], accept_set_list[1], accept_set_list[2])
    #print(accept_set_list,common_accept_expl_sample)
    #pdb.set_trace()
    if common_accept_expl_sample:
        id_accept_expl_list.append(k)
        our_expl=""
        for idx in list(common_accept_expl_sample):
            idx=int(idx)-1
            if len(v['explanation_list'][idx]) > len(our_expl):
                our_expl = v['explanation_list'][idx]
                our_accept_expl_list.append(our_expl)
    else:
        id_unaccept_expl_list.append(k)


#rationale unlabeled data construction 991
import pdb
scr_csqa_unlabeled_train_file="/cognitive_comp/huangyongfeng/evaluate_LM_with_rationalization/few_shot_explanations/data/acceptability_annotations/commonsenseqa_train.csv"
fse_csqa_train_dataset = datasets.load_dataset('csv', data_files=scr_csqa_unlabeled_train_file)
scr_csqa_unlabeled_train_df=pd.read_csv(scr_csqa_unlabeled_train_file)
fse_csqa_train_data_dict={}
for kk, da in enumerate(fse_csqa_train_dataset['train']):
#     pdb.set_trace()
    id_=da['Input.id']
    if da['Answer.acceptable']:
        answer_accept=set(da['Answer.acceptable'].split('|'))
    else:
        answer_accept=set()
    explanation_list=[da['Input.explanation_1'], da['Input.explanation_2'],da['Input.explanation_3'],da['Input.explanation_4'],da['Input.explanation_5']]
    sample_list=[da['Input.source_sample_1'], da['Input.source_sample_2'],da['Input.source_sample_3'],da['Input.source_sample_4'],da['Input.source_sample_5']]
    if id_ not in fse_csqa_train_data_dict.keys():
        fse_csqa_train_data_dict[id_]={"index":kk,"id":id_,"question":da["Input.question"],'answer':da['Input.gold_label'],"accept_set_list":[answer_accept],"explanation_list":explanation_list,"sample_list":sample_list,"greedy_expanation":explanation_list[sample_list.index('4')]}
        #[[kk, id_, answer_accept, explanation_list]]
    else:
        fse_csqa_train_data_dict[id_]["accept_set_list"].append(answer_accept)
        
    if len(fse_csqa_train_data_dict[id_]["accept_set_list"])==3:
        fse_csqa_train_data_dict[id_]["common_expl_list"]=[]
        common_accept_expl_sample=set.intersection(fse_csqa_train_data_dict[id_]["accept_set_list"][0], 
                                                   fse_csqa_train_data_dict[id_]["accept_set_list"][1], 
                                                   fse_csqa_train_data_dict[id_]["accept_set_list"][2])
        for idx in list(common_accept_expl_sample):
            idx=int(idx)-1
            fse_csqa_train_data_dict[id_]["common_expl_list"].append(fse_csqa_train_data_dict[id_]["explanation_list"][idx])

    


Using custom data configuration default-ef5bde298eedd1cc
Reusing dataset csv (/home/huangyongfeng/.cache/huggingface/datasets/csv/default-ef5bde298eedd1cc/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

ValueError: '4' is not in list

In [None]:
from tqdm import tqdm
qae_list = []
score_list = []
for da in tqdm(rationale_pair_dev_data, total=len(rationale_pair_dev_data)):
    id_ = da['id']
    if id_ in fse_csqa_dev_data_dict.keys():
        explanation_list = fse_csqa_dev_data_dict[id_]['explanation_list']
        greedy_expanation = fse_csqa_dev_data_dict[id_]['greedy_expanation']
    else:
        explanation_list = fse_csqa_train_data_dict[id_]['explanation_list']
        greedy_expanation = fse_csqa_train_data_dict[id_]['greedy_expanation']
    qae = "{} answer: {} explanation: {}".format(da['question'], 
                                                 da['answer'], 
                                                 greedy_expanation) 
    
    

    scores = get_scores(
        [qae],
        '3b',
        device='cuda:0',
        batch_size=1,
        verbose=False)
    score_list.append(np.max(scores))
#     if scores[0] > 0.7 or scores[0] < 0.2:
#         print("question: {}".format(da['question']))
#         print("answer: {}".format(da['answer']))
#         print("common_expl_list: {}".format(da['common_expl_list']))
#         print("generated_explanation: {}".format(da['generated_explanation']))
#         print("score: {}".format(scores[0]))

In [15]:
np.mean(score_list), np.median(score_list)

(0.5661890743142781, 0.6518362164497375)

In [9]:
from tqdm import tqdm
qae_list = []
score_list = []
for da in tqdm(rationale_pair_dev_data, total=len(rationale_pair_dev_data)):
    id_ = da['id']
    if id_ in fse_csqa_dev_data_dict.keys():
        explanation_list = fse_csqa_dev_data_dict[id_]['explanation_list']
    else:
        explanation_list = fse_csqa_train_data_dict[id_]['explanation_list']
    qae = ["{} answer: {} explanation: {}".format(da['question'], 
                                                 da['answer'], 
                                                 expl) for expl in explanation_list]
    
    

    scores = get_scores(
        qae,
        '3b',
        device='cuda:0',
        batch_size=1,
        verbose=False)
    score_list.append(np.max(scores))
#     if scores[0] > 0.7 or scores[0] < 0.2:
#         print("question: {}".format(da['question']))
#         print("answer: {}".format(da['answer']))
#         print("common_expl_list: {}".format(da['common_expl_list']))
#         print("generated_explanation: {}".format(da['generated_explanation']))
#         print("score: {}".format(scores[0]))

 78%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                                     | 156/201 [00:36<00:10,  4.33it/s]


KeyboardInterrupt: 

In [None]:
np.mean(score_list), np.median(score_list)

## evaluate generated rationale with bert-score

In [None]:
import datasets
import numpy as np
bertscore_metric = datasets.load_metric("bertscore")
rouge_metric = datasets.load_metric('rouge')
bleu_metric = datasets.load_metric('sacrebleu')

In [None]:
import pdb

bert_scores = []
bleu_scores = []
rouge1_scores = []
rouge2_scores = []
rougeL_scores = []

for da in tqdm(rationale_pair_dev_data, total=len(rationale_pair_dev_data)):
    generated_expl = da['generated_explanation']
    common_expl_list = da['common_expl_list']
    pred_expl = generated_expl.split("<extra_id_0> ")[1].split("<extra_id_1>")[0]
    list_gold_expl = [l.lower() for l in common_expl_list]
    
    bert_score = bertscore_metric.compute(predictions=[pred_expl.lower()], references=[list_gold_expl], lang="en")["f1"][0]*100
    bleu_score = bleu_metric.compute(predictions=[pred_expl.lower()], references=[list_gold_expl])['score']
    rouge_score = rouge_metric.compute(predictions=[pred_expl.lower()]*len(list_gold_expl), references=list_gold_expl)
    rouge1_score = rouge_score["rouge1"].mid.fmeasure
    rouge2_score = rouge_score["rouge2"].mid.fmeasure
    rougeL_score = rouge_score["rougeL"].mid.fmeasure
    bert_scores.append(bert_score)
    bleu_scores.append(bleu_score)
    rouge1_scores.append(rouge1_score)
    rouge2_scores.append(rouge2_score)
    rougeL_scores.append(rougeL_score)
    
    
#     #print(generated_expl)
#     #print(generated_expl.split("<extra_id_0> ")[1].split("<extra_id_1>")[0])
#     instance_bertscores = []
#     for gold_expl in list_gold_expl: 
#         score = bertscore_metric.compute(predictions=[pred_expl.lower()]*len(), references=[gold_expl.lower()], lang="en")["f1"][0]*100
#         instance_bertscores.append(score)
#     bertscores.append(np.mean(instance_bertscores))
    
#     bleuscore = bleu_score(pred_expl, list_gold_expl)
#     bleuscores.append(bleuscore)
    
#     rougescore = rouge(pred_expl, list_gold_expl)
#     rouge1_scores.append(rougescore['rouge1_fmeasure'].numpy()[0])
#     rouge2_scores.append(rougescore['rouge2_fmeasure'].numpy()[0])
#     rougeL_scores.append(rougescore['rougeL_fmeasure'].numpy()[0])
    
    


    #pdb.set_trace()

In [None]:
print("bert_score: {}".format(np.mean(bert_scores)))
print("bleu_score: {}".format(np.mean(bleu_scores)))
print("rouge1_score: {}".format(np.mean(rouge1_scores)))
print("rouge2_score: {}".format(np.mean(rouge2_scores)))
print("rougeL_score: {}".format(np.mean(rougeL_scores)))


In [None]:
rouge_score

In [None]:
bertscore

In [None]:
bertscores

In [None]:
np.mean(score_list)