# Numerical Check of Theorem
In response to the reviewers' request, we numerically verify our derivation of the state-to-state Jacobian.

Approach: Iterate a single token until convergence in a model with random weights, starting from a random initial hidden state.


In [134]:
from mamba_deq.models import ImplicitModel, ExplicitModel
from omegaconf import OmegaConf as om
import torch
from torch.autograd.functional import jacobian

# Seed PyTorch's global RNG and ask PyTorch to use deterministic algorithms.
torch.manual_seed(0)
torch.use_deterministic_algorithms(True)
# IPython magic (not plain Python): sets the cuBLAS workspace environment variable.
# NOTE(review): presumably required by cuBLAS once deterministic algorithms are
# enabled on CUDA — confirm against the PyTorch reproducibility notes.
%env CUBLAS_WORKSPACE_CONFIG=:4096:8

# TODO
# - 64 bit precision
#   NOTE(review): `precision = torch.float64` is set in the next cell, so this
#   item may already be done — confirm and remove.

In [135]:
# Global numerical settings used by the cells below.
d_model = 64
precision = torch.float64

# Load the DEQ configuration and override the solver settings for this check.
deq_params = om.load('config/deq/phantom.yaml')
_SOLVER_OVERRIDES = {
    'f_tol': 1e-6,
    'f_max_iter': 512,
    'eval_f_max_iter': 512,
}
for _key, _value in _SOLVER_OVERRIDES.items():
    deq_params.solver[_key] = _value

In [136]:
# Keyword arguments common to both model wrappers.
_shared_kwargs = dict(
    d_model=d_model,
    n_layer=1,
    d_inner=0,
    pre_norm=True,
    device='cuda',
    dtype=precision,
)
# Mixer configuration: one state dimension, no expansion, head size d_model.
_block_cfg = {'d_state': 1, 'expand': 1, 'headdim': d_model}

# Each constructor receives its own copy of the block configuration, as in
# the original cell where two separate dict literals were passed.
model = ImplicitModel(
    deq_params,
    pretrain_steps=0,
    block_cfg=dict(_block_cfg),
    **_shared_kwargs,
)
explicit_model = ExplicitModel(
    block_cfg=dict(_block_cfg),
    **_shared_kwargs,
)
# Point the explicit model at the very same layer objects as the implicit one.
explicit_model.layers = model.layers

In [137]:
# Allocate the model's inference cache for a single sequence on the GPU.
# `initial_state[0]` is later unpacked as (conv_state, ssm_state).
device = 'cuda'
batch_size = 1
initial_state = model.allocate_inference_cache(batch_size=batch_size, device=device)

In [138]:
def _to_target(tensor):
    """Move a tensor to the notebook's device and working precision."""
    return tensor.to(device, precision)


# The random draws are kept in the original order (x, u, conv_state, h).
x = _to_target(torch.randn(1, d_model))              # input passed to the explicit step
u = _to_target(torch.randn(1, 2 * d_model + 3))      # injected input for the mixer
z = _to_target(torch.zeros(1, d_model))              # initial iterate for the implicit step
conv_state = _to_target(torch.rand_like(initial_state[0][0]))
h = _to_target(torch.rand(1, 1, d_model, 1)) * 2     # SSM state, uniform in [0, 2)
model = model.to(device, precision)

In [139]:
def get_inference_cache(h):
    """Build a one-layer inference cache from the fixed global ``conv_state`` and the SSM state ``h``.

    Returns a list with a single ``(conv_state, h)`` tuple.
    """
    layer_cache = (conv_state, h)
    return [layer_cache]

## Compare Explicit and Implicit Model Jacobian

In [140]:
def step(h):
    """Run one implicit sequential step from SSM state ``h``.

    Calls ``model._sequential_step`` with the global ``z`` and ``u`` and a
    cache built from ``h``, prints the reported step count and relative
    difference, and returns ``(new_h, fixed_point)`` where ``new_h`` is the
    SSM-state entry of the first layer's updated cache.
    """
    cache_in = get_inference_cache(h)
    fixed_point, cache_out, _, r, s = model._sequential_step(z, u, cache_in)
    print(f'Convergence after {s} steps with rel diff {r}')
    return cache_out[0][1], fixed_point

In [141]:
# get fixed point
# Fixed point for the current state; no autograd graph is needed for this call.
with torch.no_grad():
    z_fp = step(h)[1]


def _next_state(state):
    """Return only the updated SSM state of the implicit step."""
    return step(state)[0]


# Autograd Jacobian of the new SSM state with respect to the previous one.
J = jacobian(_next_state, h, strict=True)

