# Zoning Game AlphaZero Demo

## Individually construct bits and pieces

In [1]:
from nsai_experiments.general_az_1p.game import Game
from nsai_experiments.general_az_1p.policy_value_net import PolicyValueNet

from nsai_experiments.general_az_1p.zoning_game.zoning_game_az_impl import ZoningGameGame
from nsai_experiments.general_az_1p.zoning_game.zoning_game_az_impl import ZoningGamePolicyValueNet

### The `Game`

In [2]:
# Instantiate the zoning game and check that it implements the generic `Game` interface.
mygame = ZoningGameGame()
assert isinstance(mygame, Game)

# Reset with a fixed seed, then print the text that `render()` hands back.
mygame.reset_wrapper(seed=47)
rendered = mygame.render()
print(rendered.read())  # type: ignore[union-attr]

Tile grid:
[[0 0 0 5 1 0]
 [0 4 0 0 0 0]
 [0 3 0 3 2 4]
 [0 0 0 0 0 0]
 [2 0 0 0 0 0]
 [0 0 0 0 3 1]]
Tile queue (leftmost next): [1 4 2 1 5 2 3 3 2 3 1 1 1 4 2 2 1 5 5 2 1 5 3 2 5 1 0 0 0 0 0 0 0 0 0 0]
where 0 = EMPTY, 1 = RESIDENTIAL, 2 = COMMERCIAL, 3 = INDUSTRIAL, 4 = DOWNTOWN, 5 = PARK.
After 0 moves, current grid score is 3; terminated = False, truncated = False.



### The `PolicyValueNet`

In [3]:
import torch
from nsai_experiments.zoning_game.notebook_utils import get_zg_data
from nsai_experiments.zoning_game.zg_policy import create_policy_indiv_greedy

# Seed torch's global RNG before the random calls below (`randperm`, `random_split`).
torch.manual_seed(47)

# Dataset settings: how many games to request, where the data lives on disk,
# and the fraction of examples held out for validation and for testing.
n_games = 20_000
savedir = "../../zoning_game/zg_data"
valid_frac = 0.15
test_frac = 0.15

# Obtain the state / value / move tensors for `create_policy_indiv_greedy`.
# In the recorded run this loaded existing data from under `savedir`.
states_tensor, values_tensor, moves_tensor = get_zg_data(
    create_policy_indiv_greedy,
    n_games=n_games,
    savedir=savedir,
)

# Index all three tensors with the same permutation so corresponding rows stay aligned.
# Note the dataset's field order is (state, move, value), not the order returned above.
indices = torch.randperm(len(values_tensor))
full_dataset_3 = torch.utils.data.TensorDataset(
    states_tensor[indices],
    moves_tensor[indices],
    values_tensor[indices],
)

# Validation and test sizes are truncated to whole examples; training gets the remainder,
# so the three sizes always add up to the full dataset length.
n_examples = len(full_dataset_3)
valid_size_3 = int(valid_frac * n_examples)
test_size_3 = int(test_frac * n_examples)
train_size_3 = n_examples - (valid_size_3 + test_size_3)

# NOTE(review): `random_split` assigns examples to the splits at random as well, so the
# permutation above looks redundant; it is kept so the seeded RNG stream is unchanged.
train_dataset_3, valid_dataset_3, test_dataset_3 = torch.utils.data.random_split(
    full_dataset_3,
    [train_size_3, valid_size_3, test_size_3],
)
print("Done loading, shuffling, splitting data")

Loading data from disk: ../../zoning_game/zg_data/create_policy_indiv_greedy__20000
Done loading, shuffling, splitting data


In [None]:
# Build the zoning-game network with a 3-epoch training configuration and check
# that it implements the generic `PolicyValueNet` interface.
training_params = {"epochs": 3}
mynet = ZoningGamePolicyValueNet(training_params=training_params)
assert isinstance(mynet, PolicyValueNet)

# Train on the training split; with `needs_reshape=False` the recorded output
# reports that the reshape of `examples` is skipped.
mynet.train(train_dataset_3, needs_reshape=False)

# Kept as the bare last expression so the notebook displays the prediction for the
# current game state (the recorded output is a 36-element array plus a scalar).
mynet.predict(mygame.state)

Skipping reshape of `examples`.
Epoch 1/3, Train Loss: 225.8652
Epoch 2/3, Train Loss: 58.3128
Epoch 3/3, Train Loss: 35.9788


(array([0.01218951, 0.01753684, 0.03265317, 0.03207622, 0.01752278,
        0.01161393, 0.02130278, 0.02662969, 0.03313527, 0.03251413,
        0.0260546 , 0.0192566 , 0.02537464, 0.03309538, 0.04804223,
        0.05170261, 0.03571438, 0.03096365, 0.02715765, 0.03155551,
        0.04565817, 0.04267368, 0.0378471 , 0.02587227, 0.01670748,
        0.02475507, 0.02818183, 0.03365848, 0.03014457, 0.01869849,
        0.01263633, 0.019414  , 0.03090652, 0.03224783, 0.0209248 ,
        0.01358179], dtype=float32),
 1.7862025499343872)