In [1]:
## Our model: U-Mamba with pyramid pooling (done)

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


from mamba_model import UMambaBot
import os

from dice_loss import DiceLoss
import torch.optim as optim
from brain_mri_dataset import BrainMRIDatasetBuilder,BrainMRIDataset

from transforms import BrainMRITransforms

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Run on the current CUDA device when PyTorch can see a GPU; otherwise use the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
import random  # local import keeps this configuration cell self-contained

# Training hyper-parameters.
batch_size = 64
learning_rate = 0.0003  # step size passed to the Adam optimizer

# Seed Python's and PyTorch's RNGs before the data split, weight initialisation and
# DataLoader shuffling, so a Restart-&-Run-All is repeatable. This does not switch
# cuDNN/CUDA kernels to deterministic algorithms, so GPU runs may still differ slightly.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

In [5]:
data_dir = "../datasets/lgg-mri-segmentation/kaggle_3m"

# Index the dataset and split it into train / validation / test frames.
builder = BrainMRIDatasetBuilder(data_dir)
df = builder.create_df()
train_df, val_df, test_df = builder.split_df(df)

# A single transform object is used for both images and masks.
# NOTE(review): if BrainMRITransforms applies random augmentation, calling it separately
# on an image and its mask could desynchronise them; confirm it is deterministic.
transform_ = BrainMRITransforms()

train_data, val_data, test_data = (
    BrainMRIDataset(split_frame, transform=transform_, mask_transform=transform_)
    for split_frame in (train_df, val_df, test_df)
)

# Only the training split is shuffled; validation and test keep a fixed order.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
    for dataset, is_train in ((train_data, True), (val_data, False), (test_data, False))
)


In [6]:
# U-Mamba (the "Bot" variant, presumably Mamba at the bottleneck) with a pyramid
# pooling module, wrapped in nn.DataParallel so batches are split across visible GPUs.
# Because of the wrapper, model.state_dict() keys carry a "module." prefix.
umamba_kwargs = dict(
    input_channels=3,                            # assumes 3-channel (RGB) input images
    n_stages=5,
    features_per_stage=(32, 64, 128, 256, 512),
    conv_op=nn.Conv2d,                           # 2D convolutions
    kernel_sizes=(3,) * 5,
    strides=(1, 2, 2, 2, 2),                     # first stage keeps the input resolution
    num_classes=1,                               # single-channel mask output
    n_conv_per_stage=(1,) * 5,
    n_conv_per_stage_decoder=(1,) * 4,
    conv_bias=True,
    norm_op=nn.InstanceNorm2d,                   # 2D instance normalization
    norm_op_kwargs={},
    dropout_op=None,
    nonlin=nn.LeakyReLU,
    nonlin_kwargs={'inplace': True},
    ppm_pool_sizes=(1, 2, 3, 6),                 # pyramid pooling pool sizes
)
model = nn.DataParallel(UMambaBot(**umamba_kwargs)).to(device)

In [7]:
# Sanity check: push one random 256x256 RGB image through the model and confirm the
# prediction keeps the input's spatial size. Autograd is disabled so the throwaway
# computation graph is not kept alive by the global `output`.
input_tensor = torch.randn(1, 3, 256, 256).to(device)
with torch.no_grad():
    output = model(input_tensor)
print(output.shape)
assert output.shape[-2:] == input_tensor.shape[-2:], (
    f"model changed the spatial size: {tuple(input_tensor.shape)} -> {tuple(output.shape)}"
)

torch.Size([1, 1, 256, 256])


In [8]:
# Dice loss on the predicted masks, minimised with Adam over every model parameter.
criterion = DiceLoss()
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [9]:
# Destination for the per-epoch training checkpoints (path relative to the working directory).
checkpoint_dir = "../checkpoints/" "pp_umamba_checkpoints/"

In [10]:
# Train for `epochs` epochs. After each epoch the mean per-batch Dice loss on the
# training and validation splits is recorded and a checkpoint is written.
epochs = 100

