In [2]:
# --- Colab: 3D heart reconstruction (CT + MRI) with compact 3D U-Net ---
# Run in Google Colab (Runtime -> Change runtime type -> GPU)
# Pipeline in this cell: mount Drive -> load one CT and one MRI volume plus
# their label maps -> resample/normalize -> train a small 3D U-Net on those two
# samples -> threshold the CT prediction -> export predicted and ground-truth
# surfaces as OBJ meshes into `output_dir` (section 4).
# 1) Mount Google Drive
# Colab-only import: the dataset paths read below and `output_dir` written
# below all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# 2) Install required packages
# (Colab already has many packages; install or upgrade if needed)
# NOTE: `!pip` is IPython/Colab shell-escape syntax, not Python; it only works
# inside a notebook cell. Outside a notebook, install these packages beforehand.
# NOTE(review): `tqdm` is imported below but not used anywhere in this cell.
!pip install nibabel trimesh scikit-image tqdm

# 3) Imports
import os
import numpy as np
import nibabel as nib
from scipy.ndimage import zoom
from skimage import measure
import trimesh
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import BinaryCrossentropy
from tqdm import tqdm

# 4) Paths (update if different)
# Everything lives under one Drive project folder; edit `project_dir` to relocate.
project_dir = "/content/drive/MyDrive/Major_Project"
dataset_dir = os.path.join(project_dir, "MM-WHS 2017 Dataset")
ct_img_path = os.path.join(dataset_dir, "ct_train", "ct_train_1001_image.nii.gz")
ct_lbl_path = os.path.join(dataset_dir, "ct_train", "ct_train_1001_label.nii.gz")
mr_img_path = os.path.join(dataset_dir, "mr_train", "mr_train_1001_image.nii.gz")
mr_lbl_path = os.path.join(dataset_dir, "mr_train", "mr_train_1001_label.nii.gz")
output_dir = os.path.join(project_dir, "heart_model_output")
os.makedirs(output_dir, exist_ok=True)  # no-op when the folder already exists

# 5) Helper functions
def load_nifti(path):
    """Load a NIfTI file and return its voxel data as a float32 array.

    Args:
        path: Path to a ``.nii`` / ``.nii.gz`` file.

    Returns:
        ``np.ndarray`` of dtype float32 with the image's array shape. The
        affine (voxel spacing / orientation) is not returned, so callers work
        purely in voxel index space.
    """
    # Ask nibabel for float32 directly instead of building the default float64
    # array and then copying it with .astype: this avoids a temporary that is
    # twice the size of the result (significant for full-resolution CT).
    return nib.load(path).get_fdata(dtype=np.float32)

def resample_to_shape(volume, target_shape=(128,128,128), is_label=False):
    """Resize ``volume`` so that its shape becomes ``target_shape``.

    Label maps use nearest-neighbour interpolation (spline order 0) so class
    ids are never blended; intensity volumes use linear interpolation (order 1).

    Args:
        volume: Array to resize; one zoom factor is derived per axis.
        target_shape: Desired output shape, one entry per axis of ``volume``.
        is_label: True when ``volume`` holds discrete labels.

    Returns:
        The resized array from ``scipy.ndimage.zoom`` (same dtype as ``volume``).
    """
    interp_order = 0 if is_label else 1
    scale = [wanted / current for current, wanted in zip(volume.shape, target_shape)]
    resized = zoom(volume, scale, order=interp_order)
    return resized

def normalize_volume(vol, lower_pct=1, upper_pct=99):
    """Robustly rescale intensities to the range [0, 1].

    Values are first clipped to the ``lower_pct`` / ``upper_pct`` percentiles
    (1st / 99th by default) so a few extreme voxels do not dominate the
    scaling, then min-max scaled.

    Args:
        vol: Intensity array of any shape.
        lower_pct: Percentile used as the lower clipping bound.
        upper_pct: Percentile used as the upper clipping bound.

    Returns:
        Array with the same shape as ``vol`` and values in [0, 1]. Floating
        inputs keep their dtype; integer inputs are returned as float32. The
        result is all zeros when the clipped range is (numerically) constant.
    """
    vol = np.asarray(vol)
    out_dtype = vol.dtype if np.issubdtype(vol.dtype, np.floating) else np.dtype(np.float32)
    vol = vol.astype(out_dtype, copy=False)
    if vol.size == 0:
        return vol.copy()
    # np.percentile yields float64 scalars; under NumPy 2 promotion rules
    # (NEP 50) clipping a float32 array against them would silently upcast the
    # whole volume to float64. Cast the bounds to the working dtype first.
    lo, hi = np.percentile(vol, (lower_pct, upper_pct)).astype(out_dtype)
    clipped = np.clip(vol, lo, hi)
    # Scan the volume once for each extreme instead of once per use.
    vmin = clipped.min()
    span = clipped.max() - vmin
    if span < 1e-8:
        return np.zeros_like(clipped)
    return (clipped - vmin) / span

