In [107]:
import collections
import functools
import glob
import pickle
import itertools
import json
import os
import random
import sys
import numpy as np
import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# If domainbed is a custom module, you might need to adjust its imports or setup
from domainbed import datasets
from domainbed import algorithms
from domainbed.lib import misc, reporting
from domainbed import model_selection
from domainbed.lib.query import Q
import warnings


In [177]:
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import itertools
import numpy as np

def get_test_records(records):
    """Given records with a common test env, get the test records (i.e. the
    records with *only* that single test env and no other test envs)"""
    return records.filter(lambda r: len(r['args']['test_envs']) == 1)

class SelectionMethod:
    """Abstract class whose subclasses implement strategies for model
    selection across hparams and timesteps."""

    def __init__(self):
        raise TypeError

    @classmethod
    def run_acc(self, run_records):
        """
        Given records from a run, return a {val_acc, test_acc} dict representing
        the best val-acc and corresponding test-acc for that run.
        """
        raise NotImplementedError

    @classmethod
    def hparams_accs(self, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return a sorted list of (run_acc, records) tuples.
        """

        return (records.group('args.hparams_seed')
            .map(lambda _, run_records:
                (
                    self.run_acc(run_records),
                    run_records
                )
            ).filter(lambda x: x[0] is not None)
            .sorted(key=lambda x: x[0]['val_acc'])[::-1]
        )

    @classmethod
    def sweep_acc(self, records):
        """
        Given all records from a single (dataset, algorithm, test env) pair,
        return the mean test acc of the k runs with the top val accs.
        """
        _hparams_accs = self.hparams_accs(records)
        a = _hparams_accs
        # for i in range(len(a)):
        #     print(a[i][0]['val_acc'], a[i][0]['test_acc'])
        #     print(f"Hparams {(a[i][1][0]['hparams']['grad_alpha'], a[i][1][0]['hparams']['grad_alpha'],  a[i][1][0]['hparams']['penalty_anneal_iters'])}")
        print(f"Best val acc and test for {a[0][1][0]['args']['dataset']}, env {a[0][1][0]['args']['test_envs']}:", a[0][0]['val_acc'], a[0][0]['test_acc']) 
        print(f"Best hyperparameters for {a[0][1][0]['args']['dataset']}, env {a[0][1][0]['args']['test_envs']}:", a[0][1][0]['hparams'])

        if len(_hparams_accs):
            # breakpoint()
            return _hparams_accs[0][0]['test_acc']
        else:
            return None

class OracleSelectionMethod(SelectionMethod):
    """Like Selection method which picks argmax(test_out_acc) across all hparams
    and checkpoints, but instead of taking the argmax over all
    checkpoints, we pick the last checkpoint, i.e. no early stopping."""
    name = "test-domain validation set (oracle)"

    @classmethod
    def run_acc(self, run_records):
        run_records = run_records.filter(lambda r:
            len(r['args']['test_envs']) == 1)
        if not len(run_records):
            return None
        test_env = run_records[0]['args']['test_envs'][0]
        test_out_acc_key = 'env{}_out_acc'.format(test_env)
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        chosen_record = run_records.sorted(lambda r: r['step'])[-1]
        return {
            'val_acc':  chosen_record[test_out_acc_key],
            'test_acc': chosen_record[test_in_acc_key]
        }

class IIDAccuracySelectionMethod(SelectionMethod):
    """Picks argmax(mean(env_out_acc for env in train_envs))"""
    name = "training-domain validation set"

    @classmethod
    def _step_acc(self, record):
        """Given a single record, return a {val_acc, test_acc} dict."""
        test_env = record['args']['test_envs'][0]
        val_env_keys = []
        for i in itertools.count():
            if f'env{i}_out_acc' not in record:
                break
            if i != test_env:
                val_env_keys.append(f'env{i}_out_acc')
        test_in_acc_key = 'env{}_in_acc'.format(test_env)
        return {
            'val_acc': np.mean([record[key] for key in val_env_keys]),
            'test_acc': record[test_in_acc_key]
        }

    @classmethod
    def run_acc(self, run_records):
        test_records = get_test_records(run_records)
        if not len(test_records):
            return None

        index_of_max = test_records.map(self._step_acc).map(lambda x: x['val_acc'])._list.index(
            max(test_records.map(self._step_acc).map(lambda x: x['val_acc'])))
        full_record_with_hyperparams = test_records[index_of_max]
        # print(f"Hyperparameters for {full_record_with_hyperparams['args']['dataset']}, env {full_record_with_hyperparams['args']['test_envs']}:", full_record_with_hyperparams['hparams'])
        return test_records.map(self._step_acc).argmax('val_acc')

class LeaveOneOutSelectionMethod(SelectionMethod):
    """Picks (hparams, step) by leave-one-out cross validation."""
    name = "leave-one-domain-out cross-validation"

    @classmethod
    def _step_acc(self, records):
        """Return the {val_acc, test_acc} for a group of records corresponding
        to a single step."""
        test_records = get_test_records(records)
        if len(test_records) != 1:
            return None

        test_env = test_records[0]['args']['test_envs'][0]
        n_envs = 0
        for i in itertools.count():
            if f'env{i}_out_acc' not in records[0]:
                break
            n_envs += 1
        val_accs = np.zeros(n_envs) - 1
        for r in records.filter(lambda r: len(r['args']['test_envs']) == 2):
            val_env = (set(r['args']['test_envs']) - set([test_env])).pop()
            val_accs[val_env] = r['env{}_in_acc'.format(val_env)]
        val_accs = list(val_accs[:test_env]) + list(val_accs[test_env+1:])
        if any([v==-1 for v in val_accs]):
            return None
        val_acc = np.sum(val_accs) / (n_envs-1)
        return {
            'val_acc': val_acc,
            'test_acc': test_records[0]['env{}_in_acc'.format(test_env)]
        }

    @classmethod
    def run_acc(self, records):
        step_accs = records.group('step').map(lambda step, step_records:
            self._step_acc(step_records)
        ).filter_not_none()
        if len(step_accs):
            return step_accs.argmax('val_acc')
        else:
            return None


In [178]:
def count_subdirectories_and_empty_ones(directory):
    
    num_subdirectories = 0
    num_empty_subdirectories = 0

    # Walk through all directories in the given directory
    for root, dirs, files in os.walk(directory):
        # Iterate over each directory in the current root
        for d in dirs:
            num_subdirectories += 1
            subdirectory_path = os.path.join(root, d)
            # Check if the directory is empty
            if not os.listdir(subdirectory_path):
                num_empty_subdirectories += 1

    return num_subdirectories, num_empty_subdirectories


# Specify the directory you want to inspect
directory_path = 'results_resnet_new'

# Get the count of subdirectories and empty subdirectories
subdirectories, empty_subdirectories = count_subdirectories_and_empty_ones(directory_path)
print(f"Total subdirectories: {subdirectories}")
print(f"Empty subdirectories: {empty_subdirectories}")

Total subdirectories: 50
Empty subdirectories: 0


In [179]:
def format_mean(data, latex):
    """Given a list of datapoints, return a string describing their mean and
    standard error"""
    if len(data) == 0:
        return None, None, "X"
    mean = 100 * np.mean(list(data))
    err = 100 * np.std(list(data) / np.sqrt(len(data)))
    if latex:
        return mean, err, "{:.1f} $\\pm$ {:.1f}".format(mean, err)
    else:
        return mean, err, "{:.1f} +/- {:.1f}".format(mean, err)

def print_table(table, header_text, row_labels, col_labels, colwidth=10,
    latex=True):
    """Pretty-print a 2D array of data, optionally with row/col labels"""
    print("")

    if latex:
        num_cols = len(table[0])
        print("\\begin{center}")
        print("\\adjustbox{max width=\\textwidth}{%")
        print("\\begin{tabular}{l" + "c" * num_cols + "}")
        print("\\toprule")
    else:
        print("--------", header_text)

    for row, label in zip(table, row_labels):
        row.insert(0, label)

    if latex:
        col_labels = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
            for col_label in col_labels]
    table.insert(0, col_labels)

    for r, row in enumerate(table):
        misc.print_row(row, colwidth=colwidth, latex=latex)
        if latex and r == 0:
            print("\\midrule")
    if latex:
        print("\\bottomrule")
        print("\\end{tabular}}")
        print("\\end{center}")

def print_results_tables(records, selection_method, latex):
    """Given all records, print a results table for each dataset."""
    grouped_records = reporting.get_grouped_records(records).map(lambda group:
        { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
    ).filter(lambda g: g["sweep_acc"] is not None)

    # read algorithm names and sort (predefined order)
    alg_names = Q(records).select("args.algorithm").unique()
    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
        [n for n in alg_names if n not in algorithms.ALGORITHMS])

    # read dataset names and sort (lexicographic order)
    dataset_names = Q(records).select("args.dataset").unique().sorted()
    dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

    for dataset in dataset_names:
        if latex:
            print()
            print("\\subsubsection{{{}}}".format(dataset))
        test_envs = range(datasets.num_environments(dataset))
        # breakpoint()
        table = [[None for _ in [*test_envs, "Avg"]] for _ in alg_names]
        for i, algorithm in enumerate(alg_names):
            means = []
            for j, test_env in enumerate(test_envs):
                trial_accs = (grouped_records
                    .filter_equals(
                        "dataset, algorithm, test_env",
                        (dataset, algorithm, test_env)
                    ).select("sweep_acc"))
                mean, err, table[i][j] = format_mean(trial_accs, latex)
                means.append(mean)
            if None in means:
                table[i][-1] = "X"
            else:
                table[i][-1] = "{:.1f}".format(sum(means) / len(means))

        col_labels = [
            "Algorithm",
            *datasets.get_dataset_class(dataset).ENVIRONMENTS,
            "Avg"
        ]
        header_text = (f"Dataset: {dataset}, "
            f"model selection method: {selection_method.name}")
        print_table(table, header_text, alg_names, list(col_labels),
            colwidth=20, latex=latex)

    # Print an "averages" table
    if latex:
        print()
        print("\\subsubsection{Averages}")

    table = [[None for _ in [*dataset_names, "Avg"]] for _ in alg_names]
    for i, algorithm in enumerate(alg_names):
        means = []
        for j, dataset in enumerate(dataset_names):
            trial_averages = (grouped_records
                .filter_equals("algorithm, dataset", (algorithm, dataset))
                .group("trial_seed")
                .map(lambda trial_seed, group:
                    group.select("sweep_acc").mean()
                )
            )
            mean, err, table[i][j] = format_mean(trial_averages, latex)
            means.append(mean)
        if None in means:
            table[i][-1] = "X"
        else:
            table[i][-1] = "{:.1f}".format(sum(means) / len(means))

    col_labels = ["Algorithm", *dataset_names, "Avg"]
    header_text = f"Averages, model selection method: {selection_method.name}"
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
        latex=latex)

In [180]:
input_dir = "./results_resnet_new"
latex = False
results_file = "results.tex" if latex else "results.txt"

In [181]:
records = reporting.load_records(input_dir)
# selection_methods = [
#     model_selection.IIDAccuracySelectionMethod,
#     model_selection.OracleSelectionMethod
# ]
selection_methods = [
    IIDAccuracySelectionMethod,
    OracleSelectionMethod,
]

for selection_method in selection_methods:
    print_results_tables(records, selection_method, latex)


                                                                                

Best val acc and test for VLCS, env [2]: 0.8760349497248341 0.715917745620716
Best hyperparameters for VLCS, env [2]: {'batch_size': 39, 'class_balanced': False, 'data_augmentation': True, 'grad_alpha': 44.41840541006661, 'hess_beta': 18.552574037335336, 'lr': 2.7028930742148706e-05, 'model_type': 'ResNet', 'nonlinear_classifier': False, 'penalty_anneal_iters': 4072, 'resnet18': False, 'resnet_dropout': 0.5, 'weight_decay': 0.00044832883881609976}
Best val acc and test for VLCS, env [0]: 0.8122642250904618 0.9840989399293286
Best hyperparameters for VLCS, env [0]: {'batch_size': 39, 'class_balanced': False, 'data_augmentation': True, 'grad_alpha': 44.41840541006661, 'hess_beta': 18.552574037335336, 'lr': 2.7028930742148706e-05, 'model_type': 'ResNet', 'nonlinear_classifier': False, 'penalty_anneal_iters': 4072, 'resnet18': False, 'resnet_dropout': 0.5, 'weight_decay': 0.00044832883881609976}
Best val acc and test for RotatedMNIST, env [0]: 0.9908272610372911 0.9488965073923291
Best hyp



In [70]:
records[0]

{'args': {'algorithm': 'HessianAlignment',
  'checkpoint_freq': None,
  'data_dir': './domainbed/data/',
  'dataset': 'VLCS',
  'device': 0,
  'holdout_fraction': 0.2,
  'hparams': '{"model_type":"ResNet"}',
  'hparams_seed': 1,
  'output_dir': './domainbed/results_resnet_new/be7d8e28947326fedd849898854fc595',
  'save_model_every_checkpoint': False,
  'seed': 728992854,
  'skip_model_save': False,
  'steps': None,
  'task': 'domain_generalization',
  'test_envs': [2],
  'trial_seed': 0,
  'uda_holdout_fraction': 0},
 'env0_in_acc': 0.6554770318021201,
 'env0_out_acc': 0.6784452296819788,
 'env1_in_acc': 0.49317647058823527,
 'env1_out_acc': 0.4858757062146893,
 'env2_in_acc': 0.49124143183549124,
 'env2_out_acc': 0.4817073170731707,
 'env3_in_acc': 0.4683450573861533,
 'env3_out_acc': 0.4725925925925926,
 'epoch': 0.0,
 'erm_loss': 1.5806922912597656,
 'grad_pen': 0.0,
 'hess_pen': 0.0,
 'hparams': {'batch_size': 39,
  'class_balanced': False,
  'data_augmentation': True,
  'grad_alpha

In [84]:
grouped_records2 = reporting.get_grouped_records(records).map(lambda group:
    { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
).filter(lambda g: g["sweep_acc"] is not None)

In [85]:
len(grouped_records2)

10

In [86]:
grouped_records = reporting.get_grouped_records(records)

In [87]:
len(grouped_records)

10

In [92]:
grouped_records[0]['records'][0]

{'args': {'algorithm': 'HessianAlignment',
  'checkpoint_freq': None,
  'data_dir': './domainbed/data/',
  'dataset': 'VLCS',
  'device': 0,
  'holdout_fraction': 0.2,
  'hparams': '{"model_type":"ResNet"}',
  'hparams_seed': 1,
  'output_dir': './domainbed/results_resnet_new/be7d8e28947326fedd849898854fc595',
  'save_model_every_checkpoint': False,
  'seed': 728992854,
  'skip_model_save': False,
  'steps': None,
  'task': 'domain_generalization',
  'test_envs': [2],
  'trial_seed': 0,
  'uda_holdout_fraction': 0},
 'env0_in_acc': 0.6554770318021201,
 'env0_out_acc': 0.6784452296819788,
 'env1_in_acc': 0.49317647058823527,
 'env1_out_acc': 0.4858757062146893,
 'env2_in_acc': 0.49124143183549124,
 'env2_out_acc': 0.4817073170731707,
 'env3_in_acc': 0.4683450573861533,
 'env3_out_acc': 0.4725925925925926,
 'epoch': 0.0,
 'erm_loss': 1.5806922912597656,
 'grad_pen': 0.0,
 'hess_pen': 0.0,
 'hparams': {'batch_size': 39,
  'class_balanced': False,
  'data_augmentation': True,
  'grad_alpha

In [115]:
a = records.group('args.hparams_seed').map(lambda _, run_records:
                (
                    selection_method.run_acc(run_records),
                    run_records
                )
            ).filter(lambda x: x[0] is not None).sorted(key=lambda x: x[0]['val_acc'])[::-1]

In [120]:
b = records.group('args.hparams_seed').map(lambda _, run_records:
                (
                    selection_method.run_acc(run_records),
                    run_records
                )
            ).filter(lambda x: x[0] is not None)

In [121]:
len(b)

5

In [128]:
for i in range(len(a)):
    print(a[i][0]['val_acc'], b[i][0]['val_acc'])

0.9901414487783969 0.9901414487783969
0.9828546935276468 0.9819974282040291
0.9819974282040291 0.9828546935276468
0.979854264894985 0.979854264894985
0.7820121951219512 0.7820121951219512


In [132]:
len(a[0][1])

378

In [141]:
for i in range(len(a[0][1])):
    print(a[0][1][i]['args']['output_dir'])

./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a17063f6ffc3d197dbdfa139c
./domainbed/results_resnet_new/33b8159a1

In [161]:
for i in range(len(a)):
    print(a[i][0]['val_acc'], a[i][0]['test_acc'])
    print(f"Hparams {(a[i][1][0]['hparams']['grad_alpha'], a[i][1][0]['hparams']['grad_alpha'],  a[i][1][0]['hparams']['penalty_anneal_iters'])}")


0.9901414487783969 1.0
Hparams (100, 100, 0)
0.9828546935276468 1.0
Hparams (115.67607214735708, 115.67607214735708, 1112)
0.9819974282040291 0.9837154488965074
Hparams (44.41840541006661, 44.41840541006661, 4072)
0.979854264894985 0.9934647525176773
Hparams (437.8858725808355, 437.8858725808355, 1319)
0.7820121951219512 0.9900990099009901
Hparams (11.274368753947169, 11.274368753947169, 2498)


In [155]:
a[1]

({'val_acc': 0.9828546935276468, 'test_acc': 1.0},
 [{'args': {'algorithm': 'HessianAlignment', 'checkpoint_freq': None, 'data_dir': './domainbed/data/', 'dataset': 'VLCS', 'device': 0, 'holdout_fraction': 0.2, 'hparams': '{"model_type":"ResNet"}', 'hparams_seed': 2, 'output_dir': './domainbed/results_resnet_new/2006fbf922f428c83d1352311a271a56', 'save_model_every_checkpoint': False, 'seed': 875388191, 'skip_model_save': False, 'steps': None, 'task': 'domain_generalization', 'test_envs': [0], 'trial_seed': 0, 'uda_holdout_fraction': 0}, 'env0_in_acc': 0.6819787985865724, 'env0_out_acc': 0.6819787985865724, 'env1_in_acc': 0.5327058823529411, 'env1_out_acc': 0.5178907721280602, 'env2_in_acc': 0.43107387661843105, 'env2_out_acc': 0.4298780487804878, 'env3_in_acc': 0.46945575712699, 'env3_out_acc': 0.46370370370370373, 'epoch': 0.0, 'erm_loss': 1.900767207145691, 'grad_pen': 0.0, 'hess_pen': 0.0, 'hparams': {'batch_size': 25, 'class_balanced': False, 'data_augmentation': True, 'grad_alpha'

In [150]:
a[0][1][0]['args']['dataset']

'RotatedMNIST'

In [151]:
a[0][1][i]['args']['test_envs']

[4]