In [1]:
import os
from itertools import chain
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import cifar10
from tensorflow.keras import utils
from tensorflow.python.keras.utils.np_utils import to_categorical
from tensorflow.keras import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv2D, Activation, MaxPooling2D, GlobalAveragePooling2D, LSTM, TimeDistributed, Dropout, Dense, BatchNormalization
from tensorflow.keras.preprocessing.image import load_img, img_to_array, ImageDataGenerator, array_to_img
import numpy as np
from PIL import Image
from tqdm import tqdm
import pandas as pd
import math
import matplotlib.pyplot as plt
import shutil

from data import Data

In [2]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

In [3]:
LABEL_INDEX = {
    'ap': 0,
    'bs': 1,
    'mid': 2,
    'oap': 3,
    'obs': 4,
}

In [4]:
keras_app = tf.keras.applications.mobilenet
keras_model = tf.keras.applications.mobilenet.MobileNet
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    rescale=1,
    fill_mode="nearest",
    preprocessing_function=keras_app.preprocess_input)
datagen = ImageDataGenerator(preprocessing_function=keras_app.preprocess_input)

In [5]:
class PhaseDataGenerator(keras.utils.Sequence):
    def __init__(self, data: Data, datasets=None, batch_size=32, target_size=(224, 224), 
                 slices_per_sample=25, shuffle=True, image_data_generator=None):
        self.data = data
        self.datasets = datasets
        self.batch_size = batch_size
        self.target_size=  target_size
        self.slices_per_sample = slices_per_sample
        self.shuffle = shuffle
        self.datagen = image_data_generator
        
        self.n_classes = 5
        self.label_indices = LABEL_INDEX

        self.samples = dict()
        self.max_slices = 1
        
        if datasets is None:
            datasets = list(data.data.keys())
        if isinstance(datasets, str):
            datasets = [datasets]
        
        # All plural variables are dicts
        for dataset in datasets:
            for patient, phases in data.data[dataset].items():
                for phase, slices in phases.items():
                    key = "{dataset}_{patient:06d}_{slice:02d}".format(dataset=dataset, patient=patient, slice=phase)
                    self.samples[key] = slices
                    if len(slices) > self.max_slices:
                        self.max_slices = len(slices)
        
        if slices_per_sample < self.max_slices:
            raise ValueError("There are some samples that contain more than {} slices ({})".format(
                slices_per_sample, self.max_slices))
        
        unlabeled = []
        self.images_by_label = [0] * len(self.label_indices)
        for phases in self.samples.values():
            for phase in phases.values():
                label = self.data.labels.get(phase, None)
                if label is None:
                    unlabeled.append(phase)
                else:
                    index = self.label_indices[label]
                    self.images_by_label[index] += 1
        if unlabeled:
            raise ValueError("{} unlabeled slice(s): {}...".format(len(unlabeled), str(unlabeled)[:200]))
        
        self.n_batches = math.ceil(len(self.samples) / batch_size)
        
        self._refresh_sample_keys()
    
    def _refresh_sample_keys(self):
        self.sample_keys = sorted(list(self.samples.keys()))
        if self.shuffle:
            np.random.shuffle(self.sample_keys)
        
    def _get_sample_key_batch(self, index):
        return self.sample_keys[index * self.batch_size:(index + 1) * self.batch_size]

    def _load_and_preprocess_image(self, path, standardize=False):
        img = Image.open(path)
        img = img.resize(self.target_size, Image.NEAREST)
        #img = load_img(path, color_mode="rgb", target_size=self.target_size)
        x = img_to_array(img, data_format="channels_last")
        params = datagen.get_random_transform(x.shape)
        x = x / 65536 * 255
        x = datagen.apply_transform(x, params)
        if standardize:
            x = datagen.standardize(x)
        return x
    
    def get_class_weight(self):
        counts = np.array(self.images_by_label)
        weights = counts.sum() / counts / len(counts)
        weights = { i: weight for i, weight in enumerate(weights.tolist())}
        return weights
        
    def __getitem__(self, index):
        """Get `index`th batch
        """
        keys = self._get_sample_key_batch(index)
        batch_size = len(keys)
        x = np.zeros((batch_size, self.slices_per_sample) + self.target_size + (3,))
        y = np.zeros((batch_size, self.slices_per_sample) + (self.n_classes,))
        for i, key in enumerate(keys):
            items = sorted(list(self.samples[key].items()))
            for j, (slice_index, sid) in enumerate(items):
                path = self.data.paths[sid]
                image = self._load_and_preprocess_image(path, standardize=True)
                x[i][j] = image
                label = self.data.labels[sid]
                label_index = self.label_indices[label]
                y[i][j][label_index] = 1
        
        return x, y
    
    def __len__(self):
        return self.n_batches
    
    def on_epoch_end(self):
        self._refresh_sample_keys()

