In [7]:
import subprocess
from textwrap import dedent
from itertools import product
import os
CONDA_ENV_NAME = "HL-env"
REPO_DIR = os.path.abspath(".")  # adjust if needed
SWEEP_CONFIG = "grid"
PROJECT = f"decouple-alphas-{SWEEP_CONFIG}"
data = True # add the data param?
slurm = False  # whether to launch the jobs on SLURM or not


# Parameters that represent each unique optimisation space
# You can also make an item a lambda function to evaluate it dynamically where the argument is the hyperparameter dictionary
grid = {
    "default": {
        "model": ["basic_mlp"],
        "task": ["fmnist"],
        "optimizer.lr": [0.1, 0.5],
        # "optimizer.weight_decay": [0.0, 1e-4, 1e-3],
        # "optimizer.momentum": [0.0, 0.9, 0.99],
        "corruption.alpha": [0.01, 0.25, 0.5, 0.75, 0.99],
        "trainer.min_epochs": ["20"],
        "trainer.max_epochs": ["20"],
    },
    "md": {
        "optimizer.update_alg": ['md'],
        "optimizer.alpha": [0.01, 0.25, 0.5, 0.75, 0.99],
        "optimizer.block_size": ['4'],
    },
    # "gd": {
    #     "optimizer.update_alg": ['gd'],
    # },
}

def launch_job(**hp):
    """
    Launch a job on SLURM with the specified parameters.

    args == hyper params
    """
    # if any value is a lambda function, evaluate it with the current hp
    for key, value in hp.items():
        if callable(value):
            hp[key] = value(hp)

    name = "_".join([str(hp[k]) for k in sorted(hp)])
    study_name = f"study_{name}"
    group = name

    data_dir = "$TMP_SHARED"
    # Create the batch script as a multi-line string
    template_script = dedent(f"""\
        #!/bin/bash
        #SBATCH --job-name={name}
        #SBATCH --output=slurm-logs/{PROJECT}/{name}_%j.out
        #SBATCH --error=slurm-logs/{PROJECT}/{name}_%j.err
        #SBATCH --time=01:00:00
        #SBATCH --partition=gpu
        #SBATCH --gres=gpu:1
        #SBATCH --mem=16G
        #SBATCH --cpus-per-task=4

        module load miniforge
        conda activate $HOME/{CONDA_ENV_NAME}

        export CUDA_DEVICE_ORDER=PCI_BUS_ID

        LOGGING="$SCRATCH/{PROJECT}/{study_name}"

        mkdir -p "$LOGGING"
        CHKP="$LOGGING/last.ckpt"

        cd $LOGGING
        echo "Copying data from {REPO_DIR}/data into {data_dir}/data"
        cp -r "{REPO_DIR}/data" "{data_dir}/data"

    """)

    cmd = [
        "python", f"{REPO_DIR}/src/train.py", "-m",
        f"hydra.sweeper.study_name={study_name}",
        f"hparams_search={SWEEP_CONFIG}",
        f"logger.group={group}",
        f"save_dir=$LOGGING" if slurm else f"save_dir={REPO_DIR}/logs/{PROJECT}/{study_name}",
        f"logger.project={PROJECT}",
    ]

    if data:
        cmd.append(f"data.data_dir={data_dir}/data")


    # the keu is the name of the hyperparameter, the value is the value to set it to
    for key, value in hp.items():
        cmd.append(f"{key}={value}")

    # Add the command to run the script
    batch_script = template_script + "\n" + " ".join(cmd) + "\n" + "echo 'Job completed.'\n"

    # Write the script to a temp file (can be named uniquely)
    script_filename = f"tmp.sh"

    if slurm:
        with open(script_filename, "w") as f:
            f.write(batch_script)

        # Launch the job using sbatch
        subprocess.run(["sbatch", script_filename])
    else:
        # If not using SLURM, just run the command directly
        print("Running command directly (not on SLURM):", " ".join(cmd))
        subprocess.run(cmd)

def print_grid_stats(grid):
    default = grid.get("default", {})
    total = 0

    print("Grid Search Stats:\n")

    for space, params in grid.items():
        if space == "default":
            continue

        # Merge default with specific subspace params
        full_params = {**default, **params}
        keys = sorted(full_params.keys())
        values_list = [full_params[key] for key in keys]

        num_configs = 1
        for v in values_list:
            num_configs *= len(v)

        print(f"  - {space}: {num_configs} configurations")
        total += num_configs

    print(f"\nTotal configurations: {total}")


print_grid_stats(grid)
input("Press Enter to continue... or Ctrl+C to exit.")
for space, params in grid.items():
    if space == "default":
        continue

    # Add the default parameters to the grid
    full_params = {**grid["default"], **params}
    keys = sorted(full_params.keys())
    values_list = [full_params[key] for key in keys]

    for values in product(*values_list):
        hp = {
            key: value
            for key, value in zip(keys, values)
        }
        # Launch the job with the hyperparameters
        launch_job(**hp)


