In [6]:
file_path = '/home/zrc3hc/Chess/2. Models/combined_saved_games.csv'

from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

spark = SparkSession.builder.appName("benchmarkmodel").getOrCreate()

df = spark.read.csv(file_path, header = True, inferSchema = True)



Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/12/01 09:32:32 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/12/01 09:32:33 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [11]:
df.show(2)

+----+-------+---------+------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
|Move|game_id|next_move|result| a1| b1| c1| d1| e1| f1| g1| h1| a2| b2| c2| d2| e2| f2| g2| h2| a3| b3| c3| d3| e3| f3| g3| h3| a4| b4| c4| d4| e4| f4| g4| h4| a5| b5| c5| d5| e5| f5| g5| h5| a6| b6| c6| d6| e6| f6| g6| h6| a7| b7| c7| d7| e7| f7| g7| h7| a8| b8| c8| d8| e8| f8| g8| h8|
+----+-------+---------+------+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+
|  83|  15380|     h7g8|     1|  0|  0|  0|  0|  0|  0|  0|  5| -5|  0|  0|  0|  0|  1|  0| 10|  0|  0|  0|  0|  0|  0|  1|  1|  0|  0| 

In [10]:
df.count()

10018



# **Chess Piece Values**

| **Chess Piece**     | **Value** |
|----------------------|-----------|
| White Rook 1         | `5`       |
| White Rook 2         | `5`       |
| White Knight 1       | `3`       |
| White Knight 2       | `3`       |
| White Bishop 1       | `3`       |
| White Bishop 2       | `3`       |
| White Queen          | `9`       |
| White King           | `10`      |
| White Pawn 1â€“8       | `1`       |

**Note:** Black pieces have the same values as white pieces but are negative. 

**Note:** If the result is labeled 1, that means white won. If black won, result is labeled 0.


In [21]:
## Creating training/validation split

training_data, validation_data = df.randomSplit([0.9, 0.1], seed=42)

#specifying feature space

board_spots = [col for col in df.columns if col not in ['Move', 'game_id', 'next_move', 'result']]

vector_assembler = VectorAssembler(inputCols=board_spots, outputCol="features")
training_data = vector_assembler.transform(training_data)
validation_data = vector_assembler.transform(validation_data)

In [22]:
# Basic Logistic Model

logistic_regression = LogisticRegression(featuresCol="features", labelCol="result", maxIter=1000)
lr_model = logistic_regression.fit(training_data)


In [23]:
# Predictions on Validatin Set

predictions = lr_model.transform(validation_data)
predictions.select("features", "result", "prediction", "probability").show(5)


+--------------------+------+----------+--------------------+
|            features|result|prediction|         probability|
+--------------------+------+----------+--------------------+
|(64,[0,4,5,6,7,9,...|     0|       1.0|[0.40431716519803...|
|(64,[0,3,5,6,7,8,...|     1|       1.0|[0.42697934966022...|
|(64,[0,2,6,7,8,9,...|     0|       0.0|[0.82664399980251...|
|(64,[0,2,5,7,8,9,...|     1|       1.0|[0.47374769140373...|
|(64,[0,3,6,7,8,9,...|     0|       1.0|[0.38246962742278...|
+--------------------+------+----------+--------------------+
only showing top 5 rows



In [24]:
evaluator = BinaryClassificationEvaluator(labelCol="result", rawPredictionCol="prediction", metricName="areaUnderROC")
auc = evaluator.evaluate(predictions)
print(f"Area Under ROC Curve (AUC): {auc}")

Area Under ROC Curve (AUC): 0.7084113980627702
