In [None]:
import yaml
import numpy as np
import os
import sys

import dataclasses
import yaml

from pathlib import Path

import subprocess

import coref.run_manager as rm

from coref import COREF_ROOT

from coref.utils import slugify
from importlib import reload

import torch

import json

In [None]:
# Experiment configs are read from and written to <COREF_ROOT>/experiments.
expts_root = Path(COREF_ROOT) / 'experiments'
# Root directory for run outputs; choose a directory with > 20 gb space.
# Set PROP_PROBES_OUTPUT_ROOT to relocate outputs without editing the notebook;
# the fallback is the original hard-coded location.
outputs_root = os.environ.get('PROP_PROBES_OUTPUT_ROOT', '/data/fjiahai/prop-probes-test')

In [None]:
def get_output_path(config_path, main):
    """Load a run config and resolve the directory its run is assigned.

    Args:
        config_path: Path to the experiment YAML config.
        main: Entry point of the script being launched. Not used by this
            function.

    Returns:
        A ``(cfg, output_path)`` pair: the config returned by ``rm.load_cfg``
        and the directory returned by ``rm.get_run_dir`` for this config.
    """
    run_dir = rm.get_run_dir(
        config_path=config_path,
        runs_root=outputs_root,
        experiments_root=expts_root,
    )
    # load_cfg returns a pair; only its first element (the config) is needed.
    loaded_cfg, _ = rm.load_cfg(config_path)
    return loaded_cfg, run_dir

def run_sbatch(config_path, num_devices, slurm_path):
    """Submit a Slurm batch script with ``sbatch``, pointing it at a config file.

    Args:
        config_path: Config file for the job (str or path-like). It is set as
            ``CONFIG_FILE`` in the environment of the ``sbatch`` process.
        num_devices: GPU count requested via ``--gres=gpu:<num_devices>``.
        slurm_path: Batch script handed to ``sbatch``.

    Returns:
        ``(command, stdout, stderr)``: the submitted command joined with
        spaces, followed by sbatch's captured stdout and stderr as bytes.

    Raises:
        subprocess.CalledProcessError: If ``sbatch`` exits with a non-zero
            status. Its captured stderr is echoed before the error propagates.
    """
    slurm_cmd = ['sbatch', f'--gres=gpu:{num_devices}', slurm_path]
    # Callers pass both str and Path configs; normalise so the environment
    # mapping never holds a Path object.
    job_env = {**os.environ, 'CONFIG_FILE': os.fspath(config_path)}
    try:
        slurm_output = subprocess.run(slurm_cmd, env=job_env, capture_output=True, check=True)
    except subprocess.CalledProcessError as err:
        # The exception text only carries the exit status, so echo sbatch's own
        # diagnostics before re-raising the failure unchanged.
        if err.stderr:
            print(err.stderr.decode(errors='replace'), file=sys.stderr)
        raise
    return ' '.join(slurm_cmd), slurm_output.stdout, slurm_output.stderr

def get_last_output(cfg_path):
    """Return the directory of the most recent successful run of a config.

    Run directories are the integer-named subdirectories of the config's run
    parent directory (as reported by ``rm.get_run_dir_parent``). A run counts
    as successful when its directory contains ``done.out``. Entries whose names
    are not integers are ignored instead of aborting the lookup.

    Args:
        cfg_path: Path to the experiment config whose runs are inspected.

    Returns:
        ``pathlib.Path`` of the highest-numbered successful run directory.

    Raises:
        ValueError: If the config has no successful run.
        FileNotFoundError: If the run parent directory does not exist.
    """
    parent_dir = Path(rm.get_run_dir_parent(cfg_path, outputs_root, expts_root))
    # Map run number -> directory name; stray non-numeric folders are skipped.
    run_names = {
        int(entry.name): entry.name
        for entry in parent_dir.iterdir()
        if entry.is_dir() and entry.name.isdecimal()
    }
    successful = [
        run_id for run_id, name in run_names.items()
        if (parent_dir / name / 'done.out').exists()
    ]
    if not successful:
        # Same exception type the bare max() call raised, with a usable message.
        raise ValueError(
            f'No successful run (one containing done.out) found for {cfg_path} in {parent_dir}'
        )
    max_run = max(run_names)
    max_success = max(successful)
    if max_run != max_success:
        print(f'Warning: latest run {max_run} of {cfg_path} is not successful. Falling back to {max_success}')
    # Use the on-disk name so the returned path always exists.
    return parent_dir / run_names[max_success]
        
    

