# In [None]:
import copy
from typing import Any, NamedTuple, Optional, Sequence

import gym
import numpy as np
import pandas as pd
import plotnine as gg
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical, Normal
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler

# Seed PyTorch's global RNG so torch-side randomness is repeatable.
# NOTE(review): NumPy's global RNG (used by Replay.sample and by the
# epsilon-greedy draw in DQNAgent.select_action) is not seeded here.
torch.manual_seed(1)


class Transition(NamedTuple):
  """A single environment step, as built in `main` and passed to the agent.

  The fields are annotated as tensors, but nothing in this file enforces
  that; `main` fills them directly from the environment's return values.
  """
  # Observation the action was chosen from.
  states: torch.Tensor
  # Action selected by the agent for `states`.
  actions: torch.Tensor
  # Log-probability associated with the action, as supplied by the caller
  # (presumably only meaningful for policy-gradient agents -- confirm).
  log_probs: torch.Tensor
  # Reward returned by the environment for this step.
  rewards: torch.Tensor
  # Observation returned by the environment after the step.
  states_: torch.Tensor



class Replay:
  """Fixed-capacity ring buffer that stores each item field in its own array.

  Storage is allocated lazily on the first `add`: every element of the first
  item sequence is converted with `np.asarray`, and one zero-filled array of
  shape `(capacity,) + element.shape` is created per element. Once `capacity`
  items have been added, new items overwrite the oldest slots.
  """

  data: Optional[Sequence[np.ndarray]]
  capacity: int
  num_added: int

  def __init__(self, capacity: int):
    """Create an empty buffer.

    Args:
      capacity: Maximum number of items retained; must be positive.

    Raises:
      ValueError: If `capacity` is not positive.
    """
    if capacity <= 0:
      # A zero capacity would otherwise only surface later as a
      # ZeroDivisionError inside `add` / `fraction_filled`.
      raise ValueError(f'capacity must be positive, got {capacity}.')
    self.data = None
    self.capacity = capacity
    self.num_added = 0

  def add(self, items: Sequence[Any]):
    """Store one item, overwriting the oldest entry when the buffer is full.

    Args:
      items: One value per field, in the same order (and with the same
        shapes) as the first item ever added.
    """
    if self.data is None:
      self.preallocate(items)

    write_index = self.num_added % self.capacity
    for slot, item in zip(self.data, items):
      slot[write_index] = item

    self.num_added += 1

  def sample(self, size: int) -> Sequence[np.ndarray]:
    """Draw `size` stored items uniformly at random, with replacement.

    Args:
      size: Number of items to draw.

    Returns:
      A list with one array per field, each with leading dimension `size`.

    Raises:
      ValueError: If nothing has been added since construction or `reset`.
    """
    if self.data is None or self.size == 0:
      raise ValueError('Cannot sample from an empty replay buffer.')
    indices = np.random.randint(self.size, size=size)
    return [slot[indices] for slot in self.data]

  def reset(self,):
    """Discard all stored items and release the backing arrays."""
    self.data = None
    # The counter must be cleared together with the storage; otherwise `size`
    # would still report the old fill level and `sample` could index slots
    # of the freshly zero-initialised arrays that were never written.
    self.num_added = 0

  @property
  def size(self) -> int:
    """Number of items currently available for sampling."""
    return min(self.capacity, self.num_added)

  @property
  def fraction_filled(self) -> float:
    """Fill level of the buffer as a value in [0, 1]."""
    return self.size / self.capacity

  def preallocate(self, items: Sequence[Any]):
    """Allocate one zero-filled storage array per field of `items`.

    Args:
      items: A representative item; its per-field shapes and dtypes (after
        `np.asarray`) define the layout of the storage arrays.
    """
    as_array = [np.asarray(item) for item in items]
    self.data = [np.zeros(dtype=x.dtype, shape=(self.capacity,) + x.shape)
                 for x in as_array]



class QNetwork(nn.Module):
  """One-hidden-layer MLP that maps observations to one value per action."""

  def __init__(self, obs_shape, hidden_size, num_actions):
    """Build the network.

    Args:
      obs_shape: Shape of a single observation; it is flattened, so only the
        product of its dimensions matters.
      hidden_size: Width of the hidden layer.
      num_actions: Number of outputs (one per discrete action).
    """
    # nn.Module must be initialised before any sub-module is assigned as an
    # attribute; without this call the assignments below raise AttributeError.
    super().__init__()
    # np.prod returns a NumPy scalar; convert to a plain int for nn.Linear.
    self.linear = nn.Linear(int(np.prod(obs_shape)), hidden_size)
    self.q_head = nn.Linear(hidden_size, int(num_actions))
    self.num_actions = int(num_actions)

  def forward(self, x):
    """Return the network output for a batch of observations.

    Args:
      x: Tensor whose first dimension is the batch; all remaining dimensions
        are flattened before the first linear layer.

    Returns:
      Tensor of shape `(batch, num_actions)`.
    """
    x = torch.flatten(x, start_dim=1)
    x = F.relu(self.linear(x))
    return self.q_head(x)


