In [27]:
import collections
import functools
import glob
import pickle
import itertools
import json
import os
import random
import sys
import numpy as np
import tqdm
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# domainbed is imported as a local project package: run this notebook from the
# repository root (or add that directory to sys.path) so these imports resolve.
from domainbed import datasets
from domainbed import algorithms
from domainbed.lib import misc, reporting
from domainbed import model_selection
from domainbed.lib.query import Q
import warnings


In [37]:
def format_mean(data, latex):
    """Summarise a list of accuracies as mean and standard error, in percent.

    Args:
        data: sized iterable of accuracies given as fractions (e.g. 0.83).
            Only ``len()`` and iteration are used, so a domainbed ``Q`` or a
            plain list both work.
        latex: if True, separate mean and error with ``$\\pm$``; otherwise
            with ``+/-``.

    Returns:
        A tuple ``(mean, err, text)`` where ``mean`` and ``err`` are scaled by
        100 and ``text`` shows both with one decimal place. For empty ``data``
        returns ``(None, None, "X")``.
    """
    if len(data) == 0:
        return None, None, "X"
    # Materialise once; iterating a Q-like container twice is wasteful and
    # dividing a raw Python list by a numpy scalar only works by accident.
    values = np.asarray(list(data))
    mean = 100 * values.mean()
    # Standard error of the mean using the population std (ddof=0), which is
    # numerically what the former std(x / sqrt(n)) expression computed.
    err = 100 * values.std() / np.sqrt(len(values))
    separator = "$\\pm$" if latex else "+/-"
    return mean, err, "{:.1f} {} {:.1f}".format(mean, separator, err)

def print_table(table, header_text, row_labels, col_labels, colwidth=10,
    latex=True):
    """Pretty-print a 2D table of cells with row and column labels.

    Args:
        table: list of rows (each a list of cell values); row k is labelled
            with ``row_labels[k]``.
        header_text: heading printed before the table in plain-text mode;
            unused when ``latex`` is True.
        row_labels: labels placed in front of each row.
        col_labels: header row; its first entry heads the row-label column.
        colwidth: column width forwarded to ``misc.print_row``.
        latex: if True, wrap the rows in a LaTeX ``tabular`` (inside
            ``center``/``adjustbox``) and bold/escape the column labels;
            otherwise print plain text.

    Neither ``table``, its rows, nor ``col_labels`` are modified.
    """
    print("")

    # Copy each row before prepending its label so the caller's data is left
    # untouched. As before, rows beyond len(row_labels) are printed unlabelled.
    rows = [list(row) for row in table]
    for row, label in zip(rows, row_labels):
        row.insert(0, label)

    if latex:
        # One centred column per data column; fall back to the header width
        # when there are no rows instead of failing on table[0].
        num_cols = len(table[0]) if table else max(len(col_labels) - 1, 0)
        print("\\begin{center}")
        print("\\adjustbox{max width=\\textwidth}{%")
        print("\\begin{tabular}{l" + "c" * num_cols + "}")
        print("\\toprule")
        # Bold the labels and escape "%", which would start a LaTeX comment.
        header = ["\\textbf{" + str(col_label).replace("%", "\\%") + "}"
            for col_label in col_labels]
    else:
        print("--------", header_text)
        header = list(col_labels)

    for r, row in enumerate([header, *rows]):
        misc.print_row(row, colwidth=colwidth, latex=latex)
        if latex and r == 0:
            print("\\midrule")
    if latex:
        print("\\bottomrule")
        print("\\end{tabular}}")
        print("\\end{center}")

