# Conditional Generative Adversarial Network (cGAN) dengan JAX/Flax

Notebook ini mendemonstrasikan implementasi **Conditional Generative Adversarial Network (cGAN)** menggunakan framework **JAX** dan library **Flax (NNX)**. Kita akan melatih model ini pada dataset **CIFAR-10** untuk menghasilkan gambar berdasarkan label kelas tertentu.

## 1. Perbedaan antara Conditional GAN (cGAN) dan DCGAN

Penting untuk memahami perbedaan mendasar antara model cGAN yang kita buat ini dengan model **DCGAN (Deep Convolutional GAN)** yang mungkin sudah Anda pelajari sebelumnya:

| Fitur | DCGAN | Conditional GAN (cGAN) |
|---|---|---|
| **Input Generator** | Hanya noise acak ($z$). | Noise acak ($z$) DAN informasi tambahan (misalnya label kelas $y$). |
| **Input Discriminator** | Hanya gambar ($x$, baik asli maupun palsu). | Gambar ($x$) DAN informasi tambahan (label kelas $y$). |
| **Kontrol Output** | Tidak ada kontrol langsung. Generator membuat gambar apa saja dari distribusi data. | Kita bisa mengontrol kelas gambar yang ingin dihasilkan (misal: "buatkan saya gambar kucing"). |
| **Struktur Data** | Proses training bersifat *unsupervised* atau *self-supervised*. | Proses training memerlukan label (*supervised*) untuk mengkondisikan model. |

**Intinya:**
- **DCGAN** berfokus pada penggunaan arsitektur *Convolutional* untuk stabilitas training, tetapi ia menghasilkan gambar secara acak.
- **cGAN** menambahkan mekanisme "pengkondisian" (conditioning) yang memungkinkan kita memandu proses pembangkitan gambar dengan memberikan label sebagai input tambahan baik pada Generator maupun Discriminator.

## 2. Persiapan Lingkungan dan Import

Pertama, kita siapkan library yang dibutuhkan.

In [1]:
import jax
import jax.numpy as jnp
from flax import nnx
import matplotlib.pyplot as plt
import sys, os
import numpy as np
import time as timer
from tqdm import tqdm
import grain.python as grain
import optax
import urllib.request
import tarfile
import pickle

# Make the parent directory importable so the local helper modules
# (viz_utils, model_utils) resolve when the notebook runs from its subfolder
# (the original note refers to the dl-jax101 folder layout).
parent_dir = os.path.abspath(os.pardir)
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

import viz_utils as vu
import model_utils as mu

print(f"JAX Device: {jax.devices()}")

JAX Device: [CpuDevice(id=0)]


## 3. Hyperparameters dan Direktori

Kita definisikan konstanta yang akan digunakan selama eksperimen.

In [2]:
DATA_DIR = os.path.join(parent_dir, "data")
MODEL_DIR = os.path.join(parent_dir, "models")

BATCH_SIZE = 64
NUM_EPOCH = 50
IMAGE_SIZE = 64
NC = 3        # RGB color channels
NZ = 100      # Latent (noise) vector dimension
NGF = 64      # Generator base feature width
NDF = 64      # Discriminator base feature width
LR = 2e-4     # Adam learning rate
BETA1 = 0.5   # Adam beta1
NVIZ = 64     # Number of samples in each visualization grid
NUM_CLASSES = 10

DATASET = 'cifar10'
checkpoint_dir = os.path.join(MODEL_DIR, f"cond-gan_{DATASET}_z{NZ}")
sample_dir = os.path.join(checkpoint_dir, "samples")

# exist_ok=True replaces the check-then-create pattern: it is race-free and
# idempotent across notebook re-runs (parents are created as needed).
for d in [checkpoint_dir, sample_dir, DATA_DIR]:
    os.makedirs(d, exist_ok=True)

## 4. Dataset: CIFAR-10

Fungsi-fungsi di bawah ini digunakan untuk mengunduh dan memuat dataset CIFAR-10 secara lokal.

In [3]:
def download_and_extract_cifar10(dest_dir):
    """Ensure the CIFAR-10 python archive is present and extracted in ``dest_dir``.

    If ``dest_dir/cifar-10-batches-py`` already exists, nothing is done.
    Otherwise the archive ``cifar-10-python.tar.gz`` is downloaded (only if it
    is not already in ``dest_dir``) and extracted into ``dest_dir``.

    Args:
        dest_dir: directory holding the archive and the extracted folder;
            created if missing.

    Returns:
        Path to the extracted ``cifar-10-batches-py`` directory.
    """
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = os.path.join(dest_dir, "cifar-10-python.tar.gz")
    extract_path = os.path.join(dest_dir, "cifar-10-batches-py")

    if os.path.exists(extract_path):
        return extract_path

    os.makedirs(dest_dir, exist_ok=True)

    if not os.path.exists(filename):
        print(f"Downloading {url}...")
        # Download to a temporary name and rename only on success, so an
        # interrupted download never leaves a truncated archive that a later
        # run would mistake for a complete one.
        tmp_filename = filename + ".part"
        try:
            urllib.request.urlretrieve(url, tmp_filename)
        except BaseException:
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            raise
        os.replace(tmp_filename, filename)

    with tarfile.open(filename, "r:gz") as tar:
        # The "data" filter rejects absolute paths, ".." traversal and special
        # files; it exists on Python 3.12+ (and security backports).
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest_dir, filter="data")
        else:
            tar.extractall(path=dest_dir)
    return extract_path

def load_cifar10_local(data_dir):
    """Load the five CIFAR-10 training batches (``data_batch_1`` .. ``_5``).

    Args:
        data_dir: directory containing the extracted python-format batches.

    Returns:
        ``(X_train, y_train)``: the batches' ``b'data'`` arrays stacked
        row-wise with ``np.vstack``, and their ``b'labels'`` lists
        concatenated and cast to ``np.int32``.

    NOTE(security): batches are read with ``pickle.load``, which can execute
    arbitrary code; only use this on archives from a trusted source.
    """
    def _read_batch(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle, encoding='bytes')

    batches = [_read_batch(os.path.join(data_dir, f"data_batch_{i}")) for i in range(1, 6)]
    pixels = np.vstack([batch[b'data'] for batch in batches])
    targets = np.hstack([batch[b'labels'] for batch in batches]).astype(np.int32)
    return pixels, targets

# Fetch the archive if needed, then load every training batch into memory.
print("Loading dataset...")
cifar_path = download_and_extract_cifar10(DATA_DIR)
X_train_all, y_train_all = load_cifar10_local(cifar_path)
print(f"Data loaded: {X_train_all.shape} images.")

Loading dataset...
Data loaded: (50000, 3072) images.


### Pipeline Data dengan Google Grain

Kita gunakan Grain untuk melakukan batching dan preprocessing gambar (resizing ke 64x64 dan normalisasi ke range [-1, 1]).

In [8]:
class CIFARSource(grain.RandomAccessDataSource):
    """Grain random-access source over in-memory CIFAR-10 arrays.

    Each record is ``{'image': float32 (S, S, 3) array in [-1, 1],
    'label': labels[index]}``, where S is the output side length.

    Args:
        images: indexable collection of flattened CIFAR rows; each row is
            reshaped to channel-major (3, 32, 32), so it must hold 3072 values.
        labels: per-image class labels, same length as ``images``.
        image_size: output side length. ``None`` (default) uses the
            module-level ``IMAGE_SIZE``, looked up on each access.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.
    """

    def __init__(self, images, labels, image_size=None):
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self._images = images
        self._labels = labels
        self._image_size = image_size

    def __len__(self):
        return len(self._images)

    def __getitem__(self, index):
        from PIL import Image

        size = IMAGE_SIZE if self._image_size is None else self._image_size
        # Raw CIFAR rows are channel-major (3, 32, 32); PIL expects HWC (32, 32, 3).
        img = self._images[index].reshape(3, 32, 32).transpose(1, 2, 0).astype(np.uint8)
        img = Image.fromarray(img).resize((size, size), Image.BILINEAR)
        img = np.asarray(img, dtype=np.float32)
        # Map [0, 255] -> [-1, 1], the range of the generator's tanh output.
        image = (img / 255.0) * 2.0 - 1.0
        return {'image': image, 'label': self._labels[index]}

def create_loader(data_source, batch_size, shuffle=False, seed=0, drop_remainder=False):
    """Wrap ``data_source`` in a re-iterable iterator of numpy ``(images, labels)`` batches.

    Each full iteration over the returned object is one epoch. When
    ``shuffle`` is True, pass ``k`` (0-based) uses sampler seed ``seed + k``,
    so every epoch sees a fresh permutation. The first pass is identical to
    a loader seeded with ``seed``.

    Args:
        data_source: grain random-access source whose records are dicts with
            ``'image'`` and ``'label'`` keys.
        batch_size: records per yielded batch; must be positive.
        shuffle: shuffle record order (reshuffled every epoch).
        seed: base seed for the shuffle.
        drop_remainder: if True, skip the final short batch so every batch has
            the same shape (avoids an extra jit compilation for the last batch).

    Returns:
        An iterable whose ``len()`` is the number of batches per epoch.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    num_records = len(data_source)

    def _make_dataloader(epoch_seed):
        """Build a single-pass grain DataLoader whose sampler uses ``epoch_seed``."""
        sampler = grain.IndexSampler(
            num_records=num_records,
            shard_options=grain.NoSharding(),
            shuffle=shuffle,
            num_epochs=1,
            seed=epoch_seed,
        )
        return grain.DataLoader(data_source=data_source, sampler=sampler, worker_count=0)

    class BatchIterator:
        """Re-iterable wrapper: every ``iter()`` is one epoch of numpy batches."""

        def __init__(self, loader, batch_size, num_records):
            self.loader = loader
            self.batch_size = batch_size
            self.num_records = num_records
            self._epoch = 0

        def __len__(self):
            if drop_remainder:
                return self.num_records // self.batch_size
            return (self.num_records + self.batch_size - 1) // self.batch_size

        def __iter__(self):
            # A single-epoch sampler with a fixed seed yields the same
            # permutation every time the same DataLoader is re-iterated, so
            # rebuild it with seed + epoch for every pass after the first.
            if self._epoch > 0:
                self.loader = _make_dataloader(seed + self._epoch)
            self._epoch += 1

            batch_images, batch_labels = [], []
            for record in self.loader:
                batch_images.append(record['image'])
                batch_labels.append(record['label'])
                if len(batch_images) == self.batch_size:
                    yield np.stack(batch_images), np.array(batch_labels)
                    batch_images, batch_labels = [], []
            if batch_images and not drop_remainder:
                yield np.stack(batch_images), np.array(batch_labels)

    return BatchIterator(_make_dataloader(seed), batch_size, num_records)

# Shuffled training loader over the in-memory CIFAR-10 arrays (seeded for reproducibility).
train_loader = create_loader(
    CIFARSource(X_train_all, y_train_all),
    batch_size=BATCH_SIZE,
    shuffle=True,
    seed=42,
)

## 5. Arsitektur Model

### Generator
Generator menerima noise acak dan label kelas (one-hot encoded), lalu menghasilkan gambar. Di sini, label digabungkan (*concatenate*) dengan noise latent sebelum masuk ke lapisan konvolusi.

In [9]:
class Generator(nnx.Module):
    """Conditional DCGAN-style generator: (noise, one-hot label) -> image in [-1, 1].

    The noise ``z`` (N, nz) and label vector ``c`` (N, num_classes) are
    concatenated and viewed as an (N, 1, 1, nz + num_classes) NHWC tensor.
    A 4x4 'VALID' transposed conv produces 4x4. Each following stride-2
    'SAME' transposed conv doubles the spatial size (8, 16, 32, 64). The final
    layer emits ``nc`` channels through tanh.
    """

    def __init__(self, nz, ngf, nc, num_classes, rngs: nnx.Rngs):
        init_fn = nnx.initializers.normal(0.02)
        self.num_classes = num_classes

        def up(c_in, c_out, stride, padding):
            # Bias-free 4x4 transposed conv with N(0, 0.02) kernel init.
            return nnx.ConvTranspose(c_in, c_out, kernel_size=(4, 4), strides=(stride, stride),
                                     padding=padding, use_bias=False, rngs=rngs,
                                     kernel_init=init_fn)

        # Layers are created in the same order as before, so they draw the
        # same keys from ``rngs``.
        self.convt1 = up(nz + num_classes, ngf * 8, 1, 'VALID')
        self.bn1 = nnx.BatchNorm(ngf * 8, rngs=rngs)
        self.convt2 = up(ngf * 8, ngf * 4, 2, 'SAME')
        self.bn2 = nnx.BatchNorm(ngf * 4, rngs=rngs)
        self.convt3 = up(ngf * 4, ngf * 2, 2, 'SAME')
        self.bn3 = nnx.BatchNorm(ngf * 2, rngs=rngs)
        self.convt4 = up(ngf * 2, ngf, 2, 'SAME')
        self.bn4 = nnx.BatchNorm(ngf, rngs=rngs)
        self.convt5 = up(ngf, nc, 2, 'SAME')

    def __call__(self, z, c, train: bool = True, use_running_average: bool = None):
        # Default: batch statistics while training, running averages otherwise.
        use_ra = (not train) if use_running_average is None else use_running_average

        x = jnp.concatenate([z, c], axis=1)
        x = x.reshape(x.shape[0], 1, 1, -1)
        stages = ((self.convt1, self.bn1), (self.convt2, self.bn2),
                  (self.convt3, self.bn3), (self.convt4, self.bn4))
        for deconv, norm in stages:
            x = nnx.relu(norm(deconv(x), use_running_average=use_ra))
        return nnx.tanh(self.convt5(x))

### Discriminator
Discriminator menerima gambar dan label kelas. Label one-hot disebarkan (*broadcast*) ke setiap posisi spasial sehingga memiliki dimensi spasial yang sama dengan gambar (misal 64x64), lalu digabungkan sebagai channel tambahan.

In [6]:
class Discriminator(nnx.Module):
    """Conditional DCGAN-style discriminator: (image, one-hot label) -> one logit per sample.

    The label vector ``c`` (N, num_classes) is broadcast over every spatial
    position of ``x`` (N, H, W, nc) and appended as extra channels. Four
    stride-2 'SAME' 4x4 convs followed by a 4x4 'VALID' conv reduce a 64x64
    input to 1x1, which is flattened to shape (N,).
    """

    def __init__(self, nc, ndf, num_classes, rngs: nnx.Rngs):
        init_fn = nnx.initializers.normal(0.02)

        def down(c_in, c_out, stride, padding):
            # Bias-free 4x4 conv with N(0, 0.02) kernel init.
            return nnx.Conv(c_in, c_out, kernel_size=(4, 4), strides=(stride, stride),
                            padding=padding, use_bias=False, rngs=rngs, kernel_init=init_fn)

        # Creation order unchanged so parameters draw the same keys from ``rngs``.
        self.conv1 = down(nc + num_classes, ndf, 2, 'SAME')
        self.conv2 = down(ndf, ndf * 2, 2, 'SAME')
        self.bn2 = nnx.BatchNorm(ndf * 2, rngs=rngs)
        self.conv3 = down(ndf * 2, ndf * 4, 2, 'SAME')
        self.bn3 = nnx.BatchNorm(ndf * 4, rngs=rngs)
        self.conv4 = down(ndf * 4, ndf * 8, 2, 'SAME')
        self.bn4 = nnx.BatchNorm(ndf * 8, rngs=rngs)
        self.conv5 = down(ndf * 8, 1, 1, 'VALID')

    def __call__(self, x, c, train: bool = True, use_running_average: bool = None):
        # Default: batch statistics while training, running averages otherwise.
        use_ra = (not train) if use_running_average is None else use_running_average

        n, height, width = x.shape[0], x.shape[1], x.shape[2]
        label_maps = jnp.broadcast_to(c[:, None, None, :], (n, height, width, c.shape[1]))
        h = jnp.concatenate([x, label_maps], axis=-1)

        def lrelu(t):
            return nnx.leaky_relu(t, negative_slope=0.2)

        h = lrelu(self.conv1(h))  # first layer has no BatchNorm
        for conv, norm in ((self.conv2, self.bn2), (self.conv3, self.bn3), (self.conv4, self.bn4)):
            h = lrelu(norm(conv(h), use_running_average=use_ra))
        return self.conv5(h).reshape(-1)

## 6. Logic Training

Berikut adalah pembungkus (*wrapper*) model CGAN, inisialisasi optimizer untuk Generator dan Discriminator, serta fungsi loss.

In [7]:
class CGAN(nnx.Module):
    """Container holding the conditional generator (``netG``) and discriminator (``netD``).

    Both sub-networks share one ``rngs`` object; the generator is built first,
    so it draws its parameter keys before the discriminator.
    """

    def __init__(self, nz, ngf, nc, ndf, num_classes, rngs: nnx.Rngs):
        self.netG = Generator(nz=nz, ngf=ngf, nc=nc, num_classes=num_classes, rngs=rngs)
        self.netD = Discriminator(nc=nc, ndf=ndf, num_classes=num_classes, rngs=rngs)

# Build the model from one seeded RNG stream, then give each sub-network its
# own Adam optimizer (only nnx.Param leaves are optimized). G is created first.
rngs = nnx.Rngs(0)
model = CGAN(NZ, NGF, NC, NDF, NUM_CLASSES, rngs=rngs)
optimizerG, optimizerD = (
    nnx.Optimizer(net, optax.adam(LR, b1=BETA1), wrt=nnx.Param)
    for net in (model.netG, model.netD)
)

def loss_bce(logits, labels):
    """Mean sigmoid binary cross-entropy over all elements.

    ``logits`` are raw (pre-sigmoid) scores; ``labels`` are targets in [0, 1]
    whose shape must be broadcastable with ``logits``.
    """
    per_element = optax.sigmoid_binary_cross_entropy(logits, labels)
    return per_element.mean()

### Fungsi Step Training

Kita membagi training menjadi dua langkah: `train_step_D` untuk Discriminator dan `train_step_G` untuk Generator.

In [10]:
@nnx.jit
def train_step_D(model, optimizerD, real_x, c_vec, noise):
    """Run one discriminator update on a real batch and a freshly generated batch.

    Returns:
        ``(loss, mean D(real), mean D(fake))``, where D(.) is the sigmoid
        probability of "real".
    """
    # G runs outside the differentiated function and its output is frozen,
    # so this loss sends no gradient into G's parameters.
    generated = jax.lax.stop_gradient(model.netG(noise, c_vec, train=True))

    def d_objective(m):
        logits_real = m.netD(real_x, c_vec, train=True)
        logits_fake = m.netD(generated, c_vec, train=True)
        total = (loss_bce(logits_real, jnp.ones_like(logits_real))
                 + loss_bce(logits_fake, jnp.zeros_like(logits_fake)))
        return total, (nnx.sigmoid(logits_real), nnx.sigmoid(logits_fake))

    grad_fn = nnx.value_and_grad(d_objective, has_aux=True)
    (total, (p_real, p_fake)), grads = grad_fn(model)
    optimizerD.update(model.netD, grads.netD)  # apply only the discriminator's gradients
    return total, p_real.mean(), p_fake.mean()

@nnx.jit
def train_step_G(model, optimizerG, c_vec, noise):
    """Run one generator update with the non-saturating GAN loss.

    Returns:
        ``(loss, mean D(G(noise)))``, where D(.) is the sigmoid probability
        the discriminator assigns to "real".
    """
    def g_objective(m):
        generated = m.netG(noise, c_vec, train=True)
        logits = m.netD(generated, c_vec, train=True)
        # Target "real" (1) for fakes: G is rewarded when D is fooled.
        return loss_bce(logits, jnp.ones_like(logits)), nnx.sigmoid(logits)

    grad_fn = nnx.value_and_grad(g_objective, has_aux=True)
    (loss, d_prob), grads = grad_fn(model)
    # Gradients are computed for the whole model, but only G's are applied.
    optimizerG.update(model.netG, grads.netG)
    return loss, d_prob.mean()

## 7. Loop Pelatihan (Main Loop)

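Sekarang kita jalankan proses pelatihannya. Noise tetap (`fixed_latent`) dan label tetap (`fixed_cvec`) dibuat sekali sebelum training dimulai. Di akhir setiap epoch, keduanya dipakai untuk membangkitkan sampel, sehingga perkembangan Generator dapat dipantau secara konsisten dari epoch ke epoch.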

In [None]:
print("Starting Training Loop...")
step_rng = jax.random.PRNGKey(0)

# Fixed noise and labels, created once before training, so the per-epoch
# sample grids are directly comparable. Labels cycle 0..NUM_CLASSES-1.
fixed_latent = jax.random.normal(jax.random.PRNGKey(42), (NVIZ, NZ))
fixed_y = jnp.arange(NVIZ) % NUM_CLASSES
fixed_cvec = jax.nn.one_hot(fixed_y, NUM_CLASSES)

for epoch in range(NUM_EPOCH):
    start_t = timer.time()
    with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
        for batch_idx, (real_x, y) in enumerate(tepoch):
            batch_size = real_x.shape[0]
            c_vec = jax.nn.one_hot(y, NUM_CLASSES)

            # Independent noise draws for the D step and the G step.
            step_rng, rng_d, rng_g = jax.random.split(step_rng, 3)
            noise_d = jax.random.normal(rng_d, (batch_size, NZ))
            errD, D_x, D_G_z1 = train_step_D(model, optimizerD, real_x, c_vec, noise_d)

            noise_g = jax.random.normal(rng_g, (batch_size, NZ))
            errG, D_G_z2 = train_step_G(model, optimizerG, c_vec, noise_g)

            # Formatting the metrics copies them to the host, so only do it every 10 steps.
            if batch_idx % 10 == 0:
                tepoch.set_postfix(Loss_D=f"{errD:.4f}", Loss_G=f"{errG:.4f}", Dx=f"{D_x:.4f}", Dgz=f"{D_G_z2:.4f}")
    print(f"Epoch {epoch+1}/{NUM_EPOCH} done in {timer.time() - start_t:.1f}s")

    # Visualize eval-mode samples (BatchNorm uses running averages) from the fixed inputs.
    fake_samples = model.netG(fixed_latent, fixed_cvec, train=False)
    # NOTE(review): fake_samples is NHWC (NVIZ, H, W, NC); the transpose below
    # assumes vu.set_grid returns a CHW grid — confirm against viz_utils.
    grid = vu.set_grid(fake_samples, num_cells=NVIZ)
    fig = plt.figure(figsize=(8, 8))
    plt.imshow(np.transpose(np.array(vu.normalize(grid, 0, 1)), (1, 2, 0)))

    plt.title(f"Generated Samples Epoch {epoch+1}")
    plt.axis('off')
    plt.savefig(os.path.join(sample_dir, f'samples_epoch_{epoch+1}.png'))
    plt.show()
    # Close explicitly: non-inline backends otherwise keep every epoch's
    # figure alive (matplotlib warns once more than 20 are open).
    plt.close(fig)

Starting Training Loop...


Epoch 1:   7%|▋         | 57/782 [05:11<1:06:05,  5.47s/batch, Dgz=0.0005, Dx=0.9543, Loss_D=0.2194, Loss_G=8.7065]


KeyboardInterrupt: 