In [6]:
import os 
import re

def extract_top_n_accuracies_with_firstline(log_folder, top_n=10):
    entries = []

    for root, dirs, files in os.walk(log_folder):
        for file in files:
            if file == "info_log.txt":
                full_path = os.path.join(root, file)
                try:
                    with open(full_path, 'r') as f:
                        lines = f.readlines()
                        first_line = lines[0].strip() if lines else ""
                        for line in lines:
                            match = re.search(r"Accuracy:\s*(\d+\.\d+)", line)
                            if match:
                                accuracy = float(match.group(1))
                                entries.append({
                                    'Accuracy': accuracy,
                                    'LogPath': full_path,
                                    'FirstLine': first_line,
                                    'Line': line.strip()
                                })
                except Exception as e:
                    print(f"⚠️ Error reading {full_path}: {e}")

    if not entries:
        print("❌ No accuracy entries found.")
        return []

    # Sort by accuracy
    sorted_entries = sorted(entries, key=lambda x: x['Accuracy'], reverse=True)[:top_n]
    return sorted_entries



In [9]:
# Extract entries
top_entries = extract_top_n_accuracies_with_firstline("/scratch/namrata/dog_heart/models", top_n=50)

# Write to TXT file
output_txt_path = "top_accuracies_with_context.txt"
with open(output_txt_path, 'w') as f:
    for i, entry in enumerate(top_entries, start=1):
        f.write(f"#{i}\n")
        f.write(f"Accuracy: {entry['Accuracy']}\n")
        f.write(f"LogPath: {entry['LogPath']}\n")
        f.write(f"FirstLine: {entry['FirstLine']}\n")
        f.write(f"Line: {entry['Line']}\n")
        f.write("\n")

output_txt_path


'top_accuracies_with_context.txt'

In [15]:
import os
import sys

def remove_empty_second_level_folders(source_folder):
    """
    Removes only second-level subfolders if they (or their subdirectories) do not contain any .pth files.
    """
    # Check if the source folder exists
    if not os.path.exists(source_folder):
        print(f"❌ Source folder '{source_folder}' does not exist.")
        return

    # Iterate over each second-level folder
    for second_level_folder in os.listdir(source_folder):
        second_level_path = os.path.join(source_folder, second_level_folder)
        print(second_level_path)

        # Skip if not a directory
        if not os.path.isdir(second_level_path):
            continue

        # Check for .pth files in the entire second-level directory tree
        has_pth_files = False
        for root, dirs, files in os.walk(second_level_path):
            if any(file.endswith(".pth") for file in files):
                has_pth_files = True
                break
        
        # If no .pth files found, remove the second-level folder
        if not has_pth_files:
            try:
                # Use shutil.rmtree to remove the entire directory tree
                import shutil
                shutil.rmtree(second_level_path)
                print(f"🗑️ Removed folder without models: {second_level_path}")
            except Exception as e:
                print(f"⚠️ Error removing {second_level_path}: {e}")

if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python remove_empty_second_level.py <source_folder>")
        sys.exit(1)
    
    source_folder = "/scratch/namrata/dog_heart/models/custom"
    remove_empty_second_level_folders(source_folder)




/scratch/namrata/dog_heart/models/custom/splits
🗑️ Removed folder without models: /scratch/namrata/dog_heart/models/custom/splits
/scratch/namrata/dog_heart/models/custom/CM_Mamba_Trial2.py
/scratch/namrata/dog_heart/models/custom/20250504_041051
/scratch/namrata/dog_heart/models/custom/20250503_050118
/scratch/namrata/dog_heart/models/custom/20250505_053604
/scratch/namrata/dog_heart/models/custom/20250507_062810
🗑️ Removed folder without models: /scratch/namrata/dog_heart/models/custom/20250507_062810
/scratch/namrata/dog_heart/models/custom/20250507_063322
/scratch/namrata/dog_heart/models/custom/20250503_035821
/scratch/namrata/dog_heart/models/custom/20250503_050421
/scratch/namrata/dog_heart/models/custom/20250504_012517
/scratch/namrata/dog_heart/models/custom/mamba
🗑️ Removed folder without models: /scratch/namrata/dog_heart/models/custom/mamba
/scratch/namrata/dog_heart/models/custom/CM_Mamba_Attn.py
/scratch/namrata/dog_heart/models/custom/evaluate.py
/scratch/namrata/dog_hea

## Generate Predictions

In [17]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "5" 

import random
import logging
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from scipy.io import loadmat, savemat

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.nn import L1Loss, CrossEntropyLoss
from torch.cuda.amp import GradScaler, autocast

