In [1]:
# import wandb
# wandb.login()

# %env WANDB_PROJECT=evaluate_LM_with_rationalization

In [2]:
# import gpt3
import logging
import math
import os
import pdb
from typing import List, Dict, Any, NewType

# Opaque alias over `Any`; it is only used to annotate the collator's input examples.
InputDataClass = NewType("InputDataClass", Any)
# Expose only GPU index 2 to CUDA. Set here so it is in place before the
# transformers/torch imports that follow.
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from transformers import (
    T5Config,
    T5ForConditionalGeneration,
    T5Tokenizer,
    HfArgumentParser,
    TrainingArguments,
    set_seed,
    EarlyStoppingCallback
)
from transformers.trainer_utils import EvaluationStrategy
from transformers.integrations import TensorBoardCallback
import transformers
from transformers import Trainer

#from feature_conversion_methods import format_instance

from custom_args import (
    DataTrainingArguments,
    ModelArguments
)
from metrics import evaluate
import torch
import datasets
import git
import time
from datetime import datetime
import sys
from tqdm import trange
import random 
import pandas as pd 
import jsonlines
from copy import deepcopy 

# Notebook-level logger, plus INFO verbosity for the transformers library's own logging.
logger = logging.getLogger(__name__)
transformers.logging.set_verbosity_info()
import re
def set_global_logging_level(level=logging.ERROR, prefices=("",)):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.

    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: iterable of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
          Default is `("",)` to match all active loggers. An empty iterable also matches all loggers.
          The match is a case-sensitive, literal `module_name.startswith(prefix)`; characters such
          as "." are NOT treated as regex wildcards.
    """
    # An empty collection historically matched every logger; keep that behaviour.
    wanted = tuple(prefices) or ("",)
    # Snapshot the names so the registry is not iterated while loggers are being resolved.
    for name in list(logging.root.manager.loggerDict):
        if name.startswith(wanted):
            logging.getLogger(name).setLevel(level)
# Raise the level of every logger whose name starts with "datasets" to ERROR.
set_global_logging_level(logging.ERROR, ["datasets"])


# Lookup tables from a model-family key to the matching config / model / tokenizer class.
CONFIG_MAPPING = {"t5": T5Config}
MODEL_MAPPING = {"t5": T5ForConditionalGeneration}
TOKENIZER_MAPPING = {"t5": T5Tokenizer}


def set_other_seeds(seed):
    """Apply two extra reproducibility-related settings for the given seed.

    Exports ``PYTHONHASHSEED`` (as a string) into the process environment and
    switches off the cuDNN benchmark mode. ``cudnn.deterministic`` is not
    modified here.

    Args:
        seed: value written to ``PYTHONHASHSEED`` via ``str(seed)``.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.backends.cudnn.benchmark = False

# inspired by DefaultDataCollator from:
# https://github.com/huggingface/transformers/blob/master/src/transformers/data/data_collator.py
# modified to perform batch-level padding.
class SequenceCollator:
    """Collate a list of feature dicts into a dict of ``torch.long`` tensors.

    Only the keys listed in ``self.columns`` are kept. List-valued fields are
    right-padded to the longest sequence in the batch, using the per-field fill
    value from ``self.pad_token_mapping``.
    """

    def __init__(self, model, pad_token):
        # Stored but not read by this class.
        self.model = model
        # Fill value per field: -100 for labels, 0 for the two masks,
        # and the caller-provided pad id for the input ids.
        self.pad_token_mapping = {
            "labels": -100,
            "attention_mask": 0,
            "decoder_attention_mask": 0,
            "input_ids": pad_token,
        }

        # Fields that are copied into the batch; everything else is dropped.
        self.columns = ["input_ids", "attention_mask", "labels", "decoder_attention_mask"]

    def __call__(self, examples: List[Dict[str, InputDataClass]]) -> Dict[str, torch.Tensor]:
        """Build the batch; key order follows the keys of the first example."""
        batch = {}
        for field in examples[0].keys():
            if field not in self.columns:
                continue

            values = [example[field] for example in examples]

            # Pad only when the field holds sequences (decided by the first example).
            if isinstance(values[0], list):
                longest = max(len(seq) for seq in values)
                filler = self.pad_token_mapping[field]
                values = [seq + [filler] * (longest - len(seq)) for seq in values]

            batch[field] = torch.tensor(values, dtype=torch.long)
        return batch


In [3]:
from collections import defaultdict
import random
"""
Example-to-Feature conversion methods
Modified from
https://github.com/salesforce/cos-e/blob/master/code/generation/train_commonsenseqa_v1.0.py and ""_v1.11.py (identical)
as well as Tensorflow code for WT5?!: 
https://github.com/google-research/google-research/blob/master/wt5/wt5/preprocessors.py
"""
# This code is based on https://github.com/allenai/label_rationale_association/blob/main/feature_conversion_methods.py

# Label tables used by the formatting helpers below.
# e-SNLI label ids rendered as lower-case yes/maybe/no answers.
unified_qa_esnli_label_mapping = {0: 'yes', 1: 'maybe', 2: 'no'}
# Capitalised variant of the table above.
unified_qa_esnli_label_mapping_upper = {0: 'Yes', 1: 'Maybe', 2: 'No'} 
# e-SNLI label ids rendered as NLI class names.
wt5_esnli_label_mapping = {0: 'entailment', 1: 'neutral', 2: 'contradiction'} 
# SBIC `offensiveYN` values rendered as Yes/No answers.
unified_qa_sbic_label_mapping = {"offensive": 'Yes', "not offensive": 'No'}

def format_instance(
        example,
        tokenizer,
        explanation_sep,
        max_seq_length=None,
        datasource=None,
        io_format=None,
):
    """Turn one raw example into tokenised encoder/decoder features.

    The example is rendered to an (input, answer) string pair by the
    dataset-specific formatter, both strings are whitespace-normalised, and each
    is passed through ``tokenizer.encode_plus``.

    Args:
        example: mapping holding the raw fields of one instance.
        tokenizer: object exposing ``encode_plus``.
        explanation_sep: separator text placed between answer and explanation.
        max_seq_length: forwarded as ``max_length`` to both tokenizer calls.
        datasource: one of cos_e, esnli, sbic, sensemaking, ecqa.
        io_format: name of the prompt format understood by the chosen formatter.

    Returns:
        A new dict with every key of ``example`` plus the input encoding and
        ``labels``, ``decoder_attention_mask`` and ``question_encoding``
        (the latter is the same object as ``input_ids``).

    Raises:
        AssertionError: if ``datasource`` is not one of the supported names.
    """
    assert datasource in {"cos_e", "esnli", "sbic", "sensemaking", "ecqa"}

    if datasource == "esnli":
        source_text, target_text = esnli_formatting(example, io_format, explanation_sep)
    elif datasource == 'sbic':
        source_text, target_text = sbic_formatting(example, io_format, explanation_sep)
    elif datasource == 'sensemaking':
        source_text, target_text = sensemaking_formatting(example, io_format, explanation_sep)
    elif datasource in ("cos_e", "ecqa"):
        source_text, target_text = cqa_formatting(example, io_format, explanation_sep, datasource)
    else:
        raise ValueError("Unknown task. Currently supported: esnli, cos_e, sbic, sensemaking, ecqa.")

    # "unified*" formats get an explicit end-of-sequence marker, except the "unifew" ones.
    wants_eos_marker = 'unified' in io_format and 'unifew' not in io_format
    if wants_eos_marker:
        source_text += '</s>'

    # Collapse every run of whitespace into a single space.
    source_text = ' '.join(source_text.split())
    target_text = ' '.join(target_text.split())

    # Both sides are tokenised with the same options; no padding happens here.
    tokenizer_options = dict(
        max_length=max_seq_length,
        pad_to_max_length=False,
        return_token_type_ids=False,
        return_attention_mask=True,
    )
    encodings = tokenizer.encode_plus(source_text, **tokenizer_options)
    dec = tokenizer.encode_plus(target_text, **tokenizer_options)

    encodings["labels"] = dec["input_ids"]
    encodings["decoder_attention_mask"] = dec["attention_mask"]
    encodings["question_encoding"] = encodings["input_ids"]

    return {**example, **encodings}

