## 예제 16.1 맘바 블록 코드
코드 출처: https://github.com/johnma2006/mamba-minimal/blob/master/model.py

In [None]:
class MambaBlock(nn.Module):
    """A single Mamba block: input projection, depthwise conv, SSM, gated output.

    The block reads ``d_model``, ``d_inner``, ``d_conv``, ``d_state``,
    ``dt_rank``, ``bias`` and ``conv_bias`` from ``args``.

    NOTE: ``forward`` calls ``self.ssm``, which is not defined inside this
    class body; it must be available on the class for ``forward`` to run.
    """

    def __init__(self, args: ModelArgs):
        """Create the projection, convolution and state-space parameters."""
        super().__init__()
        self.args = args

        # One projection produces both the main branch and the gating
        # (residual) branch, hence the factor of 2 on the output width.
        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # groups == channels makes this a depthwise convolution. Padding by
        # d_conv - 1 lengthens the output; forward() slices it back to the
        # input length.
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # Used inside ssm:
        # expands the input x into the vectors from which delta, B and C
        # are taken (dt_rank + 2 * d_state features).
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
        # Expands the dt_rank-dimensional vector to d_inner dimensions to
        # produce delta.
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        # A holds the values 1..d_state, repeated for each of the d_inner
        # channels -> shape (d_inner, d_state). It is stored as log(A).
        A = repeat(
            torch.arange(1, args.d_state + 1),
            'd_state -> d_inner d_state',
            d_inner=args.d_inner,
        )
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

    def forward(self, x):
        """Apply the block to ``x`` of shape (b, l, d_model).

        Returns the result of ``out_proj``, whose last dimension is
        ``d_model``.
        """
        (b, l, d_model) = x.shape

        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_inner)
        (x, res) = x_and_res.split(
            split_size=[self.args.d_inner, self.args.d_inner], dim=-1
        )

        # Conv1d convolves over the last axis, so move the sequence axis
        # last, then trim the extra positions introduced by the padding so
        # the length is l again.
        x = rearrange(x, 'b l d_inner -> b d_inner l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_inner l -> b l d_inner')

        x = F.silu(x)
        y = self.ssm(x)

        # Gate the SSM output with the activated residual branch.
        y = y * F.silu(res)

        output = self.out_proj(y)
        return output

## 예제 16.2 ssm 메서드
코드 출처: https://github.com/johnma2006/mamba-minimal/blob/master/model.py

In [None]:
def ssm(self, x):
    """Derive the input-dependent SSM parameters from ``x`` and run the scan.

    ``x`` is projected by ``self.x_proj`` into (b, l, dt_rank + 2*d_state)
    features, which are split along the last axis into a low-rank delta,
    B and C. The low-rank delta goes through ``self.dt_proj`` and a softplus
    to give delta of shape (b, l, d_inner). ``A`` is ``-exp(A_log)`` with
    shape (d_inner, d_state).

    Returns whatever ``self.selective_scan(x, delta, A, B, C, D)`` returns.
    """
    # Unpacking also enforces that A_log is two-dimensional.
    _, n_state = self.A_log.shape
    rank = self.args.dt_rank

    projected = self.x_proj(x)  # (b, l, dt_rank + 2*d_state)
    delta_low_rank, B, C = torch.split(projected, [rank, n_state, n_state], dim=-1)

    # softplus keeps every step size positive.
    delta = F.softplus(self.dt_proj(delta_low_rank))  # (b, l, d_inner)

    neg_A = -torch.exp(self.A_log.float())  # shape (d_inner, d_state)
    skip = self.D.float()

    return self.selective_scan(x, delta, neg_A, B, C, skip)

## 예제 16.3 selective_scan 코드
코드 출처: https://github.com/johnma2006/mamba-minimal/blob/master/model.py

In [None]:
def selective_scan(self, x, delta, A, B, C, D):
    """Evaluate the selective state-space recurrence one time step at a time.

    For every position ``i`` along the sequence axis the hidden state is
    updated as ``h = deltaA[:, i] * h + deltaB_x[:, i]`` and the output is
    ``h`` contracted with ``C[:, i]`` over the state axis. A skip term
    ``x * D`` is added to the stacked outputs.

    Args:
        x: input of shape (b, l, d_inner).
        delta: step sizes, (b, l, d_inner) per the einsum patterns below.
        A: state matrix of shape (d_inner, d_state).
        B: input matrix of shape (b, l, d_state).
        C: output matrix of shape (b, l, d_state).
        D: skip weight multiplied with ``x`` by broadcasting
            (presumably shape (d_inner,) - confirm against the caller).

    Returns:
        Tensor of shape (b, l, d_inner).
    """
    (b, l, d_inner) = x.shape
    d_state = A.shape[1]

    # Discretize the parameters: exp(delta * A) for the state transition and
    # delta * B * x for the input contribution, both (b, l, d_inner, d_state).
    deltaA = torch.exp(einsum(delta, A, 'b l d_inner, d_inner d_state -> b l d_inner d_state'))
    deltaB_x = einsum(delta, B, x, 'b l d_inner, b l d_state, b l d_inner -> b l d_inner d_state')

    # The hidden state starts at zero and is carried across time steps.
    h = torch.zeros((b, d_inner, d_state), device=deltaA.device)
    ys = []
    for i in range(l):
        h = deltaA[:, i] * h + deltaB_x[:, i]
        y = einsum(h, C[:, i, :], 'b d_inner d_state, b d_state -> b d_inner')
        ys.append(y)
    y = torch.stack(ys, dim=1)  # shape (b, l, d_inner)

    y = y + x * D
    return y