### Causal Mediation Analyses
Static version of activate alignment search without interchange intervention training. This analysis closely follows the [causal mediation analyses paper](https://proceedings.neurips.cc/paper/2020/file/92650b2e92217715fe312e6fa7b90d82-Paper.pdf). We simplify their method to adapt it to CEBaB. Details can be found in the code below.

In [None]:
from libs import *
from modelings.modelings_bert import *
from modelings.modelings_roberta import *
from modelings.modelings_gpt2 import *
from modelings.modelings_lstm import *

# For evaluation we use a single random seed, since the models were
# already trained with 5 different seeds.
# Seed Python's `random`, NumPy and PyTorch with the same value, in that order.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    # Assigning to `_` keeps the return value out of the notebook output.
    _ = _seed_fn(123)
# Release cached CUDA memory before the models are loaded.
torch.cuda.empty_cache()

In [None]:
# ---- experiment configuration -------------------------------------------
# `seed` and `class_num` select the checkpoint named in `model_path` below.
seed = 42
class_num = 5
# NOTE(review): `beta` and `gemma` (presumably "gamma") are not referenced in
# the visible cells; names are kept in case other code relies on them.
beta = 1.0
gemma = 3.0
# Number of neurons assigned to each concept.
h_dim = 192
dataset_type = f'{class_num}-way'
correction_epsilon = None
cls_dropout = 0.1
enc_dropout = 0.1
control = False
model_arch = "bert-base-uncased"
# Four concepts, h_dim neurons each.
neuron_pool_size = h_dim * 4

# ---- control setting: a random neuron-to-concept assignment --------------
# Shuffle all neuron ids once, then give each concept its own contiguous,
# non-overlapping slice of h_dim ids from the shuffled pool.
concepts = ["ambiance", "food", "noise", "service"]
random_neuron_pool = list(range(neuron_pool_size))
random.shuffle(random_neuron_pool)
random_align_neurons = {}
for idx, c in enumerate(concepts):
    random_align_neurons[c] = set(
        random_neuron_pool[idx * h_dim:(idx + 1) * h_dim]
    )

# ---- runtime settings -----------------------------------------------------
device = 'cuda:0'
batch_size = 32

model_path = f'CEBaB/{model_arch}.CEBaB.sa.{class_num}-class.exclusive.seed_{seed}'

# Load the CEBaB dataset from the Hugging Face hub (authenticated via
# use_auth_token), caching downloads under ../train_cache/.
cebab = datasets.load_dataset(
    'CEBaB/CEBaB', use_auth_token=True,
    cache_dir="../train_cache/"
)
# Use the 'train_exclusive' split as the training split from here on.
cebab['train'] = cebab['train_exclusive']
# preprocess_hf_dataset is a project helper defined outside this notebook;
# it returns the three splits unpacked below.
train, dev, test = preprocess_hf_dataset(
    cebab, one_example_per_world=True, 
    verbose=1, dataset_type=dataset_type
)

train_dataset = train.copy()
# NOTE(review): despite its name, dev_dataset is a copy of the *test* split,
# and `dev` is not used in the visible cells — confirm this is intended.
dev_dataset = test.copy()

# Model wrapper built from the checkpoint at model_path; it is passed to
# cebab_pipeline in the evaluation cells below.
tf_model = BERTForCEBaB(
    model_path, 
    device=device, 
    batch_size=batch_size
)

# Explainer built without any neuron alignment (align_neurons=None). It is
# used below only for its tokenizer and wrapped model when preloading
# representations; the evaluation cells build new explainers with an
# explicit alignment.
explainer = CausalMediationModelForBERT(
    model_path,
    device=device, 
    batch_size=batch_size,
    intervention_h_dim=h_dim,
    align_neurons=None
)

In [None]:
def sample_rows_with_condition(df, col_name, value, condition="=="):
    """Return the rows of ``df`` whose ``col_name`` entry compares to ``value``.

    Rows whose ``col_name`` entry is the empty string are always dropped,
    regardless of ``condition``.

    Args:
        df: pandas DataFrame to filter.
        col_name: name of the column to compare.
        value: value the column entries are compared against.
        condition: ``"=="`` keeps rows equal to ``value``; ``"!="`` keeps
            rows that differ from ``value``.

    Returns:
        The boolean-mask selection of ``df`` that satisfies the condition.

    Raises:
        ValueError: if ``condition`` is neither ``"=="`` nor ``"!="``.
    """
    # Validate first and raise explicitly: an `assert` would be stripped under
    # `python -O`, making the function silently return None.
    if condition not in ("==", "!="):
        raise ValueError(
            f"Unsupported condition {condition!r}; expected '==' or '!='."
        )
    column = df[col_name]
    matches = (column == value) if condition == "==" else (column != value)
    return df[matches & (column != "")]

def get_treatment_effect_pairs(dataset, align_concept):
    """Build (base, counterfactual) description pairs for one concept.

    For every row with a non-empty ``{align_concept}_aspect_majority`` label,
    look for rows whose label for ``align_concept`` differs while the labels
    of the three remaining concepts are identical. If at least one such row
    exists, one is drawn at random and paired with the base row.

    Because ``sample_rows_with_condition`` drops empty labels, a base row
    with an empty label for any of the remaining concepts never gets a match.

    Args:
        dataset: DataFrame with a ``description`` column and one
            ``{concept}_aspect_majority`` column per concept.
        align_concept: the concept in which base and counterfactual differ.

    Returns:
        List of ``(description, counterfactual_description)`` tuples.
    """
    all_concepts = ["ambiance", "food", "noise", "service"]
    held_fixed = list(set(all_concepts) - set([align_concept]))
    target_col = f"{align_concept}_aspect_majority"

    pairs = []
    for _, base_row in dataset.iterrows():
        base_description = base_row["description"]
        base_label = base_row[target_col]
        if base_label == "":
            # Unlabeled for the target concept: cannot serve as a base.
            continue

        # Candidates must differ from the base in the target concept ...
        candidates = sample_rows_with_condition(
            dataset, target_col, base_label, "!="
        )
        # ... and agree with it on every other concept.
        for other in held_fixed:
            other_col = f"{other}_aspect_majority"
            candidates = sample_rows_with_condition(
                candidates, other_col, base_row[other_col], "=="
            )

        if len(candidates) > 0:
            picked = candidates.sample().iloc[0]
            pairs += [(base_description, picked["description"])]
    return pairs

def evaluate_te(explainer, tc_pairs, neuron_id):
    """Average effect of intervening on one neuron over (base, source) pairs.

    For each pair, the cached base representation is passed together with the
    cached source representation to ``intervene_neuron_logits`` (defined
    outside this notebook; presumably it swaps neuron ``neuron_id`` from
    source into base and recomputes the output — confirm). The result is
    compared with the cached base output using the module-level ``loss``.

    Relies on the module-level ``preload_logits`` and
    ``preload_representations`` dictionaries, both keyed by description.

    Args:
        explainer: explainer forwarded to ``intervene_neuron_logits``.
        tc_pairs: sized collection of ``(base_description,
            source_description)`` pairs.
        neuron_id: id of the neuron to intervene on.

    Returns:
        The mean of the per-pair losses.

    Raises:
        ValueError: if ``tc_pairs`` is empty (previously this surfaced as an
            uninformative ``ZeroDivisionError``).
    """
    if len(tc_pairs) == 0:
        raise ValueError("tc_pairs must contain at least one (base, source) pair.")
    te = 0.0
    for pair in tc_pairs:
        base_key, source_key = pair[0], pair[1]
        # Only the base output is needed as the reference; the source output
        # was never used, so it is no longer looked up.
        base_logits = preload_logits[base_key]
        counterfactual_logits = intervene_neuron_logits(
            explainer,
            # Clone so the cached base representation is never modified.
            preload_representations[base_key].clone(),
            preload_representations[source_key],
            neuron_id,
        )
        te += loss(counterfactual_logits, base_logits)
    return te / len(tc_pairs)

def sequential_alignment_search(aspect_neuron_te_map):
    """Greedily assign disjoint, equally sized neuron sets to the concepts.

    Concepts are processed in a fixed priority order (food, service,
    ambiance, noise). Each concept takes the highest-scoring neurons that are
    still unassigned; those neurons are then removed from the pool.

    Args:
        aspect_neuron_te_map: dict mapping each concept name to a list of
            ``(neuron_id, score)`` items. The length of the ``"ambiance"``
            list defines the pool ``0 .. n-1``; ids outside it are ignored.

    Returns:
        Dict mapping each concept to its set of neuron ids, each holding at
        most ``n // 4`` ids.
    """
    # lets help causal mediation a bit: concepts earlier in this list pick first.
    priority = ["food", "service", "ambiance", "noise"]
    n_neurons = len(aspect_neuron_te_map["ambiance"])
    quota = n_neurons // len(priority)
    unassigned = set(range(n_neurons))

    assignment = {}
    for concept in priority:
        candidates = [
            entry for entry in aspect_neuron_te_map[concept]
            if entry[0] in unassigned
        ]
        # Stable descending sort: ties keep their original relative order.
        candidates.sort(key=lambda entry: entry[1], reverse=True)
        chosen = {entry[0] for entry in candidates[:quota]}
        assignment[concept] = chosen
        unassigned -= chosen
    return assignment

In [None]:
# Bookkeeping shared by the analysis cells below.
# `align_neurons` is a placeholder that the alignment-search cell overwrites.
align_neurons = {}
# Ids of all candidate neurons; the treatment-effect loop iterates over them.
neuron_pool = set(range(neuron_pool_size))
# Used by evaluate_te to compare the intervened output with the base output.
loss = nn.MSELoss()

# Keep only the review text and the four per-aspect majority labels.
dataset = train_dataset[
    [
        "description",
        "ambiance_aspect_majority",
        "food_aspect_majority",
        "noise_aspect_majority",
        "service_aspect_majority",
    ]
]

print("preloading representations and descriptions.")
# Pre-compute, once per description, the values evaluate_te needs so the
# per-neuron loop does not repeat the forward pass.
# Both dictionaries are keyed by the description text; a repeated description
# simply overwrites its earlier entry.
preload_representations = {}
preload_logits = {}
# Eval mode is loop-invariant, so it is set once up front.
explainer.model.model.eval()
# Every stored tensor is detached, so the autograd graph was never used;
# disabling gradient tracking avoids building it for each forward pass.
with torch.no_grad():
    for index, row in dataset.iterrows():
        description = row["description"]
        x = explainer.tokenizer(
            [description], padding=True, truncation=True, return_tensors='pt'
        )
        x_batch = {k: v.to(explainer.device) for k, v in x.items()}
        outputs = explainer.model.model(
            **x_batch,
            output_hidden_states=True,
        )
        # Last hidden layer, first token position.
        cls_hidden_state = outputs.hidden_states[-1][:, 0, :].detach()
        # Softmax-normalised output kept on CPU.
        # NOTE(review): the result is indexed with [0] twice (before and after
        # the softmax); whether that yields a full class distribution depends
        # on the structure of `outputs.logits` — confirm.
        output_logit = torch.nn.functional.softmax(
            outputs.logits[0].cpu(), dim=-1
        ).detach()[0]
        preload_representations[description] = cls_hidden_state
        preload_logits[description] = output_logit

# Build the (base, source) description pairs for each aspect.
# Within a pair, base and source should differ in exactly one aspect.
# Each line first builds all candidate pairs (which itself draws random
# counterfactual rows) and then draws k of them with random.sample, so the
# order of these statements affects which pairs are drawn for a given seed.
# random.sample raises ValueError if fewer than k pairs are available.
k = 300
ambiance_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, "ambiance"), k=k)
food_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, "food"), k=k)
noise_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, "noise"), k=k)
service_tc_pairs = random.sample(get_treatment_effect_pairs(dataset, "service"), k=k)

