In [18]:
import chex
import shinrl
import gym
import jax.numpy as jnp
import distrax
import jax
from copy import deepcopy

# Create custom Solver

This tutorial demonstrates how to create a custom solver.
We are going to implement a very simple solver that tries to maximize the one-step reward.

You need to implement two classes, 
1. a config class inheriting `shinrl.SolverConfig` and 
2. a solver class inheriting `shinrl.BaseSolver`.

## 1. Config

The config class is a dataclass inheriting `shinrl.SolverConfig`.
It holds hyperparameters of a solver.

In [31]:
@chex.dataclass
class ExampleConfig(shinrl.SolverConfig):
    seed: int = 0
    discount: float = 0.99
    eval_trials: int = 10
    eval_interval: int = 100
    add_interval: int = 100
    steps_per_epoch: int = 1000

## 2. Solver

The main solver class must inherit `shinrl.BaseSolver`.
You need to implement three functions (See details in [shinrl/solvers/base/solver.py](../shinrl/solvers/base/solver.py)):

* **make_mixins** (staticmethod): Make a list of mixins from env and config. A solver is instantiated by mixing generated mixins.
* **evaluate** (function): Evaluate the solver and return the dict of results. Called every self.config.eval_interval steps.
* **step** (function): Execute the solver by one step and return the dict of results.

The following code implements `evaluate` and `step` functions through mixins:

In [36]:

class ExampleStepMixIn:
    def initialize(self, env, config=None) -> None:
        super().initialize(env, config)
        
        dA = self.env.action_space.n
        policy = jnp.ones(dA)
        policy = policy / policy.sum()
        
        # Any jittable object (e.g., network parameters, Q-table, etc.)
        # should be stored in this `solver.data` dictionary
        self.data["Policy"] = policy

    def step(self):
        policy = self.data["Policy"]
        dist = distrax.Greedy(policy)
        act = dist.sample(seed=self.key).item()
        self.env.obs, rew, done, _ = self.env.step(act)
        
        # Update policy
        policy = policy.at[act].add(rew)
        self.data["Policy"] = policy / policy.sum()
        
        # Return any scalar data you want to record
        return {"Rew": rew}

    
class ExampleEvalMixIn:
    def initialize(self, env, config=None) -> None:
        super().initialize(env, config)
        self._eval_env = deepcopy(self.env)

    def evaluate(self):
        self._eval_env.reset()
        
        policy = self.data["Policy"]
        dist = distrax.Greedy(policy)
        ret = 0
        done = False
        while not done:
            act = dist.sample(seed=self.key).item()       
            self._eval_env.obs, rew, done, _ = self._eval_env.step(act)
            ret += rew
            
        # Return any scalar data you want to record
        return {"Return": ret}


class ExampleSolver(shinrl.BaseSolver):
    DefaultConfig = ExampleConfig
    @staticmethod
    def make_mixins(env, config):
        return [ExampleStepMixIn, ExampleEvalMixIn, ExampleSolver]

In [37]:
env = gym.make("CartPole-v0")
config = ExampleSolver.DefaultConfig(add_interval=5, steps_per_epoch=20, eval_interval=10)
mixins = ExampleSolver.make_mixins(env, config)
solver = ExampleSolver.factory(env, config, mixins)
solver.run()

