In [38]:
from functools import partial

import torch
from matplotlib import pyplot as plt

In [None]:
%load_ext autoreload
%autoreload 2

import torch

torch.set_printoptions(linewidth=400, threshold=100000)

import sys

sys.path.append("../..")

from mlstm_kernels.components.ln import MultiHeadLayerNorm
from mlstm_kernels.mlstm.chunkwise.max_triton_fwbw_v3 import (
    mlstm_chunkwise_max_triton_v3,
)

import pandas as pd
import torch.nn.functional as F

# Interactive plot of the mLSTM Signal Propagation

In [40]:
# dtype and device passed to every random input tensor (Q, K, V and the gate
# pre-activations) created by the helper functions in this notebook.
DTYPE = torch.bfloat16
DEVICE = torch.device("cuda:0")

## Code: mLSTM implementations

In [41]:
def mlstm_paper_unstable_fgate(
    matQ: torch.Tensor,
    matK: torch.Tensor,
    matV: torch.Tensor,
    vecI: torch.Tensor,
    vecF: torch.Tensor,
    eps: float = 1e-6,
    mstate_mode: str = "paper",
) -> tuple:
    """Parallel mLSTM forward pass with the cumsum-difference forget-gate matrix.

    The pairwise log forget-gate matrix is built as the difference of a single
    cumulative sum of ``logsigmoid(vecF)`` over the sequence axis.

    Args:
        matQ: queries, shape (B, NH, S, DHQK).
        matK: keys, shape (B, NH, S, DHQK).
        matV: values, shape (B, NH, S, DHV).
        vecI: input-gate pre-activations, shape (B, NH, S).
        vecF: forget-gate pre-activations, shape (B, NH, S).
        eps: constant added to the normalizer before the division.
        mstate_mode: how the normalizer ``vecN`` is computed:
            "paper"              -> max(|row sum of matCtilde|, exp(-m))
            "exp_minus_m_to_one" -> max(|row sum of matCtilde|, 1)
            "sum_only"           -> |row sum of matCtilde|
            "denom_one"          -> constant 1

    Returns:
        A 10-tuple ``(matH, vecM, vecN, matLogD, matLogD_stabilized, matD,
        matCtilde, matC, vecLogSigF, vecLogSigF_cumsum)``. ``vecM`` and ``vecN``
        have their trailing singleton axis squeezed away.

    Raises:
        ValueError: if ``mstate_mode`` is not one of the modes listed above.
    """
    import math

    B, NH, S, DHQK = matQ.shape
    assert matK.shape == (B, NH, S, DHQK)
    assert vecI.shape == (B, NH, S)
    assert vecF.shape == (B, NH, S)

    # Reject unknown modes before allocating any (S, S) matrices.
    if mstate_mode not in ("paper", "exp_minus_m_to_one", "sum_only", "denom_one"):
        raise ValueError(f"mstate_mode {mstate_mode} not recognized")

    _dtype, _device = matQ.dtype, matQ.device

    vecLogSigF = F.logsigmoid(vecF)  # (B, NH, S)
    vecLogSigF_cumsum = vecLogSigF.cumsum(-1)

    # Entry [i, j] = cumsum[i] - cumsum[j], i.e. the summed log forget gates j+1..i.
    matLogSigF = vecLogSigF_cumsum[:, :, :, None] - vecLogSigF_cumsum[:, :, None, :]

    ltr = torch.tril(
        torch.ones(
            (S, S),
            dtype=torch.bool,
            device=_device,
        )
    )

    # Causal mask: positions above the diagonal get -inf so that exp() yields 0.
    matLogSigF_mask = torch.where(ltr, matLogSigF, -float("inf"))

    matLogD = matLogSigF_mask + vecI[:, :, None, :]

    # Row-wise maximum, subtracted before exponentiating.
    vecM, _ = torch.max(matLogD, dim=-1, keepdim=True)  # (B, NH, S, 1)
    matLogD_stabilized = matLogD - vecM

    matD = torch.exp(matLogD_stabilized)  # (B, NH, S, S)

    matS = (matQ @ matK.transpose(-2, -1)) / math.sqrt(DHQK)  # (B, NH, S, S)

    matCtilde = matS * matD  # (B, NH, S, S)

    if mstate_mode == "denom_one":
        vecN = torch.tensor([1.0], device=_device, dtype=_dtype)
    else:
        vecSumAbs = matCtilde.sum(dim=-1, keepdim=True).abs()  # (B, NH, S, 1)
        if mstate_mode == "paper":
            vecN = torch.maximum(vecSumAbs, torch.exp(-vecM))
        elif mstate_mode == "exp_minus_m_to_one":
            vecN = torch.maximum(
                vecSumAbs,
                torch.tensor([1.0], device=_device, dtype=_dtype),
            )
        else:  # "sum_only"
            vecN = vecSumAbs

    # (B, NH, S, S)
    matC = matCtilde / (vecN + eps)

    matH = matC @ matV  # (B, NH, S, DH)

    return (
        matH,
        vecM.squeeze(-1),
        vecN.squeeze(-1),
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        vecLogSigF_cumsum,
    )

