In [None]:
import tensorflow as tf

In [None]:
import os
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output

# Sentinel that lets tf.data pick parallelism / buffer sizes dynamically at runtime.
AUTOTUNE = tf.data.AUTOTUNE

In [None]:
# Remember the kernel's current working directory; the bare name on the last line
# displays it as the cell output.
wr_dir = os.getcwd()
wr_dir

In [None]:
# NOTE(review): wr_dir was read from os.getcwd() in the previous cell, so this chdir is a
# no-op unless wr_dir is first edited to point at a different (project) directory.
os.chdir(wr_dir)

In [None]:
import tensorflow as tf
import glob
import matplotlib.pyplot as plt

# Collect local image file paths.
# Root folder that contains after/{train,test} and before/{train,test} subfolders of PNGs.
DATA_ROOT = 'C:/Users/user/Desktop/drug_proj/cycle'


def _list_pngs(domain, split):
    """Return the PNG paths in DATA_ROOT/<domain>/<split>, sorted for a deterministic order.

    Raises:
        FileNotFoundError: if no PNG file matches. Without this check an empty list would
            only surface later as a less obvious error inside the tf.data map.
    """
    pattern = f'{DATA_ROOT}/{domain}/{split}/*.png'
    # glob's result order depends on the filesystem; sorting makes runs reproducible.
    paths = sorted(glob.glob(pattern))
    if not paths:
        raise FileNotFoundError(f'No PNG files matched {pattern}')
    return paths


train_after_paths = _list_pngs('after', 'train')
train_before_paths = _list_pngs('before', 'train')
test_after_paths = _list_pngs('after', 'test')
test_before_paths = _list_pngs('before', 'test')

# Build tf.data datasets of file paths (one string element per image file).
train_after_ds = tf.data.Dataset.from_tensor_slices(train_after_paths)
train_before_ds = tf.data.Dataset.from_tensor_slices(train_before_paths)
test_after_ds = tf.data.Dataset.from_tensor_slices(test_after_paths)
test_before_ds = tf.data.Dataset.from_tensor_slices(test_before_paths)

# Image loading helper used by the preprocessing functions below.
def load_image(image_path):
    """Read the PNG file at `image_path` and decode it into a 3-channel (RGB) image tensor.

    The dtype is decode_png's default (uint8); channels=3 forces RGB even for
    grayscale or RGBA inputs.
    """
    png_bytes = tf.io.read_file(image_path)
    return tf.image.decode_png(png_bytes, channels=3)

# Image normalization helper.
def normalize(image):
    """Cast to float32 and linearly map pixel values from [0, 255] to [-1, 1]."""
    as_float = tf.cast(image, tf.float32)
    return as_float / 127.5 - 1

