In [1]:
# Echo the SparkContext to confirm the Spark session is live.
# NOTE(review): `sc` is not defined in this notebook; presumably injected by the PySpark kernel.
sc

In [36]:
import chess
from IPython.display import clear_output
import time
from texttable import Texttable


import pyspark.sql.functions as F
from pyspark.sql.types import ArrayType, StringType
import matplotlib.pyplot as plt

%matplotlib notebook

In [3]:
def moves_to_arr(moves_string):
    """Extract the move tokens from a movetext string.

    Tokens containing a "." (move numbers such as ``1.``) are dropped and the
    remaining tokens are returned in their original order. Splitting on any
    run of whitespace means repeated, leading or trailing spaces (and line
    breaks) no longer produce empty-string "moves".

    Args:
        moves_string: movetext such as ``"1. e4 e5 2. Nf3"``, or ``None``.

    Returns:
        List of move tokens, or ``None`` when ``moves_string`` is ``None`` so a
        SQL NULL passed through the UDF maps to a NULL array instead of raising.
    """
    if moves_string is None:
        return None
    return [token for token in moves_string.split() if "." not in token]

# Wrap moves_to_arr as a Spark UDF returning array<string>; used to build the MovesArray column.
moves_to_arr_udf = F.udf(moves_to_arr, ArrayType(StringType()))

In [4]:
def generate_base_df(file="chess_dataframe_parquet", num_partitions=200):
    """Load the games table and prepare it for move-by-move lookups.

    Reads the Parquet dataset at ``file``. Previously this argument was ignored
    and the default path was always read. Games with a null ``Moves`` column or
    a ``"*"`` result are dropped, and a ``MovesArray`` column holding the parsed
    move tokens is added. The result is cached because callers query it
    repeatedly.

    Args:
        file: path of the Parquet dataset to read.
        num_partitions: partition count to repartition into after loading.

    Returns:
        The cached, filtered DataFrame with the extra ``MovesArray`` column.
    """
    # NOTE(review): `spark` is not defined in this notebook; presumably provided by the PySpark kernel.
    games = spark.read.parquet(file).repartition(num_partitions)
    games = games.where(games.Moves.isNotNull())
    # "*" is PGN's result marker for a game without a final outcome.
    games = games.where(games.Result != "*")
    return games.withColumn("MovesArray", moves_to_arr_udf(games["Moves"])).cache()

In [5]:
# Build the filtered, cached games DataFrame and preview the parsed move lists.
# NOTE(review): game() calls generate_base_df() itself, so in the cells shown here
# this base_df is only used for the preview.
base_df = generate_base_df()
base_df.select("MovesArray").show()

