In [1]:
import tiktoken
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import json

from gpt_model import GPTModel
from clean_gutenberg_collected_work import clean_gutenberg_collected_work
from sparse_auto_encoder import SparseAutoencoder
from train_sae import train_sae

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using {device} device.")

Using cuda device.


In [3]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,
    "context_length": 256,
    "emb_dim": 768,
    "n_heads": 12,
    "n_layers": 12,
    "drop_rate": 0.2,
    "qkv_bias": True,
    "device": device,
}

In [4]:
model = GPTModel(GPT_CONFIG_124M)
checkpoint = torch.load("model_768_12_12_old_tok.pth", weights_only=True)

model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)
model.eval();

In [5]:
tokenizer = tiktoken.get_encoding("gpt2")

In [6]:
import re

def load_and_clean_text(file_path):
    """
    Loads a text file and splits it into sentences while cleaning the text.
    
    Args:
    - file_path (str): Path to the text file.
    
    Returns:
    - list: A list of cleaned sentences from the book.
    """
    
    text = clean_gutenberg_collected_work(file_path)

    # Split text into sentences (simple heuristic using punctuation)
    sentences = re.split(r"(?<=[.!?])\s+", text)

    # Remove very short or long sentences
    sentences = [s.strip() for s in sentences if 5 < len(s.split()) < 60]

    return sentences

In [7]:
directory="original_texts/"
dataset = []

sentences = load_and_clean_text(os.path.join(directory, 'complete_jane_austen.txt'))
dataset += sentences

In [8]:
def text_to_token_ids(text, tokenizer):
    encoded = tokenizer.encode(text, allowed_special={'<|endoftext|>'})
    encoded_tensor = torch.tensor(encoded).unsqueeze(0) # add batch dimension
    return encoded_tensor

def get_token_embeddings(text, model, tokenizer, layers=[6, 12]):
    """
    Extracts token embeddings from specified transformer layers.

    Args:
    - text (str): Input text.
    - model: Custom GPT model.
    - tokenizer: tiktoken encoding object.
    - layers (list): Transformer layers to extract embeddings from.

    Returns:
    - dict: Layer-wise token embeddings {layer_number: embeddings}
    """

    input_ids = text_to_token_ids(text, tokenizer).to(device)

    with torch.no_grad():
        _, hidden_states = model(input_ids, output_hidden_states=True)

    embeddings = {}
    for layer in layers:
        if layer - 1 < len(hidden_states):
            embeddings[layer] = hidden_states[layer - 1].squeeze(0).cpu().numpy()
        else:
            print(f"‚ö†Ô∏è Warning: Layer {layer} is out of range (max index {len(hidden_states) - 1})")

    return embeddings

In [9]:
layer6_embeddings = []
layer12_embeddings = []

for sentence in dataset:
    embeddings = get_token_embeddings(sentence, model, tokenizer)
    layer6_embeddings.append(embeddings[6])
    layer12_embeddings.append(embeddings[12])

# Convert to NumPy and flatten tokens into dataset
layer6_embeddings = np.vstack(layer6_embeddings)
layer12_embeddings = np.vstack(layer12_embeddings)

os.makedirs("sae_data", exist_ok=True)
np.save("sae_data/layer6_embeddings.npy", layer6_embeddings)
np.save("sae_data/layer12_embeddings.npy", layer12_embeddings)

print("Saved token embeddings:")
print(f"Layer 6: {layer6_embeddings.shape}")
print(f"Layer 12: {layer12_embeddings.shape}")

Saved token embeddings:
Layer 6: (718536, 768)
Layer 12: (718536, 768)


In [10]:
layer6_embeddings = np.load("sae_data/layer6_embeddings.npy")
layer12_embeddings = np.load("sae_data/layer12_embeddings.npy")

In [11]:
import optuna
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split
import torch.optim as optim
import torch.nn as nn
from sparse_auto_encoder import SparseAutoencoder

def objective(trial, device="cpu", embeddings_path="training_embeddings.npy"):
    # Hyperparameter search space
    print(50*"=")
    print(f"Trial number {trial.number + 1}")
    print(50*"=")
    batch_size = trial.suggest_categorical("batch_size", [32, 64, 128])
    lr = trial.suggest_loguniform("lr", 1e-4, 1e-2)
    hidden_dim = trial.suggest_int("hidden_dim", 64, 512, step=64)
    weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)

    # Load your embeddings
    embeddings = np.load(embeddings_path)  # Replace with actual file
    embeddings = torch.tensor(embeddings, dtype=torch.float32).to(device)

    input_dim = embeddings.shape[1]
    sae = SparseAutoencoder(input_dim=input_dim, hidden_dim=hidden_dim).to(device)

    optimizer = optim.AdamW(sae.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = nn.MSELoss()

    dataset = TensorDataset(embeddings)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5)

    best_val_loss = float("inf")
    patience, early_stop_counter = 10, 0

    for epoch in range(30):
        sae.train()
        train_loss = 0.0
        for batch in train_loader:
            inputs = batch[0].to(device)
            optimizer.zero_grad()
            outputs, encoded = sae(inputs)
            loss = criterion(outputs, inputs)
            sparsity_loss = torch.norm(encoded, p=1) * 1e-4
            total_loss = loss + sparsity_loss
            total_loss.backward()
            optimizer.step()
            train_loss += total_loss.item()

        sae.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_loader:
                inputs = batch[0].to(device)
                outputs, encoded = sae(inputs)
                loss = criterion(outputs, inputs)
                sparsity_loss = torch.norm(encoded, p=1) * 1e-4
                total_loss = loss + sparsity_loss
                val_loss += total_loss.item()

        train_loss /= len(train_loader)
        val_loss /= len(val_loader)

        scheduler.step(val_loss)
        print(f"Epoch {epoch}: Train loss {train_loss:.4f}, Val loss {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            early_stop_counter = 0
        else:
            early_stop_counter += 1

        if early_stop_counter >= patience:
            break

    return best_val_loss 

In [12]:
import time

start_time = time.time()

# Run hyperparameter tuning
study = optuna.create_study(direction="minimize")
study.optimize(lambda trial: objective(trial, device, embeddings_path="sae_data/layer6_embeddings.npy"), n_trials=15)

# Print best hyperparameters
print("Best hyperparameters:", study.best_params)

end_time = time.time()
execution_time_minutes = (end_time - start_time) / 60
print(f"Search completed in {execution_time_minutes:.2f} minutes.")

[I 2025-03-07 05:49:50,242] A new study created in memory with name: no-name-debab08a-1ae8-4354-b1d7-2e8a93fcaf53
  lr = trial.suggest_loguniform("lr", 1e-4, 1e-2)
  weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)


