In [1]:
import os
import pandas as pd
import numpy as np
import matplotlib as plt
import tensorflow as tf
from tqdm import tqdm
import utils

# Import data 
game_state_20, game_state_40, game_state_60, game_state_80, game_state_100 = utils.get_inputs()

2023-11-22 21:15:39.346967: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-22 21:15:39.385167: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-22 21:15:39.385197: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-22 21:15:39.386561: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-22 21:15:39.392977: I external/local_tsl/tsl/cuda/cudart_stub.cc:31] Could not find cuda drivers on your machine, GPU will not be used.
2023-11-22 21:15:39.393305: I tensorflow/core/platform/cpu_feature_guard.cc:1

In [2]:
# Prepare training, validation, and test datasets for 20% game time

train_frac = 0.8
valid_frac = 0.1

train_set = game_state_20.sample(frac = 0.8)
remaining_set = game_state_20.drop(train_set.index)
valid_set = remaining_set.sample(frac = valid_frac / (1 - train_frac))
test_set = remaining_set.drop(valid_set.index)

# Separate wins
train_wins = train_set.pop('blueWin')
valid_wins = valid_set.pop('blueWin')
test_wins = test_set.pop('blueWin')

print(game_state_20.shape, train_set.shape, valid_set.shape, test_set.shape)


(64556, 36) (51645, 35) (6456, 35) (6455, 35)


In [4]:
### Try simple DNN classification

# Convert dataframes to tensors
train_data = tf.convert_to_tensor(train_set)
train_labels = tf.convert_to_tensor(train_wins)

valid_data = tf.convert_to_tensor(valid_set)
valid_labels = tf.convert_to_tensor(valid_wins)

#train_data = train_data[..., tf.newaxis]
#valid_data = valid_data[..., tf.newaxis]
# Define batches

def get_batch(data, labels, batch_size):
    return tf.data.Dataset.from_tensor_slices((data, labels)).batch(batch_size)

train_batched = get_batch(train_data, train_labels, 64)
valid_batched = get_batch(valid_data, valid_labels, 64)

print(tf.shape(train_data), tf.shape(train_labels))


tf.Tensor([51645    35], shape=(2,), dtype=int32) tf.Tensor([51645], shape=(1,), dtype=int32)


In [None]:
# Define model
class DNNModel(tf.keras.Model):

    def __init__(self):
        super().__init__()
        # Single dense layer
        self.layer1 = tf.keras.layers.Dense(35, activation='relu')
        self.layer2 = tf.keras.layers.Dense(1)
        
    def call(self, input):
        input = self.layer1(input)
        return self.layer2(input)

model = DNNModel()

# Choose optimizer and loss function
optimizer = tf.keras.optimizers.Adam()
loss_func = tf.keras.losses.CategoricalCrossentropy()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')

valid_loss = tf.keras.metrics.Mean(name='validation_loss')
valid_accuracy = tf.keras.metrics.CategoricalAccuracy(name='validation_accuracy')


