In [1]:
import numpy as np
import gym
from go_ai import policies, game, metrics, data
from go_ai.models import value_model
import os
import random
import torch

# Hyperparameters

In [2]:
# Side length of the square Go board; kept tiny so self-play and training iterate quickly.
BOARD_SIZE: int = 4

In [3]:
ITERATIONS: int = 256              # outer rounds of self-play -> optimize -> evaluate
EPISODES_PER_ITERATION: int = 256  # self-play games generated per round
NUM_EVAL_GAMES: int = 128          # games per evaluation match (vs. checkpoint and baselines)

In [4]:
# Policy sampling temperatures. decay_temps multiplies each policy's temp by
# TEMP_DECAY once per iteration and clamps it from below at MIN_TEMP.
# NOTE(review): INIT_TEMP == MIN_TEMP, so the decay is currently a no-op
# (the training log shows the temperature staying at 0.015625). Raise
# INIT_TEMP if annealing is actually intended.
INIT_TEMP = 2 ** -6   # == 1/64
TEMP_DECAY = 0.75     # == 3/4
MIN_TEMP = 2 ** -6    # == 1/64

In [5]:
BATCH_SIZE: int = 32  # minibatch size handed to curr_model.optimize

In [6]:
# True: resume from CHECKPOINT_PATH (which must already exist).
# False: overwrite CHECKPOINT_PATH with freshly initialized weights.
LOAD_SAVED_MODELS: bool = True

# Data Parameters

In [7]:
# NOTE(review): not referenced by any of the visible cells; remove or wire it up.
EPISODES_DIR: str = 'episodes/'

In [8]:
# One checkpoint file per board size, e.g. 'checkpoints/checkpoint_4x4.pt'.
CHECKPOINT_PATH = f'checkpoints/checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.pt'

In [9]:
# Output path passed to metrics.gen_traj_fig for the sample-trajectory figure.
DEMO_TRAJECTORY_PATH: str = 'logs/a_trajectory.png'

# Go Environment
Use a small `BOARD_SIZE` × `BOARD_SIZE` board so that training is fast and debugging is easy.

In [10]:
# Create the gym_go environment with a BOARD_SIZE x BOARD_SIZE board; the same
# instance is reused for every demo, self-play, and evaluation game below.
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)

# Model

In [11]:
# Two copies of the value network: `curr_model` is the one being trained,
# `checkpoint_model` holds the last accepted weights and serves as the
# evaluation opponent during training.
curr_model = value_model.ValueNet(BOARD_SIZE)
checkpoint_model = value_model.ValueNet(BOARD_SIZE)

if LOAD_SAVED_MODELS:
    # Explicit raise instead of `assert`: survives `python -O` and names the missing file.
    if not os.path.exists(CHECKPOINT_PATH):
        raise FileNotFoundError(
            f"LOAD_SAVED_MODELS is True but no checkpoint exists at {CHECKPOINT_PATH}"
        )
    print("Starting from checkpoint")
else:
    # torch.save does not create parent directories, so make sure they exist.
    checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(curr_model.state_dict(), CHECKPOINT_PATH)
    print("Initialized checkpoint")

# Read the checkpoint once and copy it into both networks (load_state_dict copies
# tensor values, so the two models do not share storage afterwards).
checkpoint_state = torch.load(CHECKPOINT_PATH)
curr_model.load_state_dict(checkpoint_state)
checkpoint_model.load_state_dict(checkpoint_state)

optim = torch.optim.Adam(curr_model.parameters(), lr=1e-3)
curr_model

Starting from checkpoint


ValueNet(
  (convs): Sequential(
    (0): Conv2d(6, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
    (6): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
  )
  (fcs): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=1, bias=True)
  )
  (criterion): BCEWithLogitsLoss()
)

# Policies

In [12]:
# Temperature-sampled policies over the two value networks; both start at
# INIT_TEMP and are annealed by decay_temps inside the training loop.
curr_policy = policies.QTempPolicy('Current', curr_model, temp=INIT_TEMP)
checkpoint_policy = policies.QTempPolicy('Checkpoint', checkpoint_model, temp=INIT_TEMP)