Trial number 1
Epoch 0: Train loss 3.5013, Val loss 2.7664
Epoch 1: Train loss 2.6397, Val loss 2.5404
Epoch 2: Train loss 2.5101, Val loss 2.4382
Epoch 3: Train loss 2.4494, Val loss 2.3934
Epoch 4: Train loss 2.4061, Val loss 2.3609
Epoch 5: Train loss 2.3730, Val loss 2.3305
Epoch 6: Train loss 2.3432, Val loss 2.3005
Epoch 7: Train loss 2.3171, Val loss 2.2783
Epoch 8: Train loss 2.2942, Val loss 2.2555
Epoch 9: Train loss 2.2737, Val loss 2.2374
Epoch 10: Train loss 2.2572, Val loss 2.2252
Epoch 11: Train loss 2.2438, Val loss 2.2149
Epoch 12: Train loss 2.2321, Val loss 2.2054
Epoch 13: Train loss 2.2236, Val loss 2.2009
Epoch 14: Train loss 2.2164, Val loss 2.1995
Epoch 15: Train loss 2.2110, Val loss 2.1948
Epoch 16: Train loss 2.2063, Val loss 2.1908
Epoch 17: Train loss 2.2020, Val loss 2.1906
Epoch 18: Train loss 2.1982, Val loss 2.1864
Epoch 19: Train loss 2.1944, Val loss 2.1859
Epoch 20: Train loss 2.1912, Val loss 2.1838
Epoch 21: Train loss 2.1882, Val loss 2.1842
Epoch

[I 2025-03-07 06:02:12,080] Trial 0 finished with value: 2.170551270788366 and parameters: {'batch_size': 128, 'lr': 0.0002688498549528964, 'hidden_dim': 320, 'weight_decay': 0.00024172216807324387}. Best is trial 0 with value: 2.170551270788366.


Epoch 29: Train loss 2.1711, Val loss 2.1706
Trial number 2


  lr = trial.suggest_loguniform("lr", 1e-4, 1e-2)
  weight_decay = trial.suggest_loguniform("weight_decay", 1e-6, 1e-3)


Epoch 0: Train loss 2.4140, Val loss 2.3188
Epoch 1: Train loss 2.3013, Val loss 2.3061
Epoch 2: Train loss 2.2921, Val loss 2.2908
Epoch 3: Train loss 2.2870, Val loss 2.3128
Epoch 4: Train loss 2.2838, Val loss 2.3026
Epoch 5: Train loss 2.2817, Val loss 2.3122
Epoch 6: Train loss 2.2804, Val loss 2.2955
Epoch 7: Train loss 2.2793, Val loss 2.2935
Epoch 8: Train loss 2.2782, Val loss 2.2808
Epoch 9: Train loss 2.2774, Val loss 2.3122
Epoch 10: Train loss 2.2773, Val loss 2.2876
Epoch 11: Train loss 2.2768, Val loss 2.2953
Epoch 12: Train loss 2.2759, Val loss 2.2960
Epoch 13: Train loss 2.2757, Val loss 2.2897
Epoch 14: Train loss 2.2754, Val loss 2.2881
Epoch 15: Train loss 2.2601, Val loss 2.2521
Epoch 16: Train loss 2.2585, Val loss 2.2471
Epoch 17: Train loss 2.2585, Val loss 2.2550
Epoch 18: Train loss 2.2587, Val loss 2.2596
Epoch 19: Train loss 2.2585, Val loss 2.2635
Epoch 20: Train loss 2.2582, Val loss 2.2584
Epoch 21: Train loss 2.2581, Val loss 2.2611
Epoch 22: Train loss

[I 2025-03-07 06:22:54,658] Trial 1 finished with value: 2.224995661976669 and parameters: {'batch_size': 64, 'lr': 0.004508739037375544, 'hidden_dim': 192, 'weight_decay': 0.0001467563404464736}. Best is trial 0 with value: 2.170551270788366.


Epoch 29: Train loss 2.2463, Val loss 2.2263
Trial number 3
Epoch 0: Train loss 3.1771, Val loss 2.8172
Epoch 1: Train loss 2.7800, Val loss 2.6882
Epoch 2: Train loss 2.6947, Val loss 2.6175
Epoch 3: Train loss 2.6386, Val loss 2.5764
Epoch 4: Train loss 2.6049, Val loss 2.5540
Epoch 5: Train loss 2.5865, Val loss 2.5463
Epoch 6: Train loss 2.5802, Val loss 2.5443
Epoch 7: Train loss 2.5765, Val loss 2.5453
Epoch 8: Train loss 2.5741, Val loss 2.5435
Epoch 9: Train loss 2.5720, Val loss 2.5434
Epoch 10: Train loss 2.5707, Val loss 2.5442
Epoch 11: Train loss 2.5698, Val loss 2.5418
Epoch 12: Train loss 2.5687, Val loss 2.5428
Epoch 13: Train loss 2.5678, Val loss 2.5433
Epoch 14: Train loss 2.5671, Val loss 2.5422
Epoch 15: Train loss 2.5665, Val loss 2.5457
Epoch 16: Train loss 2.5662, Val loss 2.5436
Epoch 17: Train loss 2.5657, Val loss 2.5425
Epoch 18: Train loss 2.5633, Val loss 2.5331
Epoch 19: Train loss 2.5630, Val loss 2.5343
Epoch 20: Train loss 2.5628, Val loss 2.5353
Epoch

[I 2025-03-07 06:43:08,899] Trial 2 finished with value: 2.5284160660719817 and parameters: {'batch_size': 64, 'lr': 0.00020725071060021414, 'hidden_dim': 64, 'weight_decay': 1.9243970614633706e-06}. Best is trial 0 with value: 2.170551270788366.