import torchvision
import torchvision.transforms as T
import torchvision.models as models
from collections import Counter
from torchvision import models
import os
import sys
device = 'cuda' if torch.cuda.is_available() else 'cpu'

import torch.nn.functional as F
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from einops import rearrange



In [38]:
class DogHeartTestDataset(Dataset):
    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        self.imgs = sorted([f for f in os.listdir(root) if f.endswith(('png', 'jpg'))])

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, self.imgs[idx])
        img = Image.open(img_path).convert("RGB")
        if self.transforms:
            img = self.transforms(img)
        return img, self.imgs[idx]

    def __len__(self):
        return len(self.imgs)

In [41]:
# ====================
# Transform
# ====================
def get_transform(size):
    return T.Compose([
        T.ToTensor(),
        T.Resize((size, size)),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

# =======================
# Dataset
# ======================
resized_image_size = 512
true_batch_size = 256
accumulation_steps = 8
root_db_folder = "models/data"
transform = get_transform(resized_image_size)

dataset_test = DogHeartTestDataset(f'{root_db_folder}/Test_Images/Images', transform)
test_loader = DataLoader(dataset_test, batch_size=16, shuffle=False, num_workers=4)

In [42]:
print(f"Number of samples in the test dataset: {len(dataset_test)}")

# ====================
# VHS Utilities
# ====================
def calc_vhs(x):
    A, B, C, D, E, F = x[..., :2], x[..., 2:4], x[..., 4:6], x[..., 6:8], x[..., 8:10], x[..., 10:12]
    AB = torch.norm(A - B, p=2, dim=-1)
    CD = torch.norm(C - D, p=2, dim=-1)
    EF = torch.norm(E - F, p=2, dim=-1)
    return 6 * (AB + CD) / EF

def get_labels(vhs):
    return ((vhs >= 10).long() - (vhs < 8.2).long() + 1).squeeze()


Number of samples in the test dataset: 4274


In [23]:
def test_accuracy(model, test_loader, device, calc_vhs, get_labels, logger=None):
    model.eval()
    total_correct = 0
    total_samples = 0

    all_true_labels = []
    all_pred_labels = []

    progress = tqdm(test_loader, desc="Testing", leave=False, unit="batch")

    with torch.no_grad():
        for batch_idx, (idx, images, _, vhs) in enumerate(progress):
            images = images.to(device)
            vhs = vhs.to(device)

            outputs = model(images).squeeze()
            pred_vhs = calc_vhs(outputs).squeeze()

            pred_classes = get_labels(pred_vhs).cpu().numpy()
            true_classes = get_labels(vhs).cpu().numpy()

            total_correct += (pred_classes == true_classes).sum()
            total_samples += len(true_classes)

            all_true_labels.extend(true_classes)
            all_pred_labels.extend(pred_classes)

            progress.set_postfix({
                "Accuracy": f"{(total_correct / total_samples):.4f}"
            })

    progress.close()
    accuracy = total_correct / total_samples

    class_names = ['< 8.2', '8.2 - 10', '>= 10']
    #report = classification_report(all_true_labels, all_pred_labels, target_names=class_names, digits=4)
    
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy


In [33]:
# ====================
# Model
# ====================
class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.se(x)

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim)
        )
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(x + self.block(x))

class MambaBlock(nn.Module):
    def __init__(self, dim, d_state=16, expand=2):
        super().__init__()
        self.d_inner = dim * expand
        self.in_proj = nn.Linear(dim, self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, dim)
        self.A = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.B = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.C = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.D = nn.Parameter(torch.ones(self.d_inner))

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.view(B, C, -1).transpose(1, 2)
        x = self.in_proj(x)
        x = x.transpose(1, 2).view(B, self.d_inner, H, W)
        x = self.out_proj(x.view(B, self.d_inner, -1).transpose(1, 2))
        x = x.transpose(1, 2).view(B, -1, H, W)
        return x

class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.down(x)

