In [1]:
from SnakeBoard import SnakeBoard
from SnakeGame import SnakeGame
from NeuralNetwork import NeuralNetwork
import numpy as np
import time
import matplotlib.pyplot as plt

# ---------- User defined parameters ----------
# Every knob a reader may want to tune lives here; the training cell only reads these names.

# --- Run control (integer switches: 0 = off, 1 = on) ---
restore_weights_prev_training = 1  # start every network from the last best net stored in ./training_history.npy
manual_play = 0                    # 1: moves come from the keyboard (game.get_key) instead of the networks
t_between_gen = 0                  # pause in seconds after each generation
n_gens_2_save_weights = 25         # write the training history to disk every this many generations

# --- Evolution sizes ---
n_of_gens = 5000                   # generations to run
n_games_per_gen = 1000             # games played in parallel per generation, one network each
selected_games_per_gen = 20        # top networks copied unchanged and used as parents for all others

# --- Mutation settings, forwarded to NeuralNetwork.mutate ---
# assumed: "mrate" = share of parameters perturbed, "msize" = perturbation scale — TODO confirm in NeuralNetwork
mrate_bias, mrate_weights = 0.2, 0.2
msize_bias, msize_weights = 0.2, 0.2

# ---------- Machine Learning main logic ----------

# Restore weights from previous training if required.
# The checkpoint holds [best_ann_history, best_score_history]; its last network entry is
# the best network of the most recent saved generation.
# NOTE: allow_pickle=True unpickles arbitrary Python objects — only load checkpoints you produced.
# NOTE: the training loop starts a fresh history and overwrites this same file at generation 0,
#       so copy it aside first if you want to keep the earlier run's history.
seed_ann = None
if restore_weights_prev_training == 1:
    prev_training = np.load("./training_history.npy", allow_pickle=True)
    seed_ann = prev_training[0][-1]  # looked up once instead of once per network

# Create one game + one neural network per parallel game; all games share one board
record_score = 0
s_board = SnakeBoard(n_games_per_gen)
s_games, s_ann = [], []
for _ in range(n_games_per_gen):
    s_games.append(SnakeGame(s_board))
    s_ann.append(NeuralNetwork())

    if seed_ann is not None:
        # Every network starts from the same restored parameters; mutation diversifies them later
        s_ann[-1].set_weights_biases(seed_ann.weights, seed_ann.biases)

s_board.init_board()

# Run number of generations.
# Each generation: play every game until all are over, rank the networks by score, keep the
# best `selected_games_per_gen` unchanged in the first slots (elitism) and refill all other
# slots with mutated copies of those elites.
best_score_history = list()
best_ann_weights_history = list()  # best NeuralNetwork object of each generation
try:
    for idx_gen in range(n_of_gens):

        # Step all games in current generation (until all games are over)
        while True:
            game_status = list()  # one [game_over, score, idx_game] entry per game
            for idx_game, game in enumerate(s_games):

                # Get current game state and decide the next move
                state = game.get_game_state()
                if manual_play == 1:
                    next_move = game.get_key()
                else:
                    # Translate the network's output code into a SnakeGame move name;
                    # any other value is passed through unchanged
                    next_move = s_ann[idx_game].calculate(state)
                    if next_move == 0:
                        next_move = "IDLE"
                    elif next_move == 1:
                        next_move = "T_LEFT"
                    elif next_move == 2:
                        next_move = "T_RIGHT"

                # Step game instance based on the chosen move
                [game_over, score] = game.step_game(next_move)

                # Save game data in game status array
                game_status.append([game_over, score, idx_game])

            # Update graphics of all games (visual feedback) — uncomment to watch training
            #s_board.clear_board()
            #s_board.update_board_elements(s_games)

            # Finish the current generation only once every game instance is over
            if all(status[0] for status in game_status):
                break

        # Rank this generation's games by score, best first
        game_status.sort(key=lambda status: status[1], reverse=True)
        max_res_gen = game_status[0]
        gen_best_score = max_res_gen[1]

        # Update the all-time record BEFORE reporting it, so the printout includes this generation
        if gen_best_score > record_score:
            record_score = gen_best_score
        print("GEN ", idx_gen, " ----- BEST SCORE: ", gen_best_score, " ----- RECORD: ", record_score)
        best_score_history.append(gen_best_score)
        best_ann_weights_history.append(s_ann[max_res_gen[2]])

        # Save periodically, and always after the final generation so the last results are not lost
        if np.mod(idx_gen, n_gens_2_save_weights) == 0 or idx_gen == n_of_gens - 1:
            np.save("./training_history.npy", [best_ann_weights_history, best_score_history])

        # Copy the best "selected_games_per_gen" networks into the first positions.
        # Snapshot them before assigning: writing s_ann[i] in place could overwrite a network
        # that a lower-ranked elite still has to copy (e.g. rank 1's network living at index 0).
        elite_anns = [s_ann[status[2]].copy() for status in game_status[:selected_games_per_gen]]
        s_ann[:selected_games_per_gen] = elite_anns

        # Fill the remaining positions with mutated copies of the elites (round-robin over parents)
        for i in range(selected_games_per_gen, n_games_per_gen):
            s_ann[i] = s_ann[i % selected_games_per_gen].copy()
            s_ann[i].mutate(mrate_weights, msize_weights, mrate_bias, msize_bias)  # random mutations

        # Reset all games once they're finished
        for game in s_games:
            game.reset_game()

        time.sleep(t_between_gen)