Epoch 29: Train loss 2.5604, Val loss 2.5284
Trial number 4
Epoch 0: Train loss 2.8168, Val loss 2.6210
Epoch 1: Train loss 2.5877, Val loss 2.6165
Epoch 2: Train loss 2.5764, Val loss 2.5997
Epoch 3: Train loss 2.5724, Val loss 2.5933
Epoch 4: Train loss 2.5704, Val loss 2.6060
Epoch 5: Train loss 2.5694, Val loss 2.6090
Epoch 6: Train loss 2.5685, Val loss 2.5808
Epoch 7: Train loss 2.5677, Val loss 2.5884
Epoch 8: Train loss 2.5670, Val loss 2.5909
Epoch 9: Train loss 2.5670, Val loss 2.5738
Epoch 10: Train loss 2.5666, Val loss 2.5881
Epoch 11: Train loss 2.5661, Val loss 2.5779
Epoch 12: Train loss 2.5656, Val loss 2.5883
Epoch 13: Train loss 2.5655, Val loss 2.5677
Epoch 14: Train loss 2.5651, Val loss 2.5784
Epoch 15: Train loss 2.5650, Val loss 2.5797
Epoch 16: Train loss 2.5648, Val loss 2.5900
Epoch 17: Train loss 2.5645, Val loss 2.5777
Epoch 18: Train loss 2.5645, Val loss 2.5801
Epoch 19: Train loss 2.5645, Val loss 2.5810
Epoch 20: Train loss 2.5585, Val loss 2.5615
Epoch

[I 2025-03-07 06:55:02,432] Trial 3 finished with value: 2.5543929240920327 and parameters: {'batch_size': 128, 'lr': 0.0017955378450738243, 'hidden_dim': 64, 'weight_decay': 1.980034613217017e-05}. Best is trial 0 with value: 2.170551270788366.


Epoch 29: Train loss 2.5572, Val loss 2.5640
Trial number 5
Epoch 0: Train loss 2.7470, Val loss 2.3950
Epoch 1: Train loss 2.3955, Val loss 2.2924
Epoch 2: Train loss 2.3048, Val loss 2.2097
Epoch 3: Train loss 2.2311, Val loss 2.1628
Epoch 4: Train loss 2.1922, Val loss 2.1437
Epoch 5: Train loss 2.1740, Val loss 2.1422
Epoch 6: Train loss 2.1642, Val loss 2.1383
Epoch 7: Train loss 2.1578, Val loss 2.1427
Epoch 8: Train loss 2.1526, Val loss 2.1437
Epoch 9: Train loss 2.1479, Val loss 2.1372
Epoch 10: Train loss 2.1445, Val loss 2.1387
Epoch 11: Train loss 2.1415, Val loss 2.1432
Epoch 12: Train loss 2.1381, Val loss 2.1517
Epoch 13: Train loss 2.1356, Val loss 2.1578
Epoch 14: Train loss 2.1332, Val loss 2.1525
Epoch 15: Train loss 2.1310, Val loss 2.1545
Epoch 16: Train loss 2.1234, Val loss 2.1080
Epoch 17: Train loss 2.1219, Val loss 2.1100
Epoch 18: Train loss 2.1206, Val loss 2.1085
Epoch 19: Train loss 2.1194, Val loss 2.1072
Epoch 20: Train loss 2.1182, Val loss 2.1102
Epoch

[I 2025-03-07 07:16:17,436] Trial 4 finished with value: 2.083881515848338 and parameters: {'batch_size': 64, 'lr': 0.0003881008304162215, 'hidden_dim': 384, 'weight_decay': 2.3923055206292162e-05}. Best is trial 4 with value: 2.083881515848338.


Epoch 29: Train loss 2.1077, Val loss 2.0868
Trial number 6
Epoch 0: Train loss 2.3573, Val loss 2.2031
Epoch 1: Train loss 2.2544, Val loss 2.1934
Epoch 2: Train loss 2.2470, Val loss 2.1944
Epoch 3: Train loss 2.2429, Val loss 2.1903
Epoch 4: Train loss 2.2405, Val loss 2.1887
Epoch 5: Train loss 2.2391, Val loss 2.1896
Epoch 6: Train loss 2.2378, Val loss 2.1853
Epoch 7: Train loss 2.2372, Val loss 2.1904
Epoch 8: Train loss 2.2364, Val loss 2.1852
Epoch 9: Train loss 2.2354, Val loss 2.1833
Epoch 10: Train loss 2.2351, Val loss 2.1861
Epoch 11: Train loss 2.2346, Val loss 2.1805
Epoch 12: Train loss 2.2338, Val loss 2.1858
Epoch 13: Train loss 2.2336, Val loss 2.1823
Epoch 14: Train loss 2.2334, Val loss 2.1866
Epoch 15: Train loss 2.2330, Val loss 2.1816
Epoch 16: Train loss 2.2328, Val loss 2.1794
Epoch 17: Train loss 2.2324, Val loss 2.1818
Epoch 18: Train loss 2.2323, Val loss 2.1867
Epoch 19: Train loss 2.2321, Val loss 2.1766
Epoch 20: Train loss 2.2320, Val loss 2.1891
Epoch

[I 2025-03-07 07:55:47,399] Trial 5 finished with value: 2.1656301614224165 and parameters: {'batch_size': 32, 'lr': 0.0021729419335939316, 'hidden_dim': 256, 'weight_decay': 0.00048608168804415455}. Best is trial 4 with value: 2.083881515848338.


Epoch 29: Train loss 2.2222, Val loss 2.1656
Trial number 7
Epoch 0: Train loss 2.5723, Val loss 2.2847
Epoch 1: Train loss 2.2485, Val loss 2.2205
Epoch 2: Train loss 2.1990, Val loss 2.2098
Epoch 3: Train loss 2.1830, Val loss 2.1955
Epoch 4: Train loss 2.1738, Val loss 2.2090
Epoch 5: Train loss 2.1684, Val loss 2.2012
Epoch 6: Train loss 2.1642, Val loss 2.2004
Epoch 7: Train loss 2.1604, Val loss 2.2010
Epoch 8: Train loss 2.1579, Val loss 2.1832
Epoch 9: Train loss 2.1553, Val loss 2.1774
Epoch 10: Train loss 2.1524, Val loss 2.2113
Epoch 11: Train loss 2.1509, Val loss 2.2076
Epoch 12: Train loss 2.1490, Val loss 2.2049
Epoch 13: Train loss 2.1477, Val loss 2.2049
Epoch 14: Train loss 2.1454, Val loss 2.1735
Epoch 15: Train loss 2.1450, Val loss 2.1901
Epoch 16: Train loss 2.1440, Val loss 2.1863
Epoch 17: Train loss 2.1427, Val loss 2.2090
Epoch 18: Train loss 2.1412, Val loss 2.1864
Epoch 19: Train loss 2.1403, Val loss 2.2033
Epoch 20: Train loss 2.1400, Val loss 2.5091
Epoch

