In [2]:
import os

import numpy as np

import tensorflow as tf

from keras.models import Model, Input
from keras.optimizers import Adam

from keras.layers import Dense
from keras.layers import Conv3D, Conv3DTranspose
from keras.layers import Add, Concatenate
from keras.layers import Activation, LeakyReLU
from keras.layers import BatchNormalization, Lambda

from keras import backend as K

In [3]:
def accm(y_true, y_pred):
    '''
    accuracy metric
    '''
    y_pred = K.clip(y_pred, -1, 1)
    return K.mean(K.equal(y_true, K.round(y_pred)))

def mssim(y_true, y_pred):
    '''
    mean structural similarity index
    '''
    costs = 1.0 - tf.reduce_mean(tf.image.ssim(y_true, y_pred, 2.0))
    return costs

def wloss(y_true, y_predict):
    '''
    Wasserstein loss
    '''
    return -K.mean(y_true * y_predict)

In [10]:
def discriminator(
    kernel_initz,
    inp_shape = (32, 256, 256, 1), # 3d, with channels last
    trainable = True,
):
    
    gamma_init = tf.random_normal_initializer(1., 0.02)
    
    inp = Input(shape = (32, 256, 256, 1)) # 3d
    
    l0 = Conv3D(
        filters = 16, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(inp)
    l0 = LeakyReLU(alpha = 0.2)(l0)
    
    
    l1 = Conv3D(
        filters = 16 * 2, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(l0)
    l1 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l1)
    l1 = LeakyReLU(alpha = 0.2)(l1)
    
    
    l2 = Conv3D(
        filters = 16 * 4, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(l1)
    l2 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l2)
    l2 = LeakyReLU(alpha = 0.2)(l2)
    
    
    l3 = Conv3D(
        filters = 16 * 8, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz, 
        )(l2)
    l3 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l3)
    l3 = LeakyReLU(alpha = 0.2)(l3)
    
    
    l4 = Conv3D(
        filters = 16 * 16, kernel_size = 4, strides = (2, 2, 2),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l3)
    l4 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l4)
    l4 = LeakyReLU(alpha = 0.2)(l4)
    
    
    l5 = Conv3D(
        filters = 16 * 16, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz, 
        )(l4)
    l5 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l5)
    l5 = LeakyReLU(alpha = 0.2)(l5)
    
    
    l6 = Conv3D(
        filters = 16 * 16, kernel_size = 1, strides = (1, 1, 1), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(l5)
    l6 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l6)
    l6 = LeakyReLU(alpha = 0.2)(l6)
    
    
    l7 = Conv3D(
        filters = 16 * 8, kernel_size = 1, strides = (1, 1, 1),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l6)
    l7 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l7)
    l7 = LeakyReLU(alpha = 0.2)(l7)
    
    
    l8 = Conv3D(
        filters = 16 * 2, kernel_size = 1, strides = (1, 1, 1), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(l7)
    l8 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l8)
    l8 = LeakyReLU(alpha = 0.2)(l8)
    
    
    l9 = Conv3D(
        filters = 16 * 2, kernel_size = 3, strides = (1, 1, 1),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l8)
    l9 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l9)
    l9 = LeakyReLU(alpha = 0.2)(l9)
    
    
    l10 = Conv3D(
        filters = 16 * 8, kernel_size = 3, strides = (1, 1, 1),
        padding = 'same', kernel_initializer = kernel_initz,
        )(l9)
    l10 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(l10)
    l10 = LeakyReLU(alpha = 0.2)(l10)
    
    
    l11 = Add()([l7,l10])
    l11 = LeakyReLU(alpha = 0.2)(l11)
    
    
    out = Conv3D(
        filters = 1, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = 'he_normal',
        )(l11)
    
    
    model = Model(inputs = inp, outputs = out)
    return model

In [11]:
def resden(
    x, fil_lay, fil_end, beta, 
    gamma_init, kernel_initz, trainable,
):   
    
    x1 = Conv3D(
        filters = fil_lay, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
        )(x)
    x1 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x1)
    x1 = LeakyReLU(alpha = 0.2)(x1)
    x1=Concatenate(axis=-1)([x, x1])
    
    
    x2 = Conv3D(
        filters = fil_lay, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
        )(x1)
    x2 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x2)
    x2 = LeakyReLU(alpha = 0.2)(x2)
    x2 = Concatenate(axis = -1)([x1, x2])
     
        
    x3 = Conv3D(
        filters = fil_lay, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz, 
        )(x2)
    x3 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x3)
    x3 = LeakyReLU(alpha = 0.2)(x3)
    x3=Concatenate(axis = -1)([x2, x3])
    
    
    x4 = Conv3D(
        filters = fil_lay, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz, 
        )(x3)
    x4 = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(x4)
    x4 = LeakyReLU(alpha = 0.2)(x4)
    x4 = Concatenate(axis = -1)([x3, x4])
    
    
    x5 = Conv3D(
        filters = fil_end, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz,
        )(x4)
    
    x5 = Lambda(lambda x: x * beta)(x5)
    
    xout = Add()([x5,x])
    return xout

