In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import tensorflow as tf

from muzero import Game, Node, MuZeroConfig, ReplayBuffer, SharedStorage
from muzero.play import play_game, run_selfplay
from muzero.model import train_network
import gym
from time import time, sleep
from tqdm import tqdm

import warnings
warnings.filterwarnings('ignore')

In [None]:
env = gym.make("Breakout-ram-v0")

In [3]:
MAX_MOVES = 27000       # Maximum number of moves in each game
DISCOUNT = 0.95
DIRICHLET_ALPHA = 0.25  # paper/stream
NUM_SIMULATIONS = 50    # Number of mcts sims 
BATCH_SIZE = 1024       # arbitrary
TD_STEPS = 10           # arbitrary
NUM_ACTORS = 10         # game-specific

config = MuZeroConfig(env, MAX_MOVES, DISCOUNT, DIRICHLET_ALPHA, 
                      NUM_SIMULATIONS, BATCH_SIZE, TD_STEPS, NUM_ACTORS)

storage = SharedStorage()
replay_buffer = ReplayBuffer(config)

**Load the pretrained network**:

In [4]:
network = storage.latest_network(config)
network.load_pretrained('data/model')

storage._networks[0] = network

**Load previous game history**:

In [5]:
replay_buffer.load_buffer('data/buffer.pkl')

Play and train:

In [1]:
NUM_LOOPS = 2
for i in range(NUM_LOOPS):
  for _ in tqdm(range(config.num_actors)):
    run_selfplay(config, storage, replay_buffer)
  train_network(config, storage, replay_buffer)

  # Print some statistics
  latest_network = storage.latest_network(config)
  print(f'Latest loss: {latest_network.losses[-1][0]}')
  max_reward = max([sum(g.rewards) for g in replay_buffer.buffer[-config.num_actors:]])
  print(f'Maximum reward: {max_reward}')
  print()

## Play a game

In [103]:
network = storage.latest_network(config)

In [28]:
game = play_game(config, network)

In [29]:
print(sum(game.rewards))

1.0


In [101]:
obs = env.reset()
for i, action in enumerate(game.history):
    env.render()
    sleep(.1)
    obs, r, d, info = env.step(action)
#     env.render()
    if i == 125:
      print(i)
    if r > 0:
      print(f'Reward baby: {r}')
    if d:
      print('Done')
      break
env.close()

125
Done


### Saving

**Saving** the model:

In [104]:
network.save('data/model')

INFO:tensorflow:Assets written to: data/model/assets


**Saving** the game history:

In [40]:
replay_buffer.save_buffer('data/buffer.pkl')

/

In [5]:
for _ in tqdm(range(5)):
  run_selfplay(config, storage, replay_buffer)

train_network(config, storage, replay_buffer)

for _ in tqdm(range(5)):
  run_selfplay(config, storage, replay_buffer)
  
train_network(config, storage, replay_buffer)

100%|██████████| 5/5 [05:02<00:00, 60.47s/it]
100%|██████████| 5/5 [05:36<00:00, 67.25s/it]