def random_jitter(image):
    """Random resize-crop-flip augmentation for a single (H, W, 3) image.

    Steps: resize to 286x286 with nearest-neighbour sampling, take a random 256x256
    crop, and randomly mirror it left/right. The result is returned with a leading
    axis of size 1, i.e. shape (1, 256, 256, 3). Downstream code depends on that extra
    axis and squeezes it away again after batching.
    """
    enlarged = tf.image.resize(image, [286, 286], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    cropped = tf.image.random_crop(enlarged, size=[256, 256, 3])
    mirrored = tf.image.random_flip_left_right(cropped)
    # Re-introduce the leading size-1 axis expected by callers.
    return mirrored[tf.newaxis, ...]




# Settings for the preprocessing pipelines below.
# tf.data.AUTOTUNE is the stable name (the first cell uses it too);
# tf.data.experimental.AUTOTUNE is a deprecated alias of the same value.
AUTOTUNE = tf.data.AUTOTUNE
BUFFER_SIZE = 1000  # shuffle buffer size, in dataset elements
BATCH_SIZE = 1

def preprocess_image_train(image_path):
    """Training preprocessing: load PNG -> random jitter -> scale to [-1, 1].

    Returns a float32 tensor of shape (1, 256, 256, 3); the leading size-1 axis
    comes from random_jitter.
    """
    return normalize(random_jitter(load_image(image_path)))

def preprocess_image_test(image_path):
    """Deterministic preprocessing for evaluation images.

    Unlike the training path, no random crop or flip is applied, so a given test image
    always produces the same tensor and is shown in full rather than randomly cropped.
    The image is resized straight to 256x256, as the single-image inference cell also does.
    Nearest-neighbour sampling matches the training resize. A leading size-1 axis is added
    so the output shape (1, 256, 256, 3) matches preprocess_image_train; downstream code
    squeezes axis 1 after batching.
    """
    image = load_image(image_path)
    image = tf.image.resize(image, [256, 256], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    image = tf.expand_dims(image, 0)
    image = normalize(image)
    return image

# Training pipelines. There is deliberately no .cache() after the map. Caching the output of
# preprocess_image_train would store the first epoch's random crop/flip and replay exactly
# those images in every later epoch, defeating the augmentation. Files are therefore
# re-read and re-augmented each epoch; prefetch overlaps that work with training.
# NOTE(review): if file I/O becomes a bottleneck, cache the decoded images *before* the
# random jitter (memory permitting) instead.
train_after = (train_after_ds
               .map(preprocess_image_train, num_parallel_calls=AUTOTUNE)
               .shuffle(BUFFER_SIZE)
               .batch(BATCH_SIZE)
               .prefetch(AUTOTUNE))
train_before = (train_before_ds
                .map(preprocess_image_train, num_parallel_calls=AUTOTUNE)
                .shuffle(BUFFER_SIZE)
                .batch(BATCH_SIZE)
                .prefetch(AUTOTUNE))

# Test pipelines: caching after the map is only appropriate while preprocess_image_test
# is deterministic.
test_after = test_after_ds.map(preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_before = test_before_ds.map(preprocess_image_test, num_parallel_calls=AUTOTUNE).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# Visualise one training sample per domain: first as produced by the pipeline, then after
# one more random_jitter pass. Pixel values are mapped from [-1, 1] back to [0, 1].
plt.figure(figsize=(8, 8))

sample_after = next(iter(train_after))
sample_before = next(iter(train_before))

panels = (
    (221, 'After', sample_after, False),
    (222, 'After with random jitter', sample_after, True),
    (223, 'Before', sample_before, False),
    (224, 'Before with random jitter', sample_before, True),
)
for position, panel_title, batch, jitter_again in panels:
    plt.subplot(position)
    plt.title(panel_title)
    first_image = batch[0][0]
    if jitter_again:
        first_image = tf.squeeze(random_jitter(first_image))
    plt.imshow(first_image * 0.5 + 0.5)

plt.show()


In [None]:
import tensorflow_addons as tfa
from tensorflow_examples.models.pix2pix import pix2pix


In [None]:
OUTPUT_CHANNELS = 3

# Two U-Net generators with identical architecture but separate weights:
# generator_g maps "before" -> "after", generator_f maps "after" -> "before".
generator_g, generator_f = (
    pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm') for _ in range(2)
)

# Two discriminators: discriminator_x scores "before" images, discriminator_y scores
# "after" images. target=False presumably means the discriminator takes only the image,
# with no paired target input (verify against the pix2pix module).
discriminator_x, discriminator_y = (
    pix2pix.discriminator(norm_type='instancenorm', target=False) for _ in range(2)
)

In [None]:
# Run the (still untrained) generators once on the samples; squeeze removes the extra
# size-1 axis that preprocessing leaves at position 1.
to_after = generator_g(tf.squeeze(sample_before, [1]))
to_before = generator_f(tf.squeeze(sample_after, [1]))

plt.figure(figsize=(8, 8))
# Amplification applied to the generated images only, presumably because untrained
# outputs are too faint to see otherwise.
contrast = 8

imgs = [sample_before, to_after, sample_after, to_before]
title = ['before', 'to_after', 'after', 'to_before']

for i, (img, name) in enumerate(zip(imgs, title)):
  plt.subplot(2, 2, i + 1)
  plt.title(name)
  # Even slots are real samples, odd slots are generated translations.
  scale = 0.5 if i % 2 == 0 else 0.5 * contrast
  plt.imshow(tf.squeeze(img[0]) * scale + 0.5)
plt.show()



In [None]:
plt.figure(figsize=(8, 8))

# Show the last output channel of each discriminator for one real sample of its domain.
for position, heading, disc, sample in (
    (121, 'Is a real after?', discriminator_y, sample_after),
    (122, 'Is a real before?', discriminator_x, sample_before),
):
  plt.subplot(position)
  plt.title(heading)
  disc_output = disc(tf.squeeze(sample, [1]))
  plt.imshow(disc_output[0, ..., -1], cmap='RdBu_r')

plt.show()


In [None]:
# Weight of the cycle-consistency loss; identity_loss uses LAMBDA * 0.5.
# NOTE(review): the TensorFlow CycleGAN tutorial uses 10. The value 90 looks deliberate
# (cf. the "train_90_600" checkpoint directory) — confirm.
LAMBDA = 90

In [None]:
# Binary cross-entropy shared by the adversarial losses. from_logits=True means the
# discriminator outputs are treated as unnormalised logits (sigmoid applied inside the loss).
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

In [None]:
def discriminator_loss(real, generated):
  """Discriminator loss: mean of the BCE on real outputs (target 1) and generated outputs (target 0).

  Args:
    real: discriminator logits for real images.
    generated: discriminator logits for generated images.
  """
  loss_on_real = loss_obj(tf.ones_like(real), real)
  loss_on_fake = loss_obj(tf.zeros_like(generated), generated)
  # Halve the sum so the discriminator learns more slowly relative to the generator.
  return (loss_on_real + loss_on_fake) * 0.5

In [None]:
def generator_loss(generated):
  """Adversarial generator loss: BCE pushing the discriminator's logits on fakes toward 'real' (1)."""
  pretend_real = tf.ones_like(generated)
  return loss_obj(pretend_real, generated)

In [None]:
def calc_cycle_loss(real_image, cycled_image):
  """Cycle-consistency loss: LAMBDA times the mean absolute (L1) error between an image and its round trip."""
  mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
  return LAMBDA * mean_abs_error

In [None]:
def identity_loss(real_image, same_image):
  """Identity loss: half-LAMBDA-weighted L1 between a real image and the generator's output on it.

  Encourages a generator to leave images that are already in its target domain unchanged.
  """
  l1_distance = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * l1_distance

In [None]:
# Learning-rate schedule: constant 2e-4 for the first LR_CONSTANT_EPOCHS epochs, then 0.
# The boundary is measured in optimizer steps, so it must be scaled by the real number of
# steps per epoch. The training loop iterates zip(train_before, train_after), which stops at
# the shorter dataset, so one epoch = min(#batches in train_before, #batches in train_after).
# cardinality() is negative when unknown/infinite; fall back to the old hard-coded 100 then.
LR_CONSTANT_EPOCHS = 100
_steps_per_epoch = min(int(train_before.cardinality()), int(train_after.cardinality()))
num_steps_per_epoch = _steps_per_epoch if _steps_per_epoch > 0 else 100
lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [float(LR_CONSTANT_EPOCHS * num_steps_per_epoch)], [0.0002, 0.])
# NOTE(review): with EPOCHS = 600 the learning rate is 0 for every epoch after
# LR_CONSTANT_EPOCHS, so those epochs update nothing. Consider a linear decay to 0
# over the remaining epochs, or fewer EPOCHS.



In [None]:
# Optimizers: one Adam instance per network, all driven by the same learning-rate schedule.
(generator_g_optimizer,
 generator_f_optimizer,
 discriminator_x_optimizer,
 discriminator_y_optimizer) = [tf.keras.optimizers.Adam(learning_rate=lr_schedule) for _ in range(4)]

In [None]:
checkpoint_path = "./checkpoints/train_90_600"

# Track everything needed to resume training: the four networks and their optimizers.
ckpt = tf.train.Checkpoint(**dict(
    generator_g=generator_g,
    generator_f=generator_f,
    discriminator_x=discriminator_x,
    discriminator_y=discriminator_y,
    generator_g_optimizer=generator_g_optimizer,
    generator_f_optimizer=generator_f_optimizer,
    discriminator_x_optimizer=discriminator_x_optimizer,
    discriminator_y_optimizer=discriminator_y_optimizer,
))

# Keep only the five most recent checkpoints on disk.
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint in checkpoint_path, if there is one.
latest = ckpt_manager.latest_checkpoint
if latest:
  ckpt.restore(latest)
  print ('Latest checkpoint restored!!')

In [None]:
# Number of passes over the zipped (before, after) training data in the training loop.
EPOCHS = 600

In [None]:
def generate_images(model, test_input):
  """Plot the first image of `test_input` next to `model`'s translation of it.

  `test_input` carries an extra size-1 axis at position 1 (left by preprocessing);
  it is squeezed away before the model call.
  """
  batch = tf.squeeze(test_input, [1])  # drop the per-image extra axis
  predicted = model(batch)

  plt.figure(figsize=(12, 12))

  panels = zip(['Input Image', 'Predicted Image'], [batch[0], predicted[0]])
  for position, (caption, picture) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.title(caption)
    # Undo the [-1, 1] scaling from normalize() so imshow gets values in [0, 1].
    plt.imshow(picture * 0.5 + 0.5)
    plt.axis('off')
  plt.show()

In [None]:
@tf.function
def train_step(real_x, real_y):
    """Run one CycleGAN optimisation step on one batch from each domain.

    X is the "before" domain and Y the "after" domain: the training loop passes batches
    from train_before as real_x and from train_after as real_y. Each of the four networks
    (G: X -> Y, F: Y -> X, discriminator_x, discriminator_y) receives one gradient update
    from its own optimizer.

    Args:
        real_x: batch of "before" images with an extra size-1 axis at position 1 (left by
            the preprocessing); it is squeezed away below.
        real_y: batch of "after" images, same layout as real_x.
    """
    real_x = tf.squeeze(real_x, axis=[1])  # Squeeze to remove the extra dimension
    real_y = tf.squeeze(real_y, axis=[1])  # Squeeze to remove the extra dimension

    # persistent=True because four separate gradients are taken from this tape below.
    with tf.GradientTape(persistent=True) as tape:
        # Generator G translates X -> Y
        # Generator F translates Y -> X.

        # Forward cycles: X -> G -> F -> X and Y -> F -> G -> Y.
        fake_y = generator_g(real_x, training=True)
        cycled_x = generator_f(fake_y, training=True)

        fake_x = generator_f(real_y, training=True)
        cycled_y = generator_g(fake_x, training=True)

        # same_x and same_y are used for identity loss.
        # (each generator applied to an image already in its output domain)
        same_x = generator_f(real_x, training=True)
        same_y = generator_g(real_y, training=True)

        disc_real_x = discriminator_x(real_x, training=True)
        disc_real_y = discriminator_y(real_y, training=True)

        disc_fake_x = discriminator_x(fake_x, training=True)
        disc_fake_y = discriminator_y(fake_y, training=True)

        # calculate the loss
        gen_g_loss = generator_loss(disc_fake_y)
        gen_f_loss = generator_loss(disc_fake_x)

        # Both generators share the same cycle-consistency term (both cycles summed).
        total_cycle_loss = calc_cycle_loss(real_x, cycled_x) + calc_cycle_loss(real_y, cycled_y)

        # Total generator loss = adversarial loss + cycle loss + identity loss
        total_gen_g_loss = gen_g_loss + total_cycle_loss + identity_loss(real_y, same_y)
        total_gen_f_loss = gen_f_loss + total_cycle_loss + identity_loss(real_x, same_x)

        disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)
        disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)

    # Calculate the gradients for generator and discriminator
    generator_g_gradients = tape.gradient(total_gen_g_loss, generator_g.trainable_variables)
    generator_f_gradients = tape.gradient(total_gen_f_loss, generator_f.trainable_variables)

    discriminator_x_gradients = tape.gradient(disc_x_loss, discriminator_x.trainable_variables)
    discriminator_y_gradients = tape.gradient(disc_y_loss, discriminator_y.trainable_variables)

    # Apply the gradients to the optimizer
    generator_g_optimizer.apply_gradients(zip(generator_g_gradients, generator_g.trainable_variables))
    generator_f_optimizer.apply_gradients(zip(generator_f_gradients, generator_f.trainable_variables))
    discriminator_x_optimizer.apply_gradients(zip(discriminator_x_gradients, discriminator_x.trainable_variables))
    discriminator_y_optimizer.apply_gradients(zip(discriminator_y_gradients, discriminator_y.trainable_variables))

