# Causal Analysis

In [None]:
import import_ipynb # using this to import the modules notebook
import modules # importing the notebook
import torch
import plotly.express as px
import pandas as pd
import numpy as np


Set Model Configuration

In [None]:
# Hyper-parameters used to build the transformer in the next cell:
#   num_tokens - vocabulary size, i.e. how many distinct tokens the corpus contains
#   t          - length of every sequence that is fed to the model
#   depth      - number of stacked transformer blocks
#   heads      - number of attention heads in each multi-head attention layer
#   k          - embedding dimension; has to be an integer multiple of heads

num_tokens = 10  # the integers 0 to 9
t = 5
depth = 2
heads = 3
k = 6  # x * heads

Load Model

In [None]:
# Build the architecture from the configuration above and restore the trained weights.
# map_location='cpu' makes the checkpoint loadable on a CPU-only machine even if it
# was saved from a GPU.
model = modules.GTransformer(k=k, heads=heads, depth=depth, t=t, num_tokens=num_tokens)
model.load_state_dict(torch.load('gtransformer.pth', map_location='cpu'))

Load Training Data

In [None]:
# enumerate the token ids of the vocabulary: 0, 1, ..., num_tokens - 1
tokens = np.arange(0, num_tokens)
print(tokens)

In [None]:
import os

# Class of sequences analysed in this run; it selects the dataset file that is
# loaded below as well as the folder the figures are written to.
input_class = "class2"
save_fig_path = f"./images/causal_mediation_{input_class}/"
# The plot helpers write image files straight into this folder, which fails if the
# folder does not exist yet, so create it up front.
os.makedirs(save_fig_path, exist_ok=True)

In [None]:
# Load the aligned sequences: the file stores a dataset object whose first tensor
# holds the source sequences.
# The file is a pickled Python object rather than a plain state dict, so it needs
# weights_only=False (PyTorch >= 2.6 defaults to weights_only=True and would refuse it).
# NOTE: unpickling can execute arbitrary code - only load dataset files you trust.
aligned_data = torch.load(f'data/data_{input_class}.pt', weights_only=False).tensors[0]
print(aligned_data.shape)
print(aligned_data[0])

In [None]:
# Derive the misaligned counterpart of every aligned sequence: the third token
# (index 2) is overwritten with a copy of the second token (index 1).
misaligned_data = aligned_data.clone()
misaligned_data[:, 2] = misaligned_data[:, 1]
print(misaligned_data.shape)
print(misaligned_data[0])

In [None]:
# pair every aligned sequence with its misaligned counterpart: (aligned, misaligned)
data = list(zip(aligned_data, misaligned_data))
print(data)

In [None]:
def _isolate_mediator(stacked, layer, index):
    """Return zeros shaped like ``stacked`` that keep only ``stacked[layer][index]``."""
    isolated = torch.zeros_like(stacked)
    position = (layer, *index)
    isolated[position] = stacked[position]
    return isolated


def _intervened_probs(model, inputs, fix_name, fix_value):
    """Run ``model`` on ``inputs`` with one fix keyword set and return probabilities."""
    log_probs, _, _ = model(inputs, **{fix_name: fix_value})
    return torch.exp(log_probs)


