In [None]:
"""
Stores activations for text samples created by export-jailbreak-generations.ipynb
"""
None

In [None]:
"""
Imports
"""
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.loss.loss_utils import ForCausalLMLoss
from datasets import load_dataset, Dataset
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import gc
import pickle
from termcolor import colored
import importlib 

from utils.memory import check_memory, clear_all_cuda_memory
from utils.loader import load_model_and_tokenizer, load_custom_forward_pass
from utils.store_topk import convert_topk_to_df
from utils.store_outputs import convert_outputs_to_df_fast
from utils.pretrained_models import gptoss

main_device = 'cuda:0'
seed = 123

# Apply the seed to the RNGs of the libraries imported above; without this the
# `seed` constant is never consumed by any cell of this notebook.
torch.manual_seed(seed)
np.random.seed(seed)

# Start from a clean GPU state and report memory before anything is loaded.
clear_all_cuda_memory()
check_memory()

# Workspace root; every input/output path in this notebook is built from it.
ws = '/workspace/deliberative-alignment-jailbreaks'

# Load base model

In [None]:
"""
Load the base tokenizer/model
"""
# Which checkpoint to load; supported values: gptoss-20b, gptoss-120b
model_prefix = 'gptoss-20b'

# The loader hands back the tokenizer, the model, an architecture identifier
# (used to pick the custom forward pass) and the number of layers.
(
    tokenizer,
    model,
    model_architecture,
    model_n_layers,
) = load_model_and_tokenizer(model_prefix, device = main_device)

# Report memory usage now that the model has been loaded.
check_memory()

In [None]:
"""
Load custom forward pass and verify equality to base model forward pass
"""
# Build the forward-pass callable for this architecture. It is invoked later as
# run_forward_with_hs(model, input_ids, attention_mask, return_hidden_states = True)
# and its output is read like a dict (logits, top-k experts/weights, router logits, hidden states).
# NOTE(review): the cell description mentions verifying equality with the base forward pass -
# presumably that check happens inside load_custom_forward_pass; confirm.
run_forward_with_hs = load_custom_forward_pass(model_architecture, model, tokenizer)

## Get dataset

In [None]:
"""
Load dataset
"""
# Read the classified generations; the `redteam_output_full` and `redteam_prompt_ix`
# columns of this frame are consumed by the cells below.
gens_df = pd.read_csv(
    f'{ws}/experiments/user-injections/base-harmful-responses-classified.csv'
)

gens_df

In [None]:
"""
Test max tokenization length
"""
# Upper bound on the tokenized length; anything longer is truncated by the tokenizer.
max_tokenization_length = 6_000

# Pad only up to the longest row in the batch instead of always padding to the cap:
# the attention-mask sums (the only thing used below) are identical, but the
# intermediate tensors are no larger than they need to be.
inputs_test = tokenizer(
    gens_df['redteam_output_full'].tolist(),
    add_special_tokens = False,
    max_length = max_tokenization_length,
    padding = 'longest',
    truncation = True,
    return_offsets_mapping = True,
    return_tensors = 'pt'
)

# Number of non-padded tokens of the longest sample (capped at max_tokenization_length).
max_input_length = inputs_test['attention_mask'].sum(dim = 1).max().item()
print(f"Max input length: {max_input_length}")

# A value equal to the cap means at least one sample hit the truncation limit, so the
# reported maximum may understate the true length - make that visible instead of silent.
if max_input_length >= max_tokenization_length:
    print(
        f"WARNING: at least one sample reached the {max_tokenization_length}-token cap and may have been truncated; "
        "consider raising max_tokenization_length."
    )

In [None]:
""" 
Load dataset into a dataloader. The dataloader returns the original tokens - this is important for BPE tokenizers as otherwise it's difficult to reconstruct the correct string later!
"""
from utils.dataset import ReconstructableTextDataset, stack_collate
from torch.utils.data import DataLoader

# Create and chunk into lists of size 250 each - these will be the export breaks
def chunk_list(input_list, max_length):
    """
    Split a list into consecutive chunks of at most `max_length` items.

    Params:
        @input_list: The list (or any sliceable sequence with a length) to split.
        @max_length: The maximum number of items per chunk; must be >= 1.

    Returns:
        A list of slices of `input_list`, in order. Every chunk has `max_length` items except possibly the last one.
        An empty input yields an empty list.

    Raises:
        ValueError: If `max_length` is smaller than 1. A negative value would otherwise silently return no chunks
          at all (dropping every item), and zero would fail inside `range` with a less helpful message.
    """
    if max_length < 1:
        raise ValueError(f"max_length must be a positive integer, got {max_length}")

    return [input_list[i:i + max_length] for i in range(0, len(input_list), max_length)]

