
Context is Environment

By Sharut Gupta, Stefanie Jegelka, David Lopez-Paz, Kartik Ahuja

[Read the full paper on arXiv]

Two lines of work are taking center stage in AI research. On the one hand, the community is making increasing efforts to build models that discard spurious correlations and generalize better in novel test environments. Unfortunately, a hard lesson so far is that no proposal convincingly outperforms a simple empirical risk minimization baseline. On the other hand, large language models (LLMs) have erupted as algorithms able to learn in-context, generalizing on-the-fly to the eclectic contextual circumstances that users enforce by prompting. We argue that context is environment, and posit that in-context learning holds the key to better domain generalization. Via extensive theory and experiments, we show that paying attention to context (unlabeled examples as they arrive) allows our proposed In-Context Risk Minimization (ICRM) algorithm to zoom in on the test environment risk minimizer, leading to significant out-of-distribution performance improvements. From all of this, two messages are worth taking home: researchers in domain generalization should consider environment as context, and harness the adaptive power of in-context learning. Researchers in LLMs should consider context as environment, to better structure data towards generalization.

The key contributions of this work include:

  • Establishing a strong parallel between the concept of environment in domain generalization and the concept of context in next-token prediction
  • Introducing In-Context Risk Minimization (ICRM), a novel algorithm that learns in-context about environmental features by paying attention to unlabeled instances (context) as they arrive; a minimal sketch follows this list
  • Theoretically proving that such in-context learners can amortize context to zoom in on the empirical risk minimizer of the test environment, achieving competitive out-of-distribution performance
  • Demonstrating that in several settings, ICRM learns invariances in the extended input-context feature space that ERM-based algorithms ignore
  • Empirically demonstrating the efficacy of ICRM and providing extensive ablations that dissect and deepen our understanding of it
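To make the idea concrete, here is a minimal, hedged sketch (not the repository's implementation; the class name, dimensions, and hyperparameters are all illustrative) of a model that attends over unlabeled context examples from the current environment before predicting the label of a query point:

import torch
import torch.nn as nn

class InContextClassifier(nn.Module):
    # Hypothetical in-context learner: a transformer encoder attends over
    # unlabeled context points and a query, then classifies the query.
    def __init__(self, input_dim, num_classes, d_model=64, n_heads=4, n_layers=2):
        super().__init__()
        self.embed = nn.Linear(input_dim, d_model)
        layer = nn.TransformerEncoderLayer(d_model, n_heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(d_model, num_classes)

    def forward(self, context, query):
        # context: (batch, k, input_dim), unlabeled examples from one environment
        # query:   (batch, input_dim), the point to classify
        tokens = torch.cat([context, query.unsqueeze(1)], dim=1)
        hidden = self.encoder(self.embed(tokens))
        return self.head(hidden[:, -1])  # read out the query position

model = InContextClassifier(input_dim=32, num_classes=10)
context = torch.randn(8, 5, 32)   # 5 unlabeled context points per query
query = torch.randn(8, 32)
logits = model(context, query)    # shape: (8, 10)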

ICRM: In-Context Risk Minimization

Prerequisites

The code has the following package dependencies:

  • PyTorch >= 1.13.0 (preferably 2.0.0)
  • torchvision >= 0.12.0

To install all dependencies, create the conda environment from the environment.yml file:

conda env create -f environment.yml

Datasets

FEMNIST

Tiny ImageNet-C

Rotated MNIST is built from MNIST, which can be downloaded through torchvision, while WILDS Camelyon is downloaded automatically through the WILDS package.
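For example (a hedged sketch; the data root directory is arbitrary), the raw sources can be fetched with:

import torchvision
from wilds import get_dataset

torchvision.datasets.MNIST(root="./data", download=True)  # MNIST, the basis for Rotated MNIST
get_dataset(dataset="camelyon17", root_dir="./data", download=True)  # WILDS Camelyon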

Available algorithms

The currently available algorithms include ERM and ICRM, the two used in the examples below.

Available datasets

The currently available datasets are:

  • FEMNIST
  • Rotated MNIST
  • Tiny ImageNet-C
  • WILDS Camelyon

Quick start

Our code builds on DomainBed and adapts most of its functionality, such as sweeping across algorithms and datasets and selecting the top hyperparameters. Download the additional datasets:

python -m download --data_dir=./data

Train a model:

python -m main --data_dir=./data/ --algorithm ICRM --dataset FEMNIST

Launch a sweep:

python -m sweep launch --data_dir=./data --output_dir=./out --command_launcher MyLauncher

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py. Currently, it supports single-node multi-GPU, multi-node multi-GPU, and local launchers. You can also implement a launcher for your specific cluster requirements in command_launchers.py; see the sketch after the command below. By default, the full sweep trains thousands of models (all algorithms x all datasets x 3 independent trials x 10 random hyperparameter choices). To customize the sweep, use a command like the following:

python -m sweep launch\
       --data_dir=./data\
       --output_dir=./out\
       --command_launcher MyLauncher\
       --algorithms ICRM ERM\
       --datasets FEMNIST WILDSCamelyon\
       --n_hparams 5\
       --n_trials 1
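Following DomainBed's convention (a hedged sketch; the function name is illustrative and not part of the repository), a launcher in command_launchers.py is simply a function that receives a list of shell commands and decides how to execute them:

import subprocess

def my_launcher(commands):
    # Run each sweep command sequentially on the local machine.
    for cmd in commands:
        subprocess.run(cmd, shell=True, check=True)

# Assuming a DomainBed-style registry in command_launchers.py:
# REGISTRY['my_launcher'] = my_launcher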

Once all jobs have concluded, whether they succeeded or failed, there are two ways to re-launch the sweep:

  • Launch from scratch: use python -m sweep delete_incomplete to remove the data from failed jobs, and re-launch with python -m sweep launch. Use the same command-line arguments for all sweep calls to maintain consistency with the original jobs; an example follows this list.
  • Resume from where the run stopped: use the exact same command previously used to run the sweep. It automatically resumes all incomplete jobs from their latest checkpoints.
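For example, assuming the same arguments as the sweep launched above:

python -m sweep delete_incomplete\
       --data_dir=./data\
       --output_dir=./out\
       --command_launcher MyLauncher

python -m sweep launch\
       --data_dir=./data\
       --output_dir=./out\
       --command_launcher MyLauncher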

To view the results of your sweep:

python -m collect_results\
       --input_dir=/my/sweep/output/path\
       --mode=<avg or wo>

Here, mode=avg reports average performance metrics, and mode=wo reports worst-group performance.

Results

License

This project is licensed under CC-BY-NC, as described in the LICENSE file.