# Paper Subspace experiments

In [None]:
import scripts.run_hessians

In [None]:
# Settings shared by every Hessian run config generated in this section.
base_cfg = {
    'num_devices': 4,
    'is_hf': True,
    'hessian_mode': 'point',
    'name_width': 1,
    'attr_width': 1,
    'template': "NameCountryTemplate",
    'swap_dir': False,
}
# Per-model overrides, keyed by the short model name used in config filenames.
model_cfgs = {
    'llama': {'model': "Llama-2-13b-chat-hf", 'prompt_type': "llama_chat"},
    'tulu': {'model': "tulu-2-13b", 'prompt_type': "tulu_chat"},
}

In [None]:
def build_hessian_cfg(model, uniform_scale, interpolating_factor):
    """Assemble a ``scripts.run_hessians.Cfg`` for one model and setting.

    The shared ``base_cfg`` is overlaid with the entry for ``model`` in
    ``model_cfgs`` and then with the two sweep parameters. Both dicts are read
    from the notebook namespace at call time; note that later cells rebind
    ``model_cfgs`` with a ``chat_style`` key in place of ``prompt_type``.

    Args:
        model: Key into ``model_cfgs`` (e.g. ``'llama'``).
        uniform_scale: Value stored under ``uniform_scale``.
        interpolating_factor: Value stored under ``interpolating_factor``.

    Returns:
        The constructed ``scripts.run_hessians.Cfg`` instance.
    """
    # A fresh dict is built, so the shared base_cfg is never mutated.
    merged = {
        **base_cfg,
        **model_cfgs[model],
        'uniform_scale': uniform_scale,
        'interpolating_factor': interpolating_factor,
    }
    return scripts.run_hessians.Cfg(**merged)

In [None]:
# Write one Hessian config YAML per (model, uniform_scale, interpolating_factor)
# combination and record the path of each file as a string.
all_cfg_paths = []
for model in model_cfgs:
    for uniform_scale in (False,):
        for interpolating_factor in (0.5,):
            test_path = expts_root / f'point_hessians/paper/{model}_scale_{uniform_scale}_interpolating_{interpolating_factor}.yaml'
            all_cfg_paths.append(str(test_path))
            hessian_run_cfg = build_hessian_cfg(
                model=model,
                uniform_scale=uniform_scale,
                interpolating_factor=interpolating_factor,
            )
            hessian_run_cfg.save(test_path, check=True, meta_kwargs={'_output_root': outputs_root})

In [None]:
# Show the Hessian config paths collected by the generation loop.
all_cfg_paths

In [None]:
# Submit a Hessian job per selected config, then print one tab-separated record
# (config, output dir, command, sbatch stdout, sbatch stderr) for each job.
# NOTE(review): the `[-1:]` slice submits only the last generated config —
# confirm this is intended rather than a leftover from a partial re-run.
cmd_logs = []
for cfg_path in all_cfg_paths[-1:]:
    cfg, output_path = get_output_path(cfg_path, scripts.run_hessians.main)
    slurm_cmd, slurm_out, slurm_err = run_sbatch(cfg_path, cfg['num_devices'], 'slurm/run_hessians.sh')
    cmd_logs.append('\t'.join(map(str, (cfg_path, output_path, slurm_cmd, slurm_out, slurm_err))))
for cmd in cmd_logs:
    print(cmd)

In [None]:
import scripts.run_eval_form

In [None]:
all_cfg_paths = []

# NOTE: this rebinds the notebook-level `model_cfgs`; these entries carry the
# `chat_style` key that is passed to `scripts.run_eval_form.Cfg` below.
model_cfgs = {
    'llama': dict(
        model="Llama-2-13b-chat-hf",
        chat_style="llama_chat"
    ),
    'tulu': dict(
        model="tulu-2-13b",
        chat_style="tulu_chat"
    )
}


