# Cosine Decay Learning Rate

## Import Library

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt

## Cek Data

In [None]:
# Open the three .npz archives used by this notebook
train_samples = np.load('data/train_samples.npz')
ori_samples = np.load('data/ori_train.npz')
test_samples = np.load('data/test_samples.npz')

# Print the array names stored in each archive
for archive_name, archive in (
	("train_samples", train_samples),
	("ori_samples", ori_samples),
	("test_samples", test_samples),
):
	print(f"Keys in {archive_name}:", list(archive.keys()))

In [None]:
# Pull the image/label arrays out of the loaded archives
x_train, y_train = train_samples['train_images'], train_samples['train_labels']
x_ori = ori_samples['ori_train']
x_test, y_test = test_samples['test_images'], test_samples['test_labels']

# Report the shape of each image array
shape_report = " || ".join([
	f"x_train shape => {x_train.shape}",
	f"x_test shape => {x_test.shape}",
	f"x_ori shape => {x_ori.shape}",
])
print(shape_report)

# Report the dtype of each image array
dtype_report = " || ".join([
	f"x_train type => {x_train.dtype}",
	f"x_test type => {x_test.dtype}",
	f"x_ori type => {x_ori.dtype}",
])
print(dtype_report)

# Report how many distinct labels occur in the train and test sets
train_classes = np.unique(y_train)
test_classes = np.unique(y_test)
print(f"Jumlah label train : {len(train_classes)} => {train_classes}")
print(f"Jumlah label test : {len(test_classes)} => {test_classes}")

In [None]:
# Human-readable (Indonesian) description for each damage label
label_descriptions = dict(
	hole="Berlubang",
	bleed="Tinta Tembus",
	stain="Bercak",
	missing="Teks Hilang",
)

# Visualization helper: one row of images titled with their label description
def visualize_images(images, labels, indices, label_descriptions):
	"""Show the images at ``indices`` side by side in a single row.

	Args:
		images: indexable collection of images; each image is squeezed and
			drawn with a gray colormap.
		labels: indexable collection of labels aligned with ``images``.
		indices: positions of the samples to draw.
		label_descriptions: mapping from a label value to the title text.

	Returns:
		None. Nothing is drawn when ``indices`` is empty.

	Raises:
		KeyError: if a selected label is not a key of ``label_descriptions``.
	"""
	num_images = len(indices)
	if num_images == 0:
		# plt.subplots rejects a grid with zero columns; nothing to show
		return

	# squeeze=False keeps `axes` two-dimensional even for a single image,
	# so indexing below works for any number of columns
	_, axes = plt.subplots(1, num_images, figsize=(15, 5), squeeze=False)

	for ax, idx in zip(axes[0], indices):
		ax.imshow(images[idx].squeeze(), cmap='gray')
		ax.set_title(f"{label_descriptions[labels[idx]]}")
		ax.axis('off')

	plt.show()

In [None]:
# Indices of the training samples to preview
train_indices = [5, 25, 45, 65]

# Draw the selected training images with their label descriptions as titles
visualize_images(x_train, y_train, train_indices, label_descriptions)

## Normalisasi Data

In [None]:
# Per-image normalization
def normalize_data(img):
	"""Return ``img`` as a float32, single-channel array scaled by (x - 127.5) / 127.5.

	With that scaling a pixel value of 0 maps to -1 and 255 maps to 1
	(tanh-style range).

	Args:
		img: image array. A 3-D array whose last axis has length 3 is
			converted to grayscale with ``cv2.COLOR_BGR2GRAY``; a 2-D array
			or an ``(H, W, 1)`` array is used as is.

	Returns:
		float32 array with a trailing channel axis of length 1.
	"""
	img = np.asarray(img)

	# Drop the colour channels only for real 3-channel images; a 2-D
	# image that merely happens to be 3 pixels wide must not be converted
	if img.ndim == 3 and img.shape[-1] == 3:
		img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

	# float32 before the arithmetic so the result stays float32
	img = img.astype('float32')

	# Tanh-style normalization
	img = (img - 127.5) / 127.5

	# Append the channel axis unless the image already ends in a single channel
	if not (img.ndim == 3 and img.shape[-1] == 1):
		img = np.expand_dims(img, axis=-1)

	return img

