In [3]:
from keras.models import Model
from keras.layers import Conv2D, MaxPooling2D, Input, ZeroPadding2D, \
    Dropout, Conv2DTranspose, Cropping2D, Add, UpSampling2D
from keras.layers.merge import concatenate
from image_segmentation_keras.keras_segmentation.models.model_utils import get_segmentation_model
from glob import glob

import tensorflow as tf

# Build the model under a MirroredStrategy so its variables are created inside
# the strategy scope; the print below reports how many replicas are in sync.
strategy = tf.distribute.MirroredStrategy()
print("Number of devices: {}".format(strategy.num_replicas_in_sync))


def _double_conv(x, filters, dropout_rate=0.2):
    """Apply Conv2D(3x3, relu) -> Dropout -> Conv2D(3x3, relu) to `x`.

    Both convolutions use `filters` output channels and 'same' padding, so the
    spatial size of `x` is left unchanged.
    """
    x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
    x = Dropout(dropout_rate)(x)
    return Conv2D(filters, (3, 3), activation='relu', padding='same')(x)


def _upsample(x, filters):
    """Upsample `x` with a 2x2, stride-2 transposed convolution to `filters` channels."""
    return Conv2DTranspose(filters, (2, 2), strides=2, padding='same')(x)


# Open a strategy scope.
with strategy.scope():
    # Network geometry. The encoder pools 2x2 four times and every decoder
    # stage concatenates with the matching encoder output, so height and width
    # must be divisible by 16 for the skip connections to line up.
    input_height = 256
    input_width = 256
    n_classes = 3
    channels = 3

    img_input = Input(shape=(input_height, input_width, channels))

    # ---- Encoder: double conv followed by 2x2 max-pooling, 64 -> 512 filters.
    conv0 = _double_conv(img_input, 64)
    pool0 = MaxPooling2D((2, 2))(conv0)

    conv1 = _double_conv(pool0, 128)
    pool1 = MaxPooling2D((2, 2))(conv1)

    conv2 = _double_conv(pool1, 256)
    pool2 = MaxPooling2D((2, 2))(conv2)

    conv3 = _double_conv(pool2, 512)
    pool3 = MaxPooling2D((2, 2))(conv3)

    # ---- Bottleneck (no pooling after it).
    conv4 = _double_conv(pool3, 1024)
    print("conv4",conv4.shape)
    print('conv3',conv3.shape)

    # ---- Decoder: upsample, concatenate the encoder output of the same
    # resolution along the channel (last) axis, then double conv.
    up_ = _upsample(conv4, 512)
    print('up_',up_.shape)
    up0 = concatenate([up_, conv3], axis=-1)
    print(up0.shape)
    conv5 = _double_conv(up0, 512)

    up_2 = _upsample(conv5, 256)
    up1 = concatenate([up_2, conv2], axis=-1)
    print(up1.shape)
    conv6 = _double_conv(up1, 256)

    up_3 = _upsample(conv6, 128)
    up2 = concatenate([up_3, conv1], axis=-1)
    print(up2.shape)
    conv7 = _double_conv(up2, 128)

    up_4 = _upsample(conv7, 64)
    up3 = concatenate([up_4, conv0], axis=-1)
    print(up3.shape)
    conv8 = _double_conv(up3, 64)
    print(conv8.shape)

    # 1x1 convolution producing one channel per class; no activation is applied
    # here. NOTE(review): get_segmentation_model is expected to add the final
    # activation/reshape -- confirm in keras_segmentation.models.model_utils.
    out = Conv2D(n_classes, (1, 1), padding='same')(conv8)
    print('out',out.shape)

    # Wrap the input/output tensors into the library's segmentation model.
    model = get_segmentation_model(img_input, out)

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
Number of devices: 1
conv4 (None, 16, 16, 1024)
conv3 (None, 32, 32, 512)
up_ (None, None, None, 512)
(None, 32, 32, 1024)
(None, 64, 64, 512)
(None, 128, 128, 256)
(None, 256, 256, 128)
(None, 256, 256, 64)
out (None, 256, 256, 3)


In [2]:
# Training data locations.
# NOTE(review): these are machine-specific absolute paths; point them at your
# own copy of the dataset (ideally via a configurable data directory).
TRAIN_IMAGES_DIR = "/Users/mavaylon/Research/Research_Gambier/Data_P/BP_lrc_filtered/"
TRAIN_ANNOTATIONS_DIR = "/Users/mavaylon/Research/Research_Gambier/Data_P/BP_lrc_norm/"

EPOCHS = 20
BATCH_SIZE = 1

# One epoch should visit every annotation file once, so the step count is the
# number of files divided by the batch size (rounded up). With BATCH_SIZE = 1
# this equals the raw file count.
n_train_samples = len(glob(TRAIN_ANNOTATIONS_DIR + "*"))
if n_train_samples == 0:
    # Fail early with a clear message instead of training with 0 steps/epoch.
    raise FileNotFoundError(
        "No annotation files found in {}".format(TRAIN_ANNOTATIONS_DIR)
    )
steps_per_epoch = (n_train_samples + BATCH_SIZE - 1) // BATCH_SIZE

model.train(
    train_images=TRAIN_IMAGES_DIR,
    train_annotations=TRAIN_ANNOTATIONS_DIR,
    epochs=EPOCHS,
    steps_per_epoch=steps_per_epoch,
    batch_size=BATCH_SIZE,
)

  3%|▎         | 26/1024 [00:00<00:03, 256.84it/s]

Verifying training dataset


100%|██████████| 1024/1024 [00:04<00:00, 207.96it/s]


Dataset verified! 
correct
Epoch 1/20

Epoch 00001: saving model to pet_class_crf.h5
Epoch 2/20
 112/1024 [==>...........................] - ETA: 30:53 - loss: 0.0404 - accuracy: 0.9834

KeyboardInterrupt: 