def build_eval_hessian_cfg(model, form_path):
    """Build a ``scripts.run_eval_form.Cfg`` that evaluates a saved Hessian.

    Args:
        model: Key into ``model_cfgs`` (e.g. ``'llama'``).
        form_path: Path stored under ``form_path`` in the config.

    Returns:
        The constructed ``scripts.run_eval_form.Cfg`` instance.
    """
    base_cfg = dict(
        num_devices=2,
        is_hf=True,
        form_path=form_path,
        form_type='hessian_1_1'
    )
    cfg = {**base_cfg, **model_cfgs[model]}
    return scripts.run_eval_form.Cfg(**cfg)


eval_cfgs = []

for model in model_cfgs.keys():
    for uniform_scale in [False, True]:
        for interpolating_factor in [0., 0.5, 1.]:
            hessian_cfg = expts_root / f'point_hessians/paper/{model}_scale_{uniform_scale}_interpolating_{interpolating_factor}.yaml'
            eval_hessian_cfg = expts_root / f'point_hessians/paper/eval_{model}_scale_{uniform_scale}_interpolating_{interpolating_factor}.yaml'
            try:
                hessian_dir = get_last_output(hessian_cfg)
            except Exception as err:
                # Best effort: combinations without a usable Hessian run are
                # skipped. Catching Exception (not a bare except) keeps
                # KeyboardInterrupt working, and the cause is reported.
                print(f'Failed to get last output for {hessian_cfg}: {err!r}')
                hessian_dir = None
            if hessian_dir is not None:
                form_path = str(hessian_dir / 'hessian.pt')
                # Eval outputs are placed in an `eval` folder inside the Hessian run directory.
                build_eval_hessian_cfg(model, form_path).save(eval_hessian_cfg, check=True, meta_kwargs={'_output_dir': str(hessian_dir / 'eval')})
                eval_cfgs.append(eval_hessian_cfg)

In [None]:
# Show the eval configs written above (only combinations with a successful Hessian run).
eval_cfgs

In [None]:
# random baseline

# Rebinds the notebook-level `model_cfgs` with the `chat_style` variant.
model_cfgs = {
    'llama': {'model': "Llama-2-13b-chat-hf", 'chat_style': "llama_chat"},
    'tulu': {'model': "tulu-2-13b", 'chat_style': "tulu_chat"},
}


def build_random_eval_hessian_cfg(model):
    """Build a ``scripts.run_eval_form.Cfg`` with ``form_type='random'``.

    Args:
        model: Key into ``model_cfgs`` (e.g. ``'llama'``).

    Returns:
        The constructed ``scripts.run_eval_form.Cfg`` instance; ``form_path``
        is set to the empty string.
    """
    cfg = {
        'num_devices': 4,
        'is_hf': True,
        'form_path': '',
        'form_type': 'random',
        **model_cfgs[model],
    }
    return scripts.run_eval_form.Cfg(**cfg)


# Write one random-baseline config per model and keep the paths for submission.
random_cfgs = []
for model in ('llama', 'tulu'):
    cfg_path = expts_root / f'point_hessians/paper/random_{model}.yaml'
    random_eval_cfg = build_random_eval_hessian_cfg(model)
    random_eval_cfg.save(cfg_path, check=True, meta_kwargs={'_output_root': outputs_root})
    random_cfgs.append(cfg_path)


In [None]:
# Submit an eval-form job for every random-baseline and Hessian-eval config,
# then print one tab-separated record per submission.
cmd_logs = []
for cfg_path in random_cfgs + eval_cfgs:
    cfg, output_path = get_output_path(cfg_path, scripts.run_eval_form.main)
    slurm_cmd, slurm_out, slurm_err = run_sbatch(cfg_path, cfg['num_devices'], 'slurm/run_eval_form.sh')
    cmd_logs.append('\t'.join(map(str, (cfg_path, output_path, slurm_cmd, slurm_out, slurm_err))))
for cmd in cmd_logs:
    print(cmd)

In [None]:
# vector baseline
import coref.vector_subspace_baseline

# Rebinds the notebook-level `model_cfgs` with the `chat_style` variant.
model_cfgs = {
    'llama': {'model': "Llama-2-13b-chat-hf", 'chat_style': "llama_chat"},
    'tulu': {'model': "tulu-2-13b", 'chat_style': "tulu_chat"},
}


def build_vector_baseline_cfg(model):
    """Build a ``coref.vector_subspace_baseline.Cfg`` for one model.

    Args:
        model: Key into ``model_cfgs`` (e.g. ``'llama'``).

    Returns:
        The constructed ``coref.vector_subspace_baseline.Cfg`` instance.
    """
    cfg = {'num_devices': 2, 'is_hf': False, **model_cfgs[model]}
    return coref.vector_subspace_baseline.Cfg(**cfg)


