In [17]:
import time
import pickle

import jax 
from jax import numpy as jnp 
jax.config.update('jax_platform_name', 'cpu')
from mctx import gumbel_muzero_policy

import muax
from muax import nn 

from envs import ClanClassicEnv
import cv2
import numpy as np


In [18]:
class Buffer():
    """In-memory store of finished games.

    Each game contributes one entry to three parallel lists: its RGB frames,
    the actions taken, and its result. Index ``i`` in every list refers to the
    same game.

    Args:
        capacity: Number of stored games at which the buffer counts as full.
            Defaults to 64, the value that was previously hard-coded.
    """

    # Class-level fallback so that instances restored without running
    # __init__ (pickle does this by default) still have a capacity.
    capacity = 64

    def __init__(self, capacity=64):
        self.capacity = capacity
        self.games_rgb = []
        self.games_actions = []
        self.games_results = []

    def add_game(self, rgb_frames, actions, result):
        """Append one game's frames, actions and result to the buffer."""
        self.games_rgb.append(rgb_frames)
        self.games_actions.append(actions)
        self.games_results.append(result)

    def buffer_full(self):
        """Return True once at least ``capacity`` games have been stored.

        Uses ``>=`` rather than ``==`` so the buffer keeps reporting full if
        more games are added after the threshold was reached.
        """
        return len(self.games_results) >= self.capacity
    

# Root PRNG key; later cells split it before each stochastic call.
rng_key = jax.random.PRNGKey(42)

# --- Network / search hyperparameters ---------------------------------------
embedding_size = 128
num_actions = 2305
support_size = 2305
# Derived rather than hard-coded so it cannot drift away from support_size.
# 2 * 2305 + 1 == 4611, the literal used previously.
# NOTE(review): presumably the categorical support spans
# [-support_size, support_size] -- confirm against muax.
full_support_size = 2 * support_size + 1
discount = 0.99
num_simulations = 1
output_init_scale = 1.0

# --- Replay sampling parameters (passed to buffer.sample when training) ------
num_trajectory = 32
sample_per_trajectory = 1
k_steps = 10

# --- Network constructors handed to muax.MuZero ------------------------------
repr_fn = nn._init_ez_representation_func(nn.EZRepresentation, embedding_size)
pred_fn = nn._init_ez_prediction_func(nn.EZPrediction, num_actions, full_support_size, output_init_scale)
dy_fn = nn._init_ez_dynamic_func(nn.EZDynamic, embedding_size, num_actions, full_support_size, output_init_scale)



In [19]:

# Optimizer built by muax from the learning-rate schedule arguments below.
gradient_transform = muax.model.optimizer(
    init_value=0.02,
    peak_value=0.02,
    end_value=0.002,
    warmup_steps=5000,
    transition_steps=5000,
)

# MuZero agent assembled from the representation / prediction / dynamics
# functions defined in the configuration cell.
model = muax.MuZero(
    repr_fn,
    pred_fn,
    dy_fn,
    policy='muzero',
    discount=discount,
    optimizer=gradient_transform,
    support_size=support_size,
)

# Dummy batch holding a single all-zero 240x135 RGB frame; it is only used to
# initialise the model.
sample_input = jnp.zeros((1, 240, 135, 3))
rng_key, subkey = jax.random.split(rng_key)
model.init(subkey, sample_input)

# Empty game buffer to be filled during self-play.
buffer = Buffer()


In [11]:

# Retry once per second until the environment client can be constructed.
while True:
    try:
        env = ClanClassicEnv(serial="RFCWC04A2VY", host=True)
        break
    except Exception as err:
        # Catch Exception rather than using a bare `except:` so that
        # KeyboardInterrupt / SystemExit still stop the cell, and report the
        # reason instead of retrying silently.
        print(f"ClanClassicEnv initialisation failed ({err!r}); retrying in 1 s.")
        time.sleep(1)


Client Initialized.


