In [20]:
from tensorflow.keras.utils import Progbar
import tensorflow as tf

import math

from models import SMD_Unet

import numpy as np
from data_generator import DR_Generator

class Trainer:
    '''Custom training loop for the multi-head SMD U-Net.

    Two phases are supported:
      * for_recons=True  -- only the reconstruction objective is optimised and
        the four lesion decoders are frozen.
      * for_recons=False -- all decoders are trained on a weighted sum of the
        per-lesion Dice losses and the reconstruction MSE.
    All losses are written with TensorFlow ops so gradients reach the model
    and ``train_on_batch`` also works as a real ``tf.function`` graph.
    '''

    def __init__(self, model, epochs, optimizer, for_recons, alpha, beta=None):
        '''
        model      : model exposing the decoder sub-layers ``HardExudate``,
                     ``Hemohedge``, ``Microane`` and ``SoftExudates`` and
                     accepting an ``only_recons`` call argument.
        epochs     : number of passes over the training data.
        optimizer  : Keras optimizer used to apply the gradients.
        for_recons : bool, selects the training phase (True = reconstruction only).
        alpha      : weight of the reconstruction loss (segmentation phase only).
        beta       : sequence of 4 weights for the HardExudate, Hemohedge,
                     Microane and SoftExudates mask losses. Required when
                     ``for_recons`` is False.

        Raises ValueError if ``for_recons`` is False and ``beta`` is missing.
        '''
        self.model = model
        self.epochs = epochs
        self.optimizer = optimizer
        self.for_recons = for_recons
        self.alpha = alpha
        self.beta = beta
        if beta is not None:
            self.b1, self.b2, self.b3, self.b4 = beta
        elif not for_recons:
            # Fail fast instead of an AttributeError on self.b1 mid-training.
            raise ValueError('beta (4 mask-loss weights) is required when for_recons is False')

        # In the reconstruction-only phase the lesion decoders are unused, so
        # freeze them; otherwise make sure they are trainable.
        decoders_trainable = not self.for_recons
        for decoder in (self.model.HardExudate, self.model.Hemohedge,
                        self.model.Microane, self.model.SoftExudates):
            decoder.trainable = decoders_trainable

    @staticmethod
    def _flatten_per_sample(x):
        '''Cast to float32 and reshape a batch-first tensor to (batch, -1).'''
        x = tf.cast(x, tf.float32)
        return tf.reshape(x, [tf.shape(x)[0], -1])

    @staticmethod
    def _mean(values):
        '''Arithmetic mean of a list of Python floats; NaN for an empty list.'''
        return sum(values) / len(values) if values else float('nan')

    def dice_loss(self, inputs, targets, smooth = 1.):
        '''Batch mean of the per-sample soft Dice loss (1 - Dice coefficient).

        Uses only TensorFlow ops, so gradients flow back through the
        prediction argument. The Dice coefficient is symmetric, so the order
        of ``inputs`` and ``targets`` does not matter.

        inputs, targets : batch-first arrays/tensors with the same number of
                          elements per sample.
        smooth          : additive smoothing term; avoids 0/0 for empty masks.
        Returns a scalar float32 tensor.
        '''
        inputs_flat = self._flatten_per_sample(inputs)
        targets_flat = self._flatten_per_sample(targets)

        intersection = tf.reduce_sum(inputs_flat * targets_flat, axis=1)
        denominator = tf.reduce_sum(inputs_flat, axis=1) + tf.reduce_sum(targets_flat, axis=1)
        dice_coef = (2. * intersection + smooth) / (denominator + smooth)
        return tf.reduce_mean(1. - dice_coef)

    def mean_square_error(self, input_hats, inputs):
        '''Mean squared reconstruction error, averaged per sample then over the batch.

        input_hats : reconstructed images (batch-first).
        inputs     : original images (batch-first, same number of elements per sample).
        Returns a scalar float32 tensor that stays on the gradient tape.
        '''
        hats_flat = self._flatten_per_sample(input_hats)
        inputs_flat = self._flatten_per_sample(inputs)
        per_sample_mse = tf.reduce_mean(tf.square(hats_flat - inputs_flat), axis=1)
        return tf.reduce_mean(per_sample_mse)

    def _compute_loss(self, preds, x_batch, y_batch):
        '''Training objective for one batch; shared by training and validation.

        preds   : model outputs; preds[0] is the reconstruction and, in the
                  segmentation phase, preds[1:5] are the ex/he/ma/se masks.
        y_batch : ground-truth masks in the order ex, he, ma, se.
        '''
        loss_recons = self.mean_square_error(preds[0], x_batch)
        if self.for_recons:
            return loss_recons

        ex_loss = self.dice_loss(y_batch[0], preds[1])
        he_loss = self.dice_loss(y_batch[1], preds[2])
        ma_loss = self.dice_loss(y_batch[2], preds[3])
        se_loss = self.dice_loss(y_batch[3], preds[4])
        # Weighted sum of the mask losses plus the weighted reconstruction loss.
        return (self.b1 * ex_loss + self.b2 * he_loss + self.b3 * ma_loss
                + self.b4 * se_loss + self.alpha * loss_recons)

    @tf.function
    def train_on_batch(self, x_batch_train, y_batch_train):
        '''Run one optimisation step and return the scalar training loss.'''
        with tf.GradientTape() as tape:
            # training=True so BatchNormalization uses/updates batch statistics.
            preds = self.model(x_batch_train, only_recons=self.for_recons, training=True)
            train_loss = self._compute_loss(preds, x_batch_train, y_batch_train)

        grads = tape.gradient(train_loss, self.model.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_weights))
        return train_loss

    def _evaluate_batch(self, x_batch_val, y_batch_val):
        '''Compute the training objective on one batch without updating weights.'''
        preds = self.model(x_batch_val, only_recons=self.for_recons, training=False)
        return self._compute_loss(preds, x_batch_val, y_batch_val)

    def train(self, train_dataset, val_dataset):
        '''Run the custom training loop for ``self.epochs`` epochs.

        train_dataset, val_dataset : iterables of (x_batch, y_batch) pairs that
            also support len() (e.g. the DR_Generator instances used here).
        Each epoch trains on every batch of ``train_dataset``, evaluates the
        same objective on every batch of ``val_dataset``, and reports the
        epoch-mean train and validation losses on a Keras progress bar.
        '''
        metrics_names = ['train_loss', 'val_loss']

        for epoch in range(self.epochs):
            print("\nEpoch {}/{}".format(epoch+1, self.epochs))

            # The bar's target is the number of training batches, so it is
            # advanced by batch count.
            progBar = Progbar(len(train_dataset), stateful_metrics=metrics_names)

            train_losses = []
            for step_train, (x_batch_train, y_batch_train) in enumerate(train_dataset):
                # float() copies the scalar to Python so no tensor is kept alive.
                batch_loss = float(self.train_on_batch(x_batch_train, y_batch_train))
                train_losses.append(batch_loss)
                progBar.update(step_train + 1, values=[('train_loss', batch_loss)])

            val_losses = [float(self._evaluate_batch(x_batch_val, y_batch_val))
                          for x_batch_val, y_batch_val in val_dataset]

            values = [('train_loss', self._mean(train_losses)),
                      ('val_loss', self._mean(val_losses))]
            progBar.update(len(train_losses), values=values, finalize=True)

            # keras.utils.Sequence.on_epoch_end (e.g. reshuffling) is called by
            # model.fit(); a hand-written loop has to call it itself.
            on_epoch_end = getattr(train_dataset, 'on_epoch_end', None)
            if callable(on_epoch_end):
                on_epoch_end()


