In [1]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell executes; edits to project code then take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the parent of the notebook's working directory importable, so modules
# that live one level above this notebook can be imported by the cells below.
notebook_path = Path.cwd()
project_root = str(notebook_path.parent)
if project_root not in sys.path:
    # Guarded so that re-running this cell does not pile up duplicate entries.
    sys.path.append(project_root)

In [3]:
import torch
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

from neural_controllers import NeuralController

In [4]:
from utils import supervised_language_dataset, pca_language_dataset

# Seed the PyTorch (CPU and CUDA) and NumPy global generators with one fixed value.
SEED = 0
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    seed_rng(SEED)

In [None]:
# Load the causal language model and its tokenizer for the chosen model family.
# Defines `language_model`, `tokenizer` and `model_name`; `model_name` is the
# label later passed to NeuralController.save/load.
model_type = 'llama'

if model_type == 'llama':
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # NOTE(review): no torch_dtype is passed here (the gemma branch uses
    # bfloat16), so the weights load in the library's default dtype — confirm
    # this is intended for an 8B model.
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # The fast tokenizer is only requested when the architecture is NOT LlamaForCausalLM.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=use_fast_tokenizer, padding_side="left", legacy=False
    )
    # Fall back to token id 0 for padding when the tokenizer defines no pad token.
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    model_name = 'llama_3_8b_it'

elif model_type == 'gemma':

    # NOTE(review): unlike the llama branch, padding_side is left at the
    # tokenizer default here — confirm this is intended.
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
    language_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-9b-it",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_name = 'gemma_2_9b_it'

else:
    # Fail loudly instead of leaving language_model/tokenizer/model_name
    # undefined (or stale from an earlier run of this cell).
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'llama' or 'gemma'"
    )

In [None]:
# Language pair to contrast. Other pairs that were tried:
#   ['english', 'chinese'], ['english', 'german']
concept_types = ['english', 'spanish']
data_dir = "../data/languages"

# pca_language_dataset(data_dir, concept_types, tokenizer) is the drop-in
# alternative; either call can also be given n_train=128 (left disabled here).
data = supervised_language_dataset(
    data_dir,
    concept_types,
    tokenizer,
)

In [None]:
# Fit one NeuralController per language in `concept_types` and keep them in
# `controllers`, keyed by language name. Only the 'train' split of `data` is
# used here; control_method is left at the NeuralController default.
controllers = {}
for concept_type in tqdm(concept_types):
    train_data = data[concept_type]['train']

    language_controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        batch_size=2,
        n_components=5
    )

    # Compute the directions for this language from its training inputs/labels.
    language_controller.compute_directions(train_data['inputs'], train_data['labels'])

    controllers[concept_type] = language_controller
    

In [9]:
# Save every fitted controller under the concept key "<language>_<other language>".
for concept_type in concept_types:
    contrast_type = [name for name in concept_types if name != concept_type][0]
    controllers[concept_type].save(
        concept=f'{concept_type}_{contrast_type}',
        model_name=model_name,
        path='../directions/',
    )

# Control

In [None]:
# Rebuild one controller per language and load its saved state from
# '../directions/' under the concept key "<language>_<other language>".
#
# NOTE(review): this cell uses the english/chinese pair and control_method='pca',
# whereas the training cells above fit and save english/spanish controllers
# without passing control_method. load() therefore presumably relies on files
# written by an earlier run — confirm they exist before running top to bottom.
concept_types = ['english', 'chinese']  # alternative pair: ['english', 'german']

controllers = {}
for concept_type in concept_types:
    contrast_type = [name for name in concept_types if name != concept_type][0]

    steering_controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        control_method='pca',
        n_components=1,
    )
    steering_controller.load(
        concept=f'{concept_type}_{contrast_type}',
        model_name=model_name,
        path='../directions/',
    )

    controllers[concept_type] = steering_controller

In [None]:
# Language whose loaded controller is used for steering below.
concept_type = "chinese"  # alternative: "english"
language_controller = controllers[concept_type]

