In [64]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [65]:
import json
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional

import numpy as np
import pandas as pd
import pyrallis
import torch
from torch import Tensor
from transformers import (
    AutoTokenizer,
)

from src.consts import (
    FILTERATIONS,
    MODEL_SIZES_PER_ARCH_TO_MODEL_ID,
    PATHS,
)
from src.datasets.download_dataset import load_dataset
from src.types import DATASETS, MODEL_ARCH, DatasetArgs, TModelID

In [66]:
@dataclass
class Args:
    # model_arch: MODEL_ARCH = MODEL_ARCH.MINIMAL_MAMBA2_new
    model_arch: MODEL_ARCH = MODEL_ARCH.MAMBA1
    model_size: str = "2.8B"
    dataset_args: DatasetArgs = pyrallis.field(
        default=DatasetArgs(name=DATASETS.COUNTER_FACT, splits="all"), is_mutable=True
    )
    filteration: str = FILTERATIONS.all_correct
    _batch_size: int = 16  # Adjust based on GPU memory
    output_file: Optional[Path] = None
    with_slurm: bool = False
    temperature = 1
    top_k = 0
    top_p = 1
    window_size = 5
    prompt_indices = [1, 2, 3, 4, 5]
    knockout_map = {"last": ["last", "first", "subject", "relation"], "subject": ["context", "subject"]}

    output_dir: Optional[Path] = None

    @property
    def batch_size(self) -> int:
        return (
            1
            if (self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2 or self.model_arch == MODEL_ARCH.MINIMAL_MAMBA2_new)
            else self._batch_size
        )

    @property
    def model_id(self) -> TModelID:
        return MODEL_SIZES_PER_ARCH_TO_MODEL_ID[self.model_arch][self.model_size]

In [67]:
args = Args()

In [68]:
original_res, attn_res = [
    pd.read_parquet(
        PATHS.OUTPUT_DIR
        / args.model_id
        / "data_construction"
        / f"ds={args.dataset_args.dataset_name}"
        / f"entire_results_{'attention' if attention else 'original'}.parquet"
    )
    for attention in [True, False]
]

mask = (original_res["hit"] == attn_res["hit"]) & (attn_res["hit"] == True)
data = attn_res[mask]

In [14]:
import pandas as pd
import plotly.express as px

df = pd.read_csv(
    "/a/home/cc/students/cs/nirendy/repos/ssm_analysis/output/state-spaces/mamba-2.8B-hf/evaluate_context_interference/ds=counter_fact/2.8B_norm_1_output.csv"
)
px.line(df, x="layer", y="acc", color="category").show()

In [15]:
import pandas as pd
import plotly.express as px

df = pd.read_csv(
    "/a/home/cc/students/cs/nirendy/repos/ssm_analysis/output/state-spaces/mamba-1.4B-hf/evaluate_context_interference/ds=counter_fact/1.4B_norm_1_output.csv"
)
px.line(df, x="layer", y="acc", color="category").show()

In [16]:
import pandas as pd
import plotly.express as px

df = pd.read_csv(
    "/a/home/cc/students/cs/nirendy/repos/ssm_analysis/output/state-spaces/mamba-130M-hf/evaluate_context_interference/ds=counter_fact/130M_norm_1_output.csv"
)
px.line(df, x="layer", y="acc", color="category").show()

In [None]:
from typing import Iterable

from torch import matmul, zeros_like

from .. import KnockoutMode