In [None]:
for epoch in range(EPOCHS):
  start = time.time()

  # zip stops at the shorter of the two training datasets.
  paired_batches = tf.data.Dataset.zip((train_before, train_after))
  for n, (image_x, image_y) in enumerate(paired_batches):
    train_step(image_x, image_y)
    # One progress dot every 10 steps.
    if not n % 10:
      print ('.', end='')

  clear_output(wait=True)
  # Translate the same fixed sample every epoch so the model's progress is clearly visible.
  generate_images(generator_g, sample_before)

  if (epoch + 1) % 5 == 0:
    ckpt_save_path = ckpt_manager.save()
    print ('Saving checkpoint for epoch {} at {}'.format(epoch+1, ckpt_save_path))

  elapsed = time.time() - start
  print ('Time taken for epoch {} is {} sec\n'.format(epoch + 1, elapsed))

In [None]:
# Translate a few "before" test batches with the trained generator G and plot them.
test_samples = test_before.take(5)
for test_batch in test_samples:
  generate_images(generator_g, test_batch)

In [None]:

new_image_path = "C:/Users/user/Desktop/drug_proj/preprocessed_img1.png"

def load_and_resize_image(image_path, size=(256, 256)):
    """Read a PNG, decode it as RGB and resize it to `size` (default 256x256).

    This deliberately uses a new name. Redefining `load_image` here would silently
    replace the loader used by the tf.data training/test pipelines whenever their cell
    is re-run after this one.
    """
    image = tf.io.read_file(image_path)
    image = tf.image.decode_png(image, channels=3)  # decode as 3-channel PNG
    # tf.image.resize defaults to bilinear sampling and returns float32.
    image = tf.image.resize(image, list(size))
    return image

