In [1]:
# Enable IPython's autoreload extension and reload all imported modules before
# each cell executes, so edits to the project's .py files (e.g. `utils`,
# `neural_controllers`) take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
from pathlib import Path

# Make the parent of the kernel's working directory (normally the folder that
# holds this notebook) importable; presumably that is where the project-local
# modules `utils` / `neural_controllers` live — confirm against repo layout.
notebook_path = Path().absolute()
project_root = str(notebook_path.parent)
# Guarded so re-running this cell does not keep appending duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
import torch
from tqdm import tqdm
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
from neural_controllers import NeuralController

In [4]:
from utils import shakespeare_dataset

import random

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch CPU/CUDA)
# so that a fresh Restart & Run All is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Explicitly cover every visible GPU (the model may be sharded across devices);
# this call is a no-op when CUDA is unavailable.
torch.cuda.manual_seed_all(SEED)

In [5]:
# Load the instruction-tuned model and tokenizer selected by `model_type`, and
# record `assistant_tag`, which is passed to the dataset builder further down
# (presumably to mark where the assistant's turn starts — confirm in utils).
model_type = 'llama'

if model_type == 'llama':

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

    # NOTE(review): no torch_dtype is passed, so weights load in transformers'
    # default dtype (float32 in most versions) — unlike the bfloat16 Gemma branch.
    language_model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda"
    )

    use_fast_tokenizer = "LlamaForCausalLM" not in language_model.config.architectures
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, use_fast=use_fast_tokenizer, padding_side="left", legacy=False
    )
    # Fall back to token id 0 for padding when the tokenizer defines no pad token.
    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    model_name = 'llama_3_8b_it'
    assistant_tag = '<|start_header_id|>assistant<|end_header_id|>'

elif model_type == 'gemma':

    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")
    language_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-9b-it",
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_name = 'gemma_2_9b_it'
    # Gemma's chat template opens the model's turn with this tag. It must be
    # defined here too, otherwise the dataset cell raises NameError.
    assistant_tag = '<start_of_turn>model'

else:
    # Fail fast instead of letting later cells crash on undefined names.
    raise ValueError(f"Unknown model_type {model_type!r}; expected 'llama' or 'gemma'")


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [6]:
# Default-configured controller over the loaded model; it is handed to the
# dataset builder below (presumably for prompt formatting — confirm in utils).
controller = NeuralController(language_model, tokenizer)

n_components: 5
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 10
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 5
calibrate            : False



In [7]:
# Location of the paired English / Shakespearean text samples (relative to the
# kernel's working directory).
data_dir = "../data/languages"

# The two concepts to learn steering directions for.
concept_types = ['english', 'shakespeare']

# Build per-concept train/test splits; each split is indexed below by
# data[concept]['train' | 'test'].
data = shakespeare_dataset(data_dir, concept_types, controller, assistant_tag)

train 200 test 200
train 200 test 200


In [8]:
# Train one RFM-based NeuralController per concept: each learns directions from
# that concept's training inputs/labels and is stored under its concept name.
controllers = {}
for concept_type in tqdm(concept_types):

    # The contrasting concept; not used within this cell (presumably kept for
    # later cells — verify before removing).
    other_type = [k for k in concept_types if k != concept_type][0]

    concept_split = data[concept_type]
    train_data, test_data = concept_split['train'], concept_split['test']

    language_controller = NeuralController(
        language_model,
        tokenizer,
        control_method='rfm',
        n_components=5,
        batch_size=2,
        rfm_iters=8,
    )
    language_controller.compute_directions(
        train_data['inputs'],
        train_data['labels'],
    )

    controllers[concept_type] = language_controller
    

  0%|          | 0/2 [00:00<?, ?it/s]

n_components: 5
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 5
calibrate            : False

