## ENCODER

In [1]:
import torch
from torch import nn
from mamba_ssm import Mamba


class FeedForward(nn.Module):
    """Two-layer MLP applied to the last dimension of its input.

    The input of width ``n_embed`` is expanded to ``4 * n_embed``, passed
    through ReLU, projected back to ``n_embed`` and finally run through
    dropout with probability ``dropout``.
    """

    def __init__(self, n_embed, dropout=0.2) -> None:
        super().__init__()
        inner_dim = 4 * n_embed
        # The stage order fixes the Sequential indices (0..3) and therefore
        # the parameter names in the state dict.
        stages = [
            nn.Linear(n_embed, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, n_embed),
            nn.Dropout(dropout),
        ]
        self.ffn = nn.Sequential(*stages)

    def forward(self, x):
        """Return the MLP output; the shape of ``x`` is preserved."""
        projected = self.ffn(x)
        return projected


class MamabaBlock(nn.Module):
    """Residual block made of a Mamba mixer followed by a feed-forward MLP.

    Each sub-layer sees a LayerNorm-ed copy of its input and its output is
    added back onto the un-normalized input (pre-norm residual layout).

    Args:
        n_embed: feature width; passed to Mamba as ``d_model`` and used for
            both LayerNorms and the feed-forward MLP.
        d_state: forwarded to Mamba (SSM state expansion factor).
        d_conv: forwarded to Mamba (local convolution width).
        expand: forwarded to Mamba (block expansion factor).
    """

    def __init__(
        self,
        n_embed,
        d_state=16,
        d_conv=4,
        expand=2,
    ) -> None:
        super().__init__()
        # This module uses roughly 3 * expand * d_model^2 parameters
        mixer_cfg = dict(
            d_model=n_embed,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )
        # Sub-modules are created in the order mixer, MLP, norm 1, norm 2 so
        # parameter registration order stays the same.
        self.sa_head = Mamba(**mixer_cfg)
        self.ffn = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        """Apply the mixer and MLP sub-layers, each with a residual add."""
        mixed = x + self.sa_head(self.ln1(x))
        return mixed + self.ffn(self.ln2(mixed))


class MambaEncoder(nn.Module):
    """Sequential stack of ``num_layers`` ``MamabaBlock`` modules."""

    def __init__(self, n_embed, num_layers=6) -> None:
        super().__init__()
        self.layers = nn.ModuleList(
            MamabaBlock(n_embed) for _ in range(num_layers)
        )

    def forward(self, x, pos_enc=None):
        """Run ``x`` through every block in order.

        When ``pos_enc`` is given it is added to ``x`` once, before the first
        block; the caller's tensor is not modified in place.
        """
        hidden = x if pos_enc is None else x + pos_enc
        for block in self.layers:
            hidden = block(hidden)
        return hidden


# Smoke test: push a random (2, 30, 512) batch through a 6-layer encoder.
# Both the module and the input are moved to the GPU with .cuda().
src = torch.randn(2, 30, 512)
encoder = MambaEncoder(512, num_layers=6).cuda()
out = encoder(src.cuda())
print(out.shape)  # recorded output: torch.Size([2, 30, 512]), same as the input

torch.Size([2, 30, 512])


## DECODER

In [2]:
from linformer import LinformerSelfAttention


