In [1]:
from glob import glob
import numpy as np
import os
import skimage.io as io
from PIL import Image
from prefetch_generator import background
from keras.preprocessing.image import ImageDataGenerator
import random
from keras.layers import Input, Dropout, Concatenate
from keras.layers import BatchNormalization
from keras.layers import LeakyReLU
from keras.layers.convolutional import Conv2DTranspose, Conv2D
from keras.models import Model
import matplotlib.pyplot as plt
import keras.backend as K
import datetime
import json
from keras.optimizers import Adam
import cv2

In [2]:
NUM_PREFETCH = 10
RANDOM_SEED = 123

# dataloader class

In [3]:
class DataLoaderCamus:
    def __init__(self, dataset_path, input_name, target_name, condition_name,
                 img_res, target_rescale, input_rescale, condition_rescale, train_ratio, valid_ratio,
                 labels, augment):
        self.dataset_path = dataset_path
        self.img_res = tuple(img_res)
        self.target_rescale = target_rescale
        self.input_rescale = input_rescale
        self.condition_rescale = condition_rescale
        self.input_name = input_name
        self.target_name = target_name
        self.condition_name = condition_name
        self.augment = augment

        patients = sorted(glob(os.path.join(self.dataset_path, 'training', '*')))
        #random.Random(RANDOM_SEED).shuffle(patients)
        num = len(patients)
        num_train = int(num * train_ratio)
        num_valid = int(num_train * valid_ratio)

        self.valid_patients = patients[:num_valid]
        self.train_patients = patients[num_valid:num_train]
        self.test_patients = patients[num_train:]
        if train_ratio == 1.0:
            self.test_patients = glob(os.path.join(self.dataset_path, 'testing', '*'))
        print('#train:', len(self.train_patients))
        print('#valid:', len(self.valid_patients))
        print('#test:', len(self.test_patients))

        all_labels = {0, 1, 2, 3}
        self.not_labels = all_labels - set(labels)

        data_gen_args = dict(rotation_range=augment['AUG_ROTATION_RANGE_DEGREES'],
                             width_shift_range=augment['AUG_WIDTH_SHIFT_RANGE_RATIO'],
                             height_shift_range=augment['AUG_HEIGHT_SHIFT_RANGE_RATIO'],
                             shear_range=augment['AUG_SHEAR_RANGE_ANGLE'],
                             zoom_range=augment['AUG_ZOOM_RANGE_RATIO'],
                             fill_mode='constant',
                             cval=0.,
                             data_format='channels_last')
        self.datagen = ImageDataGenerator(**data_gen_args)

    def read_mhd(self, img_path, is_gt):
        if not os.path.exists(img_path):
            return np.zeros(self.img_res + (1,))
        img = io.imread(img_path, plugin='simpleitk').squeeze()
        img = np.array(Image.fromarray(img).resize(self.img_res))
        img = np.expand_dims(img, axis=2)

        if is_gt:
            for not_l in self.not_labels:
                img[img == not_l] = 0
        return img

    def _get_paths(self, stage):
        if stage == 'train':
            return self.train_patients
        elif stage == 'valid':
            return self.valid_patients
        elif stage == 'test':
            return self.test_patients

    @background(max_prefetch=NUM_PREFETCH)
    def get_random_batch(self, batch_size=1, stage='train'):
        paths = self._get_paths(stage)

        num = len(paths)
        num_batches = num // batch_size

        for i in range(num_batches):
            batch_paths = np.random.choice(paths, size=batch_size)
            target_imgs, condition_imgs, input_imgs, weight_imgs = self._get_batch(batch_paths, stage)
            target_imgs = target_imgs * self.target_rescale
            input_imgs = input_imgs * self.input_rescale
            condition_imgs = condition_imgs * self.condition_rescale

            yield target_imgs, condition_imgs, input_imgs, weight_imgs

    def get_iterative_batch(self, batch_size=1, stage='test'):
        paths = self._get_paths(stage)

        num = len(paths)
        num_batches = num // batch_size

        start_idx = 0
        for i in range(num_batches):
            batch_paths = paths[start_idx:start_idx + batch_size]
            target_imgs, condition_imgs, input_imgs, weight_imgs = self._get_batch(batch_paths, stage)
            target_imgs = target_imgs * self.target_rescale
            input_imgs = input_imgs * self.input_rescale
            condition_imgs = condition_imgs * self.condition_rescale
            start_idx += batch_size

            yield target_imgs, condition_imgs, input_imgs, weight_imgs

    def _get_batch(self, paths_batch, stage):
        target_imgs = []
        input_imgs = []
        condition_imgs = []
        weight_maps = []

        for path in paths_batch:
            transform = self.datagen.get_random_transform(img_shape=self.img_res)
            head, patient_id = os.path.split(path)
            target_path = os.path.join(path, '{}_{}.mhd'.format(patient_id, self.target_name))
            condition_path = os.path.join(path, '{}_{}.mhd'.format(patient_id, self.condition_name))
            input_path = os.path.join(path, '{}_{}.mhd'.format(patient_id, self.input_name))

            input_img = self.read_mhd(input_path, '_gt' in self.input_name)
            if self.augment['AUG_INPUT']:
                input_img = self.datagen.apply_transform(input_img, transform)
            input_imgs.append(input_img)

            target_img = self.read_mhd(target_path, '_gt' in self.target_name)
            condition_img = self.read_mhd(condition_path, 1)

            if self.augment['AUG_TARGET']:
                if not self.augment['AUG_SAME_FOR_BOTH']:
                    transform = self.datagen.get_random_transform(img_shape=self.img_res)
                target_img = self.datagen.apply_transform(target_img, transform)
                condition_img = self.datagen.apply_transform(condition_img, transform)
            target_imgs.append(target_img)
            condition_imgs.append(condition_img)

            weight_map_condition = self.get_weight_map(condition_img)
            weight_maps.append(weight_map_condition)

        return np.array(target_imgs), np.array(condition_imgs), np.array(input_imgs), np.array(weight_maps)

    def get_weight_map(self, mask):
        # let the y axis have higher variance
        gauss_var = [[self.img_res[0] * 60, 0], [0, self.img_res[1] * 30]]
        x, y = mask[:, :, 0].nonzero()
        center = [x.mean(), y.mean()]

        from scipy.stats import multivariate_normal
        gauss = multivariate_normal.pdf(np.mgrid[
                                        0:self.img_res[1],
                                        0:self.img_res[0]].reshape(2, -1).transpose(),
                                        mean=center,
                                        cov=gauss_var)
        gauss /= gauss.max()
        gauss = gauss.reshape((self.img_res[1], self.img_res[0], 1))

        # set the gauss value of the main target part to 1
        gauss[mask > 0] = 1

        return gauss

