In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from game import Environment
from dqn import DQN, ReplayMemory
from agent import Agent

pygame 2.6.1 (SDL 2.28.4, Python 3.12.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# Hyperparameters forwarded to Agent in the setup cell.
BATCH_SIZE = 128   # passed as `batch_size` (presumably the replay minibatch size — see agent.py)
GAMMA = 0.99       # passed as `gamma` (presumably the discount factor)
TAU = 0.005        # passed as `tau` (presumably the soft target-network update rate)
LR = 1e-4          # passed as `lr` (optimizer learning rate)
EPS_DECAY = 0.995  # epsilon is multiplied by this once per episode (0.995**100 ≈ 0.60, matching the log)

# Seed every RNG this notebook imports so a Restart & Run All reproduces the training run.
# NOTE(review): if Environment/Agent draw from numpy or another RNG, seed that too;
# GPU kernels may still be nondeterministic even with seeds set.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Create the training environment and take one reset observation purely to learn
# how wide the network input must be.
env = Environment()
state = env.reset()
n_observations = len(state)  # length of one observation = network input width
n_actions = 2  # number of discrete actions; assumed to match what Environment.step accepts — TODO confirm

# The agent owns the networks, replay memory and epsilon schedule; every tunable
# value comes from the hyperparameter cell.
agent = Agent(
    n_observations,
    n_actions,
    batch_size=BATCH_SIZE,
    gamma=GAMMA,
    tau=TAU,
    lr=LR,
    epsilon_decay=EPS_DECAY,
)

In [4]:
# Train for a fixed number of episodes. Each episode is capped in length so a strong
# policy cannot run forever; the network is checkpointed at the end.
num_episodes = 1000
max_steps_per_episode = 1000  # hard cap on environment steps in a single episode
log_every = 100               # print progress every `log_every` episodes
episode_rewards = []          # total reward of every episode, kept for plotting/inspection

for episode in range(num_episodes):
    state = env.reset()
    done = False
    total_reward = 0
    steps = 0

    while not done:
        action = agent.nextAction(state)
        next_state, reward, done = env.step(action)

        # `done` is stored exactly as the environment reported it: hitting the step cap
        # below is a truncation, not a terminal state, so it is not recorded as done=True.
        agent.storeTransition(state, action, reward, next_state, done)

        state = next_state
        total_reward += reward

        # One learning update and one target-network update per environment step
        # (implementations live in agent.py).
        agent.replay()
        agent.updateTargetNetwork()

        steps += 1
        # `>=` caps the episode at exactly max_steps_per_episode steps.
        if steps >= max_steps_per_episode:
            break

    agent.decayEpsilon()  # epsilon decays once per episode, not per step
    episode_rewards.append(total_reward)

    if episode % log_every == 0:
        print(f"Episode {episode}, Steps: {steps}, Total Reward: {total_reward}, Epsilon: {agent.epsilon:.2f}")

# NOTE(review): this checkpoints the *target* network, which lags the online network
# under soft updates. If Agent exposes its online/policy network, saving that is
# usually preferable — confirm the attribute name in agent.py.
torch.save(agent.target_net.state_dict(), "model2.pth")
print("Model saved as model2.pth")

Episode 0, Steps: 39, Total Reward: 28, Epsilon: 0.99
Episode 100, Steps: 40, Total Reward: 29, Epsilon: 0.60
Episode 200, Steps: 63, Total Reward: 52, Epsilon: 0.37
Episode 300, Steps: 84, Total Reward: 73, Epsilon: 0.22
Episode 400, Steps: 84, Total Reward: 73, Epsilon: 0.13
Episode 500, Steps: 61, Total Reward: 50, Epsilon: 0.08
Episode 600, Steps: 59, Total Reward: 48, Epsilon: 0.05
Episode 700, Steps: 74, Total Reward: 63, Epsilon: 0.03
Episode 800, Steps: 84, Total Reward: 73, Epsilon: 0.02
Episode 900, Steps: 84, Total Reward: 73, Epsilon: 0.01
Model saved as model2.pth


In [None]:
# Load the checkpoint produced by the training cell and watch the learned policy play
# in a rendered window.
model_path = "model2.pth"  # must match the filename passed to torch.save in the training cell

model = DQN(n_observations, n_actions)
# The file holds a plain state_dict, so refuse to unpickle arbitrary objects.
model.load_state_dict(torch.load(model_path, weights_only=True))
model.eval()  # inference mode (affects dropout/batch-norm layers, if DQN has any)

test_env = Environment(renderGame=True)
try:
    state = test_env.reset()
    done = False

    # No gradients are needed while acting.
    with torch.no_grad():
        while not done:
            action = model.action(state)
            next_state, reward, done = test_env.step(action)
            state = next_state

            # Skip rendering the final frame once the episode has ended.
            if not done:
                test_env.render()
finally:
    # Always close the game window, even if stepping or rendering raises.
    test_env.quit()