# Ginkgo clustering environment

## Setup

In [None]:
%matplotlib inline

import sys
import numpy as np
from matplotlib import pyplot as plt
import gym
import logging
from stable_baselines.common.policies import MlpPolicy
from stable_baselines import PPO2

sys.path.append("../")
from ginkgo_rl import GinkgoLikelihoodEnv, GinkgoLikelihood1DWrapper


In [None]:
# Logging setup: compact timestamped format, DEBUG level on the root logger
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="%H:%M",
    format="%(asctime)-5.5s %(name)-20.20s %(levelname)-7.7s %(message)s",
)

# Ginkgo likes to output a lot of logging info, we don't really want that.
# Every logger already registered whose name does not contain "ginkgo_rl"
# is raised to WARNING, so only ginkgo_rl keeps emitting DEBUG/INFO output.
for key in logging.Logger.manager.loggerDict:
    if "ginkgo_rl" in key:
        continue
    logging.getLogger(key).setLevel(logging.WARNING)


## Let's play a round of clustering manually

In [None]:
# Initial setup: build the "GinkgoLikelihood-v0" environment (presumably
# registered with gym when ginkgo_rl is imported -- confirm), reset it to
# obtain the first observation, and draw the starting state.
env = gym.make("GinkgoLikelihood-v0")
state = env.reset()
env.render()

In [None]:
# Merge two particles: the action is the pair of particle indices to combine
action = (0, 1)

# Apply the merge and redraw the environment
state, reward, done, info = env.step(action)
env.render()

# Report what the environment returned for this step
print(f"Reward: {reward}", f"Done: {done}", f"Info: {info}", sep="\n")

# Repeat this cell as often as you feel like


## The same, but with the version of the environment that has a 1D action space

In [None]:
# Initial setup: build the "GinkgoLikelihood1D-v0" environment (presumably
# registered with gym when ginkgo_rl is imported -- confirm), reset it to
# obtain the first observation, and draw the starting state.
env1d = gym.make("GinkgoLikelihood1D-v0")
state = env1d.reset()
env1d.render()


In [None]:
# Merge two particles: choose the pair, then let the environment translate it
# into the action format its step() expects (presumably a single flat index --
# confirm against wrap_action)
action = (1, 10)
action_wrapped = env1d.wrap_action(action)

# Apply the merge and redraw the environment
state, reward, done, info = env1d.step(action_wrapped)
env1d.render()

# Report what the environment returned for this step
print(f"Reward: {reward}", f"Done: {done}", f"Info: {info}", sep="\n")

# Repeat this cell as often as you feel like


## Let's let some RL agents loose

In [None]:
# n_steps: total_timesteps handed to the agent's learn() call
# n_eval:  number of environment steps taken in the evaluation loop
n_steps, n_eval = 1000, 100

In [None]:
# Train a PPO2 agent on the 1D-action environment, then roll the trained
# policy out for n_eval steps while rendering.
env1d = gym.make("GinkgoLikelihood1D-v0")

model = PPO2(MlpPolicy, env1d, verbose=1)
model.learn(total_timesteps=n_steps)

state = env1d.reset()
for _ in range(n_eval):
    action, _states = model.predict(state)
    # Feed the new observation back into `state`; otherwise the policy would
    # keep predicting from the initial observation on every step.
    state, reward, done, info = env1d.step(action)
    # Render and close the environment being evaluated (env1d), not the `env`
    # left over from the manual-play cells above.
    env1d.render()
    if done:
        # A plain gym environment is not reset automatically at episode end
        # (unlike the vectorized wrapper used inside learn()), so start a new
        # episode before the next prediction.
        state = env1d.reset()

env1d.close()
