In [64]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [65]:
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import pyrallis
import torch
from torch import Tensor

from src.consts import (
    FILTERATIONS,
    MODEL_SIZES_PER_ARCH_TO_MODEL_ID,
)
from src.types import DATASETS, MODEL_ARCH, DatasetArgs, TModelID

In [66]:
@dataclass
class Args:
    # model_arch: MODEL_ARCH = MODEL_ARCH.MINIMAL_MAMBA2_new
    model_arch: MODEL_ARCH = MODEL_ARCH.MAMBA1
    model_size: str = "2.8B"
    dataset_args: DatasetArgs = pyrallis.field(
        default=DatasetArgs(name=DATASETS.COUNTER_FACT, splits="all"), is_mutable=True
    )
    filteration: str = FILTERATIONS.all_correct
    _batch_size: int = 16  # Adjust based on GPU memory
    output_file: Optional[Path] = None
    with_slurm: bool = False
    temperature = 1
    top_k = 0
    top_p = 1
    window_size = 5
    prompt_indices = [1,2,3,4,5]
    knockout_map = {'last': ['last', 'first', "subject", "relation"], 
                    'subject': ['context', 'subject']}

    output_dir: Optional[Path] = None

    @property
    def batch_size(self) -> int:
        return (
            1
            if (
                self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2
                or self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2_new
            )
            else self._batch_size
        )

    @property
    def model_id(self) -> TModelID:
        return MODEL_SIZES_PER_ARCH_TO_MODEL_ID[self.model_arch][self.model_size]


In [None]:
from typing import Iterable

from torch import matmul, zeros_like

from .. import KnockoutMode