def _build_chunk_dataloader(data_chunk):
    """
    Wrap one chunk of `gens_df` records into a non-shuffled DataLoader (batch size 8).
    """
    texts = [record['redteam_output_full'] for record in data_chunk]
    prompt_ixs = [record['redteam_prompt_ix'] for record in data_chunk]

    chunk_dataset = ReconstructableTextDataset(
        texts,
        tokenizer,
        max_length = max_input_length,
        redteam_prompt_ix = prompt_ixs
    )

    return DataLoader(
        chunk_dataset,
        batch_size = 8,
        shuffle = False,
        collate_fn = stack_collate
    )

# One dataloader per chunk of (at most) 250 records; each dataloader is later exported to its own directory.
test_dls = [
    _build_chunk_dataloader(data_chunk)
    for data_chunk in tqdm(chunk_list(gens_df.to_dict('records'), 250))
]

## Get expert selections + export

In [None]:
""" 
Run forward passes + export data
"""
def _save_dl_outputs(dl_dir: str, sample_dfs: list, topk_dfs: list, all_router_logits: list, all_pre_mlp_hidden_states: list):
    """
    Concatenate the per-batch results collected for one dataloader and write them into `dl_dir`.

    Writes `samples.pkl`, `topks.pkl`, `all-router-logits.pt` and `all-pre-mlp-hidden-states.pt`.
    If no batch was collected, nothing is written (concatenating empty lists would raise).
    """
    if len(sample_dfs) == 0:
        print(f"No batches were processed for {dl_dir}, nothing to save.")
        return

    pd.concat(sample_dfs, ignore_index = True).to_pickle(f'{dl_dir}/samples.pkl')
    pd.concat(topk_dfs, ignore_index = True).to_pickle(f'{dl_dir}/topks.pkl')
    torch.save(torch.cat(all_router_logits, dim = 0), f'{dl_dir}/all-router-logits.pt')
    torch.save(torch.cat(all_pre_mlp_hidden_states, dim = 0), f'{dl_dir}/all-pre-mlp-hidden-states.pt')


