In [1]:
import subprocess

s3_paths = [
    "s3://measurement-noise-scaling-laws/data/larry/100000/1.0/results/Geneformer/checkpoint-63000/model.safetensors",
    "s3://measurement-noise-scaling-laws/data/larry/100000/1.0/results/SCVI/model/model.pt",
    "s3://measurement-noise-scaling-laws/data/merfish/60000/1.0/results/Geneformer/checkpoint-65000/model.safetensors",
    "s3://measurement-noise-scaling-laws/data/merfish/60000/1.0/results/SCVI/model/model.pt",
    "s3://measurement-noise-scaling-laws/data/PBMC/100000/1.0/results/Geneformer/checkpoint-42000/model.safetensors",
    "s3://measurement-noise-scaling-laws/data/PBMC/100000/1.0/results/SCVI/model/model.pt",
    "s3://measurement-noise-scaling-laws/data/shendure/10000000/1.0/results/Geneformer/checkpoint-86000/model.safetensors",
    "s3://measurement-noise-scaling-laws/data/shendure/59948/1.0/results/SCVI/model/model.pt",
]

output_dir = "outputs/2026-01-08_parameter_counts"
import os
os.makedirs(output_dir, exist_ok=True)

for s3_path in s3_paths:
    # Extract the filename from the S3 path
    filename = s3_path.split("/")[-1]
    # Find the dataset folder (immediately after "data/")
    parts = s3_path.split('/')
    if "data" in parts:
        data_index = parts.index("data")
        # assume next part is dataset (case insensitive in output on disk)
        dataset = parts[data_index + 1].lower()
    else:
        dataset = "unknown"
    # Build destination folder
    dataset_dir = os.path.join(output_dir, dataset)
    os.makedirs(dataset_dir, exist_ok=True)
    dest_path = os.path.join(dataset_dir, filename)
    subprocess.run(["aws", "s3", "cp", s3_path, dest_path])

download: s3://measurement-noise-scaling-laws/data/larry/100000/1.0/results/Geneformer/checkpoint-63000/model.safetensors to outputs/2026-01-08_parameter_counts/larry/model.safetensors
download: s3://measurement-noise-scaling-laws/data/larry/100000/1.0/results/SCVI/model/model.pt to outputs/2026-01-08_parameter_counts/larry/model.pt
download: s3://measurement-noise-scaling-laws/data/merfish/60000/1.0/results/Geneformer/checkpoint-65000/model.safetensors to outputs/2026-01-08_parameter_counts/merfish/model.safetensors
download: s3://measurement-noise-scaling-laws/data/merfish/60000/1.0/results/SCVI/model/model.pt to outputs/2026-01-08_parameter_counts/merfish/model.pt
download: s3://measurement-noise-scaling-laws/data/PBMC/100000/1.0/results/Geneformer/checkpoint-42000/model.safetensors to outputs/2026-01-08_parameter_counts/pbmc/model.safetensors
download: s3://measurement-noise-scaling-laws/data/PBMC/100000/1.0/results/SCVI/model/model.pt to outputs/2026-01-08_parameter_counts/pbmc/mo

In [18]:

import os
import csv

def count_params(model):
    # Try common keys for state dicts
    if isinstance(model, dict):
        if 'model_state_dict' in model:
            state_dict = model['model_state_dict']
        elif 'state_dict' in model:
            state_dict = model['state_dict']
        else:
            # Fallback: take first tensor-dict found
            for v in model.values():
                if isinstance(v, dict):
                    for possible_sd in ['model_state_dict', 'state_dict']:
                        if possible_sd in v:
                            state_dict = v[possible_sd]
                            break
                    else:
                        continue
                    break
            else:
                state_dict = None
    elif hasattr(model, 'state_dict'):
        state_dict = model.state_dict()
    else:
        state_dict = None

    if state_dict is None:
        raise ValueError("Could not find model state dict to count parameters.")

    return sum(v.numel() for v in state_dict.values() if hasattr(v, "numel"))

def count_parameters_safetensors(safetensors_path):
    from safetensors.torch import load_file
    state_dict = load_file(safetensors_path)
    total = sum(p.numel() for p in state_dict.values())
    return total

base_dir = os.path.join("outputs", "2026-01-08_parameter_counts")
csv_output_path = os.path.join(base_dir, "parameter_counts.csv")
rows = []

try:
    import safetensors.torch  # noqa: F401
    safetensors_available = True
except ImportError:
    safetensors_available = False

for dataset_folder in sorted(os.listdir(base_dir)):
    dataset_path = os.path.join(base_dir, dataset_folder)
    if not os.path.isdir(dataset_path):
        continue

    pt_path = os.path.join(dataset_path, "model.pt")
    safetensors_path = os.path.join(dataset_path, "model.safetensors")

    # Process svi (model.pt)
    if os.path.exists(pt_path):
        import torch
        try:
            model = torch.load(pt_path, map_location="cpu", weights_only=False)
            pt_params = count_params(model)
            rows.append({
                "dataset": dataset_folder,
                "model": "svi",
                "params": pt_params
            })
        except Exception as e:
            rows.append({
                "dataset": dataset_folder,
                "model": "svi",
                "params": None
            })
    # Process geneformer (model.safetensors)
    if os.path.exists(safetensors_path) and safetensors_available:
        try:
            safetensor_params = count_parameters_safetensors(safetensors_path)
            rows.append({
                "dataset": dataset_folder,
                "model": "geneformer",
                "params": safetensor_params
            })
        except Exception as e:
            rows.append({
                "dataset": dataset_folder,
                "model": "geneformer",
                "params": None
            })
    elif os.path.exists(safetensors_path) and not safetensors_available:
        rows.append({
            "dataset": dataset_folder,
            "model": "geneformer",
            "params": None
        })

# Write to CSV with columns: dataset, model, params
with open(csv_output_path, mode="w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["dataset", "model", "params"])
    writer.writeheader()
    for row in rows:
        writer.writerow(row)



In [4]:
import pandas as pd
from wandb import Api

# Download all W&B run history (all columns)
api = Api()
run = api.run('igor-somite-somite/geneformer-scaling-laws/id3ize6p')
df = pd.DataFrame(list(run.scan_history()))

print(df)

     total_parameters  train/learning_rate  eval/steps_per_second  \
0          13480151.0                  NaN                    NaN   
1                 NaN             0.000200                    NaN   
2                 NaN                  NaN                  2.404   
3                 NaN             0.000400                    NaN   
4                 NaN                  NaN                  2.521   
..                ...                  ...                    ...   
171               NaN             0.000948                    NaN   
172               NaN                  NaN                  2.510   
173               NaN             0.000947                    NaN   
174               NaN                  NaN                  4.543   
175               NaN                  NaN                    NaN   

     eval/samples_per_second  _step    _timestamp  train/grad_norm  eval/loss  \
0                        NaN      0  1.764008e+09              NaN        NaN   
1        

### Geneformer params

In [8]:
print("total_params count", df['total_parameters'].loc[0])

total_params count 13480151.0


### SCVI params

In [None]:
api = Api()
run = api.run('igor-somite-somite/scvi-scaling-laws/kse2pfsd')
df_scvi = pd.DataFrame(list(run.scan_history()))
