
Bugfixes, RSC, cleanup datasets.py
David Lopez-Paz committed Oct 1, 2020
1 parent 6ff150b commit 7df6f06
Showing 7 changed files with 104 additions and 214 deletions.
187 changes: 47 additions & 140 deletions domainbed/algorithms.py
@@ -1,13 +1,12 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import copy

import numpy as np
import torch
import torch.autograd as autograd
import torch.nn as nn
import torch.nn.functional as F
import random
import torch.autograd as autograd

import copy
import numpy as np

from domainbed import networks
from domainbed.lib.misc import random_pairs_of_minibatches
@@ -762,149 +761,57 @@ def update(self, minibatches):
def predict(self, x):
return self.network_c(self.network_f(x))


class RSC(ERM):
"""
Representation Self-Challenging (RSC)
from: https://arxiv.org/pdf/2007.02454.pdf
"""
def __init__(self, input_shape, num_classes, num_domains, hparams):
super(RSC, self).__init__(input_shape, num_classes, num_domains,
hparams)

self.f_drop_factor = hparams['rsc_f_drop_factor']
self.b_drop_factor = hparams['rsc_b_drop_factor']
hparams)
self.drop_f = (1 - hparams['rsc_f_drop_factor']) * 100
self.drop_b = (1 - hparams['rsc_b_drop_factor']) * 100
self.num_classes = num_classes
self.flatten = nn.Flatten()
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.featurizer_name = self.featurizer.__class__.__name__

"""
For ResNet (original paper):
Delete the Average Pooling since this method operates on features after conv layer 4
Dropout is disabled in hyperparameters both for default and random
Flatten from original ResNet is temporarily reverted for computations during update since it is not accessible as a layer
For MNIST_CNN (custom):
Delete the Average Pooling and use features after bn3
SqueezeLastTwo is accessible as a Layer so we can delete it and replace it through Identity, serves as our flatten operation later
Disclaimer: Apparently this algorithm doesn't work very well for the MNIST datasets with the current architecture, probably would require more fine-tuning
"""

if self.featurizer_name == "ResNet":
del self.featurizer.network.avgpool
self.featurizer.network.avgpool = networks.Identity()
elif self.featurizer_name == "MNIST_CNN":
del self.featurizer.avgpool
self.featurizer.avgpool = networks.Identity()
del self.featurizer.squeezeLastTwo
self.featurizer.squeezeLastTwo = networks.Identity()
self.flatten = networks.SqueezeLastTwo()

self.network = nn.Sequential(self.featurizer, self.avgpool, self.flatten, self.classifier)
self.optimizer = torch.optim.Adam(
self.network.parameters(),
lr=self.hparams["lr"],
weight_decay=self.hparams['weight_decay']
)

def create_onehots(self, num_samples, targets):
one_hots = torch.zeros(num_samples, self.num_classes) # By default this sets requires_grad = False
one_hots[range(one_hots.shape[0]), targets] = 1
return one_hots.cuda()

def construct_mask(self, mean, num_samples, size):
num_to_drop = int(size * self.f_drop_factor)
mask_values = torch.sort(mean, dim=1, descending=True)[0][:, num_to_drop]
mask_values = mask_values.view(num_samples, 1).expand(num_samples, size)
mask = torch.where(mean >= mask_values, torch.zeros(mean.shape).cuda(), torch.ones(mean.shape).cuda())
return mask

def update(self, minibatches):
self.network.eval()
features = torch.cat([self.featurizer(xi) for xi, _ in minibatches])

# In ResNet the flatten is not accessible as a layer (it is applied inline in the forward pass), so revert it here
if self.featurizer_name == "ResNet":
features = features.view(features.shape[0], self.featurizer.n_outputs, 7, 7)

targets = torch.cat([yi for _,yi in minibatches])

# detach features to disable gradient flow into the featurizer
features_detached = features.clone().detach()
features_detached.requires_grad = True

# Predict
output = self.latent_predict(features_detached)
num_samples = features_detached.shape[0]
num_channels = features_detached.shape[1]
H = features_detached.shape[2]
HW = features_detached.shape[2] * features_detached.shape[3]

# Create onehot targets for batch
one_hots = self.create_onehots(num_samples, targets)

# calculate dot product elem_wise between output and one-hot target
elem_wise = torch.sum(output * one_hots)

# calculate gradient of elem_wise wrt feature_detached
self.optimizer.zero_grad()
elem_wise.backward()

# Channel & Spatial means
grad_val = features_detached.grad.clone().detach()
channel_mean = torch.mean(grad_val.view(num_samples, num_channels, -1), dim=2)
spatial_mean = torch.mean(grad_val, dim=1).view(num_samples, HW)
self.optimizer.zero_grad()

# Choose either spatial or channel. Spatial doesn't make sense for architectures which do avgpooling afterwards, but it would make sense for e.g. AlexNet
if self.featurizer_name in ["ResNet", "MNIST_CNN"]:
choose_spatial = False
else:
choose_spatial = random.choice([True, False])

if choose_spatial:
mask = self.construct_mask(spatial_mean, num_samples, HW)
mask = mask.view(num_samples, 1, H, H)

else:
channel_vector = self.construct_mask(channel_mean, num_samples, num_channels)
mask = channel_vector.view(num_samples, num_channels, 1, 1)

# Drop the mask for certain samples inside each batch
class_prob_before = F.softmax(output, dim=1)
features_detached_masked = features_detached * mask
output_masked = self.latent_predict(features_detached_masked)
class_prob_after = F.softmax(output_masked, dim=1)

before_vector = torch.sum(one_hots * class_prob_before, dim=1)
after_vector = torch.sum(one_hots * class_prob_after, dim=1)
change = before_vector - after_vector
change = torch.where(change > 0, change, torch.zeros(change.shape).cuda())
threshold_value = torch.sort(change, dim=0, descending=True)[0][int(num_samples * self.b_drop_factor)]
drop_indeces = change.gt(threshold_value)
mask[~drop_indeces.long(), :] = 1

# Apply the mask on the non-detached features
self.train()
mask.requires_grad = True
features = features * mask

# calculate loss between prediction and target
output_final = self.latent_predict(features)
loss = F.cross_entropy(output_final, targets)
# inputs
all_x = torch.cat([x for x, y in minibatches])
# labels
all_y = torch.cat([y for _, y in minibatches])
# one-hot labels
all_o = torch.nn.functional.one_hot(all_y, self.num_classes)
# features
all_f = self.featurizer(all_x)
# predictions
all_p = self.classifier(all_f)

# Equation (1): compute gradients with respect to representation
all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

# Equation (2): compute top-gradient-percentile mask
percentiles = np.percentile(all_g.cpu(), self.drop_f, axis=1)
percentiles = torch.Tensor(percentiles)
percentiles = percentiles.unsqueeze(1).repeat(1, all_g.size(1))
mask_f = all_g.lt(percentiles.cuda()).float()

# Equation (3): mute top-gradient-percentile activations
all_f_muted = all_f * mask_f

# Equation (4): compute muted predictions
all_p_muted = self.classifier(all_f_muted)

# Section 3.3: Batch Percentage
all_s = F.softmax(all_p, dim=1)
all_s_muted = F.softmax(all_p_muted, dim=1)
changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
percentile = np.percentile(changes.detach().cpu(), self.drop_b)
mask_b = changes.lt(percentile).float().view(-1, 1)
mask = torch.logical_or(mask_f, mask_b).float()

# Equations (3) and (4) again, this time muting over examples
all_p_muted_again = self.classifier(all_f * mask)

# Equation (5): update
loss = F.cross_entropy(all_p_muted_again, all_y)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()

return {'loss': loss.item()}

def latent_predict(self, x):
x = self.avgpool(x)
x = self.flatten(x)
return self.classifier(x)

def predict(self, x):
features = self.featurizer(x)
if self.featurizer_name == "ResNet":
features = features.view(features.shape[0], self.featurizer.n_outputs, 7, 7)
return self.latent_predict(features)
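
For reference, the snippet below is a standalone sketch of the percentile-based masking that the new update implements (the steps labelled Equations (1)-(5) in the comments above). It is illustrative only and not part of the commit: the Linear featurizer and classifier, the batch shapes, and the 1/3 drop factor are assumptions chosen so it runs on CPU, and the .cuda() calls from the actual code are omitted.

# Illustrative sketch (not from the commit): RSC-style percentile masking on toy modules.
import numpy as np
import torch
import torch.autograd as autograd
import torch.nn.functional as F

torch.manual_seed(0)
num_samples, in_dim, feat_dim, num_classes = 8, 32, 16, 3
drop_f = (1 - 1/3) * 100   # keep-percentile for features (assumed drop factor 1/3)
drop_b = (1 - 1/3) * 100   # keep-percentile for batch examples (assumed drop factor 1/3)

featurizer = torch.nn.Linear(in_dim, feat_dim)        # stand-in for self.featurizer
classifier = torch.nn.Linear(feat_dim, num_classes)   # stand-in for self.classifier

all_x = torch.randn(num_samples, in_dim)              # inputs
all_y = torch.randint(num_classes, (num_samples,))    # labels
all_o = F.one_hot(all_y, num_classes)                 # one-hot labels
all_f = featurizer(all_x)                             # features
all_p = classifier(all_f)                             # predictions

# Equation (1): gradient of the true-class logits w.r.t. the representation
all_g = autograd.grad((all_p * all_o).sum(), all_f)[0]

# Equation (2): per-sample mask that zeroes the top-gradient features
percentiles = torch.Tensor(np.percentile(all_g.cpu(), drop_f, axis=1))
mask_f = all_g.lt(percentiles.unsqueeze(1)).float()

# Equations (3)-(4): predictions from the muted representation
all_p_muted = classifier(all_f * mask_f)

# Section 3.3: keep the feature mask only for the examples whose true-class
# probability dropped the most; the rest fall back to an all-ones mask
all_s = F.softmax(all_p, dim=1)
all_s_muted = F.softmax(all_p_muted, dim=1)
changes = (all_s * all_o).sum(1) - (all_s_muted * all_o).sum(1)
mask_b = changes.lt(np.percentile(changes.detach().cpu(), drop_b)).float().view(-1, 1)
mask = torch.logical_or(mask_f, mask_b).float()

# Equation (5): cross-entropy on the muted predictions drives the update
loss = F.cross_entropy(classifier(all_f * mask), all_y)
loss.backward()
print(round(loss.item(), 4))

Because both masks come from .lt() comparisons, no gradient flows through the masking itself; the featurizer and classifier receive gradients only through the final forward pass on all_f * mask, i.e. the step the commit labels Equation (5).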