<a href="https://colab.research.google.com/github/d3sm0/torch_dqn/blob/master/torch_dqn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
!pip install -q torch gym

In [None]:
from collections import namedtuple,deque
import random

import torch
import torch.nn.functional
import torch.optim as optim
import numpy as np
import gym

In [0]:
class QNetwork(torch.nn.Module):
    """Small fully connected network with two 64-unit ReLU hidden layers."""

    def __init__(self, obs_shape, act_shape):
        """Build the three linear layers.

        obs_shape: input feature count of the first linear layer.
        act_shape: output feature count of the last linear layer.
        """
        super().__init__()
        self.fc_0 = torch.nn.Linear(obs_shape, 64)
        self.fc_1 = torch.nn.Linear(64, 64)
        self.out = torch.nn.Linear(64, act_shape)

    def forward(self, x):
        """Flatten all axes after the first and return the output layer's values."""
        relu = torch.nn.functional.relu
        flat = x.view((x.size(0), -1))
        hidden = relu(self.fc_0(flat))
        hidden = relu(self.fc_1(hidden))
        return self.out(hidden)

In [0]:
# One replay-buffer record. The field order is the positional order used
# when a record is built with Transition(*args).
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'done'],
)

def one_hot(a, num_classes):
    """Return one-hot rows for the integer class indices in ``a``.

    Parameters
    ----------
    a: array-like of int
        Class indices; any shape, flattened before encoding.
    num_classes: int
        Width of each one-hot row.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(a.size, num_classes)``. A 0-d ``a`` yields a
        single row of shape ``(num_classes,)``.
    """
    indices = np.asarray(a)
    rows = np.eye(num_classes)[indices.reshape(-1)]
    # Only a scalar index collapses to one row. A blanket squeeze would also
    # drop the batch axis of a one-element batch (and the class axis when
    # num_classes == 1), changing the rank the caller sees.
    return rows[0] if indices.ndim == 0 else rows


class Memory:
    """Fixed-capacity ring buffer of transitions with uniform random sampling."""

    def __init__(self, capacity):
        # `position` is the slot the next push writes to.
        self.position = 0
        self.memory = []
        self.capacity = capacity

    def push(self, *args):
        """Store Transition(*args), overwriting the slot at `position` once full."""
        # Grow by one placeholder slot until the buffer holds `capacity` items.
        if len(self.memory) < self.capacity:
            self.memory.append(None)
        slot = self.position
        self.memory[slot] = Transition(*args)
        self.position = (slot + 1) % self.capacity

    def sample(self, batch_size):
        """Return `batch_size` distinct stored items chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of items currently stored."""
        return len(self.memory)


In [0]:
class LinearSchedule(object):
    """Scalar that is linearly interpolated from ``initial_p`` to ``final_p``."""

    def __init__(self, schedule_timesteps, final_p, initial_p=1.0):
        """Linear interpolation between initial_p and final_p over
        schedule_timesteps. After this many timesteps pass final_p is
        returned.

        Parameters
        ----------
        schedule_timesteps: int
            Number of timesteps for which to linearly anneal initial_p
            to final_p
        initial_p: float
            initial output value
        final_p: float
            final output value
        """
        self.schedule_timesteps = schedule_timesteps
        self.final_p = final_p
        self.initial_p = initial_p
        self.p = self.initial_p

    @property
    def value(self):
        """Current value, as left by the most recent ``update`` or ``reset``."""
        return self.p

    def reset(self):
        """Set the current value back to ``initial_p``."""
        self.p = self.initial_p

    def update(self, t):
        """Recompute the current value for timestep ``t``.

        The value is ``initial_p + fraction * (final_p - initial_p)`` with
        ``fraction = min(t / schedule_timesteps, 1)``. Once the value has
        reached ``final_p`` it is left unchanged by later calls.
        """
        # Test "reached the end" in the direction of travel so that both
        # decreasing and increasing schedules advance.
        if self.final_p < self.initial_p:
            finished = self.p <= self.final_p
        else:
            finished = self.p >= self.final_p
        if finished:
            return
        if self.schedule_timesteps <= 0:
            # Nothing to anneal over: jump to the end instead of dividing by zero.
            fraction = 1.0
        else:
            fraction = min(float(t) / self.schedule_timesteps, 1.0)
        self.p = self.initial_p + fraction * (self.final_p - self.initial_p)


In [None]:
# Training hyper-parameters.
lr = 1e-3
max_steps = int(1e5)       # number of iterations of the training loop
env = 'CartPole-v1'        # environment id; `env` is rebound to the gym env object later
gamma = .99
train_every = 4            # a gradient step is taken when t % train_every == 0
update_every = int(1e4)    # q_target is re-synced when t % update_every == 0
batch_size = 64            # transitions sampled from the replay memory per step

In [22]:
# Environment plus the sizes the networks are built from.
env = gym.make('CartPole-v1')
obs_shape = env.observation_space.shape[0]
act_shape = env.action_space.n

