In [3]:
import numpy as np
import os
import random
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
from architecture.model import DLUNet
from utils.custom_loss import Weighted_BCEnDice_loss
from architecture.model import mean_iou, class_dice
from utils.custom_metric import dice_coef
from utils.load_data import BrainDataset
from utils.training import resume_training


# Define the directory containing the data
train_data = r"../data"
batch_size = 8
num_epochs = 5

# Seed every RNG in play so data shuffling and weight initialisation are reproducible.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Pick the best available accelerator instead of hard-coding "mps",
# which raises on machines without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Initialize model, optimizer and loss
model = DLUNet(in_channels=4).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Divide the LR by 10 when the monitored (validation) loss stops improving for 5 steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', patience=5, factor=0.1)
criterion = Weighted_BCEnDice_loss


train_dataset = BrainDataset(train_data)
val_dataset = BrainDataset(train_data, "val")

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)


best_val_loss = float('inf')

checkpoint_interval = 1  # Save a checkpoint every epoch
os.makedirs('model', exist_ok=True)

Using device: mps


In [2]:

# for epoch in range(num_epochs):
#     print(f"Epoch {epoch+1}/{num_epochs}")

#     # Training phase
#     model.train()
#     train_loss = 0.0

#     for step, (images, masks) in enumerate(train_loader, start=1):
#         images = images.to(device)
#         masks = masks.to(device)

#         optimizer.zero_grad()

#         outputs = model(images)
#         loss = Weighted_BCEnDice_loss(outputs, masks)
#         dice = dice_coef(outputs, masks)

#         loss.backward()
#         optimizer.step()

#         train_loss += loss.item()

#         # Log progress for training
#         if step % 10 == 0 or step == len(train_loader):
#             print(
#                 f"\r{step}/{len(train_loader)} [==============================] - loss: {loss.item():.4f}", end="")
#             print(
#                 f"\r{step}/{len(train_loader)} [==============================] - metrics: {mean_iou(outputs, masks):.4f} c_2: {class_dice(outputs, masks, 2)} c_3: {class_dice(outputs, masks, 3)} c_4: {class_dice(outputs, masks, 4)}", end="")

#     train_loss /= len(train_loader)
#     print(f"\nTraining Loss: {train_loss:.4f}")

#     # Validation phase
#     model.eval()
#     val_loss = 0.0
#     val_dice = 0.0

#     with torch.no_grad():
#         for step, (images, masks) in enumerate(val_loader, start=1):
#             images = images.to(device)
#             masks = masks.to(device)

#             outputs = model(images)
#             loss = Weighted_BCEnDice_loss(outputs, masks)
#             dice = dice_coef(outputs, masks)

#             val_loss += loss.item()
#             val_dice += dice.item()

#             # Log progress for validation
#             if step % 10 == 0 or step == len(val_loader):
#                 print(
#                     f"\rValidation {step}/{len(val_loader)} [==============================]"
#                     f" - val_loss: {loss.item():.4f}"
#                     f" - val_dice: {dice.item():.4f}",
#                     end=""
#                 )

#     val_loss /= len(val_loader)
#     val_dice /= len(val_loader)
#     print(
#         f"\nValidation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}")

#     # Step the scheduler with validation loss
#     scheduler.step(val_loss)

#     # Save the best model
#     if val_loss < best_val_loss:
#         best_val_loss = val_loss
#         torch.save(model.state_dict(), 'model/dlu_net_model_best.pth')
#         print(f"Saved new best model with val_loss: {val_loss:.4f}")

#     # Save an intermediate checkpoint every 'checkpoint_interval' epochs
#     if (epoch + 1) % checkpoint_interval == 0:
#         checkpoint_path = f"model/dlu_net_model_epoch_{epoch+1}.pth"
#         torch.save(model.state_dict(), checkpoint_path)
#         print(
#             f"Intermediate checkpoint saved at epoch {epoch+1} to '{checkpoint_path}'")

#     print('-' * 60)

In [4]:
from utils.training import resume_training


# Resume from the best slimmed checkpoint.
# Pass the optimizer that `scheduler` was built on in the setup cell. Constructing a new
# Adam here would leave ReduceLROnPlateau adjusting an optimizer that is never stepped,
# so learning-rate decay would silently never reach training.
model, best_val_loss = resume_training(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_function=Weighted_BCEnDice_loss,
    device=device,
    resume_checkpoint_path='../model/slimmed_model_best.pth',
    starting_epoch=0,
    num_epochs=10,
    checkpoint_interval=1,
    model_save_dir='../model/pruned_pretrained/slimmed',
    validate_every=10,
)

