In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import gym
import numpy as np
import matplotlib.pyplot as plt
import sys, os
from IPython.display import clear_output

In [None]:
# Hyperparameters shared by the training cell further down.
gamma = 0.99       # discount factor multiplying the bootstrapped next-state value
batch_size = 32    # number of transitions drawn from the replay buffer per update

In [None]:
# Fill the replay buffer with transitions produced by randomly sampled actions.
# Each entry is a (prev_obs, action, reward, done, obs) tuple.
# NOTE(review): this relies on the classic gym API, where reset() returns only the
# observation and step() returns a 4-tuple — confirm against the installed gym version.
replay = []
env = gym.make('CartPole-v0')

try:
    for _ in range(100):
        obs = env.reset()
        done = False

        while not done:
            prev_obs = obs
            action = env.action_space.sample()
            obs, reward, done, info = env.step(action)
            replay.append((prev_obs, action, reward, done, obs))
finally:
    # Release the environment even if a step raises or the cell is interrupted.
    env.close()

In [None]:
# Number of transitions stored in the replay buffer.
len(replay)

In [None]:
class dqn(nn.Module):
    """Fully connected Q-network mapping an observation to one Q-value per action.

    Args:
        ob_dim: Length of the flat observation vector fed to the first layer.
        action_dim: Number of discrete actions, i.e. the width of the output layer.
        hidden_dim: Width of the two hidden layers.

    The defaults ``ob_dim=4`` and ``action_dim=2`` are intended to match the
    CartPole environment used in this notebook, so ``dqn()`` can still be
    called without arguments.
    """

    def __init__(self, ob_dim=4, action_dim=2, hidden_dim=128):
        super(dqn, self).__init__()

        # These attributes used to be read without ever being assigned, which
        # made constructing the network raise AttributeError.
        self.ob_dim = ob_dim
        self.action_dim = action_dim

        self.layers = nn.Sequential(
            nn.Linear(self.ob_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, self.action_dim)
        )

    def forward(self, x):
        """Return Q-values of shape ``(..., action_dim)`` for ``x`` of shape ``(..., ob_dim)``."""
        return self.layers(x)

In [None]:
# dqn_agent produces the Q-values that enter the loss; dqn_target supplies the
# bootstrapped next-state values.
dqn_agent = dqn()
dqn_target = dqn()
# Start the target as an exact copy of the online network; two independently
# initialised networks would otherwise disagree until the first sync.
dqn_target.load_state_dict(dqn_agent.state_dict())

In [None]:
def target_update(dqn_main, dqn_target):
    """Hard-copy every parameter and buffer of ``dqn_main`` into ``dqn_target``.

    Both modules must expose matching state-dict keys. Returns ``None``.
    """
    main_weights = dqn_main.state_dict()
    dqn_target.load_state_dict(main_weights)

In [None]:
# Train dqn_agent with one-step TD targets computed from dqn_target.
iteration = 5000
target_update_interval = 100  # steps between hard copies of dqn_agent into dqn_target
plot_interval = 100           # steps between refreshes of the diagnostic plots

# Optimise the online network: its Q-values are the ones that enter the loss.
# The target network only ever receives weights through target_update().
optimizer = optim.SGD(dqn_agent.parameters(), lr=1e-4)
q_values = []  # mean predicted Q(s, a) of each sampled batch
losses = []    # scalar TD loss of each step

# `step` rather than `iter`, so the builtin iter() is not shadowed afterwards.
for step in range(iteration):
    # Draw a minibatch uniformly (with replacement) from the replay buffer.
    sample_index = np.random.randint(low=0, high=len(replay), size=batch_size)
    train_batch = [replay[idx] for idx in sample_index]

    # Transpose the list of (prev_obs, action, reward, done, obs) tuples into columns.
    prev_obs_b, action_b, reward_b, done_b, obs_b = zip(*train_batch)
    s = torch.as_tensor(np.array(prev_obs_b), dtype=torch.float32)
    a = torch.as_tensor(np.array(action_b), dtype=torch.int64).view(-1, 1)
    r = torch.as_tensor(np.array(reward_b), dtype=torch.float32).view(-1, 1)
    d = torch.as_tensor(np.array(done_b), dtype=torch.bool).view(-1, 1)
    s_2 = torch.as_tensor(np.array(obs_b), dtype=torch.float32)

    # Q(s, a) for the action that was actually taken.
    Q = torch.gather(dqn_agent(s), 1, a)

    with torch.no_grad():
        next_q = dqn_target(s_2).max(dim=1, keepdim=True)[0]
        # Do not bootstrap past a terminal transition: the target there is just r.
        y = r + gamma * next_q * (~d).float()

    loss = F.mse_loss(Q, y)

    # Gradients accumulate across backward() calls, so clear them every step.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Store plain floats; keeping the tensors would retain every autograd graph
    # and cannot be plotted while they still require grad.
    losses.append(loss.item())
    q_values.append(Q.mean().item())

    if step % target_update_interval == 0:
        target_update(dqn_agent, dqn_target)

    if step % plot_interval == 0:
        clear_output(True)
        plt.figure(figsize=(20, 5))
        plt.subplot(121)
        plt.title('loss')
        plt.plot(losses)
        plt.subplot(122)
        plt.title('q-value')
        plt.plot(q_values)
        plt.show()
        
