### Evaluating CPM with CEBaB
Our Causal Proxy Model (CPM) provides concept-based explanations for a blackbox model. We use the recently developed CEBaB benchmark to compare CPM with other concept-based explanation methods. This notebook evaluates CPM on the CEBaB benchmark under different settings.

More importantly, we introduce new baselines for CPM as well. Formally, we evaluate the blackbox model with interchange intervention evaluation (which will be introduced in detail below).
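
An interchange intervention runs the model on a base input and a source input, overwrites the part of the base hidden representation assigned to one concept with the corresponding part from the source, and continues the forward pass from the patched representation. The sketch below shows only the swap step, on plain Python lists. The slice layout (concept `k` occupies positions `k * h_dim` up to `(k + 1) * h_dim`) and the helper name `interchange` are assumptions made for illustration, not the notebook's actual implementation.

```python
# Hypothetical concept order, mirroring the index mapping used in preprocess().
CONCEPT_INDEX = {"ambiance": 0, "food": 1, "noise": 2, "service": 3}


def interchange(base_hidden, source_hidden, concept, h_dim):
    """Return a copy of base_hidden whose slice for `concept` comes from source_hidden."""
    if len(base_hidden) != len(source_hidden):
        raise ValueError("base and source must have the same hidden size")
    start = CONCEPT_INDEX[concept] * h_dim
    end = start + h_dim
    if end > len(base_hidden):
        raise ValueError("concept slice exceeds hidden size")
    patched = list(base_hidden)  # copy, so the base representation is untouched
    patched[start:end] = source_hidden[start:end]
    return patched


# Toy hidden vectors of size 8 with 2 dimensions per concept.
base = [0.0] * 8
source = [1.0] * 8
patched = interchange(base, source, "food", h_dim=2)
print(patched)
```

Only the two positions assigned to `food` take the source values; every other position keeps its base value.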

In this notebook, we will evaluate the following models:
- CPM: `BERT-base-uncased`
- CPM: `RoBERTa-base`
- CPM: `GPT2`
- CPM: `LSTM+GloVe`
- CPM: `Control`

and we will evaluate with the following conditions:
- 2-class
- 3-class
- 5-class

#### Imports and Libs

In [5]:
from libs import *
"""
For evaluation, we use a single random seed, since
the models were already trained with 5 different
seeds.
"""
_ = random.seed(123)
_ = np.random.seed(123)
_ = torch.manual_seed(123)

In [None]:
"""
TODO:
After this notebook is fully constructed, the following
section needs to be separated out to another file.
"""
class BERTForCEBaB(Model):
    def __init__(self, model_path, device='cpu', batch_size=64):
        self.device = device
        self.model_path = model_path
        self.tokenizer_path = model_path
        self.batch_size = batch_size

        self.model = BertForNonlinearSequenceClassification.from_pretrained(
            self.model_path,
            cache_dir="../../huggingface_cache"
        )
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

        self.model.to(device)

    def __str__(self):
        return self.model_path.split('/')[-1]

    def preprocess(self, df):
        x = self.tokenizer(df['description'].to_list(), padding=True, truncation=True, return_tensors='pt')
        y = df['review_majority'].astype(int)

        return x, y

    def fit(self, dataset):
        # assume model was already trained
        pass

    def predict_proba(self, dataset):
        self.model.eval()

        x, y = self.preprocess(dataset)

        # get the predictions batch by batch
        probas = []
        for i in range(ceil(len(dataset) / self.batch_size)):
            x_batch = {k: v[i * self.batch_size:(i + 1) * self.batch_size].to(self.device) for k, v in x.items()}
            probas.append(torch.nn.functional.softmax(self.model(**x_batch).logits.cpu(), dim=-1).detach())

        probas = torch.concat(probas)
        probas = np.round(probas.numpy(), decimals=16)

        predictions = np.argmax(probas, axis=1)
        clf_report = classification_report(y.to_numpy(), predictions, output_dict=True)

        return probas, clf_report

    def get_embeddings(self, sentences_list):
        x = self.tokenizer(sentences_list, padding=True, truncation=True, return_tensors='pt')
        embeddings = []
        for i in range(ceil(len(x['input_ids']) / self.batch_size)):
            x_batch = {k: v[i * self.batch_size:(i + 1) * self.batch_size].to(self.device) for k, v in x.items()}
            embeddings.append(self.model.base_model(**x_batch).pooler_output.detach().cpu().tolist())

        return embeddings

    def get_classification_head(self):
        return self.model.classifier
    
def get_iit_examples(df):
    """
    Given a dataframe in the new data scheme, keep only the columns needed to
    look up IIT source examples (ids, edit info, text, labels, predictions).
    """
    # Drop label distribution and worker information.
    columns_to_keep = ['id', 'original_id', 'edit_id', 'is_original', 'edit_goal', 'edit_type', 'description', 'review_majority','food_aspect_majority', 'ambiance_aspect_majority', 'service_aspect_majority', 'noise_aspect_majority']
    columns_to_keep += [col for col in df.columns if 'prediction' in col]
    df = df[columns_to_keep]
    return df

class CausalProxyModelForBERT(Explainer):
    def __init__(
        self, 
        blackbox_model_path,
        cpm_model_path, 
        device, batch_size, 
        intervention_h_dim=1,
        min_iit_pair_examples=1,
        match_non_int_type=False,
        cache_dir="../../huggingface_cache",
    ):
        self.batch_size = batch_size
        self.device = device
        self.min_iit_pair_examples = min_iit_pair_examples
        self.match_non_int_type = match_non_int_type
        # blackbox model loading.
        self.blackbox_model = BertForNonlinearSequenceClassification.from_pretrained(
            blackbox_model_path,
            cache_dir=cache_dir
        )
        self.blackbox_model.to(device)
        
        # causal proxy model loading.
        cpm_config = AutoConfig.from_pretrained(
            cpm_model_path,
            cache_dir=cache_dir,
            use_auth_token=True if "CEBaB/" in cpm_model_path else False,
        )
        # Fall back to the constructor argument if the saved config lacks the field.
        if not hasattr(cpm_config, "intervention_h_dim"):
            cpm_config.intervention_h_dim = intervention_h_dim
        print(f"intervention_h_dim={cpm_config.intervention_h_dim}")
        cpm_model = IITBERTForSequenceClassification.from_pretrained(
            cpm_model_path,
            config=cpm_config,
            cache_dir=cache_dir
        )
        cpm_model.to(device)
        self.cpm_model = InterventionableIITTransformerForSequenceClassification(
            model=cpm_model
        )
        
        self.tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        
    def fit(self, dataset, classifier_predictions, classifier, dev_dataset=None):
        # we don't need to train IIT here.
        pass
    
    def preprocess(self, pairs_dataset, dev_dataset):
        
        # configs
        min_iit_pair_examples = self.min_iit_pair_examples
        match_non_int_type = self.match_non_int_type
        
        query_dataset = get_iit_examples(dev_dataset)
        iit_pairs_dataset = []
        iit_id = 0
        for index, row in pairs_dataset.iterrows():
            query_description_base = row['description_base']
            query_int_type = row['intervention_type']
            query_non_int_type = {
                "ambiance", "food", "noise", "service"
            } - {query_int_type}
            query_int_aspect_base = row["intervention_aspect_base"]
            query_int_aspect_assignment = row['intervention_aspect_counterfactual']
            query_original_id = row["original_id_base"]
            matched_iit_examples = query_dataset[
                (query_dataset[f"{query_int_type}_aspect_majority"]==query_int_aspect_assignment)&
                (query_dataset["original_id"]!=query_original_id)
            ]
            if match_non_int_type:
                for _t in query_non_int_type:
                    matched_iit_examples = matched_iit_examples[
                        (matched_iit_examples[f"{_t}_aspect_majority"]==\
                         row[f"{_t}_aspect_majority_base"])
                    ]
            if len(set(matched_iit_examples["id"])) < min_iit_pair_examples:
                if match_non_int_type:
                    # simply avoid mapping the rest of the aspects.
                    matched_iit_examples = query_dataset[
                        (query_dataset[f"{query_int_type}_aspect_majority"]==query_int_aspect_assignment)&
                        (query_dataset["original_id"]!=query_original_id)
                    ]
                else:
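                    raise ValueError(
                        f"Fewer than {min_iit_pair_examples} matching IIT source "
                        f"examples for the pair at index {index}."
                    )
            # random.sample needs a sequence; sets are rejected on Python 3.11+.
            sampled_iit_example_ids = random.sample(
                sorted(set(matched_iit_examples["id"])), min_iit_pair_examples
            )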
            for _id in sampled_iit_example_ids:
                description_iit = query_dataset[query_dataset["id"]==_id]["description"].iloc[0]
                iit_pairs_dataset += [[
                    iit_id,
                    query_int_type,
                    query_description_base, 
                    description_iit
                ]]
            iit_id += 1
        iit_pairs_dataset = pd.DataFrame(
            columns=[
                'iit_id',
                'intervention_type', 
                'description_base', 
                'description_iit'], 
            data=iit_pairs_dataset
        )
        
        base_x = self.tokenizer(
            iit_pairs_dataset['description_base'].to_list(), 
            padding=True, truncation=True, return_tensors='pt'
        )
        source_x = self.tokenizer(
            iit_pairs_dataset['description_iit'].to_list(), 
            padding=True, truncation=True, return_tensors='pt'
        )
        intervention_corr = []
        for _type in iit_pairs_dataset["intervention_type"].tolist():
            if _type == "ambiance":
                intervention_corr += [0]
            if _type == "food":
                intervention_corr += [1]
            if _type == "noise":
                intervention_corr += [2]
            if _type == "service":
                intervention_corr += [3]
        intervention_corr = torch.tensor(intervention_corr).long()
        return base_x, source_x, intervention_corr, iit_pairs_dataset
    
    def estimate_icace(self, pairs, df):
        CPM_iTEs = []
        self.blackbox_model.eval()
        self.cpm_model.model.eval()
        base_x, source_x, intervention_corr, iit_pairs_dataset = self.preprocess(
            pairs, df
        )
        with torch.no_grad():
            for i in tqdm(range(ceil(len(iit_pairs_dataset)/self.batch_size))):
                base_x_batch = {k:v[i*self.batch_size:(i+1)*self.batch_size].to(self.device) for k,v in base_x.items()} 
                source_x_batch = {k:v[i*self.batch_size:(i+1)*self.batch_size].to(self.device) for k,v in source_x.items()} 
                intervention_corr_batch = intervention_corr[i*self.batch_size:(i+1)*self.batch_size].to(self.device)
                
                base_outputs = torch.nn.functional.softmax(
                    self.blackbox_model(**base_x_batch).logits.cpu(), dim=-1
                ).detach()
                _, _, counterfactual_outputs = self.cpm_model.forward(
                    base=(base_x_batch['input_ids'], base_x_batch['attention_mask']),
                    source=(source_x_batch['input_ids'], source_x_batch['attention_mask']),
                    base_intervention_corr=intervention_corr_batch,
                    source_intervention_corr=intervention_corr_batch,
                )
                counterfactual_outputs = torch.nn.functional.softmax(
                    counterfactual_outputs["logits"][0].cpu(), dim=-1
                ).detach()
                CPM_iTE = counterfactual_outputs-base_outputs
                CPM_iTEs.append(CPM_iTE)
        CPM_iTEs = torch.concat(CPM_iTEs)
        CPM_iTEs = np.round(CPM_iTEs.numpy(), decimals=4)

        # only for iit explainer!
        iit_pairs_dataset["EiCaCE"] = list(CPM_iTEs)
        CPM_iTEs = list(iit_pairs_dataset.groupby(["iit_id"])["EiCaCE"].mean())
        
        return CPM_iTEs

def cebab_pipeline(
    model, explainer, 
    train_dataset, dev_dataset, 
    dataset_type='5-way', 
    shorten_model_name=False,
    correction_epsilon=0.001,
):
    # get predictions on train and dev
    train_predictions, _ = model.predict_proba(
        train_dataset
    )
    dev_predictions, dev_report = model.predict_proba(
        dev_dataset
    )

    # append predictions to datasets
    train_dataset['prediction'] = list(train_predictions)
    dev_dataset['prediction'] = list(dev_predictions)

    # fit explainer
    explainer.fit(
        train_dataset, train_predictions, 
        model, dev_dataset
    )

    # get intervention pairs
    
    pairs_dataset = get_intervention_pairs(
        dev_dataset, dataset_type=dataset_type
    )  # TODO why is the index not unique here?
        
    # get explanations
    if isinstance(explainer, CausalProxyModelForBERT):
        explanations = explainer.estimate_icace(
            pairs_dataset,
            train_dataset # for query data.
        )
    else:
        explanations = explainer.estimate_icace(
            pairs_dataset,
        )
    
    # append explanations to the pairs
    pairs_dataset['EICaCE'] = explanations
    
    # TODO: add cosine
    pairs_dataset = metric_utils._calculate_ite(pairs_dataset)  # effect of crowd-workers on other crowd-workers (no model, no explainer)
    
    def _calculate_icace(pairs):
        """
        This metric measures the effect of a certain concept on the given model.
        It is independent of the explainer.
        """
        pairs['ICaCE'] = (pairs['prediction_counterfactual'] - pairs['prediction_base']).apply(lambda x: np.round(x, decimals=4))

        return pairs
    pairs_dataset = _calculate_icace(pairs_dataset)  # effect of concept on the model (with model, no explainer)
    
    # TOREMOVE: just to try if we ignore all [0,0,...] cases here.
    # zeros_like = tuple([0.0 for i in range(int(dataset_type.split("-")[0]))])
    # pairs_dataset = pairs_dataset[pairs_dataset.ICaCE.map(tuple).isin([zeros_like])==False]
    
    def _cosine_distance(a,b,epsilon):
        if epsilon is None:
            if np.linalg.norm(a, ord=2) == 0 or np.linalg.norm(b, ord=2) == 0:
                return 1
            else:
                return cosine(a,b)
        
        if np.linalg.norm(a, ord=2) == 0 and np.linalg.norm(b, ord=2) == 0:
            """
            We cannot determine whether the prediction is correct or not.
            Thus, we simply return 1.
            """
            return 1
        elif np.linalg.norm(a, ord=2) == 0:
            """
            When true iCACE score is 0, instead of always returning 1, we
            allow some epsilon error by default. If the EiCACE is within a
            range, we return the error as 0.
            """
            if np.max(np.abs(b)) <= epsilon:
                return 0
            else:
                return 1
        elif np.linalg.norm(b, ord=2) == 0:
            """
            This case happens when iCACE is not 0, but EiCACE is 0. This is
            unlikely, but we give score of 1 for this case.
            """
            return 1
        else:
            return cosine(a,b)
    
    def _calculate_estimate_loss(pairs,epsilon):
        """
        Calculate the distance between the ICaCE and EICaCE.
        """

        # Access columns by label; positional integer keys on a labeled Series are deprecated in pandas.
        pairs['ICaCE-L2'] = pairs[['ICaCE', 'EICaCE']].apply(lambda x: np.linalg.norm(x['ICaCE'] - x['EICaCE'], ord=2), axis=1)
        pairs['ICaCE-cosine'] = pairs[['ICaCE', 'EICaCE']].apply(lambda x: _cosine_distance(x['ICaCE'], x['EICaCE'], epsilon), axis=1)
        pairs['ICaCE-normdiff'] = pairs[['ICaCE', 'EICaCE']].apply(lambda x: abs(np.linalg.norm(x['ICaCE'], ord=2) - np.linalg.norm(x['EICaCE'], ord=2)), axis=1)

        return pairs
    
    pairs_dataset = _calculate_estimate_loss(pairs_dataset,correction_epsilon)  # l2 CEBaB Score (model and explainer)

    # only keep columns relevant for metrics
    CEBaB_metrics_per_pair = pairs_dataset[[
        'intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual', 'ITE', 'ICaCE', 'EICaCE', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff']].copy()
    CEBaB_metrics_per_pair['count'] = 1

    # get CEBaB tables
    metrics = ['count', 'ICaCE', 'EICaCE']

    groupby_aspect_direction = ['intervention_type', 'intervention_aspect_base', 'intervention_aspect_counterfactual']

    CaCE_per_aspect_direction = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, metrics)
    CaCE_per_aspect_direction.columns = ['count', 'CaCE', 'ECaCE']
    CaCE_per_aspect_direction = CaCE_per_aspect_direction.set_index(['count'], append=True)
    
    ACaCE_per_aspect = metric_utils._aggregate_metrics(CaCE_per_aspect_direction.abs(), ['intervention_type'], ['CaCE', 'ECaCE'])
    ACaCE_per_aspect.columns = ['ACaCE', 'EACaCE']

    CEBaB_metrics_per_aspect_direction = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff'])
    CEBaB_metrics_per_aspect_direction.columns = ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff']
    CEBaB_metrics_per_aspect_direction = CEBaB_metrics_per_aspect_direction.set_index(['count'], append=True)

    CEBaB_metrics_per_aspect = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, ['intervention_type'], ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff'])
    CEBaB_metrics_per_aspect.columns = ['count', 'ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff']
    CEBaB_metrics_per_aspect = CEBaB_metrics_per_aspect.set_index(['count'], append=True)

    CEBaB_metrics = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, [], ['ICaCE-L2', 'ICaCE-cosine', 'ICaCE-normdiff'])

    # get ATE table
    ATE = metric_utils._aggregate_metrics(CEBaB_metrics_per_pair, groupby_aspect_direction, ['count', 'ITE'])
    ATE.columns = ['count', 'ATE']

    # add model and explainer information
    if shorten_model_name:
        model_name = str(model).split('.')[0]
    else:
        model_name = str(model)

    CaCE_per_aspect_direction.columns = pd.MultiIndex.from_tuples(
        [(model_name, str(explainer), col) if col != 'CaCE' else (model_name, '', col) for col in CaCE_per_aspect_direction.columns])
    ACaCE_per_aspect.columns = pd.MultiIndex.from_tuples(
        [(model_name, str(explainer), col) if col != 'ACaCE' else (model_name, '', col) for col in ACaCE_per_aspect.columns])
    CEBaB_metrics_per_aspect_direction.columns = pd.MultiIndex.from_tuples(
        [(model_name, str(explainer), col) for col in CEBaB_metrics_per_aspect_direction.columns])
    CEBaB_metrics_per_aspect.columns = pd.MultiIndex.from_tuples(
        [(model_name, str(explainer), col) for col in CEBaB_metrics_per_aspect.columns])
    CEBaB_metrics.index = pd.MultiIndex.from_product([[model_name], [str(explainer)], CEBaB_metrics.index])
    
    # performance report
    performance_report_index = ['macro-f1', 'accuracy']
    performance_report_data = [dev_report['macro avg']['f1-score'], dev_report['accuracy']]
    performance_report_col = [model_name]
    performance_report = pd.DataFrame(data=performance_report_data, index=performance_report_index, columns=performance_report_col)

    return pairs_dataset, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, ACaCE_per_aspect, performance_report
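
The pipeline above compares two vectors per intervention pair: `ICaCE`, the change in the blackbox model's predicted distribution between the base text and the real counterfactual text, and `EICaCE`, the explainer's estimate of that change. The following is a minimal, dependency-free sketch of the three distances reported (`ICaCE-L2`, `ICaCE-cosine`, `ICaCE-normdiff`) for a single pair. The probability vectors are made up and the helper names are hypothetical. Cosine distance is taken as one minus cosine similarity, assuming that `cosine` in the pipeline is SciPy's cosine distance, and zero-norm vectors are scored 1 as in the branch without an epsilon correction.

```python
import math


def l2_norm(v):
    return math.sqrt(sum(x * x for x in v))


def cosine_distance(a, b):
    """1 - cosine similarity; returns 1 when either vector has zero norm."""
    na, nb = l2_norm(a), l2_norm(b)
    if na == 0 or nb == 0:
        return 1.0
    dot = sum(x * y for x, y in zip(a, b))
    return 1.0 - dot / (na * nb)


def pair_metrics(pred_base, pred_counterfactual, estimate):
    """Compare the true effect (ICaCE) with an explainer's estimate (EICaCE)."""
    icace = [c - b for c, b in zip(pred_counterfactual, pred_base)]
    diff = [i - e for i, e in zip(icace, estimate)]
    return {
        "ICaCE": icace,
        "ICaCE-L2": l2_norm(diff),
        "ICaCE-cosine": cosine_distance(icace, estimate),
        "ICaCE-normdiff": abs(l2_norm(icace) - l2_norm(estimate)),
    }


# Made-up 3-class probabilities for one intervention pair.
pred_base = [0.7, 0.2, 0.1]
pred_counterfactual = [0.1, 0.2, 0.7]

perfect = pair_metrics(pred_base, pred_counterfactual, estimate=[-0.6, 0.0, 0.6])
null = pair_metrics(pred_base, pred_counterfactual, estimate=[0.0, 0.0, 0.0])

print(perfect["ICaCE-L2"], perfect["ICaCE-cosine"], perfect["ICaCE-normdiff"])
print(null["ICaCE-L2"], null["ICaCE-cosine"], null["ICaCE-normdiff"])
```

An estimate that matches the true effect scores approximately 0 on all three metrics. An all-zero estimate scores 1 on cosine distance, and its L2 and norm-difference scores both equal the norm of the true effect.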

#### Main evaluation script

In [10]:
"""
The following blocks run the CEBaB benchmark on
all combinations of the following conditions.
"""
grid = {
    "h_dim": [192],
    "class_num": [5],
    "control": [False],
    "beta" : [1.0],
    "gemma" : [3.0],
    "cls_dropout" : [0.1],
    "enc_dropout" : [0.1],
    "model_arch" : ["bert-base-uncased"]
}

keys, values = zip(*grid.items())
permutations_dicts = [dict(zip(keys, v)) for v in itertools.product(*values)]

device = 'cuda:9'
batch_size = 32
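
The grid above holds one value per key, so it expands to a single setting. As a sketch of how the same expansion behaves with several values per key, the example below uses a hypothetical two-key grid. The values are illustrative and do not correspond to a configuration that was run in this notebook.

```python
import itertools

# Hypothetical grid for illustration only.
example_grid = {
    "class_num": [2, 3, 5],
    "control": [False, True],
}

# Same expansion pattern as the cell above: one dict per combination of values.
keys, values = zip(*example_grid.items())
settings = [dict(zip(keys, combo)) for combo in itertools.product(*values)]

for setting in settings:
    dataset_type = f"{setting['class_num']}-way"
    print(dataset_type, setting)
```

Three class settings and two control settings give six runs, ordered with the last key varying fastest.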

In [11]:
results = {}
for i in range(len(permutations_dicts)):
    
    seed=42
    class_num=permutations_dicts[i]["class_num"]
    beta=permutations_dicts[i]["beta"]
    gemma=permutations_dicts[i]["gemma"]
    h_dim=permutations_dicts[i]["h_dim"]
    dataset_type = f'{class_num}-way'
    correction_epsilon=None
    cls_dropout=permutations_dicts[i]["cls_dropout"]
    enc_dropout=permutations_dicts[i]["enc_dropout"]
    control=permutations_dicts[i]["control"]
    model_arch=permutations_dicts[i]["model_arch"]
    
    if model_arch == "bert-base-uncased":
        model_path = "BERT-control-results" if control else "BERT-results"
        model_module = BERTForCEBaB
        explainer_module = CausalProxyModelForBERT
    elif model_arch == "roberta-base":
        model_path = "RoBERTa-control-results" if control else "RoBERTa-results"
    elif model_arch == "gpt2":
        model_path = "gpt2-control-results" if control else "gpt2-results"
    elif model_arch == "lstm":
        model_path = "lstm-control-results" if control else "lstm-results"
        
    grid_conditions=(
        ("seed", seed),
        ("class_num", class_num),
        ("beta", beta),
        ("gemma", gemma),
        ("h_dim", h_dim),
        ("dataset_type", dataset_type),
        ("correction_epsilon", correction_epsilon),
        ("cls_dropout", cls_dropout),
        ("enc_dropout", enc_dropout),
        ("control", control),
        ("model_arch", model_arch),
    )
    print("Running for this setting: ", grid_conditions)

    blackbox_model_path = f'CEBaB/{model_arch}.CEBaB.sa.{class_num}-class.exclusive.seed_{seed}'
    if control:
        cpm_model_path = blackbox_model_path
    else:
        cpm_model_path = f'../proxy_training_results/{model_path}/'\
                           f'cebab.train.train.alpha.1.0'\
                           f'.beta.{beta}.gemma.{gemma}.dim.{h_dim}.hightype.'\
                           f'{model_arch}.Proxy.'\
                           f'CEBaB.sa.{class_num}-class.exclusive.'\
                           f'mode.align.cls.dropout.{cls_dropout}.enc.dropout.{enc_dropout}.seed_{seed}'

    # load data from HF
    cebab = datasets.load_dataset(
        'CEBaB/CEBaB', use_auth_token=True,
        cache_dir="../../huggingface_cache/"
    )
    cebab['train'] = cebab['train_exclusive']
    train, dev, test = preprocess_hf_dataset(
        cebab, one_example_per_world=False, 
        verbose=1, dataset_type=dataset_type
    )

    tf_model = model_module(
        blackbox_model_path, 
        device=device, 
        batch_size=batch_size
    )
    explanator = explainer_module(
        blackbox_model_path,
        cpm_model_path, 
        device=device, 
        batch_size=batch_size,
        intervention_h_dim=h_dim,
    )

    train_dataset = train.copy()
    dev_dataset = test.copy()

    result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \
    CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \
    ACaCE_per_aspect, performance_report = cebab_pipeline(
        tf_model, explanator, 
        train_dataset, dev_dataset, 
        dataset_type=dataset_type,
        correction_epsilon=correction_epsilon,
    )
    
    results[grid_conditions] = (
        result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \
        CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \
        ACaCE_per_aspect, performance_report
    )

Running for this setting:  (('seed', 42), ('class_num', 5), ('beta', 1.0), ('gemma', 3.0), ('h_dim', 192), ('dataset_type', '5-way'), ('correction_epsilon', None), ('cls_dropout', 0.1), ('enc_dropout', 0.1), ('control', False), ('model_arch', 'bert-base-uncased'))


Using custom data configuration CEBaB--CEBaB-0e2f7ed67c9d7e55
Reusing dataset parquet (../../huggingface_cache/parquet/CEBaB--CEBaB-0e2f7ed67c9d7e55/0.0.0/0b6d5799bb726b24ad7fc7be720c170d8e497f575d02d47537de9a5bac074901)


  0%|          | 0/4 [00:00<?, ?it/s]

Dropping no majority reviews: 16.03% of train dataset.
intervention_h_dim=192


100%|██████████| 124/124 [00:26<00:00,  4.65it/s]


#### Tabularize your results

In [12]:
important_keys = list(grid.keys())
values = []
for k, v in results.items():
    _values = []
    for ik in important_keys:
        _values.append(dict(k)[ik])
    _values.append(v[2]["ICaCE-L2"].iloc[0])
    _values.append(v[2]["ICaCE-cosine"].iloc[0])
    _values.append(v[2]["ICaCE-normdiff"].iloc[0])
    # macro-f1 is the first row of the single-column performance report
    _values.append(v[-1].iloc[0, 0])
    values.append(_values)
important_keys.extend(["ICaCE-L2", "ICaCE-cosine", "ICaCE-normdiff", "macro-f1"])
df = pd.DataFrame(values, columns=important_keys)
df.sort_values(by=['class_num'], ascending=True)

| | h_dim | class_num | control | beta | gemma | cls_dropout | enc_dropout | model_arch | ICaCE-L2 | ICaCE-cosine | ICaCE-normdiff | macro-f1 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 192 | 5 | False | 1.0 | 3.0 | 0.1 | 0.1 | bert-base-uncased | 0.6617 | 0.4717 | 0.4691 | 0.696461 |
