# Generative Adversarial Networks (GANs) menggunakan JAX dan Flax

Generative Adversarial Networks (GANs) adalah kelas kerangka kerja pembelajaran mesin yang dirancang oleh Ian Goodfellow dan rekan-rekannya pada tahun 2014. GAN terdiri dari dua jaringan saraf, yaitu **Generator** dan **Discriminator**, yang bersaing satu sama lain dalam permainan zero-sum.

## 1. Latar Belakang Matematis

Proses pelatihan GAN sering digambarkan sebagai permainan minimax di mana Generator ($G$) mencoba menghasilkan sampel realistis untuk menipu Discriminator ($D$), dan Discriminator mencoba membedakan antara data asli dan sampel palsu yang dihasilkan oleh Generator.

Fungsi objektif utamanya adalah **Value Function** $V(D, G)$:

$$\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_z(z)}[\log(1 - D(G(z)))]$$

Di mana:
- $x$: sebuah gambar asli dari data pelatihan.
- $z$: sebuah vektor noise acak dari distribusi prior (misalnya, Gaussian).
- $G(z)$: sampel (data palsu) yang dihasilkan generator dari vektor noise $z$.
- $D(x)$: estimasi discriminator tentang probabilitas bahwa sampel $x$ berasal dari data asli (bukan dari generator).

### Langkah Pelatihan
1. **Perbarui Discriminator**: Maksimalkan probabilitas penempatan label yang benar baik untuk contoh pelatihan maupun sampel dari $G$.
2. **Perbarui Generator**: Minimalkan $\log(1 - D(G(z)))$. Dalam praktiknya, kita sering kali memaksimalkan $\log D(G(z))$ untuk memberikan gradien yang lebih baik di awal pelatihan.

## 2. Persiapan dan Impor

Pertama, mari kita impor pustaka yang diperlukan. Kita akan menggunakan JAX untuk komputasi backend, Flax NNX untuk definisi model, dan Optax untuk optimasi.

In [1]:
import jax
import jax.numpy as jnp
from flax import nnx
import matplotlib.pyplot as plt
import sys, os
import numpy as np
import time as timer
from tqdm import tqdm
import grain.python as grain
from sklearn.datasets import fetch_openml
import optax

# Add parent directory to path to import utils
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
import viz_utils as vu
import model_utils as mu

## 3. Konfigurasi dan Hiperparameter

In [2]:
# Define constants
# Root directories can be overridden through environment variables so the
# notebook also runs outside the author's machine; the defaults are unchanged.
DATA_DIR = os.environ.get("GAN_DATA_DIR", "/Users/mghifary/Work/Code/AI/data")
MODEL_DIR = os.environ.get("GAN_MODEL_DIR", "/Users/mghifary/Work/Code/AI/IF5281/jax/models")

BATCH_SIZE = 128
NUM_EPOCH = 50
NC = 1  # num channels
NZ = 100  # num latent variables
LR = 2e-4  # learning rate
BETA1 = 0.5  # beta1 for Adam optimizer
NVIZ = 64  # number of images shown in a sample grid

DATASET = 'mnist'
mname = "gan"
checkpoint_dir = os.path.join(MODEL_DIR, f"{mname}_{DATASET}_z{NZ}")
sample_dir = os.path.join(checkpoint_dir, "samples")

for d in (checkpoint_dir, sample_dir):
    if not os.path.isdir(d):
        # exist_ok avoids a crash if the directory appears between the check
        # and the creation (e.g. a parallel run).
        os.makedirs(d, exist_ok=True)
        print(f'The new directory {d} has been created')

## 4. Pemuatan Data

Kita memuat dataset MNIST menggunakan `fetch_openml` dari Scikit-Learn dan menggunakan `grain` dari Google untuk pemuatan data yang efisien.

In [3]:
print("Loading MNIST via OpenML (may take a minute)...Found locally if cache exists")
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='liac-arff')
X_all = mnist.data
y_all = mnist.target.astype(np.int32)

# First 60,000 rows form the training split, the remaining rows the test split (60k / 10k).
X_train_all, y_train_all = X_all[:60000], y_all[:60000]
X_test_all, y_test_all = X_all[60000:], y_all[60000:]

class MNISTSource(grain.RandomAccessDataSource):
    """Random-access data source over in-memory MNIST arrays.

    Each record is a dict ``{'image': ..., 'label': ...}``. The image is the
    row reshaped to ``(1, 28, 28)``, cast to float32 and mapped linearly from
    [0, 255] to [-1, 1]; the label is the matching entry of ``labels``.
    """

    def __init__(self, images, labels):
        # images: one flat row of 784 pixel values per sample (0-255 presumably);
        # labels: one target per row.
        self._images = images
        self._labels = labels

    def __len__(self):
        return len(self._images)

    def __getitem__(self, index):
        pixels = self._images[index].reshape(1, 28, 28).astype(np.float32)
        # [0, 255] -> [0, 1] -> [-1, 1], matching the generator's tanh output range.
        scaled = pixels / 255.0 * 2.0 - 1.0
        return {'image': scaled, 'label': self._labels[index]}

def create_loader(data_source, batch_size, shuffle=False, seed=0):
    """Build a re-iterable mini-batch iterator over a grain data source.

    Every pass (``for images, labels in loader``) yields ``(images, labels)``
    tuples of NumPy arrays, stacked from the records' ``'image'`` and
    ``'label'`` fields. The last batch of a pass may hold fewer than
    ``batch_size`` records.

    When ``shuffle`` is True, the first pass uses ``seed`` and every later pass
    uses ``seed + pass_index``. Each epoch therefore sees a different order;
    reusing one single-epoch sampler with a fixed seed would replay the same
    permutation on every pass.

    Args:
        data_source: grain random-access source returning dicts with
            ``'image'`` and ``'label'`` keys.
        batch_size: number of records per batch (>= 1).
        shuffle: whether to shuffle record order.
        seed: base seed for shuffling.

    Returns:
        An iterable object supporting ``len()`` (number of batches per pass).

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    num_records = len(data_source)

    def _build_dataloader(sampler_seed):
        """Create a single-epoch grain DataLoader over ``data_source``."""
        sampler = grain.IndexSampler(
            num_records=num_records,
            shard_options=grain.NoSharding(),
            shuffle=shuffle,
            num_epochs=1,
            seed=sampler_seed,
        )
        return grain.DataLoader(
            data_source=data_source,
            sampler=sampler,
            worker_count=0,
        )

    class BatchIterator:
        def __init__(self, loader, batch_size, num_records):
            self.loader = loader
            self.batch_size = batch_size
            self.num_records = num_records
            self._passes = 0  # how many iterations over the data have started

        def __len__(self):
            return (self.num_records + self.batch_size - 1) // self.batch_size

        def __iter__(self):
            # Re-seed the sampler for every pass after the first so that
            # shuffling changes between epochs.
            if shuffle and self._passes > 0:
                self.loader = _build_dataloader(seed + self._passes)
            self._passes += 1

            batch_images, batch_labels = [], []
            for record in self.loader:
                batch_images.append(record['image'])
                batch_labels.append(record['label'])
                if len(batch_images) == self.batch_size:
                    yield np.stack(batch_images), np.array(batch_labels)
                    batch_images, batch_labels = [], []
            # Emit the trailing partial batch, if any.
            if batch_images:
                yield np.stack(batch_images), np.array(batch_labels)

    return BatchIterator(_build_dataloader(seed), batch_size, num_records)

# Shuffled loader over the training split; the fixed seed keeps runs reproducible.
train_loader = create_loader(
    MNISTSource(X_train_all, y_train_all),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=42,
)

Loading MNIST via OpenML (may take a minute)...Found locally if cache exists


## 5. Arsitektur Model

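Baik Generator maupun Discriminator adalah Multi-Layer Perceptron (MLP) sederhana, mengikuti pendekatan makalah asli GAN untuk dataset sederhana seperti MNIST. Berbeda dengan makalah aslinya, implementasi di sini memakai aktivasi LeakyReLU dan Batch Normalization pada lapisan tersembunyi Generator. Output Generator memakai tanh sehingga berada di rentang $[-1, 1]$, sesuai dengan normalisasi data pelatihan.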

In [4]:
class Generator(nnx.Module):
    """MLP generator mapping latent vectors to images in [-1, 1].

    Architecture:
    - Linear(input_size -> 128) -> LeakyReLU(0.2).
    - Three Linear -> BatchNorm -> LeakyReLU(0.2) stages with 256, 512 and
      1024 units.
    - Linear(1024 -> output_size) -> tanh.

    The flat output is reshaped to ``(batch, channels, img_size, img_size)``,
    where ``channels = output_size // img_size**2``.

    Args:
        input_size: dimensionality of the latent vector ``z``.
        output_size: values generated per sample; must be a multiple of
            ``img_size * img_size``.
        rngs: ``nnx.Rngs`` used to initialise the layers.
        img_size: spatial height/width of the generated image (default 28).

    Raises:
        ValueError: if ``output_size`` is not divisible by ``img_size ** 2``.
    """

    def __init__(self, input_size=100, output_size=784, rngs: nnx.Rngs = None, img_size=28):
        pixels_per_channel = img_size * img_size
        if output_size % pixels_per_channel != 0:
            raise ValueError(
                f"output_size={output_size} is not a multiple of img_size**2={pixels_per_channel}"
            )
        # Derived from the constructor arguments instead of the global NC, so
        # the reshape in __call__ always agrees with output_size.
        self.num_channels = output_size // pixels_per_channel
        self.img_size = img_size

        # Layer creation order is unchanged so RNG consumption (and thus
        # initialisation) matches the original.
        self.fc1 = nnx.Linear(input_size, 128, rngs=rngs)
        self.fc2 = nnx.Linear(128, 256, rngs=rngs)
        self.bn2 = nnx.BatchNorm(256, rngs=rngs)
        self.fc3 = nnx.Linear(256, 512, rngs=rngs)
        self.bn3 = nnx.BatchNorm(512, rngs=rngs)
        self.fc4 = nnx.Linear(512, 1024, rngs=rngs)
        self.bn4 = nnx.BatchNorm(1024, rngs=rngs)
        self.fc5 = nnx.Linear(1024, output_size, rngs=rngs)

    def __call__(self, z):
        h = nnx.leaky_relu(self.fc1(z), negative_slope=0.2)
        h = nnx.leaky_relu(self.bn2(self.fc2(h)), negative_slope=0.2)
        h = nnx.leaky_relu(self.bn3(self.fc3(h)), negative_slope=0.2)
        h = nnx.leaky_relu(self.bn4(self.fc4(h)), negative_slope=0.2)
        h = nnx.tanh(self.fc5(h))  # squash to [-1, 1] to match the data normalisation
        return h.reshape(h.shape[0], self.num_channels, self.img_size, self.img_size)

class Discriminator(nnx.Module):
    """MLP discriminator producing a sigmoid probability per sample.

    Inputs are flattened to ``(batch, -1)`` and then passed through
    Linear(input_size -> 512) -> LeakyReLU(0.2), Linear(512 -> 256) ->
    LeakyReLU(0.2), and Linear(256 -> num_classes) -> sigmoid. The result is
    flattened, so with ``num_classes=1`` the output has shape ``(batch,)``.
    """

    def __init__(self, input_size=784, num_classes=1, rngs: nnx.Rngs = None):
        # Attribute names are kept as-is: NNX uses them as state paths, so
        # renaming would change the saved/loaded state layout.
        self.fc1 = nnx.Linear(input_size, 512, rngs=rngs)
        self.fc2 = nnx.Linear(512, 256, rngs=rngs)
        self.fc3 = nnx.Linear(256, num_classes, rngs=rngs)

    def __call__(self, x):
        out = x.reshape(x.shape[0], -1)
        for hidden in (self.fc1, self.fc2):
            out = nnx.leaky_relu(hidden(out), negative_slope=0.2)
        probs = nnx.sigmoid(self.fc3(out))
        return probs.flatten()

## 6. Fungsi Pelatihan

Kita mendefinisikan fungsi loss dan langkah pelatihan yang dikompilasi JIT.

In [None]:
def binary_cross_entropy(logits, labels):
    """Mean binary cross-entropy between predicted probabilities and targets.

    Despite its name, ``logits`` must contain *probabilities* in [0, 1]
    (e.g. a sigmoid output), not raw logits. The parameter name is kept for
    backward compatibility.

    Args:
        logits: predicted probabilities, any shape.
        labels: targets broadcastable to ``logits`` (1 = real, 0 = fake).

    Returns:
        Scalar mean of ``-(y * log(p) + (1 - y) * log(1 - p))``.
    """
    probs = jnp.asarray(logits)
    if not jnp.issubdtype(probs.dtype, jnp.floating):
        probs = probs.astype(jnp.float32)
    # The clip bound must be representable in the working dtype. With the old
    # epsilon of 1e-12, float32 rounds 1 - 1e-12 to exactly 1.0, so a
    # saturated sigmoid (p == 1) produced log(0) = -inf and an infinite/NaN loss.
    eps = float(jnp.finfo(probs.dtype).eps)
    probs = jnp.clip(probs, eps, 1.0 - eps)
    return -jnp.mean(labels * jnp.log(probs) + (1 - labels) * jnp.log1p(-probs))

@nnx.jit
def train_step_D(netD, netG, optimizerD, real_x, noise):
    """Run one JIT-compiled discriminator update.

    Produces a fake batch ``netG(noise)`` and updates ``netD`` in place through
    ``optimizerD``, minimising
    ``BCE(netD(real_x), 1) + BCE(netD(fake_x), 0)``.

    Args:
        netD: discriminator module; the only module differentiated here.
        netG: generator module, used only to produce the fake batch.
        optimizerD: ``nnx.Optimizer`` holding ``netD``'s optimizer state.
        real_x: batch of real images.
        noise: latent vectors fed to ``netG`` (presumably one per real image,
            as the training loop passes).

    Returns:
        Tuple ``(loss, mean netD(real_x), mean netD(fake_x))``.
    """
    # fake_x is computed outside loss_fn, so it is a constant w.r.t. the
    # gradient: no gradient reaches netG in this step.
    fake_x = netG(noise)
    def loss_fn(netD):
        real_pred = netD(real_x)
        fake_pred = netD(fake_x)
        # Real samples are labelled 1, generated samples 0.
        errD = binary_cross_entropy(real_pred, jnp.ones_like(real_pred)) + \
               binary_cross_entropy(fake_pred, jnp.zeros_like(fake_pred))
        return errD, (real_pred, fake_pred)
    # Differentiate w.r.t. netD (the argument); predictions are returned as aux.
    (loss, (real_p, fake_p)), grads = nnx.value_and_grad(loss_fn, has_aux=True)(netD)
    optimizerD.update(netD, grads)
    return loss, jnp.mean(real_p), jnp.mean(fake_p)

@nnx.jit
def train_step_G(netD, netG, optimizerG, noise):
    """Run one JIT-compiled generator update with the non-saturating loss.

    Minimises ``BCE(netD(netG(noise)), 1)``, which is ``-mean(log D(G(z)))``,
    i.e. maximises ``log D(G(z))``. ``netG`` is updated in place through
    ``optimizerG``.

    Args:
        netD: discriminator; used for the forward pass only, not differentiated
            and not updated here.
        netG: generator module being trained.
        optimizerG: ``nnx.Optimizer`` holding ``netG``'s optimizer state.
        noise: latent vectors fed to ``netG``.

    Returns:
        Tuple ``(loss, mean netD(netG(noise)))``.
    """
    def loss_fn(netG):
        fake_x = netG(noise)
        preds = netD(fake_x)
        # Label fakes as real (1) so the generator is rewarded for fooling netD.
        return binary_cross_entropy(preds, jnp.ones_like(preds)), preds
    # Gradients are taken w.r.t. netG only; netD is merely closed over.
    (loss, preds), grads = nnx.value_and_grad(loss_fn, has_aux=True)(netG)
    optimizerG.update(netG, grads)
    return loss, jnp.mean(preds)

## 7. Loop Pelatihan

Sekarang kita menginisialisasi model dan menjalankan loop pelatihan.

In [None]:
rngs = nnx.Rngs(0)
netG = Generator(input_size=NZ, output_size=NC * 28 * 28, rngs=rngs)
netD = Discriminator(input_size=NC * 28 * 28, num_classes=1, rngs=rngs)

optimizerG = nnx.Optimizer(netG, optax.adam(LR, b1=BETA1), wrt=nnx.Param)
optimizerD = nnx.Optimizer(netD, optax.adam(LR, b1=BETA1), wrt=nnx.Param)

# Fixed latent batch so the sample grids are comparable across epochs.
fixed_latent = jax.random.normal(jax.random.PRNGKey(42), (NVIZ, NZ))
step_rng = jax.random.PRNGKey(0)

# Generator and discriminator get separate checkpoint folders so their
# per-epoch checkpoint files do not collide. Created up front so the final
# save below never depends on the loop having run.
path_g = os.path.join(checkpoint_dir, "generator")
path_d = os.path.join(checkpoint_dir, "discriminator")
os.makedirs(path_g, exist_ok=True)
os.makedirs(path_d, exist_ok=True)


def _save_image_grid(images, path):
    """Tile up to NVIZ images into one grid, save it to ``path``, then display it."""
    grid = vu.set_grid(images, num_cells=NVIZ)
    plt.figure(figsize=(10, 10))
    plt.imshow(np.transpose(np.array(vu.normalize(grid, 0, 1)), (1, 2, 0)), cmap='gray')
    plt.axis('off')
    # savefig must come before show(): with inline/interactive backends show()
    # releases the figure, and a later savefig() writes a blank image.
    plt.savefig(path, bbox_inches='tight')
    plt.show()
    plt.close()


def _save_models(epoch_num):
    """Write generator and discriminator checkpoints tagged with ``epoch_num``."""
    mu.save_checkpoint(netG, epoch_num, filedir=path_g)
    mu.save_checkpoint(netD, epoch_num, filedir=path_d)


for epoch in range(NUM_EPOCH):
    start_t = timer.time()
    with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
        for batch_idx, (real_x, _) in enumerate(tepoch):
            step_rng, rng_d, rng_g = jax.random.split(step_rng, 3)
            num_real = real_x.shape[0]

            # Discriminator step on a fresh noise batch.
            noise_d = jax.random.normal(rng_d, (num_real, NZ))
            errD, D_x, D_G_z1 = train_step_D(netD, netG, optimizerD, real_x, noise_d)

            # Generator step on an independent noise batch.
            noise_g = jax.random.normal(rng_g, (num_real, NZ))
            errG, D_G_z2 = train_step_G(netD, netG, optimizerG, noise_g)

            if batch_idx % 100 == 0:
                tepoch.set_postfix(
                    Loss_D=f"{errD:.4f}",
                    Loss_G=f"{errG:.4f}",
                    Dx=f"{D_x:.4f}",
                    DGz=f"{D_G_z1:.4f}/{D_G_z2:.4f}",
                )

    print(f'Epoch {epoch+1} finished in {timer.time() - start_t:.2f}s')

    # Real samples are saved once, from the last batch of the first epoch.
    # real_x is (B, 1, 28, 28); slicing already handles batches smaller than NVIZ.
    if epoch == 0:
        _save_image_grid(real_x[:NVIZ], os.path.join(sample_dir, 'real_samples.jpg'))

    # Every 10 epochs: save a fake-sample grid and both checkpoints.
    if epoch % 10 == 0:
        # NOTE(review): netG is called without switching modes; if its BatchNorm
        # layers run in training mode this also updates their running statistics
        # with the fixed latent batch. Consider wrapping with netG.eval()/netG.train().
        fake_samples = netG(fixed_latent)
        _save_image_grid(fake_samples, os.path.join(sample_dir, f'fake_samples_epoch-{epoch+1}.jpg'))
        _save_models(epoch + 1)
        print(" --- Models stored ---")

if NUM_EPOCH > 0:
    print(" -- Storing final checkpoints --")
    _save_models(NUM_EPOCH)