In [1]:
import glob
import itertools
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [1]:
import argparse

In [None]:
# Command-line options are not consumed by this notebook. Parse an empty argv
# so the kernel's own launch arguments (e.g. ``-f <connection-file>``) do not
# make argparse error out. ``add_argument()`` with no name/flag would raise
# TypeError, so no arguments are registered here.
parser = argparse.ArgumentParser(description='Train a U-Net segmentation model.')
args = parser.parse_args([])

In [2]:
class UNet:
    """U-Net semantic-segmentation model built with ``tf.keras``.

    Every 3x3 convolution uses ``padding='valid'``, so the output map is smaller
    than the input (a 572x572 input gives a 388x388 output). Label masks are
    therefore fed at the output resolution (see ``_valid_output_size``).

    Args:
        input_width, input_height: size the RGB input images are resized to.
        num_classes: number of segmentation classes (softmax output channels).
        train_images, train_instances: directories holding the training images
            and their label masks as ``*.png``. Files are paired by sorted name.
        val_images, val_instances: the same for the validation split.
        ecophs: number of training epochs (spelling kept for compatibility).
        lr: initial Adam learning rate.
        lr_decay: inverse-time decay rate, ``lr_t = lr / (1 + lr_decay * step)``.
        batch_size: number of samples per generated batch.
        save_path: file that ``train`` writes the trained model to.
    """

    # Filters per encoder level. The decoder mirrors the first four in reverse.
    ENCODER_FILTERS = (64, 128, 256, 512, 1024)

    def __init__(
        self,
        input_width,
        input_height,
        num_classes,
        train_images,
        train_instances,
        val_images,
        val_instances,
        ecophs,
        lr,
        lr_decay,
        batch_size,
        save_path
    ):
        self.input_width = input_width
        self.input_height = input_height
        self.num_classes = num_classes
        self.train_images = train_images
        self.train_instances = train_instances
        self.val_images = val_images
        self.val_instances = val_instances
        self.ecophs = ecophs
        self.lr = lr
        self.lr_decay = lr_decay
        self.batch_size = batch_size
        self.save_path = save_path

    @staticmethod
    def _double_conv(x, filters):
        """Apply two 3x3 'valid' ReLU convolutions with ``filters`` channels each."""
        for _ in range(2):
            x = tf.keras.layers.Conv2D(filters, (3, 3), padding='valid', activation='relu')(x)
        return x

    def lefNetwork(self, inputs):
        """Contracting path.

        Returns ``[o_1, ..., o_5]``, the feature map after the double
        convolution of each level. ``o_1``-``o_4`` are the skip connections and
        ``o_5`` is the bottleneck. Levels are separated by 2x2 max-pooling.
        """
        outputs = []
        x = inputs
        for level, filters in enumerate(self.ENCODER_FILTERS):
            if level > 0:
                x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2))(x)
            # Both convolutions of a level use the level's filter count
            # (previously the second one was hard-coded to 64 on every level).
            x = self._double_conv(x, filters)
            outputs.append(x)
        return outputs

    @staticmethod
    def _crop_to_match(skip, target):
        """Center-crop ``skip`` so its height/width equal those of ``target``.

        Replaces crop sizes that were hard-coded for 572x572 inputs. Requires
        static spatial shapes, i.e. a fixed input size.
        """
        dh = skip.shape[1] - target.shape[1]
        dw = skip.shape[2] - target.shape[2]
        return tf.keras.layers.Cropping2D(
            cropping=((dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2))
        )(skip)

    def rightNetwork(self, inputs):
        """Expanding path.

        Takes the list returned by ``lefNetwork``. At each level it upsamples
        2x, concatenates the center-cropped skip connection, and applies a
        double convolution. It ends with a 1x1 convolution to ``num_classes``
        channels and a softmax.
        """
        *skips, x = inputs
        decoder_filters = reversed(self.ENCODER_FILTERS[:-1])
        for skip, filters in zip(reversed(skips), decoder_filters):
            x = tf.keras.layers.UpSampling2D((2, 2))(x)
            x = tf.keras.layers.concatenate([self._crop_to_match(skip, x), x], axis=3)
            x = self._double_conv(x, filters)
        x = tf.keras.layers.Conv2D(self.num_classes, (1, 1), padding='valid')(x)
        return tf.keras.layers.Activation('softmax')(x)

    @staticmethod
    def _valid_output_size(size, depth=len(ENCODER_FILTERS) - 1):
        """Spatial output size of the network for an input side length ``size``.

        Mirrors the layer arithmetic. Each double 'valid' 3x3 conv removes 4
        pixels, each 2x2 max-pool halves the size (floor), and each upsampling
        doubles it. Example: 572 -> 388.
        """
        for _ in range(depth):
            size = (size - 4) // 2
        size -= 4
        for _ in range(depth):
            size = size * 2 - 4
        return size

    def build_model(self):
        """Assemble the encoder and decoder into a ``tf.keras.Model``."""
        inputs = tf.keras.Input(shape=[self.input_height, self.input_width, 3])
        left_output = self.lefNetwork(inputs)
        right_output = self.rightNetwork(left_output)

        model = tf.keras.Model(inputs = inputs, outputs = right_output)

        return model

    def train(self, steps_per_epoch=4, validation_steps=4):
        """Build, compile, fit and save the model to ``self.save_path``.

        Args:
            steps_per_epoch: training batches drawn per epoch from the infinite
                training generator.
            validation_steps: validation batches drawn per epoch.
        """
        G_train = self.dataGenerator(mode='training')
        G_eval = self.dataGenerator(mode='validation')

        model = self.build_model()
        model.summary()
        # Adam's second positional parameter is beta_1, so Adam(lr, lr_decay)
        # set beta_1 = lr_decay. Use an explicit schedule instead.
        # InverseTimeDecay with decay_steps=1 computes lr / (1 + decay * step),
        # the legacy Keras ``decay`` formula, and works on old and new optimizers.
        schedule = tf.keras.optimizers.schedules.InverseTimeDecay(
            initial_learning_rate=self.lr, decay_steps=1, decay_rate=self.lr_decay
        )
        model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=schedule),
            loss='categorical_crossentropy',
            metrics=['categorical_accuracy', 'Recall', 'AUC'],
        )
        # fit_generator is deprecated or removed in newer Keras. fit accepts
        # Python generators directly.
        model.fit(
            G_train,
            steps_per_epoch=steps_per_epoch,
            validation_data=G_eval,
            validation_steps=validation_steps,
            epochs=self.ecophs,
        )
        save_dir = os.path.dirname(self.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        model.save(self.save_path)

    def dataGenerator(self, mode):
        """Yield ``(images, one_hot_labels)`` batches forever for ``mode``.

        ``mode`` is 'training' or 'validation'. Images are float32 scaled to
        [0, 1]. Labels are one-hot at the network's output resolution.

        Raises (on the first ``next``):
            ValueError: unknown ``mode``, no images found, or mismatched
                image/label counts.
            FileNotFoundError: a listed file cannot be decoded by OpenCV.
        """
        if mode == 'training':
            image_dir, label_dir = self.train_images, self.train_instances
        elif mode == 'validation':
            image_dir, label_dir = self.val_images, self.val_instances
        else:
            raise ValueError(
                "mode must be 'training' or 'validation', got {!r}".format(mode)
            )
        yield from self._batches(image_dir, label_dir)

    @staticmethod
    def _read(path, flags):
        """``cv2.imread`` that raises instead of silently returning ``None``."""
        image = cv2.imread(path, flags)
        if image is None:
            raise FileNotFoundError('could not read image {!r}'.format(path))
        return image

    def _batches(self, image_dir, label_dir):
        """Infinite batch generator over ``*.png`` pairs from two directories."""
        images = sorted(glob.glob(os.path.join(image_dir, '*.png')))
        labels = sorted(glob.glob(os.path.join(label_dir, '*.png')))
        if not images:
            raise ValueError('no *.png images found in {!r}'.format(image_dir))
        if len(images) != len(labels):
            raise ValueError(
                '{} images in {!r} but {} label masks in {!r}; files are paired '
                'by sorted name, so the counts must match'.format(
                    len(images), image_dir, len(labels), label_dir
                )
            )
        out_h = self._valid_output_size(self.input_height)
        out_w = self._valid_output_size(self.input_width)
        pairs = itertools.cycle(zip(images, labels))
        while True:
            batch_x, batch_y = [], []
            for _ in range(self.batch_size):
                image_path, label_path = next(pairs)
                image = cv2.resize(
                    self._read(image_path, cv2.IMREAD_COLOR),
                    (self.input_width, self.input_height),
                )
                # Labels must match the (smaller) output map. Nearest-neighbour
                # keeps class ids integral. This is a no-op if the masks are
                # already at the output size.
                label = cv2.resize(
                    self._read(label_path, cv2.IMREAD_GRAYSCALE),
                    (out_w, out_h),
                    interpolation=cv2.INTER_NEAREST,
                )
                # Scale to [0, 1] to match the preprocessing used at inference.
                batch_x.append(image.astype(np.float32) / 255.0)
                batch_y.append(
                    tf.keras.utils.to_categorical(label, num_classes=self.num_classes)
                )
            yield np.array(batch_x), np.array(batch_y)
                    

In [5]:
# All hyper-parameters in one place. With a 572x572 input the valid-padding
# U-Net produces 388x388 maps, the size the label masks are resized to below.
UNET_CONFIG = dict(
    # Network geometry.
    input_width=572,
    input_height=572,
    num_classes=3,
    # Data locations (directory prefixes; '*.png' is appended when listing).
    train_images='./datasets/training/images/',
    train_instances='./datasets/training/instances/',
    val_images='./datasets/validation/images/',
    val_instances='./datasets/validation/instances/',
    # Optimisation.
    ecophs=10,
    lr=0.0001,
    lr_decay=0.000001,
    batch_size=2,
    # Output.
    save_path='./my_models/model_unet.h5',
)
unet = UNet(**UNET_CONFIG)


In [8]:
# Build the (untrained) Keras graph and print its layer table. This is a quick
# way to check the spatial sizes produced by the unpadded ('valid') convolutions.
model = unet.build_model()
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 572, 572, 3) 0                                            
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 570, 570, 64) 1792        input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 568, 568, 64) 36928       conv2d_38[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 284, 284, 64) 0           conv2d_39[0][0]                  
____________________________________________________________________________________________

