#### 참고코드 : https://github.com/wikibook/tf2

In [1]:
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from tqdm import tqdm
import tensorflow as tf

In [None]:
# Load ImageNet-V2 through TFDS. Without a `split` argument, tfds.load returns
# a dict keyed by split name; this notebook reads the 'test' split from it.
dataset = tfds.load('imagenet_v2')

In [None]:
# Materialize the test split into two parallel lists: image tensors and their
# integer class labels.
images, labels = [], []
for example in tqdm(dataset['test']):
    images.append(example['image'])
    labels.append(example['label'])

In [None]:
# Preview the first nine images in a 3x3 grid, each titled with its label value.
plt.figure(figsize=[12, 12])
for pos in range(1, 10):
    idx = pos - 1  # subplot positions are 1-based, list indices are 0-based
    plt.subplot(3, 3, pos)
    plt.imshow(images[idx])
    plt.title(f'index number :{labels[idx].numpy()}')

In [None]:
# Extract a random patch from the source image and return an (input, target) training pair.
def get_hr_and_lr(img, patch_size=50, scale=2):
    """Return a ``(x, y)`` pair: degraded network input and high-res target.

    A random ``patch_size`` x ``patch_size`` x 3 patch is cropped from ``img``
    as the target ``y``. The input ``x`` is produced by shrinking ``y`` to
    ``patch_size // scale`` per side and resizing it back to ``patch_size``.
    ``x`` and ``y`` therefore have the same shape, but ``x`` has lost detail.
    The defaults reproduce the original 50 -> 25 -> 50 pipeline.

    Args:
        img: image tensor; assumed (H, W, 3) with H, W >= patch_size
            (``tf.image.random_crop`` fails otherwise). TODO confirm for all
            inputs.
        patch_size: side length of the square patch to crop.
        scale: integer downsampling factor used to degrade the patch.

    Returns:
        Tuple ``(x, y)`` of float32 tensors, each (patch_size, patch_size, 3).

    Raises:
        ValueError: if ``patch_size // scale`` would be smaller than 1.
    """
    low_res = patch_size // scale
    if low_res < 1:
        raise ValueError(
            f"patch_size={patch_size} and scale={scale} give an empty low-res patch")

    # convert_image_dtype rescales integer images (e.g. uint8) into [0, 1] floats.
    img = tf.image.convert_image_dtype(img, tf.float32)
    y = tf.image.random_crop(img, [patch_size, patch_size, 3])
    # Down- then up-sample to simulate a low-resolution version of the same patch.
    x = tf.image.resize(y, [low_res, low_res])
    x = tf.image.resize(x, [patch_size, patch_size])
    return x, y

In [None]:
# Build the paired dataset, one random patch per image:
# X_data holds the degraded inputs, Y_data the matching high-resolution targets.
X_data, Y_data = [], []
for image in tqdm(images):
    lr_patch, hr_patch = get_hr_and_lr(image)
    X_data.append(lr_patch)
    Y_data.append(hr_patch)

In [None]:
# Show two (input, target) pairs, one per row: left = degraded network input,
# right = the high-resolution patch it should be restored to.
plt.figure(figsize=[12, 12])
for i in [1, 3]:
    plt.subplot(2, 2, i)
    plt.imshow(X_data[i])
    plt.title('input (low-res)')
    plt.subplot(2, 2, i + 1)
    # Index Y_data with the same i as X_data so each row shows a matching pair.
    plt.imshow(Y_data[i])
    plt.title('target (high-res)')

In [None]:
from sklearn.model_selection import train_test_split
import numpy as np

In [None]:
# Shuffle and split into 80% train / 10% validation / 10% test: first hold out
# 20% of the samples, then cut that hold-out set in half. Converting to NumPy
# arrays up front means every split comes back as an ndarray.
train_X, holdout_X, train_Y, holdout_Y = train_test_split(
    np.array(X_data), np.array(Y_data), train_size=0.8)
