# Libs

In [None]:
%%time

import os
import imageio

import pandas as pd
import numpy as np

import tensorflow as tf
from tensorflow.keras import layers

import matplotlib.pyplot as plt
from IPython import display

# Base variables

Paths to files

In [None]:
# Kaggle input locations for the Sartorius cell instance segmentation data.
# Directory paths keep their trailing slash: file names are appended with `+`.
TRAIN_DIR = '../input/sartorius-cell-instance-segmentation/train/'
TRAIN_CSV = '../input/sartorius-cell-instance-segmentation/train.csv'

# directory whose files are listed and predicted on in the test section
TEST_DIR = '../input/sartorius-cell-instance-segmentation/test/'

Variables derived from _train_df_

In [None]:
train_df = pd.read_csv(TRAIN_CSV)

# one row per annotation; an image id repeats once per annotated cell
TRAIN_IDS = train_df['id'].unique()

# mask decoding, tiling and stitching all use a single global image size taken
# from the first row, so make that assumption explicit
assert (train_df[['width', 'height']].nunique() == 1).all(), 'images have different sizes'
WIDTH, HEIGHT = train_df.loc[0, ['width', 'height']]

# Get images

In [None]:
def get_mask(df: pd.DataFrame, idx: str,
             width: int = None, height: int = None) -> np.ndarray:
    """Build the binary mask of one image from its run-length annotations.

    Every row of ``df`` whose ``id`` equals ``idx`` holds one annotation string
    of space-separated ``start length`` pairs. Starts are 1-based offsets into
    the row-major flattened image. All annotations are merged into one mask.

    params:
        df:     dataframe with ``id`` and ``annotation`` columns
        idx:    image id whose annotations are rasterized
        width:  image width in pixels (defaults to the global ``WIDTH``)
        height: image height in pixels (defaults to the global ``HEIGHT``)

    returns:
        float array of shape (height, width): 1 where any annotation covers
        the pixel, 0 elsewhere (all zeros if ``idx`` has no rows)"""
    width = WIDTH if width is None else width
    height = HEIGHT if height is None else height

    mask = np.zeros(width * height)

    for annotation in df.loc[df['id'] == idx, 'annotation']:
        runs = np.array(annotation.split(), dtype=np.int64).reshape(-1, 2)
        for start, length in runs:
            # RLE starts are 1-based; a slice assignment marks the whole run
            mask[start - 1:start - 1 + length] = 1

    return mask.reshape(height, width)

def get_image(path: str) -> np.ndarray:
    """Read an image file from disk and return it as a plain numpy array.

    params:
        path: location of the image file"""
    return np.array(imageio.imread(path))

In [None]:
%%time

# read every training image and rasterize its merged annotations into one mask
train_imgs = np.array([get_image(TRAIN_DIR + idx + '.png') for idx in TRAIN_IDS])
train_masks = np.array([get_mask(train_df, idx) for idx in TRAIN_IDS])

print('train shape:', train_imgs.shape)
print('mask shape:', train_masks.shape)

# sanity check: the first image with its mask overlaid
plt.figure(figsize=(10, 10))
plt.imshow(train_imgs[0], cmap="binary")
plt.imshow(train_masks[0], cmap="gnuplot", alpha=0.3)

There are 606 training images, each 704×520 pixels.<br>
Let's split these images into 128×128 tiles: this increases the number of training examples and reduces the GPU memory needed to train the model.

