In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "" #"1"

# set directory
os.chdir("/nas/ucb/oliveradk/diverse-gen/")

In [2]:
from pathlib import Path
from itertools import product
import json
import copy

import numpy as np

from losses.loss_types import LossType
from utils.exp_utils import get_executor, get_executor_local, run_experiments

In [3]:
SCRIPT_NAME = "measurement_tampering.py"
EXP_DIR = "output/mtd_hparam_sweep"
n_trials = 16

In [4]:
# diamonds hparams: 
## full model: 2e-5 
## probe: 2e-4 
## ground truth probes: 2e-4 
## num epochs: 5 
## num warmup stesp: 64 

# generated stories hparams: 
## full model: 1e-6
## probes: 5e-4 
## ground truth probes: 5e-3 
## num epochs: 4 
## num warmup stesp: 8

In [5]:
N_EPOCHS = 2
import nevergrad as ng

method_configs = {
    "TopK 0.1": {"loss_type": LossType.TOPK},
    "DivDis": {"loss_type": LossType.DIVDIS},
}

# TODO: set frac warmup steps appropriately for each dataset
dimaond_env_configs = {
    f"diamonds-seed{i}": {
        "dataset": f"diamonds-seed{i}", 
        "model": f"codegen-350M-mono-measurement_pred-diamonds-seed{i}", 
        "bootstrap_eval": False, 
        "lr": 2e-5,
        "frac_warmup": 0.10, # 64 step
    } 
    for i in range(8)
}
dataset_configs = {
   **dimaond_env_configs,
   "generated-stories": {
        "dataset": "generated_stories", 
        "model": "pythia-1_4b-deduped-measurement_pred-generated_stories", 
        "bootstrap_eval": True, 
        "micro_batch_size": 2, 
        "max_length": 1536, 
        "feature_dim": 2048, 
        "frac_warmup": 0.15, # 8 steps
   }
}

dataset_method_hparam_ranges = {
    **{f"diamonds-seed{i}": {
        "TopK 0.1": {"aux_weight": ng.p.Log(lower=1e0, upper=1e1)},
        "DivDis": {"aux_weight": ng.p.Log(lower=1e0, upper=1e1)},
        } for i in range(8)
    },
    "generated-stories": {
        "TopK 0.1": {"lr": ng.p.Log(lower=1e-6, upper=2e-5), "aux_weight": ng.p.Log(lower=1e0, upper=1e1)},
        "DivDis": {"lr": ng.p.Log(lower=1e-6, upper=2e-5), "aux_weight": ng.p.Log(lower=1e0, upper=1e1)},
    }
}


def sample_hparams(dataset_name, method_name):
    method_ranges = dataset_method_hparam_ranges[dataset_name][method_name]
    instrum = ng.p.Instrumentation(**method_ranges)
    sampled = instrum.sample()
    return sampled.kwargs


In [6]:
configs = {(ds_name, method_name): {**ds_config, **method_config}
           for (ds_name, ds_config), (method_name, method_config) in product(dataset_configs.items(), method_configs.items())}


def get_conf_exp_dir(ds_name, method_name, i):
    return Path(EXP_DIR, f"{ds_name}_{method_name}/{i}")

sampled_configs = []
for (ds_name, method_name), conf in configs.items():
    n_trials_ds = n_trials if not ds_name.startswith("diamonds") else n_trials // 8
    for i in range(n_trials_ds):
        sample_conf = copy.deepcopy(conf)
        samples_hparams = sample_hparams(ds_name, method_name)
        sample_conf.update(samples_hparams)
        seed = np.random.randint(0, 1000000)
        sample_conf["seed"] = seed
        sample_conf["num_epochs"] = N_EPOCHS
        sample_conf["exp_dir"] = get_conf_exp_dir(ds_name, method_name, i)
        sampled_configs.append(sample_conf)

In [7]:
len(sampled_configs)

64

# Run Experiments

In [9]:
# local_executor = get_executor_local(f"output/{datetime.now().strftime('%Y-%m-%d')}")
# jobs = run_experiments(local_executor, sampled_configs[:1], SCRIPT_NAME)

In [10]:
non_80gb_nodes = ["ddpg", "dqn", "gail", "gan","ppo", "vae"]
slurm_exclude = ",".join([f"{node}.ist.berkeley.edu" for node in non_80gb_nodes])

executor = get_executor(EXP_DIR, mem_gb=32, slurm_exclude=slurm_exclude)
jobs = run_experiments(executor, sampled_configs, SCRIPT_NAME)

# Plot Results

In [13]:
import yaml
import pandas as pd
all_results = []
not_found_configs = []
for (ds_name, method_name), config in configs.items():
    n_trials_ds = n_trials if not ds_name.startswith("diamonds") else n_trials // 8
    for i in range(n_trials_ds):
        exp_dir = get_conf_exp_dir(ds_name, method_name, i)
        metrics_path = Path(exp_dir, "metrics.json")
        config_path = Path(exp_dir, "config.yaml")
        if not metrics_path.exists():
            # print(f"ds_name: {ds_name}, method_name: {method_name}, i: {i}")
            not_found_configs.append((ds_name, method_name, i))
            continue
        with open(metrics_path, "r") as f:
            metrics = json.load(f)
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        min_val_loss = min(metrics["val_loss"])
        all_results.append({
            "dataset": ds_name if not ds_name.startswith("diamonds") else "diamonds",
            "method_name": method_name,
            "aux_weight": config["aux_weight"],
            "lr": config["lr"],
            "loss": min_val_loss,
        })

df = pd.DataFrame(all_results, columns=["dataset", "method_name", "aux_weight", "lr", "loss"])

In [11]:
import pandas as pd
import plotly.graph_objects as go

