In [1]:
from torch import nn
import torch
import gym
from collections import deque
import itertools
import numpy as np
import random
import matplotlib.pyplot as plt

In [2]:

# --- DQN hyperparameters (everything a reader might want to tune) ---
GAMMA = 0.99                # discount factor applied to the bootstrapped target Q-value
BATCH_SIZE = 32             # number of transitions per gradient step
BUFFER_SIZE = 50_000        # maximum length of the replay deque
MIN_REPLAY_SIZE = 1_800     # random-policy transitions collected before training starts
EPSILON_START = 1.0         # exploration probability at step 0
EPSILON_END = 0.02          # exploration probability once the decay has finished
EPSILON_DECAY = 10_000      # number of steps over which epsilon is linearly interpolated
TARGET_UPDATE_FREQ = 1_000  # steps between copies of the online weights into the target net


In [3]:
class Network(nn.Module):
    """Small MLP that maps a flattened observation to one Q-value per action.

    Args:
        env: object exposing ``observation_space.shape`` and ``action_space.n``;
            only those two attributes are read, to size the input and output layers.
    """
    def __init__(self,env):
        super().__init__()
        # Flatten whatever shape the observation space has into a single feature count.
        in_features=int(np.prod(env.observation_space.shape))
        self.net=nn.Sequential(
            nn.Linear(in_features,64),
            nn.Tanh(),
            nn.Linear(64,env.action_space.n)
        )
    def forward(self,x):
        """Return the Q-values for a batch of observations ``x``."""
        return self.net(x)
    def act(self,obs):
        """Return the greedy action (index of the largest Q-value) for one observation.

        Args:
            obs: a single observation, anything ``torch.as_tensor`` accepts.

        Returns:
            The chosen action as a Python ``int``.
        """
        # Use the device the weights live on instead of a module-level `device`
        # global, so the class works regardless of which cell defined the device.
        param_device=next(self.parameters()).device
        obs_t=torch.as_tensor(obs,dtype=torch.float32,device=param_device)
        # Action selection is pure inference; skip building an autograd graph.
        with torch.no_grad():
            q_values=self(obs_t.unsqueeze(0))
        max_q_index=torch.argmax(q_values,dim=1)[0]
        return int(max_q_index.item())
        

In [4]:
# Environment plus the containers that the later cells fill in.
env = gym.make('CartPole-v0')

# Bounded FIFO of (obs, action, rew, done, new_obs) tuples; the deque drops the
# oldest entry once BUFFER_SIZE is reached.
replay_buffer = deque(maxlen=BUFFER_SIZE)

# Per-episode bookkeeping appended to by the training loop.
rew_buffer, epsilon_degradation = [], []

