### IU Sketch Baseline 코드  
 - Python 모듈 설치 코드는 처음 한 번만 실행해 주세요.  


In [None]:
# !pip install imageio
# !pip install imageio --upgrade
# !pip install einops

In [None]:
# Module path setup: insert "<current working directory>/.." at position 1 of sys.path,
# presumably so the project modules imported below (config_env, classes, models)
# resolve when the notebook is launched from its own directory — confirm.
import os,sys
sys.path.insert(1, os.path.join(os.getcwd()  , '..'))

베이스라인 코드.  

In [None]:
import os, glob, random
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import ThresholdedReLU

from tensorflow.keras import layers, Input, Model
from tensorflow.keras import losses
from tensorflow.keras import optimizers

import config_env as cfg

from classes.image_frame import ImgFrame
from classes.video_clip import VideoClip
from models.dataset_generator import DataSetGenerator


import warnings
warnings.filterwarnings("ignore")

# from models.layer_conv import Conv2Plus1D, TConv2Plus1D
# from models.layer_encoder import Encoder5D, Decoder5D
# from models.layer_lstm import ConvLstmSeries

In [None]:
# Create the working directories if they do not exist yet.
# os.makedirs(..., exist_ok=True) also creates missing parent directories (os.mkdir
# would fail there) and does not race between the existence check and the creation.
for required_dir in (
    cfg.RAW_CLIP_PATH,    # raw training clips (*.gif)
    cfg.MODEL_SAVE_PATH,  # saved models
    cfg.TEMP_DATA_PATH,   # temporary data (e.g. prediction output)
):
    os.makedirs(required_dir, exist_ok=True)

In [None]:
# --- Hyper-parameters (trailing comments name the config_env values they override) ---
img_w, img_h = 64, 64   # cfg.DATA_IMG_W, cfg.DATA_IMG_H
batch_size = 4          # cfg.DATA_BATCH_SIZE
time_steps = 8          # cfg.DATA_TIME_STEP
# Encoder depth. The generator halves H/W once per encoder block and doubles them the same
# number of times in the decoder, so its output matches the input only when img_w and img_h
# are divisible by 2**enc_blk_count (for 64x64 inputs: at most 6).
enc_blk_count = 6
disc_blk_count = 3
EPOCHS = 2

if img_w % 2 ** enc_blk_count or img_h % 2 ** enc_blk_count:
    raise ValueError(
        f"img_w/img_h ({img_w}, {img_h}) must be divisible by 2**enc_blk_count "
        f"({2 ** enc_blk_count}) so the generator output matches the label size."
    )

# --- Dataset options ---
data_seq_type = 'all'       # 'all', 'rest', 'arandom', 'aforward', 'forward', 'reverse', 'random'
data_label_type = '1step'   # 'all', 'rest', 'same', '1step'
stacked = False
stakced = stacked           # backward-compatible alias for the original (misspelled) name
overlap = False

# Seed for the train/validation split, so repeated runs use the same split.
split_seed = 42

# Collect every raw clip. Sorting first makes the seeded shuffle independent of the
# filesystem's listing order.
img_list = sorted(glob.glob(os.path.join(cfg.RAW_CLIP_PATH, "*.gif")))
if not img_list:
    raise FileNotFoundError(
        f"No .gif clips found in {cfg.RAW_CLIP_PATH!r}; add training clips before building the generators."
    )
random.Random(split_seed).shuffle(img_list)

# Split the clip list 9:1 into train/validation.
train_val_ratio = 0.9
train_img_cnt = int(len(img_list) * train_val_ratio)
train_img_list = img_list[:train_img_cnt]
val_img_list = img_list[train_img_cnt:]

# Both generators share every option except the clip list.
gen_kwargs = dict(batch_size=batch_size,
                  time_step=time_steps, imgw=img_w, imgh=img_h,
                  seq_type=data_seq_type, label_type=data_label_type,
                  stacked=stacked, overlap=overlap)

tdgen = DataSetGenerator(imgs=train_img_list, **gen_kwargs)
vdgen = DataSetGenerator(imgs=val_img_list, **gen_kwargs)


