In [1]:
# Re-import edited modules automatically before each cell executes, so changes to
# local helper modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Path() is the kernel's current working directory (Jupyter normally starts the kernel
# in the notebook's folder). Its parent is put on sys.path, presumably so the local
# `utils` / `neural_controllers` modules imported below resolve — confirm layout.
notebook_path = Path().absolute()
_parent_dir = str(notebook_path.parent)
# Guard keeps the cell idempotent: re-running it must not keep appending duplicates.
if _parent_dir not in sys.path:
    sys.path.append(_parent_dir)

In [3]:
import torch
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from neural_controllers import NeuralController

In [4]:
from utils import programming_language_dataset, pca_programming_language_dataset

# Seed the torch (CPU + CUDA) and numpy global RNGs so runs are repeatable.
SEED = 0
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
# Select the model under study. `model_name` is later passed to the controllers'
# save/load calls, so it must stay stable between the training and control sections.
model_type = 'llama'

if model_type == 'llama':
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # NOTE(review): unlike the gemma branch, no torch_dtype is passed here, so the
    # weights load in the library's default dtype — confirm memory use is acceptable.
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # Request the slow tokenizer only when the architecture is LlamaForCausalLM.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    # Left padding: the usual choice for batched generation with decoder-only models.
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast_tokenizer, padding_side="left", legacy=False)
    # Fall back to token id 0 when the tokenizer defines no pad token.
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    model_name = 'llama_3_8b_it'

elif model_type == 'gemma':
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
    language_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-9b-it",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_name = 'gemma_2_9b_it'

else:
    # Fail here instead of with a NameError on `tokenizer` / `model_name` in a later cell.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'llama' or 'gemma'")

In [None]:
# The pair of languages contrasted when learning directions.
# Alternative pair tried: ["python", "c++"].
concept_types = ["python", "javascript"]
data_dir = "../data/programming"  # NOTE(review): not passed to the dataset builder below

# Alternative builder: pca_programming_language_dataset(concept_types, tokenizer)
dataset = programming_language_dataset(concept_types, tokenizer)

In [None]:
# Learn one set of directions per concept from that concept's labeled training split.
controllers = {}
for concept_type in tqdm(concept_types):
    # The dataset also carries a 'test' split per concept; it is not used in this cell.
    train_data = dataset[concept_type]['train']

    language_controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        batch_size=2,
    )
    language_controller.compute_directions(train_data['inputs'], train_data['labels'])

    controllers[concept_type] = language_controller
    

In [None]:
# Persist each trained controller under '<concept>_<other concept>', e.g. 'python_javascript'.
for concept_type in concept_types:
    controller = controllers.get(concept_type)
    if controller is None:
        # No controller was trained for this concept (e.g. the training cell was skipped).
        print(f'{concept_type} not found')
        continue
    other_type = [k for k in concept_types if k != concept_type][0]
    # Save errors now propagate instead of being misreported as "not found".
    controller.save(concept=f'{concept_type}_{other_type}', model_name=model_name, path='../directions/')
    

# Control: steer generation with the learned directions

In [7]:
from datasets import load_dataset
# LeetCode dataset; the train split exposes per-language solution columns,
# of which 'python' and 'javascript' are pulled out here.
huggingface_dataset = load_dataset("greengerong/leetcode")
train_split = huggingface_dataset["train"]
python_dataset, js_dataset = train_split['python'], train_split['javascript']


def extract_code(c: str) -> str:
    """Return the contents of the first ``` fenced block in ``c``.

    The text is split on ``` and the segment after the first fence is returned,
    so any info string on the opening fence (e.g. ``python``) is kept at the
    start of the result. If the opening fence is never closed, everything after
    it is returned.

    Args:
        c: Text containing at least one ``` fence.

    Returns:
        The text between the first and second ``` markers.

    Raises:
        ValueError: If ``c`` contains no ``` fence at all.
    """
    items = c.split("```")
    if len(items) < 2:
        # Without this check the caller would get an opaque IndexError.
        raise ValueError("no ``` fenced code block found in text")
    return items[1]

In [None]:
# Rebuild one RFM controller per concept and load its saved directions from disk.
concept_types = ['python', 'javascript']
controllers = {}

for concept_type in concept_types:
    other_type = [k for k in concept_types if k != concept_type][0]

    controller = NeuralController(
        language_model,
        tokenizer,
        control_method='rfm',
        n_components=1
    )

    try:
        # Same '<concept>_<other concept>' naming used when the directions were saved.
        controller.load(
                        concept=f'{concept_type}_{other_type}',
                        model_name=model_name,
                        path='../directions/')
    except Exception as err:
        # Best-effort: a missing direction must not stop the notebook, but report the
        # actual cause. The former bare `except:` also swallowed KeyboardInterrupt.
        print(f'{concept_type} not found: {err!r}')
        continue
    controllers[concept_type] = controller
    

In [None]:
# Concept whose learned direction is used for steering ("javascript" or "python").
concept_type = "javascript"

idx = 0
js_task = extract_code(js_dataset[idx])
python_task = extract_code(python_dataset[idx])

# Alternative instruction tried earlier:
#   "Give a single, different re-writing of this program with the same function. "
#   "The output will be judged by an expert in all programming languages. "
# NOTE(review): the program below is the JavaScript solution while steering toward
# "javascript"; an earlier variant embedded `python_task` instead — confirm intent.
prompt = "Re-state the following program. "
prompt += f"Do not include an explanation.\n\n```{js_task}```"

# Layer indices -1, -2, ..., -(num_hidden_layers - 2). This reproduces the previously
# hand-written range(-1, -31, -1) for llama (32 layers) and range(-1, -41, -1) for gemma
# (42 layers), so switching model_type no longer needs a manual edit here.
num_layers = language_model.config.num_hidden_layers
layer_id = list(range(-1, -(num_layers - 1), -1))
language_controller = controllers[concept_type]
num_new_tokens = 150

inputs = language_controller.format_prompt(prompt)

# RFM control coefficients, tuned for the "javascript" concept, per model.
CONTROL_COEF = {'llama': 0.7, 'gemma': 9}
coeff = CONTROL_COEF[model_type]

print(inputs)
print("===== No Control =====")
gen1 = language_controller.generate(inputs, max_new_tokens=num_new_tokens, do_sample=False)
# Slice off the first len(inputs) characters, presumably the echoed prompt — confirm.
print(gen1[len(inputs):])
print()
print(f"===== + {concept_type} Control =====")
gen2 = language_controller.generate(inputs, layers_to_control=layer_id, control_coef=coeff,
                                    max_new_tokens=num_new_tokens, do_sample=False)
print(gen2[len(inputs):])
print()