In [None]:
# Procedure:
# 1. For every neuron, estimate its treatment effect (TE) on each aspect.
# 2. Score the neuron for an aspect as its TE on that aspect divided by the
#    sum of its TEs on the other three aspects, and record the score.
aspect_neuron_te_map = {
    "ambiance": [],
    "food": [],
    "noise": [],
    "service": [],
}
# Sampled pairs per aspect, in the order the effects are estimated.
tc_pairs_by_aspect = (
    ("ambiance", ambiance_tc_pairs),
    ("food", food_tc_pairs),
    ("noise", noise_tc_pairs),
    ("service", service_tc_pairs),
)
# For each aspect, the three competing aspects in the exact order their
# effects are added up (order kept fixed so floating-point sums are unchanged).
competing_aspects = {
    "ambiance": ("food", "noise", "service"),
    "food": ("ambiance", "noise", "service"),
    "noise": ("food", "ambiance", "service"),
    "service": ("food", "noise", "ambiance"),
}
for neuron_id in tqdm(neuron_pool):
    te_est = {
        aspect: evaluate_te(explainer, pairs, neuron_id)
        for aspect, pairs in tc_pairs_by_aspect
    }
    for aspect, (first, second, third) in competing_aspects.items():
        aspect_neuron_te_map[aspect].append(
            (
                neuron_id,
                te_est[aspect] / (te_est[first] + te_est[second] + te_est[third]),
            )
        )