In [142]:
def step_explicit(h):
    """Run a single ``model._step`` from SSM state ``h`` and return the new SSM state.

    The step is called with the global input ``x`` and no injected inputs;
    the returned value is the SSM-state entry of the first layer's updated
    cache.

    NOTE(review): ``explicit_model`` is constructed earlier in the notebook
    but this function calls ``model._step`` — confirm that is intended.
    """
    cache_in = get_inference_cache(h)
    _, cache_out = model._step(x, injected_inputs=None, inference_cache=cache_in)
    return cache_out[0][1]

In [143]:
# Autograd Jacobian of the explicit step's new SSM state with respect to h.
J_explicit = jacobian(step_explicit, h, strict=True)

In [144]:
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np

# Magnitudes of both Jacobians, so they can be drawn on a logarithmic colour scale.
data1 = np.abs(J.squeeze().cpu().numpy())
data2 = np.abs(J_explicit.squeeze().cpu().numpy())

fig, axes = plt.subplots(1, 2, figsize=(6, 3))

# One LogNorm instance is shared by both images so their colours are comparable.
norm = LogNorm(vmin=1e-4, vmax=1)
print(norm.vmin, norm.vmax)
img1 = axes[0].imshow(data1, norm=norm)
img2 = axes[1].imshow(data2, norm=norm)

# A single colorbar attached to both axes.
fig.colorbar(img2, ax=axes, orientation='vertical', fraction=0.046, pad=0.04)

# Titles and axis labels.
for ax, title in zip(axes, ('Implicit Jacobian', 'Explicit Jacobian')):
    ax.set_title(title)
    ax.set_xlabel('State index')
    ax.set_ylabel('State index')
fig.subplots_adjust(left=0.1, right=0.8, wspace=0.4)

In [145]:
# Histogram of the implicit Jacobian's entry magnitudes on a log-spaced grid.
fig, ax = plt.subplots(1, 1, figsize=(3, 3))
bins = np.logspace(-8, 0, 20)
# The bins are strictly positive, so signed values would silently lose every
# negative entry; histogram the magnitudes instead (as the heat map above does).
ax.hist(np.abs(J.squeeze().cpu().numpy()).flatten(), bins=bins)
ax.set_xscale('log')

## Check Theorem Equation

### Get Mamba Internal Variables
We need the derivatives of $\Lambda$ and $u$, which we obtain from the Mamba code.

In [146]:
# Copyright (c) 2024, Tri Dao, Albert Gu.
import math
from typing import Tuple
from einops import rearrange
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

