In [1]:
# Notebook for processing data into 3D cine + segmentations

# Imports
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import glob
import numpy as np
from tqdm import tqdm
import tensorflow_mri as tfmri
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import pydicom as dicom
from unet3plusnew import *
from custom_unet_code import *
from process_utils import *


2025-06-13 14:39:23.533909: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-06-13 14:39:23.678059: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-13 14:39:23.734772: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned e

In [None]:
# Make the output directories for processed data.
# `os.makedirs` creates any missing parents (so ./processed_data itself is
# covered), and `exist_ok=True` makes this cell safe to re-run without a
# separate existence check per path.
for _out_dir in ('./processed_data/3D_cine', './processed_data/seg_ML'):
    os.makedirs(_out_dir, exist_ok=True)

In [3]:
# Applies the de-banding model to any number of slices.
def apply_debanding_model(input_im, frames=32):
    """Run the de-banding network over the first `frames` entries of `input_im`.

    The saved Keras model at ./models_final/Deband_model is loaded only to
    harvest its weights. Those weights are copied into a freshly built
    `tfmri.models.UNet3D([32, 64, 128])` graph whose input is declared as
    (None, None, None, 1), so the spatial size is not fixed.

    Args:
        input_im: indexable sequence of volumes. Each `input_im[i]` gets a
            batch axis prepended and a channel axis appended before prediction.
        frames: number of leading entries of `input_im` to process.

    Returns:
        List of `frames` prediction results, one per processed volume, in order.
    """
    source_model = tf.keras.models.load_model("./models_final/Deband_model", compile=False)
    pretrained_weights = source_model.get_weights()

    volume_in = tf.keras.Input(shape=[None, None, None, 1])
    unet_layer = tfmri.models.UNet3D([32, 64, 128], kernel_size=3, out_channels=1, use_global_residual=False)
    runner = tf.keras.Model(inputs=volume_in, outputs=unet_layer(volume_in))
    runner.set_weights(pretrained_weights)

    def _predict_one(volume):
        # Add a leading batch axis and a trailing channel axis for the 3-D U-Net.
        batched = tf.expand_dims(tf.expand_dims(volume, 0), -1)
        return runner.predict(batched, verbose=0)

    return [_predict_one(input_im[idx]) for idx in range(frames)]

# Function that applies deformations to 28-slice data.
def deformation_28(x):
    """Warp each of the first 28 slices with its own 2-D displacement field.

    Used as the body of a `tf.keras.layers.Lambda` layer.

    Args:
        x: 3-item sequence `[volume, dy, dx]`. Each item is indexed as
            `item[0, s, :, :]`, i.e. batch entry 0 is used and axis 1 is the
            slice axis (presumably (batch, slices, H, W) -- confirm with caller).

    Returns:
        float32 tensor with the 28 warped slices stacked on a new axis and a
        leading axis of size 1 added, i.e. (1, 28, H, W, 1) for 2-D slices.
    """
    warped_slices = []
    for s in range(28):
        # (H, W) -> (1, H, W, 1): dense_image_warp takes a batched NHWC image.
        image = tf.image.convert_image_dtype(
            tf.expand_dims(tf.expand_dims(x[0][0, s, :, :], -1), 0),
            tf.dtypes.float32,
        )
        # Flow of shape (H, W, 2). Channel 0 is dy (row offset) and channel 1
        # is dx (column offset), which is the channel order tfa expects.
        flow = tf.image.convert_image_dtype(
            tf.concat(
                (tf.expand_dims(x[1][0, s, :, :], -1),
                 tf.expand_dims(x[2][0, s, :, :], -1)),
                axis=-1,
            ),
            tf.dtypes.float32,
        )
        warped = tfa.image.dense_image_warp(image, flow)
        warped_slices.append(tf.squeeze(warped, 0))

    stacked = tf.image.convert_image_dtype(warped_slices, tf.dtypes.float32)
    return tf.expand_dims(stacked, axis=0)