# Fixed baselines and the interactive opponent.
random_policy = policies.RandomPolicy()
greedy_policy = policies.QTempPolicy('Greedy', policies.greedy_val_func, temp=0)
human_policy = policies.HumanPolicy()

In [13]:
def decay_temps(policies, temp_decay, min_temp):
    """Anneal each policy's sampling temperature in place.

    Each policy's ``temp`` is multiplied by ``temp_decay`` and clamped from
    below at ``min_temp``; the resulting value is printed.

    Args:
        policies: iterable of objects with mutable ``temp`` and a ``name``
            attribute. Note that this parameter shadows the imported
            ``policies`` module inside the function body.
        temp_decay: multiplicative decay factor applied once per call.
        min_temp: lower bound for the temperature.

    Raises:
        AssertionError: if a policy has no ``temp`` attribute.
    """
    for pol in policies:
        assert hasattr(pol, 'temp')
        pol.temp = max(pol.temp * temp_decay, min_temp)
        print(f"{pol.name} temp decayed to {pol.temp}")

# Demos and Game Timing

Symmetries

In [14]:
%%time
# Play one opening move at (1, 1) and plot the symmetries of the resulting state.
# NOTE(review): the 'logs/' directory is assumed to exist already.
go_env.reset()
state_after_move, _, _, _ = go_env.step((1, 1))
metrics.plot_symmetries(state_after_move, 'logs/symmetries.jpg')

CPU times: user 323 ms, sys: 28.2 ms, total: 352 ms
Wall time: 172 ms


Time one self-play game with trajectory (replay memory) collection enabled

In [15]:
%%time
# Time one full Current-vs-Current game with trajectory collection enabled;
# the two return values are discarded.
go_env.reset()
(_, _) = game.pit(go_env, curr_policy, curr_policy, get_traj=True)

CPU times: user 675 ms, sys: 65.1 ms, total: 741 ms
Wall time: 358 ms


In [16]:
%%time
# Generate the sample-trajectory figure for curr_policy (the same call the training
# loop makes after accepting a model); presumably written to DEMO_TRAJECTORY_PATH.
metrics.gen_traj_fig(go_env, curr_policy, DEMO_TRAJECTORY_PATH)

CPU times: user 3.4 s, sys: 126 ms, total: 3.53 s
Wall time: 3.14 s


# Train

In [17]:
# Self-play training loop: generate games with the current policy, fit the value
# net on them, then gate the new weights against the last accepted checkpoint.
ACCEPT_WINRATE = 0.6  # winrate vs. checkpoint strictly above which new weights are accepted
REJECT_WINRATE = 0.4  # winrate vs. checkpoint strictly below which new weights are rolled back

