In [1]:
import torch
from torch import nn
import torch.nn.functional as F

# Seed the global RNG first so the layer initialisation below is reproducible.
torch.manual_seed(42)

# Small MLP: 30 features -> 10 hidden units (ReLU) -> 30 features.
net = nn.Sequential(nn.Linear(30, 10), nn.ReLU(), nn.Linear(10, 30))
# Module.cuda() moves the parameters and returns the same module object.
net = net.cuda()

def latent_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    """Apply ``net`` repeatedly to the concatenation of ``x``, ``y_latent`` and ``z_latent``.

    The three tensors are joined along their last dimension. Every step feeds
    the current state through ``net`` and adds the result back onto the state.
    The returned slices come from the last raw output of ``net``, not from the
    residual sum.

    Args:
        x: Input features; occupies the first ``x.shape[-1]`` positions of the state.
        y_latent: First latent; occupies the next ``y_latent.shape[-1]`` positions.
        z_latent: Second latent; occupies the following ``z_latent.shape[-1]`` positions.
        n_latent_reasoning_steps: Number of times ``net`` is applied; must be >= 1.
        net: Module applied at each step. Its output must be addable to its
            input. The default is bound once, when this ``def`` is executed.

    Returns:
        Tuple ``(y, z)``: the ``y_latent``- and ``z_latent``-sized slices of the
        last ``net`` output, taken along the last dimension.

    Raises:
        ValueError: If ``n_latent_reasoning_steps`` is smaller than 1.
    """
    if n_latent_reasoning_steps < 1:
        # Without at least one step there is no network output to slice.
        raise ValueError(
            f"n_latent_reasoning_steps must be >= 1, got {n_latent_reasoning_steps}"
        )

    x_dim = x.shape[-1]
    y_end = x_dim + y_latent.shape[-1]
    z_end = y_end + z_latent.shape[-1]

    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor

    # Slice along the last dimension (the one used for the concatenation) so
    # inputs with any number of leading dimensions are handled consistently.
    y = output_tensor[..., x_dim:y_end]
    z = output_tensor[..., y_end:z_end]
    return y, z

def deep_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    """Refine the latents with graph-free passes, then one differentiable pass.

    ``t_recursion_steps - 1`` calls to ``latent_recursion`` run under
    ``torch.no_grad()`` inside a CUDA bfloat16 autocast region. The resulting
    latents are then flagged ``requires_grad`` and a last call runs inside a
    second, separate autocast region.

    NOTE: a later ``def deep_recursion`` in this same cell rebinds the name, so
    this version is shadowed once the whole cell has been executed.

    Returns:
        Tuple ``(y_latent, z_latent)`` produced by the final ``latent_recursion`` call.
    """
    # Warm-up region: values only, no autograd graph is recorded.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        for _ in range(1, t_recursion_steps):
            with torch.no_grad():
                # Rebind to the fresh outputs; the incoming tensors are not written to.
                y_latent, z_latent = latent_recursion(
                    x, y_latent.detach(), z_latent.detach(), net=net
                )

    # requires_grad_ works in place and returns the same tensor.
    y_latent.requires_grad_(True)
    z_latent.requires_grad_(True)

    # Final pass in its own autocast region.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        y_out, z_out = latent_recursion(x, y_latent, z_latent, net=net)
    return y_out, z_out