# Write one vector-baseline config per model and keep the paths for submission.
baseline_cfgs = []
for model in ('llama', 'tulu'):
    cfg_path = expts_root / f'point_hessians/paper/baseline_{model}.yaml'
    vector_baseline_cfg = build_vector_baseline_cfg(model)
    vector_baseline_cfg.save(cfg_path, check=True, meta_kwargs={'_output_root': outputs_root})
    baseline_cfgs.append(cfg_path)
baseline_cfgs

In [None]:
# Submit a vector-subspace-baseline job per config, then print one
# tab-separated record per submission.
cmd_logs = []
for cfg_path in baseline_cfgs:
    cfg, output_path = get_output_path(cfg_path, coref.vector_subspace_baseline.main)
    slurm_cmd, slurm_out, slurm_err = run_sbatch(cfg_path, cfg['num_devices'], 'slurm/run_vector_subspace_baseline.sh')
    cmd_logs.append('\t'.join(map(str, (cfg_path, output_path, slurm_cmd, slurm_out, slurm_err))))
for cmd in cmd_logs:
    print(cmd)

## DAS

In [None]:
import coref.train_das
# Rebinds the notebook-level `model_cfgs` with the `chat_style` variant.
model_cfgs = {
    'llama': {'model': "Llama-2-13b-chat-hf", 'chat_style': "llama_chat"},
    'tulu': {'model': "tulu-2-13b", 'chat_style': "tulu_chat"},
}


def build_das_cfg(model, d_subspace):
    """Build a ``coref.train_das.Cfg`` for one model and subspace dimension.

    Args:
        model: Key into ``model_cfgs`` (e.g. ``'llama'``).
        d_subspace: Value stored under ``d_subspace`` in the config.

    Returns:
        The constructed ``coref.train_das.Cfg`` instance.
    """
    cfg = {
        'num_devices': 4,
        'is_hf': False,
        'd_subspace': d_subspace,
        **model_cfgs[model],
    }
    return coref.train_das.Cfg(**cfg)

In [None]:
# Subspace dimensions swept by the DAS training and evaluation cells below.
# NOTE(review): 5120 presumably equals the full hidden width of the 13B models — confirm.
all_dims = [1, 3, 15, 50, 250, 1000, 5120]

In [None]:
# Write one DAS training config per (model, subspace dimension) pair and keep
# the paths for submission.
das_cfgs = []
for model in ('llama', 'tulu'):
    for d_subspace in all_dims:
        cfg_path = expts_root / 'das' / f'{model}_{d_subspace}.yaml'
        das_train_cfg = build_das_cfg(model, d_subspace)
        das_train_cfg.save(cfg_path, check=True, meta_kwargs={'_output_root': outputs_root})
        das_cfgs.append(cfg_path)
das_cfgs

In [None]:
# Submit a DAS training job per config, then print one tab-separated record
# per submission.
cmd_logs = []
for cfg_path in das_cfgs:
    cfg, output_path = get_output_path(cfg_path, coref.train_das.main)
    slurm_cmd, slurm_out, slurm_err = run_sbatch(cfg_path, cfg['num_devices'], 'slurm/run_das.sh')
    cmd_logs.append('\t'.join(map(str, (cfg_path, output_path, slurm_cmd, slurm_out, slurm_err))))
for cmd in cmd_logs:
    print(cmd)

In [None]:
# eval DAS
import scripts.eval_das
# Rebinds the notebook-level `model_cfgs` with the `chat_style` variant.
model_cfgs = {
    'llama': {'model': "Llama-2-13b-chat-hf", 'chat_style': "llama_chat"},
    'tulu': {'model': "tulu-2-13b", 'chat_style': "tulu_chat"},
}


def build_eval_das_cfg(model, d_subspace, das_path):
    """Build a ``scripts.eval_das.Cfg`` that evaluates a trained DAS run.

    Args:
        model: Key into ``model_cfgs`` (e.g. ``'llama'``).
        d_subspace: Not used by this function.
        das_path: Value stored under ``das_path`` in the config.

    Returns:
        The constructed ``scripts.eval_das.Cfg`` instance.
    """
    cfg = {
        'num_devices': 2,
        'is_hf': False,
        'das_path': das_path,
        **model_cfgs[model],
    }
    return scripts.eval_das.Cfg(**cfg)