# Two networks with identical architecture; only `q` is handed to the optimizer.
q = QNetwork(obs_shape, act_shape)
q_target = QNetwork(obs_shape, act_shape)
opt = optim.Adam(q.parameters(), lr=1e-3)

# Replay memory and the linear schedule read by get_action.
memory = Memory(capacity=int(1e4))
scheduler = LinearSchedule(schedule_timesteps=int(1e5 * 0.1), final_p=0.01)

# Rolling windows holding at most the 40 most recent entries each.
avg_rw, avg_len = deque(maxlen=40), deque(maxlen=40)

  result = entry_point.load(False)


In [0]:
def get_action(s, t):
    """Choose an epsilon-greedy action for observation ``s`` at timestep ``t``.

    With probability ``scheduler.value`` a uniformly random action in
    ``[0, act_shape)`` is returned; otherwise the argmax of ``q`` on ``s``.
    The schedule is then advanced to ``t`` on every call, so the exploration
    rate tracks the timestep whichever branch was taken.

    Parameters
    ----------
    s: numpy.ndarray
        A single observation; a leading batch axis is added before the
        forward pass.
    t: int
        Current timestep, forwarded to ``scheduler.update``.

    Returns
    -------
    int
        Index of the selected action.
    """
    if np.random.sample() > scheduler.value:
        # Greedy branch: the forward pass is only needed here, and no
        # gradient is required for action selection.
        with torch.no_grad():
            obs = torch.as_tensor(s, dtype=torch.float32)[None, :]
            best_action = torch.argmax(q(obs), dim=-1).item()
    else:
        best_action = np.random.randint(0, act_shape)
    scheduler.update(t)
    return best_action

In [0]:
 def train(batch, gamma=.99):
     """Take one gradient step on ``q`` from a batch of transitions.

     The loss is the mean of ``0.5 * (target - q(s)[a]) ** 2`` where
     ``target = r + gamma * (1 - done) * max_a' q_target(s')``.

     Parameters
     ----------
     batch: sequence of Transition
         Sampled transitions; transposed into per-field tuples below.
     gamma: float
         Discount factor applied to the bootstrapped next-state value.

     Returns
     -------
     float
         The scalar loss of this step.
     """
     batch = Transition(*zip(*batch))
     # Stack through NumPy first: building a tensor straight from a tuple of
     # ndarrays goes through a slow element-wise path.
     s = torch.as_tensor(np.array(batch.state), dtype=torch.float32)
     a = torch.as_tensor(np.array(batch.action), dtype=torch.int64)
     r = torch.as_tensor(np.array(batch.reward), dtype=torch.float32)
     d = torch.as_tensor(np.array(batch.done), dtype=torch.float32)
     s1 = torch.as_tensor(np.array(batch.next_state), dtype=torch.float32)

     # Q-value of the action actually taken in each transition.
     value = q(s).gather(1, a.view(-1, 1)).squeeze(1)
     # The target is a constant for this step: no graph is built through
     # q_target, so no gradients accumulate on its parameters.
     with torch.no_grad():
         next_value = r + gamma * (1. - d) * torch.max(q_target(s1), dim=-1)[0]
     loss = (.5 * (next_value - value) ** 2).mean()
     opt.zero_grad()
     loss.backward()
     opt.step()
     return loss.item()

In [32]:
state = env.reset()

# Start with both networks holding identical weights.
q_target.load_state_dict(q.state_dict())

ep_rw = 0    # return accumulated in the current episode
ep_len = 0   # steps taken in the current episode
ep = 0       # number of finished episodes
for t in range(int(max_steps)):
    action = get_action(state, t)
    next_state, reward, done, info = env.step(action)
    # An episode cut short by the time limit has not reached a terminal
    # state, so the stored flag must not zero out the bootstrap term for it.
    # When the info key is absent this falls back to `done`.
    terminal = bool(done) and not info.get('TimeLimit.truncated', False)
    memory.push(state, action, next_state, reward, terminal)
    ep_rw += reward
    ep_len += 1

    if done:
        # Record the episode statistics and start a new episode.
        ep += 1
        avg_rw.append(ep_rw)
        avg_len.append(ep_len)
        ep_rw = 0
        ep_len = 0
        state = env.reset()
    else:
        state = next_state.copy()

    # Learn only once the memory holds more than one batch of transitions.
    if t % train_every == 0 and len(memory) > batch_size:
        batch = memory.sample(batch_size=batch_size)
        train(batch)

    # Periodically sync the target network and report progress.
    if t % update_every == 0:
        q_target.load_state_dict(q.state_dict())
        print(f't:{t}\tep:{ep}\tavg_rw:{np.mean(avg_rw)}\tavg_len:{np.mean(avg_len)}\teps:{scheduler.value}')


t:0	ep:0	avg_rw:12.666666666666666	avg_len:12.666666666666666	eps:1.0


KeyboardInterrupt: ignored