In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import collections
import datetime
import gc
import json
import os
import pickle
import threading
import time

import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.keras as keras

from tqdm.notebook import trange

In [3]:
for gpu in tf.config.list_physical_devices('GPU'):
    tf.config.experimental.set_memory_growth(gpu, True)

In [4]:
import solution.constants as const
import solution.model as model
import solution.utils as utils
from solution.replay_buffer import ReplayBuffer
from solution.rl import n_step_return

In [19]:
def get_replay_data(delete_processed_files=False):
    with tf.device('CPU:0'):
        current_chunk = []
        for file_name in os.listdir(const.REPLAY_PATH):
            file_path = os.path.join(const.REPLAY_PATH, file_name)
            print('kuddai file_path', file_path)
            with open(file_path, 'rb') as f:
                data = pickle.load(f)
                current_chunk += data
            
            if delete_processed_files:
                os.unlink(file_path)
    return current_chunk


MAX_REPLAY_BUFFER_SIZE = 5000
replay_buffer = ReplayBuffer(MAX_REPLAY_BUFFER_SIZE)


def replay_buffer_job():
    clear_dir(const.REPLAY_PATH)
    step = 0
    total_sequences = 0
    
    while replay_data_needed:
        chunk = get_replay_data(True)
        if not chunk:
            time.sleep(1.0)
            continue
            
        if step == 0:
            start_time = time.time()
            
        total_sequences += len(chunk)
        
        with tf.device('CPU:0'):
            add_start_time = time.time()
            replay_buffer.add(chunk)
            with tf_writer_lock, tf_writer.as_default(), tf.name_scope('replay_buffer'):
                tf.summary.scalar('buffer size', len(replay_buffer), step)
                tf.summary.scalar('add chunk time', time.time() - add_start_time, step)
                tf.summary.scalar('sequences per second', total_sequences / (time.time() - start_time), step)

        step += 1

In [18]:
get_replay_data(False)

kuddai file_path /tmp/lux_ai/replays/1bdae962-37be-45ec-ac15-1915dffbad0f
kuddai file_path /tmp/lux_ai/replays/4c3ba141-e33c-4e0c-b999-af63c04846e4
kuddai file_path /tmp/lux_ai/replays/d7db9abd-dce7-4990-9c5b-b2edfcf70b1e
kuddai file_path /tmp/lux_ai/replays/ae28f34e-8f34-445b-b355-888ff463944b
kuddai file_path /tmp/lux_ai/replays/26bb0e13-7818-4024-9c63-ff5c5a058724
kuddai file_path /tmp/lux_ai/replays/e83a5959-02a8-4e30-be2c-bea8ae0ca7b0
kuddai file_path /tmp/lux_ai/replays/16434de8-606f-4b42-b502-ecbbb4d87c42
kuddai file_path /tmp/lux_ai/replays/5344bf49-bdc1-4ef5-889f-c633eb9469e0
kuddai file_path /tmp/lux_ai/replays/94409386-d497-46d0-9595-d788e792bbc5
kuddai file_path /tmp/lux_ai/replays/5200b7c9-79ec-4828-8eff-0b3f8925cc8f
kuddai file_path /tmp/lux_ai/replays/9f7cbc30-c1f3-4e00-97e6-066dea51ee6d
kuddai file_path /tmp/lux_ai/replays/71abc974-2dbd-4f54-b899-068173445da3
kuddai file_path /tmp/lux_ai/replays/33d2e72a-3473-43d0-b25f-a8bb42e55c32
kuddai file_path /tmp/lux_ai/replays/2

KeyboardInterrupt: 

In [20]:
def games_data_job():
    clear_dir(const.GAMES_DATA_PATH)
    step = collections.defaultdict(int)
    
    while games_data_needed:
        for file_name in os.listdir(const.GAMES_DATA_PATH):
            file_path = os.path.join(const.GAMES_DATA_PATH, file_name)
            with open(file_path) as f:
                game_data = json.load(f)
                agent_id = game_data['agent_id']
                
                for k, v in game_data.items():
                    if k == 'agent_id':
                        continue
                    
                    if v is None:
                        continue
                        
                    with tf_writer_lock, tf_writer.as_default():
                        with tf.name_scope('game_data/total'):
                            tf.summary.scalar(k, v, step['total'])
                            
                        with tf.name_scope('game_data/%s' % agent_id):
                            tf.summary.scalar(k, v, step[agent_id])
                    
                    
                step['total'] += 1
                step[agent_id] += 1

            os.unlink(file_path)

        time.sleep(1)

