In [None]:
import os
from glob import glob
import numpy as np
import tensorflow as tf
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
# Render every imshow call in this notebook with the grayscale colormap by default.
plt.rcParams.update({'image.cmap': 'gray'})

## 모델 불러오기

In [None]:
from models import UNetGenerator, Discriminator
# from nested_unet_model import nUNetGenerator, Discriminator

# Instantiate the U-Net generator and the discriminator from models.py.
# Alternative generator (nested U-Net): import it from nested_unet_model above and use
# generator = nUNetGenerator()
generator, discriminator = UNetGenerator(), Discriminator()

## Loss function, Optimizer

In [None]:
from keras import losses

# Shared loss objects for the generator / discriminator objectives.
# from_logits=False means the discriminator's output is treated as probabilities
# (presumably it ends in a sigmoid -- confirm in models.py).
mae = losses.MeanAbsoluteError()
bce = losses.BinaryCrossentropy(from_logits=False)

def get_gene_loss(fake_output, real_output, fake_disc):
    """Compute the two terms of the generator objective.

    Args:
        fake_output: images produced by the generator.
        real_output: ground-truth target images.
        fake_disc: discriminator output for the generated images.

    Returns:
        ``(gene_loss, l1_loss)`` -- the adversarial BCE of ``fake_disc`` against
        all-ones labels, and the mean absolute error between target and generated
        images (un-weighted; the caller applies the L1 weight).
    """
    reconstruction = mae(real_output, fake_output)
    # The generator wants the discriminator to label its output as real (1).
    adversarial = bce(tf.ones_like(fake_disc), fake_disc)
    return adversarial, reconstruction

def get_disc_loss(fake_disc, real_disc):
    """Discriminator objective: BCE of fakes vs. 0 labels plus BCE of reals vs. 1 labels.

    Args:
        fake_disc: discriminator output for generated images.
        real_disc: discriminator output for ground-truth images.

    Returns:
        The summed binary cross-entropy of both terms.
    """
    fake_term = bce(tf.zeros_like(fake_disc), fake_disc)
    real_term = bce(tf.ones_like(real_disc), real_disc)
    return fake_term + real_term

from keras import optimizers

# Separate Adam optimizers with identical hyper-parameters for each network.
gene_opt = optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
disc_opt = optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)

## Train step

