<a href="https://colab.research.google.com/github/kleelab-bch/MARS-Net/blob/master/run_MARS_Net_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPU Setting

Verify that a GPU runtime is selected. <br>
In Colab, you can select one via
**Edit → Notebook settings → Hardware accelerator → GPU**.

In [None]:
# Show the CUDA toolkit version and the GPU(s) visible to this runtime.
!nvcc --version
!nvidia-smi

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2020 NVIDIA Corporation
Built on Wed_Jul_22_19:09:09_PDT_2020
Cuda compilation tools, release 11.0, V11.0.221
Build cuda_11.0_bu.TC445_37.28845127_0
Sun Apr 11 04:15:03 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P8     9W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+---------

# Example models and images

Download and display one example frame of the demo movie from the MARS-Net repository.

In [None]:
url = 'https://github.com/kleelab-bch/MARS-Net/blob/master/assets/040119_PtK1_S01_01_phase_ROI2/img_all/040119_PtK1_S01_01_phase_ROI2_1_001.png?raw=true'

from PIL import Image
import requests
from io import BytesIO
from matplotlib.pyplot import imshow
%matplotlib inline

response = requests.get(url)
img = Image.open(BytesIO(response.content))
imshow(img)


<_io.BytesIO object at 0x7fce7de57b90>


UnidentifiedImageError: ignored

## Model Definition

In [None]:
import gc
import glob
import math
import os
import os.path
import statistics

import cv2
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras.layers import (Conv2D, Cropping2D, Dropout, Input,
                                     MaxPooling2D, UpSampling2D, concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model

try:
    from keras.utils.data_utils import get_file
except ImportError:  # newer Keras releases no longer ship keras.utils.data_utils
    from tensorflow.keras.utils import get_file

done


In [None]:

def UNet(img_rows, img_cols, crop_margin, right_crop, bottom_crop, weights_path=None):
    """Build the U-Net segmentation model and load its trained weights.

    The input layer is channels-first with 3 channels, shape
    ``(3, img_rows, img_cols)``, and skip connections are concatenated on
    axis 1. There are four 2x2 poolings followed by four 2x upsamplings, so
    ``img_rows`` and ``img_cols`` must be divisible by 16 for the skip
    connections to line up.

    Args:
        img_rows, img_cols: input height and width.
        crop_margin: pixels cropped from every side of the output when
            ``right_crop`` and ``bottom_crop`` are both 0 (training layout).
        right_crop, bottom_crop: otherwise, the number of columns removed from
            the right and rows removed from the bottom of the output. This
            drops the reflected padding added for prediction.
        weights_path: local weights file to load. If None (the default), the
            specialist U-Net weights published in the MARS-Net repository are
            downloaded and cached with ``get_file``.

    Returns:
        The Keras ``Model`` (single-channel sigmoid output) with weights loaded.
    """
    inputs = Input((3, img_rows, img_cols))
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    conv1 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv1)
    pool1 = MaxPooling2D(pool_size=(2, 2))(conv1)

    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(pool1)
    conv2 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv2)
    pool2 = MaxPooling2D(pool_size=(2, 2))(conv2)

    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(pool2)
    conv3 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv3)
    pool3 = MaxPooling2D(pool_size=(2, 2))(conv3)

    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(pool3)
    conv4 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv4)
    pool4 = MaxPooling2D(pool_size=(2, 2))(conv4)

    conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(pool4)
    conv5 = Conv2D(1024, (3, 3), activation='relu', padding='same')(conv5)

    up6 = concatenate([UpSampling2D(size=(2, 2))(conv5), conv4], axis=1)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), conv3], axis=1)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), conv2], axis=1)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), conv1], axis=1)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv9)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    if bottom_crop == 0 and right_crop == 0:
        # ((top_crop, bottom_crop), (left_crop, right_crop)) margin crop used for training
        conv10 = Cropping2D(cropping=((crop_margin, crop_margin), (crop_margin, crop_margin)))(conv10)
    else:
        # Remove the reflected portion from the bottom/right of the image for prediction.
        # Both values are checked so a nonzero right_crop is not ignored when bottom_crop is 0.
        conv10 = Cropping2D(cropping=((0, bottom_crop), (0, right_crop)))(conv10)

    model = Model(inputs=inputs, outputs=conv10)

    if weights_path is None:
        # Use a cache file name unique to this model. get_file skips the download when
        # a file with the same name is already cached, so the name 'vgg19_weights.h5'
        # shared with VGG19_dropout made the two models load each other's weights.
        weights_path = get_file(
            'unet_specialist_frame2_D_repeat0.hdf5',
            'https://github.com/kleelab-bch/MARS-Net/raw/master/models/results/model_round1_specialist_unet/model_frame2_D_repeat0.hdf5')
    # NOTE(review): by_name=True matches layers by name, but these layers use Keras'
    # auto-generated names. Confirm they match the names stored in the weights file.
    model.load_weights(weights_path, by_name=True)

    return model

