In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path
# (presumably the repository root that holds the project-local modules
# imported further down -- confirm against the repo layout).
notebook_path = Path().absolute()
project_root = str(notebook_path.parent)
# Guard the append so that re-running this cell does not keep stacking
# duplicate entries onto sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

In [6]:
import torch
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from neural_controllers import NeuralController
import utils

# Seed PyTorch (default and CUDA generators) and NumPy's global RNG with one
# shared value so repeated runs start from the same random state.
SEED = 0
for seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    seed_rng(SEED)

In [None]:
# Which model family to load. Only 'llama' is configured in this cell.
model_type = 'llama'

if model_type == 'llama':
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # Load the causal LM with all weights placed on the CUDA device.
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # The fast tokenizer is requested only when the model is NOT a
    # LlamaForCausalLM architecture; Llama models get use_fast=False.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        use_fast=use_fast_tokenizer,
        padding_side="left",
        legacy=False,
    )
    # Fall back to token id 0 for padding when the tokenizer defines none.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = 0

    # Short tag passed as `model_name` when directions are loaded later on.
    model_name = 'llama_3_8b_it'
else:
    # Fail here with a clear message instead of leaving language_model,
    # tokenizer and model_name undefined for the cells that follow.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; this cell only configures 'llama'."
    )

In [None]:
# Controller built with default settings around the loaded model and
# tokenizer; it is the object handed to utils.concept_dataset when the
# dataset is built.
controller = NeuralController(language_model, tokenizer)

In [None]:
# Where the disambiguation data lives, and the two concepts to model.
data_dir = "../data/disambiguation"
concept_types = ['River', 'Bank']

# Build the dataset for all concepts. Later cells index the result as
# dataset[concept]['train' | 'test'] and read its 'inputs' and 'labels'.
# NOTE(review): the controller is presumably used for prompt formatting --
# confirm in utils.concept_dataset.
dataset = utils.concept_dataset(data_dir, concept_types, controller)

In [None]:
# Fit one RFM controller per concept and collect them keyed by concept name.
# Note: the loop rebinds the notebook-level `controller` on every iteration.
controllers = {}
for concept_type in tqdm(concept_types):
    train_data = dataset[concept_type]['train']
    # Held-out split; not consumed in this cell, kept available for inspection.
    test_data = dataset[concept_type]['test']

    # A fresh controller per concept, stored under its concept name below.
    controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        batch_size=4,
        control_method='rfm',
    )

    controller.compute_directions(train_data['inputs'], train_data['labels'])

    controllers[concept_type] = controller

In [18]:
# Save every trained controller under ../directions/. The model tag comes from
# `model_name` (set when the model was loaded) rather than a repeated literal,
# so saving and the later controller.load(...) call always agree on the name.
for concept_type in concept_types:
    controller = controllers[concept_type]
    controller.save(concept=concept_type, model_name=model_name, path='../directions/')

# Control

In [None]:
# Rebuild one RFM controller per concept and load what was previously saved
# for it from ../directions/ under the current `model_name` tag.
concept_types = ['Bank', 'River']
controllers = {}

for concept_type in concept_types:
    controller = NeuralController(
        language_model,
        tokenizer,
        control_method='rfm',
    )

    controller.load(concept=concept_type, model_name=model_name, path='../directions/')

    controllers[concept_type] = controller
    

In [None]:
# Choose the concept to steer towards and fetch its loaded controller.
# concept_type = "River"
concept_type = "Bank"
controller = controllers[concept_type]

# Prompts to try; the commented-out entries are alternatives for experiments.
raw_inputs = [
    # f"Consider all options. What kinds of things might you find at a bank?",
    "The fisherman went to the bank by the river. Explain the items that he sees.",
    # f"Give the most likely answer to the following question. The teller went to the bank. What kind of bank is it?"
]
inputs = list(map(controller.format_prompt, raw_inputs))

# Control coefficient and generation budget.
coef = 0.4
num_new_tokens = 120

# Layer indices to control: -1, -2, ..., -30 (negative indices, presumably
# counted from the last layer -- confirm in NeuralController).
layers = list(range(-1, -31, -1))
# layers = list(range(-1, -41, -1))

# Settings shared by the uncontrolled and the controlled generation calls.
shared_kwargs = dict(max_new_tokens=num_new_tokens, do_sample=False)

# For every prompt, print the uncontrolled output and then the controlled
# output, each with the prompt text removed; controlled outputs go into `gens`.
gens = []
print()
for prompt in inputs:
    print("Prompt:", prompt)
    print("===== No Control =====")
    baseline = controller.generate(prompt, **shared_kwargs)
    print(baseline.replace(prompt, ""))
    print()

    print(f"===== + {concept_type} Control =====")
    steered = controller.generate(
        prompt, layers_to_control=layers, control_coef=coef, **shared_kwargs
    )
    gen = steered.replace(prompt, "")
    gens.append(gen)
    print(gen)
    print()
    print()