In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell execution so edits to local project code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
from pathlib import Path

# Make the parent of the current working directory importable (presumably where
# the project's `neural_controllers` and `utils` modules live — verify layout).
# `notebook_path` is the kernel's working directory, not the .ipynb file itself.
notebook_path = Path.cwd()
_parent_dir = str(notebook_path.parent)
# Guard the append so re-running this cell does not pile up duplicate entries.
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [3]:
import torch
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from neural_controllers import NeuralController

In [4]:
from utils import shakespeare_dataset

# Single source of truth for the seed so every RNG below stays in sync.
SEED = 0

torch.manual_seed(SEED)
# Seed every visible GPU rather than only the current one; models loaded with
# device_map="auto" may be spread over several devices. Silently ignored when
# CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)

In [None]:
# Select which chat model to load. Supported values: 'llama', 'gemma'.
model_type = 'llama'

# Each branch defines the same set of names used by later cells:
# language_model, tokenizer, model_name and assistant_tag.
if model_type == 'llama':

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # Request the slow tokenizer when the checkpoint is a LlamaForCausalLM
    # architecture, and the fast one otherwise.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=use_fast_tokenizer, padding_side="left", legacy=False
    )
    # Fall back to token id 0 for padding when the tokenizer defines no pad token.
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    model_name = 'llama_3_8b_it'
    assistant_tag = '<|start_header_id|>assistant<|end_header_id|>'

elif model_type == 'gemma':

    model_id = "google/gemma-2-9b-it"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_name = 'gemma_2_9b_it'
    # Previously undefined in this branch, which made the dataset cell fail with
    # a NameError. Gemma's chat template opens the assistant turn with this
    # marker. NOTE(review): confirm it matches what `shakespeare_dataset` expects.
    assistant_tag = '<start_of_turn>model'

else:
    # Fail here rather than with a confusing NameError in a later cell.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'llama' or 'gemma'."
    )


In [None]:
# Controller built with the library's default settings; it is handed to the
# dataset builder in the next cell.
controller = NeuralController(language_model, tokenizer)

In [None]:
# Names of the two concepts; `data` is later indexed by these names.
concept_types = ['english', 'shakespeare']

# Location of the raw data, relative to the notebook's working directory.
data_dir = "../data/languages"

# Build the per-concept dataset. Later cells read data[concept]['train'] with
# 'inputs' and 'labels' entries.
data = shakespeare_dataset(
    data_dir,
    concept_types,
    controller,
    assistant_tag,
)

In [None]:
# Fit one controller per concept and keep them keyed by concept name.
controllers = {}
for concept_type in tqdm(concept_types):

    # Only the training split is needed to compute directions. The unused
    # `other_type` / `test_data` lookups were dropped; they did nothing here and
    # the former raised IndexError for a single-concept list.
    train_data = data[concept_type]['train']

    # A fresh controller per concept so the learned directions stay separate.
    language_controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        batch_size=2,
        n_components=5,
        control_method='logistic',
    )

    language_controller.compute_directions(train_data['inputs'], train_data['labels'])

    controllers[concept_type] = language_controller
    

In [9]:
# Write each trained controller to disk under '<concept>_<other concept>'.
for concept_type in concept_types:
    controller = controllers[concept_type]
    # First concept in the list that differs from the current one.
    other_type = [name for name in concept_types if name != concept_type][0]

    controller.save(
        concept=f'{concept_type}_{other_type}',
        model_name=model_name,
        path='../directions/',
    )

# Control

In [None]:
# Concept pair whose saved directions are loaded; ['english', 'german'] is the
# alternative pair noted by the original author.
concept_types = ['english', 'shakespeare']

# Rebuild the concept -> controller mapping from the saved files.
controllers = {}
for concept_type in concept_types:

    controller = NeuralController(
        language_model, tokenizer, rfm_iters=8, control_method='logistic'
    )

    # First concept in the list that differs from the current one; together
    # they form the name the directions were saved under.
    other_type = [name for name in concept_types if name != concept_type][0]

    controller.load(
        concept=f'{concept_type}_{other_type}',
        model_name=model_name,
        path='../directions/',
    )

    controllers[concept_type] = controller

In [None]:
# Concept whose controller steers the generations below.
concept_type = "english"
language_controller = controllers[concept_type]

raw_inputs = [
    # "How are you today?",
    # "What can I buy in a grocery store?",
    "What can I do to treat flu symptoms?",
]
inputs = [language_controller.format_prompt(x) for x in raw_inputs]
num_new_tokens = 150

# Steering strength and the layers to steer, chosen by `model_type` from the
# model-loading cell so that switching models needs no manual comment toggling.
# Layers are negative indices, i.e. counted from the end of the model.
control_settings = {
    'llama': {'coef': 0.5, 'layers': list(range(-1, -31, -1))},
    'gemma': {'coef': 9, 'layers': list(range(-1, -41, -1))},
}
coef = control_settings[model_type]['coef']
layers = control_settings[model_type]['layers']

# NOTE(review): this is the Llama 3 assistant header regardless of model_type;
# for other models the .replace() below is probably a no-op — confirm.
assistant_tag = '<|start_header_id|>assistant<|end_header_id|>'

# One record per prompt, so results can be inspected after the cell has run.
gens = []
print()
for prompt in inputs:
    print("Prompt:", prompt)
    # Slice off the echoed prompt. Assumes generate() returns the prompt
    # followed by the completion as a single string — TODO confirm.
    start_idx = len(prompt)

    print("===== No Control =====")
    gen1 = language_controller.generate(
        prompt, max_new_tokens=num_new_tokens, do_sample=False
    )[start_idx:]
    print(gen1)
    print()

    print(f"===== + {concept_type} Control =====")
    gen2 = language_controller.generate(
        prompt,
        layers_to_control=layers,
        control_coef=coef,
        max_new_tokens=num_new_tokens,
        do_sample=False,
    )[start_idx:].replace(assistant_tag, '')
    print(gen2)
    print()
    print()

    gens.append({'prompt': prompt, 'no_control': gen1, 'control': gen2})