To use: move this notebook into `dictionary_learning/dictionaries` and run it from there. `dictionaries` is a folder inside a submodule, so we don't currently track it.

In [None]:
import os
import shutil

def copy_folder_structure(src, dst, exclude_files):
    """Mirror the directory tree under ``src`` into ``dst``, copying files selectively.

    Every directory visited while walking ``src`` is created under ``dst``
    (directories that already exist are reused), and every file whose name is
    not in ``exclude_files`` is copied with ``shutil.copy2``.

    Args:
        src: Root directory to walk.
        dst: Root directory that receives the mirrored tree.
        exclude_files: Container of file names (not paths) to leave out.
    """
    for current_dir, _subdirs, filenames in os.walk(src):
        # Re-root the directory being visited onto the destination tree.
        target_dir = os.path.join(dst, os.path.relpath(current_dir, src))
        os.makedirs(target_dir, exist_ok=True)

        for name in filenames:
            if name in exclude_files:
                continue
            shutil.copy2(
                os.path.join(current_dir, name),
                os.path.join(target_dir, name),
            )

# Example usage: mirror one sweep folder into a sibling "<folder>_results"
# folder, leaving out every file named "ae.pt".
src_folder = "../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712"
dst_folder = f"{src_folder}_results"
files_to_exclude = ["ae.pt"]

copy_folder_structure(src=src_folder, dst=dst_folder, exclude_files=files_to_exclude)

print(f"Folder structure copied from {src_folder} to {dst_folder}, excluding {files_to_exclude}")

In [3]:
import os
import shutil

## This copies all files from src to dst, excluding those in exclude_files
# Used to copy results files from a data file to the autoencoder folder

def copy_files(src: str, dst: str, exclude_files: list[str]):
    """Copy every non-excluded file under ``src`` to the same relative path under ``dst``.

    The tree below ``src`` is walked recursively. Each file whose name is not
    in ``exclude_files`` is copied with ``shutil.copy2``; missing parent
    directories under ``dst`` are created on demand. One ``Copied: ...`` line
    is printed per copied file.

    Args:
        src: Root directory to read from.
        dst: Root directory to write into.
        exclude_files: File names (not paths) that must not be copied.

    Raises:
        ValueError: If no file at all was copied.
    """
    copied = []
    for current_dir, _subdirs, filenames in os.walk(src):
        for name in filenames:
            if name in exclude_files:
                continue
            source_path = os.path.join(current_dir, name)
            target_path = os.path.join(dst, os.path.relpath(source_path, src))
            # Only directories that actually receive a file are created.
            os.makedirs(os.path.dirname(target_path), exist_ok=True)
            shutil.copy2(source_path, target_path)
            print(f"Copied: {target_path}")
            copied.append(target_path)
    if not copied:
        raise ValueError(f"No files copied from {src} to {dst}")

# Paths

# Name of the sweep folder to process. Other sweeps this cell has been pointed at:
#   "pythia70m_sweep_gated_ctx128_0730"
#   "pythia70m_sweep_topk_ctx128_0730"
ae_path = "pythia70m_sweep_standard_ctx128_0712"

# Results are read from the "<sweep>_results" folder and written into the
# sweep's own folder.
src_folder = f"../dictionary_learning/dictionaries/all_730_results/{ae_path}_results"
dst_folder = f"../dictionary_learning/dictionaries/{ae_path}"

# Files to always exclude
files_to_exclude = ['config.json', 'ae.pt']

# Run the copy function
copy_files(src=src_folder, dst=dst_folder, exclude_files=files_to_exclude)

print(f"Files copied from {src_folder} to {dst_folder}")
print(f"Excluded files: {', '.join(files_to_exclude)}")

Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_18/node_effects.pkl
Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_18/class_accuracies.pkl
Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_18/eval_results.json
Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_16/node_effects.pkl
Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_16/class_accuracies.pkl
Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_16/eval_results.json
Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_11/node_effects.pkl
Copied: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_la

In [30]:
import os
import torch

