In [None]:
import os
import nibabel as nib
import numpy as np
import cv2
import matplotlib.pyplot as plt
import random
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Conv2DTranspose, concatenate, Dropout

In [None]:
def load_dataset_by_patient(img_dir, mask_dir, patient_ids):
    """Load the image/mask PNG pairs that belong to the given patients.

    A file belongs to patient ``pid`` when its name starts with ``f"{pid}_"``.
    The mask of an image is the file with the same name inside ``mask_dir``.

    Args:
        img_dir: Directory containing the grayscale image PNGs.
        mask_dir: Directory containing the mask PNGs (same file names).
        patient_ids: Iterable of patient ID strings, e.g. ``["p243", "p157"]``.

    Returns:
        Tuple ``(images, masks)`` of NumPy arrays with a trailing channel axis
        appended. Pixel values are divided by 255.0. Slices are ordered by
        patient (in the order given) and then by sorted file name.

    Raises:
        FileNotFoundError: If an image or its mask is missing or unreadable.
    """
    # List the directory once instead of once per patient.
    filenames = sorted(os.listdir(img_dir))

    images, masks = [], []
    for pid in patient_ids:
        # Match on "<pid>_" so that e.g. "p24" does not also pick up "p243_...".
        prefix = f"{pid}_"
        for fname in filenames:
            if not fname.startswith(prefix):
                continue

            img_path = os.path.join(img_dir, fname)
            mask_path = os.path.join(mask_dir, fname)
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            # cv2.imread reports failure by returning None instead of raising.
            if img is None:
                raise FileNotFoundError(f"Missing or unreadable image: {img_path}")
            if mask is None:
                raise FileNotFoundError(f"Missing or unreadable mask: {mask_path}")

            images.append(img / 255.0)
            masks.append(mask / 255.0)

    # Append the channel axis expected by the network input.
    images = np.array(images)[..., np.newaxis]
    masks = np.array(masks)[..., np.newaxis]

    return images, masks

def process_patient(ct_path, mask_path, patient_id,
                    out_img_dir="images",
                    out_mask_dir="masks",
                    target_size=(256, 256),
                    intensity_window=(-200.0, 500.0)):
    """Export the labelled slices of one patient as pairs of 8-bit PNG files.

    The CT and mask volumes are read with nibabel and walked along their third
    array axis. Slices whose mask has no positive voxel are skipped. Each kept
    slice is written as ``<patient_id>_slice_<index>.png`` in both output
    directories.

    Args:
        ct_path: Path of the CT volume (NIfTI).
        mask_path: Path of the mask volume (NIfTI), same shape as the CT.
        patient_id: Prefix used for the output file names.
        out_img_dir: Directory that receives the CT slices.
        out_mask_dir: Directory that receives the mask slices.
        target_size: Output size, passed to ``cv2.resize`` as ``dsize``.
        intensity_window: ``(low, high)`` range of CT values mapped linearly to
            0-255; values outside are clipped. Expressed in the units of the
            volume (presumably Hounsfield units - confirm for your data, and
            tune the window to the structure of interest).

    Raises:
        ValueError: If the window is empty or the two volumes differ in shape.
        OSError: If a PNG cannot be written.
    """
    window_low, window_high = intensity_window
    if window_high <= window_low:
        raise ValueError(
            f"intensity_window must be (low, high) with low < high, got {intensity_window}"
        )

    os.makedirs(out_img_dir, exist_ok=True)
    os.makedirs(out_mask_dir, exist_ok=True)

    ct = nib.load(ct_path).get_fdata()
    mask = nib.load(mask_path).get_fdata()
    if ct.shape != mask.shape:
        raise ValueError(
            f"CT shape {ct.shape} and mask shape {mask.shape} differ for {patient_id}"
        )

    for i in range(ct.shape[2]):
        foreground = mask[:, :, i] > 0
        if not foreground.any():
            continue

        # Convert to uint8 explicitly. Handing the raw float slice to imwrite
        # leaves the 8-bit conversion to OpenCV, which saturates every value
        # outside 0-255 and so discards most of the CT intensity range.
        ct_slice = cv2.resize(ct[:, :, i], target_size)
        ct_slice = np.clip(ct_slice, window_low, window_high)
        ct_slice = (ct_slice - window_low) / (window_high - window_low) * 255.0
        ct_slice = np.rint(ct_slice).astype(np.uint8)

        # Binarise before resizing and use nearest-neighbour so the label is
        # not blurred; thresholding a bilinear result at > 0 grows the mask.
        mask_slice = cv2.resize(
            foreground.astype(np.uint8),
            target_size,
            interpolation=cv2.INTER_NEAREST,
        ) * 255

        fname = f"{patient_id}_slice_{i:03d}.png"
        for out_dir, data in ((out_img_dir, ct_slice), (out_mask_dir, mask_slice)):
            out_path = os.path.join(out_dir, fname)
            # cv2.imwrite returns False on failure instead of raising.
            if not cv2.imwrite(out_path, data):
                raise OSError(f"Could not write {out_path}")


