# Ginkgo clustering environment

## Setup

In [1]:
%matplotlib inline

import copy
import sys
import os
import numpy as np
from matplotlib import pyplot as plt
import gym
import logging
from tqdm import trange

from stable_baselines.common.policies import MlpPolicy, MlpLnLstmPolicy
from stable_baselines.deepq import MlpPolicy as DQNMlpPolicy
from stable_baselines import PPO2, ACER, DQN
from stable_baselines.bench import Monitor
from stable_baselines import results_plotter
from stable_baselines.common.env_checker import check_env
from stable_baselines.common.callbacks import BaseCallback

sys.path.append("../")
from ginkgo_rl import GinkgoLikelihoodEnv, GinkgoLikelihood1DEnv, GinkgoLikelihoodShuffledEnv, GinkgoEvaluator
from ginkgo_rl import BatchedACERAgent, RandomMCTSAgent, MCTSAgent


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Logging setup: configure the root logger (DEBUG level, compact fixed-width columns)
logging.basicConfig(
    level=logging.DEBUG,
    datefmt='%H:%M',
    format='%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s',
)

# Only ginkgo_rl should be chatty: every other logger registered so far is limited to errors
for logger_name in logging.Logger.manager.loggerDict:
    if "ginkgo_rl" in logger_name:
        continue
    logging.getLogger(logger_name).setLevel(logging.ERROR)
        
def set_output(on=True):
    """Switch the verbosity of all registered ``ginkgo_rl`` loggers.

    Every logger whose name contains ``"ginkgo_rl"`` is set to ``logging.DEBUG``
    when ``on`` is truthy and to ``logging.ERROR`` otherwise. Other loggers are
    left untouched.
    """
    level = logging.DEBUG if on else logging.ERROR
    ginkgo_logger_names = [
        name for name in logging.Logger.manager.loggerDict if "ginkgo_rl" in name
    ]
    for name in ginkgo_logger_names:
        logging.getLogger(name).setLevel(level)


## Let's play a round of clustering manually

In [None]:
# Turn ginkgo_rl debug logging on, then build the environment, start an episode and render it
set_output(on=True)
env = GinkgoLikelihoodEnv(
    n_max=6,
    illegal_reward=0.0,
    min_reward=None,
)
state = env.reset()
env.render()

In [11]:
# Walk over all nodes of the current jet. A node whose two child indices are both
# negative is a leaf; any other node is printed with its split information.
jet = env.jet
node_records = zip(jet['tree'], jet['content'], jet['logLH'], jet['deltas'])
for i, (children, p, log_likelihood_split, delta) in enumerate(node_records):
    is_inner_node = children[0] >= 0 or children[1] >= 0
    if is_inner_node:
        print(f"{i}: {p}. {i} -> ({children[0]}, {children[1]}) with log p = {log_likelihood_split}, delta = {delta}")
    else:
        print(f"{i}: {p}. Leaf.")

0: [407.92156109 230.94010768 230.94010768 230.94010768]. 0 -> (1, 2) with log p = -14.680485725402832, delta = 6400.0
1: [309.75645614 166.06642975 164.32836866 203.35401195]. Leaf.
2: [98.16510495 64.87367792 66.61173901 27.58609573]. 2 -> (3, 8) with log p = -8.415014266967773, delta = 229.67758178710938
3: [67.73644979 46.7227897  43.13286949 21.7859088 ]. 3 -> (4, 7) with log p = -6.781652450561523, delta = 70.13741302490234
4: [52.2916298  35.95941174 33.94140284 15.70927786]. 4 -> (5, 6) with log p = -3.7083871364593506, delta = 42.535091400146484
5: [33.95171902 23.17600461 22.85206443  9.2338885 ]. Leaf.
6: [18.33990696 12.7834045  11.08933593  6.47538821]. Leaf.
7: [15.44481806 10.76337663  9.19146543  6.07663032]. Leaf.
8: [30.42865825 18.15089026 23.47887162  5.8001878 ]. Leaf.


