In [3]:
import time
import tensorflow as tf

bar = tf.keras.utils.Progbar(target=10, stateful_metrics=['train_loss', 'val_loss']) 
step = 0
train_loss = 1
val_loss = 1
for x in range(10):
    time.sleep(0.5)
    step += 1
    train_loss += 1
    bar.update(step, values=[('train_loss', train_loss)], finalize=True)
    

 1/10 [==>...........................] - 1s 501ms/step - train_loss: 2.0000
 2/10 [=====>........................] - 1s 502ms/step - train_loss: 3.0000


KeyboardInterrupt: 

In [1]:
from tensorflow.keras.utils import Progbar
import tensorflow as tf

import math

from models import SMD_Unet

import numpy as np
from data_generator import DR_Generator

tf.config.run_functions_eagerly(True)

class Trainer:
    def __init__(self, model, epochs, optimizer, for_recons, alpha, beta=None):
        '''
        for_recons : bool, 학습 단계 구분하기 위함
        alpha : recons loss에 곱해줄 가중치
        beta : [] , mask loss에 곱해줄 가중치 리스트
        '''
        self.model = model
        self.epochs = epochs
        self.optimizer = optimizer
        self.for_recons = for_recons
        self.alpha = tf.cast(alpha, dtype=tf.float64)
        self.beta = beta 

        if beta!=None:
            self.b1, self.b2, self.b3, self.b4 = beta
            self.b1 = tf.cast(self.b1, dtype=tf.float64)
            self.b2 = tf.cast(self.b2, dtype=tf.float64)
            self.b3 = tf.cast(self.b3, dtype=tf.float64)
            self.b4 = tf.cast(self.b4, dtype=tf.float64)
        
        # reconstruction만 학습하는거면 안쓰는 decoder trainable=False로 해주기
        if self.for_recons:
            self.model.HardExudate.trainable=False
            self.model.Hemohedge.trainable=False
            self.model.Microane.trainable=False
            self.model.SoftExudates.trainable=False
        else:
            self.model.HardExudate.trainable=True
            self.model.Hemohedge.trainable=True
            self.model.Microane.trainable=True
            self.model.SoftExudates.trainable=True

    # loss 함수 계산하는 부분 
    # return 값이 텐서여야 하는건가? -> 아마도 그런 것 같다.
    def dice_loss(self, inputs, targets, smooth = 1.):
        dice_losses = []
        
        for input, target in zip(inputs, targets): 
            input_flat = tf.reshape(input, [-1])
            target_flat = tf.reshape(target, [-1])
            
            input_flat = tf.cast(input_flat, dtype=tf.float64)
            target_flat = tf.cast(target_flat, dtype=tf.float64) 
            
            intersection = tf.reduce_sum(input_flat * target_flat)
            dice_coef = (2. * intersection + smooth) / (tf.reduce_sum(input_flat) + tf.reduce_sum(target_flat) + smooth)

            dice_losses.append(1. - dice_coef)
            
        result = tf.reduce_mean(dice_losses) 
        return result
    
    def mean_square_error(self, input_hats, inputs):        
        mses = []
        
        for input_hat, input in zip(input_hats, inputs):
            mses.append(tf.reduce_mean(tf.square(input_hat - input)))
            
        result = tf.reduce_mean(mses) # 배치 나눠서 계산하고 평균해주기
        return result

    @tf.function
    def train_on_batch(self, x_batch_train, y_batch_train):
        with tf.GradientTape() as tape:
            preds = self.model(x_batch_train, only_recons=self.for_recons)    # 모델이 예측한 결과
#             input_hat, ex_hat, he_hat, ma_hat, se_hat = preds
            
