<a href="https://colab.research.google.com/github/nidhisoley/DL_Final_Project_2024/blob/main/medmamba.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## DL Project - Reading in Data

This is a script to read in the ACDC dataset. Download the dataset from this url:

https://humanheart-project.creatis.insa-lyon.fr/database/#collection/637218c173e9f0047faa00fb

Then edit the paths below to match where the training and testing data is (these folders should be located in the downloaded data). Currently extracting just 1 of the 7-9 images for each subject (note: the index is 2 bc index 0 didn't always have a gt with it), so this is something we can change if we need.

Should see 200 training images and 100 testing images.


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Edit these variables to match your setup
dataset_path_training = '/content/drive/My Drive/python/Resources/training'

dataset_path_testing = '/content/drive/My Drive/python/Resources/testing'

# dataset_path_training = '/Users/calebhallinan/Desktop/jhu/classes/deep_learning/DL_Final_Project_2024/data/ACDC/training'
# dataset_path_testing = '/Users/calebhallinan/Desktop/jhu/classes/deep_learning/DL_Final_Project_2024/data/ACDC/testing'

In [None]:
### Import packages

import os
import nibabel as nib
import numpy as np
import re
from skimage.transform import resize


### Functions to load in the data ###

# Regular expression to extract the patient number and frame number from filenames
filename_pattern = re.compile(r'patient(\d+)_frame(\d+)(_gt)?\.nii\.gz')

# Function to get sorting key from the filename
def get_sort_key(filepath):
    match = filename_pattern.search(os.path.basename(filepath))
    if match:
        patient_num = int(match.group(1))
        frame_num = int(match.group(2))
        return (patient_num, frame_num)
    else:
        raise ValueError(f'Filename does not match expected pattern: {filepath}')


In [None]:
### Read in training data ###

# Lists to hold the file paths for images and ground truths
image_file_paths_train = []
ground_truth_file_paths_train = []

# Walk through the directory and collect all relevant file paths
for root, dirs, files in os.walk(dataset_path_training):
    for file in files:
        if 'frame' in file:
            full_path = os.path.join(root, file)
            if '_gt' in file:
                ground_truth_file_paths_train.append(full_path)
            else:
                image_file_paths_train.append(full_path)

# Sort the file paths to ensure alignment
image_file_paths_train.sort(key=get_sort_key)
ground_truth_file_paths_train.sort(key=get_sort_key)

# Check to make sure each image has a corresponding ground truth
assert len(image_file_paths_train) == len(ground_truth_file_paths_train)
for img_path, gt_path in zip(image_file_paths_train, ground_truth_file_paths_train):
    assert get_sort_key(img_path) == get_sort_key(gt_path), "Mismatch between image and ground truth files"

# Load the images and ground truths into numpy arrays
# using 2 index bc not all 0 index had a gt
images_train = [resize(nib.load(path).get_fdata()[:,:,2], (224,224)) for path in image_file_paths_train]
ground_truths_train = [resize(nib.load(path).get_fdata()[:,:,2], (224,224)) for path in ground_truth_file_paths_train]

# Stack the arrays into 4D numpy arrays
images_array_train = np.stack(images_train)
ground_truths_array_train = np.stack(ground_truths_train)

print(f'Training Images array shape: {images_array_train.shape}')
print(f'Training Ground truths array shape: {ground_truths_array_train.shape}')

Training Images array shape: (200, 224, 224)
Training Ground truths array shape: (200, 224, 224)


In [None]:
### Read in testing data ###

# Lists to hold the file paths for images and ground truths
image_file_paths_test = []
ground_truth_file_paths_test = []

# Walk through the directory and collect all relevant file paths
for root, dirs, files in os.walk(dataset_path_testing):
    for file in files:
        if 'frame' in file:
            full_path = os.path.join(root, file)
            if '_gt' in file:
                ground_truth_file_paths_test.append(full_path)
            else:
                image_file_paths_test.append(full_path)

# Sort the file paths to ensure alignment
image_file_paths_test.sort(key=get_sort_key)
ground_truth_file_paths_test.sort(key=get_sort_key)

# Check to make sure each image has a corresponding ground truth
assert len(image_file_paths_test) == len(ground_truth_file_paths_test)
for img_path, gt_path in zip(image_file_paths_test, ground_truth_file_paths_test):
    assert get_sort_key(img_path) == get_sort_key(gt_path), "Mismatch between image and ground truth files"

# Load the images and ground truths into numpy arrays
# using 2 index bc not all 0 index had a gt
images_test = [resize(nib.load(path).get_fdata()[:,:,2], (224,224)) for path in image_file_paths_test]
ground_truths_test = [resize(nib.load(path).get_fdata()[:,:,2], (224,224)) for path in ground_truth_file_paths_test]

# Stack the arrays into 4D numpy arrays
images_array_test = np.stack(images_test)
ground_truths_array_test = np.stack(ground_truths_test)

print(f'Test Images array shape: {images_array_test.shape}')
print(f'Test Ground truths array shape: {ground_truths_array_test.shape}')

Test Images array shape: (100, 224, 224)
Test Ground truths array shape: (100, 224, 224)


UNET

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

def unet_model(input_shape):
    inputs = tf.keras.Input(shape=input_shape)

    # Downscaling path
    conv1 = layers.Conv2D(64, 3, activation='relu', padding='same')(inputs)
    conv1 = layers.Conv2D(64, 3, activation='relu', padding='same')(conv1)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = layers.Conv2D(128, 3, activation='relu', padding='same')(pool1)
    conv2 = layers.Conv2D(128, 3, activation='relu', padding='same')(conv2)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = layers.Conv2D(256, 3, activation='relu', padding='same')(pool2)
    conv3 = layers.Conv2D(256, 3, activation='relu', padding='same')(conv3)
    pool3 = layers.MaxPooling2D(pool_size=(2, 2))(conv3)

    # Bottleneck
    conv4 = layers.Conv2D(512, 3, activation='relu', padding='same')(pool3)
    conv4 = layers.Conv2D(512, 3, activation='relu', padding='same')(conv4)

    # Upscaling path
    up5 = layers.Conv2DTranspose(256, (2, 2), strides=(2, 2), padding='same')(conv4)
    up5 = layers.concatenate([up5, conv3], axis=3)
    conv5 = layers.Conv2D(256, 3, activation='relu', padding='same')(up5)
    conv5 = layers.Conv2D(256, 3, activation='relu', padding='same')(conv5)

    up6 = layers.Conv2DTranspose(128, (2, 2), strides=(2, 2), padding='same')(conv5)
    up6 = layers.concatenate([up6, conv2], axis=3)
    conv6 = layers.Conv2D(128, 3, activation='relu', padding='same')(up6)
    conv6 = layers.Conv2D(128, 3, activation='relu', padding='same')(conv6)

    up7 = layers.Conv2DTranspose(64, (2, 2), strides=(2, 2), padding='same')(conv6)
    up7 = layers.concatenate([up7, conv1], axis=3)
    conv7 = layers.Conv2D(64, 3, activation='relu', padding='same')(up7)
    conv7 = layers.Conv2D(64, 3, activation='relu', padding='same')(conv7)

    # Output layer
    outputs = layers.Conv2D(1, 1, activation='sigmoid')(conv7)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

# Define input shape
input_shape = (224, 224, 1)

# Create the model
model = unet_model(input_shape)

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Display model summary
model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 224, 224, 1)]        0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 224, 224, 64)         640       ['input_1[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, 224, 224, 64)         36928     ['conv2d[0][0]']              
                                                                                                  
 max_pooling2d (MaxPooling2  (None, 112, 112, 64)         0         ['conv2d_1[0][0]']            
 D)                                                                                           

In [None]:
# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model
history = model.fit(images_array_train, ground_truths_array_train, batch_size=32, epochs=10, validation_split=0.2)

# Evaluate the model on test data
test_loss, test_accuracy = model.evaluate(images_array_test, ground_truths_array_test)
print(f'Test Loss: {test_loss}, Test Accuracy: {test_accuracy}')

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test Loss: 1.6920727491378784, Test Accuracy: 0.9398632645606995


unet++

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class UpConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpConvBlock, self).__init__()
        self.upconv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
        self.conv_block = ConvBlock(in_channels, out_channels)

    def forward(self, x, skip_connection):
        x = self.upconv(x)
        x = torch.cat([x, skip_connection], dim=1)
        x = self.conv_block(x)
        return x

