In [53]:
import jax
import jax.numpy as jnp
from jax import random, lax
import functools

Helper: GELU, LayerNorm, conv2d & depthwise conv

In [54]:
def gelu(x):
    return 0.5 * x * (1.0 + jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (x + 0.044715 * x**3)))


def layer_norm(x, gamma, beta, eps=1e-5):
    """
    x: (..., C)  (channels-last)
    gamma, beta: (C,)
    """
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    x_hat = (x - mean) / jnp.sqrt(var + eps)
    return gamma * x_hat + beta


def conv2d(x, w, stride=(1, 1), padding="SAME", feature_group_count=1):
    """
    x: (N, H, W, C_in)
    w: (KH, KW, C_in, C_out)
    """
    return lax.conv_general_dilated(
        lhs=x,
        rhs=w,
        window_strides=stride,
        padding=padding,
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
        feature_group_count=feature_group_count,
    )


def depthwise_conv2d(x, w, stride=(1, 1), padding="SAME"):
    """
    Depthwise conv:
    x: (N, H, W, C)
    w: (KH, KW, 1, C)
    feature_group_count = C
    """
    C = x.shape[-1]
    return conv2d(x, w, stride=stride, padding=padding, feature_group_count=C)


**Penjelasan:**

* `def gelu(x):`
  Mendefinisikan fungsi `gelu` yang menerima tensor `x` (bentuk arbitrer, biasanya `(N, ..., C)`).

* `return 0.5 * x * (1.0 + jnp.tanh(...))`
  Ini implementasi aproksimasi GELU yang populer (Hendrycks & Gimpel).
  Di dalamnya:

  * `x**3` → pangkat tiga tiap elemen `x`.
  * `0.044715 * x**3` → konstanta empiris untuk akurasi aproksimasi.
  * `(x + 0.044715 * x**3)` → argumen untuk fungsi `tanh`.
  * `jnp.sqrt(2.0 / jnp.pi)` → konstanta (\sqrt{2/\pi}).
  * `jnp.tanh(jnp.sqrt(2.0 / jnp.pi) * (...))` → tanh dari argumen tadi.
  * `1.0 + tanh(...)` → shift supaya dalam range kira-kira [0, 2].
  * `0.5 * x * (1 + tanh(...))` → skema akhir GELU aproksimasi.


**LayerNorm:**

* `def layer_norm(x, gamma, beta, eps=1e-5):`
  Mendefinisikan LayerNorm generik:

  * `x`: tensor dengan channel di dimensi terakhir, misalnya `(N, H, W, C)` atau `(N, C)`.
  * `gamma`, `beta`: parameter skala dan bias per channel, shape `(C,)`.
  * `eps`: konstanta kecil untuk menghindari pembagian nol.

* `""" x: (..., C) ... """`
  Docstring: menjelaskan bahwa `x` punya dimensi terakhir = channels.

* `mean = jnp.mean(x, axis=-1, keepdims=True)`

  * Menghitung rata-rata (`mean`) sepanjang dimensi channel (dimensi terakhir).
  * `axis=-1` → pakai dimensi terakhir.
  * `keepdims=True` → supaya hasilnya bisa di-*broadcast* kembali ke `x` (shape jadi `(..., 1)`).

* `var = jnp.var(x, axis=-1, keepdims=True)`

  * Menghitung varians per posisi (untuk LN) sepanjang channel.

* `x_hat = (x - mean) / jnp.sqrt(var + eps)`

  * Normalisasi: kurangi mean dan bagi akar varians + eps.
  * `x_hat` sekarang punya mean 0 dan var 1 di setiap posisi (per sample/pixel).

* `return gamma * x_hat + beta`

  * `gamma` dan `beta` shape `(C,)`, akan otomatis *broadcast* ke `(..., C)`.
  * Menghasilkan output LN yang sudah distandarisasi dan di-skala/bias.



**Conv2D:**

* `def conv2d(x, w, stride=(1, 1), padding="SAME", feature_group_count=1):`

  * Wrapper sederhana untuk konvolusi 2D JAX.
  * `x`: input tensor bentuk `(N, H, W, C_in)` (NHWC).
  * `w`: weight kernel `(KH, KW, C_in, C_out)`.
  * `stride`: tuple `(stride_h, stride_w)`.
  * `padding`: `"SAME"` atau `"VALID"`.
  * `feature_group_count`: untuk group conv (dipakai di depthwise).

* Docstring menjelaskan shape `x` dan `w`.

* `return lax.conv_general_dilated(...`

  * `lhs=x`: input (left-hand side).
  * `rhs=w`: filter/kernel.
  * `window_strides=stride`: stride.
  * `padding=padding`: scheme padding (string).
  * `dimension_numbers=("NHWC", "HWIO", "NHWC")`:

    * Format input: `"NHWC"`
    * Format kernel: `"HWIO"` (Height, Width, InputChan, OutputChan)
    * Format output: `"NHWC"`.
  * `feature_group_count=feature_group_count`:

    * Jika 1 → conv standard.
    * Jika `C_in` → depthwise (tiap channel di-*group* sendiri).


**Penjelasan:**

* `def depthwise_conv2d(x, w, stride=(1, 1), padding="SAME"):`
  Fungsi khusus untuk depthwise convolution:

* Docstring:

  * `x` shape `(N, H, W, C)`.
  * `w` shape `(KH, KW, C, 1)` → satu filter per channel input, menghasilkan 1 channel output per channel → total tetap `C`.

* `C = x.shape[-1]`
  Mengambil jumlah channel dari dimensi terakhir tensor `x`.

* `return conv2d(... feature_group_count=C)`

  * Menggunakan `conv2d` umum, tapi `feature_group_count=C`.
  * Ini berarti setiap channel diproses oleh grup terpisah → **depthwise convolution**.



Init weight Simplified

In [55]:
def init_conv_weight(key, ksize, in_ch, out_ch):
    kh, kw = ksize
    fan_in = kh * kw * in_ch
    std = 1.0 / jnp.sqrt(fan_in)
    w = std * random.normal(key, (kh, kw, in_ch, out_ch))
    return w


def init_depthwise_weight(key, ksize, channels):
    kh, kw = ksize
    fan_in = kh * kw  # per-channel, in_channels_per_group = 1
    std = 1.0 / jnp.sqrt(fan_in)
    # (KH, KW, in_channels_per_group=1, out_channels_per_group=1) for each of C groups
    # Tapi karena feature_group_count = C, total out_channels = C * 1 = C
    w = std * random.normal(key, (kh, kw, 1, channels))
    return w



def init_dense_weight(key, in_dim, out_dim):
    fan_in = in_dim
    std = 1.0 / jnp.sqrt(fan_in)
    w = std * random.normal(key, (in_dim, out_dim))
    b = jnp.zeros((out_dim,))
    return w, b


def init_layer_norm_params(dim):
    gamma = jnp.ones((dim,))
    beta = jnp.zeros((dim,))
    return {"gamma": gamma, "beta": beta}

**Penjelasan:**

* `def init_conv_weight(key, ksize, in_ch, out_ch):`

  * Fungsi untuk inisialisasi kernel conv biasa.
  * `key`: PRNGKey JAX untuk random.
  * `ksize`: tuple `(kh, kw)`.
  * `in_ch`, `out_ch`: jumlah channel.

* `kh, kw = ksize`
  Pecah ukuran kernel menjadi tinggi dan lebar.

* `fan_in = kh * kw * in_ch`

  * Fan-in = banyaknya input per neuron (per output channel per posisi).

* `std = 1.0 / jnp.sqrt(fan_in)`

  * Standar deviasi untuk inisialisasi (mirip Xavier/He sederhana).

* `w = std * random.normal(key, (kh, kw, in_ch, out_ch))`

  * Generate tensor normal N(0,1) lalu skala dengan `std`.
  * Shape sesuai kernel conv: `(KH, KW, C_in, C_out)`.

* `return w`

  * Mengembalikan weight conv.

* Sama seperti `init_conv_weight`, tapi khusus depthwise:

  * `channels`: jumlah channel input.
  * `fan_in = kh * kw * 1` karena per channel bekerja sendiri.
  * Shape output: `(KH, KW, C, 1)` sesuai depthwise conv.

---

**Penjelasan:**

* `def init_dense_weight(key, in_dim, out_dim):`

  * Inisialisasi matrix weight dan bias untuk layer dense/linear.

* `fan_in = in_dim` → jumlah input per neuron.

* `std = 1.0 / jnp.sqrt(fan_in)` → scaling.

* `w = std * random.normal(key, (in_dim, out_dim))`

  * Weight matrix, shape `(in_dim, out_dim)`.

* `b = jnp.zeros((out_dim,))`

  * Bias vector, awalnya semua nol.

* `return w, b`

  * Mengembalikan pasangan weight & bias.




ConvNeXt Block (DW 7×7, LN, MLP 4×, LayerScale, residual)

In [56]:
def init_convnext_block_params(key, dim, mlp_ratio=4, layer_scale_init=1e-6):
    """
    ConvNeXt block:
    - DWConv 7x7
    - LayerNorm (channels-last)
    - MLP: dim -> 4*dim -> dim
    - LayerScale (gamma)
    - Residual
    """
    k1, k2, k3 = random.split(key, 3)
    # depthwise conv
    dw_w = init_depthwise_weight(k1, (7, 7), dim)

    # layernorm params
    ln = init_layer_norm_params(dim)

    # MLP
    hidden_dim = dim * mlp_ratio
    w1, b1 = init_dense_weight(k2, dim, hidden_dim)
    w2, b2 = init_dense_weight(k3, hidden_dim, dim)

    # layer scale (per-channel)
    layer_scale = jnp.ones((dim,)) * layer_scale_init

    return {
        "dw_conv": {"w": dw_w},
        "ln": ln,
        "mlp": {
            "w1": w1,
            "b1": b1,
            "w2": w2,
            "b2": b2,
        },
        "layer_scale": layer_scale,
    }

Forward block

In [57]:
def convnext_block_forward(params, x):
    """
    x: (N, H, W, C)
    """
    shortcut = x
    dim = x.shape[-1]

    # 1) Depthwise conv 7x7
    w_dw = params["dw_conv"]["w"]
    x = depthwise_conv2d(x, w_dw, stride=(1, 1), padding="SAME")

    # 2) LayerNorm (channels-last)
    ln_params = params["ln"]
    x = layer_norm(x, ln_params["gamma"], ln_params["beta"])

    # 3) MLP: (N,H,W,C) -> (N*H*W, C)
    N, H, W, C = x.shape
    x_flat = x.reshape(N * H * W, C)

    mlp = params["mlp"]
    x_flat = jnp.dot(x_flat, mlp["w1"]) + mlp["b1"]
    x_flat = gelu(x_flat)
    x_flat = jnp.dot(x_flat, mlp["w2"]) + mlp["b2"]

    x = x_flat.reshape(N, H, W, C)

    # 4) LayerScale + residual
    gamma = params["layer_scale"]  # (C,)
    x = shortcut + gamma * x

    return x

**Forward:**

* `def convnext_block_forward(params, x):`

  * Fungsi forward untuk satu block.
  * `params`: dict dari `init_convnext_block_params`.
  * `x`: feature map `(N, H, W, C)`.

* Docstring menjelaskan shape `x`.

* `shortcut = x`

  * Menyimpan input original untuk jalur residual.

* `dim = x.shape[-1]`

  * Mengambil jumlah channel C (meski tidak dipakai eksplisit, cuma referensi).

---

**1) Depthwise Conv**

* `w_dw = params["dw_conv"]["w"]`

  * Ambil kernel depthwise dari dict params.

* `x = depthwise_conv2d(x, w_dw, stride=(1, 1), padding="SAME")`

  * Terapkan depthwise conv 7×7 dengan stride 1, padding “SAME”.
  * Setelah ini, shape `x` tetap `(N, H, W, C)`.

---

**2) LayerNorm**

* `ln_params = params["ln"]`

  * Ambil param LN `{gamma, beta}`.

* `x = layer_norm(x, ln_params["gamma"], ln_params["beta"])`

  * Terapkan LayerNorm di dimensi channel.

---

**3) MLP (per posisi pixel)**

* `N, H, W, C = x.shape`

  * Unpack dim, supaya bisa reshape.

* `x_flat = x.reshape(N * H * W, C)`

  * Flatten semua posisi `(H,W)` dan batch `N` ke satu dimensi besar.
  * Kita akan menganggap setiap posisi `(n,h,w)` sebagai satu vektor `C` dan melewati MLP sama untuk semua posisi.

* `mlp = params["mlp"]`

  * Shortcut ke dict MLP `{w1, b1, w2, b2}`.

* `x_flat = jnp.dot(x_flat, mlp["w1"]) + mlp["b1"]`

  * Linear pertama: `(N*H*W, C) · (C, 4C) → (N*H*W, 4C)`.

* `x_flat = gelu(x_flat)`

  * Aktivasi GELU di hidden.

* `x_flat = jnp.dot(x_flat, mlp["w2"]) + mlp["b2"]`

  * Linear kedua: `(N*H*W, 4C) · (4C, C) → (N*H*W, C)`.

* `x = x_flat.reshape(N, H, W, C)`

  * Kembalikan ke bentuk spasial semula.

---

**4) LayerScale + Residual**

* `gamma = params["layer_scale"]  # (C,)`

  * Ambil vektor gamma per channel.

* `x = shortcut + gamma * x`

  * `gamma * x` → meng-scale output block per channel.
  * Tambah `shortcut` → residual connection.

* `return x`

  * Output block: `(N, H, W, C)` dengan fitur yang sudah diproses.


Downsample layer & stem & head

In [58]:
def init_downsample_layer_params(key, in_dim, out_dim):
    k1, k2 = random.split(key)
    ln = init_layer_norm_params(in_dim)
    w = init_conv_weight(k2, (2, 2), in_dim, out_dim)
    b = jnp.zeros((out_dim,))
    return {
        "ln": ln,
        "conv": {"w": w, "b": b},
    }


def downsample_forward(params, x):
    """
    x: (N, H, W, C_in)
    """
    ln = params["ln"]
    x = layer_norm(x, ln["gamma"], ln["beta"])
    w = params["conv"]["w"]
    b = params["conv"]["b"]
    x = conv2d(x, w, stride=(2, 2), padding="SAME")
    x = x + b  # broadcast
    return x


**Downsample:**

* `def init_downsample_layer_params(key, in_dim, out_dim):`

  * Inisialisasi layer downsample antara dua stage.
  * `in_dim`: C_in.
  * `out_dim`: C_out.

* `k1, k2 = random.split(key)`

  * `k1` → LN (meski LN tidak random).
  * `k2` → conv 2×2.

* `ln = init_layer_norm_params(in_dim)`

  * LN untuk channel `in_dim`.

* `w = init_conv_weight(k2, (2, 2), in_dim, out_dim)`

  * Kernel conv 2×2 yang mengubah channel `in_dim → out_dim`.

* `b = jnp.zeros((out_dim,))`

  * Bias conv layer.

* `return { "ln": ln, "conv": {"w": w, "b": b} }`

  * Struktur params downsample: LN + conv.



In [59]:
def init_stem_params(key, in_ch=3, dim=96):
    w = init_conv_weight(key, (4, 4), in_ch, dim)
    b = jnp.zeros((dim,))
    return {"w": w, "b": b}


def stem_forward(params, x):
    """
    x: (N, H, W, 3), pixel [0,1] or [0,255] (silakan normalisasi sendiri)
    """
    w = params["w"]
    b = params["b"]
    x = conv2d(x, w, stride=(4, 4), padding="SAME")
    x = x + b
    return x


global avg pool + LN + FC ke num_classes

In [60]:
def init_head_params(key, dim, num_classes):
    k1, k2 = random.split(key)
    ln = init_layer_norm_params(dim)
    w, b = init_dense_weight(k2, dim, num_classes)
    return {"ln": ln, "fc": {"w": w, "b": b}}


def head_forward(params, x):
    """
    x: (N, H, W, C)
    """
    # Global average pooling
    x = jnp.mean(x, axis=(1, 2))  # (N, C)

    # LN over channels
    ln = params["ln"]
    x = layer_norm(x, ln["gamma"], ln["beta"])

    # Linear classifier
    w = params["fc"]["w"]
    b = params["fc"]["b"]
    logits = jnp.dot(x, w) + b
    return logits


Full ConvNeXt-Tiny params & forward
Konfigurasi ConvNeXt-Tiny

Depths: [3, 3, 9, 3]

Dims: [96, 192, 384, 768]

In [61]:
CONVNEXT_TINY_DEPTHS = [3, 3, 9, 3]
CONVNEXT_TINY_DIMS   = [96, 192, 384, 768]

In [62]:
def init_convnext_tiny_params(key, num_classes=1000, in_ch=3):
    """
    Return: nested dict: {'stem', 'stages', 'head'}
    """
    depths = CONVNEXT_TINY_DEPTHS
    dims = CONVNEXT_TINY_DIMS

    keys = random.split(key, 2 + sum(depths) + (len(dims) - 1) + 1)
    k_idx = 0

    # Stem
    stem_key = keys[k_idx]; k_idx += 1
    stem = init_stem_params(stem_key, in_ch=in_ch, dim=dims[0])

    stages = []
    block_keys_needed = sum(depths)
    block_keys = keys[k_idx: k_idx + block_keys_needed]
    k_idx += block_keys_needed

    # Downsample layers (between stages)
    downsample_keys = keys[k_idx: k_idx + (len(dims) - 1)]
    k_idx += (len(dims) - 1)

    # Build stages
    bk = 0
    for stage_idx, (depth, dim) in enumerate(zip(depths, dims)):
        # Blocks in this stage
        blocks = []
        for _ in range(depth):
            blocks.append(
                init_convnext_block_params(block_keys[bk], dim=dim, mlp_ratio=4, layer_scale_init=1e-6)
            )
            bk += 1

        # Downsample from previous stage to this stage (except first)
        if stage_idx == 0:
            downsample = None
        else:
            in_dim = dims[stage_idx - 1]
            out_dim = dim
            ds_params = init_downsample_layer_params(downsample_keys[stage_idx - 1], in_dim, out_dim)
            downsample = ds_params

        stages.append(
            {
                "blocks": blocks,
                "downsample": downsample,
            }
        )

    # Head
    head_key = keys[k_idx]; k_idx += 1
    head = init_head_params(head_key, dim=dims[-1], num_classes=num_classes)

    return {
        "stem": stem,
        "stages": stages,
        "head": head,
    }


Forward ConvNeXt-Tiny

In [63]:
def convnext_tiny_forward(params, x, train=False):
    """
    x: (N, H, W, 3)
    return: logits (N, num_classes)
    train flag currently unused (no dropout/stochastic depth),
           tapi kamu bisa extend nanti.
    """
    # Stem
    x = stem_forward(params["stem"], x)

    # Stages
    stages = params["stages"]
    for i, stage in enumerate(stages):
        if stage["downsample"] is not None:
            x = downsample_forward(stage["downsample"], x)

        for block_params in stage["blocks"]:
            x = convnext_block_forward(block_params, x)

    # Head
    logits = head_forward(params["head"], x)
    return logits


Training dengan autograd (AdamW manual)

In [64]:
def cross_entropy_loss(logits, labels):
    """
    logits: (N, num_classes)
    labels: (N,) int32
    """
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    nll = -log_probs[jnp.arange(labels.shape[0]), labels]
    return jnp.mean(nll)


In [65]:
def loss_fn(params, x, y):
    logits = convnext_tiny_forward(params, x, train=True)
    return cross_entropy_loss(logits, y)


AdamW

In [66]:
def init_adamw_state(params, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.05):
    m = jax.tree.map(jnp.zeros_like, params)
    v = jax.tree.map(jnp.zeros_like, params)
    return {
        "m": m,
        "v": v,
        "t": jnp.array(0, dtype=jnp.int32),
        "lr": lr,
        "beta1": beta1,
        "beta2": beta2,
        "eps": eps,
        "weight_decay": weight_decay,
    }


def adamw_update(params, grads, opt_state):
    beta1 = opt_state["beta1"]
    beta2 = opt_state["beta2"]
    eps = opt_state["eps"]
    lr = opt_state["lr"]
    wd = opt_state["weight_decay"]

    t = opt_state["t"] + 1

    m = jax.tree.map(lambda m, g: beta1 * m + (1 - beta1) * g, opt_state["m"], grads)
    v = jax.tree.map(lambda v, g: beta2 * v + (1 - beta2) * (g * g), opt_state["v"], grads)

    m_hat = jax.tree.map(lambda m: m / (1 - beta1**t), m)
    v_hat = jax.tree.map(lambda v: v / (1 - beta2**t), v)

    def update_param(p, mh, vh):
        # Adam step + decoupled weight decay
        adam_step = mh / (jnp.sqrt(vh) + eps)
        return p - lr * (adam_step + wd * p)

    new_params = jax.tree.map(update_param, params, m_hat, v_hat)

    new_opt_state = {
        "m": m,
        "v": v,
        "t": t,
        "lr": lr,
        "beta1": beta1,
        "beta2": beta2,
        "eps": eps,
        "weight_decay": wd,
    }
    return new_params, new_opt_state


Train step (jit + autograd)

In [67]:
@jax.jit
def train_step(params, opt_state, x, y):
    """
    x: (N, H, W, 3) float32
    y: (N,) int32 labels
    """
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    new_params, new_opt_state = adamw_update(params, grads, opt_state)
    return new_params, new_opt_state, loss

Inference

In [68]:
@jax.jit
def predict(params, x):
    """
    x: (N, H, W, 3)
    return: predicted class (N,)
    """
    logits = convnext_tiny_forward(params, x, train=False)
    return jnp.argmax(logits, axis=-1)


Ex:

In [52]:
key = random.PRNGKey(42)

# 1) Init params ConvNeXt-Tiny untuk, misalnya, CIFAR-10 (10 kelas)
num_classes = 10
params = init_convnext_tiny_params(key, num_classes=num_classes, in_ch=3)

# 2) Init optimizer state
opt_state = init_adamw_state(params, lr=1e-3, weight_decay=0.05)

# 3) Dummy batch (misal sudah di-resize ke 224x224 dan dinormalisasi)
batch_size = 4
x_dummy = random.normal(key, (batch_size, 224, 224, 3))  # contoh
y_dummy = jnp.array([0, 1, 2, 3], dtype=jnp.int32)

# 4) Satu step training
params, opt_state, loss_val = train_step(params, opt_state, x_dummy, y_dummy)
print("loss:", float(loss_val))

# 5) Inference
preds = predict(params, x_dummy)
print("preds:", preds)

loss: 2.5652852058410645
preds: [0 1 2 3]