In [21]:
def soft_update(source_variables, target_variables, tau):
    for v_s, v_t in zip(source_variables, target_variables):
        v_t.assign((1 - tau) * v_t + tau * v_s)



def save_model(name, epoch, model, link_path=None):
    if not os.path.exists(const.MODELS_PATH):
        os.makedirs(const.MODELS_PATH)
        
    file_path = os.path.join(const.MODELS_PATH, '%s_%s_%s' % (name, current_time, epoch))
    model.save_weights(file_path)
    
    if link_path is not None:
        tmp_link_path = '%s_tmp' % link_path
        os.symlink(file_path, tmp_link_path)
        os.rename(tmp_link_path, link_path)


def clear_dir(path):
    if not os.path.exists(path):
        os.makedirs(path)
    
    for file_name in os.listdir(path):
        file_path = os.path.join(path, file_name)
        os.unlink(file_path)

In [22]:
tf.random.set_seed(42)

In [23]:
BATCH_SIZE = 64
input_shape = [(BATCH_SIZE,) + tuple(i.shape) for i in utils.STATE_SPEC]

network = model.Model()
network.build(input_shape)
network.summary()
#latest_path = tf.train.latest_checkpoint(const.MODELS_PATH)
#print(latest_path)
#network.load_weights(latest_path)

target_network = model.Model()
target_network.build(input_shape)
target_network.trainable = False
target_network.set_weights(network.get_weights())


optimizer = keras.optimizers.Adam(learning_rate=5e-4, clipnorm=40)
loss_fn = keras.losses.MeanSquaredError()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_16 (Sequential)   (64, 256)                 215176    
_________________________________________________________________
sequential_17 (Sequential)   (64, 64)                  11584     
_________________________________________________________________
sequential_18 (Sequential)   (2048, 256)               777456    
_________________________________________________________________
sequential_19 (Sequential)   (64, 32, 32)              1312      
_________________________________________________________________
sequential_20 (Sequential)   (64, 32, 256)             205568    
_________________________________________________________________
sequential_21 (Sequential)   (64, 32, 512)             558080    
_________________________________________________________________
sequential_22 (Sequential)   (64, 32, 7)               1331

In [24]:
@tf.function
def learn_from_batch(states, actions, rewards, is_not_done, *, gamma, weights, online_network, target_network):
    #print('Tracing lfb')
    # states is tuple of 5 state_element, state_element has (B, N, *) shape
    target = n_step_return(
        states, actions, rewards, is_not_done,
        gamma=gamma, n=const.N_STEPS, online_network=online_network, target_network=target_network
    )
        
    first_state = tuple([state[:, 0] for state in states])
    first_action = actions[:, 0]
    first_units_mask = tf.cast(states[4][:, 0], dtype=tf.float32)
    
    with tf.GradientTape() as tape:
        q_values = online_network(first_state, training=True)
        prediction = tf.gather(q_values, first_action, axis=-1, batch_dims=2)
        prediction = tf.reduce_sum(prediction * first_units_mask, axis=-1)
        
        td_error = target - prediction
        #loss = tf.reduce_mean(tf.math.pow(td_error, 2))
        loss = tf.reduce_mean(tf.math.pow(td_error, 2) * weights)
    
    gradients = tape.gradient(loss, online_network.trainable_weights)
    grad_min = tf.reduce_mean([tf.reduce_min(g) for g in gradients])
    grad_max = tf.reduce_max([tf.reduce_max(g) for g in gradients])
    grad_l2 = tf.norm([tf.norm(g) for g in gradients])
    optimizer.apply_gradients(zip(gradients, online_network.trainable_weights))
    
    return loss, tf.math.abs(td_error), grad_min, grad_max, grad_l2

In [25]:
# Setup tensorboard
TB_LOGS_DIR = '/tmp/tb_logs/'

if not os.path.exists(TB_LOGS_DIR):
    os.mkdir(TB_LOGS_DIR)

current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
current_logs_dir = os.path.join(TB_LOGS_DIR, current_time)
tf_writer = tf.summary.create_file_writer(current_logs_dir)
tf_writer_lock = threading.Lock()

steps = 0

save_model('online', 0, network, const.LATEST_MODEL_SYMLINK_PATH)
save_model('target', 0, target_network, const.TARGET_MODEL_SYMLINK_PATH)