[I 2025-03-07 08:08:39,194] Trial 6 finished with value: 2.1458648746663873 and parameters: {'batch_size': 128, 'lr': 0.002262804895465889, 'hidden_dim': 512, 'weight_decay': 1.2074671078795074e-05}. Best is trial 4 with value: 2.083881515848338.


Epoch 29: Train loss 2.1242, Val loss 2.1671
Trial number 8
Epoch 0: Train loss 3.6830, Val loss 2.9287
Epoch 1: Train loss 2.6973, Val loss 2.5486
Epoch 2: Train loss 2.5121, Val loss 2.4390
Epoch 3: Train loss 2.4410, Val loss 2.3816
Epoch 4: Train loss 2.3982, Val loss 2.3592
Epoch 5: Train loss 2.3705, Val loss 2.3356
Epoch 6: Train loss 2.3475, Val loss 2.3143
Epoch 7: Train loss 2.3279, Val loss 2.2913
Epoch 8: Train loss 2.3100, Val loss 2.2707
Epoch 9: Train loss 2.2940, Val loss 2.2571
Epoch 10: Train loss 2.2788, Val loss 2.2455
Epoch 11: Train loss 2.2643, Val loss 2.2292
Epoch 12: Train loss 2.2505, Val loss 2.2204
Epoch 13: Train loss 2.2369, Val loss 2.2035
Epoch 14: Train loss 2.2259, Val loss 2.1959
Epoch 15: Train loss 2.2149, Val loss 2.1837
Epoch 16: Train loss 2.2055, Val loss 2.1777
Epoch 17: Train loss 2.1969, Val loss 2.1710
Epoch 18: Train loss 2.1900, Val loss 2.1665
Epoch 19: Train loss 2.1841, Val loss 2.1618
Epoch 20: Train loss 2.1790, Val loss 2.1631
Epoch

[I 2025-03-07 08:21:38,398] Trial 7 finished with value: 2.1394470236518166 and parameters: {'batch_size': 128, 'lr': 0.00019293742668148652, 'hidden_dim': 448, 'weight_decay': 0.00022579186946369568}. Best is trial 4 with value: 2.083881515848338.


Epoch 29: Train loss 2.1504, Val loss 2.1417
Trial number 9
Epoch 0: Train loss 2.4441, Val loss 2.1521
Epoch 1: Train loss 2.1927, Val loss 2.1148
Epoch 2: Train loss 2.1680, Val loss 2.1177
Epoch 3: Train loss 2.1613, Val loss 2.1051
Epoch 4: Train loss 2.1571, Val loss 2.1005
Epoch 5: Train loss 2.1540, Val loss 2.1026
Epoch 6: Train loss 2.1515, Val loss 2.0959
Epoch 7: Train loss 2.1492, Val loss 2.0967
Epoch 8: Train loss 2.1474, Val loss 2.0888
Epoch 9: Train loss 2.1458, Val loss 2.0864
Epoch 10: Train loss 2.1449, Val loss 2.0961
Epoch 11: Train loss 2.1438, Val loss 2.0909
Epoch 12: Train loss 2.1426, Val loss 2.0885
Epoch 13: Train loss 2.1418, Val loss 2.0868
Epoch 14: Train loss 2.1407, Val loss 2.0861
Epoch 15: Train loss 2.1403, Val loss 2.0882
Epoch 16: Train loss 2.1397, Val loss 2.0880
Epoch 17: Train loss 2.1392, Val loss 2.0777
Epoch 18: Train loss 2.1386, Val loss 2.0790
Epoch 19: Train loss 2.1383, Val loss 2.0814
Epoch 20: Train loss 2.1380, Val loss 2.0770
Epoch

[I 2025-03-07 09:00:44,986] Trial 8 finished with value: 2.07491034575172 and parameters: {'batch_size': 32, 'lr': 0.000816547713285386, 'hidden_dim': 512, 'weight_decay': 2.9090240560596477e-06}. Best is trial 8 with value: 2.07491034575172.


Epoch 29: Train loss 2.1341, Val loss 2.0782
Trial number 10
Epoch 0: Train loss 3.6428, Val loss 2.9041
Epoch 1: Train loss 2.6809, Val loss 2.5343
Epoch 2: Train loss 2.5082, Val loss 2.4249
Epoch 3: Train loss 2.4365, Val loss 2.3805
Epoch 4: Train loss 2.3950, Val loss 2.3383
Epoch 5: Train loss 2.3643, Val loss 2.3227
Epoch 6: Train loss 2.3415, Val loss 2.3036
Epoch 7: Train loss 2.3214, Val loss 2.2759
Epoch 8: Train loss 2.3022, Val loss 2.2607
Epoch 9: Train loss 2.2851, Val loss 2.2427
Epoch 10: Train loss 2.2683, Val loss 2.2313
Epoch 11: Train loss 2.2530, Val loss 2.2129
Epoch 12: Train loss 2.2384, Val loss 2.2007
Epoch 13: Train loss 2.2255, Val loss 2.1905
Epoch 14: Train loss 2.2141, Val loss 2.1795
Epoch 15: Train loss 2.2038, Val loss 2.1727
Epoch 16: Train loss 2.1950, Val loss 2.1652
Epoch 17: Train loss 2.1873, Val loss 2.1620
Epoch 18: Train loss 2.1811, Val loss 2.1559
Epoch 19: Train loss 2.1750, Val loss 2.1516
Epoch 20: Train loss 2.1699, Val loss 2.1459
Epoc

[I 2025-03-07 09:13:30,328] Trial 9 finished with value: 2.133140113136985 and parameters: {'batch_size': 128, 'lr': 0.00020385491002032069, 'hidden_dim': 448, 'weight_decay': 0.0009293870195375112}. Best is trial 8 with value: 2.07491034575172.