In [None]:
# Turn the per-aspect neuron scores into a disjoint concept -> neuron-id-set
# assignment (greedy, in the concept order fixed inside the function).
align_neurons = sequential_alignment_search(aspect_neuron_te_map)

Evaluate with neurons selected by causal mediation analyses

In [None]:
# Rebuild the explainer, this time with the neuron sets found by the
# alignment search (this rebinds `explainer`).
explainer = CausalMediationModelForBERT(
    model_path,
    device=device,
    batch_size=batch_size,
    intervention_h_dim=h_dim,
    align_neurons=align_neurons,
)

# Run the CEBaB evaluation pipeline. `dataset_type` and `correction_epsilon`
# come from the configuration cell rather than being repeated as literals, so
# the evaluation cannot drift from the setting used to preprocess the data.
(
    result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction,
    CEBaB_metrics_per_aspect, CaCE_per_aspect_direction,
    ACaCE_per_aspect, performance_report,
) = cebab_pipeline(
    tf_model, explainer,
    train_dataset, dev_dataset,
    dataset_type=dataset_type,
    correction_epsilon=correction_epsilon,
)

# Last expression of the cell: displays the metrics in the notebook.
CEBaB_metrics

Evaluate with neurons that are randomly selected

In [None]:
# Control condition: same explainer class, but with the randomly assigned
# neuron sets built in the configuration cell.
control_explainer = CausalMediationModelForBERT(
    model_path,
    device=device,
    batch_size=batch_size,
    intervention_h_dim=h_dim,
    align_neurons=random_align_neurons,
)

# Run the CEBaB evaluation pipeline. `dataset_type` and `correction_epsilon`
# come from the configuration cell rather than being repeated as literals.
# NOTE: the result names below are the same ones the aligned-neuron
# evaluation binds, so running this cell overwrites those results.
(
    result_per_example, ATE, CEBaB_metrics, CEBaB_metrics_per_aspect_direction,
    CEBaB_metrics_per_aspect, CaCE_per_aspect_direction,
    ACaCE_per_aspect, performance_report,
) = cebab_pipeline(
    tf_model, control_explainer,
    train_dataset, dev_dataset,
    dataset_type=dataset_type,
    correction_epsilon=correction_epsilon,
)

# Last expression of the cell: displays the metrics in the notebook.
CEBaB_metrics