In [1]:
import jax
from jax import numpy as jnp
from jax import grad, pmap, value_and_grad
import flax
from flax import linen as nn
from typing import Sequence, Any

# 创建 PRNGKey (PRNG State)
key = jax.random.PRNGKey(0)

# 创建模型

In [2]:
class MLP(nn.Module):
    
    def setup(self):
        self.layer1 = nn.Dense(features=512)
        self.dropout1 = nn.Dropout(rate=0.3)
        self.norm1 = nn.BatchNorm()
        
        self.layer2 = nn.Dense(features=512)
        self.dropout2 = nn.Dropout(rate=0.4)
        self.norm2 = nn.BatchNorm()
        
        self.layer3 = nn.Dense(features=10)
        
    
    def __call__(self, x, train:bool = True):
        x = nn.relu(self.layer1(x))
        x = self.dropout1(x, deterministic=not train)
        x = self.norm1(x, use_running_average=not train)
        x = nn.relu(self.layer2(x))
        x = self.dropout2(x, deterministic=not train)
        x = self.norm2(x, use_running_average=not train)
        
        x = self.layer3(x)

        return x

In [3]:
# 创建模型
model = MLP()

# 使用`init`和dummy_x来创建模型参数
key, init_key = jax.random.split(key)
dummy_x = jax.random.uniform(init_key, (784, ))

key, init_key, drop_key = jax.random.split(key, 3)

variables = model.init({"params": init_key, "dropout": drop_key}, dummy_x, train=True)

In [4]:
y, non_trainable_params = model.apply(variables, dummy_x, train=True, rngs={"dropout": drop_key},
                                      mutable=['batch_stats']) 

In [5]:
y

DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [6]:
non_trainable_params.keys()

frozen_dict_keys(['batch_stats'])

# 数据读取

In [7]:
import numpy as np
from torch.utils.data import DataLoader
from torch.utils.data import Sampler, SequentialSampler
from torchvision.datasets import MNIST


class FlattenAndCast(object):  
    def __call__(self, pic):
        return np.ravel(np.array(pic, dtype=jnp.float32))


# DataLoader返回numpy array，而不是torch Tensor
def numpy_collate(batch):
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    elif isinstance(batch[0], (tuple,list)):
        transposed = zip(*batch)
        return [numpy_collate(samples) for samples in transposed]
    else:
        return np.array(batch)

class JAXRandomSampler(Sampler):
    def __init__(self, data_source, rng_key):
        self.data_source = data_source
        self.rng_key = rng_key
        
    def __len__(self):
        return len(self.data_source)
    
    def __iter__(self):
        self.rng_key, current_rng = jax.random.split(self.rng_key)
        return iter(jax.random.permutation(current_rng, jnp.arange(len(self))).tolist())

In [8]:
class NumpyLoader(DataLoader):
    def __init__(self, dataset, rng_key=None, batch_size=1,
                 shuffle=False, **kwargs):
        if shuffle:
            sampler = JAXRandomSampler(dataset, rng_key)
        else:
            sampler = SequentialSampler(dataset)
        
        super().__init__(dataset, batch_size, sampler=sampler, **kwargs)

In [9]:
# 借助于torchvision和NumpyLoader
mnist_dataset_train = MNIST('/tmp/mnist/', download=True, transform=FlattenAndCast())
key, loader_key = jax.random.split(key)
train_loader = NumpyLoader(mnist_dataset_train, loader_key, batch_size=128 * jax.local_device_count(), shuffle=True,
                           num_workers=0, collate_fn=numpy_collate, drop_last=True)

mnist_dataset_test = MNIST('/tmp/mnist/', download=True, train=False, transform=FlattenAndCast())
eval_loader = NumpyLoader(mnist_dataset_test, batch_size=128 * jax.local_device_count(), shuffle=False, num_workers=0,
                          collate_fn=numpy_collate, drop_last=False)

# 优化器和学习率调度算法

In [10]:
import optax

lr = 1e-3
lr_decay_fn = optax.linear_schedule(
        init_value=lr,
        end_value=1e-5,
        transition_steps=200,
)

optimizer = optax.adam(
            learning_rate=lr_decay_fn,
)

# TrainState

将训练过程中的状态封装为一个类，统一管理

In [11]:
from flax.training import train_state

