In [1]:
# Reload imported modules automatically before each cell executes (mode 2 = reload all modules),
# so edits to local modules (presumably neural_controllers / utils) apply without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the parent of the current working directory importable (in Jupyter the cwd is normally the
# notebook's own directory); presumably this is where neural_controllers / utils live.
notebook_path = Path().absolute()
project_root = str(notebook_path.parent)
if project_root not in sys.path:  # avoid piling up duplicate entries when the cell is re-run
    sys.path.append(project_root)

In [3]:
import torch
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM

from neural_controllers import NeuralController

In [4]:
from utils import concept_dataset, multi_concept_dataset

# Seed every RNG the notebook relies on so repeated runs are reproducible.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [5]:
# Model family to load; only 'llama' is wired up below.
model_type = 'llama'

if model_type == 'llama':
    # model_id = "meta-llama/Meta-Llama-3.1-8B"  # base (non-instruct) variant
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    # Request the slow tokenizer only when the config lists the LlamaForCausalLM architecture.
    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    # Left padding, as is usual for batched generation with decoder-only models.
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=use_fast_tokenizer, padding_side="left", legacy=False
    )
    # Fall back to token id 0 for padding when the tokenizer defines no pad token.
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
else:
    # Fail here instead of with a NameError on language_model/tokenizer in later cells.
    raise ValueError(f"Unsupported model_type: {model_type!r}")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Examples: [128000, 15339] !


In [6]:
# Controller with default hyperparameters; it is handed to concept_dataset when building the dataset.
controller = NeuralController(language_model, tokenizer)


Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 8
M_batch_size         : 2048
n_components         : 5



In [8]:
# Science subjects to contrast. 'Geology' and 'Chemistry' were also tried previously (commented out).
data_dir = "../data/science_subjects"
concept_types = ['Biology', 'Classical Mechanics']

dataset = concept_dataset(data_dir, concept_types, controller)

train 300 test 122
train 300 test 122


In [10]:
# Fit one PCA-based controller per concept and keep it keyed by concept name.
controllers = {}
for concept_type in tqdm(concept_types):
    train_data = dataset[concept_type]['train']

    concept_controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,  # NOTE(review): presumably ignored when control_method='pca' — confirm
        batch_size=4,
        control_method='pca'
    )

    # Spectrum logging (log_path=..., log_spectrum=True) was previously passed here and is disabled.
    concept_controller.compute_directions(train_data['inputs'], train_data['labels'])

    controllers[concept_type] = concept_controller
    

  0%|          | 0/2 [00:00<?, ?it/s]


Controller hyperparameters:
control_method       : pca
rfm_iters            : 8
forward_batch_size   : 4
M_batch_size         : 2048
n_components         : 5

