In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
import datetime
import json
import os
import pickle
import threading
import time

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras

from tqdm.notebook import trange

In [4]:
# Show the GPUs TensorFlow can see; as the last expression of the cell the
# returned list is displayed as the cell output.
tf.config.list_physical_devices('GPU')

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU')]

In [5]:
# Enable memory growth on every visible GPU (per the TensorFlow docs this makes
# the runtime allocate GPU memory on demand instead of reserving it all upfront).
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

In [6]:
import solution.constants as const
import solution.model as model
import solution.utils as utils

In [7]:
def get_replay_data(delete_processed_files=False):
    """Load and concatenate every pickled chunk found in ``const.REPLAY_PATH``.

    Args:
        delete_processed_files: When true, each file is unlinked right after
            its contents have been appended to the result.

    Returns:
        One list holding the items of all files, in ``os.listdir`` order
        (empty when the directory contains no files).
    """
    samples = []
    # Unpickling happens inside the CPU device scope
    # (presumably so unpickled tensors land on the CPU -- TODO confirm).
    with tf.device('CPU:0'):
        for entry in os.listdir(const.REPLAY_PATH):
            chunk_path = os.path.join(const.REPLAY_PATH, entry)
            with open(chunk_path, 'rb') as chunk_file:
                samples.extend(pickle.load(chunk_file))

            if delete_processed_files:
                os.unlink(chunk_path)
    return samples


# Size limit that replay_buffer_job uses to decide when old samples are dropped.
MAX_REPLAY_BUFFER_SIZE = 50000
# Shared list of replay samples: filled by replay_buffer_job and sampled by the
# training loop.
replay_buffer = []


def replay_buffer_job():
    """Background worker that feeds new replay files into ``replay_buffer``.

    The replay directory is emptied first, then, for as long as the global
    ``replay_data_needed`` flag is true, every file in it is loaded (and
    deleted) via ``get_replay_data(True)`` and appended to the global
    ``replay_buffer`` under ``buffer_lock``. When nothing new is available the
    worker sleeps for one second before polling again.

    The buffer is capped at ``MAX_REPLAY_BUFFER_SIZE``: after appending, exactly
    the overflowing number of oldest samples is dropped, so the newest samples
    are always kept and the cap also holds when a single chunk is larger than
    the cap.
    """
    global replay_buffer

    clear_dir(const.REPLAY_PATH)

    while replay_data_needed:
        chunk = get_replay_data(True)
        if not chunk:
            time.sleep(1.0)
            continue

        with buffer_lock:
            replay_buffer.extend(chunk)
            # Evict only as many of the oldest samples as needed to respect the cap.
            overflow = len(replay_buffer) - MAX_REPLAY_BUFFER_SIZE
            if overflow > 0:
                del replay_buffer[:overflow]
            print('New buffer length: %s' % len(replay_buffer))

In [8]:
def games_data_job():
    """Background worker that turns per-game JSON reports into TensorBoard scalars.

    The games-data directory is emptied first. Then, while the global
    ``games_data_needed`` flag is true, every file in the directory is parsed
    as JSON, each non-``None`` entry is written as a scalar summary (one summary
    step per file), and the file is deleted. The directory is polled once per
    second.
    """
    clear_dir(const.GAMES_DATA_PATH)
    games_logged = 0

    while games_data_needed:
        for entry in os.listdir(const.GAMES_DATA_PATH):
            report_path = os.path.join(const.GAMES_DATA_PATH, entry)
            with open(report_path) as report_file:
                metrics = json.load(report_file)

                # The summary writer is shared, so hold its lock while writing.
                with tf_writer_lock, tf_writer.as_default():
                    for tag, value in metrics.items():
                        if value is not None:
                            tf.summary.scalar(tag, value, games_logged)
                games_logged += 1

            os.unlink(report_path)

        time.sleep(1)

In [9]:
def soft_update(source_variables, target_variables, tau):
    """Blend each source variable into its paired target variable in place.

    Every target is assigned ``(1 - tau) * target + tau * source``. Pairs are
    formed with ``zip``, so surplus entries of the longer sequence are ignored.

    Args:
        source_variables: Iterable of values to blend from.
        target_variables: Iterable of objects exposing ``assign``; updated in place.
        tau: Interpolation weight given to the source value.
    """
    for source_var, target_var in zip(source_variables, target_variables):
        blended = (1 - tau) * target_var + tau * source_var
        target_var.assign(blended)



def save_model(epoch):
    """Save the online network's weights and repoint the "latest model" symlink.

    Weights are written to ``<const.MODELS_PATH>/<current_time>_<epoch>``. The
    symlink at ``const.LATEST_MODEL_SYMLINK_PATH`` is then switched to that path
    by creating a temporary link and renaming it over the old one, so readers
    never observe a missing link.

    Args:
        epoch: Value used as the suffix of the weights file name.
    """
    # exist_ok avoids the check-then-create race and also creates missing parents.
    os.makedirs(const.MODELS_PATH, exist_ok=True)

    file_path = os.path.join(const.MODELS_PATH, '%s_%s' % (current_time, epoch))
    network.save_weights(file_path)

    tmp_link_path = '%s_tmp' % const.LATEST_MODEL_SYMLINK_PATH
    # A temp link left behind by an interrupted save would make os.symlink raise
    # FileExistsError on every later call, so remove it first.
    if os.path.lexists(tmp_link_path):
        os.unlink(tmp_link_path)
    os.symlink(file_path, tmp_link_path)
    # os.replace overwrites an existing destination on every platform.
    os.replace(tmp_link_path, const.LATEST_MODEL_SYMLINK_PATH)
    

