In [1]:
import torch
import torch.nn as nn
import mamba_ssm
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from tqdm.notebook import trange, tqdm

ModuleNotFoundError: No module named 'mamba_ssm'

In [None]:
# Sweep configuration shared by the cells below.
in_features_min, in_features_max = 64, 8192  # smallest / largest width in the sweep
head_dim = 64  # per-head size used to derive num_heads in the MHA sweep
seq_len, bsz = 512, 2  # tokens per sequence, sequences per batch
device = "cuda"

In [None]:
# Widths to sweep: start at in_features_min and keep doubling while <= in_features_max.
in_features_list = []
in_feat = in_features_min
while not in_feat > in_features_max:
    in_features_list.append(in_feat)
    in_feat = 2 * in_feat


In [None]:
def init_lins(mod: nn.Module) -> None:
    """Re-initialise, in place, every ``nn.Linear`` reachable from ``mod``.

    Each weight is redrawn from a normal distribution with std
    ``1 / sqrt(in_features)``; each bias, when present, is set to zero. The
    qualified name of every re-initialised submodule is printed.

    Args:
        mod: Root module; ``mod.named_modules()`` (which includes ``mod`` itself)
            is searched for ``nn.Linear`` layers.
    """
    linear_layers = [
        (lin_name, layer)
        for lin_name, layer in mod.named_modules()
        if isinstance(layer, nn.Linear)
    ]
    for lin_name, layer in linear_layers:
        # The `=` specifier prints the variable name too, e.g. lin_name='fc1'.
        print(f"Init {lin_name=}")
        fan_in = layer.in_features
        nn.init.normal_(layer.weight, std=1 / (fan_in ** 0.5))
        bias = layer.bias
        if bias is None:
            continue
        nn.init.zeros_(bias)


In [None]:
# Sanity check: fc1 weight statistics (mean and mean square) of a fresh 512-wide
# GatedMLP, first with the module's own default init, then after init_lins.
# GatedMLP is imported here so this cell also runs on a fresh kernel.
from mamba_ssm.modules.mlp import GatedMLP

for custom_init in (False, True):
    mlp = GatedMLP(512, device=device)
    if custom_init:
        init_lins(mlp)
    with torch.no_grad():
        print(mlp.fc1.weight.mean())
        print(mlp.fc1.weight.pow(2).mean())

In [None]:
# Sanity check: output mean and mean square of a bias-free 512->512 Linear on one
# shared random batch, first with PyTorch's default init, then after init_lins.
inputs = torch.randn(16, 512, device=device)
for use_custom_init in (False, True):
    lin = nn.Linear(512, 512, bias=False, device=device)
    if use_custom_init:
        init_lins(lin)
    with torch.no_grad():
        out = lin(inputs)
        print(out.mean())
        print(out.pow(2).mean())

In [None]:
from mamba_ssm.modules.mlp import GatedMLP

# Sweep GatedMLP over in_features_list and record output-activation statistics,
# once with init_lins applied (custom_init=True) and once with the default init.
mlp_amp = False  # when True, run the forward under bf16 autocast
mlp_results = []
for custom_init in (True, False):
    for in_features in tqdm(in_features_list):
        mlp = GatedMLP(in_features, device=device)
        if custom_init:
            init_lins(mlp)

        inputs = torch.randn(bsz, seq_len, in_features, device=device)
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=mlp_amp):
                outputs = mlp(inputs)
            # Reduce in fp32: with autocast enabled the outputs may be bf16, and
            # bf16 sums/variances over millions of elements lose precision.
            # No-op when the outputs are already fp32.
            outputs = outputs.float()
        mlp_results_dict = {"in_features": in_features,
                            "l2_mean": outputs.pow(2).mean().item(),
                            "l1_mean": outputs.abs().mean().item(),
                            "l2_sum": outputs.pow(2).sum().item(),
                            "l1_sum": outputs.abs().sum().item(),
                            "std": outputs.std().item(),
                            "var": outputs.var().item(),
                            "mean": outputs.mean().item(),
                            "custom_init": custom_init,
                           }
        mlp_results.append(mlp_results_dict)

mlp_df = pd.DataFrame(mlp_results)
mlp_df

In [None]:
# Mean squared MLP output vs. width, one line per init scheme.
mlp_plot = sns.lineplot(x="in_features", y="l2_mean", hue="custom_init", data=mlp_df)
# Widths double from one point to the next, so use log scales on both axes.
mlp_plot.set(xscale="log", yscale="log")
plt.suptitle("MLP scaling")

In [None]:
from mamba_ssm.modules.mha import MHA

