# Ginkgo clustering environment

## Setup

In [1]:
%matplotlib inline

import sys
import os
import numpy as np
from matplotlib import pyplot as plt
import gym
import logging
from stable_baselines.common.policies import MlpPolicy, MlpLnLstmPolicy
from stable_baselines.deepq import MlpPolicy as DQNMlpPolicy
from stable_baselines import PPO2, ACER, DQN
from stable_baselines.bench import Monitor
from stable_baselines import results_plotter
from stable_baselines.common.env_checker import check_env
from stable_baselines.common.callbacks import BaseCallback

sys.path.append("../")
from ginkgo_rl import GinkgoLikelihoodEnv, GinkgoLikelihood1DEnv, GinkgoLikelihoodShuffledEnv


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Logging setup: timestamped messages at DEBUG level by default; every logger that
# is already registered and does not belong to ginkgo_rl is silenced to ERROR.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt='%H:%M',
    format='%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s',
)

for key in logging.Logger.manager.loggerDict:
    if "ginkgo_rl" in key:
        continue
    logging.getLogger(key).setLevel(logging.ERROR)


def set_output(on=True):
    """Switch log output of all registered ginkgo_rl loggers on or off.

    Parameters
    ----------
    on : bool
        True sets the ginkgo_rl loggers to DEBUG (verbose env output);
        False raises them to ERROR (quiet). Loggers outside ginkgo_rl are untouched.
    """
    registry = logging.Logger.manager.loggerDict
    for name in (k for k in registry if "ginkgo_rl" in k):
        target_level = logging.DEBUG if on else logging.ERROR
        logging.getLogger(name).setLevel(target_level)


## Let's play a round of clustering manually

In [3]:
# Turn on ginkgo_rl log output so the env's DEBUG/INFO messages are shown
set_output(True)
# Create the shuffled-likelihood clustering env and sample a fresh jet
env = gym.make("GinkgoLikelihoodShuffled-v0")
state = env.reset()
# Print the current particle list (emitted through the ginkgo_rl logger, see output below)
env.render()

17:52 ginkgo_rl.envs.ginkg DEBUG   Initializing environment
17:52 ginkgo_rl.envs.ginkg DEBUG   Sampling new jet with 7 leaves
17:52 ginkgo_rl.envs.ginkg DEBUG   Resetting environment
17:52 ginkgo_rl.envs.ginkg DEBUG   Sampling new jet with 9 leaves
17:52 ginkgo_rl.envs.ginkg INFO    9 particles:
17:52 ginkgo_rl.envs.ginkg INFO      p[ 0] = (  0.2,   0.1,   0.2,   0.1)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 1] = (  1.5,   1.0,   0.6,   0.9)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 3] = (  0.2,   0.1,   0.1,   0.1)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 4] = (  0.3,   0.2,   0.2,   0.1)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 5] = (  0.3,   0.2,   0.2,   0.2)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 6] = (  0.3,   0.2,   0.1,   0.2)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 7] = (  0.5,   0.2,   0.3,   0.3)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 8] = (  0.3,   0.1,   0.2,   0.2)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 9] = (  0.5,   0.2,   0.3,   0.3)


In [4]:
# Merge two particles: an action is a pair (i, j) of particle indices.
# NOTE(review): the env logs this action as (3, 2) in the output below,
# presumably because the shuffled env remaps indices — confirm.
action = 0, 5

# Apply the merge; for this legal merge the reward equals the log likelihood
# printed in the DEBUG log below
state, reward, done, info = env.step(action)
env.render()

print(f"Reward: {reward}")
print(f"Done: {done}")
print(f"Info: {info}")

# Re-run this cell as often as you like to keep merging particles


17:52 ginkgo_rl.envs.ginkg DEBUG   Environment step. Action: (3, 2)
17:52 ginkgo_rl.envs.ginkg DEBUG   Computing log likelihood of action (3, 2): -9.241791725158691
17:52 ginkgo_rl.envs.ginkg DEBUG   Merging particles 3 and 2. New state has 8 particles.
17:52 ginkgo_rl.envs.ginkg INFO    8 particles:
17:52 ginkgo_rl.envs.ginkg INFO      p[ 0] = (  0.3,   0.1,   0.2,   0.2)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 1] = (  0.5,   0.3,   0.3,   0.3)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 2] = (  0.5,   0.2,   0.3,   0.3)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 3] = (  0.3,   0.2,   0.1,   0.2)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 4] = (  1.5,   1.0,   0.6,   0.9)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 6] = (  0.3,   0.2,   0.2,   0.1)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 7] = (  0.5,   0.2,   0.3,   0.3)
17:52 ginkgo_rl.envs.ginkg INFO      p[ 9] = (  0.2,   0.1,   0.1,   0.1)


Reward: -9.241791725158691
Done: False
Info: {'legal': True, 'illegal_action_counter': 0, 'replace_illegal_action': False, 'i': 3, 'j': 2}