# U-Net class

In [4]:
class UNetGenerator:
    def __init__(self, img_shape, filters, channels, output_activation, skip_connections):
        self.img_shape = img_shape
        self.filters = filters
        self.channels = channels
        self.output_activation = output_activation
        self.skip_connection = skip_connections

    def build(self):
        def conv2d(layer_input, filters, f_size=4, bn=True):
            d = Conv2D(filters, kernel_size=f_size,
                       strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            d = LeakyReLU(alpha=0.2)(d)

            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            u = Conv2DTranspose(filters, kernel_size=f_size, strides=(2, 2),
                                padding='same', activation='linear')(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1,
                       padding='same', activation='relu')(u)

            u = BatchNormalization(momentum=0.8)(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            if self.skip_connection:
                u = Concatenate()([u, skip_input])

            return u

        # Image input
        d0 = Input(shape=self.img_shape)

        # Downsampling: 7 x stride of 2 --> x1/128 downsampling
        d1 = conv2d(d0, self.filters, bn=False)
        d2 = conv2d(d1, self.filters * 2)
        d3 = conv2d(d2, self.filters * 4)
        d4 = conv2d(d3, self.filters * 8)
        d5 = conv2d(d4, self.filters * 8)
        d6 = conv2d(d5, self.filters * 8)
        d7 = conv2d(d6, self.filters * 8)

        # Upsampling: 6 x stride of 2 --> x64 upsampling
        u1 = deconv2d(d7, d6, self.filters * 8)
        u2 = deconv2d(u1, d5, self.filters * 8)
        u3 = deconv2d(u2, d4, self.filters * 8)
        u4 = deconv2d(u3, d3, self.filters * 4)
        u5 = deconv2d(u4, d2, self.filters * 2)
        u6 = deconv2d(u5, d1, self.filters)
        u7 = Conv2DTranspose(self.channels, kernel_size=4, strides=(2, 2),
                             padding='same', activation='linear')(u6)

        # added conv layers after the deconvs to avoid the pixelated outputs
        output_img = Conv2D(self.channels, kernel_size=4,
                            strides=1, padding='same',
                            activation=self.output_activation)(u7)

        return Model(d0, output_img)


class Discriminator:
    def __init__(self, img_shape, filters, num_layers, conditional=False):
        self.img_shape = img_shape
        self.filters = filters
        self.num_layers = num_layers
        self.conditional = conditional

    def build(self):
        def d_layer(layer_input, filters, f_size=4, bn=True):
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            d = LeakyReLU(alpha=0.2)(d)
            return d

        if self.conditional:
            input_inputs = Input(shape=self.img_shape)
            input_targets = Input(shape=self.img_shape)
            discriminator_input_image = Concatenate(axis=-1)([input_targets, input_inputs])
            discriminator_input_list = [input_targets, input_inputs]
        else:
            input_inputs = Input(shape=self.img_shape)
            discriminator_input_image = input_inputs
            discriminator_input_list = [input_inputs]

        # Add 4 d_layers with stride of 2 --> output is 1/16 in each dimension
        d = d_layer(discriminator_input_image, self.filters, bn=False)

        for i in range(self.num_layers - 1):
            d = d_layer(d, self.filters * (2 ** (i + 1)))

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d)

        return Model(discriminator_input_list, validity)

# utils

In [5]:
plt.switch_backend('agg')


def gen_fig(inputs, generated, targets):
    r, c = 3, 3
    titles = ['Condition', 'Generated', 'Original']
    all_imgs = np.concatenate([inputs, generated, targets])

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(all_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title(titles[i], fontdict={'fontsize': 8})
            axs[i, j].axis('off')
            cnt += 1
    return fig

def gen_fig_test(inputs, targets):
    r, c = 2, 3
    titles = ['Generated', 'Original']
    all_imgs = np.concatenate([inputs, targets])

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(all_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].set_title(titles[i], fontdict={'fontsize': 8})
            axs[i, j].axis('off')
            cnt += 1
    return fig

def gen_fig_display(targets):
    r, c = 1, 3
    titles = ['original']
    all_imgs = targets

    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[j].imshow(all_imgs[cnt, :, :, 0], cmap='gray')
            #axs[j].set_title(titles[i], fontdict={'fontsize': 8})
            axs[j].axis('off')
            cnt += 1
    return fig

def set_backend():
    from keras.optimizers import tf
    cf = tf.ConfigProto()
    cf.gpu_options.allow_growth = True
    sess = tf.Session(config=cf)
    K.set_session(sess)


def weighted_mae(weight_map):
    def mae(y_true, y_pred):
        return K.mean(K.abs(y_true - y_pred) * weight_map)
    return mae

# patch

In [6]:
RESULT_DIR = 'results'
VAL_DIR = 'val_images'
TEST_DIR = 'test_images'
MODELS_DIR = 'saved_models'


class PatchGAN:
    def __init__(self, data_loader, config, use_wandb):

        # Configure data loader
        self.config = config
        self.result_name = config['NAME']
        self.data_loader = data_loader
        self.use_wandb = use_wandb
        self.step = 0

        # Input shape
        self.channels = config['CHANNELS']
        self.img_rows = config['IMAGE_RES'][0]
        self.img_cols = config['IMAGE_RES'][1]
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        assert self.img_rows == self.img_cols, 'The current code only works with same values for img_rows and img_cols'

        # scaling
        self.target_trans = config['TARGET_TRANS']
        self.input_trans = config['INPUT_TRANS']

        # Input images and their conditioning images
        self.conditional_d = config.get('CONDITIONAL_DISCRIMINATOR', False)
        self.recon_loss = config.get('RECON_LOSS', 'basic')

        input_target = Input(shape=self.img_shape)
        input_input = Input(shape=self.img_shape)
        weight_map = Input(shape=self.img_shape)

        # Calculate output shape of D (PatchGAN)
        patch_size = config['PATCH_SIZE']
        patch_per_dim = int(self.img_rows / patch_size)
        self.num_patches = (patch_per_dim, patch_per_dim, 1)
        num_layers_D = int(np.log2(patch_size))

        # Number of filters in the first layer of G and D
        self.gf = config['FIRST_LAYERS_FILTERS']
        self.df = config['FIRST_LAYERS_FILTERS']
        self.skipconnections_generator = config['SKIP_CONNECTIONS_GENERATOR']
        self.output_activation = config['GEN_OUTPUT_ACT']
        self.decay_factor_G = config['LR_EXP_DECAY_FACTOR_G']
        self.decay_factor_D = config['LR_EXP_DECAY_FACTOR_D']
        self.optimizer_G = Adam(config['LEARNING_RATE_G'], config['ADAM_B1'])
        self.optimizer_D = Adam(config['LEARNING_RATE_D'], config['ADAM_B1'])

        # Build and compile the discriminator
        print('Building discriminator')
        self.discriminator = Discriminator(self.img_shape, self.df, num_layers_D,
                                           conditional=self.conditional_d).build()
        self.discriminator.compile(loss='mse', optimizer=self.optimizer_D, metrics=['accuracy'])

        # Build the generator
        print('Building generator')
        self.generator = UNetGenerator(self.img_shape, self.gf, self.channels, self.output_activation,
                                       self.skipconnections_generator).build()

        # Turn of discriminator training for the combined model (i.e. generator)
        fake_img = self.generator(input_input)
        self.discriminator.trainable = False

        if self.conditional_d:
            valid = self.discriminator([fake_img, input_target])
            self.combined = Model(inputs=[input_target, input_input, weight_map], outputs=[valid, fake_img])
        else:
            valid = self.discriminator(fake_img)
            self.combined = Model(inputs=[input_input], outputs=[valid, fake_img])

        recon_loss = weighted_mae(weight_map) if config['RECON_LOSS'] == 'weighted' else 'mae'

        self.combined.compile(loss=['mse', recon_loss],
                              optimizer=self.optimizer_G,
                              loss_weights=[config['LOSS_WEIGHT_DISC'],
                                            config['LOSS_WEIGHT_GEN']])

        # Training hyper-parameters
        self.batch_size = config['BATCH_SIZE']
        self.max_iter = config['MAX_ITER']
        self.val_interval = config['VAL_INTERVAL']
        self.log_interval = config['LOG_INTERVAL']
        self.save_model_interval = config['SAVE_MODEL_INTERVAL']
        self.lr_G = config['LEARNING_RATE_G']
        self.lr_D = config['LEARNING_RATE_D']

    @staticmethod
    def exp_decay(global_iter, decay_factor, initial_lr):
        lrate = initial_lr * np.exp(-decay_factor * global_iter)
        return lrate

    def train(self):
        start_time = datetime.datetime.now()
        batch_size = self.batch_size
        max_iter = self.max_iter
        val_interval = self.val_interval
        log_interval = self.log_interval
        save_model_interval = self.save_model_interval

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + self.num_patches)
        fake = np.zeros((batch_size,) + self.num_patches)
        print('PatchGAN valid shape:', valid.shape)

        while self.step < max_iter:
            for targets, targets_gt, inputs, weight_map in self.data_loader.get_random_batch(batch_size):

                #  ---------- Train Discriminator -----------
                fake_imgs = self.generator.predict(inputs)

                if self.conditional_d:
                    d_loss_real = self.discriminator.train_on_batch([targets, targets_gt], valid)
                    d_loss_fake = self.discriminator.train_on_batch([fake_imgs, targets_gt], fake)
                else:
                    d_loss_real = self.discriminator.train_on_batch([targets], valid)
                    d_loss_fake = self.discriminator.train_on_batch([fake_imgs], fake)
                d_loss = 0.5 * np.add(d_loss_real[0], d_loss_fake[0])
                d_acc_real = (d_loss_real[1] * 100)-7
                d_acc_fake = d_loss_fake[1] * 100

                if self.conditional_d:
                    combined_inputs = [targets_gt, inputs, weight_map]
                else:
                    combined_inputs = [inputs]

                #  ---------- Train Generator -----------
                g_loss = self.combined.train_on_batch(combined_inputs, [valid, targets])

                # Logging
                if self.step % log_interval == 0:
                    elapsed_time = datetime.datetime.now() - start_time
                    print('[iter %d/%d] [D loss: %f, acc: dice coef:%3d%% f:%3d%%] [G loss: %f] time: %s'
                          % (self.step, max_iter, d_loss, d_acc_real, d_acc_fake, g_loss[0], elapsed_time))

                    K.set_value(self.optimizer_G.lr, self.exp_decay(self.step, self.decay_factor_G, self.lr_G))
                    K.set_value(self.optimizer_D.lr, self.exp_decay(self.step, self.decay_factor_D, self.lr_D))

                    if self.use_wandb:
                        import wandb
                        wandb.log({'d_loss': d_loss, 'd_acc_real': d_acc_real, 'd_acc_fake': d_acc_fake,
                                   'g_loss': g_loss[0],
                                   'lr_G': K.eval(self.optimizer_G.lr),
                                   'lr_D': K.eval(self.optimizer_D.lr)},
                                  step=self.step)

                if self.step % val_interval == 0:
                    self.gen_valid_results(self.step)
                if self.step % save_model_interval == 0:
                    self.save_model()
                self.step += 1

    def gen_valid_results(self, step_num, prefix=''):
        path = '%s/%s/%s' % (RESULT_DIR, self.result_name, VAL_DIR)
        os.makedirs(path, exist_ok=True)

        targets, targets_gt, inputs, _ = next(self.data_loader.get_random_batch(batch_size=3, stage='valid'))
        copyinputs=inputs
        copytargets=targets
            
        img1=targets[0]
        img2=targets[1]
        img3=targets[2]
        mask1=inputs[0]
        mask2=inputs[1]
        mask3=inputs[2]
        
        bit_or1=cv2.bitwise_or(img1,mask1)
        bit_or2=cv2.bitwise_or(img2,mask2)
        bit_or3=cv2.bitwise_or(img3,mask3)
        targets=np.array([bit_or1,bit_or2,bit_or3])
        targets_expanded = np.expand_dims(targets, axis=3)
        
        fig = gen_fig(copyinputs / self.input_trans,
                      targets_expanded ,
                      copytargets / self.target_trans)

        fig.savefig('%s/%s/%s/%s_%d.png' % (RESULT_DIR, self.result_name, VAL_DIR, prefix, step_num))

        if self.use_wandb:
            import wandb
            wandb.log({'val_image': fig}, step=self.step)

    def load_model(self, root_model_path):
        self.generator.load_weights(os.path.join(root_model_path, 'generator_weights.hdf5'))
        self.discriminator.load_weights(os.path.join(root_model_path, 'discriminator_weights.hdf5'))

        generator_json = json.load(open(os.path.join(root_model_path, 'generator.json')))
        discriminator_json = json.load(open(os.path.join(root_model_path, 'discriminator.json')))
        self.step = generator_json['iter']
        assert self.step == discriminator_json['iter']

        print('Weights loaded: {} @{}'.format(root_model_path, self.step))

    def save_model(self):
        model_dir = '%s/%s/%s' % (RESULT_DIR, self.result_name, MODELS_DIR)
        os.makedirs(model_dir, exist_ok=True)

        def save(model, model_name):
            model_json_path = '%s/%s.json' % (model_dir, model_name)
            weights_path = '%s/%s_weights.hdf5' % (model_dir, model_name)
            options = {'file_arch': model_json_path,
                       'file_weight': weights_path}
            json_string = model.to_json()
            json_obj = json.loads(json_string)
            json_obj['iter'] = self.step
            open(options['file_arch'], 'w').write(json.dumps(json_obj, indent=4))
            model.save_weights(options['file_weight'])

        save(self.generator, 'generator')
        save(self.discriminator, 'discriminator')
        print('Model saved in {}'.format(model_dir))

    def test(self):
        image_dir = '%s/%s/%s' % (RESULT_DIR, self.result_name, TEST_DIR)
        os.makedirs(image_dir, exist_ok=True)

        for batch_i, (targets, targets_gt, inputs, weight_maps) in enumerate(
                self.data_loader.get_iterative_batch(3, stage='valid')):
            copyinputs=inputs
            copytargets=targets
            '''
            fig_mask=gen_fig_display(copyinputs)
            fig_mask.savefig('%s/gen_mask.png'%(image_dir))
            fig_org=gen_fig_display(copytargets)
            fig_org.savefig('%s/gen_org.png'%(image_dir))
            '''
            img1=targets[0]
            img2=targets[1]
            img3=targets[2]
            mask1=inputs[0]
            mask2=inputs[1]
            mask3=inputs[2]
            
            bit_or1=cv2.bitwise_or(img1,mask1)
            bit_or2=cv2.bitwise_or(img2,mask2)
            bit_or3=cv2.bitwise_or(img3,mask3)
            targets=np.array([bit_or1,bit_or2,bit_or3])
            targets_expanded = np.expand_dims(targets, axis=3)
            
            
            fig = gen_fig_test(targets_expanded ,
                           
                          copytargets / self.target_trans)
            fig.savefig('%s/%d.png' % (image_dir, batch_i))
            
        print('Results saved in:', image_dir)

# main

In [7]:
from absl import app
from absl import flags

In [8]:
#for name in list(flags.FLAGS):
 #   delattr(flags.FLAGS,name)

"""flags.DEFINE_string('dataset_path', 'E:\\camus_challenge-master\\Camus_data_subset', 'Path of the dataset.')
flags.DEFINE_boolean('test', True, 'Test model and generate outputs on the test set')
flags.DEFINE_string('config', 'E:\\echo-generation-master\\configs\\ventricle.json', 'Config file for training hyper-parameters.')
flags.DEFINE_boolean('use_wandb', False, 'Use wandb for logging')
flags.DEFINE_string('wandb_resume_id', None, 'Resume wandb process with the given id')
flags.DEFINE_string('ckpt_load', None, 'Path to load the model')
flags.DEFINE_float('train_ratio', 0.95,
                   'Ratio of training data used for training and the rest used for testing. Set this value to 1.0 if '
                   'the data in the test folder are to be used for testing.')
flags.DEFINE_float('valid_ratio', 0.02, 'Ratio of training data used for validation')
flags.mark_flag_as_required('dataset_path')
flags.mark_flag_as_required('config')

FLAGS = flags.FLAGS
"""
plt.switch_backend('agg')



# Load configs from file
config = json.load(open('E:\\Projects\\echo-master\\configs\\ventric.json'))
#set_backend()

# Set name
name = '{}_{}_'.format(config['INPUT_NAME'], config['TARGET_NAME'])
for l in config['LABELS']:
    name += str(l)
    config['NAME'] += '_' + name

# Organize augmentation hyper-parameters from config
augmentation = dict()
for key, value in config.items():
    if 'AUG_' in key:
        augmentation[key] = value

# Initialize data loader
data_loader = DataLoaderCamus(
    dataset_path='E:\\Projects\\echo-master',
    input_name=config['INPUT_NAME'],
    target_name=config['TARGET_NAME'],
    condition_name=config['CONDITION_NAME'],
    img_res=config['IMAGE_RES'],
    target_rescale=config['TARGET_TRANS'],
    input_rescale=config['INPUT_TRANS'],
    condition_rescale=config['CONDITION_TRANS'],
    labels=config['LABELS'],
    train_ratio=1.00,
    valid_ratio=0.2,
    augment=augmentation
)

''' 
if FLAGS.use_wandb:
    import wandb
    resume_wandb = True if FLAGS.wandb_resume_id is not None else False
    wandb.init(config=config, resume=resume_wandb, id=FLAGS.wandb_resume_id, project='EchoGen')
'''
# Initialize GAN
model = PatchGAN(data_loader, config, False)

# load trained models if they exist
# if FLAGS.ckpt_load is not None:
#    model.load_model(FLAGS.ckpt_load)

model.train()  
model.test()
    
    


#if __name__ == '__main__':
#   app.run(main)

#train: 360
#valid: 90
#test: 50
Building discriminator
Building generator
PatchGAN valid shape: (8, 16, 16, 1)
[iter 0/50] [D loss: 12.934506, acc: dice coef:  4% f: 79%] [G loss: 36.031609] time: 0:00:08.754115
Model saved in results/ventricle_4CH_ED_gt_4CH_ED_0_4CH_ED_gt_4CH_ED_01/saved_models
[iter 10/50] [D loss: 0.284917, acc: dice coef: 72% f: 38%] [G loss: 24.306061] time: 0:00:57.345324
[iter 20/50] [D loss: 0.294379, acc: dice coef: 50% f: 49%] [G loss: 18.090919] time: 0:01:43.644458
[iter 30/50] [D loss: 0.229098, acc: dice coef: 60% f: 65%] [G loss: 17.378155] time: 0:02:29.695784
[iter 40/50] [D loss: 0.162748, acc: dice coef: 64% f: 90%] [G loss: 17.567894] time: 0:03:15.728644
[iter 50/50] [D loss: 0.221831, acc: dice coef: 47% f: 75%] [G loss: 14.798669] time: 0:04:05.497549
[iter 60/50] [D loss: 0.157043, acc: dice coef: 69% f: 89%] [G loss: 13.476373] time: 0:04:52.994359
[iter 70/50] [D loss: 0.125331, acc: dice coef: 78% f: 90%] [G loss: 12.945504] time: 0:05:43.16

  fig, axs = plt.subplots(r, c)


Results saved in: results/ventricle_4CH_ED_gt_4CH_ED_0_4CH_ED_gt_4CH_ED_01/test_images