# Straightforward: each branch only needs to define the input and output strings.
def cqa_formatting(item, io_format, explanation_sep, datasource):
    """Render a multiple-choice QA example (cos_e / ecqa) as input and target strings.

    Args:
        item: mapping with "question" and "answer". The explanation is read from
            "abstractive_explanation" when ``datasource == 'cos_e'`` and from
            "explanation" otherwise. "choices" is required by every format except
            ``t5_fewshot_infilling_without_choices_use_refined_expl``, which
            instead reads "our_explanation".
        io_format: name of the prompt format (see the branches below).
        explanation_sep: separator text placed between answer and explanation.
        datasource: dataset name, interpolated into most prompts.

    Returns:
        Tuple ``(input_string, answer_string)``.

    Raises:
        ValueError: if ``io_format`` is not one of the supported formats.
    """
    question = item["question"]
    answer = item["answer"]
    abstr_expl = item["abstractive_explanation"].lower() if datasource == 'cos_e' else item["explanation"].lower()

    if io_format == 't5_fewshot_infilling_with_choices':
        input_string = f"explain {datasource} question: {question} choice: " + " choice: ".join(item["choices"]) + f" <extra_id_0> {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer} <extra_id_1> {abstr_expl} <extra_id_2>"
    elif io_format == 't5_fewshot_infilling_more_natural':
        input_string = f"explain {datasource} question: {question} choice: " + " choice: ".join(item["choices"]) + f" The answer is <extra_id_0> {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer} <extra_id_1> {abstr_expl} <extra_id_2>"
    elif io_format == "squad":
        # e.g. explain cos_e question: When getting in shape you need to have this in between workouts? context: give up, period of recovery, jogging
        input_string = f"explain {datasource} question: {question} context: " + ', '.join(item['choices'])
        # e.g. period of recovery because without a period of recovery you will not get any gains.
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"
    elif io_format == "record":
        # might not work because cos_e doesn't have a passage
        # e.g. explain cos_e query: When getting in shape you need to have this in between workouts? entities: give up, period of recovery, jogging
        input_string = f"explain {datasource} query: {question} entities: " + ', '.join(item['choices'])
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"
    elif io_format == 'unifiedqa_matching':
        choice_ids = ['(A)', '(B)', '(C)', '(D)', '(E)']
        # "\\n" is a literal backslash followed by "n" in the prompt text, not a newline.
        input_string = f'explain {question.lower()} \\n'
        for choice_id, choice in zip(choice_ids, item["choices"]):
            input_string += f' {choice_id} {choice.lower()}'
        # The whole target, separator included, is lower-cased.
        answer_string = f"{answer} {explanation_sep} {abstr_expl}".lower()
    elif io_format == 't5_fewshot_infilling_without_choices_use_refined_expl':
        # The prompt is built from question and answer only; the answer choices
        # are deliberately not read, so items without "choices" are accepted.
        input_string = f"explain {datasource} question: {question} answer: {answer}" + f" {explanation_sep} <extra_id_0>"
        answer_string = f"<extra_id_0> {item['our_explanation']} <extra_id_1>"
    else:
        raise ValueError(f"The IO format {io_format!r} is not supported for {datasource}.")

    return input_string, answer_string


def esnli_formatting(item, io_format, explanation_sep):

    premise = item["premise"]
    hypothesis = item["hypothesis"]
    answer = unified_qa_esnli_label_mapping[item["label"]] if 'unified' in io_format else wt5_esnli_label_mapping[item["label"]]
    abstr_expl = item["explanation_1"].lower() 
    # Dev/test instances have more than one explanation annotated; merge them into one sequence separated by [SEP] 
    for k in [2,3]:
        if f"explanation_{k}" in item and item[f'explanation_{k}']!='': 
            abstr_expl += f" [SEP] {item[f'explanation_{k}'].lower()}"

    if io_format == 'standard':
        input_string = f"explain nli hypothesis: {hypothesis} premise: {premise}"
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"
    elif io_format == 't5_fewshot_infilling':
        input_string = f"explain nli hypothesis: {hypothesis} premise: {premise} <extra_id_0> {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer} <extra_id_1> {abstr_expl} <extra_id_2>"
    elif io_format == 't5_fewshot_infilling_more_natural':
        input_string = f"explain nli hypothesis: {hypothesis} premise: {premise} This is <extra_id_0> {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer} <extra_id_1> {abstr_expl} <extra_id_2>"
    elif io_format == "squad": 
        input_string = f"explain nli question: Is this entailment? context: {hypothesis} {premise}"  
        answer_ynm = unified_qa_esnli_label_mapping[item["label"]]
        answer_string = f"{answer_ynm} {explanation_sep} {abstr_expl}" 
    elif io_format == "squad_endswith_what":
        input_string = f"explain nli question: What is this? context: {hypothesis} {premise}"  
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"  
    elif io_format == "squad_nli_mix": 
        input_string = f"explain nli question: Is this entailment? context: hypothesis: {hypothesis} premise: {premise}"  
        answer_ynm = unified_qa_esnli_label_mapping[item["label"]]
        answer_string = f"{answer_ynm} {explanation_sep} {abstr_expl}"  
    elif io_format == "squad_nli_mix_endswith_what":  
        input_string = f"explain nli question: What is this? context: hypothesis: {hypothesis} premise: {premise}"  
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"   
    elif io_format == 'unifiedqa_unifew':
        hypothesis = hypothesis.lower().rstrip('.')
        unified_qa_esnli_label_mapping_upper = {0: 'Yes', 1: 'Maybe', 2: 'No'}
        answer = unified_qa_esnli_label_mapping_upper[item["label"]]
        input_string = f'explain {premise} Is {hypothesis}? \\n (A) Yes (B) Maybe (C) No'
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"  
    elif io_format == 'unifiedqa_unifew_nli_mix':
        premise = premise.lower().rstrip('.')
        unified_qa_esnli_label_mapping_upper = {0: 'Yes', 1: 'Maybe', 2: 'No'}
        input_string = f'explain hypothesis: {hypothesis} Is premise: {premise}? \\n (A) Yes (B) Maybe (C) No'
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"  
    elif io_format == 'unifiedqa_ynm': 
        input_string = f'explain is this entailment? \\n {hypothesis.lower()} {premise.lower()}'  
        answer = unified_qa_esnli_label_mapping[item["label"]]
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"  
    elif io_format == 'unifiedqa_snli_mix_ynm': 
        input_string = f'explain is this entailment? \\n hypothesis: {hypothesis.lower()} premise: {premise.lower()}' 
        answer = unified_qa_esnli_label_mapping[item["label"]]
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"  
    elif io_format == 'unifiedqa_snli_mix_ynm_with_choices': 
        input_string = f'explain is this entailment? \\n (A) yes (B) maybe (C) no \\n hypothesis: {hypothesis.lower()} premise: {premise.lower()}'  
        answer = unified_qa_esnli_label_mapping[item["label"]]
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"  
    elif io_format == 'unifiedqa_what_v2': 
        input_string = f'explain what is this? \\n {hypothesis.lower()} {premise.lower()}'  
        answer = wt5_esnli_label_mapping[item["label"]]
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"  
    elif io_format == 'unifiedqa_snli_mix_what_v2': 
        input_string = f'explain what is this? \\n hypothesis: {hypothesis.lower()} premise: {premise.lower()}'  
        answer = wt5_esnli_label_mapping[item["label"]]
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"  
    elif io_format == 'unifiedqa_snli_mix_what_with_choices_v2': 
        input_string = f'explain what is this? \\n (A) entailment (B) neutral (C) contradiction \\n hypothesis: {hypothesis.lower()} premise: {premise.lower()}'  
        answer = wt5_esnli_label_mapping[item["label"]]
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 't5_fewshot_infilling_without_choices_use_refined_expl':
        answer = wt5_esnli_label_mapping[item["label"]]
        input_string = f'explain why the relation is {answer} between hypothesis: {hypothesis.lower()} and premise: {premise.lower()}'  
        #input_string = f"explain {datasource} question: {question} answer: {answer}" + f" {explanation_sep} <extra_id_0>"
        answer_string = f"<extra_id_0> {item['our_explanation']} <extra_id_1>"
        #pdb.set_trace()
    else:
        raise ValueError("The IO format is not supported.")

    return input_string, answer_string