train_y torch.Size([240, 1]) val_y torch.Size([60, 1])
Getting activations from forward passes



  0%|          | 0/75 [00:00<?, ?it/s][A
  1%|▏         | 1/75 [00:00<00:37,  1.98it/s][A
  3%|▎         | 2/75 [00:01<00:37,  1.94it/s][A
  4%|▍         | 3/75 [00:01<00:37,  1.93it/s][A
  5%|▌         | 4/75 [00:02<00:36,  1.93it/s][A
  7%|▋         | 5/75 [00:02<00:36,  1.93it/s][A
  8%|▊         | 6/75 [00:03<00:35,  1.92it/s][A
  9%|▉         | 7/75 [00:03<00:35,  1.92it/s][A
 11%|█         | 8/75 [00:04<00:34,  1.91it/s][A
 12%|█▏        | 9/75 [00:04<00:34,  1.92it/s][A
 13%|█▎        | 10/75 [00:05<00:33,  1.92it/s][A
 15%|█▍        | 11/75 [00:05<00:33,  1.92it/s][A
 16%|█▌        | 12/75 [00:06<00:32,  1.92it/s][A
 17%|█▋        | 13/75 [00:06<00:32,  1.92it/s][A
 19%|█▊        | 14/75 [00:07<00:31,  1.92it/s][A
 20%|██        | 15/75 [00:07<00:31,  1.93it/s][A
 21%|██▏       | 16/75 [00:08<00:30,  1.92it/s][A
 23%|██▎       | 17/75 [00:08<00:30,  1.92it/s][A
 24%|██▍       | 18/75 [00:09<00:29,  1.92it/s][A
 25%|██▌       | 19/75 [00:09<00:29,  1.93it/s]

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



  3%|▎         | 1/31 [00:00<00:08,  3.48it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



  6%|▋         | 2/31 [00:00<00:07,  3.73it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 10%|▉         | 3/31 [00:00<00:07,  3.82it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 13%|█▎        | 4/31 [00:01<00:06,  3.86it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 16%|█▌        | 5/31 [00:01<00:06,  3.88it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 19%|█▉        | 6/31 [00:01<00:06,  3.90it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 23%|██▎       | 7/31 [00:01<00:06,  3.90it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 26%|██▌       | 8/31 [00:02<00:05,  3.91it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 29%|██▉       | 9/31 [00:02<00:05,  3.91it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 32%|███▏      | 10/31 [00:02<00:05,  3.91it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 35%|███▌      | 11/31 [00:02<00:05,  3.91it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 39%|███▊      | 12/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 42%|████▏     | 13/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 45%|████▌     | 14/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 48%|████▊     | 15/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 52%|█████▏    | 16/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 55%|█████▍    | 17/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 58%|█████▊    | 18/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 61%|██████▏   | 19/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 65%|██████▍   | 20/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 68%|██████▊   | 21/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 71%|███████   | 22/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 74%|███████▍  | 23/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 77%|███████▋  | 24/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 81%|████████  | 25/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 84%|████████▍ | 26/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 87%|████████▋ | 27/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 90%|█████████ | 28/31 [00:07<00:00,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 94%|█████████▎| 29/31 [00:07<00:00,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 97%|█████████▋| 30/31 [00:07<00:00,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



100%|██████████| 31/31 [00:07<00:00,  3.90it/s][A


Computing signs
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidd


100%|██████████| 31/31 [00:00<00:00, 68146.45it/s]
 50%|█████     | 1/2 [00:47<00:47, 47.30s/it]


Controller hyperparameters:
control_method       : pca
rfm_iters            : 8
forward_batch_size   : 4
M_batch_size         : 2048
n_components         : 5

train_y torch.Size([240, 1]) val_y torch.Size([60, 1])
Getting activations from forward passes



  0%|          | 0/75 [00:00<?, ?it/s][A
  1%|▏         | 1/75 [00:00<00:38,  1.94it/s][A
  3%|▎         | 2/75 [00:01<00:38,  1.91it/s][A
  4%|▍         | 3/75 [00:01<00:37,  1.91it/s][A
  5%|▌         | 4/75 [00:02<00:37,  1.90it/s][A
  7%|▋         | 5/75 [00:02<00:36,  1.90it/s][A
  8%|▊         | 6/75 [00:03<00:36,  1.90it/s][A
  9%|▉         | 7/75 [00:03<00:35,  1.90it/s][A
 11%|█         | 8/75 [00:04<00:35,  1.90it/s][A
 12%|█▏        | 9/75 [00:04<00:34,  1.90it/s][A
 13%|█▎        | 10/75 [00:05<00:34,  1.90it/s][A
 15%|█▍        | 11/75 [00:05<00:33,  1.90it/s][A
 16%|█▌        | 12/75 [00:06<00:33,  1.90it/s][A
 17%|█▋        | 13/75 [00:06<00:32,  1.89it/s][A
 19%|█▊        | 14/75 [00:07<00:32,  1.89it/s][A
 20%|██        | 15/75 [00:07<00:31,  1.89it/s][A
 21%|██▏       | 16/75 [00:08<00:31,  1.89it/s][A
 23%|██▎       | 17/75 [00:08<00:30,  1.89it/s][A
 24%|██▍       | 18/75 [00:09<00:30,  1.89it/s][A
 25%|██▌       | 19/75 [00:10<00:29,  1.89it/s]

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



  3%|▎         | 1/31 [00:00<00:07,  3.98it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



  6%|▋         | 2/31 [00:00<00:07,  3.94it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 10%|▉         | 3/31 [00:00<00:07,  3.93it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 13%|█▎        | 4/31 [00:01<00:06,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 16%|█▌        | 5/31 [00:01<00:06,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 19%|█▉        | 6/31 [00:01<00:06,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 23%|██▎       | 7/31 [00:01<00:06,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 26%|██▌       | 8/31 [00:02<00:05,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 29%|██▉       | 9/31 [00:02<00:05,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 32%|███▏      | 10/31 [00:02<00:05,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 35%|███▌      | 11/31 [00:02<00:05,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 39%|███▊      | 12/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 42%|████▏     | 13/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 45%|████▌     | 14/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 48%|████▊     | 15/31 [00:03<00:04,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 52%|█████▏    | 16/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 55%|█████▍    | 17/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 58%|█████▊    | 18/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 61%|██████▏   | 19/31 [00:04<00:03,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 65%|██████▍   | 20/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 68%|██████▊   | 21/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 71%|███████   | 22/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 74%|███████▍  | 23/31 [00:05<00:02,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 77%|███████▋  | 24/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 81%|████████  | 25/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 84%|████████▍ | 26/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 87%|████████▋ | 27/31 [00:06<00:01,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 90%|█████████ | 28/31 [00:07<00:00,  3.91it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 94%|█████████▎| 29/31 [00:07<00:00,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



 97%|█████████▋| 30/31 [00:07<00:00,  3.92it/s][A

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 1])
Training PCA model



100%|██████████| 31/31 [00:07<00:00,  3.92it/s][A


Computing signs
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidden_state_projections torch.Size([300]) all_y torch.Size([300, 1])
hidd


100%|██████████| 31/31 [00:00<00:00, 68505.49it/s]
100%|██████████| 2/2 [01:35<00:00, 47.52s/it]


In [11]:
# Persist each concept's directions; spaces in the concept name become hyphens in the saved name.
for concept_type in concept_types:
    controller = controllers[concept_type]
    print("concept_type", concept_type)
    slug = concept_type.replace(' ', '-')
    controller.save(concept=slug, model_name='llama_3_8b_it', path='../directions/')

concept_type Biology
concept_type Classical Mechanics


## Multi-concept

In [30]:
# Train a single RFM controller jointly on all concepts in `concept_types`.
# NOTE(review): `user_tag` and `assistant_tag` are not defined anywhere in this notebook; this cell only
# ran because they were left in the kernel by an earlier (since-deleted?) cell. Define them (or adopt the
# controller-based call used for concept_dataset) before running on a fresh kernel — TODO confirm signature.
# NOTE(review): this rebinds `dataset`, replacing the per-concept dataset built earlier.
dataset = multi_concept_dataset(data_dir, concept_types, user_tag=user_tag, assistant_tag=assistant_tag)#, n_train=128)

train_data = dataset['train']
test_data = dataset['test']

language_controller = NeuralController(
    language_model,
    tokenizer,
    rfm_iters=8,
    batch_size=4,
    control_method='rfm'
)

# Per-layer AGOP spectra are written to files named ../agop_spectra/multi_science_layer_<layer>.pt
# (per the printed output).
# FIXME: as saved, this call fails with UnboundLocalError ('signs') during the "Computing signs" step
# inside compute_directions; the multi-label path needs fixing in neural_controllers or this cell removed.
language_controller.compute_directions(train_data['inputs'], train_data['labels'],
                                       log_path=f'../agop_spectra/multi_science', log_spectrum=True)


Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 4
M_batch_size         : 2048
n_components         : 1

Getting activations from forward passes


100%|██████████| 75/75 [00:33<00:00,  2.27it/s]
  0%|          | 0/31 [00:00<?, ?it/s]

train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.0020981610286980867, R2: None, reg: 0.001, bw: 1000, acc: 100.0


  3%|▎         | 1/31 [00:02<01:25,  2.86s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-1.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.002401650184765458, R2: None, reg: 0.001, bw: 1000, acc: 100.0


  6%|▋         | 2/31 [00:05<01:20,  2.76s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-2.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.002959304256364703, R2: None, reg: 0.001, bw: 1000, acc: 100.0


 10%|▉         | 3/31 [00:08<01:19,  2.83s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-3.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.0029589904006570578, R2: None, reg: 0.001, bw: 1000, acc: 100.0


 13%|█▎        | 4/31 [00:11<01:14,  2.77s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-4.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.003636735724285245, R2: None, reg: 0.001, bw: 100, acc: 100.0


 16%|█▌        | 5/31 [00:14<01:15,  2.89s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-5.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.004714768845587969, R2: None, reg: 0.001, bw: 100, acc: 100.0


 19%|█▉        | 6/31 [00:17<01:12,  2.90s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-6.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.005513761192560196, R2: None, reg: 0.001, bw: 1000, acc: 100.0


 23%|██▎       | 7/31 [00:20<01:11,  2.97s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-7.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.005704257171601057, R2: None, reg: 0.001, bw: 100, acc: 100.0


 26%|██▌       | 8/31 [00:23<01:07,  2.94s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-8.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.005211944691836834, R2: None, reg: 0.001, bw: 1000, acc: 100.0


 29%|██▉       | 9/31 [00:26<01:05,  2.99s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-9.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.006502476986497641, R2: None, reg: 0.001, bw: 100, acc: 100.0


 32%|███▏      | 10/31 [00:29<01:01,  2.95s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-10.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.006231639068573713, R2: None, reg: 0.001, bw: 1000, acc: 100.0


 35%|███▌      | 11/31 [00:32<00:58,  2.93s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-11.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.00794919766485691, R2: None, reg: 0.001, bw: 100, acc: 98.33334350585938


 39%|███▊      | 12/31 [00:34<00:54,  2.86s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-12.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.008004785515367985, R2: None, reg: 0.01, bw: 100, acc: 100.0


 42%|████▏     | 13/31 [00:37<00:51,  2.88s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-13.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.0074312458746135235, R2: None, reg: 0.001, bw: 100, acc: 100.0


 45%|████▌     | 14/31 [00:40<00:48,  2.83s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-14.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.008539503440260887, R2: None, reg: 0.001, bw: 100, acc: 100.0


 48%|████▊     | 15/31 [00:43<00:45,  2.85s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-15.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.008295046165585518, R2: None, reg: 0.001, bw: 100, acc: 98.33334350585938


 52%|█████▏    | 16/31 [00:46<00:46,  3.08s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-16.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.010946803726255894, R2: None, reg: 0.001, bw: 100, acc: 98.33334350585938


 55%|█████▍    | 17/31 [00:49<00:42,  3.02s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-17.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.014955594204366207, R2: None, reg: 0.001, bw: 100, acc: 98.33334350585938


 58%|█████▊    | 18/31 [00:52<00:39,  3.02s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-18.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.012817835435271263, R2: None, reg: 0.001, bw: 100, acc: 96.66667175292969


 61%|██████▏   | 19/31 [00:55<00:36,  3.07s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-19.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.01776823401451111, R2: None, reg: 0.001, bw: 100, acc: 96.66667175292969


 65%|██████▍   | 20/31 [00:58<00:33,  3.06s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-20.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.018211757764220238, R2: None, reg: 0.001, bw: 100, acc: 96.66667175292969


 68%|██████▊   | 21/31 [01:02<00:31,  3.19s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-21.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.011796312406659126, R2: None, reg: 0.001, bw: 5, acc: 98.33334350585938


 71%|███████   | 22/31 [01:05<00:27,  3.10s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-22.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.005552544258534908, R2: None, reg: 0.001, bw: 5, acc: 98.33334350585938


 74%|███████▍  | 23/31 [01:08<00:24,  3.11s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-23.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.005343331489712, R2: None, reg: 0.01, bw: 5, acc: 98.33334350585938


 77%|███████▋  | 24/31 [01:11<00:20,  2.98s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-24.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.004615167621523142, R2: None, reg: 0.01, bw: 5, acc: 98.33334350585938


 81%|████████  | 25/31 [01:14<00:17,  2.96s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-25.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.0063772243447601795, R2: None, reg: 0.001, bw: 5, acc: 98.33334350585938


 84%|████████▍ | 26/31 [01:16<00:14,  2.89s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-26.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.0036815761122852564, R2: None, reg: 0.001, bw: 10, acc: 100.0


 87%|████████▋ | 27/31 [01:19<00:11,  2.89s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-27.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.003232859540730715, R2: None, reg: 0.001, bw: 10, acc: 100.0


 90%|█████████ | 28/31 [01:22<00:08,  2.84s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-28.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.00330732180736959, R2: None, reg: 0.001, bw: 10, acc: 98.33334350585938


 94%|█████████▎| 29/31 [01:25<00:05,  2.86s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-29.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.01613503508269787, R2: None, reg: 0.01, bw: 0.2, acc: 95.00000762939453


 97%|█████████▋| 30/31 [01:28<00:02,  2.81s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-30.pt
train X shape: torch.Size([240, 4096]) train y shape: torch.Size([240, 4]) val X shape: torch.Size([60, 4096]) val y shape: torch.Size([60, 4])
Best RFM loss: 0.01888142339885235, R2: None, reg: 0.001, bw: 5, acc: 96.66667175292969


100%|██████████| 31/31 [01:30<00:00,  2.93s/it]

spectrum_filename ../agop_spectra/multi_science_layer_-31.pt
Computing signs





UnboundLocalError: cannot access local variable 'signs' where it is not associated with a value

## Control

In [None]:
def combine_directions(poetry_dirs, harmful_dirs, a=0.5, b=0.5):
    """Blend two per-layer direction mappings into a weighted sum.

    Args:
        poetry_dirs: mapping of layer -> direction; its keys (and their order) define the result's keys.
        harmful_dirs: mapping that must contain every key of ``poetry_dirs``.
        a: weight applied to ``poetry_dirs`` entries.
        b: weight applied to ``harmful_dirs`` entries.

    Returns:
        A new dict ``{layer: a * poetry_dirs[layer] + b * harmful_dirs[layer]}``.

    Raises:
        KeyError: if ``harmful_dirs`` lacks a key present in ``poetry_dirs``.
    """
    combined = {}
    for layer, first_dir in poetry_dirs.items():
        combined[layer] = a * first_dir + b * harmful_dirs[layer]
    return combined

In [75]:
# Load previously saved 'english_shakespeare' directions; they are blended in only when `combine` is True.
concept_type = 'english_shakespeare'
poetry_controller = NeuralController(language_model, tokenizer, control_method='rfm')
poetry_controller.load(concept=concept_type, model_name='llama_3_8b_it', path='../directions/')

Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 8
M_batch_size         : 2048
n_components         : 5

Detector found


In [82]:
# Reload the saved per-concept directions, optionally blended with the poetry directions.
concept_types = ['Biology', 'Classical Mechanics']
controllers = {}

combine = False  # when True, mix poetry_controller's directions into each concept's directions

for concept_type in concept_types:
    # NOTE(review): these directions were computed with control_method='pca' but are loaded into an
    # 'rfm' controller — confirm load() does not depend on control_method.
    controller = NeuralController(
        language_model,
        tokenizer,
        control_method='rfm'
    )

    # Same hyphenated name used when the directions were saved.
    controller.load(concept=concept_type.replace(' ', '-'), model_name='llama_3_8b_it', path='../directions/')

    if combine:
        # Per-layer weighted sum: 0.9 * poetry direction + 1.0 * concept direction.
        controller.directions = combine_directions(poetry_controller.directions, controller.directions, a=0.9, b=1.0)
    controllers[concept_type] = controller

Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 8
M_batch_size         : 2048
n_components         : 5

Detector found
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 8
M_batch_size         : 2048
n_components         : 5

Detector found


In [84]:
# Generate text while steering the model toward one concept with its loaded controller.
# concept = "Biology"
concept = "Classical Mechanics"
# Select the controller BEFORE formatting prompts; previously format_prompt ran on whichever
# `controller` an earlier cell happened to leave behind.
controller = controllers[concept]

raw_inputs = [
    # "What might someone who studies biology be interested in?",
    "What is interesting about bird flight?",
    # "What is interesting about oxygen?",
    # "What governs the movement of air in the respiratory system?",
]
inputs = [controller.format_prompt(x) for x in raw_inputs]

num_new_tokens = 180

coef = 0.4  # control coefficient used for llama
# coef=9

# Steer layers -1 .. -30. NOTE(review): -31 is excluded although the loaded controller lists it — confirm intent.
layers = list(range(-1, -31, -1))
# layers = list(range(-1, -41, -1))

gens = []
print()
for prompt in inputs:
    print("Prompt:", prompt)
    # Uncomment for an unsteered baseline:
    # print("===== No Control =====")
    # print(controller.generate(prompt, max_new_tokens=num_new_tokens, do_sample=False).replace(prompt, ""))
    # print()

    print(f"===== + {concept} Control =====")
    # Greedy decoding; strip the echoed prompt from the returned text.
    gen = controller.generate(prompt, layers_to_control=layers, control_coef=coef,
                              max_new_tokens=num_new_tokens, do_sample=False).replace(prompt, "")
    gens.append(gen)
    print(gen)
    print()
    print()

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Prompt: "<|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

What is interesting about bird flight?<|eot_id|>
===== + Classical Mechanics Control =====
<|start_header_id|>assistant<|end_header_id|>

Bird flight is a most wondrous and subtle thing, full of mysteries and paradoxes. Here are some of the most interesting aspects of bird flight:

1. **The Lifting Line**: The concept of the lifting line, introduced by Sir George Taylor, shows that a wing can produce lift without any net force on the wing itself. This is a most curious property, for it implies that the wing can produce lift without any net force on the wing, and yet, the wing must be subject to a net force to produce lift.

2. **The Lifting Line Paradox**: The lifting line paradox, also known as the "lifting line conundrum," arises when one considers the motion of a wing in a fluid, such as air. The paradox states that,