In [203]:
import jax
import jax.numpy as jnp
import evojax
from evojax.task.base import TaskState, VectorizedTask
from evojax.policy.base import PolicyState, PolicyNetwork

from flax.struct import dataclass
import numpy as np

In [None]:
SEED = np.random.randint(0, 100000)
NUM = 8
RANDOM_KEY = jax.random.PRNGKey(SEED)

def GET_RANDOM_KEY(): return jax.random.PRNGKey(np.random.randint(0, 100000))
SEED, RANDOM_KEY

In [289]:
solver = evojax.algo.PGPE(8, 1)
solver.ask().shape

(8, 1)

In [290]:
class Problem:
    def __init__(self):
        self.target = jnp.array(range(NUM))

@dataclass
class MyTaskInitState(TaskState):
    obs : jnp.ndarray = jnp.array([0 for i in range(NUM)])

        
@dataclass
class MyPolicyState:
    keys: jnp.ndarray

    
class MyTask(VectorizedTask):
    def __init__(
        self,
        problems: tuple[Problem, ...] = (Problem(),), 
        is_test: bool = True):
        self.problems = problems
        
        self.max_steps = 1
        self.obs_shape = (NUM, )
        self.act_shape = (1, )
        
        self.is_test = is_test
    
    def __step__(self, state: TaskState, result: jnp.ndarray):
        if self.is_test:
            reward = NUM - jnp.sum(result)
        else:
            reward = self.calc_reward(result)

        return state, reward, jnp.ones((), dtype=jnp.int32)
    
    def calc_reward(self, result: jnp.ndarray):
        return - jnp.sum((result - 1) ** 2)
    
    def reset(self, key: jnp.ndarray) -> TaskState:
        return MyTaskInitState()
    
    def step(self, state: TaskState, result: jnp.ndarray) -> tuple[TaskState, jnp.ndarray, jnp.ndarray]:
        return self.__step__(state, result)

class MyPolicy(PolicyNetwork):
    def get_actions(
        self, 
        t_states: TaskState, 
        params: jnp.ndarray,
        p_states: PolicyState,
    ) -> tuple[jnp.ndarray, PolicyState]:
        return jax.random.randint(GET_RANDOM_KEY(), (NUM, ), 0, 2), p_states
   

In [291]:
init_state = MyTaskInitState()
train_task = MyTask()
test_task = MyTask(is_test=True)

In [322]:
trainer = evojax.Trainer(
    policy=MyPolicy(),
    solver=solver,
    train_task=train_task,
    test_task=test_task,
    n_evaluations=NUM,
    log_dir="./logs"
)

In [323]:
trainer.run(demo_mode=False)

DeviceArray(8., dtype=float32)

In [325]:
trainer.run(demo_mode=False)

DeviceArray(8., dtype=float32)

In [329]:
trainer.solver.ask()

DeviceArray([[-0.00436163],
             [-0.00541603],
             [-0.00548816],
             [-0.00428949],
             [-0.00460761],
             [-0.00517004],
             [-0.00517295],
             [-0.00460471]], dtype=float32)