In [None]:
class EncodeBlock(layers.Layer):
    """Down-sampling block: Conv3D (stride 2 on the two spatial axes) -> optional BatchNorm -> LeakyReLU(0.2).

    Args:
        n_filters: number of output channels of the convolution.
        use_bn: when False the BatchNormalization layer is still created but is
            skipped in ``call``.
    """

    def __init__(self, n_filters, use_bn=True):
        super(EncodeBlock, self).__init__()
        self.use_bn = use_bn
        # Kernel/stride of 1 on the first (frame) axis, so each frame is convolved on its own.
        self.conv = layers.Conv3D(
            filters=n_filters,
            kernel_size=(1, 4, 4),
            strides=(1, 2, 2),
            padding="same",
        )
        self.batchnorm = layers.BatchNormalization()
        self.lrelu = layers.LeakyReLU(0.2)

    def call(self, x):
        out = self.conv(x)
        out = self.batchnorm(out) if self.use_bn else out
        return self.lrelu(out)


class Encoder(layers.Layer):
    """Down-sampling path: a stack of ``blk_cnt`` EncodeBlocks.

    Each block halves the two spatial axes (stride (1, 2, 2)). The first block
    skips BatchNorm. Channel widths come from [64, 128, 256, 512, 512, 512, 512, 512],
    so ``blk_cnt`` greater than 8 raises IndexError.
    """

    def __init__(self, blk_cnt=5):
        super(Encoder, self).__init__()
        widths = [64, 128, 256, 512, 512, 512, 512, 512]
        # Only the very first block runs without BatchNorm.
        self.blocks = [EncodeBlock(widths[idx], use_bn=idx != 0)
                       for idx in range(blk_cnt)]

    def call(self, x):
        out = x
        for blk in self.blocks:
            out = blk(out)
        return out

    def get_summary(self, input_shape=(None, img_w, img_h, 1)):
        # The default shape reads the notebook globals img_w / img_h when the class is defined.
        probe = Input(input_shape)
        return Model(probe, self.call(probe)).summary()


In [None]:
class DecodeBlock(layers.Layer):
    """Up-sampling block: Conv3DTranspose (stride 2 on H/W) -> BatchNorm -> optional Dropout(0.5) -> ReLU.

    Args:
        f: number of output channels of the transposed convolution.
        dropout: whether to apply Dropout(0.5) after BatchNorm.
    """

    def __init__(self, f, dropout=True):
        super(DecodeBlock, self).__init__()
        self.dropout = dropout
        self.Transconv = layers.Conv3DTranspose(filters=f,
                    kernel_size=(1, 4, 4),
                    strides=(1, 2, 2),
                    padding="same")
        self.batchnorm = layers.BatchNormalization()
        # Build the Dropout layer once, here. The original created a new Dropout
        # inside call() on every invocation, which leaves it untracked by the layer.
        self.drop = layers.Dropout(.5) if dropout else None
        self.relu = layers.ReLU()

    def call(self, x, training=None):
        # `training` is filled in by Keras from the enclosing call context when not
        # passed explicitly; it switches BatchNorm statistics and Dropout on/off.
        x = self.Transconv(x)
        x = self.batchnorm(x, training=training)
        if self.dropout:
            x = self.drop(x, training=training)
        return self.relu(x)

    
class Decoder(layers.Layer):
    """Up-sampling path: ``blk_cnt`` DecodeBlocks followed by a 1-channel Conv3DTranspose.

    Each stage doubles the two spatial axes. Channel widths come from
    [512, 512, 512, 512, 256, 128, 64], and Dropout is used in the first three blocks.
    The final transposed convolution has no normalization or activation.
    """

    def __init__(self, blk_cnt=4):
        super(Decoder, self).__init__()
        widths = [512, 512, 512, 512, 256, 128, 64]
        # Dropout is enabled only for the first three blocks.
        self.blocks = [DecodeBlock(widths[idx], dropout=idx < 3)
                       for idx in range(blk_cnt)]
        # Last stage: up-sample to a single output channel.
        self.blocks.append(
            layers.Conv3DTranspose(
                filters=1,
                kernel_size=(1, 4, 4),
                strides=(1, 2, 2),
                padding="same",
            )
        )

    def call(self, x):
        out = x
        for stage in self.blocks:
            out = stage(out)
        return out

