In [1]:
import numpy as np
import gym
from go_ai import policies, game, metrics, data
from go_ai.models import value_model
import os
import random
import torch

# Hyperparameters

In [2]:
BOARD_SIZE = 4  # board size handed to the Go environment and ValueNet; also embedded in the checkpoint file name

In [3]:
ITERATIONS = 256  # number of passes through the training loop
EPISODES_PER_ITERATION = 128  # self-play games generated per training iteration
NUM_EVAL_GAMES = 128  # games per evaluation match against the baseline policies

In [4]:
INIT_TEMP = 1  # starting temperature of the current and checkpoint policies
TEMP_DECAY = 3/4  # factor the temperature is multiplied by after every training iteration
MIN_TEMP = 1/64  # floor: decay_temps never lets a temperature drop below this

In [5]:
BATCH_SIZE = 32  # batch size passed to curr_model.optimize

In [6]:
LOAD_SAVED_MODELS = False  # True: resume from the existing checkpoint file; False: overwrite it with freshly initialised weights

# Data Parameters

In [7]:
EPISODES_DIR = 'episodes/'  # NOTE(review): not referenced by any cell visible in this notebook — confirm it is still needed

In [8]:
CHECKPOINT_PATH = f'checkpoints/checkpoint_{BOARD_SIZE}x{BOARD_SIZE}.pt'  # one checkpoint file per board size

In [9]:
DEMO_TRAJECTORY_PATH = 'logs/a_trajectory.png'  # image path handed to metrics.gen_traj_fig

# Go Environment
Train on a small board for fast training and efficient debugging

In [10]:
go_env = gym.make('gym_go:go-v0', size=BOARD_SIZE)  # single environment instance reused by every demo, training and evaluation game below

# Model

In [11]:
curr_model = value_model.ValueNet(BOARD_SIZE)
checkpoint_model = value_model.ValueNet(BOARD_SIZE)

if LOAD_SAVED_MODELS:
    # Resuming: the checkpoint file must already be on disk.
    assert os.path.exists(CHECKPOINT_PATH), f"No checkpoint found at {CHECKPOINT_PATH}"
    print("Starting from checkpoint")
else:
    # Fresh start: torch.save does not create missing parent directories,
    # so make sure the checkpoint folder exists before writing into it.
    checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(curr_model.state_dict(), CHECKPOINT_PATH)
    print("Initialized checkpoint")

# Both networks start from the same weights stored in the checkpoint file.
curr_model.load_state_dict(torch.load(CHECKPOINT_PATH))
checkpoint_model.load_state_dict(torch.load(CHECKPOINT_PATH))

# Bare expression so the notebook displays the network architecture.
curr_model

Initialized checkpoint


ValueNet(
  (convs): Sequential(
    (0): Conv2d(6, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ReLU()
  )
  (fcs): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=1, bias=True)
  )
  (criterion): BCEWithLogitsLoss()
)

# Policies

In [12]:
curr_policy = policies.QTempPolicy('Current', curr_model, INIT_TEMP)  # policy backed by the model being trained
checkpoint_policy = policies.QTempPolicy('Checkpoint', checkpoint_model, INIT_TEMP)  # policy backed by the last accepted weights

random_policy = policies.RandomPolicy()  # baseline opponent used in evaluation
greedy_policy = policies.QTempPolicy('Greedy', policies.greedy_val_func, temp=0)  # baseline: same policy class driven by greedy_val_func at temperature 0
human_policy = policies.HumanPolicy()  # used for the human-vs-AI game in the last cell

In [13]:
def decay_temps(policies, temp_decay, min_temp):
    """Scale each policy's temperature by `temp_decay`, never going below `min_temp`.

    Args:
        policies: iterable of policy objects exposing `temp` and `name` attributes.
        temp_decay: multiplicative factor applied to every policy's temperature.
        min_temp: lower bound the decayed temperature is clamped to.

    Raises:
        AssertionError: if a policy has no `temp` attribute.

    The new temperature of every policy is printed.
    """
    for pol in policies:
        assert hasattr(pol, 'temp')
        # Decay, then clamp to the floor in a single step.
        pol.temp = max(pol.temp * temp_decay, min_temp)
        print(f"{pol.name} temp decayed to {pol.temp}")

# Demo and Time Games

Symmetries

In [14]:
%%time
go_env.reset()
# Take one move at (1, 1) and keep the resulting state for plotting.
action = (1, 1)
# step() is unpacked as a 4-tuple; only its first element (the next state) is used.
next_state, _, _, _ = go_env.step(action)
# NOTE(review): presumably writes the symmetric variants of the state to this image file — confirm in go_ai.metrics.
metrics.plot_symmetries(next_state, 'logs/symmetries.jpg')

CPU times: user 338 ms, sys: 32.6 ms, total: 371 ms
Wall time: 189 ms


With replay memory

In [15]:
%%time
go_env.reset()
# Time one pit() call of the current policy against itself with get_traj=True; both return values are discarded.
_,_ = game.pit(go_env, curr_policy, curr_policy, get_traj=True)

CPU times: user 504 ms, sys: 63.3 ms, total: 567 ms
Wall time: 193 ms


In [16]:
%%time
# Time the trajectory-figure generation for the current policy; DEMO_TRAJECTORY_PATH is the image path.
metrics.gen_traj_fig(go_env, curr_policy, DEMO_TRAJECTORY_PATH)

CPU times: user 2.26 s, sys: 99.4 ms, total: 2.36 s
Wall time: 1.98 s


# Train

In [17]:
ACCEPT_WINRATE = 0.6  # above this win rate vs. the checkpoint, the new weights are accepted
REJECT_WINRATE = 0.4  # below this win rate vs. the checkpoint, the new weights are rolled back

