In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import common
from nb_export.core import build_all

In [3]:
import gym
import numpy as np
import time

import torch
import torch.optim as optim
import torch.nn.functional as F

import ptan

In [4]:
params = common.HYPERPARAMS['pong']
device = torch.device('cuda')

env = gym.make(params.env_name)
env = ptan.common.wrappers.wrap_dqn(env)
env.seed(common.SEED)

net = common.DQN(env.observation_space.shape, env.action_space.n).to(device)
tgt_net = ptan.agent.TargetNet(net)

selector = ptan.actions.EpsilonGreedyActionSelector(epsilon=params.epsilon_start)
epsilon_tracker = common.EpsilonTracker(selector, params)
agent = ptan.agent.DQNAgent(net, selector, device=device)

exp_source = ptan.experience.ExperienceSourceFirstLast(env, agent, params.gamma, steps_count=params.n)
buffer = ptan.experience.ExperienceReplayBuffer(exp_source, params.replay_size)
opt = optim.Adam(net.parameters(), params.learning_rate)

In [5]:
iteration = 0
episode = 0
time_begin = time.time()

for batch in common.batch_generator(buffer, params.replay_initial, params.batch_size):
    iteration += 1
    
    for reward,steps in exp_source.pop_rewards_steps():
        episode += 1
        elapsed = time.time() - time_begin
        print(f'Episode: {episode} | Reward: {reward} | Epsilon: {selector.epsilon} | Elapsed: {elapsed}')
        time_begin = time.time()
    
    opt.zero_grad()
    states, actions, rewards, last_states, dones = common.unpack_batch(batch)

    actual_qs = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

    # target
    with torch.no_grad():
        target_qs = tgt_net.model(last_states).max(dim=1)[0]
        target_qs[dones] = 0.0
        target = rewards + (params.gamma**params.n)*target_qs.detach()

    loss = F.mse_loss(target, actual_qs)
    loss.backward()
    opt.step()
    
    epsilon_tracker.frame(iteration)
    
    if iteration%params.target_net_sync==0:
        print("Sync")
        tgt_net.sync()

Done populating
Episode: 1 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 31.85182285308838
Episode: 2 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 1.9073486328125e-06
Episode: 3 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 1.1920928955078125e-06
Episode: 4 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 5 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 6 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 1.1920928955078125e-06
Episode: 7 | Reward: -18.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 8 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 9 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 10 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 11 | Reward: -19.0 | Epsilon: 0.99617 | Elapsed: 6.208643674850464
Sync
Episode: 12 | Reward: -21.0 | Epsilon: 0.98831 | Elapsed: 13.85500431060791
Sync
Episode: 13 | Reward: -20.0 | Epsilon: 0.97915 | Elapsed: 15.892884969711304
Episode: 1

Sync
Sync
Sync
Episode: 98 | Reward: 4.0 | Epsilon: 0.02 | Elapsed: 52.0904541015625
Sync
Sync
Sync
Episode: 99 | Reward: 4.0 | Epsilon: 0.02 | Elapsed: 51.346985816955566
Sync
Sync
Sync
Episode: 100 | Reward: 4.0 | Epsilon: 0.02 | Elapsed: 54.168370723724365
Sync
Sync
Sync
Episode: 101 | Reward: -1.0 | Epsilon: 0.02 | Elapsed: 57.62851047515869
Sync
Sync
Sync
Episode: 102 | Reward: 3.0 | Epsilon: 0.02 | Elapsed: 52.81425213813782
Sync
Sync
Sync
Sync
Episode: 103 | Reward: 4.0 | Epsilon: 0.02 | Elapsed: 63.71525287628174
Sync
Sync
Sync
Episode: 104 | Reward: 11.0 | Epsilon: 0.02 | Elapsed: 47.79267144203186
Sync
Sync
Episode: 105 | Reward: 12.0 | Epsilon: 0.02 | Elapsed: 48.96186542510986
Sync
Sync
Sync
Episode: 106 | Reward: 14.0 | Epsilon: 0.02 | Elapsed: 42.41900992393494
Sync
Sync
Sync
Episode: 107 | Reward: 9.0 | Epsilon: 0.02 | Elapsed: 45.96444082260132
Sync
Sync
Episode: 108 | Reward: 16.0 | Epsilon: 0.02 | Elapsed: 39.21913433074951
Sync
Sync
Episode: 109 | Reward: 14.0 | Epsi

KeyboardInterrupt: 