In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
import matplotlib
import matplotlib.pyplot as plt
from collections import deque, namedtuple
import random
from deep_q.commons import DeepQConfig

# Treat the session as a notebook when matplotlib is using an inline backend.
is_ipython = "inline" in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Turn on matplotlib interactive mode.
plt.ion()

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

In [ ]:
# Single-step transition record (not used by the classes in this section).
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward"])
# One recorded run: the states, actions and rewards stored together by ReplayMemory.push.
Run = namedtuple("Run", ["states", "actions", "rewards"])


class ReplayMemory:
    """Bounded FIFO store of recorded runs used for experience replay.

    Once ``capacity`` entries are held, pushing another one silently evicts
    the oldest entry.
    """

    def __init__(self, capacity):
        # A deque with maxlen drops from the left automatically on overflow.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one run; positional args are forwarded to ``Run(states, actions, rewards)``."""
        entry = Run(*args)
        self.memory.append(entry)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored entries drawn uniformly without replacement."""
        population = list(self.memory)
        return random.sample(population, batch_size)

    def __len__(self):
        """Number of entries currently stored."""
        return len(self.memory)

In [ ]:
class DRQN(nn.Module):
    """Recurrent Q-network: an LSTM encoder followed by a linear Q-value head.

    Args:
        input_size: number of features per time step.
        output_size: number of Q-values produced per time step.
        hidden_size: width of the LSTM hidden state.
        hidden_layers: number of stacked LSTM layers.
    """

    def __init__(self, input_size, output_size, hidden_size=128, hidden_layers=1):
        super().__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, hidden_layers, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, x, i=None):
        """Return ``(q_values, hidden)`` for the input sequence ``x``.

        ``x`` is ``(batch, seq, input_size)`` (the LSTM uses ``batch_first=True``)
        or an unbatched ``(seq, input_size)``. ``i`` is an optional ``(h, c)``
        LSTM state to resume from. Q-values are emitted for every time step.
        """
        # nn.LSTM treats hx=None as a zero initial state, so one call covers both cases.
        features, hidden = self.lstm(x, i)
        return self.out(features), hidden


In [ ]:
class DRQN_Agent:
    """Deep recurrent Q-learning agent with a target network and run-level replay.

    Args:
        net: network class (e.g. ``DRQN``) built as
            ``net(config.n_inputs, config.n_outputs, **config.net_kwargs)``.
        config: hyper-parameters; this class reads ``n_inputs``, ``n_outputs``,
            ``net_kwargs``, ``device``, ``lr``, ``rm_size``, ``tau``, ``gamma``
            and ``batch_size``.
        path: optional file written by :meth:`save`; when given, the online
            network is loaded from it instead of being freshly constructed.
    """

    def __init__(self, net: nn.Module, config: DeepQConfig, path=None):
        self.config = config
        device = self.config.device

        self.target_net = net(self.config.n_inputs, self.config.n_outputs, **self.config.net_kwargs).to(device)

        if path:
            # save() pickles the whole module, which torch.load refuses with
            # weights_only=True (the default since PyTorch 2.6).
            # SECURITY: unpickling can execute arbitrary code -- only load trusted files.
            # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
            self.q_net = torch.load(path, map_location=device, weights_only=False).to(device)
        else:
            # Put the online net on the same device as target_net (it used to stay on CPU).
            self.q_net = net(self.config.n_inputs, self.config.n_outputs, **self.config.net_kwargs).to(device)

        self.target_net.load_state_dict(self.q_net.state_dict())
        self.target_net.eval()

        self.optimizer = optim.Adam(self.q_net.parameters(), lr=self.config.lr)
        self.memory = ReplayMemory(self.config.rm_size)

        self.criterion = nn.MSELoss()

    def update_target_net(self):
        """Hard update: copy the online network's weights into the target network."""
        self.target_net.load_state_dict(self.q_net.state_dict())

    def soft_update_target_net(self):
        """Polyak update: ``target <- tau * online + (1 - tau) * target`` for every state entry."""
        target_net_state_dict = self.target_net.state_dict()
        policy_net_state_dict = self.q_net.state_dict()
        tau = self.config.tau
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * tau + target_net_state_dict[key] * (1 - tau)
        self.target_net.load_state_dict(target_net_state_dict)

    def add_to_memory(self, *args):
        """Store one run in replay memory; args are forwarded to ``Run(states, actions, rewards)``."""
        self.memory.push(*args)

    def select_action(self, state, i=None, epsilon=0):
        """Pick an action epsilon-greedily and advance the recurrent state.

        Args:
            state: observation(s); reshaped to ``(-1, n_inputs)`` and fed as a
                single unbatched sequence (for an LSTM-based net such as DRQN).
            i: recurrent state returned by the previous call, or ``None``.
            epsilon: probability of returning a uniformly random action.

        Returns:
            ``(action, hidden)`` where ``action`` is a long tensor of shape ``(1,)``.
        """
        # Run the network even when exploring so the hidden state still advances.
        with torch.no_grad():
            x, i = self.q_net(state.view(-1, self.config.n_inputs), i)

        if random.random() < epsilon:
            # Explore: take a random action
            return torch.tensor([random.randrange(self.config.n_outputs)], device=self.config.device, dtype=torch.long), i
        # Exploit: greedy action from the Q-values of the last time step
        return x.max(1)[1][-1].view(1), i

    def Q(self, state, action):
        """Return the online net's Q-values for the taken actions.

        ``state`` and ``action`` are iterables of per-run tensors; each is padded
        with ``pad_sequence(batch_first=True)``. The result has shape
        ``(batch, action_seq, 1)``.
        """
        # Materialise first: older pad_sequence versions index sequences[0].
        state = pad_sequence(list(state), batch_first=True)
        action = pad_sequence(list(action), batch_first=True)

        x, _ = self.q_net(state)
        # Gather one Q-value per time step along the action axis (dim 2).
        return x.gather(2, action.view(action.size(0), -1, 1))

    def target(self, state, reward):
        """Return ``reward + gamma * max_a Q_target(next_state, a)`` as a ``(batch, seq)`` tensor.

        ``state`` holds the per-run next-state tensors and ``reward`` the per-run
        rewards; both are padded with ``pad_sequence(batch_first=True)``.
        """
        state = pad_sequence(list(state), batch_first=True)
        reward = pad_sequence(list(reward), batch_first=True)

        # Targets are constants for the loss; don't build a graph through target_net.
        with torch.no_grad():
            x, _ = self.target_net(state)
        Q_target = x.max(2)[0]

        # Flatten to (batch, seq): rewards stored per run as (L, 1) would otherwise
        # broadcast against the (batch, seq) Q_target into (batch, seq, seq).
        reward = reward.view(reward.size(0), -1)

        # NOTE(review): no terminal mask -- the last step of each run still bootstraps.
        return (Q_target * self.config.gamma) + reward

    def optimize(self):
        """Run one gradient step on a random batch of stored runs.

        Returns:
            The scalar MSE loss, or ``None`` when memory holds fewer than
            ``config.batch_size`` runs.
        """
        batch_size = self.config.batch_size
        if len(self.memory) < batch_size:
            return None

        batch = Run(*zip(*self.memory.sample(batch_size)))

        # NOTE(review): states are split with a fixed offset of 3 along the first axis,
        # so next_state is 3 entries ahead of state. That is one step only if states are
        # flattened with n_inputs == 3; for (T, n_inputs) tensors a one-step TD target
        # would use 1 here. Confirm against how runs are pushed.
        next_state_batch = [s[3:] for s in batch.states]
        state_batch = [s[:-3] for s in batch.states]
        action_batch = batch.actions
        reward_batch = batch.rewards

        self.optimizer.zero_grad()

        # state_batch: N runs of (L x states); action_batch: N runs of (L x 1)
        y = self.Q(state_batch, action_batch).view(batch_size, -1)

        # next_state_batch: N runs of (L x states); reward_batch: N runs of (L x 1)
        yl = self.target(next_state_batch, reward_batch).view(batch_size, -1)

        # NOTE(review): zero-padded positions of shorter runs are included in this loss.
        loss = self.criterion(y, yl)

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def save(self, name):
        """Pickle the whole online network (not just its state_dict) to ``name``."""
        torch.save(self.q_net, name)