def causal_mediation_analysis(data_tuple, model, heads, depth, k, t):
    """Compute total, direct and indirect effects for one (aligned, misaligned) pair.

    Mediators are every attention head of every layer and every single
    feed-forward output value (layer, sequence position, embedding dimension).
    For each mediator two intervened runs are made:

    * direct:   misaligned input, mediator fixed to its value from the aligned run
    * indirect: aligned input, mediator fixed to its value from the misaligned run

    Both effects are reported as the change in output probability relative to
    the plain aligned run.

    The model is called as ``model(x)``, ``model(x, fix_attn_out_unified_all=...)``
    and ``model(x, fix_ff_out_all=...)`` and must return
    ``(log_probs, attn_outputs_per_layer, ff_outputs_per_layer)``. The fix tensor
    stacks all layers and is zero everywhere except at the mediator; how the model
    applies it is defined by the model and is not visible from this function.

    Args:
        data_tuple: ``(aligned_sequence, misaligned_sequence)``; each gets a batch
            dimension added via ``unsqueeze(0)``.
        model: the trained model (switched to eval mode here).
        heads: number of attention heads per layer.
        depth: number of layers to analyse.
        k: embedding dimension; head ``h`` owns the slice
            ``[h * (k // heads), (h + 1) * (k // heads))`` of the attention output.
        t: sequence length (number of positions analysed for the feed-forward output).

    Returns:
        A 12-tuple
        ``(total_result_all, total_result_class,
        direct_result_all_attn, direct_result_all_ff,
        direct_result_class_attn, direct_result_class_ff,
        indirect_result_all_attn, indirect_result_all_ff,
        indirect_result_class_attn, indirect_result_class_ff,
        attn_mediators, ff_mediators)``.
        The ``total_*`` dicts hold ``'probs_aligned'`` / ``'probs_misaligned'``; the
        effect dicts map a mediator name (``'layer_<l>_head_<h>'`` or
        ``'layer_<l>_seq_<s>_dim_<d>'``) to a probability difference. ``*_all``
        entries cover every sequence position, ``*_class`` entries only the last
        one. The two lists contain the mediator names in dict order.
    """
    # make sure the model is in evaluation mode
    model.eval()

    aligned_input = data_tuple[0].unsqueeze(0)
    misaligned_input = data_tuple[1].unsqueeze(0)
    head_width = k // heads

    direct_result_all_attn = {}
    direct_result_all_ff = {}
    direct_result_class_attn = {}
    direct_result_class_ff = {}
    indirect_result_all_attn = {}
    indirect_result_all_ff = {}
    indirect_result_class_attn = {}
    indirect_result_class_ff = {}
    attn_mediators = []
    ff_mediators = []

    # inference only: no gradients are needed anywhere below
    with torch.no_grad():
        # unintervened reference runs
        log_probs_aligned, attn_out_aligned, ff_out_aligned = model(aligned_input)
        log_probs_misaligned, attn_out_misaligned, ff_out_misaligned = model(misaligned_input)
        probs_aligned = torch.exp(log_probs_aligned)
        probs_misaligned = torch.exp(log_probs_misaligned)
        total_result_all = {
            'probs_aligned': probs_aligned,
            'probs_misaligned': probs_misaligned,
        }
        # the class prediction is read at the last sequence position
        total_result_class = {
            'probs_aligned': probs_aligned[0][-1],
            'probs_misaligned': probs_misaligned[0][-1],
        }
        class_baseline = total_result_class['probs_aligned']

        # stack the per-layer outputs once: (layers, batch, sequence, k)
        attn_aligned = torch.stack(list(attn_out_aligned))
        attn_misaligned = torch.stack(list(attn_out_misaligned))
        ff_aligned = torch.stack(list(ff_out_aligned))
        ff_misaligned = torch.stack(list(ff_out_misaligned))

        # attention heads as mediators; head h owns a contiguous slice of width
        # k // heads in the last dimension of the unified attention output
        for h in range(heads):
            head_index = (0, slice(None), slice(h * head_width, (h + 1) * head_width))
            for l in range(depth):
                name = f'layer_{l}_head_{h}'
                probs_direct = _intervened_probs(
                    model, misaligned_input, 'fix_attn_out_unified_all',
                    _isolate_mediator(attn_aligned, l, head_index))
                probs_indirect = _intervened_probs(
                    model, aligned_input, 'fix_attn_out_unified_all',
                    _isolate_mediator(attn_misaligned, l, head_index))
                attn_mediators.append(name)
                direct_result_all_attn[name] = probs_direct - probs_aligned
                direct_result_class_attn[name] = probs_direct[0][-1] - class_baseline
                indirect_result_all_attn[name] = probs_indirect - probs_aligned
                indirect_result_class_attn[name] = probs_indirect[0][-1] - class_baseline

        # single feed-forward output values as mediators
        for l in range(depth):
            for s in range(t):
                for dim in range(k):
                    name = f'layer_{l}_seq_{s}_dim_{dim}'
                    probs_direct = _intervened_probs(
                        model, misaligned_input, 'fix_ff_out_all',
                        _isolate_mediator(ff_aligned, l, (0, s, dim)))
                    probs_indirect = _intervened_probs(
                        model, aligned_input, 'fix_ff_out_all',
                        _isolate_mediator(ff_misaligned, l, (0, s, dim)))
                    ff_mediators.append(name)
                    direct_result_all_ff[name] = probs_direct - probs_aligned
                    direct_result_class_ff[name] = probs_direct[0][-1] - class_baseline
                    indirect_result_all_ff[name] = probs_indirect - probs_aligned
                    indirect_result_class_ff[name] = probs_indirect[0][-1] - class_baseline

    return total_result_all, total_result_class , direct_result_all_attn, direct_result_all_ff, direct_result_class_attn, direct_result_class_ff, indirect_result_all_attn, indirect_result_all_ff, indirect_result_class_attn, indirect_result_class_ff, attn_mediators, ff_mediators
               

    
        