for iteration in range(ITERATIONS):
    print(f"Iteration {iteration}")

    # Self-play; the second return value is the replay data used for training.
    _, replay_data = game.play_games(go_env, curr_policy, curr_policy, True, EPISODES_PER_ITERATION)

    # Shuffle the transitions, then convert them to numpy for batching.
    random.shuffle(replay_data)
    replay_data = data.replaylist_to_numpy(replay_data)

    curr_model.optimize(replay_data, optim, BATCH_SIZE)

    # Gate: pit the freshly optimized weights against the last accepted checkpoint.
    opp_winrate, _ = game.play_games(go_env, curr_policy, checkpoint_policy, False, NUM_EVAL_GAMES)

    if opp_winrate > ACCEPT_WINRATE:
        # New parameters are significantly better: persist them and make them the new opponent.
        torch.save(curr_model.state_dict(), CHECKPOINT_PATH)
        checkpoint_model.load_state_dict(curr_model.state_dict())
        print(f"{100*opp_winrate:.1f}% Accepted new model")

        # Plot samples of states and response heatmaps
        metrics.gen_traj_fig(go_env, curr_policy, DEMO_TRAJECTORY_PATH)
        print("Plotted sample trajectory")

        # Track progress against fixed baselines and keep the numbers in the log.
        rand_winrate, _ = game.play_games(go_env, curr_policy, random_policy, False, NUM_EVAL_GAMES)
        greed_winrate, _ = game.play_games(go_env, curr_policy, greedy_policy, False, NUM_EVAL_GAMES)
        print(f"{100*rand_winrate:.1f}% vs. Random, {100*greed_winrate:.1f}% vs. Greedy")

    elif opp_winrate >= REJECT_WINRATE:
        # Inconclusive: keep training the current weights.
        print(f"{100*opp_winrate:.1f}% Continuing to train current weights")
    else:
        # New parameters are significantly worse: roll back to the checkpoint
        # (checkpoint_model always mirrors the file at CHECKPOINT_PATH).
        curr_model.load_state_dict(checkpoint_model.state_dict())
        # Adam's moment estimates still encode the rejected update direction;
        # clear them so optimization restarts cleanly from the restored weights.
        optim.state.clear()
        print(f"{100*opp_winrate:.1f}% Rejected new model")

    # Anneal sampling temperatures (clamped at MIN_TEMP).
    decay_temps([curr_policy, checkpoint_policy], TEMP_DECAY, MIN_TEMP)

Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Iteration 0


Current vs. Current: 100%|██████████| 256/256 [02:02<00:00,  2.09it/s, 46.1%]
Optimizing: 182it [00:28,  6.40it/s, 82.8%, 0.381L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:14<00:00,  1.90it/s, 61.5%]


61.5% Accepted new model


Current vs. Random:   0%|          | 0/128 [00:00<?, ?it/s]

Plotted sample trajectory


Current vs. Random: 100%|██████████| 128/128 [00:36<00:00,  3.55it/s, 93.0%]
Current vs. Greedy: 100%|██████████| 128/128 [00:49<00:00,  2.59it/s, 72.7%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 1


Current vs. Current: 100%|██████████| 256/256 [02:21<00:00,  1.80it/s, 50.4%]
Optimizing: 223it [00:34,  6.49it/s, 68.8%, 0.585L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:20<00:00,  1.82it/s, 39.8%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

39.8% Rejected new model
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 2


Current vs. Current: 100%|██████████| 256/256 [02:30<00:00,  1.70it/s, 43.4%]
Optimizing: 227it [00:36,  6.18it/s, 64.5%, 0.607L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:22<00:00,  1.79it/s, 42.2%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

42.2% Continuing to train current weights
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 3


Current vs. Current: 100%|██████████| 256/256 [02:19<00:00,  1.83it/s, 49.4%]
Optimizing: 211it [00:33,  6.34it/s, 71.5%, 0.483L]
Current vs. Checkpoint: 100%|██████████| 256/256 [02:10<00:00,  1.97it/s, 42.2%]
Current vs. Current:   0%|          | 0/256 [00:00<?, ?it/s]

42.2% Continuing to train current weights
Current temp decayed to 0.015625
Checkpoint temp decayed to 0.015625
Iteration 4


Current vs. Current: 100%|██████████| 256/256 [02:00<00:00,  2.12it/s, 48.4%]
Optimizing: 174it [00:27,  6.29it/s, 65.4%, 0.584L]
Current vs. Checkpoint:  40%|███▉      | 102/256 [00:51<01:11,  2.17it/s, 25.0%]

KeyboardInterrupt: 

# Evaluate

Play against our AI (the last accepted checkpoint model)

In [None]:
# `set_temps` is not defined anywhere in this notebook (only decay_temps is),
# so set the temperatures directly. temp=0 is the same setting the 'Greedy'
# baseline policy uses.
for policy in (curr_policy, checkpoint_policy):
    policy.temp = 0

In [None]:
# Interactive game: human_policy versus the last accepted checkpoint weights.
# NOTE(review): the positional False presumably corresponds to the get_traj flag
# passed by keyword in the self-play timing cell — confirm against game.pit.
game.pit(go_env, human_policy, checkpoint_policy, False)