class MambaStem(nn.Module):
    def __init__(self, in_ch=3, out_ch=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.conv(x)

class MambaStage(nn.Module):
    def __init__(self, in_ch, out_ch, depth=2):
        super().__init__()
        self.down = Downsample(in_ch, out_ch)
        self.blocks = nn.Sequential(*[nn.Sequential(ResidualBlock(out_ch), MambaBlock(out_ch)) for _ in range(depth)])
        self.se = SELayer(out_ch)

    def forward(self, x):
        x = self.down(x)
        x = self.blocks(x)
        x = self.se(x)
        return x

class MambaKeypointRegressor(nn.Module):
    def __init__(self, in_ch=3, num_points=12):
        super().__init__()
        self.stem = MambaStem(in_ch, 64)
        self.stage1 = MambaStage(64, 128, 2)
        self.stage2 = MambaStage(128, 256, 3)
        self.stage3 = MambaStage(256, 512, 3)
        self.stage4 = MambaStage(512, 640, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(640, 384),
            nn.ReLU(),
            nn.Linear(384, num_points)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.pool(x)
        return self.regressor(x).view(x.size(0), -1)

# Load pretrained model weights 
checkpoint_path = '/scratch/namrata/dog_heart/models/custom2/20250507_044926/iter_10/models/bm_15.pth'

# Load teacher and student
model = MambaKeypointRegressor()
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu',weights_only=True))
model = model.to(device)


In [27]:
test_accuracy(model, test_loader, device, calc_vhs, get_labels)


                                                                              

Test Accuracy: 0.8802




0.8802058961160505

In [36]:

import matplotlib.pyplot as plt
import scipy.io as sio
import os
import torch
import numpy as np

PREDICT_FOLDER = "predictions"
os.makedirs(PREDICT_FOLDER, exist_ok=True)

In [46]:
IMAGE_FOLDER = "models/data/Test_Images/Images"
img_size = 512
total_images = len(test_loader.dataset)
processed_images = 0

model.eval()

def save_predictions_to_mat(image_name, predicted_points, vhs_value):
    """Save predicted points and VHS in MATLAB-compatible .mat format."""
    mat_filename = os.path.splitext(image_name)[0] + ".mat"
    mat_filepath = os.path.join(PREDICT_FOLDER, mat_filename)

    # Convert points to correct MATLAB format
    mat_data = {
        "six_points": np.array(predicted_points),  # Store 6 points in an array
        "VHS": np.array([[vhs_value]])  # Store VHS as a 2D array
    }

    sio.savemat(mat_filepath, mat_data)
    print(f"✅ Predictions saved in MATLAB format: {mat_filepath}")


def calculateVHS(A,B,C,D,E,F):
    # Calculate distances using Euclidean formula
    AB = np.linalg.norm(B - A)  
    CD = np.linalg.norm(D - C)  
    EF = np.linalg.norm(F - E) 

    # Calculate VHS
    VHS = 6 * (AB + CD) / EF
    return VHS
    
with torch.no_grad():
    for images, names in test_loader:
        images = images.to(device)
        outputs = model(images)
        outputs = outputs.cpu().numpy()
        outputs = outputs.reshape(outputs.shape[0], 6, 2)

        for i, points in enumerate(outputs):
            img = Image.open(f'{IMAGE_FOLDER}/{names[i]}')
            
            # Get original image size and return predicted points back to original points size
            w, h = img.size
            points = points.reshape(-1, 2)
            points = points * img_size
            
            points[:, 0] = w / img_size * points[:, 0]
            points[:, 1] = h / img_size * points[:, 1]
            
            vhs = calculateVHS(points[0], points[1], points[2], points[3], points[4], points[5])

            save_predictions_to_mat(names[i], points, vhs)

            processed_images += 1
            progress = round((processed_images * 100) / total_images, 0)
            

✅ Predictions saved in MATLAB format: predictions/11037.mat
✅ Predictions saved in MATLAB format: predictions/11037_3.mat
✅ Predictions saved in MATLAB format: predictions/11038.mat
✅ Predictions saved in MATLAB format: predictions/11038_3.mat
✅ Predictions saved in MATLAB format: predictions/11039.mat
✅ Predictions saved in MATLAB format: predictions/11040.mat
✅ Predictions saved in MATLAB format: predictions/11040_3.mat
✅ Predictions saved in MATLAB format: predictions/11041.mat
✅ Predictions saved in MATLAB format: predictions/11041_3.mat
✅ Predictions saved in MATLAB format: predictions/11042.mat
✅ Predictions saved in MATLAB format: predictions/11042_3.mat
✅ Predictions saved in MATLAB format: predictions/11043.mat
✅ Predictions saved in MATLAB format: predictions/11044.mat
✅ Predictions saved in MATLAB format: predictions/11045.mat
✅ Predictions saved in MATLAB format: predictions/11046.mat
✅ Predictions saved in MATLAB format: predictions/11047.mat
✅ Predictions saved in MATLAB 

In [None]:
#################### START : MODEL
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ====================
# Model
# ====================
class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.se(x)

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim)
        )
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(x + self.block(x))