In [None]:

def VGG19_dropout(img_rows, img_cols, crop_margin, right_crop, bottom_crop, weights_path=None):
    """Build the VGG19-encoder U-Net with dropout and load its trained weights.

    The input layer is channels-first with 3 channels, shape
    ``(3, img_rows, img_cols)``. The encoder layers keep the ``blockN_convM``
    names, which matter because weights are loaded with ``by_name=True``.
    Four 2x2 poolings are followed by four 2x upsamplings, so ``img_rows``
    and ``img_cols`` must be divisible by 16 for the skip connections to line up.

    Args:
        img_rows, img_cols: input height and width.
        crop_margin: pixels cropped from every side of the output when
            ``right_crop`` and ``bottom_crop`` are both 0 (training layout).
        right_crop, bottom_crop: otherwise, the number of columns removed from
            the right and rows removed from the bottom of the output. This
            drops the reflected padding added for prediction.
        weights_path: local weights file to load. If None (the default), the
            generalist VGG19_dropout weights published in the MARS-Net
            repository are downloaded and cached with ``get_file``.

    Returns:
        The Keras ``Model`` (single-channel sigmoid output) with weights loaded.
    """
    inputs = Input(shape=[3, img_rows, img_cols])
    # Block 1
    x = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1')(inputs)
    block1_conv2 = Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool')(block1_conv2)
    x = Dropout(0.25)(x)

    # Block 2
    x = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1')(x)
    block2_conv2 = Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool')(block2_conv2)
    x = Dropout(0.5)(x)

    # Block 3
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2')(x)
    x = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3')(x)
    block3_conv4 = Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv4')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool')(block3_conv4)
    x = Dropout(0.5)(x)

    # Block 4
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3')(x)
    block4_conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv4')(x)
    x = MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool')(block4_conv4)
    x = Dropout(0.5)(x)

    # Block 5
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2')(x)
    x = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3')(x)
    block5_conv4 = Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv4')(x)

    # upsampling model
    up6 = concatenate([UpSampling2D(size=(2, 2))(block5_conv4), block4_conv4], axis=1)
    up6 = Dropout(0.5)(up6)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(up6)
    conv6 = Conv2D(512, (3, 3), activation='relu', padding='same')(conv6)

    up7 = concatenate([UpSampling2D(size=(2, 2))(conv6), block3_conv4], axis=1)
    up7 = Dropout(0.5)(up7)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(up7)
    conv7 = Conv2D(256, (3, 3), activation='relu', padding='same')(conv7)

    up8 = concatenate([UpSampling2D(size=(2, 2))(conv7), block2_conv2], axis=1)
    up8 = Dropout(0.5)(up8)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(up8)
    conv8 = Conv2D(128, (3, 3), activation='relu', padding='same')(conv8)

    up9 = concatenate([UpSampling2D(size=(2, 2))(conv8), block1_conv2], axis=1)
    up9 = Dropout(0.5)(up9)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(up9)
    conv9 = Conv2D(64, (3, 3), activation='relu', padding='same')(conv9)

    conv10 = Conv2D(1, (1, 1), activation='sigmoid')(conv9)
    if bottom_crop == 0 and right_crop == 0:
        # ((top_crop, bottom_crop), (left_crop, right_crop)) margin crop used for training
        conv10 = Cropping2D(cropping=((crop_margin, crop_margin), (crop_margin, crop_margin)))(conv10)
    else:
        # Remove the reflected portion from the bottom/right of the image for prediction.
        # Both values are checked so a nonzero right_crop is not ignored when bottom_crop is 0.
        conv10 = Cropping2D(cropping=((0, bottom_crop), (0, right_crop)))(conv10)

    model = Model(inputs=inputs, outputs=conv10)

    # Load weights.
    if weights_path is None:
        # Use a cache file name unique to this model. get_file skips the download when
        # a file with the same name is already cached, so the name 'vgg19_weights.h5'
        # shared with UNet made the two models load each other's weights.
        weights_path = get_file(
            'vgg19_dropout_generalist_frame2_D_repeat0.hdf5',
            'https://github.com/kleelab-bch/MARS-Net/raw/master/models/results/model_round1_generalist_VGG19_dropout/model_frame2_D_repeat0.hdf5')
    # NOTE(review): the decoder layers are auto-named. With by_name=True, confirm their
    # names match the ones stored in the weights file, or those layers are not loaded as intended.
    model.load_weights(weights_path, by_name=True)

    return model

