In [64]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import random
import collections

In [17]:
Observation = collections.namedtuple('Observation', ('state', 'action', 'reward', 'next_state', 'done'))

In [4]:
env = gym.make('CartPole-v1')
env.reset()

(array([ 0.01155154, -0.01353062,  0.01343306,  0.02006867], dtype=float32),
 {})

In [66]:
env = gym.make("CartPole-v1")
observation_space = env.observation_space.shape[0]
action_space = env.action_space.n
dqn_solver = DQNSolver(observation_space, action_space)
epsilon = 0.5
while True:
    state = env.reset()
    state = torch.tensor(np.reshape(state[0], [1, observation_space]))
    while True:
        env.render()
        action = dqn_solver.choose_action(state,epsilon)
        state_next, reward, terminal, truncated, info = env.step(action)
        reward = reward if not terminal else 0
        state_next = torch.tensor(np.reshape(state_next, [1, observation_space]))
        dqn_solver.remember(Observation(state, action, reward, state_next, terminal))
        dqn_solver.replay()
        state = state_next
        if terminal:
            break

NameError: name 'replay_memory' is not defined

In [21]:
class DQN(nn.Module):
    def __init__(self, n_inputs, n_outputs):
        super().__init__()
        self.fc1 = nn.Linear(n_inputs, 24)
        self.fc2 = nn.Linear(24, 48)
        self.fc3 = nn.Linear(48, n_outputs)
    def forward(self, x):        
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        return x

In [65]:
class DQNSolver:
    GAMMA = 0.99
    BATCH_SIZE = 64
    REPLAY_MEMORY = 10000
    MIN_REPLAY_MEMORY = 1000
    def __init__(self, n_inputs, n_outputs):
        self.dqn = DQN(n_inputs, n_outputs)
        self.criterion  = torch.nn.MSELoss()
        self.num_actions = n_outputs
        self.opt = torch.optim.Adam(self.dqn.parameters(), lr = 0.01)
        self.replay_memory = collections.deque([], maxlen = 100000)
    
    def choose_action(self, state, epsilon):
        if (np.random.random() <= epsilon):
            return np.random.randint(self.num_actions)
        else:
            with torch.no_grad():
                return torch.argmax(self.dqn(state)).numpy()
    
    def remember(self, observation):
        #add past actions to deque memory
        self.replay_memory.append(observation)
        
    def replay(self):
        #take a random minibatch from memory - largeset minibatch size is given by batch_size
        #for each minibatch, y is the dqn action from state
        
        #reward is +1 for surviving another tick, +0 for not living
        if len(self.replay_memory) < self.MIN_REPLAY_MEMORY:
            return 0
        y_batch, y_target_batch = [], []
        minibatch = random.sample(self.replay_memory, self.BATCH_SIZE)
        for state, action, reward, next_state, done in minibatch:
            print(state, action, reward, next_state, done)
            y = self.dqn(state)
            print(y.shape)
            y_target = y.clone().detach()
            with torch.no_grad():
                y_target[0][action] = reward if done else reward + self.gamma * torch.max(self.dqn(next_state)[0])
            y_batch.append(y[0])
            y_target_batch.append(y_target[0])
        
        y_batch = torch.cat(y_batch)
        y_target_batch = torch.cat(y_target_batch)
        
        self.opt.zero_grad()
        loss = self.criterion(y_batch, y_target_batch)
        loss.backward()
        self.opt.step()        
        
        return loss.item()
    
        

In [9]:
class DQNCartPoleSolver:
    def __init__(self, n_episodes=2000, n_win_ticks=195, max_env_steps=None, gamma=1.0, epsilon=1.0, epsilon_min=0.01, alpha=0.01, alpha_decay=0.01, batch_size=64, monitor=False, quiet=False):
        self.memory = deque(maxlen=100000)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.n_episodes = n_episodes
        self.batch_size = batch_size
        if max_env_steps is not None: 
            self.env._max_episode_steps = max_env_steps
            
        # Init model
        self.dqn = DQN()
        self.criterion = torch.nn.MSELoss()
        self.opt = torch.optim.Adam(self.dqn.parameters(), lr=0.01)
    
    def choose_action(self, state, epsilon):
        #if random drawn number is less than epsilon, act randomly
        #else, choose greedy from dqn output
        if (np.random.random() <= epsilon):
            return self.env.action_space.sample() 
        else:
            with torch.no_grad():
                return torch.argmax(self.dqn(state)).numpy()
    
    def remember(self, state, action, reward, next_state, done):
        #add past actions to deque memory
        reward = torch.tensor(reward)
        self.memory.append((state, action, reward, next_state, done))
    
    def replay(self, batch_size):
        #take a random minibatch from memory - largeset minibatch size is given by batch_size
        #for each minibatch, y is the dqn action from state
        
        #reward is +1 for surviving another tick, +0 for not living
        y_batch, y_target_batch = [], []
        minibatch = random.sample(self.memory, min(len(self.memory), batch_size))
        for state, action, reward, next_state, done in minibatch:
            y = self.dqn(state)
            print(y.shape)
            y_target = y.clone().detach()
            with torch.no_grad():
                y_target[0][action] = reward if done else reward + self.gamma * torch.max(self.dqn(next_state)[0])
            y_batch.append(y[0])
            y_target_batch.append(y_target[0])
        
        y_batch = torch.cat(y_batch)
        y_target_batch = torch.cat(y_target_batch)
        
        self.opt.zero_grad()
        loss = self.criterion(y_batch, y_target_batch)
        loss.backward()
        self.opt.step()        
        
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
            
    def run(self):
        scores = deque(maxlen=100)
        
        for e in range(self.n_episodes):
            state = self.preprocess_state(self.env.reset())
            done = False
            i = 0
            while not done:
                if e % 100 == 0 and not self.quiet:
                    self.env.render()
                action = self.choose_action(state, self.get_epsilon(e))
                next_state, reward, done, _ = self.env.step(action)
                next_state = self.preprocess_state(next_state)
                self.remember(state, action, reward, next_state, done)
                state = next_state
                i += 1
            
            scores.append(i)
            mean_score = np.mean(scores)
            if mean_score >= self.n_win_ticks and e >= 100:
                if not self.quiet: print('Ran {} episodes. Solved after {} trials ✔'.format(e, e - 100))
                return e - 100
            if e % 100 == 0 and not self.quiet:
                print('[Episode {}] - Mean survival time over last 100 episodes was {} ticks.'.format(e, mean_score))
            self.replay(self.batch_size)
        
        if not self.quiet: print('Did not solve after {} episodes'.format(e))
        return e