## Causal Inference

In [1]:
import os
import json
from typing import Dict, List, Union

import pandas as pd
import numpy as np
from scipy.special import softmax
from scipy.special import expit
# import torch

from cma import report_CMA
import fuse, causal_utils
from kl_general import DEFAULT_CONFIG as ESTIMATE_C_DEFAULT_CONFIG

import sys
sys.path.append("../")
from my_package.models.traditional.classifier import Classifier
from my_package.utils.handcrafted_features.counter import count_negations
from my_package.utils.handcrafted_features.overlap import get_lexical_overlap, get_entities_overlap

## Configs

In [2]:
DATASET = "fever" # fever or qqp
IS_CORRECTION = False

IS_GENERATE_NEW_BIAS_PRED = False

_MODEL_PATH = {
    # outputs_fever_bert_base_1 outputs_fever_weighted_bert_1 outputs_fever_poe_bert_1
    "fever": "../results/outputs_fever_weighted_bert_1", 
    # outputs_qqp_bert_base_1 outputs_qqp_weighted_bert_1 outputs_qqp_poe_bert_1
    "qqp": "../results/outputs_qqp_bert_base_1"
}
RESULT_PATH = "../results/"

In [3]:
_BIAS_CLASS = {
    "fever": "REFUTES",
    "qqp": "1"
}
_POSSIBLE_LABELS = {
    "fever": ("SUPPORTS", "NOT ENOUGH INFO", "REFUTES"),
    "qqp": ("0", "1")
}


_DATA_PATH = {
    "fever": "../data/fact_verification",
    "qqp": "../data/paraphrase_identification"
}


_SENT1_KEYS = {
    "fever": ("claim", "claim", "claim"),
    "qqp": ("sentence1", "sentence1")
}
_SENT2_KEYS = {
    "fever": ("evidence", "evidence_sentence", "evidence"),
    "qqp": ("sentence2", "sentence2")
}
_LABEL_KEYS = {
    "fever": ("gold_label", "label", "label"),
    "qqp": ("is_duplicate", "is_duplicate")
}
_TEST_FILES = {
    "fever": (
        "fever.dev.jsonl",
        "fever_symmetric_v0.1.test.jsonl",
        "fever_symmetric_v0.2.test.jsonl",
    ),
    "qqp": (
        "qqp.dev.jsonl",
        "paws.dev_and_test.jsonl"
    )
}
_TEST_SETS = {
    "fever": (
        "fever_dev",
        "fever_sym1",
        "fever_sym2",
    ),
    "qqp": (
        "qqp_dev",
        "qqp_paws"
    )
}

_BIAS_MODEL_PATH = {
    "fever": "../results/fever/bias_model",
    "qqp": "../results/qqp/bias_model"
}

_BIAS_VAL_PRED_FILE = {
    "fever": "weighted_fever.val.jsonl",
    "qqp": "weighted_qqp.val.jsonl",
}
_MODEL_VAL_PRED_FILE = {
    "fever": "raw_fever.val.jsonl",
    "qqp": "raw_qqp.val.jsonl",
}

MODEL_PROB_KEY = "probs"
BIAS_PROB_KEY = "bias_prob"

# Fusion method
FUSION = fuse.sum_fuse

In [4]:
BIAS_CLASS = _BIAS_CLASS[DATASET]
POSSIBLE_LABELS = _POSSIBLE_LABELS[DATASET]
ESTIMATE_C_DEFAULT_CONFIG["N_LABELS"] = len(POSSIBLE_LABELS)

ROOT_DATA_PATH = "../data"
DATA_PATH = _DATA_PATH[DATASET]
TEST_FILES = _TEST_FILES[DATASET]
TEST_SETS = _TEST_SETS[DATASET]
SENT1_KEYS = _SENT1_KEYS[DATASET]
SENT2_KEYS = _SENT2_KEYS[DATASET]
LABEL_KEYS = _LABEL_KEYS[DATASET]
WEIGHT_KEY = "sample_weight"