def print_results_tables(records, selection_method, latex):
    """Print a results table for each dataset, then an averages table.

    Args:
        records: run records as loaded by ``reporting.load_records``.
        selection_method: a ``model_selection`` class. Its ``sweep_acc`` is
            applied to the records of each group returned by
            ``reporting.get_grouped_records``, and its ``name`` appears in the
            plain-text table headers.
        latex: if True, emit LaTeX (``\\subsubsection`` headings and tabular
            environments) instead of plain-text tables.

    Cells with no data show "X"; a row's "Avg" cell is "X" whenever any cell
    in that row is missing.
    """
    def _avg_cell(means):
        # Unweighted mean of the row's per-column means, or "X" if any column
        # is missing or the row has no columns at all (avoids dividing by 0).
        if not means or None in means:
            return "X"
        return "{:.1f}".format(sum(means) / len(means))

    grouped_records = reporting.get_grouped_records(records).map(lambda group:
        { **group, "sweep_acc": selection_method.sweep_acc(group["records"]) }
    ).filter(lambda g: g["sweep_acc"] is not None)

    # Algorithm names: those listed in algorithms.ALGORITHMS first (in that
    # order), followed by any others in the order Q.unique() yields them.
    alg_names = Q(records).select("args.algorithm").unique()
    alg_names = ([n for n in algorithms.ALGORITHMS if n in alg_names] +
        [n for n in alg_names if n not in algorithms.ALGORITHMS])

    # Dataset names: kept in datasets.DATASETS order; any dataset not listed
    # there is skipped (its environment count is looked up via `datasets`).
    dataset_names = Q(records).select("args.dataset").unique()
    dataset_names = [d for d in datasets.DATASETS if d in dataset_names]

    for dataset in dataset_names:
        if latex:
            print()
            print("\\subsubsection{{{}}}".format(dataset))
        test_envs = range(datasets.num_environments(dataset))
        # One column per test environment plus a trailing "Avg" column.
        table = [[None] * (len(test_envs) + 1) for _ in alg_names]
        for i, algorithm in enumerate(alg_names):
            means = []
            for j, test_env in enumerate(test_envs):
                trial_accs = (grouped_records
                    .filter_equals(
                        "dataset, algorithm, test_env",
                        (dataset, algorithm, test_env)
                    ).select("sweep_acc"))
                mean, _, table[i][j] = format_mean(trial_accs, latex)
                means.append(mean)
            table[i][-1] = _avg_cell(means)

        col_labels = [
            "Algorithm",
            *datasets.get_dataset_class(dataset).ENVIRONMENTS,
            "Avg"
        ]
        header_text = (f"Dataset: {dataset}, "
            f"model selection method: {selection_method.name}")
        print_table(table, header_text, alg_names, list(col_labels),
            colwidth=20, latex=latex)

    # Print an "averages" table: per dataset, average the sweep accuracies
    # within each trial seed, then summarise across trial seeds.
    if latex:
        print()
        print("\\subsubsection{Averages}")

    table = [[None] * (len(dataset_names) + 1) for _ in alg_names]
    for i, algorithm in enumerate(alg_names):
        means = []
        for j, dataset in enumerate(dataset_names):
            trial_averages = (grouped_records
                .filter_equals("algorithm, dataset", (algorithm, dataset))
                .group("trial_seed")
                .map(lambda trial_seed, group:
                    group.select("sweep_acc").mean()
                )
            )
            mean, _, table[i][j] = format_mean(trial_averages, latex)
            means.append(mean)
        table[i][-1] = _avg_cell(means)

    col_labels = ["Algorithm", *dataset_names, "Avg"]
    header_text = f"Averages, model selection method: {selection_method.name}"
    print_table(table, header_text, alg_names, col_labels, colwidth=25,
        latex=latex)

In [38]:
# Report configuration.
input_dir = "./results_resnet"  # sweep output directory passed to reporting.load_records
latex = False                   # False -> plain-text tables, True -> LaTeX tables
if latex:
    results_file = "results.tex"
else:
    results_file = "results.txt"
# NOTE(review): results_file is not used by any cell visible here; the tables
# are only printed to the notebook output.

In [39]:
# Load every run record under input_dir, then print the per-dataset tables and
# the averages table once for each model-selection strategy.
records = reporting.load_records(input_dir)

selection_methods = [
    model_selection.IIDAccuracySelectionMethod,  # "training-domain validation set"
    model_selection.OracleSelectionMethod,
]
for selection_method in selection_methods:
    print_results_tables(records, selection_method, latex)


                                                                                


-------- Dataset: VLCS, model selection method: training-domain validation set
Algorithm             C                     L                     S                     V                     Avg                  
HessianAlignment      83.3 +/- 4.6          64.2 +/- 1.2          59.7 +/- 1.2          76.4 +/- 0.6          70.9                 

-------- Dataset: PACS, model selection method: training-domain validation set
Algorithm             A                     C                     P                     S                     Avg                  
HessianAlignment      84.1 +/- 1.2          69.3 +/- 2.3          96.1 +/- 0.5          78.2 +/- 0.6          81.9                 

-------- Dataset: OfficeHome, model selection method: training-domain validation set
Algorithm             A                     C                     P                     R                     Avg                  
HessianAlignment      51.8 +/- 1.4          50.2 +/- 0.5          63.2 +/- 1.0          76.8 +

AttributeError: 'Q' object has no attribute 'shape'

In [26]:
# Peek at the first two raw run records.
records[:2]

[{'args': {'algorithm': 'HessianAlignment',
   'checkpoint_freq': None,
   'data_dir': './domainbed/data/',
   'dataset': 'PACS',
   'device': 0,
   'holdout_fraction': 0.2,
   'hparams': '{"model_type":"ResNet"}',
   'hparams_seed': 3,
   'output_dir': './domainbed/results_resnet/1d5173d1b760a99669e81f677ef5721b',
   'save_model_every_checkpoint': False,
   'seed': 735808364,
   'skip_model_save': False,
   'steps': None,
   'task': 'domain_generalization',
   'test_envs': [0],
   'trial_seed': 0,
   'uda_holdout_fraction': 0},
  'env0_in_acc': 0.225137278828554,
  'env0_out_acc': 0.20293398533007334,
  'env1_in_acc': 0.17857142857142858,
  'env1_out_acc': 0.17307692307692307,
  'env2_in_acc': 0.30988023952095806,
  'env2_out_acc': 0.2874251497005988,
  'env3_in_acc': 0.044529262086514,
  'env3_out_acc': 0.04331210191082802,
  'epoch': 0.0,
  'hparams': {'batch_size': 34,
   'class_balanced': False,
   'data_augmentation': True,
   'grad_alpha': 5.440033278684992e-06,
   'hess_beta': 

In [12]:
# Display the configured results directory.
input_dir

'./results_resnet'