[BUG]: Activation Checkpointing Failed Test with PyTorch 1.9 #718

@FrankLeeeee

Description

🐛 Describe the bug

When running the unit tests with PyTorch 1.9, the following exception occurs in tests/test_utils/test_activation_checkpointing.py.

(Screenshot of the exception traceback omitted.)

Reasons

This bug occurs because of an invalid use of ctx.save_for_backward. In a torch autograd function, ctx.save_for_backward must be given either inputs or intermediate values of the function. However, when activation_offload=True, the current implementation copies the tensors to CPU, creating new tensors which are neither inputs nor intermediate activations (i.e. they do not have gradient accumulator objects associated with them).

for i, arg in enumerate(args):
    if torch.is_tensor(arg):
        if activation_offload:
            tensor_inputs.append(copy_to_device(arg, 'cpu'))
        else:
            tensor_inputs.append(arg)
        ctx.tensor_indices.append(i)
        ctx.inputs.append(None)
    else:
        ctx.inputs.append(arg)

ctx.save_for_backward(*tensor_inputs)
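To illustrate the point above, a minimal sketch (using `Tensor.to` with `copy=True` as a stand-in for `copy_to_device`, which is Colossal-AI's own helper): the offloaded copy is a brand-new tensor that is disconnected from the autograd graph, unlike the original input.

```python
import torch

x = torch.randn(2, requires_grad=True)

# What offloading effectively does: create a fresh CPU copy outside autograd
with torch.no_grad():
    cpu_copy = x.to('cpu', copy=True)

assert x.is_leaf and x.requires_grad  # original input participates in the graph
assert not cpu_copy.requires_grad     # the copy is a brand-new tensor...
assert cpu_copy.grad_fn is None       # ...with no connection to autograd
```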

This will trigger the assertion statement inside PyTorch. (Screenshot of the assertion error omitted.)

Solution

Use ctx.tensor_inputs = tensor_inputs instead of ctx.save_for_backward(*tensor_inputs) when activation_offload=True, so that the offloaded CPU copies are stored as a plain attribute on ctx rather than going through save_for_backward.
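A minimal sketch of how the fix could look end to end (hypothetical class and attribute names; this is not the actual Colossal-AI implementation, just the pattern): when offloading, the CPU copy is stashed on ctx as a plain attribute and moved back in backward; otherwise the input goes through save_for_backward as usual.

```python
import torch


class CheckpointWithOffload(torch.autograd.Function):
    """Sketch: checkpointing with optional CPU offload of the input."""

    @staticmethod
    def forward(ctx, run_function, activation_offload, x):
        ctx.run_function = run_function
        ctx.activation_offload = activation_offload
        if activation_offload:
            # CPU copy is a new tensor, so bypass save_for_backward
            ctx.tensor_inputs = [x.detach().to('cpu')]
        else:
            # x is a real input, so save_for_backward is valid here
            ctx.save_for_backward(x)
        with torch.no_grad():
            return run_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        if ctx.activation_offload:
            x = ctx.tensor_inputs[0].to(grad_output.device)
        else:
            (x,) = ctx.saved_tensors
        # Recompute the forward pass with grad enabled
        x = x.detach().requires_grad_(True)
        with torch.enable_grad():
            out = ctx.run_function(x)
        out.backward(grad_output)
        return None, None, x.grad
```

A quick check that gradients still flow with offloading enabled:

```python
x = torch.randn(4, requires_grad=True)
y = CheckpointWithOffload.apply(torch.square, True, x)
y.sum().backward()  # x.grad == 2 * x, as for y = x**2
```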

Environment

CUDA: 11.1
PyTorch: 1.9.1
