In [None]:
# Install the dependency first if needed: %pip install segmentation-models  (imported below as `sm`)
import tensorflow as tf 
import segmentation_models as sm
import glob
import cv2 
import os 
import numpy as np 
from matplotlib import pyplot as plt 


# Encoder backbone for the U-Net, plus the input-preprocessing function that
# segmentation_models pairs with that backbone.
BACKBONE = 'resnet34'
preprocess_input = sm.get_preprocessing(BACKBONE)

# Square image size in pixels (height and width are equal).
SIZE_X, SIZE_Y = 224, 224


In [None]:

# Load every training image into a list, then stack them into one array.
# Directory and file paths are sorted so that the i-th image lines up with the
# i-th mask loaded in the next cell (raw glob order is arbitrary).
TRAIN_IMAGE_GLOB = "insert train image here"  # placeholder: directory (or glob of directories) holding training PNGs

train_images = []

for directory_path in sorted(glob.glob(TRAIN_IMAGE_GLOB)):
    # "*.png" matches every PNG; a bare ".png" would only match a file literally named ".png".
    for img_path in sorted(glob.glob(os.path.join(directory_path, "*.png"))):
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)  # 3-channel, OpenCV BGR channel order
        if img is None:
            # Fail loudly: appending None would misalign images/masks and break np.array below.
            raise IOError(f"Could not read image: {img_path}")
        # cv2.resize takes dsize=(width, height); every sample must share one shape to stack.
        img = cv2.resize(img, (SIZE_X, SIZE_Y))
        train_images.append(img)

train_images = np.array(train_images)

In [None]:
# Load every training mask into a list (same sorted order as the images), then stack them.
TRAIN_MASK_GLOB = "insert train mask here"  # placeholder: directory (or glob of directories) holding mask PNGs

train_masks = []

for directory_path in sorted(glob.glob(TRAIN_MASK_GLOB)):
    for mask_path in sorted(glob.glob(os.path.join(directory_path, "*.png"))):
        # Masks are single-channel label images.
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise IOError(f"Could not read mask: {mask_path}")
        # Nearest-neighbour keeps label values discrete (no interpolated in-between values at edges).
        mask = cv2.resize(mask, (SIZE_X, SIZE_Y), interpolation=cv2.INTER_NEAREST)
        train_masks.append(mask)

train_masks = np.array(train_masks)

In [None]:
# Use the customary X (inputs) / Y (targets) naming.
assert len(train_images) > 0, "no training images were loaded - check the image glob"
assert len(train_images) == len(train_masks), (
    f"{len(train_images)} images but {len(train_masks)} masks - they must pair up one-to-one"
)

X = train_images

# binary_crossentropy with a sigmoid output expects targets in [0, 1];
# 8-bit masks stored as 0/255 are rescaled, masks already in [0, 1] are left alone.
Y = train_masks.astype(np.float32)
if Y.max() > 1.0:
    Y /= 255.0

# Add a trailing channel axis: (N, H, W) -> (N, H, W, 1) for the single-channel model output.
Y = np.expand_dims(Y, axis=3)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for validation; the fixed random_state makes the split reproducible.
x_train, x_val, y_train, y_val = train_test_split(
    X,
    Y,
    test_size=0.2,
    random_state=42,
)



In [None]:
# Apply the backbone-specific preprocessing to both splits.
# NOTE(review): re-running this cell feeds already-preprocessed arrays back through preprocess_input.
x_train, x_val = (preprocess_input(split) for split in (x_train, x_val))



In [None]:
# Define a U-Net whose encoder starts from ImageNet-pretrained weights.
# classes=1 / activation='sigmoid' are segmentation_models' defaults, stated explicitly here
# because binary_crossentropy relies on one per-pixel probability in [0, 1].
model = sm.Unet(BACKBONE, encoder_weights='imagenet', classes=1, activation='sigmoid')
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=[sm.metrics.iou_score])

# summary() prints the table itself and returns None; wrapping it in print() adds a stray "None".
model.summary()



In [None]:
# Train for a fixed number of epochs, evaluating on the held-out split after each one.
history = model.fit(
    x=x_train,
    y=y_train,
    validation_data=(x_val, y_val),
    epochs=10,
    batch_size=8,
    verbose=1,
)

In [None]:
# Plot training vs. validation curves per epoch to spot over/under-fitting.
# Loss is always plotted; IoU (the metric compiled into the model) gets a second panel
# only when both its training and validation values are present in the history.
train_history = history.history
epochs = range(1, len(train_history['loss']) + 1)

# (history key, name used in titles/legends, y-axis label)
panels = [('loss', 'loss', 'Loss')]
if 'iou_score' in train_history and 'val_iou_score' in train_history:
    panels.append(('iou_score', 'IoU', 'IoU'))

fig, axes = plt.subplots(1, len(panels), figsize=(6 * len(panels), 4), squeeze=False)
for ax, (key, name, ylabel) in zip(axes[0], panels):
    ax.plot(epochs, train_history[key], 'y', label=f'Training {name}')
    ax.plot(epochs, train_history['val_' + key], 'r', label=f'Validation {name}')
    ax.set(title=f'Training and validation {name}', xlabel='Epochs', ylabel=ylabel)
    ax.legend()
plt.show()

# Persist the trained model, then reload it as a stand-alone inference model.
# The save must actually run, otherwise load_model only works if a stale file already exists.
MODEL_PATH = 'membrane.h5'
model.save(MODEL_PATH)
# compile=False: the optimizer/loss/IoU metric are not needed for prediction.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)

# Test on an image that was not part of training.
TEST_IMAGE_PATH = 'membrane/test/0.png'
test_img = cv2.imread(TEST_IMAGE_PATH, cv2.IMREAD_COLOR)
if test_img is None:
    raise IOError(f"Could not read test image: {TEST_IMAGE_PATH}")

# Mirror the training inputs: BGR as returned by cv2.imread (no channel swap),
# resized with dsize=(width, height), then the same backbone preprocessing.
test_img = cv2.resize(test_img, (SIZE_X, SIZE_Y))
test_img = preprocess_input(test_img)
test_batch = np.expand_dims(test_img, axis=0)  # add batch axis: (1, H, W, 3)

prediction = model.predict(test_batch)

# View and save the segmented image.
# Keep batch item 0 and output channel 0 to get a 2-D per-pixel probability map.
prediction_image = prediction[0, ..., 0]
plt.imshow(prediction_image, cmap='gray')
plt.imsave('membrane/test0_segmented.jpg', prediction_image, cmap='gray')