# Upload SAEs from Weights & Biases to HuggingFace

This script uploads trained SAEs from Weights & Biases run artifacts to a HuggingFace model repo. Each run gets its own folder containing the weights, a config JSON and a tensor of sparsity values. The models are first downloaded to the local machine in the `scripts/artifacts` folder and then uploaded to HuggingFace. Finally, the artifacts downloaded during this run are deleted from `scripts/artifacts` (pre-existing artifacts in this folder are left unchanged).

To run this script you'll need to:
- Create a HuggingFace model repo to upload to (or use an existing one).
- Get a HuggingFace [write access token](https://huggingface.co/docs/hub/en/security-tokens) for the repo.
- Fill in these values in the cell below, along with the name of the W&B project you want to transfer.

Note: only finished runs will be transferred. You can change this or add extra filtering in the cell beginning with the comment "more filtering on W&B runs can be done here".

Script variables

In [None]:
import os

# W&B project whose finished runs will be transferred.
wandb_project_name = "YOUR-WANDB-PROJECT"
# Target HuggingFace model repo (must already exist).
hf_repo_id = "YOUR-HF-REPO"
# HF write token. The HF_TOKEN environment variable takes precedence, so the
# secret never has to be pasted into (and accidentally committed with) this
# notebook. If HF_TOKEN is unset or empty, the literal below is used instead.
hf_token = os.environ.get("HF_TOKEN") or "YOUR-HF-TOKEN"  # do not upload to github!

## W&B downloads

In [None]:
import wandb

### Get runs from W&B

In [None]:
# W&B public API client; used in the next cell to list the project's runs.
api = wandb.Api()



In [None]:
# more filtering on W&B runs can be done here
runs = api.runs(wandb_project_name)
# Keep only runs whose state is "finished".
completed_runs = list(filter(lambda r: r.state == "finished", runs))
# Order the remaining runs by their config["hook_point_layer"] value.
sorted_by_layer = sorted(
    completed_runs,
    key=lambda r: r.config["hook_point_layer"],
)

### W&B helper functions

In [None]:
def get_model_endstr(run):
    """Build the name suffix used to recognise a run's model artifact.

    The suffix has the form ``"<hook_point>_<d_sae>"``, where ``d_sae`` is
    ``int(d_in) * int(expansion_factor)`` read from ``run.config``.
    """
    cfg = run.config
    d_in = int(cfg["d_in"])
    width = d_in * int(cfg["expansion_factor"])
    return "_".join((cfg["hook_point"], str(width)))

In [None]:
def is_model(artifact, model_endstr):
    """Return True if the artifact's name, up to its first ':', ends with ``model_endstr``.

    Text after the first ':' (presumably the W&B version/alias such as ``v0``)
    is ignored.
    """
    base_name, _sep, _rest = artifact.name.partition(":")
    return base_name.endswith(model_endstr)

In [None]:
def get_sparsity_endstr(run):
    """Return the name suffix used to recognise a run's sparsity artifact.

    It is the model suffix from ``get_model_endstr`` followed by
    ``"_log_feature_sparsity"``.
    """
    return f"{get_model_endstr(run)}_log_feature_sparsity"

In [None]:
def is_sparsity(artifact, sparsity_endstr):
    """Return True if the artifact's name, up to its first ':', ends with ``sparsity_endstr``."""
    name_without_tag = artifact.name.split(":", 1)[0]
    return name_without_tag.endswith(sparsity_endstr)

In [None]:
def get_model_artifact(run):
    """Return the first artifact logged by ``run`` that ``is_model`` accepts.

    The match suffix comes from ``get_model_endstr(run)``.

    Args:
        run: a W&B run exposing ``config`` and ``logged_artifacts()``.

    Returns:
        The first matching artifact, or ``None`` if no logged artifact matches.
    """
    model_endstr = get_model_endstr(run)
    # next() with a default stops at the first match and makes the
    # "not found -> None" outcome explicit instead of an implicit fall-through.
    return next(
        (a for a in run.logged_artifacts() if is_model(a, model_endstr)),
        None,
    )
    

In [None]:
def get_sparsity_artifact(run):
    """Return the first artifact logged by ``run`` that ``is_sparsity`` accepts.

    The match suffix comes from ``get_sparsity_endstr(run)``.

    Args:
        run: a W&B run exposing ``config`` and ``logged_artifacts()``.

    Returns:
        The first matching artifact, or ``None`` if no logged artifact matches.
    """
    sparsity_endstr = get_sparsity_endstr(run)
    # next() with a default stops at the first match and makes the
    # "not found -> None" outcome explicit instead of an implicit fall-through.
    return next(
        (a for a in run.logged_artifacts() if is_sparsity(a, sparsity_endstr)),
        None,
    )

In [None]:
def download_model_and_sparsity(run):
    """Download a run's model artifact and its sparsity artifact into one folder.

    The model artifact is downloaded to its default location. The sparsity
    artifact is then downloaded into that same directory.

    Args:
        run: a W&B run exposing ``config`` and ``logged_artifacts()``.

    Returns:
        The local directory path returned by the model artifact's ``download()``.

    Raises:
        LookupError: if the run has no matching model or sparsity artifact.
    """
    hook_point = run.config["hook_point"]
    print(f"Downloading model & sparsity for {hook_point}")

    # Resolve both artifacts before downloading anything. A run missing either
    # one then fails fast with a clear message, instead of leaving a
    # half-downloaded folder behind and raising AttributeError on None.
    model_artifact = get_model_artifact(run)
    if model_artifact is None:
        raise LookupError(f"No model artifact found for run with hook_point {hook_point!r}")
    sparsity_artifact = get_sparsity_artifact(run)
    if sparsity_artifact is None:
        raise LookupError(f"No sparsity artifact found for run with hook_point {hook_point!r}")

    model_path = model_artifact.download()
    sparsity_artifact.download(root=model_path)
    return model_path

### Download from W&B (quite slow)

In [None]:
# Download every selected run, in layer order; keep the local folder paths.
model_paths = list(map(download_model_and_sparsity, sorted_by_layer))

## HF uploads

In [None]:
from huggingface_hub import HfApi
# HuggingFace Hub client. Note: this rebinds `api`, replacing the W&B client
# created earlier; `upload_to_hf` below calls `api.upload_folder` on this object.
api = HfApi()

In [None]:
def upload_to_hf(model_path):
    """Upload one local run folder into ``hf_repo_id`` under a sub-folder of the same name.

    Uses the module-level ``api`` (an ``HfApi`` instance), ``hf_repo_id`` and
    ``hf_token``.

    Args:
        model_path: local directory to upload, e.g. as returned by
            ``download_model_and_sparsity``.
    """
    # Local import: the notebook's `import os` cell runs after this function is used.
    import os

    # basename(normpath(...)) copes with OS-specific separators (Windows "\\")
    # and trailing slashes. Splitting on "/" gets both wrong: it returns the
    # whole path on Windows and "" for a trailing slash.
    repo_path = os.path.basename(os.path.normpath(model_path))
    api.upload_folder(
        folder_path=model_path,
        repo_id=hf_repo_id,
        path_in_repo=repo_path,
        token=hf_token,
    )

In [None]:
# Push each downloaded run folder to the HF repo, reporting progress as we go.
for model_path in model_paths:
    print(f"Uploading {model_path}")
    upload_to_hf(model_path)

## Cleanup

In [None]:
import os

In [None]:
import shutil

# Delete every folder downloaded in this run. shutil.rmtree removes the whole
# tree, including any nested sub-directories. os.remove() on each entry would
# fail on a sub-directory, and the os.scandir() iterator was never closed.
for model_path in model_paths:
    shutil.rmtree(model_path)