class DQNAgent:
    """Epsilon-greedy DQN agent with a replay buffer and a target network.

    Every call to `update` stores one transition and, once the replay buffer
    holds at least `batch_size` transitions, performs one gradient step on a
    one-step TD loss whose bootstrap value comes from the target network.

    NOTE(review): transitions carry no terminal flag, so the target always
    bootstraps from the next state, including at episode ends.
    """

    # Not read anywhere in this class (looks like a policy-gradient leftover).
    clip_param = 0.2
    # Maximum global gradient norm applied before each optimizer step.
    max_grad_norm = 0.5
    # Number of transitions retained by the replay buffer.
    buffer_capacity = 1_000
    # Number of transitions sampled per gradient step.
    batch_size = 32
    # Discount factor applied to the bootstrapped next-state value.
    discount = 0.99
    # Number of gradient steps between target-network synchronisations.
    target_network_update_interval = 100
    # Probability of selecting a uniformly random action.
    epsilon = 0.01

    def __init__(self, q_network, action_bound=None):
        """Create the agent.

        Args:
            q_network: Module mapping a batch of observations to one value per
                action; must expose a `num_actions` attribute.
            action_bound: Unused; accepted for call-site compatibility.
        """
        self.training_step = 0

        self.q_network = q_network
        self.target_network = copy.deepcopy(q_network)
        # Sized from the class attribute so the two cannot drift apart.
        self.replay = Replay(self.buffer_capacity)
        self.counter = 0

        self.optimizer = optim.Adam(self.q_network.parameters(), lr=1e-3)

    def select_action(self, state):
        """Choose an action epsilon-greedily.

        Args:
            state: A single (unbatched) observation.

        Returns:
            A one-element integer tensor holding the chosen action index.
        """
        # Epsilon greedy.
        if np.random.rand() < self.epsilon:
            # torch.randint requires an explicit size; (1,) matches the shape
            # produced by the greedy branch below.
            return torch.randint(int(self.q_network.num_actions), (1,))
        state = torch.as_tensor(state, dtype=torch.float).unsqueeze(0)
        with torch.no_grad():
            q_values = self.q_network(state)

        return torch.argmax(q_values, dim=-1)

    def update(self, transition):
        """Store `transition` and, when enough data is available, train once.

        Args:
            transition: A 5-field sequence ordered as (state, action,
                log_prob, reward, next_state). The log_prob field is stored
                but not used by the loss.
        """
        self.replay.add(transition)
        self.counter += 1
        if self.replay.size < self.batch_size:
            return  # Not enough stored transitions for a batch yet.

        self.training_step += 1

        # The replay buffer returns one array per transition field.
        states, actions, _, rewards, next_states = self.replay.sample(
            self.batch_size)

        s = torch.as_tensor(states, dtype=torch.float)
        # gather() needs integer indices.
        a = torch.as_tensor(actions, dtype=torch.long).view(-1, 1)
        r = torch.as_tensor(rewards, dtype=torch.float).view(-1, 1)
        s_ = torch.as_tensor(next_states, dtype=torch.float)

        if self.training_step % self.target_network_update_interval == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        with torch.no_grad():
            # Bootstrap from the target network's best next-state value.
            next_q = self.target_network(s_).max(dim=-1, keepdim=True).values
            target = r + self.discount * next_q

        # Value predicted by the online network for the action actually taken.
        q_taken = self.q_network(s).gather(1, a)
        loss = F.smooth_l1_loss(q_taken, target)

        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.q_network.parameters(),
                                 self.max_grad_norm)
        self.optimizer.step()


def main(num_episodes: int):
    """Train a DQN agent on CartPole-v1 and collect per-episode returns.

    Args:
        num_episodes: Number of episodes to run.

    Returns:
        A list with one dict per episode, holding the keys 'episode' and
        'return' (the undiscounted sum of rewards over the episode).
    """
    env = gym.make('CartPole-v1', new_step_api=True)

    obs_shape = env.observation_space.shape

    q_network = QNetwork(obs_shape, 100,
                         num_actions=int(env.action_space.n))
    agent = DQNAgent(q_network)

    results = []
    for episode in range(num_episodes):
        episode_return = 0.0
        state = env.reset()

        while True:
            # int() accepts a one-element tensor as well as a plain integer.
            action = int(agent.select_action(state))
            state_, reward, terminated, truncated, _ = env.step(action)

            # The agent produces no log-probability; 0.0 only fills the
            # corresponding Transition field.
            # NOTE(review): Transition has no terminal flag, so the agent
            # cannot tell terminal steps apart from ordinary ones.
            transition = Transition(state, action, 0.0, reward, state_)
            agent.update(transition)
            episode_return += reward
            state = state_

            # Check for the episode end only after the final step has been
            # recorded, so its reward counts towards the return.
            if terminated or truncated:
                break

        result = {
            'episode': episode,
            'return': episode_return,
        }
        results.append(result)

        if episode % 5 == 0:
            print(result)

    return results

# Runs 1000 episodes as soon as this cell/module is executed; `results` keeps
# the list of per-episode dicts returned by `main`.
results = main(1000)