In [1]:
!pip3 install gym
!pip3 install seaborn



In [2]:
import torch
from torch import nn
import copy
from collections import deque
from tqdm import tqdm
import random
import gym
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
# Seed every RNG the notebook draws from: the stdlib `random` module drives
# replay-buffer sampling, torch drives weight initialisation and exploration.
# torch.manual_seed stays last so the cell still displays the Generator it returns.
random.seed(1234)
torch.manual_seed(1234)

<torch._C.Generator at 0x7f14ac47a510>

In [4]:
# Capacity of the replay buffer (used as the deque's maxlen when it is created).
exp_replay_size = 256

# The target network is refreshed once `network_sync_counter` reaches
# `network_sync_freq`; the counter is advanced by every call to train().
# `sync_freq` carries the same value but is not referenced by the visible cells.
sync_freq = network_sync_freq = 5
network_sync_counter = 0

In [5]:
# Online Q-network: 4 inputs -> 64 tanh units -> 2 outputs (one value per action).
q_net = nn.Sequential(
    nn.Linear(4, 64),
    nn.Tanh(),
    nn.Linear(64, 2),
)
q_net.cuda()

# Target network used to compute bootstrap targets. It starts as an exact copy of
# the online network (same architecture, same weights, same device) so that the
# targets produced before the first synchronisation in train() are consistent
# with q_net instead of coming from an unrelated random initialisation.
target_net = copy.deepcopy(q_net)

loss_fn = torch.nn.MSELoss()
# Only the online network is optimised; target_net changes solely through
# load_state_dict() inside train().
optimizer = torch.optim.Adam(q_net.parameters(), lr=1e-3)

# Reset the sync counter so re-running this cell restarts the sync schedule.
network_sync_counter = 0

# Discount factor applied to the bootstrapped next-state value.
discount = torch.tensor(0.95).float().cuda()

# Bounded FIFO buffer of transitions; the oldest entries are dropped once
# exp_replay_size items are stored.
experience_replay = deque(maxlen=exp_replay_size)


def get_action(state, action_space_len, epsilon):
    """Choose an action epsilon-greedily from the online network's Q-values.

    Args:
        state: numpy array holding the current observation (converted with
            ``torch.from_numpy``).
        action_space_len: number of discrete actions; upper bound (exclusive)
            for the random action.
        epsilon: exploration threshold; a uniform draw that is not greater
            than it selects a random action.

    Returns:
        A tensor holding the chosen action index: the arg-max of ``q_net``'s
        output, or a shape ``(1,)`` tensor from ``torch.randint``. Callers
        read the value with ``.item()``.
    """
    # No gradient is needed here: this is only used while collecting
    # experience or at inference time.
    with torch.no_grad():
        q_values = q_net(torch.from_numpy(state).float().cuda())
    _, greedy_action = torch.max(q_values, axis=0)

    if torch.rand(1, ).item() > epsilon:
        return greedy_action
    return torch.randint(0, action_space_len, (1,))

def get_q_next(state):
    """Return the largest target-network Q-value for each row of ``state``.

    The forward pass runs under ``torch.no_grad()``, so the result carries no
    gradient history. The maximum is taken along axis 1.
    """
    with torch.no_grad():
        next_q_values = target_net(state)
    return torch.max(next_q_values, axis=1).values

def collect_experience(experience):
    """Append one transition to the module-level ``experience_replay`` buffer."""
    experience_replay.append(experience)

def sample_from_experience(sample_size):
    """Draw a random batch of transitions from ``experience_replay``.

    The batch is capped at the number of transitions currently stored.

    Returns:
        Four float tensors built from positions 0, 1, 2 and 3 of every sampled
        transition (stored by the callers as state, action, reward and next
        state).
    """
    batch_len = min(sample_size, len(experience_replay))
    batch = random.sample(experience_replay, batch_len)

    def column(position):
        # Gather one field across the whole batch as a float tensor.
        return torch.tensor([transition[position] for transition in batch]).float()

    return column(0), column(1), column(2), column(3)

