# Segmentation avec U-Net

Le résultat de la segmentation précédente par l'algorithme `watershed` est bon mais pas entièrement satisfaisant. Bien qu'une très grande majorité de leucocytes soit bien extrait des images, il y a quelques cas où la segmentation échoue notamment pour les cellules dont la surface est étendu et non convexe.

Dans la technique précédente, les maximums locaux déterminent le nombre de cellules dans l'image et nous pouvons nous retrouver avec 2 minimums locaux pour la même cellule : il en résulte donc l'extraction de 2 cellules au lieu d'une seule.

Nous allons utilisé le réseau U-Net pour réaliser cette segmentation donc dans un mode supervisé. Pour cela, nous devons disposer des images des cellules ainsi que d'un masque binaire représentant la zone d'intérêt (ROI) que nous voulons extraire. Le réseau va donc apprendre à repérer les ROI dans les images pour pouvoir les extraire plus tard par inférence.

Cependant nous ne disposons pas de ces masques dans nos données et il serait fastidieux de les créer manuellement. Nous allons donc créer ces masques automatiquement. Dans la tentative précédente, nous avons généré ces masques mais comme ils ne sont pas satisfaisants pour toutes les cellules, nous ne pouvons pas leur faire confiance.

Nous commençons par vérifier que nous avons bien accès au GPU de la machine :

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorflow as tf
tf.config.experimental.list_physical_devices()

[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

Nous créons le modèle U-Net :

![U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

In [2]:
from keras.models import Model
from keras.layers import (Input, Conv2D, BatchNormalization, Activation,
                          MaxPool2D, Conv2DTranspose, Concatenate)

INPUT_SHAPE = (128, 128, 3)
BATCH_SIZE = 16


def conv_block(inp, num_filters):
    # Block 1
    x = Conv2D(num_filters, (3, 3), strides=1, padding='same')(inp)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    # Block 2
    x = Conv2D(num_filters, (3, 3), strides=1, padding='same')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    return x


def encoder_block(inp, num_filters):
    x = conv_block(inp, num_filters)
    p = MaxPool2D((2, 2))(x)
    return x, p


def decoder_block(inp, enc, num_filters):
    x = Conv2DTranspose(num_filters, (2, 2), strides=2, padding='same')(inp)
    x = Concatenate()([x, enc])
    x = conv_block(x, num_filters)
    return x


def create_model():
    # Input
    inputs = Input(INPUT_SHAPE)

    # Contraction part (top-down)
    e1, p1 = encoder_block(inputs, 64)
    e2, p2 = encoder_block(p1, 128)
    e3, p3 = encoder_block(p2, 256)
    e4, p4 = encoder_block(p3, 512)

    # Bottleneck
    b1 = conv_block(p4, 1024)

    # Expansion part (botom-up)
    d1 = decoder_block(b1, e4, 512)
    d2 = decoder_block(d1, e3, 256)
    d3 = decoder_block(d2, e2, 128)
    d4 = decoder_block(d3, e1, 64)

    # Output
    outputs = Conv2D(1, 1, padding="same", activation='sigmoid')(d4)

    model = Model(inputs, outputs, name="U-Net")
    return model


model = create_model()
model.summary()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

Model: "U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 128, 128, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 128, 128, 64  1792        ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 128, 128, 64  256        ['conv2d[0][0]']                 
 alization)                     )                                                             

 batch_normalization_9 (BatchNo  (None, 8, 8, 1024)  4096        ['conv2d_9[0][0]']               
 rmalization)                                                                                     
                                                                                                  
 activation_9 (Activation)      (None, 8, 8, 1024)   0           ['batch_normalization_9[0][0]']  
                                                                                                  
 conv2d_transpose (Conv2DTransp  (None, 16, 16, 512)  2097664    ['activation_9[0][0]']           
 ose)                                                                                             
                                                                                                  
 concatenate (Concatenate)      (None, 16, 16, 1024  0           ['conv2d_transpose[0][0]',       
                                )                                 'activation_7[0][0]']           
          

 batch_normalization_17 (BatchN  (None, 128, 128, 64  256        ['conv2d_17[0][0]']              
 ormalization)                  )                                                                 
                                                                                                  
 activation_17 (Activation)     (None, 128, 128, 64  0           ['batch_normalization_17[0][0]'] 
                                )                                                                 
                                                                                                  
 conv2d_18 (Conv2D)             (None, 128, 128, 1)  65          ['activation_17[0][0]']          
                                                                                                  