def knockout_scan(seq_len: int, ssm_state: Tensor, discrete_A: Tensor, discrete_B: Tensor, u: Tensor, C: Tensor, knocked_out_inputs: Iterable[int], affected_outputs: Iterable[int], knockout_mode: KnockoutMode, dtype) -> List[Tensor]:
    deltaB_u = discrete_B * u[:, :, :, None].float()
    knockout_state = zeros_like(ssm_state)
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
        if i not in knocked_out_inputs:
            knockout_state = discrete_A[:, :, i, :] * knockout_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
        elif i in knocked_out_inputs:
            if knockout_mode == KnockoutMode.ZERO_ATTENTION:
                knockout_state = discrete_A[:, :, i, :] * knockout_state
            elif knockout_mode == KnockoutMode.ZERO_DELTA:
                knockout_state = knockout_state
        if i in affected_outputs:
            scan_output = matmul(knockout_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        else:
            scan_output = matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        scan_outputs.append(scan_output[:, :, 0])

    return scan_outputs


def materialize_ssm_transition(A: torch.Tensor) -> torch.Tensor:
    batch = A.shape[0]
    D = A.shape[1]
    T = A.shape[2]
    N = A.shape[3]
    A = A.transpose(-1,-2).repeat(1,1,1,T).reshape(batch,D,N,T,T).transpose(-1,-2)
    A = torch.tril(A) + torch.triu(torch.ones_like(A),1)
    A_cumprod = torch.cumprod(A, dim=-2)

    transition_mat = A_cumprod.transpose(-2,-3)

    return transition_mat


def materialize_ssm_attention(A: Tensor, B: Tensor, C: Tensor, return_transition: bool) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: 
    transition_mat = materialize_ssm_transition(A)

    AB = (transition_mat * B.unsqueeze(-1))

    out = torch.einsum('btn, bdtnq -> bdtq', C, AB)
    
    if return_transition:
        return out, transition_mat
    
    return out


def knockout_matrix(seq_len: int, discrete_A: Tensor, discrete_B: Tensor, u: Tensor, C: Tensor, knocked_out_inputs: Iterable[int], affected_outputs: Iterable[int], dtype) -> List[Tensor]:
    attn = materialize_ssm_attention(discrete_A, discrete_B, C, False)
    for i, j in zip(affected_outputs, knocked_out_inputs):
        attn[:, :, i, j] = 0
    outputs = attn * u[:, :, :, None].float()
    return outputs


In [83]:
from typing import Iterable, List

from torch import Tensor


def knockout_scan(seq_len: int, ssm_state: Tensor, discrete_A: Tensor, deltaB_u: Tensor, C: Tensor, knocked_out_inputs: Iterable[int], affected_outputs: Iterable[int], dtype) -> List[Tensor]:
    knockout_state = zeros_like(ssm_state)
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
        print(ssm_state)
        if i not in knocked_out_inputs:
            knockout_state = discrete_A[:, :, i, :] * knockout_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
        elif i in knocked_out_inputs:
            knockout_state = discrete_A[:, :, i, :] * knockout_state
        if i in affected_outputs:
            scan_output = matmul(knockout_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        else:
            scan_output = matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        scan_outputs.append(scan_output[:, :, 0])

    return scan_outputs

In [84]:
from typing import Iterable, List

from torch import Tensor


def scan(seq_len: int, ssm_state: Tensor, discrete_A: Tensor, deltaB_u: Tensor, C: Tensor, knocked_out_inputs: Iterable[int], affected_outputs: Iterable[int], dtype) -> List[Tensor]:
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]      # [batch, intermediade_size, ssm_state]
        # print(ssm_state)
        scan_output = matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        scan_outputs.append(scan_output[:, :, 0])

    return scan_outputs

In [101]:
import torch

batch = 1
d_intermediate = 1
T = 5
ssm_d = 8

A = torch.ones((batch, d_intermediate, T, ssm_d)) # [batch, intermediate_size, seq_len, ssm_state_size]
B = torch.ones((batch, d_intermediate, T, ssm_d)) # [batch, intermediate_size, seq_len, ssm_state_size]
B = B / ssm_d
C = torch.ones((batch, T, ssm_d)) # [batch, seq_len, intermediate_size, ssm_state_size]
u = torch.Tensor([[[2**i for i in range(T)]]])    # [batch, seq_len, intermediate_size]

Bu = B * u[:,:,:,None]

state = torch.zeros(batch,d_intermediate,ssm_d)

In [102]:
out = knockout_scan(T, state, A, Bu, C, [0,1,2], [4], torch.float)
# out = scan(T, state, A, Bu, C, [1], [3], torch.float)
print(u)
print(u.cumsum(dim=-1))
for i in out:
    print(i, "{0:b}".format(int(i.item())))


tensor([[[0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]]])
tensor([[[0.3750, 0.3750, 0.3750, 0.3750, 0.3750, 0.3750, 0.3750, 0.3750]]])
tensor([[[0.8750, 0.8750, 0.8750, 0.8750, 0.8750, 0.8750, 0.8750, 0.8750]]])
tensor([[[1.8750, 1.8750, 1.8750, 1.8750, 1.8750, 1.8750, 1.8750, 1.8750]]])
tensor([[[3.8750, 3.8750, 3.8750, 3.8750, 3.8750, 3.8750, 3.8750, 3.8750]]])
tensor([[[ 1.,  2.,  4.,  8., 16.]]])
tensor([[[ 1.,  3.,  7., 15., 31.]]])
tensor([[1.]]) 1
tensor([[3.]]) 11
tensor([[7.]]) 111
tensor([[15.]]) 1111
tensor([[24.]]) 11000


In [130]:
def compute_attn_matrix_fn(dA, dB, C, L, x_shape, dtype=torch.float16):
    # dA = torch.exp(torch.einsum("bdl,dn->bldn", dt, A))
    # dB = torch.einsum("bdl,bnl->bldn", dt, B.squeeze(1))
    AttnMatrixOverCLS = torch.zeros((x_shape[0], x_shape[1], x_shape[2], x_shape[2])).to(dtype).to(dA.device) #BHLL: L vectors per batch and channel
    for r in range(L):
        for c in range(r+1):
            curr_C = C[:,r,:]
            currA = torch.ones((dA.shape[0],dA.shape[1],dA.shape[3]), dtype = dtype).to(dA.device)
            if c < r:
                for i in range(r-c):
                    currA = currA*dA[:,:,r-i,:]
            currB = dB[:,:,c,:]
            AttnMatrixOverCLS[:,:,r,c] = torch.sum(curr_C*currA*currB, axis=-1)
    return AttnMatrixOverCLS

def knockout_matrix(seq_len: int, discrete_A: Tensor, discrete_B: Tensor, u: Tensor, C: Tensor, knocked_out_inputs: Iterable[int], affected_outputs: Iterable[int], dtype) -> List[Tensor]:
    # _attn = materialize_ssm_attention(discrete_A, discrete_B, C, False)
    attn = compute_attn_matrix_fn(discrete_A, discrete_B, C, seq_len, u.shape, dtype)
    for i in affected_outputs:
        for j in knocked_out_inputs:
            attn[:, :, i, j] = 0
    # u = u.transpose(2,1)
    outputs = (attn @ u).squeeze(-1)
    return outputs

In [132]:
materialize_ssm_attention(A, B, C, False)
out = knockout_matrix(T, A, B, u.unsqueeze(-1), C, [0,1,3], [4], torch.float)
for i in out[0][0]:
    print(i, "{0:b}".format(int(i.item())))


tensor(1.) 1
tensor(3.) 11
tensor(7.) 111
tensor(15.) 1111
tensor(20.) 10100


In [16]:
from dataclasses import dataclass
from typing import Iterable, Optional

import torch
from einops import rearrange, repeat
from torch import Tensor, device


def segsum(x: Tensor, device: Optional[device] = None) -> Tensor:
    """Stable segment sum calculation.

    `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.

    Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
    """
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def mamba2_knockout(L,B,C,x, list_of_masks):
    CBT = torch.einsum("bclhn, bcshn -> bhcls", C, B)
    print(CBT.shape)
    attention_matrix = torch.einsum("bhcls, bhcls -> bclhs", CBT, L)
    # attention_matrix shape is: batch, chunk, seq_len , heads, seq_len
    # To remove the attention *given* by idx1 to idx2 do :
    # attention_matrix[:,:,idx1,:,idx2] = 0
    # by this notion idx1 is always greater or equal to idx2
    # attention_matrix[:, :, 11, :, 3] = 0
    for idx1, idx2 in list_of_masks:
        attention_matrix[:, :, idx1, :, idx2] = 0

    out_by_atten = torch.einsum("bclhs, bcshp-> bclhp", attention_matrix, x)
    out_by_atten = rearrange(out_by_atten, 'b c l h p -> b (c l) h p')
    return out_by_atten

In [30]:
import torch

b, c, l, h, n, p = 1, 1, 5, 1, 8, 1
L = torch.ones((b, h, c, l, l)) #bhcls
L = torch.tril(L)

# bclhn, bcshn
C = torch.ones((b, c, l, h, n))
B = torch.ones((b, c, l, h, n))
B = B / n
u = torch.ones((b, c, l, h, p))
for i in range(l):
    u[0, 0, i, 0, 0] = 2**i

knockout_list = [(3,1), (4,1)]

out = mamba2_knockout(L,B,C,u, knockout_list)
print(out)
print(out.shape)
for i in range(l):
    print("{0:b}".format(int(out[0,i,0,0].item())))

torch.Size([1, 1, 1, 5, 5])
tensor([[[[ 1.]],

         [[ 3.]],

         [[ 7.]],

         [[13.]],

         [[29.]]]])
torch.Size([1, 5, 1, 1])
1
11
111
1101
11101
