In [None]:
import glob
import os
import sys

import cv2
from keras.preprocessing.image import load_img
from tensorflow.keras.utils import Sequence
import numpy as np
from utils import config

def get_img_for_model(img_filepath):
    """Load an image file resized for the model and scaled to the range [-1, 1].

    Args:
        img_filepath: path of the image to read.

    Returns:
        Float ndarray of the resized image. Each 0..255 pixel value x is mapped
        to (x - 127.5) / 127.5, so 0 -> -1.0 and 255 -> 1.0. Presumably shaped
        (IMG_HEIGHT, IMG_WIDTH, 3) given load_img's default RGB mode — confirm.
    """
    height, width = config.IMG_HEIGHT, config.IMG_WIDTH
    pixels = np.array(load_img(img_filepath, target_size=(height, width)))
    return (pixels - 127.5) / 127.5

class DataGenerator(Sequence):
    """Keras ``Sequence`` that yields paired (source, target) image batches.

    Every file in ``source_dir`` matching ``file_pattern`` is paired with the
    file of the same name in ``target_dir``. Each batch is the 4-tuple
    ``(imgs_source, imgs_target, discriminator_labels_real, discriminator_labels_fake)``.

    In training mode:
    - the file order is shuffled at construction and on every ``on_epoch_end``;
    - colour hints sampled from the target are painted onto the source;
    - each pair is flipped left-right with probability 0.5.

    Otherwise, files are served in sorted order with no augmentation.
    """

    def __init__(self, source_dir, target_dir, batch_size, is_training,
                 file_pattern='bottomwear_pants*.png'):
        """
        Args:
            source_dir: directory holding the model-input images.
            target_dir: directory holding the ground-truth images, using the same
                filenames as ``source_dir``.
            batch_size: maximum number of pairs per batch; the last batch may be smaller.
            is_training: enable shuffling, colour hints and flip augmentation.
            file_pattern: glob pattern, relative to ``source_dir``, selecting the
                images to use. The default is the previously hard-coded pattern.

        Raises:
            FileNotFoundError: if no file in ``source_dir`` matches ``file_pattern``.
        """
        self.source_dir = source_dir
        self.target_dir = target_dir
        self.batch_size = batch_size
        self.is_training = is_training
        self.file_pattern = file_pattern

        self.img_filenames = self._get_img_filenames(source_dir, file_pattern)
        if not self.img_filenames:
            # An empty dataset would turn every training epoch into a silent no-op.
            raise FileNotFoundError(
                f'No files matching {file_pattern!r} found in {source_dir!r}')
        if self.is_training:
            np.random.shuffle(self.img_filenames)
        else:
            # Deterministic order for evaluation/inference.
            self.img_filenames = sorted(self.img_filenames)

    def __getitem__(self, batch_num):
        """Return batch ``batch_num`` as (sources, targets, real_labels, fake_labels).

        Raises:
            IndexError: if ``batch_num`` is outside ``[0, len(self))``. Raising this
                also terminates Python's fallback ``__getitem__`` iteration protocol,
                which would otherwise yield empty batches forever.
        """
        if not 0 <= batch_num < len(self):
            raise IndexError(
                f'batch index {batch_num} out of range for {len(self)} batches')
        idx_start = batch_num * self.batch_size
        idx_end = min(idx_start + self.batch_size, len(self.img_filenames))
        img_filenames_batch = self.img_filenames[idx_start:idx_end]
        return self._get_batch(img_filenames_batch)

    def __len__(self):
        """Number of batches per epoch; the last batch may be partial."""
        return int(np.ceil(len(self.img_filenames) / self.batch_size))

    def on_epoch_end(self):
        """Reshuffle the file order between epochs (training mode only)."""
        if self.is_training:
            np.random.shuffle(self.img_filenames)

    def _draw_color_circles_on_src_img(self, img_src, img_target):
        """Paint colour hints onto ``img_src`` in place.

        Up to ``config.USER_COLOR_POINTS_PER_IMG`` non-white pixels of
        ``img_target`` are sampled. At each one, a patch of ``img_src`` is
        filled with the mean target colour over that patch.
        """
        for center_y, center_x in self._get_non_white_coordinates(img_target):
            self._draw_color_circle_on_src_img(img_src, img_target, center_y, center_x)

    def _draw_color_circle_on_src_img(self, img_src, img_target, center_y, center_x):
        """Fill one hint patch of ``img_src`` (in place) with the mean colour of ``img_target`` there.

        Despite the name, the patch is an axis-aligned square whose bounds come
        from ``_get_color_point_bbox_coords``.
        """
        y0, y1, x0, x1 = self._get_color_point_bbox_coords(center_y, center_x)
        color = np.mean(img_target[y0:y1, x0:x1], axis=(0, 1))
        img_src[y0:y1, x0:x1] = color

    def _get_batch(self, img_filenames_batch):
        """Load and stack the image pairs named in ``img_filenames_batch``.

        Returns:
            ``(img_sources, img_targets, discriminator_labels_real,
            discriminator_labels_fake)``. The image arrays have shape
            ``(len(img_filenames_batch),) + config.IMG_SHAPE``.
        """
        batch_size = len(img_filenames_batch)
        batch_shape = (batch_size,) + config.IMG_SHAPE
        img_sources = np.empty(batch_shape)
        img_targets = np.empty(batch_shape)

        for idx, img_filename in enumerate(img_filenames_batch):
            img_source, img_target = self._get_img_source_and_img_target(img_filename)
            img_sources[idx] = img_source
            img_targets[idx] = img_target

        discriminator_labels_real = self._get_discriminator_labels_real(batch_size)
        discriminator_labels_fake = self._get_discriminator_labels_fake(batch_size)

        return img_sources, img_targets, discriminator_labels_real, discriminator_labels_fake

    def _get_color_point_bbox_coords(self, center_y, center_x):
        """Return half-open bounds ``(y0, y1, x0, x1)`` of a hint square around a point.

        The square spans ``center - radius + 1`` to ``center + radius`` (exclusive)
        on each axis, i.e. ``2 * radius - 1`` pixels wide, clipped to the image.
        """
        radius = config.USER_COLOR_POINTS_RADIUS
        y0 = max(0, center_y - radius + 1)
        y1 = min(config.IMG_HEIGHT, center_y + radius)
        x0 = max(0, center_x - radius + 1)
        x1 = min(config.IMG_WIDTH, center_x + radius)

        return y0, y1, x0, x1

    def _get_discriminator_labels_fake(self, batch_size):
        """All-zero labels shaped ``(batch_size, IMG_PATCH_HEIGHT, IMG_PATCH_WIDTH, 1)``.

        The shape presumably matches the discriminator's patch output map — confirm.
        """
        return np.zeros((batch_size, config.IMG_PATCH_HEIGHT, config.IMG_PATCH_WIDTH, 1))

    def _get_discriminator_labels_real(self, batch_size):
        """All-one labels shaped ``(batch_size, IMG_PATCH_HEIGHT, IMG_PATCH_WIDTH, 1)``."""
        return np.ones((batch_size, config.IMG_PATCH_HEIGHT, config.IMG_PATCH_WIDTH, 1))

    def _get_img_filenames(self, directory, pattern='bottomwear_pants*.png'):
        """Return basenames of files in ``directory`` matching the glob ``pattern``.

        The directory is escaped so that glob metacharacters in the path (e.g. ``[``)
        are matched literally. The pattern was switched from .jpg to .png.
        """
        search_path = os.path.join(glob.escape(directory), pattern)
        return [os.path.basename(fp) for fp in glob.glob(search_path)]

    def _get_img_source_and_img_target(self, img_filename):
        """Load one (source, target) pair, applying training-time augmentation if enabled."""
        img_source = get_img_for_model(os.path.join(self.source_dir, img_filename))
        img_target = get_img_for_model(os.path.join(self.target_dir, img_filename))

        if self.is_training:
            self._draw_color_circles_on_src_img(img_source, img_target)
            # Data augmentation: flip source and target together so they stay aligned.
            if np.random.random_sample() > 0.5:
                img_source = np.fliplr(img_source)
                img_target = np.fliplr(img_target)

        return img_source, img_target

    def _get_non_white_coordinates(self, img):
        """Randomly sample up to ``config.USER_COLOR_POINTS_PER_IMG`` non-white pixel coordinates.

        A pixel counts as non-white when its channel sum is below 2.75. With
        [-1, 1] scaling, pure white sums to 3.0, so this means a mean channel
        value below about 244/255.

        Returns:
            An iterator of ``(y, x)`` pairs, sampled without replacement.
        """
        non_white_mask = np.sum(img, axis=-1) < 2.75
        non_white_y, non_white_x = np.nonzero(non_white_mask)

        n_non_white = len(non_white_y)
        n_color_points = min(n_non_white, config.USER_COLOR_POINTS_PER_IMG)
        idxs = np.random.choice(n_non_white, n_color_points, replace=False)
        return zip(non_white_y[idxs], non_white_x[idxs])

