# Train bilingual control probes

## Imports

In [2]:
import os
import sys

talk_tuner_path = "/root/sandbox/sandbox/bilingual_user_models/talk_tuner"
sys.path.append(talk_tuner_path)

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F
from src.losses import edl_mse_loss

from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm.auto import tqdm


import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

from src.probes import ProbeClassification, ProbeClassificationMixScaler
from src.train_test_utils import train, test
import torch.nn as nn

import time

# Both timer aliases point at the same wall-clock function.
tic = toc = time.time


## Models

In [3]:
device = "cuda" if torch.cuda.is_available() else "mps"
# Credentials must never be committed in a notebook: read the Hugging Face
# token from the HF_TOKEN environment variable. When the variable is unset
# (or empty), fall back to `True`, which tells `from_pretrained` to use the
# token cached by `huggingface-cli login`.
# NOTE(review): the token that was previously hard-coded here has been
# exposed and should be revoked.
access_token = os.environ.get("HF_TOKEN") or True
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf", use_auth_token=access_token
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-13b-chat-hf", use_auth_token=access_token
)
# Cast to half precision, move to the selected device, and switch to
# inference mode. `model.eval()` is the last expression so the cell
# displays the module summary.
model.half().to(device)
model.eval()




Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 5120)
    (layers): ModuleList(
      (0-39): 40 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (k_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (v_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (o_proj): Linear(in_features=5120, out_features=5120, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (up_proj): Linear(in_features=5120, out_features=13824, bias=False)
          (down_proj): Linear(in_features=13824, out_features=5120, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((5120,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((5120,), eps=1e-05)
      )
    )
    (no

## Control Probe [age]

In [4]:
from src.probes import LinearProbeClassification
import sklearn.model_selection
import pickle

### Training config

In [5]:

class TrainerConfig:
    """Bag of optimizer hyperparameters passed to ``probe.configure_optimizers``.

    The class attributes below are the defaults. Any keyword argument given
    to the constructor is stored as an instance attribute, overriding a
    default of the same name or adding a new setting.
    """

    # Optimization defaults.
    learning_rate = 1e-3
    betas = (0.9, 0.95)
    # Original note: only applied on matmul weights (not enforced in this class).
    weight_decay = 0.1

    def __init__(self, **kwargs):
        """Store every keyword argument as an attribute on this instance."""
        for option_name in kwargs:
            setattr(self, option_name, kwargs[option_name])


In [15]:

from pydantic import BaseModel

# --- Loss / probe switches -------------------------------------------------
# `uncertainty` selects edl_mse_loss (True) or nn.BCELoss (False) and picks
# the train/test call variant in the training loop.
uncertainty: bool = False
# Forwarded to LinearProbeClassification as `logistic=`.
logistic: bool = True
# One-hot label switch (the TextDataset call passes its own literal value).
one_hot: bool = True

# --- Dataset construction switches (all forwarded to TextDataset) ----------
new_format: bool = True
residual_stream: bool = True
if_augmented: bool = False
remove_last_ai_response: bool = True
include_inst: bool = True





### Training Utils 

Two things need to be set up first:
1. A map from each age label to its integer class id (0–3)
2. The directories to load into the dataset


In [16]:
# Age-group label -> integer class id, numbered in the order listed.
label_to_id_age = {
    age_group: class_id
    for class_id, age_group in enumerate(
        ("child", "adolescent", "adult", "older adult")
    )
}

### Dataset 

Load in the TextDataset()

In [17]:
import torch
from src.dataset import TextDataset

#### [TODO] Mix with Spanish dataset

In [18]:
# Primary conversation directory handed to TextDataset.
directory = "../dataset/llama_age_1/"
# Extra directories passed to TextDataset via `additional_datas`.
additional_dataset = [
    "../dataset/llama_age_2/",
    "../dataset/openai_age_1/",
    "../dataset/openai_age_2/",
]

In [19]:
# Build the age dataset from `directory` plus the additional directories.
# The tokenizer and model are handed to TextDataset (presumably so it can
# compute the representations the probes train on; confirm in src.dataset).
# Note that `one_hot` is passed as a literal False here, independent of the
# notebook-level `one_hot` flag.
dataset = TextDataset(
    directory,
    tokenizer,
    model,
    label_idf="_age_",
    label_to_id=label_to_id_age,
    convert_to_llama2_format=True,
    additional_datas=additional_dataset,
    new_format=new_format,
    control_probe=True,
    residual_stream=residual_stream,
    if_augmented=if_augmented,
    remove_last_ai_response=remove_last_ai_response,
    include_inst=include_inst,
    k=1,
    one_hot=False,
    last_tok_pos=-1,
)

  0%|          | 0/4000 [00:00<?, ?it/s]

Corrupted file at ../dataset/openai_age_2/conversation_650_age_adolescent.txt
Corrupted file at ../dataset/openai_age_2/conversation_735_age_older adult.txt
Corrupted file at ../dataset/openai_age_2/conversation_740_age_child.txt


### Training Loop

In [20]:
# Alias the age label map under the generic name read by the training loop.
label_to_id = label_to_id_age

In [21]:
# Stratified 80/20 split over dataset indices (fixed seed for repeatability).
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_idx, val_idx = sklearn.model_selection.train_test_split(
    list(range(len(dataset))),
    train_size=train_size,
    test_size=test_size,
    stratify=dataset.labels,
    shuffle=True,
    random_state=12345,
)

train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, val_idx)

# No custom sampler; the training loader shuffles instead.
sampler = None
train_loader = DataLoader(
    train_dataset,
    batch_size=200,
    shuffle=True,
    sampler=sampler,
    num_workers=1,
    pin_memory=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=400,
    shuffle=False,
    num_workers=1,
    pin_memory=True,
)

# Evidential loss when `uncertainty` is on, binary cross-entropy otherwise.
loss_func = edl_mse_loss if uncertainty else nn.BCELoss()

torch_device = "cuda"

# Per-attribute accuracy logs: best, final-epoch and train accuracy.
dict_name = "age"
accuracy_dict = {
    dict_name: [],
    dict_name + "_final": [],
    dict_name + "_train": [],
}

In [1]:
# Per-layer accuracy logs: best test accuracy, final-epoch test accuracy and
# final-epoch train accuracy. One entry is appended per probed layer.
accs = []
final_accs = []
train_accs = []

os.environ["TOKENIZERS_PARALLELISM"] = "false"

# torch.save does not create missing directories, so make sure the
# checkpoint folder exists before the first save.
checkpoint_dir = "../probe_checkpoints/controlling_probe"
os.makedirs(checkpoint_dir, exist_ok=True)

num_classes = len(label_to_id)

# Train one fresh linear probe per layer index (0..40 inclusive).
for i in tqdm(range(0, 41)):
    trainer_config = TrainerConfig()
    probe = LinearProbeClassification(probe_class=num_classes, device=torch_device, input_dim=5120,
                                      logistic=logistic)
    optimizer, scheduler = probe.configure_optimizers(trainer_config)
    best_acc = 0
    max_epoch = 50
    verbosity = False
    layer_num = i
    print("-" * 40 + f"Layer {layer_num}" + "-" * 40)
    for epoch in range(1, max_epoch + 1):
        # Only the last epoch is reported verbosely.
        if epoch == max_epoch:
            verbosity = True

        # Keyword arguments shared by the train and test helpers.
        shared_kwargs = dict(
            loss_func=loss_func,
            verbose=verbosity,
            layer_num=layer_num,
            return_raw_outputs=True,
            num_classes=num_classes,
        )
        if uncertainty:
            shared_kwargs["epoch_num"] = epoch
        else:
            # Use the notebook-level `one_hot` flag; the previous
            # `args.one_hot` referenced an `args` object that is never
            # defined in this notebook and raised NameError.
            shared_kwargs["one_hot"] = one_hot

        train_results = train(probe, torch_device, train_loader, optimizer,
                              epoch, verbose_interval=None, **shared_kwargs)
        test_results = test(probe, torch_device, test_loader,
                            scheduler=scheduler, **shared_kwargs)

        # Checkpoint whenever the test accuracy (index 1) improves.
        if test_results[1] > best_acc:
            best_acc = test_results[1]
            torch.save(probe.state_dict(), f"{checkpoint_dir}/{dict_name}_probe_at_layer_{layer_num}.pth")
    # Always keep the weights from the final epoch as well.
    torch.save(probe.state_dict(), f"{checkpoint_dir}/{dict_name}_probe_at_layer_{layer_num}_final.pth")

    accs.append(best_acc)
    final_accs.append(test_results[1])
    train_accs.append(train_results[1])
    cm = confusion_matrix(test_results[3], test_results[2])
    cm_display = ConfusionMatrixDisplay(cm, display_labels=label_to_id.keys()).plot()
    plt.show()

    # Store the running per-layer lists exactly once. Appending them on
    # every layer iteration (as before) piled up 41 references to the same
    # list. Element [0] is still the per-layer list, as before.
    accuracy_dict[dict_name] = [accs]
    accuracy_dict[dict_name + "_final"] = [final_accs]
    accuracy_dict[dict_name + "_train"] = [train_accs]

    # Rewrite the results file after every layer so partial runs are kept.
    with open("../probe_checkpoints/controlling_probe_experiment.pkl", "wb") as outfile:
        pickle.dump(accuracy_dict, outfile)

# Release the dataset and loaders, then free cached GPU memory.
del dataset, train_dataset, test_dataset, train_loader, test_loader
torch.cuda.empty_cache()

NameError: name 'os' is not defined