In [None]:
"""
Project pre-trained role probes onto a repeated-content lexicon and plot per-token role probabilities
"""
None

In [None]:
"""
Imports
"""
import torch
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
import importlib
import cupy
import yaml

from utils.memory import check_memory, clear_all_cuda_memory
from utils.loader import load_model_and_tokenizer, load_custom_forward_pass
from utils.probes import run_and_export_states, run_projections, check_max_seq_len

# Device string passed to the model loader
main_device = 'cuda:0'

# Apply the seed to the numpy and torch global RNGs so a fresh-kernel rerun starts from the same RNG state.
# Previously `seed` was declared but never handed to any generator.
seed = 123
np.random.seed(seed)
torch.manual_seed(seed)

# Workspace root; the probe and conversation paths used later are built from it
ws = '/workspace/deliberative-alignment-jailbreaks'

# utils.memory helpers; presumably these free cached CUDA memory and then report usage (see utils.memory)
clear_all_cuda_memory()
check_memory()

# Load model

In [None]:
"""
Load the base tokenizer/model
"""
model_prefix = 'gptoss-20b'

# The loader's return value is unpacked, in order, into the four names below
(
    tokenizer,
    model,
    model_architecture,
    model_n_layers,
) = load_model_and_tokenizer(model_prefix, device = main_device)

# Memory report after the model has been loaded
check_memory()

In [None]:
"""
Load custom forward pass and verify equality to base model forward pass
"""
# `run_forward_with_hs` is later passed to `run_and_export_states` as its `run_model_return_states` argument.
# NOTE(review): nothing in this cell checks equality with the base forward pass; presumably
# `load_custom_forward_pass` performs that check internally — confirm against utils.loader.
run_forward_with_hs = load_custom_forward_pass(model_architecture, model, tokenizer)

# Load probes

In [None]:
TEST_MODEL = 'gptoss-20b'
TEST_LAYER_IX = 12
TEST_ROLE_SPACE = ['system', 'user', 'cot', 'assistant']

# NOTE(security): unpickling can execute arbitrary code, so only load probe files produced by this project.
with open(f'{ws}/experiments/role-analysis/outputs/probes/{TEST_MODEL}.pkl', 'rb') as f:
    probes = pickle.load(f)

# Select the probe for the requested layer and role space. Fail with an explicit message
# instead of a bare IndexError when the pickle holds no such probe.
matching_probes = [p for p in probes if p['layer_ix'] == TEST_LAYER_IX and p['role_space'] == TEST_ROLE_SPACE]
if not matching_probes:
    raise ValueError(
        f'No probe with layer_ix = {TEST_LAYER_IX} and role_space = {TEST_ROLE_SPACE} found in {TEST_MODEL}.pkl'
    )

probe = matching_probes[0]
probe

# Create dataset and collect states

In [None]:
"""
Define a lexicon
"""
# Number of conversations to draw from, and how many times each text is repeated within one prompt
MAX_ROWS = 25
N_REPEATS = 20

# Read the per-model test prefix; the context manager closes the config file handle deterministically
with open('./../role-analysis/config/probe.yaml') as config_file:
    test_prefix = yaml.safe_load(config_file)[model_prefix]['test_prefix']

raw_convs = pd.read_csv(f'{ws}/experiments/role-analysis/data/conversations/{TEST_MODEL}.csv')

# One lexicon row per (conversation, role): the prefix followed by that role's text repeated N_REPEATS times
lexicon_df = pd.DataFrame([
    {'role': role, 'content': test_prefix + ' ' + ' '.join([row[col]] * N_REPEATS)}
    for row in raw_convs.head(MAX_ROWS).to_dict('records')
    for role, col in [('user', 'user_query'), ('cot', 'cot'), ('assistant', 'assistant')]
])

lexicon_df

