In [1]:
%load_ext autoreload
%autoreload 2
import sys

# Import TLens from the local copy shipped with this project - it includes various bugfixes as well as
# implementations for Vision-Language (VL) model hooking. The library lives in a subfolder, so its path
# is added to sys.path first (once).
sys.path.append("third_party/TransformerLens")
try:
    # Some Python problem causes the first import to throw; swallow that one failure and retry below.
    # `except Exception` (not a bare `except:`) so Ctrl-C / SystemExit are not swallowed.
    import transformer_lens as lens
except Exception:
    pass
# Retry; if the library is genuinely broken, the error surfaces here instead of being hidden.
import transformer_lens as lens
import torch
import torch
import numpy as np
import seaborn as sns
import plotly.express as px
from visualization_utils import imshow, line, scatter, multiple_lines
from general_utils import get_tokens, topk_2d
from analysis_utils import load_model, load_dataset, SUPPORTED_TASKS
from modality_alignment_utils import get_image_positions, get_text_sequence_positions
import plotly.io as pio
from collections import defaultdict, OrderedDict
import numpy as np
import plotly.graph_objects as go
from pprint import pprint
from plotly.subplots import make_subplots

# This notebook only inspects saved results, so autograd is switched off globally.
torch.set_grad_enabled(False)
device = "cuda"
COLORBLIND_COLORS = [
    '#0173b2', '#de8f05', '#029e73', '#d55e00', '#cc78bc',
    '#ca9161', '#fbafe4', '#949494', '#ece133', '#56b4e9',
]


DEFAULT_METRIC = "LD"

# Single source of truth per model: (internal name, local checkpoint path, name shown in figures).
_MODEL_REGISTRY = [
    (
        "qwen2-7b-vl-instruct",
        "/PATH_TO_MODELS/models--Qwen--Qwen2-VL-7B-Instruct/snapshots/a7a06a1cc11b4514ce9edcde0e3ca1d16e5ff2fc",
        "Qwen2-VL-7B",
    ),
    (
        "pixtral-12b",
        "/PATH_TO_MODELS/models--mistral-community--pixtral-12b/snapshots/c2756cbbb9422eba9f6c5c439a214b0392dfc998/",
        "Pixtral-12B",
    ),
    (
        "gemma-3-12b-it",
        "/PATH_TO_MODELS/models--google--gemma-3-12b-it/snapshots/96b6f1eccf38110c56df3a15bffe176da04bfd80",
        "Gemma-3-12B",
    ),
]
MODEL_NAMES = [name for name, _path, _shown in _MODEL_REGISTRY]
MODEL_PATHS = {name: path for name, path, _shown in _MODEL_REGISTRY}
VISUALIZED_MODEL_NAMES = {name: shown for name, _path, shown in _MODEL_REGISTRY}

## Faithfulness graphs

In [None]:
# Show node faithfulness graphs

HIGH_FAITH_THRESHOLD = 0.80

def get_first_over_threshold(faiths, threshold=None):
    """Return the index of the first entry of `faiths` strictly greater than `threshold`.

    Args:
        faiths: 1-D tensor of faithfulness scores (callers pass the `.diag()` of a faithfulness matrix).
        threshold: level that counts as "high" faithfulness. Defaults to the module-level
            HIGH_FAITH_THRESHOLD, read at call time so later edits of that constant are honoured.

    Returns:
        The index as a Python int, or None when no entry exceeds the threshold (including empty input).
        Callers must handle None explicitly: using it as an index (``percentages[None]``) does not
        select an element.
    """
    if threshold is None:
        threshold = HIGH_FAITH_THRESHOLD
    over_threshold_indices = (faiths > threshold).nonzero().view(-1)
    if over_threshold_indices.numel() == 0:
        return None
    return over_threshold_indices[0].item()

# Which faithfulness results to report for every (model, task) pair.
show_l_faiths = True
show_vl_faiths = True
show_cross_modality_faiths = False
show_interchange_faiths = True


# Circuit size at which the L / VL circuit first exceeds HIGH_FAITH_THRESHOLD, one entry per (model, task).
l_percentages, vl_percentages = [], []
for model_name in MODEL_NAMES:
    for task_name in SUPPORTED_TASKS:
        try:
            print(model_name, task_name)
            faiths_cf = []
            line_titles_cf = []
            # Reset per (model, task) so an x-axis left over from the previous iteration is never reused.
            percentages = None

            if show_l_faiths:
                results_path = f"./data/{task_name}/results/{model_name}/faithfulness_{DEFAULT_METRIC}_l_node_circuit.pt"
                percentages, faiths_l_cf, faiths_l_mask = torch.load(results_path, weights_only=True)
                first_high_idx = get_first_over_threshold(faiths_l_cf.diag())
                if first_high_idx is None:
                    # Don't index with None: it would not select a circuit size and could corrupt the average.
                    print(f'L circuit never exceeds faithfulness {HIGH_FAITH_THRESHOLD}')
                else:
                    l_percentages.append(percentages[first_high_idx])
                    print(f'L percentage with high faith: {l_percentages[-1] :.3f}')
                faiths_cf.append(faiths_l_cf.diag())
                line_titles_cf.append(f"{DEFAULT_METRIC} L-Discover L-Eval CF")

            if show_vl_faiths:
                results_path = f"./data/{task_name}/results/{model_name}/faithfulness_{DEFAULT_METRIC}_vl_node_circuit.pt"
                percentages, faiths_vl_cf, faiths_vl_mask = torch.load(results_path, weights_only=True)
                first_high_idx = get_first_over_threshold(faiths_vl_cf.diag())
                if first_high_idx is None:
                    print(f'VL circuit never exceeds faithfulness {HIGH_FAITH_THRESHOLD}')
                else:
                    vl_percentages.append(percentages[first_high_idx])
                    print(f'VL percentage with high faith: {vl_percentages[-1] :.3f}')
                faiths_cf.append(faiths_vl_cf.diag())
                line_titles_cf.append(f"{DEFAULT_METRIC} VL-Discover VL-Eval CF")

            if show_cross_modality_faiths:
                results_path = f"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_modal_{DEFAULT_METRIC}.pt"
                percentages_cross, faith_discover_l_eval_vl, faith_discover_vl_eval_l = torch.load(results_path, weights_only=True)
                if percentages is None:
                    # Neither L nor VL results were loaded: the cross-modal file supplies the x-axis itself.
                    percentages = percentages_cross
                assert percentages == percentages_cross
                faiths_cf.append(faith_discover_l_eval_vl)
                line_titles_cf.append(f"{DEFAULT_METRIC} L-Discover VL-Eval CF")
                faiths_cf.append(faith_discover_vl_eval_l)
                line_titles_cf.append(f"{DEFAULT_METRIC} VL-Discover L-Eval CF")

            if show_interchange_faiths:
                results_path = f"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{DEFAULT_METRIC}.pt"
                results_dict = torch.load(results_path, weights_only=True)
                print(f"Interchange results: ")
                avg = lambda k1, k2: (results_dict[k1] + results_dict[k2]) / 2
                print(f"D Interchange | Random Baseline | Clean result: {avg('DL_QV_LV', 'DV_QL_LL') :.3f} | {avg('DR_QV_LV', 'DR_QL_LL') :.3f} | {avg('DV_QV_LV', 'DL_QL_LL') :.3f}")
                print(f"Q Interchange | Random Baseline | Clean result: {avg('DV_QL_LV', 'DL_QV_LL') :.3f} | {avg('DV_QR_LV', 'DL_QR_LL') :.3f} | {avg('DV_QV_LV', 'DL_QL_LL') :.3f}")
                print(f"L Interchange | Random Baseline | Clean result: {avg('DV_QV_LL', 'DL_QL_LV') :.3f} | {avg('DV_QV_LR', 'DL_QL_LR') :.3f} | {avg('DV_QV_LV', 'DL_QL_LL') :.3f}")

            # Only plot when at least one faithfulness curve was loaded for this (model, task).
            if faiths_cf:
                fig = multiple_lines(
                    x=percentages,
                    y=faiths_cf,
                    line_titles=line_titles_cf,
                    title=f"Faithfulness vs Ablation percentage<br>({task_name}, {model_name})",
                    width=500,
                    show_fig=False
                )
                fig.update_xaxes(title_text="Nodes included in circuit (percent)")
                fig.update_yaxes(title_text="Faithfulness")
                # The figure is created with show_fig=False so the axis titles can be set first; show it now.
                fig.show()

        except Exception as e:
            # Best-effort over all (model, task) pairs: report the failure and move on.
            print(f"\n\nError in {task_name} - {model_name}\n\n")
            print(f"{type(e).__name__}: {e}")

