In [54]:
from tensorflow.keras.utils import Progbar
import tensorflow as tf

import math

from models import SMD_Unet

import numpy as np
from data_generator import DR_Generator

class Trainer:
    '''Custom training loop for a model with one reconstruction output and four mask outputs.

    The model is called as ``model(x, only_recons=..., training=...)``. It must
    return a 5-tuple ``(input_hat, ex_hat, he_hat, ma_hat, se_hat)``. Every
    batch yielded by the datasets is ``(x, (ex, he, ma, se))``.
    '''

    def __init__(self, model, epochs, optimizer, for_recons, alpha, beta):
        '''
        model      : Keras model exposing the HardExudate / Hemohedge /
                     Microane / SoftExudates sub-decoders.
        epochs     : number of passes over the training dataset.
        optimizer  : tf.keras optimizer used to apply the gradients.
        for_recons : bool, selects the training stage. True trains only the
                     reconstruction path (mask decoders frozen). False also
                     trains the mask decoders.
        alpha      : weight applied to the reconstruction loss.
        beta       : sequence of four weights for the (ex, he, ma, se) mask
                     losses. Required when for_recons is False.

        Raises ValueError if beta is missing while for_recons is False, or if
        it does not unpack into exactly four weights.
        '''
        self.model = model
        self.epochs = epochs
        self.optimizer = optimizer
        self.for_recons = for_recons
        self.alpha = alpha
        self.beta = beta

        if beta is not None and len(beta) > 0:
            self.b1, self.b2, self.b3, self.b4 = beta
        elif not for_recons:
            raise ValueError(
                "beta must hold four mask-loss weights when for_recons is False")
        else:
            # Mask losses are not computed in the reconstruction-only stage.
            self.b1 = self.b2 = self.b3 = self.b4 = 0.0

        # When only reconstruction is trained, freeze the unused mask decoders
        # so their weights drop out of model.trainable_weights.
        mask_decoders = (self.model.HardExudate, self.model.Hemohedge,
                         self.model.Microane, self.model.SoftExudates)
        for decoder in mask_decoders:
            decoder.trainable = not self.for_recons

    # Loss functions. They are written with TF ops only, so the returned
    # tensors stay connected to the GradientTape. Converting to NumPy would
    # make every gradient None.
    def dice_loss(self, inputs, targets, smooth=1.):
        '''Mean soft Dice loss over the batch.

        inputs, targets : batched tensors/arrays of identical shape; axis 0 is
                          the batch. The formula is symmetric, so argument
                          order does not matter.
        smooth          : additive smoothing that keeps the ratio defined when
                          both masks are empty.
        Returns a scalar float32 tensor: mean over samples of (1 - Dice).
        '''
        inputs = tf.cast(inputs, tf.float32)
        targets = tf.cast(targets, tf.float32)

        # Flatten each sample so Dice is computed per sample, then averaged.
        batch_size = tf.shape(inputs)[0]
        input_flat = tf.reshape(inputs, [batch_size, -1])
        target_flat = tf.reshape(targets, [batch_size, -1])

        intersection = tf.reduce_sum(input_flat * target_flat, axis=1)
        totals = tf.reduce_sum(input_flat, axis=1) + tf.reduce_sum(target_flat, axis=1)
        dice_coef = (2. * intersection + smooth) / (totals + smooth)
        return tf.reduce_mean(1. - dice_coef)

    def mean_square_error(self, input_hats, inputs):
        '''Mean squared error, averaged per sample and then over the batch.

        input_hats : reconstructed images; axis 0 is the batch.
        inputs     : original images with the same shape.
        Returns a scalar float32 tensor, differentiable w.r.t. input_hats.
        '''
        input_hats = tf.cast(input_hats, tf.float32)
        inputs = tf.cast(inputs, tf.float32)

        squared = tf.square(input_hats - inputs)
        per_sample = tf.reduce_mean(
            tf.reshape(squared, [tf.shape(squared)[0], -1]), axis=1)
        return tf.reduce_mean(per_sample)

    def _compute_loss(self, x_batch, y_batch, preds):
        '''Weighted total loss for one batch; shared by training and validation.'''
        input_hat, ex_hat, he_hat, ma_hat, se_hat = preds

        # Reconstruction loss.
        loss_recons = self.mean_square_error(input_hat, x_batch)
        if self.for_recons:
            return loss_recons

        # Lesion mask losses: ex, he, ma, se.
        ex, he, ma, se = y_batch
        ex_loss = self.dice_loss(ex, ex_hat)
        he_loss = self.dice_loss(he, he_hat)
        ma_loss = self.dice_loss(ma, ma_hat)
        se_loss = self.dice_loss(se, se_hat)

        # Weighted sum of all losses.
        return (self.b1 * ex_loss + self.b2 * he_loss + self.b3 * ma_loss
                + self.b4 * se_loss + self.alpha * loss_recons)

    @tf.function
    def train_on_batch(self, x_batch_train, y_batch_train):
        '''Run one optimisation step and return (train_loss, preds).'''
        with tf.GradientTape() as tape:
            # training=True so layers such as BatchNorm/Dropout use training mode.
            preds = self.model(x_batch_train, only_recons=self.for_recons,
                               training=True)
            train_loss = self._compute_loss(x_batch_train, y_batch_train, preds)

        trainable = self.model.trainable_weights
        grads = tape.gradient(train_loss, trainable)
        self.optimizer.apply_gradients(zip(grads, trainable))

        return train_loss, preds

    def train(self, train_dataset, val_dataset):
        '''Train for self.epochs epochs, validating at the end of each epoch.

        train_dataset, val_dataset : iterables of (x, (ex, he, ma, se))
            batches. len(train_dataset) must be the number of batches; it is
            used as the progress-bar target.
        The bar shows the running mean training loss. At the end of the epoch
        it shows the mean validation loss over all validation batches.
        '''
        metrics_names = ['train_loss', 'val_loss']

        for epoch in range(self.epochs):
            print("\nEpoch {}/{}".format(epoch + 1, self.epochs))

            # The Progbar target is counted in batches, so updates pass batch counts.
            progBar = Progbar(len(train_dataset), stateful_metrics=metrics_names)

            train_loss_sum = 0.0
            train_steps = 0
            for x_batch_train, y_batch_train in train_dataset:
                train_loss, _ = self.train_on_batch(x_batch_train, y_batch_train)
                train_loss_sum += float(train_loss)
                train_steps += 1
                progBar.update(train_steps,
                               values=[('train_loss', train_loss_sum / train_steps)])

            val_loss_sum = 0.0
            val_steps = 0
            for x_batch_val, y_batch_val in val_dataset:
                preds = self.model(x_batch_val, only_recons=self.for_recons,
                                   training=False)
                val_loss_sum += float(self._compute_loss(x_batch_val, y_batch_val, preds))
                val_steps += 1

            values = [('train_loss', train_loss_sum / max(train_steps, 1))]
            if val_steps:
                values.append(('val_loss', val_loss_sum / val_steps))
            progBar.update(train_steps, values=values, finalize=True)