2025/04/28 06:14:52 INFO mlflow.tracking.fluent: Experiment with name 'brain_segmentation' does not exist. Creating a new experiment.


MLflow experiment 'brain_segmentation' initialized with run ID: d9c9e8b4dd034ea599cb67cabd326fde
Tracking URI: file:///Users/joe_codes/dev/school/projects/rewrite_brain_segmentation_pytourch/src/mlruns
Loading checkpoint from ../model/slimmed_model_best.pth
Resuming training from epoch 0
Epoch 1/10
Training Loss: 0.1097
Running validation on 164 batches...
Validation Loss: 0.1453, Validation Dice: 0.9862
Saved new best model with val_loss: 0.1453


Successfully registered model 'brain_segmentation_best'.
Created version '1' of model 'brain_segmentation_best'.


Best model logged to MLflow with metrics: {'best_train_loss': 0.10965258961189295, 'best_train_dice': 0.9898649122787664, 'best_train_iou': 0.9820288708663298, 'best_train_class_2_dice': 0.8896188203154838, 'best_train_class_3_dice': 0.7802447084180676, 'best_train_class_4_dice': 0.9222212411322683, 'best_learning_rate': 0.0001, 'best_val_loss': 0.14534988663182025, 'best_val_dice': 0.9862056352743288, 'best_val_iou': 0.9739553782998062, 'best_val_class_2_dice': 0.8112278960463477, 'best_val_class_3_dice': 0.7513610792414445, 'best_val_class_4_dice': 0.8661600420387779, 'best_best_val_loss': 0.14534988663182025}
Intermediate checkpoint saved at epoch 1 to '../model/pruned_pretrained/slimmed/dlu_net_model_epoch_1.pth'
------------------------------------------------------------
Epoch 2/10
Training Loss: 0.0905
Skipping validation for epoch 2 (will validate every 10 epochs)
Intermediate checkpoint saved at epoch 2 to '../model/pruned_pretrained/slimmed/dlu_net_model_epoch_2.pth'
--------

## Here we pre-train the pruned models for at least 10 epochs


In [None]:
# Load the DepGraph-pruned model so it can be fine-tuned in the next cell.
from architecture.model import DLUNet, ReASPP3
import torch

# Allow-list the model classes for torch's safe unpickler. This is only consulted by
# weights_only=True loads; the load below uses weights_only=False.
torch.serialization.add_safe_globals([DLUNet, ReASPP3])

pruned_path = 'model/pruned_dlu_net.pth'
print(f"Loading pruned model from {pruned_path}")

# weights_only=False unpickles arbitrary Python objects: only use it on files we produced.
# map_location lets a checkpoint saved on another device (e.g. CUDA) load on this one.
checkpoint = torch.load(pruned_path, map_location=device, weights_only=False)

if isinstance(checkpoint, torch.nn.Module):
    # Keep the pickled module itself so its pruned structure (smaller layers, or pruning
    # masks/hooks if the pruner added any) survives. Copying its state_dict into a fresh,
    # unpruned DLUNet fails whenever parameter shapes or keys differ.
    model = checkpoint.to(device)
    print("Using the pruned model object from the checkpoint")
else:
    # A bare state_dict must match the unpruned DLUNet layout.
    model = DLUNet(in_channels=4).to(device)
    model.load_state_dict(checkpoint)
    print("Loaded state dictionary successfully")

# Load errors are deliberately NOT caught: swallowing them would let the next cell
# "fine-tune" a randomly initialised, unpruned network without any warning.

In [None]:
# Fine-tune the pruned model held in `model` with its own optimizer and LR schedule.
# The learning rate is the same 1e-4 used for the original training run.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3)

# `model` already carries the pruned weights, so resume_checkpoint_path is left empty
# to skip loading, and epochs are counted from 0. Checkpoints go to their own directory.
model, best_val_loss = resume_training(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_function=Weighted_BCEnDice_loss,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    resume_checkpoint_path='',
    starting_epoch=0,
    num_epochs=10,
    checkpoint_interval=1,
    model_save_dir='model/pruned_pretrained/depgraph',
)

### Pre-training the DepGraph-pruned model


In [None]:
from architecture.model import load_trained_model