for iteration in range(ITERATIONS):
    print(f"Iteration {iteration}")

    # Generate self-play games; the second return value is the replay data
    _, replay_data = game.play_games(go_env, curr_policy, curr_policy, True, EPISODES_PER_ITERATION)

    # Process the data
    random.shuffle(replay_data)
    replay_data = data.replaylist_to_numpy(replay_data)

    # Optimize
    curr_model.optimize(replay_data, BATCH_SIZE)

    # Evaluate against the checkpoint model, using the same number of games
    # as the baseline evaluations below
    opp_winrate, _ = game.play_games(go_env, curr_policy, checkpoint_policy, False, NUM_EVAL_GAMES)

    if opp_winrate > ACCEPT_WINRATE:
        # New parameters are significantly better. Accept it
        torch.save(curr_model.state_dict(), CHECKPOINT_PATH)
        checkpoint_model.load_state_dict(torch.load(CHECKPOINT_PATH))
        print(f"{100*opp_winrate:.1f}% Accepted new model")

        # Plot samples of states and response heatmaps
        metrics.gen_traj_fig(go_env, curr_policy, DEMO_TRAJECTORY_PATH)
        print("Plotted sample trajectory")

        # Measure the newly accepted model against the fixed baselines
        rand_winrate, _ = game.play_games(go_env, curr_policy, random_policy, False, NUM_EVAL_GAMES)
        greed_winrate, _ = game.play_games(go_env, curr_policy, greedy_policy, False, NUM_EVAL_GAMES)
        print(f"{100*rand_winrate:.1f}% vs. Random, {100*greed_winrate:.1f}% vs. Greedy")

    elif opp_winrate >= REJECT_WINRATE:
        # Keep trying
        print(f"{100*opp_winrate:.1f}% Continuing to train current weights")
    else:
        # New parameters are significantly worse. Reject it.
        curr_model.load_state_dict(torch.load(CHECKPOINT_PATH))
        print(f"{100*opp_winrate:.1f}% Rejected new model")

    # Decay the temperatures if any
    decay_temps([curr_policy, checkpoint_policy], TEMP_DECAY, MIN_TEMP)

Current vs. Current:   0%|          | 0/128 [00:00<?, ?it/s]

Iteration 0


Current vs. Current: 100%|██████████| 128/128 [00:25<00:00,  5.04it/s, 59.0%]
Optimizing: 82it [00:00, 175.68it/s, 52.4%, 0.692L]
Current vs. Checkpoint: 100%|██████████| 128/128 [00:24<00:00,  5.17it/s, 60.2%]


60.2% Accepted new model


Current vs. Random:   2%|▏         | 2/128 [00:00<00:11, 10.74it/s, 100.0%]

Plotted sample trajectory


Current vs. Random: 100%|██████████| 128/128 [00:15<00:00,  8.39it/s, 52.0%]
Current vs. Greedy: 100%|██████████| 128/128 [00:26<00:00,  4.85it/s, 4.7%]
Current vs. Current:   1%|          | 1/128 [00:00<00:21,  5.88it/s, 100.0%]

Current temp decayed to 0.75
Checkpoint temp decayed to 0.75
Iteration 1


Current vs. Current: 100%|██████████| 128/128 [00:27<00:00,  4.62it/s, 52.3%]
Optimizing: 91it [00:00, 182.48it/s, 52.4%, 0.693L]
Current vs. Checkpoint: 100%|██████████| 128/128 [00:26<00:00,  4.92it/s, 57.0%]
Current vs. Current:   0%|          | 0/128 [00:00<?, ?it/s]

57.0% Continuing to train current weights
Current temp decayed to 0.5625
Checkpoint temp decayed to 0.5625
Iteration 2


Current vs. Current: 100%|██████████| 128/128 [00:25<00:00,  5.02it/s, 55.9%]
Optimizing: 81it [00:00, 168.01it/s, 52.5%, 0.693L]
Current vs. Checkpoint: 100%|██████████| 128/128 [00:24<00:00,  5.31it/s, 57.4%]
Current vs. Current:   1%|          | 1/128 [00:00<00:22,  5.57it/s, 0.0%]

57.4% Continuing to train current weights
Current temp decayed to 0.421875
Checkpoint temp decayed to 0.421875
Iteration 3


Current vs. Current: 100%|██████████| 128/128 [00:23<00:00,  5.37it/s, 60.5%]
Optimizing: 80it [00:00, 184.48it/s, 51.7%, 0.693L]
Current vs. Checkpoint: 100%|██████████| 128/128 [00:24<00:00,  5.30it/s, 58.6%]
Current vs. Current:   0%|          | 0/128 [00:00<?, ?it/s]

58.6% Continuing to train current weights
Current temp decayed to 0.31640625
Checkpoint temp decayed to 0.31640625
Iteration 4


Current vs. Current: 100%|██████████| 128/128 [00:23<00:00,  5.49it/s, 64.1%]
Optimizing: 76it [00:00, 181.74it/s, 49.3%, 0.694L]
Current vs. Checkpoint:  15%|█▍        | 19/128 [00:03<00:21,  4.99it/s, 47.4%]

KeyboardInterrupt: 

# Evaluate

Play against our AI

In [None]:
for policy in (curr_policy, checkpoint_policy):
    # Temperature 0 is the value the 'Greedy' baseline policy is constructed with.
    policy.temp = 0

In [None]:
game.pit(go_env, human_policy, checkpoint_policy, False)  # human vs. checkpoint policy; NOTE(review): the positional False is presumably get_traj — confirm against game.pit