# Sweep causal MHA over the widths and record output-activation statistics, once
# with init_lins applied (custom_init=True) and once with the default init.
attn_results = []
attn_amp = False  # when True, run the forward under bf16 autocast
# num_heads is width // head_dim, so only widths that are exact multiples of
# head_dim satisfy num_heads * head_dim == width. A width below head_dim would give
# zero heads. Skip any other widths; this only matters if head_dim has been changed
# since the config cell ran.
attn_widths = [w for w in in_features_list if w % head_dim == 0]
skipped_widths = [w for w in in_features_list if w % head_dim != 0]
if skipped_widths:
    print(f"Skipping widths not divisible by {head_dim=}: {skipped_widths}")
for custom_init in (True, False):
    for in_features in tqdm(attn_widths):
        attn_cfg = {
            "causal": True,
            "head_dim": head_dim,
            "num_heads": in_features // head_dim,
            "out_proj_bias": False,
            "qkv_proj_bias": False,
            "rotary_emb_dim": head_dim // 2,  # Apparently correct for mamba-ssm -- TODO confirm
        }

        mha = MHA(in_features, **attn_cfg, device=device)
        if custom_init:
            init_lins(mha)
        inputs = torch.randn(bsz, seq_len, in_features, device=device)
        with torch.no_grad():
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=attn_amp):
                outputs = mha(inputs)
            # Reduce in fp32: with autocast enabled the outputs may be bf16.
            outputs = outputs.float()
        attn_results_dict = {"in_features": in_features,
                             "l2_mean": outputs.pow(2).mean().item(),
                             "l1_mean": outputs.abs().mean().item(),
                             "l2_sum": outputs.pow(2).sum().item(),
                             "l1_sum": outputs.abs().sum().item(),
                             "std": outputs.std().item(),
                             "var": outputs.var().item(),
                             "mean": outputs.mean().item(),
                             "custom_init": custom_init,
                            }
        attn_results.append(attn_results_dict)

attn_df = pd.DataFrame(attn_results)
attn_df


In [None]:
# Mean squared MHA output vs. width, one line per init scheme.
attn_plot = sns.lineplot(x="in_features", y="l2_mean", hue="custom_init", data=attn_df)
attn_plot.set(xscale="log", yscale="log")

plt.suptitle("MHA scaling")

In [None]:
import torch.nn as nn
from typing import Any, Optional

class InputStatsHook:
    """Forward pre-hook that records summary statistics of a module's first input.

    The hook is registered on ``module`` at construction. On every forward call of
    ``module`` it appends one dict to ``results_list`` with the keys ``name``,
    ``width`` and ``step`` (number of calls this hook has seen before), followed by
    every entry of ``other_data``, and then the statistics ``mean``, ``l1_mean``,
    ``l2_mean``, ``std`` and ``var`` of ``args[0]``. The statistics are computed in
    fp32 whatever the input dtype is.

    Args:
        module: Module to observe.
        name: Label stored under ``"name"`` in every record.
        results_list: Caller-owned list that records are appended to.
        width: Value stored under ``"width"`` in every record.
        other_data: Extra constant fields merged into every record. They are merged
            after name/width/step, so they override those keys on collision.
    """

    def __init__(
        self,
        module: nn.Module,
        name: str,
        results_list: list[dict],
        width: int,
        other_data: Optional[dict] = None,
    ) -> None:
        self.module = module
        self.name = name
        self.width = width
        self.results_list = results_list
        self._step = 0
        self.other_data = other_data or {}
        # Register only once all state the hook reads has been set.
        self._hook = module.register_forward_pre_hook(self)

    def __call__(self, module: nn.Module, args: Any) -> None:
        """Record statistics of ``args[0]`` (the module's first positional input)."""
        results = {"name": self.name, "width": self.width, "step": self._step}
        results = {**results, **self.other_data}
        with torch.no_grad():
            # Upcast so the reductions are not done in bf16/fp16 when the forward
            # runs under autocast or in reduced precision.
            inputs = args[0].detach().float()
            results["mean"] = inputs.mean().item()
            results["l1_mean"] = inputs.abs().mean().item()
            results["l2_mean"] = inputs.pow(2).mean().item()
            results["std"] = inputs.std().item()
            results["var"] = inputs.var().item()
        self.results_list.append(results)
        self._step += 1

    def remove(self) -> None:
        """Detach the hook from the module; later forwards are not recorded."""
        self._hook.remove()


In [None]:
from mamba_ssm.models.config_mamba import MambaConfig
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from fms_fsdp.mup.mup_mamba import apply_mup_init