def sbic_formatting(item, io_format, explanation_sep):
    """Render an SBIC example as whitespace-normalised input and target strings.

    Args:
        item: mapping with "post", "offensiveYN" ("offensive" / "not offensive")
            and "targetStereotype" (used as the explanation text).
        io_format: name of the prompt format (see the branches below).
        explanation_sep: separator text placed between answer and explanation.

    Returns:
        Tuple ``(input_string, answer_string)`` with runs of whitespace collapsed.

    Raises:
        ValueError: if ``io_format`` is not one of the supported formats.
    """
    # We pre-processed the SBIC dataset such that we join multiple implied statements with the [SEP] token for dev/test instances
    # Each annotation in the training split is a separate instance
    post = item['post']
    # Default answer wording: Yes/No for "unified*" formats, offensive/not_offensive otherwise.
    answer = unified_qa_sbic_label_mapping[item["offensiveYN"]] if 'unified' in io_format else item["offensiveYN"].replace("not offensive", "not_offensive")
    abstr_expl = item["targetStereotype"]

    # NOTE: "\\n" in the prompts below is a literal backslash followed by "n", not a newline.
    if io_format == 'cola':
        input_string = f"explain sbic sentence: {post}"
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"
    elif io_format == 't5_fewshot_infilling':
        input_string = f"explain sbic sentence: {post} <extra_id_0> {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer} <extra_id_1> {abstr_expl} <extra_id_2>"
    elif io_format == 't5_fewshot_infilling_bool':
        # The blank in "It is ___ that this is offensive" is filled with True/False,
        # so the target carries the boolean rather than the class label.
        answer_bool = 'True' if item["offensiveYN"] == 'offensive' else 'False'
        input_string = f"explain sbic sentence: {post} It is <extra_id_0> that this is offensive {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer_bool} <extra_id_1> {abstr_expl} <extra_id_2>"
    elif io_format == 't5_fewshot_infilling_more_natural':
        input_string = f"explain sbic sentence: {post} This is <extra_id_0> {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer} <extra_id_1> {abstr_expl} <extra_id_2>"
    elif io_format == 'squad_yn':
        # e.g. explain sbic question: Is this offensive? context: It's only happened to you like two times.
        input_string = f"explain sbic question: Is this offensive? context: {post}"
        answer_yn = unified_qa_sbic_label_mapping[item["offensiveYN"]]
        # e.g. No because this post does not imply anything offensive
        answer_string = f"{answer_yn} {explanation_sep} {abstr_expl}"
    elif io_format == 'squad_yn_with_tags':
        input_string = f"explain sbic question: Is this offensive? context: post: {post}"
        answer_yn = unified_qa_sbic_label_mapping[item["offensiveYN"]]
        answer_string = f"{answer_yn} {explanation_sep} {abstr_expl}"
    elif io_format == 'squad_what':
        input_string = f"explain sbic question: What is this? context: {post}"
        answer = item["offensiveYN"].replace("not offensive", "not_offensive")
        # e.g. not_offensive because this post does not imply anything offensive
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"
    elif io_format == 'squad_what_with_tags':
        input_string = f"explain sbic question: What is this? context: post: {post}"
        answer = item["offensiveYN"].replace("not offensive", "not_offensive")
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"
    elif io_format == 'unified_qa_yn':
        input_string = f'explain is this offensive? \\n {post.lower()}'
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unified_qa_yn_with_tags':
        input_string = f'explain is this offensive? \\n post: {post.lower()}'
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unified_qa_yn_with_choices':
        input_string = f'explain is this offensive? \\n (A) yes (B) no \\n {post.lower()}'
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unified_qa_yn_with_choices_and_tags':
        input_string = f'explain is this offensive? \\n (A) yes (B) no \\n post: {post.lower()}'
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unified_qa_what':
        input_string = f'explain what is this? \\n {post.lower()}'
        answer = item["offensiveYN"].replace("not offensive", "not_offensive")
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unified_qa_what_with_tags':
        input_string = f'explain what is this? \\n post: {post.lower()}'
        answer = item["offensiveYN"].replace("not offensive", "not_offensive")
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unified_qa_what_with_choices':
        input_string = f'explain what is this? \\n (A) offensive (B) not_offensive \\n {post.lower()}'
        answer = item["offensiveYN"].replace("not offensive", "not_offensive")
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unified_qa_what_with_choices_and_tags':
        input_string = f'explain what is this? \\n (A) offensive (B) not_offensive \\n post: {post.lower()}'
        answer = item["offensiveYN"].replace("not offensive", "not_offensive")
        answer_string = f"{answer} {explanation_sep} {abstr_expl.lower()}"
    elif io_format == 'unifiedqa_unifew':
        input_string = f"Topic? \\n (A) offensive (B) not_offensive \\n {post}"
        answer = item["offensiveYN"].replace("not offensive", "not_offensive")
        answer_string = f"{answer} {explanation_sep} {abstr_expl}"
    else:
        raise ValueError(f"The IO format {io_format!r} is not supported for sbic.")

    # Collapse every run of whitespace into a single space.
    input_string = ' '.join(input_string.split())
    answer_string = ' '.join(answer_string.split())
    return input_string, answer_string

def sensemaking_formatting(item, io_format, explanation_sep):
    """Render a sensemaking example as whitespace-normalised input and target strings.

    Args:
        item: mapping with the two sentences "sent0" / "sent1", a "label" that
            ``int()`` accepts (0 -> choice1, 1 -> choice2 is the nonsensical
            sentence) and a free-text "explanation".
        io_format: name of the prompt format (see the branches below).
        explanation_sep: separator text placed between answer and explanation.

    Returns:
        Tuple ``(input_string, answer_string)`` with runs of whitespace collapsed.

    Raises:
        ValueError: if ``io_format`` is not one of the supported formats.
    """
    # TODO: explore whether removing periods makes difference?
    sent0 = item['sent0']
    sent1 = item['sent1']
    label = int(item['label'])
    # 1-based index of the nonsensical sentence, as text ("1" or "2").
    nonsensical_sentence = str(label + 1)
    explanation = item['explanation'].lower()

    # NOTE: "\\n" in the prompts below is a literal backslash followed by "n", not a newline.
    if io_format == 'copa_with_question':
        input_string = f"explain sensemaking choice1: {sent0} choice2: {sent1} question: nonsensical"
        answer_string = f"choice{nonsensical_sentence} {explanation_sep} {explanation}"
    elif io_format == 'copa_bool':
        answer_bool = str(bool(label))  # True if choice2 is more nonsensical
        input_string = f"explain sensemaking choice1: {sent0} choice2: {sent1} Less common is choice2"
        answer_string = f"{answer_bool} {explanation_sep} {explanation}"
    elif io_format == 't5_fewshot_infilling':
        input_string = f"explain sensemaking choice1: {sent0} choice2: {sent1} <extra_id_0> {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> choice{nonsensical_sentence} <extra_id_1> {explanation} <extra_id_2>"
    elif io_format == 't5_fewshot_infilling_bool':
        answer_bool = str(bool(label))  # True if choice2 is more nonsensical
        input_string = f"explain sensemaking choice1: {sent0} choice2: {sent1} It is <extra_id_0> that choice2 is less common {explanation_sep} <extra_id_1>"
        answer_string = f"<extra_id_0> {answer_bool} <extra_id_1> {explanation} <extra_id_2>"
    elif io_format == "squad_yn":
        input_string = f"explain sensemaking question: Is choice2 more nonsensical? context: choice1: {sent0} choice2: {sent1}"
        answer = "Yes" if bool(label) else "No"
        answer_string = f"{answer} {explanation_sep} {explanation}"
    elif io_format == "squad_yn_no_tags":
        input_string = f"explain sensemaking question: Is choice2 more nonsensical? context: {sent0} {sent1}"
        answer = "Yes" if bool(label) else "No"
        answer_string = f"{answer} {explanation_sep} {explanation}"
    elif io_format == "squad_what":
        # e.g. context: choice1: All state flowers are the scarlet carnation. choice2: The New Jersey state flower is the scarlet carnation
        input_string = f"explain sensemaking question: What is more nonsensical? context: choice1: {sent0} choice2: {sent1}"
        # e.g. choice1 because state flowers are unique to each state.
        answer_string = f"choice{nonsensical_sentence} {explanation_sep} {explanation}"
    elif io_format == "squad_what_no_tags":
        input_string = f"explain sensemaking question: What is more nonsensical? context: {sent0} {sent1}"
        answer_string = f"choice{nonsensical_sentence} {explanation_sep} {explanation}"
    elif io_format == "record":
        input_string = f"explain sensemaking query: What is more nonsensical? entities: choice1, choice2 passage: choice1: {sent0} choice2: {sent1}"
        answer_string = f"choice{nonsensical_sentence} {explanation_sep} {explanation}"
    elif io_format == 'unifiedqa_yn_with_choices':
        answer = "yes" if bool(label) else "no"
        input_string = f'explain is choice2 more nonsensical? \\n (A) yes (B) no \\n choice1: {sent0.lower()} choice2: {sent1.lower()}'
        answer_string = f"{answer} {explanation_sep} {explanation.lower()}"
    elif io_format == 'unifiedqa_yn':
        answer = "yes" if bool(label) else "no"
        input_string = f'explain is choice2 more nonsensical? \\n choice1: {sent0.lower()} choice2: {sent1.lower()}'
        answer_string = f"{answer} {explanation_sep} {explanation.lower()}"
    elif io_format == 'unifiedqa_yn_no_tags':
        answer = "yes" if bool(label) else "no"
        input_string = f'explain is choice2 more nonsensical? \\n {sent0.lower()} {sent1.lower()}'
        answer_string = f"{answer} {explanation_sep} {explanation.lower()}"
    elif io_format == 'unifiedqa_what_with_choices':
        input_string = f'explain what is more nonsensical? \\n (A) choice1 (B) choice2 \\n choice1: {sent0.lower()} choice2: {sent1.lower()}'
        answer_string = f"choice{nonsensical_sentence} {explanation_sep} {explanation.lower()}"
    elif io_format == 'unifiedqa_what':
        input_string = f'explain what is more nonsensical? \\n choice1: {sent0.lower()} choice2: {sent1.lower()}'
        answer_string = f"choice{nonsensical_sentence} {explanation_sep} {explanation.lower()}"
    elif io_format == 'unifiedqa_what_no_tags':
        input_string = f'explain what is more nonsensical? \\n {sent0.lower()} {sent1.lower()}'
        answer_string = f"choice{nonsensical_sentence} {explanation_sep} {explanation.lower()}"
    else:
        # Fail with a clear error instead of an unbound-variable error further down.
        raise ValueError(f"The IO format {io_format!r} is not supported for sensemaking.")

    # Collapse every run of whitespace into a single space.
    input_string = ' '.join(input_string.split())
    answer_string = ' '.join(answer_string.split())
    return input_string, answer_string