In [5]:
env.close()  # release the env used for the manual clustering session above

## Env checker

In [6]:
# Validate the env against the stable-baselines env checker and verify that the
# index permutation and its inverse undo each other for every index.
set_output(False)
env = gym.make("GinkgoLikelihoodShuffled-v0")

try:
    check_env(env)

    for i in range(env.n_max):
        # Both compositions must map i back to itself
        assert i == env.inverse_permutation[env.permutation[i]] == env.permutation[env.inverse_permutation[i]], (
            f"Permutation and inverse permutation disagree at index {i}"
        )
finally:
    # Close the env even if a check above fails, so a failing cell does not leak it
    env.close()



## Reward distribution for random actions

In [None]:
# Sample rewards of uniformly random actions, recording whether each action was legal.
n_steps = 4000

set_output(False)
env = gym.make("GinkgoLikelihoodShuffled1D-v0")
state = env.reset()
rewards, legal = [], []

try:
    for _ in range(n_steps):
        action = env.action_space.sample()
        legal.append(env.check_legality(action))
        state, reward, done, _ = env.step(action)
        rewards.append(reward)
        # Stepping an episode after it is done is undefined under the gym API,
        # so start a new jet whenever an episode ends.
        if done:
            state = env.reset()
finally:
    env.close()


  P = np.sqrt(tp)/2 * np.sqrt( 1 - 2 * (t_child+t_sib)/tp + (t_child - t_sib)**2 / tp**2 )


In [None]:
# Histogram of the sampled rewards: all actions vs. legal actions only.
range_ = (-20.1, -3)

rewards = np.asarray(rewards)
# Cast to bool so rewards[legal] is a boolean mask even if check_legality
# returned 0/1 integers (which would otherwise trigger integer fancy indexing)
legal = np.asarray(legal, dtype=bool)

fig = plt.figure(figsize=(5,5))
plt.hist(rewards, range=range_, bins=50, histtype="step", label="all actions")
plt.hist(rewards[legal], range=range_, bins=50, histtype="step", label="legal actions")
plt.xlabel("Reward")
plt.ylabel("Number of steps")
plt.yscale("log")
plt.legend()
# Lay out last, after scale and labels are set, so nothing gets clipped
plt.tight_layout()
plt.show()


## Let some RL agents loose! First, define the training and evaluation procedures

In [None]:
# Fixed evaluation set: sample n_test jets and store each env's internal state,
# so every evaluation replays exactly the same jets.
n_test = 5
set_output(False)
# NOTE(review): states come from the unshuffled "GinkgoLikelihood-v0" env but are later
# loaded into the shuffled eval envs via set_internal_state — confirm they are compatible.
env = gym.make("GinkgoLikelihood-v0")

test_internal_states = []
test_log_likelihoods = []

try:
    for _ in range(n_test):
        env.reset()
        test_internal_states.append(env.get_internal_state())
        # Sum of the jet's stored "logLH" entries; presumably the log-likelihood of
        # the generating clustering, useful as a reference — TODO confirm
        test_log_likelihoods.append(sum(env.jet["logLH"]))
finally:
    env.close()


class GinkgoEvalCallback(BaseCallback):
    """Periodically evaluate the model being trained on a fixed set of test jets.

    Every ``eval_freq`` calls, the policy is rolled out from each stored test
    state. Two quantities are recorded, both averaged over the test states: the
    summed reward of legal steps (a log-likelihood) and the number of illegal
    actions.

    Parameters
    ----------
    eval_env : gym.Env
        Environment used for evaluation. Its ``min_reward`` attribute is set to -1000.
    eval_freq : int
        Evaluate every ``eval_freq`` calls of ``_on_step``. A non-positive value disables evaluation.
    verbose : int
        Forwarded to ``BaseCallback``.
    test_states : list or None
        Internal env states to evaluate on. ``None`` (default) uses the notebook-level
        ``test_internal_states`` list, looked up at evaluation time.
    max_steps : int
        Maximum rollout length per test state, so a policy that never finishes
        an episode cannot hang evaluation.
    """

    def __init__(self, eval_env, eval_freq=100, verbose=0, test_states=None, max_steps=100000):
        super(GinkgoEvalCallback, self).__init__(verbose)

        self.eval_env = eval_env
        self.eval_env.min_reward = -1000.0
        self.eval_freq = eval_freq
        self.test_states = test_states
        self.max_steps = max_steps

        # Training call count (n_calls) at which each evaluation happened
        self.steps = []
        # Per evaluation: mean over test states of the summed legal-step rewards
        self.log_likelihoods = []
        # Per evaluation: mean over test states of the number of illegal actions
        self.errors = []

    def _rollout(self, internal_state):
        """Play one episode from ``internal_state``; return (legal reward sum, illegal count)."""
        self.eval_env.reset()
        self.eval_env.set_internal_state(internal_state)
        state = self.eval_env.get_state()
        done = False
        n_env_steps = 0
        reward_sum = 0.0
        n_illegal = 0

        while not done and n_env_steps < self.max_steps:
            action, _ = self.model.predict(state)
            state, reward, done, info = self.eval_env.step(action)
            n_env_steps += 1
            if info["legal"]:
                reward_sum += reward
            else:
                n_illegal += 1

        return reward_sum, n_illegal

    def _on_step(self):
        if self.eval_freq > 0 and self.n_calls % self.eval_freq == 0:
            states = test_internal_states if self.test_states is None else self.test_states
            log_likelihood = 0.
            errors = 0.

            if states:
                results = [self._rollout(s) for s in states]
                # Average over the states actually evaluated (not a separate global count)
                log_likelihood = sum(r for r, _ in results) / len(states)
                errors = sum(e for _, e in results) / len(states)

            self.steps.append(self.n_calls)
            self.log_likelihoods.append(log_likelihood)
            self.errors.append(errors)
            print(log_likelihood, errors)
        # Returning True tells stable-baselines to continue training
        return True