# 6) Load & preprocess (adjust target_shape if you need)
target_shape = (128, 128, 128)

print("Loading images...")
ct_img, ct_lbl, mr_img, mr_lbl = [
    load_nifti(path)
    for path in (ct_img_path, ct_lbl_path, mr_img_path, mr_lbl_path)
]

print("Resampling to", target_shape)
# Intensity volumes get linear interpolation, label maps nearest-neighbour.
ct_img_r, mr_img_r = [
    resample_to_shape(vol, target_shape, is_label=False) for vol in (ct_img, mr_img)
]
ct_lbl_r, mr_lbl_r = [
    resample_to_shape(vol, target_shape, is_label=True) for vol in (ct_lbl, mr_lbl)
]

# Images -> [0, 1] intensities; labels -> binary foreground (any non-zero id).
# Both gain a trailing channel axis for the Conv3D input.
ct_img_n, mr_img_n = [
    normalize_volume(vol)[..., np.newaxis] for vol in (ct_img_r, mr_img_r)
]
ct_lbl_b, mr_lbl_b = [
    (vol > 0).astype(np.float32)[..., np.newaxis] for vol in (ct_lbl_r, mr_lbl_r)
]

# Create a tiny dataset (demo). For real training, collect many samples.
X = np.stack([ct_img_n, mr_img_n], axis=0)  # shape (N,128,128,128,1)
y = np.stack([ct_lbl_b, mr_lbl_b], axis=0)
print("Dataset shapes:", X.shape, y.shape)

# 7) Build compact 3D U-Net
def conv_block(x, filters):
    """Apply two 3x3x3 'same'-padded ReLU convolutions, each followed by BatchNorm.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels for both convolutions.

    Returns:
        Output Keras tensor with the same spatial size as ``x``.
    """
    for _ in range(2):
        conv = layers.Conv3D(filters, 3, padding='same', activation='relu')
        x = conv(x)
        norm = layers.BatchNormalization()
        x = norm(x)
    return x