class MambaBlock(nn.Module):
    def __init__(self, dim, d_state=16, expand=2):
        super().__init__()
        self.d_inner = dim * expand
        self.in_proj = nn.Linear(dim, self.d_inner)
        self.out_proj = nn.Linear(self.d_inner, dim)
        self.A = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.B = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.C = nn.Parameter(torch.randn(self.d_inner, d_state))
        self.D = nn.Parameter(torch.ones(self.d_inner))

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.view(B, C, -1).transpose(1, 2)
        x = self.in_proj(x)
        x = x.transpose(1, 2).view(B, self.d_inner, H, W)
        x = self.out_proj(x.view(B, self.d_inner, -1).transpose(1, 2))
        x = x.transpose(1, 2).view(B, -1, H, W)
        return x

class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.down(x)

class MambaStem(nn.Module):
    def __init__(self, in_ch=3, out_ch=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.conv(x)

class MambaStage(nn.Module):
    def __init__(self, in_ch, out_ch, depth=2):
        super().__init__()
        self.down = Downsample(in_ch, out_ch)
        self.blocks = nn.Sequential(*[nn.Sequential(ResidualBlock(out_ch), MambaBlock(out_ch)) for _ in range(depth)])
        self.se = SELayer(out_ch)

    def forward(self, x):
        x = self.down(x)
        x = self.blocks(x)
        x = self.se(x)
        return x

class MambaKeypointRegressor(nn.Module):
    def __init__(self, in_ch=3, num_points=12):
        super().__init__()
        self.stem = MambaStem(in_ch, 64)
        self.stage1 = MambaStage(64, 128, 2)
        self.stage2 = MambaStage(128, 256, 3)
        self.stage3 = MambaStage(256, 512, 3)
        self.stage4 = MambaStage(512, 640, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(640, 384),
            nn.ReLU(),
            nn.Linear(384, num_points)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.pool(x)
        return self.regressor(x).view(x.size(0), -1)

##=================

model = MambaKeypointRegressor()

# 2. Load the checkpoint
checkpoint_path = '20250507_004756/models/bm_5.pth'

# 3. Load model state dict
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu',weights_only=True))

model = model.to(device)

In [1]:
import torch
from mamba_ssm import Mamba

batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")
y = model(x)
assert y.shape == x.shape

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
import torch.nn as nn
from mamba_ssm import Mamba

class ChannelSELayer(nn.Module):
    """Squeeze-and-Excitation layer for channel attention"""
    def __init__(self, channels, reduction=8):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, t, c = x.size()
        y = self.avg_pool(x.transpose(1, 2)).view(b, c)
        y = self.fc(y).view(b, 1, c)
        return x * y.expand_as(x)

class ResidualBlock(nn.Module):
    """Residual block with layer normalization and dropout"""
    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.mamba = Mamba(
            d_model=dim,
            d_state=16,
            d_conv=4,
            expand=2,
        )
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.se = ChannelSELayer(dim)

    def forward(self, x):
        residual = x
        x = self.norm(x)
        x = self.mamba(x)
        x = self.se(x)  # Apply channel attention
        x = self.dropout(x)
        return residual + x

class EnhancedMambaRegressor(nn.Module):
    def __init__(self, input_dim, hidden_dim=128, output_dim=12, num_layers=4):
        super().__init__()
        self.hidden_dim = hidden_dim
        
        # Input processing
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.GELU(),
            nn.Dropout(0.1)
        )
        
        # Stack of Mamba residual blocks
        self.blocks = nn.Sequential(*[
            ResidualBlock(hidden_dim) 
            for _ in range(num_layers)
        ])
        
        # Output processing
        self.output_proj = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, output_dim)
        )
        
        # Skip connection from input to output
        self.skip = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        # x shape: [batch, seq_len, input_dim]
        skip = self.skip(x.mean(dim=1))  # Average over sequence
        
        # Process through main network
        x = self.input_proj(x)
        x = self.blocks(x)
        
        # Take last timestep and project
        x = x[:, -1, :]
        x = self.output_proj(x)
        
        # Combine with skip connection
        return x + skip

# Example usage
if __name__ == "__main__":
    # Configuration
    input_dim = 32
    batch_size = 4
    seq_len = 20
    
    # Create model
    model = EnhancedMambaRegressor(input_dim=input_dim)
    
    # Test input
    x = torch.randn(batch_size, seq_len, input_dim)
    output = model(x)
    
    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")  # [batch_size, 12]

RuntimeError: Expected x.is_cuda() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba

class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.se(x)

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim)
        )
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(x + self.block(x))

class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.down(x)

