In [None]:
%load_ext autoreload
%autoreload 2


import os
import datetime
from glob import glob
from tqdm import tqdm

import cv2
import numpy as np
import mlflow
import matplotlib.pyplot as plt
import albumentations as A

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers,backend as K
from tensorflow.keras.metrics import Recall,Precision
import  tensorflow_addons as tfa
import logging

import one_ring
from one_ring.config import get_config
from one_ring.data import get_data_loader,get_camvid_data_loader
from one_ring.transformers import Transformer
from one_ring.models import Unet,DeepLabV3Plus,AttUnet
from one_ring.losses import FocalTverskyLoss, DiceLoss,BASNetHybridLoss,JaccardLoss,LogCoshDiceLoss,ComboLoss,BoundaryDoULoss
from one_ring.train import Trainer
from one_ring.callbacks import get_callbacks
from one_ring.losses import FocalTverskyLoss, DiceLoss,LogCoshDiceLoss,binary_focal_loss,categorical_focal_loss,FocalLoss,sym_unified_focal_loss,SymmetricUnifiedFocalLoss
from one_ring.metrics import DiceScore,JaccardScore
from one_ring.callbacks import ORLearningRateCallback
from one_ring.scheduler import ORLearningRateScheduler
from one_ring.utils import generate_overlay_image,calculate_confusion_matrix_and_report,plot_history_dict

print('tensorflow version :',tf.__version__)

import warnings

warnings.filterwarnings("always")
logging.getLogger('tensorflow').setLevel(logging.ERROR)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  

In [2]:
# # create a dataset
    # X = np.random.normal(1,2,(100,224,224,3)).astype(np.float32)
    # y = np.random.normal(1,2,(100,224,224,12)).astype(np.float32)



    # # create a dataset
    # dataset = tf.data.Dataset.from_tensor_slices((X,y))

    # dataset = dataset.batch(10)

    # # split the dataset
    # train_dataset = dataset.take(70)
    # val_dataset = dataset.skip(70)


    # # create most simple  model
    # model = keras.Sequential()
    # model.add(layers.Input(shape=(224,224,3)))
    # model.add(layers.Conv2D(1,3,activation='relu',padding='same'))

    # model.summary()


    # model.compile(optimizer='adam',loss=BoundaryDoULoss(1),metrics=[DiceScore()])
    # model.fit(train_dataset,epochs=20)


In [3]:
# def log_albumentation(cfg,prefix):
   
    #     # assume an albumentations Compose config
    #     tr_cfg = cfg["transform"]["transforms"]

    #     for t in tr_cfg:
    #         # Use the transformation class name as part of the parameter name
    #         class_name = t['__class_fullname__'].split('.')[-1]  # Extracts the class name without the full module path
    #         param_name = f"{prefix}_{class_name}"
            
    #         # Prepare a dictionary with the parameters, excluding '__class_fullname__'
    #         params = {k: v for k, v in t.items() if k != '__class_fullname__'}
            
    #         # Log the parameters of each transformation as a separate parameter, converting the dictionary to a JSON string
    #         mlflow.log_param(param_name,params)


    # prefix = "aug"
    # # Log the augmentation type
    # mlflow.log_param(f"{prefix}_type", cfg.aug_type)

    # # for train

    # log_albumentation(cfg["train"],prefix="aug_train")
    # log_albumentation(cfg["test"],prefix="aug_test")


    # For the complex nested structure like cfg.train, consider logging key aspects or entire structure as a JSON string
    # Given the request to adjust, we're focusing on the transformations specifically

In [43]:
# Load the spinal-cord experiment config and keep its augmentation section as `cfg`.
# In the visible cells `cfg` is only referenced by the commented-out MLflow
# augmentation-logging sketch above; the next cell reloads `config` anyway.
config = get_config(config_filename="spinal_cord")
cfg = config["augmentation"]

In [None]:

# Reload the experiment configuration and build the train/validation loaders.
config = get_config(config_filename="spinal_cord")
train_data_loader, val_data_loader = get_data_loader(config.data, train_data=True, val_data=True, test_data=False)

# Side length every sample is resized to (first entry of the configured image_size).
IM_SIZE = config.data["image_size"][0]

# Augmentation hyper-parameters; also forwarded to the Trainer via `tracing_object`.
aug_config = {
    "aug_prob": 0.1,
    # "random_contrast_limit": 0.4,
    # "random_brightness_limit": 0.3,
    "rotate_limit": 10,
}
tracing_object = {"mlflow": {"augmentation": aug_config}}