In [None]:
def split_img(img: np.ndarray, 
              size: int = 128, 
              excess: bool = True) -> np.ndarray:
    """Split an image into square ``size`` x ``size`` tiles, row by row.

    The tile grid is computed from ``img.shape`` (height, width first).

    params:
        img:    original image, at least 2-D (height, width, ...)
        size:   side length of each tile
        excess: if True and a dimension is not a multiple of ``size``, add one
                extra tile per row/column aligned to the far edge (it overlaps
                its neighbour) so the whole image is covered; if False the
                remainder pixels are dropped. Nothing extra is added when the
                dimension divides evenly.

    returns:
        array of shape (n_tiles, size, size, ..., 1); tiles are in row-major
        grid order"""
    height, width = img.shape[:2]

    def _starts(length: int) -> list:
        """Top/left offsets of the tiles along one axis."""
        offsets = [k * size for k in range(length // size)]
        if excess and length % size:
            # last tile is flush with the far edge, overlapping the previous one
            offsets.append(max(length - size, 0))
        return offsets

    tiles = [img[h:h + size, w:w + size]
             for h in _starts(height)
             for w in _starts(width)]

    # trailing axis = single channel, as expected by the model
    return np.array(tiles)[..., None]

def view_splitimg_imgs(pieces: np.ndarray, 
                       masks: np.ndarray = None, 
                       excess: bool = True) -> None:
    """Show an image reassembled from its tiles as a grid of subplots.

    The grid size is derived from the global ``WIDTH``/``HEIGHT`` and the tile
    size, mirroring the tiling done by ``split_img``.

    params:
        pieces: array of tiles of one original image, in row-major grid order
        masks:  optional array of mask tiles to overlay (None or empty: no overlay)
        excess: whether the tiling included the extra edge-aligned row/column"""
    size = pieces.shape[1]

    cols, w_excess = divmod(WIDTH, size)
    rows, h_excess = divmod(HEIGHT, size)

    if excess:
        cols += 1 if w_excess else 0
        rows += 1 if h_excess else 0

    has_masks = masks is not None and len(masks) > 0

    # squeeze=False keeps `ax` 2-D even when the grid is a single row/column
    fig, ax = plt.subplots(rows, cols, figsize=(12, 10), squeeze=False)

    for i in range(rows):
        for j in range(cols):

            idx = j + i * cols

            ax[i, j].imshow(pieces[idx], cmap="binary")
            if has_masks:
                ax[i, j].imshow(masks[idx], cmap="gnuplot", alpha=0.3)
            ax[i, j].axis("off")

    plt.subplots_adjust(wspace=0.01, hspace=0.01)

In [None]:
%%time

# Cut every image and mask into 128x128 tiles. With excess=False the right and
# bottom remainder that does not fill a whole tile is dropped.
# NOTE(review): this overwrites train_imgs / train_masks, so the cell is not
# idempotent -- re-run the loading cell before running it again.
train_imgs = np.concatenate([split_img(train_imgs[i], excess=False) for i in range(len(train_imgs))], axis=0)
train_masks = np.concatenate([split_img(train_masks[i], excess=False) for i in range(len(train_masks))], axis=0)

print(train_imgs.shape, train_masks.shape)

# tiles of the first image laid out in their grid positions
view_splitimg_imgs(train_imgs, train_masks, excess=False)

# Preprocess data

Split into training and validation sets

In [None]:
# Hold out the first 10% of tiles for validation (the tiles are not shuffled
# first). Tiles of one image are consecutive, so the split point can fall
# inside an image and put some of its tiles in both sets.
split_by = len(train_imgs) // 10

train_imgs, valid_imgs = train_imgs[split_by:], train_imgs[:split_by]
train_masks, valid_masks = train_masks[split_by:], train_masks[:split_by]

print('train_imgs shape:', train_imgs.shape)
print('train_masks shape:', train_masks.shape)

print('valid_imgs shape:', valid_imgs.shape)
print('valid_masks shape:', valid_masks.shape)

Wrap the arrays in `tf.data` datasets

In [None]:
%%time

# wrap the in-memory arrays as datasets of (image tile, mask tile) pairs
train_ds = tf.data.Dataset.from_tensor_slices((train_imgs, train_masks))
valid_ds = tf.data.Dataset.from_tensor_slices((valid_imgs, valid_masks))

preprocess datasets

In [None]:
# tiles per batch, used for both the training and the validation dataset
BATCH_SIZE = 1

def prep_data(img: np.ndarray, 
              mask: np.ndarray) -> tuple:
    """Convert an (image, mask) pair to float32 and scale the image to [0, 1].

    params:
        img:  image array; assumed to hold 8-bit (0-255) pixel values
        mask: mask array

    returns:
        (img, mask) as float32 tensors"""
    # Cast before dividing: an in-place `/=` on an integer numpy array raises,
    # and casting first keeps the result float32 whatever the input dtype is.
    img = tf.cast(img, tf.float32) / 255
    mask = tf.cast(mask, tf.float32)

    return img, mask

def pipline(ds, shuffle_buffer: int = 1000, batch_size: int = None):
    """cache -> shuffle -> preprocess -> split into batches -> prefetch

    params:
        ds:             dataset of (image, mask) pairs
        shuffle_buffer: number of elements held in the shuffle buffer
        batch_size:     elements per batch; defaults to the global BATCH_SIZE

    returns:
        the transformed dataset"""
    if batch_size is None:
        batch_size = BATCH_SIZE

    ds = ds.cache()
    # shuffle after cache so the order is re-drawn every epoch
    ds = ds.shuffle(shuffle_buffer)
    # element-wise normalization, safe to run in parallel (order is kept)
    ds = ds.map(prep_data, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size)
    # overlap input preparation with the training step
    ds = ds.prefetch(tf.data.AUTOTUNE)

    return ds

train_ds = pipline(train_ds)
# Validation data is only preprocessed and batched (no cache/shuffle), so its
# order is fixed and the first batch is always the same tiles.
valid_ds = valid_ds.map(prep_data).batch(BATCH_SIZE)

# Model

In [None]:
class UNet(tf.keras.Model):
    """U-Net for single-channel binary segmentation.

    Structure: a 4-level encoder (two 3x3 convs, then 2x2 max-pool per level),
    a 1024-filter bottleneck, and a decoder that upsamples with stride-2
    transposed convolutions. Each decoder level concatenates its upsampled
    features with the matching encoder output along the channel axis (axis 3,
    channels-last).

    Input height and width should be divisible by 16 (four 2x poolings and
    four 2x upsamplings) so the skip connections line up, e.g. the 128x128
    tiles used in this notebook. The output has one channel of sigmoid
    probabilities.

    NOTE: checkpoints written via tf.train.Checkpoint are keyed by these
    attribute names; renaming them breaks restoring saved weights."""
    
    def __init__(self):
        super().__init__()
        
        # encoder
        self.conv_enc64_1 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')
        self.conv_enc64_2 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')
        
        self.conv_enc128_1 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')
        self.conv_enc128_2 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')
        
        self.conv_enc256_1 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.conv_enc256_2 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        
        self.conv_enc512_1 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv_enc512_2 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        
        # one pooling layer (no weights) reused after every encoder level
        self.maxpool = layers.MaxPooling2D((2, 2), (2, 2), padding='same')
        
        # decoder
        # conv_dec1024_* form the bottleneck: they run on the most downsampled features
        self.conv_dec1024_1 = layers.Conv2D(1024, (3, 3), padding='same', activation='relu')
        self.conv_dec1024_2 = layers.Conv2D(1024, (3, 3), padding='same', activation='relu')
        self.conv_transp_512 = layers.Conv2DTranspose(512, (3, 3), strides=(2, 2), padding='same', activation='relu')
        
        self.conv_dec512_1 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv_dec512_2 = layers.Conv2D(512, (3, 3), padding='same', activation='relu')
        self.conv_transp_256 = layers.Conv2DTranspose(256, (3, 3), strides=(2, 2), padding='same', activation='relu')
        
        self.conv_dec256_1 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.conv_dec256_2 = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.conv_transp_128 = layers.Conv2DTranspose(128, (3, 3), strides=(2, 2), padding='same', activation='relu')
        
        self.conv_dec128_1 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')
        self.conv_dec128_2 = layers.Conv2D(128, (3, 3), padding='same', activation='relu')
        self.conv_transp_64 = layers.Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu')
        
        self.conv_dec64_1 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')
        self.conv_dec64_2 = layers.Conv2D(64, (3, 3), padding='same', activation='relu')
        
        # 1 output channel with sigmoid: per-pixel foreground probability
        self.conv_final = layers.Conv2D(1, (3, 3), padding='same', activation='sigmoid')
        
    def call(self, x):
        """Run the U-Net on a batch of images and return per-pixel probabilities."""
        
        # encoder; out1..out4 are kept (pre-pooling) as skip connections
        out = self.conv_enc64_1(x)
        out1 = self.conv_enc64_2(out)
        out = self.maxpool(out1)
        
        out = self.conv_enc128_1(out)
        out2 = self.conv_enc128_2(out)
        out = self.maxpool(out2)
        
        out = self.conv_enc256_1(out)
        out3 = self.conv_enc256_2(out)
        out = self.maxpool(out3)
        
        out = self.conv_enc512_1(out)
        out4 = self.conv_enc512_2(out)
        out = self.maxpool(out4)
        
        # decoder: upsample 2x, then concatenate with the same-resolution skip
        out = self.conv_dec1024_1(out)
        out = self.conv_dec1024_2(out)
        out = self.conv_transp_512(out)
        out = tf.concat([out4, out], axis=3)
        
        out = self.conv_dec512_1(out)
        out = self.conv_dec512_2(out)
        out = self.conv_transp_256(out)
        out = tf.concat([out3, out], axis=3)
        
        out = self.conv_dec256_1(out)
        out = self.conv_dec256_2(out)
        out = self.conv_transp_128(out)
        out = tf.concat([out2, out], axis=3)
        
        out = self.conv_dec128_1(out)
        out = self.conv_dec128_2(out)
        out = self.conv_transp_64(out)
        out = tf.concat([out1, out], axis=3)
        
        out = self.conv_dec64_1(out)
        out = self.conv_dec64_2(out)
        
        out = self.conv_final(out)
        
        return out

## init training objects

In [None]:
model = UNet()

optimizer = tf.keras.optimizers.Adam()
# The model ends in a single sigmoid channel, so accuracy must threshold the
# probability (BinaryAccuracy, threshold 0.5). SparseCategoricalAccuracy would
# argmax over a 1-wide channel axis, always predict class 0 and just report
# the background fraction.
accuracy = tf.keras.metrics.BinaryAccuracy()

loss_obj = tf.keras.losses.BinaryCrossentropy()
# running mean of the per-batch training loss
loss = tf.keras.metrics.Mean()

## checkpoint

In [None]:
checkpoint_path = "./checkpoints/trainUnet"

# save model weights together with optimizer state so training can resume
ckpt = tf.train.Checkpoint(model=model, optimizer=optimizer)
# only the newest checkpoint is kept on disk
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=1)

# resume from a previous run's checkpoint, if one exists
if ckpt_manager.latest_checkpoint:
    ckpt.restore(ckpt_manager.latest_checkpoint)
    print('Latest checkpoint restored!')

## variables for training

In [None]:
EPOCHS = 10

# histories appended at each logging step of the training loop
train_loss = []
train_accuracy = []
valid_loss = []
valid_accuracy = []

# a single fixed validation batch, reused at every logging step
valid_batch_img, valid_batch_mask = next(valid_ds.as_numpy_iterator())

## training

In [None]:
def plot_process(train_values, valid_values, figsize=(16, 4)):
    """Draw the train and valid curves of one tracked quantity (loss or accuracy).

    params:
        train_values: sequence of values logged on the training data
        valid_values: sequence of values logged on the validation data
        figsize:      matplotlib figure size"""
    _, axis = plt.subplots(figsize=figsize)
    for values, label in ((train_values, 'train'), (valid_values, 'valid')):
        axis.plot(values, label=label)
    axis.legend()
    plt.show()
    
def view_masks(img, mask, pred_mask, figsize=(14, 8)):
    """Show an image twice, side by side: with the true mask and with the predicted one.

    params:
        img:       2-D array of the original image
        mask:      ground-truth mask drawn over the left panel
        pred_mask: predicted mask drawn over the right panel
        figsize:   matplotlib figure size"""
    fig, ax = plt.subplots(1, 2, figsize=figsize)

    panels = (('true', mask), ('pred', pred_mask))
    for axis, (title, overlay) in zip(ax, panels):
        axis.set_title(title)
        axis.imshow(img, cmap='binary')
        axis.imshow(overlay, cmap='gnuplot', alpha=0.3)

    plt.show()

In [None]:
%%time

for epoch in range(EPOCHS):
    print(epoch)
    
    for i, (train_batch_img, train_batch_mask) in enumerate(train_ds):

        with tf.GradientTape() as tape:
            train_batch_mask_pred = model(train_batch_img)
            train_loss_value = loss_obj(train_batch_mask, train_batch_mask_pred)

        gradients = tape.gradient(train_loss_value, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        loss(train_loss_value)
        accuracy(train_batch_mask, train_batch_mask_pred)
        
        if i % 100 == 0:
            valid_batch_mask_pred = model(valid_batch_img)

            valid_loss_value = loss_obj(valid_batch_mask, valid_batch_mask_pred)

            valid_accuracy_value = accuracy(valid_batch_mask, valid_batch_mask_pred)
            train_accuracy_value = accuracy(train_batch_mask, train_batch_mask_pred)

            print(f'{i}\ttrain loss: {train_loss_value:.4f} | train accuracy: {train_accuracy_value:.4f} |',
                  f'valid loss: {valid_loss_value:.4f} | valid accuracy: {valid_accuracy_value:.4f}')

            train_loss.append(train_loss_value)
            valid_loss.append(valid_loss_value)

            train_accuracy.append(train_accuracy_value)
            valid_accuracy.append(valid_accuracy_value)
            
        if i % 1000 == 0:
            display.clear_output(wait=False)
            
            plot_process(train_accuracy, valid_accuracy, figsize=(16, 3))
            plot_process(train_loss, valid_loss, figsize=(16, 3))
            
            view_masks(
                img=valid_batch_img[0, ..., 0],
                mask=valid_batch_mask[0, ...],
                pred_mask=(valid_batch_mask_pred[0, ... ,0] > 0.5).numpy().astype(np.float32), 
                figsize=(12, 8)
            )
            
            ckpt_manager.save()
            print("checkpoint was saved")

# Test images

## Get test images

In [None]:
def prediction_test(img: np.ndarray, 
                    batch: int = 4,
                    threshold: float = 0.5,
                    predictor=None) -> np.ndarray:
    """Predict binary masks for an array of image pieces, ``batch`` at a time.

    params:
        img:       array of image pieces with raw 0-255 pixel values
        batch:     number of pieces passed to the model per call
        threshold: probability above which a pixel is marked as foreground
        predictor: callable mapping a float32 batch to per-pixel probabilities;
                   defaults to the global ``model``

    returns:
        array of 0/1 masks (dtype ``np.uint``), one per input piece, in input
        order; an empty input gives an empty array"""
    if predictor is None:
        predictor = model

    test_mask_preds = []

    # the last slice may be shorter than `batch`
    for start in range(0, len(img), batch):
        # same normalization as prep_data: float32 scaled to [0, 1]
        chunk = np.asarray(img[start:start + batch], dtype=np.float32) / 255
        probs = np.asarray(predictor(chunk))
        test_mask_preds.extend((probs > threshold).astype(np.uint))

    return np.array(test_mask_preds)

View test images

In [None]:
# sorted: os.listdir returns files in an arbitrary, filesystem-dependent order,
# and each name's position is reused for the predictions and the submission
test_imgs = sorted(os.listdir(TEST_DIR))

test_img0 = np.array([get_image(TEST_DIR + test_imgs[0])])
test_img1 = np.array([get_image(TEST_DIR + test_imgs[1])])
test_img2 = np.array([get_image(TEST_DIR + test_imgs[2])])

print('test_imgs:', test_img0.shape)

fig, ax = plt.subplots(1, 3, figsize=(18, 16))

def add_imshow(img: np.ndarray, idx: int, name: str) -> None:
    """Draw `img` titled `name` on panel `idx` of the module-level `ax` row."""
    ax[idx].set_title(name)
    ax[idx].imshow(img, cmap="binary")
    ax[idx].axis("off")
    
add_imshow(test_img0[0], 0, test_imgs[0])
add_imshow(test_img1[0], 1, test_imgs[1])
add_imshow(test_img2[0], 2, test_imgs[2])

plt.subplots_adjust(wspace=0.05, hspace=0.01)

## Preprocess test

Split the original images into pieces

In [None]:
# Tile each test image. excess=True adds an extra, overlapping last row/column
# of tiles so the full image is covered and can be stitched back later.
# NOTE(review): overwrites test_img0..2 (full image -> tiles); not idempotent.
test_img0 = split_img(test_img0[0], excess=True)
test_img1 = split_img(test_img1[0], excess=True)
test_img2 = split_img(test_img2[0], excess=True)

print(test_img0.shape)

view_splitimg_imgs(test_img0, excess=True)

## Prediction test mask

In [None]:
# predict a 0/1 mask for every tile of the first test image
test_mask_preds0 = prediction_test(test_img0)

print(test_mask_preds0.shape)

view_splitimg_imgs(test_img0, test_mask_preds0)

In [None]:
# predict a 0/1 mask for every tile of the second test image
test_mask_preds1 = prediction_test(test_img1)

print(test_mask_preds1.shape)

view_splitimg_imgs(test_img1, test_mask_preds1)

In [None]:
# predict a 0/1 mask for every tile of the third test image
test_mask_preds2 = prediction_test(test_img2)

print(test_mask_preds2.shape)

view_splitimg_imgs(test_img2, test_mask_preds2)

## Concatenate results

In [None]:
def concat_img(pieces: np.ndarray, height: int = None, width: int = None) -> np.ndarray:
    """Stitch tiles produced by ``split_img(..., excess=True)`` back into one image.

    The last tile of each row and the tiles of the last row were cut flush
    with the right/bottom edge and overlap their neighbours. Only their
    non-overlapping part is kept.

    params:
        pieces: array of tiles (n_tiles, size, size, ...), in row-major grid order
        height: original image height (defaults to the global ``HEIGHT``)
        width:  original image width (defaults to the global ``WIDTH``)

    returns:
        array of shape (height, width, ...)"""
    height = HEIGHT if height is None else height
    width = WIDTH if width is None else width

    size = pieces.shape[1]

    cols, w_excess = divmod(width, size)
    rows, h_excess = divmod(height, size)
    cols += 1 if w_excess else 0
    rows += 1 if h_excess else 0

    strips = []
    for i in range(rows):
        tiles = list(pieces[i * cols:(i + 1) * cols])

        if w_excess:
            # keep only the columns of the edge tile that are not already covered
            tiles[-1] = tiles[-1][:, size - w_excess:]

        # always join the row's tiles horizontally (also when w_excess == 0)
        strip = np.concatenate(tiles, axis=1)

        if h_excess and i == rows - 1:
            # same for the bottom row: drop the overlapping top part
            strip = strip[size - h_excess:]

        strips.append(strip)

    return np.concatenate(strips, axis=0)

In [None]:
# Stitch the tiles of the image and of its predicted mask back to full size.
# [..., 0] drops the trailing single-channel axis.
test_img0 = concat_img(test_img0[..., 0])
test_mask_preds0 = concat_img(test_mask_preds0[..., 0])

print(test_img0.shape, test_mask_preds0.shape)

plt.figure(figsize=(13, 11))
plt.imshow(test_img0, cmap="gray")
plt.imshow(test_mask_preds0, cmap="gnuplot", alpha=0.3)

In [None]:
# Stitch the tiles of the image and of its predicted mask back to full size.
# [..., 0] drops the trailing single-channel axis.
test_img1 = concat_img(test_img1[..., 0])
test_mask_preds1 = concat_img(test_mask_preds1[..., 0])

print(test_img1.shape, test_mask_preds1.shape)

plt.figure(figsize=(13, 11))
plt.imshow(test_img1, cmap="gray")
plt.imshow(test_mask_preds1, cmap="gnuplot", alpha=0.3)

In [None]:
# Stitch the tiles of the image and of its predicted mask back to full size.
# [..., 0] drops the trailing single-channel axis.
test_img2 = concat_img(test_img2[..., 0])
test_mask_preds2 = concat_img(test_mask_preds2[..., 0])

print(test_img2.shape, test_mask_preds2.shape)

plt.figure(figsize=(13, 11))
plt.imshow(test_img2, cmap="gray")
plt.imshow(test_mask_preds2, cmap="gnuplot", alpha=0.3)

## Mask encoding

In [None]:
def recoding_mask(mask: np.ndarray) -> np.ndarray:
    """Run-length encode a mask in the same format the training annotations use.

    The mask is flattened in row-major order. Every run of non-zero pixels
    becomes a ``start length`` pair with a 1-based start, the inverse of the
    decoding in ``get_mask`` (which subtracts 1 from each start). The old
    implementation emitted 0-based starts, shifting every run by one pixel.

    params:
        mask: original mask (any shape; non-zero means foreground)

    returns:
        1-D integer array ``[start1, len1, start2, len2, ...]``; empty if the
        mask has no foreground pixels"""
    # pad with background on both ends so every run has a rising and a falling edge
    pixels = np.concatenate([[0], (np.asarray(mask).ravel() != 0).astype(np.int8), [0]])

    # positions where the value changes; the +1 makes starts 1-based
    runs = np.flatnonzero(pixels[1:] != pixels[:-1]) + 1
    # turn each (start, end) pair into (start, length)
    runs[1::2] -= runs[::2]

    return runs

In [None]:
# run-length encode each full-size predicted mask for the submission
# NOTE(review): overwrites the mask arrays with their encodings; not idempotent
test_mask_preds0 = recoding_mask(test_mask_preds0)
test_mask_preds1 = recoding_mask(test_mask_preds1)
test_mask_preds2 = recoding_mask(test_mask_preds2)

# Submission

In [None]:
# one row per test image: id (file name without its last 4 characters, e.g.
# '.png') and the space-separated run-length encoding of its predicted mask
submition = pd.DataFrame(
    [
        [test_imgs[0][:-4], " ".join(test_mask_preds0.astype(str))],
        [test_imgs[1][:-4], " ".join(test_mask_preds1.astype(str))],
        [test_imgs[2][:-4], " ".join(test_mask_preds2.astype(str))],
    ],
    columns=["id", "predicted"],
)
submition

In [None]:
# write the submission file without the dataframe index column
submition.to_csv("submission.csv", index=False)