In [3]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import math
from collections import deque 
import random

In [4]:
# Create the CartPole-v0 environment through gym's registry.
env = gym.make('CartPole-v0')
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [5]:
class DQN(nn.Module):
    """Fully connected Q-network: obs_space -> 64 -> 128 -> act_space.

    A ReLU follows each of the two hidden linear layers; the output layer is
    linear and yields one value per action.
    """

    def __init__(self, obs_space, act_space):
        super().__init__()
        hidden_widths = (64, 128)

        # Build the stack front to back so the layer order (and therefore the
        # state_dict keys under `sequential`) is Linear, ReLU, Linear, ReLU, Linear.
        layers = []
        in_features = obs_space
        for width in hidden_widths:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, act_space))

        self.sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the observation tensor `x`."""
        return self.sequential(x)

In [6]:
def normalize_rewards(rewards, eps=1e-8):
    """Standardize a reward tensor to zero mean and unit standard deviation.

    Args:
        rewards: floating-point tensor of rewards.
        eps: lower bound applied to the standard deviation before dividing, so
            a batch of identical rewards yields zeros instead of NaN/inf.

    Returns:
        The standardized rewards with all size-1 dimensions squeezed out.
    """
    centered = rewards - torch.mean(rewards)
    if rewards.numel() < 2:
        # torch.std is NaN with fewer than two samples; the centered value
        # (zero for a single reward) is the only meaningful answer here.
        return centered.squeeze()
    # Clamping (rather than adding eps) leaves the result unchanged whenever
    # the standard deviation is already above eps.
    std = torch.clamp(torch.std(rewards), min=eps)
    return (centered / std).squeeze()

In [72]:
class DQNTrainer():
    """Runs one gradient step of Q-learning on an online network.

    The online network (`dqnNet`) is optimized with Adam against a Bellman
    target computed from a separate target network (`targetNet`).
    """

    def __init__(self, dqnNet, targetNet ,gamma=0.99, lr=1e-3):
        """
        Args:
            dqnNet: online Q-network whose parameters are optimized.
            targetNet: network used to evaluate the next-state values.
            gamma: discount factor applied to the next-state value.
            lr: learning rate of the Adam optimizer.
        """
        self.gamma = gamma
        self.dqnNet = dqnNet
        self.targetNet = targetNet
        self.loss = nn.MSELoss()
        self.optimizer = torch.optim.Adam(dqnNet.parameters(), lr=lr)

    def train_model(self, obs_tensor, act_tensor, rew_tensor, next_obs_tensor, dones):
        """Perform one optimization step on a batch of transitions.

        Args:
            obs_tensor: observations, shape (batch, obs_dim).
            act_tensor: integer action indices, one per transition; any shape
                that flattens to (batch,).
            rew_tensor: rewards, one per transition; flattened to (batch,).
            next_obs_tensor: next observations, shape (batch, obs_dim).
            dones: terminal flags, one per transition; flattened to (batch,).

        Returns:
            The MSE loss (Python float) measured before the optimizer step.
            The mean is taken over every (transition, action) entry; entries
            for actions that were not taken contribute zero.
        """
        pred = self.dqnNet(obs_tensor)

        # Accept both the flat (batch,) layout and the (1, 1) single-transition layout.
        actions = act_tensor.reshape(-1).long()
        rewards = rew_tensor.reshape(-1).to(pred.dtype)
        terminal = dones.reshape(-1).bool()

        # The target is a constant with respect to the optimization: evaluate
        # the target network without building a graph so no gradients are
        # accumulated on its parameters.
        with torch.no_grad():
            next_q = self.targetNet(next_obs_tensor).max(dim=1).values
            # Bellman equation: r for terminal transitions, r + gamma * max_a' Q'(s', a') otherwise.
            q_new = torch.where(terminal, rewards, rewards + self.gamma * next_q)

        # Start from the current predictions so only the taken action's entry
        # differs from `pred` and contributes to the loss.
        target = pred.detach().clone()
        rows = torch.arange(actions.shape[0], device=pred.device)
        target[rows, actions] = q_new.to(target.dtype)

        loss = self.loss(pred, target)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

In [78]:
BATCH_SIZE = 100

class CartPoleAgent():
    """DQN agent holding an online network, a target network and replay memory.

    `memory` stores the transitions of the current episode (cleared with
    `clearMemory`); `permanent` is a bounded replay buffer of the most recent
    2000 transitions that training batches are sampled from.
    """

    def __init__(self, obs_space, act_space, device = device, gamma=1e-3, lr=1e-3, warmup_episodes=50):
        """
        Args:
            obs_space: length of the observation vector.
            act_space: number of discrete actions.
            device: torch device the networks and training tensors live on.
            gamma: discount factor forwarded to DQNTrainer.
            lr: learning rate forwarded to DQNTrainer.
            warmup_episodes: number of initial episodes during which
                `getAction` acts uniformly at random.
        """
        # NOTE(review): the gamma default of 1e-3 is unusually small for a
        # discount factor (DQNTrainer itself defaults to 0.99). It is kept for
        # backward compatibility; callers should pass gamma explicitly.
        self.gamma = gamma
        self.device = device
        self.act_space = act_space
        self.warmup_episodes = warmup_episodes
        self.dqnModel = DQN(obs_space, act_space).to(device)
        self.targetModel = DQN(obs_space, act_space).to(device).eval()
        # Start the target network as a copy of the online network instead of
        # an independent random initialization.
        self.targetModel.load_state_dict(self.dqnModel.state_dict())
        self.trainer = DQNTrainer(self.dqnModel, self.targetModel, gamma,lr)

        #Memory
        self.memory = deque()
        self.permanent = deque(maxlen=2000)

    def getAction(self, observation, episode):
        """Pick an action: uniformly random during warm-up, greedy afterwards.

        Args:
            observation: current environment observation (array-like).
            episode: index of the current episode.

        Returns:
            The chosen action as a Python int.
        """
        if episode < self.warmup_episodes:
            return torch.randint(0, self.act_space, (1,)).item()

        obs_tensor = torch.as_tensor(
            np.asarray(observation), dtype=torch.float32, device=self.device
        ).reshape(1, -1)
        with torch.no_grad():
            pred = self.dqnModel(obs_tensor)
        return torch.argmax(pred).item()

    def saveToMemory(self, obs, act, rew, next_obs, done):
        """Record a transition in both the episode memory and the replay buffer."""
        transition = (obs, act, rew, next_obs, done)
        self.memory.append(transition)
        # appendleft + maxlen: once full, the oldest transition falls off the right end.
        self.permanent.appendleft(transition)

    def clearMemory(self):
        """Drop the current episode's transitions (the replay buffer is kept)."""
        self.memory = deque()

    def _to_tensors(self, obs, act, rew, next_obs, done):
        """Convert one transition, or tuples of transitions, to tensors on self.device."""
        # Going through numpy avoids the slow element-wise path torch takes for
        # a tuple of ndarrays, and the explicit dtypes match what the networks
        # and the trainer expect regardless of the environment's array dtype.
        def convert(values, dtype):
            return torch.as_tensor(np.asarray(values), dtype=dtype, device=self.device)

        return (
            convert(obs, torch.float32),
            convert(act, torch.int64),
            convert(rew, torch.float32),
            convert(next_obs, torch.float32),
            convert(done, torch.bool),
        )

    def train_step(self, obs, act, rew, next_obs, done):
        """Train on a single transition; returns the loss as a float."""
        tensors = self._to_tensors(obs, act, rew, next_obs, done)
        # Every tensor gets a leading batch dimension of 1.
        return self.trainer.train_model(*(t.reshape(1, -1) for t in tensors))

    def train_episode(self):
        """Train on a random mini-batch drawn from the replay buffer.

        Uses BATCH_SIZE transitions when the buffer holds more than that,
        otherwise the whole buffer.

        Returns:
            The loss as a float, or None when the replay buffer is empty.
        """
        if not self.permanent:
            return None

        if len(self.permanent) > BATCH_SIZE:
            mini_sample = random.sample(self.permanent, BATCH_SIZE) #List of tuples
        else:
            mini_sample = list(self.permanent)

        batch = self._to_tensors(*zip(*mini_sample))
        return self.trainer.train_model(*batch)

    def getEpisodeRewards(self):
        """Return the summed reward of the current episode (0.0 when it is empty)."""
        return np.sum([transition[2] for transition in self.memory])

    def updateTargetNetwork(self):
        """Copy the online network's weights into the target network."""
        self.targetModel.load_state_dict(self.dqnModel.state_dict())
        
    