# Run a trained model to segment a phase-contrast live-cell movie

## Code for data loading and preprocessing

In [None]:
def to3channel(imgs):
    """Return a float32 copy of ``imgs`` with every entry along axis 1 repeated 3 times.

    For a channels-first stack of shape (N, 1, rows, cols) this produces
    (N, 3, rows, cols) with three identical channels.
    """
    as_float = np.asarray(imgs).astype(np.float32)
    return as_float.repeat(3, axis=1)


def preprocess_output(imgs):
    """Return a float32 copy of 8-bit masks rescaled from [0, 255] to [0, 1]."""
    masks = imgs.astype('float32')
    return np.divide(masks, 255.)


def preprocess_input(imgs, std=None, mean=None):
    """Standardize an image stack after replicating it to 3 channels.

    Args:
        imgs: image stack whose axis 1 is replicated by ``to3channel`` (the
            notebook passes (N, 1, rows, cols)).
        std: divisor. If None, the std of the 3-channel float32 stack is used.
        mean: value subtracted. If None, the mean of the 3-channel float32 stack is used.

    Returns:
        float32 array ``(imgs - mean) / std``. When ``std`` is 0 (e.g. a
        constant image) only the mean is subtracted. Dividing by zero would fill
        the model input with NaN/inf.
    """
    imgs_p = to3channel(imgs)
    if std is None:
        std = np.std(imgs_p)
    if mean is None:
        mean = np.mean(imgs_p)

    imgs_p -= mean
    if std != 0:
        imgs_p /= std

    return imgs_p


def normalize_input(imgs):
    """Return the 3-channel float32 version of ``imgs`` scaled from [0, 255] to [0, 1]."""
    return np.true_divide(to3channel(imgs), 255.)


def square(list):
    """Return a new list holding the square of each element of the iterable ``list``."""
    return [*map(lambda value: value ** 2, list)]


def get_rest_indices_from_all(all_indices, chosen_index):
    '''
    Given a list of indices and one chosen dataset index, return the other
    indices.

    The result keeps the order of first appearance in ``all_indices`` and
    contains no duplicates. A plain set difference returns elements in
    set-iteration order. That order is not the input order, and for string
    keys it changes between interpreter runs because of hash randomization.
    '''
    return list(dict.fromkeys(index for index in all_indices if index != chosen_index))


def loop_aggregate_std_mean(constants, crop_path=''):
    """Compute leave-one-out pooled (std, mean) for every dataset and frame.

    The previous call ``aggregate_std_mean(constants, dataset_index, frame)``
    did not match that function's five-parameter signature and raised
    TypeError. These arguments match ``aggregate_std_mean_except``, so that
    function is called instead.

    Args:
        constants: object exposing ``dataset``, ``frame_list`` and (for
            ``aggregate_std_mean_except``) ``model_names``.
        crop_path: prefix of the ``<dataset>_<frame>_std_mean.npz`` files.
            The default '' means the current working directory.

    Returns:
        dict mapping ``(dataset_index, frame)`` to ``(std, mean)``.
    """
    results = {}
    for dataset_index in range(len(constants.dataset)):
        for frame in constants.frame_list:
            results[(dataset_index, frame)] = aggregate_std_mean_except(
                constants, dataset_index, frame, crop_path)
    return results


