In [35]:
from tensorflow.keras import Input, optimizers, datasets, Sequential,Model
from tensorflow.keras.layers import Conv2D,MaxPooling2D,Conv2DTranspose
from tensorflow.keras.losses import binary_crossentropy
from tensorflow import concat
import tensorflow as tf
from unet import get_unet
from resunet3 import ResUNet
from deeplabv3plus0 import DeeplabV3Plus
from resunetplusplus import ResUnetPlusPlus
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
from unet_plus import unet_plus_plus
from tensorflow.keras.callbacks import ModelCheckpoint
#from generator_data1 import get_dataset
import cv2
import numpy as np
import datetime
import os
import random

gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

smooth = 1.
def dice_coef(y_true, y_pred):
   
    y_true_f = K.flatten(y_true)#no one_hot
    y_pred_f = K.flatten(y_pred)
    y_true_f=tf.cast(y_true_f,tf.float32)
    intersection = K.sum(y_true_f * y_pred_f)
    return (2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)

def miou(y_true, y_pred):
   
    y_true_f = K.flatten(y_true)#no one_hot
    y_pred_f = K.flatten(y_pred)
    y_true_f=tf.cast(y_true_f,tf.float32)
    intersection = K.sum(y_true_f * y_pred_f)
    return (intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) - intersection + smooth)


def dice_coef_loss(y_true, y_pred):
    return 1-dice_coef(y_true, y_pred)

def binary_focal_loss(y_true,y_pred):
    gamma=2
    alpha=0.25
    alpha = tf.constant(alpha, dtype=tf.float32)
    gamma = tf.constant(gamma, dtype=tf.float32)

    y_true = K.flatten(y_true)
    y_pred = K.flatten(y_pred)
    y_true = tf.cast(y_true, tf.float32)
    alpha_t = y_true*alpha + (K.ones_like(y_true)-y_true)*(1-alpha)
    p_t = y_true*y_pred + (K.ones_like(y_true)-y_true)*(K.ones_like(y_true)-y_pred) + K.epsilon()
    focal_loss = - alpha_t * K.pow((K.ones_like(y_true)-p_t),gamma) * K.log(p_t)
    return K.mean(focal_loss)
    


def f1(y_true, y_pred):
    def recall(y_true, y_pred):
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
        recall = true_positives / (possible_positives + K.epsilon())
        return recall

    def precision(y_true, y_pred):
        true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
        predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
        precision = true_positives / (predicted_positives + K.epsilon())
        return precision
    precision = precision(y_true, y_pred)
    recall = recall(y_true, y_pred)
    return 2*((precision*recall)/(precision+recall+K.epsilon()))


def dice_binary_cross(y_true, y_pred):#后续更新
    return binary_crossentropy(y_true, y_pred)+dice_coef_loss(y_true, y_pred)


def dice_focal_loss(y_true,y_pred):
    return dice_coef_loss(y_true, y_pred)+binary_focal_loss(y_true,y_pred)


In [51]:
def get_train(name,img):
    if img:
        x=cv2.imread('/home/mist/train/img/'+name)
        x=cv2.resize(x,(512,512),interpolation=cv2.INTER_CUBIC)
        x=np.array(x,dtype=np.float32)
        #x-=x.mean()
        #x/=x.std()
        x=x/255.
    else:
        x=cv2.imread('/home/mist/train/mask/'+name)[:,:,2]
        x[x==128]=255
        x=cv2.resize(x,(512,512),interpolation=cv2.INTER_CUBIC).reshape([512,512,1])
        x=x/255.0
    return x

count=1
def generate_arrays_from_file(X):
    X=np.array(X)
    index=np.arange(len(X))
    random.shuffle(index)
    X=X[index]
    batch_size = 4
    global count
    while 1:
        if count>=len(X)/batch_size:
            count=1
        index =X[(count - 1) * batch_size:count * batch_size]
        batch_x = np.array([get_train(i,True) for i in index]).astype(np.float32)
        batch_y = np.array([get_train(i,False) for i in index]).astype(np.int32)
        # print(batch_x.shape,type(batch_x),batch_y.shape,type(batch_y))
        batch_x=tf.convert_to_tensor(batch_x)
        batch_y=tf.convert_to_tensor(batch_y)
        count = count + 1
        yield batch_x, batch_y

In [52]:
def get_test(name,img):
    if img:
        x=cv2.imread('/home/mist/test/img/'+name)
        x=cv2.resize(x,(512,512),interpolation=cv2.INTER_CUBIC)
        x=np.array(x,dtype=np.float32)
        #x-=x.mean()
        #x/=x.std()
        x=x/255.
    else:
        x=cv2.imread('/home/mist/test/mask/'+name)[:,:,2]
        x[x==128]=255
        x=cv2.resize(x,(512,512),interpolation=cv2.INTER_CUBIC).reshape([512,512,1])
        x=x/255.0
    return x