Tuning metric: auc
Getting activations from forward passes



  0%|          | 0/80 [00:00<?, ?it/s][A
  1%|▏         | 1/80 [00:00<00:16,  4.69it/s][A
  2%|▎         | 2/80 [00:00<00:13,  5.64it/s][A
  5%|▌         | 4/80 [00:00<00:09,  7.91it/s][A
  8%|▊         | 6/80 [00:00<00:08,  8.91it/s][A
 10%|█         | 8/80 [00:00<00:07,  9.45it/s][A
 12%|█▎        | 10/80 [00:01<00:07,  9.78it/s][A
 15%|█▌        | 12/80 [00:01<00:06, 10.00it/s][A
 18%|█▊        | 14/80 [00:01<00:06, 10.15it/s][A
 20%|██        | 16/80 [00:01<00:06, 10.22it/s][A
 22%|██▎       | 18/80 [00:01<00:06, 10.25it/s][A
 25%|██▌       | 20/80 [00:02<00:05, 10.30it/s][A
 28%|██▊       | 22/80 [00:02<00:05, 10.32it/s][A
 30%|███       | 24/80 [00:02<00:05, 10.34it/s][A
 32%|███▎      | 26/80 [00:02<00:05, 10.33it/s][A
 35%|███▌      | 28/80 [00:02<00:05, 10.36it/s][A
 38%|███▊      | 30/80 [00:03<00:04, 10.36it/s][A
 40%|████      | 32/80 [00:03<00:04, 10.38it/s][A
 42%|████▎     | 34/80 [00:03<00:04, 10.39it/s][A
 45%|████▌     | 36/80 [00:03<00:04, 10.40it

Getting activations from forward passes



  0%|          | 0/20 [00:00<?, ?it/s][A
 10%|█         | 2/20 [00:00<00:01, 10.47it/s][A
 20%|██        | 4/20 [00:00<00:01, 10.35it/s][A
 30%|███       | 6/20 [00:00<00:01, 10.29it/s][A
 40%|████      | 8/20 [00:00<00:01, 10.30it/s][A
 50%|█████     | 10/20 [00:00<00:00, 10.33it/s][A
 60%|██████    | 12/20 [00:01<00:00, 10.31it/s][A
 70%|███████   | 14/20 [00:01<00:00, 10.27it/s][A
 80%|████████  | 16/20 [00:01<00:00, 10.28it/s][A
 90%|█████████ | 18/20 [00:01<00:00, 10.31it/s][A
100%|██████████| 20/20 [00:01<00:00, 10.30it/s][A

  0%|          | 0/31 [00:00<?, ?it/s][A

train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.04992985725402832 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0043735504150390625 seconds
Early stopping at iteration 2
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.006851911544799805 seconds
Optimal M batch size: 160
Time taken for round 1: 0.004742860794067383 seconds
Early stopping at iteration 2
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003595590591430664 seconds
Early stopping at iteration 1
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003072023391723633 seconds
Early 


  3%|▎         | 1/31 [00:00<00:06,  4.63it/s][A

Time taken to compute eigenvectors: 0.048282623291015625 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0039844512939453125 seconds
Optimal M batch size: 160
Time taken for round 1: 0.004174947738647461 seconds
Optimal M batch size: 160
Time taken for round 2: 0.004091739654541016 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0035622119903564453 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003389596939086914 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0035300254821777344 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004000663757324219 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0039691925048828125 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160


  6%|▋         | 2/31 [00:00<00:09,  3.08it/s][A

Time taken for round 0: 0.1859581470489502 seconds
Optimal M batch size: 160
Time taken for round 1: 0.006491661071777344 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003314971923828125 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0032677650451660156 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003195047378540039 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0031774044036865234 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003225564956665039 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033071041107177734 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003387928009033203 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0035257339477539062 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003360748291015625 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0036280155181884766 secon


 10%|▉         | 3/31 [00:00<00:07,  3.63it/s][A

Optimal M batch size: 160
Time taken for round 6: 0.004434823989868164 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033674240112304688 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0033195018768310547 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003516674041748047 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003524303436279297 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034363269805908203 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003376483917236328 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0033822059631347656 seconds
Optimal M batch size: 160
Time taken for round 6: 0.00397038459777832 seconds
Optimal M batch size: 160
Time taken for round 7: 0.00556492805480957 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.


 13%|█▎        | 4/31 [00:01<00:06,  3.95it/s][A

Time taken to compute eigenvectors: 0.010809183120727539 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0037703514099121094 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034029483795166016 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033349990844726562 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0035572052001953125 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003403186798095703 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034792423248291016 seconds
Optimal M batch size: 160
Time taken for round 6: 0.005643367767333984 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003404378890991211 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 16


 16%|█▌        | 5/31 [00:01<00:06,  4.17it/s][A

Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20253562927246094 seconds
Time taken to compute eigenvectors: 0.010241985321044922 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0034530162811279297 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003408193588256836 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0032825469970703125 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003397226333618164 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003329753875732422 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003317594528198242 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0034723281860351562 seconds
Optimal M batch size: 160
Time taken for 


 19%|█▉        | 6/31 [00:01<00:05,  4.29it/s][A

Optimal M batch size: 160
Time taken for round 3: 0.0052640438079833984 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003334522247314453 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0036094188690185547 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0037627220153808594 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0036628246307373047 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20580077171325684 seconds
Time taken to compute eigenvectors: 0.009146451950073242 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003296375274658203 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034580230712890625 seconds
Optimal M batch size: 160
Time taken fo


 23%|██▎       | 7/31 [00:01<00:05,  4.59it/s][A

Time taken for round 7: 0.00680851936340332 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.17367219924926758 seconds
Time taken to compute eigenvectors: 0.008723020553588867 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003402233123779297 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034437179565429688 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0034444332122802734 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003510713577270508 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003458738327026367 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0037817955017089844 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0037934780120849


 26%|██▌       | 8/31 [00:01<00:04,  4.60it/s][A

Optimal M batch size: 160
Time taken for round 3: 0.0049591064453125 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003478527069091797 seconds
Optimal M batch size: 160
Time taken for round 5: 0.005083560943603516 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003941535949707031 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0032384395599365234 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20673537254333496 seconds
Time taken to compute eigenvectors: 0.0055425167083740234 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003105640411376953 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003527402877807617 seconds
Optimal M batch size: 160
Time taken for rou


 29%|██▉       | 9/31 [00:02<00:04,  4.85it/s][A
 32%|███▏      | 10/31 [00:02<00:04,  5.08it/s][A

Time taken to compute eigenvectors: 0.013906717300415039 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003154754638671875 seconds
Early stopping at iteration 1
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.002986907958984375 seconds
Early stopping at iteration 1
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003816843032836914 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0033698081970214844 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0034646987915039062 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0036368370056152344 seconds
Optimal M batch size: 160
Time taken for round


 35%|███▌      | 11/31 [00:02<00:03,  5.14it/s][A

Optimal M batch size: 160
Time taken for round 0: 0.0044329166412353516 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003361940383911133 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003259897232055664 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003297090530395508 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0035657882690429688 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0033292770385742188 seconds
Optimal M batch size: 160
Time taken for round 6: 0.005382061004638672 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003382444381713867 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003304004669189453 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003870725631713867 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0036153793334960938 seconds
Optimal M batch size: 160
Time taken for round 3:


 39%|███▊      | 12/31 [00:02<00:03,  5.20it/s][A

Optimal M batch size: 160
Time taken for round 4: 0.005306243896484375 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034291744232177734 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003305196762084961 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0034520626068115234 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0033111572265625 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003345966339111328 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0032935142517089844 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0038213729858398438 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033261775970458984 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0032987594604492188 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0038535594940185547 seconds
Optimal M batch size: 160
Time taken for round 7


 42%|████▏     | 13/31 [00:02<00:03,  4.91it/s][A

Time taken for round 4: 0.004360675811767578 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003177165985107422 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003286600112915039 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003213644027709961 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0032341480255126953 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034224987030029297 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033783912658691406 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003347158432006836 seconds
Optimal M batch size: 160
Time taken for round 4: 0.004978179931640625 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0037689208984375 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0032644271850585938 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003211498260498047 second


 45%|████▌     | 14/31 [00:03<00:03,  4.98it/s][A

Optimal M batch size: 160
Time taken for round 2: 0.004702329635620117 seconds
Early stopping at iteration 3
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003161191940307617 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0033402442932128906 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033712387084960938 seconds
Optimal M batch size: 160
Time taken for round 3: 0.004877567291259766 seconds
Optimal M batch size: 160
Time taken for round 4: 0.004259586334228516 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003348827362060547 seconds
Optimal M batch size: 160
Time taken for round 6: 0.00327301025390625 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003315448760986328 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0030150413513183594 seconds
Optimal M batch size: 


 48%|████▊     | 15/31 [00:03<00:03,  4.77it/s][A

Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0033159255981445312 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003476858139038086 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033867359161376953 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0032532215118408203 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003299236297607422 seconds
Optimal M batch size: 160
Time taken for round 5: 0.004470109939575195 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004413604736328125 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033409595489501953 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0030031204223632812 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034346580505371094 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033211708068847656 se


 52%|█████▏    | 16/31 [00:03<00:03,  4.62it/s][A

Time taken to compute eigenvectors: 0.01989006996154785 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0038030147552490234 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0036792755126953125 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003686189651489258 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0038809776306152344 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003579378128051758 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034461021423339844 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0035402774810791016 seconds
Optimal M batch size: 160
Time taken for round 7: 0.004495143890380859 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160


 55%|█████▍    | 17/31 [00:03<00:03,  4.53it/s][A

Time taken for round 6: 0.0050814151763916016 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033597946166992188 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20655059814453125 seconds
Time taken to compute eigenvectors: 0.01957416534423828 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.004370450973510742 seconds
Optimal M batch size: 160
Time taken for round 1: 0.004621744155883789 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033528804779052734 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033299922943115234 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003263235092163086 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003224372863769


 58%|█████▊    | 18/31 [00:03<00:02,  4.47it/s][A

Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0031714439392089844 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003638029098510742 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0037865638732910156 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033686161041259766 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033659934997558594 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034744739532470703 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0037293434143066406 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003988504409790039 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20595502853393555 seconds
Time taken to compute eigenvectors: 0.021345853805541992 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40,


 61%|██████▏   | 19/31 [00:04<00:02,  4.45it/s][A

Optimal M batch size: 160
Time taken for round 1: 0.0046427249908447266 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0032765865325927734 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003729581832885742 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003398418426513672 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034613609313964844 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0033664703369140625 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003410816192626953 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0034096240997314453 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003195524215698242 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0031795501708984375 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0031304359436035156 seconds
Optimal M batch size: 160
Time taken for round


 65%|██████▍   | 20/31 [00:04<00:02,  4.46it/s][A

Optimal M batch size: 160
Time taken for round 5: 0.004749774932861328 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003363370895385742 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003657102584838867 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0033485889434814453 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0033578872680664062 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033953189849853516 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034270286560058594 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003555774688720703 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034160614013671875 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003470182418823242 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003699064254760742 seconds
Optimal M batch size: 160
Fitting RFM with ntrai


 68%|██████▊   | 21/31 [00:04<00:02,  4.45it/s][A

Optimal M batch size: 160
Time taken for round 7: 0.005259275436401367 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003294706344604492 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0038213729858398438 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0035474300384521484 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003344297409057617 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003313302993774414 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003284454345703125 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0034284591674804688 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033087730407714844 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0031213760375976562 seconds
Optimal M batch size: 160
Time taken for round 1:


 71%|███████   | 22/31 [00:04<00:02,  4.44it/s][A

Optimal M batch size: 160
Time taken for round 1: 0.004453897476196289 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033643245697021484 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003710508346557617 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003336668014526367 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0035076141357421875 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0038313865661621094 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033288002014160156 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003198862075805664 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0033190250396728516 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0038514137268066406 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033969879150390625 seconds
Optimal M batch size: 160
Time taken for round


 74%|███████▍  | 23/31 [00:05<00:01,  4.47it/s][A

Optimal M batch size: 160
Time taken for round 5: 0.004540205001831055 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004233837127685547 seconds
Optimal M batch size: 160
Time taken for round 7: 0.004389286041259766 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0029349327087402344 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0031588077545166016 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0031843185424804688 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0032422542572021484 seconds
Optimal M batch size: 160
Time taken for round 4: 0.00321197509765625 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003185749053955078 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003769397735595703 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0034372806549072266 seconds
Optimal M batch size: 160
Fitting RFM with ntrain


 77%|███████▋  | 24/31 [00:05<00:01,  4.45it/s][A

Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003239870071411133 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003288745880126953 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0032835006713867188 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003409147262573242 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003278970718383789 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003298521041870117 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0050389766693115234 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003886699676513672 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003293752670288086 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0036897659301757812 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003542661666870117 second


 81%|████████  | 25/31 [00:07<00:05,  1.07it/s][A

Time taken to compute eigenvectors: 2.3817763328552246 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003683805465698242 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003255605697631836 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003230571746826172 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003176450729370117 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003236532211303711 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003326416015625 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0036804676055908203 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033998489379882812 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time t


 84%|████████▍ | 26/31 [00:08<00:03,  1.38it/s][A

Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20157909393310547 seconds
Time taken to compute eigenvectors: 0.019203901290893555 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0034978389739990234 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003794431686401367 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003572225570678711 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0032968521118164062 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033309459686279297 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003336668014526367 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0032973289489746094 seconds
Optimal M batch size: 160
Time taken for


 87%|████████▋ | 27/31 [00:08<00:02,  1.74it/s][A

Optimal M batch size: 160
Time taken for round 2: 0.0042798519134521484 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033774375915527344 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0032949447631835938 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003301858901977539 seconds
Optimal M batch size: 160
Time taken for round 6: 0.00532984733581543 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003464937210083008 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20359492301940918 seconds
Time taken to compute eigenvectors: 0.019266128540039062 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003408670425415039 seconds
Optimal M batch size: 160
Time taken for r


 90%|█████████ | 28/31 [00:08<00:01,  2.13it/s][A

Optimal M batch size: 160
Time taken for round 4: 0.00460362434387207 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0033447742462158203 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0037620067596435547 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003534555435180664 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0031228065490722656 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003449678421020508 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0053255558013916016 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034859180450439453 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003220081329345703 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0032129287719726562 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003223419189453125 seconds
Optimal M batch size: 160
Time taken for round 7


 94%|█████████▎| 29/31 [00:08<00:00,  2.51it/s][A

Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003057718276977539 seconds
Optimal M batch size: 160
Time taken for round 1: 0.00334930419921875 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003899812698364258 seconds
Optimal M batch size: 160
Time taken for round 3: 0.01645207405090332 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033006668090820312 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003386974334716797 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0032312870025634766 seconds
Optimal M batch size: 160
Time taken for round 7: 0.004414796829223633 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003020048141479492 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003525972366333008 seconds
Optimal M batch size: 160
Time taken for round 2: 0.00


 97%|█████████▋| 30/31 [00:09<00:00,  2.88it/s][A

Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0032062530517578125 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003437519073486328 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0036361217498779297 seconds
Optimal M batch size: 160
Time taken for round 3: 0.005038738250732422 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0036377906799316406 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003311634063720703 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003283262252807617 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003333568572998047 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0031003952026367188 seconds
Optimal M batch size: 160
Time taken for round 1: 0.004163980484008789 seconds
Optimal M batch size: 160
Time taken for round 2: 


100%|██████████| 31/31 [00:09<00:00,  3.34it/s][A


Time taken for round 1: 0.0046460628509521484 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003938436508178711 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003934144973754883 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003552675247192383 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0033833980560302734 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003531932830810547 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0039904117584228516 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0032088756561279297 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003464937210083008 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003641366958618164 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0037970542907714844 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0036766529083251953 se


100%|██████████| 31/31 [00:00<00:00, 15061.21it/s]
 50%|█████     | 1/2 [00:19<00:19, 19.24s/it]

n_components: 5
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 5
calibrate            : False

Tuning metric: auc
Getting activations from forward passes



  0%|          | 0/80 [00:00<?, ?it/s][A
  1%|▏         | 1/80 [00:00<00:09,  8.08it/s][A
  2%|▎         | 2/80 [00:00<00:09,  7.93it/s][A
  4%|▍         | 3/80 [00:00<00:09,  7.87it/s][A
  5%|▌         | 4/80 [00:00<00:09,  7.87it/s][A
  6%|▋         | 5/80 [00:00<00:09,  7.85it/s][A
  8%|▊         | 6/80 [00:00<00:09,  7.84it/s][A
  9%|▉         | 7/80 [00:00<00:09,  7.85it/s][A
 10%|█         | 8/80 [00:01<00:09,  7.84it/s][A
 11%|█▏        | 9/80 [00:01<00:09,  7.83it/s][A
 12%|█▎        | 10/80 [00:01<00:08,  7.83it/s][A
 14%|█▍        | 11/80 [00:01<00:08,  7.84it/s][A
 15%|█▌        | 12/80 [00:01<00:08,  7.85it/s][A
 16%|█▋        | 13/80 [00:01<00:08,  7.83it/s][A
 18%|█▊        | 14/80 [00:01<00:08,  7.82it/s][A
 19%|█▉        | 15/80 [00:01<00:08,  7.84it/s][A
 20%|██        | 16/80 [00:02<00:08,  7.85it/s][A
 21%|██▏       | 17/80 [00:02<00:08,  7.85it/s][A
 22%|██▎       | 18/80 [00:02<00:07,  7.82it/s][A
 24%|██▍       | 19/80 [00:02<00:07,  7.82it/s]

Getting activations from forward passes



  0%|          | 0/20 [00:00<?, ?it/s][A
 10%|█         | 2/20 [00:00<00:01, 10.34it/s][A
 20%|██        | 4/20 [00:00<00:01, 10.26it/s][A
 30%|███       | 6/20 [00:00<00:01, 10.29it/s][A
 40%|████      | 8/20 [00:00<00:01, 10.30it/s][A
 50%|█████     | 10/20 [00:00<00:00, 10.33it/s][A
 60%|██████    | 12/20 [00:01<00:00, 10.33it/s][A
 70%|███████   | 14/20 [00:01<00:00, 10.34it/s][A
 80%|████████  | 16/20 [00:01<00:00, 10.34it/s][A
 90%|█████████ | 18/20 [00:01<00:00, 10.34it/s][A
100%|██████████| 20/20 [00:01<00:00, 10.32it/s][A

  0%|          | 0/31 [00:00<?, ?it/s][A
  3%|▎         | 1/31 [00:00<00:04,  6.61it/s][A

train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.004344463348388672 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0041081905364990234 seconds
Optimal M batch size: 160
Time taken for round 2: 0.005003690719604492 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034172534942626953 seconds
Early stopping at iteration 4
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003628969192504883 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0037872791290283203 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003534078598022461 seconds
Early stopping at iteration 3
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time t


  6%|▋         | 2/31 [00:02<00:45,  1.59s/it][A

Time taken to compute eigenvectors: 2.3756370544433594 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.004022359848022461 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0035851001739501953 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0035016536712646484 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034630298614501953 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033681392669677734 seconds
Optimal M batch size: 160
Time taken for round 5: 0.004008054733276367 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004733085632324219 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003255605697631836 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
T


 10%|▉         | 3/31 [00:05<00:57,  2.04s/it][A

Time taken to compute eigenvectors: 2.3810348510742188 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003952980041503906 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0038437843322753906 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0036003589630126953 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003336668014526367 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0034105777740478516 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003462076187133789 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0034644603729248047 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003367900848388672 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
T


 13%|█▎        | 4/31 [00:05<00:35,  1.33s/it][A

Optimal M batch size: 160
Time taken for round 6: 0.004652500152587891 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003439664840698242 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.2072451114654541 seconds
Time taken to compute eigenvectors: 0.02654242515563965 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003328084945678711 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003648042678833008 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003482341766357422 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0035572052001953125 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0057010650634765625 seconds
Optimal M batch size: 160
Time taken for ro


 16%|█▌        | 5/31 [00:05<00:24,  1.07it/s][A

Time taken for round 6: 0.0045986175537109375 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0034673213958740234 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0029630661010742188 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0032341480255126953 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0031976699829101562 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003226041793823242 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003182649612426758 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003192901611328125 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003239154815673828 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0037865638732910156 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.20383644104003906 seconds
Time take


 19%|█▉        | 6/31 [00:06<00:17,  1.44it/s][A

Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003290891647338867 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003271341323852539 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003254413604736328 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033483505249023438 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003315448760986328 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0032625198364257812 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004651784896850586 seconds
Optimal M batch size: 160
Time taken for round 7: 0.004251956939697266 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0031778812408447266 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0036864280700683594 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0036079883575439453 seco


 23%|██▎       | 7/31 [00:06<00:13,  1.85it/s][A

Optimal M batch size: 160
Time taken for round 2: 0.0060918331146240234 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033690929412841797 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003295421600341797 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003278970718383789 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0033173561096191406 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003240823745727539 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0050661563873291016 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0035860538482666016 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003364086151123047 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003406524658203125 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003609895706176758 seconds
Optimal M batch size: 160
Time taken for round 5


 26%|██▌       | 8/31 [00:06<00:10,  2.26it/s][A

Optimal M batch size: 160
Time taken for round 4: 0.00435948371887207 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003329038619995117 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0032987594604492188 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003204345703125 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003443002700805664 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0035271644592285156 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003488302230834961 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003386974334716797 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0035316944122314453 seconds
Optimal M batch size: 160
Time taken for round 5: 0.005343437194824219 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003321409225463867 seconds
Optimal M batch size: 160
Time taken for round 7: 0.00


 29%|██▉       | 9/31 [00:06<00:08,  2.69it/s][A

Time taken for round 0: 0.004334211349487305 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003504037857055664 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0036003589630126953 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034525394439697266 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003332376480102539 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0033283233642578125 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003265380859375 seconds
Optimal M batch size: 160
Time taken for round 7: 0.004828691482543945 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0030329227447509766 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0035250186920166016 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0032472610473632812 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0031731128692626953 seco


 32%|███▏      | 10/31 [00:06<00:06,  3.22it/s][A

Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003181934356689453 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003464221954345703 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003325939178466797 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033011436462402344 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033082962036132812 seconds
Optimal M batch size: 160
Time taken for round 5: 0.004398345947265625 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004506826400756836 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003299713134765625 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0029795169830322266 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0032613277435302734 seconds
Optimal M batch size: 160
Time taken for round 2: 


 35%|███▌      | 11/31 [00:07<00:05,  3.70it/s][A

Time taken for round 6: 0.005358219146728516 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003383159637451172 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003610372543334961 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0040891170501708984 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003535032272338867 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033216476440429688 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003569364547729492 seconds
Optimal M batch size: 160
Time taken for round 5: 0.004019021987915039 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004031658172607422 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0035746097564697266 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0034303665161132812 secon


 39%|███▊      | 12/31 [00:07<00:04,  4.14it/s][A

Optimal M batch size: 160
Time taken for round 4: 0.004734039306640625 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034799575805664062 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003323078155517578 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003438234329223633 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0031325817108154297 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0031948089599609375 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0031647682189941406 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003237009048461914 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003196239471435547 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0031964778900146484 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003217458724975586 seconds
Optimal M batch size: 160
Time taken for round 7


 42%|████▏     | 13/31 [00:07<00:04,  4.48it/s][A

Optimal M batch size: 160
Time taken for round 3: 0.005311489105224609 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0036249160766601562 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003414630889892578 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0033271312713623047 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003341197967529297 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.15221738815307617 seconds
Time taken to compute eigenvectors: 0.02454996109008789 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003889799118041992 seconds
Early stopping at iteration 1
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M bat


 45%|████▌     | 14/31 [00:07<00:03,  4.65it/s][A

Optimal M batch size: 160
Time taken for round 5: 0.0052547454833984375 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0036973953247070312 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033731460571289062 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 10, center_grads: True
Time taken to train rfm probe: 0.16866159439086914 seconds
Time taken to compute eigenvectors: 0.023534774780273438 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0036156177520751953 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034165382385253906 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0037217140197753906 seconds
Optimal M batch size: 160
Time taken for round 3: 0.004900455474853516 seconds
Optimal M batch size: 160
Time taken 


 48%|████▊     | 15/31 [00:10<00:14,  1.07it/s][A

Time taken to compute eigenvectors: 2.3749217987060547 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.00529789924621582 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0033998489379882812 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033538341522216797 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003242015838623047 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033066272735595703 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0032989978790283203 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003225088119506836 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0032529830932617188 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
T


 52%|█████▏    | 16/31 [00:10<00:10,  1.39it/s][A

Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20232224464416504 seconds
Time taken to compute eigenvectors: 0.014586448669433594 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.00546717643737793 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003418445587158203 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0032944679260253906 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003574848175048828 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0037233829498291016 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003701448440551758 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003531217575073242 seconds
Optimal M batch size: 160
Time taken for ro


 55%|█████▍    | 17/31 [00:10<00:08,  1.74it/s][A

Optimal M batch size: 160
Time taken for round 1: 0.005792856216430664 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003337860107421875 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0032629966735839844 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003258943557739258 seconds
Optimal M batch size: 160
Time taken for round 5: 0.00323486328125 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003500223159790039 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0032410621643066406 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.21011614799499512 seconds
Time taken to compute eigenvectors: 0.027493715286254883 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round


 58%|█████▊    | 18/31 [00:10<00:06,  2.12it/s][A

Optimal M batch size: 160
Time taken for round 1: 0.00452113151550293 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003353118896484375 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003314971923828125 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033686161041259766 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003655672073364258 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0035872459411621094 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0036101341247558594 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003137350082397461 seconds
Optimal M batch size: 160
Time taken for round 1: 0.004581928253173828 seconds
Optimal M batch size: 160
Time taken for round 2: 0.004245758056640625 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0032596588134765625 seconds
Optimal M batch size: 160
Time taken for round 4: 


 61%|██████▏   | 19/31 [00:11<00:04,  2.52it/s][A

Optimal M batch size: 160
Time taken for round 4: 0.004940509796142578 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034041404724121094 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0035283565521240234 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033884048461914062 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0030257701873779297 seconds
Optimal M batch size: 160
Time taken for round 1: 0.004758119583129883 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0044438838958740234 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034706592559814453 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033369064331054688 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0032672882080078125 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0034253597259521484 seconds
Optimal M batch size: 160
Time taken for rou


 65%|██████▍   | 20/31 [00:11<00:03,  2.88it/s][A

Optimal M batch size: 160
Time taken for round 7: 0.005202531814575195 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0032656192779541016 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003365039825439453 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0034494400024414062 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003865480422973633 seconds
Optimal M batch size: 160
Time taken for round 4: 0.004950761795043945 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0036156177520751953 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0033299922943115234 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0033216476440429688 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0030601024627685547 seconds
Optimal M batch size: 160
Time taken for round 1


 68%|██████▊   | 21/31 [00:11<00:03,  3.21it/s][A

Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0032808780670166016 seconds
Optimal M batch size: 160
Time taken for round 1: 0.004057645797729492 seconds
Optimal M batch size: 160
Time taken for round 2: 0.004170894622802734 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003573894500732422 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003360271453857422 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003557920455932617 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004010677337646484 seconds
Optimal M batch size: 160
Time taken for round 7: 0.004029035568237305 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0031974315643310547 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003974199295043945 seconds
Optimal M batch size: 160
Time taken for round 2: 0.


 71%|███████   | 22/31 [00:11<00:02,  3.49it/s][A

Optimal M batch size: 160
Time taken for round 2: 0.004500865936279297 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034563541412353516 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003360748291015625 seconds
Optimal M batch size: 160
Time taken for round 5: 0.004083156585693359 seconds
Optimal M batch size: 160
Time taken for round 6: 0.004668235778808594 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0032994747161865234 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.002925395965576172 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0032341480255126953 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003240823745727539 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033185482025146484 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003265380859375 seconds
Optimal M batch size: 160
Time taken for round 5: 0.


 74%|███████▍  | 23/31 [00:12<00:02,  3.69it/s][A

Time taken to compute eigenvectors: 0.019321918487548828 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003465890884399414 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034613609313964844 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003434896469116211 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034863948822021484 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003581523895263672 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003519296646118164 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003585338592529297 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0034804344177246094 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160



 77%|███████▋  | 24/31 [00:12<00:01,  3.86it/s][A

Time taken for round 6: 0.005013227462768555 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0034995079040527344 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20755863189697266 seconds
Time taken to compute eigenvectors: 0.019521236419677734 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003842592239379883 seconds
Optimal M batch size: 160
Time taken for round 1: 0.00516963005065918 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0034449100494384766 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033638477325439453 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033197402954101562 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003303527832031


 81%|████████  | 25/31 [00:12<00:01,  4.05it/s][A

Optimal M batch size: 160
Time taken for round 0: 0.004070758819580078 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0031974315643310547 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003187417984008789 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0032231807708740234 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0031561851501464844 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003743886947631836 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0033502578735351562 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0034775733947753906 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20051050186157227 seconds
Time taken to compute eigenvectors: 0.014869928359985352 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM 


 84%|████████▍ | 26/31 [00:12<00:01,  4.15it/s][A

Optimal M batch size: 160
Time taken for round 4: 0.004725217819213867 seconds
Optimal M batch size: 160
Time taken for round 5: 0.005511045455932617 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0034825801849365234 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003297567367553711 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0035932064056396484 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0035517215728759766 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033223628997802734 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0033881664276123047 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0033311843872070312 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003342866897583008 seconds
Optimal M batch size: 160
Time taken for round 6: 0.003415822982788086 seconds
Optimal M batch size: 160
Time taken for round 


 87%|████████▋ | 27/31 [00:12<00:00,  4.21it/s][A

Optimal M batch size: 160
Time taken for round 7: 0.004698276519775391 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0032122135162353516 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003319978713989258 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003643035888671875 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003736734390258789 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0034923553466796875 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034830570220947266 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0033380985260009766 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003375530242919922 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0033655166625976562 seconds
Optimal M batch size: 160
Time taken for round 1:


 90%|█████████ | 28/31 [00:15<00:02,  1.19it/s][A

Time taken to compute eigenvectors: 2.0419180393218994 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.0035288333892822266 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003889799118041992 seconds
Optimal M batch size: 160
Time taken for round 2: 0.0033037662506103516 seconds
Optimal M batch size: 160
Time taken for round 3: 0.003620147705078125 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0038230419158935547 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003322124481201172 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0034966468811035156 seconds
Optimal M batch size: 160
Time taken for round 7: 0.0037479400634765625 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160



 94%|█████████▎| 29/31 [00:17<00:02,  1.37s/it][A

Time taken to compute eigenvectors: 2.3998405933380127 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003974437713623047 seconds
Optimal M batch size: 160
Time taken for round 1: 0.003599882125854492 seconds
Optimal M batch size: 160
Time taken for round 2: 0.003481149673461914 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0034089088439941406 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0034494400024414062 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0034780502319335938 seconds
Optimal M batch size: 160
Time taken for round 6: 0.005009651184082031 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003706693649291992 seconds
Optimal M batch size: 160
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Ti


 97%|█████████▋| 30/31 [00:18<00:01,  1.03s/it][A

Optimal M batch size: 160
Time taken for round 7: 0.00455164909362793 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20527362823486328 seconds
Time taken to compute eigenvectors: 0.010685443878173828 seconds
train X shape: torch.Size([160, 4096]) train y shape: torch.Size([160, 1]) val X shape: torch.Size([40, 4096]) val y shape: torch.Size([40, 1])
Fitting RFM with ntrain: 160, d: 4096, and nval: 40
Optimal M batch size: 160
Time taken for round 0: 0.003370523452758789 seconds
Optimal M batch size: 160
Time taken for round 1: 0.0034034252166748047 seconds
Optimal M batch size: 160
Time taken for round 2: 0.004860877990722656 seconds
Optimal M batch size: 160
Time taken for round 3: 0.004319429397583008 seconds
Optimal M batch size: 160
Time taken for round 4: 0.003410816192626953 seconds
Optimal M batch size: 160
Time taken for round 5: 0.0033121109008789062 seconds
Optimal M batch size: 160
Time taken for ro


100%|██████████| 31/31 [00:18<00:00,  1.70it/s][A


Optimal M batch size: 160
Time taken for round 2: 0.004784345626831055 seconds
Optimal M batch size: 160
Time taken for round 3: 0.0038933753967285156 seconds
Optimal M batch size: 160
Time taken for round 4: 0.0032815933227539062 seconds
Optimal M batch size: 160
Time taken for round 5: 0.003480672836303711 seconds
Optimal M batch size: 160
Time taken for round 6: 0.0034744739532470703 seconds
Optimal M batch size: 160
Time taken for round 7: 0.003577709197998047 seconds
Optimal M batch size: 160
Best RFM auc: 1.0, reg: 0.001, bw: 1, center_grads: True
Time taken to train rfm probe: 0.20703125 seconds
Time taken to compute eigenvectors: 0.014101505279541016 seconds



100%|██████████| 31/31 [00:00<00:00, 15823.71it/s]
100%|██████████| 2/2 [00:49<00:00, 24.89s/it]


In [9]:
# Persist every trained controller. The saved concept name records both the
# concept itself and the one it was contrasted against, e.g. 'english_shakespeare'.
for concept_type in concept_types:
    controller = controllers[concept_type]
    # With two concepts this is simply "the other one".
    other_type = [name for name in concept_types if name != concept_type][0]
    controller.save(
        concept=f'{concept_type}_{other_type}',
        model_name=model_name,
        path='../directions/',
    )

# Control

In [10]:
concept_types = ['english', 'shakespeare']
# concept_types = ['english', 'german']

controllers = {}

# Build a fresh RFM controller per concept and restore its directions from
# '../directions/' using the '<concept>_<other concept>' naming used by controller.save.
for concept_type in concept_types:
    controller = NeuralController(
        language_model,
        tokenizer,
        rfm_iters=8,
        control_method='rfm',
    )

    # With two concepts this is simply "the other one".
    other_type = [name for name in concept_types if name != concept_type][0]

    controller.load(
        concept=f'{concept_type}_{other_type}',
        model_name=model_name,
        path='../directions/',
    )

    controllers[concept_type] = controller

n_components: 5
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 5
calibrate            : False

Detector found
n_components: 5
Hidden layers: [-1, -2, -3, -4, -5, -6, -7, -8, -9, -10, -11, -12, -13, -14, -15, -16, -17, -18, -19, -20, -21, -22, -23, -24, -25, -26, -27, -28, -29, -30, -31]

Controller hyperparameters:
control_method       : rfm
rfm_iters            : 8
forward_batch_size   : 2
M_batch_size         : 2048
n_components         : 5
calibrate            : False

Detector found


In [14]:
# Compare greedy generations without and with steering by one concept controller.
concept_type = "english"
language_controller = controllers[concept_type]

raw_inputs = [
    # "How are you today?",
    # "What can I buy in a grocery store?",
    "What can I do to treat flu symptoms?",
]
# Wrap each raw question in the model's chat template.
inputs = [language_controller.format_prompt(x) for x in raw_inputs]
num_new_tokens = 150


# llama
coef = 0.6
# Controls layers -1 .. -30.
# NOTE(review): the controller exposes layers down to -31; confirm -31 is excluded on purpose.
layers = list(range(-1, -31, -1))

# gemma
# coef = 9
# layers = list(range(-1, -41, -1))

# Llama-3 header that opens the assistant turn. It appears in the decoded text right
# after the prompt, so strip it from BOTH generations so they are printed the same way.
assistant_tag = '<|start_header_id|>assistant<|end_header_id|>'
gens = []  # NOTE(review): not appended to in this cell
print()
for i in inputs:
    print("Prompt:", i)
    # The decoded output appears to echo the prompt first; drop that many characters.
    start_idx = len(i)

    print("===== No Control =====")
    gen1 = language_controller.generate(
        i, max_new_tokens=num_new_tokens, do_sample=False
    )[start_idx:].replace(assistant_tag, '')
    print(gen1)
    print()

    print(f"===== + {concept_type} Control =====")
    gen2 = language_controller.generate(
        i, layers_to_control=layers, control_coef=coef,
        max_new_tokens=num_new_tokens, do_sample=False,
    )[start_idx:].replace(assistant_tag, '')
    print(gen2)
    print()
    print()

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.



Prompt: <|begin_of_text|><|start_header_id|>system<|end_header_id|>

Cutting Knowledge Date: December 2023
Today Date: 26 Jul 2024

<|eot_id|><|start_header_id|>user<|end_header_id|>

What can I do to treat flu symptoms?<|eot_id|>
===== No Control =====


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


<|start_header_id|>assistant<|end_header_id|>

Treating flu symptoms involves a combination of self-care, over-the-counter medications, and home remedies. Here are some effective ways to alleviate flu symptoms:

**Self-Care:**

1. **Stay hydrated**: Drink plenty of fluids, such as water, clear broths, and electrolyte-rich beverages like sports drinks.
2. **Rest**: Get plenty of sleep and rest to help your body recover.
3. **Use a humidifier**: Add moisture to the air to relieve congestion and cough.
4. **Practice good hygiene**: Wash your hands frequently, avoid close contact with others, and avoid sharing utensils or personal items.

**Over-the-Counter Medications:**

1. **Pain relievers**: Use acet

===== + english Control =====


Good sir or madam, thou seekest relief from the woes of the flu? Verily, I shall impart upon thee some tried and true remedies to alleviate the symptoms of this most foul affliction.

Firstly, 'tis essential to rest and conserve thy energies, for the flu do