Trying to use forward mode autodiff currently fails with:
NotImplementedError: You must implement the jvp function for custom autograd.Function to use it with forward mode AD.
It would be useful to add a jvp implementation to the custom autograd.Function so that forward mode AD is supported.
See https://pytorch.org/docs/stable/notes/extending.html#forward-mode-ad and https://pytorch.org/tutorials/intermediate/forward_ad_usage.html#custom-autograd-function
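For reference, here is a minimal sketch of what this involves, following the pattern in the linked docs. `Square` and `square_with_tangent` are hypothetical names used only for illustration (they are not the Function this issue is about). The sketch assumes a recent PyTorch version that provides `ctx.save_for_forward`. The `jvp` staticmethod receives one tangent per forward input and returns the tangent of the output. For y = x², that tangent is 2·x·ẋ.

```python
import torch
import torch.autograd.forward_ad as fwAD


class Square(torch.autograd.Function):
    """Hypothetical example: y = x**2 with both reverse-mode and forward-mode rules."""

    @staticmethod
    def forward(ctx, x):
        # Save x for backward (reverse mode) and for jvp (forward mode).
        ctx.save_for_backward(x)
        ctx.save_for_forward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        # Vector-Jacobian product: dL/dx = dL/dy * 2x
        (x,) = ctx.saved_tensors
        return grad_out * 2 * x

    @staticmethod
    def jvp(ctx, x_tangent):
        # Jacobian-vector product: dy = 2x * dx
        # Without this method, forward AD raises the NotImplementedError quoted above.
        (x,) = ctx.saved_tensors
        return 2 * x * x_tangent


def square_with_tangent(x, v):
    """Hypothetical helper: returns (Square(x), J @ v) using forward mode AD."""
    with fwAD.dual_level():
        dual_x = fwAD.make_dual(x, v)
        out = Square.apply(dual_x)
        # Unpack inside the dual level, since tangents are dropped when it exits.
        primal, tangent = fwAD.unpack_dual(out)
    return primal, tangent


x = torch.tensor([1.0, 2.0, 3.0])
primal, tangent = square_with_tangent(x, torch.ones_like(x))
print(primal, tangent)
```

With a tangent of all ones, the returned tangent should equal 2·x, matching what the `backward` rule gives for the same Jacobian.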