def resresden(x, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable):
    
    x1 = resden(x,  fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable,)
    x2 = resden(x1, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable,)
    x3 = resden(x2, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable,)
    
    x3 = Lambda(lambda x : x * beta)(x3)
    
    xout = Add()([x3,x])
    return xout

In [17]:
def generator(inp_shape, kernel_initz, trainable = True,):
    gamma_init = tf.random_normal_initializer(1., 0.02)
    
    fil_lay = 32
    fil_end = 512
    rrd_count = 12
    beta = 0.2

    
    inp_usamp_imag = Input(inp_shape) # (32, 256, 256, 2)
    
    
    lay_1dn = Conv3D(
        filters = 64, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(inp_usamp_imag)
    lay_1dn = LeakyReLU(alpha = 0.2)(lay_1dn)

    
    lay_2dn = Conv3D(
        filters = 128, kernel_size = 4, strides = (1, 2, 2),
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_1dn)
    lay_2dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_2dn)
    lay_2dn = LeakyReLU(alpha = 0.2)(lay_2dn)

    
    lay_3dn = Conv3D(
        filters = 256, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_2dn)
    lay_3dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_3dn)
    lay_3dn = LeakyReLU(alpha = 0.2)(lay_3dn)

    
    lay_4dn = Conv3D(
        filters = 512, kernel_size = 4, strides = (2, 2, 2),
        padding = 'same', kernel_initializer = kernel_initz, 
        )(lay_3dn)
    lay_4dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_4dn)
    lay_4dn = LeakyReLU(alpha = 0.2)(lay_4dn)  

    
    lay_5dn = Conv3D(
        filters = 512, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_4dn)
    lay_5dn = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_5dn)
    lay_5dn = LeakyReLU(alpha = 0.2)(lay_5dn)


    c1 = Conv3D(
        filters = fil_end, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz, 
        )(lay_5dn)
    
    xrrd = c1
    for _ in range(rrd_count):
        xrrd = resresden(xrrd, fil_lay, fil_end, beta, gamma_init, kernel_initz, trainable)

    c2 = Conv3D(
        filters = fil_end, kernel_size = 3, strides = 1,
        padding = 'same', kernel_initializer = kernel_initz, 
        )(xrrd)
    
    
    lay_5upc=Add()([c1, c2])

    
    lay_4up = Conv3DTranspose(
        filters = 1024, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_5upc)
    lay_4up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_4up)
    lay_4up = Activation('relu')(lay_4up) 

    lay_4upc = Concatenate(axis = -1)([lay_4up,lay_4dn])

    
    
    lay_3up = Conv3DTranspose(
        filters = 256, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_4upc) 
    lay_3up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_3up)
    lay_3up = Activation('relu')(lay_3up)

    lay_3upc = Concatenate(axis = -1)([lay_3up,lay_3dn])

    
    lay_2up = Conv3DTranspose(
        filters = 128, kernel_size = 4, strides = (2, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_3upc)
    lay_2up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_2up)
    lay_2up = Activation('relu')(lay_2up)

    lay_2upc = Concatenate(axis = -1)([lay_2up, lay_2dn])

    
    lay_1up = Conv3DTranspose(
        filters = 64, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_2upc)
    lay_1up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_1up)
    lay_1up = Activation('relu')(lay_1up) 

    lay_1upc = Concatenate(axis = -1)([lay_1up,lay_1dn])

    lay_256up = Conv3DTranspose(
        filters = 64, kernel_size = 4, strides = (1, 2, 2), 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_1upc)
    lay_256up = BatchNormalization(gamma_initializer = gamma_init, trainable = trainable)(lay_256up)
    lay_256up = Activation('relu')(lay_256up)

    out = Conv3D(
        filters = 1, kernel_size = 1, strides = (1, 1, 1), activation = 'tanh', 
        padding = 'same', kernel_initializer = kernel_initz,
        )(lay_256up)

    model = Model(inputs = inp_usamp_imag, outputs = out)

    return model

In [18]:
def define_gan_model(gen_model, dis_model, inp_shape):
        
    dis_model.trainable = False
    inp = Input(shape = inp_shape)
    out_g = gen_model(inp)
    out_dis = dis_model(out_g)
    model = Model(inputs = inp, outputs = [out_dis, out_g])
    model.summary()
    return model