In [79]:
#https://www.youtube.com/watch?v=Ql8QPcp8818
import matplotlib.pyplot as plt
from IPython import display

plt.ion()
# Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8' and later
# removed the old name; prefer the new name and fall back on older releases.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

def plot(rewards, loss, fig, axs):
    """Redraw the live training curves.

    Args:
        rewards: sequence of per-episode reward totals, drawn on axs[0].
        loss: sequence of loss values, drawn on axs[1].
        fig: figure that owns the axes; accepted for interface compatibility
            and not used by this function.
        axs: indexable pair of Matplotlib axes.
    """
    display.clear_output(wait=True)

    panels = (
        (axs[0], rewards, 'tab:orange', 'Episode Rewards'),
        (axs[1], loss, 'tab:blue', 'Loss'),
    )
    for ax, series, color, title in panels:
        ax.clear()
        ax.plot(series, color)
        ax.set_title(title)
        # Annotate the latest point; skipped while the series is still empty.
        if len(series) > 0:
            ax.text(len(series) - 1, series[-1], str(series[-1]))

    # Label the bottom panel explicitly instead of relying on pyplot's
    # notion of the "current" axes.
    axs[1].set_xlabel('Episode')

    plt.show(block=False)
    plt.pause(.1)
    


In [80]:
# Length of the observation vector (first entry of the observation space's shape).
obs_space = env.observation_space.shape[0]
# Number of discrete actions reported by the action space.
act_space = env.action_space.n

