In [1]:
# Enable IPython's autoreload extension. Mode 2 reloads all imported modules
# before each cell executes, so edits to the project's .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path.
# Presumably this is the repository root that holds `neural_controllers` and
# `utils`, imported in later cells -- confirm against the repo layout.
notebook_path = Path().absolute()
project_root = str(notebook_path.parent)
# Guard the append so that re-running this cell does not keep adding duplicate
# entries to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from neural_controllers import NeuralController

# Seed the PyTorch CPU generator, the PyTorch CUDA generator and NumPy's
# global generator with 0, in that order.
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(0)

In [8]:
from utils import politics_dataset, pca_politics_dataset

# Reset the PyTorch (CPU and CUDA) and NumPy generators to seed 0 again.
# NOTE(review): this repeats the seeding done in the import cell; presumably
# kept so this cell can be re-run on its own -- confirm.
_seed = 0
torch.manual_seed(_seed)
torch.cuda.manual_seed(_seed)
np.random.seed(_seed)

In [None]:
# Choose which chat model to steer. Every branch must define:
#   language_model - the causal LM handed to NeuralController
#   tokenizer      - its tokenizer
#   model_name     - short label passed to controller.save()/load()
#   assistant_tag  - assistant-turn marker passed to the dataset builders
model_type = 'llama'

if model_type == 'llama':
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # NOTE(review): no torch_dtype is passed here (the gemma branch uses
    # bfloat16), so the library default precision is used -- confirm intended.
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # Evaluates to False for a LlamaForCausalLM checkpoint, which matches the
    # previously hard-coded use_fast=False; the flag is now actually used
    # instead of being computed and ignored.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=use_fast_tokenizer, padding_side="left", legacy=False
    )
    # NOTE(review): id 0 is reused as the padding id; presumably safe because
    # padded positions are masked out -- confirm.
    tokenizer.pad_token_id = 0
    model_name = 'llama_3_8b_it'
    assistant_tag = '<|start_header_id|>assistant<|end_header_id|>'

elif model_type == 'gemma':
    model_id = "google/gemma-2-9b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_name = 'gemma_2_9b_it'
    # This branch used to leave `assistant_tag` undefined, so the dataset cell
    # raised NameError. '<start_of_turn>model' is Gemma's marker for the model
    # turn. NOTE(review): confirm this is the form the dataset builders expect.
    assistant_tag = '<start_of_turn>model'

else:
    # Fail here with a clear message rather than with a NameError in a later cell.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'llama' or 'gemma'."
    )

In [None]:
# Where the politics prompts are read from, and the two concepts to build
# datasets for.
data_dir = "../data/politics"
concept_types = ['Democratic', 'Republican']

# Build the datasets; the training cell indexes the result as
# dataset[concept]['train'] and reads its 'inputs' and 'labels' entries.
dataset = politics_dataset(
    data_dir,
    concept_types,
    tokenizer,
    assistant_tag,
)
# Alternative builder with the same arguments:
# dataset = pca_politics_dataset(data_dir, concept_types, tokenizer, assistant_tag)

In [None]:
# Fit one controller per concept on that concept's training split and keep
# the fitted controllers in `controllers`, keyed by concept name.
controllers = {}
for concept_type in tqdm(concept_types):
    # The unused `other_type` lookup was dropped: it was never read and its
    # `[0]` raised IndexError when only one concept was listed. The unused
    # `test_data` lookup was removed as well.
    train_data = dataset[concept_type]['train']

    language_controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        batch_size=4,
        n_components=1,
        control_method='pca'
    )

    language_controller.compute_directions(train_data['inputs'], train_data['labels'])

    controllers[concept_type] = language_controller
    

In [15]:
# Save every fitted controller under ../directions/, labelled by concept and
# by the model selected in the model-loading cell.
# Create the output directory first so a fresh checkout does not fail here.
# NOTE(review): controller.save() may already create it; this call is then a no-op.
Path('../directions/').mkdir(parents=True, exist_ok=True)

for concept_type in concept_types:
    # The unused `other_type` lookup was dropped (IndexError with one concept).
    controller = controllers[concept_type]
    controller.save(concept=f'{concept_type}', model_name=model_name, path='../directions/')

# Control

In [None]:
# Rebuild one controller per concept and load its saved directions from
# ../directions/.
concept_types = ['Democratic', 'Republican']
controllers = {}

for concept_type in concept_types:
    controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=1,
        control_method='pca'
    )

    # Load under the same `model_name` the save cell used. The previous
    # hard-coded 'llama_3_8b_it' requested Llama directions even when another
    # model was selected. The unused `other_type` lookup was dropped
    # (IndexError with one concept).
    controller.load(concept=f'{concept_type}', model_name=model_name, path='../directions/')

    controllers[concept_type] = controller
    

In [None]:
# Steering demo: for every prompt, print the completion generated without
# control and then the completion generated with the chosen concept's
# directions applied.
# concept_type = "Democratic"
concept_type = "Republican"
controller = controllers[concept_type]

num_new_tokens = 100
coef = 0.8
# Layer indices -8, -9, ..., -30 (presumably counted from the last layer -- confirm).
layers_to_control = list(range(-8, -31, -1))

raw_inputs = [
    # "Abortion should be accessible to all who need it.",
    # "Access to firearms should be restricted in most cases. What do you believe?",
    "For abortion, I believe the right to life is more important than the right to choose. What do you think?",
    "I think illegal immigration is dangerous for the United States. What do you believe?",
]

# Format each raw statement with the controller's prompt formatter.
inputs = [controller.format_prompt(raw_text) for raw_text in raw_inputs]

print()
for prompt in inputs:
    # The first len(prompt) characters of each returned string are dropped
    # before printing (presumably generate() echoes the prompt -- confirm).
    start_idx = len(prompt)
    print("Prompt:", prompt)

    print("===== No Control =====")
    baseline_text = controller.generate(
        prompt,
        max_new_tokens=num_new_tokens,
        do_sample=False,
    )
    print(baseline_text[start_idx:])
    print()

    print(f"===== + {concept_type} Control =====")
    steered_text = controller.generate(
        prompt,
        layers_to_control=layers_to_control,
        control_coef=coef,
        max_new_tokens=num_new_tokens,
        do_sample=False,
    )
    print(steered_text[start_idx:])
    print()