Total params: 31,055,297
Trainable params: 31,043,521
Non-trainable params: 11,776
__________________________________________________________________________________________________


Nous avons isolé 1600 images et masques de la segmentation à base de couleur. Sur ces 1600 images, nous avons rejeté 259 et conservé 1341 images. Nous créons nos jeux d'entraînement, de test et de validation.

In [3]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pathlib
import cv2
from sklearn.model_selection import train_test_split

DATA_DIR = pathlib.Path('/home/damien/yawbcc_data/barcelona_256_unet')

image_directory = DATA_DIR / 'images'
mask_directory = DATA_DIR / 'masks'

# Load datasets
image_dataset = pd.Series(sorted([filename for filename in image_directory.glob('*.jpg')]))
mask_dataset = pd.Series(sorted([filename for filename in mask_directory.glob('*.npy')]))

# Split datasets
X_train, X_test, y_train, y_test = train_test_split(image_dataset, mask_dataset, test_size=0.2, random_state=2022)
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.2, random_state=2022)

Nous créons une classe pour générer les batchs d'images (X) et de masques (y)

In [4]:
class WBCDataSequence(tf.keras.utils.Sequence):

    def __init__(self, image_set, mask_set, batch_size=32, image_size=(128, 128)):
        self.image_set = np.array(image_set)
        self.mask_set = np.array(mask_set)
        self.batch_size = batch_size
        self.image_size = image_size

    def __get_input(self, path):
        image = tf.keras.preprocessing.image.load_img(path, color_mode='rgb')
        image_arr = tf.keras.preprocessing.image.img_to_array(image)
        image_arr = cv2.cvtColor(image_arr, cv2.COLOR_RGB2XYZ)
        image_arr = tf.image.resize(image_arr, self.image_size).numpy()
        return image_arr / 255.

    def __get_output(self, path):
        mask_arr = (cv2.resize(np.load(path), self.image_size) > 0).astype(np.uint8)
        return mask_arr

    def __get_data(self, images, masks):
        image_batch = images
        mask_batch = masks
        image_batch = np.asarray([self.__get_input(path) for path in image_batch])
        mask_batch = np.asarray([self.__get_output(path) for path in mask_batch])
        return image_batch, mask_batch

    def __getitem__(self, index):
        images = self.image_set[index * self.batch_size:(index + 1) * self.batch_size]
        masks = self.mask_set[index * self.batch_size:(index + 1) * self.batch_size]
        images, masks = self.__get_data(images, masks)        
        return images, masks

    def __len__(self):
        return len(self.image_set) // self.batch_size + (len(self.image_set) % self.batch_size > 0)


train_ds = WBCDataSequence(X_train, y_train, image_size=INPUT_SHAPE[:2], batch_size=BATCH_SIZE)
valid_ds = WBCDataSequence(X_valid, y_valid, image_size=INPUT_SHAPE[:2], batch_size=BATCH_SIZE)
test_ds = WBCDataSequence(X_test, y_test, image_size=INPUT_SHAPE[:2], batch_size=BATCH_SIZE)

### Entraînement du modèle

In [None]:
history = model.fit(train_ds,
                    batch_size=BATCH_SIZE,
                    epochs=25,
                    validation_data=valid_ds,
                    shuffle=False)

model.evaluate(test_ds)
model.save(DATA_DIR / 'unet_128_5.hdf5')

In [5]:
from keras.models import load_model
model = load_model(DATA_DIR / 'unet_128_5.hdf5', compile=True)
model.evaluate(test_ds)



[0.008620369248092175, 0.9972246289253235]

### Interprétation des résultats

In [6]:
import itertools

preds = model.predict(test_ds)
batches = list(zip(*test_ds))
images = list(itertools.chain(*batches[0]))
masks = list(itertools.chain(*batches[1]))