# Applies the respiratory-correction model.
def apply_resp_model_28(input_im, frames=32):
    """Run the respiratory-correction network frame by frame.

    Builds a model in which a U-Net (`build_3d_unet_resp`, 2 output channels)
    predicts deformation fields that `deformation_28` then uses to warp the
    input. Trained weights are loaded from
    ./models_final/Resp_Correction_model/variables/variables.

    Args:
        input_im: indexable sequence; each `input_im[i]` is sliced over five
            axes and passed to `predict`. The model input is declared with
            shape (None, 256, 128, 1) per sample.
        frames: number of leading entries of `input_im` to process.

    Returns:
        `(deformations, resp_corrected)`: two lists of length `frames` holding
        the predicted deformation fields and the warped volumes, respectively.
    """
    volume_in = tf.keras.Input(shape=[None, 256, 128, 1])
    # The U-Net acts as a deformation-field generator.
    field_net = build_3d_unet_resp([None, 256, 128, 1], 2)
    fields = field_net(volume_in)
    warp_layer = tf.keras.layers.Lambda(deformation_28)
    warped = warp_layer([volume_in[:, :, :, :, 0], fields[:, :, :, :, 0], fields[:, :, :, :, 1]])
    resp_net = tf.keras.Model(inputs=volume_in, outputs=[fields, warped])
    resp_net.load_weights('./models_final/Resp_Correction_model/variables/variables')

    deformations, resp_corrected = [], []
    for idx in range(frames):
        field_pred, warped_pred = resp_net.predict(input_im[idx][:, :, :, :, :], verbose=0)
        resp_corrected.append(warped_pred)
        deformations.append(field_pred)
    return deformations, resp_corrected

# Applies the super-resolution model.
def apply_SR_model(input_im, frames=32):
    """Run the super-resolution network over the first `frames` entries of `input_im`.

    Weights are taken from the saved end-to-end model at
    ./models_final/E2E_SR_model and copied into a stand-alone
    `build_3d_unet(num_classes=1)` graph with input shape (None, None, None, 1).

    Args:
        input_im: indexable sequence; each `input_im[i]` is passed to
            `predict` unchanged.
        frames: number of leading entries of `input_im` to process.

    Returns:
        List of `frames` prediction results, in order.
    """
    e2e = tf.keras.models.load_model("./models_final/E2E_SR_model", compile=False)
    # Only the weight tensors from index 22 onward go into the SR network.
    # NOTE(review): presumably the first 22 belong to earlier stages of the
    # end-to-end model -- confirm against how E2E_SR_model was built.
    sr_weights = e2e.get_weights()[22:]

    volume_in = tf.keras.Input(shape=[None, None, None, 1])
    sr_net = build_3d_unet(input_shape=(None, None, None, 1), num_classes=1)
    sr_model = tf.keras.Model(inputs=volume_in, outputs=sr_net(volume_in))
    sr_model.set_weights(sr_weights)

    return [sr_model.predict(input_im[idx], verbose=0) for idx in range(frames)]

# Applies the segmentation model.
def apply_seg_model(input_im, frames=32):
    """Segment the first `frames` entries of `input_im` into 7 classes.

    Weights come from the saved model at ./models_final/Segmentation_model and
    are copied into a freshly built 3-D UNet3+ (`unet3plus`, softmax output).

    Args:
        input_im: indexable sequence; each `input_im[i]` is passed to
            `predict` unchanged.
        frames: number of leading entries of `input_im` to process.

    Returns:
        List of `frames` results of `get_one_hot(prediction, 7)`, in order.
    """
    saved_model = tf.keras.models.load_model("./models_final/Segmentation_model", compile=False)
    pretrained_weights = saved_model.get_weights()

    number_of_seg = 7

    volume_in = tf.keras.Input(shape=[None, None, None, 1])
    unet3 = unet3plus(
        volume_in,
        filters=[32, 64, 128],
        rank=3,                 # spatial dimensionality of the convolutions
        out_channels=number_of_seg,
        add_dropout=0,          # 1 or 0 to add dropout
        dropout_rate=0.3,
        kernel_size=3,
        encoder_block_depth=2,
        decoder_block_depth=1,
        pool_size=2,            # int, or a tuple for per-dimension pooling
        skip_type='encoder',
        batch_norm=1,
        skip_batch_norm=1,
        activation=tf.keras.layers.LeakyReLU(alpha=0.01),
        out_activation='softmax',
        CGM=0,
        deep_supervision=0,     # 1 or 0 to add deep supervision
    )
    seg_net = tf.keras.Model(inputs=volume_in, outputs=unet3.outputs())
    seg_net.set_weights(pretrained_weights)

    return [
        get_one_hot(seg_net.predict(input_im[idx], verbose=0), number_of_seg)
        for idx in range(frames)
    ]

