<a href="https://colab.research.google.com/github/newmantic/DQN/blob/main/DQN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque

In [2]:
class QNetwork(nn.Module):
    def __init__(self, state_size, action_size, hidden_size=64):
        super(QNetwork, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

In [3]:
class DQNAgent:
    def __init__(self, state_size, action_size, hidden_size=64, gamma=0.99, learning_rate=0.001, batch_size=64, memory_size=10000, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995):
        self.state_size = state_size
        self.action_size = action_size
        self.hidden_size = hidden_size
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.memory = deque(maxlen=memory_size)
        self.epsilon = epsilon_start
        self.epsilon_min = epsilon_end
        self.epsilon_decay = epsilon_decay

        # Q-Network
        self.q_network = QNetwork(state_size, action_size, hidden_size)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def choose_action(self, state):
        if np.random.rand() <= self.epsilon:
            return random.choice(range(self.action_size))  # Explore
        else:
            state = torch.FloatTensor(state).unsqueeze(0)
            with torch.no_grad():
                q_values = self.q_network(state)
            return np.argmax(q_values.cpu().numpy())  # Exploit

    def replay(self):
        if len(self.memory) < self.batch_size:
            return

        minibatch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*minibatch)

        states = torch.FloatTensor(states)
        actions = torch.LongTensor(actions).unsqueeze(1)
        rewards = torch.FloatTensor(rewards).unsqueeze(1)
        next_states = torch.FloatTensor(next_states)
        dones = torch.FloatTensor(dones).unsqueeze(1)

        current_q_values = self.q_network(states).gather(1, actions)
        next_q_values = self.q_network(next_states).max(1)[0].unsqueeze(1)
        target_q_values = rewards + self.gamma * next_q_values * (1 - dones)

        loss = nn.MSELoss()(current_q_values, target_q_values)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, path):
        self.q_network.load_state_dict(torch.load(path))

    def save(self, path):
        torch.save(self.q_network.state_dict(), path)

In [4]:
class Simple1DEnv:
    def __init__(self, length=10, start=0, goal=9):
        self.length = length
        self.start = start
        self.goal = goal
        self.state = start

    def reset(self):
        self.state = self.start
        return self.state

    def step(self, action):
        if action == 0:  # move left
            self.state = max(0, self.state - 1)
        elif action == 1:  # move right
            self.state = min(self.length - 1, self.state + 1)

        if self.state == self.goal:
            return self.state, 1, True  # reached goal, reward 1
        else:
            return self.state, -1, False  # not at goal, penalty -1

In [5]:
def train_dqn():
    env = Simple1DEnv()
    state_size = 1  # since the state is just the position in the 1D space
    action_size = 2  # two possible actions: move left or right

    agent = DQNAgent(state_size, action_size, hidden_size=24, gamma=0.99, learning_rate=0.001, batch_size=32, memory_size=2000, epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.995)

    n_episodes = 500
    max_steps = 100

    for e in range(n_episodes):
        state = env.reset()
        state = np.reshape(state, [1])
        for time in range(max_steps):
            action = agent.choose_action(state)
            next_state, reward, done = env.step(action)
            next_state = np.reshape(next_state, [1])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"Episode: {e+1}/{n_episodes}, score: {time+1}, epsilon: {agent.epsilon:.2}")
                break
            agent.replay()

    return agent

# Train the DQN agent
dqn_agent = train_dqn()

# Save the trained model
dqn_agent.save("dqn_model.pth")

Episode: 1/500, score: 29, epsilon: 1.0


  states = torch.FloatTensor(states)


Episode: 26/500, score: 80, epsilon: 0.01
Episode: 27/500, score: 9, epsilon: 0.01
Episode: 28/500, score: 9, epsilon: 0.01
Episode: 29/500, score: 9, epsilon: 0.01
Episode: 30/500, score: 9, epsilon: 0.01
Episode: 31/500, score: 9, epsilon: 0.01
Episode: 32/500, score: 9, epsilon: 0.01
Episode: 33/500, score: 9, epsilon: 0.01
Episode: 34/500, score: 9, epsilon: 0.01
Episode: 35/500, score: 9, epsilon: 0.01
Episode: 36/500, score: 9, epsilon: 0.01
Episode: 37/500, score: 9, epsilon: 0.01
Episode: 38/500, score: 9, epsilon: 0.01
Episode: 39/500, score: 9, epsilon: 0.01
Episode: 40/500, score: 9, epsilon: 0.01
Episode: 41/500, score: 9, epsilon: 0.01
Episode: 42/500, score: 9, epsilon: 0.01
Episode: 43/500, score: 9, epsilon: 0.01
Episode: 44/500, score: 9, epsilon: 0.01
Episode: 45/500, score: 9, epsilon: 0.01
Episode: 46/500, score: 9, epsilon: 0.01
Episode: 47/500, score: 9, epsilon: 0.01
Episode: 48/500, score: 9, epsilon: 0.01
Episode: 49/500, score: 9, epsilon: 0.01
Episode: 50/500