In [4]:
# Wall-clock start of the run.
og_start_time = time.time()

parser = HfArgumentParser(
    (ModelArguments, DataTrainingArguments, TrainingArguments)
)

# The notebook has no real command line, so the arguments are spelled out here.
# Keeping them in a variable lets the exact list be recorded in the output directory below.
notebook_cli_args = [
    "--model_type", "t5-3b",
    "--tokenizer_name", "t5-3b",
    "--task_name", "esnli",
    "--output_dir", "./esnli_output_t5_base",
    "--n_shots", "10",
    "--do_train", "True",
]

model_args, data_args, training_args, unused_args = parser.parse_args_into_dataclasses(
    notebook_cli_args, return_remaining_strings=True)
if unused_args != []:
    raise ValueError(f"Received unused arguments: {unused_args}")
# make sure only one dataset split pick if manually specifying evaluation file

# Sanity checks on the argument combination when the GPT-3 path is selected.
if model_args.use_gpt3:
    assert training_args.do_train
    assert not training_args.do_eval
    assert data_args.generations_filepath is None
    if data_args.gpt3_max_eval_size is not None:
        assert data_args.gpt3_max_eval_size <= data_args.fewshot_eval_size
        assert data_args.gpt3_max_eval_size % 2 == 0
        assert data_args.gpt3_max_eval_size % 3 == 0

# When a generations file is given, training/eval are switched off and the split
# to predict on is inferred from the substring found in the file path.
if data_args.generations_filepath is not None:
    training_args.do_train = False
    training_args.do_eval = False
    if "train" in data_args.generations_filepath:
        data_args.train_predict = True
        data_args.test_predict = False
        data_args.dev_predict = False
    elif "test" in data_args.generations_filepath:
        data_args.train_predict = False
        data_args.test_predict = True
        data_args.dev_predict = False
    elif "validation" in data_args.generations_filepath:
        data_args.train_predict = False
        data_args.test_predict = False
        data_args.dev_predict = True

if not training_args.do_train and data_args.generations_filepath is None:
    if not model_args.pretrained_model_file:
        raise Exception(
            "if not training a model from scratch, must specify a trained model to load for evaluation"
        )

if training_args.do_train:
    # create a save directory and a logfile
    training_args.output_dir = os.path.join(
        training_args.output_dir, datetime.now().strftime("%m%d%y_%H%M%S")
    )
    training_args.logging_dir = training_args.output_dir
    # The timestamped directory must be new; it is created empty right here, so no
    # further "exists and is not empty" check is needed.
    assert not os.path.exists(training_args.output_dir)
    os.makedirs(training_args.output_dir)

    handlers = [
        logging.FileHandler(os.path.join(training_args.output_dir, "logger.log")),
        logging.StreamHandler(),
    ]
else:
    # don't overwrite existing logfile or create new directory
    training_args.output_dir = model_args.pretrained_model_file
    handlers = [logging.StreamHandler()]

# Setup logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    handlers=handlers,
)
logger.warning(
    "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
    training_args.local_rank,
    training_args.device,
    training_args.n_gpu,
    bool(training_args.local_rank != -1),
    training_args.fp16,
)
logger.info("Save path: %s" % training_args.output_dir)

# get git hash and branch where deployed
repo = git.Repo(search_parent_directories=True)
git_hash = repo.head.object.hexsha
git_branch = repo.active_branch.name
logger.info("Git branch: %s" % git_branch)
logger.info("Git hash: %s" % git_hash)

model_class = "t5"
assert data_args.task_name in {"cos_e", "esnli", "sbic", "sensemaking", "ecqa"}

if training_args.do_train:
    # write command and args to file
    with open(
            os.path.join(training_args.output_dir, "commandline_args.txt"), "w"
    ) as f:
        f.write("Git branch: " + git_branch + "\n")
        f.write("Git hash: " + git_hash + "\n")
        f.write("Command:\n")
        # Record the arguments that were actually parsed above; in a notebook,
        # sys.argv only holds the kernel's own launch arguments.
        f.write("\n".join(notebook_cli_args))

# Set seed
set_seed(training_args.seed)
set_other_seeds(training_args.seed)

# Load pretrained model and tokenizer
#
# Distributed training:
# The .from_pretrained methods guarantee that only one local process can concurrently
# download model & vocab.

PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
01/28/2023 15:18:44 - INFO - __main__ -   Save path: ./esnli_output_t5_base/012823_151844
01/28/2023 15:18:44 - INFO - __main__ -   Git branch: dev
01/28/2023 15:18:44 - INFO - __main__ -   Git hash: 1cbb5c3b4e53baf31cbafc20d9655c63f091f901


In [5]:
# NOTE(review): none of the visible cells read `training_args.set_device`; presumably it is
# consumed by code outside this view — confirm that this assignment has any effect.
training_args.set_device = "cuda:0"

In [6]:
import logging

logger = logging.getLogger(__name__)

# Lookup tables from the short model-family name to its Hugging Face classes.
CONFIG_MAPPING = {"t5": T5Config}
MODEL_MAPPING = {"t5": T5ForConditionalGeneration}
TOKENIZER_MAPPING = {"t5": T5Tokenizer}

model_class = "t5"
# Despite its name, `tokenizer_name` holds the tokenizer *class* for the family;
# the checkpoint name lives in `model_args.tokenizer_name`.
tokenizer_name = TOKENIZER_MAPPING[model_class]

logger.info("Loading pretrained tokenizer...")
model_args.tokenizer_name = 't5-base'
tokenizer = tokenizer_name.from_pretrained(
    model_args.tokenizer_name
)  # , cache_dir=model_args.cache_dir)

# Load the pretrained seq2seq model through the same family lookup.
model = MODEL_MAPPING[model_class].from_pretrained("t5-base")

01/28/2023 15:18:44 - INFO - __main__ -   Loading pretrained tokenizer...
loading file https://huggingface.co/t5-base/resolve/main/spiece.model from cache at /home/huangyongfeng/.cache/huggingface/transformers/684a47ca6257e4ca71f0037771464c5b323e945fbc58697d2fad8a7dd1a2f8ba.3b69006860e7b5d0a63ffdddc01ddcd6b7c318a6f4fd793596552c741734c62d
loading file https://huggingface.co/t5-base/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/t5-base/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/t5-base/resolve/main/tokenizer_config.json from cache at None
loading file https://huggingface.co/t5-base/resolve/main/tokenizer.json from cache at /home/huangyongfeng/.cache/huggingface/transformers/90de37880b5ff5ac7ab70ff0bd369f207e9b74133fa153c163d14c5bb0116207.8627f1bd5d270a9fd2e5a51c8bec3223896587cc3cfe13edeabb0992ab43c529
loading configuration file https://huggingface.co/t5-base/resolve/main/config.json from cache at /ho

In [7]:
# Empty per-split placeholders (every value starts as None).
data_splits = dict.fromkeys(('train', 'validation', 'test'))
original_data_splits = dict.fromkeys(('train', 'validation', 'test'))

# Select the input/output format used when formatting examples, then show the arguments.
data_args.io_format = "t5_fewshot_infilling_without_choices_use_refined_expl"
data_args