# Normalize every image of the train, test and original sets
x_train = np.array([normalize_data(img) for img in x_train])
x_test = np.array([normalize_data(img) for img in x_test])
x_ori = np.array([normalize_data(img) for img in x_ori])

# Shapes after normalization
print(f"x_train shape => {x_train.shape} || x_test shape => {x_test.shape} || x_ori shape => {x_ori.shape}")

# Dtypes after normalization
print(f"x_train type => {x_train.dtype} || x_test type => {x_test.dtype} || x_ori type => {x_ori.dtype}")

# Global pixel range of each set
print(f"Nilai pixel train : min => {x_train.min()} & max => {x_train.max()}")
print(f"Nilai pixel test : min => {x_test.min()} & max => {x_test.max()}")
print(f"Nilai pixel ori : min => {x_ori.min()} & max => {x_ori.max()}")

# Average of the per-image minimum / maximum pixel values.
# Each set is flattened to (num_images, pixels) so the extrema are taken per
# image first; averaging the global min/max would just repeat the lines above.
train_flat = x_train.reshape(len(x_train), -1)
ori_flat = x_ori.reshape(len(x_ori), -1)
test_flat = x_test.reshape(len(x_test), -1)

print(f"Rata-rata pixel train  min : {np.mean(train_flat.min(axis=1))}")
print(f"Rata-rata pixel train  max : {np.mean(train_flat.max(axis=1))}")

print(f"Rata-rata pixel ori min : {np.mean(ori_flat.min(axis=1))}")
print(f"Rata-rata pixel ori max : {np.mean(ori_flat.max(axis=1))}")

print(f"Rata-rata pixel test min : {np.mean(test_flat.min(axis=1))}")
print(f"Rata-rata pixel test max : {np.mean(test_flat.max(axis=1))}")

In [None]:
# Indices of the training samples to preview
train_indices = [11, 31, 51, 71]

# Draw the selected training images with their label descriptions as titles
visualize_images(x_train, y_train, train_indices, label_descriptions)

## Train Model

### Import

In [None]:
import tensorflow as tf

# import keras
from keras.optimizers import Adam
from keras.optimizers.schedules import ExponentialDecay

# loss_function
from utils.loss_function import generator_loss, discriminator_loss

# import models
from models import buildGen, buildDisc

# import metrics
from utils.metrics import psnr, ssim

In [None]:
# Wrap the damaged and original image arrays as tf.data pipelines, 16 images per batch
damaged_images, real_images = (
	tf.data.Dataset.from_tensor_slices(image_array).batch(16)
	for image_array in (x_train, x_ori)
)

# Print the shape of the first batch of each pipeline
first_pair = tf.data.Dataset.zip((damaged_images, real_images)).take(1)
for x_img, y_img in first_pair:
	print(x_img.shape, y_img.shape, sep="\n")

### Hyperparameter

In [None]:
epochs = 50

# Decay horizon: total number of optimizer steps over the whole run.
# The datasets are batched without dropping the remainder, so the number
# of batches per epoch is the ceiling of len(x_train) / batch_size.
batch_size = 16
batches = int(np.ceil(len(x_train) / batch_size))
decay_steps = batches * epochs

# Learning-rate schedule for the generator: starts at fix_lr and is
# multiplied by decay_rate once every decay_steps optimizer steps
fix_lr = 0.001
exp_lr = ExponentialDecay(
	initial_learning_rate=fix_lr, decay_steps=decay_steps, decay_rate=0.9
)

# Optimizers: decayed learning rate for the generator, constant one for the
# discriminator. The name optimizer_g_cosine is kept because train_model
# refers to it, even though the schedule is exponential.
optimizer_g_cosine = Adam(learning_rate=exp_lr)
optimizer_d = Adam(learning_rate=fix_lr)

### Generator