In [42]:
def mlstm_paper_stable_fgate(
    matQ: torch.Tensor,
    matK: torch.Tensor,
    matV: torch.Tensor,
    vecI: torch.Tensor,
    vecF: torch.Tensor,
    eps: float = 1e-6,
    mstate_mode: str = "paper",
) -> tuple:
    """Parallel mLSTM forward pass with a directly accumulated forget-gate matrix.

    Instead of subtracting two entries of one long cumulative sum, each column
    of the pairwise log forget-gate matrix is accumulated on its own: the log
    gates are broadcast to (S, S), masked to the strict lower triangle and
    cumulatively summed along the row axis.

    Args:
        matQ: queries, shape (B, NH, S, DHQK).
        matK: keys, shape (B, NH, S, DHQK).
        matV: values, shape (B, NH, S, DHV).
        vecI: input-gate pre-activations, shape (B, NH, S).
        vecF: forget-gate pre-activations, shape (B, NH, S).
        eps: constant added to the normalizer before the division.
        mstate_mode: how the normalizer ``vecN`` is computed:
            "paper"              -> max(|row sum of matCtilde|, exp(-m))
            "exp_minus_m_to_one" -> max(|row sum of matCtilde|, 1)
            "sum_only"           -> |row sum of matCtilde|
            "denom_one"          -> constant 1

    Returns:
        A 10-tuple ``(matH, vecM, vecN, matLogD, matLogD_stabilized, matD,
        matCtilde, matC, vecLogSigF, None)``. The last slot is ``None`` because
        this variant never forms a 1-D cumulative sum. ``vecM`` and ``vecN``
        have their trailing singleton axis squeezed away.

    Raises:
        ValueError: if ``mstate_mode`` is not one of the modes listed above.
    """
    import math

    B, NH, S, DHQK = matQ.shape
    assert matK.shape == (B, NH, S, DHQK)
    assert vecI.shape == (B, NH, S)
    assert vecF.shape == (B, NH, S)

    # Reject unknown modes before allocating any (S, S) matrices.
    if mstate_mode not in ("paper", "exp_minus_m_to_one", "sum_only", "denom_one"):
        raise ValueError(f"mstate_mode {mstate_mode} not recognized")

    _dtype, _device = matQ.dtype, matQ.device

    vecLogSigF = F.logsigmoid(vecF)  # (B, NH, S)

    # Row i holds logsigmoid(f_i) in every column j < i and 0 elsewhere, so the
    # cumulative sum down the rows gives entry [i, j] = sum of log gates j+1..i.
    matLogSigF_tril = vecLogSigF[:, :, :, None].repeat(1, 1, 1, S).tril(-1)
    matLogSigF_cum = matLogSigF_tril.cumsum(-2)

    ltr = torch.tril(
        torch.ones(
            (S, S),
            dtype=torch.bool,
            device=_device,
        )
    )

    # Causal mask: positions above the diagonal get -inf so that exp() yields 0.
    matLogSigF_mask = torch.where(ltr, matLogSigF_cum, -float("inf"))

    matLogD = matLogSigF_mask + vecI[:, :, None, :]

    # Row-wise maximum, subtracted before exponentiating.
    vecM, _ = torch.max(matLogD, dim=-1, keepdim=True)  # (B, NH, S, 1)
    matLogD_stabilized = matLogD - vecM

    matD = torch.exp(matLogD_stabilized)  # (B, NH, S, S)

    matS = (matQ @ matK.transpose(-2, -1)) / math.sqrt(DHQK)  # (B, NH, S, S)

    matCtilde = matS * matD  # (B, NH, S, S)

    if mstate_mode == "denom_one":
        vecN = torch.tensor([1.0], device=_device, dtype=_dtype)
    else:
        vecSumAbs = matCtilde.sum(dim=-1, keepdim=True).abs()  # (B, NH, S, 1)
        if mstate_mode == "paper":
            vecN = torch.maximum(vecSumAbs, torch.exp(-vecM))
        elif mstate_mode == "exp_minus_m_to_one":
            vecN = torch.maximum(
                vecSumAbs,
                torch.tensor([1.0], device=_device, dtype=_dtype),
            )
        else:  # "sum_only"
            vecN = vecSumAbs

    # (B, NH, S, S)
    matC = matCtilde / (vecN + eps)

    matH = matC @ matV  # (B, NH, S, DH)

    return (
        matH,
        vecM.squeeze(-1),
        vecN.squeeze(-1),
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        None,
    )

## Code: Plotting code

### matplotlib plots

In [43]:
def make_h_output_plot_mlstm_with_internals(
    mlstm_func,
    B,
    NH,
    S,
    DHQK,
    DHV,
    vecI_offset,
    vecF_offset,
    seed=0,
    plot_max_min=True,
    vecI_init_fn=torch.randn,
    vecF_init_fn=torch.randn,
):
    """Run ``mlstm_func`` on seeded random inputs and plot its output statistics with matplotlib.

    Q, K, V and the gate pre-activations are drawn with ``DTYPE`` on ``DEVICE``
    after seeding torch with ``seed``; the offsets are added to the gate draws.
    Plotted are the per-position mean of the hidden output over the last axis,
    a +/- one standard deviation band, optionally the max/min over the last
    axis, and the m state if ``mlstm_func`` returned a tuple.

    Returns:
        The 10-tuple returned by ``mlstm_func``. If ``mlstm_func`` returned a
        bare tensor, that tensor followed by nine ``None`` placeholders.
    """
    torch.manual_seed(seed)
    tensor_kwargs = {"dtype": DTYPE, "device": DEVICE}
    matQ = torch.randn((B, NH, S, DHQK), **tensor_kwargs)
    matK = torch.randn((B, NH, S, DHQK), **tensor_kwargs)
    matV = torch.randn((B, NH, S, DHV), **tensor_kwargs)
    vecI = vecI_offset + vecI_init_fn((B, NH, S), **tensor_kwargs)
    vecF = vecF_offset + vecF_init_fn((B, NH, S), **tensor_kwargs)

    result = mlstm_func(matQ, matK, matV, vecI, vecF)

    # A bare tensor result is padded so that both cases unpack the same way.
    internals = result if isinstance(result, tuple) else (result,) + (None,) * 9
    (
        h_out,
        m_out,
        n_out,
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        vecLogSigF_cumsum,
    ) = internals

    def _as_series(tensor):
        # Flatten to a 1-D float32 numpy array for plotting.
        return tensor.flatten().cpu().float().numpy()

    mean_series = _as_series(h_out.mean(-1))
    std_series = _as_series(h_out.std(-1))
    max_series = _as_series(h_out.max(-1)[0])
    min_series = _as_series(h_out.min(-1)[0])

    # The m state is drawn first so that it takes the first color of the cycle.
    if m_out is not None:
        plt.plot(_as_series(m_out), label="m_state")

    plt.plot(mean_series, label="h_out_mean")
    plt.fill_between(
        range(len(mean_series)),
        mean_series - std_series,
        mean_series + std_series,
        alpha=0.5,
    )
    if plot_max_min:
        for series, series_label in ((max_series, "h_out_max"), (min_series, "h_out_min")):
            plt.plot(series, label=series_label)

    plt.legend()
    print(f"vecI_offs: {vecI_offset}, vecF_offs: {vecF_offset}")
    print(f"S: {S}, B: {B}, NH: {NH}, DHQK: {DHQK}, DHV: {DHV}")
    plt.show()

    return (
        h_out,
        m_out,
        n_out,
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        vecLogSigF_cumsum,
    )