BIAS_MODEL_PATH = _BIAS_MODEL_PATH[DATASET] 
MODEL_PATH = _MODEL_PATH[DATASET]


BIAS_VAL_PRED_FILE = _BIAS_VAL_PRED_FILE[DATASET]
MODEL_VAL_PRED_FILE = _MODEL_VAL_PRED_FILE[DATASET]

## Bias model instantiation

In [5]:
def get_fever_classifier():
    POSSIBLE_LABELS = _POSSIBLE_LABELS["fever"]
    feature_extractors = [
        lambda s1, s2: count_negations(s1),
        lambda s1, s2: count_negations(s2),
        get_lexical_overlap,
        get_entities_overlap
    ]
    classifier = Classifier(
        possible_labels=POSSIBLE_LABELS,
        feature_extractors=feature_extractors
    )
    classifier.load(BIAS_MODEL_PATH)
    return classifier

def get_qqp_classifier():
    POSSIBLE_LABELS = _POSSIBLE_LABELS["qqp"]
    feature_extractors = [
        get_lexical_overlap,
        get_entities_overlap
    ]
    classifier = Classifier(
        possible_labels=POSSIBLE_LABELS,
        feature_extractors=feature_extractors
    )
    classifier.load(BIAS_MODEL_PATH)
    return classifier

MAP_GET_CLS = {
    "fever": get_fever_classifier,
    "qqp": get_qqp_classifier
}

In [6]:
if DATASET == "fever":
    classifier = get_fever_classifier()
elif DATASET == "qqp":
    classifier = get_qqp_classifier()
else:
    raise NotImplementedError("No classifier for %s does"%DATASET)

## Compute average input for bias model

In [7]:
def _read_jsonl(file_path: str) -> List[Dict[str, Union[str, int]]]:
    output = []
    f = open(file_path, 'r')
    line = f.readline()
    while line:
        doc = json.loads(line)
        output.append(doc)
        line = f.readline()
    f.close()
    return output

In [8]:
val_data = _read_jsonl(os.path.join(DATA_PATH, "%s.val.jsonl"%DATASET))
val_data = [
    (x[SENT1_KEYS[0]], x[SENT2_KEYS[0]])
    for x in val_data
]
print(len(val_data))
val_data[:2]

5000


[('Pate is a type of fish .',
  'Pate , pâte , or paste , a term for the interior body -LRB- non-rind portion -RRB- of cheese , described by its texture , density , and color'),
 ('Floyd Mayweather Jr. was born in New York in February 24 , 1977 .',
  'Tabure Thabo Bogopa Junior -LRB- born March 12 , 1987 -RRB- , who performs under the stage name JR , is a South African rapper .')]

In [9]:
feature_vecs = []
for data in val_data:
    feature_vecs.append(
        classifier.normalizer.transform( # 2) Normalize vector
            [classifier._transform(data),] # 1) Text to raw vector
        )[0]
    )
feature_vecs = np.array(feature_vecs)
print("Shape feature vecs: ", feature_vecs.shape)

input_a0 = np.mean(feature_vecs, axis=0).tolist()
print("input_a0 len: ", len(input_a0))
print(input_a0)

