In [20]:
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(42)

net = nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 30),
).cuda()

def latent_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    x_dim = x.shape[-1]
    y_latent_dim = y_latent.shape[-1]
    z_latent_dim = z_latent.shape[-1]
    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor
    y = output_tensor[:, x_dim:x_dim+y_latent_dim]
    z = output_tensor[:, x_dim+y_latent_dim:x_dim+y_latent_dim+z_latent_dim]
    return y, z

def deep_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    # Don't modify y_latent and z_latent in place within no_grad
    for _ in range(t_recursion_steps - 1):
        with torch.no_grad():
            y_latent_new, z_latent_new = latent_recursion(x, y_latent.detach(), z_latent.detach(), net=net)
        y_latent = y_latent_new
        z_latent = z_latent_new
    y_latent = y_latent.requires_grad_(True)
    z_latent = z_latent.requires_grad_(True)
    y_latent, z_latent = latent_recursion(x, y_latent, z_latent)
    return y_latent, z_latent

x = torch.randn(1,10).cuda()
y_latent = torch.randn(1,10).cuda()
z_latent = torch.randn(1,10).cuda()
# with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
y_latent, z_latent = deep_recursion(x, y_latent, z_latent, net=net)
example_class = torch.randint(0, 10, (1,)).cuda()

loss = F.cross_entropy(y_latent, example_class)
loss.backward()
print(net[0].weight.grad)


tensor([[ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
        [-8.8206e-02,  1.3895e-01,  2.8990e-03, -7.4096e-02, -5.9568e-03,
          3.8062e-02,  4.7087e-02, -5.5109e-02,  7.3397e-02, -1.1877e-01,
          7.1411e-02,  9.0347e-03,  6.8171e-02,  5.2951e-02,  3.3600e-02,
         -2.3978e-03, -9.0476e-04, -8.5729e-02,  9.9391e-02, -6.9879e-03,
          6.3970e-02,  2.6974e-02, -4.5227e-02, -7.3767e-02, -3.1551e-02,
          8.6635e-02,  1.2341e-01,  4.3791e-02,  5.0437e-02, -3.9473e-02],
        [ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  

In [38]:
import torch
from torch import nn
import torch.nn.functional as F

# --- Network and functions are unchanged ---
net = nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 30),
).cuda()

def latent_recursion(
    x: torch.Tensor,
    y_latent: torch.Tensor,
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    x_dim = x.shape[-1]
    y_latent_dim = y_latent.shape[-1]
    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor
    y = output_tensor[:, x_dim:x_dim+y_latent_dim]
    z = output_tensor[:, x_dim+y_latent_dim:] # Simplified slicing
    return y, z

def deep_recursion(
    x: torch.Tensor,
    y_latent: torch.Tensor,
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    # This loop runs in default precision (float32) and without gradients
    for _ in range(t_recursion_steps - 1):
        with torch.no_grad():
            y_latent_new, z_latent_new = latent_recursion(x, y_latent.detach(), z_latent.detach())
        y_latent = y_latent_new
        z_latent = z_latent_new

    # Make the float32 tensors require gradients
    y_latent = y_latent.requires_grad_(True)
    z_latent = z_latent.requires_grad_(True)

    # **THE FIX: Apply autocast ONLY to the final differentiable step**
    with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
        y_latent_final, z_latent_final = latent_recursion(x, y_latent, z_latent)

    return y_latent_final, z_latent_final

# --- Training loop ---
x = torch.randn(1,10).cuda()
y_latent_start = torch.randn(1,10).cuda()
z_latent_start = torch.randn(1,10).cuda()
scaler = torch.amp.GradScaler()

# No top-level autocast here
y_latent_out, z_latent_out = deep_recursion(x, y_latent_start, z_latent_start)
example_class = torch.randint(0, 10, (1,)).cuda()

# The loss calculation can also be inside the autocast block if preferred,
# but it works here too because y_latent_out is now correctly connected.
# For consistency, we put the autocast around the operation that needs it.
loss = F.cross_entropy(y_latent_out, example_class)

scaler.scale(loss).backward()
print(net[0].weight.grad)

tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan],
        [-0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., -0., 0., 0., 0.,
         0., -0., 0., -0., -0., 0.],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan],
        [-0., 0., 0., -0., 0., 0., -0., 0., -0., 0., 0., 0., -0., 0., -0., 0., -0., 0., 0., -0., -0., 0., 0., 0.,
         0., -0., 0., -0., -0., 0.],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan,
         nan, nan, nan, nan, nan, nan],
        [nan, nan, nan