Grid Search Stats:

  - md: 50 configurations

Total configurations: 50
Press Enter to continue... or Ctrl+C to exit.
Running command directly (not on SLURM): python /content/HeterosynapticLearning/src/train.py -m hydra.sweeper.study_name=study_0.01_basic_mlp_0.01_4_0.1_md_fmnist_20_20 hparams_search=grid logger.group=0.01_basic_mlp_0.01_4_0.1_md_fmnist_20_20 save_dir=/content/HeterosynapticLearning/logs/decouple-alphas-grid/study_0.01_basic_mlp_0.01_4_0.1_md_fmnist_20_20 logger.project=decouple-alphas-grid data.data_dir=$TMP_SHARED/data corruption.alpha=0.01 model=basic_mlp optimizer.alpha=0.01 optimizer.block_size=4 optimizer.lr=0.1 optimizer.update_alg=md task=fmnist trainer.max_epochs=20 trainer.min_epochs=20
Running command directly (not on SLURM): python /content/HeterosynapticLearning/src/train.py -m hydra.sweeper.study_name=study_0.01_basic_mlp_0.01_4_0.5_md_fmnist_20_20 hparams_search=grid logger.group=0.01_basic_mlp_0.01_4_0.5_md_fmnist_20_20 save_dir=/content/HeterosynapticL

KeyboardInterrupt: 

In [1]:
!git clone https://github.com/clarakuempel/HeterosynapticLearning.git

