# Conditional Generative Adversarial Network (cGAN) dengan JAX/Flax (GCS Edition)

Notebook ini mendemonstrasikan implementasi **Conditional Generative Adversarial Network (cGAN)** menggunakan framework **JAX** dan library **Flax (NNX)**. Kita akan melatih model ini pada dataset **CIFAR-10** dan menyimpan hasilnya langsung ke **Google Cloud Storage (GCS)**.

## 1. Perbedaan antara Conditional GAN (cGAN) dan DCGAN

| Fitur | DCGAN | Conditional GAN (cGAN) |
|---|---|---|
| **Input Generator** | Noise acak ($z$). | Noise acak ($z$) + Label kelas ($y$). |
| **Input Discriminator** | Gambar ($x$). | Gambar ($x$) + Label kelas ($y$). |
| **Kontrol Output** | Acak. | Terkendali (berdasarkan label). |

**cGAN** memungkinkan kita untuk memandu proses pembangkitan gambar dengan memberikan label sebagai input tambahan.

## 2. Persiapan Lingkungan, Import, dan GCS Auth

Kita siapkan library dan autentikasi ke Google Cloud Storage.

In [None]:
import jax
import jax.numpy as jnp
from flax import nnx
import matplotlib.pyplot as plt
import sys, os
import numpy as np
import time as timer
from tqdm import tqdm
import grain.python as grain
import optax
import urllib.request
import tarfile
import pickle
import safetensors
from safetensors.flax import save_file

from google.colab import auth
from google.cloud import storage

# Authenticate the Colab session so the Google Cloud Storage client can be used.
auth.authenticate_user()
storage_client = storage.Client()
# Name of the bucket that every upload in this notebook targets.
BUCKET_NAME = 'dljax'
bucket = storage_client.bucket(BUCKET_NAME)

# Report which accelerator devices JAX can see.
print(f"JAX Device: {jax.devices()}")

JAX Device: [CudaDevice(id=0)]


In [None]:
!nvidia-smi

In [None]:
# GCS Helpers
def upload_to_gcs(local_path, gcs_path):
    """Upload one local file to the module-level GCS ``bucket``.

    Args:
        local_path: Path of the file on the local filesystem.
        gcs_path: Object name to write inside the bucket.
    """
    bucket.blob(gcs_path).upload_from_filename(local_path)

# Visualization helpers
def set_grid(D, num_cells=1):
    """Tile a batch of images into one square mosaic in channels-first layout.

    The grid has ``ceil(sqrt(num_cells))`` cells per side and is filled in
    row-major order. Cells for which no image is available stay zero.

    Args:
        D: Image batch. A 3-D array ``(n, h, w)`` is treated as single-channel.
            A 4-D array whose last axis has size 1 or 3 is treated as
            channels-last and transposed to channels-first; any other 4-D
            array is used as-is (assumed channels-first).
        num_cells: Maximum number of images to place in the grid.

    Returns:
        Array of shape ``(c, grid_size * h, grid_size * w)``.
    """
    if len(D.shape) == 3:
        D = D[:, jnp.newaxis, :, :]

    # A trailing axis of size 1 or 3 is the only case that needs a transpose;
    # every other layout is already (n, c, h, w) for our purposes.
    if D.shape[3] in (1, 3):
        D = jnp.transpose(D, (0, 3, 1, 2))
    n, c, d1, d2 = D.shape

    grid_size = int(jnp.ceil(jnp.sqrt(num_cells)))
    used = min(n, num_cells)

    # Write all images in one update instead of one full-grid copy per cell
    # (each functional `.at[].set` returns a new array).
    cells = jnp.zeros((grid_size * grid_size, c, d1, d2))
    cells = cells.at[:used].set(D[:used])

    # (row, col, c, y, x) -> (c, row, y, col, x) -> (c, rows*h, cols*w)
    cells = cells.reshape(grid_size, grid_size, c, d1, d2)
    cells = jnp.transpose(cells, (2, 0, 3, 1, 4))
    return cells.reshape(c, grid_size * d1, grid_size * d2)

