In [None]:
import os, csv
import time
import argparse
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import sys
from collections import defaultdict
import numpy as np
import pickle  

In [None]:
import wilds
from wilds.common.data_loaders import get_train_loader, get_eval_loader
from wilds.common.grouper import CombinatorialGrouper

In [None]:
from utils import set_seed, Logger, BatchLogger, log_config, ParseKwargs, load, initialize_wandb, log_group_data, parse_bool, get_model_prefix
from train import train, evaluate
from algorithms.initializer import initialize_algorithm
from transforms import initialize_transform
from configs.utils import populate_defaults
import configs.supported as supported

In [None]:
from datetime import datetime
from pathlib import Path
from PIL import Image
import json
from wilds.datasets.wilds_dataset import WILDSDataset
from wilds.common.grouper import CombinatorialGrouper
from wilds.common.metrics.all_metrics import Accuracy, Recall, F1
from typing import Optional
import copy
from tqdm import tqdm
import math
from scipy.special import softmax
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split
from configs.supported import process_outputs_functions, process_pseudolabels_functions
from utils import save_model, save_pred, get_pred_prefix, get_model_prefix, collate_list, detach_and_clone, InfiniteDataIterator
import matplotlib.pyplot as plt

In [None]:
# Sanity check: is a CUDA device visible, and which CUDA version was this
# torch build compiled against (None for CPU-only builds).
print(torch.cuda.is_available(), torch.version.cuda, sep="\n")

# Configuration

In [None]:
# Base directory of the experiment tree; the dataset root handed to
# --root_dir is "<root_folder>/data". Defaults to "/wilds" and can be
# redirected by exporting WILDS_ROOT before starting the kernel.
root_folder = os.environ.get("WILDS_ROOT", "/wilds")

In [None]:
# Command-line interface for an experiment run. In this notebook the parser
# is not fed from sys.argv: a later cell passes an explicit list of strings
# to parser.parse_args(). Options declared with `type=parse_bool, const=True,
# nargs='?'` can be given either as a bare flag (value becomes True) or
# followed by an explicit value that is converted by parse_bool.
parser = argparse.ArgumentParser()

# Required arguments
parser.add_argument('-d', '--dataset', choices=wilds.supported_datasets, required=True)
parser.add_argument('--algorithm', required=True, choices=supported.algorithms)
parser.add_argument('--root_dir', required=True,
                    help='The directory where [dataset]/data can be found (or should be downloaded to, if it does not exist).')

# Dataset
parser.add_argument('--split_scheme', help='Identifies how the train/val/test split is constructed. Choices are dataset-specific.')
parser.add_argument('--dataset_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--download', default=False, type=parse_bool, const=True, nargs='?',
                    help='If true, tries to downloads the dataset if it does not exist in root_dir.')
parser.add_argument('--frac', type=float, default=1,
                    help='Convenience parameter that scales all dataset splits down to the specified fraction, for development purposes. Note that this also scales the test set down, so the reported numbers are not comparable with the full test set.')
parser.add_argument('--version', default=None, type=str)

# Loaders
parser.add_argument('--loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--train_loader', choices=['standard', 'group'])
parser.add_argument('--uniform_over_groups', type=parse_bool, const=True, nargs='?')
parser.add_argument('--distinct_groups', type=parse_bool, const=True, nargs='?')
parser.add_argument('--n_groups_per_batch', type=int)
parser.add_argument('--batch_size', type=int)
parser.add_argument('--eval_loader', choices=['standard'], default='standard')

# Model
parser.add_argument('--model', choices=supported.models)
parser.add_argument('--model_kwargs', nargs='*', action=ParseKwargs, default={},
    help='keyword arguments for model initialization passed as key1=value1 key2=value2')

# Transforms
parser.add_argument('--train_transform', choices=supported.transforms)
parser.add_argument('--eval_transform', choices=supported.transforms)
parser.add_argument('--target_resolution', nargs='+', type=int, help='The input resolution that images will be resized to before being passed into the model. For example, use --target_resolution 224 224 for a standard ResNet.')
parser.add_argument('--resize_scale', type=float)
parser.add_argument('--max_token_length', type=int)

# Objective
parser.add_argument('--loss_function', choices = supported.losses)

# Algorithm
parser.add_argument('--groupby_fields', nargs='+')
parser.add_argument('--group_dro_step_size', type=float)
parser.add_argument('--coral_penalty_weight', type=float)
parser.add_argument('--irm_lambda', type=float)
parser.add_argument('--irm_penalty_anneal_iters', type=int)
parser.add_argument('--algo_log_metric')

# Model selection
parser.add_argument('--val_metric')
parser.add_argument('--val_metric_decreasing', type=parse_bool, const=True, nargs='?')

# Optimization
parser.add_argument('--n_epochs', type=int)
parser.add_argument('--optimizer', choices=supported.optimizers)
parser.add_argument('--lr', type=float)
parser.add_argument('--weight_decay', type=float)
parser.add_argument('--max_grad_norm', type=float)
parser.add_argument('--optimizer_kwargs', nargs='*', action=ParseKwargs, default={})

# Scheduler
parser.add_argument('--scheduler', choices=supported.schedulers)
parser.add_argument('--scheduler_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--scheduler_metric_split', choices=['train', 'val'], default='val')
parser.add_argument('--scheduler_metric_name')