print("Average L circuit percentage: ", np.mean(l_percentages))
print("Average VL circuit percentage: ", np.mean(vl_percentages))

## Intersection results

In [None]:
def unify_l_vl(result_dict, key_suffix):
    """Average the L-side and VL-side entries stored under `key_suffix`, rounded to 3 decimals."""
    l_value = result_dict[f"l_{key_suffix}"]
    vl_value = result_dict[f"vl_{key_suffix}"]
    return round((l_value + vl_value) / 2, 3)

# Print L/VL-averaged component IoU for every (task, model) pair that has intersection results on disk.
for task_name in SUPPORTED_TASKS:
    for model_name in MODEL_NAMES:
        results_path = f"./data/{task_name}/results/{model_name}/intersection_results.pt"
        try:
            result_dict = torch.load(results_path, weights_only=True)
        except FileNotFoundError:
            print(f"Model: {model_name}; Task: {task_name}, No intersection results found")
            continue

        try:
            print(f"Model: {model_name}; Task: {task_name}")

            print("Averaged between L and VL SIDES")
            print(f"Unified by pos (MLP,ATTN): {unify_l_vl(result_dict, 'mlp_iou'), unify_l_vl(result_dict, 'head_iou')}; Random Baseline (MLP,ATTN): {unify_l_vl(result_dict, 'mlp_baseline'), unify_l_vl(result_dict, 'head_baseline')}")
            print(f"D pos: {unify_l_vl(result_dict, 'D_neurons_iou'), unify_l_vl(result_dict, 'D_head_iou')}")
            print(f"Q pos: {unify_l_vl(result_dict, 'Q_neurons_iou'), unify_l_vl(result_dict, 'Q_head_iou')}")
            print(f"L pos: {unify_l_vl(result_dict, 'G_neurons_iou'), unify_l_vl(result_dict, 'G_head_iou')}")
        except KeyError as e:
            # The file exists but lacks an expected entry: say so instead of claiming the file is missing.
            print(f"Model: {model_name}; Task: {task_name}, missing intersection key {e}")

## Backpatching results

In [None]:
# Present full back-patching heatmaps: accuracy change w.r.t. the clean accuracy for every (src, dst) layer pair.

BP_DIFF_CLIP = 0.1  # differences are clipped to [-0.1, 0.1] so small effects stay visible in the color scale

for model_name in MODEL_NAMES:
    for task_name in SUPPORTED_TASKS:
        # weights_only=False unpickles arbitrary objects - only load result files produced by this project.
        backpatching_results, src_layer_range, dst_layer_range = torch.load(
            f"./data/{task_name}/results/{model_name}/backpatching_results.pt", weights_only=False
        )
        clean_acc = backpatching_results['clean_accs'][0]
        for cfg, cfg_results in backpatching_results.items():
            # Only 2-tuple keys are back-patching settings (e.g. 'clean_accs' is not); show settings with cfg[0] True.
            if len(cfg) != 2 or cfg[0] is not True:
                continue

            print(model_name, task_name, cfg)
            acc_grid = cfg_results[0]
            n_src, n_dst = acc_grid.shape
            results = (acc_grid - clean_acc).clamp(min=-BP_DIFF_CLIP, max=BP_DIFF_CLIP)
            # Invalid (dst >= src) layer pairs are stored as 0 accuracy (the same `> 0` validity convention is
            # used when comparing against the control). Blank them instead of drawing a maximal accuracy drop.
            results = results.masked_fill(acc_grid <= 0, float('nan'))
            fig = px.imshow(
                results,
                title=f"Back-patching results<br>({task_name}, {model_name})",
                color_continuous_scale="RdBu",
                color_continuous_midpoint=0,
                width=600
            )
            fig.update_xaxes(
                tickvals=list(range(n_dst)),
                ticktext=[dst_layer_range[i] for i in range(n_dst)],
            )
            fig.update_yaxes(
                tickvals=list(range(n_src)),
                ticktext=[src_layer_range[i] for i in range(n_src)],
            )
            fig.show()

In [None]:
# Observe full back-patching results for every (model, task) pair: the top-k settings, and how often
# V back-patching beats the L->L control.

k = 10  # number of top (setting, src, dst) results listed per (model, task)
control_results = {}
for model_name in MODEL_NAMES:
    for task_name in SUPPORTED_TASKS:
        print(model_name, task_name)

        # weights_only=False unpickles arbitrary objects - only load result files produced by this project.
        backpatching_results, src_layer_range, dst_layer_range = torch.load(
            f"./data/{task_name}/results/{model_name}/backpatching_results.pt", weights_only=False
        )
        print(backpatching_results['clean_accs'])

        # Only 2-tuple keys are back-patching settings. Ignore setting where data positions are processed
        # post back-patching (cfg[0] is not True).
        valid_cfgs = [cfg for cfg in backpatching_results.keys() if len(cfg) == 2 and cfg[0] is True]

        # Get the top-k backpatching results across all settings and src->dst options
        top_results = []
        for cfg in valid_cfgs:
            (top_h_indices, top_w_indices), top_accs = topk_2d(backpatching_results[cfg][0], k)
            for top_h_index, top_w_index, top_acc in zip(top_h_indices, top_w_indices, top_accs):
                top_results.append(cfg + (src_layer_range[top_h_index], dst_layer_range[top_w_index], top_acc))

        top_results = sorted(top_results, key=lambda x: x[-1], reverse=True)[:k]
        pprint([f"Repeat Processing={r[0]}; Layer window size={r[1]}; Layers={r[2]}->{r[3]}; Acc={r[-1].item() :.3f}" for r in top_results])

        # Comparing to control (L->L backpatching; Should (hopefully) lead to a smaller improvement)
        for cfg in valid_cfgs:
            bp_accs = backpatching_results[cfg][0].view(-1)
            control_accs = backpatching_results[cfg][1].view(-1)
            # Non-valid settings (i.e. dst >= src) are stored as 0. Use ONE mask for both grids so the element-wise
            # comparison pairs up the same (src, dst) settings; independent masks can misalign the two vectors
            # or give them different lengths when a valid setting has exactly 0 accuracy in only one grid.
            valid = (bp_accs > 0) | (control_accs > 0)
            if not valid.any():
                print(f"{cfg}: no valid back-patching settings")
                continue
            backpatching_diffs = bp_accs[valid] - backpatching_results["clean_accs"][0]
            control_backpatching_diffs = control_accs[valid] - backpatching_results["clean_accs"][1]

            # Fraction in [0, 1] of settings where V back-patching gains at least as much as the control.
            bp_better_than_control = (backpatching_diffs >= control_backpatching_diffs).float().mean().item()
            control_results[(model_name, task_name, cfg[1])] = bp_better_than_control
            print(f"{cfg}: V Backpatching gets stronger boost in {bp_better_than_control :.1%} of the cases")

            best_backpatching_increase = backpatching_diffs.max()
            best_control_increase = control_backpatching_diffs.max()
            print(task_name, model_name, cfg, best_backpatching_increase, best_control_increase)