# The DepGraph checkpoint is a whole pickled module (not just a state_dict), so it is
# used as-is. weights_only=False executes pickled code: only load files this project produced.
# map_location remaps tensors saved on another device (e.g. CUDA) onto `device`.
depgraph_model = torch.load(
    'model/pruned_dlu_net.pth', map_location=device, weights_only=False)
if not isinstance(depgraph_model, torch.nn.Module):
    raise TypeError(
        "Expected a pickled nn.Module in 'model/pruned_dlu_net.pth', "
        f"got {type(depgraph_model).__name__}")
# Assign rather than leaving `.to(device)` as the last expression, which would dump the
# whole model repr into the cell output.
depgraph_model = depgraph_model.to(device)

In [None]:
# Train the DepGraph-pruned model (`depgraph_model`) loaded in the previous cell.
#
# The optimizer/scheduler MUST be built over depgraph_model's parameters. The shared
# `optimizer` from earlier cells tracks `model`, so stepping it would never update this
# network, and its gradients would keep accumulating because zero_grad() never reached them.
depgraph_optimizer = torch.optim.Adam(depgraph_model.parameters(), lr=1e-4)
depgraph_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    depgraph_optimizer, 'min', patience=3, factor=0.1)

# Start a fresh "best" tracker: any value left by earlier cells belongs to another model.
best_val_loss = float('inf')
# torch.save does not create missing directories.
os.makedirs('model/pruned_pretrained', exist_ok=True)


def _train_one_epoch(net, loader, opt):
    """Run one optimisation pass of `net` over `loader` with `opt`; return the mean batch loss."""
    net.train()
    running_loss = 0.0
    for step, (images, masks) in enumerate(loader, start=1):
        images = images.to(device)
        masks = masks.to(device)

        opt.zero_grad()
        outputs = net(images)
        loss = Weighted_BCEnDice_loss(outputs, masks)
        loss.backward()
        opt.step()

        running_loss += loss.item()

        # One progress line (a second "\r" print would immediately overwrite the first).
        if step % 10 == 0 or step == len(loader):
            print(
                f"\r{step}/{len(loader)} [==============================] - loss: {loss.item():.4f}"
                f" - metrics: {mean_iou(outputs, masks):.4f} c_2: {class_dice(outputs, masks, 2)}"
                f" c_3: {class_dice(outputs, masks, 3)} c_4: {class_dice(outputs, masks, 4)}",
                end="")
    return running_loss / len(loader)


def _validate(net, loader):
    """Evaluate `net` on `loader` without gradients; return (mean loss, mean dice)."""
    net.eval()
    total_loss = 0.0
    total_dice = 0.0
    with torch.no_grad():
        for step, (images, masks) in enumerate(loader, start=1):
            images = images.to(device)
            masks = masks.to(device)

            outputs = net(images)
            loss = Weighted_BCEnDice_loss(outputs, masks)
            dice = dice_coef(outputs, masks)

            total_loss += loss.item()
            total_dice += dice.item()

            if step % 10 == 0 or step == len(loader):
                print(
                    f"\rValidation {step}/{len(loader)} [==============================]"
                    f" - val_loss: {loss.item():.4f}"
                    f" - val_dice: {dice.item():.4f}",
                    end=""
                )
    return total_loss / len(loader), total_dice / len(loader)


for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")

    train_loss = _train_one_epoch(depgraph_model, train_loader, depgraph_optimizer)
    print(f"\nTraining Loss: {train_loss:.4f}")

    # Validate the model being trained (the original validated `model` by mistake).
    val_loss, val_dice = _validate(depgraph_model, val_loader)
    print(
        f"\nValidation Loss: {val_loss:.4f}, Validation Dice: {val_dice:.4f}")

    # Step the scheduler that actually drives depgraph_model's optimizer.
    depgraph_scheduler.step(val_loss)

    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(depgraph_model.state_dict(),
                   'model/pruned_pretrained/depgraph_best.pth')
        print(f"Saved new best model with val_loss: {val_loss:.4f}")

    # Save an intermediate checkpoint every 'checkpoint_interval' epochs
    if (epoch + 1) % checkpoint_interval == 0:
        checkpoint_path = f"model/pruned_pretrained/depgraph{epoch+1}.pth"
        torch.save(depgraph_model.state_dict(), checkpoint_path)
        print(
            f"Intermediate checkpoint saved at epoch {epoch+1} to '{checkpoint_path}'")

    print('-' * 60)

### Fine-tuning the network-slimming model


In [None]:
# Load the network-slimming checkpoint.
pruned_path = 'model/slimmed_model.pth'
print(f"Loading pruned model from {pruned_path}")