Define the patient dataset: the CT and mask file paths for each patient ID.

In [None]:
# Root folder holding one "Burdeus_<number>" sub-folder per patient. Set the
# LA_DATA_ROOT environment variable to run the notebook on another machine.
DATA_ROOT = os.environ.get(
    "LA_DATA_ROOT",
    r"C:\Users\u163680\Documents\SlicerDICOMDatabase",
)

# Patient numbers as they appear in the folder and file names.
PATIENT_NUMBERS = ["243", "157", "166", "274"]

# Maps patient ID -> (CT volume path, mask volume path).
patients = {
    f"p{number}": (
        os.path.join(DATA_ROOT, f"Burdeus_{number}", f"CT_{number}.nii.gz"),
        os.path.join(DATA_ROOT, f"Burdeus_{number}", f"Mask_LA_{number}.nii.gz"),
    )
    for number in PATIENT_NUMBERS
}


Build the dataset.

In [None]:
# Export the labelled slices of every configured patient as PNG files.
for pid in patients:
    ct_path, mask_path = patients[pid]
    process_patient(ct_path, mask_path, pid)


Define the training and validation patients.

In [None]:
train_patients = ["p243", "p157", "p166"]
val_patients   = ["p274"]

# Splitting by patient only prevents leakage if no patient is in both lists.
overlap = set(train_patients) & set(val_patients)
assert not overlap, f"Patients in both train and validation splits: {sorted(overlap)}"

X_train, y_train = load_dataset_by_patient("images", "masks", train_patients)
X_val, y_val     = load_dataset_by_patient("images", "masks", val_patients)

# Fail here with a clear message rather than deep inside model.fit.
assert len(X_train) > 0, "No training slices found - run the dataset-building cell first."
assert len(X_val) > 0, "No validation slices found - run the dataset-building cell first."
assert X_train.shape == y_train.shape, "Training images and masks differ in shape."
assert X_val.shape == y_val.shape, "Validation images and masks differ in shape."

print(X_train.shape, X_val.shape)


In [None]:
def simple_unet_model(input_size=(256, 256, 1)):
    """Build a small U-Net with three down-sampling and three up-sampling stages.

    Args:
        input_size: Shape of one input sample, passed to ``Input``.

    Returns:
        An uncompiled ``Model`` ending in a 1-filter, 1x1 sigmoid convolution.
    """
    def double_conv(tensor, filters):
        # Two stacked 3x3 ReLU convolutions with 'same' padding.
        for _ in range(2):
            tensor = Conv2D(filters, (3, 3), activation='relu', padding='same')(tensor)
        return tensor

    inputs = Input(input_size)

    # --- ENCODER (contracting path) ---
    # Keep each stage's pre-pooling output for the skip connections.
    skips = []
    x = inputs
    for filters in (16, 32, 64):
        x = double_conv(x, filters)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)

    # --- BOTTLENECK (deepest part) ---
    x = double_conv(x, 128)
    x = Dropout(0.2)(x)  # to limit overfitting when data is scarce

    # --- DECODER (expanding path with skip connections) ---
    # Each up-sampling stage is paired with the encoder stage of equal width,
    # deepest first.
    for filters, skip in zip((64, 32, 16), reversed(skips)):
        x = Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same')(x)
        x = concatenate([x, skip])
        x = double_conv(x, filters)

    # --- OUTPUT ---
    # Sigmoid: one foreground probability per pixel.
    outputs = Conv2D(1, (1, 1), activation='sigmoid')(x)

    return Model(inputs=[inputs], outputs=[outputs])