valid_X, test_X, valid_Y, test_Y = train_test_split(
    holdout_X, holdout_Y, train_size=0.5)

for arr in (train_X, valid_X, test_X, train_Y, valid_Y, test_Y):
    print(arr.shape)

In [None]:
# REDNet network definition using the tf.keras functional API.
def REDNet(num_layers):
    """Build a REDNet-style conv/deconv encoder-decoder with symmetric skips.

    Encoder: one 3-filter Conv2D followed by ``num_layers - 1`` 64-filter
    Conv2D layers (3x3 kernels, 'same' padding, ReLU). Decoder:
    ``num_layers - 1`` 64-filter Conv2DTranspose layers (ReLU), then a final
    3-filter Conv2DTranspose with no activation. Every layer uses 'same'
    padding and default stride, so spatial size is preserved end to end.

    Skip connections: encoder outputs at even loop steps are pushed on a
    stack. At odd decoder steps the most recent one is popped, added to the
    running tensor and passed through ReLU before the next transposed conv.
    When ``num_layers`` is even, one stored skip is never consumed.

    Args:
        num_layers: number of conv layers (and of transposed-conv layers);
            e.g. 15 gives the 30-layer "REDNet-30".

    Returns:
        An uncompiled ``tf.keras.Model`` mapping (None, None, 3) inputs to
        3-channel outputs of the same spatial size.
    """
    inputs = tf.keras.layers.Input(shape=(None, None, 3))

    # Layers are instantiated in the same order as the reference code
    # (first conv, then interleaved conv/deconv, then the output deconv).
    # NOTE(review): the first conv has only 3 filters, a narrow bottleneck;
    # kept as in the reference implementation.
    encoder = [tf.keras.layers.Conv2D(3, kernel_size=3, padding='same', activation='relu')]
    decoder = []
    for _ in range(num_layers - 1):
        encoder.append(tf.keras.layers.Conv2D(64, kernel_size=3, padding='same', activation='relu'))
        decoder.append(tf.keras.layers.Conv2DTranspose(64, kernel_size=3, padding='same', activation='relu'))
    decoder.append(tf.keras.layers.Conv2DTranspose(3, kernel_size=3, padding='same'))

    # Encoder pass: remember every other activation for the decoder skips.
    skips = []
    h = encoder[0](inputs)
    for step, conv in enumerate(encoder[1:]):
        h = conv(h)
        if step % 2 == 0:
            skips.append(h)

    # Decoder pass: merge skips in LIFO order (deepest encoder output first).
    for step, deconv in enumerate(decoder[:-1]):
        if step % 2 == 1:
            h = tf.keras.layers.Add()([h, skips.pop()])
            h = tf.keras.layers.Activation('relu')(h)
        h = deconv(h)

    outputs = decoder[-1](h)
    return tf.keras.Model(inputs=inputs, outputs=outputs)

In [None]:
# PSNR metric definition.
def psnr_metric(y_true, y_pred):
    """Peak signal-to-noise ratio between ``y_true`` and ``y_pred``.

    Uses ``max_val=1.0``, so it assumes both are float images scaled to
    [0, 1]. ``tf.image.psnr`` reduces over the last three axes, returning one
    value per image in the batch.
    """
    peak_value = 1.0
    return tf.image.psnr(y_true, y_pred, max_val=peak_value)

In [None]:
# Build and compile REDNet-30 (15 conv + 15 transposed-conv layers):
# Adam with lr 1e-4, pixel-wise MSE loss, PSNR reported as a metric.
model = REDNet(15)
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=1e-4),
    loss='mse',
    metrics=[psnr_metric],
)

In [None]:
# Train REDNet-30 on the (degraded input -> high-res target) patch pairs,
# reporting validation loss and PSNR at the end of each epoch.
history = model.fit(train_X, train_Y,
                    epochs=1000,
                    # Keras documents validation_data as a tuple (x_val, y_val);
                    # some tf.keras versions treat a non-tuple as inputs only.
                    validation_data=(valid_X, valid_Y))