diff --git a/domainbed/algorithms.py b/domainbed/algorithms.py
index 35b0e721..a19adc71 100644
--- a/domainbed/algorithms.py
+++ b/domainbed/algorithms.py
@@ -1,13 +1,12 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
-import copy
-
-import numpy as np
 import torch
-import torch.autograd as autograd
 import torch.nn as nn
 import torch.nn.functional as F
-import random
+import torch.autograd as autograd
+
+import copy
+import numpy as np
 
 from domainbed import networks
 from domainbed.lib.misc import random_pairs_of_minibatches
@@ -762,149 +761,57 @@ def update(self, minibatches):
 
     def predict(self, x):
         return self.network_c(self.network_f(x))
 
+
 class RSC(ERM):
-    """
-    Representation Self-Challenging (RSC)
-    from: https://arxiv.org/pdf/2007.02454.pdf
-    """
     def __init__(self, input_shape, num_classes, num_domains, hparams):
         super(RSC, self).__init__(input_shape, num_classes, num_domains,
-                                  hparams)
-
-        self.f_drop_factor = hparams['rsc_f_drop_factor']
-        self.b_drop_factor = hparams['rsc_b_drop_factor']
+                                   hparams)
+        self.drop_f = (1 - hparams['rsc_f_drop_factor']) * 100
+        self.drop_b = (1 - hparams['rsc_b_drop_factor']) * 100
         self.num_classes = num_classes
-        self.flatten = nn.Flatten()
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.featurizer_name = self.featurizer.__class__.__name__
-
-        """
-        For ResNet (original paper):
-        Delete the Average Pooling since this method operates on features after conv layer 4
-        Dropout is disabled in hyperparameters both for default and random
-        Flatten from original ResNet is temporarily reverted for computations during update since it is not accessible as a layer
-        For MNIST_CNN (custom):
-        Delete the Average Pooling and use features after bn3
-        SqueezeLastTwo is accessible as a Layer so we can delete it and replace it through Identity, serves as our flatten operation later
-
-        Disclaimer: Apparently this algorithm doesn't work very well for the MNIST datasets with the current architecture, probably would require more fine-tuning
-        """
-
-        if self.featurizer_name == "ResNet":
-            del self.featurizer.network.avgpool
-            self.featurizer.network.avgpool = networks.Identity()
-        elif self.featurizer_name == "MNIST_CNN":
-            del self.featurizer.avgpool
-            self.featurizer.avgpool = networks.Identity()
-            del self.featurizer.squeezeLastTwo
-            self.featurizer.squeezeLastTwo = networks.Identity()
-            self.flatten = networks.SqueezeLastTwo()
-
-        self.network = nn.Sequential(self.featurizer, self.avgpool, self.flatten, self.classifier)
-        self.optimizer = torch.optim.Adam(
-            self.network.parameters(),
-            lr=self.hparams["lr"],
-            weight_decay=self.hparams['weight_decay']
-        )
-
-    def create_onehots(self, num_samples, targets):
-        one_hots = torch.zeros(num_samples, self.num_classes) # By default this sets requires_grad = False
-        one_hots[range(one_hots.shape[0]), targets] = 1
-        return one_hots.cuda()
-
-    def construct_mask(self, mean, num_samples, size):
-        num_to_drop = int(size * self.f_drop_factor)
-        mask_values = torch.sort(mean, dim=1, descending=True)[0][:, num_to_drop]
-        mask_values = mask_values.view(num_samples, 1).expand(num_samples, size)
-        mask = torch.where(mean >= mask_values, torch.zeros(mean.shape).cuda(), torch.ones(mean.shape).cuda())
-        return mask
 
     def update(self, minibatches):
-        self.network.eval()
-        features = torch.cat([self.featurizer(xi) for xi, _ in minibatches])
-
-        # In ResNet we don't have access to the Flatten as a Layer as it is used inplace in the forward pass, hence revert it here
-        if self.featurizer_name == "ResNet":
-            features = features.view(features.shape[0], self.featurizer.n_outputs, 7, 7)
-
-        targets = torch.cat([yi for _,yi in minibatches])
-
-        # detach features to disable gradient flow into the featurizer
-        features_detached = features.clone().detach()
-        features_detached.requires_grad = True
-
-        # Predict
-        output = self.latent_predict(features_detached)
-        num_samples = features_detached.shape[0]
-        num_channels = features_detached.shape[1]
-        H = features_detached.shape[2]
-        HW = features_detached.shape[2] * features_detached.shape[3]
-
-        # Create onehot targets for batch
-        one_hots = self.create_onehots(num_samples, targets)
-
-        # calculate dot product elem_wise between output and one-hot target
-        elem_wise = torch.sum(output * one_hots)
-
-        # calculate gradient of elem_wise wrt feature_detached
-        self.optimizer.zero_grad()
-        elem_wise.backward()
-
-        # Channel & Spatial means
-        grad_val = features_detached.grad.clone().detach()
-        channel_mean = torch.mean(grad_val.view(num_samples, num_channels, -1), dim=2)
-        spatial_mean = torch.mean(grad_val, dim=1).view(num_samples, HW)
-        self.optimizer.zero_grad()
-
-        # Choose either spatial or channel. Spatial doesn't make sense for architectures which do avgpooling afterwards, but it would make sense for e.g. AlexNet
-        if self.featurizer_name in ["ResNet", "MNIST_CNN"]:
-            choose_spatial = False
-        else:
-            choose_spatial = random.choice([True, False])
-
-        if choose_spatial:
-            mask = self.construct_mask(spatial_mean, num_samples, HW)
-            mask = mask.view(num_samples, 1, H, H)
-
-        else:
-            channel_vector = self.construct_mask(channel_mean, num_samples, num_channels)
-            mask = channel_vector.view(num_samples, num_channels, 1, 1)
-
-        # Drop the mask for certain samples inside each batch
-        class_prob_before = F.softmax(output, dim=1)
-        features_detached_masked = features_detached * mask
-        output_masked = self.latent_predict(features_detached_masked)
-        class_prob_after = F.softmax(output_masked, dim=1)
-
-        before_vector = torch.sum(one_hots * class_prob_before, dim=1)
-        after_vector = torch.sum(one_hots * class_prob_after, dim=1)
-        change = before_vector - after_vector
-        change = torch.where(change > 0, change, torch.zeros(change.shape).cuda())
-        threshold_value = torch.sort(change, dim=0, descending=True)[0][int(num_samples * self.b_drop_factor)]
-        drop_indeces = change.gt(threshold_value)
-        mask[~drop_indeces.long(), :] = 1
-
-        # Apply the mask on the non-detached features
-        self.train()
-        mask.requires_grad = True
-        features = features * mask
-
-        # calculate loss between prediction and target
-        output_final = self.latent_predict(features)
-        loss = F.cross_entropy(output_final, targets)
+        # inputs
+        all_x = torch.cat([x for x, y in minibatches])
+        # labels
+        all_y = torch.cat([y for _, y in minibatches])
+        # one-hot labels
+        all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
+        # features
+        all_f = self.featurizer(all_x)
+        # predictions
+        all_p = self.classifier(all_f)
+
+        # Equation (1): compute gradients with respect to representation
+        all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]
+
+        # Equation (2): compute top-gradient-percentile mask
+        percentiles = np.percentile(all_g.cpu(), self.drop_f, axis=1)
+        percentiles = torch.Tensor(percentiles)
+        percentiles = percentiles.unsqueeze(1).repeat(1, all_g.size(1))
+        mask_f = all_g.lt(percentiles.cuda()).float()
+
+        # Equation (3): mute top-gradient-percentile activations
+        all_f_muted = all_f * mask_f
+
+        # Equation (4): compute muted predictions
+        all_p_muted = self.classifier(all_f_muted)
+
+        # Section 3.3: Batch Percentage
+        all_s = F.softmax(all_p, dim=1)
+        all_s_muted = F.softmax(all_p_muted, dim=1)
+        changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
+        percentile = np.percentile(changes.detach().cpu(), self.drop_b)
+        mask_b = changes.lt(percentile).float().view(-1, 1)
+        mask = torch.logical_or(mask_f, mask_b).float()
+
+        # Equations (3) and (4) again, this time muting over examples
+        all_p_muted_again = self.classifier(all_f * mask)
+
+        # Equation (5): update
+        loss = F.cross_entropy(all_p_muted_again, all_y)
         self.optimizer.zero_grad()
         loss.backward()
         self.optimizer.step()
 
         return {'loss': loss.item()}
-
-    def latent_predict(self, x):
-        x = self.avgpool(x)
-        x = self.flatten(x)
-        return self.classifier(x)
-
-    def predict(self, x):
-        features = self.featurizer(x)
-        if self.featurizer_name == "ResNet":
-            features = features.view(features.shape[0], self.featurizer.n_outputs, 7, 7)
-        return self.latent_predict(features)
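Note on the rewritten RSC.update() above: the masking logic can be exercised in isolation. Below is a minimal sketch on CPU with stand-in linear modules; the module names, dimensions, and the CPU setting are placeholders for illustration, not DomainBed's actual featurizer or classifier (the patched code itself assumes CUDA).

    import numpy as np
    import torch
    import torch.nn.functional as F
    import torch.autograd as autograd

    torch.manual_seed(0)
    n, in_dim, feat_dim, num_classes = 8, 32, 16, 4
    drop_f = (1 - 1/3) * 100  # default rsc_f_drop_factor = 1/3
    drop_b = (1 - 1/3) * 100  # default rsc_b_drop_factor = 1/3

    featurizer = torch.nn.Linear(in_dim, feat_dim)
    classifier = torch.nn.Linear(feat_dim, num_classes)

    all_x = torch.randn(n, in_dim)
    all_y = torch.randint(0, num_classes, (n,))
    all_o = F.one_hot(all_y, num_classes)
    all_f = featurizer(all_x)
    all_p = classifier(all_f)

    # Equation (1): gradient of the true-class logits w.r.t. the features
    all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

    # Equation (2): per sample, zero the top third of gradient entries
    cut = torch.Tensor(np.percentile(all_g, drop_f, axis=1)).unsqueeze(1)
    mask_f = all_g.lt(cut).float()

    # Equations (3)-(4): how much does muting hurt the true-class probability?
    p_true = (F.softmax(all_p, dim=1) * all_o).sum(1)
    p_true_muted = (F.softmax(classifier(all_f * mask_f), dim=1) * all_o).sum(1)
    changes = p_true - p_true_muted

    # Section 3.3: keep the feature mask only for the most-affected samples
    mask_b = changes.lt(np.percentile(changes.detach(), drop_b)).float()
    mask = torch.logical_or(mask_f, mask_b.view(-1, 1)).float()

    # Equation (5): backpropagate through the masked features
    loss = F.cross_entropy(classifier(all_f * mask), all_y)
    loss.backward()

No retain_graph is needed in the grad() call: the gradient is taken only with respect to all_f, so the featurizer's part of the graph is left intact for the final backward(), and the muted predictions are produced by fresh classifier forward passes.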
diff --git a/domainbed/datasets.py b/domainbed/datasets.py
index cc810a6f..ce38170d 100644
--- a/domainbed/datasets.py
+++ b/domainbed/datasets.py
@@ -1,19 +1,13 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
 
 import os
-from collections import defaultdict
-
-import PIL
 import torch
 from PIL import Image, ImageFile
-from torch.utils.data import TensorDataset
 from torchvision import transforms
 import torchvision.datasets.folder
-from torchvision.datasets import CIFAR100, MNIST, ImageFolder
+from torch.utils.data import TensorDataset
+from torchvision.datasets import MNIST, ImageFolder
 from torchvision.transforms.functional import rotate
-import tqdm
-import io
-import functools
 
 ImageFile.LOAD_TRUNCATED_IMAGES = True
 
@@ -22,8 +16,8 @@
     "Debug28",
     "Debug224",
     # Small images
-    "RotatedMNIST",
     "ColoredMNIST",
+    "RotatedMNIST",
     # Big images
     "VLCS",
     "PACS",
@@ -32,21 +26,6 @@
     "DomainNet",
 ]
 
-NUM_ENVIRONMENTS = {
-    # Debug
-    "Debug28": 3,
-    "Debug224": 3,
-    # Small images
-    "RotatedMNIST": 6,
-    "ColoredMNIST": 3,
-    # Big images
-    "VLCS": 4,
-    "PACS": 4,
-    "OfficeHome": 4,
-    "TerraIncognita": 4,
-    "DomainNet": 6,
-}
-
 
 def get_dataset_class(dataset_name):
     """Return the dataset class with the given name."""
@@ -54,40 +33,46 @@
         raise NotImplementedError("Dataset not found: {}".format(dataset_name))
     return globals()[dataset_name]
 
+
+def num_environments(dataset_name):
+    return len(get_dataset_class(dataset_name).ENVIRONMENTS)
+
+
 class MultipleDomainDataset:
-    N_STEPS = 5001
-    CHECKPOINT_FREQ = 100
-    N_WORKERS = 8
+    N_STEPS = 5001           # Default, subclasses may override
+    CHECKPOINT_FREQ = 100    # Default, subclasses may override
+    N_WORKERS = 8            # Default, subclasses may override
+    ENVIRONMENTS = None      # Subclasses should override
+    INPUT_SHAPE = None       # Subclasses should override
+
+    def __getitem__(self, index):
+        return self.datasets[index]
+
+    def __len__(self):
+        return len(self.datasets)
+
 
 class Debug(MultipleDomainDataset):
-    DATASET_SIZE = 16
-    INPUT_SHAPE = None # Subclasses should override
     def __init__(self, root, test_envs, hparams):
         super().__init__()
         self.input_shape = self.INPUT_SHAPE
         self.num_classes = 2
-        self.environments = [0, 1, 2]
         self.datasets = []
-        for _ in range(len(self.environments)):
+        for _ in [0, 1, 2]:
            self.datasets.append(
                 TensorDataset(
-                    torch.randn(self.DATASET_SIZE, *self.INPUT_SHAPE),
-                    torch.randint(0, self.num_classes, (self.DATASET_SIZE,))
+                    torch.randn(16, *self.INPUT_SHAPE),
+                    torch.randint(0, self.num_classes, (16,))
                 )
             )
 
-    def __getitem__(self, index):
-        return self.datasets[index]
-
-    def __len__(self):
-        return len(self.datasets)
 
 class Debug28(Debug):
     INPUT_SHAPE = (3, 28, 28)
-    ENVIRONMENT_NAMES = ['0', '1', '2']
+    ENVIRONMENTS = ['0', '1', '2']
 
 class Debug224(Debug):
     INPUT_SHAPE = (3, 224, 224)
-    ENVIRONMENT_NAMES = ['0', '1', '2']
+    ENVIRONMENTS = ['0', '1', '2']
 
 class MultipleEnvironmentMNIST(MultipleDomainDataset):
@@ -112,25 +97,18 @@ def __init__(self, root, environments, dataset_transform, input_shape,
         original_labels = original_labels[shuffle]
 
         self.datasets = []
-        self.environments = environments
 
-        for i in range(len(self.environments)):
-            images = original_images[i::len(self.environments)]
-            labels = original_labels[i::len(self.environments)]
+        for i in range(len(environments)):
+            images = original_images[i::len(environments)]
+            labels = original_labels[i::len(environments)]
             self.datasets.append(dataset_transform(images, labels, environments[i]))
 
         self.input_shape = input_shape
         self.num_classes = num_classes
 
-    def __getitem__(self, index):
-        return self.datasets[index]
-
-    def __len__(self):
-        return len(self.datasets)
-
 
 class ColoredMNIST(MultipleEnvironmentMNIST):
-    ENVIRONMENT_NAMES = ['+90%', '+80%', '-90%']
+    ENVIRONMENTS = ['+90%', '+80%', '-90%']
 
     def __init__(self, root, test_envs, hparams):
         super(ColoredMNIST, self).__init__(root, [0.1, 0.2, 0.9],
@@ -170,7 +148,7 @@ def torch_xor_(self, a, b):
 
 
 class RotatedMNIST(MultipleEnvironmentMNIST):
-    ENVIRONMENT_NAMES = ['0', '15', '30', '45', '60', '75']
+    ENVIRONMENTS = ['0', '15', '30', '45', '60', '75']
 
     def __init__(self, root, test_envs, hparams):
         super(RotatedMNIST, self).__init__(root, [0, 15, 30, 45, 60, 75],
@@ -180,7 +158,7 @@ def rotate_dataset(self, images, labels, angle):
         rotation = transforms.Compose([
             transforms.ToPILImage(),
             transforms.Lambda(lambda x: rotate(x, angle, fill=(0,),
-                resample=PIL.Image.BICUBIC)),
+                resample=Image.BICUBIC)),
             transforms.ToTensor()])
 
         x = torch.zeros(len(images), 1, 28, 28)
@@ -194,8 +172,8 @@
 class MultipleEnvironmentImageFolder(MultipleDomainDataset):
     def __init__(self, root, test_envs, augment, hparams):
         super().__init__()
-        self.environments = [f.name for f in os.scandir(root) if f.is_dir()]
-        self.environments = sorted(self.environments)
+        environments = [f.name for f in os.scandir(root) if f.is_dir()]
+        environments = sorted(environments)
 
         transform = transforms.Compose([
             transforms.Resize((224,224)),
@@ -216,7 +194,7 @@ def __init__(self, root, test_envs, augment, hparams):
         ])
 
         self.datasets = []
-        for i, environment in enumerate(self.environments):
+        for i, environment in enumerate(environments):
 
             if augment and (i not in test_envs):
                 env_transform = augment_transform
@@ -232,43 +210,37 @@ def __init__(self, root, test_envs, augment, hparams):
         self.input_shape = (3, 224, 224,)
         self.num_classes = len(self.datasets[-1].classes)
 
-    def __getitem__(self, index):
-        return self.datasets[index]
-
-    def __len__(self):
-        return len(self.datasets)
-
 class VLCS(MultipleEnvironmentImageFolder):
     CHECKPOINT_FREQ = 300
-    ENVIRONMENT_NAMES = ["C", "L", "S", "V"]
+    ENVIRONMENTS = ["C", "L", "S", "V"]
     def __init__(self, root, test_envs, hparams):
         self.dir = os.path.join(root, "VLCS/")
         super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
 
 class PACS(MultipleEnvironmentImageFolder):
     CHECKPOINT_FREQ = 300
-    ENVIRONMENT_NAMES = ["A", "C", "P", "S"]
+    ENVIRONMENTS = ["A", "C", "P", "S"]
     def __init__(self, root, test_envs, hparams):
         self.dir = os.path.join(root, "PACS/")
         super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
 
 class DomainNet(MultipleEnvironmentImageFolder):
     CHECKPOINT_FREQ = 1000
-    ENVIRONMENT_NAMES = ["clip", "info", "paint", "quick", "real", "sketch"]
+    ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"]
     def __init__(self, root, test_envs, hparams):
         self.dir = os.path.join(root, "domain_net/")
         super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
 
 class OfficeHome(MultipleEnvironmentImageFolder):
     CHECKPOINT_FREQ = 300
-    ENVIRONMENT_NAMES = ["A", "C", "P", "R"]
+    ENVIRONMENTS = ["A", "C", "P", "R"]
     def __init__(self, root, test_envs, hparams):
         self.dir = os.path.join(root, "office_home/")
         super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
 
 class TerraIncognita(MultipleEnvironmentImageFolder):
     CHECKPOINT_FREQ = 300
-    ENVIRONMENT_NAMES = ["L100", "L38", "L43", "L46"]
+    ENVIRONMENTS = ["L100", "L38", "L43", "L46"]
     def __init__(self, root, test_envs, hparams):
         self.dir = os.path.join(root, "terra_incognita/")
         super().__init__(self.dir, test_envs, hparams['data_augmentation'], hparams)
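The ENVIRONMENTS class attribute is now the single source of truth that the deleted NUM_ENVIRONMENTS table used to duplicate. An illustrative check (attribute access only, so no dataset download is needed; shown for RotatedMNIST but the same holds for every dataset class above):

    from domainbed import datasets

    datasets.RotatedMNIST.ENVIRONMENTS         # ['0', '15', '30', '45', '60', '75']
    datasets.num_environments("RotatedMNIST")  # 6: the length of the list above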
diff --git a/domainbed/hparams_registry.py b/domainbed/hparams_registry.py
index 4cd01f04..59ace4eb 100644
--- a/domainbed/hparams_registry.py
+++ b/domainbed/hparams_registry.py
@@ -55,8 +55,8 @@ def _hparams(algorithm, dataset, random_state):
         hparams['mlp_depth'] = (3, int(random_state.choice([3, 4, 5])))
         hparams['mlp_dropout'] = (0., random_state.choice([0., 0.1, 0.5]))
     elif algorithm == "RSC":
-        hparams['rsc_f_drop_factor'] = (1/3, random_state.uniform(0,0.5)) # Feature drop factor
-        hparams['rsc_b_drop_factor'] = (1/3, random_state.uniform(0, 0.5)) # Batch drop factor
+        hparams['rsc_f_drop_factor'] = (1/3, random_state.uniform(0, 0.5))
+        hparams['rsc_b_drop_factor'] = (1/3, random_state.uniform(0, 0.5))
     elif algorithm == "SagNet":
         hparams['sag_w_adv'] = (0.1, 10**random_state.uniform(-2, 1))
     elif algorithm == "IRM":
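Both drop factors stay in [0, 0.5]. The rewritten RSC.__init__ converts each factor into the percentile threshold handed to np.percentile; worked out for the default value of 1/3 (arithmetic shown purely for illustration):

    drop_f = (1 - 1/3) * 100  # = 66.67: per sample, gradient entries above the
                              # 66.67th percentile (the top third) are masked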
diff --git a/domainbed/scripts/save_images.py b/domainbed/scripts/save_images.py
index 3a69cb80..2097173b 100644
--- a/domainbed/scripts/save_images.py
+++ b/domainbed/scripts/save_images.py
@@ -25,7 +25,7 @@
     hparams = hparams_registry.default_hparams('ERM', dataset_name)
     dataset = datasets.get_dataset_class(dataset_name)(
         args.data_dir,
-        list(range(datasets.NUM_ENVIRONMENTS[dataset_name])),
+        list(range(datasets.num_environments(dataset_name))),
         hparams)
     for env_idx, env in enumerate(tqdm(dataset)):
         for i in tqdm(range(50)):
@@ -46,5 +46,5 @@
         x = x.numpy().astype('uint8').transpose(1,2,0)
         imageio.imwrite(
             os.path.join(args.output_dir,
-                f'{dataset_name}_env{env_idx}{dataset.environments[env_idx]}_{i}_idx{idx}_class{y}.png'),
+                f'{dataset_name}_env{env_idx}{dataset.ENVIRONMENTS[env_idx]}_{i}_idx{idx}_class{y}.png'),
             x)
diff --git a/domainbed/scripts/sweep.py b/domainbed/scripts/sweep.py
index cedc827f..ebf980d9 100644
--- a/domainbed/scripts/sweep.py
+++ b/domainbed/scripts/sweep.py
@@ -100,7 +100,7 @@ def make_args_list(n_trials, dataset_names, algorithms, n_hparams, steps,
     for dataset in dataset_names:
         for algorithm in algorithms:
             all_test_envs = all_test_env_combinations(
-                datasets.NUM_ENVIRONMENTS[dataset])
+                datasets.num_environments(dataset))
             for test_envs in all_test_envs:
                 for hparams_seed in range(n_hparams):
                     train_args = {}
diff --git a/domainbed/scripts/train.py b/domainbed/scripts/train.py
index 9ef9954e..6cc09738 100644
--- a/domainbed/scripts/train.py
+++ b/domainbed/scripts/train.py
@@ -10,7 +10,9 @@
 import uuid
 
 import numpy as np
+import PIL
 import torch
+import torchvision
 import torch.utils.data
 
 from domainbed import datasets
@@ -52,6 +54,15 @@
     sys.stdout = misc.Tee(os.path.join(args.output_dir, 'out.txt'))
     sys.stderr = misc.Tee(os.path.join(args.output_dir, 'err.txt'))
 
+    print("Environment:")
+    print("\tPython: {}".format(sys.version.split(" ")[0]))
+    print("\tPyTorch: {}".format(torch.__version__))
+    print("\tTorchvision: {}".format(torchvision.__version__))
+    print("\tCUDA: {}".format(torch.version.cuda))
+    print("\tCUDNN: {}".format(torch.backends.cudnn.version()))
+    print("\tNumPy: {}".format(np.__version__))
+    print("\tPIL: {}".format(PIL.__version__))
+
     print('Args:')
     for k, v in sorted(vars(args).items()):
         print('\t{}: {}'.format(k, v))
diff --git a/domainbed/test/test_datasets.py b/domainbed/test/test_datasets.py
index 8198d340..8ed8a164 100644
--- a/domainbed/test/test_datasets.py
+++ b/domainbed/test/test_datasets.py
@@ -32,13 +32,13 @@ def test_dataset_erm(self, dataset_name):
         """
         Test that ERM can complete one step on a given dataset without raising
         an error.
-        Also test that NUM_ENVIRONMENTS[dataset] is set correctly.
+        Also test that num_environments() works correctly.
         """
         batch_size = 8
         hparams = hparams_registry.default_hparams('ERM', dataset_name)
         dataset = datasets.get_dataset_class(dataset_name)(
             os.environ['DATA_DIR'], [], hparams)
-        self.assertEqual(datasets.NUM_ENVIRONMENTS[dataset_name],
+        self.assertEqual(datasets.num_environments(dataset_name),
             len(dataset))
         algorithm = algorithms.get_algorithm_class('ERM')(
             dataset.input_shape,
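A quick smoke test tying the patch together, in the spirit of test_dataset_erm above. This sketch is not part of the patch: it assumes a CUDA device (the new update() moves its masks with .cuda()) and uses Debug28, which requires no downloaded data.

    import torch
    from domainbed import algorithms, datasets, hparams_registry

    hparams = hparams_registry.default_hparams('RSC', 'Debug28')
    dataset = datasets.Debug28(None, [0], hparams)   # root is unused by Debug
    assert len(dataset) == datasets.num_environments('Debug28')

    algorithm = algorithms.get_algorithm_class('RSC')(
        dataset.input_shape, dataset.num_classes,
        len(dataset) - 1, hparams).cuda()

    x, y = dataset[1][:8]                            # a batch from environment '1'
    print(algorithm.update([(x.cuda(), y.cuda())]))  # {'loss': ...}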