In [None]:
import torch
import torch.nn as nn

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    Two layouts are accepted for ``freqs_cis``:

    * its shape equals the last two dims of ``x``; every leading dim of the
      result is set to 1;
    * its shape equals the last three dims of ``x``; every dim before those
      is set to 1.

    Args:
        freqs_cis: Tensor to be reshaped (viewed, not copied).
        x: Reference tensor; must have at least two dimensions.

    Returns:
        A view of ``freqs_cis`` with ``x.ndim`` dimensions.

    Raises:
        ValueError: If ``freqs_cis`` matches neither accepted layout.
    """
    ndim = x.ndim
    assert ndim > 1
    freq_shape = tuple(freqs_cis.shape)
    x_shape = tuple(x.shape)
    if freq_shape == x_shape[-2:]:
        kept = 2
    elif ndim >= 3 and freq_shape == x_shape[-3:]:
        kept = 3
    else:
        # Previously this fell through with `shape` unbound (or indexed
        # x.shape[-3] on a 2-D tensor); fail with an actionable message instead.
        raise ValueError(
            f"freqs_cis shape {freq_shape} does not match the trailing "
            f"2 or 3 dims of x with shape {x_shape}"
        )
    shape = [1] * (ndim - kept) + list(x_shape[-kept:])
    return freqs_cis.view(*shape)

def apply_rotary_emb(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate query and key tensors by the complex factors in ``freqs_cis``.

    The last dim of each input is grouped into (real, imag) pairs, multiplied
    by ``freqs_cis`` as complex numbers, and flattened back from dim 3 onward.

    Args:
        xq: Query tensor; its last dim must be even.
        xk: Key tensor; its last dim must be even.
        freqs_cis: Complex rotation factors, reshaped for broadcasting against
            the complex view of ``xq`` via ``reshape_for_broadcast``.

    Returns:
        Tuple ``(xq_out, xk_out)`` cast back to the dtype/device of the inputs.
    """

    def as_complex(t):
        # Math is done in float32; pairs of the last dim become complex numbers.
        return torch.view_as_complex(t.float().reshape(*t.shape[:-1], -1, 2))

    def as_real_like(rotated, reference):
        flat = torch.view_as_real(rotated).flatten(3)
        return flat.type_as(reference).to(reference.device)

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    return (
        as_real_like(q_complex * rotation, xq),
        as_real_like(k_complex * rotation, xk),
    )