def normalize(x, new_min=0, new_max=255):
    """Linearly rescale ``x`` so its minimum maps to ``new_min`` and its maximum to ``new_max``.

    Args:
        x: Array-like of numeric values (must be non-empty).
        new_min: Value the smallest element is mapped to.
        new_max: Value the largest element is mapped to.

    Returns:
        The rescaled array. If every element of ``x`` is equal, the value
        range is zero and all elements are mapped to ``new_min`` instead of
        producing NaN/inf from a division by zero.
    """
    old_min = np.min(x)
    old_max = np.max(x)
    span = old_max - old_min
    # Guard the constant-input case: a zero span would otherwise divide by zero.
    scale = (new_max - new_min) / span if span != 0 else 0.0
    return (x - old_min) * scale + new_min

# Checkpoint helper with GCS Auto-Upload
def save_checkpoint(model: nnx.Module, epoch: int, filedir: str = "checkpoint", gcs_prefix: str = "models"):
    """Write a module's state to a safetensors file and upload it to GCS.

    The module state is converted to a nested plain dict, flattened into
    dot-joined keys (entries whose value is ``None`` are skipped), saved under
    ``/content/<filedir>/epoch_<epoch>.safetensors`` and then uploaded to
    ``<gcs_prefix>/<filedir>/epoch_<epoch>.safetensors`` in the bucket.

    Args:
        model: Module whose state is checkpointed.
        epoch: Epoch number embedded in the file name.
        filedir: Sub-directory name used both locally and in the GCS path.
        gcs_prefix: Leading path component of the GCS object name.
    """
    def _iter_leaves(tree, prefix=''):
        # Depth-first walk yielding (dotted_key, leaf) pairs.
        for name, value in tree.items():
            if value is None:
                continue
            dotted = f"{prefix}.{name}" if prefix else name
            if isinstance(value, dict):
                yield from _iter_leaves(value, dotted)
            else:
                yield dotted, value

    pure_state = nnx.split(model)[1].to_pure_dict()
    tensors = dict(_iter_leaves(pure_state))

    filename = f"epoch_{epoch}.safetensors"
    local_dir = os.path.join("/content", filedir)
    os.makedirs(local_dir, exist_ok=True)
    local_path = os.path.join(local_dir, filename)
    save_file(tensors, local_path)

    # Mirror the local file into the bucket under the same relative layout.
    gcs_path = f"{gcs_prefix}/{filedir}/{filename}"
    upload_to_gcs(local_path, gcs_path)
    print(f"Model saved locally to {local_path} and uploaded to gs://{BUCKET_NAME}/{gcs_path}")

## 3. Hyperparameters dan Direktori

Kita definisikan konstanta yang akan digunakan selama eksperimen.

In [None]:
BATCH_SIZE = 128   # records per mini-batch handed to create_loader
NUM_EPOCH = 50     # number of passes of the training loop
IMAGE_SIZE = 64    # CIFAR images are resized to IMAGE_SIZE x IMAGE_SIZE
NC = 3             # image channel count passed to both networks
NZ = 100           # length of the latent noise vector
NGF = 64           # base feature width of the generator
NDF = 64           # base feature width of the discriminator
LR = 2e-4          # Adam learning rate (both optimizers)
BETA1 = 0.5        # Adam b1 (both optimizers)
NVIZ = 64          # number of fixed samples rendered after each epoch
NUM_CLASSES = 10   # size of the one-hot class-conditioning vector

DATA_DIR = "/content/data"   # where the CIFAR-10 archive is downloaded and extracted
SAMPLE_DIR = "samples"       # local directory for the per-epoch sample images
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(SAMPLE_DIR, exist_ok=True)

## 4. Dataset: CIFAR-10

Fungsi-fungsi di bawah ini digunakan untuk mengunduh dan memuat dataset CIFAR-10 secara lokal.

In [None]:
def download_and_extract_cifar10(dest_dir):
    """Download (if needed) and extract the CIFAR-10 python archive.

    Nothing is done when the extracted ``cifar-10-batches-py`` directory
    already exists. An archive already present in ``dest_dir`` is reused
    instead of being downloaded again.

    Args:
        dest_dir: Directory that holds the archive and the extracted data.
            Created if it does not exist.

    Returns:
        Path of the extracted ``cifar-10-batches-py`` directory.
    """
    url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    filename = os.path.join(dest_dir, "cifar-10-python.tar.gz")
    extract_path = os.path.join(dest_dir, "cifar-10-batches-py")

    if os.path.exists(extract_path):
        return extract_path

    os.makedirs(dest_dir, exist_ok=True)

    if not os.path.exists(filename):
        print(f"Downloading {url}...")
        # Download to a temporary name and rename on success, so an interrupted
        # transfer never leaves a truncated archive that later runs would reuse.
        partial = filename + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, filename)
        finally:
            if os.path.exists(partial):
                os.remove(partial)

    with tarfile.open(filename, "r:gz") as tar:
        # Use the "data" extraction filter where the interpreter supports it:
        # it rejects members that would escape dest_dir (absolute paths, "..",
        # unsafe links) in an archive fetched from the network.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(path=dest_dir, filter="data")
        else:
            tar.extractall(path=dest_dir)
    return extract_path