class UNetPlusPlus(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(UNetPlusPlus, self).__init__()
        self.conv1 = ConvBlock(in_channels, 64)
        self.conv2 = ConvBlock(64, 128)
        self.conv3 = ConvBlock(128, 256)
        self.conv4 = ConvBlock(256, 512)
        self.conv5 = ConvBlock(512, 1024)

        self.upconv1 = UpConvBlock(1024, 512)
        self.upconv2 = UpConvBlock(512, 256)
        self.upconv3 = UpConvBlock(256, 128)
        self.upconv4 = UpConvBlock(128, 64)

        self.final_conv = nn.Conv2d(64, num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.conv1(x)
        x2 = F.max_pool2d(x1, kernel_size=2, stride=2)
        x3 = self.conv2(x2)
        x4 = F.max_pool2d(x3, kernel_size=2, stride=2)
        x5 = self.conv3(x4)
        x6 = F.max_pool2d(x5, kernel_size=2, stride=2)
        x7 = self.conv4(x6)
        x8 = F.max_pool2d(x7, kernel_size=2, stride=2)
        x9 = self.conv5(x8)

        # Decoder
        x10 = self.upconv1(x9, x7)
        x11 = self.upconv2(x10, x5)
        x12 = self.upconv3(x11, x3)
        x13 = self.upconv4(x12, x1)

        # Final convolution
        output = self.final_conv(x13)
        return output

# Example usage:
input_channels = 1  # Example number of input channels (assuming grayscale images)
num_classes = 4  # Example number of classes for the ACDC dataset
model = UNetPlusPlus(input_channels, num_classes)
print(model)


UNetPlusPlus(
  (conv1): ConvBlock(
    (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv2): ConvBlock(
    (conv): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv3): ConvBlock(
    (conv): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv4): ConvBlock(
    (conv): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
  )
  (conv5): ConvBlock(
    (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=

In [None]:
# Define your model
model = UNetPlusPlus(input_channels=1, num_classes=num_classes)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * images.size(0)

    train_loss /= len(train_loader.dataset)

    # Validation
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * images.size(0)

    val_loss /= len(val_loader.dataset)

    print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

# Save the trained model
torch.save(model.state_dict(), 'unetplusplus_acdc.pth')


TypeError: UNetPlusPlus.__init__() got an unexpected keyword argument 'input_channels'

MedMamba

In [None]:
import torch
import torch.nn as nn

class ConvSSMBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvSSMBlock, self).__init__()
        self.conv_branch = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.ssm_branch = nn.Sequential(
            nn.LayerNorm((out_channels, 4, 4)),
            nn.Linear(out_channels * 4 * 4, out_channels),
            nn.ReLU(inplace=True),
        )
        self.merge = nn.Sequential(
            nn.LayerNorm(out_channels),
            nn.Linear(out_channels, out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        conv_out = self.conv_branch(x)
        ssm_out = self.ssm_branch(x.view(x.size(0), -1))
        ssm_out = ssm_out.view(x.size(0), out_channels, 4, 4)  # Reshape to match conv_out
        ssm_out = self.merge(ssm_out)
        return conv_out + ssm_out  # Residual connection



class MedMamba(nn.Module):
    def __init__(self, in_channels=1, num_classes=4):
        super(MedMamba, self).__init__()
        self.patch_embed = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=4),
            nn.ReLU(inplace=True)
        )
        self.conv_ssm_blocks = nn.Sequential(
            ConvSSMBlock(64, 128),
            ConvSSMBlock(128, 256),
            ConvSSMBlock(256, 512),
            ConvSSMBlock(512, 1024)
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(1024, num_classes)

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.conv_ssm_blocks(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

model = MedMamba(in_channels=1, num_classes=4)
print(model)


MedMamba(
  (patch_embed): Sequential(
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(4, 4))
    (1): ReLU(inplace=True)
  )
  (conv_ssm_blocks): Sequential(
    (0): ConvSSMBlock(
      (conv_branch): Sequential(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (ssm_branch): Sequential(
        (0): LayerNorm((128, 4, 4), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=2048, out_features=128, bias=True)
        (2): ReLU(inplace=True)
      )
      (merge): Sequential(
        (0): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (1): Linear(in_features=128, out_features=128, bias=True)
        (2): ReLU(inplace=True)
      )
    )
    (1): ConvSSMBlock(
      (conv_branch): Sequential(
        (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): ReLU(inplace=True)
      )
      (ssm_branch): Sequential(
        (0): LayerNorm((256, 4, 4), 

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score

# Convert numpy arrays to PyTorch tensors
train_images_tensor = torch.tensor(images_array_train).unsqueeze(1).float()  # Add channel dimension
train_ground_truths_tensor = torch.tensor(ground_truths_array_train).unsqueeze(1).float()  # Add channel dimension

test_images_tensor = torch.tensor(images_array_test).unsqueeze(1).float()  # Add channel dimension
test_ground_truths_tensor = torch.tensor(ground_truths_array_test).unsqueeze(1).float()  # Add channel dimension

# Create PyTorch datasets and dataloaders
train_dataset = TensorDataset(train_images_tensor, train_ground_truths_tensor)
test_dataset = TensorDataset(test_images_tensor, test_ground_truths_tensor)

train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    for images, ground_truths in train_dataloader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, ground_truths)
        loss.backward()
        optimizer.step()

    # Evaluate on test set
    model.eval()
    all_predictions = []
    all_ground_truths = []
    with torch.no_grad():
        for images, ground_truths in test_dataloader:
            outputs = model(images)
            all_predictions.append(outputs.numpy())
            all_ground_truths.append(ground_truths.numpy())

    all_predictions = np.concatenate(all_predictions)
    all_ground_truths = np.concatenate(all_ground_truths)
    accuracy = accuracy_score(all_ground_truths.flatten(), all_predictions.flatten())
    print(f"Epoch {epoch+1}/{num_epochs}, Test Accuracy: {accuracy}")


RuntimeError: Given normalized_shape=[128, 4, 4], expected input with shape [*, 128, 4, 4], but got input of size[32, 200704]

Another method https://github.com/YubiaoYue/MedMamba/blob/main/train.py

In [None]:
import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm
import nibabel as nib
import numpy as np
import re
from skimage.transform import resize

### Define the VSSM model ###
class VSSM(nn.Module):
    def __init__(self, num_classes=2):
        super(VSSM, self).__init__()
        # Define the architecture of the model
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 28 * 28, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 28 * 28)  # Flatten the input for the fully connected layers
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("Using {} device.".format(device))

    num_classes = 2


    # Convert data to PyTorch tensors
    train_dataset = torch.utils.data.TensorDataset(torch.tensor(images_array_train), torch.tensor(ground_truths_train))
    test_dataset = torch.utils.data.TensorDataset(torch.tensor(images_array_test), torch.tensor(ground_truths_test))

    batch_size = 32
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=nw)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=nw)

    print("Using {} images for training, {} images for testing.".format(len(train_dataset), len(test_dataset)))

    net = VSSM(num_classes=num_classes)
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)

    epochs = 100
    best_acc = 0.0
    save_path = './VSSMNet.pth'

    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for step, (images, labels) in enumerate(train_bar):
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = net(images.unsqueeze(1).float())  # Add a channel dimension to images
            loss = loss_function(outputs, labels.long())
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()

            train_bar.desc = "Train Epoch [{}/{}] Loss: {:.3f}".format(epoch + 1, epochs, loss.item())

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            test_bar = tqdm(test_loader, file=sys.stdout)
            for images, labels in test_bar:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images.unsqueeze(1).float())
                predict_y = torch.argmax(outputs, dim=1)
                acc += torch.eq(predict_y, labels.long()).sum().item()

        val_accurate = acc / len(test_dataset)
        print('[Epoch %d] Training Loss: %.3f  Test Accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')

if __name__ == '__main__':
    main()


Using cpu device.


  train_dataset = torch.utils.data.TensorDataset(torch.tensor(images_array_train), torch.tensor(ground_truths_train))


Using 2 dataloader workers every process
Using 200 images for training, 100 images for testing.
  0%|          | 0/7 [00:04<?, ?it/s]


RuntimeError: 0D or 1D target tensor expected, multi-target not supported

MambaUNET

In [None]:
import tensorflow as tf

class MambaBlock(tf.keras.layers.Layer):
  def __init__(self, filters, kernel_size=(3, 3, 3), activation='relu', **kwargs):
    super(MambaBlock, self).__init__(**kwargs)
    self.filters = filters
    self.kernel_size = kernel_size
    self.activation = tf.keras.activations.get(activation)  # Map the activation string to the corresponding function


    # Define layers for Mamba block
    self.conv1 = tf.keras.layers.Conv3D(filters, kernel_size, padding='same')
    self.conv2 = tf.keras.layers.Conv3D(filters, kernel_size, padding='same')
    self.bnorm = tf.keras.layers.BatchNormalization(axis=-1)

  def call(self, inputs):
    x = self.conv1(inputs)
    x = self.activation(x)
    x = self.conv2(x)
    x = self.bnorm(x)
    return x

class MambaUNet(tf.keras.Model):
    def __init__(self, input_shape, n_classes, filters=32):
        super(MambaUNet, self).__init__()
        self.inputs = tf.keras.layers.Input(shape=input_shape)

        # Encoder - Contracting Path (with Mamba Blocks)
        conv1 = MambaBlock(filters, activation='relu')(self.inputs)
        pool1 = tf.keras.layers.MaxPooling3D((2, 2, 2))(conv1)
        conv2 = MambaBlock(filters * 2, activation='relu')(pool1)
        pool2 = tf.keras.layers.MaxPooling3D((2, 2, 2))(conv2)
        conv3 = MambaBlock(filters * 4, activation='relu')(pool2)
        pool3 = tf.keras.layers.MaxPooling3D((2, 2, 2))(conv3)
        conv4 = MambaBlock(filters * 8, activation='relu')(pool3)

        # Decoder - Expanding Path (with Mamba Blocks and Skip Connections)
        upconv5 = tf.keras.layers.Conv3DTranspose(filters * 4, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv4)
        crop_concat5 = tf.keras.layers.Cropping3D(cropping=((0, 0), (0, 0), (0, 0)))(conv3)
        if conv3.shape[3] < conv4.shape[3]:
            conv3 = tf.pad(conv3, [[0, 0], [0, 0], [0, 0], [0, conv4.shape[3] - conv3.shape[3]], [0, 0]])
        elif conv4.shape[3] < conv3.shape[3]:
            conv4 = tf.pad(conv4, [[0, 0], [0, 0], [0, 0], [0, conv3.shape[3] - conv4.shape[3]], [0, 0]])

        concat5 = tf.keras.layers.Concatenate(axis=-1)([upconv5, crop_concat5])
        # concat5 = tf.keras.layers.Concatenate(axis=-1)([upconv5, crop_concat5])
        conv5 = MambaBlock(filters * 4, activation='relu')(concat5)

        upconv6 = tf.keras.layers.Conv3DTranspose(filters * 2, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv5)
        crop_concat6 = tf.keras.layers.Cropping3D(cropping=((0, 0), (0, 0), (0, 0)))(conv2)
        concat6 = tf.keras.layers.Concatenate()([upconv6, crop_concat6])
        conv6 = MambaBlock(filters * 2, activation='relu')(concat6)

        upconv7 = tf.keras.layers.Conv3DTranspose(filters, (2, 2, 2), strides=(2, 2, 2), padding='same')(conv6)
        crop_concat7 = tf.keras.layers.Cropping3D(cropping=((0, 0), (0, 0), (0, 0)))(conv1)
        concat7 = tf.keras.layers.Concatenate()([upconv7, crop_concat7])
        conv7 = MambaBlock(filters, activation='relu')(concat7)

        # Output layer
        self.outputs = tf.keras.layers.Conv3D(n_classes, 1, activation='softmax')(conv7)

    def call(self, inputs):
        return self.outputs(inputs)



In [None]:
import tensorflow as tf
from sklearn.model_selection import train_test_split

input_shape = (224, 224, 20, 1)  # Adjust dimensions and channels according to your data
n_classes = 4  # Assuming 4 classes for segmentation
model = MambaUNet(input_shape, n_classes)


# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(images_array_train, ground_truths_array_train, test_size=0.2, random_state=42)

# Reshape the data to match the model's input shape
X_train = X_train.reshape(X_train.shape + (1,))
X_val = X_val.reshape(X_val.shape + (1,))

# Train the model
history = model.fit(X_train, y_train, validation_data=(X_val, y_val), batch_size=32, epochs=10)

# Evaluate the model on test data
X_test = images_array_test.reshape(images_array_test.shape + (1,))
y_test = ground_truths_array_test
test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=2)
print(f"Test Loss: {test_loss}, Test Accuracy: {test_accuracy}")


ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 56, 56, 4, 128), (None, 56, 56, 5, 128)]