In [81]:
# Reset the environment; the returned value is shown as this cell's output.
env.reset()

array([-0.03804986, -0.04472082, -0.04827625,  0.04838444], dtype=float32)

In [None]:
%matplotlib qt
fig, axs = plt.subplots(2)

maxIter = 200
cart_pole_agent = CartPoleAgent(obs_space, act_space)

rewards = []
losses = []
e=0
updateTargetNet = 40


while True:
    cart_pole_agent.clearMemory()
    observation = env.reset()
    for i in range(maxIter):
        #env.render()
        action = cart_pole_agent.getAction(observation, e)
        #print(action)
        next_observation, reward, done, info = env.step(action)
        cart_pole_agent.saveToMemory(observation, action, reward,next_observation, done)
        #print(reward)
        if done:
            print(f"Episode finished at time step {i}")
            break;
        lossStep = cart_pole_agent.train_episode()
        #print(loss)
        observation = next_observation
    #print(cart_pole_agent.train_episode())
    #lossEpisode = cart_pole_agent.train_episode()

    rewards.append(cart_pole_agent.getEpisodeRewards())
    losses.append(lossEpisode)
    plot(rewards, losses, fig, axs)
    
    if e%updateTargetNet==0:
        cart_pole_agent.updateTargetNetwork()
        
    e+=1

Episode finished at time step 29


In [58]:
# Scratch check of deque ordering: start from an empty deque.
d = deque()

In [59]:
# append() adds on the right end; the bare `d` displays the deque as cell output.
d.append(2)
d

deque([2])

In [60]:
# A second append lands to the right of the existing element.
d.append(1)
d

deque([2, 1])

In [61]:
# Elements stay in insertion order; the deque does not sort them.
d.append(3)
d

deque([2, 1, 3])