# Launch Experiment Sweets

This script grid searches over all listed experiment configurations and runs them sequentially. Optionally, runs that already exist in Weights and Biases can be skipped. 

In [1]:
import os
import ast
import numpy as np

from tqdm import tqdm
from itertools import combinations

## Build Run Grid

In [2]:
models = ["mistralai/Mistral-7B-Instruct-v0.3"]

# Interventions
editing_interventions = ["memit"]
unlearning_interventions = ["rmu"]
pruning_interventions = ["wanda"]
quantization_interventions = ["awq"]
all_interventions = editing_interventions + unlearning_interventions + pruning_interventions + quantization_interventions
print(all_interventions)

# Intervention Settings
pruning_levels = [0.25, 0.35, 0.45, 0.55, 0.65, 0.75]
quant_levels = [2, 3, 4, 5, 6, 8]
rmu_setting_overrides = {
    "rmu_alpha": [1000, 1000],
    "rmu_layer_id": 6,
    "rmu_max_num_batches": 450,
}

['memit', 'rmu', 'wanda', 'awq']


In [3]:
def get_all_combinations(lst):
    all_combs = []
    for r in range(1, len(lst) + 1):
        all_combs.extend(combinations(lst, r))

    return all_combs


all_intervention_combinations = get_all_combinations(all_interventions)
all_intervention_combinations

[('memit',),
 ('rmu',),
 ('wanda',),
 ('awq',),
 ('memit', 'rmu'),
 ('memit', 'wanda'),
 ('memit', 'awq'),
 ('rmu', 'wanda'),
 ('rmu', 'awq'),
 ('wanda', 'awq'),
 ('memit', 'rmu', 'wanda'),
 ('memit', 'rmu', 'awq'),
 ('memit', 'wanda', 'awq'),
 ('rmu', 'wanda', 'awq'),
 ('memit', 'rmu', 'wanda', 'awq')]

In [4]:
max_num_interventions = 2
run_configurations = []
for intervention_combination in all_intervention_combinations:
    if len(intervention_combination) > max_num_interventions:
        continue

    count_interventions = len(intervention_combination)
    all_compression_interventions = pruning_interventions + quantization_interventions
    is_double_compression = sum([technique in all_compression_interventions for technique in intervention_combination]) == 2
    if is_double_compression:
        continue

    intervention_orderings = {intervention_combination, tuple(reversed(intervention_combination))}
    for model_name in models:
        for ordering in intervention_orderings:
            if contains_pruning := any([technique in pruning_interventions for technique in ordering]):
                for sparsity_ratio in pruning_levels:
                    run_configurations.append({"interventions": ordering, "sparsity_ratio": sparsity_ratio, "wbits": 16, "model_name": model_name})
            elif contains_quantization := any([technique in quantization_interventions for technique in ordering]):
                for wbits in quant_levels:
                    run_configurations.append({"interventions": ordering, "sparsity_ratio": 0, "wbits": wbits, "model_name": model_name})

            default_sparsity_ratio = 0
            default_wbits = 16


# add one without interventions, just the model
run_configurations.append({"interventions": [], "sparsity_ratio": 0, "wbits": 16, "model_name": model_name})

run_configurations