# Lower bound for RandomSizedCrop's crop height. The upper bound is the resized
# image size (IM_SIZE) rather than a hard-coded 224, and the lower bound is
# clamped to it, so the pipeline stays valid if `image_size` changes in the config.
CROP_MIN_HEIGHT = 180

train_transforms = A.Compose(
    [
        A.Resize(IM_SIZE, IM_SIZE),
        # A.GaussNoise(var_limit=(10.0, 50.0), p=aug_config["aug_prob"]),
        # A.CLAHE(p=aug_config["aug_prob"]),
        # A.RandomBrightnessContrast(p=aug_config["aug_prob"], brightness_limit=aug_config["random_brightness_limit"], contrast_limit=aug_config["random_contrast_limit"]),
        # A.RandomGamma(p=aug_config["aug_prob"]),
        A.HorizontalFlip(p=aug_config["aug_prob"]),
        A.VerticalFlip(p=aug_config["aug_prob"]),
        A.Rotate(limit=aug_config["rotate_limit"], p=aug_config["aug_prob"]),
        A.RandomSizedCrop(
            min_max_height=(min(CROP_MIN_HEIGHT, IM_SIZE), IM_SIZE),
            height=IM_SIZE,
            width=IM_SIZE,
            p=aug_config["aug_prob"],
        ),
    ]
)
# Validation transforms only resize; no augmentation.
test_transforms = A.Compose(
    [
        A.Resize(IM_SIZE, IM_SIZE),
        # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

tr_transforms_object = Transformer(config.augmentation, "train", train_transforms)
ts_transforms_object = Transformer(config.augmentation, "test", test_transforms)
train_dataset = train_data_loader.load_data(transform_func=tr_transforms_object)
val_dataset = val_data_loader.load_data(transform_func=ts_transforms_object, shuffle=False)

# Callbacks keyed by name; the Trainer cell passes list(callbacks.values()).
# callbacks = get_callbacks(config.callbacks)
callbacks = {}

# Learning-rate schedule built from config.train.lr_scheduler.
steps_per_epoch = len(train_dataset)
lr_schedule = ORLearningRateScheduler(
    strategy=config.train.lr_scheduler["name"],
    total_epochs=config.train.epochs,
    steps_per_epoch=steps_per_epoch,
    **config.train.lr_scheduler["params"]
).get()
callbacks["lr_sch"] = ORLearningRateCallback(lr_schedule)

# One TensorBoard directory per run, stamped with the start time.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = f"board_logs/{config.train.experiment_name}/{run_stamp}"
callbacks["tensorboard"] = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir, histogram_freq=1, write_images=True, write_graph=True, update_freq='epoch'
)  # , profile_batch='500,520')


In [6]:
# Loss selection: `loss_name` picks a class from `loss_dict`; `params` are its
# constructor kwargs (and are logged to MLflow in a later cell).
loss_name = "boundary_dou_loss"
params = {
    "gamma": 4/3,
    "alpha": 0.4,
    "loss_weight": 0.4
}

# Reset Keras global state before the model is built in the next cell.
tf.keras.backend.clear_session()

# Registry of selectable loss classes (all imported in the setup cell).
loss_dict = {
    "focal_loss": FocalLoss,
    "dice_loss": DiceLoss,
    "jaccard_loss": JaccardLoss,
    "log_cosh_dice_loss": LogCoshDiceLoss,
    "focal_tversky_loss": FocalTverskyLoss,
    "symmetric_unified_focal_loss": SymmetricUnifiedFocalLoss,
    "boundary_dou_loss": BoundaryDoULoss,
    "basnet_hybrid_loss": BASNetHybridLoss,
    "combo_loss": ComboLoss,
}

# Metric classes; a fresh instance of each is created below.
metrics_dict = {
    "recall": Recall,
    "precision": Precision,
    "jaccard_score": JaccardScore,
}

if loss_name not in loss_dict:
    raise ValueError(f"Unknown loss_name {loss_name!r}; expected one of {sorted(loss_dict)}")
loss = loss_dict[loss_name](**params)
metrics = [metric_cls() for metric_cls in metrics_dict.values()]

In [None]:
# Build the segmentation network. Alternatives tried in this notebook:
#   AttUnet(**config.model).build_model()
#   Unet(**config.model).build_model()
model = DeepLabV3Plus(**config.model).build_model()

# Optional inspection helpers:
#   model.summary()
#   keras.utils.plot_model(model, show_shapes=True, to_file=f'{model.name}_model.png')

# Callbacks are handed over as a list (the values of the name -> callback dict).
trainer = Trainer(
    config,
    model,
    train_dataset,
    val_dataset,
    callbacks=list(callbacks.values()),
    metrics=metrics,
    loss=loss,
    tracing_object=tracing_object,
)

In [None]:
# Restore the trainer from a previous run directory before training.
# NOTE(review): hard-coded timestamped path. What `setup_components=True` restores
# (optimizer/loss/metrics?) is not visible here — confirm against Trainer.load.
trainer.load("models/20240908202630",setup_components = True)

In [None]:
# Train the model; the value returned by fit() is kept as `history`.
# Debug helpers for inspecting how the trainer was configured:
# #trainer.compile()
# print(trainer.loss)
# print(trainer.optimizer)
# print(trainer.callbacks)
# print(trainer.metrics)
history = trainer.fit()

In [None]:
# Wrap up the training run (presumably saves the model / metadata — confirm in Trainer).
trainer.finalize_training()

In [None]:
# Display the trainer's save path (reloaded from in a following cell).
trainer.save_path

In [None]:
# Reload from the trainer's own save path to check the save/load round trip.
trainer.load(trainer.save_path)

In [None]:
# Display the metadata the trainer read during load().
trainer.loaded_metadata

In [None]:
# Display the trainer's current metadata (compare with loaded_metadata above).
trainer.metadata

In [None]:
# Load another run's exported SavedModel directly with Keras, bypassing Trainer.load.
# NOTE(review): if that model was compiled with one_ring custom loss/metric classes,
# loading without `custom_objects=` (or with `compile=False`) may fail — confirm.
tf.keras.models.load_model("models/20240908190040/saved_model")

In [None]:
# Load the same earlier run through the Trainer API instead (hard-coded run directory).
trainer.load("models/20240908190040")

In [None]:
# Save through the trainer's saver to the relative path "test".
# NOTE(review): a later cell calls model.save("test") on the same path; use distinct paths.
trainer.saver.save("test")

In [None]:
# Usage sketch for the ModelSaver API (save / load / update_model).
# It references names that are not defined or imported in the visible cells
# (`ModelSaver`, `processors`, `custom_objects`, `retrained_model`). Run
# unguarded, it raises NameError and breaks "Restart & Run All". Set the flag to
# True only after those names exist.
RUN_MODEL_SAVER_EXAMPLE = False

if RUN_MODEL_SAVER_EXAMPLE:
    # Saving a model
    saver = ModelSaver(model, config, processors)
    saver.save(
        path="model_directory",
        train_ds=train_dataset,
        val_ds=val_dataset,
        custom_objects=custom_objects,
        additional_metadata={"training_iteration": 1, "best_val_loss": 0.1}
    )

    # Loading a model
    loaded_data = saver.load(
        path="model_directory",
        compile=True,
        custom_objects=custom_objects
    )
    loaded_model = loaded_data['model']
    loaded_metadata = loaded_data['metadata']
    loaded_processors = loaded_data['processors']

    # Updating a model after retraining
    saver.update_model(
        new_model=retrained_model,
        path="model_directory",
        additional_metadata={"training_iteration": 2, "best_val_loss": 0.05}
    )

In [16]:
# Pull the Keras model out of the trainer for the ad-hoc export/inspection cells below.
model = trainer.model

In [None]:
# Export the model with Keras to the relative path "test".
# NOTE(review): trainer.saver.save("test") above targets the same path, so whichever
# cell runs last likely overwrites the other's output.
model.save("test")

In [None]:
# Reload the export above to check it round-trips.
# NOTE(review): custom one_ring loss/metric objects may require `custom_objects=` here.
m = tf.keras.models.load_model("test")

In [None]:
#model.evaluate(val_dataset)

In [None]:
# Record which loss (and its constructor kwargs) this run used. One log_params
# call replaces a list comprehension that ran only for its side effects and
# echoed a list of Nones as the cell output.
# NOTE(review): if no MLflow run is active (e.g. finalize_training() closed it),
# MLflow starts a new run for these params — confirm the intended ordering.
mlflow.log_params({
    "loss_name": loss_name,
    **{f"loss_{k}": v for k, v in params.items()},
})


In [None]:
# Close out the trainer (presumably ends the active MLflow run — confirm in Trainer.end).
trainer.end()

In [13]:
def calculate_binary_differece(target, pred):
    """Colour-code the agreement between a binary ground-truth mask and a prediction.

    Pixels are painted (R, G, B):
      * true positive  (pred == 1, target == 1) -> green (0, 255, 0)
      * false positive (pred == 1, target == 0) -> red   (255, 0, 0)
      * false negative (pred == 0, target == 1) -> blue  (0, 0, 255)
      * anything else (true negatives, non-0/1 values) -> black

    Args:
        target: ground-truth mask of shape (H, W) or (H, W, 1) holding 0/1 values.
        pred: predicted mask with the same spatial shape, holding 0/1 values.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    target = np.asarray(target)
    pred = np.asarray(pred)
    # Drop a trailing singleton channel so both (H, W) and (H, W, 1) masks work;
    # squeezing the RGB result instead fails for plain 2-D masks.
    if target.ndim == 3 and target.shape[-1] == 1:
        target = target[..., 0]
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]

    tp = (pred == 1) & (target == 1)  # True Positives
    fp = (pred == 1) & (target == 0)  # False Positives
    fn = (pred == 0) & (target == 1)  # False Negatives

    # Black background; each outcome class gets its own colour.
    diff_image = np.zeros(target.shape + (3,), dtype=np.uint8)
    diff_image[tp] = (0, 255, 0)
    diff_image[fp] = (255, 0, 0)
    diff_image[fn] = (0, 0, 255)
    return diff_image

In [None]:
# Probability cut-off for turning the model output into a binary mask.
threshold = 0.6
# NOTE: `i`, `n`, `pred_logit`, `target`, `pred_value` are left bound after the
# loop on purpose; the spot-check cells below read them.
for i in val_dataset:
    pred_logits = model.predict(i[0])

    for n in range(len(i[0])):
        image = i[0][n].numpy().astype(np.uint8)
        pred_logit = pred_logits[n]
        # Binarise and give the mask an explicit single channel. The spatial size
        # comes from the input image instead of a hard-coded 224.
        pred_value = np.where(pred_logit > threshold, 1, 0).reshape(image.shape[0], image.shape[1], 1)
        pred_mask = (pred_value * 255).astype(np.uint8)

        overlay = generate_overlay_image(pred_mask, image, alpha=0.3)
        target = i[1][n].numpy()

        fig, axes = plt.subplots(1, 3, figsize=(20, 10))
        for ax, title, panel in zip(axes, ("target", "pred", "overlay"), (target, pred_value, overlay)):
            ax.set_title(title)
            ax.imshow(panel)
            ax.grid(True)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the whole dataset

        cm, cr = calculate_confusion_matrix_and_report(pred_value, target)
        print(cm)
        print(cr)


In [None]:
# Pixel-wise binary accuracy for the spot-check below. A metric *instance* is
# needed: `tf.keras.metrics` itself is a module and calling it raises TypeError.
# BinaryAccuracy thresholds predictions at 0.5 by default and accumulates state
# across calls, so reset it between unrelated evaluations.
accuracy = tf.keras.metrics.BinaryAccuracy()

In [None]:
# Spot-check the last sample left over from the visualisation loop (`i`, `n` and
# `pred_logit` are that loop's variables; `accuracy` is the metric from the cell above).
target = i[1][n]
# Inspect the raw shapes before singleton axes are squeezed away. A bare tuple in
# the middle of a cell is not displayed, hence the print.
print(pred_logit.shape, target.shape)

a = accuracy(tf.squeeze(target), tf.squeeze(pred_logit))
a

In [None]:
from one_ring.losses import dice_coef

# Standalone Dice loss for ad-hoc checks. It gets its own name so it does not
# overwrite `loss`, which the Trainer cell was built with. Re-running that cell
# after this one would otherwise silently train with DiceLoss.
dice_loss_fn = DiceLoss()

In [None]:
# Hard 0/1 float32 mask from the last `pred_logit` of the visualisation loop.
# NOTE(review): uses a 0.5 cut-off, while the visualisation loop uses `threshold` (0.6).
pred_value = np.greater(pred_logit, 0.5).astype(np.float32)


In [None]:
# dice_coef against the raw (un-thresholded) prediction.
dice_coef(target,pred_logit)

In [None]:
# dice_coef against `pred_value` as last assigned (the 0.5-thresholded mask if the cell above ran).
dice_coef(target,pred_value)

In [None]:
# Shape of the target once its singleton axes are squeezed away.
tf.squeeze(target).shape

#### Sequential experiments: Focal Tversky loss sweep over gamma and alpha

In [None]:
# Reload the experiment configuration and build the train/validation loaders.
config = get_config(config_filename="spinal_cord")
train_data_loader, val_data_loader = get_data_loader(config.data, train_data=True, val_data=True, test_data=False)

# Side length every sample is resized to (first entry of the configured image_size).
IM_SIZE = config.data["image_size"][0]

# Augmentation hyper-parameters; also forwarded to the Trainer via `tracing_object`.
aug_config = {
    "aug_prob": 0.1,
    # "random_contrast_limit": 0.4,
    # "random_brightness_limit": 0.3,
    "rotate_limit": 10,
}
tracing_object = {"mlflow": {"augmentation": aug_config}}

# Lower bound for RandomSizedCrop's crop height. The upper bound is the resized
# image size (IM_SIZE) rather than a hard-coded 224, and the lower bound is
# clamped to it, so the pipeline stays valid if `image_size` changes in the config.
CROP_MIN_HEIGHT = 180

train_transforms = A.Compose(
    [
        A.Resize(IM_SIZE, IM_SIZE),
        # A.GaussNoise(var_limit=(10.0, 50.0), p=aug_config["aug_prob"]),
        # A.CLAHE(p=aug_config["aug_prob"]),
        # A.RandomBrightnessContrast(p=aug_config["aug_prob"], brightness_limit=aug_config["random_brightness_limit"], contrast_limit=aug_config["random_contrast_limit"]),
        # A.RandomGamma(p=aug_config["aug_prob"]),
        A.HorizontalFlip(p=aug_config["aug_prob"]),
        A.VerticalFlip(p=aug_config["aug_prob"]),
        A.Rotate(limit=aug_config["rotate_limit"], p=aug_config["aug_prob"]),
        A.RandomSizedCrop(
            min_max_height=(min(CROP_MIN_HEIGHT, IM_SIZE), IM_SIZE),
            height=IM_SIZE,
            width=IM_SIZE,
            p=aug_config["aug_prob"],
        ),
    ]
)
# Validation transforms only resize; no augmentation.
test_transforms = A.Compose(
    [
        A.Resize(IM_SIZE, IM_SIZE),
        # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

tr_transforms_object = Transformer(config.augmentation, "train", train_transforms)
ts_transforms_object = Transformer(config.augmentation, "test", test_transforms)
train_dataset = train_data_loader.load_data(transform_func=tr_transforms_object)
val_dataset = val_data_loader.load_data(transform_func=ts_transforms_object, shuffle=False)

In [None]:
# Grid search over the Focal Tversky loss: gamma in {5/4, 4/3, 3/2, 2} and
# alpha in {0.2, 0.4, 0.6, 0.8, 1.0}. Each combination trains a fresh AttUnet,
# logs alpha/gamma to MLflow and calls trainer.end().
# NOTE(review): this cell uses `config.trainer`, `losses=`, a callbacks dict and
# `trainer._model`. The training section above uses `config.train`, `loss=`, a
# callbacks list and `trainer.model`. One of the two is probably stale against
# the current Trainer/config API — confirm before running.
steps_per_epoch = len(train_dataset)  # loop-invariant

for gamma in [5/4, 4/3, 3/2, 2]:
    for alpha_step in range(1, 6):
        # Callbacks carry per-run state, so build a fresh set for every run.
        callbacks = get_callbacks(config.callbacks)

        lr_schedule = ORLearningRateScheduler(
            strategy=config.trainer.lr_scheduler["name"],
            total_epochs=config.trainer.epochs,
            steps_per_epoch=steps_per_epoch,
            **config.trainer.lr_scheduler["params"]
        ).get()

        log_dir = f"board_logs/{config.trainer.experiment_name}/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        callbacks["tensorboard"] = tf.keras.callbacks.TensorBoard(
            log_dir=log_dir, histogram_freq=1, write_images=True, write_graph=True, update_freq='epoch'
        )  # , profile_batch='500,520')
        callbacks["lr_sch"] = ORLearningRateCallback(lr_schedule)

        # round() keeps float noise such as 0.6000000000000001 out of the loss and logged params.
        alpha = round(0.2 * alpha_step, 1)

        losses = [FocalTverskyLoss(gamma=gamma, alpha=alpha)]
        metrics = [Recall(), Precision()]

        model = AttUnet(**config.model).build_model()
        trainer = Trainer(config, model, train_dataset, val_dataset, callbacks=callbacks, metrics=metrics, losses=losses, tracing_object=tracing_object)
        trainer.fit(continue_training=True)
        model = trainer._model
        print(alpha, model.evaluate(val_dataset))

        mlflow.log_param("loss_alpha", alpha)
        mlflow.log_param("loss_gamma", gamma)

        trainer.end()

        for metric in metrics:
            metric.reset_states()

        # `model` is deliberately NOT deleted. The visualisation cells after this
        # sweep call model.predict on the last trained model, and the name is
        # rebound on the next iteration anyway.
        del trainer, losses
        tf.keras.backend.clear_session()

In [None]:
# scores = model.evaluate(val_dataset, verbose=1)

# save_path = f"best/"
# os.makedirs(save_path,exist_ok=True)
# model_name = save_path+f"d-{config.data.name}-dsc-{scores[1]:.4f}"
# #model.save(model_name)
# trainer.save(model_name)

In [None]:
def calculate_binary_differece(target, pred):
    """Colour-code the agreement between a binary ground-truth mask and a prediction.

    Pixels are painted (R, G, B):
      * true positive  (pred == 1, target == 1) -> green (0, 255, 0)
      * false positive (pred == 1, target == 0) -> red   (255, 0, 0)
      * false negative (pred == 0, target == 1) -> blue  (0, 0, 255)
      * anything else (true negatives, non-0/1 values) -> black

    Args:
        target: ground-truth mask of shape (H, W) or (H, W, 1) holding 0/1 values.
        pred: predicted mask with the same spatial shape, holding 0/1 values.

    Returns:
        uint8 array of shape (H, W, 3).
    """
    target = np.asarray(target)
    pred = np.asarray(pred)
    # Drop a trailing singleton channel so both (H, W) and (H, W, 1) masks work;
    # squeezing the RGB result instead fails for plain 2-D masks.
    if target.ndim == 3 and target.shape[-1] == 1:
        target = target[..., 0]
    if pred.ndim == 3 and pred.shape[-1] == 1:
        pred = pred[..., 0]

    tp = (pred == 1) & (target == 1)  # True Positives
    fp = (pred == 1) & (target == 0)  # False Positives
    fn = (pred == 0) & (target == 1)  # False Negatives

    # Black background; each outcome class gets its own colour.
    diff_image = np.zeros(target.shape + (3,), dtype=np.uint8)
    diff_image[tp] = (0, 255, 0)
    diff_image[fp] = (255, 0, 0)
    diff_image[fn] = (0, 0, 255)
    return diff_image

In [None]:
# Probability cut-off for turning the model output into a binary mask.
# NOTE(review): relies on `model` still being bound by an earlier cell
# (e.g. the last model trained in the sweep above).
threshold = 0.6
for i in val_dataset:
    pred_logits = model.predict(i[0])

    for n in range(len(i[0])):
        image = i[0][n].numpy().astype(np.uint8)
        pred_logit = pred_logits[n]
        # Binarise and give the mask an explicit single channel. The spatial size
        # comes from the input image instead of a hard-coded 224.
        pred_value = np.where(pred_logit > threshold, 1, 0).reshape(image.shape[0], image.shape[1], 1)
        pred_mask = (pred_value * 255).astype(np.uint8)

        overlay = generate_overlay_image(pred_mask, image, alpha=0.3)
        target = i[1][n].numpy()

        fig, axes = plt.subplots(1, 3, figsize=(20, 10))
        for ax, title, panel in zip(axes, ("target", "pred", "overlay"), (target, pred_value, overlay)):
            ax.set_title(title)
            ax.imshow(panel)
            ax.grid(True)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the whole dataset

        cm, cr = calculate_confusion_matrix_and_report(pred_value, target)
        print(cm)
        print(cr)

        # Optional TP/FP/FN colour map of a centre crop:
        # diff_image_corrected = calculate_binary_differece(target, pred_value)
        # im = diff_image_corrected[150:350, 150:350]

        # plt.figure(figsize=(10,10))
        # plt.imshow(im,cmap='gray')
        # plt.grid(True)
        # plt.show()


    
        

        