In [None]:
from ActivationStoreParallel import ActivationsStore
from sparse_transcoder import SparseTranscoder
from transcoder_training_parallel import train_transcoder_on_language_model_parallel
from transcoder_runner_parallel import language_model_transcoder_runner_parallel
from dataclasses import dataclass
import transformer_lens
import torch
import wandb

In [None]:
@dataclass
class Config1():

    # Data Generating Function (Model + Training Distibuion)
    model_name = "roneneldan/TinyStories-1Layer-21M"
    hook_transcoder_in = "blocks.0.hook_resid_pre"
    hook_point = "blocks.0.hook_resid_pre"
    hook_transcoder_out = "blocks.0.attn.hook_q"
    target = "blocks.0.attn.hook_q"
    hook_point_layer = 0
    ln = 'blocks.0.ln1.hook_scale'
    d_in = 1024
    d_out = 1024
    n_head = 16
    d_head = 64
    dataset_path = "roneneldan/TinyStories"
    is_dataset_tokenized=False
    layer = 0
    use_ghost_grads = False
    training = True
    attn_scores_normed = False
    
    
    # SAE Parameters
    expansion_factor = 12 # determines the dimension of the SAE.
    d_hidden = 1024
    b_dec_init_method = "mean"
    
    # Training Parameters
    lr = 4e-4
    reg_coefficient = 4e-6
    lr_scheduler_name=None
    train_batch_size = 4096
    context_size = 256
    lr_warm_up_steps=5000
    
    # Activation Store Parameters
    n_batches_in_buffer = 128
    total_training_tokens = 20_000 * 10_000 # 200M tokens seems doable overnight.
    store_batch_size = 32
    use_cached_activations = False
    
    # Resampling protocol
    feature_sampling_method = 'none'
    feature_sampling_window = 1000
    feature_reinit_scale = 0.2
    resample_batches=1028
    dead_feature_window=50000
    dead_feature_threshold = 1e-6
    
    # WANDB
    log_to_wandb = True
    wandb_project= "sparsification"
    wandb_entity = None
    wandb_log_frequency=1000
    entity = "kwyn390"
    
    # Misc
    device = "cuda"
    eps = 1e-7
    seed = 42
    reshape_from_heads = True
    n_checkpoints = 10
    checkpoint_path = "checkpoints"
    dtype = torch.float32
    run_name = "qk_parallel"
    type = "resid_to_queries"

cfg1 = Config1()

cfg1.run_name = str(cfg1.d_hidden) + "_" + str(cfg1.reg_coefficient) + "_" + str(cfg1.lr)

In [None]:
@dataclass
class Config2():

    # Data Generating Function (Model + Training Distibuion)
    model_name = "gpt2"
    hook_transcoder_in = "blocks.10.hook_resid_pre"
    hook_point = "blocks.10.hook_resid_pre"
    hook_transcoder_out = "blocks.10.attn.hook_k"
    target = "blocks.10.attn.hook_k"
    hook_point_layer = 2
    layer = 10
    d_in = 768
    d_out = 768
    dataset_path = "Skylion007/openwebtext"
    is_dataset_tokenized=False
    use_ghost_grads = False
    training = True
    d_head = 64
    n_head = 12

    
    # SAE Parameters
    expansion_factor = 12 # determines the dimension of the SAE.
    d_hidden = 1024
    b_dec_init_method = "mean"
    
    # Training Parameters
    lr = 0.0002
    reg_coefficient = 4e-6
    lr_scheduler_name=None
    train_batch_size = 4096
    context_size = 128
    lr_warm_up_steps=5000
    
    # Activation Store Parameters
    n_batches_in_buffer = 128
    total_training_tokens = 20_000 * 20_000 # 200M tokens seems doable overnight.
    store_batch_size = 32
    use_cached_activations = False
    
    # Resampling protocol
    feature_sampling_method = 'none'
    feature_sampling_window = 1000
    feature_reinit_scale = 0.2
    resample_batches=1028
    dead_feature_window=50000
    dead_feature_threshold = 1e-6
    
    # WANDB
    log_to_wandb = True
    wandb_project= "transcoder_training_gpt2_L10"
    wandb_entity = None
    wandb_log_frequency=1000
    
    # Misc
    device = "cuda"
    eps = 1e-7
    seed = 42
    reshape_from_heads = True
    n_checkpoints = 10
    checkpoint_path = "checkpoints"
    dtype = torch.float32
    run_name = "parallel_test"
    type = "resid_to_keys"

cfg2 = Config2()

In [None]:
sparse_transcoder_Q, sparse_transcoder_K = language_model_transcoder_runner_parallel(cfg1, cfg2)

In [1]:
import torch

In [2]:
kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target = True)


In [54]:
a = torch.rand(12, 256, 256)
b = torch.rand(12, 256, 256)


a_patt = a.softmax(-1)
b_patt = b.softmax(-1)

print(kl_loss(a_patt, b_patt).mean())

tensor(0.0819)


In [None]:
kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target = True)

kl_loss(a_patt, b_patt)