In [None]:
# Notebook parameters: every value below is read by a later cell.
# NOTE(review): looks like a papermill-style parameters cell — confirm.

# Model selection, forwarded to models.fetch_model(MODEL_TAG, num_devices=..., hf=...).
MODEL_TAG = "tulu-2-13b"
NUM_DEVICES = 4
IS_HF = True

# template_tag is resolved through ta.get_template(); chat_style is passed to
# ess.eval_subspace and to the cross-attention plot's template_content.
template_tag = 'NameCountryTemplate'
chat_style = 'tulu_chat'

# form_path is read with torch.load unless form_type == 'random';
# form_type is also handed to coref.form_processors.process_form.
form_path = ''
form_type = 'hessian_1_1'

# Directory that receives the saved metrics / data / plot files
# (wrapped in a pathlib.Path by a later cell).
output_dir = ''

In [None]:
from importlib import reload
import os
from pathlib import Path
import json
from coref import COREF_ROOT
from coref.utils import cache
from functools import partial
import itertools

import torch

import coref.models as models
import coref.experiments.triplet_expts as te
import coref.datasets.templates.triplet as tt

In [None]:
import seaborn as sns

In [None]:
from functools import partial
import numpy as np
import torch
import einops

import coref.datasets.templates.simple as ts
import coref.datasets.templates.common as tc
import coref.datasets.templates.triplet as tt
import coref.datasets.api as ta
import coref.parameters as p
import coref.datascience as ds
import coref.expt_utils as eu
import coref.injection_interventions as ii
import coref.eval_subspace as ess
import coref.form_processors

In [None]:
# Normalise the configured output location to a Path and make sure it exists,
# so the save cells that follow each (expensive) evaluation cannot fail merely
# because the directory is missing. An empty string resolves to the current
# directory, for which mkdir(exist_ok=True) is a no-op.
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Context used by the serial cross-attention plots:
# tc.Statement(0, 0) followed by tc.Statement(1, 1).
source_context = [tc.Statement(idx, idx) for idx in range(2)]

In [None]:

# Look up the template class by tag and instantiate it with 'llama', once for
# the train template and once for the test template (two independent lookups
# and instances, in that order).
train_template_cls = ta.get_template(template_tag)
train_template = train_template_cls('llama')
test_template_cls = ta.get_template(template_tag)
test_template = test_template_cls('llama')


In [None]:

# Fetch the model named by MODEL_TAG in bfloat16, using the device count and
# HF flag from the parameters cell.
fetch_kwargs = dict(
    num_devices=NUM_DEVICES,
    dtype=torch.bfloat16,
    hf=IS_HF,
)
model = models.fetch_model(MODEL_TAG, **fetch_kwargs)

In [None]:
# Build `form`: either a random form sized by model.cfg.d_model, or a form
# loaded from form_path and post-processed according to form_type.
# NOTE(review): torch.load unpickles its input — only point form_path at trusted files.
if form_type != 'random':
    loaded_form = torch.load(form_path)
    form = coref.form_processors.process_form(loaded_form, form_type)
else:
    form = coref.form_processors.random_form(model.cfg.d_model)

In [None]:
# Pre-bind the arguments shared by every subspace evaluation so the sweep
# cells only have to pass the subspace and the swap kind.
eval_subspace = partial(
    ess.eval_subspace,
    model=model,
    chat_style=chat_style,
    test_template=test_template,
    verbose=True,
)

In [None]:
# Plot the leading entries of form.S (presumably the form's singular values —
# confirm) on a log-scaled y axis.
# The x positions are derived from the values actually plotted, so a form whose
# S has fewer than `max_plotted` entries no longer yields mismatched x/y lengths.
max_plotted = 1000
leading_s = form.S[:max_plotted].cpu().numpy()
ax = sns.lineplot(x=np.arange(len(leading_s)), y=leading_s)
ax.set(yscale='log')

In [None]:
# Subspace sizes swept by the evaluation loops: each value is the number of
# leading columns kept from form.U / form.Vh.T.
# NOTE(review): 5120 presumably equals the full hidden size of the default
# model — confirm, and update it if MODEL_TAG changes.
all_dims = [1, 3, 15, 50, 250, 1000, 5120]