[2m2021-12-27T04:12:19.974884Z[0m [1mset_config is called.         [0m [36mconfig[0m=[35m{'seed': 0, 'discount': 0.99, 'eval_trials': 10, 'eval_interval': 10, 'add_interval': 5, 'steps_per_epoch': 20}[0m [36menv_id[0m=[35mNone[0m [36msolver_id[0m=[35mMixedSolver-11[0m
[2m2021-12-27T04:12:19.978362Z[0m [1mset_env is called.            [0m [36menv_id[0m=[35m0[0m [36msolver_id[0m=[35mMixedSolver-11[0m
[2m2021-12-27T04:12:19.983226Z[0m [1mSolver is initialized.        [0m [36menv_id[0m=[35m0[0m [36mmethods[0m=[35m['BaseSolver.__init__', 'History.add_scalar', 'ExampleEvalMixIn.evaluate', 'History.init_history', 'ExampleStepMixIn.initialize', 'History.load', 'History.recent_summary', 'BaseSolver.run', 'History.save', 'BaseSolver.seed', 'History.set_config', 'BaseSolver.set_env', 'ExampleStepMixIn.step'][0m [36mmixins[0m=[35m[<class '__main__.ExampleStepMixIn'>, <class '__main__.ExampleEvalMixIn'>, <class '__main__.ExampleSolver'>][0m [36msolver_id

In [38]:
solver.data["Policy"]

DeviceArray([4.8828125e-04, 9.9951172e-01], dtype=float32)

### solver.scalars

The results from `step` and `evaluate` functions are stored in solver.scalars:

In [39]:
solver.scalars

{'Return': {'x': [0, 10], 'y': [10.0, 9.0]},
 'Rew': {'x': [0, 5, 10, 15], 'y': [1.0, 1.0, 0.0, 0.0]}}

# Useful MixIns

For ease of implementation, we provide the following base mixins (See details in [shinrl/solvers/base/base_mixin.py](../shinrl/solvers/base/base_mixin.py):

* **BaseGymEvalMixIn**: Base mixin for gym.Env evaluation. `explore` function is implemented. Need to implement `eval_act` function.
* **BaseGymExploreMixIn**: Base mixin for gym.Env exploration. `evaluate` function is implemented. Need to implement `explore_act` function.
* **BaseShinEvalMixIn**: Base mixin for ShinEnv evaluation. `explore` function is implemented. solver.data need to have `EvaluatePolicy` table.
* **BaseShinExploreMixIn**: Base mixin for ShinEnv exploration. `evaluate` function is implemented. solver.data need to have `ExplorePolicy` table.

### Gym Solver Example

`BaseGymEvalMixIn` and `BaseGymExploreMixIn` conduct **sampling-based** evaluation and exploration.

You need to implement three functions: 
* `step` 
* `eval_act`
* `explore_act`

Here we implement the step function in `GymStepMixIn` and the act functions in `GymActMixIn`:

In [7]:
@chex.dataclass
class GymConfig(shinrl.SolverConfig):
    seed: int = 0
    discount: float = 0.99
    eval_trials: int = 10
    eval_interval: int = 100
    add_interval: int = 100
    steps_per_epoch: int = 1000
    num_samples: int = 10
        

class GymStepMixIn:
    def step(self):
        samples = self.explore()
        dummy_loss = (samples.rew).mean()
        return {"DummyLoss": dummy_loss.item()}

    
class GymActMixIn:
    def eval_act(self, key, obs):
        new_key = jax.random.split(self.key)
        act = self._eval_env.action_space.sample()
        log_prob = 0.0
        return new_key, act, log_prob

    def explore_act(self, key, obs):
        new_key = jax.random.split(self.key)
        act = self.env.action_space.sample()
        log_prob = 0.0
        return new_key, act, log_prob


class GymSolver(shinrl.BaseSolver):
    DefaultConfig = GymConfig
    @staticmethod
    def make_mixins(env, config):
        return [GymStepMixIn, GymActMixIn, shinrl.BaseGymExploreMixIn, shinrl.BaseGymEvalMixIn, GymSolver]

In [8]:
env = gym.make("CartPole-v0")
config = GymSolver.DefaultConfig(add_interval=5, steps_per_epoch=20, eval_interval=10)
mixins = GymSolver.make_mixins(env, config)
solver = GymSolver.factory(env, config, mixins)
solver.run()

[2m2021-12-27T03:56:39.003904Z[0m [1mset_config is called.         [0m [36mconfig[0m=[35m{'seed': 0, 'discount': 0.99, 'eval_trials': 10, 'eval_interval': 10, 'add_interval': 5, 'steps_per_epoch': 20, 'num_samples': 10}[0m [36menv_id[0m=[35mNone[0m [36msolver_id[0m=[35mMixedSolver-1[0m
[2m2021-12-27T03:56:39.008038Z[0m [1mset_env is called.            [0m [36menv_id[0m=[35m0[0m [36msolver_id[0m=[35mMixedSolver-1[0m
[2m2021-12-27T03:56:39.015721Z[0m [1mSolver is initialized.        [0m [36menv_id[0m=[35m0[0m [36mmethods[0m=[35m['BaseSolver.__init__', 'History.add_scalar', 'GymActMixIn.eval_act', 'BaseGymEvalMixIn.evaluate', 'BaseGymExploreMixIn.explore', 'GymActMixIn.explore_act', 'History.init_history', 'BaseGymEvalMixIn.initialize', 'History.load', 'History.recent_summary', 'BaseSolver.run', 'History.save', 'BaseSolver.seed', 'History.set_config', 'BaseSolver.set_env', 'GymStepMixIn.step'][0m [36mmixins[0m=[35m[<class '__main__.GymStepMixIn'>

In [9]:
solver.scalars

{'Return': {'x': [0, 10], 'y': [20.4, 26.1]},
 'DummyLoss': {'x': [0, 5, 10, 15], 'y': [1.0, 1.0, 1.0, 1.0]}}

### ShinEnv Solver Example

`BaseShinEvalMixIn` and `BaseShinExploreMixIn` conduct **oracle** evaluation and exploration.

You need to set two arrays to solver.data:

* `ExplorePolicy`: dS x dA probability array
* `EvaluatePolicy`: dS x dA probability array

Here we implement them in `BuildTableMixIn`.

In [10]:
@chex.dataclass
class ShinConfig(shinrl.SolverConfig):
    seed: int = 0
    discount: float = 0.99
    eval_trials: int = 10
    eval_interval: int = 100
    add_interval: int = 100
    steps_per_epoch: int = 1000
    num_samples: int = 10
        
        
class BuildTableMixIn:
    def initialize(self, env, config=None) -> None:
        # build tables
        super().initialize(env, config)
        self.data["Q"] = jnp.zeros((self.dS, self.dA))
        self.data["ExplorePolicy"] = jnp.ones((self.dS, self.dA)) / self.dA
        self.data["EvaluatePolicy"] = jnp.ones((self.dS, self.dA)) / self.dA


class ShinStepMixIn:
    def step(self):
        samples = self.explore()
        dummy_loss = (samples.rew).mean()
        return {"DummyLoss": dummy_loss.item()}

    
class ShinSolver(shinrl.BaseSolver):
    DefaultConfig = ShinConfig
    @staticmethod
    def make_mixins(env, config):
        return [ShinStepMixIn, BuildTableMixIn, shinrl.BaseShinExploreMixIn, shinrl.BaseShinEvalMixIn, ShinSolver]

In [11]:
env = gym.make("ShinMountainCar-v0")
config = ShinSolver.DefaultConfig(add_interval=5, steps_per_epoch=20, eval_interval=10)
mixins = ShinSolver.make_mixins(env, config)
solver = ShinSolver.factory(env, config, mixins)
solver.run()

[2m2021-12-27T03:56:41.292676Z[0m [1mset_config is called.         [0m [36mconfig[0m=[35m{'seed': 0, 'discount': 0.99, 'eval_trials': 10, 'eval_interval': 10, 'add_interval': 5, 'steps_per_epoch': 20, 'num_samples': 10}[0m [36menv_id[0m=[35mNone[0m [36msolver_id[0m=[35mMixedSolver-2[0m
[2m2021-12-27T03:56:41.293170Z[0m [1mset_env is called.            [0m [36menv_id[0m=[35m0[0m [36msolver_id[0m=[35mMixedSolver-2[0m
[2m2021-12-27T03:56:41.295118Z[0m [1mSolver is initialized.        [0m [36menv_id[0m=[35m0[0m [36mmethods[0m=[35m['BaseSolver.__init__', 'History.add_scalar', 'BaseShinEvalMixIn.evaluate', 'BaseShinExploreMixIn.explore', 'History.init_history', 'BuildTableMixIn.initialize', 'History.load', 'History.recent_summary', 'BaseSolver.run', 'History.save', 'BaseSolver.seed', 'History.set_config', 'BaseSolver.set_env', 'ShinStepMixIn.step'][0m [36mmixins[0m=[35m[<class '__main__.ShinStepMixIn'>, <class '__main__.BuildTableMixIn'>, <class 'sh

In [12]:
solver.scalars

{'Return': {'x': [0, 10], 'y': [-199.546630859375, -199.546630859375]},
 'DummyLoss': {'x': [0, 5, 10, 15], 'y': [-1.0, -1.0, -1.0, -1.0]}}

### Gym & ShinEnv Solver Example

A solver can support both gym.Env & ShinEnv by modifing the `make_mixin` function:

In [13]:
@chex.dataclass
class GymAndShinConfig(shinrl.SolverConfig):
    seed: int = 0
    discount: float = 0.99
    eval_trials: int = 10
    eval_interval: int = 100
    add_interval: int = 100
    steps_per_epoch: int = 1000
    num_samples: int = 10
        
        
class GymAndShinSolver(shinrl.BaseSolver):
    DefaultConfig = GymAndShinConfig
    @staticmethod
    def make_mixins(env, config):
        is_shin_env = isinstance(env, shinrl.ShinEnv)
        
        if is_shin_env:
            mixin_list = [ShinStepMixIn, BuildTableMixIn, shinrl.BaseShinExploreMixIn, shinrl.BaseShinEvalMixIn, GymAndShinSolver]
        else:
            mixin_list = [GymStepMixIn, GymActMixIn, shinrl.BaseGymExploreMixIn, shinrl.BaseGymEvalMixIn, GymAndShinSolver]
        
        return mixin_list

In [14]:
# GymEnv
env = gym.make("MountainCar-v0")
config = GymAndShinSolver.DefaultConfig(add_interval=5, steps_per_epoch=20, eval_interval=10)
mixins = GymAndShinSolver.make_mixins(env, config)
solver = GymAndShinSolver.factory(env, config, mixins)
solver.run()

[2m2021-12-27T03:56:41.880841Z[0m [1mset_config is called.         [0m [36mconfig[0m=[35m{'seed': 0, 'discount': 0.99, 'eval_trials': 10, 'eval_interval': 10, 'add_interval': 5, 'steps_per_epoch': 20, 'num_samples': 10}[0m [36menv_id[0m=[35mNone[0m [36msolver_id[0m=[35mMixedSolver-3[0m
[2m2021-12-27T03:56:41.881617Z[0m [1mset_env is called.            [0m [36menv_id[0m=[35m0[0m [36msolver_id[0m=[35mMixedSolver-3[0m
[2m2021-12-27T03:56:41.882815Z[0m [1mSolver is initialized.        [0m [36menv_id[0m=[35m0[0m [36mmethods[0m=[35m['BaseSolver.__init__', 'History.add_scalar', 'GymActMixIn.eval_act', 'BaseGymEvalMixIn.evaluate', 'BaseGymExploreMixIn.explore', 'GymActMixIn.explore_act', 'History.init_history', 'BaseGymEvalMixIn.initialize', 'History.load', 'History.recent_summary', 'BaseSolver.run', 'History.save', 'BaseSolver.seed', 'History.set_config', 'BaseSolver.set_env', 'GymStepMixIn.step'][0m [36mmixins[0m=[35m[<class '__main__.GymStepMixIn'>

In [15]:
# ShinEnv
env = gym.make("ShinMountainCar-v0")
config = GymAndShinSolver.DefaultConfig(add_interval=5, steps_per_epoch=20, eval_interval=10)
mixins = GymAndShinSolver.make_mixins(env, config)
solver = GymAndShinSolver.factory(env, config, mixins)
solver.run()

[2m2021-12-27T03:56:42.579567Z[0m [1mset_config is called.         [0m [36mconfig[0m=[35m{'seed': 0, 'discount': 0.99, 'eval_trials': 10, 'eval_interval': 10, 'add_interval': 5, 'steps_per_epoch': 20, 'num_samples': 10}[0m [36menv_id[0m=[35mNone[0m [36msolver_id[0m=[35mMixedSolver-4[0m
[2m2021-12-27T03:56:42.580077Z[0m [1mset_env is called.            [0m [36menv_id[0m=[35m0[0m [36msolver_id[0m=[35mMixedSolver-4[0m
[2m2021-12-27T03:56:42.582145Z[0m [1mSolver is initialized.        [0m [36menv_id[0m=[35m0[0m [36mmethods[0m=[35m['BaseSolver.__init__', 'History.add_scalar', 'BaseShinEvalMixIn.evaluate', 'BaseShinExploreMixIn.explore', 'History.init_history', 'BuildTableMixIn.initialize', 'History.load', 'History.recent_summary', 'BaseSolver.run', 'History.save', 'BaseSolver.seed', 'History.set_config', 'BaseSolver.set_env', 'ShinStepMixIn.step'][0m [36mmixins[0m=[35m[<class '__main__.ShinStepMixIn'>, <class '__main__.BuildTableMixIn'>, <class 'sh