def train(batch_size):
    """Run one gradient step of Q-learning on a random replay batch.

    Args:
        batch_size: number of transitions requested from the replay buffer
            (``sample_from_experience`` may return fewer if the buffer is
            smaller).

    Returns:
        The scalar MSE loss of this step as a Python float.

    Side effects:
        Updates ``q_net`` through ``optimizer``, increments the module-level
        ``network_sync_counter`` and, once it reaches ``network_sync_freq``,
        copies ``q_net``'s weights into ``target_net`` and resets the counter.
    """
    global network_sync_counter

    s, a, rn, sn = sample_from_experience(sample_size=batch_size)

    # `>=` rather than `==` so the sync still happens if the counter ever
    # overshoots (e.g. the frequency is lowered between runs).
    if network_sync_counter >= network_sync_freq:
        target_net.load_state_dict(q_net.state_dict())
        network_sync_counter = 0

    # Predicted return of the action that was actually taken in each sampled
    # state: pick column a[i] from row i of the online network's output.
    # (Taking the row-wise max instead would train the value of whichever
    # action currently looks best, not the one the reward belongs to.)
    qp = q_net(s.cuda())
    taken_actions = a.long().cuda().unsqueeze(1)
    pred_return = qp.gather(1, taken_actions).squeeze(1)

    # Bootstrapped target from the target network.
    # NOTE(review): stored transitions carry no terminal flag, so the target
    # also bootstraps from the successor of an episode's final step.
    q_next = get_q_next(sn.cuda())
    target_return = rn.cuda() + discount * q_next

    loss = loss_fn(pred_return, target_return)
    optimizer.zero_grad()
    # The graph is rebuilt on every call, so there is nothing to retain.
    loss.backward()
    optimizer.step()

    network_sync_counter += 1
    return loss.item()

In [6]:
env = gym.make('CartPole-v0')

# Our observation space has 4 inputs
input_dim = 4

# Our action space has 2 outputs
output_dim = 2

# Number of random transitions collected before learning starts. Re-assigning it
# here does not resize the `experience_replay` deque, whose maxlen was fixed when
# the deque was created.
exp_replay_size = 256

# Training schedule
episodes = 10000
epsilon = 1               # initial exploration rate
epsilon_min = 0.05        # decay stops once epsilon reaches this floor
epsilon_decay = 1 / 5000  # amount subtracted from epsilon after each episode
train_interval = 128      # a training round starts once more than this many steps have elapsed
train_iterations = 4      # gradient updates per training round
train_batch_size = 16     # transitions requested per gradient update

# Per-episode statistics
losses_list, reward_list, episode_len_list, epsilon_list = [], [], [], []

# Initialize experience replay with purely random actions (epsilon=1).
# Collection stops as soon as exp_replay_size transitions have been gathered,
# instead of continuing to reset the environment after the buffer is full.
index = 0
while index < exp_replay_size:
    obs, done = env.reset(), False
    while not done and index < exp_replay_size:
        A = get_action(obs, env.action_space.n, epsilon=1)
        action = A.item()
        obs_next, reward, done, _ = env.step(action)
        collect_experience([obs, action, reward, obs_next])
        obs = obs_next
        index += 1

# Main training loop.
# Starting the step counter at train_interval makes the very first environment
# step trigger a training round.
index = train_interval
for i in tqdm(range(episodes)):
    obs, done, losses, ep_len, rew = env.reset(), False, 0, 0, 0
    while not done:
        ep_len += 1
        A = get_action(obs, env.action_space.n, epsilon)
        action = A.item()
        obs_next, reward, done, _ = env.step(action)
        collect_experience([obs, action, reward, obs_next])

        obs = obs_next
        rew += reward
        index += 1

        if index > train_interval:
            index = 0
            for j in range(train_iterations):
                losses += train(batch_size=train_batch_size)

    # Linear epsilon decay, clamped so it never drops below the floor.
    if epsilon > epsilon_min:
        epsilon = max(epsilon_min, epsilon - epsilon_decay)

    # The recorded loss is the episode's summed training loss divided by the
    # episode length (0 for episodes in which no training round was triggered).
    losses_list.append(losses / ep_len)
    reward_list.append(rew)
    episode_len_list.append(ep_len)
    epsilon_list.append(epsilon)


 73%|█████████████████████████████████████████████████████████████████████████████████▉                              | 7321/10000 [00:37<00:13, 193.62it/s]


KeyboardInterrupt: 

In [None]:
# Episode length (environment steps per episode) over the course of training.
fig, ax = plt.subplots()
sns.lineplot(x=list(range(len(episode_len_list))), y=episode_len_list, ax=ax)
ax.set(title="Episode length during training", xlabel="Episode", ylabel="Steps per episode")
plt.show()