In [None]:
from matplotlib.patches import Rectangle
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

import torch

In [None]:
def heatmap(ax, tensor, vmin=-5, vmax=0, fmt=".2f"):
    """Draw an annotated heatmap with square cells of a 2-D tensor/array onto ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        tensor: 2-D ``torch.Tensor`` (any device, with or without grad) or array-like.
        vmin, vmax: color-scale limits passed to ``sns.heatmap``.
        fmt: format string for the per-cell annotations.

    Returns:
        The Axes returned by ``sns.heatmap``.
    """
    if isinstance(tensor, torch.Tensor):
        # Tensor.numpy() refuses tensors on a GPU or tensors that require grad,
        # so detach and move to the CPU first.
        data = tensor.detach().cpu().numpy()
    else:
        data = np.asarray(tensor)
    return sns.heatmap(
        data,
        ax=ax,
        annot=True,
        fmt=fmt,
        square=True,
        vmin=vmin,
        vmax=vmax,
        cbar=False,
        cmap=sns.color_palette("dark:#5A9", as_cmap=True),
    )


def plot_factorization(results, plot_titles, highlight=None, vmin=-5, vmax=0):
    """Draw one annotated heatmap per result, side by side in a single row.

    The tick labels assume every result is 4x4: rows are objects
    (O_0, O_1, O'_0, O'_1) and columns are items (I_0, I_1, I'_0, I'_1).

    Args:
        results: non-empty sequence of 2-D tensors/arrays, one per panel.
        plot_titles: one title per panel (at least ``len(results)`` entries).
        highlight: optional sequence with one entry per panel; each entry is a list of
            ``(row, col)`` heatmap cells to outline in crimson.
        vmin, vmax: shared color-scale limits forwarded to ``heatmap``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``results`` is empty or ``plot_titles``/``highlight`` are too short.
    """
    num_panels = len(results)
    if num_panels == 0:
        raise ValueError("results must contain at least one matrix")
    if len(plot_titles) < num_panels:
        raise ValueError(f"expected {num_panels} plot titles, got {len(plot_titles)}")
    if highlight is not None and len(highlight) < num_panels:
        raise ValueError(f"expected {num_panels} highlight lists, got {len(highlight)}")

    # squeeze=False keeps `axs` 2-D (1 x num_panels), so indexing works for any panel count.
    # Width scales with the panel count; four panels give a 9x9 figure.
    fig, axs = plt.subplots(1, num_panels, figsize=(2.25 * num_panels, 9), squeeze=False)

    for i, ax in enumerate(axs[0]):
        heatmap(ax, results[i], vmin=vmin, vmax=vmax)
        ax.set_title(plot_titles[i])
        squares = highlight[i] if highlight is not None else []
        for sqr, sqc in squares:
            # Rectangle takes (x, y) = (column, row) for the cell's corner.
            ax.add_patch(Rectangle((sqc, sqr), 1, 1, fill=False, edgecolor='crimson', lw=1.5, clip_on=False))
        ax.set_xticklabels(['$I_0$', '$I_1$', "$I'_0$", "$I'_1$"])
        ax.set_xlabel("Items")

        # Only the left-most panel carries the row labels; the rest share them.
        if i == 0:
            ax.set_ylabel("Objects")
            ax.set_yticklabels(['$O_0$', '$O_1$', "$O'_0$", "$O'_1$"])
        else:
            ax.set_yticklabels([])

    fig.tight_layout()
    return fig