In [None]:
# One optimization step for the generator
def training_generator(generator, discriminator, real_images, damaged_images, optimizer):
	"""Run a single generator update and return its loss.

	The generator is called on ``damaged_images`` in training mode, the
	discriminator scores the (generated, damaged) pair in inference mode,
	and ``generator_loss(real_images, generated, scores)`` is differentiated
	with respect to the generator's trainable variables only.

	Args:
		generator: model mapping damaged images to generated images.
		discriminator: model called on ``[generated, damaged]``.
		real_images: target images passed to ``generator_loss``.
		damaged_images: input batch for the generator.
		optimizer: optimizer that applies the generator gradients.

	Returns:
		The value returned by ``generator_loss`` for this batch.
	"""
	with tf.GradientTape() as gen_tape:
		restored = generator(damaged_images, training=True)
		scores_on_restored = discriminator([restored, damaged_images], training=False)
		loss_value = generator_loss(real_images, restored, scores_on_restored)

	# Only the generator weights are updated in this step
	gen_grads = gen_tape.gradient(loss_value, generator.trainable_variables)
	grad_var_pairs = zip(gen_grads, generator.trainable_variables)
	optimizer.apply_gradients(grad_var_pairs)

	return loss_value

# Build the generator with the project's buildGen() and print its summary
generator = buildGen()
generator.summary()

### Discriminator

In [None]:
# One optimization step for the discriminator
def training_discriminator(discriminator, generator, real_images, damaged_images, optimizer):
	"""Run a single discriminator update and return its loss.

	The generator produces images from ``damaged_images``; the discriminator
	then scores both the (real, damaged) and the (generated, damaged) pairs
	in training mode. ``discriminator_loss(real_scores, fake_scores)`` is
	differentiated with respect to the discriminator's trainable variables.

	Args:
		discriminator: model called on ``[image, damaged]`` pairs.
		generator: model mapping damaged images to generated images.
		real_images: target images paired with ``damaged_images``.
		damaged_images: input batch for the generator.
		optimizer: optimizer that applies the discriminator gradients.

	Returns:
		The value returned by ``discriminator_loss`` for this batch.
	"""
	with tf.GradientTape() as disc_tape:
		restored = generator(damaged_images, training=True)
		scores_on_real = discriminator([real_images, damaged_images], training=True)
		scores_on_restored = discriminator([restored, damaged_images], training=True)
		loss_value = discriminator_loss(scores_on_real, scores_on_restored)

	# Only the discriminator weights are updated in this step
	disc_grads = disc_tape.gradient(loss_value, discriminator.trainable_variables)
	grad_var_pairs = zip(disc_grads, discriminator.trainable_variables)
	optimizer.apply_gradients(grad_var_pairs)

	return loss_value

# Build the discriminator with the project's buildDisc() and print its summary
discriminator = buildDisc()
discriminator.summary()

### Train Loop

In [None]:
from tqdm import tqdm
import time

In [None]:
def visualize_images(damaged_images, real_images, generated_images, num_samples=2):
	"""Plot damaged, real and generated images in a 3 x ``num_samples`` grid.

	Row 1 holds the damaged inputs, row 2 the real targets and row 3 the
	generated outputs; column ``i`` shows sample ``i`` of each batch.

	Args:
		damaged_images: indexable batch of damaged images.
		real_images: indexable batch of real images.
		generated_images: indexable batch of generated images.
		num_samples: number of leading samples (columns) to draw.
	"""
	# (title, batch) for each grid row, top to bottom
	grid_rows = (
		("Damaged", damaged_images),
		("Real", real_images),
		("Generated", generated_images),
	)

	plt.figure(figsize=(15, 5))
	for col in range(num_samples):
		for row, (row_title, batch) in enumerate(grid_rows):
			# subplot positions are 1-based and counted row by row
			plt.subplot(3, num_samples, row * num_samples + col + 1)
			plt.imshow(np.squeeze(batch[col]), cmap='gray')
			plt.title(row_title)
			plt.axis("off")
	plt.tight_layout()
	plt.show()