In [None]:
# Self-play loop: for each of 64 games, repeatedly grab a frame from the
# environment, let the model choose an action, and print that action.
# NOTE(review): this looks unfinished -- the inner `while True` contains no
# `break`, so the first game never ends; `rgb_frames` / `actions` are created
# but never appended to; and `buffer.add_game(...)` is never called. Confirm
# the intended termination condition (the commented-out `env.in_game()`?).
for i in range(64):
    rgb_frames = []
    actions = []
    #env.reset()
    print(f"Game {i} began.")
    while True: #env.in_game():
        obs = env.get_observation()
        # cv2 takes dsize as (width, height); for a 3-channel frame the result
        # has shape (240, 135, 3), the layout of the sample input passed to
        # model.init.
        obs = cv2.resize(obs, dsize=(135, 240), interpolation=cv2.INTER_CUBIC)
        # Use a fresh subkey for every action so the root key is never reused.
        rng_key, subkey = jax.random.split(rng_key)
        # obs is a single unbatched frame, hence obs_from_batch=False.
        a, pi, v = model.act(subkey, obs, 
                        with_pi=True, 
                        with_value=True, 
                        obs_from_batch=False,
                        num_simulations=num_simulations,
                        temperature=1.0)
        print(a)


In [21]:
import pickle
import muax

#model.load('networks/initial_model_uniform_v0.2.npy')

# Load a previously recorded game buffer. It replaces the empty `buffer`
# created earlier and must provide a `.sample(...)` method, which the Buffer
# class defined in this notebook does not have.
# SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
with open('buffers/game_buffer_S1_uniform_0403-0856_H.pkl', 'rb') as file_handle:
    buffer = pickle.load(file_handle)

num_epochs = 1
updates_per_epoch = 50

for ep in range(num_epochs):
    # Reset per epoch so the reported mean never includes an earlier epoch's
    # value (it was previously initialised once, outside the epoch loop).
    train_loss = 0
    for _ in range(updates_per_epoch):
        # Sampling parameters come from the configuration cell.
        transition_batch = buffer.sample(num_trajectory=num_trajectory,
                                         sample_per_trajectory=sample_per_trajectory,
                                         k_steps=k_steps)
        loss_metric = model.update(transition_batch)
        train_loss += loss_metric['loss']
        print(f"Loss: {loss_metric['loss']}")

    # Mean loss over this epoch's updates.
    train_loss /= updates_per_epoch
    print(f'epoch: {ep:04d}, loss: {train_loss:.8f}')


Loss: 12.700068473815918
Loss: 11.437064170837402
Loss: 10.445414543151855
Loss: 9.770403861999512
Loss: 9.374824523925781
Loss: 9.19653034210205
Loss: 9.12408447265625
Loss: 9.059216499328613
Loss: 9.00328254699707
Loss: 8.968838691711426
Loss: 8.93356704711914
Loss: 8.887943267822266
Loss: 8.853827476501465
Loss: 8.825200080871582
Loss: 8.789388656616211
Loss: 8.752554893493652
Loss: 8.72392749786377
Loss: 8.693243026733398
Loss: 8.659235000610352
Loss: 8.638696670532227
Loss: 8.617006301879883
Loss: 8.585643768310547
Loss: 8.551481246948242
Loss: 8.546408653259277
Loss: 8.526519775390625
Loss: 8.509419441223145
Loss: 8.492897987365723
Loss: 8.478777885437012
Loss: 8.460850715637207
Loss: 8.442511558532715
Loss: 8.439632415771484
Loss: 8.429237365722656
Loss: 8.418272018432617
Loss: 8.400196075439453
Loss: 8.40036392211914
Loss: 8.394373893737793
Loss: 8.386743545532227
Loss: 8.38058090209961
Loss: 8.372041702270508
Loss: 8.363141059875488
Loss: 8.364082336425781
Loss: 8.356446266174

In [14]:
# Save the model under a file name relative to the current working directory.
# NOTE(review): the commented-out load in the training cell reads
# 'networks/initial_model_uniform_v0.2.npy' -- confirm which location is intended.
model.save('initial_model_uniform_v0.2.npy')