def aggregate_std_mean_except(constants, dataset_index, frame, crop_path):
    """Pool saved statistics over every dataset except ``constants.dataset[dataset_index]``.

    For each remaining dataset this loads
    ``<crop_path><dataset>_<frame>_std_mean.npz`` (arr_0 = mean, arr_1 = std)
    and prints the values.

    Returns:
        ``(std, mean)``: the square root of the mean of the squared stds, and
        the mean of the means.
    """
    print(constants.model_names[dataset_index], end=' ')
    print(frame)
    means = []
    stds = []
    for other in get_rest_indices_from_all(range(len(constants.dataset)), dataset_index):
        other_name = constants.dataset[other]
        archive = np.load(crop_path + other_name + '_' + str(frame) + '_std_mean.npz')
        mean_value = archive['arr_0'].tolist()
        std_value = archive['arr_1'].tolist()
        means.append(mean_value)
        stds.append(std_value)
        print(other_name, mean_value, std_value)
    pooled_mean = statistics.mean(means)
    pooled_std = math.sqrt(statistics.mean(square(stds)))
    return pooled_std, pooled_mean


def aggregate_std_mean(dataset_names, excluded_dataset_name, frame, repeat_index, crop_path):
    """Pool saved statistics over every dataset except ``excluded_dataset_name``.

    Used for self-training five-fold validation: the std and mean of the four
    training movies are combined to preprocess the test-set images. For each
    other dataset this loads
    ``<crop_path><name>_frame<frame>_repeat<repeat_index>_std_mean.npz``
    (arr_0 = mean, arr_1 = std) and prints the values.

    Returns:
        ``(std, mean)``: the square root of the mean of the squared stds, and
        the mean of the means.
    """
    print('aggregate_std_mean:' + str(frame))
    means = []
    stds = []

    for name in dataset_names:
        if name == excluded_dataset_name:
            continue
        suffix = '{}_frame{}_repeat{}'.format(name, str(frame), str(repeat_index))
        archive = np.load(crop_path + suffix + '_std_mean.npz')
        mean_value = archive['arr_0'].tolist()
        std_value = archive['arr_1'].tolist()
        means.append(mean_value)
        stds.append(std_value)
        print(name, mean_value, std_value)

    pooled_mean = statistics.mean(means)
    pooled_std = math.sqrt(statistics.mean(square(stds)))
    return pooled_std, pooled_mean


def get_std_mean_from_images(all_img_path, img_format):
    """Return ``(std, mean)`` of the pixel intensities of all images in a folder.

    Args:
        all_img_path: folder prefix. Files matching
            ``all_img_path + '*' + img_format`` are read, so include the
            trailing separator (e.g. ``'.../img_all/'``).
        img_format: file extension to match, e.g. ``'.png'``.

    Returns:
        ``(std, avg)`` from ``np.std`` / ``np.mean`` over every image, read as
        8-bit grayscale.

    Raises:
        FileNotFoundError: no file matches. Raising replaces the previous
            ``exit()``, which stops or restarts a notebook kernel instead of
            reporting the problem.
        ValueError: an image cannot be decoded, or its size differs from the
            first image's.
    """
    # glob.escape stops characters such as '[' in the folder name from acting as wildcards.
    # Sorting makes the accumulation order deterministic.
    img_list = sorted(glob.glob(glob.escape(all_img_path) + '*' + glob.escape(img_format)))

    if not img_list:
        raise FileNotFoundError(
            'img list is empty: no *{} files in {}'.format(img_format, all_img_path))

    imgs = None
    for i, img_file in enumerate(img_list):
        img = cv2.imread(img_file, cv2.IMREAD_GRAYSCALE)
        if img is None:
            raise ValueError('could not read image: ' + img_file)
        if imgs is None:
            # Size the stack from the first image, as before.
            imgs = np.empty((len(img_list),) + img.shape, dtype=np.uint8)
        elif img.shape != imgs.shape[1:]:
            raise ValueError('image {} has shape {}, expected {}'.format(
                img_file, img.shape, imgs.shape[1:]))
        imgs[i] = img

    avg = np.mean(imgs)
    std = np.std(imgs)
    return std, avg