def plot_loss(generator_losses, discriminator_losses):
	"""Plot the generator and discriminator loss curves on one figure.

	Args:
		generator_losses: sequence of generator loss values (drawn in blue).
		discriminator_losses: sequence of discriminator loss values (drawn in red).
	"""
	# (values, legend label, line colour) for each curve
	curves = (
		(generator_losses, 'Generator Loss', 'blue'),
		(discriminator_losses, 'Discriminator Loss', 'red'),
	)

	plt.figure(figsize=(10, 5))
	for loss_values, legend_label, line_color in curves:
		plt.plot(loss_values, label=legend_label, color=line_color)

	plt.xlabel('Epochs')
	plt.ylabel('Loss')
	plt.title('Training Losses')
	plt.legend()
	plt.show()

In [None]:
def train_model(generator, discriminator, damaged_dataset, real_dataset, epochs):
	"""Train the generator and discriminator alternately for ``epochs`` epochs.

	For every batch the discriminator is updated first, then the generator.
	After each epoch the elapsed time and the average losses are printed;
	on every 10th epoch (0, 10, 20, ...) a preview of the last batch is
	plotted once.

	Uses the module-level ``optimizer_d`` and ``optimizer_g_cosine``.

	Args:
		generator: generator model passed to the training steps.
		discriminator: discriminator model passed to the training steps.
		damaged_dataset: batched dataset of damaged images.
		real_dataset: batched dataset of real images, aligned with
			``damaged_dataset``.
		epochs: number of passes over the zipped dataset.
	"""
	# Pair each damaged batch with its real batch
	train_dataset = tf.data.Dataset.zip((damaged_dataset, real_dataset))

	# Progress-bar length taken from the dataset itself; tf.data reports a
	# negative cardinality when the size is unknown or infinite, in which
	# case tqdm runs without a total
	cardinality = int(train_dataset.cardinality().numpy())
	total_batches = cardinality if cardinality >= 0 else None

	for epoch in range(epochs):
		start_time = time.time()
		gen_loss_list = []
		disc_loss_list = []
		damaged_batch = real_batch = None

		with tqdm(total=total_batches, desc=f"Epoch {epoch+1}/{epochs}", unit="batch") as pbar:
			for damaged_batch, real_batch in train_dataset:
				# Train the discriminator
				disc_loss = training_discriminator(discriminator, generator, real_batch, damaged_batch, optimizer_d)
				# Train the generator
				gen_loss = training_generator(generator, discriminator, real_batch, damaged_batch, optimizer_g_cosine)

				gen_value = gen_loss.numpy()
				disc_value = disc_loss.numpy()
				gen_loss_list.append(gen_value)
				disc_loss_list.append(disc_value)

				# Update progress bar
				pbar.set_postfix(Gen_Loss=gen_value, Disc_Loss=disc_value)
				pbar.update(1)

		epoch_duration = time.time() - start_time

		if not gen_loss_list:
			# Nothing was trained; averaging would divide by zero
			print(f"Epoch {epoch+1}/{epochs} skipped: dataset produced no batches")
			continue

		# Average losses over the epoch
		avg_gen_loss = sum(gen_loss_list) / len(gen_loss_list)
		avg_disc_loss = sum(disc_loss_list) / len(disc_loss_list)

		print(f"Epoch {epoch+1}/{epochs} completed in {epoch_duration:.2f}s")
		print(f"Average Generator Loss: {avg_gen_loss:.4f}, Average Discriminator Loss: {avg_disc_loss:.4f}\n")

		# Preview once per 10 epochs, after the batch loop, so a figure is
		# not produced for every single batch of those epochs
		if epoch % 10 == 0:
			# The last batch may hold fewer than two images
			preview_count = min(2, int(damaged_batch.shape[0]))
			generated_images = generator(damaged_batch, training=False)
			visualize_images(damaged_batch, real_batch, generated_images, num_samples=preview_count)

# Start training with the models, datasets and epoch count defined earlier in this file
train_model(generator, discriminator, damaged_images, real_images, epochs)