# Prompts to generate from. Only the uncommented entries are used; the rest are
# kept as ready-to-enable alternatives (each ends with a comma so that
# re-enabling a line cannot silently concatenate it with the next string).
raw_inputs = [
    # "¿Cómo estás hoy?",
    # "¿Qué artículos puedo comprar en una tienda de comestibles?",
    # "¿Qué puedo hacer para tratar los síntomas de la gripe?",
    # "How are you today?",
    # "What can I do to treat flu symptoms?",
    # # "How can I tell if meat has gone bad?",
    # "How do I start investing with a small amount of money?",
    # "How do I free up storage space on my phone?",
    # "How do I keep indoor plants healthy?",
    # "What is the chemical symbol for gold?",
    "你今天怎么样?",
    "我可以在杂货店买什么?",
    "我可以做什么来治疗流感症状?",
    # "How are you today?",
    # "Who painted the Mona Lisa?",
    # "What time is it?",
    # "Winter is my favorite season",
    # "Let's try that new restaurant for lunch.",
    # # "It's amazing how fast technology is advancing.",
    # "How did you spend your weekend?",
    # "That's one of the best meals I've ever had.",
    # # "What time does the meeting start tomorrow?",
    # "Do you have a recommendation for a good movie?",
    # "Please remember to bring your ID to the event.",
    # "Could you help me understand this math problem?",
    # "It's going to rain later today, so bring an umbrella.",
    # "How are you today?",
    # "What can I do to treat flu symptoms?",
    # "List items I can buy in a grocery store",
    # "How do I improve my Wi-Fi signal at home?",
    # "How do I start investing with a small amount of money?",
    # "How do I calculate compound interest over multiple years?",
    # "How do I set up a home server for media streaming?",
    # "What are the side effects of combining specific medications?",
    # "How do I recover accidentally deleted files on Windows or macOS?",
    # "How do I debug a Python script that keeps crashing?",
    # "How do I test the efficiency of a solar panel setup?",
    # "How can I perform regression analysis on a dataset?",
    # "How do I calculate my daily caloric needs based on activity levels?",
    # "How do I interpret my blood test results?",
    # "What’s the difference between a viral and bacterial infection?",
    # "What’s the process for filing a provisional patent for an invention?",
]
# Generate a baseline and a steered completion for every prompt and print both.
# Wrap each raw prompt in the controller's prompt format first.
inputs = [language_controller.format_prompt(x) for x in raw_inputs]
num_new_tokens = 100

# Steering strength and the layers to steer differ per model family, so they
# are looked up from `model_type` (set in the model-loading cell) instead of
# being switched by commenting lines in and out.
control_settings = {
    'llama': {'coef': 0.5, 'layers': list(range(-1, -31, -1))},
    'gemma': {'coef': 9, 'layers': list(range(-1, -41, -1))},
}
if model_type not in control_settings:
    raise ValueError(
        f"No control settings for model_type={model_type!r}; "
        f"expected one of {sorted(control_settings)}"
    )
coef = control_settings[model_type]['coef']
layers = control_settings[model_type]['layers']

# Llama-3 style assistant header; it is removed from the steered output only.
assistant_tag = '<|start_header_id|>assistant<|end_header_id|>'

# One record per prompt, so the generations can be inspected after the loop.
gens = []
print()
for prompt in inputs:
    print("Prompt:", prompt)
    # The generated text is sliced by the prompt's character length —
    # assumes generate() returns text that begins with the prompt; TODO confirm.
    start_idx = len(prompt)

    print("===== No Control =====")
    gen1 = language_controller.generate(
        prompt, max_new_tokens=num_new_tokens, do_sample=False
    )[start_idx:]
    print(gen1)
    print()

    print(f"===== + {concept_type} Control =====")
    gen2 = language_controller.generate(
        prompt,
        layers_to_control=layers,
        control_coef=coef,
        max_new_tokens=num_new_tokens,
        do_sample=False,
    )[start_idx:].replace(assistant_tag, '')
    print(gen2)
    print()
    print()

    gens.append({'prompt': prompt, 'no_control': gen1, 'control': gen2})