In [23]:
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(42)

net = nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 30),
).cuda()

def latent_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    x_dim = x.shape[-1]
    y_latent_dim = y_latent.shape[-1]
    z_latent_dim = z_latent.shape[-1]
    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor
    y = output_tensor[:, x_dim:x_dim+y_latent_dim]
    z = output_tensor[:, x_dim+y_latent_dim:x_dim+y_latent_dim+z_latent_dim]
    return y, z

def deep_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    # Don't modify y_latent and z_latent in place within no_grad
    for _ in range(t_recursion_steps - 1):
        with torch.no_grad():
            y_latent_new, z_latent_new = latent_recursion(x, y_latent.detach(), z_latent.detach(), net=net)
        y_latent = y_latent_new
        z_latent = z_latent_new
    y_latent = y_latent.requires_grad_(True)
    z_latent = z_latent.requires_grad_(True)
    y_latent, z_latent = latent_recursion(x, y_latent, z_latent)
    return y_latent, z_latent

x = torch.randn(1,10).cuda()
y_latent = torch.randn(1,10).cuda()
z_latent = torch.randn(1,10).cuda()
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    y_latent, z_latent = deep_recursion(x, y_latent, z_latent, net=net)
    example_class = torch.randint(0, 10, (1,)).cuda()

    loss = F.cross_entropy(y_latent, example_class)
    loss.backward()
    print(net[0].weight.grad)


None


In [24]:
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(42)

net = nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 30),
).cuda()

def latent_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    x_dim = x.shape[-1]
    y_latent_dim = y_latent.shape[-1]
    z_latent_dim = z_latent.shape[-1]
    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor
    y = output_tensor[:, x_dim:x_dim+y_latent_dim]
    z = output_tensor[:, x_dim+y_latent_dim:x_dim+y_latent_dim+z_latent_dim]
    return y, z

def deep_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    # Don't modify y_latent and z_latent in place within no_grad
    for _ in range(t_recursion_steps - 1):
        net.requires_grad_(False)
        y_latent_new, z_latent_new = latent_recursion(x, y_latent.detach(), z_latent.detach(), net=net)
        net.requires_grad_(True)
        y_latent = y_latent_new
        z_latent = z_latent_new
    y_latent, z_latent = latent_recursion(x, y_latent, z_latent)
    return y_latent, z_latent

x = torch.randn(1,10).cuda()
y_latent = torch.randn(1,10).cuda()
z_latent = torch.randn(1,10).cuda()
with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
    y_latent, z_latent = deep_recursion(x, y_latent, z_latent, net=net)
    example_class = torch.randint(0, 10, (1,)).cuda()

    loss = F.cross_entropy(y_latent, example_class)
    loss.backward()
    print(net[0].weight.grad)


tensor([[ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,
          0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00],
        [-8.8379e-02,  1.3867e-01,  3.2959e-03, -7.3730e-02, -6.1646e-03,
          3.8086e-02,  4.6875e-02, -5.4443e-02,  7.2754e-02, -1.1914e-01,
          7.1289e-02,  9.0332e-03,  6.7383e-02,  5.2979e-02,  3.3691e-02,
         -2.3651e-03, -9.3079e-04, -8.5938e-02,  9.9609e-02, -6.9580e-03,
          6.3477e-02,  2.6978e-02, -4.5410e-02, -7.3730e-02, -3.1494e-02,
          8.6426e-02,  1.2402e-01,  4.3701e-02,  5.0293e-02, -3.9307e-02],
        [ 0.0000e+00, -0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00,
         -0.0000e+00, -0.0000e+00,  