def load_cifar10_local(data_dir):
    """Load the five CIFAR-10 training batches from an extracted archive.

    Args:
        data_dir: Directory containing ``data_batch_1`` .. ``data_batch_5``.

    Returns:
        ``(X_train, y_train)``: the batches' ``b'data'`` entries stacked
        row-wise, and their ``b'labels'`` entries concatenated as int32.
    """
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point this at an archive obtained from a trusted source.
    batches = []
    for batch_no in range(1, 6):
        batch_file = os.path.join(data_dir, f"data_batch_{batch_no}")
        with open(batch_file, 'rb') as handle:
            batches.append(pickle.load(handle, encoding='bytes'))

    X_train = np.vstack([entry[b'data'] for entry in batches])
    y_train = np.hstack([entry[b'labels'] for entry in batches]).astype(np.int32)
    return X_train, y_train

# Fetch CIFAR-10 into DATA_DIR (skipped when already extracted) and load the
# five training batches into memory as NumPy arrays.
print("Loading dataset...")
cifar_path = download_and_extract_cifar10(DATA_DIR)
X_train_all, y_train_all = load_cifar10_local(cifar_path)
print(f"Data loaded: {X_train_all.shape} images.")

### Pipeline Data dengan Google Grain

In [None]:
class CIFARSource(grain.RandomAccessDataSource):
    """Random-access data source that serves resized, rescaled CIFAR records.

    Each stored image row is reshaped to ``(3, 32, 32)``, converted to
    height-width-channel order, resized to ``IMAGE_SIZE x IMAGE_SIZE`` with
    bilinear interpolation and rescaled from ``[0, 255]`` to ``[-1, 1]``.
    """

    def __init__(self, images, labels):
        self._images, self._labels = images, labels

    def __len__(self):
        return len(self._images)

    def __getitem__(self, index):
        """Return ``{'image': float32 HWC array, 'label': stored label}`` for one record."""
        from PIL import Image

        channels_first = self._images[index].reshape(3, 32, 32)
        channels_last = channels_first.transpose(1, 2, 0).astype(np.uint8)
        resized = Image.fromarray(channels_last).resize((IMAGE_SIZE, IMAGE_SIZE), Image.BILINEAR)
        pixels = np.array(resized).astype(np.float32)
        # [0, 255] -> [0, 1] -> [-1, 1]
        return {'image': (pixels / 255.0) * 2.0 - 1.0, 'label': self._labels[index]}

def create_loader(data_source, batch_size, shuffle=False, seed=0):
    """Build a re-iterable mini-batch loader on top of Grain.

    Args:
        data_source: Random-access source whose records are dicts with
            ``'image'`` and ``'label'`` entries.
        batch_size: Number of records per yielded batch.
        shuffle: Whether the sampler shuffles the record order.
        seed: Base seed for the sampler. The first pass uses ``seed`` itself;
            each later pass uses ``seed + <pass index>``.

    Returns:
        An object supporting ``len()`` (batches per pass, counting a final
        short batch) and iteration. Every pass yields ``(images, labels)``
        tuples of NumPy arrays; the last batch may hold fewer than
        ``batch_size`` records.
    """

    class BatchIterator:
        """Groups single records from a per-pass Grain loader into batches."""

        def __init__(self, source, batch_size, shuffle, seed):
            self.source = source
            self.batch_size = batch_size
            self.num_records = len(source)
            self.shuffle = shuffle
            self.seed = seed
            self.loader = None  # most recently built grain.DataLoader
            self._passes = 0    # number of passes started so far

        def __len__(self):
            return (self.num_records + self.batch_size - 1) // self.batch_size

        def _new_loader(self):
            # The sampler's order is determined by its seed, so re-iterating a
            # loader built once with a fixed seed would replay the same
            # shuffled order every epoch. Build a fresh sampler per pass with
            # a pass-dependent seed; this stays reproducible across runs.
            pass_seed = self.seed if self.seed is None else self.seed + self._passes
            self._passes += 1
            sampler = grain.IndexSampler(
                num_records=self.num_records,
                shard_options=grain.NoSharding(),
                shuffle=self.shuffle,
                num_epochs=1,
                seed=pass_seed,
            )
            self.loader = grain.DataLoader(data_source=self.source, sampler=sampler, worker_count=0)
            return self.loader

        def __iter__(self):
            batch_images, batch_labels = [], []
            for record in self._new_loader():
                batch_images.append(record['image'])
                batch_labels.append(record['label'])
                if len(batch_images) == self.batch_size:
                    yield np.stack(batch_images), np.array(batch_labels)
                    batch_images, batch_labels = [], []
            # Flush the final, possibly short, batch.
            if batch_images:
                yield np.stack(batch_images), np.array(batch_labels)

    return BatchIterator(data_source, batch_size, shuffle, seed)