In [None]:
class EncoderDecoderGenerator(Model):
    """Generator: Encoder followed directly by Decoder (no skip connections between them).

    The decoder is built with ``blk_cnt - 1`` DecodeBlocks. Together with its final
    transposed convolution, that gives ``blk_cnt`` up-samplings, mirroring the
    ``blk_cnt`` down-samplings of the encoder.
    """

    def __init__(self, blk_cnt=4):
        super(EncoderDecoderGenerator, self).__init__()
        self.encoder = Encoder(blk_cnt)
        self.decoder = Decoder(blk_cnt - 1)

    def call(self, x):
        return self.decoder(self.encoder(x))

    def get_summary(self, input_shape=(None, img_w, img_h, 1)):
        # The default shape reads the notebook globals img_w / img_h when the class is defined.
        probe = Input(input_shape)
        return Model(probe, self.call(probe)).summary()

In [None]:
class DiscBlock(layers.Layer):
    """Discriminator block: [ZeroPadding3D] -> Conv3D(1x4x4) -> [BatchNorm] -> [LeakyReLU(0.2)].

    Args:
        n_filters: output channels of the convolution.
        stride: stride on the two spatial axes (the frame axis always uses stride 1).
        custom_pad: if > 0, zero-pad the spatial axes by this amount and use a
            "valid" convolution; otherwise use "same" padding.
        use_bn: apply BatchNormalization after the convolution.
        act: apply LeakyReLU(0.2) at the end.
    """

    def __init__(self, n_filters, stride=2, custom_pad=0, use_bn=True, act=True):
        super(DiscBlock, self).__init__()
        self.custom_pad = custom_pad
        self.use_bn = use_bn
        self.act = act

        # outputsize = (w - f + 2*p) / s + 1
        explicit_pad = custom_pad > 0
        if explicit_pad:
            # Pad only the two spatial axes; the frame axis gets no padding.
            self.padding = layers.ZeroPadding3D(padding=(0, custom_pad, custom_pad))
        self.conv = layers.Conv3D(
            filters=n_filters,
            kernel_size=(1, 4, 4),
            strides=(1, stride, stride),
            padding="valid" if explicit_pad else "same",
        )

        self.batchnorm = layers.BatchNormalization() if use_bn else None
        self.lrelu = layers.LeakyReLU(0.2) if act else None

    def call(self, x):
        out = self.padding(x) if self.custom_pad else x
        out = self.conv(out)
        if self.use_bn:
            out = self.batchnorm(out)
        return self.lrelu(out) if self.act else out


In [None]:
class Discriminator(Model):
    """Conditional discriminator producing a per-patch probability map.

    ``x`` (condition) and ``y`` (real or generated target) are concatenated on the
    last axis (the ``layers.Concatenate`` default). The result goes through
    ``blk_cnt`` stride-2 DiscBlocks, then two stride-1 padded DiscBlocks (the last
    one has 1 channel, no BN and no activation), and finally a sigmoid. For 64x64
    inputs and blk_cnt=3 the output map is 6x6 per frame.
    """

    def __init__(self, blk_cnt=3):
        super(Discriminator, self).__init__()

        self.concat = layers.Concatenate()

        widths = [64, 128, 256, 512, 512, 512]
        # Down-sampling blocks: BatchNorm everywhere except in the first one.
        self.blocks = [
            DiscBlock(n_filters=widths[idx], stride=2, custom_pad=0,
                      use_bn=idx > 0, act=True)
            for idx in range(blk_cnt)
        ]

        # Stride-1 head with explicit zero padding; the final block outputs raw scores.
        self.blocks.append(DiscBlock(n_filters=512, stride=1, custom_pad=1, use_bn=True, act=True))
        self.blocks.append(DiscBlock(n_filters=1, stride=1, custom_pad=1, use_bn=False, act=False))
        self.sigmoid = layers.Activation("sigmoid")

    def call(self, x, y):
        out = self.concat([x, y])
        for blk in self.blocks:
            out = blk(out)
        return self.sigmoid(out)

    def get_summary(self, x_shape=(None, img_w, img_h, 1), y_shape=(None, img_w, img_h, 1)):
        cond, target = Input(x_shape), Input(y_shape)
        return Model((cond, target), self.call(cond, target)).summary()