In [None]:
"""
Prep
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

MAX_INPUT_LEN = 768

# Each lexicon row becomes one prompt; prompt_ix is the row's positional index in lexicon_df
lexicon_ds = ReconstructableTextDataset(
    lexicon_df['content'].tolist(),
    tokenizer,
    max_length = MAX_INPUT_LEN,
    prompt_ix = list(range(len(lexicon_df)))
)
input_dl = DataLoader(lexicon_ds, batch_size = 8, shuffle = False, collate_fn = stack_collate)

# Name the layer list once so the export call and the re-keying below cannot drift apart
layers_to_keep = [TEST_LAYER_IX]
run_result = run_and_export_states(
    model,
    tokenizer,
    run_model_return_states = run_forward_with_hs,
    dl = input_dl,
    layers_to_keep_acts = layers_to_keep
)

# Axis 1 of `all_hs` is indexed by position within the kept-layers list; re-key those slices by layer index
all_hs_fp16 = run_result['all_hs'].to(torch.float16)
hs = {layer_ix: all_hs_fp16[:, save_ix, :] for save_ix, layer_ix in enumerate(layers_to_keep)}

# sample_ix is the positional row number in sample_df
sample_df = run_result['sample_df'].assign(sample_ix = lambda df: range(len(df)))
gc.collect()

In [None]:
"""
Role space projections
"""
def run_projections(valid_sample_df, layer_hs, probe):
    """
    Run probe-level projections: score the selected rows of `layer_hs` against every role in the probe.

    NOTE: this definition shadows the `run_projections` imported from `utils.probes` in the imports cell.

    Params:
        @valid_sample_df: A sample-level df with column `sample_ix`, used as a positional row index into `layer_hs`.
            Can be shorter than `layer_hs` due to pre-filters, as long as sample_ix corresponds to the full length.
        @layer_hs: A tensor of size T x D for the layer to project.
        @probe: The probe dict with keys `probe` (the trained model, which must expose `predict_proba`) and
            `role_space` (the roles list, used as the column names of the `predict_proba` output)

    Returns:
        A df at (sample_ix, target_role) level with cols `sample_ix`, `target_role`, `prob` (rounded to 4 decimals)

    Raises:
        ValueError: If the probe returns a different number of rows than `valid_sample_df` contains.
    """
    sample_ixs = valid_sample_df['sample_ix'].tolist()
    x_cp = cupy.asarray(layer_hs[sample_ixs, :])
    y_cp = probe['probe'].predict_proba(x_cp).round(12)

    proj_results = pd.DataFrame(cupy.asnumpy(y_cp), columns = probe['role_space'])
    if len(proj_results) != len(valid_sample_df):
        raise ValueError(
            f'Probe returned {len(proj_results)} rows but {len(valid_sample_df)} samples were requested'
        )

    # Both frames are re-indexed 0..n-1 so the column-wise concat aligns rows by position, not by original index
    role_df =\
        pd.concat([
            proj_results.reset_index(drop = True),
            valid_sample_df[['sample_ix']].reset_index(drop = True)
        ], axis = 1)\
        .melt(id_vars = ['sample_ix'], var_name = 'target_role', value_name = 'prob')\
        .reset_index(drop = True)\
        .assign(prob = lambda df: df['prob'].round(4))

    return role_df

# Long-format probe output: one row per (sample_ix, target_role)
token_role_probs = run_projections(valid_sample_df = sample_df, layer_hs = hs[TEST_LAYER_IX], probe = probe)

# Token metadata plus each token's 0-based position within its own prompt
token_meta = sample_df[['prompt_ix', 'sample_ix', 'token']]\
    .assign(token_in_prompt_ix = lambda df: df.groupby(['prompt_ix']).cumcount())

# Lexicon rows keyed by positional index, matching the prompt_ix values given to the dataset
lexicon_keyed = lexicon_df.assign(prompt_ix = lambda df: range(0, len(df)))


def _ewm_mean(probs):
    """Exponentially weighted running mean (alpha = 0.25) over one group's probabilities."""
    return probs.ewm(alpha = 0.25).mean()


output_projections = token_role_probs\
    .merge(token_meta, how = 'inner', on = ['sample_ix'])\
    .merge(lexicon_keyed, how = 'inner', on = 'prompt_ix')\
    .assign(rollprob = lambda df: df.groupby(['prompt_ix', 'target_role'])['prob'].transform(_ewm_mean))

output_projections


In [None]:
target_cols = output_projections['target_role'].unique().tolist()

# Mean smoothed probability per token position, probed role, and source role
group_keys = ['token_in_prompt_ix', 'target_role', 'role']
output_agg = output_projections\
    .groupby(group_keys, as_index = False)\
    .agg(mean_rollprob = ('rollprob', lambda x: x.mean()))

# Wide view with one column per probed role
role_wide = output_agg\
    .pivot(index = ['role', 'token_in_prompt_ix'], columns = 'target_role', values = 'mean_rollprob')\
    .reset_index()

# Show only token positions at or beyond MAX_INPUT_LEN - 10
role_wide[role_wide['token_in_prompt_ix'] >= MAX_INPUT_LEN - 10]

# .to_csv('lexical_role_projections.csv', index = False)

In [None]:
import plotly.express as px

# One facet per source role, one coloured line per probed role
fig = px.line(
    output_agg,
    x = 'token_in_prompt_ix',
    y = 'mean_rollprob',
    color = 'target_role',
    facet_col = 'role',
    title = 'role'
)

fig.show()