# Load and preprocess the image. No random jitter is applied at inference time.
new_image = load_and_resize_image(new_image_path)
new_image = normalize(new_image)

# Add a leading batch axis, since the model is called on batched input.
new_image = tf.expand_dims(new_image, 0)

# Translate the image with generator G ("before" -> "after").
output = generator_g(new_image)


In [None]:
# Take the first image of the output batch and map it from [-1, 1] back to [0, 1].
output_image = 0.5 * output[0] + 0.5

# Display it.
plt.imshow(output_image)
plt.show()

In [None]:
OUTPUT_CHANNELS = 3
# Rebuild a fresh generator G and load only its weights from a saved training checkpoint.
generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')

checkpoint_path = 'C:/Users/user/Desktop/drug_proj/checkpoints/train_90_600/ckpt-120'
checkpoint = tf.train.Checkpoint(generator_g=generator_g)
# The training checkpoint also holds generator_f, both discriminators and all optimizers.
# expect_partial() silences the "unresolved object" warnings for those.
# assert_existing_objects_matched() fails loudly if generator_g's own variables were not
# restored, e.g. because the checkpoint is the wrong one.
checkpoint.restore(checkpoint_path).expect_partial().assert_existing_objects_matched()

# NOTE(review): a model containing the custom pix2pix InstanceNormalization layer may need
# custom_objects when loaded from HDF5 (or may not serialize to .h5 at all) — verify.
generator_g.save('C:/Users/user/Desktop/drug_proj/model/generator_g_train_90_600.h5')



In [None]:
# Display the restored generator object (a bare expression is rendered as the cell output).
generator_g

In [None]:
# Run the generator restored from the checkpoint on the preprocessed new image.
# `Falses` was a NameError; training=False runs the layers in inference mode
# (e.g. any dropout is disabled).
output = generator_g(new_image, training=False)

In [None]:
# Pick the first image of the batch and rescale it from [-1, 1] to [0, 1] for display.
first_output = output[0]
output_image = first_output * 0.5 + 0.5

plt.imshow(output_image)
plt.show()