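To make this project's environment selectable as a Jupyter kernel, run the following once from a terminal with that environment activated:

`python -m ipykernel install --user --name Connect-Four --display-name "Conda Connect-Four"`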

In [47]:
import os
import glob
import time
import numpy as np
import pandas as pd
from game import game

In [106]:
def swap_tokens(x):
    """
    Exchange the two player tokens while leaving empty cells untouched.

    Token mapping:
        0 --> 0
        1 --> 2
        2 --> 1

    The computation uses only arithmetic operators and ``np.sign``, so it
    applies element-wise when ``x`` is a NumPy array.
    """
    doubled = x * 2
    # 2 % 3 == 2 and 4 % 3 == 1, so doubling modulo 3 exchanges 1 and 2.
    exchanged = doubled % 3
    # Scale by the sign of the input; a zero input therefore stays zero.
    return exchanged * np.sign(x)

In [99]:
# Make sure the directory that receives generated game data is available,
# creating it on first use and reporting which of the two cases applied.
out_dir = "generated_games"
already_present = os.path.isdir(out_dir)
if already_present:
    print(f"out_dir already_exists at: {out_dir}")
else:
    os.mkdir(out_dir)
    print(f"out_dir created at: {out_dir}")

out_dir already_exists at: generated_games


In [115]:
# Report the line count and on-disk size of every csv already in out_dir

# A missing out_dir is treated the same as a directory with no csv files.
csvs = glob.glob(os.path.join(out_dir, "*.csv")) if os.path.exists(out_dir) else []

if not csvs:
    print("No csv data found.")

for csv in csvs:
    # Stream the file to count lines instead of loading it into memory.
    with open(csv) as f_obj:
        num_lines = sum(1 for line in f_obj)
    # Size is reported in megabytes (bytes / 1e6), rounded to 2 decimals.
    mb = round(os.path.getsize(csv) / 1e6, 2)
    print(f"{num_lines} lines found in {csv} ({mb} MB)")

246125 lines found in generated_games\random_choice_games.csv (46.07 MB)


In [114]:
# Run games and append to existing data
#
# Plays GAME_BATCH_SIZE games, turns every board state of each decided game
# into one flattened row (always from the winner's point of view, i.e. the
# winner's pieces are encoded as token 1), attaches a per-turn reward, and
# appends the rows to random_choice_games.csv inside out_dir.

# Configuration
GAME_BATCH_SIZE = 10000

# Play games
start = time.time()
games = game.play_n_games(GAME_BATCH_SIZE)
accumulator = []
skipped_games = 0

# Create DataFrame
# Board dimensions are read from the first recorded state of the first game.
n_rows = games[0]["states"][0].shape[0]
n_cols = games[0]["states"][0].shape[1]
columns = [f"position_{i}" for i in range(n_cols * n_rows)]
for g in games:
    winner_token = g["winner"]

    if winner_token == 1:
        records = [state.flatten() for state in g["states"]]
    elif winner_token == 2:
        # Swap the tokens so the winning player is always encoded as 1.
        records = [swap_tokens(state.flatten()) for state in g["states"]]
    else:
        # Neither player 1 nor player 2 won. Without this branch `records`
        # would keep the previous game's states (silently duplicating them)
        # or be undefined on the first iteration.
        # NOTE(review): presumably this is a draw -- confirm how
        # game.play_n_games encodes games without a winner.
        skipped_games += 1
        continue
    states_data = pd.DataFrame.from_records(records, columns=columns)

    # Reward grows logarithmically from 0.01 on the first recorded state
    # to 1.0 on the last one.
    num_turns = states_data.shape[0]
    reward = np.logspace(-2, 0, num=num_turns)
    states_data["reward"] = reward

    accumulator.append(states_data)

# pd.concat raises ValueError on an empty list, so fall back to an empty
# frame with the expected columns when no game produced any rows.
if accumulator:
    data = pd.concat(accumulator)
else:
    data = pd.DataFrame(columns=columns + ["reward"])

if skipped_games:
    print(f"Skipped {skipped_games} games without a winner")


# Append to an existing file, if such a file exists
fname = os.path.join(out_dir, "random_choice_games.csv")
# Only append (without a header) when the file already holds data; a missing
# or zero-byte file is (re)written with the header row.
append_to_existing = os.path.isfile(fname) and os.path.getsize(fname) > 0
data.to_csv(
    fname,
    mode="a" if append_to_existing else "w",
    header=not append_to_existing,
    index=False,
)
print(f"Wrote {data.shape[0]} lines to {fname}")

end = time.time()
print(f"Data generation took {round(end - start, 3)} seconds")

Wrote 223249 lines to generated_games\random_choice_games.csv
Data generation took 151.644 seconds


In [103]:
# Delete every csv in out_dir -- destructive, so it only runs when the
# TURN_ON switch below is explicitly set to True.
TURN_ON = False

if not TURN_ON:
    print("Data not deleted")
elif os.path.exists(out_dir):
    csvs = glob.glob(os.path.join(out_dir, "*.csv"))
    for csv in csvs:
        os.remove(csv)
        print(f"{csv} deleted")