### Causal Mediation Analyses
A static version of activate alignment search, without interchange intervention training. This analysis closely follows the [causal mediation analysis paper](https://proceedings.neurips.cc/paper/2020/file/92650b2e92217715fe312e6fa7b90d82-Paper.pdf). We simplify their method to adapt it to CEBaB. Details can be found in the code below.

In [None]:
from libs import *
# Evaluation runs under one fixed random seed only: the models being
# evaluated were already trained with 5 different seeds.
# Every seeding call is bound to `_` rather than left as a bare expression.
_ = random.seed(123)
_ = np.random.seed(123)
_ = torch.manual_seed(123)
    
def intervene_neuron_logits(
    explanator, hidden_reprs, counterfactual_reprs, neuron_id
):
    """Swap one neuron for its counterfactual value and re-run the model head.

    ``hidden_reprs`` is modified in place: position ``[0, neuron_id]`` is
    overwritten with the value at the same position of
    ``counterfactual_reprs``. Callers that need the original representation
    afterwards must pass a copy.

    Despite the function name, the returned tensor is the softmax over the
    first entry of the model's ``logits`` (moved to CPU and detached), of
    which row 0 is returned.
    """
    # Single-neuron interchange: only [0, neuron_id] differs afterwards.
    hidden_reprs[0, neuron_id] = counterfactual_reprs[0, neuron_id]

    patched_cls = hidden_reprs.unsqueeze(dim=1)
    forward_result = explanator.model.forward_with_cls_hidden_reprs(
        cls_hidden_reprs=patched_cls
    )
    # The forward call returns three values; only the first is used here.
    intervened_outputs, _, _ = forward_result

    first_logits = intervened_outputs.logits[0].cpu()
    probabilities = torch.nn.functional.softmax(first_logits, dim=-1)
    return probabilities.detach()[0]


In [None]:
# --- Experiment configuration -------------------------------------------
# `seed` selects which trained checkpoint is loaded (it is part of
# `model_path` below).
seed=42
class_num=5
beta=1.0
gemma=3.0  # NOTE(review): possibly meant "gamma"; name kept as-is — confirm.
h_dim=192
dataset_type = f'{class_num}-way'
correction_epsilon=None
cls_dropout=0.1
enc_dropout=0.1
control=False
model_arch="bert-base-uncased"

# Preferred accelerator is GPU index 6. Fall back instead of failing later
# when this machine exposes fewer GPUs (or none at all).
device='cuda:6'
if not torch.cuda.is_available():
    device = 'cpu'
elif torch.cuda.device_count() <= 6:
    device = 'cuda:0'
batch_size=32

model_path = f'CEBaB/{model_arch}.CEBaB.sa.{class_num}-class.exclusive.seed_{seed}'

# load data from HF
cebab = datasets.load_dataset(
    'CEBaB/CEBaB', use_auth_token=True,
    cache_dir="../../huggingface_cache/"
)
# Use the 'train_exclusive' split as the training split.
cebab['train'] = cebab['train_exclusive']
train, dev, test = preprocess_hf_dataset(
    cebab, one_example_per_world=True, 
    verbose=1, dataset_type=dataset_type
)

train_dataset = train.copy()
# NOTE: the *test* split is what is carried forward as `dev_dataset`;
# the `dev` split returned above is not used in this cell.
dev_dataset = test.copy()

In [None]:
# Pre-calculate representations to avoid repeated computation: for every
# description, cache the position-0 (CLS) vector of the last hidden layer and
# the softmax over the model's logits, keyed by the description text.
#
# NOTE(review): `dataset` and `explanator` are not defined in any cell above
# this one; they are assigned in later cells, so on a fresh kernel this cell
# only works after those have run — confirm the intended execution order.
preload_representations = {}
preload_logits = {}

# Switching to eval mode is loop-invariant, so do it once up front.
explanator.model.model.eval()

# Every tensor kept below is detached anyway, so skip building autograd
# graphs for these forward passes.
with torch.no_grad():
    for description in dataset["description"]:
        x = explanator.tokenizer([description], padding=True, truncation=True, return_tensors='pt')
        x_batch = {k: v.to(explanator.device) for k, v in x.items()}
        outputs = explanator.model.model(
            **x_batch,
            output_hidden_states=True,
        )
        cls_hidden_state = outputs.hidden_states[-1][:,0,:].detach()
        output_logit = torch.nn.functional.softmax(
                outputs.logits[0].cpu(), dim=-1
        ).detach()[0]
        preload_representations[description] = cls_hidden_state
        preload_logits[description] = output_logit

In [110]:
# Greedy neuron-to-concept alignment, cached on disk per model.
#
# For each concept in turn, every neuron still in the pool is scored by
#   treatment effect / mean control effect
# where an "effect" is the MSE between the original output distribution and
# the distribution obtained after swapping that single neuron with its value
# from a counterfactual example. The treatment counterfactual differs in the
# concept being aligned; control counterfactuals differ in each other concept.
# The top `intervention_h_dim` neurons are assigned to the concept and removed
# from the pool.
#
# NOTE(review): this cell reads `explanator`, `preload_representations` and
# `preload_logits`. `explanator` is constructed in a later cell that itself
# takes `align_neurons` as input, so on a fresh kernel the compute branch
# cannot run top-to-bottom — confirm the intended execution order.
output_filename = os.path.join("./neuron_alignments/", model_path.split("/")[-1]+".alignment")
if not os.path.isfile(output_filename):
    dataset = train_dataset[[
        "description",
        "ambiance_aspect_majority",
        "food_aspect_majority",
        "noise_aspect_majority",
        "service_aspect_majority",
    ]]
    concepts = ["ambiance", "food", "noise", "service"]
    align_neurons = {}
    hidden_size = explanator.model.model.config.hidden_size
    neurons_per_concept = int(explanator.intervention_h_dim)
    neuron_pool = set(range(hidden_size))
    loss = nn.MSELoss()

    for align_concept in concepts:
        print(f"aligning for concept={align_concept}.")
        align_neurons[align_concept] = set([])
        # Score only neurons that are still available. Ranking over all
        # neurons would let an already-assigned neuron (score stuck at 0.0)
        # tie with an unscored pool neuron and be assigned to two concepts.
        neuron_causal_effect = {
            i: 0.0 for i in sorted(neuron_pool)
        } # treatment effect / control effect
        # Built from the ordered `concepts` list rather than a set difference:
        # iteration order of a set of strings can change between interpreter
        # runs, which would change the order of the random draws below.
        control_concepts = [c for c in concepts if c != align_concept]
        for index, row in tqdm(dataset.iterrows(), total=dataset.shape[0]):
            description = row["description"]
            concept_label = row[f"{align_concept}_aspect_majority"]
            # Random example whose label for the aligned concept differs.
            counterfactual_concept_row = dataset[
                dataset[f"{align_concept}_aspect_majority"]!=concept_label
            ].sample().iloc[0]
            counterfactual_description = counterfactual_concept_row["description"]
            # Skip pairs where either side has an empty label for the concept.
            if concept_label != "" and counterfactual_concept_row[
                f"{align_concept}_aspect_majority"
            ] != "":
                logits = preload_logits[description]
                # check logit change by looping through all neurons.
                for neuron_id in neuron_pool:
                    # The base representation is cloned because the helper
                    # overwrites the neuron in place.
                    intervened_logits = intervene_neuron_logits(
                        explanator, 
                        preload_representations[description].clone(), 
                        preload_representations[counterfactual_description],
                        neuron_id,
                    )
                    te = loss(intervened_logits,logits)
                    ce = 0.0
                    ce_count = 0
                    for control_concept in control_concepts:
                        control_label = row[f"{control_concept}_aspect_majority"]
                        # A fresh control counterfactual is drawn per neuron
                        # and per control concept.
                        counterfactual_control_concept_row = dataset[
                            dataset[f"{control_concept}_aspect_majority"]!=control_label
                        ].sample().iloc[0]
                        counterfactual_control_description = counterfactual_control_concept_row[
                            "description"
                        ]
                        if control_label != "" and counterfactual_control_concept_row[
                            f"{control_concept}_aspect_majority"
                        ] != "":
                            intervened_control_logits = intervene_neuron_logits(
                                explanator, 
                                preload_representations[description].clone(), 
                                preload_representations[counterfactual_control_description],
                                neuron_id,
                            )
                            ce += loss(intervened_control_logits,logits)
                            ce_count += 1
                    if ce_count != 0:
                        # NOTE(review): if every control intervention leaves
                        # the output unchanged, `ce` is 0 and this ratio is
                        # inf/nan; that case is not guarded here.
                        neuron_causal_effect[neuron_id] += (te/(ce/ce_count)*(1/len(dataset))).tolist()
        # Rank candidates by accumulated effect ratio, highest first.
        neuron_causal_effect = sorted(neuron_causal_effect.items(), key=lambda x: x[1], reverse=True)
        selected_neuron = set(
            ranked_id for ranked_id, _score in neuron_causal_effect[:neurons_per_concept]
        )
        align_neurons[align_concept] = selected_neuron
        # remove from available pool.
        neuron_pool -= selected_neuron
        # we may stop early, if we ask to align all neurons for this repr.
        if len(neuron_pool) == neurons_per_concept:
            print("since mapping all neurons, skip the last one to take all remaining neurons.")
            align_neurons[concepts[-1]] = neuron_pool
            break

    # save the alignment for future use.
    # NOTE: a freshly computed `align_neurons` holds sets, while one reloaded
    # from disk (else-branch) holds lists.
    serialized_align_neurons = {}
    for k, v in align_neurons.items():
        serialized_align_neurons[k] = list(v)

    # Make sure the target directory exists before writing the cache file.
    os.makedirs(os.path.dirname(output_filename), exist_ok=True)
    with open(output_filename, "w") as outfile:
        json.dump(serialized_align_neurons, outfile, indent=4)
else:
    print(f"saved alignment found for this model: {output_filename}.")
    with open(output_filename) as infile:
        align_neurons = json.load(infile)

aligning for concept=ambiance.


100%|██████████| 1463/1463 [25:44<00:00,  1.06s/it]


aligning for concept=food.


100%|██████████| 1463/1463 [36:30<00:00,  1.50s/it]


aligning for concept=noise.


100%|██████████| 1463/1463 [09:11<00:00,  2.65it/s]

since mapping all neurons, skip the last one to take all remaining neurons.





In [136]:
# Both objects are built from the same `model_path` checkpoint and placed on
# the same device.
tf_model = BERTForCEBaB(model_path, device=device, batch_size=batch_size)

# The explainer receives the neuron alignment computed (or loaded) earlier and
# runs with a batch size of 1.
explanator = CausalMediationModelForBERT(
    model_path,
    align_neurons=align_neurons,
    intervention_h_dim=h_dim,
    batch_size=1,
    device=device,
)

Some weights of IITBERTForSequenceClassification were not initialized from the model checkpoint at CEBaB/bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42 and are newly initialized: ['multitask_classifier.out_proj.bias', 'multitask_classifier.dense.weight', 'multitask_classifier.dense.bias', 'multitask_classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [137]:
# Run the CEBaB pipeline with the model and explainer built above.
# `dataset_type` and `correction_epsilon` come from the configuration cell
# instead of being repeated as literals, so a change to `class_num` there
# cannot silently disagree with the value used here.
(
    result_per_example,
    ATE,
    CEBaB_metrics,
    CEBaB_metrics_per_aspect_direction,
    CEBaB_metrics_per_aspect,
    CaCE_per_aspect_direction,
    ACaCE_per_aspect,
    performance_report,
) = cebab_pipeline(
    tf_model, explanator, 
    train_dataset, dev_dataset, 
    dataset_type=dataset_type,
    correction_epsilon=correction_epsilon,
)

100%|██████████| 3958/3958 [01:07<00:00, 58.46it/s]


In [138]:
# Display the `CEBaB_metrics` result returned by `cebab_pipeline`.
CEBaB_metrics

Unnamed: 0,Unnamed: 1,Unnamed: 2,ICaCE-L2,ICaCE-cosine,ICaCE-normdiff
bert-base-uncased.CEBaB.sa.5-class.exclusive.seed_42,CausalMediationModelForBERT,mean,0.7981,0.6885,0.7222
