In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import tensorflow as tf

import muzero as mz
import gym
from time import time, sleep
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

In [2]:
env = gym.make("Breakout-ram-v0")

In [3]:
MAX_MOVES = 27000       # Maximum number of moves in each game
DISCOUNT = 0.95
DIRICHLET_ALPHA = 0.25  # paper/stream
NUM_SIMULATIONS = 50    # Number of mcts sims 
BATCH_SIZE = 128       # arbitrary
TD_STEPS = 10           # arbitrary
NUM_ACTORS = 10         # game-specific

config = mz.MuZeroConfig(
  env, MAX_MOVES, DISCOUNT, DIRICHLET_ALPHA, NUM_SIMULATIONS, BATCH_SIZE, 
  TD_STEPS, NUM_ACTORS
)

storage = mz.SharedStorage()
replay_buffer = mz.ReplayBuffer(config)

**Load the pretrained network**:

In [4]:
network = storage.latest_network(config)
network.load_pretrained('data/model')

storage._networks[0] = network

**Load previous game history**:

In [5]:
replay_buffer.load_buffer('data/buffer.pkl')

In [6]:
print(max([len(g.history) for g in replay_buffer.buffer]))

10000


Play and train:

In [9]:
NUM_LOOPS = 2
for i in range(NUM_LOOPS):
  for _ in tqdm(range(config.num_actors)):
    mz.run_selfplay(config, storage, replay_buffer)
  mz.train_network(config, storage, replay_buffer)

  # Print some statistics
  latest_network = storage.latest_network(config)
  print(f'Latest loss: {latest_network.losses[-1][0]}')
  max_reward = max(
    [sum(g.rewards) for g in replay_buffer.buffer[-config.num_actors:]]
  )
  print(f'Maximum reward: {max_reward}')
  print()

100%|██████████| 10/10 [11:45<00:00, 70.52s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Latest loss: 6266.70654296875
Maximum reward: 4.0



100%|██████████| 10/10 [10:53<00:00, 65.35s/it]


Latest loss: 5344.46923828125
Maximum reward: 5.0



## Play a game

In [10]:
network = storage.latest_network(config)

In [12]:
game = mz.play_game(config, network)

In [13]:
print(sum(game.rewards))

0.0


In [14]:
obs = env.reset()
for i, action in enumerate(game.history):
    env.render()
    sleep(.1)
    obs, r, d, info = env.step(action)
#     env.render()
    if i == 125:
      print(i)
    if r > 0:
      print(f'Reward baby: {r}')
    if d:
      print('Done')
      break
env.close()

Reward baby: 1.0
125
Reward baby: 1.0


### Saving

**Saving** the model:

In [15]:
network.save('data/model')

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:tensorflow:Assets written to: data/model/assets


**Saving** the game history:

In [16]:
replay_buffer.save_buffer('data/buffer.pkl')

/

In [5]:
for _ in tqdm(range(5)):
  run_selfplay(config, storage, replay_buffer)

train_network(config, storage, replay_buffer)

for _ in tqdm(range(5)):
  run_selfplay(config, storage, replay_buffer)
  
train_network(config, storage, replay_buffer)

100%|██████████| 5/5 [05:02<00:00, 60.47s/it]
100%|██████████| 5/5 [05:36<00:00, 67.25s/it]