In [55]:
import os

# One mask directory per lesion type. Presumably DR_Generator yields the masks
# in this order, which Trainer unpacks as (ex, he, ma, se). TODO confirm.
mask_dir = '../data/FGADR-Seg-set_Release/Seg-set'
masks = [
    'HardExudate_Masks',
    'Hemohedge_Masks',  # path component: must match the directory name on disk
    'Microaneurysms_Masks',
    'SoftExudate_Masks',
]
mask_paths = [os.path.join(mask_dir, mask_name) for mask_name in masks]

generator_args = dict(
    dir_path='../data/FGADR-Seg-set_Release/Seg-set/Original_Images/',
    mask_path=mask_paths,
    use_mask=True,
    img_size=(512, 512),
    batch_size=4,      # batch_size=8 ran out of memory immediately
    dataset='FGADR',   # 'FGADR' or 'EyePacks'
    is_train=True,
)

# Images 0-500 for training, 500-600 for validation. Assumes start_end_index
# is a half-open index range; TODO confirm in DR_Generator.
# NOTE(review): the variable names say "eyepacks", but dataset='FGADR' is
# loaded. is_train=True is also passed to the validation generator; confirm
# whether that enables augmentation/shuffling during validation.
tr_eyepacks_gen = DR_Generator(start_end_index=(0, 500), **generator_args)
val_eyepacks_gen = DR_Generator(start_end_index=(500, 600), **generator_args)