# Define training step
@tf.function
def train_step(match_states, outcomes):
    with tf.GradientTape() as tape:  
        predictions = tf.squeeze(model(match_states))
        loss = loss_func(outcomes, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(outcomes, predictions)

# Define testing step
@tf.function
def test_step(match_states, outcomes):
    predictions = tf.squeeze(model(match_states))
    loss = loss_func(outcomes, predictions)

    valid_loss(loss)
    valid_accuracy(outcomes, predictions)

# Set output
output_dir = './output'
output_prefix = os.path.join(output_dir, "DNN_2l_20")


# Run
EPOCHS = 100

for epoch in tqdm(range(EPOCHS)):
    train_loss.reset_states()
    train_accuracy.reset_states()
    valid_loss.reset_states()
    valid_accuracy.reset_states()

    for match_states, outcomes in train_batched:
        train_step(match_states, outcomes)

    for match_states, outcomes in valid_batched:
        test_step(match_states, outcomes)

    if epoch % 10 == 0:
        model.save_weights(output_prefix)

    print(
    f'Epoch {epoch + 1}, '
    f'Loss: {train_loss.result()}, '
    f'Accuracy: {train_accuracy.result()}, '
    f'Validation Loss: {valid_loss.result()}, '
    f'Validation Accuracy: {valid_accuracy.result()}'
  )

model.save_weights(output_prefix)

  1%|▍                                          | 1/100 [00:02<03:38,  2.21s/it]

Epoch 1, Loss: 135.4866485595703, Accuracy: 0.012391573749482632, Validation Loss: 133.3809051513672, Validation Accuracy: 0.0


  2%|▊                                          | 2/100 [00:02<02:13,  1.37s/it]

Epoch 2, Loss: 135.02342224121094, Accuracy: 0.01610904559493065, Validation Loss: 133.60118103027344, Validation Accuracy: 0.0


  3%|█▎                                         | 3/100 [00:03<01:46,  1.10s/it]

Epoch 3, Loss: 134.68026733398438, Accuracy: 0.01610904559493065, Validation Loss: 132.9066619873047, Validation Accuracy: 0.009900989942252636


  4%|█▋                                         | 4/100 [00:04<01:33,  1.03it/s]

Epoch 4, Loss: 134.41465759277344, Accuracy: 0.014869888313114643, Validation Loss: 132.76626586914062, Validation Accuracy: 0.009900989942252636


  5%|██▏                                        | 5/100 [00:05<01:25,  1.11it/s]

Epoch 5, Loss: 134.217041015625, Accuracy: 0.013630731031298637, Validation Loss: 132.36080932617188, Validation Accuracy: 0.009900989942252636


  6%|██▌                                        | 6/100 [00:06<01:20,  1.16it/s]

Epoch 6, Loss: 134.0821075439453, Accuracy: 0.012391573749482632, Validation Loss: 132.57261657714844, Validation Accuracy: 0.009900989942252636


  7%|███                                        | 7/100 [00:06<01:17,  1.20it/s]

Epoch 7, Loss: 133.9376220703125, Accuracy: 0.012391573749482632, Validation Loss: 132.12342834472656, Validation Accuracy: 0.009900989942252636


  8%|███▍                                       | 8/100 [00:07<01:15,  1.23it/s]

Epoch 8, Loss: 133.88865661621094, Accuracy: 0.012391573749482632, Validation Loss: 132.1063690185547, Validation Accuracy: 0.009900989942252636


  9%|███▊                                       | 9/100 [00:08<01:13,  1.24it/s]

Epoch 9, Loss: 133.94210815429688, Accuracy: 0.011152416467666626, Validation Loss: 132.360107421875, Validation Accuracy: 0.009900989942252636


 10%|████▏                                     | 10/100 [00:09<01:11,  1.26it/s]

Epoch 10, Loss: 134.04368591308594, Accuracy: 0.013630731031298637, Validation Loss: 132.0482940673828, Validation Accuracy: 0.009900989942252636


 11%|████▌                                     | 11/100 [00:09<01:10,  1.26it/s]

Epoch 11, Loss: 133.79043579101562, Accuracy: 0.013630731031298637, Validation Loss: 132.0484161376953, Validation Accuracy: 0.009900989942252636


 12%|█████                                     | 12/100 [00:10<01:09,  1.26it/s]

Epoch 12, Loss: 133.70339965820312, Accuracy: 0.013630731031298637, Validation Loss: 132.07205200195312, Validation Accuracy: 0.009900989942252636


 13%|█████▍                                    | 13/100 [00:11<01:08,  1.27it/s]

Epoch 13, Loss: 133.68862915039062, Accuracy: 0.012391573749482632, Validation Loss: 131.8603515625, Validation Accuracy: 0.009900989942252636


 14%|█████▉                                    | 14/100 [00:12<01:07,  1.27it/s]

Epoch 14, Loss: 133.67669677734375, Accuracy: 0.012391573749482632, Validation Loss: 131.9541473388672, Validation Accuracy: 0.009900989942252636


 15%|██████▎                                   | 15/100 [00:13<01:06,  1.28it/s]

Epoch 15, Loss: 133.60824584960938, Accuracy: 0.00991325918585062, Validation Loss: 131.8675537109375, Validation Accuracy: 0.009900989942252636


 16%|██████▋                                   | 16/100 [00:13<01:05,  1.28it/s]

Epoch 16, Loss: 133.6483917236328, Accuracy: 0.006195786874741316, Validation Loss: 131.9732208251953, Validation Accuracy: 0.009900989942252636


 17%|███████▏                                  | 17/100 [00:14<01:04,  1.28it/s]

Epoch 17, Loss: 133.59808349609375, Accuracy: 0.008674101904034615, Validation Loss: 131.93165588378906, Validation Accuracy: 0.009900989942252636


 18%|███████▌                                  | 18/100 [00:15<01:03,  1.29it/s]

Epoch 18, Loss: 133.6345977783203, Accuracy: 0.0074349441565573215, Validation Loss: 131.82086181640625, Validation Accuracy: 0.009900989942252636


 19%|███████▉                                  | 19/100 [00:16<01:02,  1.29it/s]

Epoch 19, Loss: 133.6363067626953, Accuracy: 0.0074349441565573215, Validation Loss: 131.9536895751953, Validation Accuracy: 0.009900989942252636


 20%|████████▍                                 | 20/100 [00:16<01:01,  1.29it/s]

Epoch 20, Loss: 135.1437530517578, Accuracy: 0.011152416467666626, Validation Loss: 133.34471130371094, Validation Accuracy: 0.009900989942252636


 21%|████████▊                                 | 21/100 [00:17<01:01,  1.29it/s]

Epoch 21, Loss: 134.42642211914062, Accuracy: 0.00991325918585062, Validation Loss: 132.1162872314453, Validation Accuracy: 0.009900989942252636


 22%|█████████▏                                | 22/100 [00:18<01:00,  1.29it/s]

Epoch 22, Loss: 133.66168212890625, Accuracy: 0.013630731031298637, Validation Loss: 131.9220733642578, Validation Accuracy: 0.009900989942252636


 23%|█████████▋                                | 23/100 [00:19<00:59,  1.29it/s]

Epoch 23, Loss: 133.56326293945312, Accuracy: 0.013630731031298637, Validation Loss: 131.82293701171875, Validation Accuracy: 0.009900989942252636


 24%|██████████                                | 24/100 [00:20<00:58,  1.29it/s]

Epoch 24, Loss: 133.52639770507812, Accuracy: 0.00991325918585062, Validation Loss: 131.79876708984375, Validation Accuracy: 0.009900989942252636


 25%|██████████▌                               | 25/100 [00:20<00:58,  1.29it/s]

Epoch 25, Loss: 133.5132598876953, Accuracy: 0.00991325918585062, Validation Loss: 131.76644897460938, Validation Accuracy: 0.009900989942252636


 26%|██████████▉                               | 26/100 [00:21<00:57,  1.29it/s]

Epoch 26, Loss: 133.5107879638672, Accuracy: 0.00991325918585062, Validation Loss: 131.7425994873047, Validation Accuracy: 0.009900989942252636


 27%|███████████▎                              | 27/100 [00:22<00:56,  1.29it/s]

Epoch 27, Loss: 133.5374298095703, Accuracy: 0.00991325918585062, Validation Loss: 131.74136352539062, Validation Accuracy: 0.009900989942252636


 28%|███████████▊                              | 28/100 [00:23<00:56,  1.28it/s]

Epoch 28, Loss: 133.51329040527344, Accuracy: 0.013630731031298637, Validation Loss: 131.78582763671875, Validation Accuracy: 0.009900989942252636


 29%|████████████▏                             | 29/100 [00:24<00:55,  1.27it/s]

Epoch 29, Loss: 133.49925231933594, Accuracy: 0.012391573749482632, Validation Loss: 131.71824645996094, Validation Accuracy: 0.009900989942252636


 30%|████████████▌                             | 30/100 [00:24<00:55,  1.27it/s]

Epoch 30, Loss: 133.49363708496094, Accuracy: 0.011152416467666626, Validation Loss: 131.7110137939453, Validation Accuracy: 0.009900989942252636


 31%|█████████████                             | 31/100 [00:25<00:55,  1.25it/s]

Epoch 31, Loss: 133.4871063232422, Accuracy: 0.012391573749482632, Validation Loss: 131.7118682861328, Validation Accuracy: 0.009900989942252636


 32%|█████████████▍                            | 32/100 [00:26<00:54,  1.24it/s]

Epoch 32, Loss: 133.5148468017578, Accuracy: 0.013630731031298637, Validation Loss: 131.70144653320312, Validation Accuracy: 0.009900989942252636


 33%|█████████████▊                            | 33/100 [00:27<00:53,  1.24it/s]

Epoch 33, Loss: 133.5001220703125, Accuracy: 0.013630731031298637, Validation Loss: 131.69593811035156, Validation Accuracy: 0.009900989942252636


 34%|██████████████▎                           | 34/100 [00:28<00:53,  1.24it/s]

Epoch 34, Loss: 133.4703369140625, Accuracy: 0.013630731031298637, Validation Loss: 131.7056121826172, Validation Accuracy: 0.009900989942252636


 35%|██████████████▋                           | 35/100 [00:28<00:52,  1.25it/s]

Epoch 35, Loss: 133.46229553222656, Accuracy: 0.014869888313114643, Validation Loss: 131.6865692138672, Validation Accuracy: 0.009900989942252636


 36%|███████████████                           | 36/100 [00:29<00:51,  1.24it/s]

Epoch 36, Loss: 133.45742797851562, Accuracy: 0.013630731031298637, Validation Loss: 131.693603515625, Validation Accuracy: 0.009900989942252636


 37%|███████████████▌                          | 37/100 [00:31<01:00,  1.03it/s]

Epoch 37, Loss: 133.46646118164062, Accuracy: 0.014869888313114643, Validation Loss: 131.67564392089844, Validation Accuracy: 0.009900989942252636