Cloning into 'HeterosynapticLearning'...
remote: Enumerating objects: 944, done.[K
remote: Counting objects: 100% (313/313), done.[K
remote: Compressing objects: 100% (163/163), done.[K
remote: Total 944 (delta 201), reused 246 (delta 143), pack-reused 631 (from 1)[K
Receiving objects: 100% (944/944), 3.75 MiB | 8.29 MiB/s, done.
Resolving deltas: 100% (568/568), done.


In [3]:
!pip install -r requirements.txt

[31mERROR: Could not open requirements file: [Errno 2] No such file or directory: 'requirements.txt'[0m[31m
[0m

In [4]:
!git checkout prune-tests

fatal: not a git repository (or any of the parent directories): .git


In [5]:
%cd /content/HeterosynapticLearning

# check branches
!git branch -a

# switch branch
!git checkout prune-tests

# install requirements from inside repo
!pip install -r requirements.txt

/content/HeterosynapticLearning
* [32mmain[m
  [31mremotes/origin/HEAD[m -> origin/main
  [31mremotes/origin/changes[m
  [31mremotes/origin/instantiate[m
  [31mremotes/origin/lr_schdl[m
  [31mremotes/origin/main[m
  [31mremotes/origin/mdM[m
  [31mremotes/origin/nonCausal[m
  [31mremotes/origin/optuna-optim[m
  [31mremotes/origin/penn-fix[m
  [31mremotes/origin/prune-tests[m
  [31mremotes/origin/rerun-corrup[m
  [31mremotes/origin/test-penn-treebank-cluster[m
Branch 'prune-tests' set up to track remote branch 'prune-tests' from 'origin'.
Switched to a new branch 'prune-tests'
Collecting lightning>=2.0.0 (from -r requirements.txt (line 4))
  Downloading lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
[31mERROR: Could not find a version that satisfies the requirement torchtext==0.17.1 (from versions: 0.1.1, 0.2.0, 0.2.1, 0.2.3, 0.3.1, 0.4.0, 0.5.0, 0.6.0, 0.16.2, 0.17.2, 0.18.0)[0m[31m
[0m[31mERROR: No matching distribution found for torchtext==0.17.1[0m

In [6]:
%cd /content/HeterosynapticLearning
!sed -i '/torchtext/d' requirements.txt
!pip install -r requirements.txt

/content/HeterosynapticLearning
Collecting lightning>=2.0.0 (from -r requirements.txt (line 4))
  Using cached lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
Collecting torchdata==0.7.1 (from -r requirements.txt (line 8))
  Downloading torchdata-0.7.1-py3-none-any.whl.metadata (13 kB)
Collecting portalocker==3.2.0 (from -r requirements.txt (line 9))
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Collecting hydra-core==1.3.2 (from -r requirements.txt (line 12))
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting hydra-colorlog==1.2.0 (from -r requirements.txt (line 13))
  Downloading hydra_colorlog-1.2.0-py3-none-any.whl.metadata (949 bytes)
Collecting hydra-optuna-sweeper==1.2.0 (from -r requirements.txt (line 14))
  Downloading hydra_optuna_sweeper-1.2.0-py3-none-any.whl.metadata (1.0 kB)
Collecting rootutils (from -r requirements.txt (line 20))
  Downloading rootutils-1.0.7-py3-none-any.whl.metadata (4.7 kB)
Collecting colorlog (from hy

In [8]:
import os, getpass
if "WANDB_API_KEY" not in os.environ:
    os.environ["WANDB_API_KEY"] = getpass.getpass("Enter your W&B API key: ")
# optional defaults
os.environ.setdefault("WANDB_ENTITY", "hp-learning-rules")
os.environ.setdefault("WANDB_PROJECT", "decouple-alphas-grid")

Enter your W&B API key: ··········


'decouple-alphas-grid'

In [None]:
import subprocess
from itertools import product
import os
from pathlib import Path

# ---- paths & config ----
REPO_DIR = Path("/content/HeterosynapticLearning").resolve()  # repo root
SWEEP_CONFIG = "grid"
PROJECT = f"decouple-alphas-{SWEEP_CONFIG}"
USE_DATA_PARAM = True   # pass data.data_dir if data/ exists
SLURM = False           # we're on Colab

# ---- grid ----
grid = {
    "default": {
        "model": ["basic_mlp"],
        "task": ["fmnist"],
        "optimizer.lr": [0.1, 0.5],
        "corruption.alpha": [0.01, 0.25, 0.5, 0.75, 0.99],
        "trainer.min_epochs": [20],
        "trainer.max_epochs": [20],
    },
    "md": {
        "optimizer.update_alg": ["md"],
        "optimizer.alpha": [0.01, 0.25, 0.5, 0.75, 0.99],
        "optimizer.block_size": ["4"],
    },
    # "gd": {"optimizer.update_alg": ["gd"]},
}

LOGS_ROOT = REPO_DIR / "logs" / PROJECT
LOGS_ROOT.mkdir(parents=True, exist_ok=True)

def print_grid_stats(grid):
    default = grid.get("default", {})
    total = 0
    print("Grid Search Stats:\n")
    for space, params in grid.items():
        if space == "default":
            continue
        full_params = {**default, **params}
        keys = sorted(full_params.keys())
        n = 1
        for k in keys:
            n *= len(full_params[k])
        print(f"  - {space}: {n} configurations")
        total += n
    print(f"\nTotal configurations: {total}\n")

def launch_job(**hp):
    # deterministic name
    name = "_".join([str(hp[k]) for k in sorted(hp)])
    study_name = f"study_{name}"
    out_dir = LOGS_ROOT / study_name
    out_dir.mkdir(parents=True, exist_ok=True)

    cmd = [
        "python", str(REPO_DIR / "src" / "train.py"), "-m",
        f"hydra.sweeper.study_name={study_name}",
        f"hparams_search={SWEEP_CONFIG}",
        f"logger.group={name}",
        f"logger.project={PROJECT}",
        f"save_dir={out_dir}",
    ]

    cmd += ["trainer.accelerator=gpu", "trainer.devices=1"]

    if USE_DATA_PARAM:
        data_dir = REPO_DIR / "data"
        if data_dir.exists():
            cmd.append(f"data.data_dir={data_dir}")
        else:
            print("⚠️ data/ not found in repo; skipping data.data_dir param.")

    for key, value in hp.items():
        cmd.append(f"{key}={value}")

    # Make sure we run from repo root (Hydra relative paths etc.)
    print("⏩ Running:", " ".join(map(str, cmd)))
    subprocess.run(cmd, cwd=str(REPO_DIR), check=True)

# --- run ---
print_grid_stats(grid)

for space, params in grid.items():
    if space == "default":
        continue
    full_params = {**grid["default"], **params}
    keys = sorted(full_params.keys())
    values_list = [full_params[k] for k in keys]
    for values in product(*values_list):
        hp = {k: v for k, v in zip(keys, values)}
        launch_job(**hp)

Grid Search Stats:

  - md: 50 configurations

Total configurations: 50

⏩ Running: python /content/HeterosynapticLearning/src/train.py -m hydra.sweeper.study_name=study_0.01_basic_mlp_0.01_4_0.1_md_fmnist_20_20 hparams_search=grid logger.group=0.01_basic_mlp_0.01_4_0.1_md_fmnist_20_20 logger.project=decouple-alphas-grid save_dir=/content/HeterosynapticLearning/logs/decouple-alphas-grid/study_0.01_basic_mlp_0.01_4_0.1_md_fmnist_20_20 trainer.accelerator=gpu trainer.devices=1 data.data_dir=/content/HeterosynapticLearning/data corruption.alpha=0.01 model=basic_mlp optimizer.alpha=0.01 optimizer.block_size=4 optimizer.lr=0.1 optimizer.update_alg=md task=fmnist trainer.max_epochs=20 trainer.min_epochs=20
⏩ Running: python /content/HeterosynapticLearning/src/train.py -m hydra.sweeper.study_name=study_0.01_basic_mlp_0.01_4_0.5_md_fmnist_20_20 hparams_search=grid logger.group=0.01_basic_mlp_0.01_4_0.5_md_fmnist_20_20 logger.project=decouple-alphas-grid save_dir=/content/HeterosynapticLearning