class DecoderBlock(nn.Module):
    """Decoder layer built from two Linformer attention modules and an MLP.

    The layer applies, each with a residual connection:
      1. attention of ``x`` over itself (input normalized by ``ln1``),
      2. attention of the running state over an external ``context``,
      3. a feed-forward MLP (input normalized by ``ln2``).

    Args:
        n_embed: feature width (``dim`` of both attention modules).
        heads: number of attention heads.
        max_seq_len: passed to both attention modules as ``seq_len``.
    """

    def __init__(self, n_embed, heads, max_seq_len):
        super().__init__()
        # One shared configuration; every constructor call below still
        # allocates its own independent weights.
        attn_cfg = dict(
            dim=n_embed,
            seq_len=max_seq_len,
            heads=heads,
            k=256,
            one_kv_head=True,
            share_kv=True,
        )
        self.self_attn = LinformerSelfAttention(**attn_cfg)
        self.cross_attn = LinformerSelfAttention(**attn_cfg)
        self.ffn = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x, context, x_pos_enc=None, context_pos_enc=None):
        """Return the updated queries after attending to ``x`` and ``context``.

        ``x_pos_enc`` / ``context_pos_enc``, when given, are added to ``x`` /
        ``context`` before use. The positionally-encoded ``x`` is also what
        the first residual connection adds back.
        """
        # Self-attention over the (optionally position-encoded) input
        queries = x if x_pos_enc is None else x + x_pos_enc
        state = queries + self.self_attn(self.ln1(queries))

        # Attention over the external context; only two LayerNorms exist, so
        # the query side of this step is fed in un-normalized.
        memory = context if context_pos_enc is None else context + context_pos_enc
        state = state + self.cross_attn(state, memory)

        # Feed-forward
        return state + self.ffn(self.ln2(state))

class Decoder(nn.Module):
    """Sequential stack of ``num_decoder_layers`` ``DecoderBlock`` modules."""

    def __init__(self, n_embed, heads, max_seq_len, num_decoder_layers=6):
        super().__init__()
        self.layers = nn.ModuleList(
            DecoderBlock(n_embed, heads, max_seq_len)
            for _ in range(num_decoder_layers)
        )

    def forward(self, x, context, x_pos_enc=None, context_pos_enc=None):
        """Feed ``x`` through every block in order.

        The same ``context`` and the same optional positional encodings are
        handed to each block; only the query stream changes between blocks.
        """
        state = x
        for block in self.layers:
            state = block(state, context, x_pos_enc, context_pos_enc)
        return state


# Smoke test with random tensors: 100 query vectors attend over a
# 625-token memory, both with feature width 512 and batch size 2.
mem = torch.randn(2, 625, 512)
obj_query = torch.randn(2, 100, 512)
dec = Decoder(512, 8, 625, num_decoder_layers=6)
z = dec(obj_query, mem)  # (2, 100, 512): same shape as obj_query
z.shape

torch.Size([2, 100, 512])

## WHOLE

In [3]:
import torch
from torch import nn
from torchvision.models import resnet50
import torchvision.transforms as T