def process_ae_files(src_folder):
    """Convert the ``decoder`` entry of every ``ae.pt`` under ``src_folder`` to ``decoder.weight``.

    Each ``ae.pt`` found while walking ``src_folder`` is loaded with
    ``torch.load``. Its ``decoder`` tensor is transposed and stored under the
    key ``decoder.weight``, the ``decoder`` key is removed, and the file is
    overwritten in place with ``torch.save``.

    Files that have no ``decoder`` key are left untouched, so the function can
    be re-run after a partial or complete conversion without failing.

    Args:
        src_folder: Root directory searched recursively for ``ae.pt`` files.
    """
    for root, _dirs, files in os.walk(src_folder):
        if 'ae.pt' not in files:
            continue

        file_path = os.path.join(root, 'ae.pt')
        print(f"Found ae.pt file: {file_path}")

        # Load the state dict
        state_dict = torch.load(file_path)

        # Nothing to convert (e.g. the file was already processed by an
        # earlier run); rewriting it would only fail on the missing key.
        if "decoder" not in state_dict:
            print(f"Skipped (no 'decoder' key): {file_path}")
            continue

        # `.T` is a strided view; store a contiguous copy under the new key.
        state_dict["decoder.weight"] = state_dict.pop("decoder").T.contiguous()

        torch.save(state_dict, file_path)

        print(f"Processed: {file_path}")

# Example usage
# Folder this cell was previously pointed at:
#   "../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712"
src_folder = "../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730"

# Rewrites the ae.pt files under src_folder in place.
process_ae_files(src_folder)

Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_0/ae.pt
Processed: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_0/ae.pt
Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_1/ae.pt
Processed: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_1/ae.pt
Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_10/ae.pt
Processed: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_10/ae.pt
Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_11/ae.pt


Processed: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_11/ae.pt
Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_12/ae.pt
Processed: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_12/ae.pt
Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_13/ae.pt
Processed: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_13/ae.pt
Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_14/ae.pt
Processed: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/resid_post_layer_3/trainer_14/ae.pt
Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_transpose_topk_ctx128_0730/

In [22]:
import os
import torch

def process_ae_files(src_folder):
    """Inspect the first ``ae.pt`` found under ``src_folder`` and then stop.

    For the first ``ae.pt`` encountered while walking ``src_folder``, prints
    its path, the keys of the loaded state dict and the shape of its
    ``decoder.weight`` entry, then deliberately raises so that no further
    file is visited. If no ``ae.pt`` exists, nothing is printed and the
    function returns ``None``.

    Args:
        src_folder: Root directory searched recursively for ``ae.pt`` files.

    Raises:
        ValueError: Always, right after the first ``ae.pt`` has been inspected.
    """
    for current_dir, _subdirs, filenames in os.walk(src_folder):
        for name in filenames:
            if name != 'ae.pt':
                continue

            file_path = os.path.join(current_dir, name)
            print(f"Found ae.pt file: {file_path}")

            # Load the state dict
            checkpoint = torch.load(file_path)

            print(checkpoint.keys())
            print(checkpoint["decoder.weight"].shape)

            # Intentional stop: one file is enough to inspect the layout.
            raise ValueError("Stop here")

            # Scaffold below is unreachable while the raise above is in place.
            # TODO: Modify the state dict here

            # TODO: Save the modified state dict

            print(f"Processed: {file_path}")

# Example usage
# Point the inspection helper defined in this cell at the standard sweep folder.
src_folder = "../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712"


# As defined in this cell, process_ae_files prints the keys and the
# decoder.weight shape of the first ae.pt it finds and then raises
# ValueError("Stop here"), so this cell is expected to end with that error.
process_ae_files(src_folder)

Found ae.pt file: ../dictionary_learning/dictionaries/pythia70m_sweep_standard_ctx128_0712/resid_post_layer_3/trainer_0/ae.pt
odict_keys(['bias', 'encoder.weight', 'encoder.bias', 'decoder.weight'])
torch.Size([512, 4096])


ValueError: Stop here