+--------------------+
|          MovesArray|
+--------------------+
|[e4, e6, d4, d5, ...|
|[e4, e5, d4, exd4...|
|[Nf3, d5, d4, Nf6...|
|[d4, d5, c4, c6, ...|
|[d4, Nf6, Bg5, e6...|
|[e4, c5, Nf3, Nf6...|
|[d4, d5, c4, c6, ...|
|[e4, e6, Qe2, d5,...|
|[d4, Nf6, Nf3, g6...|
|[d4, g6, e4, Bg7,...|
|[d4, Nf6, c4, c5,...|
|[e4, e5, Nf3, Nc6...|
|[d4, d5, Nf3, Nf6...|
|[c4, c6, g3, d5, ...|
|[d4, Nf6, Nf3, d5...|
|[d4, d5, c4, e5, ...|
|[e4, c5, g3, h5, ...|
|[e4, c5, Nf3, e6,...|
|[e4, e6, d4, d5, ...|
|[e4, c5, d4, cxd4...|
+--------------------+
only showing top 20 rows



In [61]:
def parse_results(rows):
    """Map a move's outcome rows to per-side win rates.

    Each row must expose ``Result`` and ``winRate``. ``"1-0"`` is credited to
    White and ``"0-1"`` to Black; any other result string counts as a draw.
    Sides with no row keep a rate of 0. When several rows map to the same side,
    the last one wins.
    """
    side_for_result = {"1-0": "White", "0-1": "Black"}
    rates = dict.fromkeys(("White", "Black", "Draw"), 0)
    for outcome in rows:
        rates[side_for_result.get(outcome.Result, "Draw")] = outcome.winRate
    return rates

def visualize_top_moves(top_moves, turn):
    """Print a table of candidate moves with their White/Draw/Black rates.

    Args:
        top_moves: sequence of Rows exposing ``move``, ``total``, ``Result`` and
            ``winRate``. Rows for one move need not be adjacent. A move may have
            fewer than three result rows (e.g. no draws recorded); missing
            outcomes show as 0.
        turn: zero-based ply index, displayed as the 1-based full-move number.
    """
    # Group by move, keeping first-appearance order (the caller sorts by
    # popularity). Fixed chunks of 3 rows misattribute results when a move
    # lacks an outcome or when moves with equal totals interleave.
    grouped = {}
    for row in top_moves:
        grouped.setdefault(row.move, []).append(row)

    move_number = turn // 2 + 1
    all_results = []
    for move, rows in grouped.items():
        results = parse_results(rows)
        all_results.append([move_number, move, rows[0].total,
                            results["White"], results["Draw"], results["Black"]])

    table = Texttable()
    table.add_rows([['Turn', 'Move', 'Game', 'White', 'Draw', 'Black']] + all_results)
    print(table.draw())

In [None]:
def _top_moves(game_df, ply, top_n=12):
    """Return outcome rows for the ``top_n`` most-played moves at index ``ply``.

    Each Row has ``move``, ``Result``, ``count``, ``total`` (games in which
    ``move`` was played at this ply) and ``winRate`` (``count / total``). Rows
    are ordered by ``total`` descending, then ``move``, so one move's rows are
    adjacent.
    """
    counts = (
        game_df
        .groupBy(game_df.MovesArray.getItem(ply).alias("move"), game_df.Result)
        .count()
        .filter(F.col("move").isNotNull())
    )
    # Limit on distinct moves rather than on rows, so no move loses part of its results.
    totals = (
        counts.groupBy("move")
        .agg(F.sum("count").alias("total"))
        .orderBy(F.desc("total"), "move")
        .limit(top_n)
    )
    return (
        counts.join(totals, on="move", how="inner")
        .withColumn("winRate", F.col("count") / F.col("total"))
        .orderBy(F.desc("total"), "move", F.desc("winRate"))
        .collect()
    )


def _prompt_move(board):
    """Read moves from stdin until a legal SAN move or "q" is entered.

    Returns the parsed ``chess.Move`` (not yet pushed), or ``None`` on "q".
    """
    text = input("Move(SAN): ")
    print("\n\n")
    while True:
        if text == "q":
            return None
        try:
            return board.parse_san(text)
        except ValueError:
            print("Illegal Move! Try Again.")
            text = input("Move(SAN): ")
            time.sleep(0.5)


def game(board=None, top_n=12):
    """Play an interactive game, showing database statistics before each move.

    Before every ply this prints the ``top_n`` most common moves played from
    the current position in the games table, with White/Draw/Black rates. It
    then displays the board and reads the next move in SAN. Entering "q" stops
    the game.

    Args:
        board: optional ``chess.Board`` to play on; it is reset first. When
            omitted a fresh board is created. A ``chess.Board()`` default would
            be one board object shared by every call.
        top_n: number of candidate moves to show per ply.
    """
    if board is None:
        board = chess.Board()
    board.reset()
    game_df = generate_base_df()
    ply = 0
    while not board.is_game_over():
        visualize_top_moves(_top_moves(game_df, ply, top_n), ply)
        display(board)
        move = _prompt_move(board)
        if move is None:
            break
        # Filter on canonical SAN, not the raw input. parse_san also accepts
        # non-canonical spellings that would never match the stored moves.
        # NOTE(review): assumes the dataset stores SAN as python-chess renders it.
        san = board.san(move)
        board.push(move)
        game_df = game_df.where(game_df.MovesArray.getItem(ply) == san)
        ply += 1
    print("\n\n")
    print("Game finished.")
    display(board)

In [None]:
# Start an interactive session: enter moves in SAN at the prompt, or "q" to quit.
game()