In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math
import gym
import random
import numpy as np
import matplotlib.pyplot as plt
from env import MPSPEnv

env = MPSPEnv(3, 3, 3)
observation_space = env.observation_space[0].shape[0]**2 + \
    env.observation_space[1].shape[0]**2

# env = gym.make('CartPole-v1')
# observation_space = env.observation_space.shape[0]

action_space = env.action_space.n

EPISODES = 100
LEARNING_RATE = 0.0001
MEM_SIZE = 10000
BATCH_SIZE = 64
GAMMA = 0.95
EXPLORATION_MAX = 1.0
EXPLORATION_DECAY = 0.999
EXPLORATION_MIN = 0.001

FC1_DIMS = 1024
FC2_DIMS = 512
DEVICE = torch.device("cpu")

best_reward = 0
average_reward = 0
episode_number = []
average_reward_number = []


class Network(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.input_shape = observation_space
        self.action_space = action_space

        self.fc1 = nn.Linear(self.input_shape, FC1_DIMS)
        self.fc2 = nn.Linear(FC1_DIMS, FC2_DIMS)
        self.fc3 = nn.Linear(FC2_DIMS, self.action_space)

        self.optimizer = optim.Adam(self.parameters(), lr=LEARNING_RATE)
        self.loss = nn.MSELoss()
        self.to(DEVICE)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x


class ReplayBuffer:
    def __init__(self):
        self.mem_count = 0

        self.states = np.zeros((MEM_SIZE, observation_space), dtype=np.int64)
        self.actions = np.zeros(MEM_SIZE, dtype=np.int64)
        self.rewards = np.zeros(MEM_SIZE, dtype=np.int64)
        self.states_ = np.zeros((MEM_SIZE, observation_space), dtype=np.int64)
        self.dones = np.zeros(MEM_SIZE, dtype=bool)

    def add(self, state, action, reward, state_, done):
        mem_index = self.mem_count % MEM_SIZE

        self.states[mem_index] = state
        self.actions[mem_index] = action
        self.rewards[mem_index] = reward
        self.states_[mem_index] = state_
        self.dones[mem_index] = 1 - done

        self.mem_count += 1

    def sample(self):
        MEM_MAX = min(self.mem_count, MEM_SIZE)
        batch_indices = np.random.choice(MEM_MAX, BATCH_SIZE, replace=True)

        states = self.states[batch_indices]
        actions = self.actions[batch_indices]
        rewards = self.rewards[batch_indices]
        states_ = self.states_[batch_indices]
        dones = self.dones[batch_indices]

        return states, actions, rewards, states_, dones


class DQN_Solver:
    def __init__(self):
        self.memory = ReplayBuffer()
        self.exploration_rate = EXPLORATION_MAX
        self.network = Network()

    def choose_action(self, observation, mask):
        if random.random() < self.exploration_rate:
            return env.action_space.sample(mask)

        state = torch.tensor(observation).float().detach()
        state = state.to(DEVICE)
        state = state.unsqueeze(0)
        q_values = self.network(state).detach()
        q_max = q_values.abs().max()
        masked_argmax = (
            q_values - 2 * q_max * (1 - mask)
        ).argmax()
        return masked_argmax.item()

    def learn(self):
        if self.memory.mem_count < BATCH_SIZE:
            return

        states, actions, rewards, states_, dones = self.memory.sample()
        states = torch.tensor(states, dtype=torch.float32).to(DEVICE)
        actions = torch.tensor(actions, dtype=torch.long).to(DEVICE)
        rewards = torch.tensor(rewards, dtype=torch.float32).to(DEVICE)
        states_ = torch.tensor(states_, dtype=torch.float32).to(DEVICE)
        dones = torch.tensor(dones, dtype=torch.bool).to(DEVICE)
        batch_indices = np.arange(BATCH_SIZE, dtype=np.int64)

        q_values = self.network(states)
        next_q_values = self.network(states_)

        predicted_value_of_now = q_values[batch_indices, actions]
        predicted_value_of_future = torch.max(next_q_values, dim=1)[0]

        q_target = rewards + GAMMA * predicted_value_of_future * dones

        loss = self.network.loss(q_target, predicted_value_of_now)
        self.network.optimizer.zero_grad()
        loss.backward()
        self.network.optimizer.step()

        self.exploration_rate *= EXPLORATION_DECAY
        self.exploration_rate = max(EXPLORATION_MIN, self.exploration_rate)

    def returning_epsilon(self):
        return self.exploration_rate


In [None]:
agent = DQN_Solver()

In [3]:
for i in range(1, EPISODES):
    state, info = env.reset()
    state = np.concatenate((state[0].flatten(), state[1].flatten()))
    score = 0

    while True:
        action = agent.choose_action(state, info['mask'])
        state_, reward, done, info = env.step(action)
        state_ = np.concatenate((state_[0].flatten(), state_[1].flatten()))
        agent.memory.add(state, action, reward, state_, done)
        agent.learn()
        state = state_
        score += reward

        if done:
            if score > best_reward:
                best_reward = score
            average_reward += score
            print("Episode {} Average Reward {} Best Reward {} Last Reward {} Epsilon {}".format(
                i, average_reward/i, best_reward, score, agent.returning_epsilon()))
            break

        episode_number.append(i)
        average_reward_number.append(average_reward/i)

plt.plot(episode_number, average_reward_number)
plt.show()


Episode 1 Average Reward -432.0 Best Reward 0 Last Reward -13 Epsilon 1.0
Episode 2 Average Reward -236.5 Best Reward 0 Last Reward -41 Epsilon 0.9360999518731578
Episode 3 Average Reward -169.0 Best Reward 0 Last Reward -34 Epsilon 0.862367254825433
Episode 4 Average Reward -131.25 Best Reward 0 Last Reward -18 Epsilon 0.8235779600286178
Episode 5 Average Reward -105.0 Best Reward 0 Last Reward 0 Epsilon 0.818648829478636
Episode 6 Average Reward -88.66666666666667 Best Reward 0 Last Reward -7 Epsilon 0.8024304668606914
Episode 7 Average Reward -79.85714285714286 Best Reward 0 Last Reward -27 Epsilon 0.7534131012276413
Episode 8 Average Reward -70.0 Best Reward 0 Last Reward -1 Epsilon 0.7474068498462175
Episode 9 Average Reward -62.22222222222222 Best Reward 0 Last Reward 0 Epsilon 0.7444217038990517
Episode 10 Average Reward -56.2 Best Reward 0 Last Reward -2 Epsilon 0.7340672721936974
Episode 11 Average Reward -52.90909090909091 Best Reward 0 Last Reward -20 Epsilon 0.6989478686545

KeyboardInterrupt: 