In [None]:
# Smoke test: analyse the first sample pair only and inspect the structure of the results
(
    total_result_all,
    total_result_class,
    direct_result_all_attn,
    direct_result_all_ff,
    direct_result_class_attn,
    direct_result_class_ff,
    indirect_result_all_attn,
    indirect_result_all_ff,
    indirect_result_class_attn,
    indirect_result_class_ff,
    attn_mediators,
    ff_mediators,
) = causal_mediation_analysis(data[0], model, heads, depth, k, t)

# mediator names and how many of them there are
print(f"attention mediators - number (depth*head): {len(attn_mediators)}")
print(attn_mediators)
print(f"feed forward mediators - number (depth*k*t): {len(ff_mediators)}")
print(ff_mediators)

# the effect dicts are keyed by mediator name
print("check keys")
for effect_dict in (direct_result_all_attn, direct_result_all_ff):
    print(effect_dict.keys())

# shapes of the effects at the class position and over all positions
first_attn, first_ff = 'layer_0_head_0', 'layer_0_seq_0_dim_0'
print("check values (class)")
print(direct_result_class_attn[first_attn].shape)
print(direct_result_class_ff[first_ff].shape)
print("check values (all)")
print(direct_result_all_attn[first_attn].shape)
print(direct_result_all_ff[first_ff].shape)




In [None]:
# Run the mediation analysis for every (aligned, misaligned) pair and collect each
# kind of result in its own list; entry i of every list belongs to data[i].
data_total_result_all = []
data_total_result_class = []
data_total_result_class_ff = []
data_direct_result_all_attn = []
data_direct_result_all_ff = []
data_direct_result_class_attn = []
data_direct_result_class_ff = []
data_indirect_result_all_attn = []
data_indirect_result_all_ff = []
data_indirect_result_class_attn = []
data_indirect_result_class_ff = []

for sample_pair in data:
    (
        total_result_all,
        total_result_class,
        direct_result_all_attn,
        direct_result_all_ff,
        direct_result_class_attn,
        direct_result_class_ff,
        indirect_result_all_attn,
        indirect_result_all_ff,
        indirect_result_class_attn,
        indirect_result_class_ff,
        attn_mediators,
        ff_mediators,
    ) = causal_mediation_analysis(sample_pair, model, heads, depth, k, t)

    # totals
    data_total_result_all.append(total_result_all)
    data_total_result_class.append(total_result_class)
    # direct effects
    data_direct_result_all_attn.append(direct_result_all_attn)
    data_direct_result_all_ff.append(direct_result_all_ff)
    data_direct_result_class_attn.append(direct_result_class_attn)
    data_direct_result_class_ff.append(direct_result_class_ff)
    # indirect effects
    data_indirect_result_all_attn.append(indirect_result_all_attn)
    data_indirect_result_all_ff.append(indirect_result_all_ff)
    data_indirect_result_class_attn.append(indirect_result_class_attn)
    data_indirect_result_class_ff.append(indirect_result_class_ff)