# End-to-end sweep: for each width, build an attention-only MambaLMHeadModel, run
# one forward pass on random token ids, and record logit statistics
# (model_results). InputStatsHook records LM-head input statistics
# (lm_head_input_results).
model_results = []
lm_head_input_results = []
model_amp = True  # run the forward under bf16 autocast
mup = True  # use muP attention scaling and apply_mup_init
n_layer = 4
vocab_size = 128256
# Kept separate from the global head_dim used by the MHA sweep cell, so re-running
# that cell after this one is unaffected.
model_head_dim = 128
# num_heads is width // model_head_dim, so every width must be a multiple of
# model_head_dim for num_heads * model_head_dim == d_model. Round the lower bound up
# to such a multiple (in_features_min=64 would otherwise give zero heads) and step
# by a multiple of model_head_dim.
model_width_min = -(-in_features_min // model_head_dim) * model_head_dim
for in_features in trange(model_width_min, in_features_max + 1, 4 * model_head_dim):
    attn_cfg = {
        "causal": True,
        "head_dim": model_head_dim,
        "num_heads": in_features // model_head_dim,
        "out_proj_bias": False,
        "qkv_proj_bias": False,
        "rotary_emb_dim": model_head_dim // 2,  # Apparently correct for mamba-ssm -- TODO confirm
    }
    if mup:
        # muP scales attention scores by 1/head_dim instead of the usual
        # 1/sqrt(head_dim). softmax_scale is the factor QK^T is multiplied by
        # (flash-attn convention), so it must be the reciprocal of head_dim.
        attn_cfg["softmax_scale"] = 1.0 / model_head_dim

    config = MambaConfig(
        d_model=in_features,
        d_intermediate=4 * in_features,
        n_layer=n_layer,
        attn_layer_idx=list(range(n_layer)),  # Transformer-only blocks
        vocab_size=vocab_size,
        attn_cfg=attn_cfg,
        tie_embeddings=False,
    )
    model = MambaLMHeadModel(config=config, device=device)

    # Record statistics of whatever is fed into the LM head on each forward.
    hook = InputStatsHook(model.lm_head, "lm_head", lm_head_input_results, width=in_features)
    if mup:
        apply_mup_init(model)
    inputs = torch.randint(vocab_size, size=(bsz, seq_len), device=device)
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=model_amp):
            outputs = model(inputs).logits
        # Reduce in fp32: under autocast the logits may be bf16, and bf16 sums over
        # bsz * seq_len * vocab_size (~1e8) elements lose precision.
        outputs = outputs.float()
        print(f"{outputs.shape=}, {in_features=}")
    # Detach the hook: it stops recording and the module<->hook reference cycle
    # no longer keeps this model alive after `model` is rebound.
    hook.remove()
    model_results_dict = {"in_features": in_features,
                          "l2_mean": outputs.pow(2).mean().item(),
                          "l1_mean": outputs.abs().mean().item(),
                          "l2_sum": outputs.pow(2).sum().item(),
                          "l1_sum": outputs.abs().sum().item(),
                          "std": outputs.std().item(),
                          "var": outputs.var().item(),
                          "mean": outputs.mean().item(),
                         }
    model_results.append(model_results_dict)
model_df = pd.DataFrame(model_results)
model_df


In [None]:
# Mean squared logit vs. model width.
model_plot = sns.lineplot(x="in_features", y="l2_mean", data=model_df)
model_plot.set(xscale="log", yscale="log")

plt.suptitle("Model scaling")


In [None]:
# One row per recorded lm_head forward call, from InputStatsHook.
lm_head_df = pd.DataFrame(data=lm_head_input_results)
lm_head_df


In [None]:
# Mean square of the LM-head input vs. model width.
lm_head_plot = sns.lineplot(x="width", y="l2_mean", data=lm_head_df)
lm_head_plot.set(xscale="log", yscale="log")

plt.suptitle("LM Head scaling")

In [None]:
# Display the LM-head input statistics table again.
lm_head_df

In [None]:
# Toy data for trying out seaborn line styles: "red" rows have y = x + 1,
# "blue" rows have y = 0, for x in 0..9.
test_data = [
    {"x": x_val, "y": (x_val + 1) if group == "red" else 0, "group": group}
    for group in ("red", "blue")
    for x_val in range(10)
]

In [None]:
# Toy frame with columns x, y, group.
test_df = pd.DataFrame(data=test_data)
test_df

In [None]:
# Default styling: groups are distinguished by colour only.
sns.lineplot(x="x", y="y", hue="group", data=test_df)

In [None]:
# Dash pattern per group: (1, 0) for every group, then override "blue" with (4, 2).
dd = dict.fromkeys(test_df.group.unique(), (1, 0))
dd["blue"] = (4, 2)

In [None]:
# Same plot, with "group" also driving the line style via the custom dash map.
sns.lineplot(x="x", y="y", hue="group", style="group", dashes=dd, data=test_df)