In [1]:
# set cuda visible devices
def is_notebook() -> bool:
    try:
        shell = get_ipython().__class__.__name__
        if shell == 'ZMQInteractiveShell':
            return True   # Jupyter notebook or qtconsole
        elif shell == 'TerminalInteractiveShell':
            return False  # Terminal running IPython
        else:
            return False  # Other type (?)
    except NameError:
        return False      # Probably standard Python interpreter

import os
if is_notebook():
    os.environ["CUDA_VISIBLE_DEVICES"] = "" #"1"
    # os.environ['CUDA_LAUNCH_BLOCKING']="1"
    # os.environ['TORCH_USE_CUDA_DSA'] = "1"

import matplotlib 
if not is_notebook():
    matplotlib.use('Agg')

# set directory
os.chdir("/nas/ucb/oliveradk/diverse-gen/")

In [2]:
# for now incomplete spurious correlation: 
#   waterbirds

# other notebooks:
# multi-class classification: 
#   multi-nli-cc

# known group labels
# waterbirds (normal)
# mulit-nli cc 


In [3]:
from tqdm import tqdm
from pathlib import Path
from datetime import datetime
from itertools import product
import json
import copy

import numpy as np

from losses.loss_types import LossType
from utils.exp_utils import get_executor, run_experiments

In [4]:
SCRIPT_NAME = "spur_corr_exp.py"
EXP_DIR = "output/real_data_aux_weight_sweep"
n_trials = 32

In [5]:
method_configs = {
    "DivDis": {"loss_type": LossType.DIVDIS},
    "TopK 0.1": {"loss_type": LossType.TOPK, "mix_rate_lower_bound": 0.1},
    "TopK 0.5": {"loss_type": LossType.TOPK, "mix_rate_lower_bound": 0.5},
    "DBAT": {"loss_type": LossType.DBAT, "shared_backbone": False, "freeze_heads": True, "binary": True, "batch_size": 16, "target_batch_size": 32},
}

dataset_configs = {
    "waterbirds": {"dataset": "waterbirds", "model": "Resnet50", "epochs": 5, "source_cc": False},
}

method_ranges = {
    "DivDis": {"aux_weight": [0, 1]},
    "TopK 0.1": {"aux_weight": [0, 1]},
    "TopK 0.5": {"aux_weight": [-1, 1]},
    "DBAT": {"aux_weight": [-1, 1]},
}

configs = {
    (ds_name, method_name): {**ds_config, **method_config} 
    for (ds_name, ds_config), (method_name, method_config) in product(dataset_configs.items(), method_configs.items())
}

# TODO: should have done by seed - by randomly sampled aux weight is really silly
def get_conf_exp_dir(ds_name, method_name, aux_weight):
    return Path(EXP_DIR, f"{ds_name}_{method_name}/{aux_weight}")



sampled_configs = []
for (ds_name, method_name), conf in configs.items():
    for i in range(n_trials):
        sample_conf = copy.deepcopy(conf)
        sample_range = method_ranges[method_name]["aux_weight"]
        aux_weight = 10**(np.random.uniform(sample_range[0], sample_range[1]))
        seed = np.random.randint(0, 1000000)
        sample_conf["aux_weight"] = aux_weight
        sample_conf["seed"] = seed
        sample_conf["exp_dir"] = get_conf_exp_dir(ds_name, method_name, aux_weight)
        sample_conf["plot_activations"] = False
        sampled_configs.append(sample_conf)

# Run Experiments

In [12]:
executor = get_executor(EXP_DIR, mem_gb=16)
jobs = run_experiments(executor, sampled_configs, SCRIPT_NAME)

# Plot Results

In [9]:
import pandas as pd
import yaml

def get_method_name(conf):
    if conf["loss_type"] == LossType.DIVDIS:
        return "DivDis"
    elif conf["loss_type"] == LossType.TOPK:
        return f"TopK {conf['mix_rate_lower_bound']}"
    elif conf["loss_type"] == LossType.DBAT:
        return "DBAT"
    else:
        raise ValueError(f"Unknown loss type: {conf['loss_type']}")


