# Debug trained MCTS agent

In [1]:
%matplotlib inline

import numpy as np
import sys
import logging
from matplotlib import pyplot as plt
import torch

sys.path.append("../../")
from ginkgo_rl import GinkgoLikelihood1DEnv, MCTSAgent


In [2]:
# Route everything from DEBUG upwards through the root logger, printing the bare message only.
logging.basicConfig(level=logging.DEBUG, format="%(message)s", datefmt="%H:%M")

# Snapshot the names of all loggers registered so far that do not belong to ginkgo_rl ...
foreign_logger_names = [
    name for name in logging.Logger.manager.loggerDict if "ginkgo_rl" not in name
]

# ... and restrict each of them to ERROR, so the notebook output is dominated by ginkgo_rl messages.
for name in foreign_logger_names:
    logging.getLogger(name).setLevel(logging.ERROR)


In [3]:
# Create the environment and an MCTS agent acting in it (verbose=1 as in the original run).
env = GinkgoLikelihood1DEnv()
agent = MCTSAgent(env, verbose=1)

# Restore the trained weights from the saved run.
# map_location="cpu" makes torch.load deserialize all tensors onto the CPU, so a checkpoint
# written on a GPU machine can still be opened on a CPU-only one.
# NOTE(review): assumes load_state_dict is the torch.nn.Module method, which copies the loaded
# values into the agent's existing parameters — confirm against MCTSAgent.
agent.load_state_dict(
    torch.load("../data/runs/mcts_20200901_174303/model.pty", map_location="cpu")
)

Initializing environment
Creating linear layer: 100->100
Creating linear head layer: 100->1


<All keys matched successfully>

## Let's play an episode

In [4]:
# Play one episode with the loaded agent, rendering the environment after every step.
#
# Accumulators:
#   log_likelihood - sum of the rewards returned by env.step over the episode
#   errors         - number of steps whose info["legal"] flag was falsy
#   reward         - reward of the previous step (0.0 before the first step), passed to agent.update

# Initialize episode
state = env.reset()
done = False
log_likelihood = 0.
errors = 0
reward = 0.0
agent.set_env(env)
agent.eval()

# Render initial state
env.render()

while not done:
    # Agent step
    action, agent_info = agent.predict(state)

    # Environment step
    next_state, next_reward, done, info = env.step(action)
    env.render()

    # Book keeping
    log_likelihood += next_reward
    # An error is a step that was NOT legal; the original added int(info["legal"]) and thus
    # counted the legal steps instead. Assumes info["legal"] is truthy for a legal action.
    errors += int(not info["legal"])
    # NOTE(review): next_reward is given the *previous* reward here rather than the value just
    # returned by env.step — left unchanged; confirm against MCTSAgent.update whether intended.
    agent.update(state, reward, action, done, next_state, next_reward=reward, num_episode=0, **agent_info)
    reward, state = next_reward, next_state
    

Resetting environment
Sampling new jet with 9 leaves
9 particles:
  p[ 0] = (  1.3,   0.9,   0.7,   0.7)
  p[ 1] = (  0.8,   0.3,   0.5,   0.5)
  p[ 2] = (  0.6,   0.3,   0.3,   0.3)
  p[ 3] = (  0.5,   0.4,   0.3,   0.2)
  p[ 4] = (  0.3,   0.2,   0.2,   0.2)
  p[ 5] = (  0.2,   0.1,   0.1,   0.1)
  p[ 6] = (  0.2,   0.1,   0.1,   0.1)
  p[ 7] = (  0.1,   0.0,   0.1,   0.1)
  p[ 8] = (  0.0,   0.0,   0.0,   0.0)
Starting MCTS with 100 trajectories
MCTS results:
     0: log likelihood =  -8.4, policy = 0.00, n =  0, mean =   0.0 [0.93], max =  -inf [0.00]
     1: log likelihood =  -5.3, policy = 0.00, n =  0, mean =   0.0 [0.93], max =  -inf [0.00]
     2: log likelihood =  -4.6, policy = 0.00, n =  0, mean =   0.0 [0.93], max =  -inf [0.00]
     3: log likelihood =  -8.5, policy = 0.00, n =  0, mean =   0.0 [0.93], max =  -inf [0.00]
     4: log likelihood =  -9.7, policy = 0.00, n =  0, mean =   0.0 [0.93], max =  -inf [0.00]
     5: log likelihood =  -8.1, policy = 0.00, n =  0, mea