#             ex, he, ma, se = y_batch_train
            
            # loss 계산하기
            # reconstruction
            loss_recons = tf.cast(self.mean_square_error(preds[0], x_batch_train), dtype=tf.float64)

            if not self.for_recons:
            # ex, he, ma, se
                ex_loss = self.dice_loss(y_batch_train[0], preds[1])
                he_loss = self.dice_loss(y_batch_train[1], preds[2])
                ma_loss = self.dice_loss(y_batch_train[2], preds[3])
                se_loss = self.dice_loss(y_batch_train[3], preds[4])            
                # loss 가중합 해주기
                train_loss = self.b1 * ex_loss + self.b2 * he_loss + self.b3 * ma_loss + self.b4 * se_loss + self.alpha * loss_recons
            else:     
                train_loss = loss_recons 
            
        grads = tape.gradient(train_loss, self.model.trainable_weights)  # gradient 계산
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))  # Otimizer에게 처리된 그라데이션 적용을 요청
        
        del preds
        
        return train_loss

    def train(self, train_dataset, val_dataset):
        metrics_names = ['train_loss', 'val_loss']

        for epoch in range(self.epochs):
            print("\nEpoch {}/{}".format(epoch+1, self.epochs))

            # train_dataset = train_dataset.take(steps_per_epoch)
            # val_dataset = val_dataset.take(val_step)

            tr_progBar = Progbar(target=len(train_dataset) * train_dataset.batch_size, stateful_metrics=['train_loss'])
            
            # 데이터 집합의 배치에 대해 반복합니다
            for step_train, (x_batch_train, y_batch_train) in enumerate(train_dataset):
                train_loss = self.train_on_batch(x_batch_train, y_batch_train)

                # train metric(mean, auc, accuracy 등) 업데이트
                # acc_metric.update_state(y_batch_train, logits)

                values = [('train_loss', train_loss.numpy())]
                tr_progBar.update((step_train + 1) * train_dataset.batch_size, values=values)
                
                del train_loss
                del x_batch_train
                del y_batch_train
            
            val_progBar = Progbar(target=len(val_dataset) * val_dataset.batch_size, stateful_metrics=['val_loss'])
            
            for step_val, (x_batch_val, y_batch_val) in enumerate(val_dataset):
                preds = self.model(x_batch_val, only_recons=self.for_recons)    # 모델이 예측한 결과
#                 input_hat, ex_hat, he_hat, ma_hat, se_hat = preds
                
#                 ex, he, ma, se = y_batch_val
                
                # loss 계산하기
                # reconstruction
                loss_recons = tf.cast(self.mean_square_error(preds[0], x_batch_val), dtype=tf.float64)
                
                if not self.for_recons:
                # ex, he, ma, se
                    ex_loss = self.dice_loss(y_batch_val[0], preds[1])
                    he_loss = self.dice_loss(y_batch_val[1], preds[2])
                    ma_loss = self.dice_loss(y_batch_val[2], preds[3])
                    se_loss = self.dice_loss(y_batch_val[3], preds[4])            
                    # loss 가중합 해주기
                    val_loss = self.b1 * ex_loss + self.b2 * he_loss + self.b3 * ma_loss + self.b4 * se_loss + self.alpha * loss_recons
                else:     
                    val_loss = loss_recons
                    
                values = [('val_loss', val_loss.numpy())]
                val_progBar.update((step_val + 1) * val_dataset.batch_size, values=values)
                
                del val_loss
                del x_batch_val
                del y_batch_val
                del preds


In [2]:
import os

masks = ['HardExudate_Masks', 'Hemohedge_Masks', 'Microaneurysms_Masks', 'SoftExudate_Masks']
mask_dir = '../data/Seg-set'
mask_paths = [os.path.join(mask_dir, mask) for mask in masks]

generator_args = {
  'dir_path':'../data/Seg-set/Original_Images/',
  'mask_path':mask_paths,
  'use_mask':True,
  'img_size':(512, 512),  
  'batch_size':4, # 8로 하면 바로 OOM 뜸
  'dataset':'FGADR', # FGADR or EyePacks
  'is_train':True
}

tr_eyepacks_gen = DR_Generator(start_end_index=(0, 500), **generator_args)
val_eyepacks_gen = DR_Generator(start_end_index=(500, 600), **generator_args)

In [None]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001) 

model = SMD_Unet(filters=[32, 64, 128, 256, 512])

trainer_args = {
    'model':model,
    'epochs':1,
    'optimizer':optimizer,
    'for_recons':False,
    'alpha':1.0,
    'beta':[1.0, 1.0, 1.0, 1.0]
}
trainer = Trainer(**trainer_args)

trainer.train(train_dataset=tr_eyepacks_gen,
              val_dataset=val_eyepacks_gen
              )


Epoch 1/1

In [6]:
# model = SMD_Unet(filters=[32, 64, 128, 256, 512])
# model.build(input_shape=(None, 512, 512, 1))
# model.summary()