Epoch 29: Train loss 2.1461, Val loss 2.1341
Trial number 11
Epoch 0: Train loss 2.4641, Val loss 2.1610
Epoch 1: Train loss 2.1981, Val loss 2.1177
Epoch 2: Train loss 2.1680, Val loss 2.1113
Epoch 3: Train loss 2.1603, Val loss 2.1073
Epoch 4: Train loss 2.1552, Val loss 2.1021
Epoch 5: Train loss 2.1526, Val loss 2.1009
Epoch 6: Train loss 2.1501, Val loss 2.1042
Epoch 7: Train loss 2.1482, Val loss 2.0891
Epoch 8: Train loss 2.1468, Val loss 2.0926
Epoch 9: Train loss 2.1453, Val loss 2.0885
Epoch 10: Train loss 2.1440, Val loss 2.0865
Epoch 11: Train loss 2.1432, Val loss 2.0876
Epoch 12: Train loss 2.1423, Val loss 2.0805
Epoch 13: Train loss 2.1417, Val loss 2.0836
Epoch 14: Train loss 2.1412, Val loss 2.0824
Epoch 15: Train loss 2.1406, Val loss 2.0840
Epoch 16: Train loss 2.1400, Val loss 2.0843
Epoch 17: Train loss 2.1391, Val loss 2.0755
Epoch 18: Train loss 2.1386, Val loss 2.0769
Epoch 19: Train loss 2.1380, Val loss 2.0733
Epoch 20: Train loss 2.1372, Val loss 2.0794
Epoc

[I 2025-03-07 09:52:38,080] Trial 10 finished with value: 2.070989271797911 and parameters: {'batch_size': 32, 'lr': 0.0007246691209707045, 'hidden_dim': 512, 'weight_decay': 1.0552397856143292e-06}. Best is trial 10 with value: 2.070989271797911.


Epoch 29: Train loss 2.1352, Val loss 2.0710
Trial number 12
Epoch 0: Train loss 2.4702, Val loss 2.1798
Epoch 1: Train loss 2.2019, Val loss 2.1129
Epoch 2: Train loss 2.1672, Val loss 2.1201
Epoch 3: Train loss 2.1594, Val loss 2.1079
Epoch 4: Train loss 2.1543, Val loss 2.1025
Epoch 5: Train loss 2.1510, Val loss 2.1049
Epoch 6: Train loss 2.1492, Val loss 2.0979
Epoch 7: Train loss 2.1475, Val loss 2.0987
Epoch 8: Train loss 2.1465, Val loss 2.0943
Epoch 9: Train loss 2.1456, Val loss 2.0965
Epoch 10: Train loss 2.1447, Val loss 2.0882
Epoch 11: Train loss 2.1438, Val loss 2.0918
Epoch 12: Train loss 2.1429, Val loss 2.0872
Epoch 13: Train loss 2.1420, Val loss 2.0878
Epoch 14: Train loss 2.1413, Val loss 2.0828
Epoch 15: Train loss 2.1404, Val loss 2.0836
Epoch 16: Train loss 2.1395, Val loss 2.0862
Epoch 17: Train loss 2.1389, Val loss 2.0900
Epoch 18: Train loss 2.1385, Val loss 2.0852
Epoch 19: Train loss 2.1379, Val loss 2.0827
Epoch 20: Train loss 2.1371, Val loss 2.0838
Epoc

[I 2025-03-07 10:31:43,896] Trial 11 finished with value: 2.0660949138034574 and parameters: {'batch_size': 32, 'lr': 0.0007016224567866245, 'hidden_dim': 512, 'weight_decay': 1.1004487564277892e-06}. Best is trial 11 with value: 2.0660949138034574.


Epoch 29: Train loss 2.1256, Val loss 2.0701
Trial number 13
Epoch 0: Train loss 2.4893, Val loss 2.1768
Epoch 1: Train loss 2.2109, Val loss 2.1531
Epoch 2: Train loss 2.1821, Val loss 2.1659
Epoch 3: Train loss 2.1730, Val loss 2.1861
Epoch 4: Train loss 2.1683, Val loss 2.1886
Epoch 5: Train loss 2.1654, Val loss 2.2205
Epoch 6: Train loss 2.1636, Val loss 2.2249
Epoch 7: Train loss 2.1614, Val loss 2.2632
Epoch 8: Train loss 2.1478, Val loss 2.1724
Epoch 9: Train loss 2.1455, Val loss 2.1609
Epoch 10: Train loss 2.1442, Val loss 2.1709


[I 2025-03-07 10:47:19,730] Trial 12 finished with value: 2.153145467405852 and parameters: {'batch_size': 32, 'lr': 0.0006288969424953504, 'hidden_dim': 384, 'weight_decay': 1.0034603972640884e-06}. Best is trial 11 with value: 2.0660949138034574.


Epoch 11: Train loss 2.1431, Val loss 2.2010
Trial number 14
Epoch 0: Train loss 2.5095, Val loss 2.2279
Epoch 1: Train loss 2.2277, Val loss 2.1206
Epoch 2: Train loss 2.1678, Val loss 2.1205
Epoch 3: Train loss 2.1543, Val loss 2.1307
Epoch 4: Train loss 2.1478, Val loss 2.1580
Epoch 5: Train loss 2.1434, Val loss 2.1555
Epoch 6: Train loss 2.1400, Val loss 2.1728
Epoch 7: Train loss 2.1374, Val loss 2.1752
Epoch 8: Train loss 2.1244, Val loss 2.0986
Epoch 9: Train loss 2.1213, Val loss 2.1100
Epoch 10: Train loss 2.1196, Val loss 2.1081
Epoch 11: Train loss 2.1185, Val loss 2.1205
Epoch 12: Train loss 2.1173, Val loss 2.1234
Epoch 13: Train loss 2.1164, Val loss 2.1237
Epoch 14: Train loss 2.1156, Val loss 2.1255
Epoch 15: Train loss 2.1085, Val loss 2.0756
Epoch 16: Train loss 2.1080, Val loss 2.0798
Epoch 17: Train loss 2.1071, Val loss 2.0785
Epoch 18: Train loss 2.1066, Val loss 2.0757
Epoch 19: Train loss 2.1058, Val loss 2.0778
Epoch 20: Train loss 2.1053, Val loss 2.0760
Epoc

[I 2025-03-07 11:26:55,833] Trial 13 finished with value: 2.05181546670681 and parameters: {'batch_size': 32, 'lr': 0.0005587614855267997, 'hidden_dim': 512, 'weight_decay': 6.056437420027853e-06}. Best is trial 13 with value: 2.05181546670681.