In [6]:
class SliceDataGenerator(keras.utils.Sequence):
    def __init__(self, data: Data, datasets=None, batch_size=32,
                 target_size=(224, 224), shuffle=True, image_data_generator=None):
        self.data = data
        self.datasets = datasets
        self.batch_size = batch_size
        self.target_size=  target_size
        self.shuffle = shuffle
        self.datagen = image_data_generator
        
        self.n_classes = 5
        self.label_indices = LABEL_INDEX

        self.slices = list()

        if datasets is None:
            datasets = list(data.data.keys())
        if isinstance(datasets, str):
            datasets = [datasets]
        
        # All plural variables are dicts
        for dataset in datasets:
            for patient, phases in data.data[dataset].items():
                for phase, slices in phases.items():
                    ordered = sorted(list(slices.items()))
                    ordered_slices = [item[1] for item in ordered]
                    self.slices.extend(ordered_slices)
        self.slices.sort()

        unlabeled = []
        self.images_by_label = [0] * len(self.label_indices)
        for slice in self.slices:
            label = self.data.labels.get(slice, None)
            if label is None:
                unlabeled.append(slice)
            else:
                index = self.label_indices[label]
                self.images_by_label[index] += 1
        if unlabeled:
            raise ValueError("{} unlabeled slice(s): {}...".format(len(unlabeled), str(unlabeled)[:200]))
        
        self.n_batches = math.ceil(len(self.slices) / batch_size)
        self._refresh_slice_order()
    
    def _refresh_slice_order(self):
        self.slices = sorted(list(self.slices))
        if self.shuffle:
            np.random.shuffle(self.slices)
        
    def _get_slice_batch(self, index):
        return self.slices[index * self.batch_size:(index + 1) * self.batch_size]

    def _load_and_preprocess_image(self, path, standardize=False):
        img = Image.open(path)
        img = img.resize(self.target_size, Image.NEAREST)
        #img = load_img(path, color_mode="rgb", target_size=self.target_size)
        x = img_to_array(img, data_format="channels_last")
        params = datagen.get_random_transform(x.shape)
        x = x / 65536 * 255
        x = datagen.apply_transform(x, params)
        if standardize:
            x = datagen.standardize(x)
        return x
    
    def get_class_weight(self):
        counts = np.array(self.images_by_label)
        weights = counts.sum() / counts / len(counts)
        weights = { i: weight for i, weight in enumerate(weights.tolist())}
        return weights
        
    def __getitem__(self, index):
        """Get `index`th batch
        """
        slices = self._get_slice_batch(index)
        batch_size = len(slices)
        x = np.zeros((batch_size, ) + self.target_size + (3,))
        y = np.zeros((batch_size, ) + (self.n_classes,))
        for i, sid in enumerate(slices):
            path = self.data.paths[sid]
            image = self._load_and_preprocess_image(path, standardize=True)
            x[i] = image
            label = self.data.labels[sid]
            label_index = self.label_indices[label]
            y[i][label_index] = 1

        assert(y.sum() == len(slices))
        
        return x, y
    
    def __len__(self):
        return self.n_batches
    
    def on_epoch_end(self):
        self._refresh_slice_order()