In [None]:
from tqdm import tqdm

def train(gen_model, d_model, gan_model, training_generator, validation_generator=None,
          epochs=70, initial_epoch=0, ck_pt_freq=5, output_dir='output', save_models=True):
    """Run the adversarial training loop for a generator / discriminator pair.

    For every batch:
    - the discriminator is updated once on real ``[source, target]`` pairs;
    - it is then updated once on ``[source, generated]`` pairs;
    - finally the combined GAN model is updated with targets
      ``[d_labels_real, imgs_target_real]``.

    Args:
        gen_model: generator; its ``predict_on_batch`` maps source images to
            generated targets.
        d_model: discriminator, trained on ``[source, target]`` pairs.
        gan_model: combined model used for the generator update.
        training_generator: iterable of ``(source, target, d_labels_real,
            d_labels_fake)`` batches that also exposes ``on_epoch_end()``.
        validation_generator: currently unused.
        epochs: one past the index of the last epoch to run.
        initial_epoch: index of the first epoch (for resuming).
        ck_pt_freq: currently unused; checkpoints are written every 3 epochs.
        output_dir: currently unused; checkpoints go to a fixed Drive folder.
        save_models: when False, no checkpoint weights are written.
    """
    # NOTE(review): ck_pt_freq / output_dir are intentionally still ignored to keep the
    # existing checkpoint cadence and location. TODO: wire them in and update callers.
    checkpoint_every = 3
    checkpoint_dir = '/content/drive/MyDrive/졸업프로젝트/model_save'

    for epoch_num in tqdm(range(initial_epoch, epochs)):
        for imgs_source, imgs_target_real, d_labels_real, d_labels_fake in training_generator:
            # predict_on_batch avoids predict()'s per-call setup overhead inside a tight loop.
            imgs_target_fake = gen_model.predict_on_batch(imgs_source)

            # Update the discriminator on real pairs, then on generated pairs.
            d_model.train_on_batch([imgs_source, imgs_target_real], d_labels_real)
            d_model.train_on_batch([imgs_source, imgs_target_fake], d_labels_fake)

            # Update the generator through the combined model, using the "real" labels.
            gan_model.train_on_batch(imgs_source, [d_labels_real, imgs_target_real])

        if save_models and epoch_num % checkpoint_every == 0:
            # Use the d_model argument (not a notebook global) for the discriminator.
            for prefix, model in (('gen', gen_model), ('dis', d_model), ('gan', gan_model)):
                model.save_weights(f'{checkpoint_dir}/{prefix}_bottom_{epoch_num}.h5')

        training_generator.on_epoch_end()

