In [1]:
import os 
os.chdir("/nas/ucb/oliveradk/diverse-gen/")

In [22]:
from pathlib import Path
from itertools import product
import json
import yaml

import numpy as np
from omegaconf import OmegaConf

from diverse_gen.losses.loss_types import LossType
from diverse_gen.utils.exp_utils import get_conf_dir

In [18]:
OUT_DIR = "output/cc_mix_rate/main"

In [14]:
SEEDS = [1, 2, 3]
MIX_RATES = [0.1, 0.25, 0.5, 0.75, 1.0]
DATASETS = [
    "toy_grid", 
    "fmnist_mnist", 
    "cifar_mnist", 
    "waterbirds", 
    "celebA-0", 
    "multi-nli"
]
METHODS = ["DivDis"]

configs_dir = Path("configs")
methods = OmegaConf.load(configs_dir / "methods.yaml")
datasets = OmegaConf.load(configs_dir / "datasets.yaml")
method_ds = OmegaConf.load(configs_dir / "method_ds.yaml")

# filter configs 
datasets = {k: v for k, v in datasets.items() if k in DATASETS}
methods = {k: v for k, v in methods.items() if k in METHODS}

# topk configs with no schedule
no_sched_topk_configs = {}
for method_name, method_conf in methods.items():
    if method_conf["loss_type"] == LossType.TOPK.name:
        conf = method_conf.copy()
        conf["mix_rate_schedule"] = None
        no_sched_topk_configs[method_name+"_No_Sched"] = conf
methods.update(no_sched_topk_configs)

# generate exp configs
configs = {
    (ds_name, method_name, mix_rate, seed): {**ds, **method, "mix_rate": mix_rate, "seed": seed} 
    for (ds_name, ds), (method_name, method), mix_rate, seed in 
    product(datasets.items(), methods.items(), MIX_RATES, SEEDS)
}

In [30]:
import yaml
import pandas as pd
all_results = []
not_found_configs = []
for (ds_name, method_name, mix_rate, seed), config in configs.items():
        exp_dir = get_conf_dir((ds_name, method_name, mix_rate, seed), OUT_DIR)
        metrics_path = Path(exp_dir, "metrics.json")
        config_path = Path(exp_dir, "config.yaml")
        if not metrics_path.exists():
            # print(f"ds_name: {ds_name}, method_name: {method_name}, i: {i}")
            not_found_configs.append((ds_name, method_name, i))
            continue
        with open(metrics_path, "r") as f:
            metrics = json.load(f)
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)
        # TODO: instead, condition on source acc >= some value, then take min val loss given that mask (and do nan if no valid source acc >= that value)
        # source_acc_mask = np.array(metrics["val_source_acc_0"]) >= (ERM_MEAN_SOURCE_ACCS[ds_name] - ERM_STD_SOURCE_ACCS[ds_name])
        # # if not np.any(source_acc_mask):
        # #     val_loss = np.nan
        # #     source_acc = np.nan 
        # # else:
        # min_val_loss_epoch = np.argmin(np.array(metrics["val_loss"]) * (-np.inf * (1 - source_acc_mask)))
        min_val_loss_epoch = np.argmin(metrics["val_loss"])
        val_loss = metrics["val_loss"][min_val_loss_epoch]
        source_acc = metrics["val_source_acc_0"][min_val_loss_epoch]
        all_results.append({
            "dataset": ds_name,
            "method_name": method_name,
            "mix_rate": mix_rate, 
            "aux_weight": config["aux_weight"],
            "lr": config["lr"],
            "val_loss": val_loss,
            "source_acc": source_acc,
        })

df = pd.DataFrame(all_results, columns=["dataset", "method_name", "mix_rate", "aux_weight", "lr", "val_loss", "source_acc"])

In [33]:
import pandas as pd
import plotly.graph_objects as go
from typing import Literal
def plot_aux_weight_vs_metric(df, 
    dataset=None, 
    method_name=None, 
    metric: Literal["val_loss", "source_acc"]= "val_loss",
    use_x_log_scale: bool=True, 
    use_y_log_scale: bool=True, 
):
    # Filter data based on conditions
    plot_df = df.copy()
    if dataset:
        plot_df = plot_df[plot_df['dataset'] == dataset]
    if method_name:
        plot_df = plot_df[plot_df['method_name'] == method_name]
    
    # Create figure
    fig = go.Figure()
    
    # Group by conditions that weren't filtered
    groupby_cols = []
    if not dataset:
        groupby_cols.append('dataset')
    if not method_name:
        groupby_cols.append('method_name')
    
    if groupby_cols:
        # Plot separate lines for each group
        for name, group in plot_df.groupby(groupby_cols):
            name_str = '_'.join([str(n) for n in name]) if isinstance(name, tuple) else str(name)
            # Sort by aux_weight before plotting
            group = group.sort_values('mix_rate')
            fig.add_trace(go.Scatter(
                x=group['mix_rate'],
                y=group[metric],
                name=name_str,
                mode='markers+lines'
            ))
    else:
        # Single line plot
        # Sort by aux_weight before plotting
        print(len(plot_df))
        plot_df = plot_df.sort_values('mix_rate')
        print(len(plot_df))
        fig.add_trace(go.Scatter(
            x=plot_df['mix_rate'],
            y=plot_df[metric],
            mode='markers+lines'
        ))
    
    # Update layout
    title = f'Mix Rate vs {metric}'
    if dataset:
        title += f' for {dataset}'
    if method_name:
        title += f' with {method_name}'

    # Get y-axis range from data
    y_min = plot_df[metric].min()
    y_max = plot_df[metric].max()
    # Add small padding (5%) to the range
    y_padding = (y_max - y_min) * 0.05
    
    fig.update_layout(
        title=title,
        xaxis_title='Mix Rate',
        yaxis_title=metric,
        width=800,
        height=500, 
        yaxis=dict(
            range=[y_min - y_padding, y_max + y_padding]
        )
    )
    
    if use_x_log_scale:
        fig.update_xaxes(type='log')
    if use_y_log_scale:
        fig.update_yaxes(type='log')
    
    return fig


In [35]:
method_name = "DivDis"
metric = "source_acc"
use_y_log_scale = False 
use_x_log_scale = True

fig = plot_aux_weight_vs_metric(df, method_name=method_name, dataset="toy_grid", metric=metric, use_x_log_scale=use_x_log_scale, use_y_log_scale=use_y_log_scale)
fig.show()

fig = plot_aux_weight_vs_metric(df, method_name=method_name, dataset="cifar_mnist", metric=metric, use_x_log_scale=use_x_log_scale, use_y_log_scale=use_y_log_scale)
fig.show()

fig = plot_aux_weight_vs_metric(df, method_name=method_name, dataset="fmnist_mnist", metric=metric, use_x_log_scale=use_x_log_scale, use_y_log_scale=use_y_log_scale)
fig.show()

fig = plot_aux_weight_vs_metric(df, method_name=method_name, dataset="waterbirds", metric=metric, use_x_log_scale=use_x_log_scale, use_y_log_scale=use_y_log_scale)
fig.show()

fig = plot_aux_weight_vs_metric(df, method_name=method_name, dataset="celebA-0", metric=metric, use_x_log_scale=use_x_log_scale, use_y_log_scale=use_y_log_scale)
fig.show()

fig = plot_aux_weight_vs_metric(df, method_name=method_name, dataset="multi-nli", metric=metric, use_x_log_scale=use_x_log_scale, use_y_log_scale=use_y_log_scale)
fig.show()


15
15


15
15


15
15


15
15


15
15


15
15