In [23]:
class PhasewiseAccuracy(tf.keras.metrics.Metric):
    def __init__(self, name='phasewise_accuracy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.reset_states()

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.cast(y_true, tf.bool)
        y_pred = tf.cast(y_pred, tf.bool)

        for true, pred in zip(y_true, y_pred):  # batch
            correct_phase = True
            for t, p in zip(true, pred):  # phase
                if t.sum() == 0:
                    continue
                self.total_slices += 1
                if t.argmax() == p.argmax():
                    self.correct_slices += 1
                else:
                    correct_phase = False
            if correct_phase:
                correct_phases += 1

    def result(self):
        slice_accuracy = self.correct_slices / self.total_slices
        phase_accuracy = self.correct_phases / self.total_phases
        return slice_accuracy, phase_accuracy

    def reset_states(self):
        self.total_phases = 0
        self.total_slices = 0
        self.correct_phases = 0
        self.correct_slices = 0

In [8]:
data = Data()

In [9]:
phase_gen = PhaseDataGenerator(data, "KAG", target_size=(224, 224), batch_size=2, shuffle=True)

In [10]:
slice_gen = SliceDataGenerator(data, "KAG", target_size=(224, 224), batch_size=32, shuffle=True)

In [11]:
assert(np.isclose(slice_gen[0][0][0].min(), -1, rtol=1.0e-4))
assert(np.isclose(slice_gen[0][0][0].max(), 1, rtol=1.0e-4))

In [12]:
assert(np.isclose(phase_gen[0][0][0].min(), -1, rtol=1.0e-4))
assert(np.isclose(phase_gen[0][0][0].max(), 1, rtol=1.0e-4))

In [15]:
keras_app = tf.keras.applications.mobilenet
keras_model = tf.keras.applications.mobilenet.MobileNet
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    rescale=1,
    fill_mode="nearest",
    preprocessing_function=keras_app.preprocess_input)

In [16]:
backbone = keras_model(include_top=False, pooling='avg', weights='imagenet', input_shape=(224, 224, 3))
backbone.trainable = False

cnn_model = Sequential()
cnn_model.add(backbone)
cnn_model.add(Dense(256, activation="relu"))
cnn_model.add(Dropout(0.5))
cnn_model.add(Dense(5, activation="softmax"))

cnn_model.layers[0].trainable = False
cnn_model.compile(loss="categorical_crossentropy",
                  optimizer="adam",
                  metrics=['accuracy'])

In [17]:
cnn_model.fit(slice_gen, epochs=20)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f0d0c4609d0>

In [18]:
rnn_model = Sequential()
rnn_model.add(TimeDistributed(backbone))
rnn_model.add(LSTM(256, input_shape=(25, 2048), return_sequences=True))
rnn_model.add(Dropout(0.5))
#rnn_model.add(TimeDistributed(Dense(256, activation="relu")))
#rnn_model.add(TimeDistributed(Dropout(0.5)))
rnn_model.add(TimeDistributed(Dense(5, activation="softmax")))

In [24]:
rnn_model.layers[0].trainable = False
rnn_model.compile(loss="categorical_crossentropy",
                  optimizer="adam",
                  metrics=[PhasewiseAccuracy])

TypeError: 'property' object is not iterable

In [14]:
rnn_model.fit(phase_gen, epochs=20)

NameError: name 'rnn_model' is not defined

In [55]:
test_gen = PhaseDataGenerator(data, "KAG", target_size=(224, 224), batch_size=2, shuffle=False)

In [56]:
rnn_model.evaluate(test_gen)



[0.15982870757579803, 0.45308008790016174]

In [57]:
preds = rnn_model.predict(test_gen)

In [59]:
labels = None
for _, batch in test_gen:
    if labels is None:
        labels = batch
    else:
        labels = np.concatenate([labels, batch])
    print(labels.shape)

