### IU Sketch Baseline 코드  
 - Python 모듈 설치 코드는 처음 한 번만 실행해 주세요.


In [None]:
# !pip install imageio
# !pip install imageio --upgrade
# !pip install einops

In [None]:
# Module search path setup: insert the parent of the current working
# directory at position 1 of sys.path (presumably so the project-local
# packages imported below can be resolved -- confirm against the repo layout).
import os
import sys

sys.path.insert(
    1,
    os.path.join(os.getcwd(), os.pardir),
)

베이스 라인 코드.  

In [None]:
import os, glob, random
import numpy as np
import matplotlib.pyplot as plt

import tensorflow.keras as keras
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.layers import ThresholdedReLU


import config_env as cfg

from classes.image_frame import ImgFrame
from classes.video_clip import VideoClip
from models.dataset_generator import DataSetGenerator
from models.layer_conv import Conv2Plus1D, TConv2Plus1D
from models.layer_encoder import Encoder5D, Decoder5D
from models.layer_lstm import ConvLstmSeries

In [None]:
# Create the directories this notebook needs if they do not exist yet:
#   cfg.RAW_CLIP_PATH   - raw training clips (gif files)
#   cfg.MODEL_SAVE_PATH - saved models
#   cfg.TEMP_DATA_PATH  - temporary data
# os.makedirs(..., exist_ok=True) replaces the exists()-then-mkdir() sequence:
# it also creates missing parent directories and does not fail when the
# directory appears between the check and the creation.
for required_dir in (cfg.RAW_CLIP_PATH, cfg.MODEL_SAVE_PATH, cfg.TEMP_DATA_PATH):
    os.makedirs(required_dir, exist_ok=True)

In [None]:
# Frame size and batch size are hard-coded here; the cfg values they replace
# are noted next to each one.
img_w = 128  # cfg.DATA_IMG_W
img_h = 128  # cfg.DATA_IMG_H
batch_size = 4  # cfg.DATA_BATCH_SIZE
time_steps = cfg.DATA_TIME_STEP

# Set to True to use the encoder-decoder model.
is_autoenc_model = False

# Collect every raw clip (gif) and shuffle the list in place.
raw_clip_pattern = os.path.join(cfg.RAW_CLIP_PATH, "*.gif")
img_list = glob.glob(raw_clip_pattern)
random.shuffle(img_list)

# Split the shuffled list 9:1 into training and validation parts.
train_val_ratio = 0.9
train_img_cnt = int(len(img_list) * train_val_ratio)
train_img_list, val_img_list = img_list[:train_img_cnt], img_list[train_img_cnt:]

# Both generators share every option except the image list.
# NOTE(review): the training generator is also built with is_train=False --
# confirm against DataSetGenerator that this is intended.
generator_options = dict(
    batch_size=batch_size,
    time_step=time_steps,
    imgw=img_w,
    imgh=img_h,
    is_train=False,
    for_enc=is_autoenc_model,
)
tdgen = DataSetGenerator(imgs=train_img_list, **generator_options)
vdgen = DataSetGenerator(imgs=val_img_list, **generator_options)


# Build the encoder - lstms - decoder - retina(0) model.
enc_in_filters = 128
enc_conv_count = 3
lstm_count = 3  # NOTE(review): not referenced anywhere in this cell
dec_conv_count = 3
retina_conv_count = 3

# Derived filter counts; with the values above these are 512, 256 and 32.
enc_out_filters = enc_in_filters * 2 ** (enc_conv_count - 1)
dec_in_filters = enc_out_filters // 2
retina_in_filters = dec_in_filters // 2 ** dec_conv_count

# The same (1, 3, 3) tuple is passed to every Encoder5D/Decoder5D below.
conv_kernel = (1, 3, 3)

encoder = Encoder5D(enc_conv_count, enc_in_filters, conv_kernel, 2, "same")
decoder = Decoder5D(dec_conv_count, dec_in_filters, conv_kernel, 2, "same")
lstms = ConvLstmSeries(enc_out_filters, 0, [(3, 3), (3, 3), (3, 3)])
retina = Encoder5D(retina_conv_count, retina_in_filters, conv_kernel, 1, "same", out_channel=1)
retina0 = Encoder5D(0, retina_in_filters, conv_kernel, 1, "same", out_channel=1)
# Optional output thresholding, currently disabled:
# threshold_relu = ThresholdedReLU(theta=0.5)

inputs = layers.Input(shape=(None, img_w, img_h, 1))

# Wire the graph: the auto-encoder variant skips the LSTM stage and ends in
# retina0; the full variant runs encoder -> lstms -> decoder -> retina.
x = encoder(inputs)
if not is_autoenc_model:
    x = lstms(x)
x = decoder(x)
head = retina0 if is_autoenc_model else retina
x = head(x)

# x = threshold_relu(x)
outputs = x

model = Model(inputs=inputs, outputs=outputs, name='sketcher')
model.compile(
    optimizer=keras.optimizers.Adam(1e-4),
    loss='binary_crossentropy',
)
model.summary()

# Last expression of the cell, so the notebook displays the rendered diagram.
plot_model(model, show_shapes=True, expand_nested=False, show_dtype=False)

In [None]:
# plot_model(encoder.seq, show_shapes=True, expand_nested=True, show_dtype=False)

In [None]:
# plot_model(lstms.seq, show_shapes=True, expand_nested=True, show_dtype=False)

In [None]:
# plot_model(decoder.seq, show_shapes=True, expand_nested=True, show_dtype=False)

In [None]:
# plot_model(retina.seq, show_shapes=True, expand_nested=True, show_dtype=False)