In [None]:
# Build, compile and fit the network for `unet.ecophs` epochs using the
# training/validation generators, then save it to `unet.save_path`.
unet.train()

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_3 (InputLayer)            [(None, 572, 572, 3) 0                                            
__________________________________________________________________________________________________
conv2d_38 (Conv2D)              (None, 570, 570, 64) 1792        input_3[0][0]                    
__________________________________________________________________________________________________
conv2d_39 (Conv2D)              (None, 568, 568, 64) 36928       conv2d_38[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 284, 284, 64) 0           conv2d_39[0][0]                  
____________________________________________________________________________________________

Epoch 1/10


## 数据准备

In [6]:
import glob
import cv2
import os
import numpy as np

In [5]:
# Convert the raw instance masks, in place, into training labels:
#   * resize them to the network's output resolution (388x388 for a 572 input),
#   * remap the raw grey levels 38 and 75 to class ids 1 and 2 (0 = background).
# NOTE: this overwrites the files. Nearest-neighbour resizing keeps the masks
# integral, so re-running the cell leaves already-converted masks unchanged.
LABEL_DIR = './datasets/training/instances/'
LABEL_SIZE = (388, 388)           # (width, height) of the U-Net output map
LABEL_VALUE_MAP = {38: 1, 75: 2}  # raw grey level -> class id


def remap_label_values(label, value_map=None):
    """Return a copy of ``label`` with each raw value in ``value_map`` replaced by its class id.

    Matching is done against the unmodified input, so one mapping can never
    re-map the output of another. Defaults to ``LABEL_VALUE_MAP``.
    """
    if value_map is None:
        value_map = LABEL_VALUE_MAP
    remapped = label.copy()
    for raw_value, class_id in value_map.items():
        remapped[label == raw_value] = class_id
    return remapped


for label_path in sorted(glob.glob(os.path.join(LABEL_DIR, '*.png'))):
    label = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
    if label is None:
        raise FileNotFoundError('could not read label mask {!r}'.format(label_path))
    # INTER_NEAREST: bilinear interpolation would invent in-between values at
    # class boundaries, which are not valid class ids.
    label = cv2.resize(label, LABEL_SIZE, interpolation=cv2.INTER_NEAREST)
    label = remap_label_values(label)
    unexpected = sorted(set(np.unique(label).tolist()) - {0, *LABEL_VALUE_MAP.values()})
    if unexpected:
        print('warning: {} contains unmapped values {}'.format(label_path, unexpected))
    # imwrite overwrites the file; no need to delete it first.
    cv2.imwrite(label_path, label)

In [32]:
# Inspect which values one instance mask contains. After conversion these
# should be class ids in 0..num_classes-1.

import matplotlib.pyplot as plt

label = cv2.imread('./datasets/training/instances/label4.png', 0)
if label is None:
    # cv2.imread returns None (it does not raise) for missing/unreadable files,
    # and np.unique(None) would silently show array([None]).
    raise FileNotFoundError('could not read ./datasets/training/instances/label4.png')
np.unique(label)

array([0], dtype=uint8)

## 模型调用

In [None]:
# Run the trained model on the first validation images and show the predicted
# class map (argmax over the softmax channels) for each.
MODEL_PATH = './my_models/model_unet.h5'  # the save_path passed to UNet(...)
VAL_IMAGE_DIR = 'datasets/validation/images/'
N_PREVIEW = 10

# compile=False: only inference is needed, so optimizer/metric state is not restored.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
# The validation pipeline lists '*.png' in this directory, so preview the same files.
for image_path in sorted(glob.glob(os.path.join(VAL_IMAGE_DIR, '*.png')))[:N_PREVIEW]:
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError('could not read image {!r}'.format(image_path))
    # Resize to the network input and scale to [0, 1]; this must match the
    # preprocessing used when the model was trained.
    image = cv2.resize(image, (572, 572)) / 255.0

    pred = model.predict(np.expand_dims(image, 0), verbose=0)
    class_map = np.argmax(pred[0], axis=-1)

    fig, ax = plt.subplots()
    ax.imshow(class_map)
    ax.set_title('predicted classes: ' + os.path.basename(image_path))
    ax.axis('off')
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop