# Cart Pole solved with DQN using PyTorch

In [None]:
import utils 
from replaybuffer import ReplayBuffer
import torch 
import gym
import numpy as np 
import copy 
import matplotlib.pyplot as plt
%matplotlib inline

## Hyperparameters and constants

In [None]:
# --- Training schedule -------------------------------------------------------
max_episodes = 5000          # upper bound on the number of training episodes
max_steps_in_episode = 1000  # cap on environment steps taken within one episode
train_interval = 3           # a training step is attempted every 3rd environment step

# --- Optimisation ------------------------------------------------------------
learning_rate = 1e-3  # learning rate handed to the optimizer
gamma = 0.995         # discount factor applied to the bootstrapped next-state value
tau = 1e-3            # soft-update rate: fraction of the trained model's weights
                      # blended into the target network after each training step

# --- Replay memory -----------------------------------------------------------
max_replays = 200_000  # first ReplayBuffer argument (presumably its capacity -- confirm)
batch_size = 64        # also used as the minimum fill level before training starts

layers = [64, 64]  # forwarded to utils.DQN (presumably hidden-layer widths -- confirm)
MSE = torch.nn.MSELoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
memories = ReplayBuffer(max_replays, batch_size, device)

# Creating the environment

In [None]:
# Build the CartPole-v1 environment with gym's new step API enabled, so that
# env.step() yields a 5-element tuple.
env = gym.make("CartPole-v1", new_step_api=True)

# Problem dimensions, read straight from the environment's spaces.
num_actions = env.action_space.n
state_size = env.observation_space.shape

# Array of every valid action index; random exploration samples from it.
Actions = np.arange(num_actions)

print(f'num_actions={num_actions}\n state_size={state_size} ')

### Instantiating the deep Q networks

In [None]:
# Online Q-network. It is moved to `device` *before* the optimizer is built,
# as the torch.optim documentation recommends, so that the optimizer is
# constructed over parameters that already live where training happens.
model = utils.DQN(state_size[0], num_actions, layers)
model.to(device)

# The target network starts as an exact clone of the online network (cloned
# after the move, so it is on the same device).
target = copy.deepcopy(model)

opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

model.train()
target.eval()  # target is not trained; its weights are refreshed by soft update instead

print(target)

## Exploration vs exploitation
### When training starts, exploration is favoured over exploitation. $\epsilon$ decays over the training session, shifting the balance towards exploitation as the model learns which action to take in each state.

In [None]:
# Epsilon-greedy exploration schedule.
epsilon = 1.0  # initial exploration probability
decay = 0.995  # epsilon is multiplied by this factor once per episode

In [None]:
def getAction(dqnModel, epsilon, curState):
    """Pick an action for ``curState`` using an epsilon-greedy policy.

    Args:
        dqnModel: Q-network used for the greedy choice. It is switched to
            evaluation mode on every call.
        epsilon: Current exploration probability. Values below 0.05 are
            treated as 0.05, so some exploration always remains.
        curState: Environment observation the action is chosen for.

    Returns:
        A uniformly sampled entry of ``Actions`` when exploring, otherwise
        whatever ``utils.getQAction`` returns for ``dqnModel`` and
        ``curState``.
    """
    # Act on the network that was passed in rather than the global `model`.
    dqnModel.eval()
    if np.random.uniform(0, 1) < max(0.05, epsilon):
        return np.random.choice(Actions)  # explore: random action
    # Exploit: query the network for the state that was passed in, not the
    # global `state` of the training loop.
    return utils.getQAction(dqnModel, curState, device)

## Train the Q network

$$
y_j =
\begin{cases}
  R_j & \text{if the episode terminates at step } j+1\\
  R_j + \gamma \max_{a'}\hat{Q}(s_{j+1},a') & \text{otherwise}
\end{cases}
$$

In [None]:
def trainDQN():
    """Run one gradient step on a sampled minibatch, then soft-update the target.

    Draws a batch from ``memories``, builds the targets
    ``reward + gamma * max_a' target(next_state)[a'] * (1 - done)``, fits the
    online ``model``'s Q-values for the taken actions to them with ``MSE``,
    and finally calls ``utils.softupdate(target, model, tau)``.
    """
    model.train()
    states, actions, rewards, next_states, dones = memories.sample()

    # Next-state values come from the target network; no gradient is tracked.
    with torch.no_grad():
        next_q = target(next_states)
    best_next_q = next_q.max(dim=1).values

    # The (1 - done) factor removes the bootstrap term for terminal transitions.
    y_targets = rewards + gamma * best_next_q * (1. - dones)

    # Q-value the online network assigns to the action actually taken.
    chosen_q = model(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    opt.zero_grad()
    loss = MSE(y_targets, chosen_q)
    loss.backward()
    opt.step()

    # Blend a fraction `tau` of the freshly trained weights into the target.
    utils.softupdate(target, model, tau)
    

## Start collecting experiences and train the DQN

In [None]:
AverageRewards = []  # mean episode reward of each completed `showRes`-episode window
Rewards = 0.         # reward accumulated in the current episode
showRes = 100        # report (and test for early stopping) every `showRes` episodes
totalRewards = 0.    # reward accumulated over the current reporting window
cnt = 0              # environment steps taken across all episodes

for ep in range(1, max_episodes + 1):
    state = env.reset()
    Rewards = 0.
    epsilon *= decay  # decay exploration once per episode
    for step in range(1, max_steps_in_episode + 1):
        cnt += 1
        action = getAction(model, epsilon, state)
        # `next_state` rather than `next`, so the builtin next() is not
        # shadowed for the rest of the session.
        # NOTE(review): the fourth value returned by env.step() (presumably
        # the `truncated` flag of gym's new step API) is discarded, so an
        # episode only stops on `done` or after max_steps_in_episode steps --
        # confirm this is intended.
        next_state, reward, done, _, _ = env.step(action)
        memories.add(state, action, reward, next_state, done)
        Rewards += reward
        # Train only once the buffer holds more than one batch, and only on
        # every `train_interval`-th environment step.
        if len(memories) > batch_size and cnt % train_interval == 0:
            trainDQN()
        state = next_state
        if done:
            break
    print(f"\r episode: {ep} \t reward: {Rewards}", end="")
    totalRewards += Rewards
    if ep % showRes == 0:
        AverageRewards.append(totalRewards / showRes)
        print(f'\r Episode:{ep} \t Average rewards= {AverageRewards[-1] } ')
        totalRewards = 0.
        # Stop early once the windowed average reward exceeds 500.
        if AverageRewards[-1] > 500:
            print('\n\n training ends early ')
            break

In [None]:
# Plot the windowed average rewards collected in AverageRewards during training.
plt.plot(AverageRewards)
plt.show()

In [None]:
# Run the freshly trained network in the environment via utils.runDQNAgent.
utils.runDQNAgent(model, env, device, fps=50)
# NOTE(review): this always raises AssertionError -- presumably a deliberate
# guard that stops a "Run All" here, before the cells below execute; confirm.
assert(False)

In [None]:
# Close the environment once the demo run is over.
env.close()

## Save the trained model

In [None]:
# Persist the trained network under the name 'cartpole' (the storage location is decided by utils.saveTrainedModel).
utils.saveTrainedModel(model, 'cartpole')

## Load the trained model to play

In [None]:
# Build a fresh network with the same constructor arguments as `model`, then load the saved weights into it.
trainedModel= utils.DQN(state_size[0],  num_actions, layers)
# NOTE(review): the model was saved under the name 'cartpole' but is loaded from
# "weights/cartpole" -- presumably utils.saveTrainedModel writes into weights/; confirm.
utils.loadModel(trainedModel, "weights/cartpole")
print(trainedModel)
trainedModel.to(device)

### Run the trained agent

In [None]:
# Run the network restored from disk in the environment via utils.runDQNAgent.
utils.runDQNAgent(trainedModel, env, device, fps=50)

In [None]:
# Close the environment after the run with the loaded model.
env.close()