In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import skimage.io
import os
from data import utils as CTRUtil

In [2]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU, Input
from keras.layers import Reshape, UpSampling2D, MaxPooling2D
from keras.layers import Conv2D, Conv2DTranspose, Flatten
from keras.layers.core import Dense, Activation, Dropout
from keras.regularizers import l2
from keras.models import Sequential, Model
from keras.datasets import mnist
import keras
import matplotlib
# matplotlib.use('Agg')

In [3]:
DATA_DIR = "data/"
JSRT_DIR = DATA_DIR + "jsrt/"
WING_DIR = DATA_DIR + "wingspan/"
JSRT_FNAMES = os.listdir(JSRT_DIR + "png")
WING_FNAMES = os.listdir(WING_DIR + "png")

In [4]:
def get_imgs(base_dir, fname):
    """
    Given base_dir = JSRT_DIR or WING_DIR,
    Returns tuple of numpy arrays representing 
        - Original scan
        - Left lung mask
        - Right lung mask
        - Heart mask
    """
    img = skimage.io.imread(base_dir + "png/" + fname)
    left = skimage.io.imread(base_dir + "mask/left_lung/" + fname)
    right = skimage.io.imread(base_dir + "mask/right_lung/" + fname)
    heart = skimage.io.imread(base_dir + "mask/heart/" + fname)
    for idx, mask in enumerate([left, right, heart]):
        mask[mask > 0] = idx + 1
    return img, left, right, heart

def print_img_demo(n_imgs = 3):
    """
    Prints some JSRT images with overlays (just for looking at dataset.)
    Can pass in n_imgs = (int) x to print x images. 
    """
    if n_imgs > len(JSRT_FNAMES):
        n_imgs = len(JSRT_FNAMES)
    for frame in JSRT_FNAMES[:n_imgs]:
        img, left, right, heart = get_imgs(JSRT_DIR, frame)
        show_annotation(img, left, right, heart)
        plt.show()


def get_input_image_example(base_dir, fname):
    """
    Parameter
        base_dir = JSRT_DIR or WING_DIR,
        fname = one of JSRT_FNAMES or WING_FNAMES
    Returns numpy array representing image with 1 BW channel.
    Range of each value is [0.0, 1.0]
    Shape: (height, width, 1 channel)
    """
    img = skimage.io.imread(base_dir + "png/" + fname)
    norm = img / 255.0
    return np.expand_dims(norm, -1)


def get_input_image_examples(base_dir, fnames):
    """
    Accumulates get_input_image_example. 
    Shape: (#examples, height, width, 1 channel)
    """
    return np.array([
        get_input_image_example(base_dir, fname)
        for fname in fnames
    ])


def get_ground_truth_example(base_dir, fname):
    """
    Parameter
        base_dir = JSRT_DIR or WING_DIR,
        fname = one of JSRT_FNAMES or WING_FNAMES
    Returns numpy array representing image with 4 channels (one-hot):
        0: none
        1: Left lung mask
        2: Right lung mask
        3: Heart mask
    Shape: (height, width, 4 channels)
    """
    left = skimage.io.imread(base_dir + "mask/left_lung/" + fname)
    right = skimage.io.imread(base_dir + "mask/right_lung/" + fname)
    heart = skimage.io.imread(base_dir + "mask/heart/" + fname)

    for mask in (left, right, heart):
        mask[mask > 0] = 1.0
        
    non = np.ones(left.shape) - left - right - heart
    non[non < 0] = 0.0
    non[non > 0] = 1.0
    
    concat = np.stack((non, left, right, heart), axis=2)
    return concat

def get_ground_truth_set(base_dir, fnames):
    """
    Accumulates get_ground_truth_example. 
    Shape: (#examples, height, width, 4 channels)
    """
    return np.array([
        get_ground_truth_example(base_dir, fname)
        for fname in fnames
    ])

def get_data(base_dir, fnames):
    """
    Calls get_input_image_examples on the given 
        base_dir : ex. JSRT_DIR
        fnames : ex. JSRT_FNAMES
    """
    X = get_input_image_examples(base_dir, fnames)
    Y = get_ground_truth_set(base_dir, fnames)
    return X, Y

def show_annotation(img, left, right, heart):
    annotated = CTRUtil.add_seg(img, left + right + heart)
    skimage.io.imshow(annotated)


In [5]:
X, Y = get_data(JSRT_DIR, JSRT_FNAMES)
X.shape, Y.shape

((247, 512, 512, 1), (247, 512, 512, 4))

Each x: (examples, H, W, 1 channel)
Each y: (examples, H, W, 4 channels (left, right, heart, none))

Discriminator uses Avg pooling, 1x1 conv. 


Strategy: First approach, only train on JSRT. Then, see how well it does on Wingspan (in terms of segmentation and especially in terms of CTR estimation). Idea: these __are__ two different datasets. Want to see if that makes an impact.