In [44]:
# Setting vecF_offset to -6 (or almost any negative value) produces an m_state spike.
# Problem size: inputs are created with shapes (B, NH, S, DHQK) / (B, NH, S, DHV).
B = 1
NH = 1
S = 8192
D = 1024
DHQK = D
DHV = D
# Constant offsets added to the randomly drawn gate pre-activations.
vecI_offset = -3.0
vecF_offset = 5.0
# Callables that draw the (B, NH, S) gate pre-activations before the offset is added.
vecI_init_fn = torch.randn
vecF_init_fn = torch.randn

In [None]:
# Side-by-side comparison of normalizer modes for the cumsum-difference
# ("unstable") forget gate, the Triton kernel and the directly accumulated
# ("stable") forget gate. Every call reseeds with seed=0, so all runs see the
# same random inputs.
# NOTE: the stable-section assignments further down rebind h_out, m_out, ... and
# the *_do names, so after this cell they hold the stable-variant results.
print("paper version")
(
    h_out,
    m_out,
    n_out,
    matLogD,
    matLogD_stabilized,
    matD,
    matCtilde,
    matC,
    vecLogSigF,
    vecLogSigF_cumsum,
) = make_h_output_plot_mlstm_with_internals(
    mlstm_func=partial(mlstm_paper_unstable_fgate, mstate_mode="paper"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)
print("exp_minus_m_to_one")
_ = make_h_output_plot_mlstm_with_internals(
    mlstm_func=partial(mlstm_paper_unstable_fgate, mstate_mode="exp_minus_m_to_one"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)
print("denom_one")
(
    h_out_do,
    m_out_do,
    n_out_do,
    matLogD_do,
    matLogD_stabilized_do,
    matD_do,
    matCtilde_do,
    matC_do,
    vecLogSigF_do,
    vecLogSigF_cumsum_do,
) = make_h_output_plot_mlstm_with_internals(
    mlstm_func=partial(mlstm_paper_unstable_fgate, mstate_mode="denom_one"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)
# Reference run with the imported Triton kernel, using its default arguments.
print("max_triton_v3 kernel")
_ = make_h_output_plot_mlstm_with_internals(
    mlstm_func=mlstm_chunkwise_max_triton_v3,
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)
# Same three normalizer modes, now with the directly accumulated forget gate.
print("==== stable torch version ====")
print("paper version")
(
    h_out,
    m_out,
    n_out,
    matLogD,
    matLogD_stabilized,
    matD,
    matCtilde,
    matC,
    vecLogSigF,
    vecLogSigF_cumsum,
) = make_h_output_plot_mlstm_with_internals(
    mlstm_func=partial(mlstm_paper_stable_fgate, mstate_mode="paper"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)
print("exp_minus_m_to_one")
_ = make_h_output_plot_mlstm_with_internals(
    mlstm_func=partial(mlstm_paper_stable_fgate, mstate_mode="exp_minus_m_to_one"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)
print("denom_one")
(
    h_out_do,
    m_out_do,
    n_out_do,
    matLogD_do,
    matLogD_stabilized_do,
    matD_do,
    matCtilde_do,
    matC_do,
    vecLogSigF_do,
    vecLogSigF_cumsum_do,
) = make_h_output_plot_mlstm_with_internals(
    mlstm_func=partial(mlstm_paper_stable_fgate, mstate_mode="denom_one"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)

### static plotly plot

In [46]:
import plotly.graph_objects as go
import torch


def make_h_output_plot_mlstm_with_internals_plotly(
    mlstm_func,
    B,
    NH,
    S,
    DHQK,
    DHV,
    vecI_offset,
    vecF_offset,
    seed=0,
    plot_max_min=True,
    vecI_init_fn=torch.randn,
    vecF_init_fn=torch.randn,
):
    """Run ``mlstm_func`` on seeded random inputs and plot its output statistics with plotly.

    Inputs are drawn with ``DTYPE`` on ``DEVICE`` after seeding torch with
    ``seed``; the offsets are added to the gate draws. The figure shows the m
    state (only if ``mlstm_func`` returned a tuple), the mean of the hidden
    output over the last axis with a +/- one standard deviation band, and
    optionally the max/min over the last axis.

    Returns:
        The 10-tuple returned by ``mlstm_func``. If ``mlstm_func`` returned a
        bare tensor, that tensor followed by nine ``None`` placeholders.
    """
    torch.manual_seed(seed)
    tensor_kwargs = {"dtype": DTYPE, "device": DEVICE}
    matQ = torch.randn((B, NH, S, DHQK), **tensor_kwargs)
    matK = torch.randn((B, NH, S, DHQK), **tensor_kwargs)
    matV = torch.randn((B, NH, S, DHV), **tensor_kwargs)

    vecI = vecI_offset + vecI_init_fn((B, NH, S), **tensor_kwargs)
    vecF = vecF_offset + vecF_init_fn((B, NH, S), **tensor_kwargs)

    result = mlstm_func(matQ, matK, matV, vecI, vecF)

    # A bare tensor result is padded so that both cases unpack the same way.
    internals = result if isinstance(result, tuple) else (result,) + (None,) * 9
    (
        h_out,
        m_out,
        n_out,
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        vecLogSigF_cumsum,
    ) = internals

    def _as_series(tensor):
        # Flatten to a 1-D float32 numpy array for plotting.
        return tensor.flatten().cpu().float().numpy()

    mean_series = _as_series(h_out.mean(-1))
    std_series = _as_series(h_out.std(-1))
    max_series = _as_series(h_out.max(-1)[0])
    min_series = _as_series(h_out.min(-1)[0])

    fig = go.Figure()

    # m state first, when the implementation exposes it
    if m_out is not None:
        fig.add_trace(go.Scatter(y=_as_series(m_out), mode="lines", name="m_state"))

    fig.add_trace(go.Scatter(y=mean_series, mode="lines", name="h_out_mean"))

    # Standard deviation band: a closed polygon that walks the upper edge
    # left-to-right and the lower edge right-to-left.
    positions = list(range(len(mean_series)))
    upper_edge = list(mean_series + std_series)
    lower_edge_reversed = list((mean_series - std_series)[::-1])
    fig.add_trace(
        go.Scatter(
            x=positions + positions[::-1],
            y=upper_edge + lower_edge_reversed,
            fill="toself",
            fillcolor="rgba(0,100,80,0.2)",
            line=dict(color="rgba(255,255,255,0)"),
            showlegend=False,
            name="h_out_std",
        )
    )

    if plot_max_min:
        for series, series_name in ((max_series, "h_out_max"), (min_series, "h_out_min")):
            fig.add_trace(go.Scatter(y=series, mode="lines", name=series_name))

    fig.update_layout(
        title="h_out and Internal States",
        xaxis_title="Sequence Position",
        yaxis_title="Value",
        legend_title="Legend",
    )

    print(f"vecI_offs: {vecI_offset}, vecF_offs: {vecF_offset}")
    print(f"S: {S}, B: {B}, NH: {NH}, DHQK: {DHQK}, DHV: {DHV}")
    fig.show()

    return (
        h_out,
        m_out,
        n_out,
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        vecLogSigF_cumsum,
    )

In [None]:
# Static plotly version of the plot for the cumsum-difference ("unstable")
# forget gate with the "paper" normalizer. The returned internals rebind the
# notebook-level names h_out, m_out, ... .
print("paper version")
(
    h_out,
    m_out,
    n_out,
    matLogD,
    matLogD_stabilized,
    matD,
    matCtilde,
    matC,
    vecLogSigF,
    vecLogSigF_cumsum,
) = make_h_output_plot_mlstm_with_internals_plotly(
    mlstm_func=partial(mlstm_paper_unstable_fgate, mstate_mode="paper"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset=vecI_offset,
    vecF_offset=vecF_offset,
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)

### variable plotly plot

In [48]:
import plotly.graph_objects as go
import torch


def plot_input_mean_std_max_min(matQKV):
    """Plot mean, +/- one standard deviation band, max and min of ``matQKV`` over its last axis.

    All statistics are flattened to 1-D, so the x axis enumerates every
    leading-axis position of the input tensor.
    """

    def _as_series(tensor):
        # Flatten to a 1-D float32 numpy array for plotting.
        return tensor.flatten().cpu().float().numpy()

    mean_series = _as_series(matQKV.mean(-1))
    std_series = _as_series(matQKV.std(-1))
    max_series = _as_series(matQKV.max(-1)[0])
    min_series = _as_series(matQKV.min(-1)[0])

    plt.plot(mean_series, label="qkv_mean")
    plt.fill_between(
        range(len(mean_series)),
        mean_series - std_series,
        mean_series + std_series,
        alpha=0.5,
    )
    for series, series_label in ((max_series, "qkv_max"), (min_series, "qkv_min")):
        plt.plot(series, label=series_label)
    plt.legend()
    plt.show()


def make_h_output_plot_mlstm_with_internals_with_separate_sliders(
    mlstm_func,
    B,
    NH,
    S,
    DHQK,
    DHV,
    vecI_offset_range,  # Tuple of (min, max, step) or explicit list of vecI_offset values
    vecF_offset_range,  # Tuple of (min, max, step) or explicit list of vecF_offset values
    seed=0,
    plot_max_min=True,
    plot_m_state=True,
    vecI_init_fn=torch.randn,
    vecF_init_fn=torch.randn,
):
    """Interactive plotly figure of mLSTM output statistics over a grid of gate offsets.

    ``mlstm_func`` is evaluated once for every (vecI_offset, vecF_offset)
    combination and the results are stored as plotly frames. A single slider
    steps through all combinations; two independent sliders cannot be used
    because every frame is identified by the pair of offsets.

    Q, K, V and the gate noise are drawn once (``DTYPE`` on ``DEVICE``) after
    seeding torch with ``seed``, so all frames share the same random inputs and
    differ only in the offsets added to the gate pre-activations.

    Args:
        mlstm_func: callable ``(matQ, matK, matV, vecI, vecF)`` returning either
            the hidden output tensor or a tuple whose first two entries are the
            hidden output and the m state.
        B, NH, S, DHQK, DHV: sizes of the generated inputs.
        vecI_offset_range: ``(min, max, step)`` tuple (expanded with
            ``torch.arange(min, max + step, step)``) or a list of offsets.
        vecF_offset_range: same format, for the forget-gate offset.
        seed: torch seed used before drawing the inputs.
        plot_max_min: add max/min traces of the hidden output.
        plot_m_state: add the m-state trace when ``mlstm_func`` provides one.
        vecI_init_fn, vecF_init_fn: callables drawing the (B, NH, S) gate noise.

    Returns:
        The plotly figure (it is also shown).

    Raises:
        ValueError: if an offset range is neither a tuple nor a list, or is empty.
    """

    def _offset_values(offset_range, arg_name):
        # Expand a range specification into a list of plain Python numbers.
        if isinstance(offset_range, tuple):
            range_min, range_max, range_step = offset_range
            values = torch.arange(range_min, range_max + range_step, range_step)
        elif isinstance(offset_range, list):
            values = torch.tensor(offset_range)
        else:
            raise ValueError(f"{arg_name} must be a tuple or list")
        if values.numel() == 0:
            raise ValueError(f"{arg_name} must contain at least one offset")
        return values.tolist()

    def _as_series(tensor):
        # Flatten to a 1-D float32 numpy array for plotting.
        return tensor.flatten().cpu().float().numpy()

    # Validate the ranges before any expensive work is done.
    vecI_offsets = _offset_values(vecI_offset_range, "vecI_offset_range")
    vecF_offsets = _offset_values(vecF_offset_range, "vecF_offset_range")

    torch.manual_seed(seed)
    matQ = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)
    matK = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)
    matV = torch.randn((B, NH, S, DHV), dtype=DTYPE, device=DEVICE)
    # Gate noise is drawn once, right after Q/K/V, so that the offsets are the
    # only thing that changes from one frame to the next.
    vecI_noise = vecI_init_fn((B, NH, S), dtype=DTYPE, device=DEVICE)
    vecF_noise = vecF_init_fn((B, NH, S), dtype=DTYPE, device=DEVICE)

    plot_input_mean_std_max_min(matQ)

    def get_plot_data(vecI_offset, vecF_offset):
        # Evaluate mlstm_func for one offset pair and reduce it to plot series.
        out = mlstm_func(matQ, matK, matV, vecI_offset + vecI_noise, vecF_offset + vecF_noise)
        if isinstance(out, tuple):
            h_out, m_out = out[0], out[1]
        else:
            h_out, m_out = out, None

        m_pl = _as_series(m_out) if (plot_m_state and m_out is not None) else None
        return (
            _as_series(h_out.mean(-1)),
            _as_series(h_out.std(-1)),
            _as_series(h_out.max(-1)[0]),
            _as_series(h_out.min(-1)[0]),
            m_pl,
        )

    def _build_traces(plot_data):
        # Plotly applies frame data to the figure's traces by position, so the
        # initial figure and every frame must be built by this one function.
        h_mean, h_std, h_max, h_min, m_pl = plot_data
        positions = list(range(len(h_mean)))
        traces = [
            go.Scatter(y=h_mean, mode="lines", name="h_out_mean"),
            # Standard deviation band: upper edge forward, lower edge backward.
            go.Scatter(
                x=positions + positions[::-1],
                y=list(h_mean + h_std) + list((h_mean - h_std)[::-1]),
                fill="toself",
                fillcolor="rgba(0,100,80,0.2)",
                line=dict(color="rgba(255,255,255,0)"),
                showlegend=False,
                name="h_out_std",
            ),
        ]
        if plot_max_min:
            traces.append(go.Scatter(y=h_max, mode="lines", name="h_out_max"))
            traces.append(go.Scatter(y=h_min, mode="lines", name="h_out_min"))
        if m_pl is not None:
            traces.append(go.Scatter(y=m_pl, mode="lines", name="m_state"))
        return traces

    # Precompute every combination; the frames below only look results up.
    offset_pairs = [(vecI_offset, vecF_offset) for vecI_offset in vecI_offsets for vecF_offset in vecF_offsets]
    data_cache = {pair: get_plot_data(*pair) for pair in offset_pairs}

    # The figure starts at the first input-gate and the last forget-gate offset.
    initial_pair = (vecI_offsets[0], vecF_offsets[-1])
    initial_step = len(vecF_offsets) - 1  # position of initial_pair in offset_pairs

    fig = go.Figure(data=_build_traces(data_cache[initial_pair]))
    fig.frames = [
        go.Frame(
            data=_build_traces(data_cache[(vecI_offset, vecF_offset)]),
            name=f"{vecI_offset}_{vecF_offset}",
        )
        for vecI_offset, vecF_offset in offset_pairs
    ]

    slider_vecIF = {
        # Keep the slider handle in sync with the initially displayed frame.
        "active": initial_step,
        "currentvalue": {"prefix": "vecIF_offset: "},
        "pad": {"t": 50},
        "steps": [
            {
                "args": [
                    [f"{vecI_offset}_{vecF_offset}"],
                    {"frame": {"duration": 0, "redraw": True}, "mode": "immediate"},
                ],
                "label": f"I_off={vecI_offset}, F_off={vecF_offset}",
                "method": "animate",
            }
            for vecI_offset, vecF_offset in offset_pairs
        ],
    }

    fig.update_layout(
        sliders=[slider_vecIF],
        title="Signal propagation in mLSTM",
        xaxis_title="Sequence Position",
        yaxis_title="Value",
        legend_title="Legend",
        width=1000,
        height=1200,
    )

    fig.show()

    return fig

## Analysis: Status quo of input + forget gate with different offsets

Goal: the mLSTM output should have roughly the same max/min and mean ± std values as the random input; if it does, we consider the mLSTM stable. (Of course, it should still do some information routing.)

Conclusion: 
- the input gate preact needs a (larger) negative bias in order to avoid the feature spikes (large max/min values)
    - in between (-3, -6)
- forget gate preact needs larger positive bias (sigmoid(large x) = 1) in order to keep the max/min values on the qkv max/min level
    - approximately 6

- these values indicate that the input and forget gates are not independent of each other:
    - high fgate offset means much of the input is kept, i.e. the "sum window" is larger hence larger max min values
    - in order to avoid that the max / min values explode one can now also decrease the input gate by making the offset more negative, then many smaller values will be summed
    - how can we solve that?

- bfloat16 artefacts:
    - we see some numerical artefacts with the unstabilized fgate version that happen when the fgate offset is <-3 (m_state spikes)
    - we do not see this in the torch stabilized version and also NOT in the max_triton_v3 kernel!

Issue Summary: 
1. Coupling of input and forget gate. 
2. fgate offset sensitivity
3. igate offset sensitivity

Fixes for 1.
- Just make both learnable and hope that gradient descent finds an equilibrium
    - maybe by addressing 2. and 3., learning this gets easier
- look at trained models and adapt initialization
    - could also help the first "hope"

Fixes for 2.
- make the fgate less dependent on the bias -> use the GLA / Mamba trick
- add an additional non learnable bias (offset in code)
- adapt initialization

Fixes for 3.
- additional non learnable bias (offset in code)
- adapt initialization
    - first experiment on this is very promising at 7B scale


### Plots

In [None]:
# Interactive offset sweep for the cumsum-difference ("unstable") forget gate.
# NOTE(review): the label printed below says "paper version", but the call uses
# mstate_mode="exp_minus_m_to_one" -- confirm which one is intended.
print("paper version unstable fgate")

fig = make_h_output_plot_mlstm_with_internals_with_separate_sliders(
    mlstm_func=partial(mlstm_paper_unstable_fgate, mstate_mode="exp_minus_m_to_one"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=[-6, -3, 0],
    vecF_offset_range=(-6, 10, 1),
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)

In [None]:
# Interactive offset sweep for the directly accumulated ("stable") forget gate
# with the "paper" normalizer; same offset grid as the unstable variant.
print("paper version stable fgate")

fig = make_h_output_plot_mlstm_with_internals_with_separate_sliders(
    mlstm_func=partial(mlstm_paper_stable_fgate, mstate_mode="paper"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=[-6, -3, 0],
    vecF_offset_range=(-6, 10, 1),
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)

In [None]:
# Interactive offset sweep for the imported Triton kernel, called with its
# default arguments; same offset grid as the torch variants.
print("paper version max_triton_v3 kernels")

fig = make_h_output_plot_mlstm_with_internals_with_separate_sliders(
    mlstm_func=mlstm_chunkwise_max_triton_v3,
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=[-6, -3, 0],
    vecF_offset_range=(-6, 10, 1),
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)

## Fix attempt "glafgate" for (2) fgate sensitivity

In [52]:
def mlstm_stable_fgate_gla(
    matQ: torch.Tensor,
    matK: torch.Tensor,
    matV: torch.Tensor,
    vecI: torch.Tensor,
    vecF: torch.Tensor,
    tau: float = 1.0,
    eps: float = 1e-6,
    mstate_mode: str = "paper",
) -> tuple:
    """Parallel mLSTM forward pass with a temperature-scaled forget gate.

    The forget gate is ``sigmoid(vecF) ** (1 / tau)``. Its logarithm is computed
    as ``logsigmoid(vecF) / tau``, which is the same quantity but does not go
    through the saturating ``sigmoid``: rounding ``sigmoid`` to exactly 1 or 0
    would turn the log gate into exactly 0 or -inf. Apart from the forget gate
    the computation matches the directly accumulated ("stable") variant.

    Args:
        matQ: queries, shape (B, NH, S, DHQK).
        matK: keys, shape (B, NH, S, DHQK).
        matV: values, shape (B, NH, S, DHV).
        vecI: input-gate pre-activations, shape (B, NH, S).
        vecF: forget-gate pre-activations, shape (B, NH, S).
        tau: temperature; the log forget gate is divided by it. Must be > 0.
        eps: constant added to the normalizer before the division.
        mstate_mode: how the normalizer ``vecN`` is computed:
            "paper"              -> max(|row sum of matCtilde|, exp(-m))
            "exp_minus_m_to_one" -> max(|row sum of matCtilde|, 1)
            "sum_only"           -> |row sum of matCtilde|
            "denom_one"          -> constant 1

    Returns:
        A 10-tuple ``(matH, vecM, vecN, matLogD, matLogD_stabilized, matD,
        matCtilde, matC, vecLogSigF, None)``, where ``vecLogSigF`` is the
        temperature-scaled log forget gate.

    Raises:
        ValueError: if ``tau`` is not positive or ``mstate_mode`` is unknown.
    """
    import math

    B, NH, S, DHQK = matQ.shape
    assert matK.shape == (B, NH, S, DHQK)
    assert vecI.shape == (B, NH, S)
    assert vecF.shape == (B, NH, S)

    # Reject bad arguments before allocating any (S, S) matrices.
    if tau <= 0:
        raise ValueError(f"tau must be positive, got {tau}")
    if mstate_mode not in ("paper", "exp_minus_m_to_one", "sum_only", "denom_one"):
        raise ValueError(f"mstate_mode {mstate_mode} not recognized")

    _dtype, _device = matQ.dtype, matQ.device

    # log(sigmoid(f) ** (1 / tau)) == logsigmoid(f) / tau
    vecLogSigF = F.logsigmoid(vecF) / tau  # (B, NH, S)

    # Row i holds the log gate of step i in every column j < i and 0 elsewhere,
    # so the cumulative sum down the rows gives entry [i, j] = sum over j+1..i.
    matLogSigF_tril = vecLogSigF[:, :, :, None].repeat(1, 1, 1, S).tril(-1)
    matLogSigF_cum = matLogSigF_tril.cumsum(-2)

    ltr = torch.tril(
        torch.ones(
            (S, S),
            dtype=torch.bool,
            device=_device,
        )
    )

    # Causal mask: positions above the diagonal get -inf so that exp() yields 0.
    matLogSigF_mask = torch.where(ltr, matLogSigF_cum, -float("inf"))

    matLogD = matLogSigF_mask + vecI[:, :, None, :]

    # Row-wise maximum, subtracted before exponentiating.
    vecM, _ = torch.max(matLogD, dim=-1, keepdim=True)  # (B, NH, S, 1)
    matLogD_stabilized = matLogD - vecM

    matD = torch.exp(matLogD_stabilized)  # (B, NH, S, S)

    matS = (matQ @ matK.transpose(-2, -1)) / math.sqrt(DHQK)  # (B, NH, S, S)

    matCtilde = matS * matD  # (B, NH, S, S)

    if mstate_mode == "denom_one":
        vecN = torch.tensor([1.0], device=_device, dtype=_dtype)
    else:
        vecSumAbs = matCtilde.sum(dim=-1, keepdim=True).abs()  # (B, NH, S, 1)
        if mstate_mode == "paper":
            vecN = torch.maximum(vecSumAbs, torch.exp(-vecM))
        elif mstate_mode == "exp_minus_m_to_one":
            vecN = torch.maximum(
                vecSumAbs,
                torch.tensor([1.0], device=_device, dtype=_dtype),
            )
        else:  # "sum_only"
            vecN = vecSumAbs

    # (B, NH, S, S)
    matC = matCtilde / (vecN + eps)

    matH = matC @ matV  # (B, NH, S, DH)

    return (
        matH,
        vecM.squeeze(-1),
        vecN.squeeze(-1),
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        None,
    )

In [None]:
# Baseline for the temperature-scaled forget gate: tau=1.0 leaves the gate unscaled.
print("fix attempt GLA fgate version stable fgate")

fig = make_h_output_plot_mlstm_with_internals_with_separate_sliders(
    mlstm_func=partial(mlstm_stable_fgate_gla, mstate_mode="paper", tau=1.0),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=[-6, -3, 0],
    vecF_offset_range=(-6, 10, 1),
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
)

In [None]:
# Temperature-scaled forget gate with tau=16.0; the m-state trace is disabled.
# NOTE(review): the printed label is identical to the tau=1.0 cell -- consider
# including the tau value so the two outputs can be told apart.
print("fix attempt GLA fgate version stable fgate")

fig = make_h_output_plot_mlstm_with_internals_with_separate_sliders(
    mlstm_func=partial(mlstm_stable_fgate_gla, mstate_mode="paper", tau=16.0),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=[-6, -3, 0],
    vecF_offset_range=(-6, 10, 1),
    seed=0,
    vecI_init_fn=vecI_init_fn,
    vecF_init_fn=vecF_init_fn,
    plot_m_state=False,
)

## Make a Sweep over vecI, vecF offset

In [55]:
def compute_mlstm_outputs(
    mlstm_func,
    B,
    NH,
    S,
    DHQK,
    DHV,
    vecI_offset,
    vecF_offset,
    vecI_init_fn=torch.randn,
    vecF_init_fn=torch.randn,
    seed=0,
):
    """Run ``mlstm_func`` on seeded random inputs and return its outputs without plotting.

    Torch is reseeded with ``seed`` on every call, so repeated calls with the
    same sizes see identical Q, K, V and gate noise and differ only in the
    offsets added to the gate pre-activations.

    Args:
        mlstm_func: callable ``(matQ, matK, matV, vecI, vecF)``.
        B, NH, S, DHQK, DHV: sizes of the generated inputs.
        vecI_offset: offset added to the input-gate draw.
        vecF_offset: offset added to the forget-gate draw.
        vecI_init_fn, vecF_init_fn: callables drawing the (B, NH, S) gate noise.
        seed: torch seed used before drawing the inputs (default 0).

    Returns:
        The 10-tuple returned by ``mlstm_func``. If ``mlstm_func`` returned a
        bare tensor, that tensor followed by nine ``None`` placeholders.
    """
    torch.manual_seed(seed)
    matQ = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)
    matK = torch.randn((B, NH, S, DHQK), dtype=DTYPE, device=DEVICE)
    matV = torch.randn((B, NH, S, DHV), dtype=DTYPE, device=DEVICE)

    vecI = vecI_offset + vecI_init_fn((B, NH, S), dtype=DTYPE, device=DEVICE)
    vecF = vecF_offset + vecF_init_fn((B, NH, S), dtype=DTYPE, device=DEVICE)

    out = mlstm_func(matQ, matK, matV, vecI, vecF)

    # A bare tensor result is padded so that both cases unpack the same way.
    if not isinstance(out, tuple):
        out = (out,) + (None,) * 9

    # Explicit unpacking keeps the check that tuple results have 10 entries.
    (
        h_out,
        m_out,
        n_out,
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        vecLogSigF_cumsum,
    ) = out

    return (
        h_out,
        m_out,
        n_out,
        matLogD,
        matLogD_stabilized,
        matD,
        matCtilde,
        matC,
        vecLogSigF,
        vecLogSigF_cumsum,
    )

In [56]:
def make_offset_sweep(
    mlstm_func,
    B,
    NH,
    S,
    DHQK,
    DHV,
    vecI_offset_range,
    vecF_offset_range,
    vecI_init_fn=torch.randn,
    vecF_init_fn=torch.randn,
    metric: str = "h_out_max_mean",
):
    """Evaluate a scalar metric of the mLSTM output on a grid of gate offsets.

    Args:
        mlstm_func: callable forwarded to ``compute_mlstm_outputs``.
        B, NH, S, DHQK, DHV: sizes of the generated inputs.
        vecI_offset_range: sized iterable of input-gate offsets (1-D tensor or
            a sequence of Python numbers).
        vecF_offset_range: sized iterable of forget-gate offsets.
        vecI_init_fn, vecF_init_fn: callables drawing the gate noise.
        metric: name of the metric. Only "h_out_max_mean" is implemented: the
            max of the hidden output over its last axis, averaged over all
            remaining axes.

    Returns:
        ``(data, data_tensor)``: ``data`` is a list with one dict per grid point
        (keys "vecI_offset", "vecF_offset", "metric"); ``data_tensor`` has shape
        ``(len(vecI_offset_range), len(vecF_offset_range))`` with rows indexed
        by the input-gate offset.

    Raises:
        ValueError: if ``metric`` is not recognized.
    """
    # Fail before the first forward pass rather than after it.
    if metric != "h_out_max_mean":
        raise ValueError(f"metric {metric} not recognized")

    def _as_number(offset):
        # Offsets may be 0-dim tensors (iterating a tensor) or plain numbers.
        return offset.item() if isinstance(offset, torch.Tensor) else offset

    data = []
    data_tensor = torch.zeros(len(vecI_offset_range), len(vecF_offset_range))
    for i, vecI_offset in enumerate(vecI_offset_range):
        for j, vecF_offset in enumerate(vecF_offset_range):
            out = compute_mlstm_outputs(
                mlstm_func,
                B,
                NH,
                S,
                DHQK,
                DHV,
                vecI_offset,
                vecF_offset,
                vecI_init_fn,
                vecF_init_fn,
            )
            h_out = out[0]
            metric_val = h_out.max(-1)[0].mean().item()
            data.append(
                {
                    "vecI_offset": _as_number(vecI_offset),
                    "vecF_offset": _as_number(vecF_offset),
                    "metric": metric_val,
                }
            )
            data_tensor[i, j] = metric_val

    return data, data_tensor

In [57]:
# Setting vecF_offset to -6 (or almost any negative value) produces an m_state spike.
# Problem size for the offset sweeps below: S is reduced to 2048 here, which
# rebinds the notebook-level S used by all following cells.
B = 1
NH = 1
S = 2048
D = 1024
DHQK = D
DHV = D
# Constant offsets added to the randomly drawn gate pre-activations.
vecI_offset = -3.0
vecF_offset = 5.0
# Callables that draw the (B, NH, S) gate pre-activations before the offset is added.
vecI_init_fn = torch.randn
vecF_init_fn = torch.randn

In [58]:
# 50 x 50 grid of gate offsets; one forward pass of the cumsum-difference
# ("unstable") forget-gate variant per grid point.
vecI_offset_range = torch.linspace(-8, 6, 50)
vecF_offset_range = torch.linspace(-5, 12, 50)
data, data_tensor = make_offset_sweep(
    mlstm_func=partial(mlstm_paper_unstable_fgate, mstate_mode="paper"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=vecI_offset_range,
    vecF_offset_range=vecF_offset_range,
    metric="h_out_max_mean",
)

In [59]:
import matplotlib as mpl
import numpy as np

In [None]:
# Heatmap of the sweep metric: forget-gate offset on x, input-gate offset on y.
fig, ax = plt.subplots()
# With indexing="ij" both grids have shape (len(vecF), len(vecI)).
grid_x, grid_y = torch.meshgrid(vecF_offset_range, vecI_offset_range, indexing="ij")
grid_x = grid_x.cpu().numpy()
grid_y = grid_y.cpu().numpy()
# data_tensor is indexed [vecI, vecF]; transpose it to match the grids.
data_z = data_tensor.transpose(0, 1).cpu().numpy()

# levels = mpl.ticker.MaxNLocator(nbins=20).tick_values(data_z.min(), data_z.max())
# Fixed color boundaries between 0 and 10; clip=True maps values outside that
# range to the first/last color.
levels = np.linspace(0, 10, 10)
cmap = plt.colormaps["PiYG"]
norm = mpl.colors.BoundaryNorm(levels, ncolors=cmap.N, clip=True)

im = ax.pcolormesh(grid_x, grid_y, data_z, cmap=cmap, norm=norm)
fig.colorbar(im, ax=ax)
ax.set_title(label="h_out max (over feature dim) (mean over time)")
ax.set_ylabel("vecI_offset")
ax.set_xlabel("vecF_offset")

In [None]:
# Sweep and heatmap in a single cell.
# NOTE(review): this repeats the preceding sweep and heatmap cells with identical
# arguments (mlstm_paper_unstable_fgate, mstate_mode="paper") -- confirm whether
# a different mlstm_func was meant to be evaluated here.
vecI_offset_range = torch.linspace(-8, 6, 50)
vecF_offset_range = torch.linspace(-5, 12, 50)
data, data_tensor = make_offset_sweep(
    mlstm_func=partial(mlstm_paper_unstable_fgate, mstate_mode="paper"),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=vecI_offset_range,
    vecF_offset_range=vecF_offset_range,
    metric="h_out_max_mean",
)
fig, ax = plt.subplots()
# With indexing="ij" both grids have shape (len(vecF), len(vecI)).
grid_x, grid_y = torch.meshgrid(vecF_offset_range, vecI_offset_range, indexing="ij")
grid_x = grid_x.cpu().numpy()
grid_y = grid_y.cpu().numpy()
# data_tensor is indexed [vecI, vecF]; transpose it to match the grids.
data_z = data_tensor.transpose(0, 1).cpu().numpy()

# levels = mpl.ticker.MaxNLocator(nbins=20).tick_values(data_z.min(), data_z.max())
levels = np.linspace(0, 10, 10)
cmap = plt.colormaps["PiYG"]
norm = mpl.colors.BoundaryNorm(levels, ncolors=cmap.N, clip=True)

im = ax.pcolormesh(grid_x, grid_y, data_z, cmap=cmap, norm=norm)
fig.colorbar(im, ax=ax)
ax.set_title(label="h_out max (over feature dim) (mean over time)")
ax.set_ylabel("vecI_offset")
ax.set_xlabel("vecF_offset")

### GLA

In [62]:
# Same 50 x 50 offset grid, now for the temperature-scaled forget gate with
# tau=16.0. This rebinds data and data_tensor to the GLA results.
vecI_offset_range = torch.linspace(-8, 6, 50)
vecF_offset_range = torch.linspace(-5, 12, 50)
data, data_tensor = make_offset_sweep(
    mlstm_func=partial(mlstm_stable_fgate_gla, mstate_mode="paper", tau=16.0),
    B=B,
    NH=NH,
    S=S,
    DHQK=DHQK,
    DHV=DHV,
    vecI_offset_range=vecI_offset_range,
    vecF_offset_range=vecF_offset_range,
    metric="h_out_max_mean",
)

In [None]:
import matplotlib as mpl
import numpy as np

# Heatmap of the sweep metric currently held in data_tensor:
# forget-gate offset on x, input-gate offset on y.
fig, ax = plt.subplots()
# With indexing="ij" both grids have shape (len(vecF), len(vecI)).
grid_x, grid_y = torch.meshgrid(vecF_offset_range, vecI_offset_range, indexing="ij")
grid_x = grid_x.cpu().numpy()
grid_y = grid_y.cpu().numpy()
# data_tensor is indexed [vecI, vecF]; transpose it to match the grids.
data_z = data_tensor.transpose(0, 1).cpu().numpy()

# levels = mpl.ticker.MaxNLocator(nbins=20).tick_values(data_z.min(), data_z.max())
# Same fixed color boundaries as the other heatmap cells in this notebook.
levels = np.linspace(0, 10, 10)
cmap = plt.colormaps["PiYG"]
norm = mpl.colors.BoundaryNorm(levels, ncolors=cmap.N, clip=True)

im = ax.pcolormesh(grid_x, grid_y, data_z, cmap=cmap, norm=norm)
fig.colorbar(im, ax=ax)
ax.set_title(label="h_out max (over feature dim) (mean over time)")
ax.set_ylabel("vecI_offset")
ax.set_xlabel("vecF_offset")

In [None]:
# One curve per input-gate offset: the sweep metric held in data_tensor as a
# function of the forget-gate offset.
for i_idx in range(len(vecI_offset_range)):
    plt.plot(
        vecF_offset_range.cpu().numpy(),
        data_tensor[i_idx, :].cpu().numpy(),
        label=f"vecI_offset={vecI_offset_range[i_idx].item()}",
    )
# Axis labels so that the figure can be read on its own.
plt.xlabel("vecF_offset")
plt.ylabel("h_out max (over feature dim) (mean over time)")
# The legend has one entry per curve, so it is placed outside the axes.
plt.legend(loc="upper left", bbox_to_anchor=(1, 1))