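# Ginkgo clustering environment

This notebook covers four steps:
- Play a round of Ginkgo jet clustering by hand.
- Check the Gym environments.
- Train RL agents: MCTS here, with ACER, PPO and DQN left commented out.
- Compare the log likelihoods of the agents' clusterings with the true trees and the exact maximum-likelihood (Trellis) solution.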

## Setup

In [1]:
%matplotlib inline

import copy
import sys
import os
import numpy as np
from matplotlib import pyplot as plt
import gym
import logging
from tqdm import trange

from stable_baselines.common.policies import MlpPolicy, MlpLnLstmPolicy
from stable_baselines.deepq import MlpPolicy as DQNMlpPolicy
from stable_baselines import PPO2, ACER, DQN
from stable_baselines.bench import Monitor
from stable_baselines import results_plotter
from stable_baselines.common.env_checker import check_env
from stable_baselines.common.callbacks import BaseCallback

sys.path.append("../")
from ginkgo_rl import GinkgoLikelihoodEnv, GinkgoLikelihood1DEnv, GinkgoLikelihoodShuffledEnv, GinkgoEvaluator
from ginkgo_rl import BatchedACERAgent, RandomMCTSAgent, MCTSAgent


  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
# Logging setup: compact timestamped records with DEBUG requested on the root logger.
# Every logger registered so far that does not belong to ginkgo_rl is raised to ERROR
# so that third-party libraries stay quiet.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt='%H:%M',
    format='%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s',
)

for logger_name in [name for name in logging.Logger.manager.loggerDict if "ginkgo_rl" not in name]:
    logging.getLogger(logger_name).setLevel(logging.ERROR)


def set_output(on=True):
    """Switch the verbosity of the ginkgo_rl loggers.

    on=True puts every registered logger whose name contains "ginkgo_rl" at DEBUG,
    on=False silences them to ERROR. Only loggers that already exist when this is
    called are affected; all other loggers are left untouched.
    """
    target_level = logging.DEBUG if on else logging.ERROR
    ginkgo_loggers = [name for name in logging.Logger.manager.loggerDict if "ginkgo_rl" in name]
    for name in ginkgo_loggers:
        logging.getLogger(name).setLevel(target_level)


## Let's play a round of clustering manually

In [3]:
set_output(True)
# Small environment for playing by hand. n_max presumably caps the number of particles
# (the logged jet below has 5 leaves). The illegal-merge reward is 0 and min_reward is
# disabled; see GinkgoLikelihoodEnv for their exact semantics.
env = GinkgoLikelihoodEnv(illegal_reward=0., min_reward=None, n_max=6)
state = env.reset()
env.render()

10:50 ginkgo_rl.envs.ginkg DEBUG   Initializing environment
10:50 ginkgo_rl.envs.ginkg DEBUG   Resetting environment
10:50 ginkgo_rl.envs.ginkg DEBUG   Sampling new jet with 5 leaves
10:50 ginkgo_rl.envs.ginkg INFO    5 particles:
10:50 ginkgo_rl.envs.ginkg INFO      p[ 0] = (  1.2,   0.6,   0.7,   0.8)
10:50 ginkgo_rl.envs.ginkg INFO      p[ 1] = (  1.1,   0.8,   0.6,   0.4)
10:50 ginkgo_rl.envs.ginkg INFO      p[ 2] = (  0.9,   0.4,   0.5,   0.6)
10:50 ginkgo_rl.envs.ginkg INFO      p[ 3] = (  0.5,   0.2,   0.3,   0.3)
10:50 ginkgo_rl.envs.ginkg INFO      p[ 4] = (  0.4,   0.3,   0.2,   0.1)


In [4]:
# Walk the sampled jet's tree. A node whose two children are both negative is a leaf;
# otherwise it splits into (left, right) with the stored log likelihood and delta.
tree_columns = zip(env.jet['tree'], env.jet['content'], env.jet['logLH'], env.jet['deltas'])
for node, (children, p, log_likelihood_split, delta) in enumerate(tree_columns):
    left, right = children[0], children[1]
    if left >= 0 or right >= 0:
        print(f"{node}: {p}. {node} -> ({left}, {right}) with log p = {log_likelihood_split}, delta = {delta}")
    else:
        print(f"{node}: {p}. Leaf.")

0: [407.92156109 230.94010768 230.94010768 230.94010768]. 0 -> (1, 4) with log p = -17.501615524291992, delta = 6400.0
1: [150.63956705 108.53830782  88.01694999  55.98104365]. 1 -> (2, 3) with log p = -3.256178379058838, delta = 30.854320526123047
2: [108.42051582  77.84475298  63.08221104  41.36759101]. Leaf.
3: [42.21905123 30.69355484 24.93473895 14.61345264]. Leaf.
4: [257.28199404 122.40179986 142.92315768 174.95906402]. 4 -> (5, 8) with log p = -7.785614967346191, delta = 174.1205596923828
5: [135.43621138  64.69534027  75.43769792  91.5893318 ]. 5 -> (6, 7) with log p = -4.683798789978027, delta = 78.02823638916016
6: [48.41943841 23.01838007 29.83651693 30.39613273]. Leaf.
7: [87.01677663 41.67696195 45.60118302 61.19320154]. Leaf.
8: [121.84578266  57.70645958  67.48545977  83.36973222]. Leaf.


In [5]:
# Merge the particles at indices 0 and 1 of the current state
action = (0, 1)

state, reward, done, info = env.step(action)
env.render()

for label, value in (("Reward", reward), ("Done", done), ("Info", info)):
    print(f"{label}: {value}")

# Repeat this cell as often as you feel like; each merge reduces the particle count by one


10:50 ginkgo_rl.envs.ginkg DEBUG   Environment step. Action: (0, 1)
10:50 ginkgo_rl.envs.ginkg DEBUG   Computing log likelihood of action (0, 1): ti = 0.0, tj = 0.0, t_cut = 16.0, lam = 1.5 -> log likelihood = -10.924210548400879
10:50 ginkgo_rl.envs.ginkg DEBUG   Merging particles 0 and 1. New state has 4 particles.
10:50 ginkgo_rl.envs.ginkg INFO    4 particles:
10:50 ginkgo_rl.envs.ginkg INFO      p[ 0] = (  2.3,   1.4,   1.3,   1.2)
10:50 ginkgo_rl.envs.ginkg INFO      p[ 1] = (  0.9,   0.4,   0.5,   0.6)
10:50 ginkgo_rl.envs.ginkg INFO      p[ 2] = (  0.5,   0.2,   0.3,   0.3)
10:50 ginkgo_rl.envs.ginkg INFO      p[ 3] = (  0.4,   0.3,   0.2,   0.1)


Reward: -10.924210548400879
Done: False
Info: {'legal': True, 'illegal_action_counter': 0, 'replace_illegal_action': False, 'i': 0, 'j': 1}


In [6]:
# Release the manually driven environment; the following cells create their own environments.
env.close()

## Env checker

In [7]:
set_output(False)
env = gym.make("GinkgoLikelihoodShuffled-v0")

# Close the env even if one of the checks fails, so a failed run does not leak it.
try:
    # stable-baselines' environment checker (Gym API / space consistency).
    check_env(env)

    # The shuffled env exposes a permutation and its inverse; composing them in either
    # order must give the identity on every particle index.
    for idx in range(env.n_max):
        inv_of_perm = env.inverse_permutation[env.permutation[idx]]
        perm_of_inv = env.permutation[env.inverse_permutation[idx]]
        assert idx == inv_of_perm == perm_of_inv, (
            f"permutation/inverse_permutation mismatch at index {idx}: "
            f"inverse_permutation[permutation[i]] = {inv_of_perm}, "
            f"permutation[inverse_permutation[i]] = {perm_of_inv}"
        )
finally:
    env.close()



## Evaluation routine and baselines

In [8]:
set_output(False)
# Shared evaluator: the cells below register results on it (eval_true, eval_exact_trellis,
# eval) and finally plot them with plot_log_likelihoods.
evaluator = GinkgoEvaluator()


In [9]:
# Reference baseline: log likelihoods of the true clustering trees, stored under "Truth"
set_output(False)
evaluator.eval_true("Truth")


In [10]:
# Maximum-likelihood clustering computed exactly with the Trellis algorithm, stored under
# "ML (Trellis)". As the MLE, it is the best log likelihood any agent can reach.
set_output(False)
evaluator.eval_exact_trellis("ML (Trellis)")


In [11]:
# # Random clusterings
# set_output(False)
# evaluator.eval_random("Random", env_name="GinkgoLikelihood-v0")


## Training functions

In [12]:
# Registries filled by the training helpers:
#   log_dirs  - Monitor log directories, read by the plotting cell
#   models    - trained models, keyed by algorithm name
#   env_names - Gym id each model was trained on, keyed by algorithm name
log_dirs, models, env_names = [], {}, {}

def train_baseline(algorithm, algo_class, policy_class, env_name="GinkgoLikelihood1D-v0", n_steps=10000):
    """Train a stable-baselines agent and register it for plotting and evaluation.

    Parameters
    ----------
    algorithm : str
        Registry key for the trained model and its env id. Also names the Monitor log
        directory ``./logs/{algorithm}/``, which is added to ``log_dirs`` once, even if
        the cell is re-run.
    algo_class : class
        stable-baselines algorithm (e.g. PPO2, DQN), built as
        ``algo_class(policy_class, env, verbose=0)``.
    policy_class : class
        Policy class handed to ``algo_class``.
    env_name : str
        Gym id of the training environment.
    n_steps : int
        ``total_timesteps`` passed to ``model.learn``.
    """
    log_dir = f"./logs/{algorithm}/"
    # Re-running a training cell must not list the same directory twice in the plot.
    if log_dir not in log_dirs:
        log_dirs.append(log_dir)

    set_output(False)
    env = gym.make(env_name)
    os.makedirs(log_dir, exist_ok=True)
    # Monitor records episode statistics in log_dir; results_plotter reads them later.
    env = Monitor(env, log_dir)

    try:
        model = algo_class(policy_class, env, verbose=0)
        model.learn(total_timesteps=n_steps)
    finally:
        # Close even when training raises, so the Monitor log file is not left open.
        env.close()

    models[algorithm] = model
    env_names[algorithm] = env_name

def train_own(algorithm, algo_class, env_name="GinkgoLikelihood1D-v0", n_steps=1000):
    """Train one of the ginkgo_rl agents and register it for evaluation.

    Parameters
    ----------
    algorithm : str
        Registry key in ``models`` and ``env_names``.
    algo_class : class
        Agent class, built as ``algo_class(env)`` (e.g. MCTSAgent).
    env_name : str
        Gym id of the training environment.
    n_steps : int
        ``total_timesteps`` passed to ``model.learn``.

    The environment is not wrapped in a stable-baselines Monitor, so no training log is
    written. Consequently nothing is added to ``log_dirs``; results_plotter fails on a
    directory that contains no monitor files. To get training curves, wrap the env in
    ``Monitor(env, log_dir)`` and register ``log_dir`` the way ``train_baseline`` does.
    """
    set_output(False)
    env = gym.make(env_name)

    try:
        _ = env.reset()
        model = algo_class(env)
        model.learn(total_timesteps=n_steps)
    finally:
        # Release the environment even if training is interrupted or raises.
        env.close()

    models[algorithm] = model
    env_names[algorithm] = env_name


## Train agents

In [None]:
# Train the MCTS agent (the recorded run took about 1.8 s per step, roughly 30 minutes in total)
train_own("mcts", MCTSAgent, env_name="GinkgoLikelihood1D-v0", n_steps=1000)

  logger.debug(f"pR inv mass from p^2 in lab  frame: {np.sqrt(pR_mu[0] ** 2 - np.linalg.norm(pR_mu[1::]) ** 2)}")
  P = np.sqrt(tp)/2 * np.sqrt( 1 - 2 * (t_child+t_sib)/tp + (t_child - t_sib)**2 / tp**2 )
 32%|███▏      | 320/1000 [39:19<20:13,  1.78s/it]  

In [None]:
# train_own("random_mcts", RandomMCTSAgent)

In [None]:
# train_own("acer", BatchedACERAgent)

In [None]:
# train_baseline("ppo", PPO2, MlpPolicy, env_name="GinkgoLikelihood-v0")

In [None]:
# train_baseline("dqn", DQN, DQNMlpPolicy)

## Plot training progress

In [None]:
# Only directories that really contain stable-baselines Monitor logs can be plotted:
# results_plotter fails on a directory without monitor files and on an empty list.
# dict.fromkeys drops duplicate entries while keeping their order.
plottable_dirs = [
    log_dir for log_dir in dict.fromkeys(log_dirs)
    if os.path.isdir(log_dir) and any(f.endswith("monitor.csv") for f in os.listdir(log_dir))
]
if plottable_dirs:
    # 1e4 is the timestep cutoff for the plotted curves.
    results_plotter.plot_results(plottable_dirs, 1e4, results_plotter.X_TIMESTEPS, "Ginkgo")
else:
    print("No Monitor logs to plot (agents trained with train_own do not write any).")


## Evaluate

In [None]:
# Evaluate on the same environment the agent was trained on, as recorded by the training helper
evaluator.eval("MCTS", models["mcts"], env_names["mcts"])

In [None]:
# evaluator.eval("Random MCTS", models["random_mcts"], "GinkgoLikelihood1D-v0")

In [None]:
# evaluator.eval("ACER", models["acer"], "GinkgoLikelihood1D-v0")

In [None]:
# evaluator.eval("PPO", models["ppo"], "GinkgoLikelihood-v0")

In [None]:
# evaluator.eval("DQN", models["dqn"], "GinkgoLikelihood1D-v0")

In [None]:
# Compare the log likelihoods of all registered baselines and agents. The assignment
# stops the notebook from echoing the return value under the figure.
_ = evaluator.plot_log_likelihoods()