In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to local modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

# DQN PTAN

In [2]:
import common
import gym
import numpy as np
import time

import torch
import torch.optim as optim
import torch.nn.functional as F

import ptan

In [3]:
# Hyperparameter set for the 'pong' entry of the project's `common.HYPERPARAMS`.
params = common.HYPERPARAMS['pong']
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without a usable CUDA device instead of failing at `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Create the environment named in the hyperparameters and pass it straight
# through `ptan.common.wrappers.wrap_dqn`.
env = ptan.common.wrappers.wrap_dqn(gym.make(params.env_name))
# Seed the environment with the project-wide seed; the trailing semicolon keeps
# the call's return value out of the cell output.
env.seed(common.SEED);

In [5]:
# Model: the trainable Q-network on `device`, a PTAN target-network wrapper
# around it, and the Adam optimizer that updates the trainable network.
net = common.DQN(
    env.observation_space.shape,
    env.action_space.n,
).to(device)
tgt_net = ptan.agent.TargetNet(net)
opt = optim.Adam(net.parameters(), lr=params.learning_rate)

# Exploration: epsilon-greedy action selection starting at `epsilon_start`;
# the tracker is handed the selector and the hyperparameters.
selector = ptan.actions.EpsilonGreedyActionSelector(
    epsilon=params.epsilon_start,
)
epsilon_tracker = common.EpsilonTracker(selector, params)
agent = ptan.agent.DQNAgent(net, selector, device=device)

# Experience: transitions produced by the agent acting in `env`, stored in a
# replay buffer of `replay_size` entries.
exp_source = ptan.experience.ExperienceSourceFirstLast(
    env,
    agent,
    params.gamma,
)
buffer = ptan.experience.ExperienceReplayBuffer(
    exp_source,
    params.replay_size,
)

In [6]:
def unpack_batch(batch):
    """Convert a batch of transitions into tensors on the global ``device``.

    Each item of ``batch`` must unpack into
    ``(state, action, reward, last_state)``; ``last_state is None`` marks a
    transition that ended its episode.

    Returns:
        ``(states, actions, rewards, last_states, dones)`` where ``rewards`` is
        float32 and ``dones`` is a boolean mask that is True for terminal
        transitions. For terminal transitions ``last_states`` holds the
        transition's own state as a placeholder so the batch stays rectangular;
        callers are expected to mask those entries using ``dones``.

    Note:
        A later cell defines ``unpack_batch`` again; once the whole notebook
        has run, that later definition is the one in effect.
    """
    states, actions, rewards, last_states = zip(*batch)

    # Build the terminal mask BEFORE the ``None`` entries are replaced and the
    # sequence is turned into a tensor; afterwards no element is ``None`` any
    # more and the mask would always be all-False.
    dones = [last_state is None for last_state in last_states]

    # np.asarray (rather than np.array(..., copy=False)) also works on
    # NumPy >= 2, where copy=False raises if a copy cannot be avoided.
    # Collapsing to one ndarray first avoids building a tensor from a list
    # of separate arrays.
    states = np.asarray([np.asarray(state) for state in states])
    last_states = np.asarray([
        states[idx] if last_state is None else np.asarray(last_state)
        for idx, last_state in enumerate(last_states)
    ])

    states = torch.tensor(states).to(device)
    last_states = torch.tensor(last_states).to(device)
    actions = torch.tensor(actions).to(device)
    rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
    # The mask must live on the same device as the tensors it indexes.
    dones = torch.BoolTensor(dones).to(device)

    return states, actions, rewards, last_states, dones

In [7]:
def unpack_batch_(batch):
    """Convert a batch of experience objects into NumPy arrays.

    Each item of ``batch`` must expose ``state``, ``action``, ``reward`` and
    ``last_state`` attributes; ``last_state is None`` marks a transition that
    ended its episode.

    Returns:
        ``(states, actions, rewards, dones, last_states)`` as NumPy arrays.
        ``rewards`` is float32 and ``dones`` is uint8 (1 for terminal
        transitions). For terminal transitions ``last_states`` holds the
        transition's own state as a placeholder; callers are expected to mask
        those entries using ``dones``.
    """
    states, actions, rewards, dones, last_states = [], [], [], [], []
    for exp in batch:
        state = np.array(exp.state)
        is_terminal = exp.last_state is None
        states.append(state)
        actions.append(exp.action)
        rewards.append(exp.reward)
        dones.append(is_terminal)
        # Terminal transitions have no successor; reuse the current state so
        # every row has the same shape (the value is masked by the caller).
        last_states.append(state if is_terminal else np.array(exp.last_state))

    # np.asarray replaces np.array(..., copy=False): stacking a Python list
    # always needs a copy, and on NumPy >= 2 copy=False raises ValueError in
    # that situation instead of silently copying.
    return (
        np.asarray(states),
        np.array(actions),
        np.array(rewards, dtype=np.float32),
        np.array(dones, dtype=np.uint8),
        np.asarray(last_states),
    )

def unpack_batch(batch):
    """Convert a batch of experience objects into tensors on the global ``device``.

    Each item of ``batch`` must expose ``state``, ``action``, ``reward`` and
    ``last_state`` attributes; ``last_state is None`` marks a transition that
    ended its episode.

    Returns:
        ``(states_v, actions_v, rewards_v, next_states_v, done_mask)``.
        ``rewards_v`` is float32 and ``done_mask`` is a boolean tensor that is
        True for terminal transitions. For terminal transitions
        ``next_states_v`` holds the transition's own state as a placeholder;
        callers are expected to mask those entries using ``done_mask``.
    """
    states, actions, rewards, dones, last_states = [], [], [], [], []
    for exp in batch:
        state = np.array(exp.state)
        is_terminal = exp.last_state is None
        states.append(state)
        actions.append(exp.action)
        rewards.append(exp.reward)
        dones.append(is_terminal)
        # Terminal transitions have no successor; reuse the current state so
        # every row has the same shape (the value is masked by the caller).
        last_states.append(state if is_terminal else np.array(exp.last_state))

    # np.asarray replaces np.array(..., copy=False): stacking a Python list
    # always needs a copy, and on NumPy >= 2 copy=False raises ValueError in
    # that situation instead of silently copying.
    states = np.asarray(states)
    next_states = np.asarray(last_states)

    states_v = torch.tensor(states).to(device)
    next_states_v = torch.tensor(next_states).to(device)
    actions_v = torch.tensor(actions).to(device)
    rewards_v = torch.tensor(rewards, dtype=torch.float32).to(device)
    done_mask = torch.BoolTensor(dones).to(device)

    return states_v, actions_v, rewards_v, next_states_v, done_mask

In [None]:
# Counters for optimisation steps and finished episodes, plus the wall-clock
# reference used for the per-episode "Elapsed" report.
iteration = 0
episode = 0
time_begin = time.time()

# One optimisation step per batch yielded by the project's batch generator.
# The loop has no stopping criterion of its own here; it runs for as long as
# the generator keeps yielding (in practice until interrupted manually).
for batch in common.batch_generator(buffer, params.replay_initial, params.batch_size):
    iteration += 1

    # Report every episode the experience source has finished since the last
    # check. The timer restarts after each report, so when several episodes
    # are returned in the same iteration only the first one shows a meaningful
    # elapsed time (presumably what happens right after the initial buffer
    # fill -- verify against `common.batch_generator`).
    for reward, steps in exp_source.pop_rewards_steps():
        episode += 1
        elapsed = time.time() - time_begin
        print(f'Episode: {episode} | Reward: {reward} | Epsilon: {selector.epsilon} | Elapsed: {elapsed}')
        time_begin = time.time()

    opt.zero_grad()
    states, actions, rewards, last_states, dones = unpack_batch(batch)

    # Q-values of the actions that were actually taken: pick, along dim 1,
    # the entry indexed by each transition's action.
    actual_qs = net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1)

    # Target: reward + gamma * max over actions of the target network's output
    # for the next state, with that term zeroed for terminal transitions.
    # Everything inside no_grad is already excluded from autograd, so no
    # explicit detach is needed.
    with torch.no_grad():
        target_qs = tgt_net.model(last_states).max(dim=1)[0]
        target_qs[dones] = 0.0
        target = rewards + params.gamma * target_qs

    # mse_loss takes (input, target): the prediction comes first.
    loss = F.mse_loss(actual_qs, target)
    loss.backward()
    opt.step()

    epsilon_tracker.frame(iteration)

    # Periodically sync the target network via `tgt_net.sync()`.
    if iteration % params.target_net_sync == 0:
        print("Sync")
        tgt_net.sync()

Done populating
Episode: 1 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 28.728986263275146
Episode: 2 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 1.430511474609375e-06
Episode: 3 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 4 | Reward: -20.0 | Epsilon: 1.0 | Elapsed: 1.1920928955078125e-06
Episode: 5 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 9.5367431640625e-07
Episode: 6 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 1.1920928955078125e-06
Episode: 7 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 1.1920928955078125e-06
Episode: 8 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 7.152557373046875e-07
Episode: 9 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 7.152557373046875e-07
Episode: 10 | Reward: -19.0 | Epsilon: 1.0 | Elapsed: 7.152557373046875e-07
Episode: 11 | Reward: -21.0 | Epsilon: 1.0 | Elapsed: 7.152557373046875e-07
Episode: 12 | Reward: -20.0 | Epsilon: 0.99202 | Elapsed: 12.264180898666382
Sync
Episode: 13 | Reward: -21.0 | Epsilon: 0.98379 | Elapsed: 13.927971124649048
Sy

Sync
Episode: 98 | Reward: -20.0 | Epsilon: 0.10297 | Elapsed: 20.961503744125366
Sync
Episode: 99 | Reward: -19.0 | Epsilon: 0.09058999999999995 | Elapsed: 19.509050130844116
Sync
Sync
Episode: 100 | Reward: -19.0 | Epsilon: 0.07394 | Elapsed: 26.57668900489807
Sync
Sync
Episode: 101 | Reward: -18.0 | Epsilon: 0.056540000000000035 | Elapsed: 27.741881132125854
Sync
Episode: 102 | Reward: -20.0 | Epsilon: 0.04274 | Elapsed: 21.665473222732544
Sync
Sync
Episode: 103 | Reward: -21.0 | Epsilon: 0.029610000000000025 | Elapsed: 19.849308490753174
Sync
Episode: 104 | Reward: -19.0 | Epsilon: 0.02 | Elapsed: 21.092462301254272
Sync
Episode: 105 | Reward: -19.0 | Epsilon: 0.02 | Elapsed: 23.86977744102478
Sync
Sync
Episode: 106 | Reward: -21.0 | Epsilon: 0.02 | Elapsed: 19.050861835479736
Sync
Episode: 107 | Reward: -17.0 | Epsilon: 0.02 | Elapsed: 23.99392318725586
Sync
Episode: 108 | Reward: -19.0 | Epsilon: 0.02 | Elapsed: 18.472620964050293
Sync
Sync
Episode: 109 | Reward: -17.0 | Epsilon:

Sync
Sync
Sync
Episode: 192 | Reward: -8.0 | Epsilon: 0.02 | Elapsed: 65.51734590530396
Sync
Sync
Sync
Sync
Sync
Episode: 193 | Reward: -3.0 | Epsilon: 0.02 | Elapsed: 78.6270318031311
Sync
Sync
Sync
Sync
Episode: 194 | Reward: -8.0 | Epsilon: 0.02 | Elapsed: 65.26312923431396
Sync
Sync
Sync
Episode: 195 | Reward: -8.0 | Epsilon: 0.02 | Elapsed: 57.98408532142639
Sync
Sync
Sync
Sync
Episode: 196 | Reward: -5.0 | Epsilon: 0.02 | Elapsed: 67.9667272567749
Sync
Sync
Sync
Sync
Episode: 197 | Reward: 3.0 | Epsilon: 0.02 | Elapsed: 69.95203995704651
Sync
Sync
Sync


KeyboardInterrupt: 