train_loss = []  # mean training loss per epoch
val_loss = []    # mean validation loss per epoch
# TODO: IoU is not computed yet, so these two lists stay empty.
trainIOU = []
valIOU = []


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device).float(), labels.to(device).float()
        optimizer.zero_grad()
        loss = criterion(model(imgs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def evaluate(model, loader, criterion, device):
    """Return the mean per-batch loss over `loader` with gradient tracking disabled."""
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device).float(), labels.to(device).float()
            running_loss += criterion(model(imgs), labels).item()
    return running_loss / len(loader)


# torch.save does not create missing parent directories, so make sure it exists.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(epochs):
    train_loss.append(train_one_epoch(model, train_dataloader, criterion, optimizer, device))
    val_loss.append(evaluate(model, val_dataloader, criterion, device))

    print('Epoch: {}/{}, Train Loss: {:.4f}, Val Loss: {:.4f}'.format(epoch + 1, epochs, train_loss[-1], val_loss[-1]))

    # One checkpoint per epoch. Keys carry a "module." prefix because `model` is wrapped
    # in nn.DataParallel: load into a DataParallel-wrapped model or strip the prefix.
    checkpoint_path = os.path.join(checkpoint_dir, f'pp_umamba_checkpoint_epoch_{epoch+1}.pt')
    torch.save(model.state_dict(), checkpoint_path)

# Final weights, in the same state_dict format as the checkpoints. They are written to the
# current working directory, not to checkpoint_dir.
model_state_dict = model.state_dict()
file_path = 'pp_umamba_weights.pth'
torch.save(model_state_dict, file_path)



Epoch: 1/100, Train Loss: 0.9312, Val Loss: 0.9115
Epoch: 2/100, Train Loss: 0.8871, Val Loss: 0.8663
Epoch: 3/100, Train Loss: 0.8361, Val Loss: 0.8086
Epoch: 4/100, Train Loss: 0.7618, Val Loss: 0.7387
Epoch: 5/100, Train Loss: 0.6778, Val Loss: 0.6273
Epoch: 6/100, Train Loss: 0.5927, Val Loss: 0.5514
Epoch: 7/100, Train Loss: 0.4802, Val Loss: 0.4345
Epoch: 8/100, Train Loss: 0.3984, Val Loss: 0.3592
Epoch: 9/100, Train Loss: 0.3328, Val Loss: 0.3082
Epoch: 10/100, Train Loss: 0.2916, Val Loss: 0.3022
Epoch: 11/100, Train Loss: 0.2935, Val Loss: 0.2425
Epoch: 12/100, Train Loss: 0.2267, Val Loss: 0.2105
Epoch: 13/100, Train Loss: 0.1981, Val Loss: 0.1887
Epoch: 14/100, Train Loss: 0.1829, Val Loss: 0.1749
Epoch: 15/100, Train Loss: 0.1689, Val Loss: 0.1691
Epoch: 16/100, Train Loss: 0.1553, Val Loss: 0.1472
Epoch: 17/100, Train Loss: 0.1552, Val Loss: 0.1574
Epoch: 18/100, Train Loss: 0.1458, Val Loss: 0.1415
Epoch: 19/100, Train Loss: 0.1344, Val Loss: 0.1325
Epoch: 20/100, Train 

In [11]:
import gc


def drop_names(namespace, names):
    """Delete each of `names` from the mapping `namespace` if present.

    Returns the list of names that were actually removed. Missing names are skipped,
    so re-running this cell does not raise NameError.
    """
    removed = []
    for name in names:
        if name in namespace:
            del namespace[name]
            removed.append(name)
    return removed


# Drop every global that can keep GPU memory alive. Deleting `model` alone is not enough:
# the optimizer holds the parameters (plus Adam's state), `model_state_dict` references the
# parameter tensors, and the sanity-check / last-batch tensors are still bound.
drop_names(globals(), (
    "model", "optimizer", "model_state_dict",
    "input_tensor", "output",
    "imgs", "labels", "pred", "loss",
    "train_data", "train_dataloader",
    "val_data", "val_dataloader",
    "test_data", "test_dataloader",
))

# Collect only after the references are gone, so cyclic garbage is reclaimed. Then hand
# the caching allocator's unused blocks back to the driver (a no-op if CUDA is uninitialised).
gc.collect()
torch.cuda.empty_cache()

In [12]:
import torch

# Return cached-but-unused CUDA memory to the driver. Gradient mode has no effect on this
# call, so no torch.no_grad() context is needed; without an initialised CUDA context it does nothing.
torch.cuda.empty_cache()

: 