In [None]:
def create_plot_all(effect, title, data):
    """Save a box plot of one mediator's effect over all sequence positions.

    Args:
        effect: one effect tensor per sample, indexed as ``[0, seq_pos, vocab_pos]``.
        title: figure title; also used as the file name (``<title>.png``).
        data: the (aligned, misaligned) pairs, used for the hover label of each point.

    The image is written into the module-level ``save_fig_path`` folder.
    """
    rows = [
        {
            'Sample': f'aligned:{data[idx][0]} - misaligned:{data[idx][1]}',
            'Sequence Position': pos,
            'Vocabulary': tok,
            'Effect Value': round(float(entry[0, pos, tok].item()), 3),
        }
        for pos in range(effect[0].shape[1])
        for tok in range(effect[0].shape[2])
        for idx, entry in enumerate(effect)
    ]
    frame = pd.DataFrame(rows)

    figure = px.box(
        frame,
        x='Sequence Position',
        y='Effect Value',
        color='Vocabulary',
        title=title,
        points='all',
        hover_data=['Sample'],
    )
    # probability differences always lie within [-1, 1]
    figure.update_yaxes(range=[-1.0, 1.0], tickvals=np.linspace(-1.0, 1.0, 11))
    figure.update_layout(width=1500, height=650, font=dict(size=18, color='black'))
    target = save_fig_path + title + ".png"
    figure.write_image(target)

In [None]:
def create_plot_class(effect, title, data):
    """Save a box plot of one mediator's effect at the class position.

    Args:
        effect: one effect tensor per sample, indexed as ``[vocab_pos]``.
        title: figure title; also used as the file name (``<title>.png``).
        data: the (aligned, misaligned) pairs, used for the hover label of each point.

    The sequence position shown is ``t - 1`` with ``t`` read from module scope; the
    image is written into the module-level ``save_fig_path`` folder.
    """
    rows = [
        {
            'Sample': f'aligned:{data[idx][0]} - misaligned:{data[idx][1]}',
            'Sequence Position': t - 1,
            'Vocabulary': tok,
            'Effect Value': round(float(entry[tok].item()), 3),
        }
        for tok in range(effect[0].shape[0])
        for idx, entry in enumerate(effect)
    ]
    frame = pd.DataFrame(rows)

    figure = px.box(
        frame,
        x='Sequence Position',
        y='Effect Value',
        color='Vocabulary',
        title=title,
        points='all',
        hover_data=['Sample'],
    )
    # probability differences always lie within [-1, 1]
    figure.update_yaxes(range=[-1.0, 1.0], tickvals=np.linspace(-1.0, 1.0, 11))
    figure.update_layout(width=1500, height=650, font=dict(size=18, color='black'))
    target = save_fig_path + title + ".png"
    figure.write_image(target)

In [None]:
def create_plot_class_overview(effect, mediators,  title, data):
    """Save one box plot comparing all mediators at the class position.

    Args:
        effect: one dict per sample that maps a mediator name to an effect tensor
            indexed as ``[vocab_pos]``.
        mediators: mediator names to show on the x axis.
        title: figure title; also used as the file name (``<title>.png``).
        data: the (aligned, misaligned) pairs, used for the hover label of each point.

    The image is written into the module-level ``save_fig_path`` folder.
    """
    rows = [
        {
            'Sample': f'aligned:{data[idx][0]} - misaligned:{data[idx][1]}',
            'Mediator': name,
            'Vocabulary': tok,
            'Effect Value': round(float(per_sample[name][tok].item()), 3),
        }
        for idx, per_sample in enumerate(effect)
        for name in mediators
        for tok in range(per_sample[name].shape[0])
    ]
    frame = pd.DataFrame(rows)

    figure = px.box(
        frame,
        x='Mediator',
        y='Effect Value',
        color='Vocabulary',
        title=title,
        points='all',
        hover_data=['Sample'],
    )
    # probability differences always lie within [-1, 1]
    figure.update_yaxes(range=[-1.0, 1.0], tickvals=np.linspace(-1.0, 1.0, 11))
    # show every mediator label, rotated so that the names do not overlap
    figure.update_xaxes(tickangle=60, tickmode='linear', dtick=1)
    figure.update_layout(width=1500, height=650, font=dict(size=18, color='black'))
    target = save_fig_path + title + ".png"
    figure.write_image(target)

