In [None]:
%load_ext autoreload
%autoreload 2


import os 
import numpy as np
import matplotlib.pyplot as plt
import albumentations as A
import cv2
from glob import glob
from tqdm import tqdm
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.metrics import Recall,Precision
import  tensorflow_addons as tfa
import tensorflow as tf
import datetime
import mlflow


import one_ring
from one_ring.config import get_config
from one_ring.data import get_data_loader,get_camvid_data_loader
from one_ring.transformers import Transformer
from one_ring.models import Unet,DeepLabV3Plus,AttUnet
from one_ring.losses import FocalTverskyLoss, DiceLoss
from one_ring.train import Trainer
from one_ring.callbacks import get_callbacks
from one_ring.losses import FocalTverskyLoss, DiceLoss,LogCoshDiceLoss
from one_ring.callbacks import ORLearningRateCallback 
from one_ring.scheduler import ORLearningRateScheduler
from one_ring.utils import generate_overlay_image,calculate_confusion_matrix_and_report,plot_history_dict
#from sklearn.metrics import classification_report, confusion_matrix
#from sklearn.metrics import ConfusionMatrixDisplay

# Record the library versions in the cell output so runs can be reproduced.
print(f"tensorflow version : {tf.__version__}")
print(f"one_ring version : {one_ring.__version__}")
#print(tf.config.list_physical_devices())

In [None]:
# --- Data, augmentation, LR schedule and callbacks ---------------------------
config = get_config(config_filename="spinal_cord")
train_data_loader, val_data_loader = get_data_loader(config.data, train_data=True, val_data=True, test_data=False)

# Square images are assumed: only the first entry of config.data["image_size"] is used.
IM_SIZE = config.data["image_size"][0]

# Augmentation hyper-parameters. This whole dict is passed on via
# `tracing_object` (under the "mlflow" key), so keep it in sync with the
# transforms actually enabled below. NOTE(review): the brightness/contrast
# limits are currently recorded but unused, because RandomBrightnessContrast
# is disabled.
aug_config = {
    "aug_prob": 0.1,
    "random_contrast_limit": 0.4,
    "random_brightness_limit": 0.3,
    "rotate_limit": 10,
}
tracing_object = {"mlflow": {"augmentation": aug_config}}

# Lower bound of the random crop height. It is expressed relative to IM_SIZE
# (400 px at the original 512 px resolution) because the crop runs after
# Resize(IM_SIZE, IM_SIZE). An upper bound larger than IM_SIZE would request a
# crop bigger than the image whenever IM_SIZE < 512.
CROP_MIN_HEIGHT = IM_SIZE * 400 // 512

