In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
import os, sys
sys.path.append(os.path.abspath(os.path.join('..')))

from sam2.build_sam import build_sam2


class Adapter(nn.Module):
    """Wrap an existing block with a small bottleneck MLP applied to its input.

    The MLP output is added to the input tensor, and the sum is then passed
    through the wrapped block.

    Args:
        blk: Module to wrap. It must expose ``blk.attn.qkv.in_features``; that
            value is used as the input/output width of the bottleneck MLP.
    """

    def __init__(self, blk) -> None:
        super().__init__()
        self.block = blk
        width = blk.attn.qkv.in_features
        # Bottleneck: width -> 32 -> width, with a GELU after each projection.
        down_proj = nn.Linear(width, 32)
        up_proj = nn.Linear(32, width)
        self.prompt_learn = nn.Sequential(down_proj, nn.GELU(), up_proj, nn.GELU())

    def forward(self, x):
        """Return ``self.block(x + self.prompt_learn(x))``."""
        return self.block(x + self.prompt_learn(x))

  from torch.distributed.optim import \


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MemoryAttentionModule(nn.Module):
    """Fuse current-frame tokens with memory tokens via attention.

    The pipeline is: self-attention over the memory tokens, cross-attention
    from the current tokens (queries) to the refined memory (keys/values),
    then a position-wise feed-forward network. Each stage is followed by a
    residual connection and a LayerNorm.

    Args:
        embed_dim: Channel width of both token sequences.
        num_heads: Number of attention heads used by both attention layers.
        dropout: Dropout probability used inside both attention layers and
            inside the feed-forward network.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # Both attention layers share the same hyper-parameters and use
        # batch-first inputs.
        attn_kwargs = dict(
            embed_dim=embed_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.self_attention = nn.MultiheadAttention(**attn_kwargs)   # memory -> memory
        self.cross_attention = nn.MultiheadAttention(**attn_kwargs)  # current -> memory

        # One LayerNorm per residual stage.
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(embed_dim) for _ in range(3))

        hidden_dim = 4 * embed_dim
        self.ffn = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, current_features, memory_features):
        """Enhance the current-frame tokens with information from memory.

        Args:
            current_features: Features of current frame [B, H*W, C]
            memory_features: Features of memory frames [B, T*H*W, C]
        Returns:
            Tensor with the same shape as ``current_features``.
        """
        # Stage 1: refine the memory tokens with self-attention.
        mem_attn, _ = self.self_attention(memory_features, memory_features, memory_features)
        memory_ctx = self.norm1(memory_features + mem_attn)

        # Stage 2: current tokens query the refined memory.
        cross_attn, _ = self.cross_attention(current_features, memory_ctx, memory_ctx)
        fused = self.norm2(current_features + cross_attn)

        # Stage 3: feed-forward network with residual connection.
        return self.norm3(fused + self.ffn(fused))

class VideoSegmentationModel(nn.Module):
    """Frame-by-frame model that fuses backbone features with a rolling memory.

    Every call to ``forward`` extracts features for one frame, appends them to
    an internal memory buffer that keeps the most recent ``buffer_size``
    frames, and passes the current features and the flattened memory to a
    ``MemoryAttentionModule``.

    Args:
        backbone: Callable module mapping a frame ``[B, C_in, H_in, W_in]`` to
            a 4-D feature map ``[B, C, H, W]``.
        embed_dim: Channel width given to the memory attention module.
            NOTE(review): nothing here checks that it equals the backbone's
            output channel count ``C`` — confirm against the backbone used.
        num_heads: Number of attention heads of the memory attention module.
        buffer_size: Maximum number of frames kept in memory (default 5, the
            value that used to be hard-coded).

    Raises:
        ValueError: If ``buffer_size`` is smaller than 1.
    """

    def __init__(self, backbone, embed_dim=256, num_heads=8, buffer_size=5):
        super().__init__()
        if buffer_size < 1:
            # A size of 0 would make the ``[-buffer_size:]`` slice keep everything.
            raise ValueError(f"buffer_size must be >= 1, got {buffer_size}")
        self.backbone = backbone
        self.memory_module = MemoryAttentionModule(
            embed_dim=embed_dim,
            num_heads=num_heads
        )

        # Plain attribute (not a registered buffer): None or a tensor [B, T, H*W, C].
        self.memory_buffer = None
        self.buffer_size = buffer_size  # Number of frames to keep in memory

    def reset_memory(self):
        """Drop all stored frames, e.g. before starting a new video."""
        self.memory_buffer = None

    def _memory_is_compatible(self, features):
        """Return True if ``features`` [B, H*W, C] can be appended to the buffer."""
        memory = self.memory_buffer
        return (
            memory.shape[0] == features.shape[0]
            and memory.shape[2:] == features.shape[1:]
            and memory.device == features.device
        )

    def update_memory(self, features):
        """Append ``features`` [B, H*W, C] to the memory buffer.

        The buffer is restarted from ``features`` alone when it is empty or
        when its batch size, token count, channel count or device no longer
        matches ``features`` (previously this case crashed in ``torch.cat``).

        NOTE(review): features are stored without ``detach()``, so they keep
        their autograd history while they stay in the buffer — confirm this is
        intended when training.
        """
        new_entry = features.unsqueeze(1)  # [B, 1, H*W, C]
        if self.memory_buffer is None or not self._memory_is_compatible(features):
            self.memory_buffer = new_entry
            return

        self.memory_buffer = torch.cat([self.memory_buffer, new_entry], dim=1)

        # Keep only recent frames
        if self.memory_buffer.size(1) > self.buffer_size:
            self.memory_buffer = self.memory_buffer[:, -self.buffer_size:]

    def forward(self, x):
        """Process one frame and return memory-enhanced features.

        Args:
            x: Input frame [B, C, H, W]
        Returns:
            Tensor with the same shape as the backbone's feature map.
        """
        # Extract features using backbone
        features = self.backbone(x)
        B, C, H, W = features.shape

        # Reshape features to sequence format
        features = features.flatten(2).transpose(1, 2)  # [B, H*W, C]

        # Update memory buffer (the current frame becomes part of the memory)
        self.update_memory(features)

        # Flatten memory buffer for attention; reshape copes with the
        # non-contiguous view left behind by the trimming slice.
        memory_features = self.memory_buffer.reshape(B, -1, C)  # [B, T*H*W, C]

        # Apply memory attention
        enhanced_features = self.memory_module(features, memory_features)

        # Reshape back to spatial format
        return enhanced_features.transpose(1, 2).reshape(B, C, H, W)

In [7]:

class SAM2Adapt(nn.Module):
    """SAM2 image-encoder trunk with an ``Adapter`` wrapped around every block.

    A SAM2 model is built with ``build_sam2``; every component except the
    image encoder's trunk is deleted. The trunk's parameters are frozen, and
    each of its blocks is then wrapped in an ``Adapter``.

    Args:
        checkpoint_path: Checkpoint to load. ``None`` (default) selects
            ``DEFAULT_CHECKPOINT``. An empty string builds the model without
            passing a checkpoint to ``build_sam2``.
        model_cfg: SAM2 config name. ``None`` (default) selects
            ``DEFAULT_MODEL_CFG``.
    """

    DEFAULT_CHECKPOINT = "../sam2/checkpoints/sam2.1_hiera_large.pt"
    DEFAULT_MODEL_CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"

    # Components of the built model that are not needed once the trunk is extracted.
    _UNUSED_COMPONENTS = (
        "sam_mask_decoder",
        "sam_prompt_encoder",
        "memory_encoder",
        "memory_attention",
        "mask_downsample",
        "obj_ptr_tpos_proj",
        "obj_ptr_proj",
    )

    @staticmethod
    def _resolve_checkpoint(checkpoint_path):
        """Return the caller's checkpoint path, or the default when it is None."""
        if checkpoint_path is None:
            return SAM2Adapt.DEFAULT_CHECKPOINT
        return checkpoint_path

    def __init__(self, checkpoint_path=None, model_cfg=None) -> None:
        super(SAM2Adapt, self).__init__()
        # Previously the argument was overwritten unconditionally, so a
        # caller-supplied checkpoint was silently ignored.
        checkpoint_path = self._resolve_checkpoint(checkpoint_path)
        if model_cfg is None:
            model_cfg = self.DEFAULT_MODEL_CFG

        if checkpoint_path:
            model = build_sam2(model_cfg, checkpoint_path)
        else:
            model = build_sam2(model_cfg)

        for component in self._UNUSED_COMPONENTS:
            delattr(model, component)
        del model.image_encoder.neck

        self.encoder = model.image_encoder.trunk
        # Freeze the trunk before the adapters are created, so only the
        # adapters' own new layers are left with requires_grad=True.
        for param in self.encoder.parameters():
            param.requires_grad = False
        self.encoder.blocks = nn.Sequential(
            *(Adapter(block) for block in self.encoder.blocks)
        )

        # TODO: memory attention is not wired in yet
        # (was: self.memory_attention = MemoryAttentionModule(embed_dim=...)).

    def get_memory(self, x):
        """Encode a clip without gradients and return the encoder's last output.

        Args:
            x: Tensor of shape [T, B, C, H, W].
        Returns:
            The last element of the encoder's output for the flattened
            ``T * B`` frames, viewed as ``[T, B, C', H', W']``.
        """
        T, B, C, H, W = x.shape
        with torch.no_grad():
            outputs = self.encoder(x.contiguous().view(-1, C, H, W))

        last = outputs[-1]
        _, out_c, out_h, out_w = last.shape
        return last.view(T, B, out_c, out_h, out_w)

    def forward(self, x):
        """Run the encoder on the first element of ``x`` only.

        Args:
            x: Indexable input; ``x[0]`` is passed to the encoder.
                presumably laid out as [T, B, C, H, W] like ``get_memory``'s
                input — TODO confirm. The remaining elements are currently
                unused.
        Returns:
            Whatever ``self.encoder`` returns for ``x[0]``.
        """
        return self.encoder(x[0])

