In [11]:
# ============================================================
# BLOCK 1 — IMPORTS & GLOBAL CONFIG
# ============================================================

import os
import numpy as np
import torch
import torch.nn as nn
import nibabel as nib
import cv2
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torchvision import transforms

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

Using device: cuda


In [12]:
# ============================================================
# BLOCK 2 — UTILITY FUNCTIONS
# ============================================================

def load_nifti(path):
    return nib.load(path).get_fdata()

def normalize(img):
    img = img.astype(np.float32)
    return (img - img.mean()) / (img.std() + 1e-6)

def sigmoid(x):
    return 1 / (1 + np.exp(-x))

In [13]:
# ============================================================
# BLOCK 3 — SEGMENTATION MODEL
# ============================================================

class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU())
        self.decoder = nn.Sequential(
            nn.Conv2d(32, 16, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, out_channels, 1))

    def forward(self, x):
        x = self.encoder(x)
        return self.decoder(x)

In [17]:
# ============================================================
# VERIFY SEGMENTATION MODEL ARCHITECTURE
# ============================================================

print(seg_model)
print("\nState dict keys:")
print(list(seg_model.state_dict().keys()))


UNet(
  (encoder): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
  )
  (decoder): Sequential(
    (0): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

State dict keys:
['encoder.0.weight', 'encoder.0.bias', 'encoder.2.weight', 'encoder.2.bias', 'decoder.0.weight', 'decoder.0.bias', 'decoder.2.weight', 'decoder.2.bias']


In [16]:
# ============================================================
# BLOCK 4 — LOAD SEGMENTATION MODEL (LOCKED)
# ============================================================

import torch
import os

SEG_MODEL_PATH = "checkpoints/segmentation_model.pt"

assert os.path.exists(SEG_MODEL_PATH), \
    f"Segmentation model missing at {SEG_MODEL_PATH}"

seg_model = UNet().to(DEVICE)
seg_model.load_state_dict(
    torch.load(
        SEG_MODEL_PATH,
        map_location=DEVICE,
        weights_only=True
    )
)
seg_model.eval()

print("Segmentation model loaded ✅")


RuntimeError: Error(s) in loading state_dict for UNet:
	Missing key(s) in state_dict: "encoder.0.weight", "encoder.0.bias", "encoder.2.weight", "encoder.2.bias", "decoder.0.weight", "decoder.0.bias", "decoder.2.weight", "decoder.2.bias". 
	Unexpected key(s) in state_dict: "conv.weight", "conv.bias". 

In [9]:
import os
os.listdir("checkpoints")


['segmentation_model.pt']