In [1]:
import jax
import jax.numpy as jnp
import jraph
import evojax
from evojax.task.base import TaskState, VectorizedTask
from evojax.policy.base import PolicyState, PolicyNetwork

from torchvision import datasets
from flax import linen as nn
from flax.struct import dataclass
import numpy as np

In [2]:
SEED = np.random.randint(0, 100000)
NUM = 8
RANDOM_KEY = jax.random.PRNGKey(SEED)

def GET_RANDOM_KEY(): return jax.random.PRNGKey(np.random.randint(0, 100000))
SEED, RANDOM_KEY

(40628, DeviceArray([    0, 40628], dtype=uint32))

In [3]:
class CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=8, kernel_size=(5, 5), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=8, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=10)(x)
        x = nn.log_softmax(x)
        return x

In [4]:
def loss(prediction: jnp.ndarray, target: jnp.ndarray) -> jnp.float32:
    target = jax.nn.one_hot(target, 10)
    return -jnp.mean(jnp.sum(prediction * target, axis=1))


def accuracy(prediction: jnp.ndarray, target: jnp.ndarray) -> jnp.float32:
    predicted_class = jnp.argmax(prediction, axis=2)
    print(prediction.shape, predicted_class.shape, target.shape)
    return jnp.mean(predicted_class == target)


@dataclass
class VisionState(TaskState): # obs : batch_data (must) + args
    obs : jnp.ndarray
    labels: jnp.ndarray

class VisionPolicy(PolicyNetwork):
    def __init__(self):
        model = CNN()
        params = model.init(RANDOM_KEY, jnp.zeros([1, 28, 28, 1]))
        self.num_params, fmt_fn = evojax.util.get_params_format_fn(params)
        self._fmt_fn = jax.vmap(fmt_fn)
        
        self._forward_fn = jax.vmap(model.apply)
        
    def get_actions(self, t_states: VisionState, params: jnp.ndarray, p_states: VisionState) -> tuple[jnp.ndarray, VisionState]:
        params = self._fmt_fn(params)
        return self._forward_fn(params, t_states.obs), p_states
   
    
class MNIST_Task(VectorizedTask):
    def __init__(
        self, is_test: bool = True, batch_size: int = 32
    ):

        self.max_steps = 1
        self.obs_shape = (28, 28, 1)
        self.act_shape = (10, )
        
        dataset = datasets.MNIST("./ignore_dir/data", train=not is_test, download=True)
        data = np.expand_dims(dataset.data.numpy() / 255, axis=-1)
        labels = dataset.targets.numpy()
        
        self.is_test = is_test
        
        def f(key):
            if self.is_test:
                bd, bl = data, labels
            else:
                ix = random.choice(key=key, a=data.shape[0], shape=(batch_size,), replace=False)
                bd, bl = (jnp.take(data, indices=ix, axis=0), jnp.take(labels, indices=ix, axis=0))
            return VisionState(obs=bd, labels=bl)
        
        self._reset_fn = jax.jit(jax.vmap(f))
    
    def _step(self, current_state: VisionState, result):
        if self.is_test:
            reward = accuracy(result, current_state.labels)
        else:
            reward = -loss(result, current_state.labels)
        return current_state, reward, jnp.ones((), dtype=jnp.int32)

    def reset(self, key: jnp.ndarray) -> TaskState:
        return self._reset_fn(key)
    
    def step(self, current_state: VisionState, result: jnp.ndarray) -> tuple[VisionState, jnp.ndarray, jnp.ndarray]:
        return self._step(current_state, result)

In [5]:
policy = VisionPolicy()
train_task = MNIST_Task()
test_task = MNIST_Task(is_test=True)

In [6]:
solver = evojax.algo.PGPE(
    pop_size=8,
    param_size=policy.num_params,
    optimizer="adam",
    seed=SEED
)
solver.ask().shape

(8, 4722)

In [7]:
trainer = evojax.Trainer(
    policy=policy,
    solver=solver,
    train_task=train_task,
    test_task=test_task,
    n_evaluations=1,
    log_dir="./ignore_dir/logs"
)

In [12]:
trainer.run()

DeviceArray(0.0984, dtype=float32)

In [13]:
trainer.solver.ask()

DeviceArray([[-1.8963573 ,  5.1905236 ,  1.7493045 , ..., -0.49119452,
              -6.135384  ,  1.2280372 ],
             [-1.8829391 ,  5.173755  ,  1.7294123 , ..., -0.82921517,
              -6.1479974 ,  1.4862982 ],
             [-1.809455  ,  5.1927958 ,  1.7385601 , ..., -0.7132671 ,
              -6.147729  ,  1.3206091 ],
             ...,
             [-1.9608529 ,  5.184592  ,  1.6793004 , ..., -0.31550005,
              -6.1160927 ,  1.1364937 ],
             [-1.9503118 ,  5.1861587 ,  1.8137469 , ..., -0.6894172 ,
              -6.1302357 ,  1.4721737 ],
             [-1.8289846 ,  5.17812   ,  1.6649699 , ..., -0.6309925 ,
              -6.153146  ,  1.2421618 ]], dtype=float32)