In [56]:
# Execute tf.function-decorated code eagerly. This makes debugging easier and
# is required whenever the loss helpers call .numpy(). Turn it off for faster
# graph-mode training once the training step is pure TensorFlow.
tf.config.run_functions_eagerly(True)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

# NOTE(review): the saved run hit OOM on the CPU allocator at batch_size=4
# with 512x512 inputs. Consider a smaller img_size/batch_size, or a GPU.
model = SMD_Unet()

# for_recons=False: train the mask decoders together with reconstruction,
# with all loss terms weighted equally.
trainer_args = dict(
    model=model,
    epochs=1,
    optimizer=optimizer,
    for_recons=False,
    alpha=1.0,
    beta=[1.0, 1.0, 1.0, 1.0],
)
trainer = Trainer(**trainer_args)

trainer.train(train_dataset=tr_eyepacks_gen, val_dataset=val_eyepacks_gen)


Epoch 1/1


ResourceExhaustedError: Exception encountered when calling layer "activation_894" "                 f"(type Activation).

{{function_node __wrapped__Relu_device_/job:localhost/replica:0/task:0/device:CPU:0}} OOM when allocating tensor with shape[4,256,256,128] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator mklcpu [Op:Relu]

Call arguments received by layer "activation_894" "                 f"(type Activation):
  • inputs=tf.Tensor(shape=(4, 256, 256, 128), dtype=float32)

In [None]:
# Reference code (intentionally NOT executed): Keras multi-output /
# multi-loss training with compile()/fit(), kept as a pointer.
# Source: https://pyimagesearch.com/2018/06/04/keras-multiple-outputs-and-multiple-losses/
#
# Commented out because FashionNet, categoryLB, colorLB, Adam, INIT_LR,
# EPOCHS, trainX, trainCategoryY, trainColorY, testX, testCategoryY and
# testColorY are not defined in this notebook. The cell therefore failed
# under Restart & Run All, and if it ran it would overwrite `model`.
# NOTE: newer tf.keras optimizers expect `learning_rate=` instead of `lr=` and
# do not accept `decay=`; use a LearningRateSchedule instead.
#
# # initialize our FashionNet multi-output network
# model = FashionNet.build(96, 96,
#     numCategories=len(categoryLB.classes_),
#     numColors=len(colorLB.classes_),
#     finalAct="softmax")
# # define two dictionaries: one that specifies the loss method for
# # each output of the network along with a second dictionary that
# # specifies the weight per loss
# losses = {
#     "category_output": "categorical_crossentropy",
#     "color_output": "categorical_crossentropy",
# }
# lossWeights = {"category_output": 1.0, "color_output": 1.0}
# # initialize the optimizer and compile the model
# print("[INFO] compiling model...")
# opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)
# model.compile(optimizer=opt, loss=losses, loss_weights=lossWeights,
#     metrics=["accuracy"])
#
# H = model.fit(x=trainX,
#     y={"category_output": trainCategoryY, "color_output": trainColorY},
#     validation_data=(testX,
#         {"category_output": testCategoryY, "color_output": testColorY}),
#     epochs=EPOCHS,
#     verbose=1)