def apply_rotary_emb_mamba(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate a single tensor by the complex factors in ``freqs_cis``.

    The last dim of ``x`` is grouped into (real, imag) pairs, multiplied by
    ``freqs_cis`` as complex numbers, and the pair axis is merged back so the
    result has the same shape as ``x``.

    Args:
        x: Input tensor; its last dim must be even.
        freqs_cis: Complex rotation factors, reshaped for broadcasting against
            the complex view of ``x`` via ``reshape_for_broadcast``.

    Returns:
        Tensor with the shape, dtype and device of ``x``.
    """
    x_ = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, x_)
    # Merge the trailing (pairs, 2) axes. The previous flatten(3) only did that
    # for 4-D inputs; for the 3-D (B, L, C) tensors passed by RoPEMamba.forward
    # it was a no-op and left a stray trailing axis of size 2.
    x_out = torch.view_as_real(x_ * freqs_cis).flatten(-2)
    return x_out.type_as(x).to(x.device)

class RoPEAttention(nn.Module):
    """Multi-head self-attention whose queries and keys are rotated with RoPE.

    The token at index 0 of the sequence is left unrotated; all later tokens
    are rotated (presumably index 0 is a class token - confirm with caller).
    """

    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        """Build the projections.

        Args:
            dim: Embedding width of the input tokens.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused QKV projection has a bias.
            qk_scale: Optional logit scale; a falsy value selects
                ``(dim // num_heads) ** -0.5``.
            attn_drop: Dropout probability on the attention weights.
            proj_drop: Dropout probability on the projected output.
        """
        super().__init__()
        self.num_heads = num_heads
        per_head_width = dim // num_heads
        self.scale = qk_scale if qk_scale else per_head_width ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, freqs_cis):
        """Attend over ``x`` of shape (B, N, C) and return the same shape."""
        batch, tokens, channels = x.shape
        head_width = channels // self.num_heads

        # (B, N, 3C) -> (3, B, heads, N, head_width)
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_width)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Rotate every token except the first one, writing the result in place.
        rotated_q, rotated_k = apply_rotary_emb(q[:, :, 1:], k[:, :, 1:], freqs_cis=freqs_cis)
        q[:, :, 1:] = rotated_q
        k[:, :, 1:] = rotated_k

        scores = torch.matmul(q * self.scale, k.transpose(-2, -1))
        weights = self.attn_drop(scores.softmax(dim=-1))

        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))

import math
from einops import rearrange
from mamba_ssm.ops.triton.ssd_combined import ssd_selective_scan

def init_conv(conv):
    """Re-initialise a convolution layer in place.

    The weight gets Kaiming-normal init (fan-out mode, ReLU gain); the bias,
    when the layer has one, is zeroed.
    """
    nn.init.kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
    if conv.bias is None:
        return
    nn.init.constant_(conv.bias, 0)

class RoPEMamba(nn.Module):
    """RoPE Mamba model."""
    def __init__(self,
                 # --------------------------------
                 use_conv=True,
                 d_model=256,
                 d_state=1,
                 headdim=32,
                 A_init_range=(1, 16),
                 dt_min=0.001,
                 dt_max=0.1,
                 dt_init_floor=1e-4,
                 dt_limit=(0.0, float("inf")),
                 bias=False,
                 chunk_size=256,
                 device=None,
                 dtype=None,
                 H=180,
                 W=180,
                 # --------------------------------
                 attn_drop=0.,
                 proj_drop=0.):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = self.d_inner = self.d_ssm = d_model
        self.d_state = d_state
        self.headdim = headdim
        assert self.d_ssm % self.headdim == 0
        self.nheads = self.d_ssm // self.headdim
        self.dt_limit = dt_limit
        self.activation = "silu"
        self.chunk_size = chunk_size
        d_in_proj = (2 * self.d_inner + 2 * self.d_state + self.nheads) * 4
        self.d_in_proj = d_in_proj
        # self.in_proj_H = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
        # self.in_proj_V = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
        self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=bias, **factory_kwargs)
        self.act = nn.SiLU()
        # --- HF Path ---
        dt_HF = torch.clamp(
            torch.exp(
                torch.rand(self.nheads, **factory_kwargs)
                * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            ),
            min=dt_init_floor,
        )
        inv_dt_HF = dt_HF + torch.log(-torch.expm1(-dt_HF))
        self.dt_bias_HF = nn.Parameter(inv_dt_HF)
        self.dt_bias_HF._no_weight_decay = True

        # --- HB Path ---
        dt_HB = torch.clamp(
            torch.exp(
                torch.rand(self.nheads, **factory_kwargs)
                * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            ),
            min=dt_init_floor,
        )
        inv_dt_HB = dt_HB + torch.log(-torch.expm1(-dt_HB))
        self.dt_bias_HB = nn.Parameter(inv_dt_HB)
        self.dt_bias_HB._no_weight_decay = True

        # --- VH Path ---
        dt_VH = torch.clamp(
            torch.exp(
                torch.rand(self.nheads, **factory_kwargs)
                * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            ),
            min=dt_init_floor,
        )
        inv_dt_VH = dt_VH + torch.log(-torch.expm1(-dt_VH))
        self.dt_bias_VH = nn.Parameter(inv_dt_VH)
        self.dt_bias_VH._no_weight_decay = True

        # --- VB Path ---
        dt_VB = torch.clamp(
            torch.exp(
                torch.rand(self.nheads, **factory_kwargs)
                * (math.log(dt_max) - math.log(dt_min))
                + math.log(dt_min)
            ),
            min=dt_init_floor,
        )
        inv_dt_VB = dt_VB + torch.log(-torch.expm1(-dt_VB))
        self.dt_bias_VB = nn.Parameter(inv_dt_VB)
        self.dt_bias_VB._no_weight_decay = True

        assert A_init_range[0] > 0 and A_init_range[1] >= A_init_range[0]

        # --- A_log for HF Path ---
        A_HF = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log_HF = torch.log(A_HF).to(dtype=dtype)
        self.A_log_HF = nn.Parameter(A_log_HF)
        self.A_log_HF._no_weight_decay = True

        # --- A_log for HB Path ---
        A_HB = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log_HB = torch.log(A_HB).to(dtype=dtype)
        self.A_log_HB = nn.Parameter(A_log_HB)
        self.A_log_HB._no_weight_decay = True

        # --- A_log for VH Path ---
        A_VH = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log_VH = torch.log(A_VH).to(dtype=dtype)
        self.A_log_VH = nn.Parameter(A_log_VH)
        self.A_log_VH._no_weight_decay = True

        # --- A_log for VB Path ---
        A_VB = torch.empty(self.nheads, dtype=torch.float32, device=device).uniform_(
            *A_init_range
        )
        A_log_VB = torch.log(A_VB).to(dtype=dtype)
        self.A_log_VB = nn.Parameter(A_log_VB)
        self.A_log_VB._no_weight_decay = True

        # --- D for HF Path ---
        self.D_HF = nn.Parameter(torch.ones(self.nheads, device=device)) # Consider adding dtype=dtype if needed
        self.D_HF._no_weight_decay = True

        # --- D for HB Path ---
        self.D_HB = nn.Parameter(torch.ones(self.nheads, device=device)) # Consider adding dtype=dtype if needed
        self.D_HB._no_weight_decay = True

        # --- D for VH Path ---
        self.D_VH = nn.Parameter(torch.ones(self.nheads, device=device)) # Consider adding dtype=dtype if needed
        self.D_VH._no_weight_decay = True

        # --- D for VB Path ---
        self.D_VB = nn.Parameter(torch.ones(self.nheads, device=device)) # Consider adding dtype=dtype if needed
        self.D_VB._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_model * 4, self.d_model, bias=bias, **factory_kwargs)
        self.out_norm = nn.LayerNorm(self.d_model, eps=1e-6, **factory_kwargs)
        self.out_drop = nn.Dropout(proj_drop)
        self.use_conv = use_conv
        if self.use_conv:
            self.conv1d_HF = nn.Conv1d(
                in_channels=d_model,
                out_channels=d_model,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False,
            )
            self.conv1d_HB = nn.Conv1d(
                in_channels=d_model,
                out_channels=d_model,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False,
            )
            self.conv1d_VF = nn.Conv1d(
                in_channels=d_model,
                out_channels=d_model,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False,
            )
            self.conv1d_VB = nn.Conv1d(
                in_channels=d_model,
                out_channels=d_model,
                kernel_size=4,
                stride=4,
                padding=0,
                bias=False,
            )
            init_conv(self.conv1d_HF)
            init_conv(self.conv1d_HB)
            init_conv(self.conv1d_VF)
            init_conv(self.conv1d_VB)
            mask_low_resolution = torch.ones(
                1, 1, H // 2, W // 2, device=device if device is not None else "cpu"
            )
            morton_H_indices_low, morton_V_indices_low = self.morton_code_extraction(mask_low_resolution)
            inverse_H_indices_low = torch.empty_like(morton_H_indices_low)
            inverse_H_indices_low[morton_H_indices_low] = torch.arange(morton_H_indices_low.size(0), device=morton_H_indices_low.device)
            inverse_V_indices_low = torch.empty_like(morton_V_indices_low)
            inverse_V_indices_low[morton_V_indices_low] = torch.arange(morton_V_indices_low.size(0), device=morton_V_indices_low.device)
            self.register_buffer('morton_H_indices_low', morton_H_indices_low)
            self.register_buffer('morton_V_indices_low', morton_V_indices_low)
            self.register_buffer('inverse_H_indices_low', inverse_H_indices_low)
            self.register_buffer('inverse_V_indices_low', inverse_V_indices_low)
        self.H = H
        self.W = W
        mask = torch.ones(1, 1, H, W, device=device if device is not None else "cpu")
        morton_H_indices, morton_V_indices = self.morton_code_extraction(mask)
        inverse_H_indices = torch.empty_like(morton_H_indices)
        inverse_H_indices[morton_H_indices] = torch.arange(morton_H_indices.size(0), device=morton_H_indices.device)
        inverse_V_indices = torch.empty_like(morton_V_indices)
        inverse_V_indices[morton_V_indices] = torch.arange(morton_V_indices.size(0), device=morton_V_indices.device)
        self.register_buffer('morton_H_indices', morton_H_indices)
        self.register_buffer('morton_V_indices', morton_V_indices)
        self.register_buffer('inverse_H_indices', inverse_H_indices)
        self.register_buffer('inverse_V_indices', inverse_V_indices)
        conv_dim = self.d_inner + 2 * self.d_state
        self.conv_dim = conv_dim
        # --- Conv1d for x component ---
        # Horizontal Forward for x
        # self.conv1d_hf_x = nn.Conv1d(
        #     in_channels=conv_dim,
        #     out_channels=conv_dim,
        #     bias=False,
        #     kernel_size=4,
        #     groups=conv_dim,
        #     **factory_kwargs,
        # )
        # # Horizontal Backward for x
        # self.conv1d_hb_x = nn.Conv1d(
        #     in_channels=conv_dim,
        #     out_channels=conv_dim,
        #     bias=False,
        #     kernel_size=4,
        #     groups=conv_dim,
        #     **factory_kwargs,
        # )
        # # Vertical Forward (Height-wise) for x
        # self.conv1d_vh_x = nn.Conv1d(
        #     in_channels=conv_dim,
        #     out_channels=conv_dim,
        #     bias=False,
        #     kernel_size=4,
        #     groups=conv_dim,
        #     **factory_kwargs,
        # )
        # # Vertical Backward (Width-wise) for x
        # self.conv1d_vb_x = nn.Conv1d(
        #     in_channels=conv_dim,
        #     out_channels=conv_dim,
        #     bias=False,
        #     kernel_size=4,
        #     groups=conv_dim,
        #     **factory_kwargs,
        # )

        # # --- Conv1d for z component ---
        # # Horizontal Forward for z
        # self.conv1d_hf_z = nn.Conv1d(
        #     in_channels=self.d_inner,
        #     out_channels=self.d_inner,
        #     bias=False,
        #     kernel_size=4,
        #     groups=self.d_inner,
        #     **factory_kwargs,
        # )
        # # Horizontal Backward for z
        # self.conv1d_hb_z = nn.Conv1d(
        #     in_channels=self.d_inner,
        #     out_channels=self.d_inner,
        #     bias=False,
        #     kernel_size=4,
        #     groups=self.d_inner,
        #     **factory_kwargs,
        # )
        # # Vertical Forward (Height-wise) for z
        # self.conv1d_vh_z = nn.Conv1d(
        #     in_channels=self.d_inner,
        #     out_channels=self.d_inner,
        #     bias=False,
        #     kernel_size=4,
        #     groups=self.d_inner,
        #     **factory_kwargs,
        # )
        # # Vertical Backward (Width-wise) for z
        # self.conv1d_vb_z = nn.Conv1d(
        #     in_channels=self.d_inner,
        #     out_channels=self.d_inner,
        #     bias=False,
        #     kernel_size=4,
        #     groups=self.d_inner,
        #     **factory_kwargs,
        # )
        # ------------------------
        d_conv2d = d_model * 2
        self.conv2d_hf = nn.Sequential(
            nn.Conv2d(d_conv2d,d_conv2d,kernel_size=3,stride=1,padding=1,bias=False),
            nn.GELU(),
            nn.BatchNorm2d(d_conv2d, eps=1e-3, momentum=0.01))
        self.conv2d_hb = nn.Sequential(
            nn.Conv2d(d_conv2d,d_conv2d,kernel_size=3,stride=1,padding=1,bias=False),
            nn.GELU(),
            nn.BatchNorm2d(d_conv2d, eps=1e-3, momentum=0.01))
        self.conv2d_vf = nn.Sequential(
            nn.Conv2d(d_conv2d,d_conv2d,kernel_size=3,stride=1,padding=1,bias=False),
            nn.GELU(),
            nn.BatchNorm2d(d_conv2d, eps=1e-3, momentum=0.01))
        self.conv2d_vb = nn.Sequential(
            nn.Conv2d(d_conv2d,d_conv2d,kernel_size=3,stride=1,padding=1,bias=False),
            nn.GELU(),
            nn.BatchNorm2d(d_conv2d, eps=1e-3, momentum=0.01))
        
    def forward(self, x, freqs_cis):
        """Run the four directional scans over a feature map and fuse them.

        Each path gets its own (z, xBC, dt) slice of ``in_proj``, a 3x3 conv
        block on the concatenated (x, z) map, a token reordering, RoPE on x,
        and an ``ssd_selective_scan``. The ``*B`` paths scan the reversed
        sequence and flip the result back. With ``use_conv`` the scan output
        is reduced 4x along the sequence by a stride-4 ``Conv1d`` before the
        reordering is undone.

        Args:
            x: Feature map of shape (B, C, H, W). ``C`` must equal ``d_model``
                and ``H`` / ``W`` must match the values given to ``__init__``
                for the index buffers to fit.
            freqs_cis: Complex rotation factors passed to
                ``apply_rotary_emb_mamba``.
                NOTE(review): they are applied after the tokens have been
                reordered - confirm the caller supplies them in scan order.

        Returns:
            Tensor of shape (B, C, H // 2, W // 2) when ``use_conv`` is set,
            otherwise (B, C, H, W).
        """
        B, C, H, W = x.shape
        device = x.device
        order_H = self.morton_H_indices.to(device)
        order_V = self.morton_V_indices.to(device)
        if self.use_conv:
            restore_H = self.inverse_H_indices_low.to(device)
            restore_V = self.inverse_V_indices_low.to(device)
            out_H, out_W = H // 2, W // 2
        else:
            restore_H = self.inverse_H_indices.to(device)
            restore_V = self.inverse_V_indices.to(device)
            out_H, out_W = H, W

        # (B, C, H, W) -> (B, H*W, C); token reordering happens after in_proj.
        x_flat = x.view(B, C, -1).permute(0, 2, 1)
        zxbcdt = self.in_proj(x_flat)

        dim = self.d_ssm
        per_path = [dim, dim + 2 * self.d_state, self.nheads]
        chunks = torch.split(zxbcdt, per_path * 4, dim=-1)

        # (conv2d, A_log, D, dt_bias, order, restore, reverse, conv1d name).
        # The third path's parameters are registered with the suffix "VH" in
        # __init__ (the previous "*_VF" lookups raised AttributeError).
        paths = (
            (self.conv2d_hf, self.A_log_HF, self.D_HF, self.dt_bias_HF,
             order_H, restore_H, False, "conv1d_HF"),
            (self.conv2d_hb, self.A_log_HB, self.D_HB, self.dt_bias_HB,
             order_H, restore_H, True, "conv1d_HB"),
            (self.conv2d_vf, self.A_log_VH, self.D_VH, self.dt_bias_VH,
             order_V, restore_V, False, "conv1d_VF"),
            (self.conv2d_vb, self.A_log_VB, self.D_VB, self.dt_bias_VB,
             order_V, restore_V, True, "conv1d_VB"),
        )

        restored = []
        for i, (conv2d, A_log, D, dt_bias, order, restore, reverse, conv1d_name) in enumerate(paths):
            z, xBC, dt = chunks[3 * i: 3 * i + 3]
            out = self._scan_path(
                z, xBC, dt, freqs_cis, (B, C, H, W),
                conv2d, A_log, D, dt_bias, order, reverse,
            )
            if self.use_conv:
                # Stride-4 conv over the scan-ordered sequence: L -> L // 4.
                out = getattr(self, conv1d_name)(out.permute(0, 2, 1)).permute(0, 2, 1)
            restored.append(out[:, restore, :])

        out = self.out_proj(torch.cat(restored, dim=-1).contiguous())
        out = self.out_norm(out)
        out = self.out_drop(out)

        # `out` is (B, L, C). Move channels in front of the sequence axis
        # before unfolding it into the grid; a bare view(B, C, h, w) would
        # reinterpret memory and mix channels with positions.
        return out.permute(0, 2, 1).reshape(B, C, out_H, out_W)

    def _scan_path(self, z, xBC, dt, freqs_cis, in_shape, conv2d, A_log, D, dt_bias, order, reverse):
        """Run one scan path and return (B, L, d_ssm) in scan order."""
        B, C, H, W = in_shape
        dim = self.d_ssm
        x_p, B_p, C_p = torch.split(xBC, [dim, self.d_state, self.d_state], dim=-1)

        # Local mixing of x and z on the 2-D grid.
        xz = torch.cat([x_p, z], dim=-1).permute(0, 2, 1).view(B, C * 2, H, W)
        xz = conv2d(xz)
        x_p, z = torch.split(xz.view(B, C * 2, -1).permute(0, 2, 1), [C, C], dim=-1)

        # Reorder every per-token tensor into this path's scan order.
        x_p = x_p[:, order, :]
        z = z[:, order, :]
        B_p = B_p[:, order, :]
        C_p = C_p[:, order, :]
        dt = dt[:, order]

        x_p = apply_rotary_emb_mamba(x_p, freqs_cis=freqs_cis)

        x_p = rearrange(x_p, "b l (h p) -> b l h p", h=self.nheads).contiguous()
        B_p = rearrange(B_p, "b l (g n) -> b l g n", g=1).contiguous()
        C_p = rearrange(C_p, "b l (g n) -> b l g n", g=1).contiguous()
        z = rearrange(z, "b l (h p) -> b l h p", h=self.nheads).contiguous()
        dt = dt.to(x_p.dtype)

        if reverse:
            # Backward path: scan the reversed sequence, un-reverse afterwards.
            x_p, dt, B_p, C_p, z = (t.flip(1) for t in (x_p, dt, B_p, C_p, z))

        out = ssd_selective_scan(
            x_p,
            dt,
            -torch.exp(A_log.float()),
            B_p,
            C_p,
            D=D.float(),
            z=z,
            dt_bias=dt_bias,
            dt_softplus=True,
            dt_limit=self.dt_limit,
        )
        if reverse:
            out = out.flip(1)
        return rearrange(out, "b s h p -> b s (h p)")
    
    def morton_code_extraction(self, mask):
        device = mask.device
        h, w = mask[0][0].shape
        """
        说明：
            row_indices[2,3]的值为2，表示该位置的行索引为2
            col_indices[2,3]的值为3，表示该位置的列索引为3
        """
        row_indices, col_indices = torch.meshgrid(
            torch.arange(h, device=device),
            torch.arange(w, device=device),
            indexing="ij",
        )
        row_indices = row_indices.flatten()
        col_indices = col_indices.flatten()
        valid_indices = mask[0][0].flatten() != 0
        row_indices = row_indices[valid_indices]
        col_indices = col_indices[valid_indices]
        morton_codes_1 = self.interleave_bits(col_indices, row_indices)
        morton_codes_2 = self.interleave_bits_x_last(col_indices, row_indices)
        sorted_indices_1 = torch.argsort(morton_codes_1)
        sorted_indices_2 = torch.argsort(morton_codes_2)
        linear_indices_1 = (
            row_indices[sorted_indices_1] * w + col_indices[sorted_indices_1]
        )
        linear_indices_2 = (
            row_indices[sorted_indices_2] * w + col_indices[sorted_indices_2]
        )
        return linear_indices_1, linear_indices_2

    def interleave_bits(self, x, y):
        x = (x | (x << 8)) & 0x00FF00FF
        x = (x | (x << 4)) & 0x0F0F0F0F
        x = (x | (x << 2)) & 0x33333333
        x = (x | (x << 1)) & 0x55555555
        y = (y | (y << 8)) & 0x00FF00FF
        y = (y | (y << 4)) & 0x0F0F0F0F
        y = (y | (y << 2)) & 0x33333333
        y = (y | (y << 1)) & 0x55555555
        z = (x << 1) | y
        return z

    def interleave_bits_x_last(self, x, y):
        x = (x | (x << 8)) & 0x00FF00FF
        x = (x | (x << 4)) & 0x0F0F0F0F
        x = (x | (x << 2)) & 0x33333333
        x = (x | (x << 1)) & 0x55555555
        y = (y | (y << 8)) & 0x00FF00FF
        y = (y | (y << 4)) & 0x0F0F0F0F
        y = (y | (y << 2)) & 0x33333333
        y = (y | (y << 1)) & 0x55555555
        z = (y << 1) | x
        return z






        
        


        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        q[:, :, 1:], k[:, :, 1:] = apply_rotary_emb(q[:, :, 1:], k[:, :, 1:], freqs_cis=freqs_cis)
        attn = (q * self.scale) @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


  from .autonotebook import tqdm as notebook_tqdm


In [None]:
from timm.models.layers import DropPath
from timm.models.vision_transformer import Mlp

class RoPE_Layer_scale_init_Block(nn.Module):
    """Pre-norm residual block: RoPE attention followed by an MLP.

    Both residual branches are multiplied by a learnable per-channel scale
    (``gamma_1`` / ``gamma_2``) initialised to ``init_values``, and passed
    through stochastic depth when ``drop_path > 0``.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=RoPEAttention, Mlp_block=Mlp ,init_values=1e-4):
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # MLP branch; hidden width is dim * mlp_ratio, truncated to int.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp_block(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        # Learnable per-channel residual scales.
        self.gamma_1 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones(dim), requires_grad=True)

    def forward(self, x, freqs_cis):
        """Apply the attention and MLP residual branches to ``x``."""
        attn_branch = self.attn(self.norm1(x), freqs_cis=freqs_cis)
        x = x + self.drop_path(self.gamma_1 * attn_branch)
        mlp_branch = self.mlp(self.norm2(x))
        x = x + self.drop_path(self.gamma_2 * mlp_branch)
        return x
    
class RoPE_Layer_scale_init_Mamba_Block(nn.Module):
    """Residual block with a token mixer and an MLP, each with layer scale.

    NOTE(review): as written this is a line-for-line twin of
    ``RoPE_Layer_scale_init_Block`` and defaults to ``RoPEAttention`` as the
    mixer - confirm whether a Mamba mixer was meant to be the default.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, Attention_block=RoPEAttention, Mlp_block=Mlp ,init_values=1e-4):
        super().__init__()
        mixer_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.norm1 = norm_layer(dim)
        self.attn = Attention_block(dim, **mixer_kwargs)
        # Stochastic depth is only instantiated for a positive rate.
        self.drop_path = nn.Identity() if not drop_path > 0. else DropPath(drop_path)
        self.norm2 = norm_layer(dim)
        hidden_width = int(dim * mlp_ratio)
        self.mlp = Mlp_block(in_features=dim, hidden_features=hidden_width, act_layer=act_layer, drop=drop)
        # Per-channel scales for the two residual branches.
        self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
        self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)

    def forward(self, x, freqs_cis):
        """Apply the mixer branch, then the MLP branch, each as a residual."""
        mixed = self.gamma_1 * self.attn(self.norm1(x), freqs_cis=freqs_cis)
        x = x + self.drop_path(mixed)
        fed_forward = self.gamma_2 * self.mlp(self.norm2(x))
        return x + self.drop_path(fed_forward)

In [None]:
from timm.models.vision_transformer import PatchEmbed
from timm.models.layers import trunc_normal_
from functools import partial
import torch.nn.functional as F

def compute_mixed_cis(freqs: torch.Tensor, t_x: torch.Tensor, t_y: torch.Tensor, num_heads: int):
    """Build unit-magnitude complex rotations from per-layer 2-D frequencies.

    For each axis the positions are multiplied with that axis' frequencies
    (``freqs[0]`` for x, ``freqs[1]`` for y), the two phase tensors are added,
    and the sum is used as the angle of a unit complex number. The computation
    runs with CUDA autocast disabled.

    Args:
        freqs: Tensor indexed as ``freqs[axis]`` with ``freqs.shape[1]`` taken
            as the depth; each ``freqs[axis]`` is (depth, num_heads * k).
        t_x: 1-D tensor of N x-positions.
        t_y: 1-D tensor of N y-positions.
        num_heads: Number of heads the last dim is split into.

    Returns:
        Complex tensor of shape (depth, num_heads, N, k).
    """
    token_count = t_x.shape[0]
    depth = freqs.shape[1]

    def axis_phase(positions, axis_freqs):
        # (N, 1) @ (depth, 1, F) -> (depth, N, F), then split F across heads.
        phase = positions.unsqueeze(-1) @ axis_freqs.unsqueeze(-2)
        return phase.view(depth, token_count, num_heads, -1).permute(0, 2, 1, 3)

    with torch.cuda.amp.autocast(enabled=False):
        phase_x = axis_phase(t_x, freqs[0])
        phase_y = axis_phase(t_y, freqs[1])
        freqs_cis = torch.polar(torch.ones_like(phase_x), phase_x + phase_y)
    return freqs_cis

def init_random_2d_freqs(dim: int, num_heads: int, theta: float = 10.0, rotate: bool = True):
    """Create per-head 2-D frequency vectors, optionally randomly rotated.

    A magnitude vector ``1 / theta ** (i / dim)`` is built for ``i`` in
    ``0, 4, 8, ...`` (``dim // 4`` entries). For every head an angle is drawn
    uniformly from [0, 2*pi) when ``rotate`` is true (zero otherwise); the x
    and y frequencies are the magnitudes projected on that angle and on the
    angle shifted by pi/2.

    Returns:
        Tensor of shape (2, num_heads, 2 * (dim // 4)); index 0 holds the x
        frequencies and index 1 the y frequencies.
    """
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    mag = 1 / (theta ** exponents)

    def sample_head():
        # One random draw per head, in head order.
        angles = torch.rand(1) * 2 * torch.pi if rotate else torch.zeros(1)
        head_x = torch.cat([mag * torch.cos(angles), mag * torch.cos(torch.pi/2 + angles)], dim=-1)
        head_y = torch.cat([mag * torch.sin(angles), mag * torch.sin(torch.pi/2 + angles)], dim=-1)
        return head_x, head_y

    heads = [sample_head() for _ in range(num_heads)]
    freqs_x = torch.stack([head_x for head_x, _ in heads], dim=0)
    freqs_y = torch.stack([head_y for _, head_y in heads], dim=0)
    return torch.stack([freqs_x, freqs_y], dim=0)

def init_t_xy(end_x: int, end_y: int):
    """Return the x and y coordinate of every cell of a row-major grid.

    The grid has ``end_x`` columns and ``end_y`` rows; cell ``i`` has
    ``x = i % end_x`` and ``y = i // end_x``.

    Returns:
        Tuple ``(t_x, t_y)`` of float32 tensors of length ``end_x * end_y``.
    """
    flat_index = torch.arange(end_x * end_y, dtype=torch.float32)
    column = torch.remainder(flat_index, end_x).float()
    row = torch.div(flat_index, end_x, rounding_mode='floor').float()
    return column, row

def compute_axial_cis(dim: int, end_x: int, end_y: int, theta: float = 100.0):
    """Build axial (x-then-y) unit complex rotations for a 2-D grid.

    The same frequency vector ``1 / theta ** (i / dim)`` (``i`` in
    ``0, 4, 8, ...``, ``dim // 4`` entries) is used for both axes. It is
    multiplied with the x and the y coordinate of every grid cell, each
    product becomes the angle of a unit complex number, and the x and y
    results are concatenated along the last dim.

    Returns:
        Complex tensor of shape (end_x * end_y, 2 * (dim // 4)).
    """
    axis_freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float() / dim))
    t_x, t_y = init_t_xy(end_x, end_y)

    def unit_rotations(positions):
        angles = torch.outer(positions, axis_freqs)
        return torch.polar(torch.ones_like(angles), angles)

    return torch.cat([unit_rotations(t_x), unit_rotations(t_y)], dim=-1)

class rope_vit_models(nn.Module):
    """Vision Transformer classifier whose blocks receive rotary position factors.

    Two RoPE variants are supported, selected by ``rope_mixed``:

    * ``rope_mixed=False`` -- fixed axial factors from ``compute_axial_cis``,
      computed once at construction and shared by every block.
    * ``rope_mixed=True`` -- learnable per-layer frequencies (``self.freqs``)
      converted to factors by ``compute_mixed_cis`` on every forward pass;
      block ``i`` receives ``freqs_cis[i]``.

    A learnable absolute position embedding is added to the patch tokens only
    when ``use_ape`` is true; otherwise ``self.pos_embed`` is ``None``.

    ``global_pool``, ``dpr_constant``, ``mlp_ratio_clstk`` and ``**kwargs`` are
    accepted but not used by this class. ``img_size`` and ``patch_size`` are
    combined with ``//`` here, so they are expected to be plain integers.
    """

    def __init__(self,
                 rope_theta=100.0,
                 rope_mixed=False,
                 use_ape=False,
                 img_size=224,
                 patch_size=16,
                 in_chans=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=12,
                 num_heads=8,
                 mlp_ratio=4.,
                 qkv_bias=False,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.,
                 norm_layer=nn.LayerNorm,
                 global_pool=None,
                 block_layers=None,
                 Patch_layer=PatchEmbed,
                 act_layer=nn.GELU,
                 Attention_block=None,
                 Mlp_block=Mlp,
                 dpr_constant=True,
                 init_scale=1e-4,
                 mlp_ratio_clstk=4.0,
                 **kwargs):
        super().__init__()
        # drop_rate is applied only to the pooled feature in forward();
        # the blocks themselves are built with drop=0.0.
        self.dropout_rate = drop_rate
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        self.patch_embed = Patch_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        # Every block gets the same drop-path rate.
        dpr = [drop_path_rate for _ in range(depth)]
        self.blocks = nn.ModuleList([
            block_layers(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=0.0,
                attn_drop=attn_drop_rate,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                Attention_block=Attention_block,
                Mlp_block=Mlp_block,
                init_values=init_scale
            ) for i in range(depth)])
        self.norm = norm_layer(embed_dim)
        self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)
        self.use_ape = use_ape
        if not self.use_ape:
            # Without APE the parameter slot is cleared entirely.
            self.pos_embed = None
        self.rope_mixed = rope_mixed
        self.num_heads = num_heads
        self.patch_size = patch_size
        # Side length (in patches) of the square grid the cached RoPE
        # coordinates / factors below are built for.
        self._rope_grid_size = img_size // patch_size
        if self.rope_mixed:
            self.compute_cis = partial(compute_mixed_cis, num_heads=self.num_heads)
            # One independently initialised frequency set per block, flattened
            # over heads so compute_mixed_cis can split them again.
            freqs = [
                init_random_2d_freqs(dim=embed_dim // num_heads, num_heads=num_heads, theta=rope_theta)
                for _ in self.blocks
            ]
            freqs = torch.stack(freqs, dim=1).view(2, len(self.blocks), -1)
            self.freqs = nn.Parameter(freqs.clone(), requires_grad=True)
            t_x, t_y = init_t_xy(end_x=self._rope_grid_size, end_y=self._rope_grid_size)
            self.register_buffer('freqs_t_x', t_x)
            self.register_buffer('freqs_t_y', t_y)
        else:
            self.compute_cis = partial(compute_axial_cis, dim=embed_dim//num_heads, theta=rope_theta)
            # Kept as a plain attribute (not a buffer); forward_features moves
            # it to the input's device on each call.
            self.freqs_cis = self.compute_cis(end_x=self._rope_grid_size, end_y=self._rope_grid_size)

    def _init_weights(self, m):
        """Initialise Linear (trunc-normal weight, zero bias) and LayerNorm (unit weight, zero bias)."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names reported as exempt from weight decay."""
        return {'pos_embed', 'cls_token', 'freqs'}

    def get_num_layers(self):
        """Return the number of transformer blocks."""
        return len(self.blocks)

    def reset_classifier(self, num_classes, global_pool=''):
        """Replace the classification head; ``global_pool`` is accepted but unused."""
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        """Embed an image batch, run all blocks and return the normalised class token.

        Args:
            x: Input unpacked as ``(B, C, H, W)``.

        Returns:
            The first token of the final normalised sequence, indexed as
            ``x[:, 0]``.
        """
        B, C, H, W = x.shape
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.use_ape:
            pos_embed = self.pos_embed
            if pos_embed.shape[-2] != x.shape[-2]:
                # Token count differs from the stored embedding: resample it
                # bicubically to the current patch grid.
                # NOTE(review): the view pairs index 1 of img_size/patch_size
                # with the first spatial axis and index 0 with the second,
                # while the target size pairs H with patch_size[1] and W with
                # patch_size[0]. This is equivalent for square images and
                # patches -- confirm Patch_layer's axis convention before
                # relying on it for non-square inputs.
                img_size = self.patch_embed.img_size
                patch_size = self.patch_embed.patch_size
                pos_embed = pos_embed.view(1, (img_size[1] // patch_size[1]), (img_size[0] // patch_size[0]), self.embed_dim).permute(0, 3, 1, 2)
                pos_embed = F.interpolate(pos_embed, size=(H // patch_size[1], W // patch_size[0]), mode='bicubic', align_corners=False)
                pos_embed = pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
            x = x + pos_embed
        x = torch.cat((cls_tokens, x), dim=1)

        num_patches = x.shape[1] - 1
        grid_w = W // self.patch_size
        grid_h = H // self.patch_size
        # The cached coordinates/factors describe a square grid. Comparing only
        # the token count would wrongly reuse them for a non-square input that
        # happens to have the same number of patches (e.g. 7x28 vs 14x14), so
        # the grid dimensions are checked as well.
        grid_changed = grid_w != self._rope_grid_size or grid_h != self._rope_grid_size

        if self.rope_mixed:
            if grid_changed or self.freqs_t_x.shape[0] != num_patches:
                t_x, t_y = init_t_xy(end_x=grid_w, end_y=grid_h)
                t_x, t_y = t_x.to(x.device), t_y.to(x.device)
            else:
                t_x, t_y = self.freqs_t_x, self.freqs_t_y
            freqs_cis = self.compute_cis(self.freqs, t_x, t_y)
            # Each block gets the factors built from its own frequencies.
            for i, blk in enumerate(self.blocks):
                x = blk(x, freqs_cis=freqs_cis[i])
        else:
            if grid_changed or self.freqs_cis.shape[0] != num_patches:
                freqs_cis = self.compute_cis(end_x=grid_w, end_y=grid_h)
            else:
                freqs_cis = self.freqs_cis
            freqs_cis = freqs_cis.to(x.device)
            # Axial factors are shared by all blocks.
            for blk in self.blocks:
                x = blk(x, freqs_cis=freqs_cis)
        x = self.norm(x)
        x = x[:, 0]
        return x

    def forward(self, x):
        """Return head outputs for an image batch, with optional dropout on the pooled feature."""
        x = self.forward_features(x)
        if self.dropout_rate:
            x = F.dropout(x, p=float(self.dropout_rate), training=self.training)
        x = self.head(x)
        return x

In [None]:
# Smoke test: build a small ViT with mixed (learnable) RoPE plus an absolute
# position embedding, then push one random 224x224 image through it.
# For the axial variant, use rope_theta=100.0 and rope_mixed=False and leave
# use_ape at its default instead.
model_config = dict(
    img_size=224,
    patch_size=16,
    embed_dim=384,
    depth=12,
    num_heads=6,
    mlp_ratio=4,
    qkv_bias=True,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    block_layers=RoPE_Layer_scale_init_Block,
    Attention_block=RoPEAttention,
    rope_theta=10.0,
    rope_mixed=True,
    use_ape=True,
)
model = rope_vit_models(**model_config)

# NOTE: `input` shadows the builtin of the same name; the name is kept in case
# other cells refer to it.
input = torch.randn(1, 3, 224, 224)
output = model(input)
print(output.shape)