def plot_aux_weight_vs_loss(df, dataset=None, method_name=None, use_log_scale=True):
    # Filter data based on conditions
    plot_df = df.copy()
    if dataset:
        plot_df = plot_df[plot_df['dataset'] == dataset]
    if method_name:
        plot_df = plot_df[plot_df['method_name'] == method_name]
    
    # Create figure
    fig = go.Figure()
    
    # Group by conditions that weren't filtered
    groupby_cols = []
    if not dataset:
        groupby_cols.append('dataset')
    if not method_name:
        groupby_cols.append('method_name')
    
    if groupby_cols:
        # Plot separate lines for each group
        for name, group in plot_df.groupby(groupby_cols):
            name_str = '_'.join([str(n) for n in name]) if isinstance(name, tuple) else str(name)
            # Sort by aux_weight before plotting
            group = group.sort_values('aux_weight')
            fig.add_trace(go.Scatter(
                x=group['aux_weight'],
                y=group['loss'],
                name=name_str,
                mode='markers+lines'
            ))
    else:
        # Single line plot
        # Sort by aux_weight before plotting
        print(len(plot_df))
        plot_df = plot_df.sort_values('aux_weight')
        print(len(plot_df))
        fig.add_trace(go.Scatter(
            x=plot_df['aux_weight'],
            y=plot_df['loss'],
            mode='markers+lines'
        ))
    
    # Update layout
    title = 'Aux Weight vs Loss'
    if dataset:
        title += f' for {dataset}'
    if method_name:
        title += f' with {method_name}'
    
    fig.update_layout(
        title=title,
        xaxis_title='Aux Weight',
        yaxis_title='Loss',
        width=800,
        height=500
    )
    
    if use_log_scale:
        fig.update_xaxes(type='log')
        fig.update_yaxes(type='log')
    
    return fig

method_name = "TopK 0.1"
fig = plot_aux_weight_vs_loss(df, method_name=method_name, dataset="generated-stories")
fig.show()

16
16


In [19]:
def plot_param_vs_loss(df, x_param='aux_weight', dataset=None, method_name=None, use_log_scale=True):
    """
    Plot either aux_weight or learning rate against loss
    
    Args:
        df: DataFrame with the data
        x_param: Either 'aux_weight' or 'lr' to plot on x-axis
        dataset: Optional dataset to filter by
        method_name: Optional method to filter by
        use_log_scale: Whether to use log scale for axes
    """
    # Validate x_param
    if x_param not in ['aux_weight', 'lr']:
        raise ValueError("x_param must be either 'aux_weight' or 'lr'")
    
    # Filter data based on conditions
    plot_df = df.copy()
    if dataset:
        plot_df = plot_df[plot_df['dataset'] == dataset]
    if method_name:
        plot_df = plot_df[plot_df['method_name'] == method_name]
    
    # Create figure
    fig = go.Figure()
    
    # Group by conditions that weren't filtered
    groupby_cols = []
    if not dataset:
        groupby_cols.append('dataset')
    if not method_name:
        groupby_cols.append('method_name')
    
    if groupby_cols:
        # Plot separate lines for each group
        for name, group in plot_df.groupby(groupby_cols):
            name_str = '_'.join([str(n) for n in name]) if isinstance(name, tuple) else str(name)
            # Sort by x_param before plotting
            group = group.sort_values(x_param)
            fig.add_trace(go.Scatter(
                x=group[x_param],
                y=group['loss'],
                name=name_str,
                mode='markers+lines'
            ))
    else:
        # Single line plot
        plot_df = plot_df.sort_values(x_param)
        fig.add_trace(go.Scatter(
            x=plot_df[x_param],
            y=plot_df['loss'],
            mode='markers+lines'
        ))
    
    # Update layout
    param_name = 'Aux Weight' if x_param == 'aux_weight' else 'Learning Rate'
    title = f'{param_name} vs Loss'
    if dataset:
        title += f' for {dataset}'
    if method_name:
        title += f' with {method_name}'
    
    fig.update_layout(
        title=title,
        xaxis_title=param_name,
        yaxis_title='Loss',
        width=800,
        height=500
    )
    
    if use_log_scale:
        fig.update_xaxes(type='log')
        fig.update_yaxes(type='log')
    
    return fig

# Example usage:
method_name = "TopK 0.1"
dataset = "generated-stories"

# Plot aux_weight vs loss
fig1 = plot_param_vs_loss(df, x_param='aux_weight', method_name=method_name, dataset=dataset)
fig1.show()

# Plot learning rate vs loss
fig2 = plot_param_vs_loss(df, x_param='lr', method_name=method_name, dataset=dataset)
fig2.show()

In [22]:
# Example usage:
method_name = "DivDis"
dataset = "generated-stories"

# Plot aux_weight vs loss
fig1 = plot_param_vs_loss(df, x_param='aux_weight', method_name=method_name, dataset=dataset)
fig1.show()

# Plot learning rate vs loss
fig2 = plot_param_vs_loss(df, x_param='lr', method_name=method_name, dataset=dataset)
fig2.show()

In [23]:
# Example usage:
method_name = "DivDis"
dataset = "diamonds"

# Plot aux_weight vs loss
fig1 = plot_param_vs_loss(df, x_param='aux_weight', method_name=method_name, dataset=dataset)
fig1.show()

# Plot learning rate vs loss
fig2 = plot_param_vs_loss(df, x_param='lr', method_name=method_name, dataset=dataset)
fig2.show()

upshot: 
Topk 0.1: 
diamonds: 3.0
generated stories: 2.5 (2-5), lr=3-6
DivDis
diamonds: 3.5
generated stories: 2.5, lr=3-6