In [None]:
@tf.function
def train_step(inp, tar):
    """Run one simultaneous generator/discriminator update on a batch.

    Args:
        inp: source images; fed to the generator and used to condition the discriminator.
        tar: ground-truth target images paired with ``inp``.

    Returns:
        ``(gene_loss, l1_loss, disc_loss)`` -- the generator's adversarial loss,
        its un-weighted L1 loss, and the discriminator loss.
    """
    with tf.GradientTape() as gene_tape, tf.GradientTape() as disc_tape:
        fake_img = generator(inp, training=True)
        # Discriminator scores for (input, generated) and (input, target) pairs.
        disc_on_fake = discriminator(inp, fake_img, training=True)
        disc_on_real = discriminator(inp, tar, training=True)

        adv_loss, recon_loss = get_gene_loss(fake_img, tar, disc_on_fake)
        # Total generator objective: adversarial term + L1 term weighted by lambda = 100.
        gene_total_loss = adv_loss + 100 * recon_loss
        disc_loss = get_disc_loss(disc_on_fake, disc_on_real)

    # Compute both gradients before applying either update.
    gene_grads = gene_tape.gradient(gene_total_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    gene_opt.apply_gradients(zip(gene_grads, generator.trainable_variables))
    disc_opt.apply_gradients(zip(disc_grads, discriminator.trainable_variables))
    return adv_loss, recon_loss, disc_loss

## Checkpoint

In [None]:
## 1. Choose the checkpoint folder and file-name prefix (placeholders -- edit before running).
ckpt_directory = "체크포인트 디렉토리 경로"
# Checkpoint.save() appends its save counter to this prefix ("<prefix>-1", "<prefix>-2", ...).
# If the training loop saves once per epoch on a fresh Checkpoint, the suffix equals the epoch.
ckpt_path = os.path.join(ckpt_directory, "체크포인트_파일명.ckpt")

checkpoint_prefix = ckpt_path
# Track both networks and their optimizers so training can be resumed exactly.
checkpoint = tf.train.Checkpoint(unet_optimizer=gene_opt,
                                 discriminator_optimizer=disc_opt,
                                 unet=generator,
                                 discriminator=discriminator)

- Reference on `tf.data.Dataset.list_files` (not used in this notebook; the data below is loaded from `.npz` archives instead): https://soki.tistory.com/m/20

## 데이터셋 불러오기 및 학습

In [None]:
# Training length (epochs) and mini-batch size.
EPOCHS, n_batch = 200, 10

from numpy import load
# load and prepare training images
def load_real_samples(filename):
    """Load a paired image dataset from a NumPy ``.npz`` archive.

    Args:
        filename: path to an archive holding at least the arrays ``arr_0`` and
            ``arr_1`` (the names ``np.savez*`` gives positional arrays).

    Returns:
        ``[X1, X2]`` -- the ``arr_0`` and ``arr_1`` arrays, fully read into memory.
    """
    # np.load on an .npz returns an NpzFile that keeps the file handle open;
    # the context manager closes it once both arrays have been materialized.
    with load(filename) as data:
        X1, X2 = data['arr_0'], data['arr_1']
    return [X1, X2]

# Load the training pairs: dataset[0] is arr_0 (source images), dataset[1] is arr_1 (targets).
dataset = load_real_samples('데이터셋_파일명')
print(np.shape(dataset[0]))

from numpy.random import randint
from numpy import zeros
from numpy import ones
# select a batch of random samples, returns images and target
def generate_real_samples(dataset, n_samples, patch_shape=16):
    """Draw a random mini-batch of paired images plus an all-ones 'real' label map.

    Args:
        dataset: ``[trainA, trainB]`` -- paired arrays indexed along axis 0.
        n_samples: number of pairs to draw; indices are drawn independently,
            so the same pair can appear more than once in a batch.
        patch_shape: spatial size of the label map (default 16). Presumably this
            matches the discriminator's output grid -- confirm in models.py.

    Returns:
        ``([X1, X2], y)`` where ``X1``/``X2`` are the selected rows of
        ``trainA``/``trainB`` and ``y`` is ``ones((n_samples, patch_shape, patch_shape, 1))``.
    """
    trainA, trainB = dataset
    # randint's upper bound is exclusive, so every drawn index is valid.
    ix = randint(0, trainA.shape[0], n_samples)
    # Indexing both arrays with the same ix keeps source/target pairs aligned.
    X1, X2 = trainA[ix], trainB[ix]
    y = ones((n_samples, patch_shape, patch_shape, 1))
    return [X1, X2], y

import time
## Training steps per epoch: floor division, but at least one step so a dataset
## smaller than one batch still trains and still checkpoints every epoch.
n_steps = max(1, len(dataset[0]) // n_batch)

# Per-epoch history for the loss plot: epoch number plus the generator and
# discriminator losses of the epoch's final step, stored as plain floats.
ep_list, g_list, d_list = [], [], []
for epoch in range(1, EPOCHS + 1):
    t = time.time()
    for i in range(n_steps):
        [inp, tar], y_real = generate_real_samples(dataset, n_batch)
        g_loss, l1_loss, d_loss = train_step(inp, tar)

        # Print the losses every 10 steps.
        if (i + 1) % 10 == 0:
            elapsed_min = (time.time() - t) / 60
            print(f"EPOCH[{epoch}] - STEP[{i+1}]\n"
                  f"Generator_loss:{g_loss.numpy():.4f}, L1_loss:{l1_loss.numpy():.4f}\n"
                  f"Discriminator_loss:{d_loss.numpy():.4f}, elapsed: {elapsed_min:.2f} minute",
                  end="\n\n")

    # End of epoch: checkpoint and record the losses of the last step.
    checkpoint.save(file_prefix=checkpoint_prefix)
    ep_list.append(epoch)
    g_list.append(float(g_loss))
    d_list.append(float(d_loss))

## 학습 진행에 따른 generator/discriminator loss

In [None]:
# Loss curves keyed by epoch number, as recorded by the training loop.
gloss_dict = dict(zip(ep_list, g_list))
dloss_dict = dict(zip(ep_list, d_list))

fig, ax = plt.subplots(figsize=(10, 5))
# Plot against the real epoch numbers so the 5-epoch ticks below line up with the
# data (plotting only the values would place epoch k at x = k - 1).
# float() accepts both Python floats and scalar tensors.
ax.plot(list(gloss_dict.keys()), [float(v) for v in gloss_dict.values()], label='generator loss')
ax.plot(list(dloss_dict.keys()), [float(v) for v in dloss_dict.values()], label='discriminator loss')
ax.set_xticks(range(5, EPOCHS + 1, 5))
ax.tick_params(axis='x', labelrotation=45, labelsize=8)
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('Generator / discriminator loss per epoch')
ax.legend()
plt.show()

## 학습 후 source/ generated/ ground truth 랜덤 이미지 비교

In [None]:
from numpy.random import randint

# Pick one random training pair. ix is a length-1 index array, so the batch axis
# is kept and inp can be passed straight to the generator.
ix = randint(0, len(dataset[0]), 1)
inp, tar = dataset[0][ix], dataset[1][ix]
print(inp.shape)

pred = generator(inp)

title = ['Source', 'Generated', 'Expected']
fig, ax = plt.subplots(1, 3, figsize=(10 * 3, 10))
for panel, image, panel_title in zip(ax, (inp, pred, tar), title):
    # Drop singleton axes (batch / channel) so imshow receives a 2-D image.
    panel.imshow(np.squeeze(image))
    panel.set_title(panel_title, fontsize=30)
plt.show()

## 검증 데이터에 대한 예측 이미지 데이터 저장

In [None]:
# Load the validation pairs (arr_0 = source, arr_1 = ground truth).
valid_dt = load_real_samples('valid_a.npz')
print(valid_dt[0].shape)

# Predict each source image as a single-image batch of shape (1, 512, 512, 1).
# The comprehension keeps its loop variable local, so the `pred` shown in the
# random-sample comparison cell is not silently overwritten.
pred_img = [generator(sample.reshape(1, 512, 512, 1)) for sample in valid_dt[0]]

In [None]:
## Stack the per-sample predictions and check their shape.
# The sample count is taken from the list instead of a hard-coded 800, so any
# validation-set size works; each prediction must hold 512*512 values.
pred_img = np.array(pred_img).reshape(len(pred_img), 512, 512, 1)
assert pred_img.shape[0] == len(valid_dt[0]), "expected one prediction per validation image"
print(pred_img.shape)

In [None]:
# Save (source, ground truth, predicted) as arr_0 / arr_1 / arr_2 in one compressed archive.
# The file name is built from the training config so it cannot drift from EPOCHS / n_batch.
filename = f'{EPOCHS}ep-{n_batch}b-pred-a.npz'
np.savez_compressed(filename, valid_dt[0], valid_dt[1], pred_img)
print('Saved dataset: ', filename)

## 전체 source/ generated/ ground truth 이미지 비교

In [None]:
# three images
def get_3_image(one, two, three, vmin, vmax):
    """Build a callback that draws slice ``idx`` of three image stacks side by side.

    Args:
        one, two, three: image stacks indexed along axis 0, displayed as
            'Source', 'Generated' and 'Expected' respectively.
        vmin, vmax: shared display range passed to ``imshow``.

    Returns:
        ``current_slice(idx)``, which renders the three panels for slice ``idx``
        and shows the figure (the parameter name ``idx`` is what the slider binds to).
    """
    titles = ['Source', 'Generated', 'Expected']
    stacks = (one, two, three)

    def current_slice(idx):
        fig, axes = plt.subplots(1, 3, figsize=(15 * 3, 15))
        for ax, stack, title in zip(axes.flatten(), stacks, titles):
            # Squeeze singleton axes (e.g. a trailing channel of size 1) so imshow
            # always receives a plain image.
            ax.imshow(np.squeeze(stack[idx, ...]), vmin=vmin, vmax=vmax)
            ax.set_title(title, fontsize=50)
        plt.show()
    return current_slice

def sliceimageview_3(one, two, three, vmin=-0.01, vmax=0.5):
    """Interactive viewer: one integer slider selects the slice shown for all three stacks.

    Args:
        one, two, three: image stacks (axis 0 = slice index) shown as
            'Source', 'Generated' and 'Expected'.
        vmin, vmax: display range for ``imshow`` (defaults -0.01 and 0.5).

    Raises:
        ValueError: if any stack contains no slices.
    """
    from ipywidgets import IntSlider, interact

    current_slice = get_3_image(one, two, three, vmin=vmin, vmax=vmax)
    # Bound the slider by the shortest stack so every index is valid for all three.
    num_slices = min(one.shape[0], two.shape[0], three.shape[0])
    if num_slices == 0:
        raise ValueError("sliceimageview_3: every image stack must contain at least one slice")
    step_slider = IntSlider(min=0, max=num_slices - 1, value=num_slices // 2)
    interact(current_slice, idx=step_slider)

In [None]:
# Browse source / generated / ground-truth validation images with a slider.
sliceimageview_3(one=valid_dt[0], two=pred_img, three=valid_dt[1])