In [None]:
def stop_threads():
    if 'replay_buffer_thread' in globals():
        print('Stopping replay thread')
        global replay_data_needed
        replay_data_needed = False
        if replay_buffer_thread.is_alive():
            replay_buffer_thread.join()
    
    if 'games_data_thread' in globals():
        print('Stopping games data thread')
        global games_data_needed
        games_data_needed = False
        if games_data_thread.is_alive():
            games_data_thread.join()


stop_threads()


replay_data_needed = True
replay_buffer_thread = threading.Thread(target=replay_buffer_job)
replay_buffer_thread.start()

games_data_needed = True
games_data_thread = threading.Thread(target=games_data_job)
games_data_thread.start()


EPOCHS = 10000000
TARGET_NETWORK_UPDATE_EPISODES = 1000
MODEL_SAVE_PERIOD = 100
INITIAL_DATA_SIZE = BATCH_SIZE * 10


while len(replay_buffer) < INITIAL_DATA_SIZE:
    print('Waiting for data: %s / %s' % (len(replay_buffer), INITIAL_DATA_SIZE))
    time.sleep(1)

        
for epoch in trange(EPOCHS):
    with tf.device('CPU:0'):
        start_time = time.time()
        ts, ind, (batch_states, batch_actions, batch_rewards, batch_is_not_done), td, weights = replay_buffer.get_prioritized(BATCH_SIZE)
        #(batch_states, batch_actions, batch_rewards, batch_is_not_done), td = replay_buffer.get_uniform(BATCH_SIZE)
        get_time = time.time() - start_time
    
    if steps % TARGET_NETWORK_UPDATE_EPISODES == 0:
        print(epoch, 'Updating network')
        target_network.set_weights(network.get_weights())
        save_model('target', epoch, target_network, const.TARGET_MODEL_SYMLINK_PATH)
        gc.collect()
        tf.keras.backend.clear_session()
        # soft_update(network.variables, target_network.variables, 0.01)
    
    if steps % MODEL_SAVE_PERIOD == 0:
        print(epoch, 'Saving model')
        save_model('online', epoch, network, const.LATEST_MODEL_SYMLINK_PATH)
    
    start_time = time.time()
    #loss, new_td_error, grad_min, grad_max, grad_l2 = learn_from_batch(
    #    batch_states, batch_actions, batch_rewards, batch_is_not_done,
    #    gamma=const.GAMMA, weights=None, online_network=network, target_network=target_network)
    loss, new_td_error, grad_min, grad_max, grad_l2 = learn_from_batch(
        batch_states, batch_actions, batch_rewards, batch_is_not_done,
        gamma=const.GAMMA, weights=weights, online_network=network, target_network=target_network)
    learn_time = time.time() - start_time
    
    start_time = time.time()
    replay_buffer.update(ts, ind, new_td_error)
    update_time = time.time() - start_time

    with tf_writer_lock, tf_writer.as_default():
        tf.summary.scalar('loss', loss, step=steps)
        #tf.summary.scalar('avg_version', avg_model_version, step=steps)
        with tf.name_scope('td_error'):
            tf.summary.scalar('before', tf.reduce_mean(td), step=steps)
            tf.summary.scalar('after', tf.reduce_mean(new_td_error), step=steps)
        with tf.name_scope('timing'):
            tf.summary.scalar('get', get_time, step=steps)
            tf.summary.scalar('update', update_time, step=steps)
            tf.summary.scalar('learn', learn_time, step=steps)
        with tf.name_scope('gradients'):
            tf.summary.scalar('min', grad_min, step=steps)
            tf.summary.scalar('max', grad_max, step=steps)
            tf.summary.scalar('l2 norm', grad_l2, step=steps)
            
    steps += 1

stop_threads()