In [5]:
# Instantiate SAM2Adapt with its constructor defaults.
model = SAM2Adapt()

  OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()


In [8]:
# Take the last direct child module of the encoder and display its repr
# (per the recorded output, this is the Sequential of Adapter-wrapped blocks).
last_layer = list(model.encoder.children())[-1]
last_layer

Sequential(
  (0): Adapter(
    (block): MultiScaleBlock(
      (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
      (attn): MultiScaleAttention(
        (qkv): Linear(in_features=144, out_features=432, bias=True)
        (proj): Linear(in_features=144, out_features=144, bias=True)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
      (mlp): MLP(
        (layers): ModuleList(
          (0): Linear(in_features=144, out_features=576, bias=True)
          (1): Linear(in_features=576, out_features=144, bias=True)
        )
        (act): GELU(approximate='none')
      )
    )
    (prompt_learn): Sequential(
      (0): Linear(in_features=144, out_features=32, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=32, out_features=144, bias=True)
      (3): GELU(approximate='none')
    )
  )
  (1): Adapter(
    (block): MultiScaleBlock(
      (norm1): LayerNorm((144,), eps=1e-06, elementwise_affi

In [9]:
# Quick arithmetic check: the result (144.0) equals the in_features=144 shown in the module printout above.
1152/8

144.0

In [16]:
# Smoke test on random input; requires a CUDA device.
# NOTE(review): this cell unpacks two values (f, pos), whereas SAM2Adapt.forward as
# defined in cell In [7] returns only the encoder output. The recorded output below
# presumably comes from a different revision of the class — confirm by re-running
# from a fresh kernel.
with torch.no_grad():
    model = SAM2Adapt().cuda()
    # presumably laid out as (T, B, C, H, W), matching get_memory's unpacking — TODO confirm
    x = torch.randn(8, 32, 3, 256, 256).cuda()
    f, pos = model(x)
    # Print the shapes of the paired entries of f and pos.
    for x1, x2 in zip(f,pos):
        print(x1.shape,x2.shape)

torch.Size([32, 256, 64, 64]) torch.Size([32, 256, 64, 64])
torch.Size([32, 256, 32, 32]) torch.Size([32, 256, 32, 32])
torch.Size([32, 256, 16, 16]) torch.Size([32, 256, 16, 16])
torch.Size([32, 256, 8, 8]) torch.Size([32, 256, 8, 8])