# weights_only=False unpickles arbitrary Python objects: only use it on files we produced.
# map_location lets a checkpoint saved on another device (e.g. CUDA) load on this one.
checkpoint = torch.load(pruned_path, map_location=device, weights_only=False)

if isinstance(checkpoint, torch.nn.Module):
    # Keep the pickled module itself so any pruned structure (smaller layers, or pruning
    # masks/hooks if the pruner added any) survives. Copying its state_dict into a fresh
    # DLUNet fails whenever parameter shapes or keys differ.
    model = checkpoint.to(device)
    print("Using the pruned model object from the checkpoint")
else:
    # A bare state_dict must match the unpruned DLUNet layout; load it into a fresh
    # instance rather than whatever architecture `model` currently holds.
    model = DLUNet(in_channels=4).to(device)
    model.load_state_dict(checkpoint)
    print("Loaded state dictionary successfully")
# Load errors are deliberately NOT caught: swallowing them would fine-tune a model that
# never received the pruned weights.


# Fresh optimizer/scheduler for this fine-tuning phase
# (same 1e-4 learning rate as the original training run).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', patience=3, factor=0.1)

# `model` already holds the pruned weights, so resume_training must not load a checkpoint.
# Start from epoch 0 since this is a new training phase.
model, best_val_loss = resume_training(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_function=Weighted_BCEnDice_loss,
    device=device,
    resume_checkpoint_path='',  # Important: leave empty to skip loading
    starting_epoch=0,
    num_epochs=10,
    checkpoint_interval=1,
    # Save to a different directory
    model_save_dir='model/pruned_pretrained/slimmed',
)

### SNIP pruning


In [None]:
# Load the SNIP-pruned checkpoint.
pruned_path = 'model/snip_pruned_model.pth'
print(f"Loading pruned model from {pruned_path}")

# weights_only=False unpickles arbitrary Python objects: only use it on files we produced.
# map_location lets a checkpoint saved on another device (e.g. CUDA) load on this one.
checkpoint = torch.load(pruned_path, map_location=device, weights_only=False)

if isinstance(checkpoint, torch.nn.Module):
    # Keep the pickled module itself so any pruning state (e.g. weight masks/hooks, if the
    # pruner attached them) survives. Copying its state_dict into a fresh DLUNet fails
    # whenever parameter shapes or keys differ.
    model = checkpoint.to(device)
    print("Using the pruned model object from the checkpoint")
else:
    # A bare state_dict must match the unpruned DLUNet layout; load it into a fresh
    # instance rather than whatever architecture `model` currently holds.
    model = DLUNet(in_channels=4).to(device)
    model.load_state_dict(checkpoint)
    print("Loaded state dictionary successfully")
# Load errors are deliberately NOT caught: swallowing them would fine-tune a model that
# never received the pruned weights.


# Fresh optimizer/scheduler for this fine-tuning phase
# (same 1e-4 learning rate as the original training run).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', patience=3, factor=0.1)

# `model` already holds the pruned weights, so resume_training must not load a checkpoint.
# Start from epoch 0 since this is a new training phase.
model, best_val_loss = resume_training(
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    loss_function=Weighted_BCEnDice_loss,
    device=device,
    resume_checkpoint_path='',  # Important: leave empty to skip loading
    starting_epoch=0,
    num_epochs=10,
    checkpoint_interval=1,
    # Save to a different directory
    model_save_dir='model/pruned_pretrained/snip_pruned',
)

In [None]:
# Parameter budget of `model`: trainable (requires_grad) and frozen counts, then their total.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
non_trainable_params = sum(
    p.numel() for p in model.parameters() if not p.requires_grad)
total_params = trainable_params + non_trainable_params
print(total_params, trainable_params, non_trainable_params)

8911301 8911301 0


In [1]:
from architecture.model import load_trained_model

# Reload the best slimmed checkpoint so its size can be inspected in the next cell.
slimmed_best_path = "../model/slimmed_model_best.pth"
x = load_trained_model(slimmed_best_path)

In [2]:
# Parameter budget of the reloaded model `x`: trainable, frozen, and total counts.
trainable_params = sum(p.numel() for p in x.parameters() if p.requires_grad)
non_trainable_params = sum(
    p.numel() for p in x.parameters() if not p.requires_grad)
total_params = trainable_params + non_trainable_params
print(total_params, trainable_params, non_trainable_params)

8911301 8911301 0
