In [None]:
!pip install -U segmentation-models
!pip install -U git+https://github.com/albu/albumentations --no-cache-dir

In [None]:
import cv2
import os
import numpy as np
import glob
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from itertools import chain
from skimage.io import imread, imshow, imread_collection, concatenate_images
from skimage.transform import resize
from skimage.morphology import label
from keras.models import Model, load_model
from keras.layers import Input,Dense
from keras.layers.core import Dropout, Lambda
from keras.layers.convolutional import Conv2D, Conv2DTranspose
from keras.layers.pooling import MaxPooling2D
from keras.layers.merge import concatenate
from keras.callbacks import EarlyStopping, ModelCheckpoint
from keras import backend as K
import tensorflow as tf
import albumentations as A
from PIL import Image
import keras
from sklearn.model_selection import train_test_split
import segmentation_models as sm

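## Data loading

Collect the `.npy` label files for training and evaluation; each label's image is the `.png` file with the same name.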

In [None]:
# Sort first (glob order is filesystem-dependent) and seed the shuffle so the
# train/validation split below is reproducible across kernel restarts.
_train_paths = sorted(glob.glob('/content/drive/MyDrive/data/echocardiography/train/A*C/*.npy'))
assert _train_paths, 'no training .npy files matched the train glob'
flist = tf.random.shuffle(_train_paths, seed=42)

In [None]:
# Evaluation needs no shuffling. A fixed, sorted order makes the i-th row of
# model.predict(test_dataset) correspond to the i-th path here; the previous
# random shuffle threw that correspondence away.
# NOTE(review): the reconstruction cell pairs predictions with sizes taken from
# sorted(os.listdir('/content/drive/MyDrive/A4C')) - confirm that list matches this one.
_test_paths = sorted(glob.glob('/content/drive/MyDrive/data/echocardiography/validation/A*C/*.npy'))
assert _test_paths, 'no test .npy files matched the validation glob'
test_flist = tf.constant(_test_paths)

In [None]:
# Number of training label files (before the train/val split) and test label files.
# Both go in one tuple: only a cell's last expression is displayed.
len(flist), len(test_flist)

In [None]:
# Hold out the first 20% of the (already shuffled) training files for validation
# and train on the rest. Both slices use the same index, so the splits are
# disjoint and no file is dropped (the old `+1` silently discarded one file).
n_val = int(len(flist) * 0.2)
val_flist = flist[:n_val]
train_flist = flist[n_val:]

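## Pre-processing

Build `tf.data` pipelines that pair each label with its image and resize both to 608 x 416 (width x height).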

In [None]:
def make_dataset(test_flist, batch_size=3):
  """Build a batched, prefetched tf.data pipeline from a collection of label paths.

  Despite its name, `test_flist` is any 1-D list/tensor of `.npy` label paths
  (train, validation or test). `data_` pairs each path with its `.png` image
  and loads both; `data_preprocess` then resizes the pair.

  Args:
    test_flist: 1-D list or string tensor of `.npy` file paths.
    batch_size: number of (image, label) pairs per batch; defaults to 3, the
      value that was previously hard-coded.

  Returns:
    A `tf.data.Dataset` yielding `(image, label)` batches.

  NOTE(review): `data_` and `data_preprocess` are defined in a later cell; on a
  fresh kernel that cell must be run before this function is called.
  """
  paths = tf.data.Dataset.from_tensor_slices(test_flist)
  dataset = data_(paths)
  dataset = dataset.map(data_preprocess)
  dataset = dataset.batch(batch_size)
  return dataset.prefetch(tf.data.experimental.AUTOTUNE)


In [None]:
# NOTE(review): make_dataset relies on data_ / data_preprocess, which are defined
# in the next cell. On a fresh kernel, run that cell first; otherwise this
# raises NameError.
train_dataset, val_dataset, test_dataset = (
    make_dataset(train_flist),
    make_dataset(val_flist),
    make_dataset(test_flist),
)

In [None]:
def data_(dataset):
  """Turn a dataset of `.npy` label paths into (image, label) pairs.

  `file_replace` derives each `.png` image path from its label path, and
  `data_resize` then reads the image and loads the label array.
  """
  return dataset.map(file_replace).map(data_resize)

def file_replace(path):
  """Pair a `.npy` label path with the `.png` image path of the same name.

  Returns `(image_path, label_path)`. The pattern is an RE2 regex, so the dot is
  escaped and the match is anchored to the end of the string. The previous
  unescaped, unanchored '.npy' matched any character followed by 'npy' anywhere
  in the path (e.g. inside a directory name) and rewrote every such hit.
  """
  return tf.strings.regex_replace(path, r'\.npy$', '.png'), path

def data_load(data):
  """Load the `.npy` array whose file path is held by the eager tensor `data`.

  Meant to run inside `tf.py_function`, where `data.numpy()` yields the path as bytes.
  """
  npy_path = data.numpy()
  return np.load(npy_path)


def data_preprocess(img,label):
  """Resize an (image, label) pair by running `data_data` eagerly.

  `tf.py_function` needs exactly one `Tout` dtype per value the wrapped function
  returns. `data_data` returns two arrays, so two dtypes are declared. The
  previous three-element `Tout` produced three outputs that could not be
  unpacked into `img, label`.
  Because the outputs already have the `Tout` dtypes (float32), no extra cast is needed.
  """
  img, label = tf.py_function(data_data, inp=[img, label], Tout=(tf.float32, tf.float32))
  return img, label

def data_data(img,label):
  """Resize an image/label pair to 608 x 416 (width x height).

  Args:
    img: eager tensor with the decoded PNG (4 channels, from `decode_png`).
    label: eager tensor with the label array loaded from the `.npy` file.

  Returns:
    `(img, label)` as float32 numpy arrays of height 416 and width 608. A label
    that comes back 2-D gets a trailing channel axis.
  """
  dsize = (608, 416)  # cv2 takes dsize as (width, height)
  # `interpolation` must be passed by keyword: the third positional parameter of
  # cv2.resize is `dst`, so the old positional INTER_AREA never set the interpolation.
  img = cv2.resize(img.numpy(), dsize, interpolation=cv2.INTER_AREA)
  label = cv2.resize(label.numpy(), dsize, interpolation=cv2.INTER_AREA)
  if label.ndim == 2:
    # cv2.resize drops a singleton channel axis. The single-class sigmoid U-Net
    # outputs (416, 608, 1), so restore the axis to keep y_true aligned with y_pred.
    label = label[..., np.newaxis]
  return img.astype(np.float32), label.astype(np.float32)


def data_resize(x,y):
  """Read one (image, label) pair from disk; despite the name, nothing is resized here.

  `x` is the `.png` path, decoded to a 4-channel uint8 image. `y` is the `.npy`
  path, loaded eagerly through `data_load` and returned as float32.
  """
  image = tf.image.decode_png(tf.io.read_file(x), channels=4)
  label_array = tf.py_function(data_load, inp=[y], Tout=tf.float32)
  return image, label_array




In [None]:
# Sanity-check one training batch. Show the label shape too, since its channel
# axis has to line up with the model output for the loss to work.
images, labels = next(iter(train_dataset.take(1)))
images.shape, labels.shape

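## Training

U-Net with an EfficientNet-B7 encoder, trained from scratch (`encoder_weights=None`) on 4-channel 416 x 608 inputs.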

In [None]:
# segmentation_models configuration: one target class ('seg'), the encoder
# backbone used by sm.Unet below, and the tf.keras backend.
CLASSES = ['seg']
BACKBONE = 'efficientnetb7'
sm.set_framework('tf.keras')
sm.framework()

In [None]:
# One class -> a single output channel; otherwise one channel per class plus one
# extra (presumably background).
if len(CLASSES) == 1:
    n_classes = 1
else:
    n_classes = len(CLASSES) + 1
activation = 'sigmoid' if n_classes == 1 else 'softmax'

model = sm.Unet(
    BACKBONE,
    classes=n_classes,
    activation=activation,
    input_shape=(416, 608, 4),  # (height, width, channels) - must match data_data's resize
    encoder_weights=None,       # no pretrained encoder weights
)

In [None]:
def jacard_coef(y_true, y_pred, smooth=1e-7):
    """Soft Jaccard index (IoU) over the whole batch, without thresholding y_pred.

    Computes intersection / (sum(y_true) + sum(y_pred) - intersection) on the
    flattened tensors. `smooth` is added to numerator and denominator, so a batch
    where both label and prediction are empty scores 1.0 instead of 0 / 0 = NaN.
    A NaN batch would otherwise make the running metric mean NaN.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    union = K.sum(y_true_f) + K.sum(y_pred_f) - intersection
    return (intersection + smooth) / (union + smooth)


def dice_coef(y_true, y_pred, smooth=1e-7):
    """Soft Dice coefficient over the whole batch, without thresholding y_pred.

    Computes 2 * intersection / (sum(y_true) + sum(y_pred)) on the flattened
    tensors. `smooth` is added to numerator and denominator, so a batch where
    both label and prediction are empty scores 1.0 instead of 0 / 0 = NaN.
    """
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

In [None]:
# Combined BCE + Jaccard + Dice loss (segmentation_models losses compose with `+`),
# tracking sm's IoU alongside the two hand-written coefficients.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=sm.losses.bce_jaccard_loss + sm.losses.dice_loss,
    metrics=[sm.metrics.iou_score, jacard_coef, dice_coef],
)

In [None]:
callbacks = [
    # Saves the weights after every epoch, one file per epoch.
    tf.keras.callbacks.ModelCheckpoint('./new_{epoch:02d}.h5', verbose=True, save_weights_only=True, mode='auto'),
    # IoU is higher-is-better. With mode='auto', Keras only treats the monitor as
    # 'max' when its name contains 'acc'. So 'val_iou_score' was handled like a
    # loss, and the LR was cut whenever IoU stopped *decreasing*, i.e. while it
    # was improving. mode='max' gives the correct direction.
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_iou_score', factor=0.66, patience=4, verbose=0, mode='max', min_delta=0.0001, cooldown=0, min_lr=0),
]

In [None]:
# NOTE(review): train_dataset comes from a once-shuffled file list and has no
# .shuffle() stage, so every epoch sees the batches in the same order.
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=callbacks,
)

In [None]:
# Test-set loss, followed by the compiled metrics in order (iou_score, jacard_coef, dice_coef).
his = model.evaluate(x=test_dataset)

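## Image reconstruction

Predict masks for the test set, resize each one back to its original image size, and save it as a PNG.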

In [None]:
# Per-pixel sigmoid outputs for every test image, in test_dataset order.
a = model.predict(x=test_dataset)

In [None]:
# (height, width) of every original A4C image, in sorted file-name order; used
# to resize each prediction back to its source resolution.
size = [
    plt.imread(f'/content/drive/MyDrive/A4C/{name}').shape[:2]
    for name in sorted(os.listdir('/content/drive/MyDrive/A4C'))
]

In [None]:
def visualize(images,n,size):
  """Resize one predicted mask back to its source resolution and save it as a PNG.

  Args:
    images: predicted mask for a single image (presumably 2-D after `.squeeze()`).
    n: index used in the output file name ('9{n:02d}.png').
    size: target (height, width) of the original image. For backward
      compatibility, the whole list of per-image (height, width) pairs is also
      accepted; in that case entry `n` is used.
  """
  if np.ndim(size) == 2:
    size = size[n]
  height, width = size
  # cv2 takes dsize as (width, height). `interpolation` must be a keyword because
  # the third positional parameter of cv2.resize is `dst`.
  images = cv2.resize(images, (int(width), int(height)), interpolation=cv2.INTER_CUBIC)
  plt.imsave('/content/drive/MyDrive/1214/9{:02d}.png'.format(n), images)

In [None]:
# Save one reconstruction per prediction, resized to that image's own (height, width).
# Passing size[i] (not the whole list) gives visualize a single target size, and
# the bound keeps the loop from indexing past either `a` or `size`.
for i in range(min(100, len(a), len(size))):
  visualize(a[i].squeeze(), i, size[i])

In [None]:
# (Intentionally empty: a stray ")" here was a SyntaxError that stopped Restart & Run All.)