(2, 25, 5)
(4, 25, 5)
(6, 25, 5)
(8, 25, 5)
(10, 25, 5)
(12, 25, 5)
(14, 25, 5)
(16, 25, 5)
(18, 25, 5)
(20, 25, 5)
(22, 25, 5)
(24, 25, 5)
(26, 25, 5)
(28, 25, 5)
(30, 25, 5)
(32, 25, 5)
(34, 25, 5)
(36, 25, 5)
(38, 25, 5)
(40, 25, 5)
(42, 25, 5)
(44, 25, 5)
(46, 25, 5)
(48, 25, 5)
(50, 25, 5)
(52, 25, 5)
(54, 25, 5)
(56, 25, 5)
(58, 25, 5)
(60, 25, 5)
(62, 25, 5)
(64, 25, 5)
(66, 25, 5)
(68, 25, 5)
(70, 25, 5)
(72, 25, 5)
(74, 25, 5)
(76, 25, 5)
(78, 25, 5)
(80, 25, 5)
(82, 25, 5)
(84, 25, 5)
(86, 25, 5)
(88, 25, 5)
(90, 25, 5)
(92, 25, 5)
(94, 25, 5)
(96, 25, 5)
(98, 25, 5)
(100, 25, 5)
(102, 25, 5)
(104, 25, 5)
(106, 25, 5)
(108, 25, 5)
(110, 25, 5)
(112, 25, 5)
(114, 25, 5)
(116, 25, 5)
(118, 25, 5)
(120, 25, 5)
(122, 25, 5)
(124, 25, 5)
(126, 25, 5)
(128, 25, 5)
(130, 25, 5)
(132, 25, 5)
(134, 25, 5)
(136, 25, 5)
(138, 25, 5)
(140, 25, 5)
(142, 25, 5)
(144, 25, 5)
(146, 25, 5)
(148, 25, 5)
(150, 25, 5)
(152, 25, 5)
(154, 25, 5)
(156, 25, 5)
(158, 25, 5)
(160, 25, 5)
(162, 25, 5)


(1256, 25, 5)
(1258, 25, 5)
(1260, 25, 5)
(1262, 25, 5)
(1264, 25, 5)
(1266, 25, 5)
(1268, 25, 5)
(1270, 25, 5)
(1272, 25, 5)
(1274, 25, 5)
(1276, 25, 5)
(1278, 25, 5)
(1280, 25, 5)
(1282, 25, 5)
(1284, 25, 5)
(1286, 25, 5)
(1288, 25, 5)
(1290, 25, 5)
(1292, 25, 5)
(1294, 25, 5)
(1296, 25, 5)
(1298, 25, 5)
(1300, 25, 5)
(1302, 25, 5)
(1304, 25, 5)
(1306, 25, 5)
(1308, 25, 5)
(1310, 25, 5)
(1312, 25, 5)
(1314, 25, 5)
(1316, 25, 5)
(1318, 25, 5)
(1320, 25, 5)
(1322, 25, 5)
(1324, 25, 5)
(1326, 25, 5)
(1328, 25, 5)
(1330, 25, 5)
(1332, 25, 5)
(1334, 25, 5)
(1336, 25, 5)
(1338, 25, 5)
(1340, 25, 5)
(1342, 25, 5)
(1344, 25, 5)
(1346, 25, 5)
(1348, 25, 5)
(1350, 25, 5)
(1352, 25, 5)
(1354, 25, 5)
(1356, 25, 5)
(1358, 25, 5)
(1360, 25, 5)
(1362, 25, 5)
(1364, 25, 5)
(1366, 25, 5)
(1368, 25, 5)
(1370, 25, 5)
(1372, 25, 5)
(1374, 25, 5)
(1376, 25, 5)
(1378, 25, 5)
(1380, 25, 5)
(1382, 25, 5)
(1384, 25, 5)
(1386, 25, 5)
(1388, 25, 5)
(1390, 25, 5)
(1392, 25, 5)
(1394, 25, 5)
(1396, 25, 5)
(1398,

In [60]:
preds.shape

(1948, 25, 5)

In [64]:
def get_accuracy(preds, labels):
    total_slices = 0
    total_phases = 0
    correct_slices = 0
    correct_phases = 0

    for phase, slices in enumerate(preds):
        total_phases += 1
        correct_phase = True
        for slice, pred in enumerate(slices):
            label = labels[phase][slice]
            if label.sum() == 0:
                pass
            else:
                total_slices += 1
                if label.argmax() == pred.argmax():
                    correct_slices += 1
                else:
                    correct_phase = False
        if correct_phase:
            correct_phases += 1

    slice_accuracy = correct_slices / total_slices
    phase_accuracy = correct_phases / total_phases

    return slice_accuracy, phase_accuracy, total_slices

In [65]:
slice_accuracy, phase_accuracy, ts = get_accuracy(preds, labels)
print("Slice accuracy:", slice_accuracy)
print("Phase accuracy:", phase_accuracy)

Slice accuracy: 0.8401032534580167
Phase accuracy: 0.1570841889117043


In [66]:
print(ts)

20532


In [71]:
ts / (preds.shape[0] * 25)

0.4216016427104723