finally:
    # Close the board even when training is interrupted (e.g. a kernel interrupt in Jupyter)
    s_board.quit_board()

pygame 2.5.2 (SDL 2.28.3, Python 3.11.6)
Hello from the pygame community. https://www.pygame.org/contribute.html
SnakeBoard instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance created.
SnakeGame instance

In [None]:
# Inspect a saved training run: load the checkpoint and plot the best score of each generation.
# The training cell saves [best_ann_history, best_score_history].
# NOTE(review): the training cell writes ./training_history.npy, not *_01 — presumably an archived
#               earlier run; confirm which file is meant. The saved output below this cell looks
#               like an older format (arrays instead of scores), so verify the layout of this file.
# NOTE: allow_pickle=True unpickles arbitrary Python objects — only load files you produced.
import numpy as np
import matplotlib.pyplot as plt

history_path = "./training_history_01.npy"
training_history = np.load(history_path, allow_pickle=True)
best_anns = training_history[0]    # best network per generation; best_anns[-1] is the latest
best_scores = training_history[1]  # best score per generation

fig, ax = plt.subplots()
ax.plot(best_scores)
ax.set(title="Best score per generation", xlabel="Generation", ylabel="Best score")
plt.show()

[array([[-13.11056359,  -0.25268534,  10.63140892,  -2.51101773,
          -4.38729764,   0.49238926,  10.08253139,  -1.8115979 ,
          -5.26079732, -10.22663155,   2.05121869,  -0.9736432 ],
        [  0.62366126,  -1.32187548,  -1.7337265 ,  -1.12277197,
           4.96779296,  16.7468175 ,   1.04716625,  -1.76288842,
          -9.43265309,   3.46919337,   6.95666197,  -1.85914232],
        [ -1.49557817,   0.23098939,  -2.88300797,  -4.83061398,
           2.26725947,  -0.9829557 ,  -0.31291183,  10.97775166,
          -0.84789104,  -3.03564521,  -1.77123536,   2.26063339],
        [ -3.11371044,   7.69652653,  -0.18476883,  -0.06164074,
          -6.33443975,   0.27225258,   1.34272009,  -3.75883318,
          -5.33138687,   1.42486785,  -7.20128375,  -6.02052107],
        [ -5.26165254,   0.48771698,  -0.07182728,   0.78127872,
           1.45349488,  -0.33487072,   1.97595849,  -5.93715708,
          -1.76632394,  -0.85849187,   1.61752565,  -0.01684721]]),
 array([[ 5.345442

In [None]:
# NOTE(review): `a` is not defined in any visible cell of this notebook, so this only works with
# leftover kernel state and raises NameError after Restart & Run All — define `a` above or delete this cell.
a[-1]

4