print(control_results)
print(f'BP is higher than control (without maxing for best model-task setting) in {np.mean(list(control_results.values())) :.1%} of the cases')

In [None]:
# Present BP best results across models and tasks

print('Model\t\tTask name\t\tV Acc\t\tL Acc\t\tBackpatching-induced Acc')
# Share of the V->L clean-accuracy gap closed by the best back-patching setting; only values in (0, 1] are averaged.
relative_diffs = []
for model_name in MODEL_NAMES:
    for task_name in SUPPORTED_TASKS:
        # weights_only=False unpickles arbitrary objects - only load result files produced by this project.
        backpatching_results, src_layer_range, dst_layer_range = torch.load(
            f"./data/{task_name}/results/{model_name}/backpatching_results.pt", weights_only=False
        )
        # Best accuracy over every (setting, src, dst); non-2-tuple keys such as 'clean_accs' are skipped.
        k = 1
        top_results = []
        for cfg in backpatching_results.keys():
            if len(cfg) != 2:
                continue
            (top_h_indices, top_w_indices), top_accs = topk_2d(backpatching_results[cfg][0], k)
            for top_h_index, top_w_index, top_acc in zip(top_h_indices, top_w_indices, top_accs):
                top_results.append(cfg + (src_layer_range[top_h_index], dst_layer_range[top_w_index], top_acc))
        if not top_results:
            print(f"{model_name[:4]}\t\t{task_name[:10]}\t\tno back-patching settings found")
            continue
        top_results = sorted(top_results, key=lambda x: x[-1], reverse=True)[:k]
        bp_best_acc = top_results[0][-1].item()
        clean_v, clean_l = backpatching_results['clean_accs']
        accuracy_gap = clean_l - clean_v
        # A zero V/L gap makes the relative improvement undefined: report NaN (excluded from the average)
        # instead of raising ZeroDivisionError.
        relative_diff = (bp_best_acc - clean_v) / accuracy_gap if accuracy_gap != 0 else float('nan')
        if 0 < relative_diff <= 1.0:
            relative_diffs.append(relative_diff)
        print(f"{model_name[:4]}\t\t{task_name[:10]}\t\t{clean_v :.3f}\t\t{clean_l :.3f}\t\t{bp_best_acc :.3f} ({relative_diff :.3f})")

print("Average relative diff: ", np.mean(relative_diffs))

# Figures

#### Circuit Discovery

In [None]:
# Draw the attribution scores per layer per position (summed across components)

fs = 28 # Font size
SAVE_FIGURES = False  # set True to also export the figures to ./figures/

task_name = SUPPORTED_TASKS[0]
model_name = MODEL_NAMES[0]

for modality in ['l', 'vl']:
    scores = torch.load(f'./data/{task_name}/results/{model_name}/node_scores/nap_ig_{modality}_ig={5}_metric={DEFAULT_METRIC}.pt', weights_only=True)
    scores = {k: v.abs() for (k, v) in scores.items()}
    n_layers = len([k for k in scores.keys() if 'mlp.hook_post' in k])
    seq_len = next(iter(scores.values())).shape[0]

    # Total |score| per (layer, position): for the MLP and attention hooks of each layer, sum every
    # non-position dimension at once instead of looping over positions in Python.
    summed_scores_per_layer_per_pos = torch.zeros(n_layers, seq_len)
    for layer in range(n_layers):
        for hook_key in (f'blocks.{layer}.mlp.hook_post', f'blocks.{layer}.attn.hook_z'):
            hook_scores = scores[hook_key]
            summed_scores_per_layer_per_pos[layer] += (
                hook_scores.reshape(hook_scores.shape[0], -1).sum(dim=1)[:seq_len].cpu().float()
            )

    # Tick positions are hard-coded absolute token positions for this specific (model, task) pair;
    # adjust them if either changes.
    if modality == 'l':
        start_of_data = get_text_sequence_positions(model_name, task_name)[0]
        tick_positions = [20, 30, seq_len - 1]
        ticktext = ['Data (Text)', 'Query', 'Generation']
        layout = dict(
            title=dict(text="Textual Task Patching Effects", font=dict(size=fs), x=0.5, y=0.99),
            width=530,
            coloraxis_showscale=False,  # Hide the colorbar
        )
    else:
        start_of_data = get_image_positions(model_name, task_name)[0]
        tick_positions = [70, 100, seq_len - 1]
        ticktext = ['Data (Image)', 'Query', 'Generation']
        layout = dict(
            title=dict(text="Visual Task Patching Effects", font=dict(size=fs), x=0.5, y=0.99),
            width=1200,
            coloraxis_colorbar=dict(
                thickness=25,  # Adjust the thickness of the color bar
                len=1.05       # Adjust the length of the color bar
            ),
            coloraxis_showscale=True,
        )

    # Drop positions before the data segment; heatmap column 0 is position `start_of_data`.
    summed_scores_per_layer_per_pos = summed_scores_per_layer_per_pos[:, start_of_data:]
    fig = px.imshow(summed_scores_per_layer_per_pos, color_continuous_scale="Blues")
    fig.update_xaxes(
        tickvals=[pos - start_of_data for pos in tick_positions],
        ticktext=ticktext,
        title=dict(text='Position', font=dict(size=fs)),
        tickfont=dict(size=fs - 4),
    )
    fig.update_yaxes(tickvals=list(range(5, n_layers, 5)), title=dict(text='Layer', font=dict(size=fs)), tickfont=dict(size=fs - 4))
    fig.update_layout(
        xaxis_tickangle=0,
        margin=dict(l=0, r=0, t=30, b=0),  # Remove margins
        **layout,
    )
    fig.show()

    if SAVE_FIGURES:
        pio.write_image(fig, f"./figures/{model_name}_{task_name}_attr_scores_per_layer_per_pos_{modality}.png")
        pio.write_image(fig, f"./figures/{model_name}_{task_name}_attr_scores_per_layer_per_pos_{modality}.pdf")

In [None]:
fs = 28  # Font size