In [7]:

# Some constants
MNIST_SIZE = 512
LATENT_DIM = 100


def make_generator(num_filters=64, num_hidden_conv_layers=2, init_dim=7):
    gen = Sequential()
    # Model input is a feature vector of size 100
    gen.add(Dense(init_dim**2 * num_filters, input_dim=LATENT_DIM))
    gen.add(Activation('relu'))
    gen.add(Reshape((init_dim, init_dim, num_filters)))

    for _ in range(num_hidden_conv_layers):
        # Input: d x d x k
        # Output 2d x 2d x k/2
        if (init_dim < MNIST_SIZE):
            gen.add(UpSampling2D())
            init_dim *= 2
        num_filters //= 2
        gen.add(Conv2DTranspose(num_filters, 5, padding='same'))
        gen.add(BatchNormalization(momentum=0.4))
        gen.add(Activation('relu'))

    gen.add(Conv2DTranspose(1, 5, padding='same'))
    gen.add(Activation('sigmoid'))
    # Output should be 28 x 28 x 1
    # gen.summary()
    return gen


def make_discriminator(num_filters=32, num_hidden_layers=3, dropout=0.3):
    d = Sequential()

    d.add(Conv2D(num_filters*1, 5, strides=2,
                 input_shape=(MNIST_SIZE, MNIST_SIZE, 1), padding='same'))
    d.add(LeakyReLU())  # leakyrelu so generator has derivative
    d.add(Dropout(dropout))

    for i in range(1, num_hidden_layers):
        # Powers of 2 are generally better suited for GPU
        d.add(Conv2D(num_filters*(2**i), 5, strides=2, padding='same'))
        d.add(LeakyReLU())
        d.add(Dropout(dropout))

    # NOTE: Difference between this and build_conv_net
    #       is that there is only a SINGLE output class,
    #       which corresponds to FAKE/REAL.
    d.add(Flatten())
    d.add(Dense(1))
    d.add(Activation('sigmoid'))
    d.compile(loss='binary_crossentropy', optimizer='adam')
    return d


def make_adversial_network(generator, discriminator):
    # This will only be used for training the generator.
    # Note, the weights in the discriminator and generator are shared.
    discriminator.trainable = False
    gan = Sequential([generator, discriminator])
    gan.compile(loss='binary_crossentropy', optimizer='adam')
    return gan  # , generator, discriminator


def generate_latent_noise(n):
    return np.random.uniform(-1, 1, size=(n, LATENT_DIM))


def visualize_generator(epoch, generator,
                        num_samples=100, dim=(10, 10),
                        figsize=(10, 10), path=''):
    plt.figure(figsize=figsize)
    for i in range(num_samples):
        plt.subplot(dim[0], dim[1], i+1)
        img = generator.predict(generate_latent_noise(1))[0, :, :, 0]
        plt.imshow(img, cmap='gray')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(f'generator_samples/gan_epoch_{epoch}.png')
    plt.close()


def train(epochs=1, batch_size=128, path=''):
    # Grab all training examples. 
    X_images = X
    X_segments = Y


    # Creating GAN
    generator = make_generator()
    discriminator = make_discriminator()
    adversial_net = make_adversial_network(generator, discriminator)

    visualize_generator(0, generator, path=path)
    for epoch in range(epochs):
        print(f'Epoch {epoch+1}')

        discr_loss = 0
        gen_loss = 0
        for _ in tqdm(range(batch_size)):
            noise = generate_latent_noise(batch_size)
            generated_segments = generator.predict(noise)

            rand_choice = np.random.choice(X_segments.shape[0], batch_size,
                                                   replace=False)
            real_images = X_images[rand_choice]
            real_segments = X_segments[rand_choice]
            discrimination_data = np.concatenate([real_segments, generated_segments])

            # Labels for generated and real data, uses soft label trick
            discrimination_labels = 0.1 * np.ones(2 * batch_size)
            discrimination_labels[:batch_size] = 0.9

            # To train, we alternate between training just the discriminator
            # and just the generator.
            discriminator.trainable = True
            discr_loss += discriminator.train_on_batch(discrimination_data,
                                                       discrimination_labels)

            # Trick to 'freeze' discriminator weights in adversial_net. Only
            # the generator weights will be changed, which are shared with
            # the generator.
            discriminator.trainable = False
            # N.B, changing the labels because now we want to 'fool' the
            # discriminator.
            gen_loss += adversial_net.train_on_batch(
                noise, np.ones(batch_size))

        print(f'Discriminator Loss: {discr_loss/batch_size}')
        print(f'Generator Loss:     {gen_loss/batch_size}')
        visualize_generator(epoch+1, generator, path=path)


train(epochs=4)


TypeError: get_data() missing 1 required positional argument: 'fnames'