DataTrainingArguments(task_name='esnli', early_stopping_patience=10, overwrite_cache=False, train_predict=False, test_predict=False, dev_predict=False, version_name='v1.11', generations_filepath=None, n_shots=10, fewshot_eval_size=350, io_format='t5_fewshot_infilling_without_choices_use_refined_expl', explanation_sep='explanation', data_path=None, gpt3_max_eval_size=None)

In [8]:
# Load the dataset named by `data_args.task_name`; later cells index the result
# by split name ('train', 'validation', 'test').
dataset = datasets.load_dataset(data_args.task_name)
# train_ids_list=[x['id'] for x in dataset["train"]]



  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Quick look at the splits. These are printed explicitly because a bare
# expression is only displayed when it is the last statement of a cell, and
# this cell ends with a loop.
print(dataset.keys())
print(len(dataset['train']), len(dataset['validation']), len(dataset['test']))
print(dataset['train'][0].keys())

# Index the training examples by their (hypothesis, premise) text so that
# later cells can look an example up from annotation files that only carry
# the sentence pair. The first example seen for a pair is kept; each later
# collision is printed next to the stored example for inspection.
all_data_dict = {}
for da in dataset['train']:
    hyp = da['hypothesis']
    pre = da['premise']
    key_str = f'hypothesis: {hyp}; premise:{pre}'
    if key_str not in all_data_dict:
        all_data_dict[key_str] = da
    else:
        print("*****")
        print(da)
        print(all_data_dict[key_str])
        print("*****")

