# Brain Identification 

In [10]:
import tensorflow as tf
from tensorflow.keras.models import Model
from neurite.tf import models  # Assuming the module's location
import voxelmorph.tf.losses as vtml
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from keras.callbacks import ReduceLROnPlateau

from tensorflow.keras.layers import Input
from tensorflow.keras.layers import Lambda
import neurite as ne
import sys
import nibabel as nib

import os
import tensorflow as tf

import numpy as np
from keras.callbacks import LearningRateScheduler

def step_decay(epoch, *, initial_lr=0.0001, drop=0.5, epochs_drop=10.0):
    """Return the step-decay learning rate for ``epoch``.

    The rate is ``initial_lr * drop ** (epoch // epochs_drop)``: it is
    multiplied by ``drop`` once every ``epochs_drop`` epochs.

    The tuning parameters are keyword-only on purpose. tf.keras'
    ``LearningRateScheduler`` first tries ``schedule(epoch, lr)`` and falls
    back to ``schedule(epoch)`` on ``TypeError`` (confirm for the installed
    Keras version). If these parameters were positional, the optimizer's
    current lr would silently be taken as ``initial_lr``.

    Args:
        epoch: Zero-based epoch index.
        initial_lr: Learning rate used for epochs ``0 .. epochs_drop - 1``.
        drop: Multiplicative factor applied at each step.
        epochs_drop: Number of epochs between successive drops.

    Returns:
        The learning rate for ``epoch`` as a float.
    """
    return initial_lr * (drop ** (epoch // epochs_drop))
    
# Only surface TensorFlow errors; INFO/WARNING log chatter is suppressed.
tf.get_logger().setLevel('ERROR')

import os
import random

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
num_gen = 20
nb_labels = 2             # U-Net output channels (passed to models.unet)
nb_features = 64          # U-Net feature width (passed to models.unet)
batch_size = 4            # label maps per generator batch

# Warp / bias-field hyper-parameters.
# NOTE(review): the synthesis model below is configured with its own literal
# warp values in gen_arg, so these may be unused in this notebook - confirm.
warp_max = 2.5
warp_min = .5
warp_blur_min = np.array([2, 4, 8])
warp_blur_max = warp_blur_min * 2
bias_blur_min = np.array([2, 4, 8])
bias_blur_max = bias_blur_min * 2

initial_lr = 1e-4         # Adam learning rate used when compiling the model
lr = 1e-4
lr_lin = 1e-4
nb_levels = 5             # U-Net depth
conv_size = 3             # U-Net convolution kernel size

num_epochs = 40000        # final epoch index handed to fit(epochs=...)
initial_epoch = 8000      # epoch counter to resume training from
models_dir = "models.bi.zb.0.2"
zero_background = 0.2

# Load the checkpoint written at `initial_epoch`. The weights saver names its
# files "weights_epoch_<N>.h5", so deriving the path here keeps the restored
# weights in step with the epoch counter fit() resumes from.
checkpoint_path = os.path.join(models_dir, f"weights_epoch_{initial_epoch}.h5")

import tensorflow as tf

# Run directory that receives the TensorBoard event files.
log_dir = "logs.bi.zb.0.2"

# Make only the first physical GPU visible to TensorFlow. When the machine
# has no GPU, the list is empty and the device configuration is left alone.
gpus = tf.config.experimental.list_physical_devices('GPU')
if len(gpus) > 0:
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')

# Created after the device configuration (presumably so TensorFlow's runtime
# is not initialised before set_visible_devices is called).
summary_writer = tf.summary.create_file_writer(log_dir)


from tensorflow.keras.callbacks import ModelCheckpoint

class CustomTensorBoard(tf.keras.callbacks.TensorBoard):
    """TensorBoard callback that switches to a fresh log sub-directory periodically.

    At every epoch that is a multiple of ``period``, ``log_dir`` is set to
    ``f"{base_log_dir}/epoch_{epoch}"`` and ``set_model`` is re-run so that
    TensorBoard picks up the new directory. This relies on tf.keras
    TensorBoard internals; confirm against the installed version.

    Args:
        base_log_dir: Root directory for the per-period sub-directories.
        period: Number of epochs between directory switches (default 200).
        **kwargs: Forwarded to ``tf.keras.callbacks.TensorBoard``. When no
            ``log_dir`` is supplied, ``base_log_dir`` is used. Otherwise
            anything written before the first switch would land in
            TensorBoard's default directory instead of under the run root.
    """

    def __init__(self, base_log_dir, *, period=200, **kwargs):
        kwargs.setdefault("log_dir", base_log_dir)
        super().__init__(**kwargs)
        self.base_log_dir = base_log_dir
        self.period = period

    def on_epoch_begin(self, epoch, logs=None):
        if epoch % self.period == 0:
            self.log_dir = f"{self.base_log_dir}/epoch_{epoch}"
            # Re-running set_model makes TensorBoard use the new log_dir.
            # NOTE(review): writers opened for the previous directory are not
            # explicitly closed here.
            super().set_model(self.model)
        # Preserve whatever per-epoch bookkeeping the base class performs.
        super().on_epoch_begin(epoch, logs)




class PeriodicWeightsSaver(tf.keras.callbacks.Callback):
    """Save the model's weights every ``save_freq`` epochs.

    Files are written as ``<filepath>/weights_epoch_<N>.h5``, where ``N`` is
    ``epoch + 1``, i.e. the number of epochs completed. ``filepath`` is
    created on demand, so a missing directory does not abort a long run at
    the first save.

    Args:
        filepath: Directory that receives the weight files.
        save_freq: Save interval in epochs; must be >= 1.
        **kwargs: Forwarded to ``tf.keras.callbacks.Callback``.

    Raises:
        ValueError: If ``save_freq`` is less than 1. That value would
            otherwise cause a ZeroDivisionError at the first epoch end.
    """

    def __init__(self, filepath, save_freq=200, **kwargs):
        super().__init__(**kwargs)
        if save_freq < 1:
            raise ValueError(f"save_freq must be >= 1, got {save_freq!r}")
        self.filepath = filepath
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        completed = epoch + 1  # name files by the number of epochs completed
        if completed % self.save_freq != 0:
            return
        os.makedirs(self.filepath, exist_ok=True)
        weights_path = os.path.join(self.filepath, f"weights_epoch_{completed}.h5")
        self.model.save_weights(weights_path)
        print(f"Saved weights to {weights_path}")


# Write the model weights into models_dir every 50 epochs.
weights_saver = PeriodicWeightsSaver(filepath=models_dir, save_freq=50)

# TensorBoard logging rooted at log_dir: the model graph is written, with
# epoch-level updates only and no histograms, images, embeddings or profiling.
TB_callback = CustomTensorBoard(
    base_log_dir=log_dir,
    histogram_freq=0,
    write_graph=True,
    write_images=False,
    write_steps_per_second=False,
    update_freq='epoch',
    profile_batch=0,
    embeddings_freq=0,
    embeddings_metadata=None
)


class PeriodicModelSaver(tf.keras.callbacks.Callback):
    """Save the full model (``model.save``) every ``save_freq`` epochs.

    Files are written as ``<filepath>/model_epoch_<N>.h5``, where ``N`` is
    ``epoch + 1``, i.e. the number of epochs completed. ``filepath`` is
    created on demand, so a missing directory does not abort a long run at
    the first save.

    Args:
        filepath: Directory that receives the saved models.
        save_freq: Save interval in epochs; must be >= 1.
        **kwargs: Forwarded to ``tf.keras.callbacks.Callback``.

    Raises:
        ValueError: If ``save_freq`` is less than 1. That value would
            otherwise cause a ZeroDivisionError at the first epoch end.
    """

    def __init__(self, filepath, save_freq=200, **kwargs):
        super().__init__(**kwargs)
        if save_freq < 1:
            raise ValueError(f"save_freq must be >= 1, got {save_freq!r}")
        self.filepath = filepath
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        completed = epoch + 1  # name files by the number of epochs completed
        if completed % self.save_freq != 0:
            return
        os.makedirs(self.filepath, exist_ok=True)
        save_path = os.path.join(self.filepath, f"model_epoch_{completed}.h5")
        self.model.save(save_path)
        print(f"Saved model to {save_path}")




def dice_loss(y_true, y_pred):
    """Return the negative soft Dice score, for use as a minimisation loss.

    This is exactly ``-dice_coefficient(y_true, y_pred)``. The original body
    was a verbatim copy of ``dice_coefficient`` with the sign flipped;
    delegating keeps the training loss and the reported metric from
    drifting apart.

    Args:
        y_true: Target tensor, same shape as ``y_pred``.
        y_pred: Prediction tensor of rank >= 3. Axes other than the first
            and last are summed over.

    Returns:
        Scalar tensor equal to minus the mean Dice overlap.
    """
    return -dice_coefficient(y_true, y_pred)
    
def dice_coefficient(y_true, y_pred):
    """Return the mean soft Dice overlap between ``y_true`` and ``y_pred``.

    The sums run over every axis except the first and the last, which are
    presumably batch and label channel. For each remaining entry the score
    is ``2 * sum(y_true * y_pred) / sum(y_true + y_pred)``, with 0/0 mapped
    to 0. The scores are then averaged into a scalar.
    """
    rank = len(y_pred.get_shape().as_list())
    summed_axes = list(range(1, rank - 1))

    # Prefer the TF2 name; fall back to the TF1 spelling if it is missing.
    safe_divide = getattr(tf.math, 'divide_no_nan', None) or tf.div_no_nan  # pylint: disable=no-member

    overlap = tf.reduce_sum(y_true * y_pred, summed_axes)
    total_mass = tf.reduce_sum(y_true + y_pred, summed_axes)
    return tf.reduce_mean(safe_divide(2 * overlap, total_mass))

def one_hot_encode_image(img, depth=2):
    """One-hot encode the first channel of a label image.

    ``img[..., 0]`` is cast to int32 and expanded by ``tf.one_hot`` into
    ``depth`` classes along a new trailing axis.
    """
    labels = tf.cast(tf.convert_to_tensor(img)[..., 0], tf.int32)
    return tf.one_hot(labels, depth=depth)
    
# Load the 2-D OASIS slices: a label map (slice_seg24) and an intensity image
# (slice_norm) for every subject index 0..400 whose directory provides both.
oasis_root = "neurite-oasis.2d.v1.0"
slice_seg24_list = []
slice_norm_list = []

for i in range(401):
    subject_dir = os.path.join(oasis_root, f"OASIS_OAS1_{i:04d}_MR1")
    file_path = os.path.join(subject_dir, "slice_seg24.nii.gz")
    file_path_norm = os.path.join(subject_dir, "slice_norm.nii.gz")
    # Require both files. A missing norm image previously crashed nib.load,
    # and skipping only one of the pair would misalign the two lists.
    if not (os.path.exists(file_path) and os.path.exists(file_path_norm)):
        continue

    img = nib.load(file_path).get_fdata()
    norm = nib.load(file_path_norm).get_fdata()

    # NOTE(review): not used later in the visible code; only the last
    # subject's value survives the loop.
    num_segments = len(np.unique(img).astype(int))
    slice_seg24_list.append(img)
    slice_norm_list.append(norm)

if not slice_seg24_list:
    raise FileNotFoundError(f"No OASIS slice pairs found under {oasis_root!r}")

slice_seg24_array = np.array(slice_seg24_list)
slice_norm_array = np.array(slice_norm_list)

# 80/20 split with a fixed seed. The same indices are applied to label maps
# and intensity images so the pairs stay aligned.
train_indices, test_indices = train_test_split(range(len(slice_seg24_array)), test_size=0.2, random_state=42)
slice_seg24_array_train = [slice_seg24_array[i] for i in train_indices]
slice_seg24_array_test = [slice_seg24_array[i] for i in test_indices]

slice_norm_array_train = [slice_norm_array[i] for i in train_indices]
slice_norm_array_test = [slice_norm_array[i] for i in test_indices]


def my_generator(label_maps, batch_size=1, same_subj=False, flip=False, seed=None):
    """Yield an endless stream of random label-map batches.

    Each step draws ``batch_size`` maps uniformly with replacement and yields
    ``(src, y)``. ``src`` is the stacked maps with shape
    ``(batch_size, *label_maps[0].shape)``. ``y`` is an all-zero float32
    array of that same shape, acting as a placeholder target.

    Args:
        label_maps: Non-empty sequence of equally shaped arrays.
        batch_size: Number of maps per yielded batch; must be >= 1.
        same_subj: Accepted for backward compatibility only. The previous
            implementation drew ``2 * batch_size`` maps but yielded only the
            first ``batch_size``, so this flag never changed the output.
        flip: If True, flip the whole batch along a random subset of the
            spatial axes. The same flip is applied to every map in the batch.
        seed: Optional seed for ``np.random.default_rng``, for reproducible
            streams. ``None`` keeps the previous unseeded behaviour.

    Raises:
        ValueError: On the first ``next()`` if ``label_maps`` is empty or
            ``batch_size`` is less than 1.
    """
    print(len(label_maps))
    if len(label_maps) == 0:
        raise ValueError("label_maps must contain at least one map")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")

    in_shape = label_maps[0].shape
    num_dim = len(in_shape)
    void = np.zeros((batch_size, *in_shape), dtype='float32')
    rand = np.random.default_rng(seed)

    while True:
        ind = rand.integers(len(label_maps), size=batch_size)
        src = np.stack([label_maps[i] for i in ind])

        if flip:
            axes = rand.choice(
                num_dim, size=rand.integers(num_dim + 1), replace=False, shuffle=False
            )
            # Shift by one to skip the leading batch axis.
            src = np.flip(src, axis=tuple(axes + 1))

        yield src, void.copy()
        

# Spatial size (H, W) of the 2-D slices.
# NOTE(review): the same size is also hard-coded as (160, 192, 1) in the U-Net
# and Input shapes below - keep them in sync.
in_shape = (160,192)
# Training label maps that my_generator samples from.
input_image = slice_seg24_array_train


# Arguments for neurite's label-to-image synthesis model. Input labels 0..24
# are collapsed into a binary target: 0 stays 0, every other label becomes 1.
# NOTE(review): warp_min/warp_max here are literals, not the module-level
# warp_min/warp_max settings - confirm that is intended.
gen_arg = {
    'in_shape': in_shape,
    'labels_in': [i for i in range(25)],
    'labels_out': {i: 1 if i > 0 else 0 for i in range(25)},  # label 0 -> 0, labels 1..24 -> 1
    'warp_min': 0.01,
    'warp_max': 2.5,
    'zero_background': zero_background
}




# Segmentation network with nb_labels output channels (the cell output
# reports a softmax final activation for this unet).
unet_model = models.unet(input_shape=(160, 192, 1), nb_features=nb_features, nb_labels=nb_labels, nb_levels=nb_levels, conv_size=conv_size)

# NOTE(review): `slices` and `maps` are not used in this cell.
slices= []
maps=[]




# Synthesis model: takes a label map and returns (synthetic image, target).
# The cell output shows shapes (None, 160, 192, 1) and (None, 160, 192, 2).
gen_model_1 = ne.models.labels_to_image_new(**gen_arg, id=1)
input_img = Input(shape=(160, 192,1))


generated_img, y = gen_model_1(input_img)
print(generated_img.shape,y.shape)
segmentation = unet_model(generated_img)

# End-to-end model: label map in -> U-Net prediction on the synthesized image.
# The Dice loss against the synthesized target `y` is attached via add_loss,
# so compile() receives no loss; the zero targets yielded by my_generator are
# presumably ignored by Keras in this setup - confirm.
combined_model = Model(inputs=input_img, outputs=segmentation)
combined_model.add_loss(dice_loss(y, segmentation))
combined_model.compile(optimizer=Adam(learning_rate=initial_lr))

print(len(slice_seg24_array_test))
# Endless stream of random training label-map batches (no flips).
gen  = my_generator(input_image,batch_size=batch_size, same_subj=False, flip=False)

# NOTE(review): neither value is used below - fit() hard-codes
# steps_per_epoch=10 and no validation data is passed.
steps_per_epoch = len(slice_seg24_array_train) // 1  
validation_steps = len(slice_seg24_array_test) // 1  

# NOTE(review): load_model is imported but not used in this cell.
from tensorflow.keras.models import load_model

# Resume from checkpoint_path when it exists; otherwise training starts from
# the freshly built model's weights.
if os.path.exists(checkpoint_path):
    combined_model.load_weights(checkpoint_path)
    print("Loaded weights from the checkpoint and continued training.")
else:
    print("Checkpoint file not found.")

# NOTE(review): reduce_lr is not in the callbacks passed to fit(), and it
# monitors 'val_loss' although fit() receives no validation data.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',  # Monitor validation loss for learning rate reduction
    factor=0.2,  # Reduce learning rate by a factor of 0.2 when triggered
    patience=5,  # Number of epochs with no improvement before reducing learning rate
    min_lr=1e-6,  # Minimum learning rate
    verbose=0  # 0 = quiet: no message is printed when the learning rate is reduced
)



hist = combined_model.fit(
    gen,
    epochs=num_epochs,  # Final epoch index (absolute), not a count of additional epochs
    initial_epoch=initial_epoch,  # Epoch counter to resume from
    verbose=0,  # No progress output
    steps_per_epoch=10,  # Batches drawn from the endless generator per epoch
    callbacks=[weights_saver, TB_callback]
)



using final_pred_activation softmax for unet
(None, 160, 192, 1) (None, 160, 192, 2)
73
Checkpoint file not found.
288
Saved weights to models.bi.zb.0.2/weights_epoch_8050.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8100.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8150.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8200.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8250.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8300.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8350.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8400.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8450.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8500.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8550.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8600.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8650.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8700.h5
Saved weights to models.bi.zb.0.2/weights_epoch_8750.h5
Saved weights to models.bi.zb.0.2/weights