def knockout_scan(
    seq_len: int,
    ssm_state: Tensor,
    discrete_A: Tensor,
    discrete_B: Tensor,
    u: Tensor,
    C: Tensor,
    knocked_out_inputs: Iterable[int],
    affected_outputs: Iterable[int],
    knockout_mode: KnockoutMode,
    dtype,
) -> List[Tensor]:
    deltaB_u = discrete_B * u[:, :, :, None].float()
    knockout_state = zeros_like(ssm_state)
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
        if i not in knocked_out_inputs:
            knockout_state = (
                discrete_A[:, :, i, :] * knockout_state + deltaB_u[:, :, i, :]
            )  # [batch, intermediade_size, ssm_state]
        elif i in knocked_out_inputs:
            if knockout_mode == KnockoutMode.ZERO_ATTENTION:
                knockout_state = discrete_A[:, :, i, :] * knockout_state
            elif knockout_mode == KnockoutMode.ZERO_DELTA:
                knockout_state = knockout_state
        if i in affected_outputs:
            scan_output = matmul(knockout_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        else:
            scan_output = matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        scan_outputs.append(scan_output[:, :, 0])

    return scan_outputs


def materialize_ssm_transition(A: torch.Tensor) -> torch.Tensor:
    batch = A.shape[0]
    D = A.shape[1]
    T = A.shape[2]
    N = A.shape[3]
    A = A.transpose(-1, -2).repeat(1, 1, 1, T).reshape(batch, D, N, T, T).transpose(-1, -2)
    A = torch.tril(A) + torch.triu(torch.ones_like(A), 1)
    A_cumprod = torch.cumprod(A, dim=-2)

    transition_mat = A_cumprod.transpose(-2, -3)

    return transition_mat


def materialize_ssm_attention(
    A: Tensor, B: Tensor, C: Tensor, return_transition: bool
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
    transition_mat = materialize_ssm_transition(A)

    AB = transition_mat * B.unsqueeze(-1)

    out = torch.einsum("btn, bdtnq -> bdtq", C, AB)

    if return_transition:
        return out, transition_mat

    return out


def knockout_matrix(
    seq_len: int,
    discrete_A: Tensor,
    discrete_B: Tensor,
    u: Tensor,
    C: Tensor,
    knocked_out_inputs: Iterable[int],
    affected_outputs: Iterable[int],
    dtype,
) -> List[Tensor]:
    attn = materialize_ssm_attention(discrete_A, discrete_B, C, False)
    for i, j in zip(affected_outputs, knocked_out_inputs):
        attn[:, :, i, j] = 0
    outputs = attn * u[:, :, :, None].float()
    return outputs

In [8]:
import pandas as pd

df = pd.read_parquet(
    "/a/home/cc/students/cs/nirendy/repos/ssm_analysis/output/state-spaces/mamba-2.8B-hf/data_construction/ds=counter_fact/entire_results_original.parquet"
)
df["target_true"].head()

0     Microsoft
1       English
2        French
3        French
4        Google
Name: target_true, dtype: object

In [133]:
def get_top5_knockout():
    # Open and read the JSON file
    with open(
        "/a/home/cc/students/cs/nirendy/repos/ssm_analysis/output/state-spaces/mamba-2.8B-hf/info_flow_test_top_outputs/ds=counter_fact/ws=9/block_last_target_last/outputs.json",
        "r",
    ) as file:
        data = json.load(file)

    # Print the data
    print(data)
    new_df = {
        # 'index': [],
        "row": [],
        "token": [],
        "prob": [],
        "rank": [],
        "token_id": [],
    }

    t = 0

    for i in range(len(data)):
        for j, opt in enumerate(data[i][0]):
            # new_df['index'].append(t)
            new_df["row"].append(i)
            new_df["rank"].append(j)
            new_df["token"].append(opt[1])
            new_df["prob"].append(opt[2])
            new_df["token_id"].append(opt[0])
            t += 1

    new_df = pd.DataFrame(new_df)
    return new_df

In [181]:
def get_correct_token_knockout_prob():
    # Open and read the JSON file
    with open(
        "/a/home/cc/students/cs/nirendy/repos/ssm_analysis/output/state-spaces/mamba-2.8B-hf/info_flow_test_top_outputs_5_last_windows/ds=counter_fact/ws=9/block_last_target_last/outputs.json",
        "r",
    ) as file:
        data = json.load(file)
    print(data.keys())
    knockout_probs = data["55"]
    new_df = {
        # 'index': [],
        "row": [],
        "prob": [],
    }

    t = 0

    for i, p in enumerate(knockout_probs):
        new_df["row"].append(i)
        new_df["prob"].append(p)

    new_df = pd.DataFrame(new_df)
    return new_df

In [182]:
get_correct_token_knockout_prob()

dict_keys(['51', '52', '53', '54', '55'])


Unnamed: 0,row,prob
0,0,0.960056
1,1,0.999680
2,2,0.987940
3,3,0.784016
4,4,0.998711
...,...,...
805,805,0.994226
806,806,0.220406
807,807,0.983616
808,808,0.994120


In [183]:
def load_dataset():
    from src.consts import (
        PATHS,
    )

    # "/a/home/cc/students/cs/nirendy/repos/ssm_analysis/output/state-spaces/mamba-2.8B-hf/data_construction/ds=counter_fact/entire_results_original.parquet"
    original_res, attn_res = [
        pd.read_parquet(
            PATHS.OUTPUT_DIR
            / "state-spaces/mamba-2.8B-hf"
            / "data_construction"
            / "ds=counter_fact"
            / f"entire_results_{'attention' if attention else 'original'}.parquet"
        )
        for attention in [True, False]
    ]

    mask = (original_res["hit"] == attn_res["hit"]) & (attn_res["hit"] == True)
    data = attn_res[mask]
    data = data.reset_index(drop=True)
    return data


def calc_correct(merged):
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    merged["correct"] = merged.apply(
        lambda row: row["token_id"] == tokenizer(row["target_true"])["input_ids"][0], axis=1
    )
    print(merged["correct"].sum())
    # filtered = merged[merged['correct']]
    return merged

In [184]:
ds = load_dataset()
top5 = get_top5_knockout()
top5 = top5[top5["rank"] == 0]
merged = ds.merge(top5, left_index=True, right_on=["row"])
merged = calc_correct(merged)
(~merged["correct"]).sum()

[[[[9664, ' Microsoft', 0.9600555300712585], [7464, ' Windows', 0.0322045236825943], [253, ' the', 0.006631913594901562], [16880, 'Microsoft', 0.0005786402616649866], [14877, 'Windows', 0.00013698582188226283]]], [[[4383, ' English', 0.9996795654296875], [4782, ' British', 0.00014863394608255476], [6571, ' mostly', 6.665463297395036e-05], [7194, ' mainly', 4.3976102460874245e-05], [1925, ' called', 3.1884770578471944e-05]]], [[[5112, ' French', 0.9879401922225952], [4383, ' English', 0.01107520516961813], [1097, ' both', 0.0003751036711037159], [9883, ' Spanish', 0.00026353111024945974], [253, ' the', 0.00014860212104395032]]], [[[5112, ' French', 0.7840160727500916], [4383, ' English', 0.10410340130329132], [417, ' not', 0.0661514550447464], [7202, ' unknown', 0.035282522439956665], [1335, ' still', 0.0022929527331143618]]], [[[5559, ' Google', 0.9987107515335083], [17664, 'Google', 0.0008913984056562185], [17899, ' google', 0.0002078302059089765], [3186, ' search', 9.959286398952827e

48

In [185]:
del merged["prob"]
merged.reset_index(drop=True, inplace=True)
new_df = get_correct_token_knockout_prob()
print(new_df.shape)
merged = merged.merge(new_df, left_index=True, right_on=["row"])

dict_keys(['51', '52', '53', '54', '55'])
(810, 2)


In [186]:
merged.shape

(810, 24)

In [187]:
def calculate_prob_diff(filtered):
    filtered["prob diff"] = filtered["prob"] - filtered["true_prob"]
    filtered["normalized prob diff"] = filtered["prob diff"] / filtered["true_prob"]
    filtered.rename(columns={"true_prob": "baseline prob", "prob": "knockout prob"}, inplace=True)
    return filtered

In [188]:
merged.head(5)

Unnamed: 0,row,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,...,max_prob,hit,pred,row_x,token,rank,token_id,correct,row_y,prob
0,0,{} is a product of,,{} is a product of,Windows XP Media Center Edition is a product of,P178,Q312,Q2283,Microsoft,Apple,...,0.726621,True,Microsoft,0,Microsoft,0,9664,True,0,0.960056
1,1,"In {}, the language spoken is",In,", the language spoken is","In United Kingdom, the language spoken is",P37,Q1412,Q1860,English,Finnish,...,0.529818,True,English,1,English,0,4383,True,1,0.99968
2,2,{} is a native speaker of,,{} is a native speaker of,Henry de Montherlant is a native speaker of,P103,Q1860,Q150,French,English,...,0.748545,True,French,2,French,0,5112,True,2,0.98794
3,3,The native language of {} is,The native language of,is,The native language of Olga Georges-Picot is,P103,Q7737,Q150,French,Russian,...,0.184752,True,French,3,French,0,5112,True,3,0.784016
4,4,{} is owned by,,{} is owned by,Google Marketing Platform is owned by,P127,Q183,Q95,Google,Germany,...,0.757935,True,Google,4,Google,0,5559,True,4,0.998711


In [None]:
# merged = calculate_prob_diff(merged)

In [194]:
# from tkinter import font
import plotly.express as px

fig = px.scatter(merged, x="baseline prob", y="knockout prob", color="correct")
fig.update_layout(
    title="Baseline vs Knockout Probability",
    xaxis_title="Baseline Probability",
    yaxis_title="Knockout Probability",
    width=400,
    height=400,
    font=dict(size=12),
)
fig.show()
# merged.loc[0, 'token']

In [169]:
px.scatter(merged, x="baseline prob", y="prob diff").show()

In [170]:
px.scatter(merged, x="baseline prob", y="normalized prob diff").show()

In [213]:
px.histogram(filtered, x="normalized prob diff").show()

In [None]:
px.histogram(filtered, x="baseline prob").show()

In [207]:
filtered.sort_values("normalized prob diff", ascending=False)[
    ["normalized prob diff", "prob diff", "baseline prob", "knockout prob"]
].corr()

Unnamed: 0,normalized prob diff,prob diff,baseline prob,knockout prob
normalized prob diff,1.0,0.788688,-0.776852,-0.065352
prob diff,0.788688,1.0,-0.820757,0.198883
baseline prob,-0.776852,-0.820757,1.0,0.396631
knockout prob,-0.065352,0.198883,0.396631,1.0


In [208]:
filtered.sort_values("normalized prob diff")
filtered[
    [
        "normalized prob diff",
        "baseline prob",
        "prob diff",
    ]
].corr()

Unnamed: 0,normalized prob diff,baseline prob,prob diff
normalized prob diff,1.0,-0.776852,0.788688
baseline prob,-0.776852,1.0,-0.820757
prob diff,0.788688,-0.820757,1.0


In [214]:
high_base_prob_filtered = filtered[filtered["baseline prob"] >= 0.5]
px.histogram(high_base_prob_filtered, x="normalized prob diff").show()

In [None]:
high_base_prob_filtered = filtered[filtered["baseline prob"] >= 0.5]

norm_prob = high_base_prob_filtered["normalized prob diff"].to_numpy()
norm_prob.sort()

mean = norm_prob.cumsum() / (np.arange(norm_prob.shape[0]) + 1)

px.scatter(x=np.arange(norm_prob.shape[0]), y=mean)

In [168]:
merged["prefix_len"] = merged["relation_prefix"].str.len()

In [171]:
(merged["prefix_len"] == 0).sum()

670

In [217]:
filtered

Unnamed: 0,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,subject,...,hit,pred,row,token,knockout prob,rank,token_id,correct,prob diff,normalized prob diff
0,{} is a product of,,{} is a product of,Windows XP Media Center Edition is a product of,P178,Q312,Q2283,Microsoft,Apple,Windows XP Media Center Edition,...,True,Microsoft,0,Microsoft,0.960056,0,9664,True,0.233435,0.321261
5,"In {}, the language spoken is",In,", the language spoken is","In United Kingdom, the language spoken is",P37,Q1412,Q1860,English,Finnish,United Kingdom,...,True,English,1,English,0.999680,0,4383,True,0.469861,0.886835
10,{} is a native speaker of,,{} is a native speaker of,Henry de Montherlant is a native speaker of,P103,Q1860,Q150,French,English,Henry de Montherlant,...,True,French,2,French,0.987940,0,5112,True,0.239395,0.319814
15,The native language of {} is,The native language of,is,The native language of Olga Georges-Picot is,P103,Q7737,Q150,French,Russian,Olga Georges-Picot,...,True,French,3,French,0.784016,0,5112,True,0.599264,3.243615
20,{} is owned by,,{} is owned by,Google Marketing Platform is owned by,P127,Q183,Q95,Google,Germany,Google Marketing Platform,...,True,Google,4,Google,0.998711,0,5559,True,0.240776,0.317674
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4025,{} is created by,,{} is created by,Microsoft InfoPath is created by,P178,Q312,Q2283,Microsoft,Apple,Microsoft InfoPath,...,True,Microsoft,805,Microsoft,0.994226,0,9664,True,0.569201,1.339217
4030,{} worked in the city of,,{} worked in the city of,Pavlo Skoropadskyi worked in the city of,P937,Q84,Q1899,Kiev,London,Pavlo Skoropadskyi,...,True,K,806,K,0.220406,0,611,True,0.085604,0.635038
4035,"{}, created by",,"{}, created by","IBM 3790, created by",P176,Q66,Q37156,IBM,Boeing,IBM 3790,...,True,IBM,807,IBM,0.983616,0,21314,True,0.459724,0.877517
4040,{} belongs to the continent of,,{} belongs to the continent of,Riiser-Larsen Ice Shelf belongs to the contine...,P30,Q48,Q51,Antarctica,Asia,Riiser-Larsen Ice Shelf,...,True,Antar,808,Antar,0.994120,0,31913,True,0.173516,0.211449


In [218]:
filtered["relation_len"] = filtered["relation_suffix"].str.len()



A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy



In [227]:
(filtered["relation_len"] >= 4).sum()

644

In [176]:
px.histogram(filtered[(filtered["prefix_len"] > 0)], x="normalized prob diff").show()

In [178]:
px.histogram(data, x="true_prob")

In [180]:
(data["true_prob"] < 0.5).sum()

285

In [83]:
from typing import Iterable, List

from torch import Tensor


def knockout_scan(
    seq_len: int,
    ssm_state: Tensor,
    discrete_A: Tensor,
    deltaB_u: Tensor,
    C: Tensor,
    knocked_out_inputs: Iterable[int],
    affected_outputs: Iterable[int],
    dtype,
) -> List[Tensor]:
    knockout_state = zeros_like(ssm_state)
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
        print(ssm_state)
        if i not in knocked_out_inputs:
            knockout_state = (
                discrete_A[:, :, i, :] * knockout_state + deltaB_u[:, :, i, :]
            )  # [batch, intermediade_size, ssm_state]
        elif i in knocked_out_inputs:
            knockout_state = discrete_A[:, :, i, :] * knockout_state
        if i in affected_outputs:
            scan_output = matmul(knockout_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        else:
            scan_output = matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        scan_outputs.append(scan_output[:, :, 0])

    return scan_outputs

In [84]:
from typing import Iterable, List

from torch import Tensor


def scan(
    seq_len: int,
    ssm_state: Tensor,
    discrete_A: Tensor,
    deltaB_u: Tensor,
    C: Tensor,
    knocked_out_inputs: Iterable[int],
    affected_outputs: Iterable[int],
    dtype,
) -> List[Tensor]:
    scan_outputs = []
    for i in range(seq_len):
        ssm_state = discrete_A[:, :, i, :] * ssm_state + deltaB_u[:, :, i, :]  # [batch, intermediade_size, ssm_state]
        # print(ssm_state)
        scan_output = matmul(ssm_state.to(dtype), C[:, i, :].unsqueeze(-1))  # [batch, intermediade_size, 1]
        scan_outputs.append(scan_output[:, :, 0])

    return scan_outputs

In [101]:
import torch

batch = 1
d_intermediate = 1
T = 5
ssm_d = 8

A = torch.ones((batch, d_intermediate, T, ssm_d))  # [batch, intermediate_size, seq_len, ssm_state_size]
B = torch.ones((batch, d_intermediate, T, ssm_d))  # [batch, intermediate_size, seq_len, ssm_state_size]
B = B / ssm_d
C = torch.ones((batch, T, ssm_d))  # [batch, seq_len, intermediate_size, ssm_state_size]
u = torch.Tensor([[[2**i for i in range(T)]]])  # [batch, seq_len, intermediate_size]

Bu = B * u[:, :, :, None]

state = torch.zeros(batch, d_intermediate, ssm_d)

In [102]:
out = knockout_scan(T, state, A, Bu, C, [0, 1, 2], [4], torch.float)
# out = scan(T, state, A, Bu, C, [1], [3], torch.float)
print(u)
print(u.cumsum(dim=-1))
for i in out:
    print(i, "{0:b}".format(int(i.item())))

tensor([[[0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]]])
tensor([[[0.3750, 0.3750, 0.3750, 0.3750, 0.3750, 0.3750, 0.3750, 0.3750]]])
tensor([[[0.8750, 0.8750, 0.8750, 0.8750, 0.8750, 0.8750, 0.8750, 0.8750]]])
tensor([[[1.8750, 1.8750, 1.8750, 1.8750, 1.8750, 1.8750, 1.8750, 1.8750]]])
tensor([[[3.8750, 3.8750, 3.8750, 3.8750, 3.8750, 3.8750, 3.8750, 3.8750]]])
tensor([[[ 1.,  2.,  4.,  8., 16.]]])
tensor([[[ 1.,  3.,  7., 15., 31.]]])
tensor([[1.]]) 1
tensor([[3.]]) 11
tensor([[7.]]) 111
tensor([[15.]]) 1111
tensor([[24.]]) 11000


In [130]:
def compute_attn_matrix_fn(dA, dB, C, L, x_shape, dtype=torch.float16):
    # dA = torch.exp(torch.einsum("bdl,dn->bldn", dt, A))
    # dB = torch.einsum("bdl,bnl->bldn", dt, B.squeeze(1))
    AttnMatrixOverCLS = (
        torch.zeros((x_shape[0], x_shape[1], x_shape[2], x_shape[2])).to(dtype).to(dA.device)
    )  # BHLL: L vectors per batch and channel
    for r in range(L):
        for c in range(r + 1):
            curr_C = C[:, r, :]
            currA = torch.ones((dA.shape[0], dA.shape[1], dA.shape[3]), dtype=dtype).to(dA.device)
            if c < r:
                for i in range(r - c):
                    currA = currA * dA[:, :, r - i, :]
            currB = dB[:, :, c, :]
            AttnMatrixOverCLS[:, :, r, c] = torch.sum(curr_C * currA * currB, axis=-1)
    return AttnMatrixOverCLS


def knockout_matrix(
    seq_len: int,
    discrete_A: Tensor,
    discrete_B: Tensor,
    u: Tensor,
    C: Tensor,
    knocked_out_inputs: Iterable[int],
    affected_outputs: Iterable[int],
    dtype,
) -> List[Tensor]:
    # _attn = materialize_ssm_attention(discrete_A, discrete_B, C, False)
    attn = compute_attn_matrix_fn(discrete_A, discrete_B, C, seq_len, u.shape, dtype)
    for i in affected_outputs:
        for j in knocked_out_inputs:
            attn[:, :, i, j] = 0
    # u = u.transpose(2,1)
    outputs = (attn @ u).squeeze(-1)
    return outputs

In [126]:
u.shape

torch.Size([1, 1, 5])

In [132]:
materialize_ssm_attention(A, B, C, False)
out = knockout_matrix(T, A, B, u.unsqueeze(-1), C, [0, 1, 3], [4], torch.float)
for i in out[0][0]:
    print(i, "{0:b}".format(int(i.item())))

tensor(1.) 1
tensor(3.) 11
tensor(7.) 111
tensor(15.) 1111
tensor(20.) 10100


In [16]:
from dataclasses import dataclass
from typing import Iterable, Optional

import torch
from einops import rearrange, repeat
from torch import Tensor, device


def segsum(x: Tensor, device: Optional[device] = None) -> Tensor:
    """Stable segment sum calculation.

    `exp(segsum(A))` produces a 1-semiseparable matrix, which is equivalent to a scalar SSM.

    Source: https://github.com/state-spaces/mamba/blob/219f03c840d5a44e7d42e4e728134834fddccf45/mamba_ssm/modules/ssd_minimal.py#L23-L32
    """
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum


def mamba2_knockout(L, B, C, x, list_of_masks):
    CBT = torch.einsum("bclhn, bcshn -> bhcls", C, B)
    print(CBT.shape)
    attention_matrix = torch.einsum("bhcls, bhcls -> bclhs", CBT, L)
    # attention_matrix shape is: batch, chunk, seq_len , heads, seq_len
    # To remove the attention *given* by idx1 to idx2 do :
    # attention_matrix[:,:,idx1,:,idx2] = 0
    # by this notion idx1 is always greater or equal to idx2
    # attention_matrix[:, :, 11, :, 3] = 0
    for idx1, idx2 in list_of_masks:
        attention_matrix[:, :, idx1, :, idx2] = 0

    out_by_atten = torch.einsum("bclhs, bcshp-> bclhp", attention_matrix, x)
    out_by_atten = rearrange(out_by_atten, "b c l h p -> b (c l) h p")
    return out_by_atten

In [30]:
import torch

b, c, l, h, n, p = 1, 1, 5, 1, 8, 1
L = torch.ones((b, h, c, l, l))  # bhcls
L = torch.tril(L)

# bclhn, bcshn
C = torch.ones((b, c, l, h, n))
B = torch.ones((b, c, l, h, n))
B = B / n
u = torch.ones((b, c, l, h, p))
for i in range(l):
    u[0, 0, i, 0, 0] = 2**i

knockout_list = [(3, 1), (4, 1)]

out = mamba2_knockout(L, B, C, u, knockout_list)
print(out)
print(out.shape)
for i in range(l):
    print("{0:b}".format(int(out[0, i, 0, 0].item())))

torch.Size([1, 1, 1, 5, 5])
tensor([[[[ 1.]],

         [[ 3.]],

         [[ 7.]],

         [[13.]],

         [[29.]]]])
torch.Size([1, 5, 1, 1])
1
11
111
1101
11101