In [19]:
def train(
    g_model, d_model, gan_model, 
    dataset, 
    n_epochs, n_batch, n_critic, 
    clip_val, n_patch, 
    f,
):
    
    bat_per_epo = int(fsamp_data.shape[0]/n_batch)
    half_batch = int(n_batch / 2)
    
    for i in range(n_epochs):
        for j in range(bat_per_epo):
            
            # training the discriminator
            for k in range(n_critic):
                X = np.ones((32, 256, 256, 1)) #[h, l, w, c]
                y = np.ones((half_batch, n_patch, n_patch, n_patch, 1))

                for usamp_data, fsamp_data in dataset.take(half_batch):
                    X_real = fsamp_data
                    X_fake = g_model.predict(usamp_data)

                    y_real = np.ones((half_batch, n_patch, n_patch, n_patch, 1))
                    y_fake = -np.ones((half_batch, n_patch, n_patch, n_patch, 1))

                    X, y = np.vstack((X, X_real)), np.vstack((y, y_real))
                    X, y = np.vstack((X, X_fake)), np.vstack((y, y_fake))

                X, y = X[1:], y[1:] # take out first np.ones
                d_loss, accuracy = d_model.train_on_batch(X, y)

                for l in d_model.layers:
                    weights = l.get_weights()
                    weights = [np.clip(w, -clip_val, clip_val) for w in weights]
                    l.set_weights(weights)

                    
            # training the generator
            X_usamps = np.ones((32, 256, 256, 2)) #[h, l, w, c]
            X_fsamps = np.ones((32, 256, 256, 1))

            for usamp_data, fsamp_data in dataset.take(n_batch):
                X_usamp = usamp_data
                X_fsamp = fsamp_data
                
                X_usamps = np.vstack((X_usamps, X_usamp))
                X_fsamps = np.vstack((X_fsamps, X_fsamp)
                
            X_usamps, X_fsamps = X_usamps[1:], X_fsamps[1:] # take out first np.ones
            y_gan = np.ones((n_batch, n_patch, n_patch, n_patch, 1))
            g_loss = gan_model.train_on_batch ([X_usamps], [y_gan, X_fsamps])
            
            
            f.write(f'>epoch: {i+1}, batch: {(j+1)/bat_per_epo}, discriminator loss: {d_loss}, acc: {accuracy},  wasserstein: {g_loss[1]},  mae: {g_loss[2]},  mssim: {g_loss[3]}, generator_loss: {g_loss[0]}')
            f.write('\n')
            print (f'>epoch: {i+1}, batch: {(j+1)/bat_per_epo}, discriminator loss: {d_loss}, acc: {accuracy}, generator_loss: {g_loss[0]}')
        
        h5_filename = '~/bebi205/models/gen_weights_1_%04d.h5' % (i+1)
        g_model.save(filename)

    f.close() 

SyntaxError: invalid syntax (<ipython-input-19-a4ed2dfadce0>, line 50)

In [21]:
#hyperparameters       
n_epochs = 300
n_batch = 4
n_critic = 3
clip_val = 0.05
in_shape_gen = (32, 256, 256, 2)
in_shape_dis = (32, 256, 256, 1)


d_model = discriminator (inp_shape = in_shape_dis, kernel_initz = 'he_normal', trainable = True)
# d_model.summary()
opt = Adam(lr = 0.0002, beta_1 = 0.5)
d_model.compile(loss = wloss, optimizer = opt, metrics = [accm])


g_model = generator(inp_shape = in_shape_gen, kernel_initz = 'he_normal', trainable = True)
# g_model.summary()

gan_model = define_gan_model(g_model, d_model, in_shape_gen)
opt1 = Adam(lr = 0.0001, beta_1 = 0.5)
gan_model.compile(loss = [wloss, 'mae', mssim], optimizer = opt1, loss_weights = [0.01, 20.0, 1.0])
n_patch = d_model.output_shape[1]

# data_path='/home/cs-mri-gan/training_gt_aug.pickle' #Ground truth
# usam_path='/home/cs-mri-gan/training_usamp_1dg_a5_aug.pickle' #Zero-filled reconstructions

# df = open(data_path,'rb')
# uf = open(usam_path,'rb')

# fsamp_data = pickle.load(df)
# usamp_data = pickle.load(uf)

# fsamp_data = np.expand_dims(fsamp_data, axis = -1)
# usamp_data = np.expand_dims(usamp_data, axis = -1)

# usamp_data_real = usamp_data.real
# usamp_data_imag = usamp_data.imag


# usamp_data_2c = np.concatenate((usamp_data_real, usamp_data_imag), axis = -1)

# f = open('/home/cs-mri-gan/log_a5.txt', 'x')
# f = open('/home/cs-mri-gan/log_a5.txt', 'a') 

# train(g_par, d_par, gan_model, fsamp_data, usamp_data_2c, n_epochs, n_batch, n_critic, clip_val, n_patch, f)


    
  
  
       

    
 

Model: "model_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_10 (InputLayer)        [(None, 32, 256, 256, 2)] 0         
_________________________________________________________________
model_6 (Functional)         (None, 32, 256, 256, 1)   494673921 
_________________________________________________________________
model_5 (Functional)         (None, 4, 4, 4, 1)        7231281   
Total params: 501,905,202
Trainable params: 494,658,817
Non-trainable params: 7,246,385
_________________________________________________________________