# Reads in the example real-time (RT) sagittal stack.
def load_data_samples(number_of_scans=32):
    """Load the RT sagittal DICOM stack, grouped by time frame and slice.

    Every file in every sub-folder of ./raw_data/RT_Stack/ is read exactly
    once. A file whose ``InstanceNumber`` equals ``t + 1`` (for ``t`` in
    ``range(number_of_scans)``) contributes its ``pixel_array`` to time frame
    ``t``. Within each frame, images are ordered by ``SliceLocation``.

    Args:
        number_of_scans: number of time frames (instance numbers 1..N) to load.

    Returns:
        ``np.array`` built from the per-frame image lists. Its shape is
        (number_of_scans, slices, H, W) when every frame has the same number of
        equally sized slices.

    Raises:
        FileNotFoundError: if ./raw_data/RT_Stack/ does not exist. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    stack_dir = "./raw_data/RT_Stack/"
    if not os.path.exists(stack_dir):
        raise FileNotFoundError(f"Error with file path: {stack_dir!r} does not exist.")

    # One bucket of (slice_location, pixels) pairs per requested time frame.
    frames = [[] for _ in range(number_of_scans)]
    wanted_instances = range(1, number_of_scans + 1)

    # Read each file once and route it to its frame. The previous version
    # re-read the entire stack once per time frame.
    for folder in sorted(glob.glob(stack_dir + "*")):
        for file in glob.glob(folder + '/*'):
            ds = dicom.dcmread(file)
            instance = ds.InstanceNumber
            # `in range(...)` uses equality for non-plain-int values, matching
            # the original `InstanceNumber == i + 1` test.
            if instance in wanted_instances:
                frames[int(instance) - 1].append((ds.SliceLocation, ds.pixel_array))

    # Sort on the location only. A plain tuple sort would fall back to
    # comparing the pixel arrays when two locations are equal, which raises
    # ValueError.
    sag_volumes = [
        [pixels for _, pixels in sorted(frame, key=lambda pair: pair[0])]
        for frame in frames
    ]
    return np.array(sag_volumes)

In [None]:
time_steps = 32  # Number of cine time frames to load and process.
sag_vols = load_data_samples(number_of_scans=time_steps)  # raw per-frame slice stack

print('Cropping...')
sag_vols = np.asarray(sag_vols)
sag_vols_cropped = []
for t in range(time_steps):
    resized_slices = [resize(sag_vols[t, k, :, :], 256, 128) for k in range(sag_vols.shape[1])]
    # dstack puts the slices on the last axis; move that axis to the front so
    # each entry is ordered (slice, ...) like the original double swapaxes.
    sag_vols_cropped.append(np.moveaxis(np.dstack(resized_slices), -1, 0))

sag_vols_cropped = norm(sag_vols_cropped)

print('De-banding...')
debanded = norm(apply_debanding_model(sag_vols_cropped, frames=time_steps))

print('Resp-correction...')
def_fields, resp_cor = apply_resp_model_28(debanded, frames=time_steps)
resp_cor = norm(resp_cor)

print('SR...')
super_resed_E2E = norm(apply_SR_model(resp_cor, frames=time_steps))

for t in range(time_steps):
    np.save(f'./processed_data/3D_cine/3D_cine_{t}.npy', super_resed_E2E[t])

In [None]:

print('Segmenting...')

# Re-load the saved 3D cine frames from disk. This cell still relies on
# `time_steps` from the previous cell.
clahe = []
for i in range(time_steps):
    sr = np.load(f'./processed_data/3D_cine/3D_cine_{i}.npy')
    # Take batch 0 / channel 0 for CLAHE, then re-add a trailing channel axis
    # and a leading batch axis for the segmentation model.
    clahe.append(tf.expand_dims(tf.expand_dims(apply_clahe(sr[0, :, :, :, 0]), -1), 0))

# Pass `frames` explicitly, as the other stages do. Relying on the default
# (32) raises IndexError whenever time_steps != 32.
E2E_seg = apply_seg_model(clahe, frames=time_steps)

for i in range(time_steps):
    seg_rm_add = keep_largest_component_add(E2E_seg[i][0, ...])

    # Collapse the per-class channels into one label map: channel k
    # (k = 1..6) contributes label value k. Channel 0 is skipped
    # (presumably background -- confirm).
    final_seg = np.zeros_like(E2E_seg[i][..., 0])
    for label in range(1, 7):
        final_seg = final_seg + seg_rm_add[..., label] * label
    np.save(f'./processed_data/seg_ML/seg_{i}.npy', final_seg)