In [15]:
# Merge two particles: the action is the pair of particle indices handed to env.step
action = (0, 1)
state, reward, done, info = env.step(action)
env.render()

# Report what the environment returned for this step
print(
    f"Reward: {reward}",
    f"Done: {done}",
    f"Info: {info}",
    sep="\n",
)

# Repeat this cell as often as you feel like


14:23 ginkgo_rl.envs.ginkg DEBUG   Environment step. Action: (0, 1)
14:23 ginkgo_rl.envs.ginkg DEBUG   Computing log likelihood of action (0, 1): ti = 0.0, tj = 229.67728176139576, t_cut = 16.0, lam = 3.0 -> log likelihood = -14.680485725402832
14:23 ginkgo_rl.envs.ginkg DEBUG   Merging particles 0 and 1. New state has 1 particles.
14:23 ginkgo_rl.envs.ginkg DEBUG   Episode is done.
14:23 ginkgo_rl.envs.ginkg DEBUG   Sampling new jet with 4 leaves
14:23 ginkgo_rl.envs.ginkg INFO    4 particles:
14:23 ginkgo_rl.envs.ginkg INFO      p[ 0] = (  0.8,   0.4,   0.5,   0.5)
14:23 ginkgo_rl.envs.ginkg INFO      p[ 1] = (  0.6,   0.3,   0.3,   0.4)
14:23 ginkgo_rl.envs.ginkg INFO      p[ 2] = (  1.7,   0.9,   1.0,   1.1)
14:23 ginkgo_rl.envs.ginkg INFO      p[ 3] = (  1.0,   0.8,   0.5,   0.3)


Reward: -14.680485725402832
Done: True
Info: {'legal': True, 'illegal_action_counter': 0, 'replace_illegal_action': False, 'i': 0, 'j': 1}


In [None]:
# The manual round is over: close the environment created above
env.close()

## Env checker

In [3]:
# Run the stable-baselines environment checker on the shuffled environment (ginkgo_rl logging silenced)
set_output(on=False)
env = gym.make("GinkgoLikelihoodShuffled-v0")
check_env(env)

# The permutation and its inverse have to undo each other in both directions
permutation, inverse_permutation = env.permutation, env.inverse_permutation
for i in range(env.n_max):
    assert i == inverse_permutation[permutation[i]] == permutation[inverse_permutation[i]]

env.close()



## Evaluation routine and baselines

In [3]:
# Create the evaluator used by all evaluation cells below (ginkgo_rl logging silenced)
set_output(on=False)
evaluator = GinkgoEvaluator(n_jets=4)


In [4]:
# True log likelihoods (the string argument is presumably the label of this entry in the results — confirm in GinkgoEvaluator)
set_output(False)
evaluator.eval_true("Truth")


In [5]:
# MLE (Trellis): presumably the exact maximum-likelihood clustering obtained with the trellis algorithm — confirm in GinkgoEvaluator
set_output(False)
evaluator.eval_exact_trellis("ML (Trellis)")


In [None]:
# Random clusterings, run in the "GinkgoLikelihood1D-v0" environment (ginkgo_rl logging silenced)
set_output(False)
evaluator.eval_random("Random", env_name="GinkgoLikelihood1D-v0")


  P = np.sqrt(tp)/2 * np.sqrt( 1 - 2 * (t_child+t_sib)/tp + (t_child - t_sib)**2 / tp**2 )
 20%|██        | 82/400 [00:28<02:38,  2.01it/s]

## Training functions

In [None]:
# Registries filled by the training helpers of this section
log_dirs = list()    # log directories of the training runs
models = dict()      # run label -> trained model
env_names = dict()   # run label -> id of the environment the model was trained on