def plot_masking(results, plot_titles, highlight=None, vmin=-5, vmax=0):
    """Draw one small annotated heatmap per result, side by side in a single row.

    The tick labels assume every result is 2x2: rows are objects (O_0, O_1)
    and columns are items (I_0, I_1).

    Args:
        results: non-empty sequence of 2-D tensors/arrays, one per panel.
        plot_titles: one title per panel (at least ``len(results)`` entries).
        highlight: optional sequence with one entry per panel; each entry is a list of
            ``(row, col)`` heatmap cells to outline in crimson.
        vmin, vmax: shared color-scale limits forwarded to ``heatmap``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if ``results`` is empty or ``plot_titles``/``highlight`` are too short.
    """
    num_panels = len(results)
    if num_panels == 0:
        raise ValueError("results must contain at least one matrix")
    if len(plot_titles) < num_panels:
        raise ValueError(f"expected {num_panels} plot titles, got {len(plot_titles)}")
    if highlight is not None and len(highlight) < num_panels:
        raise ValueError(f"expected {num_panels} highlight lists, got {len(highlight)}")

    # squeeze=False keeps `axs` 2-D (1 x num_panels), so the same indexing works for one
    # panel or many. Each panel gets a 2x2-inch slot.
    fig, axs = plt.subplots(1, num_panels, figsize=(2 * num_panels, 2), squeeze=False)

    for i, ax in enumerate(axs[0]):
        heatmap(ax, results[i], vmin=vmin, vmax=vmax)
        ax.set_title(plot_titles[i])
        squares = highlight[i] if highlight is not None else []
        for sqr, sqc in squares:
            # Rectangle takes (x, y) = (column, row) for the cell's corner.
            ax.add_patch(Rectangle((sqc, sqr), 1, 1, fill=False, edgecolor='crimson', lw=1.5, clip_on=False))
        ax.set_xticklabels(['$I_0$', '$I_1$'])
        ax.set_xlabel("Items")

        # Only the left-most panel carries the row labels.
        if i == 0:
            ax.set_ylabel("Objects")
            ax.set_yticklabels(['$O_0$', '$O_1$'])
        else:
            ax.set_yticklabels([])

    fig.tight_layout()
    return fig

# Factorizability

In [None]:
# Factorizability under item interventions: no intervention, item1, item2, then both.
# The empty intervention formats to 'outputs/factorizability_.pth'.
root_temp = 'outputs/factorizability_{}.pth'

no_intervene = [[]]
item_intervenes = [['item1'], ['item2'], ['item1', 'item2']]

acc, log_probs = [], []
for intervene in no_intervene + item_intervenes:
    path = root_temp.format('_'.join(intervene))
    # map_location='cpu' lets GPU-saved results load on any machine and keeps the
    # tensors on the CPU for the numpy-based heatmaps.
    res_dict = torch.load(path, map_location='cpu')
    acc.append(res_dict['accuracy'])
    # Mean over dim -2 (presumably the example axis -- TODO confirm).
    log_probs.append(res_dict['log_probs'].mean(dim=-2))

plot_names = ['None', 'Item 0', 'Item 1', 'Items 0,1']
# One list of (row, col) cells to outline per panel.
highlight = [[(0, 0), (1, 1)],
             [(0, 2), (1, 1)],
             [(0, 0), (1, 3)],
             [(0, 2), (1, 3)],
             ]