Stopping replay thread
Stopping games data thread
Waiting for data: 0 / 640
Waiting for data: 0 / 640
Waiting for data: 0 / 640
Waiting for data: 0 / 640
Waiting for data: 0 / 640
kuddai file_path /tmp/lux_ai/replays/ad19815d-2b41-447b-8930-93d69eab604a
kuddai file_path /tmp/lux_ai/replays/0b7c35d4-c157-4619-b7d1-546d149fa9f3
Waiting for data: 0 / 640
Waiting for data: 232 / 640
Waiting for data: 232 / 640
Waiting for data: 232 / 640
kuddai file_path /tmp/lux_ai/replays/3a47b896-7dfe-4440-84ed-91bd03fa66c0
kuddai file_path /tmp/lux_ai/replays/a5a3333a-7892-4d6d-a907-abf85bd081ed
kuddai file_path /tmp/lux_ai/replays/779a4337-10b6-4e54-830e-2343c4cc271b
Waiting for data: 365 / 640
Waiting for data: 365 / 640
Waiting for data: 365 / 640
Waiting for data: 365 / 640
Waiting for data: 365 / 640
kuddai file_path /tmp/lux_ai/replays/da870b12-3f17-4239-9418-f0f0f17f30ea
kuddai file_path /tmp/lux_ai/replays/b4b8ba9e-fcdd-43d1-9383-a2201764d2c9
Waiting for data: 429 / 640
Waiting for data: 429 / 

  0%|          | 0/10000000 [00:00<?, ?it/s]

tracing get prio
0 Updating network
0 Saving model
kuddai file_path /tmp/lux_ai/replays/6ae46df0-31ec-4a02-8a23-f4ae50caa420
kuddai file_path /tmp/lux_ai/replays/5594e6e9-8076-4f5e-9996-f0d235b01789
tracing update
tracing was repl
kuddai file_path /tmp/lux_ai/replays/535808bd-b75d-45ab-a7b7-5281a968bf1d
kuddai file_path /tmp/lux_ai/replays/5536e95d-f087-4e3d-b79f-526ff9d178b3
kuddai file_path /tmp/lux_ai/replays/df6dc7e0-f611-4cdc-87d3-c1ed2282a853
kuddai file_path /tmp/lux_ai/replays/fdca10bc-d9bb-4c47-b7d8-11fe6eab11b0
kuddai file_path /tmp/lux_ai/replays/b160a743-b6ef-4d52-b653-b0676ccc9ac1
kuddai file_path /tmp/lux_ai/replays/9f713515-07c1-46c7-a14e-4a2a41ca3846
kuddai file_path /tmp/lux_ai/replays/2555156c-9482-46a5-b880-b2f8c465578a
kuddai file_path /tmp/lux_ai/replays/ab5d1710-0333-4298-b116-2c1850c9a551
kuddai file_path /tmp/lux_ai/replays/e14f6636-891d-4fa5-96ec-5d2821843ec9
100 Saving model
kuddai file_path /tmp/lux_ai/replays/05dd73d8-84cf-4f1c-8153-66cffd69e73d
kuddai file_

2700 Saving model
kuddai file_path /tmp/lux_ai/replays/415a2427-b225-4fe3-b4d8-729c87fceb82
kuddai file_path /tmp/lux_ai/replays/0d859f4b-7c8a-48cf-b8f7-5bcfd757eb1b
kuddai file_path /tmp/lux_ai/replays/f63a32b8-9ade-4dd4-a304-2df0dea16c44
kuddai file_path /tmp/lux_ai/replays/c5f2d28f-e457-4b38-8a5b-0b10064e4c2b
kuddai file_path /tmp/lux_ai/replays/2dbb6c71-961d-46b9-95b2-de16f4523e87
kuddai file_path /tmp/lux_ai/replays/2af24e25-2deb-4430-b4c1-fe2ac89be170
2800 Saving model
kuddai file_path /tmp/lux_ai/replays/d104f87b-a86d-4631-bfae-41c6a487d8e7
kuddai file_path /tmp/lux_ai/replays/23729093-8bc6-459e-b0d1-175c2aa8ee32
kuddai file_path /tmp/lux_ai/replays/67309f95-dabf-43d5-a37d-b877a85e3438
2900 Saving model
kuddai file_path /tmp/lux_ai/replays/a9672c18-184d-4c63-9667-3697047081f5
3000 Updating network
3000 Saving model
kuddai file_path /tmp/lux_ai/replays/3c60b0c3-46dd-4d7d-a99d-c3fe0ebfdebf
kuddai file_path /tmp/lux_ai/replays/d0559642-3ae7-4b4d-9528-1f6a62eb50f7
kuddai file_path /