# Shuffled training loader over the full CIFAR-10 training set.
train_loader = create_loader(CIFARSource(X_train_all, y_train_all), BATCH_SIZE, shuffle=True, seed=42)

## 5. Arsitektur Model

In [None]:
class Generator(nnx.Module):
    """Class-conditional generator built from five transposed convolutions.

    The noise vector and the class-conditioning vector are concatenated,
    viewed as a 1x1 feature map and passed through four
    ConvTranspose -> BatchNorm -> ReLU stages, followed by a final
    ConvTranspose with ``tanh`` output.

    Args:
        nz: Length of the noise vector.
        ngf: Base feature width (stages use ``8x, 4x, 2x, 1x`` this value).
        nc: Number of output channels.
        num_classes: Length of the class-conditioning vector.
        rngs: RNG container used to initialise the layers.
    """

    def __init__(self, nz, ngf, nc, num_classes, rngs: nnx.Rngs):
        normal_init = nnx.initializers.normal(0.02)

        def deconv(features_in, features_out, strides=(2, 2), padding='SAME'):
            # All transposed convolutions share kernel size, bias and init settings.
            return nnx.ConvTranspose(
                features_in, features_out, kernel_size=(4, 4), strides=strides,
                padding=padding, use_bias=False, rngs=rngs, kernel_init=normal_init,
            )

        self.num_classes = num_classes
        # Layers are created in forward order so RNG consumption is unchanged.
        self.convt1 = deconv(nz + num_classes, ngf * 8, strides=(1, 1), padding='VALID')
        self.bn1 = nnx.BatchNorm(ngf * 8, rngs=rngs)
        self.convt2 = deconv(ngf * 8, ngf * 4)
        self.bn2 = nnx.BatchNorm(ngf * 4, rngs=rngs)
        self.convt3 = deconv(ngf * 4, ngf * 2)
        self.bn3 = nnx.BatchNorm(ngf * 2, rngs=rngs)
        self.convt4 = deconv(ngf * 2, ngf)
        self.bn4 = nnx.BatchNorm(ngf, rngs=rngs)
        self.convt5 = deconv(ngf, nc)

    def __call__(self, z, c, train: bool = True, use_running_average: bool = None):
        """Generate images from noise ``z`` and conditioning vectors ``c``.

        ``use_running_average`` defaults to ``not train`` when left as ``None``.
        """
        if use_running_average is None:
            use_running_average = not train

        # (batch, nz + classes) -> (batch, 1, 1, nz + classes)
        joined = jnp.concatenate([z, c], axis=1)
        h = joined.reshape(-1, 1, 1, z.shape[1] + c.shape[1])

        stages = (
            (self.convt1, self.bn1),
            (self.convt2, self.bn2),
            (self.convt3, self.bn3),
            (self.convt4, self.bn4),
        )
        for convt, bn in stages:
            h = nnx.relu(bn(convt(h), use_running_average=use_running_average))
        return nnx.tanh(self.convt5(h))