# Evaluation
parser.add_argument('--process_outputs_function', choices = supported.process_outputs_functions)
parser.add_argument('--evaluate_all_splits', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--eval_splits', nargs='+', default=[])
parser.add_argument('--eval_only', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--eval_epoch', default=None, type=int, help='If eval_only is set, then eval_epoch allows you to specify evaluating at a particular epoch. By default, it evaluates the best epoch by validation performance.')

# Misc
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--log_dir', default='./logs')
parser.add_argument('--log_every', default=50, type=int)
parser.add_argument('--save_step', type=int)
parser.add_argument('--save_best', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--save_last', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--save_pred', type=parse_bool, const=True, nargs='?', default=True)
parser.add_argument('--no_group_logging', type=parse_bool, const=True, nargs='?')
parser.add_argument('--use_wandb', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--progress_bar', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--resume', type=parse_bool, const=True, nargs='?', default=False)
parser.add_argument('--use_unlabeled_y', default=False, type=parse_bool, const=True, nargs='?', 
                    help='If true, unlabeled loaders will also the true labels for the unlabeled data. This is only available for some datasets. Used for "fully-labeled ERM experiments" in the paper. Correct functionality relies on CrossEntropyLoss using ignore_index=-100.')
parser.add_argument('--additional_train_transform', choices=supported.additional_transforms, help='Optional data augmentations to layer on top of the default transforms.')
parser.add_argument('--load_featurizer_only', default=False, type=parse_bool, const=True, nargs='?', help='If true, only loads the featurizer weights and not the classifier weights.')
parser.add_argument('--unlabeled_loader_kwargs', nargs='*', action=ParseKwargs, default={})
parser.add_argument('--gradient_accumulation_steps', type=int, default=1, help='Number of batches to process before stepping optimizer and schedulers. If > 1, we simulate having a larger effective batch size (though batchnorm behaves differently).')
parser.add_argument('--pretrained_model_path', default=None, type=str, help='Specify a path to pretrained model weights')

In [None]:
# Argument vector for this notebook run, written exactly as it would appear
# on a command line. Order matters: each option is immediately followed by
# its value, except the bare flags noted below.
string_parsed = [
    '--dataset',
    'iwildcam',
    '--algorithm',
    'ERM',
    '--root_dir',
    os.path.join(root_folder,'data'),
    '--train_transform',
    "image_base",
    '--eval_transform',
    "image_base",    
    # Bare flags: no value follows, so argparse stores const=True for each.
    '--resume',
    '--eval_only',
    # Explicit value, converted by parse_bool (presumably to False — confirm
    # against utils.parse_bool).
    "--save_last",
    "False",
    '--n_epochs',
    '5'
]

config = parser.parse_args(string_parsed)
# populate_defaults is a project helper; presumably it fills options left
# unset above with dataset/algorithm defaults — confirm in configs.utils.
config = populate_defaults(config)
# Number of DataLoader worker processes, forwarded via **config.loader_kwargs.
config.loader_kwargs['num_workers'] = 4

In [None]:
# Resolve the integer GPU index parsed from --device into a torch.device.
# Data-parallel training is disabled in both the GPU and the CPU case.
config.use_data_parallel = False

# Guard on the type so that re-running this cell is harmless: once
# config.device is already a torch.device, converting it again would build
# the invalid string "cuda:cuda:<idx>" and raise.
if not isinstance(config.device, torch.device):
    if torch.cuda.is_available():
        config.device = torch.device("cuda:" + str(config.device))
    else:
        config.device = torch.device("cpu")

config.device

In [None]:
# Decide how the log files are opened. Existing logs are appended to when
# resuming or when only evaluating; in every other case they are overwritten.
log_dir_exists = os.path.exists(config.log_dir)
if log_dir_exists and config.resume:
    resume, mode = True, 'a'
elif log_dir_exists and config.eval_only:
    resume, mode = False, 'a'
else:
    resume, mode = False, 'w'

# Create the log directory on first use, then open the main text log.
if not log_dir_exists:
    os.makedirs(config.log_dir)
logger = Logger(os.path.join(config.log_dir, 'log.txt'), mode)

# Write the resolved configuration to the log for later reference.
log_config(config, logger)

# Seed the random number generators from --seed.
set_seed(config.seed)

# Helper Functions

In [None]:
def split_train_test(input_df, train_frac=0.8):
    """Label each row of ``input_df`` as "train" or "val", in place.

    ``int(train_frac * len(input_df))`` index labels are drawn without
    replacement via ``np.random.choice`` and marked "train" in the ``split``
    column; every remaining row is marked "val". Nothing is returned.
    """
    n_train = int(train_frac * (len(input_df)))
    train_labels = np.random.choice(input_df.index, size=n_train, replace=False)
    val_labels = input_df.index.difference(train_labels)
    input_df.loc[train_labels, 'split'] = "train"
    input_df.loc[val_labels, 'split'] = "val"
    
def q_plus(arr, alpha):
    """Return the ceil((1 - alpha) * (n + 1))-th smallest value of ``arr``.

    When that rank exceeds the number of values ``n``, ``+inf`` is returned
    instead.
    """
    values = list(arr)
    count = len(values)
    # Convert the 1-based rank into a 0-based position.
    position = np.ceil((1 - alpha) * (count + 1)).astype("int") - 1
    return float("inf") if position >= count else np.partition(values, position)[position]

def q_minus(arr, alpha):
    """Return the floor(alpha * (n + 1))-th smallest value of ``arr``.

    Args:
        arr: iterable of comparable numeric values.
        alpha: quantile level; expected to lie in (0, 1).

    Returns:
        The selected order statistic, or ``-inf`` when the rank is zero
        (i.e. ``alpha * (n + 1) < 1``, which includes empty input).
    """
    arr = list(arr)
    n = len(arr)
    # np.floor yields a float; np.partition and indexing both require an
    # integer, so cast before converting the 1-based rank to a 0-based index.
    idx = int(np.floor(alpha * (n + 1))) - 1
    if idx < 0:
        return -float("inf")
    else:
        return np.partition(arr, idx)[idx]
    
# Compute tau
def split_conformal_compute_tau(total_res, alpha, delta):
    """Two-level upper quantile over groups of residuals.

    For every residual collection in ``total_res`` the ``q_plus`` value at
    level ``alpha`` is taken; the result is the ``q_plus`` value at level
    ``delta`` of those per-collection values.
    """
    per_group_quantiles = [q_plus(group_res, alpha) for group_res in total_res]
    return q_plus(np.array(per_group_quantiles), delta)

In [None]:
# Every (alpha, delta) pair on a 50 x 50 grid over [0.02, 0.88]; delta varies
# fastest, alpha slowest.
_level_grid = np.linspace(0.02, 0.88, 50)
alpha_delta_list = [(alpha, delta) for alpha in _level_grid for delta in _level_grid]

# Filter & Split Environments

In [None]:
# construct LOO datasets
# Load the iWildCam v2.0 metadata table and prune it in two passes; `df` is
# rebound after each pass, so this cell must be run from a fresh read.
df_filename = config.root_dir + "/iwildcam_v2.0/metadata.csv"
df = pd.read_csv(df_filename)

# remove domains with <= 100 samples
# (a domain is one value of `location_remapped`; the comparison is strict,
# so a domain with exactly `sample_threshold` rows is dropped)
sample_threshold = 100
sample_count_ser = df.groupby(['location_remapped'])['split'].count()
filtered_ser = sample_count_ser[sample_count_ser > sample_threshold]
df = df[df['location_remapped'].isin(filtered_ser.index)]

# remove categories that appear in <= 5% domains
# (the 5% is taken of the domains that survived the previous filter and is
# truncated to an integer by int())
domain_threshold = int(0.05 * len(df['location_remapped'].unique()))
print(f"domain_threshold: {domain_threshold}")
domain_count_ser = df.groupby("category_id")["location_remapped"].nunique()
filtered_category_ser = domain_count_ser[domain_count_ser > domain_threshold]
df = df[df["category_id"].isin(filtered_category_ser.index)]

# Domains remaining after both filters.
unique_domains = df['location_remapped'].unique()
num_domains = len(unique_domains)
print("Number of unique domains: ", num_domains)
print("Number of unique categories: ", df.category_id.nunique())

In [None]:
# remap category ids
# Filtering left gaps in the category ids; relabel the surviving categories
# as 0..K-1 by enumerating the values returned by unique().
old_cates = df.category_id.unique()
cate_map = {cate:idx for idx, cate in enumerate(old_cates)}
df['category_id'] = df['category_id'].map(cate_map)
# `y` is the label column read by IWildCamDatasetModified.
df['y'] = df['category_id']
# Total number of classes; passed to load_datasets' dataset constructor and
# used as the default width in convert_to_one_hot.
num_cate_total = df['category_id'].nunique()

# Read Data

In [None]:
class IWildCamDatasetModified(WILDSDataset):
    """iWildCam dataset whose metadata CSV and class count are configurable.

    Compared with a fixed-metadata dataset, the constructor takes the name of
    the metadata file to read from the data directory and an optional
    ``num_classes`` override, so that subsets of the data written to separate
    CSV files can share one label space.
    """
    _dataset_name = 'iwildcam'
    _versions_dict = {
        '2.0': {
            'download_url': 'https://worksheets.codalab.org/rest/bundles/0x6313da2b204647e79a14b468131fcd64/contents/blob/',
            'compressed_size': 11_957_420_032}}


    def __init__(self, version=None, root_dir='data', download=False, split_scheme='official', metadata_filename = 'metadata.csv', num_classes = None):
        """
        Args:
            - version (str): Dataset version; stored as-is on the instance.
            - root_dir (str): Passed to initialize_data_dir to locate the data.
            - download (bool): Passed to initialize_data_dir.
            - split_scheme (str): Only 'official' is accepted.
            - metadata_filename (str): CSV file, relative to the data directory,
                                       with the columns read below: split, filename, y,
                                       location_remapped, sequence_remapped, datetime.
            - num_classes (int): Number of classes to report. When None it is
                                 inferred as max(y) + 1 from the CSV.
        Raises:
            - ValueError: If split_scheme is not 'official'.
        """
        self._version = version
        self._split_scheme = split_scheme
        if self._split_scheme != 'official':
            raise ValueError(f'Split scheme {self._split_scheme} not recognized')

        # path
        # initialize_data_dir is not defined in this class; presumably inherited
        # from WILDSDataset and returns the versioned data directory — confirm.
        self._data_dir = Path(self.initialize_data_dir(root_dir, download))

        # Load splits
        df = pd.read_csv(self._data_dir / metadata_filename)

        # Splits
        self._split_dict = {'train': 0, 'val': 1, 'test': 2, 'id_val': 3, 'id_test': 4}
        self._split_names = {'train': 'Train', 'val': 'Validation (OOD/Trans)',
                                'test': 'Test (OOD/Trans)', 'id_val': 'Validation (ID/Cis)',
                                'id_test': 'Test (ID/Cis)'}

        # A split name outside _split_dict raises KeyError here.
        df['split_id'] = df['split'].apply(lambda x: self._split_dict[x])
        self._split_array = df['split_id'].values

        # Filenames
        self._input_array = df['filename'].values

        # Labels
        self._y_array = torch.tensor(df['y'].values)
        if num_classes is None:
            self._n_classes = max(df['y']) + 1
        else:
            # Caller-supplied class count; it may exceed the labels present in
            # this particular CSV, which is why the assert below is disabled.
            self._n_classes = num_classes
        self._y_size = 1
        # assert len(np.unique(df['y'])) == self._n_classes

        # Location/group info
        n_groups = max(df['location_remapped']) + 1
        self._n_groups = n_groups
        # assert len(np.unique(df['location_remapped'])) == self._n_groups

        # Sequence info
        n_sequences = max(df['sequence_remapped']) + 1
        self._n_sequences = n_sequences
        # assert len(np.unique(df['sequence_remapped'])) == self._n_sequences

        # Extract datetime subcomponents and include in metadata
        df['datetime_obj'] = df['datetime'].apply(lambda x: datetime.strptime(x, '%Y-%m-%d %H:%M:%S.%f'))
        df['year'] = df['datetime_obj'].apply(lambda x: int(x.year))
        df['month'] = df['datetime_obj'].apply(lambda x: int(x.month))
        df['day'] = df['datetime_obj'].apply(lambda x: int(x.day))
        df['hour'] = df['datetime_obj'].apply(lambda x: int(x.hour))
        df['minute'] = df['datetime_obj'].apply(lambda x: int(x.minute))
        df['second'] = df['datetime_obj'].apply(lambda x: int(x.second))

        # One metadata row per example; column order matches _metadata_fields.
        # self.y_array is not defined in this class; presumably a base-class
        # accessor for _y_array — confirm.
        self._metadata_array = torch.tensor(np.stack([df['location_remapped'].values,
                            df['sequence_remapped'].values,
                            df['year'].values, df['month'].values, df['day'].values,
                            df['hour'].values, df['minute'].values, df['second'].values,
                            self.y_array], axis=1))
        self._metadata_fields = ['location', 'sequence', 'year', 'month', 'day', 'hour', 'minute', 'second', 'y']

        # eval grouper
        # Built from `self`, so it must come after the metadata fields are set.
        self._eval_grouper = CombinatorialGrouper(
            dataset=self,
            groupby_fields=None)

        super().__init__(root_dir, download, split_scheme)

    def eval(self, y_pred, y_true, metadata, prediction_fn=None):
        """
        Computes all evaluation metrics.
        Args:
            - y_pred (Tensor): Predictions from a model. By default, they are predicted labels (LongTensor).
                               But they can also be other model outputs such that prediction_fn(y_pred)
                               are predicted labels.
            - y_true (LongTensor): Ground-truth labels
            - metadata (Tensor): Metadata
            - prediction_fn (function): A function that turns y_pred into predicted labels
        Output:
            - results (dictionary): Dictionary of evaluation metrics
            - results_str (str): String summarizing the evaluation metrics
        """
        # `metadata` is accepted but not used: metrics are computed over all
        # examples, without any per-group breakdown.
        metrics = [
            Accuracy(prediction_fn=prediction_fn),
            Recall(prediction_fn=prediction_fn, average='macro'),
            F1(prediction_fn=prediction_fn, average='macro'),
        ]

        results = {}
        for i in range(len(metrics)):
            results.update({
                **metrics[i].compute(y_pred, y_true),
                        })
        results_str = (
            f"Average acc: {results[metrics[0].agg_metric_field]:.3f}\n"
            f"Recall macro: {results[metrics[1].agg_metric_field]:.3f}\n"
            f"F1 macro: {results[metrics[2].agg_metric_field]:.3f}\n"
        )
        return results, results_str

    def get_input(self, idx):
        """
        Args:
            - idx (int): Index of a data point
        Output:
            - x (PIL.Image): The image opened from the idx-th filename
        """

        # All images are in the train folder
        img_path = self.data_dir / 'train' / self._input_array[idx]
        img = Image.open(img_path)

        return img

In [None]:
def get_modified_dataset(dataset: str, version: Optional[str] = None, unlabeled: bool = False, metadata_filename = 'pretrain_metadata.csv', num_classes = None, **dataset_kwargs):
    """
    Builds the iWildCam dataset object used by this notebook.
    Input:
        dataset (str): Name of the dataset. Only 'iwildcam' can be constructed here.
        version (Union[str, None]): Dataset version number, e.g., '1.0'.
                                    Anything other than '1.0' selects IWildCamDatasetModified.
        unlabeled (bool): If true, use the unlabeled version of the dataset.
        metadata_filename (str): Metadata CSV read by IWildCamDatasetModified.
                                 Ignored for the unlabeled and the archived v1.0 datasets.
        num_classes (Union[int, None]): Class-count override for IWildCamDatasetModified.
                                        Ignored for the unlabeled and the archived v1.0 datasets.
        dataset_kwargs: Other keyword arguments to pass to the dataset constructors.
    Output:
        A constructed dataset instance.
    Raises:
        ValueError: If the dataset name is unknown to WILDS, if unlabeled data is
                    requested but unavailable, or if the dataset is not 'iwildcam'.
    """
    if version is not None:
        version = str(version)

    if dataset not in wilds.supported_datasets:
        raise ValueError(f'The dataset {dataset} is not recognized. Must be one of {wilds.supported_datasets}.')

    if unlabeled and dataset not in wilds.unlabeled_datasets:
        raise ValueError(f'Unlabeled data is not available for {dataset}. Must be one of {wilds.unlabeled_datasets}.')

    if dataset == 'iwildcam':
        if unlabeled:
            print("unlabeled")
            from wilds.datasets.unlabeled.iwildcam_unlabeled_dataset import IWildCamUnlabeledDataset
            return IWildCamUnlabeledDataset(version=version, **dataset_kwargs)

        if version == '1.0':
            # Previously the class was imported but never instantiated, so this
            # branch fell through and returned None.
            from wilds.datasets.archive.iwildcam_v1_0_dataset import IWildCamDataset
            return IWildCamDataset(version=version, **dataset_kwargs)

        return IWildCamDatasetModified(
            version=version,
            metadata_filename=metadata_filename,
            num_classes=num_classes,
            **dataset_kwargs)

    # Fail loudly instead of implicitly returning None for other WILDS datasets.
    raise ValueError(f"The dataset {dataset} is not supported by get_modified_dataset; only 'iwildcam' is.")

In [None]:
def load_datasets(data_filename):
    """Build per-split datasets, loaders and loggers from one metadata CSV.

    Reads the notebook-level globals ``config``, ``num_cate_total``, ``mode``
    and ``logger``.

    Args:
        data_filename: Metadata CSV path, forwarded to the dataset constructor
            as ``metadata_filename``.

    Returns:
        (datasets, train_grouper). ``datasets`` maps each split name to a dict
        with keys 'dataset', 'loader', 'split', 'name', 'verbose',
        'eval_logger' and 'algo_logger'.
    """
    full_dataset = get_modified_dataset(
        dataset=config.dataset,
        version=config.version,
        root_dir=config.root_dir,
        download=config.download,
        split_scheme=config.split_scheme,
        metadata_filename = data_filename,
        num_classes = num_cate_total,
        **config.dataset_kwargs
    )

    train_transform = initialize_transform(
        transform_name=config.train_transform,
        config=config,
        dataset=full_dataset,
        is_training = True)
    eval_transform = initialize_transform(
        transform_name=config.eval_transform,
        config=config,
        dataset=full_dataset,
        is_training = False)

    train_grouper = CombinatorialGrouper(
        dataset=full_dataset,
        groupby_fields=None)

    datasets = defaultdict(dict)
    for split in full_dataset.split_dict.keys():
        # Only the train split gets the training transform; only train and val
        # are logged verbosely.
        if split=='train':
            transform = train_transform
            verbose = True
        elif split == 'val':
            transform = eval_transform
            verbose = True
        else:
            transform = eval_transform
            verbose = False
        # Get subset
        datasets[split]['dataset'] = full_dataset.get_subset(
            split,
            frac = config.frac,
            transform=transform)

        if split == 'train':
            datasets[split]['loader'] = get_train_loader(
                loader=config.train_loader,
                dataset=datasets[split]['dataset'],
                batch_size=config.batch_size,
                uniform_over_groups=config.uniform_over_groups,
                grouper=train_grouper,
                distinct_groups=config.distinct_groups,
                n_groups_per_batch=config.n_groups_per_batch,
                **config.loader_kwargs)
        else:
            datasets[split]['loader'] = get_eval_loader(
                loader=config.eval_loader,
                dataset=datasets[split]['dataset'],
                grouper=train_grouper,
                batch_size=config.batch_size,
                **config.loader_kwargs)

        # Set fields
        datasets[split]['split'] = split
        datasets[split]['name'] = full_dataset.split_names[split]
        datasets[split]['verbose'] = verbose

        # Loggers
        datasets[split]['eval_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_eval.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))
        datasets[split]['algo_logger'] = BatchLogger(
            os.path.join(config.log_dir, f'{split}_algo.csv'), mode=mode, use_wandb=(config.use_wandb and verbose))

    # Initialise Weights & Biases once per call. This used to sit inside the
    # split loop and therefore ran once for every split.
    if config.use_wandb:
        initialize_wandb(config)

    # Logging dataset info
    # Show class breakdown if feasible
    if config.no_group_logging and full_dataset.is_classification and full_dataset.y_size==1 and full_dataset.n_classes <= 10:
        log_grouper = CombinatorialGrouper(
            dataset=full_dataset,
            groupby_fields=None)
    elif config.no_group_logging:
        log_grouper = None
    else:
        log_grouper = train_grouper
    log_group_data(datasets, log_grouper, logger)

    return datasets, train_grouper

# Train Models

In [None]:
def convert_to_one_hot(y_true, num_labels = num_cate_total):
    """Return an ``(n, num_labels)`` array with a 1 at ``[i, y_true[i]]`` and 0 elsewhere.

    ``n`` is ``y_true.shape[0]``. The default ``num_labels`` is the value of
    ``num_cate_total`` at the time this function is defined.
    """
    n_rows = y_true.shape[0]
    one_hot = np.zeros([n_rows, num_labels])
    row_idx = np.arange(n_rows)
    one_hot[row_idx, y_true] = 1
    return one_hot

In [None]:
def run_epoch(algorithm, dataset, general_logger, epoch, config, train, unlabeled_dataset=None):
    """
    Runs one full pass over dataset['loader'], either updating or evaluating the algorithm.
    Args:
        - algorithm: Object providing train()/eval(), update()/evaluate(), step_schedulers()
                     and the logging hooks used by log_results.
        - dataset (dict): Must contain 'loader', 'dataset', 'name', 'split', 'verbose',
                          'eval_logger' and 'algo_logger'.
        - general_logger: Text logger with a write() method.
        - epoch (int): Epoch index, recorded in the logged results.
        - config: Reads progress_bar, process_outputs_function, gradient_accumulation_steps,
                  log_every and scheduler_metric_split.
        - train (bool): True to call algorithm.update on each batch, False to call
                        algorithm.evaluate. Note this parameter shadows the `train`
                        function imported at the top of the notebook.
        - unlabeled_dataset (dict): Optional; its 'loader' supplies one extra batch per
                                    labeled batch during training.
    Output:
        - results (dictionary): Metrics from dataset['dataset'].eval, plus 'epoch'
        - epoch_y_pred: Collated predictions for the whole pass
    Side effect: the global torch grad mode is switched on (train) or off (eval) and is
    not restored on exit.
    """
    if dataset['verbose']:
        general_logger.write(f"\n{dataset['name']}:\n")

    if train:
        algorithm.train()
        torch.set_grad_enabled(True)
    else:
        algorithm.eval()
        torch.set_grad_enabled(False)

    # Not preallocating memory is slower
    # but makes it easier to handle different types of data loaders
    # (which might not return exactly the same number of examples per epoch)
    epoch_y_true = []
    epoch_y_pred = []
    epoch_metadata = []

    # Assert that data loaders are defined for the datasets
    assert 'loader' in dataset, "A data loader must be defined for the dataset."
    if unlabeled_dataset:
        assert 'loader' in unlabeled_dataset, "A data loader must be defined for the dataset."

    batches = dataset['loader']
    if config.progress_bar:
        batches = tqdm(batches)
    # Used to tell the algorithm which update is the last one of the epoch.
    last_batch_idx = len(batches)-1
    
    if unlabeled_dataset:
        unlabeled_data_iterator = InfiniteDataIterator(unlabeled_dataset['loader'])

    # Using enumerate(iterator) can sometimes leak memory in some environments (!)
    # so we manually increment batch_idx
    batch_idx = 0
    for labeled_batch in batches:
        if train:
            if unlabeled_dataset:
                unlabeled_batch = next(unlabeled_data_iterator)
                batch_results = algorithm.update(labeled_batch, unlabeled_batch, is_epoch_end=(batch_idx==last_batch_idx))
            else:
                batch_results = algorithm.update(labeled_batch, is_epoch_end=(batch_idx==last_batch_idx))
        else:
            batch_results = algorithm.evaluate(labeled_batch)

        # These tensors are already detached, but we need to clone them again
        # Otherwise they don't get garbage collected properly in some versions
        # The extra detach is just for safety
        # (they should already be detached in batch_results)
        epoch_y_true.append(detach_and_clone(batch_results['y_true']))
        y_pred = detach_and_clone(batch_results['y_pred'])
        if config.process_outputs_function is not None:
            y_pred = process_outputs_functions[config.process_outputs_function](y_pred)
        epoch_y_pred.append(y_pred)
        epoch_metadata.append(detach_and_clone(batch_results['metadata']))

        # During training, count optimizer steps rather than raw batches; this
        # is fractional between accumulation boundaries.
        if train: 
            effective_batch_idx = (batch_idx + 1) / config.gradient_accumulation_steps
        else: 
            effective_batch_idx = batch_idx + 1

        if train and effective_batch_idx % config.log_every==0:
            log_results(algorithm, dataset, general_logger, epoch, math.ceil(effective_batch_idx))

        batch_idx += 1

    epoch_y_pred = collate_list(epoch_y_pred)
    epoch_y_true = collate_list(epoch_y_true)
    epoch_metadata = collate_list(epoch_metadata)

    results, results_str = dataset['dataset'].eval(
        epoch_y_pred,
        epoch_y_true,
        epoch_metadata)

    if config.scheduler_metric_split==dataset['split']:
        algorithm.step_schedulers(
            is_epoch=True,
            metrics=results,
            log_access=(not train))

    # log after updating the scheduler in case it needs to access the internal logs
    # NOTE(review): effective_batch_idx is only bound inside the loop above, so
    # this line assumes the loader yielded at least one batch.
    log_results(algorithm, dataset, general_logger, epoch, math.ceil(effective_batch_idx))

    results['epoch'] = epoch
    dataset['eval_logger'].log(results)
    if dataset['verbose']:
        general_logger.write('Epoch eval:\n')
        general_logger.write(results_str)

    return results, epoch_y_pred

In [None]:
def train_model(algorithm, datasets, general_logger, config, epoch_offset, best_val_metric, domain_num, unlabeled_dataset=None, validation_threshold=0.90):
    """
    Train loop that, for each epoch in [epoch_offset, config.n_epochs):
        - Steps an algorithm on the datasets['train'] split and the unlabeled split
        - Evaluates the algorithm on the datasets['val'] split
        - Saves models / preds with frequency according to the configs
        - Counts the epoch towards early stopping when the validation "acc_avg"
          exceeds validation_threshold
    Training stops early once 10 epochs (not necessarily consecutive) have exceeded
    the threshold. Unlike the val split, no other split is evaluated here.
    Assumes that the datasets dict contains labeled data.

    Args:
        - domain_num: Tag appended to the checkpoint filename prefix.
        - best_val_metric: Best value of config.val_metric seen so far, or None.
          It is updated locally only; nothing is returned.
    """
    
    early_stop_count = 0
    for epoch in range(epoch_offset, config.n_epochs):
        general_logger.write('\nEpoch [%d]:\n' % epoch)

        # First run training
        run_epoch(algorithm, datasets['train'], general_logger, epoch, config, train=True, unlabeled_dataset=unlabeled_dataset)

        # Then run val
        val_results, y_pred = run_epoch(algorithm, datasets['val'], general_logger, epoch, config, train=False)
        curr_val_metric = val_results[config.val_metric]
        # general_logger.write(f'Validation {config.val_metric}: {curr_val_metric:.3f}\n')

        # The first evaluated epoch is always the best so far; afterwards the
        # comparison direction follows config.val_metric_decreasing.
        if best_val_metric is None:
            is_best = True
        else:
            if config.val_metric_decreasing:
                is_best = curr_val_metric < best_val_metric
            else:
                is_best = curr_val_metric > best_val_metric
        if is_best:
            best_val_metric = curr_val_metric
            # general_logger.write(f'Epoch {epoch} has the best validation performance so far.\n')

        save_model_if_needed(algorithm, datasets['val'], epoch, config, is_best, best_val_metric, domain_num)
        save_pred_if_needed(y_pred, datasets['val'], epoch, config, is_best)

        general_logger.write('\n')
        
        # Early stopping: the counter is never reset, so it accumulates over
        # all epochs above the threshold rather than over a consecutive run.
        if val_results["acc_avg"] > validation_threshold:
            early_stop_count += 1
            print(f"early_stop_count: {early_stop_count}")
            if early_stop_count >= 10:
                break

In [None]:
def evaluate(algorithm, datasets, epoch, general_logger, config, is_best):
    """Evaluate ``algorithm`` on the configured splits and log the metrics.

    Every split in ``datasets`` is evaluated when ``config.evaluate_all_splits``
    is true; otherwise only the splits named in ``config.eval_splits``.
    Predictions of every evaluated split except 'train' are saved.

    Note: this definition shadows the ``evaluate`` imported from ``train`` at
    the top of the notebook, and it leaves the global torch grad mode disabled.

    Args:
        algorithm: Object providing ``eval()`` and ``evaluate(batch)``.
        datasets: Mapping from split name to a dict with 'loader', 'dataset'
            and 'eval_logger'.
        epoch: Epoch index recorded in the logged results and file names.
        general_logger: Text logger with a ``write()`` method.
        config: Reads evaluate_all_splits, eval_splits, progress_bar and
            process_outputs_function.
        is_best: Forwarded to ``save_pred_if_needed``.
    """
    algorithm.eval()
    torch.set_grad_enabled(False)
    for split, dataset in datasets.items():
        if (not config.evaluate_all_splits) and (split not in config.eval_splits):
            continue
        epoch_y_true = []
        epoch_y_pred = []
        epoch_metadata = []
        iterator = tqdm(dataset['loader']) if config.progress_bar else dataset['loader']
        for batch in iterator:
            batch_results = algorithm.evaluate(batch)
            epoch_y_true.append(detach_and_clone(batch_results['y_true']))
            y_pred = detach_and_clone(batch_results['y_pred'])
            if config.process_outputs_function is not None:
                y_pred = process_outputs_functions[config.process_outputs_function](y_pred)
            epoch_y_pred.append(y_pred)
            epoch_metadata.append(detach_and_clone(batch_results['metadata']))

        epoch_y_pred = collate_list(epoch_y_pred)
        epoch_y_true = collate_list(epoch_y_true)
        epoch_metadata = collate_list(epoch_metadata)
        # The metrics dict must be bound to `results`: it was previously bound
        # to a misspelled name, so the lines below raised NameError.
        results, results_str = dataset['dataset'].eval(
            epoch_y_pred,
            epoch_y_true,
            epoch_metadata)

        results['epoch'] = epoch
        dataset['eval_logger'].log(results)
        general_logger.write(f'Eval split {split} at epoch {epoch}:\n')
        general_logger.write(results_str)

        # Skip saving train preds, since the train loader generally shuffles the data
        if split != 'train':
            save_pred_if_needed(epoch_y_pred, dataset, epoch, config, is_best, force_save=True)

In [None]:
def get_residuals(algorithm, batches):
    """Per-example cross-entropy of ``algorithm`` over ``batches``.

    Args:
        algorithm: Object providing ``eval()`` and ``evaluate(batch)``; each
            result must contain 'y_pred' (2-D class scores, one row per
            example) and 'y_true' (integer class indices).
        batches: Iterable of labeled batches.

    Returns:
        1-D numpy array holding ``-log softmax(y_pred)[i, y_true[i]]`` for
        every example, in loader order.
    """
    epoch_y_true = []
    epoch_y_pred = []

    algorithm.eval()
    # Inference only: do not depend on whatever grad mode the caller left set.
    with torch.no_grad():
        for labeled_batch in batches:
            batch_results = algorithm.evaluate(labeled_batch)
            epoch_y_true.append(detach_and_clone(batch_results['y_true']))
            epoch_y_pred.append(detach_and_clone(batch_results['y_pred']))

    logits = np.asarray(collate_list(epoch_y_pred))
    labels = np.asarray(collate_list(epoch_y_true))
    n, _ = logits.shape

    # Log-softmax computed directly (max-shifted log-sum-exp). Taking
    # log(softmax(.)) instead returns inf whenever a probability underflows
    # to zero.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    loss = -log_probs[np.arange(n), labels]

    return loss

In [None]:
def save_model_if_needed(algorithm, dataset, epoch, config, is_best, best_val_metric, domain_num):
    """Write the checkpoints requested by the config for this epoch.

    Up to three files are written, all sharing the prefix
    ``get_model_prefix(dataset, config) + str(domain_num) + "_"``:
    a periodic one every ``config.save_step`` epochs, a "last" one when
    ``config.save_last`` is set, and a "best" one when ``config.save_best``
    is set and ``is_best`` is true.
    """
    prefix = get_model_prefix(dataset, config) + str(domain_num) + "_"

    suffixes = []
    if config.save_step is not None and (epoch + 1) % config.save_step == 0:
        suffixes.append(f'epoch:{epoch}_model.pth')
    if config.save_last:
        suffixes.append('epoch:last_model.pth')
    if config.save_best and is_best:
        suffixes.append('epoch:best_model.pth')

    for suffix in suffixes:
        save_model(algorithm, epoch, best_val_metric, prefix + suffix)

In [None]:
def log_results(algorithm, dataset, general_logger, epoch, effective_batch_idx):
    """Flush the algorithm's accumulated log, if it has one.

    The log is stamped with ``epoch`` and ``batch``, sent to
    ``dataset['algo_logger']``, echoed to ``general_logger`` for verbose
    datasets, and then reset on the algorithm.
    """
    if not algorithm.has_log:
        return

    entry = algorithm.get_log()
    entry['epoch'] = epoch
    entry['batch'] = effective_batch_idx
    dataset['algo_logger'].log(entry)

    if dataset['verbose']:
        general_logger.write(algorithm.get_pretty_log_str())

    algorithm.reset_log()

# Inference

In [None]:
def infer_predictions(model, loader, config):
    """
    Simple inference loop that performs inference using a model (not algorithm) and returns model outputs.
    Compatible with both labeled and unlabeled WILDS datasets.

    Args:
        - model: Callable module; put in eval mode and called on each batch's inputs.
        - loader: Iterable of batches whose first element is the input tensor.
        - config: Reads device, progress_bar, soft_pseudolabels,
                  process_pseudolabels_function, self_training_threshold and dataset.
    Output:
        - A single tensor concatenated along dim 0 when the per-batch outputs are
          tensors; otherwise the list of collected outputs.
    """
    model.eval()
    y_pred = []
    iterator = tqdm(loader) if config.progress_bar else loader
    for batch in iterator:
        # Only the inputs are used; any labels/metadata in the batch are ignored.
        x = batch[0]
        x = x.to(config.device)
        with torch.no_grad(): 
            output = model(x)
            if not config.soft_pseudolabels and config.process_pseudolabels_function is not None:
                # The processing function returns four values; only the second
                # is kept as the output. A non-zero confidence threshold is
                # applied for the 'globalwheat' dataset only.
                _, output, _, _ = process_pseudolabels_functions[config.process_pseudolabels_function](
                    output,
                    confidence_threshold=config.self_training_threshold if config.dataset == 'globalwheat' else 0
                )
            elif config.soft_pseudolabels:
                output = torch.nn.functional.softmax(output, dim=1)
        # List outputs are flattened into y_pred; tensor outputs are kept per batch.
        if isinstance(output, list):
            y_pred.extend(detach_and_clone(output))
        else:
            y_pred.append(detach_and_clone(output))

    # NOTE(review): y_pred[0] assumes the loader yielded at least one batch.
    return torch.cat(y_pred, 0) if torch.is_tensor(y_pred[0]) else y_pred

In [None]:
def save_pred_if_needed(y_pred, dataset, epoch, config, is_best, force_save=False):
    """
    Write ``y_pred`` to disk according to the saving options in ``config``.

    Nothing is written unless ``config.save_pred`` is truthy. Otherwise, using the
    prefix from ``get_pred_prefix(dataset, config)``:
      * an ``epoch:{epoch}_pred`` file is written when ``force_save`` is set or
        ``epoch + 1`` is a multiple of ``config.save_step``;
      * an ``epoch:last_pred`` file is written when ``config.save_last`` is set
        and ``force_save`` is not;
      * an ``epoch:best_pred`` file is written when ``config.save_best`` is set
        and ``is_best`` is truthy.
    """
    if not config.save_pred:
        return

    prefix = get_pred_prefix(dataset, config)

    # force_save is checked first so config.save_step is only consulted when needed.
    write_epoch_file = force_save or (
        config.save_step is not None and (epoch + 1) % config.save_step == 0
    )
    if write_epoch_file:
        save_pred(y_pred, prefix + f'epoch:{epoch}_pred')

    if not force_save and config.save_last:
        save_pred(y_pred, prefix + f'epoch:last_pred')

    if config.save_best and is_best:
        save_pred(y_pred, prefix + f'epoch:best_pred')

# Conformal Algorithms

In [None]:
# use part of the training data to train a model
def train_D1():
    """
    Train a model from scratch on the D1 split.

    Loads the datasets described by ``split_data/D1_metadata.csv``, builds the
    algorithm from the notebook-level ``config``, and runs ``train_model`` from
    epoch 0 with no previous best validation metric (nothing is resumed).
    Uses the notebook-level ``config`` and ``logger``.
    """
    d1_datasets, d1_grouper = load_datasets("split_data/D1_metadata.csv")

    ## Initialize algorithm
    d1_algorithm = initialize_algorithm(
        config=config,
        datasets=d1_datasets,
        train_grouper=d1_grouper,
    )

    # No checkpoint is resumed, so training always starts from a clean state.
    train_model(
        algorithm=d1_algorithm,
        datasets=d1_datasets,
        general_logger=logger,
        config=config,
        epoch_offset=0,
        best_val_metric=None,
        domain_num="D1",
    )

In [None]:
# Number of random ID/OOD domain splits the experiment loop runs.
num_experiments = 50
# Number of domains taken from each random permutation as in-distribution (ID);
# the remaining domains of the permutation form the OOD set.
num_ID = 50
# Size of the OOD set; `num_domains` must already be defined by an earlier cell.
num_OOD = num_domains - num_ID
# Checkpoint that the experiment loop loads into the algorithm after train_D1().
# NOTE(review): absolute path with seed 0 baked into the file name — presumably it has to
# match where train_model writes the best D1 model; confirm against the save prefix.
model_save_path = f"/wilds/examples/logs/iwildcam_seed:0_D1_epoch:best_model.pth"

In [None]:
# Split-conformal experiment over random ID/OOD domain splits.
#
# For each of `num_experiments` repetitions:
#   1. randomly partition the domains into ID and OOD sets and write their metadata CSVs;
#   2. split the ID domains in half, train on the first half (D1) via train_D1();
#   3. compute residuals on every domain of the second half (calibration set);
#   4. on every OOD domain, count how often the per-class loss / true-label loss
#      falls below the threshold tau for each (alpha, delta) pair;
#   5. aggregate the per-domain counts and checkpoint all results to a pickle file.
#
# Relies on names defined in earlier cells: df, unique_domains, num_domains, num_ID,
# alpha_delta_list, config, model_save_path, split_train_test, load_datasets,
# get_residuals, split_conformal_compute_tau, train_D1.

# Called without an argument, this re-seeds NumPy from OS entropy, so the random
# splits (and therefore the results) differ from run to run.
np.random.seed()
emp_alpha_res_list, emp_delta_res_list, emp_card_res_list = [], [], []

# Directory that receives every metadata CSV written below.
split_dir = config.root_dir + "/iwildcam_v2.0/split_data"

for _ in range(num_experiments):
    # randomly split data into two datasets
    permutation_idx = np.random.permutation(num_domains)
    ID_idx = unique_domains[permutation_idx[:num_ID]]
    OOD_idx = unique_domains[permutation_idx[num_ID:]]

    # split_train_test is called for its effect on the frame it receives (its return
    # value is never used), so every frame passed to it is an explicit copy rather
    # than a boolean-mask slice of `df`.
    for i in ID_idx:
        domain_df = df[df['location_remapped'] == i].copy()
        split_train_test(domain_df, train_frac=1)
        domain_df.to_csv(f"{split_dir}/ID_{i}_metadata.csv")

    ID_df = df[df['location_remapped'].isin(ID_idx)].copy()
    split_train_test(ID_df, train_frac=0.7)
    ID_df.to_csv(f"{split_dir}/total_ID_metadata.csv")

    OOD_df = df[df['location_remapped'].isin(OOD_idx)].copy()
    split_train_test(OOD_df, train_frac=1)
    OOD_df.to_csv(f"{split_dir}/total_OOD_metadata.csv")

    # Split conformal algo: re-read the ID metadata once and split its domains 50/50
    # into training domains (D1) and calibration domains (D2).
    ID_df = pd.read_csv(f"{split_dir}/total_ID_metadata.csv")
    ID_idx = ID_df['location_remapped'].unique()
    train_train_idx, train_val_idx = train_test_split(ID_idx, test_size=0.5)

    # Train on D1
    D1_df = df[df['location_remapped'].isin(train_train_idx)].copy()
    split_train_test(D1_df)
    D1_df.to_csv(f"{split_dir}/D1_metadata.csv")
    train_D1()

    # Evaluate on D2: rebuild the algorithm and load the checkpoint at model_save_path.
    datasets, train_grouper = load_datasets("split_data/D1_metadata.csv")
    algorithm = initialize_algorithm(
        config=config,
        datasets=datasets,
        train_grouper=train_grouper)

    load(algorithm, model_save_path, device=config.device)
    algorithm.eval()

    # Residuals of the trained model on each calibration domain.
    total_train_res = []
    for domain_num in train_val_idx:
        LOO_datasets, LOO_train_grouper = load_datasets(f'split_data/ID_{domain_num}_metadata.csv')
        batches = LOO_datasets['train']['loader']
        total_train_res.append(get_residuals(algorithm, batches))
    # NOTE(review): if get_residuals returns arrays of different lengths per domain this
    # builds a ragged array, which recent NumPy versions reject — confirm its return shape.
    total_train_res = np.array(total_train_res)

    # The thresholds are computed from the calibration residuals and (alpha, delta) only,
    # so they are computed once per experiment instead of once per OOD domain.
    taus_dict = {
        (alpha, delta): split_conformal_compute_tau(total_train_res, alpha, delta)
        for (alpha, delta) in alpha_delta_list
    }

    # Test on holdout data: one entry per OOD domain, keyed by (alpha, delta).
    total_count_list, total_cover_list, total_card_list = \
        defaultdict(list), defaultdict(list), defaultdict(list)

    for domain in OOD_df['location_remapped'].unique():
        print(f"Processing domain {domain}")
        domain_df = OOD_df[OOD_df['location_remapped'] == domain].copy()
        split_train_test(domain_df, train_frac=1)
        ood_path = f"{split_dir}/OOD_{domain}_metadata.csv"
        domain_df.to_csv(ood_path)

        # NOTE(review): this call passes the full path while the other load_datasets calls
        # pass a path relative to the dataset folder — confirm load_datasets accepts both.
        datasets, train_grouper = load_datasets(ood_path)
        batches = datasets['train']['loader']

        # for each sample, evaluate
        total_count = defaultdict(int)
        total_cover = defaultdict(int)
        total_card = defaultdict(int)

        for labeled_batch in batches:
            x, y, _metadata = labeled_batch
            x = x.to(algorithm.device)

            # Evaluation only: no autograd graph is needed. log_softmax over dim 1 (the
            # class axis indexed below) replaces log(Softmax()(.)) with an implicit dim.
            with torch.no_grad():
                y_pred = algorithm.model(x)
                loss_val = -torch.nn.functional.log_softmax(y_pred, dim=1).cpu().numpy()

            # Loss of the true label for every sample in the batch.
            true_label_loss = loss_val[np.arange(len(loss_val)), np.array(y)]

            for (alpha, delta) in alpha_delta_list:
                tau = taus_dict[(alpha, delta)]
                total_count[(alpha, delta)] += y_pred.shape[0]
                # number of (sample, class) pairs whose loss is within the threshold
                total_card[(alpha, delta)] += np.sum(loss_val <= tau)
                # number of samples whose true-label loss is within the threshold
                total_cover[(alpha, delta)] += np.sum(true_label_loss <= tau)

        for (alpha, delta) in alpha_delta_list:
            total_count_list[(alpha, delta)].append(total_count[(alpha, delta)])
            total_cover_list[(alpha, delta)].append(total_cover[(alpha, delta)])
            total_card_list[(alpha, delta)].append(total_card[(alpha, delta)])

    # analyze methods' performance
    emp_alpha, emp_delta, emp_card = defaultdict(list), defaultdict(list), defaultdict(list)
    for (alpha, delta) in alpha_delta_list:
        curr_total_cover_list = np.array(total_cover_list[(alpha, delta)])
        curr_total_card_list = np.array(total_card_list[(alpha, delta)])
        curr_total_count_list = np.array(total_count_list[(alpha, delta)])

        # Per-domain flag: covered count reaches (1 - alpha) * (1 + sample count).
        is_valid_coverage = (curr_total_cover_list >= (1-alpha)*(1+curr_total_count_list))
        emp_alpha_all = curr_total_cover_list / curr_total_count_list

        emp_card[(alpha, delta)] = np.mean(curr_total_card_list / curr_total_count_list)
        emp_delta[(alpha, delta)] = np.mean(is_valid_coverage)
        # Mean coverage over the flagged domains only (NaN when no domain is flagged).
        emp_alpha[(alpha, delta)] = np.mean(emp_alpha_all[is_valid_coverage])

    # stores the empirical alpha values
    emp_alpha_res_list.append(emp_alpha)
    # stores the empirical delta values
    emp_delta_res_list.append(emp_delta)
    # stores the average set lengths
    emp_card_res_list.append(emp_card)

    # Rewritten after every experiment, so the file always holds all results so far.
    res_dict = {
        "emp_alpha_res_list": emp_alpha_res_list,
        "emp_delta_res_list": emp_delta_res_list,
        "emp_card_res_list": emp_card_res_list,
    }
    with open('split_conformal_iwilds.pickle', 'wb') as handle:
        pickle.dump(res_dict, handle)

In [None]:
# remove name
# change frac