def train_baseline(algorithm, algo_class, policy_class, env_name="GinkgoLikelihood1D-v0", n_steps=10000):
    """Train a stable-baselines agent and register it under ``algorithm``.

    The environment is wrapped as ``Monitor(env, log_dir)`` with
    ``log_dir = ./logs/<algorithm>/``. On success the model is stored in
    ``models[algorithm]`` and the environment id in ``env_names[algorithm]``.

    Parameters
    ----------
    algorithm : str
        Label of this run; used as log sub-directory and as key in ``models`` / ``env_names``.
    algo_class : type
        Agent class, instantiated as ``algo_class(policy_class, env, verbose=0)``.
    policy_class : type
        Policy class handed to ``algo_class``.
    env_name : str, optional
        Environment id passed to ``gym.make``.
    n_steps : int, optional
        Passed to ``model.learn`` as ``total_timesteps``.
    """
    log_dir = f"./logs/{algorithm}/"
    # Re-running a training cell must not list the same directory twice
    if log_dir not in log_dirs:
        log_dirs.append(log_dir)

    set_output(False)
    env = gym.make(env_name)
    try:
        os.makedirs(log_dir, exist_ok=True)
        env = Monitor(env, log_dir)

        model = algo_class(policy_class, env, verbose=0)
        model.learn(total_timesteps=n_steps)

        models[algorithm] = model
        env_names[algorithm] = env_name
    finally:
        # Release the environment even if building or training the model fails
        env.close()

def train_own(algorithm, algo_class, env_name="GinkgoLikelihood1D-v0", n_steps=10000):
    """Train an agent that is built directly from the environment and register it under ``algorithm``.

    On success the model is stored in ``models[algorithm]`` and the environment id in
    ``env_names[algorithm]``.

    Parameters
    ----------
    algorithm : str
        Label of this run; used as key in ``models`` / ``env_names``.
    algo_class : type
        Agent class, instantiated as ``algo_class(env)``.
    env_name : str, optional
        Environment id passed to ``gym.make``.
    n_steps : int, optional
        Passed to ``model.learn`` as ``total_timesteps``.
    """
    # NOTE: the environment is not wrapped in a Monitor here (that wrapper was disabled for
    # these agents), so no monitor file is written. The run is therefore deliberately NOT
    # added to `log_dirs`: a directory without monitor files would break plotting of
    # `log_dirs` with results_plotter.
    set_output(False)
    env = gym.make(env_name)
    try:
        # Reset once before the agent is built; the returned state is not needed
        env.reset()

        model = algo_class(env)
        model.learn(total_timesteps=n_steps)

        models[algorithm] = model
        env_names[algorithm] = env_name
    finally:
        # Release the environment even if building or training the agent fails
        env.close()


## Train agents

In [None]:
# MCTS agent with a training budget of a single step, registered as "mcts_untrained"
train_own("mcts_untrained", MCTSAgent, n_steps=1)

In [None]:
# train_own("mcts_1k", MCTSAgent, n_steps=1000)

In [None]:
# train_own("acer", BatchedACERAgent)

In [None]:
# train_baseline("ppo", PPO2, MlpPolicy, env_name="GinkgoLikelihood-v0")

In [None]:
# train_baseline("dqn", DQN, DQNMlpPolicy)

## Plot training progress

In [None]:
# results_plotter.plot_results(log_dirs[:], 1e4, results_plotter.X_TIMESTEPS, "Ginkgo")


## Evaluate

In [None]:
# Evaluate on the environment the agent was trained on (recorded by train_own in env_names)
# instead of repeating the environment id by hand, so the two cannot drift apart
evaluator.eval(
    "MCTS (before training)",
    models["mcts_untrained"],
    env_names["mcts_untrained"],
    n_repeats=25,
)

In [None]:
# evaluator.eval("MCTS (after training)", models["mcts_1k"], "GinkgoLikelihood1D-v0")

In [None]:
# evaluator.eval("Random MCTS", models["random_mcts"], "GinkgoLikelihood1D-v0")

In [None]:
# evaluator.eval("ACER", models["acer"], "GinkgoLikelihood1D-v0")

In [None]:
# evaluator.eval("PPO", models["ppo"], "GinkgoLikelihood-v0")

In [None]:
# evaluator.eval("DQN", models["dqn"], "GinkgoLikelihood1D-v0")

In [None]:
# Plot the evaluator's log likelihoods; assigning to `_` keeps the return value out of the cell output
_ = evaluator.plot_log_likelihoods()