In [None]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Warm-up followed by inverse-square-root decay.

    lr(step) = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)

    The learning rate grows linearly with ``step`` for the first ``warmup_steps``
    steps and then decays proportionally to ``step**-0.5``, so early training is
    aggressive and later training converges more stably.

    Args:
        d_model: scale factor; the schedule is multiplied by d_model**-0.5.
        warmup_steps: number of linear warm-up steps.
    """

    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()

        # Keep the constructor argument for get_config(); self.d_model holds the float tensor.
        self._d_model_arg = d_model
        self.d_model = tf.cast(d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # Optimizers pass their integer iteration counter; rsqrt needs a float input.
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

    def get_config(self):
        # Required for serializing the schedule (e.g. when saving an optimizer that uses it).
        return {"d_model": self._d_model_arg, "warmup_steps": self.warmup_steps}


In [None]:
# The discriminator ends with a sigmoid, so BCE takes probabilities (from_logits=False).
bce = losses.BinaryCrossentropy(from_logits=False)
mae = losses.MeanAbsoluteError()


def get_gene_loss(fake_output, real_output, fake_disc):
    """Return the generator's (adversarial_loss, l1_loss).

    adversarial_loss: BCE pushing the discriminator's verdict on generated data toward 1.
    l1_loss: mean absolute error between the real target and the generated output.
    """
    reconstruction = mae(real_output, fake_output)
    adversarial = bce(tf.ones_like(fake_disc), fake_disc)
    return adversarial, reconstruction


def get_disc_loss(fake_disc, real_disc):
    """Return the discriminator BCE loss: generated data labelled 0 plus real data labelled 1."""
    fake_term = bce(tf.zeros_like(fake_disc), fake_disc)
    real_term = bce(tf.ones_like(real_disc), real_disc)
    return fake_term + real_term

In [None]:
# Warm-up learning-rate schedules.
# NOTE(review): these schedules are created but never passed to the optimizers below,
# which use a constant learning rate of 2e-4 — confirm whether that is intentional.
gene_learning_rate = CustomSchedule(d_model=64, warmup_steps=40)
disc_learning_rate = CustomSchedule(d_model=64, warmup_steps=40)

# One Adam optimizer per network, with beta_1 = 0.5.
gene_opt = optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)
disc_opt = optimizers.Adam(learning_rate=2e-4, beta_1=0.5, beta_2=0.999)

In [None]:
# Build the two networks from the hyper-parameters set earlier.
generator = EncoderDecoderGenerator(blk_cnt=enc_blk_count)
discriminator = Discriminator(blk_cnt=disc_blk_count)

# Per-step training metrics, each key with its own fresh list.
# NOTE(review): the training loop only fills gen_loss, disc_loss and l1_loss.
history = {key: [] for key in ('gen_loss', 'disc_loss', 'real_accuracy', 'fake_accuracy', 'l1_loss')}

In [None]:
@tf.function
def train_step(sketch, label):
    """Run one joint optimization step of the generator and the discriminator.

    The generator minimizes adversarial_loss + 100 * l1_loss; the discriminator
    minimizes get_disc_loss on (generated, real) pairs conditioned on ``sketch``.

    Returns:
        (gene_loss, l1_loss, disc_loss). gene_loss is the adversarial term only,
        not the weighted total that is optimized.
    """
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        generated = generator(sketch, training=True)

        disc_on_fake = discriminator(sketch, generated, training=True)
        disc_on_real = discriminator(sketch, label, training=True)

        adv_loss, l1_loss = get_gene_loss(generated, label, disc_on_fake)
        # L1 reconstruction term weighted by lambda = 100.
        gene_total_loss = adv_loss + 100 * l1_loss
        disc_loss = get_disc_loss(disc_on_fake, disc_on_real)

    gen_vars = generator.trainable_variables
    disc_vars = discriminator.trainable_variables

    # Compute both gradients before applying either update.
    gen_grads = g_tape.gradient(gene_total_loss, gen_vars)
    disc_grads = d_tape.gradient(disc_loss, disc_vars)

    gene_opt.apply_gradients(zip(gen_grads, gen_vars))
    disc_opt.apply_gradients(zip(disc_grads, disc_vars))
    return adv_loss, l1_loss, disc_loss


In [None]:
# Training loop.
# NOTE: each iteration runs ONE optimization step on ONE batch, so EPOCHS is effectively
# the number of training steps, not the number of full passes over the training clips.
it = iter(tdgen)

for epoch in range(EPOCHS):
    try:
        x, y = next(it)
    except StopIteration:
        # The training generator ran out of batches: start a new pass over it.
        it = iter(tdgen)
        x, y = next(it)

    g_loss, l1_loss, d_loss = train_step(x, y)

    # Store plain Python floats in the history rather than live tensors.
    g_val, l1_val, d_val = float(g_loss), float(l1_loss), float(d_loss)
    history['gen_loss'].append(g_val)
    history['disc_loss'].append(d_val)
    history['l1_loss'].append(l1_val)

    print(f"{epoch}  Gloss:{g_val:.4f}  L1:{l1_val:.4f}  Dloss:{d_val:.4f}")


In [None]:
def plot_history(history):
    """Plot the recorded losses against the training-step index and show the figure.

    Args:
        history: mapping with per-step sequences under 'gen_loss', 'disc_loss'
            and 'l1_loss'.
    """
    loss_keys = ['gen_loss', 'disc_loss', 'l1_loss']

    # Upper panel of a 2-row layout: loss curves.
    plt.subplot(211)
    for key in loss_keys:
        plt.plot(history[key])
    plt.title('model loss')
    plt.ylabel('loss')
    plt.xlabel('batch iters')
    plt.legend(loss_keys, loc='upper left')

    # The lower panel (212) is left empty: 'real_accuracy' / 'fake_accuracy'
    # are not recorded by the training loop.
    plt.show()


In [None]:
plot_history(history)

In [None]:
def _save_frames_as_animation(frames, path, fps=5):
    """Write the images in `frames` to `path` as an animation using matplotlib's PillowWriter."""
    from matplotlib import animation

    anim_fig, anim_ax = plt.subplots()
    try:
        anim_ax.axis('off')
        artists = [[anim_ax.imshow(img, cmap='gray', animated=True)] for img in frames]
        anim = animation.ArtistAnimation(anim_fig, artists, interval=1000 / fps)
        anim.save(path, writer=animation.PillowWriter(fps=fps))
    finally:
        # Close the helper figure so it is not displayed alongside the frame strip.
        plt.close(anim_fig)


