# In [1]:
import torch
import transformer_lens.utils as utils

import plotly.express as px
import tqdm
from functools import partial
import einops
import plotly.graph_objects as go

# Keyword names that belong to `fig.update_layout(...)` rather than to the
# plotly express constructor. `imshow` uses this set to split its **kwargs:
# names listed here are applied to the layout, everything else is forwarded
# to `px.imshow`.
update_layout_set = {
    # Axis ranges, titles and tick formatting.
    "xaxis_range",
    "yaxis_range",
    "xaxis_title",
    "yaxis_title",
    "xaxis_tickformat",
    "yaxis_tickformat",
    # Grid styling; the y-axis entries mirror the x-axis ones.
    "xaxis_showgrid",
    "xaxis_gridwidth",
    "xaxis_gridcolor",
    "yaxis_showgrid",
    "yaxis_gridwidth",
    "yaxis_gridcolor",  # was missing, so it got forwarded to px.imshow and rejected
    # Colour handling.
    "colorbar",
    "colorscale",
    "coloraxis",
    # Title / legend placement.
    "title_x",
    "title_y",
    "legend_title_text",
    # Miscellaneous layout options.
    "hovermode",
    "bargap",
    "bargroupgap",
}

def imshow(tensor, renderer=None, xaxis="", yaxis="", return_fig=False, **kwargs):
    """Render a tensor as a heatmap with ``px.imshow``.

    Args:
        tensor: Data to plot. A ``list`` is first combined with
            ``torch.stack``; the result is converted with ``utils.to_numpy``.
        renderer: Passed straight to ``fig.show``.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        return_fig: When true, return the figure instead of showing it.
        **kwargs: Keys found in ``update_layout_set`` are applied with
            ``fig.update_layout``; all others go to ``px.imshow``. Three keys
            get special handling:

            * ``facet_labels``: iterable of strings written, in order, into
              the ``text`` of ``fig.layout.annotations``.
            * ``color_continuous_scale``: defaults to ``"RdBu"``.
            * ``color_continuous_midpoint``: defaults to ``0.0``.
            * ``labels``: merged over the x/y labels built from ``xaxis`` and
              ``yaxis`` (caller-supplied entries win).

    Returns:
        The plotly figure if ``return_fig`` is true, otherwise ``None``.
    """
    if isinstance(tensor, list):
        tensor = torch.stack(tensor)

    kwargs_post = {k: v for k, v in kwargs.items() if k in update_layout_set}
    kwargs_pre = {k: v for k, v in kwargs.items() if k not in update_layout_set}

    # `facet_labels` is our own option, not a px.imshow argument.
    facet_labels = kwargs_pre.pop("facet_labels", None)

    # Defaults rather than hard-coded arguments: passing either of these
    # explicitly used to raise "got multiple values for keyword argument".
    kwargs_pre.setdefault("color_continuous_scale", "RdBu")
    kwargs_pre.setdefault("color_continuous_midpoint", 0.0)

    # Same reasoning for `labels`: merge the caller's mapping (e.g. a "color"
    # label) over the axis labels instead of colliding with it.
    extra_labels = kwargs_pre.pop("labels", None) or {}
    kwargs_pre["labels"] = {"x": xaxis, "y": yaxis, **extra_labels}

    fig = px.imshow(utils.to_numpy(tensor), **kwargs_pre).update_layout(**kwargs_post)

    if facet_labels:
        for i, label in enumerate(facet_labels):
            fig.layout.annotations[i]["text"] = label

    if return_fig:
        return fig
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, return_fig=False, **kwargs):
    """Draw a scatter plot of ``y`` against ``x`` with ``px.scatter``.

    Both inputs are converted with ``utils.to_numpy`` before plotting.
    ``xaxis``, ``yaxis`` and ``caxis`` become the labels for the x axis,
    y axis and colour respectively; remaining keyword arguments are
    forwarded to ``px.scatter`` unchanged.

    Returns the figure when ``return_fig`` is truthy; otherwise shows it
    with the given ``renderer`` and returns ``None``.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}

    figure = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)

    if not return_fig:
        figure.show(renderer)
        return None
    return figure

from typing import List
def show_avg_logit_diffs(x_axis: List[str], per_prompt_logit_diffs: List[torch.Tensor]):
    """Show a bar chart of mean logit diff per intervention, with std error bars.

    Args:
        x_axis: One bar label per intervention.
        per_prompt_logit_diffs: One tensor per intervention; each bar's height
            is that tensor's ``mean()`` and its error bar is its ``std()``.
            Presumably each tensor holds one logit diff per prompt -- confirm
            against the caller.

    The figure is displayed with ``fig.show()``; nothing is returned.
    """
    y_data = [logit_diffs.mean().item() for logit_diffs in per_prompt_logit_diffs]

    # torch's default std applies Bessel's correction, so a tensor with fewer
    # than two elements yields NaN. Use 0.0 there so the bar is drawn without
    # a NaN-valued error bar.
    error_y_data = [
        logit_diffs.std().item() if logit_diffs.numel() > 1 else 0.0
        for logit_diffs in per_prompt_logit_diffs
    ]

    fig = go.Figure(data=[go.Bar(
        x=x_axis,
        y=y_data,
        error_y=dict(
            type='data',  # specifies that the actual values are given
            array=error_y_data,  # the magnitudes of the errors
            visible=True  # make error bars visible
        ),
    )])

    # Customize layout
    fig.update_layout(title_text='Logit Diff after Interventions',
                      xaxis_title_text='Intervention',
                      yaxis_title_text='Logit diff',
                      plot_bgcolor='white')

    # Show the figure
    fig.show()