# Appendix figure: textual (left) and visual (right) patching-effect heatmaps for every (task, model) pair.
for task_name in SUPPORTED_TASKS:
    for model_name in MODEL_NAMES:
        fig = make_subplots(
            rows=1, cols=2,
            column_widths=[0.4, 0.6],
            subplot_titles=["Textual Task Patching Effects", "Visual Task Patching Effects"],
        )
        fig.update_annotations(font_size=fs)  # subplot titles

        for col, modality in enumerate(['l', 'vl'], start=1):
            scores = torch.load(f'./data/{task_name}/results/{model_name}/node_scores/nap_ig_{modality}_ig={5}_metric={DEFAULT_METRIC}.pt', weights_only=True)
            scores = {k: v.abs() for (k, v) in scores.items()}
            n_layers = len([k for k in scores.keys() if 'mlp.hook_post' in k])
            seq_len = next(iter(scores.values())).shape[0]

            # Total |score| per (layer, position): sum every non-position dimension of the MLP and attention hooks.
            summed_scores_per_layer_per_pos = torch.zeros(n_layers, seq_len)
            for layer in range(n_layers):
                for hook_key in (f'blocks.{layer}.mlp.hook_post', f'blocks.{layer}.attn.hook_z'):
                    hook_scores = scores[hook_key]
                    summed_scores_per_layer_per_pos[layer] += (
                        hook_scores.reshape(hook_scores.shape[0], -1).sum(dim=1)[:seq_len].cpu().float()
                    )

            if modality == 'l':
                start_of_data, end_of_data = get_text_sequence_positions(model_name, task_name)
                coloraxis = "coloraxis1"
            else:
                start_of_data, end_of_data = get_image_positions(model_name, task_name)
                coloraxis = "coloraxis2"
            # Drop positions before the data segment; heatmap column 0 is now position `start_of_data`.
            summed_scores_per_layer_per_pos = summed_scores_per_layer_per_pos[:, start_of_data:]
            # The 'Q' tick marks `end_of_data`. It is assumed to be an absolute position (like `start_of_data`,
            # which is used as a slice index above), so shift it into heatmap coordinates.
            tickvals = [end_of_data - start_of_data]
            ticktext = ['Q']

            fig.add_trace(
                go.Heatmap(
                    z=summed_scores_per_layer_per_pos.numpy(),
                    coloraxis=coloraxis,
                ),
                row=1, col=col
            )

            fig.update_xaxes(
                tickvals=tickvals,
                ticktext=ticktext,
                tickfont=dict(size=fs - 4),
                row=1, col=col
            )
            fig.update_yaxes(
                tickvals=list(range(5, n_layers, 5)),
                title=dict(text="Layer", font=dict(size=fs)),
                tickfont=dict(size=fs - 4),
                row=1, col=col
            )

        fig.update_layout(
            coloraxis1=dict(colorscale="Blues", showscale=False),
            coloraxis2=dict(colorscale="Blues", colorbar=dict(thickness=25, len=1.05)),
            width=1600,
            height=600,
            margin=dict(l=0, r=0, t=50, b=0),
        )

        print(model_name, task_name)
        fig.show()
        pio.write_image(fig, f"./figures/appendix_heatmaps/{model_name}_{task_name}.png")

#### Faithfulness

In [None]:
fs = 20  # Font size
metric = DEFAULT_METRIC
# Height of the "high faithfulness" reference line; matches HIGH_FAITH_THRESHOLD (0.80).
HIGH_FAITH_LINE_Y = 0.8

fig = make_subplots(
    rows=1, cols=len(MODEL_NAMES),
    shared_yaxes=True,
    subplot_titles=[VISUALIZED_MODEL_NAMES[model_name] for model_name in MODEL_NAMES],
)

# One subplot per model, one color per task; solid line = L (textual) circuit, dotted line = VL (visual) circuit.
for col, model_name in enumerate(MODEL_NAMES, start=1):
    for task_idx, task_name in enumerate(SUPPORTED_TASKS):
        for modality, line_style in (('l', 'solid'), ('vl', 'dot')):
            results_path = f"./data/{task_name}/results/{model_name}/faithfulness_{metric}_{modality}_node_circuit.pt"
            percentages, faiths_cf, faiths_mask = torch.load(results_path, weights_only=True)
            # Each curve is plotted against the circuit sizes stored in its own file, so files with different
            # size grids cannot be drawn against the wrong x values. The first circuit size is skipped
            # (presumably 0%, which the log x-axis cannot show - TODO confirm).
            fig.add_trace(
                go.Scatter(
                    x=percentages[1:], y=faiths_cf.diag()[1:].numpy(),
                    mode='lines',
                    name=f"{task_name.replace('_', ' ').capitalize()} ({modality.upper()})",
                    line=dict(dash=line_style, color=COLORBLIND_COLORS[task_idx])
                ),
                row=1, col=col
            )

fig.update_layout(
    width=1100,
    height=300,
    showlegend=False,
    legend=dict(
        orientation="h",
        yanchor="bottom",
        y=-0.9,
        xanchor="center",
        x=0.5
    ),
    margin=dict(l=10, r=10, t=30, b=30),  # Adjust margins
    xaxis=dict(domain=[0.0, 0.32]),  # Reduce spacing between subplots (domains assume 3 models)
    xaxis2=dict(domain=[0.34, 0.65]),
    xaxis3=dict(domain=[0.67, 1.0]),
)
# Enlarge the subplot titles - the only annotations at this point (the threshold lines are added below).
fig.update_annotations(font=dict(size=fs), xanchor='center', showarrow=False)

for col in range(1, len(MODEL_NAMES) + 1):
    fig.update_xaxes(title=dict(text="Circuit size (% Components)", font=dict(size=fs)), type="log", row=1, col=col, tickvals=[0.01, 0.1, 1], tickfont=dict(size=fs - 4))
    if col == 1:
        fig.update_yaxes(title=dict(text="Faithfulness", font=dict(size=fs)), row=1, col=col, tickvals=[0, 0.2, 0.4, 0.6, 0.8, 1.0], tickfont=dict(size=fs - 4))
        fig.add_hline(y=HIGH_FAITH_LINE_Y, line_color="black", line_dash="dot", annotation_font=dict(size=fs-7), annotation_text="High faithfulness<br>threshold", annotation_position="top left", row=1, col=col)
    else:
        fig.update_yaxes(showticklabels=False, row=1, col=col)
        fig.add_hline(y=HIGH_FAITH_LINE_Y, line_color="black", line_dash="dot", annotation_text="", annotation_position="top left", row=1, col=col)

fig.show()

pio.write_image(fig, f"./figures/faithfulness-all-models-and-tasks.pdf")

#### Intersections

In [None]:
def unify_l_vl(result_dict, key_suffix):
    """Return the mean of result_dict['l_<suffix>'] and result_dict['vl_<suffix>'], rounded to 3 decimals."""
    side_total = sum(result_dict[f"{side}_{key_suffix}"] for side in ("l", "vl"))
    return round(side_total / 2, 3)


SPLIT_DQG = True  # True: one bar pair per D / Q / G position group; False: a single pair with their average
NO_POS_IN_DATA = True
nps = "_no_pos" if NO_POS_IN_DATA else "" # nps == no pos suffix

W = 0.4
x_vals_all = [[] for _ in MODEL_NAMES]  # x categories (one per bar) per subplot, used as tick positions
x_text_all = [[] for _ in MODEL_NAMES]  # tick label per x category ("" hides the label of baseline bars)


fig = make_subplots(rows=1, cols=len(MODEL_NAMES), subplot_titles=[VISUALIZED_MODEL_NAMES[m] for m in MODEL_NAMES])