def arry5d_to_img(arry5d, save_as='', threshold=0.0):
    """Show every frame of the first clip in `arry5d` side by side, optionally saving them as an animation.

    Args:
        arry5d: array-like indexed as ``arry5d[0][frame]`` -> one frame image, with
            the frame count on axis 1 (e.g. a generator batch or a model prediction).
        save_as: if non-empty, path of an animation file to write the frames to.
            The writer is matplotlib's PillowWriter, so use an animated format
            Pillow supports, such as '.gif'.
        threshold: when > 0, ``ImgFrame.threshold(threshold=...)`` is applied to
            each frame before it is displayed.
    """
    frame_cnt = arry5d.shape[1]

    frames = []
    for idx in range(frame_cnt):
        frm = ImgFrame(img=arry5d[0][idx][:, :, :], do_norm=False)
        if threshold > 0.0:
            frm.threshold(threshold=threshold)
        frames.append(frm.to_image())

    # squeeze=False keeps a 2-D axes grid, so a single-frame clip still indexes correctly.
    fig, axes = plt.subplots(nrows=1, ncols=frame_cnt, figsize=(15, 3), squeeze=False)
    for idx, img in enumerate(frames):
        axes[0, idx].imshow(img, cmap='gray')

    # Previously `save_as` was accepted but silently ignored.
    if save_as:
        _save_frames_as_animation(frames, save_as)

    plt.show()

In [None]:
# Take one batch from the validation generator to use as prediction input below.
it = iter(vdgen)
x, y = next(it)

In [None]:
# Display the frames of the first input clip (x) in the batch.
arry5d_to_img(arry5d=x)

In [None]:
# Display the frames of the first label clip (y) in the batch.
arry5d_to_img(arry5d=y)

In [None]:
# Run the generator on the first clip of the batch and display the predicted frames.
file_name = os.path.join(cfg.TEMP_DATA_PATH, 'result.gif')

in_x = x[:1, :, :, :, :]   # keep the batch axis: a batch of one clip
pred = generator(in_x)

arry5d_to_img(pred, save_as=file_name)

In [None]:
# user가 그린 임의의 그림 예측.
# user_file_name = os.path.join(cfg.TEMP_DATA_PATH, 'user_draw.gif')
# user_draw = VideoClip(gif_path=user_file_name)
# user_draw.resize(img_w, img_h, inplace=True)
# arry5d = user_draw.to_array(expand=True)
# print(arry5d.shape)
# arry5d_to_img(arry5d)

In [None]:
# pred = generator(arry5d)

# file_name = os.path.join(cfg.TEMP_DATA_PATH, 'result.gif')
# arry5d_to_img(pred, save_as=file_name)