In [None]:
def save_all(prefix='base_200'):
    """Save the full model and its sub-networks under cfg.MODEL_SAVE_PATH.

    The full model is written to ``sketcher_<prefix>`` and the ``seq``
    attribute of the encoder, lstms, decoder and retina wrappers to
    ``enc_<prefix>``, ``lstm_<prefix>``, ``dec_<prefix>`` and
    ``retina_<prefix>`` respectively.

    Args:
        prefix: Suffix appended to every saved model name.
    """
    # NOTE(review): only `retina` is saved here; `retina0` (used by the
    # auto-encoder variant of the model) is never written -- confirm intent.
    save_dir = cfg.MODEL_SAVE_PATH
    model.save(os.path.join(save_dir, f"sketcher_{prefix}"))
    parts = (("enc", encoder), ("lstm", lstms), ("dec", decoder), ("retina", retina))
    for tag, part in parts:
        part.seq.save(os.path.join(save_dir, f"{tag}_{prefix}"))

In [None]:
def load_all(prefix='base_200'):
    """Reload the sub-networks saved under cfg.MODEL_SAVE_PATH.

    Replaces the ``seq`` attribute of the encoder, lstms, decoder and retina
    wrappers with the models loaded from ``enc_<prefix>``, ``lstm_<prefix>``,
    ``dec_<prefix>`` and ``retina_<prefix>``.

    Args:
        prefix: Suffix of the saved model names to load.
    """
    parts = (("enc", encoder), ("lstm", lstms), ("dec", decoder), ("retina", retina))
    for tag, part in parts:
        saved_path = os.path.join(cfg.MODEL_SAVE_PATH, f"{tag}_{prefix}")
        part.seq = keras.models.load_model(saved_path)
    # The full model is intentionally not reloaded here (disabled):
    #     model = keras.models.load_model(os.path.join(cfg.MODEL_SAVE_PATH, f"sketcher_{prefix}"))

In [None]:
# load_all(prefix='base_300')

In [None]:
# encoder.trainable = False
# decoder.trainable = False
# lstms.trainable = False
# retina.trainable = False

In [None]:
# Training schedule: number of epochs, training batches drawn per epoch, and
# validation batches drawn at the end of each epoch.
epoch_cnt = 400
steps_per_epoch = 50
val_steps = 2

# `batch_size` is deliberately not passed to fit(): tdgen/vdgen were already
# constructed with batch_size and are consumed as batch-yielding inputs, and
# the Keras documentation says not to specify batch_size for generator or
# Sequence inputs (it is ignored or rejected depending on the version).
history = model.fit(
    tdgen,
    validation_data=vdgen,
    steps_per_epoch=steps_per_epoch,
    validation_steps=val_steps,
    epochs=epoch_cnt,
    verbose=1,
)

In [None]:
def plot_history(history):
    """Plot the training and validation loss curves of a fit history.

    Args:
        history: Object whose ``history`` mapping holds the ``'loss'`` and
            ``'val_loss'`` series (as returned by ``model.fit``).
    """
    curves = (('loss', 'train'), ('val_loss', 'val'))
    for key, label in curves:
        plt.plot(history.history[key], label=label)
    plt.legend()
    plt.show()

In [None]:
# Save the full model and its sub-networks with the 'base_400' suffix.
save_all(prefix='base_400')

In [None]:
# Plot the train/validation loss curves recorded by model.fit.
plot_history(history)

In [None]:
def arry5d_to_img(arry5d, save_as=''):
    """Display the last frame of the first clip in ``arry5d``.

    Args:
        arry5d: Indexable array; element ``[0][-1]`` is taken and sliced on
            three axes (presumably laid out as batch, time, width, height,
            channel -- confirm against the data generator).
        save_as: Passed to ``ImgFrame.to_image`` as ``save_file``
            (presumably an empty string means no file is written -- confirm).
    """
    first_clip = arry5d[0]
    last_frame = first_clip[-1]
    frm = ImgFrame(img=last_frame[:, :, :], do_norm=False)

    # Show the result. Rescaling is currently disabled:
    # frm.arry = frm.arry * 255
    img = frm.to_image(save_file=save_as)
    plt.imshow(img, cmap='gray')

In [None]:
# Pull a single batch from the validation generator to use as prediction input.
it = iter(vdgen)
x, y = next(it)
# Keep only the first sample of the batch; slicing with :1 keeps the leading axis.
in_x = x[:1, :, :, :, :]
print(x.shape, y.shape, in_x.shape)

In [None]:
# Display one image from x (last frame of the selected input sample).
arry5d_to_img(in_x)

In [None]:
# Display one image from y (last frame of the first sample in the batch).
arry5d_to_img(y)

In [None]:
# Run the model on the selected sample and display the predicted image.
pred = model.predict(in_x)

# The prediction is also passed a save path under the temporary data directory.
file_name = os.path.join(cfg.TEMP_DATA_PATH, 'result.gif')
arry5d_to_img(pred, save_as=file_name)

In [None]:
# Predict from an arbitrary drawing made by the user.
user_file_name = os.path.join(cfg.TEMP_DATA_PATH, 'user_draw.gif')
user_draw = VideoClip(gif_path=user_file_name)
# Resize to the size the model input was built with (img_w, img_h) instead of
# a hard-coded 128x128, so the two cannot drift apart when the size changes.
# NOTE(review): assumes VideoClip.resize takes (width, height) -- confirm.
user_draw.resize(img_w, img_h, inplace=True)
arry5d = user_draw.to_array(expand=True)
print(arry5d.shape)

pred = model.predict(arry5d)

# Uses the same 'result.gif' path as the validation-sample prediction cell.
file_name = os.path.join(cfg.TEMP_DATA_PATH, 'result.gif')
arry5d_to_img(pred, save_as=file_name)