### Evaluating CPM with CEBaB
Our Causal Proxy Model (CPM) provides concept-based explanations for a blackbox model. We use the newly developed CEBaB benchmark to compare CPM with other concept-based explanation methods. This notebook evaluates CPM on the CEBaB benchmark under different settings.

More importantly, we introduce new baselines for CPM as well. Formally, we evaluate the blackbox model with interchange intervention evaluation (introduced in detail below).
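
To make the idea concrete: an interchange intervention runs the model on a base input and a source input, then overwrites the part of the base hidden representation assigned to one concept with the corresponding part from the source run. The sketch below illustrates this on plain Python lists. The layout (four concepts, each owning a contiguous slice of `h_dim` units, ordered ambiance, food, noise, service) is an assumption inferred from the settings used in this notebook, and `interchange` is a hypothetical helper, not part of the CPM code.

```python
# Hypothetical illustration of an interchange intervention on a hidden vector.
# Assumption: each concept owns a contiguous slice of h_dim units.
CONCEPTS = ["ambiance", "food", "noise", "service"]

def interchange(base_hidden, source_hidden, concept, h_dim):
    """Return a copy of base_hidden whose slice for `concept` comes from source_hidden."""
    if len(base_hidden) != len(source_hidden):
        raise ValueError("base and source must have the same hidden size")
    start = CONCEPTS.index(concept) * h_dim
    end = start + h_dim
    if end > len(base_hidden):
        raise ValueError("concept slice exceeds hidden size")
    swapped = list(base_hidden)
    swapped[start:end] = source_hidden[start:end]
    return swapped

base = [0.0] * 8      # toy hidden state of the base review
source = [1.0] * 8    # toy hidden state of the source review
print(interchange(base, source, "food", h_dim=2))
# [0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```

Only the slice for the intervened concept changes; the rest of the base representation is left untouched.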

In this notebook, we can evaluate the following models:
- CPM: `BERT-base-uncased`
- CPM: `RoBERTa-base`
- CPM: `GPT2`
- CPM: `LSTM+GloVe`
- CPM: `Control`

We can evaluate under the following conditions:
- 2-class
- 3-class
- 5-class

This script also supports evaluating experiments under different control conditions:
- `control=True`
- `control=False`
- `control="pretrain"`
- `control="random"`
- `control="finetune"`
- `control="fewshots"` (supports BERT only)
- `control="fewshots-augment"` (supports BERT only)
- `control="layer"` (supports BERT only)
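
Each control setting selects a different results directory when the evaluation loop runs. The helper below is a minimal sketch of that mapping, following the directory naming used in the main loop of this notebook; `results_dir_for` is a hypothetical name introduced only for illustration.

```python
def results_dir_for(control, prefix="BERT"):
    """Map a control setting to a results directory name."""
    # Booleans are checked by identity so that strings never match them.
    if control is False:
        return f"{prefix}-results"
    if control is True:
        return f"{prefix}-control-results"
    return f"{prefix}-{control}-results"

for setting in [False, True, "pretrain", "fewshots"]:
    print(setting, "->", results_dir_for(setting))
```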

#### Imports and Libs

In [12]:
from libs import *
from modelings.modelings_bert import *
from modelings.modelings_roberta import *
from modelings.modelings_gpt2 import *
from modelings.modelings_lstm import *
"""
For evaluation, we use a single random seed, since
the models have already been trained with 5
different seeds.
"""
_ = random.seed(123)
_ = np.random.seed(123)
_ = torch.manual_seed(123)

#### Main evaluation script

In [13]:
"""
The following cells run the CEBaB benchmark on
every combination of the conditions below.
"""
grid = {
    "seed": [42],
    "h_dim": [192], # 1, 4, 16, 32, 64, 128, 192
    "class_num": [5],
    "control": ["dev"],
    # True, False, pretrain, random, finetune, 
    # fewshots fewshots-augment fewshots-layer
    # fewshots-hdim
    # fewshots-augment-balance
    # fewshots-augment-large-epochs
    # approximate-large-epochs
    # fewshots-augment-large-epochs-no-probe
    "beta" : [1.0],
    "gemma" : [3.0],
    "cls_dropout" : [0.1],
    "enc_dropout" : [0.1],
    "model_arch" : ["bert-base-uncased"],
    "true_cfc" : [1755], 
    # only used for fewshots evaluations.
    # 5, 50, 200, 600, 1200, 1755
    "interchange_layer" : [12]
}

keys, values = zip(*grid.items())
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]

device = 'cuda:0'
batch_size = 32
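
With one value per key, the grid above expands to a single configuration. To illustrate how the expansion scales, the sketch below uses a hypothetical `toy_grid` with several values for two keys; the number of configurations is the product of the list lengths. The values are made up and are not recommended settings.

```python
import itertools

# Hypothetical grid for illustration only.
toy_grid = {
    "seed": [42, 66],
    "class_num": [2, 3, 5],
    "control": [False],
}

keys, values = zip(*toy_grid.items())
configs = [dict(zip(keys, combo)) for combo in itertools.product(*values)]

print(len(configs))   # 2 * 3 * 1 = 6
print(configs[0])     # {'seed': 42, 'class_num': 2, 'control': False}
```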

In [18]:
class CausalProxyModelForBERT(Explainer, CausalExplainer):
    def __init__(
        self, 
        blackbox_model_path,
        cpm_model_path, 
        device, batch_size, 
        intervention_h_dim=1,
        min_iit_pair_examples=1,
        match_non_int_type=False,
        cache_dir="../../huggingface_cache",
    ):
        self.batch_size = batch_size
        self.device = device
        self.min_iit_pair_examples = min_iit_pair_examples
        self.match_non_int_type = match_non_int_type
        # blackbox model loading.
        self.blackbox_model = BertForNonlinearSequenceClassification.from_pretrained(
            blackbox_model_path,
            cache_dir=cache_dir
        )
        self.blackbox_model.to(device)
        
        # causal proxy model loading.
        cpm_config = AutoConfig.from_pretrained(
            cpm_model_path,
            cache_dir=cache_dir,
            use_auth_token="CEBaB/" in cpm_model_path,
        )
        # Fall back to the constructor argument when the saved config
        # does not record an intervention dimension.
        if not hasattr(cpm_config, "intervention_h_dim"):
            cpm_config.intervention_h_dim = intervention_h_dim
        print(f"intervention_h_dim={cpm_config.intervention_h_dim}")
        cpm_model = IITBERTForSequenceClassification.from_pretrained(
            cpm_model_path,
            config=cpm_config,
            cache_dir=cache_dir
        )
        cpm_model.to(device)
        self.cpm_model = InterventionableIITTransformerForSequenceClassification(
            model=cpm_model
        )
        
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        
    def preprocess_predict_proba(self, df):
        x = self.tokenizer(df['description'].to_list(), padding=True, truncation=True, return_tensors='pt')
        y = df['review_majority'].astype(int)

        return x, y
        
    def predict_proba(self, dataset):
        self.cpm_model.model.eval()

        x, y = self.preprocess_predict_proba(dataset)

        # get the predictions batch per batch
        probas = []
        for i in range(ceil(len(dataset) / self.batch_size)):
            x_batch = {k: v[i * self.batch_size:(i + 1) * self.batch_size].to(self.device) for k, v in x.items()}
            probas.append(torch.nn.functional.softmax(self.cpm_model.model(**x_batch).logits[0].cpu(), dim=-1).detach())

        probas = torch.concat(probas)
        probas = np.round(probas.numpy(), decimals=16)

        predictions = np.argmax(probas, axis=1)
        clf_report = classification_report(y.to_numpy(), predictions, output_dict=True)

        return probas, clf_report
        
    def fit(self, dataset, classifier_predictions, classifier, dev_dataset=None):
        # we don't need to train IIT here.
        pass
    
    def preprocess(self, pairs_dataset, dev_dataset):
        
        # configs
        min_iit_pair_examples = self.min_iit_pair_examples
        match_non_int_type = self.match_non_int_type
        
        query_dataset = get_iit_examples(dev_dataset)
        iit_pairs_dataset = []
        iit_id = 0
        for index, row in pairs_dataset.iterrows():
            query_description_base = row['description_base']
            query_int_type = row['intervention_type']
            query_non_int_type = {
                "ambiance", "food", "noise", "service"
            } - {query_int_type}
            query_int_aspect_base = row["intervention_aspect_base"]
            query_int_aspect_assignment = row['intervention_aspect_counterfactual']
            query_original_id = row["original_id_base"]
            matched_iit_examples = query_dataset[
                (query_dataset[f"{query_int_type}_aspect_majority"]==query_int_aspect_assignment)&
                (query_dataset["original_id"]!=query_original_id)
            ]
            if match_non_int_type:
                for _t in query_non_int_type:
                    matched_iit_examples = matched_iit_examples[
                        (matched_iit_examples[f"{_t}_aspect_majority"]==\
                         row[f"{_t}_aspect_majority_base"])
                    ]
            if len(set(matched_iit_examples["id"])) < min_iit_pair_examples:
                if match_non_int_type:
                    # simply avoid mapping the rest of the aspects.
                    matched_iit_examples = query_dataset[
                        (query_dataset[f"{query_int_type}_aspect_majority"]==query_int_aspect_assignment)&
                        (query_dataset["original_id"]!=query_original_id)
                    ]
                else:
                    raise ValueError(
                        f"Only {len(set(matched_iit_examples['id']))} source examples "
                        f"match; need at least {min_iit_pair_examples}."
                    )
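            # random.sample no longer accepts sets (removed in Python 3.11),
            # so sample from a sorted list of unique ids instead.
            sampled_iit_example_ids = random.sample(
                sorted(set(matched_iit_examples["id"])), min_iit_pair_examples
            )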
            for _id in sampled_iit_example_ids:
                description_iit = query_dataset[query_dataset["id"]==_id]["description"].iloc[0]
                iit_pairs_dataset += [[
                    iit_id,
                    query_int_type,
                    query_description_base, 
                    description_iit
                ]]
            iit_id += 1
        iit_pairs_dataset = pd.DataFrame(
            columns=[
                'iit_id',
                'intervention_type', 
                'description_base', 
                'description_iit'], 
            data=iit_pairs_dataset
        )
        
        base_x = self.tokenizer(
            iit_pairs_dataset['description_base'].to_list(), 
            padding=True, truncation=True, return_tensors='pt'
        )
        source_x = self.tokenizer(
            iit_pairs_dataset['description_iit'].to_list(), 
            padding=True, truncation=True, return_tensors='pt'
        )
        # Map each intervened aspect to its integer index.
        aspect_to_index = {"ambiance": 0, "food": 1, "noise": 2, "service": 3}
        intervention_corr = torch.tensor([
            aspect_to_index[_type]
            for _type in iit_pairs_dataset["intervention_type"].tolist()
        ]).long()
        return base_x, source_x, intervention_corr, iit_pairs_dataset
    
    def estimate_icace(self, pairs, df):
        CPM_iTEs = []
        # self.blackbox_model.eval()
        # self.cpm_model.model.eval()
        base_x, source_x, intervention_corr, iit_pairs_dataset = self.preprocess(
            pairs, df
        )
        with torch.no_grad():
            for i in tqdm(range(ceil(len(iit_pairs_dataset)/self.batch_size))):
                base_x_batch = {k:v[i*self.batch_size:(i+1)*self.batch_size].to(self.device) for k,v in base_x.items()} 
                source_x_batch = {k:v[i*self.batch_size:(i+1)*self.batch_size].to(self.device) for k,v in source_x.items()} 
                intervention_corr_batch = intervention_corr[i*self.batch_size:(i+1)*self.batch_size].to(self.device)
                
                base_outputs = torch.nn.functional.softmax(
                    self.blackbox_model(**base_x_batch).logits.cpu(), dim=-1
                ).detach()
                _, _, counterfactual_outputs = self.cpm_model.forward(
                    base=(base_x_batch['input_ids'], base_x_batch['attention_mask']),
                    source=(source_x_batch['input_ids'], source_x_batch['attention_mask']),
                    base_intervention_corr=intervention_corr_batch,
                    source_intervention_corr=intervention_corr_batch,
                )
                counterfactual_outputs = torch.nn.functional.softmax(
                    counterfactual_outputs["logits"][0].cpu(), dim=-1
                ).detach()
                CPM_iTE = counterfactual_outputs-base_outputs
                CPM_iTEs.append(CPM_iTE)
        CPM_iTEs = torch.concat(CPM_iTEs)
        CPM_iTEs = np.round(CPM_iTEs.numpy(), decimals=4)

        # only for iit explainer!
        iit_pairs_dataset["EiCaCE"] = list(CPM_iTEs)
        CPM_iTEs = list(iit_pairs_dataset.groupby(["iit_id"])["EiCaCE"].mean())
        
        return CPM_iTEs
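
The core of `estimate_icace` is a difference between two probability vectors: the proxy model's prediction under the intervention minus the blackbox prediction on the unmodified base input. When several source examples are sampled for one pair, the differences are averaged per `iit_id`. The following is a minimal sketch of that arithmetic on plain Python lists; `softmax`, `effect`, and `mean_effects_by_id` are hypothetical helpers, and the logits are made up.

```python
import math
from collections import defaultdict

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def effect(base_logits, counterfactual_logits):
    """Estimated effect: counterfactual probabilities minus base probabilities."""
    base = softmax(base_logits)
    counterfactual = softmax(counterfactual_logits)
    return [c - b for c, b in zip(counterfactual, base)]

def mean_effects_by_id(ids, effects):
    """Average effect vectors that share the same pair id."""
    grouped = defaultdict(list)
    for pair_id, vec in zip(ids, effects):
        grouped[pair_id].append(vec)
    return {
        pair_id: [sum(col) / len(vecs) for col in zip(*vecs)]
        for pair_id, vecs in grouped.items()
    }

# Made-up logits for a 3-class toy problem, two sources for the same pair.
e1 = effect([2.0, 0.0, 0.0], [0.0, 0.0, 2.0])
e2 = effect([2.0, 0.0, 0.0], [0.0, 2.0, 0.0])
print(mean_effects_by_id([0, 0], [e1, e2]))
```

Because both inputs to the subtraction are probability distributions, each effect vector sums to zero up to floating-point error.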

In [21]:
results = {}
for i in range(len(permutations_dicts)):
    
    seed=permutations_dicts[i]["seed"]
    class_num=permutations_dicts[i]["class_num"]
    beta=permutations_dicts[i]["beta"]
    gemma=permutations_dicts[i]["gemma"]
    h_dim=permutations_dicts[i]["h_dim"]
    dataset_type = f'{class_num}-way'
    correction_epsilon=None
    cls_dropout=permutations_dicts[i]["cls_dropout"]
    enc_dropout=permutations_dicts[i]["enc_dropout"]
    control=permutations_dicts[i]["control"]
    model_arch=permutations_dicts[i]["model_arch"]
    true_cfc=permutations_dicts[i]["true_cfc"]
    interchange_layer=permutations_dicts[i]["interchange_layer"]
    
    if model_arch == "bert-base-uncased":
        model_path = "BERT-results" if control == False \
            else "BERT-control-results" if control == True \
            else f"BERT-{control}-results"
        model_module = BERTForCEBaB
        explainer_module = CausalProxyModelForBERT
    elif model_arch == "roberta-base":
        model_path = "RoBERTa-results" if control == False \
            else "RoBERTa-control-results" if control == True \
            else f"RoBERTa-{control}-results" 
        model_module = RoBERTaForCEBaB
        explainer_module = CausalProxyModelForRoBERTa
    elif model_arch == "gpt2":
        model_path = "gpt2-results" if control == False \
            else "gpt2-control-results" if control == True \
            else f"gpt2-{control}-results"
        model_module = GPT2ForCEBaB
        explainer_module = CausalProxyModelForGPT2
    elif model_arch == "lstm":
        model_path = "lstm-results" if control == False \
            else "lstm-control-results" if control == True \
            else f"lstm-{control}-results"
        model_module = LSTMForCEBaB
        explainer_module = CausalProxyModelForLSTM
    else:
        raise ValueError(f"Unsupported model_arch: {model_arch}")
        
    grid_conditions=(
        ("seed", seed),
        ("class_num", class_num),
        ("beta", beta),
        ("gemma", gemma),
        ("h_dim", h_dim),
        ("dataset_type", dataset_type),
        ("correction_epsilon", correction_epsilon),
        ("cls_dropout", cls_dropout),
        ("enc_dropout", enc_dropout),
        ("control", control),
        ("model_arch", model_arch),
        ("true_cfc", true_cfc),
        ("interchange_layer", interchange_layer)
    )
    print("Running for this setting: ", grid_conditions)

    blackbox_model_path = f'CEBaB/{model_arch}.CEBaB.sa.{class_num}-class.exclusive.seed_{seed}'
    if control == "finetune": # not for other control cases, e.g., random or pretrain
        cpm_model_path = blackbox_model_path
    else:
        split_name = "train"
        cpm_model_path = '../proxy_training_results/dev/'\
                         'cebab.alpha.1.0.beta.1.0.gemma.3.0'\
                         '.lr.8e-05.dim.192.hightype.bert-base-uncased'\
                         '.CEBaB.cls.dropout.0.1.enc.dropout.0.1.counter'\
                         '.type.true.k.19684.int.layer.10.seed_42/'
        if control == "dev":
            pass
    # load data from HF
    cebab = datasets.load_dataset(
        'CEBaB/CEBaB', use_auth_token=True,
        cache_dir="../train_cache/"
    )
    train, dev, test = preprocess_hf_dataset(
        cebab, one_example_per_world=True, 
        verbose=1, dataset_type=dataset_type
    )

    tf_model = model_module(
        blackbox_model_path, 
        device=device, 
        batch_size=batch_size
    )
    explanator = explainer_module(
        blackbox_model_path,
        cpm_model_path, 
        device=device, 
        batch_size=batch_size,
        intervention_h_dim=h_dim,
    )

    train_dataset = train.copy()
    # NOTE: despite its name, dev_dataset holds the test split here.
    dev_dataset = test.copy()

    result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \
    CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \
    ACaCE_per_aspect, performance_report = cebab_pipeline(
        tf_model, explanator, 
        train_dataset, dev_dataset, 
        dataset_type=dataset_type,
        correction_epsilon=correction_epsilon,
    )
    
    results[grid_conditions] = (
        result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \
        CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \
        ACaCE_per_aspect, performance_report
    )

Running for this setting:  (('seed', 42), ('class_num', 5), ('beta', 1.0), ('gemma', 3.0), ('h_dim', 192), ('dataset_type', '5-way'), ('correction_epsilon', None), ('cls_dropout', 0.1), ('enc_dropout', 0.1), ('control', 'dev'), ('model_arch', 'bert-base-uncased'), ('true_cfc', 1755), ('interchange_layer', 12))


Using custom data configuration CEBaB--CEBaB-0e2f7ed67c9d7e55
Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-0e2f7ed67c9d7e55/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)


  0%|          | 0/4 [00:00<?, ?it/s]

Dropping no majority reviews: 16.6382% of train dataset.
intervention_h_dim=192


100%|██████████| 124/124 [00:28<00:00,  4.42it/s]


#### Show your results

In [22]:
important_keys = [
    "seed", "h_dim", "class_num", 
    "control", "beta", "gemma", 
    "cls_dropout", "enc_dropout", 
    "model_arch", "true_cfc", 
    "interchange_layer"
]
values = []
for k, v in results.items():
    _values = []
    for ik in important_keys:
        _values.append(dict(k)[ik])
    _values.append(v[2]["ICaCE-L2"].iloc[0])
    _values.append(v[2]["ICaCE-cosine"].iloc[0])
    _values.append(v[2]["ICaCE-normdiff"].iloc[0])
    _values.append(v[-1].iloc[0][0])
    values.append(_values)
important_keys.extend(["ICaCE-L2", "ICaCE-cosine", "ICaCE-normdiff", "macro-f1"])
df = pd.DataFrame(values, columns=important_keys)
df.sort_values(by=['class_num'], ascending=True)

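| | seed | h_dim | class_num | control | beta | gemma | cls_dropout | enc_dropout | model_arch | true_cfc | interchange_layer | ICaCE-L2 | ICaCE-cosine | ICaCE-normdiff | macro-f1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 42 | 192 | 5 | dev | 1.0 | 3.0 | 0.1 | 0.1 | bert-base-uncased | 1755 | 12 | 0.4481 | 0.3582 | 0.2465 | 0.68596 |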
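
The three error columns compare an estimated effect vector with a reference effect vector. The sketch below shows one plausible way to compute such distances. Reading `ICaCE-L2` as Euclidean distance, `ICaCE-cosine` as cosine distance, and `ICaCE-normdiff` as the absolute difference of vector norms is an assumption about the benchmark's definitions, so the CEBaB code remains the authoritative source. The function names and the example vectors are hypothetical.

```python
import math

def _norm(v):
    return math.sqrt(sum(x * x for x in v))

def l2_distance(a, b):
    """Euclidean distance between two effect vectors."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def cosine_distance(a, b):
    """One minus the cosine similarity; undefined for zero vectors."""
    norm_a, norm_b = _norm(a), _norm(b)
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine distance is undefined for zero vectors")
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (norm_a * norm_b)

def norm_difference(a, b):
    """Absolute difference between the two vector norms."""
    return abs(_norm(a) - _norm(b))

# Made-up effect vectors for a 3-class problem.
estimated = [-0.3, 0.1, 0.2]
reference = [-0.4, 0.0, 0.4]
print(l2_distance(estimated, reference))
print(cosine_distance(estimated, reference))
print(norm_difference(estimated, reference))
```

Under these readings, the cosine term captures whether the estimated effect points in the right direction, while the norm difference captures whether its magnitude is right.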


#### Save your results, then load them again to tabulate all results together

In [28]:
output_name = input("Please give an output file name: ")

output_directory = f'../proxy_training_results/{model_path}/'
output_filename = os.path.join(output_directory, f'{output_name}.pkl')
print("Writing to file: ", output_filename)
with open(output_filename, 'wb') as f:
    pickle.dump(results, f)

Please give an output file name: results
Writing to file:  ../proxy_training_results/BERT-fewshots-hdim-results/results.pkl
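
The heading above also mentions loading the saved results again. A minimal sketch of that step follows, assuming each run was pickled as a dictionary keyed by its grid conditions, as in the cell above. `load_all_results` is a hypothetical helper, and the demo writes toy dictionaries to a temporary directory rather than real results. Only unpickle files you created yourself, since `pickle.load` can execute arbitrary code.

```python
import glob
import os
import pickle
import tempfile

def load_all_results(directory):
    """Merge every pickled results dictionary found in `directory`."""
    merged = {}
    for path in sorted(glob.glob(os.path.join(directory, "*.pkl"))):
        with open(path, "rb") as f:
            # Later files overwrite earlier ones if grid conditions repeat.
            merged.update(pickle.load(f))
    return merged

# Demo with toy payloads standing in for real result tuples.
with tempfile.TemporaryDirectory() as tmp:
    toy_runs = {
        "run_a": {(("seed", 42),): 0.1},
        "run_b": {(("seed", 66),): 0.2},
    }
    for name, payload in toy_runs.items():
        with open(os.path.join(tmp, f"{name}.pkl"), "wb") as f:
            pickle.dump(payload, f)
    print(load_all_results(tmp))
```

The merged dictionary has the same shape as `results`, so the tabulation cell above can be reused on it unchanged.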
