Commit cb2fc6b: allow tf32 cudnn

AndreSlavescu committed Apr 25, 2023
1 parent ad35af0

Showing 6 changed files with 18 additions and 18 deletions.
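
All six diffs below make the same change: the cudnn_allow_tf32 parametrization is dropped, cudnn.allow_tf32 is set to False before check_module so the cuDNN convolutions behind the reference output run in full float32 precision, and the flag is set back to True once the check is done. Below is a minimal sketch of that pattern, assuming cudnn in these files refers to torch.backends.cudnn (the import sits outside the visible hunks) and using a hypothetical check_with_tf32_disabled helper with a plain forward pass standing in for the suite's check_module:

import torch
from torch.backends import cudnn  # assumption: the cudnn module these tests assign to


def check_with_tf32_disabled(model, args, check_fn, **kwargs):
    # Force cuDNN to use full float32 (no TF32) while the accuracy check runs.
    cudnn.allow_tf32 = False
    check_fn(model=model, args=args, **kwargs)
    # Re-allow TF32 afterwards so the process-wide default is restored.
    cudnn.allow_tf32 = True


# Illustrative usage; model(*args) stands in for check_module(model=..., args=[...]).
conv = torch.nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
check_with_tf32_disabled(
    model=conv,
    args=[torch.randn(1, 3, 32, 32)],
    check_fn=lambda model, args: model(*args),
)

On CPUs and pre-Ampere GPUs the flag has no effect, so the sketch runs anywhere.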
tests/frontends/torch/test_torch_conv1d.py (3 additions, 3 deletions)

@@ -17,8 +17,8 @@
 
 @pytest.mark.parametrize('in_shape,w_shape,stride,padding,groups', [[[1, 3, 224], [42, 3, 7], 2, 1, 1]])
 @pytest.mark.parametrize('dtype', [torch.float32])
-@pytest.mark.parametrize('cudnn_allow_tf32', [True, False])
-def test_conv1d(in_shape, w_shape, stride, padding, groups, dtype, cudnn_allow_tf32):
+def test_conv1d(in_shape, w_shape, stride, padding, groups, dtype):
+    cudnn.allow_tf32 = False
     check_module(
         model=torch.nn.Conv1d(
             in_channels=in_shape[1],
@@ -30,7 +30,7 @@ def test_conv1d(in_shape, w_shape, stride, padding, groups, dtype, cudnn_allow_tf32):
         ),
         args=[torch.randn(in_shape, dtype=dtype)],
     )
-    cudnn.allow_tf32 = cudnn_allow_tf32
+    cudnn.allow_tf32 = True
 
 
 if __name__ == '__main__':
tests/frontends/torch/test_torch_conv1d_transpose.py (3 additions, 3 deletions)

@@ -22,8 +22,8 @@
 @pytest.mark.parametrize('output_padding', [3])
 @pytest.mark.parametrize('groups', [1])
 @pytest.mark.parametrize('dtype', [torch.float32])
-@pytest.mark.parametrize('cudnn_allow_tf32', [True, False])
-def test_conv1d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype, cudnn_allow_tf32):
+def test_conv1d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype):
+    cudnn.allow_tf32 = False
     check_module(
         model=torch.nn.ConvTranspose1d(
             in_channels=in_shape[1],
@@ -37,7 +37,7 @@ def test_conv1d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype, cudnn_allow_tf32):
         args=[torch.randn([1, 3, 224], dtype=dtype)],
         atol=2e-4,
     )
-    cudnn.allow_tf32 = cudnn_allow_tf32
+    cudnn.allow_tf32 = True
 
 
 if __name__ == '__main__':
tests/frontends/torch/test_torch_conv2d.py (3 additions, 3 deletions)

@@ -21,8 +21,8 @@
 )
 @pytest.mark.parametrize('groups', [1, 3])
 @pytest.mark.parametrize('dtype', [torch.float32])
-@pytest.mark.parametrize('cudnn_allow_tf32', [True, False])
-def test_conv2d(in_shape, w_shape, stride, padding, groups, dtype, cudnn_allow_tf32):
+def test_conv2d(in_shape, w_shape, stride, padding, groups, dtype):
+    cudnn.allow_tf32 = False
     check_module(
         model=torch.nn.Conv2d(
             in_channels=in_shape[1],
@@ -34,7 +34,7 @@ def test_conv2d(in_shape, w_shape, stride, padding, groups, dtype, cudnn_allow_tf32):
         ),
         args=[torch.randn(in_shape, dtype=dtype)],
     )
-    cudnn.allow_tf32 = cudnn_allow_tf32
+    cudnn.allow_tf32 = True
 
 
 if __name__ == '__main__':
tests/frontends/torch/test_torch_conv2d_transpose.py (3 additions, 3 deletions)

@@ -22,8 +22,8 @@
 @pytest.mark.parametrize('output_padding', [3])
 @pytest.mark.parametrize('groups', [1])
 @pytest.mark.parametrize('dtype', [torch.float32])
-@pytest.mark.parametrize('cudnn_allow_tf32', [True, False])
-def test_conv2d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype, cudnn_allow_tf32):
+def test_conv2d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype):
+    cudnn.allow_tf32 = False
     check_module(
         model=torch.nn.ConvTranspose2d(
             in_channels=in_shape[1],
@@ -37,7 +37,7 @@ def test_conv2d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype, cudnn_allow_tf32):
         args=[torch.randn(in_shape, dtype=dtype)],
         atol=2e-4,
     )
-    cudnn.allow_tf32 = cudnn_allow_tf32
+    cudnn.allow_tf32 = True
 
 
 if __name__ == '__main__':
tests/frontends/torch/test_torch_conv3d.py (3 additions, 3 deletions)

@@ -21,8 +21,8 @@
 )
 @pytest.mark.parametrize('groups', [1, 3])
 @pytest.mark.parametrize('dtype', [torch.float32])
-@pytest.mark.parametrize('cudnn_allow_tf32', [True, False])
-def test_conv3d(in_shape, w_shape, stride, padding, groups, dtype, cudnn_allow_tf32):
+def test_conv3d(in_shape, w_shape, stride, padding, groups, dtype):
+    cudnn.allow_tf32 = False
     check_module(
         model=torch.nn.Conv3d(
             in_channels=in_shape[1],
@@ -34,7 +34,7 @@ def test_conv3d(in_shape, w_shape, stride, padding, groups, dtype, cudnn_allow_tf32):
         ),
         args=[torch.randn(in_shape, dtype=dtype)],
     )
-    cudnn.allow_tf32 = cudnn_allow_tf32
+    cudnn.allow_tf32 = True
 
 
 if __name__ == '__main__':
tests/frontends/torch/test_torch_conv3d_transpose.py (3 additions, 3 deletions)

@@ -22,8 +22,8 @@
 @pytest.mark.parametrize('output_padding', [3])
 @pytest.mark.parametrize('groups', [1])
 @pytest.mark.parametrize('dtype', [torch.float32])
-@pytest.mark.parametrize('cudnn_allow_tf32', [True, False])
-def test_conv3d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype, cudnn_allow_tf32):
+def test_conv3d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype):
+    cudnn.allow_tf32 = False
     check_module(
         model=torch.nn.ConvTranspose3d(
             in_channels=in_shape[1],
@@ -37,7 +37,7 @@ def test_conv3d_transpose(in_shape, w_shape, stride, padding, output_padding, groups, dtype, cudnn_allow_tf32):
         args=[torch.randn(in_shape, dtype=dtype)],
         atol=2e-4,
    )
-    cudnn.allow_tf32 = cudnn_allow_tf32
+    cudnn.allow_tf32 = True
 
 
 if __name__ == '__main__':
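
One design note, not part of this commit: cudnn.allow_tf32 is a process-global flag, so if check_module raised, TF32 would stay disabled for the tests that run afterwards. A hedged sketch of a pytest fixture that restores the previous value even on failure (the fixture name is illustrative, not from the repository):

import pytest
from torch.backends import cudnn  # assumption: the same global flag the tests toggle


@pytest.fixture
def no_tf32():
    # Remember the current setting, disable TF32 for the test body,
    # and restore the original value even if the test fails.
    previous = cudnn.allow_tf32
    cudnn.allow_tf32 = False
    try:
        yield
    finally:
        cudnn.allow_tf32 = previous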

0 comments on commit cb2fc6b