for model_idx, model_name in enumerate(MODEL_NAMES):
    task_intersections, baseline_intersections = [], []
    for task_name in SUPPORTED_TASKS:
        intersections_dict = torch.load(f"./data/{task_name}/results/{model_name}/intersection_results.pt", weights_only=True)
        # Per position group: average of neuron and head IoU (each already averaged over the L and VL sides).
        D_iou = (unify_l_vl(intersections_dict, f'D_neurons_iou{nps}') + unify_l_vl(intersections_dict, f'D_head_iou{nps}')) / 2
        Q_iou = (unify_l_vl(intersections_dict, 'Q_neurons_iou') + unify_l_vl(intersections_dict, 'Q_head_iou')) / 2
        G_iou = (unify_l_vl(intersections_dict, 'G_neurons_iou') + unify_l_vl(intersections_dict, 'G_head_iou')) / 2

        if SPLIT_DQG:
            task_intersections.append((D_iou, Q_iou, G_iou))
        else:
            avg_iou = (D_iou + Q_iou + G_iou) / 3
            task_intersections.append(avg_iou)

        D_baseline = (unify_l_vl(intersections_dict, f'D_neurons_baseline{nps}') + unify_l_vl(intersections_dict, f'D_head_baseline{nps}')) / 2
        Q_baseline = (unify_l_vl(intersections_dict, 'Q_neurons_baseline') + unify_l_vl(intersections_dict, 'Q_head_baseline')) / 2
        G_baseline = (unify_l_vl(intersections_dict, 'G_neurons_baseline') + unify_l_vl(intersections_dict, 'G_head_baseline')) / 2
        if SPLIT_DQG:
            baseline_intersections.append((D_baseline, Q_baseline, G_baseline))
        else:
            avg_baseline = (D_baseline + Q_baseline + G_baseline) / 3
            baseline_intersections.append(avg_baseline)

    for i, task_name in enumerate(SUPPORTED_TASKS):
        presented_task_name = task_name.replace('_', '<br>').capitalize()
        if SPLIT_DQG:
            for j, dqg in enumerate(['D', 'Q', 'G']):
                x_inter = f"{presented_task_name} {dqg} I"
                x_base = f"{presented_task_name} {dqg} B"
                # Register both bars so the tick update below labels the intersection bar and hides the
                # baseline bar (an empty tick list would remove every x-axis label).
                x_vals_all[model_idx].extend([x_inter, x_base])
                x_text_all[model_idx].extend([f"{presented_task_name} {dqg}", ""])

                fig.add_trace(go.Bar(
                    x=[x_inter],
                    y=[task_intersections[i][j]],
                    marker_color='green',
                    name=f"{task_name} {dqg} Intersection",
                    showlegend=False,
                    width=W,
                ), row=1, col=model_idx + 1)

                fig.add_trace(go.Bar(
                    x=[x_base],
                    y=[baseline_intersections[i][j]],
                    marker_color='red',
                    marker_pattern_shape='x',
                    name=f"{task_name} {dqg} Baseline",
                    showlegend=False,
                    width=W,
                ), row=1, col=model_idx + 1)
        else:
            x_inter = f"{presented_task_name}"
            x_base = f"{presented_task_name} Base"
            x_vals_all[model_idx].extend([x_inter, x_base])
            x_text_all[model_idx].extend([x_inter, ""])

            fig.add_trace(go.Bar(
                x=[x_inter],
                y=[task_intersections[i]],
                marker_color='green',
                name=f"{task_name} Intersection",
                showlegend=False,
                width=W,
            ), row=1, col=model_idx + 1)
            fig.add_trace(go.Bar(
                x=[x_base],
                y=[baseline_intersections[i]],
                marker_color='red',
                marker_pattern_shape="x",
                name=f"{task_name} Baseline",
                showlegend=False,
                width=W,
            ), row=1, col=model_idx + 1)

# Update y-axis visibility
for i in range(1, len(MODEL_NAMES) + 1):
    if i > 1:
        fig.update_yaxes(showticklabels=False, range=[0, 1.05], row=1, col=i)
    else:
        fig.update_yaxes(title_text="IoU", tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], range=[0, 1.05], row=1, col=i)


# Update x-axes: label intersection bars, hide baseline bar labels
for i in range(1, len(MODEL_NAMES) + 1):
    fig.update_xaxes(
        tickvals=x_vals_all[i - 1],
        ticktext=x_text_all[i - 1],
        row=1,
        col=i,
    )

fig.add_hline(y=1.0, line_dash="dot", line_color="black", row=1, col='all')

# Final bar width; note this overrides the per-trace width W set above.
fig.update_traces(width=0.8)

fig.update_layout(
    barmode='group',
    margin=dict(l=10, r=0, t=20, b=0),
    width=1000,
    height=400,
    xaxis=dict(domain=[0.0, 0.32]),
    xaxis2=dict(domain=[0.34, 0.65]),
    xaxis3=dict(domain=[0.67, 0.99])
)
fig.show()

pio.write_image(fig, f"./figures/intersections_DQG={SPLIT_DQG}.pdf")

In [None]:
def unify_l_vl(result_dict, key_suffix):
    """Mean of the 'l_'- and 'vl_'-prefixed entries for `key_suffix`, rounded to 3 decimal places."""
    l_key, vl_key = (f"{prefix}_{key_suffix}" for prefix in ("l", "vl"))
    return round((result_dict[l_key] + result_dict[vl_key]) / 2, 3)


SPLIT_DQG = True  # True: one bar per D / Q / G position group; False: a single bar with their average
NO_POS_IN_DATA = True
nps = "_no_pos" if NO_POS_IN_DATA else "" # nps == no pos suffix

W = 0.4
x_vals_all = [[] for _ in MODEL_NAMES]  # tick positions (x categories) per subplot
x_text_all = [[] for _ in MODEL_NAMES]  # tick label for each tick position


fig = make_subplots(rows=1, cols=len(MODEL_NAMES), subplot_titles=[VISUALIZED_MODEL_NAMES[m] for m in MODEL_NAMES])

for model_idx, model_name in enumerate(MODEL_NAMES):
    task_intersections, baseline_intersections = [], []
    for task_name in SUPPORTED_TASKS:
        intersections_dict = torch.load(f"./data/{task_name}/results/{model_name}/intersection_results.pt", weights_only=True)
        # Per position group: average of neuron and head IoU (each already averaged over the L and VL sides).
        D_iou = (unify_l_vl(intersections_dict, f'D_neurons_iou{nps}') + unify_l_vl(intersections_dict, f'D_head_iou{nps}')) / 2
        Q_iou = (unify_l_vl(intersections_dict, 'Q_neurons_iou') + unify_l_vl(intersections_dict, 'Q_head_iou')) / 2
        G_iou = (unify_l_vl(intersections_dict, 'G_neurons_iou') + unify_l_vl(intersections_dict, 'G_head_iou')) / 2

        if SPLIT_DQG:
            task_intersections.append((D_iou, Q_iou, G_iou))
        else:
            avg_iou = (D_iou + Q_iou + G_iou) / 3
            task_intersections.append(avg_iou)

        D_baseline = (unify_l_vl(intersections_dict, f'D_neurons_baseline{nps}') + unify_l_vl(intersections_dict, f'D_head_baseline{nps}')) / 2
        Q_baseline = (unify_l_vl(intersections_dict, 'Q_neurons_baseline') + unify_l_vl(intersections_dict, 'Q_head_baseline')) / 2
        G_baseline = (unify_l_vl(intersections_dict, 'G_neurons_baseline') + unify_l_vl(intersections_dict, 'G_head_baseline')) / 2
        if SPLIT_DQG:
            baseline_intersections.append((D_baseline, Q_baseline, G_baseline))
        else:
            avg_baseline = (D_baseline + Q_baseline + G_baseline) / 3
            baseline_intersections.append(avg_baseline)

    task_ious = torch.tensor(task_intersections)
    baseline_ious = torch.tensor(baseline_intersections)
    # Chance-corrected IoU: (IoU - baseline) / (1 - baseline), i.e. 0 at the random baseline and 1 at perfect
    # overlap. The parentheses matter: without them the expression reduces to IoU - 2 * baseline.
    # The denominator is floored to avoid division by zero if a baseline IoU is 1.
    normalized_intersections = (
        (task_ious - baseline_ious) / (1.0 - baseline_ious).clamp_min(1e-6)
    ).clamp(0.01, 1)
    print('Mean across tasks: ', normalized_intersections.mean(dim=0))
    if SPLIT_DQG:
        # Only the split (tasks x 3) tensor has a position-group dimension to average over.
        print('Mean across positions: ', normalized_intersections.mean(dim=1))
    for i, task_name in enumerate(SUPPORTED_TASKS):
        presented_task_name = task_name.replace('_', '<br>').capitalize()
        if SPLIT_DQG:
            x_inter = f"{presented_task_name}"
            # Bars use the categories "<task>D", "<task>Q", "<task>G"; put the task label under the middle (Q) bar.
            x_vals_all[model_idx].append(x_inter + 'Q')
            x_text_all[model_idx].append(x_inter)
            for j, dqg in enumerate(['D', 'Q', 'G']):
                normalized_intersection = normalized_intersections[i][j].item()
                fig.add_trace(go.Bar(
                    x=[x_inter + dqg],
                    y=[normalized_intersection],
                    marker_color=COLORBLIND_COLORS[j],
                    name=f"{task_name} {dqg} Intersection",
                    showlegend=False,
                    width=W,
                ), row=1, col=model_idx + 1)
        else:
            x_inter = f"{presented_task_name}"
            x_vals_all[model_idx].append(x_inter)
            x_text_all[model_idx].append(x_inter)
            normalized_intersection = normalized_intersections[i].item()
            fig.add_trace(go.Bar(
                x=[x_inter],
                y=[normalized_intersection],
                marker_color='green',
                name=f"{task_name} Intersection",
                showlegend=False,
                width=W,
            ), row=1, col=model_idx + 1)

