Add lookaside for torch.autograd.function.Function.apply (Lightning…
crcrpar committed Jul 4, 2024
1 parent 1d7a01d commit ab514fc
Showing 2 changed files with 40 additions and 0 deletions.
14 changes: 14 additions & 0 deletions thunder/core/jit_ext.py
@@ -891,6 +891,20 @@ def _general_jit_named_buffers_lookaside(obj: Any, *args, **kwargs):
)


@general_jit_lookaside(torch.autograd.function.Function.apply.__func__)
def _general_jit_torch_autograd_function_apply_lookaside(obj: Any, *args, **kwargs):
    # `obj` wraps the custom autograd.Function subclass; interpret its
    # `forward` directly instead of stepping into `Function.apply`.
    custom_autograd_function_cls = unwrap(obj)
    custom_forward = custom_autograd_function_cls.forward
    args_, kwargs_ = tree_map(unwrap, (args, kwargs))
    ctx = torch.autograd.function.FunctionCtx()

    # Wrap the fresh ctx and the arguments with a LOOKASIDE provenance record
    # so the interpreter can track where these values entered the trace.
    pr = ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[wrap_const(custom_forward).provenance])
    wrapped_ctx = wrap(ctx, provenance=pr)
    args_, kwargs_ = tree_map(lambda a: wrap(a, provenance=pr), (args_, kwargs_))
    return _interpret_call(custom_forward, wrapped_ctx, *args_, **kwargs_)


# Adds proxy methods
# NOTE These methods map to themselves, which prevents the interpreter from looking into them
# This is OK because these methods are written in a tracing-safe manner, and trying to
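Why register on apply.__func__? In recent PyTorch, Function.apply is a classmethod, so every torch.autograd.Function subclass binds the same underlying function object; keying the lookaside on that shared object lets it intercept SomeSubclass.apply for any custom function. A minimal sketch of that identity (the MyFn name is illustrative, not from this commit):

    import torch

    class MyFn(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            return 2 * x

        @staticmethod
        def backward(ctx, grad_output):
            return 2 * grad_output

    # Subclasses inherit the same classmethod, so the underlying function
    # object is shared; this is what the lookaside is keyed on.
    assert MyFn.apply.__func__ is torch.autograd.function.Function.apply.__func__
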
26 changes: 26 additions & 0 deletions thunder/tests/test_core.py
@@ -2854,6 +2854,32 @@ def forward(self, x) -> torch.Tensor:
    with pytest.raises(GradcheckError):
        gradcheck(model, (x,))

    class MyLinear(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
            # Save both inputs: `backward` needs `weight` for the grad w.r.t.
            # `x`, and `x` for the grad w.r.t. `weight`.
            ctx.save_for_backward(x, weight)
            ctx.pretty_attr = 100
            return torch.matmul(x, weight.t())

        @staticmethod
        def backward(ctx, grad_output):
            x, weight = ctx.saved_tensors
            return torch.matmul(grad_output, weight), torch.matmul(grad_output.t(), x)

    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l1 = torch.nn.Linear(2, 2, bias=False)

        def forward(self, x):
            return MyLinear.apply(x, self.l1.weight)

    x = torch.randn((2, 2), dtype=torch.float64, requires_grad=True)
    model = Model().to(dtype=torch.float64)
    jitted = thunder.jit(model, skip_inplace_functionalization=True)

    gradcheck(jitted, (x,))
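    # A usage sketch (not part of this commit): besides gradcheck, the jitted
    # module is expected to match eager execution of the same custom Function.
    torch.testing.assert_close(jitted(x), model(x))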


def test_proxy_repr():
# Verify that we can call `__repr__` on different proxy subclasses.