## Name Swaps

In [None]:
# Sweep the subspace sizes for the 'name' evaluation, tagging every returned
# metric / data record with the size that produced it.
all_metrics, all_data = [], []
for dim in all_dims:
    print(f'### Evaluating dimensions {dim}')
    # First `dim` columns (last axis) of form.U, cast to bfloat16.
    name_subspace = form.U[..., :dim].to(torch.bfloat16)
    metrics, data = eval_subspace(name_subspace, 'name')
    all_metrics += [dict(record, dim=dim) for record in metrics]
    all_data += [dict(record, dim=dim) for record in data]

In [None]:
# Persist the 'name' sweep: metrics as JSON, raw data via torch.save.
name_metrics_path = output_dir / 'name_metrics.json'
with name_metrics_path.open('w') as metrics_file:
    json.dump(all_metrics, metrics_file)

name_data_path = output_dir / 'name_data.pt'
torch.save(all_data, name_data_path)

## Attr Swaps

In [None]:
# Sweep the subspace sizes for the 'attr' evaluation, tagging every returned
# metric / data record with the size that produced it.
all_metrics, all_data = [], []
for dim in all_dims:
    print(f'### Evaluating dimensions {dim}')
    # First `dim` columns of form.Vh.T, cast to bfloat16.
    attr_subspace = form.Vh.T[:, :dim].to(torch.bfloat16)
    metrics, data = eval_subspace(attr_subspace, 'attr')
    all_metrics += [dict(record, dim=dim) for record in metrics]
    all_data += [dict(record, dim=dim) for record in data]

In [None]:
# Persist the 'attr' sweep: metrics as JSON, raw data via torch.save.
attr_metrics_path = output_dir / 'attr_metrics.json'
with attr_metrics_path.open('w') as metrics_file:
    json.dump(all_metrics, metrics_file)

attr_data_path = output_dir / 'attr_data.pt'
torch.save(all_data, attr_data_path)

## Qualitative

In [None]:
import coref.plotting as cplot

In [None]:
# Prepare the shared inputs for the serial cross-attention plots: the test
# template queried with name 0, over the two-statement source_context.
serial_template_content = dict(query_name=0, chat_style=chat_style)
serial_prep_kwargs = dict(
    model=model,
    template=test_template,
    template_content=serial_template_content,
    content_context=source_context,
)
cross_attn_ctxt = cplot.prep_plot_cross_attention(**serial_prep_kwargs)

In [None]:
# Serial context, plotted with the full form (form.form).
full_form_args = dict(**cross_attn_ctxt, form=form.form)
plot_data = cplot.plot_cross_attention(full_form_args)

In [None]:
# Save the full-form serial plot data.
serial_full_path = output_dir / 'serial_full.pt'
torch.save(plot_data, serial_full_path)

In [None]:
# Serial context, with the form replaced by U50 @ U50.T, where U50 is the
# first 50 columns of form.U.
u_projector = form.U[:, :50] @ form.U[:, :50].T
u_projector_args = dict(**cross_attn_ctxt, form=u_projector, plot_style='rocket')
plot_data = cplot.plot_cross_attention(u_projector_args)

In [None]:
# Serial context, with the form replaced by (U50 * S50**2) @ U50.T.
# `*` and `@` share precedence and evaluate left to right; the parentheses
# only make that order explicit.
u_weighted = (form.U[:, :50] * form.S[:50].pow(2)) @ form.U[:, :50].T
u_weighted_args = dict(**cross_attn_ctxt, form=u_weighted, plot_style='rocket')
plot_data = cplot.plot_cross_attention(u_weighted_args)

In [None]:
# Save the most recently computed serial plot data under 'serial_u2.pt'.
serial_u2_path = output_dir / 'serial_u2.pt'
torch.save(plot_data, serial_u2_path)