In [16]:
# Instantiate the multi-head model from `models`, build it for
# 512x512 single-channel inputs with an unspecified batch size, and print its layers.
model = SMD_Unet()
model.build(input_shape=(None,) + (512, 512) + (1,))
model.summary()

Model: "smd__unet_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
encoder_block_7 (EncoderBloc multiple                  18857920  
_________________________________________________________________
decoder_block_35 (DecoderBlo multiple                  12197325  
_________________________________________________________________
decoder_block_36 (DecoderBlo multiple                  12197325  
_________________________________________________________________
decoder_block_37 (DecoderBlo multiple                  12197325  
_________________________________________________________________
decoder_block_38 (DecoderBlo multiple                  12197325  
_________________________________________________________________
decoder_block_39 (DecoderBlo multiple                  12197325  
Total params: 79,844,545
Trainable params: 79,817,389
Non-trainable params: 27,156
______________________________________

In [12]:
import os

# One mask sub-directory per lesion type, under the Seg-set root.
mask_dir = '../data/Seg-set'
masks = ['HardExudate_Masks', 'Hemohedge_Masks', 'Microaneurysms_Masks', 'SoftExudate_Masks']
mask_paths = list(map(lambda mask_name: os.path.join(mask_dir, mask_name), masks))

generator_args = dict(
    dir_path='../data/Seg-set/Original_Images/',
    mask_path=mask_paths,
    use_mask=True,
    img_size=(256, 256),
    batch_size=2,  # batch_size=8 runs out of memory (OOM) immediately
    dataset='FGADR',  # 'FGADR' or 'EyePacks'
    is_train=True,
)

# Index range (0, 500) for training, (500, 600) for validation.
tr_eyepacks_gen = DR_Generator(start_end_index=(0, 500), **generator_args)
val_eyepacks_gen = DR_Generator(start_end_index=(500, 600), **generator_args)

In [18]:
# Execute tf.function bodies eagerly (easier debugging; required if a loss calls .numpy()).
tf.config.run_functions_eagerly(True)

optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

# NOTE(review): this uses the `SMD_Unet` imported from `models`. A later cell
# redefines `SMD_Unet(img_size)` as a function, so re-running this cell after
# that one raises a TypeError (missing img_size).
model = SMD_Unet()

# Phase 1: reconstruction-only training. alpha/beta only affect the loss when
# for_recons is False.
trainer_args = dict(
    model=model,
    epochs=1,
    optimizer=optimizer,
    for_recons=True,
    alpha=1.0,
    beta=[1.0, 1.0, 1.0, 1.0],
)
trainer = Trainer(**trainer_args)

trainer.train(train_dataset=tr_eyepacks_gen, val_dataset=val_eyepacks_gen)

NameError: name 'Unet' is not defined

In [22]:
from tensorflow import keras
import tensorflow as tf

def ConvBlock(x, n_filters, kernel_size=3):
    """Conv2D -> BatchNormalization -> ReLU on a Keras tensor.

    The convolution is kept linear so BatchNormalization normalises the raw
    (signed) conv outputs; the single ReLU is applied after normalisation.
    Applying ReLU both inside Conv2D and after BN, as before, fed BN only
    non-negative values and made the second ReLU redundant.

    x           : input Keras tensor.
    n_filters   : number of output channels.
    kernel_size : convolution kernel size (default 3, unchanged from before).
    Returns the output Keras tensor (same spatial size, thanks to padding='same').
    """
    x = keras.layers.Conv2D(n_filters, kernel_size, padding='same', kernel_initializer='he_normal')(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Activation('relu')(x)

    return x

def UpsampleBlock(x, skip, n_filters):
    """Upsample by 2, concatenate a skip connection, then apply one ConvBlock.

    x         : Keras tensor to upsample (stride-2 transposed convolution).
    skip      : encoder feature map matching the upsampled spatial size.
    n_filters : channel count for the transposed conv and the ConvBlock.
    Returns the output Keras tensor.
    """
    upsampled = keras.layers.Conv2DTranspose(n_filters, (2, 2), strides=2, padding='same')(x)
    # Channel order is [upsampled, skip].
    merged = keras.layers.Concatenate()([upsampled, skip])
    return ConvBlock(merged, n_filters)

def Encoder(x, filters):
    """U-Net contracting path.

    Each stage applies two ConvBlocks with ``filters[i]`` channels. Every stage
    except the last (the bottleneck) stores its output as a skip connection
    and then halves the spatial size with 2x2 max pooling.

    x       : input Keras tensor.
    filters : channel counts per stage, shallowest first (e.g. [64, ..., 1024]).
    Returns (bottleneck tensor, list of skip tensors, shallowest first).
    """
    skips = []
    last_stage = len(filters) - 1

    for stage, f in enumerate(filters):
        x = ConvBlock(x, f)
        x = ConvBlock(x, f)
        # Skip connection + downsampling on every stage except the last one.
        # (Previously hard-coded as `f != 1024`, which only worked for that list.)
        if stage < last_stage:
            skips.append(x)
            x = keras.layers.MaxPooling2D(2)(x)

    return x, skips

def Decoder(x, filters, skips):
    """U-Net expanding path followed by a single-channel linear output head.

    For each (f, skip) pair the feature map is upsampled 2x with a transposed
    convolution, passed through a ConvBlock, concatenated with the skip tensor,
    and passed through another ConvBlock.

    x       : bottleneck Keras tensor.
    filters : channel counts per upsampling stage, deepest first.
    skips   : encoder skip tensors, deepest first (paired with ``filters``).
    Returns a 1-channel Keras tensor with linear activation.
    """
    for f, skip in zip(filters, skips):
        x = keras.layers.Conv2DTranspose(f, (2, 2), strides=2, padding='same')(x)
        x = ConvBlock(x, f)
        x = keras.layers.Concatenate()([x, skip])
        x = ConvBlock(x, f)

    # Output head, applied once after all upsampling stages. It was previously
    # indented inside the loop, collapsing the features to 1 channel at every stage.
    x = ConvBlock(x, 2)
    x = keras.layers.Conv2D(filters=1, kernel_size=1, padding='same', activation='linear')(x)

    return x

def Unet(img_size):
    """Build a single-output U-Net as a functional keras.Model.

    img_size : (height, width) tuple; a channel axis of size 1 is appended.
    Returns a keras.Model mapping (H, W, 1) inputs to the Decoder's output.
    The intended training loss is mean squared error.
    """
    encoder_filters = [64, 128, 256, 512, 1024]
    inputs = keras.Input(shape=img_size + (1,))

    # Contracting path.
    bottleneck, skips = Encoder(inputs, encoder_filters)

    # Expanding path: encoder widths without the bottleneck, deepest first,
    # paired with the skip connections in reverse order.
    decoder_filters = encoder_filters[-2::-1]
    outputs = Decoder(bottleneck, decoder_filters, list(reversed(skips)))

    return keras.Model(inputs, outputs)

In [25]:
import os

# Lesion mask directories (one per lesion type) under the Seg-set root.
mask_dir = '../data/Seg-set'
masks = ['HardExudate_Masks', 'Hemohedge_Masks', 'Microaneurysms_Masks', 'SoftExudate_Masks']
mask_paths = [os.path.join(mask_dir, name) for name in masks]

generator_args = {
    'dir_path': '../data/Seg-set/Original_Images/',
    'mask_path': mask_paths,
    'use_mask': True,
    'img_size': (512, 512),
    'batch_size': 4,  # batch_size=8 runs out of memory (OOM) immediately
    'dataset': 'FGADR',  # 'FGADR' or 'EyePacks'
    'is_train': True,
}

# Index range (0, 500) for training, (500, 600) for validation.
tr_eyepacks_gen = DR_Generator(start_end_index=(0, 500), **generator_args)
val_eyepacks_gen = DR_Generator(start_end_index=(500, 600), **generator_args)

In [26]:
# Plain single-decoder U-Net for 512x512 single-channel inputs.
model = Unet((512, 512))
# model.summary()

In [31]:
# model.compile(loss='mean_squared_error')
# model.fit(tr_eyepacks_gen, epochs=1)
# unet은 문제없음

In [32]:
def SMD_Unet(img_size):
    """Functional multi-head U-Net: one shared encoder, five independent decoders.

    img_size : (height, width) tuple; a channel axis of size 1 is appended.
    Returns a keras.Model whose outputs are, in order: the reconstructed
    input, then the HardExudate, Hemohedge, Microane and SoftExudates masks.

    NOTE(review): this redefines the `SMD_Unet` imported from `models` at the
    top of the notebook. After this cell runs, calls such as `SMD_Unet()`
    without `img_size` fail, and this version has no `only_recons` argument
    or named decoder attributes.
    """
    inputs = keras.Input(shape=img_size + (1,))
    encoder_filters = [64, 128, 256, 512, 1024]

    bottleneck, skips = Encoder(inputs, encoder_filters)

    decoder_filters = encoder_filters[::-1][1:]
    skips_deep_first = skips[::-1]
    # Each Decoder call creates its own layers, so the five heads share only the encoder.
    head_names = ('input_hat', 'ex', 'he', 'ma', 'se')
    heads = [Decoder(bottleneck, decoder_filters, skips_deep_first) for _ in head_names]

    return keras.Model(inputs, outputs=heads)

In [36]:
# model = SMD_Unet((512, 512))
# model.compile(loss='mean_squared_error')
# model.fit(tr_eyepacks_gen, epochs=1)

# trainer class 문제는 아님