# Running reward of the episode currently in progress.
episode_reward = 0.0

  logger.warn(


In [5]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of raising on `.to('cuda')`.
device='cuda' if torch.cuda.is_available() else 'cpu'
online_net=Network(env).to(device)
target_net=Network(env).to(device)
# Start the target network as an exact copy of the online network's weights.
target_net.load_state_dict(online_net.state_dict())

<All keys matched successfully>

In [6]:
# Play random actions to pre-fill the replay buffer before any training happens.
obs=env.reset()[0]
for _ in range(MIN_REPLAY_SIZE):
    action=env.action_space.sample()
    # With the 5-tuple step API the 3rd and 4th values are `terminated` and
    # `truncated`; the last one is the info dict.
    new_obs,rew,terminated,truncated,_=env.step(action)
    done=terminated or truncated
    # Only a true terminal state should zero out the bootstrap term in the
    # target, so the transition stores `terminated`, not `done`.
    transistion=(obs,action,rew,terminated,new_obs)
    replay_buffer.append(transistion)
    obs=new_obs
    # Reset on either flag; ignoring `truncated` would keep stepping an
    # episode that already hit its time limit.
    if done:
        obs=env.reset()[0]

In [7]:
def train_model(transistions):
    """Run one pass of gradient steps over ``transistions`` in consecutive mini-batches.

    The list is walked front to back in chunks of ``BATCH_SIZE``; a trailing
    partial chunk is skipped. Each chunk produces one optimizer step on
    ``online_net`` against one-step TD targets computed from ``target_net``.

    Args:
        transistions: sequence of ``(obs, action, rew, done, new_obs)`` tuples.

    Relies on the module-level ``online_net``, ``target_net``, ``optimizer``,
    ``BATCH_SIZE`` and ``GAMMA``.
    """
    # Build tensors directly on the device the network lives on, avoiding
    # GPU -> CPU -> GPU round trips for every batch.
    net_device=next(online_net.parameters()).device
    # `+1` keeps the final full batch when the length is an exact multiple of
    # BATCH_SIZE (and lets a buffer of exactly BATCH_SIZE produce one batch).
    for start in range(0,len(transistions)-BATCH_SIZE+1,BATCH_SIZE):
        batch=transistions[start:start+BATCH_SIZE]
        obses=np.asarray([t[0] for t in batch])
        actions=np.asarray([t[1] for t in batch])
        rews=np.asarray([t[2] for t in batch])
        dones=np.asarray([t[3] for t in batch])
        new_obses=np.asarray([t[4] for t in batch])

        obses_t=torch.as_tensor(obses,dtype=torch.float32,device=net_device)
        # unsqueeze(-1) makes these column vectors so they line up with the
        # keepdim=True max below and with the gather index.
        actions_t=torch.as_tensor(actions,dtype=torch.int64,device=net_device).unsqueeze(-1)
        rews_t=torch.as_tensor(rews,dtype=torch.float32,device=net_device).unsqueeze(-1)
        dones_t=torch.as_tensor(dones,dtype=torch.float32,device=net_device).unsqueeze(-1)
        new_obses_t=torch.as_tensor(new_obses,dtype=torch.float32,device=net_device)

        # Targets are constants for the loss: no graph through target_net, so
        # no gradients accumulate on its parameters.
        with torch.no_grad():
            target_q_values=target_net(new_obses_t)
            max_target_q_values=target_q_values.max(dim=1,keepdim=True)[0]
            # (1 - done) removes the bootstrap term for terminal transitions.
            targets=rews_t+GAMMA*(1-dones_t)*max_target_q_values

        q_values=online_net(obses_t)
        # Pick the Q-value of the action that was actually taken.
        action_q_values=torch.gather(input=q_values,dim=1,index=actions_t)
        loss=nn.functional.smooth_l1_loss(action_q_values, targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [None]:
# Main DQN loop: epsilon-greedy interaction with the environment, periodic
# training passes over the replay buffer, and periodic target-network syncs.
# NOTE: itertools.count() never ends; the cell runs until the kernel is interrupted.
optimizer=torch.optim.Adam(online_net.parameters(), lr=5e-4)
TRAIN_FREQ=100  # steps between training passes over the replay buffer
LOG_FREQ=100    # steps between progress printouts
obs=env.reset()[0]
# A fresh reset starts a new episode, so the running reward starts over too
# (otherwise re-running this cell would carry over a stale partial sum).
episode_reward=0.0
for step in itertools.count():
    # Linear decay from EPSILON_START to EPSILON_END over EPSILON_DECAY steps;
    # np.interp clamps to EPSILON_END afterwards.
    epsilon=np.interp(step,[0,EPSILON_DECAY],[EPSILON_START,EPSILON_END])
    if random.random()<=epsilon:
        action=env.action_space.sample()
    else:
        action=online_net.act(obs)
    # With the 5-tuple step API the 3rd and 4th values are `terminated` and
    # `truncated`; the last one is the info dict.
    new_obs,rew,terminated,truncated,_=env.step(action)
    done=terminated or truncated
    # Store `terminated` so only true terminal states drop the bootstrap term.
    transistion=(obs,action,rew,terminated,new_obs)
    replay_buffer.append(transistion)
    obs=new_obs
    episode_reward+=rew
    # Reset on either flag; ignoring `truncated` would keep stepping an
    # episode that already hit its time limit.
    if done:
        obs=env.reset()[0]
        rew_buffer.append(episode_reward)
        epsilon_degradation.append(epsilon)
        episode_reward=0.0

    if step % TRAIN_FREQ==0:
        transistions = list(replay_buffer)
        train_model(transistions)
    if step % TARGET_UPDATE_FREQ==0:
        target_net.load_state_dict(online_net.state_dict())
    if step % LOG_FREQ==0:
        # Avoid np.mean([]) (RuntimeWarning) before the first episode finishes.
        avg_reward=np.mean(rew_buffer) if rew_buffer else float('nan')
        print()
        print('step',step)
        print('avg reward',avg_reward)




In [None]:
# Left panel: total reward of each finished episode.
# Right panel: the epsilon value recorded when each episode ended.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
ax1.scatter(x=range(len(rew_buffer)),y=rew_buffer)
ax1.set(title='Episode reward',xlabel='episode',ylabel='total reward')
ax2.scatter(x=range(len(epsilon_degradation)),y=epsilon_degradation)
ax2.set(title='Epsilon at episode end',xlabel='episode',ylabel='epsilon')
plt.show()