# Подготовка модели

# In [None]:
# Encoder configuration for blocks 1-4: (filters, number of 3x3 convs).
# Each block's output is kept as a skip connection and then max-pooled 2x,
# so spatial input dims must be divisible by 2 ** len(_UNET_ENCODER_SPEC).
_UNET_ENCODER_SPEC = ((64, 2), (128, 2), (256, 3), (512, 3))


def _conv_bn_relu(x, filters, name=None):
    """Apply a 3x3 'same'-padded Conv2D (optionally named), BatchNormalization, then ReLU."""
    x = layers.Conv2D(filters, (3, 3), padding='same', name=name)(x)
    x = layers.BatchNormalization()(x)
    return layers.Activation('relu')(x)


def _encoder_block(x, filters, n_convs, block_idx):
    """Stack `n_convs` conv-BN-ReLU units whose convs are named 'block{idx}_conv{k}'."""
    for k in range(1, n_convs + 1):
        x = _conv_bn_relu(x, filters, name=f'block{block_idx}_conv{k}')
    return x


def _decoder_block(x, skip, filters, up_idx):
    """Upsample `x` 2x, concatenate the encoder `skip` tensor, refine with two conv-BN-ReLU units."""
    x = layers.Conv2DTranspose(filters, (2, 2), strides=(2, 2), padding='same',
                               name=f'Conv2DTranspose_UP{up_idx}')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.concatenate([x, skip])
    x = _conv_bn_relu(x, filters)
    return _conv_bn_relu(x, filters)


def unet(num_classes=3, input_shape=(160, 160, 1)):
    """Build a U-Net-style segmentation model with BatchNorm after every conv.

    Encoder: conv blocks with 64, 128, 256, 512 filters, each followed by 2x2
    max-pooling, then a 512-filter bottleneck block (block 5). Decoder: four
    Conv2DTranspose upsampling stages (512, 256, 128, 64 filters), each
    concatenated with the matching encoder block output. A final 3x3 conv with
    softmax produces per-pixel class probabilities.

    Args:
        num_classes: number of output channels (classes) of the final softmax conv.
        input_shape: shape passed to ``keras.Input``; assumed channels-last
            (H, W, C). H and W must be divisible by 16, or be None.

    Returns:
        An uncompiled ``keras.Model`` mapping the input image to a
        (H, W, num_classes) softmax map.

    Raises:
        ValueError: if H or W is not divisible by 16. Without this check the
            decoder's concatenate would fail later with a shape mismatch.
    """
    divisor = 2 ** len(_UNET_ENCODER_SPEC)
    for dim in input_shape[:-1]:
        if dim is not None and dim % divisor != 0:
            raise ValueError(
                f'unet: spatial input dims must be divisible by {divisor}, '
                f'got input_shape={input_shape}')

    img_input = keras.Input(input_shape)

    # Encoder (blocks 1-4): keep each block's output for the skip connections.
    skips = []
    x = img_input
    for block_idx, (filters, n_convs) in enumerate(_UNET_ENCODER_SPEC, start=1):
        x = _encoder_block(x, filters, n_convs, block_idx)
        skips.append(x)
        x = layers.MaxPooling2D()(x)

    # Bottleneck (block 5): same width as block 4, no pooling afterwards.
    x = _encoder_block(x, 512, 3, len(_UNET_ENCODER_SPEC) + 1)

    # Decoder: mirror the encoder, deepest skip first (UP1 pairs with block 4).
    decoder_filters = [filters for filters, _ in reversed(_UNET_ENCODER_SPEC)]
    for up_idx, (filters, skip) in enumerate(zip(decoder_filters, reversed(skips)), start=1):
        x = _decoder_block(x, skip, filters, up_idx)

    # Per-pixel class probabilities (softmax over the channel axis).
    x = layers.Conv2D(num_classes, (3, 3), activation='softmax', padding='same')(x)

    return keras.Model(img_input, x)

# Re-running the model-definition cells keeps old models in the Keras session;
# uncomment the next line to free that RAM before building a fresh model.
# keras.backend.clear_session()

# Build the segmentation network with its default configuration, spelled out.
model = unet(num_classes=3, input_shape=(160, 160, 1))
model.summary()

# Тренировка модели

# In [None]:
# Derive the class count from the model so the metric cannot drift from the
# architecture's softmax width.
n_classes = model.output_shape[-1]

# categorical_crossentropy expects one-hot y, and the model emits softmax
# probabilities. Plain MeanIoU expects integer class ids and would cast both
# tensors to integers, reporting a meaningless score. OneHotMeanIoU takes the
# argmax over the class axis of y_true and y_pred before computing IoU.
# NOTE(review): OneHotMeanIoU requires a recent TF 2.x; confirm for the pinned version.
model.compile(optimizer=keras.optimizers.Adam(learning_rate=0.005),
              loss="categorical_crossentropy",
              metrics=[tf.keras.metrics.OneHotMeanIoU(num_classes=n_classes)])

callbacks = [
    # With save_best_only=True the file is overwritten only when the monitored
    # quantity improves (ModelCheckpoint's default monitor is 'val_loss').
    keras.callbacks.ModelCheckpoint("anode_defects_unet.h5", save_best_only=True)
]

# Train the model, doing validation at the end of each epoch.
# batch_size=None lets Keras use its default batch size (32 for array inputs).
# validation_split holds out the LAST 20% of X/y, selected before shuffling. If
# the arrays are ordered (e.g. by source image or class), shuffle them first.
epochs = 5
model.fit(X, y, epochs=epochs, batch_size=None, validation_split=0.2, verbose=1, callbacks=callbacks)