In [None]:
def create_plot_total_result_all(effect, title, data):
    """Save a box plot of the total effect over all sequence positions.

    The total effect is the misaligned probability minus the aligned probability.

    Args:
        effect: one dict per sample with the keys ``'probs_aligned'`` and
            ``'probs_misaligned'``, each indexed as ``[0][seq_pos][vocab_pos]``.
        title: figure title; also used as the file name (``<title>.png``).
        data: the (aligned, misaligned) pairs, used for the hover label of each point.

    The image is written into the module-level ``save_fig_path`` folder.
    """
    rows = []
    for idx, probs in enumerate(effect):
        label = f'aligned:{data[idx][0]} - misaligned:{data[idx][1]}'
        changed = probs['probs_misaligned'][0]
        for pos in range(changed.shape[0]):
            for tok in range(changed.shape[1]):
                after = float(probs['probs_misaligned'][0][pos][tok].item())
                before = float(probs['probs_aligned'][0][pos][tok].item())
                rows.append({
                    'Sample': label,
                    'Sequence Position': pos,
                    'Vocabulary': tok,
                    'Effect Value': round(after - before, 3),
                })
    frame = pd.DataFrame(rows)

    figure = px.box(
        frame,
        x='Sequence Position',
        y='Effect Value',
        color='Vocabulary',
        title=title,
        points='all',
        hover_data=['Sample'],
    )
    # probability differences always lie within [-1, 1]
    figure.update_yaxes(range=[-1.0, 1.0], tickvals=np.linspace(-1.0, 1.0, 11))
    figure.update_layout(width=1500, height=650, font=dict(size=18, color='black'))
    target = save_fig_path + title + ".png"
    figure.write_image(target)

In [None]:
def create_plot_total_result_class(effect, class_position, title, data):
    """Save a box plot of the total effect at the class position.

    The total effect is the misaligned probability minus the aligned probability.

    Args:
        effect: one dict per sample with the keys ``'probs_aligned'`` and
            ``'probs_misaligned'``, each indexed as ``[vocab_pos]``.
        class_position: value shown on the x axis as the sequence position.
        title: figure title; also used as the file name (``<title>.png``).
        data: the (aligned, misaligned) pairs, used for the hover label of each point.

    The image is written into the module-level ``save_fig_path`` folder.
    """
    rows = []
    for idx, probs in enumerate(effect):
        label = f'aligned:{data[idx][0]} - misaligned:{data[idx][1]}'
        for tok in range(probs['probs_misaligned'].shape[0]):
            after = float(probs['probs_misaligned'][tok].item())
            before = float(probs['probs_aligned'][tok].item())
            rows.append({
                'Sample': label,
                'Sequence Position': class_position,
                'Vocabulary': tok,
                'Effect Value': round(after - before, 3),
            })
    frame = pd.DataFrame(rows)

    figure = px.box(
        frame,
        x='Sequence Position',
        y='Effect Value',
        color='Vocabulary',
        title=title,
        points='all',
        hover_data=['Sample'],
    )
    # probability differences always lie within [-1, 1]
    figure.update_yaxes(range=[-1.0, 1.0], tickvals=np.linspace(-1.0, 1.0, 11))
    figure.update_layout(width=1500, height=650, font=dict(size=18, color='black'))
    target = save_fig_path + title + ".png"
    figure.write_image(target)

## Attention Mediators

In [None]:
# Regroup the per-sample results by mediator: <kind>_effects_<scope>[m][i] is the effect
# of attention mediator m on sample i ("all" = every sequence position, "class" = class
# position only).
direct_effects_all = [
    [per_sample[mediator] for per_sample in data_direct_result_all_attn]
    for mediator in attn_mediators
]
indirect_effects_all = [
    [per_sample[mediator] for per_sample in data_indirect_result_all_attn]
    for mediator in attn_mediators
]
direct_effects_class = [
    [per_sample[mediator] for per_sample in data_direct_result_class_attn]
    for mediator in attn_mediators
]
indirect_effects_class = [
    [per_sample[mediator] for per_sample in data_indirect_result_class_attn]
    for mediator in attn_mediators
]

# Plot the direct and indirect effects of every attention mediator over all sequence positions
for m, mediator in enumerate(attn_mediators):
    direct_effect_all = direct_effects_all[m]
    indirect_effect_all = indirect_effects_all[m]

    # Plot the direct effects of this attention mediator over all sequence positions
    create_plot_all(direct_effect_all, f'Direct Effects (change probability) caused by Attention Mediators (for all sequence positions): {mediator}', data)

    # Plot the indirect effects of this attention mediator over all sequence positions
    create_plot_all(indirect_effect_all, f'Indirect Effects (change probability) caused by Attention Mediators (for all sequence positions): {mediator}', data)

    # Optional: one plot per mediator restricted to the class position (disabled, the
    # overview plots below show the same information for all mediators at once)
    # direct_effect_class = direct_effects_class[m]
    # indirect_effect_class = indirect_effects_class[m]
    # create_plot_class(direct_effect_class, f'Direct Effects (change probability) of Attention Mediator (for class positions): {mediator}', data)
    # create_plot_class(indirect_effect_class, f'Indirect Effects (change probability) of Attention Mediator (for class positions): {mediator}', data)