def deep_recursion(x, y_latent, z_latent, t_recursion_steps=2, net=net):
    """Run ``t_recursion_steps - 1`` graph-free passes, then one pass with autograd on.

    Args:
        x: Input tensor forwarded unchanged to every ``latent_recursion`` call.
        y_latent: First latent; replaced by each pass's ``y`` output.
        z_latent: Second latent; replaced by each pass's ``z`` output.
        t_recursion_steps: Total number of passes; all but the last run under
            ``torch.no_grad()``.
        net: Module handed to ``latent_recursion``.

    Returns:
        Tuple ``(y_latent, z_latent)`` from the final, differentiable pass.

    Note:
        With ``t_recursion_steps <= 1`` no burn-in pass runs, so the
        ``requires_grad_(True)`` calls below act on the caller's own tensors.
    """
    # burn-in steps: no graph, just values
    for _ in range(t_recursion_steps - 1):
        with torch.no_grad():
            y_latent, z_latent = latent_recursion(x, y_latent, z_latent, net=net)

    if t_recursion_steps > 1:
        # NOTE(review): if the caller wraps this function in torch.autocast, the
        # low-precision weight copies cached during the no_grad burn-in are
        # understood to carry no autograd history and to be reused by the final
        # pass, detaching its outputs from the parameters. Clearing the cache
        # forces fresh casts; it does nothing when the cache is empty. Confirm
        # against the PyTorch version in use.
        torch.clear_autocast_cache()

    # final step: build graph through net
    # (no need to flip requires_grad_ on latents unless you want their .grad)
    y_latent = y_latent.requires_grad_(True)
    z_latent = z_latent.requires_grad_(True)
    with torch.set_grad_enabled(True):
        y_latent, z_latent = latent_recursion(x, y_latent, z_latent, net=net)
    return y_latent, z_latent

# Draw the three (1, 10) tensors on the CPU generator, in order x, y, z, then move each to the GPU.
x, y_latent, z_latent = (torch.randn(1, 10).cuda() for _ in range(3))

# One deep-recursion call; the names are rebound to its outputs.
y_latent, z_latent = deep_recursion(x, y_latent, z_latent, net=net)

# Random target class in [0, 10), sampled after the forward pass as before.
example_class = torch.randint(low=0, high=10, size=(1,)).cuda()

loss = F.cross_entropy(y_latent, example_class)
loss.backward()

# Inspect the gradient that reached the first linear layer.
print(net[0].weight.grad)


