[WIP] Allow autocast for 1.6 (pytorch#2384)
* Fixes Xiao's repro

* Ports nms to use full dispatcher

* Move HIPGuard to nms_cuda

* clang-format

* run models in test_models.py on GPU if available

* Francisco's comment, also disable cuda model tests to see if CPU alone still passes

* cuda tests now pass locally, although still not comparing to saved numerics

* add note for thing to ask francisco

* Allow cuda and cpu tests to share a data file

* ignore suffix if unneeded

* Skip autocast numerics checks for a few models

* Add roi_align test

Co-authored-by: Michael Carilli <mcarilli@nvidia.com>
2 people authored and fmassa committed Jul 9, 2020
1 parent a0e29a4 commit fc7d027
Showing 12 changed files with 314 additions and 125 deletions.
16 changes: 12 additions & 4 deletions test/common_utils.py
@@ -85,7 +85,7 @@ def is_iterable(obj):
class TestCase(unittest.TestCase):
precision = 1e-5

def assertExpected(self, output, subname=None, prec=None):
def assertExpected(self, output, subname=None, prec=None, strip_suffix=None):
r"""
Test that a python value matches the recorded contents of a file
derived from the name of this test and subname. The value must be
@@ -96,16 +96,24 @@ def assertExpected(self, output, subname=None, prec=None):
If you call this multiple times in a single function, you must
give a unique subname each time.
strip_suffix allows different tests that expect similar numerics, e.g.
"test_xyz_cuda" and "test_xyz_cpu", to use the same pickled data.
test_xyz_cuda would pass strip_suffix="_cuda", test_xyz_cpu would pass
strip_suffix="_cpu", and they would both use a data file name based on
"test_xyz".
"""
def remove_prefix(text, prefix):
def remove_prefix_suffix(text, prefix, suffix):
if text.startswith(prefix):
return text[len(prefix):]
text = text[len(prefix):]
if suffix is not None and text.endswith(suffix):
text = text[:len(text) - len(suffix)]
return text
# NB: we take __file__ from the module that defined the test
# class, so we place the expect directory where the test script
# lives, NOT where test/common_utils.py lives.
module_id = self.__class__.__module__
munged_id = remove_prefix(self.id(), module_id + ".")
munged_id = remove_prefix_suffix(self.id(), module_id + ".", strip_suffix)
test_file = os.path.realpath(sys.modules[module_id].__file__)
expected_file = os.path.join(os.path.dirname(test_file),
"expect",
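As an editorial aside, here is a minimal sketch (not part of this commit) of how a device-suffixed test could use the new strip_suffix argument so its CPU and CUDA variants read the same expect file; the test class, model, and file layout below are illustrative assumptions:

```python
import unittest

import torch
import torchvision.models as models
from common_utils import TestCase  # the TestCase patched in this diff

class ExampleTester(TestCase):
    def _check(self, dev):
        torch.manual_seed(0)
        model = models.resnet18(num_classes=50).eval().to(dev)
        # Sample on CPU first so CPU and CUDA runs see identical inputs.
        x = torch.rand(1, 3, 224, 224).to(dev)
        out = model(x)
        # "test_example_cpu" and "test_example_cuda" both strip their device
        # suffix, so they resolve to one expect file keyed on
        # "ExampleTester.test_example".
        self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)

    def test_example_cpu(self):
        self._check("cpu")

    @unittest.skipIf(not torch.cuda.is_available(), "needs GPU")
    def test_example_cuda(self):
        self._check("cuda")
```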
206 changes: 138 additions & 68 deletions test/test_models.py
@@ -74,72 +74,114 @@ def get_available_video_models():
}


# The following models exhibit flaky numerics under autocast in _test_*_model harnesses.
# This may be caused by the harness environment (e.g. num classes, input initialization
# via torch.rand), and does not prove autocast is unsuitable when training with real data
# (autocast has been used successfully with real data for some of these models).
# TODO: investigate why autocast numerics are flaky in the harnesses.
#
# For the following models, _test_*_model harnesses skip numerical checks on outputs when
# trying autocast. However, they still try an autocasted forward pass, so they still ensure
# autocast coverage suffices to prevent dtype errors in each model.
autocast_flaky_numerics = (
"fasterrcnn_resnet50_fpn",
"inception_v3",
"keypointrcnn_resnet50_fpn",
"maskrcnn_resnet50_fpn",
"resnet101",
"resnet152",
"wide_resnet101_2",
)


class ModelTester(TestCase):
def checkModule(self, model, name, args):
if name not in script_test_models:
return
unwrapper = script_test_models[name].get('unwrapper', None)
return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)

def _test_classification_model(self, name, input_shape):
def _test_classification_model(self, name, input_shape, dev):
set_rng_seed(0)
# passing num_classes other than the default makes the test stricter
model = models.__dict__[name](num_classes=50)
model.eval()
x = torch.rand(input_shape)
model.eval().to(device=dev)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertExpected(out, prec=0.1)
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
self.assertEqual(out.shape[-1], 50)
self.checkModule(model, name, (x,))

def _test_segmentation_model(self, name):
if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(x)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
self.assertEqual(out.shape[-1], 50)

def _test_segmentation_model(self, name, dev):
# passing num_classes other than the default makes the test stricter
model = models.segmentation.__dict__[name](num_classes=50, pretrained_backbone=False)
model.eval()
model.eval().to(device=dev)
input_shape = (1, 3, 300, 300)
x = torch.rand(input_shape)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
self.checkModule(model, name, (x,))

def _test_detection_model(self, name):
if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(x)
self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))

def _test_detection_model(self, name, dev):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
model.eval()
model.eval().to(device=dev)
input_shape = (3, 300, 300)
x = torch.rand(input_shape)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
model_input = [x]
out = model(model_input)
self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1)

def subsample_tensor(tensor):
num_elems = tensor.numel()
num_samples = 20
if num_elems <= num_samples:
return tensor

flat_tensor = tensor.flatten()
ith_index = num_elems // num_samples
return flat_tensor[ith_index - 1::ith_index]

def compute_mean_std(tensor):
# can't compute mean of integral tensor
tensor = tensor.to(torch.double)
mean = torch.mean(tensor)
std = torch.std(tensor)
return {"mean": mean, "std": std}

# maskrcnn_resnet50_fpn is numerically unstable across platforms, so for now
# compare results with mean and std
if name == "maskrcnn_resnet50_fpn":
test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
# mean values are small, use large prec
self.assertExpected(test_value, prec=.01)
else:
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor), prec=0.01)
def check_out(out):
self.assertEqual(len(out), 1)

def subsample_tensor(tensor):
num_elems = tensor.numel()
num_samples = 20
if num_elems <= num_samples:
return tensor

flat_tensor = tensor.flatten()
ith_index = num_elems // num_samples
return flat_tensor[ith_index - 1::ith_index]

def compute_mean_std(tensor):
# can't compute mean of integral tensor
tensor = tensor.to(torch.double)
mean = torch.mean(tensor)
std = torch.std(tensor)
return {"mean": mean, "std": std}

# maskrcnn_resnet50_fpn is numerically unstable across platforms, so for now
# compare results with mean and std
if name == "maskrcnn_resnet50_fpn":
test_value = map_nested_tensor_object(out, tensor_map_fn=compute_mean_std)
# mean values are small, use large prec
self.assertExpected(test_value, prec=.01, strip_suffix="_" + dev)
else:
self.assertExpected(map_nested_tensor_object(out, tensor_map_fn=subsample_tensor),
prec=0.01,
strip_suffix="_" + dev)

check_out(out)

scripted_model = torch.jit.script(model)
scripted_model.eval()
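Condensed into a standalone sketch, the autocast pattern the harnesses above now follow looks roughly like this; the helper name, the check_numerics callback, and the flaky tuple are illustrative, not part of the diff:

```python
import torch
import torchvision.models as models

def autocast_smoke_test(name, check_numerics, flaky=()):
    # Assumes a CUDA device; mirrors the harness pattern in the diff above.
    model = models.__dict__[name](num_classes=50).eval().cuda()
    x = torch.rand(1, 3, 224, 224).cuda()
    with torch.cuda.amp.autocast():
        out = model(x)  # dtype mismatches inside the model would surface here
    # Coverage is always exercised; strict numeric comparison is skipped for
    # models whose autocast numerics are known to be flaky in the harness.
    if name not in flaky:
        check_numerics(out.float().cpu())
    return out
```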
@@ -156,6 +198,13 @@ def compute_mean_std(tensor):
# self.check_script(model, name)
self.checkModule(model, name, ([x],))

if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(model_input)
# See autocast_flaky_numerics comment at top of file.
if name not in autocast_flaky_numerics:
check_out(out)

def _test_detection_model_validation(self, name):
set_rng_seed(0)
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
@@ -179,18 +228,24 @@ def _test_detection_model_validation(self, name):
targets = [{'boxes': boxes}]
self.assertRaises(ValueError, model, x, targets=targets)

def _test_video_model(self, name):
def _test_video_model(self, name, dev):
# the default input shape is
# bs * num_channels * clip_len * h * w
input_shape = (1, 3, 4, 112, 112)
# test both basicblock and Bottleneck
model = models.video.__dict__[name](num_classes=50)
model.eval()
x = torch.rand(input_shape)
model.eval().to(device=dev)
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
x = torch.rand(input_shape).to(device=dev)
out = model(x)
self.checkModule(model, name, (x,))
self.assertEqual(out.shape[-1], 50)

if dev == "cuda":
with torch.cuda.amp.autocast():
out = model(x)
self.assertEqual(out.shape[-1], 50)

def _make_sliced_model(self, model, stop_layer):
layers = OrderedDict()
for name, layer in model.named_children():
@@ -272,6 +327,12 @@ def test_googlenet_eval(self):

@unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
def test_fasterrcnn_switch_devices(self):
def checkOut(out):
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])

model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
model.cuda()
model.eval()
@@ -280,17 +341,20 @@ def test_fasterrcnn_switch_devices(self):
model_input = [x]
out = model(model_input)
self.assertIs(model_input[0], x)
self.assertEqual(len(out), 1)
self.assertTrue("boxes" in out[0])
self.assertTrue("scores" in out[0])
self.assertTrue("labels" in out[0])

checkOut(out)

with torch.cuda.amp.autocast():
out = model(model_input)

checkOut(out)

# now switch to cpu and make sure it works
model.cpu()
x = x.cpu()
out_cpu = model([x])
self.assertTrue("boxes" in out_cpu[0])
self.assertTrue("scores" in out_cpu[0])
self.assertTrue("labels" in out_cpu[0])

checkOut(out_cpu)

def test_generalizedrcnn_transform_repr(self):

@@ -312,34 +376,40 @@ def test_generalizedrcnn_transform_repr(self):
self.assertEqual(t.__repr__(), expected_string)


_devs = ["cpu", "cuda"] if torch.cuda.is_available() else ["cpu"]


for model_name in get_available_classification_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape)
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
input_shape = (1, 3, 224, 224)
if model_name in ['inception_v3']:
input_shape = (1, 3, 299, 299)
self._test_classification_model(model_name, input_shape, dev)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)
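The "for-loop bodies don't define scopes" comment refers to Python's late binding of closures; a small standalone example (not from the diff) of the pitfall the model_name=model_name, dev=dev defaults avoid:

```python
tests = {}
for model_name in ("resnet18", "vgg11"):
    for dev in ("cpu", "cuda"):
        def bad_test():
            # Resolved at call time: always sees the loop's final values.
            return model_name, dev

        def good_test(model_name=model_name, dev=dev):
            # Default arguments freeze the current values at definition time.
            return model_name, dev

        tests["bad_" + model_name + "_" + dev] = bad_test
        tests["good_" + model_name + "_" + dev] = good_test

print(tests["bad_resnet18_cpu"]())   # ('vgg11', 'cuda')
print(tests["good_resnet18_cpu"]())  # ('resnet18', 'cpu')
```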


for model_name in get_available_segmentation_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
self._test_segmentation_model(model_name)
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
self._test_segmentation_model(model_name, dev)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)


for model_name in get_available_detection_models():
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name):
self._test_detection_model(model_name)
for dev in _devs:
# for-loop bodies don't define scopes, so we have to save the variables
# we want to close over in some way
def do_test(self, model_name=model_name, dev=dev):
self._test_detection_model(model_name, dev)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

def do_validation_test(self, model_name=model_name):
self._test_detection_model_validation(model_name)
@@ -348,11 +418,11 @@ def do_validation_test(self, model_name=model_name):


for model_name in get_available_video_models():
for dev in _devs:
def do_test(self, model_name=model_name, dev=dev):
self._test_video_model(model_name, dev)

def do_test(self, model_name=model_name):
self._test_video_model(model_name)

setattr(ModelTester, "test_" + model_name, do_test)
setattr(ModelTester, "test_" + model_name + "_" + dev, do_test)

if __name__ == '__main__':
unittest.main()
20 changes: 16 additions & 4 deletions test/test_ops.py
@@ -52,25 +52,30 @@ def _test_backward(self, device, contiguous):


class RoIOpTester(OpTester):
def _test_forward(self, device, contiguous):
def _test_forward(self, device, contiguous, x_dtype=None, rois_dtype=None):
x_dtype = self.dtype if x_dtype is None else x_dtype
rois_dtype = self.dtype if rois_dtype is None else rois_dtype
pool_size = 5
# n_channels % (pool_size ** 2) == 0 required for PS operations.
n_channels = 2 * (pool_size ** 2)
x = torch.rand(2, n_channels, 10, 10, dtype=self.dtype, device=device)
x = torch.rand(2, n_channels, 10, 10, dtype=x_dtype, device=device)
if not contiguous:
x = x.permute(0, 1, 3, 2)
rois = torch.tensor([[0, 0, 0, 9, 9], # format is (xyxy)
[0, 0, 5, 4, 9],
[0, 5, 5, 9, 9],
[1, 0, 0, 9, 9]],
dtype=self.dtype, device=device)
dtype=rois_dtype, device=device)

pool_h, pool_w = pool_size, pool_size
y = self.fn(x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1)
# the following should be true whether we're running an autocast test or not.
self.assertTrue(y.dtype == x.dtype)
gt_y = self.expected_fn(x, rois, pool_h, pool_w, spatial_scale=1,
sampling_ratio=-1, device=device, dtype=self.dtype)

self.assertTrue(torch.allclose(gt_y, y))
tol = 1e-3 if (x_dtype is torch.half or rois_dtype is torch.half) else 1e-5
self.assertTrue(torch.allclose(gt_y.to(y.dtype), y, rtol=tol, atol=tol))

def _test_backward(self, device, contiguous):
pool_size = 2
@@ -290,6 +295,13 @@ def expected_fn(self, in_data, rois, pool_h, pool_w, spatial_scale=1, sampling_r
def _test_boxes_shape(self):
self._helper_boxes_shape(ops.roi_align)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_roi_align_autocast(self):
for x_dtype in (torch.float, torch.half):
for rois_dtype in (torch.float, torch.half):
with torch.cuda.amp.autocast():
self._test_forward(torch.device("cuda"), contiguous=False, x_dtype=x_dtype, rois_dtype=rois_dtype)


class PSRoIAlignTester(RoIOpTester, unittest.TestCase):
def fn(self, x, rois, pool_h, pool_w, spatial_scale=1, sampling_ratio=-1, **kwargs):
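Finally, a hedged standalone sketch of what the new test_roi_align_autocast exercises: under autocast the op may compute internally in a different precision, but the dtype assertion in _test_forward above implies the result comes back in the input's dtype. The shapes, mixed dtypes, and device assumption (CUDA available) here are illustrative:

```python
import torch
from torchvision.ops import roi_align

assert torch.cuda.is_available(), "this sketch assumes a CUDA device"

x = torch.rand(2, 50, 10, 10, device="cuda")           # float32 feature map
rois = torch.tensor([[0, 0, 0, 9, 9],                   # (batch_idx, x1, y1, x2, y2)
                     [1, 0, 0, 9, 9]],
                    dtype=torch.half, device="cuda")

with torch.cuda.amp.autocast():
    y = roi_align(x, rois, output_size=(5, 5), spatial_scale=1.0, sampling_ratio=-1)

print(y.dtype)  # expected torch.float32, matching x.dtype per the test's assertion
```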