def clear_dir(path):
    """Ensure ``path`` exists as a directory and delete every entry inside it.

    Missing parent directories are created as well. Entries are removed with
    ``os.unlink``, so only files and symlinks are supported; a sub-directory
    inside ``path`` makes the call raise ``OSError``.

    Args:
        path: Directory to create (if needed) and empty.
    """
    # exist_ok avoids the check-then-create race and handles nested paths.
    os.makedirs(path, exist_ok=True)

    for file_name in os.listdir(path):
        os.unlink(os.path.join(path, file_name))

In [10]:
# Shapes used to build the models: a leading dimension of 1 is prepended to the
# shape of every element of utils.STATE_SPEC.
input_shape = [(1,) + tuple(i.shape) for i in utils.STATE_SPEC]

# Online network: the one whose weights learn_from_batch optimises.
network = model.Model()
network.build(input_shape)
network.summary()

# Target network: a second instance built with the same shapes and marked
# non-trainable; the training loop copies the online weights into it.
target_network = model.Model()
target_network.build(input_shape)
target_network.trainable = False


# Adam with the Keras default settings and a mean-squared-error loss, both used
# by learn_from_batch.
optimizer = keras.optimizers.Adam()
loss_fn = keras.losses.MeanSquaredError()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential (Sequential)      (1, 128)                  99232     
_________________________________________________________________
sequential_1 (Sequential)    (1, 32)                   2912      
_________________________________________________________________
sequential_2 (Sequential)    (1, 1, 1, 32)             15424     
_________________________________________________________________
sequential_3 (Sequential)    (1, 32)                   1280      
_________________________________________________________________
sequential_4 (Sequential)    (1, 64)                   8320      
_________________________________________________________________
sequential_5 (Sequential)    (1, 5)                    58885     
_________________________________________________________________
sequential_6 (Sequential)    (1, 1)                    57857 

In [11]:
# Number of reward steps accumulated before bootstrapping from the last state.
N_STEP = 5


@tf.function
def learn_from_batch(states, actions, rewards, is_not_done, *, gamma):
    """Apply one gradient step on a batch of N_STEP-long transitions.

    The regression target is the discounted sum of the first ``N_STEP`` rewards
    plus the discounted value of the last state. That value is read from
    ``target_network`` at the action that ``network`` ranks highest, and is
    replaced by zero wherever ``is_not_done[:, -1]`` is false.

    Args:
        states: Tuple of state elements, each shaped (B, N, *); index 0 along
            the second axis is the first state and index -1 the last one.
        actions: Actions shaped (B, N_STEP + 1); only ``actions[:, 0]`` is used.
        rewards: Rewards indexed as ``rewards[:, n]`` for n in ``range(N_STEP)``.
        is_not_done: Per-step flags; only the last column is read
            (presumably boolean -- TODO confirm against the data producer).
        gamma: Discount factor (keyword-only).

    Returns:
        The loss computed by ``loss_fn`` for this batch.
    """
    batch_size = actions.shape[0]
    sequence_length = actions.shape[1]
    assert sequence_length == N_STEP + 1

    initial_state = tuple([element[:, 0] for element in states])
    final_state = tuple([element[:, -1] for element in states])

    # The online network picks the action for the final state ...
    greedy_action = tf.argmax(network(final_state), -1)

    # ... while the target network supplies the value of that action.
    bootstrap_value = tf.gather(
        target_network(final_state), greedy_action, axis=-1, batch_dims=1
    )

    # No bootstrapping for sequences whose final step is flagged as done.
    bootstrap_value = tf.where(
        is_not_done[:, -1], bootstrap_value, tf.zeros(batch_size,)
    )

    # Fold the rewards backwards: R <- r_n + gamma * R.
    n_step_return = bootstrap_value
    for n in reversed(range(N_STEP)):
        n_step_return = rewards[:, n] + gamma * n_step_return

    taken_action = actions[:, 0]
    with tf.GradientTape() as tape:
        predicted_q = network(initial_state, training=True)
        predicted_value = tf.gather(predicted_q, taken_action, axis=-1, batch_dims=1)
        loss = loss_fn(n_step_return, predicted_value)

    grads = tape.gradient(loss, network.trainable_weights)
    optimizer.apply_gradients(zip(grads, network.trainable_weights))

    return loss