In [None]:
# Serial context, with the form replaced by V50 @ V50.T, where V50 is the
# first 50 columns of form.Vh.T.
v_projector = form.Vh.T[:, :50] @ form.Vh.T[:, :50].T
v_projector_args = dict(**cross_attn_ctxt, form=v_projector, plot_style='rocket')
plot_data = cplot.plot_cross_attention(v_projector_args)

In [None]:
# Serial context, with the form replaced by (V50 * S50**2) @ V50.T.
# `*` and `@` share precedence and evaluate left to right; the parentheses
# only make that order explicit.
v_weighted = (form.Vh.T[:, :50] * form.S[:50].pow(2)) @ form.Vh.T[:, :50].T
v_weighted_args = dict(**cross_attn_ctxt, form=v_weighted, plot_style='rocket')
plot_data = cplot.plot_cross_attention(v_weighted_args)

In [None]:
# Serial context, with the form rebuilt from the first 50 columns of U, the
# first 50 entries of S and the first 50 columns of Vh.T (transposed back).
# Evaluation order is (U50 * S50) @ V50.T, as in plain left-to-right parsing.
truncated_form = (form.U[:, :50] * form.S[:50]) @ form.Vh.T[:, :50].T
truncated_args = dict(**cross_attn_ctxt, form=truncated_form)
plot_data = cplot.plot_cross_attention(truncated_args)

In [None]:
# Save the most recently computed serial plot data under 'serial_50.pt'.
serial_50_path = output_dir / 'serial_50.pt'
torch.save(plot_data, serial_50_path)

In [None]:
# Serial context, with the form replaced by U50 @ V50.T (no S weighting).
# The result is not stored; the call stays last so the cell still displays it.
unweighted_form = form.U[:, :50] @ form.Vh.T[:, :50].T
cplot.plot_cross_attention(dict(**cross_attn_ctxt, form=unweighted_form))

### Parallel

In [None]:
# Prepare the shared inputs for the parallel cross-attention plots: same
# template and query as the serial case, but a single 'parallel' Statement.
parallel_context = [tc.Statement([0,1], [0,1], 'parallel')]
parallel_prep_kwargs = dict(
    model=model,
    template=test_template,
    template_content=dict(query_name=0, chat_style=chat_style),
    content_context=parallel_context,
)
cross_attn_ctxt = cplot.prep_plot_cross_attention(**parallel_prep_kwargs)

In [None]:
# Parallel context, plotted with the full form (form.form).
parallel_full_args = dict(**cross_attn_ctxt, form=form.form)
plot_data = cplot.plot_cross_attention(parallel_full_args)

In [None]:
# Save the full-form parallel plot data.
parallel_full_path = output_dir / 'parallel_full.pt'
torch.save(plot_data, parallel_full_path)

In [None]:
# Parallel context, with the form replaced by (U50 * S50**2) @ U50.T.
# `*` and `@` share precedence and evaluate left to right; the parentheses
# only make that order explicit.
parallel_u_weighted = (form.U[:, :50] * form.S[:50].pow(2)) @ form.U[:, :50].T
parallel_u_args = dict(**cross_attn_ctxt, form=parallel_u_weighted, plot_style='rocket')
plot_data = cplot.plot_cross_attention(parallel_u_args)

In [None]:
# Save the most recently computed parallel plot data under 'parallel_u2.pt'.
parallel_u2_path = output_dir / 'parallel_u2.pt'
torch.save(plot_data, parallel_u2_path)

In [None]:
# Parallel context, with the form rebuilt from the first 50 columns of U, the
# first 50 entries of S and the first 50 columns of Vh.T (transposed back).
# Evaluation order is (U50 * S50) @ V50.T, as in plain left-to-right parsing.
parallel_truncated = (form.U[:, :50] * form.S[:50]) @ form.Vh.T[:, :50].T
parallel_truncated_args = dict(**cross_attn_ctxt, form=parallel_truncated)
plot_data = cplot.plot_cross_attention(parallel_truncated_args)

In [None]:
# Save the most recently computed parallel plot data under 'parallel_50.pt'.
parallel_50_path = output_dir / 'parallel_50.pt'
torch.save(plot_data, parallel_50_path)