all_results = []
for (ds_name, method_name), conf in configs.items():
    dataset = conf["dataset"]
    parent_dir = Path(EXP_DIR, f"{ds_name}_{method_name}")
    # iterate through all directories in 
    for aux_weight_dir in parent_dir.iterdir():
        if not aux_weight_dir.is_dir():
            continue
        # get aux weight
        exp_config_path = Path(aux_weight_dir, "config.yaml")
        with open(exp_config_path, "r") as f:
            exp_config = yaml.safe_load(f)
        aux_weight = exp_config["aux_weight"]

        # get min val loss
        metrics_path = Path(aux_weight_dir, "metrics.json")
        with open(metrics_path, "r") as f:
            metrics = json.load(f)
        min_val_loss = min(metrics["val_loss"])
        all_results.append({
            "dataset": dataset,
            "method_name": method_name,
            "aux_weight": aux_weight,
            "loss": min_val_loss
        })

# create dataframe b
df = pd.DataFrame(all_results, columns=["dataset", "method_name", "aux_weight", "loss"])

In [10]:
df

Unnamed: 0,dataset,method_name,aux_weight,loss
0,waterbirds,DivDis,3.865080,0.273404
1,waterbirds,DivDis,6.844978,0.277668
2,waterbirds,DivDis,7.508777,0.264617
3,waterbirds,DivDis,1.742226,0.270747
4,waterbirds,DivDis,2.009554,0.265594
...,...,...,...,...
123,waterbirds,DBAT,8.831156,1.506094
124,waterbirds,DBAT,0.308765,1.485741
125,waterbirds,DBAT,0.276387,1.471796
126,waterbirds,DBAT,0.487482,1.481482


In [18]:
import pandas as pd
import plotly.graph_objects as go

def plot_aux_weight_vs_loss(df, dataset=None, method_name=None, use_log_scale=True):
    # Filter data based on conditions
    plot_df = df.copy()
    if dataset:
        plot_df = plot_df[plot_df['dataset'] == dataset]
    if method_name:
        plot_df = plot_df[plot_df['method_name'] == method_name]
    
    # Create figure
    fig = go.Figure()
    
    # Group by conditions that weren't filtered
    groupby_cols = []
    if not dataset:
        groupby_cols.append('dataset')
    if not method_name:
        groupby_cols.append('method_name')
    
    if groupby_cols:
        # Plot separate lines for each group
        for name, group in plot_df.groupby(groupby_cols):
            name_str = '_'.join([str(n) for n in name]) if isinstance(name, tuple) else str(name)
            # Sort by aux_weight before plotting
            group = group.sort_values('aux_weight')
            fig.add_trace(go.Scatter(
                x=group['aux_weight'],
                y=group['loss'],
                name=name_str,
                mode='markers+lines'
            ))
    else:
        # Single line plot
        # Sort by aux_weight before plotting
        print(len(plot_df))
        plot_df = plot_df.sort_values('aux_weight')
        print(len(plot_df))
        fig.add_trace(go.Scatter(
            x=plot_df['aux_weight'],
            y=plot_df['loss'],
            mode='markers+lines'
        ))
    
    # Update layout
    title = 'Aux Weight vs Loss'
    if dataset:
        title += f' for {dataset}'
    if method_name:
        title += f' with {method_name}'
    
    fig.update_layout(
        title=title,
        xaxis_title='Aux Weight',
        yaxis_title='Loss',
        width=800,
        height=500
    )
    
    if use_log_scale:
        fig.update_xaxes(type='log')
        fig.update_yaxes(type='log')
    
    return fig

# Example usage:
# Plot all data

# TopK 0.1: 1 
# TopK 0.5: 3 
# DivDis: 3
# DBAT: 0.5

# Plot for specific dataset
method_name = "TopK 0.5"
fig = plot_aux_weight_vs_loss(df, method_name=method_name)
fig.show()