# Update y-axis visibility
for i in range(1, len(MODEL_NAMES) + 1):
    fig.update_yaxes(range=[0, 1.05], row=1, col=i)
    if i > 1:
        fig.update_yaxes(showticklabels=False, tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)
    else:
        fig.update_yaxes(title_text="Normalized IoU", tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)


for i in range(1, len(MODEL_NAMES) + 1):
    fig.update_xaxes(
        tickvals=x_vals_all[i - 1],
        ticktext=x_text_all[i - 1],
        row=1,
        col=i,
    )

fig.add_hline(y=1.0, line_dash="dot", line_color="black", row=1, col='all')

# Final bar width; note this overrides the per-trace width W set above.
fig.update_traces(width=0.8)

# Update layout
fig.update_layout(
    barmode='group',
    margin=dict(l=10, r=0, t=20, b=5),
    width=1000,
    height=200,
    xaxis=dict(domain=[0.0, 0.32]),
    xaxis2=dict(domain=[0.34, 0.65]),
    xaxis3=dict(domain=[0.67, 0.99]),
)
fig.show()

pio.write_image(fig, f"./figures/intersections_normalized_DQG={SPLIT_DQG}.pdf")

#### Interchange

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Raw node-faithfulness under cross-modality interchanges for a single model.
# For each selected sub-circuit split (D / Q / G) and each task there are three bars:
# - the random-components baseline,
# - the modality-switch faithfulness,
# - the clean-circuit skyline.
# Each bar is the mean of two complementary interchange configurations.

metric = DEFAULT_METRIC
model_name = MODEL_NAMES[0]
BASELINE_KEY, VALUE_KEY, SKYLINE_KEY = 'Random Components (Baseline)', 'Modality Switch Faithfulness', 'Clean Circuit Faithfulness'
show_splits = "DQG"

# split -> [(baseline keys), (value keys), (skyline keys)]. Each entry is a pair of keys
# into faithfulness_nodes_cross_interchanges_{metric}.pt. Dict order (D, Q, G) fixes the
# row order of the plotted data, independent of the order of letters in `show_splits`.
SPLIT_KEY_PAIRS = {
    "D": [('DR_QV_LV', 'DR_QL_LL'), ('DL_QV_LV', 'DV_QL_LL'), ('DV_QV_LV', 'DL_QL_LL')],
    "Q": [('DV_QR_LV', 'DL_QR_LL'), ('DV_QL_LV', 'DL_QV_LL'), ('DV_QV_LV', 'DL_QL_LL')],
    "G": [('DV_QV_LR', 'DL_QL_LR'), ('DV_QV_LL', 'DL_QL_LV'), ('DV_QV_LV', 'DL_QL_LL')],
}
BAR_TYPES = [BASELINE_KEY, VALUE_KEY, SKYLINE_KEY]


def avg_pair(results, k1, k2):
    """Return the mean of the two single-element tensors `results[k1]` and `results[k2]` as a float.

    The results mapping is passed explicitly. This avoids depending on whichever global
    `results_dict` happens to be bound when the function is called.
    """
    return ((results[k1] + results[k2]) / 2).item()


# Each task's results file is shared by all splits, so it is loaded at most once.
interchange_results = {}
rows = []
for split, key_pairs in SPLIT_KEY_PAIRS.items():
    if split not in show_splits:
        continue
    for task_name in SUPPORTED_TASKS:
        if task_name not in interchange_results:
            interchange_results[task_name] = torch.load(f"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{metric}.pt", weights_only=True)
        results = interchange_results[task_name]
        task_label = task_name.replace('_', ' ').capitalize() + f" {split}"
        for bar_type, (k1, k2) in zip(BAR_TYPES, key_pairs):
            rows.append({'Task': task_label, 'Value': avg_pair(results, k1, k2), 'Type': bar_type})

data = pd.DataFrame(rows, columns=['Task', 'Value', 'Type'])

colors = {BASELINE_KEY: 'lightgray', VALUE_KEY: 'steelblue', SKYLINE_KEY: 'lightblue'}
# Hatch patterns that distinguish the two reference bars (baseline and skyline).
patterns = {BASELINE_KEY: '//', SKYLINE_KEY: '\\\\'}

fig, ax = plt.subplots(figsize=(8, 6))
sns.barplot(x='Task', y='Value', hue='Type', hue_order=BAR_TYPES, data=data, palette=colors, dodge=True, ax=ax)

# Apply the hatches. This assumes seaborn adds one bar container per hue level, in hue_order.
for bar_type, container in zip(BAR_TYPES, ax.containers):
    hatch = patterns.get(bar_type)
    if hatch:
        for bar in container:
            bar.set_hatch(hatch)

ax.set_ylim(0, 1)

# Rebuild the legend below the axes with a single entry per bar type (first occurrence wins).
handles, labels = ax.get_legend_handles_labels()
legend_entries = {}
for handle, label in zip(handles, labels):
    legend_entries.setdefault(label, handle)
for label, handle in legend_entries.items():
    # Mirror the bar hatching on legend patches; containers take it from their hatched bars.
    if label in patterns and hasattr(handle, "set_hatch"):
        handle.set_hatch(patterns[label])
ax.legend(list(legend_entries.values()), list(legend_entries.keys()), loc='lower center', bbox_to_anchor=(0.5, -0.12), ncol=3, frameon=False, fontsize='medium')

ax.set_xlabel('', fontsize='large')
ax.set_ylabel('Circuit Faithfulness', fontsize='large')
ax.tick_params(axis='x', labelrotation=45)
ax.set_title(f'Faithfulness when patching {" / ".join(show_splits)} sub-circuits', fontsize='large')