class DataGenerator:
    """Load a folder of grayscale frames, pad them and preprocess them for prediction.

    Args:
        img_path: folder containing the frames. A trailing path separator is
            appended if missing, because ``get_std_mean_from_images`` builds
            its glob pattern by string concatenation.
        n_frames_train, input_size, output_size: stored as attributes.
            NOTE(review): none of the methods of this class read them.
        strategy_type: selects the preprocessing in ``get_expanded_whole_frames``
            by substring match ('no_preprocessing', 'normalize_clip',
            'normalize', 'heq'); anything else means standardization with the
            folder's own std and mean.
        img_format: extension of the frame files, e.g. '.png'.
    """

    def __init__(self, img_path, n_frames_train, input_size, output_size, strategy_type, img_format = '.png'):
        self.n_frames_train = n_frames_train
        self.img_path = os.path.join(img_path, '')  # guarantee a trailing separator
        self.input_size = input_size
        self.output_size = output_size
        self.strategy_type = strategy_type
        self.img_format = img_format
        self.row, self.col = self.get_img_size()

    def get_expanded_whole_frames(self):
        """Return padded, preprocessed frames plus their names and sizes.

        Returns:
            ``(imgs, img_list, image_cols, image_rows, self.col, self.row)``:
            the channels-first 3-channel float32 stack, the frame file names in
            stack order, the padded width and height, and the original width
            and height.
        """
        img_list = self.find_namespace()
        imgs, image_rows, image_cols = self.get_expanded_images(self.img_path, img_list)

        # ------------------- pre-processing images -------------------
        # std and mean from the test-set images themselves. Statistics from the
        # training-set images (see aggregate_std_mean) are not used because they
        # yielded worse prediction results.
        std_value, mean_value = get_std_mean_from_images(self.img_path, img_format=self.img_format)
        print(mean_value, std_value)

        imgs = imgs[:, np.newaxis, :, :]  # add the channel axis (channels-first)
        strategy = str(self.strategy_type)
        if 'no_preprocessing' in strategy:
            imgs = to3channel(imgs)
        elif 'normalize_clip' in strategy:
            # NOTE(review): normalize_clip_input is not defined in this notebook.
            imgs = normalize_clip_input(imgs)
        elif 'normalize' in strategy:
            imgs = normalize_input(imgs)
        elif 'heq' in strategy:
            # NOTE(review): heq_norm_input is not defined in this notebook.
            imgs = heq_norm_input(imgs)
        else:
            imgs = preprocess_input(imgs, std_value, mean_value)

        return imgs, img_list, image_cols, image_rows, self.col, self.row

    def find_namespace(self):
        """Return the sorted names of regular files in ``img_path`` that end with ``img_format``.

        Sorting makes the frame order, and therefore the order of the
        predictions, deterministic. ``os.listdir`` order is arbitrary.
        """
        return sorted(
            name for name in os.listdir(self.img_path)
            if name.endswith(self.img_format)
            and os.path.isfile(os.path.join(self.img_path, name))
        )

    def get_img_size(self):
        """Return ``(rows, cols)`` of the first frame in ``img_path``.

        Raises:
            FileNotFoundError: the folder has no matching frame. The previous
                ``(-1, -1)`` sentinel only caused confusing failures later.
            ValueError: the first frame cannot be decoded.
        """
        img_list = self.find_namespace()
        if not img_list:
            raise FileNotFoundError('ERROR: get_img_size: no *{} images in {}'.format(
                self.img_format, self.img_path))
        first_path = os.path.join(self.img_path, img_list[0])
        first_img = cv2.imread(first_path, cv2.IMREAD_GRAYSCALE)
        if first_img is None:
            raise ValueError('could not read image: ' + first_path)
        return first_img.shape

    def get_expanded_images(self, img_path, namelist, ratio = 64.0):
        """Read frames, resize them to the first frame's size and reflect-pad them.

        The test-set images are expanded because the model only takes images
        whose size is a multiple of ``ratio``. The padding is added on the
        bottom and right with ``cv2.BORDER_REFLECT`` and is at least ``ratio``
        pixels, which the original authors note prevents boundary effects.

        Returns:
            ``(imgs, imgs_row_exp, imgs_col_exp)``: a uint8 stack of shape
            (len(namelist), imgs_row_exp, imgs_col_exp) and the padded size.
        """
        total_number = len(namelist)
        imgs_row_exp = int(np.ceil(self.row / ratio) * ratio)
        imgs_col_exp = int(np.ceil(self.col / ratio) * ratio)

        # Pad further when rounding up added less than `ratio` pixels.
        if (imgs_row_exp - self.row) < ratio:
            imgs_row_exp = imgs_row_exp + int(ratio)
            print('imgs_row_exp', imgs_row_exp)

        if (imgs_col_exp - self.col) < ratio:
            imgs_col_exp = imgs_col_exp + int(ratio)
            print('imgs_col_exp', imgs_col_exp)

        imgs = np.ndarray((total_number, imgs_row_exp, imgs_col_exp), dtype=np.uint8)
        for i, name in enumerate(namelist):
            frame_path = os.path.join(img_path, name)
            raw = cv2.imread(frame_path, cv2.IMREAD_GRAYSCALE)
            if raw is None:
                raise ValueError('could not read image: ' + frame_path)
            img = cv2.resize(raw, (int(self.col), int(self.row)), interpolation=cv2.INTER_CUBIC)
            imgs[i] = cv2.copyMakeBorder(img, 0, imgs_row_exp - self.row, 0, imgs_col_exp - self.col, cv2.BORDER_REFLECT)
        return imgs, imgs_row_exp, imgs_col_exp