# For each trained DAS config, locate its latest successful run and write an
# eval config whose output directory is an `eval` folder inside that run.
eval_das_cfgs = []
for model in ('llama', 'tulu'):
    for d_subspace in all_dims:
        das_cfg_path = expts_root / 'das' / f'{model}_{d_subspace}.yaml'
        cfg_path = expts_root / 'das' / f'eval_{model}_{d_subspace}.yaml'
        das_path = get_last_output(das_cfg_path)
        print(das_path)
        eval_das_cfg = build_eval_das_cfg(model, d_subspace, str(das_path))
        eval_das_cfg.save(cfg_path, check=True, meta_kwargs={'_output_dir': os.path.join(das_path, 'eval')})
        eval_das_cfgs.append(cfg_path)
eval_das_cfgs

In [None]:
# Submit a DAS evaluation job per config, then print one tab-separated record
# per submission.
cmd_logs = []
for cfg_path in eval_das_cfgs:
    cfg, output_path = get_output_path(cfg_path, scripts.eval_das.main)
    slurm_cmd, slurm_out, slurm_err = run_sbatch(cfg_path, cfg['num_devices'], 'slurm/run_eval_das.sh')
    cmd_logs.append('\t'.join(map(str, (cfg_path, output_path, slurm_cmd, slurm_out, slurm_err))))
for cmd in cmd_logs:
    print(cmd)

## Hessians

In [None]:
import scripts.run_hessians

In [None]:
# Existing Hessian configs (under <expts_root>/point_hessians) to submit and evaluate.
hessian_paths = [
    expts_root.joinpath('point_hessians', 'llama_13b_chat_widths_1_1.yaml'),
]

In [None]:
# Submit a Hessian job per config in `hessian_paths`, then print one
# tab-separated record per submission.
cmd_logs = []
for cfg_path in hessian_paths:
    cfg, output_path = get_output_path(cfg_path, scripts.run_hessians.main)
    slurm_cmd, slurm_out, slurm_err = run_sbatch(cfg_path, cfg['num_devices'], 'slurm/run_hessians.sh')
    cmd_logs.append('\t'.join(map(str, (cfg_path, output_path, slurm_cmd, slurm_out, slurm_err))))
for cmd in cmd_logs:
    print(cmd)

In [None]:
import scripts.run_eval_form

In [None]:
# For each Hessian config, find its latest successful run and write an
# eval-form config pointing at that run's `hessian.pt`.
# NOTE(review): this passes `template`/`prompt_type` to scripts.run_eval_form.Cfg,
# while the paper-section eval cell passes `chat_style` to the same class —
# confirm which keyword set the current Cfg accepts.
eval_form_cfgs = []
for cfg_path in hessian_paths:
    form_type = 'hessian_1_1'
    output_path = get_last_output(cfg_path)
    family_slug = slugify(rm.get_family_name(cfg_path, outputs_root, expts_root))
    fn = str(expts_root / 'eval_form' / f'{form_type}_{family_slug}.yaml')
    eval_form_cfg = scripts.run_eval_form.Cfg(
        model="Llama-2-13b-chat-hf",
        num_devices=4,
        is_hf=False,
        template='NameCountryTemplate',
        prompt_type='llama_chat',
        form_path=str(output_path / "hessian.pt"),
        form_type=form_type,
    )
    eval_form_cfg.save(fn)
    eval_form_cfgs.append(fn)
    

In [None]:
# Show the eval-form config files written by the loop above.
eval_form_cfgs

In [None]:
# Submit an eval-form job per config in `eval_form_cfgs`, then print one
# tab-separated record per submission.
cmd_logs = []
for cfg_path in eval_form_cfgs:
    cfg, output_path = get_output_path(cfg_path, scripts.run_eval_form.main)
    slurm_cmd, slurm_out, slurm_err = run_sbatch(cfg_path, cfg['num_devices'], 'slurm/run_eval_form.sh')
    cmd_logs.append('\t'.join(map(str, (cfg_path, output_path, slurm_cmd, slurm_out, slurm_err))))
for cmd in cmd_logs:
    print(cmd)