# Save before showing so the exported PDF never depends on the backend's close-on-show behaviour.
fig.savefig(f"./figures/{model_name}_interchange_faithfulness.pdf", bbox_inches='tight')
plt.show()

In [None]:
# Normalized interchange faithfulness for every model (one subplot each), every task,
# and every sub-circuit split (D / Q / G). For a single split the score is
#     (value - baseline) / (skyline - baseline)
# clamped at 0, where:
# - baseline interchanges random components ('R' in the configuration key),
# - skyline is the clean-circuit configuration,
# - value is the modality-switched interchange.
# Each term is the mean of two complementary interchange configurations.
metric = DEFAULT_METRIC
SPLITS = ['D', 'Q', 'G']

# split -> {term: (key1, key2)} into faithfulness_nodes_cross_interchanges_{metric}.pt
INTERCHANGE_KEYS = {
    'D': {'baseline': ('DR_QV_LV', 'DR_QL_LL'), 'skyline': ('DV_QV_LV', 'DL_QL_LL'), 'value': ('DL_QV_LV', 'DV_QL_LL')},
    'Q': {'baseline': ('DV_QR_LV', 'DL_QR_LL'), 'skyline': ('DV_QV_LV', 'DL_QL_LL'), 'value': ('DV_QL_LV', 'DL_QV_LL')},
    'G': {'baseline': ('DV_QV_LR', 'DL_QL_LR'), 'skyline': ('DV_QV_LV', 'DL_QL_LL'), 'value': ('DV_QV_LL', 'DL_QL_LV')},
}


def normalized_interchange_scores(results):
    """Return the normalized scores for one task's interchange results, in SPLITS order.

    Args:
        results: mapping from interchange-configuration key to a single-element tensor,
            as loaded from faithfulness_nodes_cross_interchanges_{metric}.pt.

    Returns:
        A list of len(SPLITS) floats. A split whose skyline equals its baseline has no
        dynamic range, so it scores 0.0 instead of raising ZeroDivisionError.

    Raises:
        KeyError: if `results` lacks one of the configurations in INTERCHANGE_KEYS.
    """
    def pair_mean(key_pair):
        k1, k2 = key_pair
        return ((results[k1] + results[k2]) / 2).item()

    split_scores = []
    for split in SPLITS:
        keys = INTERCHANGE_KEYS[split]
        baseline = pair_mean(keys['baseline'])
        skyline = pair_mean(keys['skyline'])
        value = pair_mean(keys['value'])
        dynamic_range = skyline - baseline
        split_scores.append((value - baseline) / dynamic_range if dynamic_range != 0 else 0.0)
    return split_scores


W = 0.4
x_vals_all = [[] for _ in MODEL_NAMES]  # x-category labels collected per subplot
fig = make_subplots(rows=1, cols=len(MODEL_NAMES), subplot_titles=[VISUALIZED_MODEL_NAMES[name] for name in MODEL_NAMES])

for model_idx, model_name in enumerate(MODEL_NAMES):
    scores = []
    for task_name in SUPPORTED_TASKS:
        results_path = f"./data/{task_name}/results/{model_name}/faithfulness_nodes_cross_interchanges_{metric}.pt"
        try:
            results_dict = torch.load(results_path, weights_only=True)
            scores.append(normalized_interchange_scores(results_dict))
        except (FileNotFoundError, KeyError) as e:
            # Missing results: plot zero-height bars, but keep one full row per task so
            # the score matrix stays rectangular (tasks x splits) for torch.tensor below.
            print(f"Skipping {model_name} / {task_name}: {e!r}")
            scores.append([0.0] * len(SPLITS))
    # Scores below the random baseline are shown as 0.
    scores = torch.tensor(scores).clamp(min=0.0)

    for i, task_name in enumerate(SUPPORTED_TASKS):
        presented_task_name = task_name.replace('_', '<br>').capitalize()
        x_inter = f"{presented_task_name}"
        x_vals_all[model_idx].append(x_inter)
        # NOTE(review): the bars use the categories `x_inter + split`, but x_vals_all
        # (used as tick values below) stores the bare `x_inter`. Confirm the x tick
        # labels render as intended.
        for j, dqg in enumerate(SPLITS):
            fig.add_trace(go.Bar(
                x=[x_inter + dqg],
                y=[scores[i][j].item()],
                marker_color=COLORBLIND_COLORS[j],
                name=f"{task_name} {dqg}",
                showlegend=False,
                width=W,
            ), row=1, col=model_idx + 1)

# Shared y-range and ticks; the y-axis title and tick labels appear only on the leftmost subplot.
for i in range(1, len(MODEL_NAMES) + 1):
    fig.update_yaxes(range=[0, 1.05], row=1, col=i)
    if i > 1:
        fig.update_yaxes(showticklabels=False, tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)
    else:
        fig.update_yaxes(title_text="Interchange Faithfulness", tickvals=[0.2, 0.4, 0.6, 0.8, 1.0], row=1, col=i)

for i in range(1, len(MODEL_NAMES) + 1):
    tickvals = x_vals_all[i - 1]
    fig.update_xaxes(
        tickvals=tickvals,
        ticktext=list(tickvals),
        row=1,
        col=i,
    )

# Dotted reference line at a normalized score of 1.0 (as faithful as the clean circuit).
fig.add_hline(y=1.0, line_dash="dot", line_color="black", row=1, col='all')

# Uniform bar width; this overrides the per-trace `width=W`.
fig.update_traces(width=0.8)

fig.update_layout(
    barmode='group',
    margin=dict(l=10, r=0, t=20, b=5),
    width=1000,
    height=200,
    # Explicit x-domains for the subplot columns (written for exactly three models).
    xaxis=dict(domain=[0.0, 0.32]),
    xaxis2=dict(domain=[0.34, 0.65]),
    xaxis3=dict(domain=[0.67, 0.99]),
)
fig.show()

pio.write_image(fig, f"./figures/interchange_faithfulness_normalized.pdf")

#### Backpatching

In [None]:
# Per-layer similarity of VL image-patch activations to the unembeddings of the tokens
# in the parallel textual sequence, before back-patching. Produces one figure per model
# with one line per task.

# (model_name, task_name) -> similarity tensor, moved to the CPU.
all_model_similarities = {}
for sim_model in MODEL_NAMES:
    for sim_task in SUPPORTED_TASKS:
        sim_path = f"./data/{sim_task}/results/{sim_model}/similarities_of_vl_activations_to_text_seq_tokens_k=0.05_use_unembed=True.pt"
        all_model_similarities[(sim_model, sim_task)] = torch.load(sim_path, weights_only=True).cpu()

# model_name -> tasks stacked along dim 0 (row i corresponds to SUPPORTED_TASKS[i]).
model_sims = {
    sim_model: torch.stack([all_model_similarities[(sim_model, sim_task)] for sim_task in SUPPORTED_TASKS])
    for sim_model in MODEL_NAMES
}