kuddai file_path /tmp/lux_ai/replays/8ce24a84-99be-4833-8291-56f41b803703
kuddai file_path /tmp/lux_ai/replays/9c80c24c-499d-4eb3-8b08-6c1966d53225
kuddai file_path /tmp/lux_ai/replays/91028aee-69f3-406c-adac-0d885e787f61
kuddai file_path /tmp/lux_ai/replays/57a5af2f-76d6-4ae4-9f63-27c4bf3d6592
6800 Saving model
kuddai file_path /tmp/lux_ai/replays/91ca8463-e5b1-4796-91c9-63edb03d8278
kuddai file_path /tmp/lux_ai/replays/997deef2-b235-47d1-b071-a22d0f23e2da
6900 Saving model
kuddai file_path /tmp/lux_ai/replays/5dfdc75c-3a03-48a3-9a6a-4a54970ce994
kuddai file_path /tmp/lux_ai/replays/256bea1b-9f29-49a9-a247-42ed475edfd5
kuddai file_path /tmp/lux_ai/replays/c302505b-c283-4a65-8f6e-d1890437da1f
kuddai file_path /tmp/lux_ai/replays/d03af602-8acb-47c9-8fb2-ef0742983889
kuddai file_path /tmp/lux_ai/replays/9ba7752d-9834-4f28-b316-b51aeac783da
kuddai file_path /tmp/lux_ai/replays/3b851545-e5e8-4ad6-bb27-cd1a3c6af703
kuddai file_path /tmp/lux_ai/replays/1ed9ccb0-094d-4ebe-b36b-5db955a97518
70

kuddai file_path /tmp/lux_ai/replays/ee92ae13-4352-4e5a-b8e7-89bf8d27c3f4
9800 Saving model
kuddai file_path /tmp/lux_ai/replays/30871e41-0ad7-46d4-8037-628298259e3a
kuddai file_path /tmp/lux_ai/replays/e31aed23-1d42-4670-b87e-3d8a9f5b6559
9900 Saving model
kuddai file_path /tmp/lux_ai/replays/ba6d5707-ba8f-4f39-b010-f10cab756ad5
kuddai file_path /tmp/lux_ai/replays/b9e04017-8bd5-4b89-a26d-1ff675af7d81
kuddai file_path /tmp/lux_ai/replays/c9a24cd0-eeb2-41e7-b5d2-9fa6a9cd6ef9
10000 Updating network
10000 Saving model
10100 Saving model
kuddai file_path /tmp/lux_ai/replays/ddffaa5f-409f-40cc-ab3b-3622e8c39f9d
kuddai file_path /tmp/lux_ai/replays/eee8d39d-6f5a-4d33-a84e-bd81b0e6d728
kuddai file_path /tmp/lux_ai/replays/2014bf2b-f22f-4e20-a657-49081a1417af
kuddai file_path /tmp/lux_ai/replays/4f9d00e0-8a79-4b87-b8b3-ea7767cb3f1b
kuddai file_path /tmp/lux_ai/replays/b9d38eb8-355a-4101-86c5-fe137d4cae4f
10200 Saving model
10300 Saving model
kuddai file_path /tmp/lux_ai/replays/cef25552-60ad-

kuddai file_path /tmp/lux_ai/replays/a43f6b69-a6d3-4fc6-9d36-241c278c8394
kuddai file_path /tmp/lux_ai/replays/7a2de6e4-e909-4bd0-957e-df20c0d8ece9
13500 Saving model
kuddai file_path /tmp/lux_ai/replays/bfde5cec-0841-426e-af4b-4c06ead3f42b
kuddai file_path /tmp/lux_ai/replays/8f98be31-ab69-4721-b998-c0cd755d8cb1
kuddai file_path /tmp/lux_ai/replays/3bbaf79a-8f08-4305-9cbb-80ecd2b83f60
13600 Saving model
kuddai file_path /tmp/lux_ai/replays/5326cce5-c388-4b0d-b760-2a12cb242dc8
kuddai file_path /tmp/lux_ai/replays/7463c984-7d03-4c7e-a1a2-b7e3ff2fe39e
13700 Saving model
kuddai file_path /tmp/lux_ai/replays/0948e5eb-8446-440e-b057-d06f80877c0b
13800 Saving model
kuddai file_path /tmp/lux_ai/replays/7c37a4d1-c64c-4640-bc15-99244d60d731
kuddai file_path /tmp/lux_ai/replays/fe8ba4e4-7128-4eb7-b38e-1ae3954bb08c
13900 Saving model
kuddai file_path /tmp/lux_ai/replays/b6729fce-976b-46ad-b3ca-8d4a1fd0d24e
14000 Updating network
14000 Saving model
kuddai file_path /tmp/lux_ai/replays/42b80027-c26

In [None]:
stop_threads()