### Causal Mediation Analyses
A static version of activate alignment search, without interchange intervention training. This analysis closely follows the [causal mediation analysis paper](https://proceedings.neurips.cc/paper/2020/file/92650b2e92217715fe312e6fa7b90d82-Paper.pdf). We simplify their method to adapt it to CEBaB. Details can be found in the code.

In [2]:
from libs import *
from modelings.modelings_bert import *
from modelings.modelings_roberta import *
from modelings.modelings_gpt2 import *
from modelings.modelings_lstm import *

# Evaluation runs under one fixed random seed: the models being evaluated
# were already trained with 5 different seeds, so no extra seed sweep is
# done here.
# Seed Python's `random`, NumPy and PyTorch in turn. The result is bound to
# `_` so that the value returned by the last call is not echoed by the cell.
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _ = _seed_rng(123)

In [56]:
# ---------------------------------------------------------------------------
# Experiment configuration.
# ---------------------------------------------------------------------------
# `seed` selects which trained checkpoint is loaded (it is part of
# `model_path`); it is not the RNG seed of this notebook.
seed=42
class_num=5
# NOTE(review): `beta`, `gemma` (possibly a typo for "gamma"), `cls_dropout`,
# `enc_dropout` and `control` are not read by any cell visible in this
# notebook; they are kept in case code elsewhere relies on them.
beta=1.0
gemma=3.0
h_dim=4
dataset_type = f'{class_num}-way'
correction_epsilon=None
cls_dropout=0.1
enc_dropout=0.1
control=False
model_arch="bert-base-uncased"
# Candidate neurons are indices 0 .. neuron_pool_size-1: h_dim neurons for
# each of the four concepts listed below.
neuron_pool_size = h_dim*4

# these are the control settings: a random partition of the neuron pool into
# one disjoint group of `h_dim` neurons per concept.
# NOTE: the shuffle draws from the global `random` state, so re-running this
# cell without re-seeding yields a different control partition.
concepts = ["ambiance", "food", "noise", "service"]
random_neuron_pool = list(range(neuron_pool_size))
random.shuffle(random_neuron_pool)
random_align_neurons = {
    c: set(random_neuron_pool[idx*h_dim:(idx+1)*h_dim])
    for idx, c in enumerate(concepts)
}

# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
batch_size=32

model_path = f'CEBaB/{model_arch}.CEBaB.sa.{class_num}-class.exclusive.seed_{seed}'

# load data from HF
# NOTE(review): `use_auth_token` is deprecated in recent `datasets` releases
# in favour of `token` -- confirm against the pinned version before changing.
cebab = datasets.load_dataset(
    'CEBaB/CEBaB', use_auth_token=True,
    cache_dir="../train_cache/"
)
# Use the "exclusive" training split as the training set.
cebab['train'] = cebab['train_exclusive']
train, dev, test = preprocess_hf_dataset(
    cebab, one_example_per_world=True, 
    verbose=1, dataset_type=dataset_type
)

train_dataset = train.copy()
# NOTE: despite its name, `dev_dataset` holds a copy of the *test* split; the
# `dev` split returned above is not used in the visible cells.
dev_dataset = test.copy()

# The model whose behaviour is being explained.
tf_model = BERTForCEBaB(
    model_path, 
    device=device, 
    batch_size=batch_size
)

Using custom data configuration CEBaB--CEBaB-0e2f7ed67c9d7e55
Reusing dataset parquet (../train_cache/CEBaB___parquet/CEBaB--CEBaB-0e2f7ed67c9d7e55/0.0.0/7328ef7ee03eaf3f86ae40594d46a1cec86161704e02dd19f232d81eee72ade8)


  0%|          | 0/4 [00:00<?, ?it/s]

Dropping no majority reviews: 16.6382% of train dataset.


In [161]:
# ---------------------------------------------------------------------------
# Find (or load from cache) the neurons aligned with each concept.
#
# For every concept, each candidate neuron is scored by
#     mean over examples of  TE / CE
# where TE is the MSE between the original output and the output after
# intervening on that neuron with a counterfactual example that differs in
# the *aligned* concept's label, and CE is the same quantity averaged over
# counterfactuals that differ in each of the *other* (control) concepts.
# The `intervention_h_dim` best-scoring neurons are assigned to the concept
# and removed from the pool before the next concept is processed.
#
# NOTE(review): `explanator` is not defined in any cell visible in this
# notebook. The search branch only runs when no cached alignment exists, so
# on a fresh kernel without the cache file it will fail unless `explanator`
# is created first -- confirm where it is supposed to come from.
# ---------------------------------------------------------------------------
output_filename = os.path.join("./neuron_alignments/", model_path.split("/")[-1]+".alignment")
if not os.path.isfile(output_filename):
    dataset = train_dataset[[
        "description",
        "ambiance_aspect_majority",
        "food_aspect_majority",
        "noise_aspect_majority",
        "service_aspect_majority",
    ]]
    align_neurons = {}
    neuron_pool = set(range(neuron_pool_size))
    loss = nn.MSELoss()
    intervention_h_dim = int(explanator.intervention_h_dim)
    # every example contributes 1/len(dataset) of its TE/CE ratio, so the
    # accumulated score is a dataset average (skipped examples count as 0).
    ratio = 1/len(dataset)

    print("preloading representations.")
    # pre-calculating representations to avoid repeated computations.
    # Both dicts are keyed by the review text, so duplicated descriptions
    # share a single entry.
    preload_representations = {}
    preload_logits = {}
    explanator.model.model.eval()
    # No gradients are needed for these forward passes; skipping graph
    # construction saves memory and time without changing the values.
    with torch.no_grad():
        for index, row in dataset.iterrows():
            description = row["description"]
            x = explanator.tokenizer([description], padding=True, truncation=True, return_tensors='pt')
            x_batch = {k: v.to(explanator.device) for k, v in x.items()}
            outputs = explanator.model.model(
                **x_batch,
                output_hidden_states=True,
            )
            # last layer, first token position of every sequence in the batch.
            cls_hidden_state = outputs.hidden_states[-1][:,0,:].detach()
            # NOTE(review): presumably `outputs.logits` is a tuple whose first
            # element holds the task logits with a leading batch dimension, so
            # the trailing [0] selects the single example -- TODO confirm.
            # Despite the name, the stored values are softmax outputs.
            output_logit = torch.nn.functional.softmax(
                    outputs.logits[0].cpu(), dim=-1
            ).detach()[0]
            preload_representations[description] = cls_hidden_state
            preload_logits[description] = output_logit

    for align_concept in concepts:
        print(f"aligning for concept={align_concept}.")
        # Keep the order of `concepts` so the sequence of random draws below
        # does not depend on string-hash randomisation.
        control_concepts = [c for c in concepts if c != align_concept]
        # Only neurons still in the pool are candidates; this guarantees that
        # a neuron already assigned to another concept (or one outside the
        # pool) can never be selected, even when scores tie at 0.
        neuron_causal_effect = {i: 0.0 for i in sorted(neuron_pool)} # treatment effect / control effect
        for index, row in tqdm(dataset.iterrows(), total=dataset.shape[0]):
            description = row["description"]
            concept_label = row[f"{align_concept}_aspect_majority"]
            # random example whose label for the aligned concept differs.
            counterfactual_concept_row = dataset[
                dataset[f"{align_concept}_aspect_majority"]!=concept_label
            ].sample().iloc[0]
            counterfactual_description = counterfactual_concept_row["description"]
            # skip pairs where either side has no label for this concept.
            if concept_label != "" and counterfactual_concept_row[
                f"{align_concept}_aspect_majority"
            ] != "":
                logits = preload_logits[description]
                # check logit change by looping through all neurons.
                for neuron_id in neuron_pool:
                    # presumably replaces neuron `neuron_id` of the base
                    # representation with the counterfactual's value and
                    # returns the resulting output -- see the helper's source.
                    intervened_logits = intervene_neuron_logits(
                        explanator, 
                        preload_representations[description].clone(), 
                        preload_representations[counterfactual_description],
                        neuron_id,
                    )
                    te = loss(intervened_logits,logits)
                    ce = 0.0
                    ce_count = 0
                    # a fresh control counterfactual is sampled for every
                    # neuron and every control concept.
                    for control_concept in control_concepts:
                        control_label = row[f"{control_concept}_aspect_majority"]
                        counterfactual_control_concept_row = dataset[
                            dataset[f"{control_concept}_aspect_majority"]!=control_label
                        ].sample().iloc[0]
                        counterfactual_control_description = counterfactual_control_concept_row[
                            "description"
                        ]
                        if control_label != "" and counterfactual_control_concept_row[
                            f"{control_concept}_aspect_majority"
                        ] != "":
                            intervened_control_logits = intervene_neuron_logits(
                                explanator, 
                                preload_representations[description].clone(), 
                                preload_representations[counterfactual_control_description],
                                neuron_id,
                            )
                            ce += loss(intervened_control_logits,logits)
                            ce_count += 1
                    # without any usable control counterfactual there is no
                    # CE to normalise by, so this neuron gets no contribution.
                    if ce_count != 0:
                        ce /= ce_count
                        neuron_causal_effect[neuron_id] += (ratio*(te/ce)).tolist()
        # rank candidates by score, best first, and keep the top ones.
        neuron_causal_effect = sorted(neuron_causal_effect.items(), key=lambda x: x[1], reverse=True)
        selected_neuron = set(
            neuron_id for neuron_id, _ in neuron_causal_effect[:intervention_h_dim]
        )
        align_neurons[align_concept] = selected_neuron
        # remove from avaliable pool.
        neuron_pool -= selected_neuron
        # we may stop early, if we ask to align all neurons for this repr.
        if len(neuron_pool) == intervention_h_dim:
            print("since mapping all neurons, skip the last one to take all remaining neurons.")
            align_neurons[concepts[-1]] = neuron_pool
            break

    # save the alignment for future use.
    # sets are not JSON-serialisable, so they are stored as lists.
    serialized_align_neurons = {k: list(v) for k, v in align_neurons.items()}

    # make sure the target directory exists so the result of the (long)
    # search above is not lost to a missing-directory error.
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    with open(output_filename, "w") as outfile:
        json.dump(serialized_align_neurons, outfile, indent=4)
else:
    print(f"saved alignment found for this model: {output_filename}.")
    # NOTE: a cached alignment maps each concept to a *list* of neuron ids,
    # whereas a freshly computed one maps to a *set*.
    with open(output_filename) as infile:
        align_neurons = json.load(infile)

saved alignment found for this model: ./neuron_alignments/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42.alignment.


Evaluate with neurons selected by causal mediation analyses

In [163]:
# Build the explainer from the neurons chosen by the causal mediation
# analysis (`align_neurons`) and run it through the CEBaB pipeline.
explainer = CausalMediationModelForBERT(
    model_path,
    device=device, 
    batch_size=batch_size,
    intervention_h_dim=h_dim,
    align_neurons=align_neurons
)

# `dataset_type` and `correction_epsilon` come from the configuration cell so
# the evaluation stays consistent with how the data was preprocessed.
# Note that `dev_dataset` is a copy of the test split.
result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \
CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \
ACaCE_per_aspect, performance_report = cebab_pipeline(
    tf_model, explainer, 
    train_dataset, dev_dataset, 
    dataset_type=dataset_type,
    correction_epsilon=correction_epsilon,
)

# last expression of the cell: displayed as the cell output.
CEBaB_metrics

Some weights of IITBERTForSequenceClassification were not initialized from the model checkpoint at CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42 and are newly initialized: ['multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias', 'multitask_classifier.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 124/124 [00:20<00:00,  6.08it/s]


Unnamed: 0,Unnamed: 1,Unnamed: 2,ICaCE-L2,ICaCE-cosine,ICaCE-normdiff
bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42,CausalMediationModelForBERT,mean,0.7922,0.6726,0.7169


Evaluate with neurons that are randomly selected

In [165]:
# Control condition: same explainer class and settings, but with the randomly
# partitioned neurons (`random_align_neurons`) instead of the searched ones.
control_explainer = CausalMediationModelForBERT(
    model_path,
    device=device, 
    batch_size=batch_size,
    intervention_h_dim=h_dim,
    align_neurons=random_align_neurons
)

# `dataset_type` and `correction_epsilon` come from the configuration cell so
# the evaluation stays consistent with how the data was preprocessed.
# NOTE: these assignments overwrite the result variables produced by the
# evaluation of the non-control explainer.
result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction, \
CEBaB_metrics_per_aspect, CaCE_per_aspect_direction, \
ACaCE_per_aspect, performance_report = cebab_pipeline(
    tf_model, control_explainer, 
    train_dataset, dev_dataset, 
    dataset_type=dataset_type,
    correction_epsilon=correction_epsilon,
)

# last expression of the cell: displayed as the cell output.
CEBaB_metrics

Some weights of IITBERTForSequenceClassification were not initialized from the model checkpoint at CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42 and are newly initialized: ['multitask_classifier.dense.weight', 'multitask_classifier.out_proj.weight', 'multitask_classifier.dense.bias', 'multitask_classifier.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 124/124 [00:12<00:00,  9.87it/s]


Unnamed: 0,Unnamed: 1,Unnamed: 2,ICaCE-L2,ICaCE-cosine,ICaCE-normdiff
bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42,CausalMediationModelForBERT,mean,0.8135,0.849,0.8108