ct=1
def generate_from_file(X):
    X=np.array(X)
    batch_size = 4
    global ct
    while 1:
        if ct>=len(X)/batch_size:
            ct=1
        index =X[(ct - 1) * batch_size:ct * batch_size]
        batch_x = np.array([get_test(i,True) for i in index]).astype(np.float32)
        batch_y = np.array([get_test(i,False) for i in index]).astype(np.int32)
        # print(batch_x.shape,type(batch_x),batch_y.shape,type(batch_y))
        batch_x=tf.convert_to_tensor(batch_x)
        batch_y=tf.convert_to_tensor(batch_y)
        ct = ct + 1
        yield batch_x, batch_y

In [53]:
train_x=os.listdir('/home/mist/train/img/')
#train_x.remove('.ipynb_checkpoints')
test_x=os.listdir('/home/mist/test/img/')
test_x.remove('.ipynb_checkpoints')

In [54]:
#import os
#if not os.path.exists('/data/img'): os.makedirs('/data/img')
#if not os.path.exists('/data/mask'): os.makedirs('/data/mask')

In [55]:
#model=ResUNet()
#model=get_unet()
#model=unet_plus_plus()
model=DeeplabV3Plus()
#model=ResUnetPlusPlus()

In [None]:
batch_size=4
log_dir=os.path.join('logg',datetime.datetime.now().strftime("%Y%m%D"))
# 创建tensorboard回调
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir,embeddings_freq=1)
model_checkpoint = ModelCheckpoint('logg/unet_weights.h5', monitor='val_loss', save_best_only=True,mode='min')#val_dice_coef
# reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 
#                                             patience=3, 
#                                             verbose=1, 
#                                             factor=0.1,
#                                             epsilon=0.0001,
#                                             mode='max',
#                                             min_lr=0)#0.00001
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', 
                                            patience=3, 
                                            verbose=1, 
                                            factor=0.5,
                                            min_lr=0.00001)#0.00001
earlystop=tf.keras.callbacks.EarlyStopping(monitor='val_dice_coef', min_delta=0, patience=5, verbose=0, mode='max', restore_best_weights=True)
callbacks=[model_checkpoint,tensorboard_callback]

rms=optimizers.Adam(lr=1e-4)
model.compile(optimizer=rms,
              loss=[dice_focal_loss],#binary_focal_loss(alpha=.25, gamma=2),dice_binary_cross
              metrics=[dice_coef,miou,tf.keras.metrics.BinaryAccuracy()])#,tf.keras.metrics.MeanIoU(num_classes=2)
history=model.fit_generator(generate_arrays_from_file(train_x), steps_per_epoch=len(train_x)/batch_size, epochs=150,callbacks=callbacks,verbose=1,validation_data=generate_from_file(test_x),validation_steps=len(test_x)/batch_size)

Epoch 1/150
Epoch 2/150
Epoch 3/150
Epoch 4/150
Epoch 5/150
Epoch 6/150
Epoch 7/150
Epoch 8/150
Epoch 9/150
Epoch 10/150
Epoch 11/150
Epoch 12/150
Epoch 13/150
Epoch 14/150
Epoch 15/150
Epoch 16/150
Epoch 17/150
Epoch 18/150
Epoch 19/150
Epoch 20/150
Epoch 21/150
Epoch 22/150
Epoch 23/150
Epoch 24/150
Epoch 25/150
Epoch 26/150
Epoch 27/150
Epoch 28/150
Epoch 29/150
Epoch 30/150
Epoch 31/150
Epoch 32/150
Epoch 33/150
Epoch 34/150
Epoch 35/150
Epoch 36/150
Epoch 37/150
Epoch 38/150
Epoch 39/150
Epoch 40/150
Epoch 41/150
Epoch 42/150
Epoch 43/150
Epoch 44/150
Epoch 48/150
Epoch 49/150
Epoch 50/150
Epoch 51/150
Epoch 52/150
Epoch 53/150
Epoch 54/150
Epoch 55/150
Epoch 56/150
Epoch 57/150
Epoch 58/150
Epoch 59/150
Epoch 60/150
Epoch 61/150
Epoch 62/150
Epoch 63/150
Epoch 64/150
Epoch 65/150
Epoch 66/150
Epoch 67/150
Epoch 68/150
Epoch 69/150
Epoch 70/150
Epoch 71/150
Epoch 72/150
Epoch 73/150
Epoch 74/150
Epoch 75/150
Epoch 76/150
Epoch 77/150
Epoch 78/150
Epoch 79/150
Epoch 80/150
Epoch 81

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



