In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Absolute path of the directory the notebook is executed from.
notebook_path = Path().absolute()

# Make the parent directory importable so the project-local imports used by
# later cells resolve. The membership check keeps this cell idempotent:
# re-running it no longer stacks duplicate entries onto sys.path.
repo_root = str(notebook_path.parent)
if repo_root not in sys.path:
    sys.path.append(repo_root)

In [3]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import utils
from neural_controllers import NeuralController

# Seed every random number generator the notebook touches (torch CPU,
# torch CUDA, NumPy's legacy global RNG) with one shared value.
SEED = 0
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_rng(SEED)

In [None]:
# Select which checkpoint to load; supported values are 'llama' and 'gemma'.
model_type = 'llama'

if model_type == 'llama':

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # Use the slow tokenizer only when the loaded model reports the
    # LlamaForCausalLM architecture. This flag used to be computed and then
    # ignored in favour of a hard-coded use_fast=False; passing it through
    # yields the same value whenever the config lists LlamaForCausalLM.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=use_fast_tokenizer, padding_side="left"
    )
    # Pad with token id 0 (batches are padded on the left, see above).
    tokenizer.pad_token_id = 0
    # Short identifier used later when saving/loading directions.
    model_name = 'llama_3_8b_it'

elif model_type == 'gemma':

    model_id = "google/gemma-2-9b-it"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_name = 'gemma_2_9b_it'

else:
    # Fail here rather than with a NameError on language_model / tokenizer /
    # model_name in a later cell.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'llama' or 'gemma'."
    )

## Generate directions

In [None]:
# Controller used to compute the concept directions in the cells below.
direction_settings = dict(
    rfm_iters=8,
    n_components=1,
    control_method='logistic',
    batch_size=1,
)
controller = NeuralController(language_model, tokenizer, **direction_settings)

In [None]:
# Location of the HaluEval QA file, relative to the current working directory.
data_path='../data/hallucinations/halu_eval/qa_data.txt'
# Build the dataset with the project helper. The next cell reads
# dataset['train']['inputs'] and dataset['train']['labels'] from the result.
# NOTE(review): the tokenizer is presumably used to format prompts — confirm in utils.
dataset = utils.hallucination_dataset(data_path, tokenizer)

In [None]:
# Pull the training inputs and flatten the per-example label arrays into one
# plain Python list before handing both to the controller.
train_inputs = dataset['train']['inputs']
train_labels = np.concatenate(dataset['train']['labels']).tolist()
controller.compute_directions(train_inputs, train_labels)

In [13]:
# Persist the computed directions under ../directions/ so the "Control" cells
# can reload them with the same concept / model_name pair.
# NOTE(review): the on-disk naming is decided inside NeuralController.save — confirm there.
controller.save(concept='hallucination_halu_eval', model_name=model_name, path='../directions/')

## Control

In [None]:
# Fresh controller for steering generation; its directions are loaded from
# disk rather than recomputed.
# NOTE(review): this controller is built with control_method='rfm' while the
# saved directions were computed with control_method='logistic' — confirm
# that the mismatch is intended.
steering_settings = dict(
    rfm_iters=8,
    control_method='rfm',
    n_components=1,
)
controller = NeuralController(language_model, tokenizer, **steering_settings)
controller.load(concept='hallucination_halu_eval', model_name=model_name, path='../directions/')

In [15]:
# Questions to send to the model. Only the first one is active; uncomment any
# of the others to try more prompts. (Plain string literals: none of them
# contain placeholders, so no f-prefix is needed.)
raw_inputs = [
    "What is the percentage of forest fires that cause earthquakes?",
    # "What is the most popular food in the U.S.?",
    # "Based on percentages alone, which food is more popular in the U.S. burgers or pizza?",
    # "List the GDP of the top 10 countries in the middle east at the end of 2020",
    # "What is the speed of light?",
    # "How tall is Mount Everest?",
    # "Who painted the Mona Lisa?",
    # "What is the chemical formula for water?",
    # "What is the capital city of Australia?",
]

# Format each raw question with the controller's prompt formatter.
inputs = [controller.format_prompt(question) for question in raw_inputs]

In [None]:
# Maximum number of tokens to generate per prompt.
num_new_tokens = 200

# Steering strength, passed to controller.generate as control_coef.
coef = 0.6

# Negative layer indices -1 .. -30, passed as layers_to_control.
layers = list(range(-1, -31, -1))

# Marker after which the assistant's reply starts in the returned text.
# NOTE(review): this is the Llama-3 style header; generations from another
# model (e.g. model_type='gemma') are not expected to contain it.
assistant_tag = '<|start_header_id|>assistant<|end_header_id|>'


def _text_after_tag(text, tag):
    """Return the part of ``text`` that follows the first occurrence of ``tag``.

    When ``tag`` does not occur, ``text`` is returned unchanged. The previous
    ``text.find(tag) + len(tag)`` slicing turned the -1 "not found" result
    into a positive offset and silently dropped the start of the output.
    """
    _, found, tail = text.partition(tag)
    return tail if found else text


# Steered generations, one entry per prompt (baseline outputs are only printed).
gens = []
print()
for prompt in inputs:
    print("Prompt:", prompt)
    print("===== No Control =====")
    # Baseline: greedy decoding without any steering.
    normal_gen = controller.generate(prompt, max_new_tokens=num_new_tokens, do_sample=False)
    print(_text_after_tag(normal_gen, assistant_tag))
    print()

    print("===== + Hallucination =====")
    # Same prompt, with the loaded direction applied on the selected layers.
    gen = controller.generate(prompt, layers_to_control=layers, control_coef=coef,
                              max_new_tokens=num_new_tokens, do_sample=False)
    gens.append(gen)
    print(_text_after_tag(gen, assistant_tag))
    print()
    print()