In [7]:
pd.DataFrame({'X': [p.name for p in X_test],
              'y': [p.name for p in y_test]})

Unnamed: 0,X,y
0,BNE_727973.jpg,BNE_727973.npy
1,MO_993578.jpg,MO_993578.npy
2,MMY_873744.jpg,MMY_873744.npy
3,MO_902264.jpg,MO_902264.npy
4,PLATELET_530681.jpg,PLATELET_530681.npy
...,...,...
264,BNE_355910.jpg,BNE_355910.npy
265,BNE_344758.jpg,BNE_344758.npy
266,PLATELET_885144.jpg,PLATELET_885144.npy
267,EO_766383.jpg,EO_766383.npy


In [8]:
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score

THRESH = 0.5
metrics = []

for idx, path in enumerate(tqdm(X_test)):

    # Convert to integer images
    X = np.round(255 * images[idx]).astype(np.uint8)
    y = 255 * (masks[idx] >= THRESH).astype(np.uint8)
    p = 255 * (preds[idx].squeeze() >= THRESH).astype(np.uint8)

    # Postprocess mask
    contours, _ = cv2.findContours(p, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    p = cv2.drawContours(np.zeros(p.shape, np.uint8), contours, 0, 255, cv2.FILLED)

    # Metrics
    accuracy = accuracy_score(y.ravel(), p.ravel())  # real results / total
    precision = precision_score(y.ravel(), p.ravel(), pos_label=255)  # true positive / actual results (red)
    recall = recall_score(y.ravel(), p.ravel(), pos_label=255)  # true positive / predicted results (green)
    metrics.append([path.stem, str(path), accuracy, precision, recall])

    # Display figure
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))

    ax1.set_title(f'WBC: {path.stem}')
    ax1.imshow(X[..., 1], cmap='gray')
    ax1.axis('off')

    ax2.set_title('Actual mask')
    ax2.imshow(y, cmap='gray')
    ax2.axis('off')

    ax3.set_title('Predicted mask')
    ax3.imshow(p, cmap='gray')
    ax3.axis('off')

    ax3.scatter(*np.flip(np.where(y < p)), c='g', s=1, marker='s')  # False positive (Type I)
    ax3.scatter(*np.flip(np.where(y > p)), c='r', s=1, marker='s')  # False negative (Type II)

    plt.tight_layout()
    plt.savefig(DATA_DIR / 'output' / f'{idx:03}_{path.stem}.png')
    plt.close()

metrics = pd.DataFrame(metrics, columns=['name', 'path', 'accuracy', 'precision', 'recall'])
metrics.head()

  0%|          | 0/269 [00:00<?, ?it/s]

Unnamed: 0,name,path,accuracy,precision,recall
0,BNE_727973,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.997925,1.0,0.99126
1,MO_993578,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.998718,0.999761,0.995238
2,MMY_873744,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.998779,1.0,0.993213
3,MO_902264,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.99884,0.999508,0.995829
4,PLATELET_530681,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.99939,1.0,0.986413


In [10]:
metrics['f1-score'] = 2 * (metrics['precision'] * metrics['recall']) / (metrics['precision'] + metrics['recall'])
metrics.sort_values('f1-score', na_position='first').head(30)

Unnamed: 0,name,path,accuracy,precision,recall,f1-score
27,BA_882416,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.798157,0.0,0.0,
40,LY_153467,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.88501,0.0,0.0,
47,EO_284748,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.649658,0.0,0.0,
69,BA_700831,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.829956,0.0,0.0,
106,MY_91439,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.817505,0.0,0.0,
162,BA_338836,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.797913,0.0,0.0,
182,BNE_218516,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.798035,0.0,0.0,
205,BA_846077,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.810913,0.0,0.0,
261,ERB_878703,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.870605,0.0,0.0,
150,EO_321528,/home/damien/yawbcc_data/barcelona_256_unet/im...,0.66272,1.0,0.000181,0.000362


In [11]:
from yawbcc.datasets import load_barcelona_wbc

meta = load_barcelona_wbc()
image_list = [path.name for path in image_dataset]
df = meta[meta['image'].isin(image_list)]
df.head()