for model_name in MODEL_NAMES:
    task_rows = model_sims[model_name]
    layer_axis = list(range(task_rows.shape[1]))

    fig = go.Figure()
    for i, task_name in enumerate(SUPPORTED_TASKS):
        fig.add_scatter(
            x=layer_axis,
            y=task_rows[i].numpy(),
            mode='lines',
            name=task_name.replace('_', ' ').capitalize(),
            line=dict(color=COLORBLIND_COLORS[i]),
        )

    legend_below = dict(orientation="h", yanchor="bottom", y=-0.3, xanchor="center", x=0.5)
    fig.update_layout(
        title=f'Similarity of image patch activations with text token unembeddings<br>({model_name})',
        xaxis_title='Layer',
        yaxis_title='Similarity',
        width=600,
        legend=legend_below,
    )
    fig.update_yaxes(range=[0, 0.3])
    fig.show()

    pio.write_image(fig, f"./figures/{model_name}_similarities_to_text_seq_tokens.pdf")

In [None]:
# Iterative ("looped") back-patching: the result after each back-patching iteration,
# with one subplot per model and one line per task.
fs = 24  # base font size for axis titles, tick labels and subplot titles

# (model_name, task_name) -> {iteration: result}. Looped configurations are the
# 3-element result keys; their third element is the iteration index.
looped_by_iteration = defaultdict(dict)
for model_name in MODEL_NAMES:
    for task_name in SUPPORTED_TASKS:
        # NOTE: weights_only=False unpickles arbitrary Python objects; only load trusted result files.
        backpatching_results, src_layer_range, dst_layer_range = torch.load(f"./data/{task_name}/results/{model_name}/backpatching_results.pt", weights_only=False)
        for cfg, cfg_result in backpatching_results.items():
            # Skip non-looped entries: string keys such as "clean_accs", and configurations
            # whose length is not 3.
            if isinstance(cfg, str) or len(cfg) != 3:
                continue
            looped_by_iteration[(model_name, task_name)][cfg[2]] = cfg_result

# Sorted list of ((model_name, task_name), per-iteration results).
# - Each results list has exactly as many entries as there are iterations stored on disk.
#   A fixed-size buffer would overflow on longer runs and plot fake 0.0 points on shorter ones.
# - An iteration with no stored result is None, so Plotly draws a gap rather than a 0.
looped_results = []
for key, per_iteration in sorted(looped_by_iteration.items()):
    n_iterations = max(per_iteration) + 1
    looped_results.append((key, [per_iteration.get(it) for it in range(n_iterations)]))

fig = make_subplots(
    rows=1, cols=len(MODEL_NAMES),
    subplot_titles=[VISUALIZED_MODEL_NAMES[name] for name in MODEL_NAMES],
    shared_yaxes=True
)

task_colors = {task_name: COLORBLIND_COLORS[i] for i, task_name in enumerate(SUPPORTED_TASKS)}

for col, model_name in enumerate(MODEL_NAMES, start=1):
    for (looped_model_name, task_name), results in looped_results:
        if looped_model_name != model_name:
            continue
        fig.add_trace(
            go.Scatter(
                x=list(range(len(results))),
                y=results,
                mode='lines',
                name=f"{task_name}",
                line=dict(color=task_colors[task_name])
            ),
            row=1, col=col
        )

fig.update_layout(
    margin=dict(l=40, r=40, t=40, b=40),
    width=1000,
    height=300,
    showlegend=False
)

for col in range(1, len(MODEL_NAMES) + 1):
    fig.update_xaxes(title=dict(text="Back-patching iteration", font=dict(size=fs)), row=1, col=col, tickfont=dict(size=fs-4))
    fig.update_yaxes(title=dict(text="Accuracy", font=dict(size=fs)), row=1, col=col, tickfont=dict(size=fs-4))
# Enlarge the subplot titles (make_subplots stores them as layout annotations).
fig.update_annotations(font_size=fs)

fig.show()

pio.write_image(fig, f"./figures/iterative-backpatching-results.pdf")

In [None]:
# Control vs. normal back-patching. The control is L->L back-patching, which should
# (hopefully) give a smaller improvement. For each (model, task, setting), plot the
# percentage of valid layer settings where back-patching's accuracy gain over its clean
# accuracy is at least the control's gain over its own clean accuracy.
# 50% is marked as the "Random" reference.

control_results = OrderedDict()
for model_name in MODEL_NAMES:
    for task_name in SUPPORTED_TASKS:
        # NOTE: weights_only=False unpickles arbitrary Python objects; only load trusted result files.
        backpatching_results, src_layer_range, dst_layer_range = torch.load(f"./data/{task_name}/results/{model_name}/backpatching_results.pt", weights_only=False)
        for cfg in backpatching_results.keys():
            # Only 2-element configurations with cfg[0] True are compared; cfg[1] becomes
            # part of the result key.
            if len(cfg) != 2 or cfg[0] is not True:
                continue

            bp_accs = backpatching_results[cfg][0].view(-1)
            control_accs = backpatching_results[cfg][1].view(-1)
            # Drop non-valid settings (dst >= src), assumed to be stored as non-positive
            # placeholders in both grids (TODO confirm).
            # A single shared mask keeps every back-patching entry paired with the control
            # entry for the same setting. Masking each grid separately misaligns the
            # comparison (or fails to broadcast) when a valid setting is 0 in only one grid.
            valid = (bp_accs > 0) | (control_accs > 0)
            backpatching_diffs = bp_accs[valid] - backpatching_results["clean_accs"][0]
            control_backpatching_diffs = control_accs[valid] - backpatching_results["clean_accs"][1]

            # Fraction in [0, 1]; converted to a percentage when plotting.
            bp_better_than_control_percent = (backpatching_diffs >= control_backpatching_diffs).float().mean()
            control_results[(model_name, task_name, cfg[1])] = bp_better_than_control_percent

# Keys keep insertion order: model-major, then task, then setting.
result_keys = list(control_results.keys())
x_labels = [f"{key[0]}-{key[1]}-{key[2]}" for key in result_keys]
y_values = [control_results[key].item() * 100 for key in result_keys]
# Marker colour encodes the task.
colors = [COLORBLIND_COLORS[SUPPORTED_TASKS.index(key[1])] for key in result_keys]
shapes = ['circle', 'square', 'star']

fig = go.Figure()
# Marker symbols cycle through `shapes` point by point.
fig.add_trace(go.Scatter(
    x=x_labels,
    y=y_values,
    mode='markers',
    marker=dict(color=colors, size=10, symbol=[shapes[i % len(shapes)] for i in range(len(x_labels))]),
    name="Control Results"
))

# Solid separators between consecutive models. Each sits halfway between two category
# positions on the categorical x axis. They are computed from the data rather than
# hard-coded, so they stay correct for any number of tasks/settings per model.
for pos in range(1, len(result_keys)):
    if result_keys[pos][0] != result_keys[pos - 1][0]:
        fig.add_vline(x=pos - 0.5, line_color="black")

fig.add_hline(y=50, line_dash="dot", line_color="gray", annotation_text="Random", annotation_position="top left", annotation_font=dict(size=16))
fig.add_hline(y=0, line_dash="dot", line_color="black", annotation_text="Control Advantage", annotation_position="top left", annotation_font=dict(size=16))
fig.add_hline(y=100, line_dash="dot", line_color="green", annotation_text="BP Advantage", annotation_position="top left", annotation_font=dict(size=16))

# The long x labels are hidden; models are told apart by the separators, tasks by colour.
fig.update_xaxes(tickvals=[], ticktext=[])
fig.update_layout(
    yaxis=dict(title="Back-patching<br>beats control<br>percentage", range=[-10, 110]),
    width=1000,
    height=200,
    showlegend=False,
    margin=dict(l=0, r=5, t=5, b=5),
)

fig.show()

pio.write_image(fig, f"./figures/control-vs-backpatching-results.pdf")