def build_unet(input_shape=(128,128,128,1), base_filters=8, depth=4):
    """Build a compact 3D U-Net for binary voxel segmentation.

    Encoder: ``depth`` conv blocks, each followed by 2x max-pooling, with the
    filter count doubling per level; then a bottleneck block. Decoder: for each
    level, a 2x transposed convolution, concatenation with the matching encoder
    output (skip connection), and a conv block.

    Args:
        input_shape: ``(D, H, W, C)`` shape of one input volume (no batch axis).
            Spatial sizes may be ``None``; known sizes must be divisible by
            ``2 ** depth`` so every upsampled tensor matches its skip tensor.
        base_filters: Filters in the first encoder block.
        depth: Number of pooling/upsampling levels (4 reproduces the original
            network exactly, including layer creation order).

    Returns:
        An uncompiled ``tf.keras.Model`` producing a single-channel sigmoid
        probability map with the input's spatial size.

    Raises:
        ValueError: If ``depth < 1`` or a known spatial size is not divisible
            by ``2 ** depth``.
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    factor = 2 ** depth
    # Pooling floors odd sizes, so a non-divisible size would only surface as a
    # cryptic shape mismatch inside `concatenate`; fail early and clearly.
    if any(s is not None and s % factor for s in input_shape[:-1]):
        raise ValueError(
            f"spatial dims of input_shape={tuple(input_shape)} must be divisible "
            f"by {factor} for depth={depth}"
        )

    inputs = layers.Input(shape=input_shape)
    x = inputs
    skips = []
    for level in range(depth):
        x = conv_block(x, base_filters * 2 ** level)
        skips.append(x)
        x = layers.MaxPool3D(2)(x)

    x = conv_block(x, base_filters * 2 ** depth)  # bottleneck

    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        x = layers.Conv3DTranspose(filters, 2, strides=2, padding='same')(x)
        x = layers.concatenate([x, skips[level]])
        x = conv_block(x, filters)

    outputs = layers.Conv3D(1, 1, activation='sigmoid')(x)
    return Model(inputs, outputs)

# Derive the network input shape from the prepared arrays so that editing
# `target_shape` in section 6 keeps the model and the data in sync.
model = build_unet(input_shape=X.shape[1:], base_filters=8)
model.summary()

# 8) Loss and compile
def dice_coef(y_true, y_pred, smooth=1e-6):
    """Soft Dice coefficient between a target mask and predicted probabilities.

    All elements of the batch are flattened into a single vector, so the score
    is one batch-global overlap ratio rather than a mean of per-sample scores.

    Args:
        y_true: Ground-truth mask; any numeric dtype (cast to ``y_pred``'s).
        y_pred: Predicted probabilities (the model's sigmoid output).
        smooth: Small constant added to numerator and denominator so that two
            empty masks score 1 instead of dividing by zero.

    Returns:
        Scalar tensor; 1 means perfect overlap.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # TensorFlow refuses to multiply mismatched dtypes (e.g. uint8 or float64
    # labels against float32 predictions), so align the target first.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.reshape(y_true, [-1])
    y_pred_f = tf.reshape(y_pred, [-1])
    intersect = tf.reduce_sum(y_true_f * y_pred_f)
    return (2. * intersect + smooth) / (tf.reduce_sum(y_true_f) + tf.reduce_sum(y_pred_f) + smooth)

def dice_loss(y_true, y_pred):
    """Return the Dice loss ``1 - dice_coef(y_true, y_pred)`` (0 = perfect overlap)."""
    overlap = dice_coef(y_true, y_pred)
    return 1.0 - overlap

bce = BinaryCrossentropy()


def combined_loss(y_true, y_pred):
    """Sum of voxel-wise binary cross-entropy and the overlap-based Dice loss."""
    voxel_term = bce(y_true, y_pred)
    overlap_term = dice_loss(y_true, y_pred)
    return voxel_term + overlap_term


model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss=combined_loss,
    metrics=[dice_coef],
)

# 9) Train (demo). For real training: increase epochs, batch size, and dataset size.
batch_size = 1
epochs = 10  # change to 100+ for real training
# Use the dataset size as the shuffle buffer so every epoch is a full uniform
# reshuffle; a fixed buffer smaller than the dataset only shuffles locally.
train_ds = (
    tf.data.Dataset.from_tensor_slices((X, y))
    .shuffle(buffer_size=len(X))
    .batch(batch_size)
)
history = model.fit(train_ds, epochs=epochs)

# 10) Predict on CT sample (index 0)
# NOTE: this sample is also in the training set, so it shows fit, not generalisation.
batch_probs = model.predict(X[0:1])   # (1, D, H, W, 1) sigmoid probabilities
pred = batch_probs[0, ..., 0]         # drop batch and channel axes
pred_bin = np.where(pred > 0.5, 1, 0).astype(np.uint8)

# 11) Mesh extraction & export as OBJ
# skimage's marching_cubes raises ValueError when `level` lies outside the
# volume's value range. That happens when a briefly trained model predicts no
# voxel (or every voxel) above 0.5, so report it instead of crashing the cell.
# Vertices are in voxel-index units of the resampled grid (no spacing given).
if pred_bin.min() == pred_bin.max():
    print("Predicted mask is empty or full; no surface to export.")
else:
    verts, faces, normals, values = measure.marching_cubes(pred_bin, level=0.5)
    mesh = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
    out_obj = os.path.join(output_dir, "heart_prediction.obj")
    mesh.export(out_obj)
    print("Saved predicted OBJ to:", out_obj)

# 12) Export ground-truth OBJ as well (nice to compare)
# Surface of the resampled binary CT label at iso-level 0.5; vertices are in
# voxel-index units of the resampled grid (marching_cubes gets no spacing).
out_obj_gt = os.path.join(output_dir, "heart_groundtruth_ct.obj")
gt_mask = ct_lbl_b[..., 0].astype(np.uint8)
verts_gt, faces_gt, normals_gt, vals_gt = measure.marching_cubes(gt_mask, level=0.5)
mesh_gt = trimesh.Trimesh(
    vertices=verts_gt,
    faces=faces_gt,
    vertex_normals=normals_gt,
)
mesh_gt.export(out_obj_gt)
print("Saved ground-truth OBJ to:", out_obj_gt)


Mounted at /content/drive
Collecting trimesh
  Downloading trimesh-4.8.3-py3-none-any.whl.metadata (18 kB)
Downloading trimesh-4.8.3-py3-none-any.whl (735 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m735.5/735.5 kB[0m [31m17.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trimesh
Successfully installed trimesh-4.8.3
Loading images...
Resampling to (128, 128, 128)
Dataset shapes: (2, 128, 128, 128, 1) (2, 128, 128, 128, 1)


Epoch 1/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m40s[0m 347ms/step - dice_coef: 0.0964 - loss: 1.6886
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 347ms/step - dice_coef: 0.1673 - loss: 1.6068
Epoch 3/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 346ms/step - dice_coef: 0.1810 - loss: 1.5757
Epoch 4/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 354ms/step - dice_coef: 0.1925 - loss: 1.5508
Epoch 5/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 344ms/step - dice_coef: 0.1213 - loss: 1.6350
Epoch 6/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 340ms/step - dice_coef: 0.1251 - loss: 1.6279
Epoch 7/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 338ms/step - dice_coef: 0.2155 - loss: 1.5025
Epoch 8/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 340ms/step - dice_coef: 0.2216 - loss: 1.4900
Epoch 9/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━