@torch.no_grad()
def run_and_export_topk(model, model_prefix: str, dls: list[DataLoader], layers_to_keep_acts: list[int], max_batches: None | int = None):
    """
    Run forward passes over every dataloader and export token-level outputs, expert selections, router logits, and
    pre-MLP hidden states to disk (one sub-directory per dataloader).

    Relies on the notebook-level names `ws`, `main_device`, `tokenizer`, and `run_forward_with_hs`.

    Params:
        @model: The model to run forward passes on. It is passed to `run_forward_with_hs`, whose output must expose the keys
          `logits`, `all_topk_experts`, `all_topk_weights`, `all_router_logits`, and `all_pre_mlp_hidden_states`.
        @model_prefix: The model prefix to save the activations under.
        @dls: A list of dataloaders whose batches provide `input_ids`, `attention_mask`, `original_tokens`, and `redteam_prompt_ix`.
        @layers_to_keep_acts: A list of layer indices (0-indexed positions in the per-layer lists returned by the forward pass)
          to retain in both the exported router logits and the exported pre-MLP hidden states.
        @max_batches: The max number of batches to run, counted across all dataloaders. The check happens after a batch has been
          processed, so at least one batch is always run. `None` runs everything.

    Returns:
        `True` once the export has finished (either all dataloaders were processed or `max_batches` was reached).
        The actual results are written under `{ws}/experiments/role-injection-analysis/activations-redteam/{model_prefix}`:
        - `metadata.pkl`: A dict with key `all_pre_mlp_hidden_states_layers` holding `layers_to_keep_acts`.
        - `{dl_ix:02d}/samples.pkl`: A sample (token)-level dataframe, merged with the original token text and `redteam_prompt_ix`.
        - `{dl_ix:02d}/topks.pkl`: The expert-selection dataframe produced by `convert_topk_to_df` (without `token_id`).
        - `{dl_ix:02d}/all-router-logits.pt`: Router logits of the non-masked positions, restricted to `layers_to_keep_acts`.
        - `{dl_ix:02d}/all-pre-mlp-hidden-states.pt`: Pre-MLP hidden states of the non-masked positions, restricted to `layers_to_keep_acts`.
    """
    cross_dl_batch_ix = 0
    output_dir = f'{ws}/experiments/role-injection-analysis/activations-redteam/{model_prefix}'
    os.makedirs(output_dir, exist_ok = True)

    # Save metadata
    with open(f'{output_dir}/metadata.pkl', 'wb') as f:
        pickle.dump(
            {'all_pre_mlp_hidden_states_layers': layers_to_keep_acts},
            f
        )

    # Iterate through dataloaders
    for dl_ix, dl in enumerate(dls):
        print(f"Processing {dl_ix} of {len(dls)}...")
        dl_dir = f"{output_dir}/{dl_ix:02d}"
        os.makedirs(dl_dir, exist_ok = True)

        # NOTE(review): these lists keep every batch's tensors alive until the dataloader is finished; if the forward
        # pass returns them on the GPU this accumulates device memory - confirm where run_forward_with_hs places them.
        all_router_logits = []
        all_pre_mlp_hidden_states = []
        sample_dfs = []
        topk_dfs = []
        reached_max_batches = False

        for batch in tqdm(dl, total = len(dl)):

            input_ids = batch['input_ids'].to(main_device)
            attention_mask = batch['attention_mask'].to(main_device)
            original_tokens = batch['original_tokens']
            redteam_prompt_ixs = batch['redteam_prompt_ix']

            output = run_forward_with_hs(model, input_ids, attention_mask, return_hidden_states = True)

            # Check no bugs by validating output/perplexity (first batch only)
            if cross_dl_batch_ix == 0:
                # Replace pad tokens by the ignore index; masked_fill keeps the labels on the same device as input_ids
                labels = input_ids.masked_fill(input_ids == tokenizer.pad_token_id, -100)
                loss = ForCausalLMLoss(output['logits'], labels, model.config.vocab_size).detach().cpu().item()
                for i in range(min(20, input_ids.size(0))):
                    decoded_input = tokenizer.decode(input_ids[i, :], skip_special_tokens = False)
                    next_token_id = torch.argmax(output['logits'][i, -1, :]).item()
                    print('---------\n' + decoded_input + colored(tokenizer.decode([next_token_id], skip_special_tokens = False).replace('\n', '<lb>'), 'green'))
                print("PPL:", torch.exp(torch.tensor(loss)).item())

            # (sequence_ix, token_ix) -> original token text, used to attach the untouched token strings
            original_tokens_df = pd.DataFrame(
                [(seq_i, tok_i, tok) for seq_i, tokens in enumerate(original_tokens) for tok_i, tok in enumerate(tokens)],
                columns = ['sequence_ix', 'token_ix', 'token']
            )

            # sequence_ix -> redteam_prompt_ix, used to trace every sequence back to its source prompt
            prompt_indices_df = pd.DataFrame(
                [(seq_i, seq_source) for seq_i, seq_source in enumerate(redteam_prompt_ixs)],
                columns = ['sequence_ix', 'redteam_prompt_ix']
            )

            # Create sample (token) level dataframe
            sample_df =\
                convert_outputs_to_df_fast(input_ids, attention_mask, output['logits'])\
                .merge(original_tokens_df, how = 'left', on = ['token_ix', 'sequence_ix'])\
                .merge(prompt_indices_df, how = 'left', on = ['sequence_ix'])\
                .assign(batch_ix = cross_dl_batch_ix)

            # Create topk x layer_ix x sample level dataframe
            topk_df =\
                convert_topk_to_df(input_ids, attention_mask, output['all_topk_experts'], output['all_topk_weights'])\
                .assign(batch_ix = cross_dl_batch_ix)\
                .drop(columns = 'token_id')

            sample_dfs.append(sample_df)
            topk_dfs.append(topk_df)

            # Store pre-MLP hidden states - the fwd pass as n_layers list as BN x D, collapse to BN x n_layers x D, with BN filtering out masked items
            valid_pos = torch.where(attention_mask.cpu().view(-1) == 1) # Valid (BN, ) positions
            all_router_logits.append(torch.stack(output['all_router_logits'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])
            all_pre_mlp_hidden_states.append(torch.stack(output['all_pre_mlp_hidden_states'], dim = 1)[valid_pos][:, layers_to_keep_acts, :])

            cross_dl_batch_ix += 1
            if max_batches is not None and cross_dl_batch_ix >= max_batches:
                reached_max_batches = True
                break

        # Persist whatever was collected for this dataloader (also on early exit, so partial runs are usable)
        _save_dl_outputs(dl_dir, sample_dfs, topk_dfs, all_router_logits, all_pre_mlp_hidden_states)

        if reached_max_batches:
            return True

    return True

# Retain activations for every 4th layer only: 0, 4, 8, ...
every_fourth_layer = list(range(0, model_n_layers, 4))

# Run the full export (no batch limit) over all chunked dataloaders.
res = run_and_export_topk(
    model,
    model_prefix,
    test_dls,
    layers_to_keep_acts = every_fourth_layer,
    max_batches = None
)

In [None]:
"""
Store an extra copy of the input dataset (gens_df) next to the exported activations
"""
# Written into the same per-model directory that the activation export writes to.
prompts_copy_path = f'{ws}/experiments/role-injection-analysis/activations-redteam/{model_prefix}/base-harmful-responses-classified.csv'
gens_df.to_csv(prompts_copy_path, index = False)