class CustomTrainState(train_state.TrainState):
    batch_stats: flax.core.FrozenDict[str, Any]

state = CustomTrainState.create(
    apply_fn=model.apply,
    params=variables['params'],
    tx=optimizer,
    batch_stats=variables['batch_stats'],
)

# 训练流程

In [12]:
def train_step(state, x, y, dropout_key):
    """Computes gradients and loss for a single batch."""
    def loss_fn(params):
        logits, new_state = state.apply_fn({"params": params, "batch_stats": state.batch_stats},
                                           x, train=True, rngs={"dropout": dropout_key}, mutable=["batch_stats"])
        
        one_hot = jax.nn.one_hot(y, 10)
        loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot))
        return loss, new_state

    grad_fn = value_and_grad(loss_fn, has_aux=True)  # `value_and_grad`在进行grad同时返回loss
    (loss, new_state), grads = grad_fn(state.params)
    grads = jax.lax.pmean(grads, "batch")  # pmean计算所有device上的梯度均值
    loss = jax.lax.pmean(loss, "batch")
    batch_stats = jax.lax.pmean(new_state["batch_stats"], "batch")
    new_state = state.apply_gradients(grads=grads, batch_stats=batch_stats)
    
    return new_state, loss

p_train_step = pmap(train_step, "batch", donate_argnums=(0,))  # donate_argnums用于buffer复用，这里指的是输入和输出的state buffer复用


def apply_model(state, x):
    """Computes gradients and loss for a single batch."""
    
    logits = state.apply_fn({"params":state.params, "batch_stats": state.batch_stats},
                            x, train=False)
    return jnp.argmax(logits, -1)


In [13]:
def eval_model(state, loader):
    total_acc = 0.
    total_num = 0.
    for xs, ys in loader:
        xs = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
        ys = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), ys)
        y_pred = pmap(apply_model)(state, xs)  # 验证时没有跨设备通信操作，不需要设置axis_name
        total_num += ys.size
        total_acc += jnp.sum(y_pred == ys)
    return total_acc / total_num

In [14]:
devices = jax.local_devices()
state = jax.device_put_replicated(state, devices)  # 或者 state = flax.jax_utils.replicate(state)

In [15]:
for epoch in range(5):
    for idx, (xs, ys) in enumerate(train_loader):
        xs = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), xs)
        ys = jax.tree_map(
            lambda x: x.reshape((jax.local_device_count(), -1) + x.shape[1:]), ys)
        
        key, dropout_key = jax.random.split(key)
        dropout_key = jax.random.split(drop_key, jax.local_device_count())
        state, loss = p_train_step(state, xs, ys, dropout_key)
        
        if idx % 100 == 0:  # evaluation
            train_acc = eval_model(state, train_loader)
            eval_acc = eval_model(state, eval_loader)
            print("Epoch {} - batch_idx {}, loss {}, Training set acc {}, eval set accuracy {}".format(
              epoch, idx, jax.tree_map(lambda x: x[0], loss), train_acc, eval_acc))

Epoch 0 - batch_idx 0, loss 2.666304588317871, Training set acc 0.4732118844985962, eval set accuracy 0.4770999848842621
Epoch 1 - batch_idx 0, loss 0.20269665122032166, Training set acc 0.9528387784957886, eval set accuracy 0.9491999745368958
Epoch 2 - batch_idx 0, loss 0.1484561264514923, Training set acc 0.9685648083686829, eval set accuracy 0.964199960231781
Epoch 3 - batch_idx 0, loss 0.12777751684188843, Training set acc 0.9755185842514038, eval set accuracy 0.9678999781608582
Epoch 4 - batch_idx 0, loss 0.10267762094736099, Training set acc 0.9769666194915771, eval set accuracy 0.9679999947547913


## 使用单设备验证

如果在模型验证时不想使用多设备，取其中一份state，jax.tree_util.tree_map(lambda x: x[0], state

In [16]:
def eval_model_single(state, loader):
    total_acc = 0.
    total_num = 0.
    for xs, ys in loader:
        y_pred = jax.jit(apply_model)(state, xs)
        total_num += ys.size
        total_acc += jnp.sum(y_pred == ys)
    return total_acc / total_num

In [17]:
eval_model_single(jax.tree_util.tree_map(lambda x: x[0], state), eval_loader)

DeviceArray(0.96809995, dtype=float32, weak_type=True)