Unnamed: 0,image,group,label,width,height,path
6,MO_526259.jpg,MONOCYTE,MO,360,363,/home/damien/yawbcc_data/barcelona/monocyte/MO...
22,MO_134596.jpg,MONOCYTE,MO,360,363,/home/damien/yawbcc_data/barcelona/monocyte/MO...
42,MO_688574.jpg,MONOCYTE,MO,360,363,/home/damien/yawbcc_data/barcelona/monocyte/MO...
45,MO_61107.jpg,MONOCYTE,MO,360,363,/home/damien/yawbcc_data/barcelona/monocyte/MO...
57,MO_578103.jpg,MONOCYTE,MO,360,363,/home/damien/yawbcc_data/barcelona/monocyte/MO...


### Test

Nous pouvons également testé avec les images que nous avons rejetées manuellement pour l'entrainement du réseau U-Net suite à la segmentation basée sur la vision par ordinateur. Nous pourrons ainsi vérifier si U-Net fait mieux que la méthode colorimétrique.

In [12]:
image_directory2 = DATA_DIR / 'trash' / 'images'
mask_directory2 = DATA_DIR / 'trash' / 'masks'

X_test2 = sorted([filename for filename in image_directory2.glob('*.jpg')])
y_test2 = sorted([filename for filename in mask_directory2.glob('*.npy')])

test_ds2 = WBCDataSequence(X_test2, y_test2, image_size=INPUT_SHAPE[:2], batch_size=BATCH_SIZE)
model.evaluate(test_ds2)



[0.38747674226760864, 0.922502338886261]

In [13]:
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score

THRESH = 0.2
metrics = []

preds = model.predict(test_ds2)
batches = list(zip(*test_ds2))
images = list(itertools.chain(*batches[0]))
masks = list(itertools.chain(*batches[1]))

for idx, path in enumerate(tqdm(X_test2)):

    # Convert to integer images
    X = np.round(255 * images[idx]).astype(np.uint8)
    y = 255 * (masks[idx] >= THRESH).astype(np.uint8)
    p = 255 * (preds[idx].squeeze() >= THRESH).astype(np.uint8)

    # Postprocess mask
    #contours, _ = cv2.findContours(p, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    #p = cv2.drawContours(np.zeros(p.shape, np.uint8), contours, 0, 255, cv2.FILLED)

    # Metrics
    accuracy = accuracy_score(y.ravel(), p.ravel())  # real results / total
    precision = precision_score(y.ravel(), p.ravel(), pos_label=255)  # true positive / actual results (red)
    recall = recall_score(y.ravel(), p.ravel(), pos_label=255)  # true positive / predicted results (green)
    metrics.append([path.stem, accuracy, precision, recall])

    # Display figure
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(12, 4))

    ax1.set_title(f'WBC: {path.stem}')
    ax1.imshow(X[..., 1], cmap='gray')
    ax1.axis('off')

    ax2.set_title('Actual mask')
    ax2.imshow(y, cmap='gray')
    ax2.axis('off')

    ax3.set_title('Predicted mask')
    ax3.imshow(p, cmap='gray')
    ax3.axis('off')

    ax3.scatter(*np.flip(np.where(y < p)), c='g', s=1, marker='s')  # False positive (Type I)
    ax3.scatter(*np.flip(np.where(y > p)), c='r', s=1, marker='s')  # False negative (Type II)

    plt.tight_layout()
    plt.savefig(DATA_DIR / 'trash' / 'output' / f'{idx:03}_{path.stem}.png')
    plt.close()

metrics = pd.DataFrame(metrics, columns=['name', 'accuracy', 'precision', 'recall'])
metrics.head()



  0%|          | 0/259 [00:00<?, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))


Unnamed: 0,name,accuracy,precision,recall
0,BA_114542,0.99823,0.998619,0.991432
1,BA_172469,0.921997,0.906972,0.839854
2,BA_202924,0.961548,0.861731,0.997419
3,BA_243213,0.991699,0.974398,0.99063
4,BA_267693,0.94751,0.991311,0.824442