# Definim el dice
def dice_coef(y_true, y_pred, smooth=1):
    """Soft Dice coefficient computed over all elements of the inputs at once.

    Both inputs are flattened, so the score is a single value for the whole
    batch rather than a per-sample average. Predictions are used as given
    (not thresholded).

    Args:
        y_true: Ground-truth mask values.
        y_pred: Predicted values, same number of elements as ``y_true``.
        smooth: Constant added to numerator and denominator so the ratio is
            defined when both inputs sum to zero.

    Returns:
        Scalar float32 tensor: ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    # Cast both sides: a direct call with float64 NumPy masks and float32
    # predictions would otherwise fail on the mixed-dtype multiplication.
    y_true_f = tf.reshape(tf.cast(y_true, tf.float32), [-1])
    y_pred_f = tf.reshape(tf.cast(y_pred, tf.float32), [-1])
    intersection = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (
        tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth
    )

# Seed NumPy and TensorFlow so that repeated runs are more comparable.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Build the model
model = simple_unet_model()
model.summary()

# Compile the model. The metrics list previously contained a blank string
# (' '), which is not a valid metric identifier; 'accuracy' is used instead.
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy', dice_coef])

# Train the model and keep its history so it can be plotted afterwards.
history = model.fit(X_train, y_train, validation_data=(X_val, y_val), epochs=10, batch_size=8)

In [None]:
def plot_history(history, metric="dice_coef", metric_label="Dice"):
    """Plot training and validation curves for the loss and for one metric.

    Args:
        history: Object returned by ``model.fit``; its ``history`` attribute
            is read as a dict of per-epoch value lists.
        metric: Key of the metric to plot. It must match the name the metric
            was logged under; ``f"val_{metric}"`` is read for validation.
        metric_label: Human-readable metric name used in labels and titles.

    Raises:
        KeyError: If one of the required series is absent from the history.
    """
    records = history.history
    required = ['loss', 'val_loss', metric, f'val_{metric}']
    missing = [key for key in required if key not in records]
    if missing:
        raise KeyError(
            f"History is missing {missing}; available keys: {sorted(records)}"
        )

    loss = records['loss']
    # Number epochs from 1 so the x axis reads like the training log.
    epochs = range(1, len(loss) + 1)

    fig, (ax_loss, ax_metric) = plt.subplots(1, 2, figsize=(12, 5))

    # Loss curves
    ax_loss.plot(epochs, loss, 'r', label='Training loss')
    ax_loss.plot(epochs, records['val_loss'], 'b', label='Validation loss')
    ax_loss.set_title('Training and Validation Loss')
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.legend()

    # Metric curves
    ax_metric.plot(epochs, records[metric], 'r', label=f'Training {metric_label}')
    ax_metric.plot(epochs, records[f'val_{metric}'], 'b', label=f'Validation {metric_label}')
    ax_metric.set_title(f'Training and Validation {metric_label} Score')
    ax_metric.set_xlabel('Epoch')
    ax_metric.set_ylabel(metric_label)
    ax_metric.legend()

    plt.show()

plot_history(history)

In [None]:
preds = model.predict(X_val)

# Pick a random slice from the validation set.
i = random.randint(0, len(X_val)-1)

# Panels, left to right: input image, ground-truth mask, thresholded prediction.
panels = [
    ("CT", X_val[i].squeeze()),
    ("Mask real", y_val[i].squeeze()),
    ("Predicció del model", preds[i].squeeze() > 0.5),
]

plt.figure(figsize=(12,4))
for position, (panel_title, panel_image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.title(panel_title)
    plt.imshow(panel_image, cmap="gray")
    plt.axis("off")

plt.show()


In [None]:
# Overlay the thresholded prediction on the slice chosen in the previous cell.
ct_image = X_val[i].squeeze()
predicted_mask = preds[i].squeeze() > 0.5

plt.figure(figsize=(6,6))
plt.imshow(ct_image, cmap="gray")
plt.imshow(predicted_mask, cmap="jet", alpha=0.4)
plt.title("Predicció superposada sobre TAC")
plt.axis("off")
plt.show()

In [None]:
# Show up to five validation slices; the bound keeps the loop from indexing
# past the end when the validation set has fewer than five slices.
for i in range(min(5, len(X_val))):
    plt.figure(figsize=(12,4))

    # Input image
    plt.subplot(1,3,1)
    plt.imshow(X_val[i].squeeze(), cmap="gray")
    plt.title("CT")
    plt.axis("off")

    # Ground-truth mask
    plt.subplot(1,3,2)
    plt.imshow(y_val[i].squeeze(), cmap="gray")
    plt.title("Mask real")
    plt.axis("off")

    # Model prediction, thresholded at 0.5
    plt.subplot(1,3,3)
    plt.imshow(preds[i].squeeze() > 0.5, cmap="gray")
    plt.title("Predicció")
    plt.axis("off")

    plt.show()