In [None]:
def stop_threads():
    """Ask the background worker threads to stop and wait for them to finish.

    For each worker whose thread object exists in the module globals, a message
    is printed, the worker's "needed" flag is set to ``False`` and, if the
    thread is still alive, it is joined. Workers that were never created are
    skipped, so the function can be called at any time.
    """
    module_scope = globals()
    # (thread global, stop-flag global, message) -- handled in this order.
    workers = (
        ('replay_buffer_thread', 'replay_data_needed', 'Stopping replay thread'),
        ('games_data_thread', 'games_data_needed', 'Stopping games data thread'),
    )

    for thread_name, flag_name, message in workers:
        if thread_name not in module_scope:
            continue

        print(message)
        module_scope[flag_name] = False
        worker = module_scope[thread_name]
        if worker.is_alive():
            worker.join()


# Stop any worker threads left over from a previous run of this cell before
# new ones are started below.
stop_threads()

# Setup tensorboard
TB_LOGS_DIR = '/tmp/tb_logs/'
if not os.path.exists(TB_LOGS_DIR):
    os.mkdir(TB_LOGS_DIR)
# Timestamp identifying this run: it names the TensorBoard log directory and is
# also read by save_model() to name the weight files.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
current_logs_dir = os.path.join(TB_LOGS_DIR, current_time)
tf_writer = tf.summary.create_file_writer(current_logs_dir)
# The writer is used by both the training loop and games_data_job; this lock
# serialises their writes.
tf_writer_lock = threading.Lock()

#tf.random.set_seed(1)

# Training hyper-parameters.
EPOCHS = 1000000  # iterations of the training loop (one batch each)
RANDOM_SIZE = 100000  # NOTE(review): not referenced anywhere in the visible notebook
TARGET_NETWORK_UPDATE_EPISODES = 2500  # training steps between target-network weight copies
MODEL_SAVE_PERIOD = 100  # training steps between save_model() calls
GAMMA = 0.99  # discount factor passed to learn_from_batch
BATCH_SIZE = 64  # samples drawn from the replay buffer per training step


# Counter of completed training steps, advanced by the training loop.
steps = 0

# Start the worker that fills replay_buffer; the lock and the flag must exist
# before the thread starts because replay_buffer_job reads both.
buffer_lock = threading.Lock()
replay_data_needed = True
replay_buffer_thread = threading.Thread(target=replay_buffer_job)
replay_buffer_thread.start()

# Start the worker that logs per-game statistics to TensorBoard.
games_data_needed = True
games_data_thread = threading.Thread(target=games_data_job)
games_data_thread.start()


# Save the initial weights and create the "latest model" symlink before
# training begins.
save_model(0)


# Both phases run inside try/finally so the worker threads are always asked to
# stop -- also when training raises or the cell is interrupted. Otherwise the
# non-daemon workers would keep running (and keep deleting replay files).
try:
    # Wait until the replay worker has collected at least ten batches of samples.
    while len(replay_buffer) < BATCH_SIZE * 10:
        print('Waiting for data:', len(replay_buffer))
        time.sleep(1)

    for epoch in trange(EPOCHS):
        # Sample and assemble the batch under the buffer lock so the replay
        # worker cannot modify the buffer while it is being indexed.
        with buffer_lock, tf.device('CPU:0'):
            # Convert the sampled indices to NumPy once; iterating the tensor
            # itself would create one scalar tensor per list lookup.
            indices = tf.random.uniform(
                (BATCH_SIZE,), 0, len(replay_buffer), dtype=tf.int64
            ).numpy()
            batch_actions = tf.stack([replay_buffer[i]['actions'] for i in indices])
            batch_rewards = tf.stack([replay_buffer[i]['rewards'] for i in indices])
            batch_is_not_done = tf.stack([replay_buffer[i]['is_not_done'] for i in indices])
            batch_states = []
            for k in range(len(replay_buffer[0]['states'])):
                batch_states.append(
                    tf.stack([replay_buffer[i]['states'][k] for i in indices])
                )

            model_versions = [
                int(replay_buffer[i]['additional_data'][0]['model_version'])
                for i in indices
            ]

        # Mean model version of the sampled items, logged next to the loss.
        avg_model_version = np.mean(model_versions)

        # Periodically copy the online weights into the target network.
        if steps % TARGET_NETWORK_UPDATE_EPISODES == 0:
            print('Updating network')
            target_network.set_weights(network.get_weights())

        # Periodically save the weights and repoint the "latest model" symlink.
        if steps % MODEL_SAVE_PERIOD == 0:
            print('Saving model')
            save_model(epoch)

        if steps == 1:
            network.summary()

        steps += 1

        loss = learn_from_batch(
            batch_states, batch_actions, batch_rewards, batch_is_not_done, gamma=GAMMA
        )

        with tf_writer_lock, tf_writer.as_default():
            tf.summary.scalar('loss', loss, step=steps)
            tf.summary.scalar('avg_version', avg_model_version, step=steps)
finally:
    stop_threads()

Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0


Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0
Waiting for data: 0


In [None]:
# Signal both background worker threads to stop and wait for them to exit;
# this is a no-op for workers that were never started or have already finished.
stop_threads()