# Accuracy on a 0-100 color scale; mean log-probs on a fixed [-14, -8] scale.
_ = plot_factorization(acc, plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_factorization(log_probs, plot_names, highlight=highlight, vmin=-14, vmax=-8)

In [None]:
# Factorizability under shape interventions: no intervention, shape1, shape2, then both.
# The empty intervention formats to 'outputs/factorizability_.pth'.
root_temp = 'outputs/factorizability_{}.pth'

no_intervene = [[]]
patch_intervenes = [['shape1'], ['shape2'], ['shape1', 'shape2']]

acc, log_probs = [], []
for intervene in no_intervene + patch_intervenes:
    path = root_temp.format('_'.join(intervene))
    # map_location='cpu' lets GPU-saved results load on any machine and keeps the
    # tensors on the CPU for the numpy-based heatmaps.
    res_dict = torch.load(path, map_location='cpu')
    acc.append(res_dict['accuracy'])
    # Mean over dim -2 (presumably the example axis -- TODO confirm).
    log_probs.append(res_dict['log_probs'].mean(dim=-2))

plot_names = ['None', 'Object 0', 'Object 1', 'Objects 0,1']
# One list of (row, col) cells to outline per panel.
highlight = [[(0, 0), (1, 1)],
             [(2, 0), (1, 1)],
             [(0, 0), (3, 1)],
             [(2, 0), (3, 1)],
             ]
# Accuracy on a 0-100 color scale; mean log-probs on a fixed [-14, -8] scale.
_ = plot_factorization(acc, plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_factorization(log_probs, plot_names, highlight=highlight, vmin=-14, vmax=-8)

In [None]:
# Factorizability under color interventions: no intervention, color1, color2, then both.
# The empty intervention formats to 'outputs/factorizability_.pth'.
root_temp = 'outputs/factorizability_{}.pth'

no_intervene = [[]]
color_intervenes = [['color1'], ['color2'], ['color1', 'color2']]

acc, log_probs = [], []
for intervene in no_intervene + color_intervenes:
    path = root_temp.format('_'.join(intervene))
    # map_location='cpu' lets GPU-saved results load on any machine and keeps the
    # tensors on the CPU for the numpy-based heatmaps.
    res_dict = torch.load(path, map_location='cpu')
    acc.append(res_dict['accuracy'])
    # Mean over dim -2 (presumably the example axis -- TODO confirm).
    log_probs.append(res_dict['log_probs'].mean(dim=-2))

plot_names = ['None', 'Color 0', 'Color 1', 'Colors 0,1']
# NOTE(review): every panel outlines the same cells, which presumably means color
# interventions are not expected to move the answer -- confirm this is intended.
highlight = [[(0, 0), (1, 1)],
             [(0, 0), (1, 1)],
             [(0, 0), (1, 1)],
             [(0, 0), (1, 1)],
             ]
# Accuracy on a 0-100 color scale; mean log-probs on a fixed [-14, -8] scale.
_ = plot_factorization(acc, plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_factorization(log_probs, plot_names, highlight=highlight, vmin=-14, vmax=-8)

# Position Independence

In [None]:
def plot_1d_logits(
    logits,
    tokens,
    color_labels,
    color_category_label,
    style_labels,
    style_category_label,
    ax=None,
    ylabel='Mean Log Prob',
):
    '''Line-plot a 3-D grid of values against token position.

    Each (color, style) pair becomes one line: hue encodes the color index and
    line style encodes the style index.

    logits: array-like indexed [position, color, style] -- a 3-D array/tensor or a
        list of 2-D tensors/arrays, one per position.
    tokens: tick label for each position, in the order of the first axis of ``logits``.
    color_labels / style_labels: display names indexed by the color / style axis.
    color_category_label / style_category_label: legend titles for hue / style.
    ax: Axes to draw on; a new 10x3 figure is created when None.
    ylabel: y-axis label, 'Mean Log Prob' by default (pass e.g. 'Accuracy' when
        plotting accuracies).

    Returns the Axes that was drawn on.
    Raises ValueError if ``logits`` is not 3-D.
    '''
    values = np.asarray(logits)
    if values.ndim != 3:
        raise ValueError(
            f"logits must be 3-D [position, color, style], got shape {values.shape}")
    prompt_position_labels = tokens

    # Long-form table: one row per (position, color, style) value, as seaborn expects.
    long_data = [
        {
            color_category_label: color_labels[color],
            style_category_label: style_labels[style],
            'logit': value,
            'position': pos,
        }
        for (pos, color, style), value in np.ndenumerate(values)
    ]
    df = pd.DataFrame(long_data)

    if ax is None:
        fig = plt.figure(figsize=(10, 3))
        ax = fig.add_subplot()
    sns.lineplot(ax=ax, data=df, x='position', y='logit', hue=color_category_label, style=style_category_label)
    ax.set_xlabel('Token Position')
    ax.set_ylabel(ylabel)
    ax.set_xticks(range(len(prompt_position_labels)), prompt_position_labels, rotation=0, rotation_mode='anchor', ha='right')
    # Reference markers at the second-to-last position (grey) and position 1 (green);
    # their meaning is not stated here -- TODO document.
    ax.axvline(len(prompt_position_labels) - 2, color='grey')
    ax.axvline(1, color='green')
    return ax

In [None]:
# Position independence under item interventions.
pi_item = torch.load('outputs/pos_ind_item.pth', map_location='cpu')

# Take the position order from the accuracy dict only, so logits, accuracy and tick
# labels are guaranteed to line up (a key missing from log_probs raises KeyError).
positions = list(pi_item['accuracy'].keys())
# Mean over dim -2 (presumably the example axis -- TODO confirm).
logits = [pi_item['log_probs'][k].mean(dim=-2) for k in positions]
accuracy = torch.stack([pi_item['accuracy'][k] for k in positions])
tokens = ['ctrl+({})'.format(str(k)) for k in positions]

plot_kwargs = dict(
    tokens=tokens,
    color_category_label='Objects',
    color_labels=['$O_0$', '$O_1$'],
    style_category_label='Items',
    style_labels=['$I_0$', '$I_1$'],
    ax=None,
)
plot_1d_logits(logits=logits, **plot_kwargs)
# plot_1d_logits labels the y-axis 'Mean Log Prob' by default; relabel the accuracy plot.
acc_ax = plot_1d_logits(logits=accuracy, **plot_kwargs)
acc_ax.set_ylabel('Accuracy');

In [None]:
# Position independence under color interventions.
pi_color = torch.load('outputs/pos_ind_color.pth', map_location='cpu')

# Take the position order from the accuracy dict only, so logits, accuracy and tick
# labels are guaranteed to line up (a key missing from log_probs raises KeyError).
positions = list(pi_color['accuracy'].keys())
# Mean over dim -2 (presumably the example axis -- TODO confirm).
logits = [pi_color['log_probs'][k].mean(dim=-2) for k in positions]
accuracy = torch.stack([pi_color['accuracy'][k] for k in positions])
tokens = ['ctrl+({})'.format(str(k)) for k in positions]

plot_kwargs = dict(
    tokens=tokens,
    color_category_label='Objects',
    color_labels=['$O_0$', '$O_1$'],
    style_category_label='Items',
    style_labels=['$I_0$', '$I_1$'],
    ax=None,
)
plot_1d_logits(logits=logits, **plot_kwargs)
# plot_1d_logits labels the y-axis 'Mean Log Prob' by default; relabel the accuracy plot.
acc_ax = plot_1d_logits(logits=accuracy, **plot_kwargs)
acc_ax.set_ylabel('Accuracy');

In [None]:
# Position independence under shape interventions.
pi_shape = torch.load('outputs/pos_ind_shape.pth', map_location='cpu')

# Take the position order from the accuracy dict only, so logits, accuracy and tick
# labels are guaranteed to line up (a key missing from log_probs raises KeyError).
positions = list(pi_shape['accuracy'].keys())
# Mean over dim -2 (presumably the example axis -- TODO confirm).
logits = [pi_shape['log_probs'][k].mean(dim=-2) for k in positions]
accuracy = torch.stack([pi_shape['accuracy'][k] for k in positions])
tokens = ['ctrl+({})'.format(str(k)) for k in positions]

plot_kwargs = dict(
    tokens=tokens,
    color_category_label='Objects',
    color_labels=['$O_0$', '$O_1$'],
    style_category_label='Items',
    style_labels=['$I_0$', '$I_1$'],
    ax=None,
)
plot_1d_logits(logits=logits, **plot_kwargs)
# plot_1d_logits labels the y-axis 'Mean Log Prob' by default; relabel the accuracy plot.
acc_ax = plot_1d_logits(logits=accuracy, **plot_kwargs)
acc_ax.set_ylabel('Accuracy');

# Mean Ablations

## Mean Difference of binding vectors

In [None]:
# Mean ablation results loaded from outputs/mean_ab_item.pth.
res_dict = torch.load('outputs/mean_ab_item.pth', map_location='cpu')
# Mean over dim -2 (presumably the example axis -- TODO confirm).
lp = res_dict['log_probs'].mean(-2)

plot_names = ['All',]
highlight = [[(1, 1), (0, 0)],
             ]
# Accuracy gets an explicit 0-100 color scale (assumes the same percentage scale the
# factorizability plots use). plot_masking's default vmin=-5/vmax=0 is a log-prob range,
# under which every non-negative accuracy saturates to the top color.
_ = plot_masking([res_dict['accuracy']], plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_masking([lp], plot_names, highlight=highlight)

In [None]:
# Mean ablation results loaded from outputs/mean_ab_shape.pth.
res_dict = torch.load('outputs/mean_ab_shape.pth', map_location='cpu')
# Mean over dim -2 (presumably the example axis -- TODO confirm).
lp = res_dict['log_probs'].mean(-2)

plot_names = ['All',]
highlight = [[(1, 1), (0, 0)],
             ]
# Accuracy gets an explicit 0-100 color scale (assumes the same percentage scale the
# factorizability plots use). plot_masking's default vmin=-5/vmax=0 is a log-prob range,
# under which every non-negative accuracy saturates to the top color.
_ = plot_masking([res_dict['accuracy']], plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_masking([lp], plot_names, highlight=highlight)

In [None]:
# Mean ablation results loaded from outputs/mean_ab_shape_item.pth.
res_dict = torch.load('outputs/mean_ab_shape_item.pth', map_location='cpu')
# Mean over dim -2 (presumably the example axis -- TODO confirm).
lp = res_dict['log_probs'].mean(-2)

plot_names = ['All',]
highlight = [[(1, 1), (0, 0)],
             ]
# Accuracy gets an explicit 0-100 color scale (assumes the same percentage scale the
# factorizability plots use). plot_masking's default vmin=-5/vmax=0 is a log-prob range,
# under which every non-negative accuracy saturates to the top color.
_ = plot_masking([res_dict['accuracy']], plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_masking([lp], plot_names, highlight=highlight)

## Random vectors

In [None]:
# Random-vector ablation results loaded from outputs/random_mean_ab_item.pth.
res_dict = torch.load('outputs/random_mean_ab_item.pth', map_location='cpu')
# Mean over dim -2 (presumably the example axis -- TODO confirm).
lp = res_dict['log_probs'].mean(-2)

plot_names = ['All',]
highlight = [[(1, 1), (0, 0)],
             ]
# Accuracy gets an explicit 0-100 color scale (assumes the same percentage scale the
# factorizability plots use). plot_masking's default vmin=-5/vmax=0 is a log-prob range,
# under which every non-negative accuracy saturates to the top color.
_ = plot_masking([res_dict['accuracy']], plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_masking([lp], plot_names, highlight=highlight)

In [None]:
# Random-vector ablation results loaded from outputs/random_mean_ab_shape.pth.
res_dict = torch.load('outputs/random_mean_ab_shape.pth', map_location='cpu')
# Mean over dim -2 (presumably the example axis -- TODO confirm).
lp = res_dict['log_probs'].mean(-2)

plot_names = ['All',]
highlight = [[(1, 1), (0, 0)],
             ]
# Accuracy gets an explicit 0-100 color scale (assumes the same percentage scale the
# factorizability plots use). plot_masking's default vmin=-5/vmax=0 is a log-prob range,
# under which every non-negative accuracy saturates to the top color.
_ = plot_masking([res_dict['accuracy']], plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_masking([lp], plot_names, highlight=highlight)

In [None]:
# Random-vector ablation results loaded from outputs/random_mean_ab_shape_item.pth.
res_dict = torch.load('outputs/random_mean_ab_shape_item.pth', map_location='cpu')
# Mean over dim -2 (presumably the example axis -- TODO confirm).
lp = res_dict['log_probs'].mean(-2)

plot_names = ['All',]
highlight = [[(1, 1), (0, 0)],
             ]
# Accuracy gets an explicit 0-100 color scale (assumes the same percentage scale the
# factorizability plots use). plot_masking's default vmin=-5/vmax=0 is a log-prob range,
# under which every non-negative accuracy saturates to the top color.
_ = plot_masking([res_dict['accuracy']], plot_names, highlight=highlight, vmin=0, vmax=100)
_ = plot_masking([lp], plot_names, highlight=highlight)