class MambaStem(nn.Module):
    def __init__(self, in_ch=3, out_ch=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.conv(x)

class MambaStage(nn.Module):
    def __init__(self, in_ch, out_ch, depth=2, d_state=16, expand=2):
        super().__init__()
        self.down = Downsample(in_ch, out_ch)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(out_ch),
                nn.LayerNorm(out_ch),
                Mamba(d_model=out_ch, d_state=d_state, expand=expand)
            )
            for _ in range(depth)
        ])
        self.se = SELayer(out_ch)

    def forward(self, x):
        # Downsample to increase channels
        x = self.down(x)  # (B, C, H, W)
        B, C, H, W = x.shape
        
        # Flatten for Mamba (B, C, H, W) -> (B, L, C)
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)  # (B, L, C)
        
        # Apply Mamba blocks
        for block in self.blocks:
            x = block(x)
        
        # Reshape back to 4D (B, C, H, W)
        x = x.view(B, H, W, C).permute(0, 3, 1, 2)  # (B, C, H, W)
        
        # Apply SE Layer
        x = self.se(x)
        return x


class MambaKeypointRegressor(nn.Module):
    def __init__(self, in_ch=3, num_points=12):
        super().__init__()
        self.stem = MambaStem(in_ch, 64)
        self.stage1 = MambaStage(64, 128, 2)
        self.stage2 = MambaStage(128, 256, 3)
        self.stage3 = MambaStage(256, 512, 3)
        self.stage4 = MambaStage(512, 640, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(640, 384),
            nn.ReLU(),
            nn.Linear(384, num_points)
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.pool(x)
        return self.regressor(x).view(x.size(0), -1)

model = MambaKeypointRegressor()
x = torch.randn(1, 3, 512, 512)
out = model(x)
print(out.shape)

RuntimeError: Given groups=1, weight of size [128, 128, 3, 3], expected input[1, 1, 16384, 128] to have 128 channels, but got 1 channels instead

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba

class SELayer(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(),
            nn.Conv2d(channels // reduction, channels, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return x * self.se(x)

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim),
            nn.SiLU(),
            nn.Conv2d(dim, dim, 3, padding=1),
            nn.BatchNorm2d(dim)
        )
        self.act = nn.SiLU()

    def forward(self, x):
        return self.act(x + self.block(x))

class Downsample(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.down = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.down(x)

class MambaStem(nn.Module):
    def __init__(self, in_ch=3, out_ch=64):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU(),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.SiLU()
        )

    def forward(self, x):
        return self.conv(x)

class MambaStage(nn.Module):
    def __init__(self, in_ch, out_ch, depth=2, d_state=16, expand=2):
        super().__init__()
        self.down = Downsample(in_ch, out_ch)
        self.blocks = nn.ModuleList([
            nn.Sequential(
                ResidualBlock(out_ch),
                nn.Sequential(
                    nn.LayerNorm(out_ch),
                    Mamba(
                        d_model=out_ch, 
                        d_state=d_state, 
                        expand=expand
                    )
                )
            )
            for _ in range(depth)
        ])
        self.se = SELayer(out_ch)

    def forward(self, x):
        x = self.down(x)
        B, C, H, W = x.shape
        
        for block in self.blocks:
            # Residual block operates on 4D tensors
            x = block[0](x)
            
            # Prepare for Mamba (B, C, H, W) -> (B, L, C)
            x_mamba = x.permute(0, 2, 3, 1).reshape(B, H*W, C)
            
            # Apply LayerNorm and Mamba
            x_mamba = block[1](x_mamba)
            
            # Reshape back to 4D (B, L, C) -> (B, C, H, W)
            x = x_mamba.reshape(B, H, W, C).permute(0, 3, 1, 2)
        
        x = self.se(x)
        return x

class MambaKeypointRegressor(nn.Module):
    def __init__(self, in_ch=3, num_points=12):
        super().__init__()
        self.stem = MambaStem(in_ch, 64)
        self.stage1 = MambaStage(64, 128, 2)
        self.stage2 = MambaStage(128, 256, 3)
        self.stage3 = MambaStage(256, 512, 3)
        self.stage4 = MambaStage(512, 640, 3)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(640, 384),
            nn.ReLU(),
            nn.Linear(384, num_points)  # Directly output 12 values
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.pool(x)
        return self.regressor(x)  # Output shape [1, 12]

# Test the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MambaKeypointRegressor().to(device)
x = torch.randn(1, 3, 512, 512).to(device)
out = model(x)
print(out.shape)  # Now outputs [1, 12]

torch.Size([1, 12])
