In [1]:
import sys

# Inside your code
print("Debug statement")
sys.stdout.flush()
import logging

# Configure logging
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

# Inside your code
logging.debug("Debug statement")

import torch
import torch.nn as nn


import numpy as np

from model import Q_network, get_input
from game import GameEnvironment
from collections import deque
from replay_buffer import replay_buffer
import time

model = Q_network(10, 20, 5)
epsilon = 0.1
gridsize = 15 
GAMMA = 0.9

board = GameEnvironment(gridsize, nothing=0, dead=-1, apple=1)
memory = replay_buffer(1000)  
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-5)

Debug statement


2024-07-12 16:02:12,586 - DEBUG - Debug statement


In [2]:
def play_game(num_games):
    global epsilon
    if epsilon >= 0.1:
            epsilon -= 0.0005
    run = True
    games_played = 0
    total_reward = 0    
    episode_games = 0
    len_array = []
    
    while run:
        state = get_input(board.snake, board.apple)
        action_0 = model(state)
        rand = np.random.uniform(0,1)
        if rand > epsilon:
            action = torch.argmax(action_0)
        else:
            action = np.random.randint(0,5)

        ## update_boardstate the same snake till    
        reward, done, len_of_snake =board.update_boardstate(action)
        next_state = get_input(board.snake, board.apple)
        
        memory.push(state, action, reward, next_state, done)
        
        total_reward += reward
        
        episode_games += 1
        
        if board.game_over == True:
            games_played += 1
            len_array.append(len_of_snake)
            board.resetgame()
            
            if num_games == games_played:
                run = False
                
    avg_len_of_snake = np.mean(len_array)
    max_len_of_snake = np.max(len_array)
    return total_reward, avg_len_of_snake, max_len_of_snake

In [3]:
MSE = nn.MSELoss()
def learn(epoches_number, batch_Size):
    total_loss = 0.0
    for _ in range(epoches_number):
        optimizer.zero_grad()
        sample = memory.sample(batch_Size)
        
        # get the states, actoins, next states, and rewards from the sample in the tensor form to use it in the model and calculations
        states, actions, rewards, next_states, dones = sample
        states = torch.cat([x.unsqueeze(0) for x in states], dim=0) 
        actions = torch.LongTensor(actions) 
        rewards = torch.FloatTensor(rewards) 
        next_states = torch.cat([x.unsqueeze(0) for x in next_states]) 
        dones = torch.tensor(dones, dtype=torch.float32)  # Convert bool to float
        
        q_local = model.forward(states)
        q_next_local = model.forward(next_states)

        q_expected  = q_local.gather(1, actions.unsqueeze(0).transpose(0,1)).transpose(0,1).squeeze(0)  
        # actions is a row vector. so i add dimention in the first position using unsqueez(0) function then transpose to use as indices on axis one of the qvalue matrix
        # . so i using theses indices to select q_values of the actions. the matrix now is a column vector so i convert it to row vector and remove this extra dimention.
        q_targets_next  = torch.max(q_next_local, 1)[0]*(torch.ones(dones.size()) - dones)
        # i use [0] because max returns tuple of selected max and indices .here done handle the terminal cases.

        q_targets  = rewards + GAMMA * q_targets_next 
        
        loss = MSE(q_expected, q_targets)
        
        total_loss += loss
        loss.backward()
        optimizer.step()
        
    return total_loss



In [4]:
num_episodes = 30000 
num_updates = 500 
print_every = 20
games_in_episode = 30
batch_size = 20

def train():
    
    scores_deque = deque(maxlen=100)
    scores_array = []
    avg_scores_array = []    
    
    avg_len_array = []
    avg_max_len_array = []
    
    time_start = time.time()
    
    
    for i_episode in range(num_episodes+1):
        
        ## print('i_episode: ', i_episode)
        
        score, avg_len, max_len = play_game(games_in_episode)
        scores_deque.append(score)
        scores_array.append(score)
        avg_len_array.append(avg_len)
        avg_max_len_array.append(max_len)
        
        
        avg_score = np.mean(scores_deque)
        avg_scores_array.append(avg_score)
        
        total_loss = learn(num_updates, batch_size)
        
        dt = (int)(time.time() - time_start)
        
        if i_episode % print_every == 0 and i_episode > 0:
            print('Ep.: {:6}, Loss: {:.3f}, Avg.Score: {:.2f}, Avg.LenOfSnake: {:.2f}, Max.LenOfSnake:  {:.2f} Time: {:02}:{:02}:{:02} '.\
                  format(i_episode, total_loss, score, avg_len, max_len, dt//3600, dt%3600//60, dt%60))
            
        memory.truncate()
        
        if i_episode % 250 == 0 and i_episode > 0:   
            torch.save(model.state_dict(), './check_points_len/Snake_{}'.format(i_episode))
            
    return scores_array, avg_scores_array, avg_len_array, avg_max_len_array    

scores, avg_scores, avg_len_of_snake, max_len_of_snake = train()

Ep.:     20, Loss: 7.753, Avg.Score: -30.00, Avg.LenOfSnake: 3.00, Max.LenOfSnake:  3.00 Time: 00:00:28 
Ep.:     40, Loss: 6.344, Avg.Score: -29.00, Avg.LenOfSnake: 3.03, Max.LenOfSnake:  4.00 Time: 00:00:55 
Ep.:     60, Loss: 4.469, Avg.Score: -29.00, Avg.LenOfSnake: 3.03, Max.LenOfSnake:  4.00 Time: 00:01:23 
Ep.:     80, Loss: 4.766, Avg.Score: -27.00, Avg.LenOfSnake: 3.10, Max.LenOfSnake:  4.00 Time: 00:01:50 
Ep.:    100, Loss: 5.010, Avg.Score: -23.00, Avg.LenOfSnake: 3.23, Max.LenOfSnake:  5.00 Time: 00:02:18 
Ep.:    120, Loss: 3.665, Avg.Score: -29.00, Avg.LenOfSnake: 3.03, Max.LenOfSnake:  4.00 Time: 00:02:48 
Ep.:    140, Loss: 3.670, Avg.Score: -27.00, Avg.LenOfSnake: 3.10, Max.LenOfSnake:  4.00 Time: 00:03:18 
Ep.:    160, Loss: 4.664, Avg.Score: -26.00, Avg.LenOfSnake: 3.13, Max.LenOfSnake:  5.00 Time: 00:03:49 
Ep.:    180, Loss: 5.405, Avg.Score: -23.00, Avg.LenOfSnake: 3.23, Max.LenOfSnake:  4.00 Time: 00:04:21 
Ep.:    200, Loss: 5.253, Avg.Score: -21.00, Avg.LenOfS