Shape feature vecs:  (5000, 604)
input_a0 len:  604
[0.0, 0.0, 0.007533333333333333, 0.0, 0.0068, 0.0, 0.004, 0.008666666666666673, 0.0, 0.0022, 0.0, 0.0, 0.00295, 0.0, 0.0, 0.008066666666666668, 0.0, 0.0, 0.0, 0.0011333333333333332, 0.0038, 0.005066666666666663, 0.0011, 0.0063, 0.0, 0.0, 0.011733333333333354, 0.0, 0.002, 0.0043, 0.003759999999999995, 0.0, 0.0068, 0.0032000000000000023, 0.0, 0.0030000000000000014, 0.007159999999999995, 0.0, 0.006133333333333328, 0.0026, 0.010800000000000018, 0.0, 0.0051, 0.00775, 0.0, 0.0, 0.0036, 0.0022000000000000006, 0.0024666666666666674, 0.0009, 0.00691999999999999, 0.00965, 0.0, 0.0048, 0.002666666666666668, 0.0044, 0.0014, 0.0019, 0.002, 0.0024, 0.0025333333333333345, 0.0046, 0.0032, 0.0027, 0.0082, 0.0024000000000000002, 0.00845, 0.0, 0.005133333333333328, 0.0033333333333333357, 0.0, 0.0012, 0.0008, 0.0018, 0.0005, 0.0, 0.0009, 0.00435, 0.0, 0.0, 0.0009, 0.0012, 0.0014800000000000008, 0.0005, 0.01599999999999999, 0.0011333333333333334, 0.0007, 

### Test prediction

In [10]:
def model_pred(
    _input: List[float]
) -> List[float]:
    return classifier.model.predict_proba([_input, ]).tolist()

In [11]:
model_pred(input_a0)

[[0.40402256167499867, 0.46030682167588166, 0.1356706166491197]]

### Predict bias prob in multiple test sets (This should be done only one time)

In [12]:
if DATASET == "fever":
    def get_weight(prob_score_ground_truth_class: float) -> float:
        return 1/prob_score_ground_truth_class
elif DATASET == 'qqp':
    def get_weight(prob_score_bias_class: float, ground_truth_label: str, bias_label: str = BIAS_CLASS) -> float:
        if ground_truth_label == bias_label:
            return 1/prob_score_bias_class
        return 1/(1-prob_score_bias_class)
else:
    raise NotImplementedError("No classifier for %s does"%DATASET)

In [13]:
MAX_SAMPLE = -1


if DATASET == "fever":
    def inference_prob_to_index(x: List[Dict[str, float]]) -> List[float]:
        return [
            x["SUPPORTS"],
            x["NOT ENOUGH INFO"],
            x["REFUTES"]
        ]
elif DATASET == "qqp":
    def inference_prob_to_index(x: List[Dict[str, float]]) -> List[float]:
        return [
            x["0"],
            x["1"]
        ]

def write_weight_to_file(
    DATA_FILE: str,
    OUTPUT_DATA_FILE: str,
    _classifier
) -> None:
    f_output = open(OUTPUT_DATA_FILE, 'w')

    N_SAMPLE = 0

    with open(DATA_FILE, 'r') as fh:
        line = fh.readline()
        while line:
            datapoint = json.loads(line)
            ground_truth_label = datapoint[LABEL_KEY]
            x = [[datapoint[SENT1_KEY], datapoint[SENT2_KEY]]]

            probs = _classifier.inference(x)[0]
            prob = probs[ground_truth_label]
            if DATASET == 'fever':
                weight = get_weight(prob_score_ground_truth_class=prob)
            elif DATASET == 'qqp':
                weight = get_weight(
                    prob_score_bias_class=prob,
                    ground_truth_label=str(ground_truth_label),
                    bias_label=BIAS_CLASS
                )
                
            if datapoint.get("weight", None) != None:
                del datapoint["weight"] # only for fever
                
            f_output.write("%s\n"%json.dumps({
                **datapoint,
                WEIGHT_KEY: weight,
                "bias_probs": inference_prob_to_index(probs),
                "bias_prob": prob
            }))

            N_SAMPLE += 1
            if MAX_SAMPLE != -1 and N_SAMPLE == MAX_SAMPLE:
                break
            line = fh.readline()

    f_output.close()

In [14]:
if IS_GENERATE_NEW_BIAS_PRED:
    for SENT1_KEY, SENT2_KEY, LABEL_KEY, test_file in zip(SENT1_KEYS, SENT2_KEYS, LABEL_KEYS, TEST_FILES):
        print("test_file: ", test_file)
        _test_file = os.path.join(DATA_PATH, test_file)
        _pred_test_file = os.path.join(DATA_PATH, "weighted_%s"%test_file)
        write_weight_to_file(
            DATA_FILE = _test_file,
            OUTPUT_DATA_FILE = _pred_test_file,
            _classifier = classifier
        )

## Causal Mediation Analysis

In [15]:
# seed_paths = [
#     '../results/outputs_fever_bert_base_1_seed23370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_fever_bert_base_1_seed33370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_fever_bert_base_1_seed43370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_fever_bert_base_1_seed53370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_fever_bert_base_1_seed63370/utama_distil/utama_distill_model_0/'
# ]
# seed_paths = [
#     '../results/outputs_qqp_bert_base_1_seed23370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_qqp_bert_base_1_seed33370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_qqp_bert_base_1_seed43370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_qqp_bert_base_1_seed53370/utama_distil/utama_distill_model_0/',
#     '../results/outputs_qqp_bert_base_1_seed63370/utama_distil/utama_distill_model_0/'
# ]
# print(seed_paths)

seed_path_prefix = "%s_seed"%MODEL_PATH

seed_paths = os.listdir(RESULT_PATH)
seed_paths = list(map(lambda x: os.path.join(RESULT_PATH, x), seed_paths))
seed_paths = list(filter(lambda x: x.startswith(seed_path_prefix), seed_paths))
seed_paths

['../results/outputs_fever_weighted_bert_1_seed53370',
 '../results/outputs_fever_weighted_bert_1_seed23370',
 '../results/outputs_fever_weighted_bert_1_seed43370',
 '../results/outputs_fever_weighted_bert_1_seed63370',
 '../results/outputs_fever_weighted_bert_1_seed33370']

In [16]:
for test_set, label_key in zip(TEST_SETS, LABEL_KEYS):
    print("========= TEST_SET: %s ========="%test_set)
    report_CMA(
        model_path = "",
        task = "",
        seed_path = seed_paths,

        data_path = DATA_PATH,
        test_set = test_set, ################
        fusion = fuse.sum_fuse,
        input_a0 = input_a0,
        estimate_c_config = ESTIMATE_C_DEFAULT_CONFIG,

        correction = IS_CORRECTION,
        ground_truth_key = label_key,
        model_pred_method = model_pred,

        bias_val_pred_file = BIAS_VAL_PRED_FILE,
        model_val_pred_file = MODEL_VAL_PRED_FILE,

        TIE_ratio_threshold = 9999,
    )
    print("========= END ======== \n\n\n")
#     break

../data/fact_verification/weighted_fever.dev.jsonl
../results/outputs_fever_weighted_bert_1_seed53370/raw_fever.dev.jsonl
unique_labels:  ['SUPPORTS', 'REFUTES']
-0.0003515276196476634 0.1270167168126008
-0.006841940716921615 0.01776347220710397
0.007057305881407289 0.13390612185833034
../results/outputs_fever_weighted_bert_1_seed23370/raw_fever.dev.jsonl
unique_labels:  ['SUPPORTS', 'REFUTES']
0.0023786042653448887 0.13211662692410966
-0.006902428616275633 0.01797412233652055
0.010179391777479911 0.13922101951897461
../results/outputs_fever_weighted_bert_1_seed43370/raw_fever.dev.jsonl
unique_labels:  ['SUPPORTS', 'REFUTES']
0.002874526626728921 0.12832161385154495
-0.007003556214163973 0.0179047783130748
0.010575928236445762 0.13500754208386534
../results/outputs_fever_weighted_bert_1_seed63370/raw_fever.dev.jsonl
unique_labels:  ['SUPPORTS', 'REFUTES']
-0.0034978625547127376 0.1263491347585421
-0.006699419687941639 0.01818315522993369
0.003270830214270498 0.13265473790828533
../resu