In [None]:
class Discriminator(nnx.Module):
    """Class-conditional discriminator built from five convolutions.

    The conditioning vector is broadcast over the image's spatial axes and
    concatenated to the image channels. The first convolution has no batch
    norm; the next three are Conv -> BatchNorm -> LeakyReLU(0.2); the last
    convolution produces the logits, which are flattened.

    Args:
        nc: Number of image channels.
        ndf: Base feature width (stages use ``1x, 2x, 4x, 8x`` this value).
        num_classes: Length of the class-conditioning vector.
        rngs: RNG container used to initialise the layers.
    """

    def __init__(self, nc, ndf, num_classes, rngs: nnx.Rngs):
        normal_init = nnx.initializers.normal(0.02)

        def conv(features_in, features_out, strides=(2, 2), padding='SAME'):
            # All convolutions share kernel size, bias and init settings.
            return nnx.Conv(
                features_in, features_out, kernel_size=(4, 4), strides=strides,
                padding=padding, use_bias=False, rngs=rngs, kernel_init=normal_init,
            )

        # Layers are created in forward order so RNG consumption is unchanged.
        self.conv1 = conv(nc + num_classes, ndf)
        self.conv2 = conv(ndf, ndf * 2)
        self.bn2 = nnx.BatchNorm(ndf * 2, rngs=rngs)
        self.conv3 = conv(ndf * 2, ndf * 4)
        self.bn3 = nnx.BatchNorm(ndf * 4, rngs=rngs)
        self.conv4 = conv(ndf * 4, ndf * 8)
        self.bn4 = nnx.BatchNorm(ndf * 8, rngs=rngs)
        self.conv5 = conv(ndf * 8, 1, strides=(1, 1), padding='VALID')

    def __call__(self, x, c, train: bool = True, use_running_average: bool = None):
        """Return flattened real/fake logits for images ``x`` conditioned on ``c``.

        ``use_running_average`` defaults to ``not train`` when left as ``None``.
        """
        if use_running_average is None:
            use_running_average = not train

        # Repeat the conditioning vector at every spatial position of x.
        tiled_shape = (x.shape[0], x.shape[1], x.shape[2], c.shape[1])
        c_spatial = jnp.broadcast_to(c[:, None, None, :], tiled_shape)
        h = jnp.concatenate([x, c_spatial], axis=-1)

        h = nnx.leaky_relu(self.conv1(h), negative_slope=0.2)
        for conv, bn in ((self.conv2, self.bn2), (self.conv3, self.bn3), (self.conv4, self.bn4)):
            h = nnx.leaky_relu(bn(conv(h), use_running_average=use_running_average), negative_slope=0.2)
        return self.conv5(h).flatten()

## 6. Logika Pelatihan

In [None]:
class CGAN(nnx.Module):
    """Container holding the generator (``netG``) and discriminator (``netD``).

    Both networks are initialised from the same ``rngs`` container, generator
    first.
    """

    def __init__(self, nz, ngf, nc, ndf, num_classes, rngs: nnx.Rngs):
        generator = Generator(nz, ngf, nc, num_classes, rngs)
        discriminator = Discriminator(nc, ndf, num_classes, rngs)
        self.netG = generator
        self.netD = discriminator

# One RNG container (seed 0) initialises both networks.
rngs = nnx.Rngs(0)
model = CGAN(NZ, NGF, NC, NDF, NUM_CLASSES, rngs=rngs)
# Separate Adam optimizers so each network is stepped independently;
# both are restricted to nnx.Param variables.
optimizerG = nnx.Optimizer(model.netG, optax.adam(LR, b1=BETA1), wrt=nnx.Param)
optimizerD = nnx.Optimizer(model.netD, optax.adam(LR, b1=BETA1), wrt=nnx.Param)

def loss_bce(logits, labels):
    """Return the mean sigmoid binary cross-entropy of ``logits`` against ``labels``."""
    per_element = optax.sigmoid_binary_cross_entropy(logits, labels)
    return jnp.mean(per_element)

@nnx.jit
def train_step_D(model, optimizerD, real_x, c_vec, noise):
    """Run one discriminator update.

    Args:
        model: The CGAN holding ``netG`` and ``netD``.
        optimizerD: Optimizer applied to ``model.netD``.
        real_x: Batch of real images.
        c_vec: Conditioning vectors used for both real and generated images.
        noise: Noise batch fed to the generator.

    Returns:
        ``(loss, mean D(real), mean D(fake))`` where the last two are mean
        sigmoid probabilities.
    """
    # Fakes are produced outside the differentiated function and detached, so
    # no gradient reaches the generator through them.
    generated = model.netG(noise, c_vec, train=True)
    fake_x = jax.lax.stop_gradient(generated)

    def loss_fn(model):
        real_logits = model.netD(real_x, c_vec, train=True)
        fake_logits = model.netD(fake_x, c_vec, train=True)
        real_term = loss_bce(real_logits, jnp.ones_like(real_logits))
        fake_term = loss_bce(fake_logits, jnp.zeros_like(fake_logits))
        probabilities = (nnx.sigmoid(real_logits), nnx.sigmoid(fake_logits))
        return real_term + fake_term, probabilities

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, (real_p, fake_p)), grads = grad_fn(model)
    # Only the discriminator's slice of the gradients is applied.
    optimizerD.update(model.netD, grads.netD)
    return loss, jnp.mean(real_p), jnp.mean(fake_p)