In [None]:
# Source (model input) and target (ground truth) image folders; files are paired by name.
source_dir = '/content/drive/MyDrive/졸업프로젝트/bottom_edge'
target_dir = '/content/drive/MyDrive/졸업프로젝트/bottom_data'

# Training generator: batches of 4, shuffled, with colour hints and flip augmentation.
training_generator = DataGenerator(
    source_dir=source_dir,
    target_dir=target_dir,
    batch_size=4,
    is_training=True,
)

In [None]:
# NOTE(review): get_generator_model / get_discriminator_model / get_gan_model are not
# defined in the visible cells; they must come from a cell or module not shown here.
# TODO: import them explicitly so "Restart & Run All" works on a fresh kernel.
gen_model, dis_model = get_generator_model(), get_discriminator_model()
# Combined model built from both networks; train() uses it for generator updates.
gan_model = get_gan_model(gen_model, dis_model)

In [None]:
# Train from scratch (initial_epoch defaults to 0) for 10 epochs.
train(
    gen_model=gen_model,
    d_model=dis_model,
    gan_model=gan_model,
    training_generator=training_generator,
    epochs=10,
)

In [None]:
# Persist the trained models with Model.save (train() checkpoints use save_weights instead).
for model_to_save, save_path in (
    (gen_model, '/content/drive/MyDrive/졸업프로젝트/model_save/gen_model_bottomwear.h5'),
    (dis_model, '/content/drive/MyDrive/졸업프로젝트/model_save/dis_model_bottomwear.h5'),
    (gan_model, '/content/drive/MyDrive/졸업프로젝트/model_save/gan_model_bottomwear.h5'),
):
    model_to_save.save(save_path)

In [None]:
def get_img_for_model(img_filepath):
    """Read ``img_filepath``, resize it to the configured model input size, and map pixels to [-1, 1].

    NOTE(review): this cell redefines a helper that has the same name and behavior
    as the one in the first cell, silently shadowing it. TODO: keep only one definition.
    """
    target_size = (config.IMG_HEIGHT, config.IMG_WIDTH)
    raw = np.array(load_img(img_filepath, target_size=target_size))
    centered = raw - 127.5  # 127.5 is the midpoint of the 0..255 pixel range
    return centered / 127.5

In [None]:
import matplotlib.pyplot as plt

# NOTE(review): this cell previously called `gen_model_bottom`, which is never defined in
# the visible cells (hidden state from another session). It now uses the generator trained
# above; substitute a model loaded from the saved .h5 when running this cell standalone.
img_source = get_img_for_model('/content/drive/MyDrive/test_bottom.png')
imgs_source = img_source[np.newaxis]  # add a leading batch axis of size 1
imgs_target_fake = gen_model.predict(imgs_source)
# Undo the [-1, 1] input scaling (presumably the generator's output range too). Clip so
# out-of-range values don't wrap around when cast to uint8 below.
imgs_target_fake = np.clip((imgs_target_fake + 1) / 2.0, 0.0, 1.0)
img_target_fake = imgs_target_fake[0]

fig, ax = plt.subplots(figsize=(5, 5))
ax.axis('off')
ax.imshow((img_target_fake * 255).astype(np.uint8))
plt.show()