Epoch 29: Train loss 2.0975, Val loss 2.0518
Trial number 15
Epoch 0: Train loss 2.8744, Val loss 2.4504
Epoch 1: Train loss 2.4838, Val loss 2.3448
Epoch 2: Train loss 2.4184, Val loss 2.2957
Epoch 3: Train loss 2.3694, Val loss 2.2535
Epoch 4: Train loss 2.3175, Val loss 2.1969
Epoch 5: Train loss 2.2606, Val loss 2.1474
Epoch 6: Train loss 2.2106, Val loss 2.1180
Epoch 7: Train loss 2.1856, Val loss 2.1069
Epoch 8: Train loss 2.1742, Val loss 2.1009
Epoch 9: Train loss 2.1669, Val loss 2.0953
Epoch 10: Train loss 2.1617, Val loss 2.0941
Epoch 11: Train loss 2.1573, Val loss 2.0918
Epoch 12: Train loss 2.1538, Val loss 2.0895
Epoch 13: Train loss 2.1505, Val loss 2.0892
Epoch 14: Train loss 2.1486, Val loss 2.0874
Epoch 15: Train loss 2.1460, Val loss 2.0858
Epoch 16: Train loss 2.1435, Val loss 2.0855
Epoch 17: Train loss 2.1417, Val loss 2.0863
Epoch 18: Train loss 2.1397, Val loss 2.0826
Epoch 19: Train loss 2.1383, Val loss 2.0815
Epoch 20: Train loss 2.1371, Val loss 2.0823
Epoc

[I 2025-03-07 12:07:36,156] Trial 14 finished with value: 2.0745204591968585 and parameters: {'batch_size': 32, 'lr': 0.00010999744095048953, 'hidden_dim': 384, 'weight_decay': 5.383174846292139e-06}. Best is trial 13 with value: 2.05181546670681.


Epoch 29: Train loss 2.1285, Val loss 2.0749
Best hyperparameters: {'batch_size': 32, 'lr': 0.0005587614855267997, 'hidden_dim': 512, 'weight_decay': 6.056437420027853e-06}
Search completed in 377.77 minutes.


In [11]:
train_losses = []
val_losses = []
global train_losses, val_losses

sae = SparseAutoencoder(input_dim=768, hidden_dim=3072, top_k=50)
train_sae(layer6_embeddings, sae, model_prefix="sae_model_6_3072", epochs=500,
          batch_size=64, lr=5e-4, weight_decay=1e-6,
          train_losses=[],val_losses=val_losses, device=device, patience=10)

Epoch [1/500] | Train Loss: 0.564636 | Val Loss: 0.443621
Epoch [2/500] | Train Loss: 0.429605 | Val Loss: 0.415526
Epoch [3/500] | Train Loss: 0.404336 | Val Loss: 0.391972
Epoch [4/500] | Train Loss: 0.380943 | Val Loss: 0.371695
Epoch [5/500] | Train Loss: 0.363363 | Val Loss: 0.355106
Epoch [6/500] | Train Loss: 0.348986 | Val Loss: 0.341697
Epoch [7/500] | Train Loss: 0.337461 | Val Loss: 0.333137
Epoch [8/500] | Train Loss: 0.328246 | Val Loss: 0.325167
Epoch [9/500] | Train Loss: 0.321236 | Val Loss: 0.319153
Epoch [10/500] | Train Loss: 0.315880 | Val Loss: 0.314109
Epoch [11/500] | Train Loss: 0.311361 | Val Loss: 0.310885
Epoch [12/500] | Train Loss: 0.307849 | Val Loss: 0.307511
Epoch [13/500] | Train Loss: 0.305241 | Val Loss: 0.305991
Epoch [14/500] | Train Loss: 0.303034 | Val Loss: 0.303367
Epoch [15/500] | Train Loss: 0.300390 | Val Loss: 0.301121
Epoch [16/500] | Train Loss: 0.298111 | Val Loss: 0.298208
Epoch [17/500] | Train Loss: 0.296032 | Val Loss: 0.296372
Epoch 

