### define the working model path

In [1]:
# Single checkpoint location shared by the notebook: the same path is passed
# to DQN.load(...) when resuming and to model.save(...) while training.
model_loading_saving_path = "snake_nn_2024_12_24"

### import all the necessary libraries

In [2]:
import time
import gymnasium as gym
from stable_baselines3 import DQN
from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from IPython.display import clear_output
import torch
from snake_nn import SnakeNN

from snake import Snake

# Project-local game wrapper. The cells below rely on its `env` attribute
# (passed to DQN) and its `training(...)` / `run(...)` methods.
snake = Snake()

### initialize a new DQN agent with the required parameters

In [21]:
# Route observations through the custom SnakeNN features extractor, asking it
# for a 64-dimensional feature vector, before the policy's own layers.
policy_kwargs = {
    "features_extractor_class": SnakeNN,
    "features_extractor_kwargs": {"features_dim": 64},
}

# Fresh (untrained) DQN agent bound to the snake environment.
model = DQN(
    "MlpPolicy",
    snake.env,
    policy_kwargs=policy_kwargs,
    learning_rate=1e-4,
    buffer_size=10000,
    verbose=1,
)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


### OR load the pre-trained model

In [3]:
# Alternative to building a new agent: restore the checkpoint stored at
# `model_loading_saving_path` and attach it to the snake environment.
model = DQN.load(
    model_loading_saving_path,
    env=snake.env,
)

Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


### train the agent for a number of episodes and save the model

In [33]:
epochs = 500                # number of outer train/evaluate iterations
save_for_every_epoch = 5    # checkpoint when the iteration index is a multiple of this
total_time = 0

# Each iteration: train the agent, play one demo game, maybe checkpoint, then
# redraw a single-line progress bar with the estimated time remaining.
try:
    for i in range(epochs):
        # perf_counter is monotonic, so the ETA is not disturbed by system
        # clock adjustments during a long run.
        start_time = time.perf_counter()

        # NOTE(review): `epochs=10000` is the budget handed to Snake.training
        # for this one call (presumably timesteps — confirm in snake.py); it is
        # unrelated to the outer `epochs` counter above.
        model = snake.training(model=model, epochs=10000)
        snake.run(model=model)

        if i % save_for_every_epoch == 0:
            model.save(model_loading_saving_path)

        iteration_time = time.perf_counter() - start_time
        total_time += iteration_time

        # ETA = mean duration of the iterations so far * iterations left.
        completed_epochs = i + 1
        avg_time_per_epoch = total_time / completed_epochs
        remaining_epochs = epochs - completed_epochs
        estimated_time_left = avg_time_per_epoch * remaining_epochs

        estimated_minutes = int(estimated_time_left // 60)
        estimated_seconds = int(estimated_time_left % 60)

        # 20-character bar; clear_output(wait=True) replaces the previous cell
        # output (including the board printed by snake.run) without flicker.
        progress_bar = "█" * int(completed_epochs / epochs * 20)
        clear_output(wait=True)
        print(f"[{i+1}/{epochs}] |{progress_bar}{'-'*(20-len(progress_bar))}| -> {estimated_minutes}m {estimated_seconds}s")
except KeyboardInterrupt:
    # A long run is typically stopped by hand. Checkpoint first so the
    # iterations since the last periodic save are not lost, then re-raise so
    # the interrupt still stops the notebook as before.
    # NOTE(review): this saves whatever `model` is currently bound to; if the
    # interrupt lands inside snake.training that is the object passed in,
    # which is assumed to be updated in place — confirm in snake.py.
    model.save(model_loading_saving_path)
    raise

# Final checkpoint after a complete run.
model.save(model_loading_saving_path)

[10/500] |--------------------| -> 48m 47s


KeyboardInterrupt: 

### run the agent

In [8]:
# Let the current agent play; the cell output below shows the board that is
# printed together with a SCORE line.
snake.run(model=model)

. . . . . S . . . . 
. . . . . S . . . . 
. . . . . S . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . C . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 
. . . . . . . . . . 

SCORE: -1.2890997124915389