train_transforms = A.Compose(
    [
        A.Resize(IM_SIZE, IM_SIZE),
        # Disabled augmentations, kept as toggles:
        # A.GaussNoise(var_limit=(10.0, 50.0), p=aug_config["aug_prob"]),
        # A.CLAHE(p=aug_config["aug_prob"]),
        # A.RandomBrightnessContrast(p=aug_config["aug_prob"], brightness_limit=aug_config["random_brightness_limit"], contrast_limit=aug_config["random_contrast_limit"]),
        # A.RandomGamma(p=aug_config["aug_prob"]),
        A.Rotate(limit=aug_config["rotate_limit"], p=aug_config["aug_prob"]),
        A.RandomSizedCrop(
            min_max_height=(CROP_MIN_HEIGHT, IM_SIZE),
            height=IM_SIZE,
            width=IM_SIZE,
            p=aug_config["aug_prob"],
        ),
    ]
)
test_transforms = A.Compose(
    [
        A.Resize(IM_SIZE, IM_SIZE),
        # A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)
tr_transforms_object = Transformer(config.augmentation, "train", train_transforms)
ts_transforms_object = Transformer(config.augmentation, "test", test_transforms)
train_dataset = train_data_loader.load_data(transform_func=tr_transforms_object).prefetch(tf.data.AUTOTUNE)
# Validation order is kept fixed (shuffle=False) so evaluation and the visual
# inspection further down see the same batches on every run.
val_dataset = val_data_loader.load_data(transform_func=ts_transforms_object, shuffle=False).prefetch(tf.data.AUTOTUNE)

callbacks = get_callbacks(config.callbacks)

# len() on a tf.data.Dataset needs a known, finite cardinality (i.e. no .repeat()).
steps_per_epoch = len(train_dataset)
lr_schedule = ORLearningRateScheduler(
    strategy=config.trainer.lr_scheduler["name"],
    total_epochs=config.trainer.epochs,
    steps_per_epoch=steps_per_epoch,
    **config.trainer.lr_scheduler["params"]
).get()

# One TensorBoard run directory per execution: board_logs/<experiment>/<timestamp>.
log_dir = os.path.join(
    "board_logs",
    config.trainer.experiment_name,
    datetime.datetime.now().strftime("%Y%m%d-%H%M%S"),
)
callbacks["tensorboard"] = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
    write_images=True,
    write_graph=True,
    update_freq="epoch",
    # profile_batch="500,520",
)

# Apply the per-step schedule built above through one_ring's callback.
callbacks["lr_sch"] = ORLearningRateCallback(lr_schedule)

In [None]:
# Focal Tversky loss hyper-parameters. They are forwarded unchanged to
# one_ring's FocalTverskyLoss, which defines their exact role.
# NOTE(review): presumably alpha weights the FP/FN terms of the Tversky index
# and gamma is the focal exponent — confirm against one_ring.losses.
alpha = 0.5
gamma = 4 / 3

# Alternatives that were tried: LogCoshDiceLoss() and BASNetHybridLoss().
# BASNetHybridLoss is not imported in this notebook.
losses = [FocalTverskyLoss(alpha=alpha, gamma=gamma)]

# Pixel-wise recall and precision, reported alongside the loss.
metrics = [Recall(), Precision()]

In [None]:
# Build the Attention U-Net from the `model` section of the config.
# (Unet(**config.model).build_model() is the plain U-Net alternative;
#  call model.summary() here to inspect the architecture.)
model = AttUnet(**config.model).build_model()

# Bundle everything the training loop needs. `tracing_object` carries the
# augmentation settings under its "mlflow" key.
trainer = Trainer(
    config,
    model,
    train_dataset,
    val_dataset,
    callbacks=callbacks,
    metrics=metrics,
    losses=losses,
    tracing_object=tracing_object,
)

In [None]:
# Train the model. `continue_training=True` is forwarded to one_ring's Trainer.fit.
# NOTE(review): presumably this resumes from previously saved weights or
# checkpoints if any exist. That would make a fresh "Restart & Run All" depend
# on files left on disk — confirm.
trainer.fit(continue_training=True)
# Rebind `model` to the model object the trainer holds after fit, so the
# following cells evaluate and visualise that model.
# NOTE(review): `_model` is a private attribute of Trainer; prefer a public
# accessor if one exists.
model = trainer._model

In [None]:
# Score the trained model on the (unshuffled) validation split. The returned
# loss/metric values become this cell's displayed output.
model.evaluate(x=val_dataset)

In [None]:
# Finalise the trainer.
# NOTE(review): presumably this closes the MLflow run / tracing opened by
# Trainer — confirm. If so, nothing computed in later cells (e.g. the visual
# inspection of validation predictions) is logged to that run.
trainer.end()

In [None]:
# scores = model.evaluate(val_dataset, verbose=1)

# save_path = f"best/"
# os.makedirs(save_path,exist_ok=True)
# model_name = save_path+f"d-{config.data.name}-dsc-{scores[1]:.4f}"
# #model.save(model_name)
# trainer.save(model_name)

In [None]:
def calculate_binary_differece(target, pred):
    """Colour-code the agreement between a binary target mask and a binary prediction.

    The function name keeps its original spelling for backward compatibility.

    Args:
        target: ground-truth mask with values in {0, 1}, shaped ``(H, W)`` or
            ``(H, W, 1)``. Any array-like accepted by ``np.asarray`` works
            (e.g. the ``.numpy()`` of a tensor).
        pred: predicted mask using the same value convention, shaped
            ``(H, W)`` or ``(H, W, 1)``.

    Returns:
        A ``uint8`` RGB image of shape ``(H, W, 3)``:
        true positives are green ``(0, 255, 0)``, false positives red
        ``(255, 0, 0)``, false negatives blue ``(0, 0, 255)``. True negatives
        (and any value other than 0/1) stay black.

    Raises:
        ValueError: if either mask is not 2-D after dropping a trailing
            singleton channel axis, or if the two ``(H, W)`` shapes differ.
    """

    def _as_2d(mask, label):
        # Accept (H, W, 1) masks as produced by the model / data loader.
        mask = np.asarray(mask)
        if mask.ndim == 3 and mask.shape[-1] == 1:
            mask = mask[..., 0]
        if mask.ndim != 2:
            raise ValueError(f"{label} must have shape (H, W) or (H, W, 1), got {mask.shape}")
        return mask

    target = _as_2d(target, "target")
    pred = _as_2d(pred, "pred")
    # Without this check, (H, W) vs (H, W, 1) inputs would silently broadcast
    # to (H, W, W) and produce a meaningless image.
    if target.shape != pred.shape:
        raise ValueError(f"target shape {target.shape} does not match pred shape {pred.shape}")

    tp = (pred == 1) & (target == 1)  # true positives
    fp = (pred == 1) & (target == 0)  # false positives
    fn = (pred == 0) & (target == 1)  # false negatives

    # Black background; each outcome class is painted in its own colour (R, G, B).
    diff_image = np.zeros(target.shape + (3,), dtype=np.uint8)
    diff_image[tp] = [0, 255, 0]
    diff_image[fp] = [255, 0, 0]
    diff_image[fn] = [0, 0, 255]
    return diff_image

In [None]:
# --- Visual inspection of validation predictions ----------------------------
# Cut-off used to binarise the model output.
# NOTE(review): this assumes the model emits per-pixel scores in [0, 1]
# (e.g. a sigmoid head) — confirm.
threshold = 0.6
# Number of validation batches to render; None renders every batch.
MAX_VIS_BATCHES = None
# Rows/cols window of the diff image shown in the zoomed view.
DIFF_CROP = slice(150, 350)

vis_dataset = val_dataset if MAX_VIS_BATCHES is None else val_dataset.take(MAX_VIS_BATCHES)

for batch in vis_dataset:
    images, targets = batch[0], batch[1]
    # verbose=0 keeps Keras progress bars from interleaving with the figures.
    pred_scores = model.predict(images, verbose=0)

    for n in range(len(images)):
        # NOTE(review): the uint8 cast assumes pixel values are already in the
        # 0-255 range. If the Transformer normalises to [0, 1], the overlay
        # would be almost black — confirm.
        image = images[n].numpy().astype(np.uint8)
        target = targets[n].numpy()

        # Binarise and restore the (H, W, 1) layout expected downstream.
        pred_value = np.where(pred_scores[n] > threshold, 1, 0).reshape(IM_SIZE, IM_SIZE, 1)
        pred_mask = (pred_value * 255).astype(np.uint8)
        overlay = generate_overlay_image(pred_mask, image, alpha=0.3)

        fig, axes = plt.subplots(1, 3, figsize=(20, 10))
        panels = (target, pred_value, overlay)
        titles = ("target", "prediction", "overlay")
        for ax, panel, title in zip(axes, panels, titles):
            ax.imshow(panel)
            ax.set_title(title)
            ax.grid(True)
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across many images

        cm, cr = calculate_confusion_matrix_and_report(pred_value, target)
        print(cm)
        print(cr)

        # Zoomed colour-coded diff: green = TP, red = FP, blue = FN.
        diff_image = calculate_binary_differece(target, pred_value)
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(diff_image[DIFF_CROP, DIFF_CROP])
        ax.set_title("TP (green) / FP (red) / FN (blue)")
        ax.grid(True)
        plt.show()
        plt.close(fig)


    
        

        