def mamba(self, hidden_states, injected_inputs, conv_state, ssm_state) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Single-token Mamba2 step that also exposes the discretized SSM terms.

        Adapted from the Mamba code (see the copyright note at the top of this cell).
        The updated convolution and SSM states are computed internally but are NOT
        returned; instead the function returns the quantities needed to check the
        theorem.

        Args:
            self: the mixer module whose parameters are used (passed explicitly, e.g. ``model.layers[0].mixer``)
            hidden_states: input to the Mamba2 layer (B, D_model)
            injected_inputs: injected input to the Mamba2 layer (B, D_in_proj), or None
            conv_state: carry for convolution (B, D_conv, W)
            ssm_state: carry for state-space model (B, nheads, headdim, D_state)

        Returns:
            out: output for this step (B, D_model)
            dA: discretized state transition ``exp(dt * A)`` (B, nheads)
            dBx: discretized input term added to the state (B, nheads, headdim, D_state)

        """
        dtype = hidden_states.dtype

        # Accept a (B, 1, D) input by dropping the singleton sequence dimension.
        if hidden_states.dim() > 2:
            assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
            hidden_states = hidden_states.squeeze(1)
        zxbcdt = self.in_proj(hidden_states)  # (B 2D)

        # inject inputs (added in place to the projection output)
        if injected_inputs is not None:
            if injected_inputs.dim() > 2:
                assert injected_inputs.shape[1] == 1, "Only support decoding with 1 token at a time for now"
                injected_inputs = injected_inputs.squeeze(1)
            zxbcdt += injected_inputs

        # Whatever width is left after the SSM, B/C and dt parts is split evenly into z0/x0.
        d_mlp = (zxbcdt.shape[-1] - 2 * self.d_ssm - 2 * self.ngroups * self.d_state - self.nheads) // 2
        z0, x0, z, xBC, dt = torch.split(
            zxbcdt, [d_mlp, d_mlp, self.d_ssm, self.d_ssm + 2 * self.ngroups * self.d_state, self.nheads], dim=-1
        )

        # Conv step
        # torch.roll returns a new tensor, so the caller's conv_state is not modified below.
        new_conv_state = torch.roll(conv_state, shifts=-1, dims=-1)  # Update state (B D W)
        new_conv_state[:, :, -1] = xBC
        xBC = torch.sum(new_conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
        if self.conv1d.bias is not None:
            xBC = xBC + self.conv1d.bias
        xBC = self.act(xBC).to(dtype=dtype)

        x, B, C = torch.split(xBC, [self.d_ssm, self.ngroups * self.d_state, self.ngroups * self.d_state], dim=-1)
        # NOTE(review): `.float()` casts A_log to float32 even when the notebook runs in
        # float64 — confirm this matches the model's own step implementation.
        A = -torch.exp(self.A_log.float())  # (nheads,)

        # SSM step
        assert self.ngroups == 1, "Only support ngroups=1 for this inference code path"
        # Discretize A and B
        dt = F.softplus(dt + self.dt_bias.to(dtype=dt.dtype))  # (batch, nheads)
        dA = torch.exp(dt * A)  # (batch, nheads)
        x = rearrange(x, "b (h p) -> b h p", p=self.headdim)
        dBx = torch.einsum("bh,bn,bhp->bhpn", dt, B, x)
        # State update: new_state = state * dA + dBx (dA broadcast over headdim and d_state).
        new_ssm_state = ssm_state * rearrange(dA, "b h -> b h 1 1") + dBx
        y = torch.einsum("bhpn,bn->bhp", new_ssm_state.to(dtype), C)
        y = y + rearrange(self.D.to(dtype), "h -> h 1") * x
        y = rearrange(y, "b h p -> b (h p)")
        # Gate with z: either a plain activation gate or inside the norm module.
        if not self.rmsnorm:
            y = y * self.act(z)  # (B D)
        if self.rmsnorm:
            y = self.norm(y, z)
        if d_mlp > 0:
            y = torch.cat([F.silu(z0) * x0, y], dim=-1)
        out = self.out_proj(y)

        # Return the transition terms instead of the updated caches.
        return out, dA, dBx

In [147]:
# Smoke test of the copied step on the first layer's mixer; as the last
# expression of the cell, the (out, dA, dBx) result is displayed by the notebook.
# Note this call uses the allocated `initial_state`, not (conv_state, h).
mamba(model.layers[0].mixer, z_fp, u, *initial_state[0])

In [148]:
def block(self, z, u, inference_cache=None):
    """Apply one residual block around the copied ``mamba`` step.

    Args:
        self: the layer module (provides ``pre_norm``, ``norm``, ``mixer``
            and ``residual_in_fp32``).
        z: block input; also used as the residual branch.
        u: injected input forwarded to ``mamba``.
        inference_cache: tuple that is unpacked into ``mamba``'s cache
            arguments; a non-tuple value is wrapped into a one-element tuple.

    Returns:
        ``(out, dA, dBx)``: the block output (residual plus mixer output) and
        the two extra tensors returned by ``mamba``.
    """
    # Pre-norm formulation: normalise only the mixer input, not the residual.
    mixer_input = z
    if self.pre_norm:
        mixer_input = self.norm(z.to(dtype=self.norm.weight.dtype))

    # Normalise the cache argument into a tuple for star-unpacking.
    if isinstance(inference_cache, tuple):
        cache_args = inference_cache
    else:
        cache_args = (inference_cache,)

    # Time mixer followed by the skip connection.
    mixer_out, dA, dBx = mamba(self.mixer, mixer_input, u, *cache_args)
    out = z + mixer_out
    if self.residual_in_fp32:
        out = out.to(torch.float32)
    return out, dA, dBx

In [149]:
def model_func(z, h):
    """Evaluate one application of the first layer plus the final norm.

    Uses the global injected input ``u`` and a cache built from ``h``.
    Returns ``(model.norm_f(block_output), A, Bx)`` where ``A`` and ``Bx``
    are the two extra tensors returned by ``block``.
    """
    layer_cache = get_inference_cache(h)[0]
    block_out, A, Bx = block(model.layers[0], z, u, layer_cache)
    normed = model.norm_f(block_out)
    return normed, A, Bx

In [150]:
# Verify model outputs match and fixed point is indeed found
# out, _, _ = model_func(z_fp, h)
# iter_out, _, _, _, s = model._sequential_step(z_fp, u, get_inference_cache(h))
# print(torch.cat([z_fp, iter_out, out], dim=0)[:,:8])
# s

In [151]:
# Values of the two extra outputs at the fixed point (no graph needed) ...
with torch.no_grad():
    A, Bx = model_func(z_fp, h)[1:]
# ... and the Jacobians of all three outputs with respect to both inputs.
jac = jacobian(model_func, (z_fp, h))

In [152]:
# jac[i][j] is the Jacobian of output i of model_func (0: normed output, 1: A, 2: Bx)
# with respect to input j (0: z, 1: h); singleton dimensions are squeezed away.
dzdz, dzdh, dAdz, dBxdz = (
    jac[out_idx][in_idx].squeeze()
    for out_idx, in_idx in ((0, 0), (0, 1), (1, 0), (2, 0))
)
print(dzdz.shape, dzdh.shape, dAdz.shape, dBxdz.shape)


In [153]:
# Sensitivity of the fixed point to the state: dfdh = (I - dzdz)^{-1} dzdh.
# The identity is created directly with dzdz's dtype on the target device
# instead of relying on implicit type promotion.
G = torch.eye(d_model, device=device, dtype=dzdz.dtype) - dzdz
# The diagnostics are computed on G = I - dzdz, so label them accordingly.
print('I - dzdz has rank', torch.linalg.matrix_rank(G).item())
print('I - dzdz condition number:', torch.linalg.cond(G).item())
# Solve G @ dfdh = dzdh directly rather than forming the explicit inverse and
# multiplying, which is the numerically preferred way to apply G^{-1}.
dfdh = torch.linalg.solve(G, dzdh)

In [154]:
# Closed-form state-to-state Jacobian, assembled from its three terms.
# Term 1: identity scaled by A (A broadcasts over the matrix).
_direct_term = torch.eye(d_model).to(device) * A
# Term 2: outer product of h with the row vector dAdz @ dfdh.
_transition_term = torch.einsum('k,kj,i->ij', dAdz, dfdh, h.squeeze())
# Term 3: dBxdz @ dfdh.
_input_term = torch.einsum('ij,jk->ik', dBxdz, dfdh)
theorem = _direct_term + _transition_term + _input_term

In [157]:
from matplotlib.colors import Normalize

# Autograd Jacobian, formula Jacobian and their element-wise difference.
data1 = J.squeeze().cpu().numpy()
data2 = theorem.squeeze().cpu().numpy()
diff  = data1 - data2

fig, axes = plt.subplots(1, 3, figsize=(9, 3))

# Value ranges, printed for reference.
print(data1.min(), data2.min(), data1.max(), data2.max())
print(diff.min(), diff.max())

# The two Jacobian panels share a linear colour scale; the difference panel
# shows magnitudes on its own logarithmic scale.
norm = Normalize(vmin=-0.1, vmax=0.1)
lognorm = LogNorm(vmin=1e-8, vmax=1)
cmap = 'coolwarm'
img1 = axes[0].imshow(data1, cmap=cmap, norm=norm)
img2 = axes[1].imshow(data2, cmap=cmap, norm=norm)
img3 = axes[2].imshow(np.abs(diff), norm=lognorm)

# One horizontal colorbar per panel.
for img, ax in zip((img1, img2, img3), axes):
    fig.colorbar(img, ax=ax, orientation='horizontal', fraction=0.046, pad=0.2)

# Titles and axis labels.
for ax, title in zip(axes, ('Autograd Jacobian', 'Formula Jacobian', 'Difference')):
    ax.set_title(title)
    ax.set_xlabel('State index')
    ax.set_ylabel('State index')
fig.subplots_adjust(left=0.1, right=0.8, wspace=0.4)
fig.savefig('implicit_jacobian_check.png', dpi=300)

In [None]:
# Histograms of entry magnitudes: autograd Jacobian (left) and the
# autograd-minus-formula difference (right), on a shared log-spaced grid.
fig, axes = plt.subplots(1, 2, figsize=(6, 3))
bins = np.logspace(-8, 0, 20)
# The bins are strictly positive, so signed values would silently lose every
# negative entry; histogram the magnitudes instead.
axes[0].hist(np.abs(data1).flatten(), bins=bins)
axes[1].hist(np.abs(diff).flatten(), bins=bins)
axes[0].set_title('|Autograd Jacobian|')
axes[1].set_title('|Difference|')

for ax in axes:
    ax.set_xscale('log')

In [160]:
# Median (of the signed entries) of the autograd Jacobian, the formula
# Jacobian and their difference, one value per line.
print(*(np.median(values) for values in (data1, data2, diff)), sep='\n')