[{'interventions': ('wanda',),
  'sparsity_ratio': 0.25,
  'wbits': 16,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3'},
 {'interventions': ('wanda',),
  'sparsity_ratio': 0.35,
  'wbits': 16,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3'},
 {'interventions': ('wanda',),
  'sparsity_ratio': 0.45,
  'wbits': 16,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3'},
 {'interventions': ('wanda',),
  'sparsity_ratio': 0.55,
  'wbits': 16,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3'},
 {'interventions': ('wanda',),
  'sparsity_ratio': 0.65,
  'wbits': 16,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3'},
 {'interventions': ('wanda',),
  'sparsity_ratio': 0.75,
  'wbits': 16,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3'},
 {'interventions': ('awq',),
  'sparsity_ratio': 0,
  'wbits': 2,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3'},
 {'interventions': ('awq',),
  'sparsity_ratio': 0,
  'wbits': 3,
  'model_name': 'mistralai/Mistral-7B-Instruct-v0.3

## Pull Historical Run Data

In [5]:
skip_previous_runs = False
previous_runs = None
# TODO

## Invoke Runs

In [6]:
def set_tag(experiment_row):
    if experiment_row["interventions"] in [None, np.nan]:
        return "NONE"

    intervention_categories = None
    if isinstance(experiment_row["interventions"], str):
        intervention_categories = ast.literal_eval(experiment_row["interventions"])
    else:
        intervention_categories = experiment_row["interventions"]

    interventions = []
    for category in intervention_categories:
        intervention = category.upper()
        if intervention in ["AWQ", "GPTQ"]:
            intervention += str(int(experiment_row["wbits"])) + "bit"
        if intervention in ["WANDA", "SPARSEGPT"]:
            intervention += str(int(experiment_row["sparsity_ratio"] * 100)) + "%"

        interventions.append(intervention)

    if len(interventions) == 0:
        interventions.append("NONE")

    return "_".join(interventions)


commands = []
use_wandb = True
is_slurm = False
python_path = "~/miniconda3/envs/lm-compose/bin/python"

run_commands = []
for run_config in run_configurations:
    intervention_type_map = {"memit": "edit", "rmu": "unlearning", "wanda": "compression", "awq": "compression"}
    interventions_arg = [intervention_type_map[intervention] for intervention in run_config["interventions"]]
    interventions_arg_str = f"interventions={interventions_arg}".replace("'", "").replace(" ", "")
    category_args = [f"{intervention_type_map[intervention]}={intervention}" for intervention in run_config["interventions"]]
    category_args_str = " ".join(category_args)

    compress_args = []
    for intervention in run_config["interventions"]:
        if intervention in pruning_interventions:
            compress_args.append(f"sparsity_ratio={run_config['sparsity_ratio']}")
        elif intervention in quantization_interventions:
            compress_args.append(f"wbits={run_config['wbits']}")

    command_prefix = "sbatch run_exp.sh" if is_slurm else f"{python_path} -m lm_compose"
    command = command_prefix + f" model_name={run_config['model_name']} {interventions_arg_str} {category_args_str} {' '.join(compress_args)}"

    involves_rmu = any([intervention in run_config["interventions"] for intervention in unlearning_interventions])
    if involves_rmu:
        for key, value in rmu_setting_overrides.items():
            command += f" {key}={value}"

    tag = set_tag(run_config)
    command += f" tag={tag}"

    if use_wandb:
        command += " wandb=online"

    print(command)
    run_commands.append(command)

print(f"Total number of runs: {len(run_configurations)}")

~/miniconda3/envs/lm-compose/bin/python -m lm_compose model_name=mistralai/Mistral-7B-Instruct-v0.3 interventions=[compression] compression=wanda sparsity_ratio=0.25 tag=WANDA25% wandb=online
~/miniconda3/envs/lm-compose/bin/python -m lm_compose model_name=mistralai/Mistral-7B-Instruct-v0.3 interventions=[compression] compression=wanda sparsity_ratio=0.35 tag=WANDA35% wandb=online
~/miniconda3/envs/lm-compose/bin/python -m lm_compose model_name=mistralai/Mistral-7B-Instruct-v0.3 interventions=[compression] compression=wanda sparsity_ratio=0.45 tag=WANDA45% wandb=online
~/miniconda3/envs/lm-compose/bin/python -m lm_compose model_name=mistralai/Mistral-7B-Instruct-v0.3 interventions=[compression] compression=wanda sparsity_ratio=0.55 tag=WANDA55% wandb=online
~/miniconda3/envs/lm-compose/bin/python -m lm_compose model_name=mistralai/Mistral-7B-Instruct-v0.3 interventions=[compression] compression=wanda sparsity_ratio=0.65 tag=WANDA65% wandb=online
~/miniconda3/envs/lm-compose/bin/python 

In [7]:
for command in tqdm(run_commands, desc="Running Experiments"):
    print(command)
    # os.system(command)
    # raise exception if command fails

    code = os.system(command)
    if code != 0:
        raise Exception(f"Command failed with code: {code}")

Running Experiments:   0%|          | 0/61 [00:00<?, ?it/s]

~/miniconda3/envs/lm-compose/bin/python -m lm_compose model_name=mistralai/Mistral-7B-Instruct-v0.3 interventions=[compression] compression=wanda sparsity_ratio=0.25 tag=WANDA25% wandb=online
CUDA extension not installed.


wandb: Currently logged in as: kyledevinobrien1 (dri-ice). Use `wandb login --relogin` to force relogin
wandb: wandb version 0.18.1 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.17.0
wandb: Run data is saved locally in /sfs/weka/scratch/hua2bv/unlearning/composable-interventions/wandb/run-20240918_181749-o9ecfun7
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run WANDA25%-20240918_181748
wandb: ⭐️ View project at https://wandb.ai/dri-ice/Composable_Interventions
wandb: 🚀 View run at https://wandb.ai/dri-ice/Composable_Interventions/runs/o9ecfun7


[2024-09-18 18:17:53,075][accelerate.utils.modeling][INFO] - We will use 90% of the memory on device 0 for storing the model, and 10% for the buffer to avoid OOM. You can set `max_memory` in to a higher value to use more memory (at your own risk).


Loading checkpoint shards: 100%|██████████| 3/3 [00:02<00:00,  1.01it/s]
2024-09-18 18:17:56,225 - lm_compose.easyeditor.editors.editor - INFO - Instantiating model


[2024-09-18 18:17:56,225][lm_compose.easyeditor.editors.editor][INFO] - Instantiating model
<class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>
[2024-09-18 18:17:56,549][lm_compose.easyeditor.editors.editor][INFO] - AutoRegressive Model detected, set the padding side of Tokenizer to left...


2024-09-18 18:17:56,549 - lm_compose.easyeditor.editors.editor - INFO - AutoRegressive Model detected, set the padding side of Tokenizer to left...


prune
############# Begin intervention: compression #############
prune
use device  cuda:0
pruning starts
loading calibdation data
Using c4