@nnx.jit
def train_step_G(model, optimizerG, c_vec, noise):
    """Run one generator update.

    The generator is trained to make the discriminator label its outputs as
    real (target 1).

    Args:
        model: The CGAN holding ``netG`` and ``netD``.
        optimizerG: Optimizer applied to ``model.netG``.
        c_vec: Conditioning vectors for the generated images.
        noise: Noise batch fed to the generator.

    Returns:
        ``(loss, mean D(G(z)))`` where the second value is a mean sigmoid
        probability.
    """
    def loss_fn(model):
        generated = model.netG(noise, c_vec, train=True)
        fake_logits = model.netD(generated, c_vec, train=True)
        targets = jnp.ones_like(fake_logits)
        return loss_bce(fake_logits, targets), nnx.sigmoid(fake_logits)

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, outD), grads = grad_fn(model)
    # Only the generator's slice of the gradients is applied.
    optimizerG.update(model.netG, grads.netG)
    return loss, jnp.mean(outD)

## 7. Loop Pelatihan (Main Loop)

In [None]:
print("Starting Training Loop...")
# Key that is split every step to draw fresh noise for the D and G updates.
step_rng = jax.random.PRNGKey(0)
# Fixed noise and labels (classes cycled 0..NUM_CLASSES-1) so the sample grids
# of different epochs are generated from identical inputs.
fixed_latent = jax.random.normal(jax.random.PRNGKey(42), (NVIZ, NZ))
fixed_y = jnp.array([i % NUM_CLASSES for i in range(NVIZ)])
fixed_cvec = jax.nn.one_hot(fixed_y, NUM_CLASSES)

for epoch in range(NUM_EPOCH):
    with tqdm(train_loader, unit="batch", desc=f"Epoch {epoch+1}") as tepoch:
        for batch_idx, (real_x, y) in enumerate(tepoch):
            c_vec = jax.nn.one_hot(y, NUM_CLASSES)
            step_rng, rng_d, rng_g = jax.random.split(step_rng, 3)
            # Discriminator step first, then generator step, each with its own noise.
            # D_G_z1 is returned by the D step but not used afterwards.
            errD, D_x, D_G_z1 = train_step_D(model, optimizerD, real_x, c_vec, jax.random.normal(rng_d, (real_x.shape[0], NZ)))
            errG, D_G_z2 = train_step_G(model, optimizerG, c_vec, jax.random.normal(rng_g, (real_x.shape[0], NZ)))
            # Refresh the progress-bar metrics every 10th batch only.
            if batch_idx % 10 == 0: tepoch.set_postfix(Loss_D=f"{errD:.4f}", Loss_G=f"{errG:.4f}", Dx=f"{D_x:.4f}", Dgz=f"{D_G_z2:.4f}")

    # Render the fixed samples (generator in eval mode), save the figure
    # locally and upload it to GCS.
    fake_samples = model.netG(fixed_latent, fixed_cvec, train=False)
    grid = set_grid(fake_samples, num_cells=NVIZ)
    plt.figure(figsize=(8, 8))
    # Min-max rescale the grid to [0, 1] and move channels last for imshow.
    plt.imshow(np.transpose(np.array(normalize(grid, 0, 1)), (1, 2, 0)))
    plt.axis('off')
    
    sample_name = f'samples_epoch_{epoch+1}.png'
    sample_path = os.path.join(SAMPLE_DIR, sample_name)
    plt.savefig(sample_path)
    upload_to_gcs(sample_path, f"samples/{sample_name}")
    plt.show()

    # Checkpoint both networks locally and to GCS after every epoch.
    save_checkpoint(model.netG, epoch + 1, filedir="generator", gcs_prefix="models")
    save_checkpoint(model.netD, epoch + 1, filedir="discriminator", gcs_prefix="models")