# Plot the direct effects of all attention mediators for the class position at once
create_plot_class_overview(data_direct_result_class_attn, attn_mediators, 'Direct Effects (change probability) caused by Attention Mediators (for class position)', data)

# Plot the indirect effects of all attention mediators for the class position at once
create_plot_class_overview(data_indirect_result_class_attn, attn_mediators, 'Indirect Effects (change probability) caused by Attention Mediators (for class position)', data)

# Plot total effect for all sequence positions
create_plot_total_result_all(data_total_result_all, 'Total Effects (change probability) for all sequence positions', data)

# Plot total effect for the class position only. The class probabilities are taken at the
# last sequence position, whose zero-based index is t - 1 (not t).
create_plot_total_result_class(data_total_result_class, t - 1, 'Total Effects (change probability) for class position', data)





## Feed Forward Mediators

In [None]:
# Regroup the per-sample results by mediator: <kind>_effects_<scope>[m][i] is the effect
# of feed forward mediator m on sample i ("all" = every sequence position, "class" = class
# position only).
direct_effects_all = [
    [per_sample[mediator] for per_sample in data_direct_result_all_ff]
    for mediator in ff_mediators
]
indirect_effects_all = [
    [per_sample[mediator] for per_sample in data_indirect_result_all_ff]
    for mediator in ff_mediators
]
direct_effects_class = [
    [per_sample[mediator] for per_sample in data_direct_result_class_ff]
    for mediator in ff_mediators
]
indirect_effects_class = [
    [per_sample[mediator] for per_sample in data_indirect_result_class_ff]
    for mediator in ff_mediators
]

# Plot the direct and indirect effects of every feed forward mediator over all sequence
# positions (two figures per mediator)
for m, mediator in enumerate(ff_mediators):
    direct_effect_all = direct_effects_all[m]
    indirect_effect_all = indirect_effects_all[m]

    # Plot the direct effects of this feed forward mediator over all sequence positions
    create_plot_all(direct_effect_all, f'Direct Effects (change probability) caused by Feed Forward Mediators (for all sequence positions): {mediator}', data)

    # Plot the indirect effects of this feed forward mediator over all sequence positions
    create_plot_all(indirect_effect_all, f'Indirect Effects (change probability) caused by Feed Forward Mediators (for all sequence positions): {mediator}', data)

    # Optional: one plot per mediator restricted to the class position (disabled, the
    # overview plots below show the same information for all mediators at once)
    # direct_effect_class = direct_effects_class[m]
    # indirect_effect_class = indirect_effects_class[m]
    # create_plot_class(direct_effect_class, f'Direct Effects (change probability) of Feed Forward Mediator (for class positions): {mediator}', data)
    # create_plot_class(indirect_effect_class, f'Indirect Effects (change probability) of Feed Forward Mediator (for class positions): {mediator}', data)

# Plot the direct effects of all feed forward mediators for the class position at once
create_plot_class_overview(data_direct_result_class_ff, ff_mediators, 'Direct Effects (change probability) caused by Feed Forward Mediators (for class position)', data)

# Plot the indirect effects of all feed forward mediators for the class position at once
create_plot_class_overview(data_indirect_result_class_ff, ff_mediators, 'Indirect Effects (change probability) caused by Feed Forward Mediators (for class position)', data)

# Plot total effect for all sequence positions (independent of the mediator type, so this
# rewrites the figure of the same name)
create_plot_total_result_all(data_total_result_all, 'Total Effects (change probability) for all sequence positions', data)

# Plot total effect for the class position only. The class probabilities are taken at the
# last sequence position, whose zero-based index is t - 1 (not t).
create_plot_total_result_class(data_total_result_class, t - 1, 'Total Effects (change probability) for class position', data)