Epoch 94/150
Epoch 95/150
Epoch 96/150
Epoch 97/150
Epoch 98/150
Epoch 99/150
Epoch 100/150
Epoch 101/150
Epoch 102/150
Epoch 103/150
Epoch 104/150
Epoch 105/150
Epoch 106/150
Epoch 107/150
Epoch 108/150
Epoch 109/150
Epoch 110/150
Epoch 111/150
Epoch 112/150
Epoch 113/150
Epoch 114/150
Epoch 115/150
Epoch 116/150
Epoch 117/150
Epoch 118/150
Epoch 119/150
Epoch 120/150
Epoch 121/150

In [58]:
current_dir=os.listdir('/home/mist/test/img')
current_dir.remove('.ipynb_checkpoints')
test_x=np.zeros([len(current_dir),512,512,3])
test_y=np.zeros([len(current_dir),512,512,1])
test_shape=np.zeros([len(current_dir),2])
i=0
for name in current_dir:
    x=cv2.imread('/home/mist/test/img/'+name)
    shapes=x.shape
    x=cv2.resize(x,(512,512),interpolation=cv2.INTER_CUBIC)
    x=np.array(x,dtype=np.float32)
    x=x/255.
    test_x[i]=x
    y=cv2.imread('/home/mist/test/mask/'+name)[:,:,2]
    y[y==128]=255
    y=cv2.resize(y,(512,512),interpolation=cv2.INTER_CUBIC).reshape([512,512,1])
    y=y/255.0
    test_y[i]=y
    test_shape[i]=np.array([shapes[0],shapes[1]])
    i+=1

In [59]:
test_x,test_y=np.array(test_x).astype(np.float32),np.array(test_y).astype(np.int32)

In [60]:
#model=ResUNet()
#model=get_unet()
#model=unet_plus_plus()
model=DeeplabV3Plus()
#model=ResUnetPlusPlus()
model.load_weights('/home/mist/logg/unet_weights.h5')

In [61]:
current_dir=os.listdir('/home/mist/test/img')
current_dir.remove('.ipynb_checkpoints')
for dir in current_dir:
    x=cv2.imread('/home/mist/test/img/'+dir)
    cv2.imwrite('/home/mist/pred/img/'+dir,x)
    y=cv2.imread('/home/mist/test/img/'+dir)
    cv2.imwrite('/home/mist/pred/mask/'+dir,y)
    shapes=x.shape
    x=cv2.resize(x,(512,512),interpolation=cv2.INTER_CUBIC)
    x=np.array(x,dtype=np.float32)
    x=x/255.
    x=x.reshape([1,512,512,3])
    pre=model.predict(x)
    pre=pre*255.
    x=x.reshape([512,512,3])
    x=x*255.
    x=cv2.resize(x,(shapes[1],shapes[0]),interpolation=cv2.INTER_CUBIC)
    pre=cv2.resize(pre.reshape([512,512,1]),(shapes[1],shapes[0]),interpolation=cv2.INTER_CUBIC)
    cv2.imwrite('/home/mist/pred/pre/'+dir,pre)
    
    #提取轮廓
    pre=cv2.imread('/home/mist/pred/pre/'+dir)
    imgray = cv2.cvtColor(pre, cv2.COLOR_BGR2GRAY)
    ret, thresh = cv2.threshold(imgray, 127, 255, 0)
    contours, hierarchy = cv2.findContours(thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    x=cv2.drawContours(x, contours, -1, (0,255,0), 3)
    cv2.imwrite('/home/mist/pred/pre/'+dir,x)
    

In [62]:
current_dir=os.listdir('/home/mist/test/img')
current_dir.remove('.ipynb_checkpoints')
for dir in current_dir:
    x=cv2.imread('/home/mist/test/img/'+dir)
    cv2.imwrite('/home/mist/pred/img/'+dir,x)
    y=cv2.imread('/home/mist/test/img/'+dir)
    cv2.imwrite('/home/mist/pred/mask/'+dir,y)
    shapes=x.shape
    x=cv2.resize(x,(512,512),interpolation=cv2.INTER_CUBIC)
    x=np.array(x,dtype=np.float32)
    x=x/255.
    x=x.reshape([1,512,512,3])
    pre=model.predict(x)
    pre=pre*255.
    x=x.reshape([512,512,3])
    x=x*255.
    x=cv2.resize(x,(shapes[1],shapes[0]),interpolation=cv2.INTER_CUBIC)
    pre=cv2.resize(pre.reshape([512,512,1]),(shapes[1],shapes[0]),interpolation=cv2.INTER_CUBIC)
    cv2.imwrite('/home/mist/pred/premask/'+dir,pre)

In [7]:
dirs=os.listdir('~/1/test/img/')
i=0
for dir in dirs:
    x=cv2.imread('~/1/test/img/'+dir)
    y=cv2.imread('~/1/test/mask/'+dir)
    cv2.imwrite('/home/mist/test/img/'+dir,x)
    cv2.imwrite('/home/mist/test/mask/'+dir,y)
    i+=1
    print(i)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