Model: "smd__unet_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_block_2 (EncoderBloc multiple                  4719584   
_________________________________________________________________
decoder_block_10 (DecoderBlo multiple                  3052269   
_________________________________________________________________
decoder_block_11 (DecoderBlo multiple                  3052269   
_________________________________________________________________
decoder_block_12 (DecoderBlo multiple                  3052269   
_________________________________________________________________
decoder_block_13 (DecoderBlo multiple                  3052269   
_________________________________________________________________
decoder_block_14 (DecoderBlo multiple                  3052269   
Total params: 19,980,929
Trainable params: 19,967,341
Non-trainable params: 13,588
______________________________________

In [None]:
from tensorflow import keras
import tensorflow as tf

def ConvBlock(x, n_filters):
    x = keras.layers.Conv2D(n_filters, 3, padding='same', activation='relu', kernel_initializer='he_normal')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation('relu')(x)

    return x

def UpsampleBlock(x, skip, n_filters):
    x = keras.layers.Conv2DTranspose(n_filters, (2, 2), strides=2, padding='same')(x)
    x = keras.layers.Concatenate()([x, skip]) 
    x = ConvBlock(x, n_filters)

    return x

def Encoder(x, filters):
    skips = []

    for f in filters:
        x = ConvBlock(x, f)
        x = ConvBlock(x, f)
        # 맨 마지막 층을 제외하고는 skip connection, downsampling을 진행
        if f != filters[-1]:
            skips.append(x)
            x = keras.layers.MaxPooling2D(2)(x)
      
    return x, skips

def Decoder(x, filters, skips):
    for f, skip in zip(filters, skips):
        x = keras.layers.Conv2DTranspose(f, (2, 2), strides=2, padding='same')(x)
        x = ConvBlock(x, f)
        x = keras.layers.Concatenate()([x, skip]) 
        x = ConvBlock(x, f)

        x = ConvBlock(x, 2)
        x = keras.layers.Conv2D(filters=1, kernel_size=1, padding='same', activation='linear')(x)

    return x

def Unet(img_size):
    inputs = keras.Input(shape=img_size + (1,))

    # 축소 경로
    filters = [64, 128, 256, 512, 1024]

    x, skips = Encoder(inputs, filters)

    # 확장 경로
    x = Decoder(x, filters[::-1][1:], skips[::-1])

    # loss = mse  
    model = keras.Model(inputs, x)

    return model

In [1]:
import os
from data_generator import DR_Generator

masks = ['HardExudate_Masks', 'Hemohedge_Masks', 'Microaneurysms_Masks', 'SoftExudate_Masks']
mask_dir = '../data/Seg-set'
mask_paths = [os.path.join(mask_dir, mask) for mask in masks]

generator_args = {
  'dir_path':'../data/Seg-set/Original_Images/',
  'mask_path':mask_paths,
  'use_mask':True,
  'img_size':(512, 512),  
  'batch_size':4, # 8로 하면 바로 OOM 뜸
  'dataset':'FGADR', # FGADR or EyePacks
  'is_train':True
}

tr_eyepacks_gen = DR_Generator(start_end_index=(0, 500), **generator_args)
val_eyepacks_gen = DR_Generator(start_end_index=(500, 600), **generator_args)

In [26]:
model = Unet(img_size=(512, 512))
# model.summary()

In [31]:
# model.compile(loss='mean_squared_error')
# model.fit(tr_eyepacks_gen, epochs=1)
# unet은 문제없음

In [3]:
def SMD_Unet(img_size, filters):
    inputs = keras.Input(shape=img_size + (1,))

    # 축소 경로
    x, skips = Encoder(inputs, filters)

    # 확장 경로
    # mask : HardExudate, Hemohedge, Microane, SoftExudates
    input_hat = Decoder(x, filters[::-1][1:], skips[::-1]) # 원본 이미지 추정
    ex = Decoder(x, filters[::-1][1:], skips[::-1])
    he = Decoder(x, filters[::-1][1:], skips[::-1])
    ma = Decoder(x, filters[::-1][1:], skips[::-1])
    se = Decoder(x, filters[::-1][1:], skips[::-1])

    model = keras.Model(inputs, outputs=[input_hat, ex, he, ma, se])

    return model

In [4]:
filters = [32, 64, 128, 256, 512]

model = SMD_Unet((512, 512), filters)
model.compile(loss='mean_squared_error')
model.fit(tr_eyepacks_gen, epochs=1)

# trainer class 문제는 아님



KeyboardInterrupt: 