([0.5646357463587044,
  0.4296048907582445,
  0.40433580629575494,
  0.38094315248394534,
  0.3633628173991632,
  0.34898595275677174,
  0.33746110880982455,
  0.32824598680859096,
  0.3212363291898616,
  0.31587958567836155,
  0.31136095165351074,
  0.3078488080558182,
  0.30524123651535423,
  0.30303444220310743,
  0.30038968253878895,
  0.29811122687307695,
  0.2960324708658065,
  0.2936735098563109,
  0.2912557582802728,
  0.2893906635408765,
  0.287565917916548,
  0.2862232499049478,
  0.2848637790549574,
  0.2838606225624349,
  0.28275578153982783,
  0.28214090536986763,
  0.28166227829243984,
  0.28111695539001663,
  0.2802673163875856,
  0.2796682889982947,
  0.2791917120979777,
  0.278557852059529,
  0.27818241567967966,
  0.2777476436412907,
  0.27731287658273557,
  0.2767877736550519,
  0.2759289249549226,
  0.27511014672035866,
  0.2744906287435728,
  0.27377619896119093,
  0.2733028698474803,
  0.2727361684090899,
  0.2723010436872749,
  0.27195265582001843,
  0.2713508229

In [17]:
train_losses = []
val_losses = []
global train_losses, val_losses

sae = SparseAutoencoder(input_dim=768, hidden_dim=3072, top_k=60)
train_sae(layer12_embeddings, sae, model_prefix="sae_model_12_3072", epochs=500,
          batch_size=64, lr=5e-3, weight_decay=1e-6, 
          train_losses=[],val_losses=val_losses, device=device, patience=10)

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.06 GiB. GPU 0 has a total capacity of 23.78 GiB of which 1.29 GiB is free. Process 1364620 has 609.79 MiB memory in use. Process 1098709 has 5.87 GiB memory in use. Process 2597354 has 2.18 GiB memory in use. Including non-PyTorch memory, this process has 11.44 GiB memory in use. Of the allocated memory 10.94 GiB is allocated by PyTorch, and 188.45 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [11]:
train_losses = []
val_losses = []
global train_losses, val_losses

train_sae(layer6_embeddings, model_name="sae_layer6_512.pth", epochs=500,
          batch_size=32, lr=5e-5, hidden_dim=512, weight_decay=1e-03, 
          train_losses=[],val_losses=val_losses, device=device, patience=30)

Epoch [1/500], Train Loss: 4.2220, Val Loss: 3.5712
Epoch [2/500], Train Loss: 3.7283, Val Loss: 3.4367
Epoch [3/500], Train Loss: 3.5819, Val Loss: 3.3105
Epoch [4/500], Train Loss: 3.4556, Val Loss: 3.3198
Epoch [5/500], Train Loss: 3.4234, Val Loss: 3.4495
Epoch [6/500], Train Loss: 3.4090, Val Loss: 4.0589
Epoch [7/500], Train Loss: 3.4005, Val Loss: 5.4380
Epoch [8/500], Train Loss: 3.3950, Val Loss: 6.2446
Epoch [9/500], Train Loss: 3.3905, Val Loss: 10.7519
Epoch [10/500], Train Loss: 3.3870, Val Loss: 13.5389
Epoch [11/500], Train Loss: 3.3847, Val Loss: 21.7800


KeyboardInterrupt: 

In [19]:
train_losses = []
val_losses = []
global train_losses, val_losses

train_sae(layer12_embeddings, "sae_layer12.pth", epochs=500,
          batch_size=32, lr=1e-5, hidden_dim=512, weight_decay=1e-02, 
          train_losses=[],val_losses=val_losses, device=device)

Epoch [1/500], Train Loss: 9.3784, Val Loss: 8.6603
Epoch [2/500], Train Loss: 8.5349, Val Loss: 8.2008
Epoch [3/500], Train Loss: 8.1273, Val Loss: 7.8981
Epoch [4/500], Train Loss: 7.8221, Val Loss: 7.6962
Epoch [5/500], Train Loss: 7.5788, Val Loss: 7.4449
Epoch [6/500], Train Loss: 7.3745, Val Loss: 7.2593
Epoch [7/500], Train Loss: 7.1991, Val Loss: 7.0678
Epoch [8/500], Train Loss: 7.0452, Val Loss: 6.8777
Epoch [9/500], Train Loss: 6.9088, Val Loss: 6.7717
Epoch [10/500], Train Loss: 6.7874, Val Loss: 6.6290
Epoch [11/500], Train Loss: 6.6799, Val Loss: 6.5192
Epoch [12/500], Train Loss: 6.5856, Val Loss: 6.3898
Epoch [13/500], Train Loss: 6.5007, Val Loss: 6.3215
Epoch [14/500], Train Loss: 6.4268, Val Loss: 6.2082
Epoch [15/500], Train Loss: 6.3615, Val Loss: 6.1749
Epoch [16/500], Train Loss: 6.3047, Val Loss: 6.1079
Epoch [17/500], Train Loss: 6.2549, Val Loss: 6.0426
Epoch [18/500], Train Loss: 6.2121, Val Loss: 6.0207
Epoch [19/500], Train Loss: 6.1741, Val Loss: 5.9502
Ep

KeyboardInterrupt: 

In [21]:
train_losses = []
val_losses = []
global train_losses, val_losses

train_sae(layer12_embeddings, "sae_layer12_expanded_1700.pth", epochs=500,
          batch_size=32, lr=1e-5, hidden_dim=1700, weight_decay=1e-04, 
          train_losses=[],val_losses=val_losses, device=device)

Epoch [1/500], Train Loss: 10.2528, Val Loss: 9.4597
Epoch [2/500], Train Loss: 9.1150, Val Loss: 8.7378
Epoch [3/500], Train Loss: 8.4605, Val Loss: 8.2728
Epoch [4/500], Train Loss: 7.9646, Val Loss: 7.8654
Epoch [5/500], Train Loss: 7.5986, Val Loss: 7.6130
Epoch [6/500], Train Loss: 7.3334, Val Loss: 7.4305
Epoch [7/500], Train Loss: 7.1272, Val Loss: 7.0853
Epoch [8/500], Train Loss: 6.9541, Val Loss: 7.0251
Epoch [9/500], Train Loss: 6.8049, Val Loss: 6.7875
Epoch [10/500], Train Loss: 6.6743, Val Loss: 6.6703
Epoch [11/500], Train Loss: 6.5586, Val Loss: 6.5076
Epoch [12/500], Train Loss: 6.4574, Val Loss: 6.4108
Epoch [13/500], Train Loss: 6.3698, Val Loss: 6.2961
Epoch [14/500], Train Loss: 6.2927, Val Loss: 6.1800
Epoch [15/500], Train Loss: 6.2243, Val Loss: 6.0854
Epoch [16/500], Train Loss: 6.1671, Val Loss: 6.0580
Epoch [17/500], Train Loss: 6.1168, Val Loss: 5.9766
Epoch [18/500], Train Loss: 6.0739, Val Loss: 5.9582
Epoch [19/500], Train Loss: 6.0375, Val Loss: 5.8765
E

Epoch [155/500], Train Loss: 5.4472, Val Loss: 5.2968
Epoch [156/500], Train Loss: 5.4463, Val Loss: 5.2974
Epoch [157/500], Train Loss: 5.4433, Val Loss: 5.2900
Epoch [158/500], Train Loss: 5.4432, Val Loss: 5.3019
Epoch [159/500], Train Loss: 5.4421, Val Loss: 5.2970
Epoch [160/500], Train Loss: 5.4419, Val Loss: 5.2957
Epoch [161/500], Train Loss: 5.4415, Val Loss: 5.2974
Epoch [162/500], Train Loss: 5.4405, Val Loss: 5.2977
Epoch [163/500], Train Loss: 5.4403, Val Loss: 5.3023
Epoch [164/500], Train Loss: 5.4392, Val Loss: 5.2934
Epoch [165/500], Train Loss: 5.4390, Val Loss: 5.3098
Epoch [166/500], Train Loss: 5.4385, Val Loss: 5.3026
Epoch [167/500], Train Loss: 5.4382, Val Loss: 5.2926
‚è≥ Early stopping at epoch 167. No improvement for 10 epochs.
‚úÖ SAE training completed. Best model saved as sae_layer12_expanded.pth.


In [11]:
with open("losses_sae.json", "r") as f:
    data = json.load(f)
train_losses = data["train_losses"]
val_losses = data["val_losses"]

In [14]:
from sparse_auto_encoder import SparseAutoencoder

sae_layer12_expanded = SparseAutoencoder(input_dim=768, hidden_dim=1700)
sae_layer12_expanded.load_state_dict(torch.load("sae_layer12_expanded.pth"))

<All keys matched successfully>

In [15]:
train_sae(layer12_embeddings, model_name="sae_layer12_expanded_1.pth", epochs=500,
          patience=30, batch_size=32, lr=1e-5, hidden_dim=2304, weight_decay=1e-03, 
          train_losses=[],val_losses=[], device=device)

Epoch [1/500], Train Loss: 10.7967, Val Loss: 9.8300
Epoch [2/500], Train Loss: 9.4837, Val Loss: 9.0222
Epoch [3/500], Train Loss: 8.7017, Val Loss: 8.3775
Epoch [4/500], Train Loss: 8.1056, Val Loss: 8.0037
Epoch [5/500], Train Loss: 7.6678, Val Loss: 7.7563
Epoch [6/500], Train Loss: 7.3639, Val Loss: 7.4141
Epoch [7/500], Train Loss: 7.1397, Val Loss: 7.2778
Epoch [8/500], Train Loss: 6.9595, Val Loss: 7.0649
Epoch [9/500], Train Loss: 6.8049, Val Loss: 6.8241
Epoch [10/500], Train Loss: 6.6685, Val Loss: 6.6474
Epoch [11/500], Train Loss: 6.5516, Val Loss: 6.5122
Epoch [12/500], Train Loss: 6.4494, Val Loss: 6.3105
Epoch [13/500], Train Loss: 6.3583, Val Loss: 6.2694
Epoch [14/500], Train Loss: 6.2799, Val Loss: 6.1230
Epoch [15/500], Train Loss: 6.2110, Val Loss: 6.0314
Epoch [16/500], Train Loss: 6.1518, Val Loss: 5.9883
Epoch [17/500], Train Loss: 6.1017, Val Loss: 5.9333
Epoch [18/500], Train Loss: 6.0584, Val Loss: 5.8741
Epoch [19/500], Train Loss: 6.0212, Val Loss: 5.8591
E

KeyboardInterrupt: 

In [12]:
layer12_embeddings

array([[ 6.419478  ,  1.097434  , -5.971122  , ..., -2.1528153 ,
         4.6942825 ,  0.84316725],
       [-0.55278456, -2.5868275 , -4.9650073 , ..., -4.4642773 ,
         8.33209   , -6.3761005 ],
       [11.52194   , -7.388887  , -7.3803024 , ..., -2.946519  ,
        -0.92312884, -1.3777622 ],
       ...,
       [ 5.6442814 , -3.999446  , -2.193171  , ..., -2.052844  ,
         0.05622917,  0.82508975],
       [ 1.7008734 ,  1.0291717 ,  1.7188739 , ...,  0.36227274,
         3.265392  , -0.6340358 ],
       [ 4.86724   , -3.615003  , -0.6352348 , ...,  3.7851796 ,
         1.1924815 , -1.6266764 ]], dtype=float32)

In [10]:
train_sae(layer12_embeddings, model_name="sae_layer12_expanded_2.pth", epochs=500,
          patience=30, batch_size=32, lr=1e-5, hidden_dim=2304, weight_decay=1e-02, 
          train_losses=[],val_losses=[], device=device)



Epoch [1/500], Train Loss: 10.7911, Val Loss: 9.8234
üîπ New Learning Rate: 1e-05
Epoch [2/500], Train Loss: 9.4748, Val Loss: 8.9847
üîπ New Learning Rate: 1e-05
Epoch [3/500], Train Loss: 8.6924, Val Loss: 8.4460
üîπ New Learning Rate: 1e-05
Epoch [4/500], Train Loss: 8.0968, Val Loss: 7.9623
üîπ New Learning Rate: 1e-05
Epoch [5/500], Train Loss: 7.6660, Val Loss: 7.7041
üîπ New Learning Rate: 1e-05
Epoch [6/500], Train Loss: 7.3656, Val Loss: 7.4455
üîπ New Learning Rate: 1e-05
Epoch [7/500], Train Loss: 7.1416, Val Loss: 7.1965
üîπ New Learning Rate: 1e-05
Epoch [8/500], Train Loss: 6.9612, Val Loss: 7.0236
üîπ New Learning Rate: 1e-05
Epoch [9/500], Train Loss: 6.8082, Val Loss: 6.9582
üîπ New Learning Rate: 1e-05
Epoch [10/500], Train Loss: 6.6735, Val Loss: 6.6553
üîπ New Learning Rate: 1e-05
Epoch [11/500], Train Loss: 6.5566, Val Loss: 6.5508
üîπ New Learning Rate: 1e-05
Epoch [12/500], Train Loss: 6.4546, Val Loss: 6.4480
üîπ New Learning Rate: 1e-05
Epoch [13/50

KeyboardInterrupt: 

In [10]:
train_sae(layer6_embeddings, model_name="sae_layer6_expanded.pth", epochs=500,
          patience=30, batch_size=16, lr=1e-6, hidden_dim=3072, weight_decay=0.01, 
          train_losses=[],val_losses=[], device=device)

Epoch [1/500], Train Loss: 5.4032, Val Loss: 4.8838
Epoch [2/500], Train Loss: 4.8314, Val Loss: 4.5803
Epoch [3/500], Train Loss: 4.6227, Val Loss: 4.3705
Epoch [4/500], Train Loss: 4.4661, Val Loss: 4.2232
Epoch [5/500], Train Loss: 4.3308, Val Loss: 4.0835
Epoch [6/500], Train Loss: 4.2070, Val Loss: 3.9584
Epoch [7/500], Train Loss: 4.0939, Val Loss: 3.8387
Epoch [8/500], Train Loss: 3.9889, Val Loss: 3.7355
Epoch [9/500], Train Loss: 3.8890, Val Loss: 3.6505
Epoch [10/500], Train Loss: 3.7948, Val Loss: 3.5522
Epoch [11/500], Train Loss: 3.7048, Val Loss: 3.4585
Epoch [12/500], Train Loss: 3.6188, Val Loss: 3.3752
Epoch [13/500], Train Loss: 3.5362, Val Loss: 3.3095
Epoch [14/500], Train Loss: 3.4574, Val Loss: 3.2342
Epoch [15/500], Train Loss: 3.3810, Val Loss: 3.1437
Epoch [16/500], Train Loss: 3.3091, Val Loss: 3.0838
Epoch [17/500], Train Loss: 3.2393, Val Loss: 3.0454
Epoch [18/500], Train Loss: 3.1739, Val Loss: 2.9571
Epoch [19/500], Train Loss: 3.1122, Val Loss: 2.9221
Ep

KeyboardInterrupt: 