tensor([[ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
        [-8.8206e-02,  1.3895e-01,  2.8990e-03, -7.4096e-02, -5.9568e-03,
          3.8062e-02,  4.7087e-02, -5.5109e-02,  7.3397e-02, -1.1877e-01,
          7.1411e-02,  9.0347e-03,  6.8171e-02,  5.2951e-02,  3.3600e-02,
         -2.3978e-03, -9.0476e-04, -8.5729e-02,  9.9391e-02, -6.9879e-03,
          6.3970e-02,  2.6974e-02, -4.5227e-02, -7.3767e-02, -3.1551e-02,
          8.6635e-02,  1.2341e-01,  4.3791e-02,  5.0437e-02, -3.9473e-02],
        [ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  

In [24]:
import torch
from torch import nn
import torch.nn.functional as F

# Reset the global RNG so this cell's layers are initialised reproducibly.
torch.manual_seed(42)

# 30 -> 10 -> 30 MLP with a ReLU in between.
net = nn.Sequential(nn.Linear(30, 10), nn.ReLU(), nn.Linear(10, 30))
# Module.cuda() moves the parameters and returns the same module object.
net = net.cuda()

def latent_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    """Apply ``net`` repeatedly to the concatenation of ``x``, ``y_latent`` and ``z_latent``.

    The inputs are concatenated along the last dimension; each step computes
    ``net(state)`` and adds it back onto ``state``. The returned ``y`` and ``z``
    are slices of the last raw ``net`` output (not of the residual sum).

    Args:
        x: Input features; first ``x.shape[-1]`` positions of the state.
        y_latent: First latent; next ``y_latent.shape[-1]`` positions.
        z_latent: Second latent; following ``z_latent.shape[-1]`` positions.
        n_latent_reasoning_steps: How many times ``net`` is applied; must be >= 1.
        net: Module applied at each step; its output must be addable to its
            input. The default is captured once, when this ``def`` runs.

    Returns:
        Tuple ``(y, z)`` sliced from the last ``net`` output along the last dimension.

    Raises:
        ValueError: If ``n_latent_reasoning_steps`` is smaller than 1.
    """
    if n_latent_reasoning_steps < 1:
        # Zero iterations would leave no network output to slice from.
        raise ValueError(
            f"n_latent_reasoning_steps must be >= 1, got {n_latent_reasoning_steps}"
        )

    x_dim = x.shape[-1]
    y_end = x_dim + y_latent.shape[-1]
    z_end = y_end + z_latent.shape[-1]

    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor

    # Index the same (last) dimension that was used for the concatenation, so
    # extra leading dimensions do not shift the slices onto the wrong axis.
    y = output_tensor[..., x_dim:y_end]
    z = output_tensor[..., y_end:z_end]
    return y, z

def deep_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    """Refine the latents with frozen-parameter passes, then one trainable pass.

    During the first ``t_recursion_steps - 1`` passes the parameters of ``net``
    are temporarily marked ``requires_grad=False`` and the latents are detached.
    The final pass runs with the parameters' original flags restored.

    Args:
        x: Input tensor forwarded unchanged to every ``latent_recursion`` call.
        y_latent: First latent; replaced by each pass's ``y`` output.
        z_latent: Second latent; replaced by each pass's ``z`` output.
        t_recursion_steps: Total number of passes; only the last is trainable.
        net: Module used for every pass, including the final one.

    Returns:
        Tuple ``(y_latent, z_latent)`` from the final pass.
    """
    if t_recursion_steps > 1:
        # Remember each parameter's flag so deliberately frozen parameters stay
        # frozen afterwards instead of being switched on wholesale.
        params = list(net.parameters())
        original_flags = [p.requires_grad for p in params]

        # NOTE(review): freezing the parameters (rather than torch.no_grad()) is
        # presumably deliberate because the caller runs this under autocast --
        # confirm before swapping the mechanism.
        net.requires_grad_(False)
        try:
            for _ in range(t_recursion_steps - 1):
                # Rebind to the new outputs; the incoming tensors are not written to.
                y_latent, z_latent = latent_recursion(
                    x, y_latent.detach(), z_latent.detach(), net=net
                )
        finally:
            # Restore even if a pass raises, so net is never left frozen.
            for param, flag in zip(params, original_flags):
                param.requires_grad_(flag)

    # Use the same net as the warm-up passes, not the default one.
    y_latent, z_latent = latent_recursion(x, y_latent, z_latent, net=net)
    return y_latent, z_latent

# Inputs are sampled on the CPU generator and then moved to the GPU.
x = torch.randn(1,10).cuda()
y_latent = torch.randn(1,10).cuda()
z_latent = torch.randn(1,10).cuda()

# Autocast wraps only the forward pass and the loss computation.
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    y_latent, z_latent = deep_recursion(x, y_latent, z_latent, net=net)
    example_class = torch.randint(0, 10, (1,)).cuda()

    loss = F.cross_entropy(y_latent, example_class)

# Backward runs outside the autocast region, as the PyTorch AMP docs recommend.
loss.backward()
print(net[0].weight.grad)


tensor([[ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
        [-8.8379e-02,  1.3867e-01,  3.2959e-03, -7.3730e-02, -6.1646e-03,
          3.8086e-02,  4.6875e-02, -5.4443e-02,  7.2754e-02, -1.1914e-01,
          7.1289e-02,  9.0332e-03,  6.7383e-02,  5.2979e-02,  3.3691e-02,
         -2.3651e-03, -9.3079e-04, -8.5938e-02,  9.9609e-02, -6.9580e-03,
          6.3477e-02,  2.6978e-02, -4.5410e-02, -7.3730e-02, -3.1494e-02,
          8.6426e-02,  1.2402e-01,  4.3701e-02,  5.0293e-02, -3.9307e-02],
        [ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  

In [None]:
import torch
from torch import nn
import torch.nn.functional as F

# Seed before building the layers so their initial weights are reproducible.
torch.manual_seed(42)
# Prefer the GPU, but fall back to the CPU so the cell still runs without CUDA.
# Every tensor and module in this cell is placed via `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 30 -> 10 -> 30 MLP with a ReLU in between.
net = nn.Sequential(nn.Linear(30, 10), nn.ReLU(), nn.Linear(10, 30)).to(device)

def latent_recursion(x, y_latent, z_latent, n=3, net=net):
    """Apply ``net`` ``n`` times to the concatenation of ``x``, ``y_latent`` and ``z_latent``.

    The inputs are concatenated along the last dimension; each step computes
    ``net(state)`` and adds it back onto ``state``. The returned ``y`` and ``z``
    are slices of the last raw ``net`` output (not of the residual sum).

    Args:
        x: Input features; first ``x.shape[-1]`` positions of the state.
        y_latent: First latent; next ``y_latent.shape[-1]`` positions.
        z_latent: Second latent; following ``z_latent.shape[-1]`` positions.
        n: How many times ``net`` is applied; must be >= 1.
        net: Module applied at each step; its output must be addable to its
            input. The default is captured once, when this ``def`` runs.

    Returns:
        Tuple ``(y, z)`` sliced from the last ``net`` output along the last dimension.

    Raises:
        ValueError: If ``n`` is smaller than 1.
    """
    if n < 1:
        # Zero iterations would leave `out` unbound.
        raise ValueError(f"n must be >= 1, got {n}")

    x_dim, y_dim, z_dim = x.shape[-1], y_latent.shape[-1], z_latent.shape[-1]
    inp = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n):
        out = net(inp)
        inp = out + inp

    # Slice the last dimension (the concatenation axis) regardless of how many
    # leading dimensions the inputs have.
    y = out[..., x_dim:x_dim+y_dim]
    z = out[..., x_dim+y_dim:x_dim+y_dim+z_dim]
    return y, z

def deep_recursion(x, y_latent, z_latent, t=2, net=net):
    """Run ``t - 1`` graph-free warm-up passes, then one pass with autograd enabled.

    Two diagnostic lines are printed around the final pass: the current grad
    mode and whether the ``y`` output requires grad.

    Args:
        x: Input tensor forwarded unchanged to every ``latent_recursion`` call.
        y_latent: First latent; replaced by each warm-up pass's ``y`` output.
        z_latent: Second latent; replaced by each warm-up pass's ``z`` output.
        t: Total number of passes; all but the last run under ``torch.no_grad()``.
        net: Module handed to ``latent_recursion``.

    Returns:
        Tuple ``(y_out, z_out)`` from the final pass.
    """
    # warm-up: explicitly no grad
    for _ in range(t - 1):
        with torch.no_grad():
            y_latent, z_latent = latent_recursion(x, y_latent, z_latent, net=net)

    if t > 1:
        # NOTE(review): if the caller wraps this function in torch.autocast, the
        # low-precision weight copies cached during the no_grad warm-up are
        # understood to carry no autograd history and to be reused by the final
        # pass, which would leave y_out with requires_grad=False. Clearing the
        # cache forces fresh casts; it does nothing when the cache is empty.
        # Confirm against the PyTorch version in use.
        torch.clear_autocast_cache()

    # FINAL pass: explicitly enable grad, and DEBUG inside
    with torch.enable_grad():
        print("inside final pass - grad enabled?", torch.is_grad_enabled())
        y_out, z_out = latent_recursion(x, y_latent, z_latent, net=net)
        print("inside final pass - y_out.requires_grad?", y_out.requires_grad)
    return y_out, z_out

# Three (1, 10) tensors sampled directly on `device`, in the order x, y0, z0.
x, y0, z0 = (torch.randn(1, 10, device=device) for _ in range(3))

# Try WITHOUT autocast first
y_out, z_out = deep_recursion(x, y0, z0, net=net)
print("after final call - y_out.requires_grad?", y_out.requires_grad)

# Random target class in [0, 10), then a single backward pass.
target = torch.randint(low=0, high=10, size=(1,), device=device)
loss = F.cross_entropy(y_out, target)
loss.backward()

# Show the gradient of the first layer and whether it was populated at all.
print(net[0].weight.grad)
print("net[0].weight.grad is None?", net[0].weight.grad is None)

inside final pass - grad enabled? True
inside final pass - y_out.requires_grad? False
after final call - y_out.requires_grad? False


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn