In [1]:
import jax, jax.numpy as jp
from flax.struct import dataclass, field, PyTreeNode
from functools import partial

nonpytree_node = partial(field, pytree_node=False)

In [2]:
@partial(dataclass, frozen=False)
class DT:
    a: int = nonpytree_node()
    b: jp.ndarray = field(default=None)
    
    @classmethod
    def create(cls, a: int):
        instance = cls(a=a)
        instance.b = instance.make_b()
        return instance
        
    def make_b(self):
        return jp.full((3, 3), self.a)


class DT2(DT):
    def make_b(self):
        return jp.full((6, 6), self.a)
    
tmp = DT.create(3)
tmp2 = DT2.create(3)
print(tmp)
print(tmp2)

2024-08-09 19:28:07.503737: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.6.20). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


DT(a=3, b=Array([[3, 3, 3],
       [3, 3, 3],
       [3, 3, 3]], dtype=int32, weak_type=True))
DT2(a=3, b=Array([[3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3]], dtype=int32, weak_type=True))


In [15]:
tmp2.a = 4
tmp2

DT2(a=4, b=Array([[3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3]], dtype=int32, weak_type=True))

In [16]:
class PN(PyTreeNode):
    a: int = nonpytree_node()
    b: jp.ndarray = field(default=None)
    
    @classmethod
    def create(cls, a: int):
        instance = cls(a=a)
        instance.b = instance.make_b()
        return instance
    
    def make_b(self):
        return jp.full((3, 3), self.a)
    
    
class PN2(PN):
    def make_b(self):
        return jp.full((6, 6), self.a)
    
tmp = PN.create(3)
tmp2 = PN2.create(3)
print(tmp)
print(tmp2)

FrozenInstanceError: cannot assign to field 'b'

In [18]:
class PN(PyTreeNode):
    a: int = nonpytree_node()
    b: jp.ndarray = field(default=None)
    
    @classmethod
    def create(cls, a: int):
        instance = cls(a=a)
        instance = instance.replace(b=instance.make_b())
        return instance
    
    def make_b(self):
        return jp.full((3, 3), self.a)
    
    
class PN2(PN):
    def make_b(self):
        return jp.full((6, 6), self.a)
    
tmp = PN.create(3)
tmp2 = PN2.create(3)
print(tmp)
print(tmp2)

PN(a=3, b=Array([[3, 3, 3],
       [3, 3, 3],
       [3, 3, 3]], dtype=int32, weak_type=True))
PN2(a=3, b=Array([[3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3],
       [3, 3, 3, 3, 3, 3]], dtype=int32, weak_type=True))


In [3]:
from pathlib import Path

import optax
import flax.linen as nn
from flax.training.train_state import TrainState

In [4]:
from dataclasses import dataclass as py_dataclass

@py_dataclass
class Config:
    lr: float = 0.01
    momentum: float = 0.9
    weight_decay: float = 0.01
    num_classes: int = 10
    num_epochs: int = 10

In [45]:
from torchvision.datasets import MNIST
from torchvision import transforms as T
from torch.utils.data import DataLoader

trns = T.Compose([
    T.RandomRotation(20),
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
    T.Lambda(lambda x: x.permute(1, 2, 0))
])

test_trns = T.Compose([
    T.ToTensor(),
    T.Normalize((0.1307,), (0.3081,)),
    T.Lambda(lambda x: x.permute(1, 2, 0))
])

train_set = MNIST(root=Path.home() / 'Datasets', train=True, download=True, transform=trns)
test_set = MNIST(root=Path.home() / 'Datasets', train=False, download=True, transform=test_trns)

In [46]:
import numpy as np

def numpy_collate(batch):
    if isinstance(batch, np.ndarray):
        return batch
    elif isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple, list)):
        # batchify the list of single data pairs (e.g. [(x1, y1), (x2, y2), ...])
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)
    
train_loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=numpy_collate)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False, collate_fn=numpy_collate)

@py_dataclass
class DataModule:
    train_loader: DataLoader
    test_loader: DataLoader

    def train_dataloader(self):
        return self.train_loader

    def test_dataloader(self):
        return self.test_loader

In [47]:
out = next(iter(train_loader))
out[0].shape

(32, 28, 28, 1)

In [48]:
class Network(nn.Module):
    features: int = 64
    num_classes: int = 10
    
    @nn.compact
    def __call__(self, x: jp.ndarray, train: bool = True):
        x = nn.Conv(self.features, (3, 3), (1, 1), padding='SAME')(x)
        x = nn.gelu(x)
        x = nn.Conv(self.features * 2, (3, 3), (1, 1), padding='SAME')(x)
        x = nn.gelu(x)
        x = nn.avg_pool(x, (2, 2), (2, 2))
        x = nn.Conv(self.features * 4, (3, 3), (1, 1), padding='SAME')(x)
        x = nn.gelu(x)
        x = nn.avg_pool(x, (2, 2), (2, 2))
        x = nn.Conv(self.features * 8, (3, 3), (1, 1), padding='SAME')(x)
        x = nn.gelu(x)
        x = nn.avg_pool(x, (2, 2), (2, 2))
        
        x = x.reshape((x.shape[0], -1))
        x = nn.Dropout(0.1)(x, deterministic=not train)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dropout(0.1)(x, deterministic=not train)
        x = nn.Dense(self.num_classes)(x)
        return x

In [49]:
import time

@partial(dataclass, frozen=False)
class TrainerDT:
    config: Config = nonpytree_node()
    model_def: nn.Module = nonpytree_node()
    data_module: DataModule = nonpytree_node()
    sample_input: jp.ndarray = nonpytree_node()
    state: TrainState = field(default=None)
    step: int = nonpytree_node(default=0)
    
    train_step: callable = nonpytree_node(default=None)
    eval_step: callable = nonpytree_node(default=None)
    
    @classmethod
    def create(cls, config, model_def, data_module, sample_input):
        trainer = cls(
            config=config,
            model_def=model_def,
            data_module=data_module,
            sample_input=sample_input,
            step=0
        )
        
        trainer.state = trainer.make_state()
        trainer.train_step, trainer.eval_step = trainer.make_jit_fn()
        return trainer
    
    def get_model_rngs(self, rng, train=True):
        if train:
            rng = jax.random.split(rng, 2)
            return {"params": rng[0], "dropout": rng[1]}
        else:
            return {"params": rng}
    
    def batch_to_input(self, batch):
        return batch[0]
    
    def init(self, rng):
        rngs = self.get_model_rngs(rng)
        sample_input = self.batch_to_input(self.sample_input)
        variables = self.model_def.init(rngs, sample_input, train=True)
        return variables
    
    def make_state(self):
        params = self.init(rng=jax.random.PRNGKey(0))
        tx = optax.chain(
            optax.clip_by_global_norm(1.0),
            optax.scale_by_adam(b1=self.config.momentum),
            optax.add_decayed_weights(self.config.weight_decay),
            optax.scale_by_schedule(optax.constant_schedule(self.config.lr))
        )
        
        state = TrainState.create(apply_fn=self.model_def.apply, params=params, tx=tx)
        return state
    
    def make_train_step(self):
        def train_step(state, batch, rng):
            rng, step_rng = jax.random.split(rng)
            loss_fn = partial(self.loss_fn, model=self.model_def, batch=batch)
            loss, grad = jax.value_and_grad(loss_fn)(state.params)
            state = state.apply_gradients(grads=grad)
            return state, {"loss": loss, "rng": rng}
        return train_step
    
    def make_eval_step(self):
        def eval_step(state, batch):
            x, y = self.batch_to_input(batch)
            logits = self.model_def.apply(state.params, x, train=False)
            acc = jp.mean(jp.argmax(logits, axis=-1) == y)
            return acc
        return eval_step
    
    def make_jit_fn(self):
        train_step = jax.jit(self.make_train_step())
        eval_step = jax.jit(self.make_eval_step())
        return train_step, eval_step
    
    def train_model(self):
        num_epochs = self.config.num_epochs
        mean_step_time = 0
        rng = jax.random.PRNGKey(0)
        for epoch in range(num_epochs):
            for i, batch in enumerate(self.data_module.train_dataloader()):
                st = time.time()
                self.state, info = self.train_step(self.state, batch, rng)
                rng = info["rng"]
                eta = (time.time() - st)
                mean_step_time += (eta - mean_step_time) / (i + 1)
                if i % 10 == 0:
                    print(f"Epoch {epoch}, Step {i}, "
                          f"Loss: {info['loss'].item():.4f}, mean_eta: {mean_step_time:.4f}")
            
            mean_acc = 0
            for j, batch in enumerate(self.data_module.test_dataloader()):
                acc = self.eval_step(self.state, batch)
                mean_acc += (acc - mean_acc) / (j + 1)
            
            print(f"Epoch {epoch}, Mean Accuracy: {mean_acc:.4f}")
            print("=" * 80)
    
    def loss_fn(self, params, model, batch):
        x, y = batch
        rngs = self.get_model_rngs(jax.random.PRNGKey(0), train=True)
        logits = model.apply(params, x, rngs=rngs, train=True)
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)
        return loss.mean()

In [51]:
out[1].shape

(32,)

In [52]:
trainer = TrainerDT.create(
    config=Config(),
    model_def=Network(),
    data_module=DataModule(train_loader=train_loader, test_loader=test_loader),
    sample_input=next(iter(train_loader))
)

In [53]:
trainer.train_model()

Epoch 0, Step 0, Loss: 2.2984, mean_eta: 3.9333
Epoch 0, Step 10, Loss: 268508352.0000, mean_eta: 0.3593
Epoch 0, Step 20, Loss: 21468979200.0000, mean_eta: 0.1887
Epoch 0, Step 30, Loss: 284918054912.0000, mean_eta: 0.1281
Epoch 0, Step 40, Loss: 1681682595840.0000, mean_eta: 0.0970
Epoch 0, Step 50, Loss: 6895111241728.0000, mean_eta: 0.0782
Epoch 0, Step 60, Loss: 21667417948160.0000, mean_eta: 0.0655
Epoch 0, Step 70, Loss: 62148621369344.0000, mean_eta: 0.0564
Epoch 0, Step 80, Loss: 121174012985344.0000, mean_eta: 0.0496
Epoch 0, Step 90, Loss: 308723322978304.0000, mean_eta: 0.0442
Epoch 0, Step 100, Loss: 657716796194816.0000, mean_eta: 0.0399
Epoch 0, Step 110, Loss: 1135320380735488.0000, mean_eta: 0.0364
Epoch 0, Step 120, Loss: 2044224312705024.0000, mean_eta: 0.0334
Epoch 0, Step 130, Loss: 3337135365029888.0000, mean_eta: 0.0309
Epoch 0, Step 140, Loss: 5178363148763136.0000, mean_eta: 0.0288
Epoch 0, Step 150, Loss: 9362075222540288.0000, mean_eta: 0.0269
Epoch 0, Step 1

ValueError: too many values to unpack (expected 2)