class DETRdemo(nn.Module):
    """DETR-style detector with a Mamba encoder and a Linformer decoder.

    Pipeline: ResNet-50 feature extractor -> 1x1 conv down to ``hidden_dim``
    -> ``MambaEncoder`` over the flattened feature map (plus learned 2-D
    positional encodings) -> ``Decoder`` driven by learned object queries ->
    linear heads for class logits and boxes.

    Args:
        num_classes: number of object classes; the classifier emits
            ``num_classes + 1`` logits (one extra slot, the "no object"
            class in DETR).
        hidden_dim: feature width used by encoder, decoder and heads. The
            positional tables are each ``hidden_dim // 2`` wide and are
            concatenated, so this is expected to be even.
        nheads: number of attention heads in the decoder.
        num_encoder_layers: number of encoder blocks.
        num_decoder_layers: number of decoder blocks.
        ckpt: only compared against ``None``. When it is ``None`` the
            backbone is created with ImageNet weights; otherwise it is
            created without pretrained weights. This class does not load
            ``ckpt`` itself. NOTE(review): presumably the caller loads the
            state dict afterwards; confirm.
        max_seq_len: maximum sequence length handed to the decoder's
            attention modules. It bounds the number of feature-map tokens
            (``H * W``) the decoder can attend over. The default of 625 is
            the previously hard-coded value (a 25x25 map, i.e. an 800x800
            input to ResNet-50). The positional tables separately cap
            ``H`` and ``W`` at 50 each.
    """

    def __init__(
        self,
        num_classes,
        hidden_dim=256,
        nheads=8,
        num_encoder_layers=6,
        num_decoder_layers=6,
        ckpt=None,
        max_seq_len=625,
    ):
        super().__init__()

        # create ResNet-50 backbone; skip the pretrained download when a
        # checkpoint is going to be used instead
        if ckpt is None:
            self.backbone = resnet50(weights='ResNet50_Weights.IMAGENET1K_V2')
        else:
            self.backbone = resnet50()
        # the classification head is never used in forward()
        del self.backbone.fc

        # create conversion layer (2048 backbone channels -> hidden_dim)
        self.conv = nn.Conv2d(2048, hidden_dim, 1)

        # encoder / decoder that replace the default PyTorch nn.Transformer
        self.enc = MambaEncoder(hidden_dim, num_encoder_layers)
        self.dec = Decoder(hidden_dim, nheads, max_seq_len, num_decoder_layers)

        # prediction heads, one extra class for the "no object" slot
        # note that in baseline DETR linear_bbox layer is 3-layer MLP
        self.linear_class = nn.Linear(hidden_dim, num_classes + 1)
        self.linear_bbox = nn.Linear(hidden_dim, 4)

        # output positional encodings (object queries)
        self.query_pos = nn.Parameter(torch.rand(100, hidden_dim))

        # spatial positional encodings
        # note that in baseline DETR we use sine positional encodings
        self.row_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))
        self.col_embed = nn.Parameter(torch.rand(50, hidden_dim // 2))

    def forward(self, inputs):
        """Detect objects in a batch of images.

        Args:
            inputs: image batch fed straight into the ResNet-50 stem.

        Returns:
            dict with ``"pred_logits"`` of shape
            ``(batch, 100, num_classes + 1)`` and ``"pred_boxes"`` of shape
            ``(batch, 100, 4)`` squashed into ``(0, 1)`` by a sigmoid.
        """
        # propagate inputs through ResNet-50 up to avg-pool layer
        x = self.backbone.conv1(inputs)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = self.backbone.layer2(x)
        x = self.backbone.layer3(x)
        x = self.backbone.layer4(x)

        # convert from 2048 to hidden_dim feature planes for the transformer
        feats = self.conv(x)

        # construct positional encodings: every (row, col) cell gets the
        # concatenation of its column and row embeddings -> (1, H*W, hidden)
        H, W = feats.shape[-2:]
        col_part = self.col_embed[:W].unsqueeze(0).repeat(H, 1, 1)
        row_part = self.row_embed[:H].unsqueeze(1).repeat(1, W, 1)
        pos = torch.cat([col_part, row_part], dim=-1).flatten(0, 1).unsqueeze(0)

        # (batch, hidden, H, W) -> (batch, H*W, hidden)
        tokens = feats.flatten(2).permute(0, 2, 1)
        # one copy of the learned queries per image -> (batch, 100, hidden)
        queries = self.query_pos.unsqueeze(0).repeat(tokens.shape[0], 1, 1)

        # propagate through the encoder / decoder
        mem = self.enc(tokens, pos)
        decoded = self.dec(queries, mem)

        # finally project decoder outputs to class labels and bounding boxes
        return {
            "pred_logits": self.linear_class(decoded),
            "pred_boxes": self.linear_bbox(decoded).sigmoid(),
        }


model = DETRdemo(num_classes=91).cuda()

# Smoke test: one random 800x800 RGB image through the full model on the GPU.
model.eval()
dummy = torch.rand(1, 3, 800, 800).cuda()
# Inference only: no_grad avoids building an autograd graph for the
# forward pass, which would otherwise hold on to every activation.
with torch.no_grad():
    out = model(dummy)
print(out['pred_logits'].shape, out['pred_boxes'].shape)

torch.Size([1, 100, 92]) torch.Size([1, 100, 4])


In [4]:
# Count total parameters: sum the element counts of every tensor returned by
# model.parameters().
total_params = sum(p.numel() for p in model.parameters())
print(f'{total_params:,} total parameters.')

36,635,552 total parameters.