*****
{'premise': 'Child in red and blue shirt painting a log.', 'hypothesis': 'The child is painting.', 'label': 0, 'explanation_1': '"Child is painting" is a simpler version of  "child painting a log".', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Child in red and blue shirt painting a log.', 'hypothesis': 'The child is painting.', 'label': 0, 'explanation_1': 'Both subjects are painting', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'Two small dogs run across the green grass.', 'hypothesis': 'Two dogs are outside.', 'label': 0, 'explanation_1': 'Dogs can be small.  Grass is usually outside.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Two small dogs run across the green grass.', 'hypothesis': 'Two dogs are outside.', 'label': 0, 'explanation_1': 'Dogs can be small. Grass is usually outside,', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A young woman is drawing with a Sharpie marker.', 'hypothesis': 'A woman is drawin

*****
{'premise': 'Two foreign women, both wearing black, selling different types of foods.', 'hypothesis': 'Two women are selling food.', 'label': 0, 'explanation_1': 'Two women are selling food is the same as two women selling food.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Two foreign women, both wearing black, selling different types of foods.', 'hypothesis': 'Two women are selling food.', 'label': 0, 'explanation_1': 'Two women are selling food is the same as two women selling food.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A bald dark-skinned man in a purple tank top and jeans works with another man next to a blue bucket.', 'hypothesis': 'Two men are working together.', 'label': 0, 'explanation_1': 'THEY WORK TOGETHER WHEN WORKING.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'A bald dark-skinned man in a purple tank top and jeans works with another man next to a blue bucket.', 'hypothesis': 'Two men are working together.', 'la

*****
{'premise': 'She is playing tennis with her blue tennis clothes on, it looks hot out there.', 'hypothesis': 'She is playing tennis with her blue tennis clothes on, it looks hot out there.', 'label': 0, 'explanation_1': 'She is playing tennis with her blue tennis clothes on, it looks hot out there is the same as she is playing tennis with her blue tennis clothes on, it looks hot out there.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'She is playing tennis with her blue tennis clothes on, it looks hot out there.', 'hypothesis': 'She is playing tennis with her blue tennis clothes on, it looks hot out there.', 'label': 1, 'explanation_1': 'How hot the weather looks cannot be inferred from how she is dressed.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A blue-eyed woman with red-hair and glasses looks up at the camera while holding a hamburger with a bite out of it, as she sits on a boat with four other young people who are drinking and eating.', '

{'premise': 'A person stands in a vast field of glacial ice.', 'hypothesis': 'A person is standing outside.', 'label': 0, 'explanation_1': 'A person needs to be outside to stand on glacial ice.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A bunch of birds outside of a building.', 'hypothesis': 'The birds are outside.', 'label': 0, 'explanation_1': '"Outside" is a less detailed restatement of "outside a building".', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'A bunch of birds outside of a building.', 'hypothesis': 'The birds are outside.', 'label': 0, 'explanation_1': 'Birds and a bunch of birds are synonyms.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A group of young African American kids sitting on a white wall in front of a white building.', 'hypothesis': 'A group of kids are outside.', 'label': 0, 'explanation_1': 'If the group of kids are sitting on a white wall in front of a white building, they are sitting outside.', 'e

*****
{'premise': 'A man with black hair is wearing a red and yellow shirt and has a yellow clown nose and face makeup on.', 'hypothesis': 'A man is dressed as a clown.', 'label': 0, 'explanation_1': 'The red and yellow shirt, yellow clown nose and face makeup are descriptive words in the first sentence that are attributes of the man being dressed as a clown.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'A man with black hair is wearing a red and yellow shirt and has a yellow clown nose and face makeup on.', 'hypothesis': 'A man is dressed as a clown.', 'label': 1, 'explanation_1': "Just because a man is wearing a red and yellow shirt doesn't necessarily imply he is dressed as a clown.", 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A person is jumping up in the woods, they are wearing a red sweater and yellow pants, they are surrounded by trees and leaves.', 'hypothesis': 'A person is jumping in the woods.', 'label': 0, 'explanation_1': 'Jumping usually

*****
{'premise': 'A delighted man is jumping rope in front of a door.', 'hypothesis': 'A man is jumping rope.', 'label': 0, 'explanation_1': 'A man jump rope is jumping rope in front of a door.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'A delighted man is jumping rope in front of a door.', 'hypothesis': 'A man is jumping rope.', 'label': 0, 'explanation_1': 'A man jumping rope is delighted.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'People watching a man breakdance.', 'hypothesis': 'A man is breakdancing.', 'label': 0, 'explanation_1': 'a man doing breakdance is been watched by people breakdancing.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'People watching a man breakdance.', 'hypothesis': 'A man is breakdancing.', 'label': 0, 'explanation_1': 'a man doing breakdance is been watch by people breakdancing.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'People watching a man breakdance.', 'hypothesis': 'A man is 

*****
{'premise': "A little boy trying to brush a woman's hair.", 'hypothesis': 'A boy is brushing hair.', 'label': 0, 'explanation_1': 'HE IS BRUSHING HER HAIR WITH A BRUSH.', 'explanation_2': '', 'explanation_3': ''}
{'premise': "A little boy trying to brush a woman's hair.", 'hypothesis': 'A boy is brushing hair.', 'label': 0, 'explanation_1': 'The boy is brushing because he is trying to brush.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'Man in an orange shirt with tattoos walking two large poodles.', 'hypothesis': 'A man is walking two dogs.', 'label': 0, 'explanation_1': 'Man in an orange shirt with tattoos walking two large poodles, would not be A man is walking two dogs.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Man in an orange shirt with tattoos walking two large poodles.', 'hypothesis': 'A man is walking two dogs.', 'label': 0, 'explanation_1': 'poodles almost the same as dogs.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'p

*****
{'premise': 'Two dogs fighting in the street.', 'hypothesis': 'Two dogs are fighting.', 'label': 0, 'explanation_1': 'Dogs usually fight in the street.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Two dogs fighting in the street.', 'hypothesis': 'Two dogs are fighting.', 'label': 0, 'explanation_1': 'Dog are in street.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'The image links are broken.', 'hypothesis': 'The image links are broken.', 'label': 0, 'explanation_1': 'Links to the image are broken', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'The image links are broken.', 'hypothesis': 'The image links are broken.', 'label': 2, 'explanation_1': 'Both sentences are the exact same. They both say the image links are broken.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'The image links are broken.', 'hypothesis': 'The image links are broken.', 'label': 1, 'explanation_1': 'this can be inferred from the image links ar

*****
{'premise': 'Two soccer player are getting hit in the head with a ball during a game, as a crowd watches.', 'hypothesis': 'A boy drips ketchup from his fries.', 'label': 2, 'explanation_1': 'One the players are getting hit in the head with a ball while the other the boy drips ketchup.  Also, one sentence is about a boy while the other is about two soccer players.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Two soccer player are getting hit in the head with a ball during a game, as a crowd watches.', 'hypothesis': 'A boy drips ketchup from his fries.', 'label': 2, 'explanation_1': 'One the players are getting hit in the head with a ball while the other the boy drips ketchup.  Also, one sentence is about a boy while the other is about two soccer players.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'Two soccer player are getting hit in the head with a ball during a game, as a crowd watches.', 'hypothesis': 'Players are getting bludgeoned with the

*****
*****
{'premise': 'A surfer rides the waves.', 'hypothesis': 'A person is surfing.', 'label': 0, 'explanation_1': "surfing and wave is the same because you can't surf without wave.", 'explanation_2': '', 'explanation_3': ''}
{'premise': 'A surfer rides the waves.', 'hypothesis': 'A person is surfing.', 'label': 0, 'explanation_1': 'A surfer is a person, and to ride the waves is to surf.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'Four men in white T-shirts, blue jeans, and hats talking.', 'hypothesis': 'Four men are talking.', 'label': 0, 'explanation_1': 'they are talking in both', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Four men in white T-shirts, blue jeans, and hats talking.', 'hypothesis': 'Four men are talking.', 'label': 0, 'explanation_1': 'it states in both of he sentences that they are talking', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A young man in a green t-shirt strums his guitar while sitting cross-l

*****
{'premise': 'A shirtless man with sunglasses and a helmet is riding a motorcycle next to traffic cones.', 'hypothesis': 'A man is riding a motorcycle.', 'label': 0, 'explanation_1': 'A man with a helmet would be riding a motorcycle.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'A shirtless man with sunglasses and a helmet is riding a motorcycle next to traffic cones.', 'hypothesis': 'A man is riding a motorcycle.', 'label': 1, 'explanation_1': 'The man was riding a motorcycle, this is correct.', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A little boy with a scooter riding down a sidewalk.', 'hypothesis': 'The boy is outside.', 'label': 0, 'explanation_1': 'sidewalk implies outside', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'A little boy with a scooter riding down a sidewalk.', 'hypothesis': 'The boy is outside.', 'label': 2, 'explanation_1': 'the boy and a little boy are playing together', 'explanation_2': '', 'explanation_3': ''}
*

*****
{'premise': 'a brown and white dog jumping over a red, yellow and white pole', 'hypothesis': 'a dog jumps over a pole', 'label': 0, 'explanation_1': 'Jumps is a rephrase of jumping.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'a brown and white dog jumping over a red, yellow and white pole', 'hypothesis': 'a dog jumps over a pole', 'label': 0, 'explanation_1': 'dog jumping over a pole is described in both sentences', 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'Some of the hair in his beard are turning grey.', 'hypothesis': 'The man is clean shaven.', 'label': 2, 'explanation_1': 'Some one cannot have hair and they are still clean shaven.', 'explanation_2': '', 'explanation_3': ''}
{'premise': 'Some of the hair in his beard are turning grey.', 'hypothesis': 'The man is clean shaven.', 'label': 2, 'explanation_1': "A beard means he isn't clean shaven.", 'explanation_2': '', 'explanation_3': ''}
*****
*****
{'premise': 'A red desk chair has been ro

In [None]:
# Add the validation examples to the shared (hypothesis, premise) index.
# A pair that is already indexed is not overwritten; it is printed together
# with the stored example instead.
for example in dataset['validation']:
    hyp, pre = example['hypothesis'], example['premise']
    pair_key = f'hypothesis: {hyp}; premise:{pre}'
    if pair_key in all_data_dict.keys():
        print("*****")
        print(example)
        print(all_data_dict[pair_key])
        print("*****")
        continue
    all_data_dict[pair_key] = example

In [None]:
# Add the test examples to the shared (hypothesis, premise) index.
# A pair that is already indexed is not overwritten; it is printed together
# with the stored example instead.
for example in dataset['test']:
    hyp, pre = example['hypothesis'], example['premise']
    pair_key = f'hypothesis: {hyp}; premise:{pre}'
    if pair_key in all_data_dict.keys():
        print("*****")
        print(example)
        print(all_data_dict[pair_key])
        print("*****")
        continue
    all_data_dict[pair_key] = example

In [None]:
# Compare the number of distinct (hypothesis, premise) keys in the index with the
# total number of examples over the three splits; the difference is the number of
# examples whose sentence pair collided with an earlier one.
len(all_data_dict), len(dataset['train']) + len(dataset['validation'])+ len(dataset['test'])


In [None]:
#rationale generation labeled data construction 115
import pandas as pd
scr_csqa_labeled_path="/cognitive_comp/huangyongfeng/evaluate_LM_with_rationalization/few_shot_explanations/data/handwritten_snli_examples.csv"
scr_csqa_label_df=pd.read_csv(scr_csqa_labeled_path)
scr_csqa_label_data=datasets.load_dataset('csv', data_files=scr_csqa_labeled_path)
# Printed explicitly: a bare expression in the middle of a cell is discarded.
print(scr_csqa_label_data['train'][0].keys())

# For every row of the CSV whose (hypothesis, premise) pair exists in the
# dataset index, take the indexed example and attach the CSV's
# 'our_explanation' text to it.
refine_train_data=[]
for da in scr_csqa_label_data['train']:
    hyp=da['hypothesis']
    pre=da['premise']
    key_str = f'hypothesis: {hyp}; premise:{pre}'
    if key_str in all_data_dict:
        # Shallow-copy the indexed example: `all_data_dict` is shared with the
        # cells that build the other splits, so mutating the stored dict would
        # let a later cell overwrite or extend the fields of this record.
        new_da=dict(all_data_dict[key_str])
        new_da['our_explanation'] = da['our_explanation']
        refine_train_data.append(new_da)
        print(new_da)
        
    
# for kk, (ex,da) in enumerate(zip(scr_csqa_label_data, data_splits['train'])):
# #     print(da)
#     data_splits['train'][kk]['our_explanation']=ex
#     #print(type(data_splits['train'][kk]),data_splits['train'][kk].keys())
#     da['our_explanation']=ex
#     refine_train_data.append(da)

# len(data_splits['train'])


In [None]:
#rationale unlabeled data construction 991
scr_csqa_unlabeled_test_file="/cognitive_comp/huangyongfeng/evaluate_LM_with_rationalization/few_shot_explanations/data/acceptability_annotations/snli_train.csv"
fse_csqa_dev_dataset = datasets.load_dataset('csv', data_files=scr_csqa_unlabeled_test_file)
scr_csqa_unlabeled_test_df=pd.read_csv(scr_csqa_unlabeled_test_file)

# Group the annotation rows by example id. Each id collects one set of accepted
# explanation numbers per row (presumably one row per annotator — confirm).
fse_csqa_dev_data_dict={}
for kk, da in enumerate(fse_csqa_dev_dataset['train']):
    id_=da['Input.id']
    # 'Answer.acceptable' is a '|'-separated list of 1-based explanation numbers;
    # an empty/missing value is treated as "nothing accepted".
    if da['Answer.acceptable']:
        answer_accept=set(da['Answer.acceptable'].split('|'))
    else:
        answer_accept=set()
    if id_ not in fse_csqa_dev_data_dict:
        fse_csqa_dev_data_dict[id_]={"index":kk,"id":id_,
                                     "premise":da["Input.premise"],
                                     "hypothesis":da["Input.hypothesis"],
                                     'answer':da['Input.gold_label'],
                                     "accept_set_list":[answer_accept],
                                     "explanation_list":[da['Input.explanation_1'],
                                                         da['Input.explanation_2'],
                                                         da['Input.explanation_3'],
                                                         da['Input.explanation_4'],
                                                         da['Input.explanation_5']]}
    else:
        fse_csqa_dev_data_dict[id_]["accept_set_list"].append(answer_accept)
    entry=fse_csqa_dev_data_dict[id_]
    if len(entry["accept_set_list"])==3:
        # Keep only the explanations accepted in all three rows. The numbers are
        # sorted so the resulting list order (and the tie-break below) does not
        # depend on set iteration order.
        common_accept_expl_sample=set.intersection(*entry["accept_set_list"])
        entry["common_expl_list"]=[entry["explanation_list"][int(idx)-1]
                                   for idx in sorted(common_accept_expl_sample, key=int)]

# fse_csqa_dev_data_dict['snli_train_570']
refine_dev_data=[]
for kk, da in fse_csqa_dev_data_dict.items():
    hyp=da['hypothesis']
    pre=da['premise']
    key_str = f'hypothesis: {hyp}; premise:{pre}'
    # Ids with fewer than three annotation rows never receive a
    # "common_expl_list"; .get() skips them instead of raising KeyError.
    common_expl_list=da.get("common_expl_list")
    if key_str in all_data_dict and common_expl_list:
        # Shallow copy so the shared index entry is not mutated (the same entry
        # may be looked up again when another split is built).
        new_da=dict(all_data_dict[key_str])
        # Use the longest commonly accepted explanation (first one on ties).
        new_da['our_explanation'] = max(common_expl_list, key=len)
        new_da['common_expl_list'] = common_expl_list
        refine_dev_data.append(new_da)
len(refine_dev_data)
    

In [None]:
#rationale unlabeled data construction 991
scr_csqa_unlabeled_test_file="/cognitive_comp/huangyongfeng/evaluate_LM_with_rationalization/few_shot_explanations/data/acceptability_annotations/snli_test.csv"
fse_csqa_dev_dataset = datasets.load_dataset('csv', data_files=scr_csqa_unlabeled_test_file)
scr_csqa_unlabeled_test_df=pd.read_csv(scr_csqa_unlabeled_test_file)

# Group the annotation rows by example id. Each id collects one set of accepted
# explanation numbers per row (presumably one row per annotator — confirm).
fse_csqa_dev_data_dict={}
for kk, da in enumerate(fse_csqa_dev_dataset['train']):
    id_=da['Input.id']
    # 'Answer.acceptable' is a '|'-separated list of 1-based explanation numbers;
    # an empty/missing value is treated as "nothing accepted".
    if da['Answer.acceptable']:
        answer_accept=set(da['Answer.acceptable'].split('|'))
    else:
        answer_accept=set()
    if id_ not in fse_csqa_dev_data_dict:
        fse_csqa_dev_data_dict[id_]={"index":kk,"id":id_,
                                     "premise":da["Input.premise"],
                                     "hypothesis":da["Input.hypothesis"],
                                     'answer':da['Input.gold_label'],
                                     "accept_set_list":[answer_accept],
                                     "explanation_list":[da['Input.explanation_1'],
                                                         da['Input.explanation_2'],
                                                         da['Input.explanation_3'],
                                                         da['Input.explanation_4'],
                                                         da['Input.explanation_5']]}
    else:
        fse_csqa_dev_data_dict[id_]["accept_set_list"].append(answer_accept)
    entry=fse_csqa_dev_data_dict[id_]
    if len(entry["accept_set_list"])==3:
        # Keep only the explanations accepted in all three rows. The numbers are
        # sorted so the resulting list order (and the tie-break below) does not
        # depend on set iteration order.
        common_accept_expl_sample=set.intersection(*entry["accept_set_list"])
        entry["common_expl_list"]=[entry["explanation_list"][int(idx)-1]
                                   for idx in sorted(common_accept_expl_sample, key=int)]

# fse_csqa_dev_data_dict['snli_train_570']
#refine_test_data=[]
# NOTE(review): these snli_test.csv records are appended to `refine_dev_data`,
# and the following cell builds `refine_test_data` from the same file — confirm
# the dev/test overlap is intended. Because `refine_dev_data` is not reset here,
# re-running this cell appends the same records again.
for kk, da in fse_csqa_dev_data_dict.items():
    hyp=da['hypothesis']
    pre=da['premise']
    key_str = f'hypothesis: {hyp}; premise:{pre}'
    # Ids with fewer than three annotation rows never receive a
    # "common_expl_list"; .get() skips them instead of raising KeyError.
    common_expl_list=da.get("common_expl_list")
    if key_str in all_data_dict and common_expl_list:
        # Shallow copy so the shared index entry is not mutated (the same entry
        # may be looked up again when another split is built).
        new_da=dict(all_data_dict[key_str])
        # Use the longest commonly accepted explanation (first one on ties).
        new_da['our_explanation'] = max(common_expl_list, key=len)
        new_da['common_expl_list'] = common_expl_list
        refine_dev_data.append(new_da)
len(refine_dev_data)
    

In [None]:
#rationale unlabeled data construction 991
scr_csqa_unlabeled_test_file="/cognitive_comp/huangyongfeng/evaluate_LM_with_rationalization/few_shot_explanations/data/acceptability_annotations/snli_test.csv"
fse_csqa_dev_dataset = datasets.load_dataset('csv', data_files=scr_csqa_unlabeled_test_file)
scr_csqa_unlabeled_test_df=pd.read_csv(scr_csqa_unlabeled_test_file)

# Group the annotation rows by example id. Each id collects one set of accepted
# explanation numbers per row (presumably one row per annotator — confirm).
fse_csqa_dev_data_dict={}
for kk, da in enumerate(fse_csqa_dev_dataset['train']):
    id_=da['Input.id']
    # 'Answer.acceptable' is a '|'-separated list of 1-based explanation numbers;
    # an empty/missing value is treated as "nothing accepted".
    if da['Answer.acceptable']:
        answer_accept=set(da['Answer.acceptable'].split('|'))
    else:
        answer_accept=set()
    if id_ not in fse_csqa_dev_data_dict:
        fse_csqa_dev_data_dict[id_]={"index":kk,"id":id_,
                                     "premise":da["Input.premise"],
                                     "hypothesis":da["Input.hypothesis"],
                                     'answer':da['Input.gold_label'],
                                     "accept_set_list":[answer_accept],
                                     "explanation_list":[da['Input.explanation_1'],
                                                         da['Input.explanation_2'],
                                                         da['Input.explanation_3'],
                                                         da['Input.explanation_4'],
                                                         da['Input.explanation_5']]}
    else:
        fse_csqa_dev_data_dict[id_]["accept_set_list"].append(answer_accept)
    entry=fse_csqa_dev_data_dict[id_]
    if len(entry["accept_set_list"])==3:
        # Keep only the explanations accepted in all three rows. The numbers are
        # sorted so the resulting list order (and the tie-break below) does not
        # depend on set iteration order.
        common_accept_expl_sample=set.intersection(*entry["accept_set_list"])
        entry["common_expl_list"]=[entry["explanation_list"][int(idx)-1]
                                   for idx in sorted(common_accept_expl_sample, key=int)]

# fse_csqa_dev_data_dict['snli_train_570']
refine_test_data=[]
for kk, da in fse_csqa_dev_data_dict.items():
    hyp=da['hypothesis']
    pre=da['premise']
    key_str = f'hypothesis: {hyp}; premise:{pre}'
    # Ids with fewer than three annotation rows never receive a
    # "common_expl_list"; .get() skips them instead of raising KeyError.
    common_expl_list=da.get("common_expl_list")
    if key_str in all_data_dict and common_expl_list:
        # Shallow copy so the shared index entry is not mutated (the same entry
        # may also have been used when another split was built).
        new_da=dict(all_data_dict[key_str])
        # Use the longest commonly accepted explanation (first one on ties).
        new_da['our_explanation'] = max(common_expl_list, key=len)
        new_da['common_expl_list'] = common_expl_list
        refine_test_data.append(new_da)
len(refine_test_data)
    

In [None]:
# Display the full list of labeled training records built from the handwritten-explanations CSV.
refine_train_data

In [None]:
# Gather the three record lists under their split names.
our_data_splits = {
    'train': refine_train_data,
    'dev': refine_dev_data,
    'test': refine_test_data,
}
# Show which fields a dev record carries.
refine_dev_data[0].keys()

# Number of shots

In [None]:
def list2dict(refine_data):
    """Convert a list of row dicts into a dict of column lists.

    The column names (and their order) are taken from the first row; every
    row must contain all of those keys.

    Args:
        refine_data: list of dicts, one per example.

    Returns:
        Dict mapping each key of the first row to the list of that key's
        values across all rows, in row order. An empty input yields an
        empty dict instead of raising IndexError.

    Raises:
        KeyError: if a later row lacks a key present in the first row.
    """
    if not refine_data:
        return {}
    return {key: [row[key] for row in refine_data] for key in refine_data[0].keys()}

# Number of training examples ("shots") kept for few-shot fine-tuning.
N_TRAIN_SHOTS = 96

# Train: shuffle with the run's seed so the selected shots are reproducible,
# and never request more rows than the labeled pool actually contains.
refine_train_data_dict = list2dict(refine_train_data)
train_pool = datasets.Dataset.from_dict(refine_train_data_dict)
our_data_splits['train'] = train_pool.shuffle(seed=training_args.seed).select(
    range(min(N_TRAIN_SHOTS, len(train_pool)))
)

# Dev and test are used in full, in their original order.
refine_dev_data_dict = list2dict(refine_dev_data)
our_data_splits['dev'] = datasets.Dataset.from_dict(refine_dev_data_dict)

refine_test_data_dict = list2dict(refine_test_data)
our_data_splits['test'] = datasets.Dataset.from_dict(refine_test_data_dict)



In [None]:
import datasets
class SequenceCollator:
    """Pad and tensorize a list of tokenized examples into one batch.

    Only the columns listed in `self.columns` are batched; any other field of
    an example is ignored. List-valued columns are right-padded to the longest
    sequence in the batch with the per-column value in `self.pad_token_mapping`.

    NOTE: a later cell of this notebook redefines `SequenceCollator` with a
    different `__init__` signature (`model, pad_token`), shadowing this class.
    """

    def __init__(self, pad_token):
        """Store the padding value for each column.

        Args:
            pad_token: id used to pad "input_ids".
        """
        # Keyed by "labels" (not the older "lm_labels" column name).
        self.pad_token_mapping = {
            "labels": -100,
            "attention_mask": 0,
            "decoder_attention_mask": 0,
            "input_ids": pad_token,
        }
        self.columns = [
            "input_ids",
            "attention_mask",
            "labels",
            "decoder_attention_mask",
        ]

    def collate_batch(self, examples):
        """Batch `examples`; same behavior as calling the collator directly."""
        return self(examples)

    def __call__(self, examples: List[Dict[str, InputDataClass]]) -> Dict[str, torch.Tensor]:
        """Build a batch of `torch.long` tensors from `examples`.

        Args:
            examples: list of example dicts; the keys of the first example
                decide which columns are considered.

        Returns:
            Dict with one tensor per batched column, in the key order of the
            first example. An empty `examples` list yields an empty dict.
        """
        batch = {}
        if not examples:
            return batch
        for key in examples[0].keys():
            if key not in self.columns:
                continue
            tmp_list = [item[key] for item in examples]

            # pad lists to max length
            if isinstance(tmp_list[0], list):
                max_length = max(map(len, tmp_list))
                pad_value = self.pad_token_mapping[key]
                tmp_list = [el + [pad_value] * (max_length - len(el)) for el in tmp_list]

            batch[key] = torch.tensor(tmp_list, dtype=torch.long)
        return batch
# dataset = datasets.load_dataset(data_args.task_name, data_args.version_name)

In [None]:
# seq_collector = SequenceCollator(0)
# train_ds = seq_collector.__call__(dataset['train'])
# train_ds
# dataset['train'][0].keys()
def _format_example(example):
    """Run `format_instance` on one example with the notebook's tokenizer and data arguments."""
    return format_instance(
        example,
        tokenizer,
        data_args.explanation_sep,
        datasource=data_args.task_name,
        io_format=data_args.io_format
    )


# Format every split example by example, bypassing any cached result.
for split in ['train','dev','test']:
    our_data_splits[split] = our_data_splits[split].map(
        _format_example,
        batched=False,
        load_from_cache_file=False,
    )

In [None]:
from tqdm import tqdm
# fse_csqa_dev_data_dict
# fse_csqa_train_data_dict


# Print each training example's explanation followed by the decoded model
# input and the decoded target, to eyeball the formatting.
train_split = our_data_splits['train']
for example in tqdm(train_split, total=len(train_split)):
    print("*******")
    print(example['our_explanation'])
    for field in ('input_ids', 'labels'):
        print(tokenizer.decode(example[field]))
    

    
    

In [None]:
# import wandb
# training_args.run_name=""
# Overrides for logging, checkpointing, evaluation and epoch count, applied in order.
for option, value in (
    ("logging_steps", 3),
    ("save_steps", 10),
    ("evaluation_strategy", "epoch"),
    ("num_train_epochs", 15),
    ("do_eval", True),
):
    setattr(training_args, option, value)

In [None]:
# Display the training arguments after the overrides applied in the previous cell.
training_args

In [None]:
# Use the same per-device batch size for evaluation and training, then show the result.
training_args.per_device_eval_batch_size = training_args.per_device_train_batch_size = 8
training_args

In [None]:
class SequenceCollator:
    """Collate tokenized examples into a padded batch of `torch.long` tensors.

    Only the columns in `self.columns` are batched. List-valued columns are
    right-padded to the longest sequence in the batch using the per-column
    value from `self.pad_token_mapping`.
    """

    def __init__(self, model, pad_token):
        """Keep a reference to `model` and set up the per-column padding values.

        Args:
            model: stored on the instance as `self.model`; not used by `__call__`.
            pad_token: id used to pad "input_ids".
        """
        self.model = model
        self.pad_token_mapping = dict(
            labels=-100,
            attention_mask=0,
            decoder_attention_mask=0,
            input_ids=pad_token,
        )
        self.columns = ["input_ids", "attention_mask", "labels", "decoder_attention_mask"]

    def __call__(self, examples: List[Dict[str, InputDataClass]]) -> Dict[str, torch.Tensor]:
        """Return one tensor per batched column, in the key order of the first example."""
        collated = {}
        for column in examples[0].keys():
            if column not in self.columns:
                continue

            values = [example[column] for example in examples]

            # Right-pad variable-length sequences up to the longest one in the batch.
            if isinstance(values[0], list):
                longest = max(len(seq) for seq in values)
                filler = self.pad_token_mapping[column]
                values = [seq + [filler] * (longest - len(seq)) for seq in values]

            collated[column] = torch.tensor(values, dtype=torch.long)
        return collated

In [None]:
# os.environ["WANDB_DISABLED"] = "True"
# The Trainer is only built when no pre-computed generations file is given.
# Bind the name up front so the training guard below cannot raise NameError.
trainer = None
if data_args.generations_filepath is None:
    callbacks = [TensorBoardCallback()]
    if data_args.early_stopping_patience > 0:
        callbacks.append(EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience))
        training_args.load_best_model_at_end = True
    else:
        training_args.load_best_model_at_end = False  # use the last model state
    # Model selection is by lowest evaluation loss.
    training_args.metric_for_best_model = 'eval_loss'
    training_args.greater_is_better = False
    # Evaluate every `eval_steps` steps when that is set, otherwise once per epoch.
    # NOTE(review): load_best_model_at_end depends on checkpoints being saved
    # when evaluation runs — confirm the save schedule (save_steps) lines up
    # with the evaluation strategy chosen here.
    if training_args.eval_steps is None:
        training_args.evaluation_strategy = EvaluationStrategy.EPOCH
    else:
        training_args.evaluation_strategy = EvaluationStrategy.STEPS

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=our_data_splits['train'],
        eval_dataset=our_data_splits['dev'],
        # The collator only stores its `model` argument; here it receives the
        # model family name string, not the model object.
        data_collator=SequenceCollator(
            model=model_class, pad_token=tokenizer.pad_token_id
        ),
        callbacks=callbacks,
    )

# Training. Skipped when GPT-3 is used or when no Trainer was built above.
if training_args.do_train and not model_args.use_gpt3 and trainer is not None:
    start_time = time.time()
    trainer.train()
    train_time = time.time() - start_time
    model = trainer.model
    #wandb.finish()
else:
    # Nothing was trained: keep both timing variables defined for later cells.
    start_time = time.time()
    train_time = 0.0

# Rationale Discriminator

## Constructing pairs of generated and accepted rationales


In [None]:
# dict_keys(['premise', 'hypothesis', 'label',
#            'explanation_1', 'explanation_2', 'explanation_3', 
#            'our_explanation', 'common_expl_list'])

In [None]:
from tqdm import tqdm
# fse_csqa_dev_data_dict
# fse_csqa_train_data_dict

# Generate one explanation per dev example with the current model and collect
# each example, extended with a "generated_explanation" field, in
# `rationale_pair_dev_data`.
rationale_pair_dev_data = []
cc=0
good_model = model
# Put the model in evaluation mode so dropout is disabled while generating
# (the mode the model was left in by earlier cells is not guaranteed here).
good_model.eval()
# For "infilling" io formats the special tokens are kept in the decoded text.
skip_special_tokens = "infilling" not in data_args.io_format

for da in tqdm(our_data_splits['dev'], total=len(our_data_splits['dev'])):
    print("*******")
    print("hypothesis: {}".format(da['hypothesis']))
    print("premise: {}".format(da['premise']))
    print("answer: {}".format(da['label']))
    print("common expl list: {}".format(da["common_expl_list"]))

    # Single-example batch of shape (1, sequence_length) on the model's device.
    inp_ids = torch.tensor(da["input_ids"], device=good_model.device).reshape(1, -1)
    out = good_model.generate(
                    inp_ids,
                    max_length=100,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                )
    words = tokenizer.decode(out[0].tolist(), skip_special_tokens=skip_special_tokens)
    print("generated explanation: {}".format(words))
    da["generated_explanation"] = words
    print("########")
    rationale_pair_dev_data.append(da)
    cc += 1

In [None]:
import json
# Write the generated-rationale records to a JSON file under ./results.
# NOTE(review): the file name says "test", but the records written here were
# generated from the 'dev' split — confirm the name is intended.
rationale_pair_save_path = os.path.join("./results", "96shots_esnli_t5_base_authorwritten_rationales_generator_test_rationale_pair.json")
# Create the results directory on first use so open() does not fail.
os.makedirs(os.path.dirname(rationale_pair_save_path), exist_ok=True)
with open(rationale_pair_save_path, 'w') as f:
    json.dump(rationale_pair_dev_data, f)