## Set the parameters

In [None]:
# The model builders use channels-first tensors: inputs are (3, rows, cols), concatenation on axis 1.
K.set_image_data_format('channels_first')
root_prediciton_path = "results/predict_wholeframe_round1_demo/"  # NOTE(review): not read by the visible cells (name is misspelled "prediciton")

frame = 2                 # used in the output folder name; passed to DataGenerator as n_frames_train
model_index = 1           # NOTE(review): not read by the visible cells
model_name = 'D'          # used in the output folder name
dataset_folder = './assets/'                    # presumably the local parent folder of the demo movie
dataset_name = '040119_PtK1_S01_01_phase_ROI2'  # demo movie name; also names the output folder
img_folder = '/img_all/'                        # presumably the frame subfolder inside dataset_name
img_format = '.png'       # only files with this extension are read

In [None]:
# Show the model builder's signature and docstring. The previous content (a bare
# `VGG19_dropout(...):` signature) was a SyntaxError that stopped "Run all".
# The model itself is built in the next section, once the padded image size is known.
help(VGG19_dropout)

## Segment the movie

In [None]:
# Local folder with the demo frames, built from the configuration cell:
# ./assets/040119_PtK1_S01_01_phase_ROI2/img_all/
# The previous value was a GitHub web page URL (.../tree/master/...), not a directory,
# so os.listdir() raised FileNotFoundError. Run this notebook from a clone of
# https://github.com/kleelab-bch/MARS-Net so that ./assets/ exists.
img_path = dataset_folder + dataset_name + img_folder
if not os.path.isdir(img_path):
    raise FileNotFoundError(
        'Demo frames not found in {}. Clone https://github.com/kleelab-bch/MARS-Net '
        'and run this notebook from the repository root.'.format(img_path))

save_path = './{}/frame{}_{}/'.format(dataset_name, str(frame), model_name)

# ------------------- Data loading -------------------
strategy_type = 'VGG19_dropout'
prediction_data_generator = DataGenerator(img_path, frame, 128, 68, strategy_type, img_format=img_format)
imgs_val, namelist, image_cols, image_rows, orig_cols, orig_rows = prediction_data_generator.get_expanded_whole_frames()

print('img size:', image_rows, image_cols)
print('orig img size:', orig_rows, orig_cols)
print('imgs_val: ', imgs_val.dtype, imgs_val.shape)

# ------------------- Load trained Model -------------------
# Crop the reflect-padding that DataGenerator added on the right and bottom.
right_crop = image_cols - orig_cols
bottom_crop = image_rows - orig_rows

# Build only the model selected by strategy_type and bind it to `model`, which the
# prediction cell uses. Previously both models were built but `model` was never
# defined (NameError), and UNet was called without its weights_path argument.
if strategy_type == 'VGG19_dropout':
    model = VGG19_dropout(image_rows, image_cols, 0, right_crop, bottom_crop)
else:
    model = UNet(image_rows, image_cols, 0, right_crop, bottom_crop, None)

print('model layers: ', len(model.layers))


FileNotFoundError: ignored

In [None]:
# ------------------- predict segmented images and save them -------------------
# cv2.imwrite does not create folders and signals failure only through its return value.
os.makedirs(save_path, exist_ok=True)

segmented_output = model.predict(imgs_val, batch_size=1, verbose=1)
# Sigmoid output in [0, 1] -> explicit 8-bit image (0=black color and 255=white color)
segmented_output = np.clip(np.rint(255 * segmented_output), 0, 255).astype(np.uint8)

for frame_name, prediction in zip(namelist, segmented_output):
    out_path = os.path.join(save_path, frame_name)
    # prediction[0] is channel 0 of the channels-first output, i.e. a (rows, cols) mask
    if not cv2.imwrite(out_path, prediction[0]):
        raise IOError('cv2.imwrite failed for ' + out_path)
K.clear_session()