In [None]:
# Registries filled by train(), keyed by algorithm name
log_dirs = []
models = {}
wrap1ds = {}
callbacks = {}

def train(algorithm, algo_class, policy_class, wrap1d=False, n_steps=100000, eval_freq=5000):
    """Train an agent on the shuffled Ginkgo env and register it under ``algorithm``.

    Parameters
    ----------
    algorithm : str
        Registry key; also names the Monitor log directory ``./logs/<algorithm>/``.
    algo_class, policy_class
        stable-baselines algorithm and policy classes (e.g. PPO2, MlpPolicy).
    wrap1d : bool
        If True, use the "GinkgoLikelihoodShuffled1D-v0" env variant.
    n_steps : int
        Total training timesteps.
    eval_freq : int
        Evaluation frequency passed to GinkgoEvalCallback.

    Returns
    -------
    The trained model, which is also stored in ``models[algorithm]``.
    """
    log_dir = f"./logs/{algorithm}/"
    # Re-running a cell for the same algorithm must not register the directory twice,
    # or results_plotter would plot the same run twice
    if log_dir not in log_dirs:
        log_dirs.append(log_dir)

    env_id = "GinkgoLikelihoodShuffled1D-v0" if wrap1d else "GinkgoLikelihoodShuffled-v0"

    set_output(False)
    os.makedirs(log_dir, exist_ok=True)
    env = Monitor(gym.make(env_id), log_dir)
    eval_env = gym.make(env_id)

    try:
        callback = GinkgoEvalCallback(eval_env, eval_freq)
        model = algo_class(policy_class, env, verbose=0)
        model.learn(total_timesteps=n_steps, callback=callback)
    finally:
        # Release both envs even if training is interrupted or fails
        env.close()
        eval_env.close()

    models[algorithm] = model
    wrap1ds[algorithm] = wrap1d
    callbacks[algorithm] = callback

    return model


In [None]:
def run(algorithm, max_steps=100000):
    """Play and render one episode with the model trained under ``algorithm``.

    Parameters
    ----------
    algorithm : str
        Key into the ``models`` / ``wrap1ds`` registries filled by ``train``.
    max_steps : int
        Maximum number of env steps, so a policy that never finishes cannot loop forever.

    Returns
    -------
    float
        Sum of the rewards collected during the episode.
    """
    model = models[algorithm]

    set_output(True)
    env = gym.make("GinkgoLikelihoodShuffled1D-v0" if wrap1ds[algorithm] else "GinkgoLikelihoodShuffled-v0")

    total_reward = 0.0
    try:
        state = env.reset()
        done = False
        steps = 0

        while not done and steps < max_steps:
            action, _ = model.predict(state)
            state, reward, done, _info = env.step(action)
            env.render()
            total_reward += reward
            steps += 1
    finally:
        # Close the env even if rendering or prediction raises
        env.close()

    return total_reward


## PPO

In [None]:
train("ppo", PPO2, MlpPolicy)
# Fetch the trained agent from the registry that train() fills (its return value is not relied on)
model = models["ppo"]

In [None]:
# One value per evaluation: mean over test jets of the summed legal-step rewards
callbacks["ppo"].log_likelihoods

In [None]:
# One value per evaluation: mean number of illegal actions per test jet
callbacks["ppo"].errors

In [None]:
# Play and render one episode with the trained PPO agent
run("ppo")

## DQN

In [None]:
# wrap1d=True makes train() use the "GinkgoLikelihoodShuffled1D-v0" env variant
train("dqn", DQN, DQNMlpPolicy, wrap1d=True)
# Fetch the trained agent from the registry that train() fills (its return value is not relied on)
model = models["dqn"]

In [None]:
# Play and render one episode with the trained DQN agent
run("dqn")

## Results

In [None]:
# Plot the Monitor training logs of every agent trained above against timesteps (up to 1e5)
results_plotter.plot_results(log_dirs, 1e5, results_plotter.X_TIMESTEPS, "Ginkgo")
