
ConvGrad CUDA Kernel Bugfix#7273

Merged
Lafi7e merged 2 commits into master from weicwang/conv_grad_bugfix
Apr 8, 2021

Conversation

Contributor

@Lafi7e Lafi7e commented Apr 7, 2021

ConvGrad CUDA kernel bugfix.

The original code gets a segmentation fault when the input's requires_grad flag is set to False, because dX and dW may then be nullptr. Since PrepareArgs only needs shape information, and dX and dW have the same shapes as X and W respectively, we can pass X and W to PrepareArgs instead.
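A minimal plain-Python sketch of the fix's reasoning (`prepare_args`, `conv_grad`, and `FakeTensor` are illustrative names, not ORT's actual API): because dX and dW always share shapes with X and W, argument preparation that consumes only shapes can fall back to X and W when those gradients were not requested.

```python
# Illustrative sketch only -- not ONNX Runtime's real ConvGrad code.

class FakeTensor:
    """Stand-in for a tensor; only carries a shape."""
    def __init__(self, shape):
        self.shape = tuple(shape)

def prepare_args(x_shape, w_shape, dy_shape):
    # Stand-in for PrepareArgs: it only consumes shape information.
    return {"x": x_shape, "w": w_shape, "dy": dy_shape}

def conv_grad(dY, X, W, dX=None, dW=None):
    # The buggy version dereferenced dX/dW unconditionally, which crashes
    # when the input's requires_grad is False and they are null. The fix:
    # dX shares X's shape and dW shares W's, so use X/W as the shape source.
    shape_src_x = dX if dX is not None else X
    shape_src_w = dW if dW is not None else W
    return prepare_args(shape_src_x.shape, shape_src_w.shape, dY.shape)

dY = FakeTensor((32, 1536, 128))
X = FakeTensor((32, 1536, 128))
W = FakeTensor((1536, 1536, 3))
args = conv_grad(dY, X, W)  # dX/dW omitted: no crash, shapes still correct
print(args["x"])  # (32, 1536, 128)
```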

@Lafi7e Lafi7e added the training label (issues related to ONNX Runtime training; typically submitted using template) Apr 7, 2021
@Lafi7e Lafi7e requested a review from SherlockNoMad April 7, 2021 09:42
@Lafi7e Lafi7e requested a review from a team as a code owner April 7, 2021 09:42
Contributor

@mrry mrry left a comment


Thanks for the fix!


```
@pytest.mark.parametrize("use_fp16", [False, True])
def test_gradient_correctness_conv1d(use_fp16):
@pytest.mark.parametrize("use_fp16, input_requires_grad", [
```
Contributor


```
@pytest.mark.parametrize("use_fp16", [...])
@pytest.mark.parametrize("input_requires_grad", [...])
```

I saw somewhere this would also work.
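For reference, stacked `parametrize` decorators generate the full cross product of their value lists, which is equivalent to one combined decorator listing every pair explicitly. A small sketch of that equivalence (plain Python; pytest itself is not required to see it):

```python
import itertools

# Values mirroring the two parameters discussed above.
use_fp16_values = [False, True]
input_requires_grad_values = [False, True]

# Two stacked @pytest.mark.parametrize decorators yield the cross product...
stacked = list(itertools.product(input_requires_grad_values, use_fp16_values))

# ...which is what a single combined decorator would enumerate explicitly.
combined = [(g, f) for g in input_requires_grad_values for f in use_fp16_values]

assert stacked == combined
print(len(stacked))  # 4 generated test cases
```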

Comment thread: orttraining/orttraining/training_ops/cuda/nn/conv_grad.cc
@thiagocrepaldi
Contributor

I am getting a failure when running this unit test locally. Would it be related?

```
____________________ test_gradient_correctness_conv1d[True] ____________________

use_fp16 = True

    @pytest.mark.parametrize("use_fp16", [False, True])
    def test_gradient_correctness_conv1d(use_fp16):
        class NeuralNetConv1D(torch.nn.Module):
            def __init__(self, in_channels, out_channels, kernel_size, padding=0, groups=1):
                super(NeuralNetConv1D, self).__init__()
                self.conv1 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups)
                self.conv2 = torch.nn.Conv1d(in_channels, out_channels, kernel_size, padding=padding, groups=groups)
    
            def forward(self, input):
                out = self.conv1(input.permute(0, 2, 1).contiguous())
                out = self.conv2(out).permute(0, 2, 1).contiguous()
                return out
    
        device = 'cuda'
        N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3
        pt_model = NeuralNetConv1D(C_in, C_out, kernel_size, padding=1).to(device)
        ort_model = ORTModule(copy.deepcopy(pt_model))
    
        def run_step(model, x):
            with amp.autocast(use_fp16):
                prediction = model(x)
                loss = prediction.sum()
            loss.backward()
            return prediction
    
        for step in range(10):
            x = torch.randn(N, seq_len, C_in, device=device, requires_grad=True)
            pt_prediction = run_step(pt_model, x)
            ort_prediction = run_step(ort_model, x)
    
            if use_fp16:
                assert torch.allclose(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3)
>               _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=1e-2, atol=1e-2)

../../../orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py:545: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

ort_model = ORTModule(
  (_original_module): NeuralNetConv1D(
    (conv1): Conv1d(1536, 1536, kernel_size=(3,), stride=(1,), paddi..., stride=(1,), padding=(1,))
      (conv2): Conv1d(1536, 1536, kernel_size=(3,), stride=(1,), padding=(1,))
    )
  )
)
pt_model = NeuralNetConv1D(
  (conv1): Conv1d(1536, 1536, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(1536, 1536, kernel_size=(3,), stride=(1,), padding=(1,))
), none_pt_params = []
reset_gradient = True, rtol = 0.01, atol = 0.01

    def assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-05, atol=1e-06):
        ort_named_params = list(ort_model.named_parameters())
        pt_named_params = list(pt_model.named_parameters())
        assert len(ort_named_params) == len(pt_named_params)
    
        for ort_named_param, pt_named_param in zip(ort_named_params, pt_named_params):
            ort_name, ort_param = ort_named_param
            pt_name, pt_param = pt_named_param
    
            assert pt_name in ort_name
            if pt_name in none_pt_params:
                print(f'{pt_name} in {none_pt_params}')
                assert pt_param.grad is None
                assert not torch.is_nonzero(torch.count_nonzero(ort_param.grad))
            else:
                print(f'{pt_name} ***NOT*** in {none_pt_params}')
>               assert torch.allclose(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol)
E               AssertionError

../../../orttraining/orttraining/test/python/_test_helpers.py:167: AssertionError
```

@Lafi7e
Contributor Author

Lafi7e commented Apr 7, 2021

@thiagocrepaldi It's not related. This issue causes a segmentation fault, not an assertion failure.

@Lafi7e Lafi7e merged commit beb299e into master Apr 8, 2021
@Lafi7e Lafi7e deleted the weicwang/conv_grad_bugfix branch April 8, 2021 00:22
@SherlockNoMad
Contributor

Hi @thiagocrepaldi, your local GPU is probably not a V100.
I haven't tested ConvGrad on any non-V100 device yet, so I opened the other PR to restrict ConvGrad to CUDA_ARCH >= 700.
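The gating described here boils down to a compute-capability check. A hedged sketch (hypothetical helper name; the real restriction lives in ORT's CUDA dispatch logic, not in Python):

```python
# Hypothetical sketch: allow the ConvGrad path only on devices with
# compute capability >= 7.0 (CUDA_ARCH >= 700, i.e. Volta/V100 and newer).
def convgrad_supported(cc_major: int, cc_minor: int = 0) -> bool:
    # Tuple comparison handles boundary cases like 6.1 vs 7.0 correctly.
    return (cc_major, cc_minor) >= (7, 0)

print(convgrad_supported(7, 0))  # True  (V100, sm_70)
print(convgrad_supported(6, 1))  # False (e.g. a Pascal-class GPU, sm_61)
```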

