In [0]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

#
import numpy as np
import gym
from collections import deque

In [0]:
class Memory:
    """Fixed-capacity circular replay buffer backed by preallocated numpy arrays.

    Transitions are written with ``member`` and drawn uniformly with replacement
    by ``sample``, which returns torch tensors moved to ``device``. Once
    ``memory_size`` transitions have been stored, new ones overwrite the oldest.

    Args:
        memory_size: maximum number of stored transitions.
        state_shape: shape of one state -- an int, or a list/tuple of ints.
        device: torch device the sampled tensors are moved to.
        dtype: numpy dtype used for every stored array (states, actions,
            rewards, next states and done flags).
    """

    def __init__(self, memory_size, state_shape, device="cpu", dtype=np.float32):
        # Normalise to a list of ints. The original only wrapped non-lists, so a
        # tuple shape (e.g. an observation's .shape) produced an invalid np.zeros shape.
        if np.ndim(state_shape) == 0:
            state_shape = [state_shape]
        state_shape = [int(dim) for dim in state_shape]
        self.states = np.zeros((memory_size, *state_shape), dtype=dtype)
        self.actions = np.zeros((memory_size, 1), dtype=dtype)
        self.rewards = np.zeros((memory_size), dtype=dtype)
        self.next_states = np.zeros((memory_size, *state_shape), dtype=dtype)
        self.dones = np.zeros((memory_size), dtype=dtype)

        self.dev = device
        self.indices = []  # slots that currently hold a written transition
        self.ptr = 0       # slot the next transition will be written to
        self.max_len = memory_size

    def member(self, state, action, reward, next_state, done):
        """Store one transition in the slot at ``self.ptr`` and advance the pointer."""
        slot = self.ptr
        self.states[slot] = self._state_to_correct_shape(state)
        self.actions[slot] = action
        self.rewards[slot] = reward
        self.next_states[slot] = self._state_to_correct_shape(next_state)
        self.dones[slot] = done

        # Record the slot that was just written. (Previously the *advanced* pointer
        # was recorded: it referenced a still-empty row, slot 0 was never sampleable,
        # and the buffer never grew past max_len - 1 entries.)
        if len(self.indices) < self.max_len:
            self.indices.append(slot)
        self.ptr = (slot + 1) % self.max_len

    @staticmethod
    def _state_to_correct_shape(state):
        """Convert ``state`` to a numpy array so it can be assigned into a buffer row."""
        state = np.array(state)
        return state

    def sample(self, batch_size = 16):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of torch
            tensors on ``self.dev``; every tensor has ``batch_size`` rows.

        Raises:
            ValueError: if nothing has been stored yet.
        """
        if not self.indices:
            raise ValueError("cannot sample from an empty Memory")
        idx = np.random.choice(self.indices, batch_size)

        def to_tensor(array):
            return torch.from_numpy(array[idx]).to(self.dev)

        return (to_tensor(self.states),
                to_tensor(self.actions),
                to_tensor(self.rewards),
                to_tensor(self.next_states),
                to_tensor(self.dones))

    def __len__(self):
        """Number of transitions currently stored (at most ``max_len``)."""
        return len(self.indices)

class MeanBuffer:
    """Running mean over the most recent ``size`` values, plus the all-time maximum.

    The sum is maintained incrementally, so ``append`` and ``mean`` are O(1).
    """
    def __init__(self, size):
        self.buffer = deque(maxlen=size)
        self.sum = 0
        self.size = size
        self.global_max = None  # None until the first value is appended

    def append(self, x):
        """Add ``x``; once the window is full the oldest value is dropped from the sum."""
        # The deque evicts buffer[0] on its own once full, so subtract it from the
        # sum first. The emptiness check avoids an IndexError when size == 0.
        if self.buffer and len(self.buffer) == self.size:
            self.sum -= self.buffer[0]
        # Compare against None explicitly: a maximum of 0 is falsy, and the old
        # truthiness test replaced it with any later, smaller value.
        self.global_max = x if self.global_max is None else max(self.global_max, x)
        self.sum += x
        self.buffer.append(x)

    def mean(self):
        """Mean of the values in the window, or 0.0 if nothing has been appended."""
        if not self.buffer:
            return 0.0
        return self.sum / len(self.buffer)

    def max(self):
        """Largest value ever appended (not limited to the window), or None."""
        return self.global_max

In [0]:
class Flatten(nn.Module):
  """Flatten every dimension after the first (batch) one: (N, *rest) -> (N, prod(rest))."""
  def __init__(self):
    super().__init__()
  def forward(self,x):
    # reshape (not view): view raises on non-contiguous inputs such as the output
    # of transpose/permute, while reshape still returns a view whenever it can.
    return x.reshape(x.size(0), -1)
  
class QNet(nn.Module):
  """Two-layer MLP mapping a flat state vector to one Q-value per action.

  Args:
    n_action: number of discrete actions (width of the output layer).
    in_channels: not used by this MLP; kept so existing callers that pass it still work.
    in_features: length of the input state vector (default 4, the value this
      class previously hard-coded).
    hidden: width of the hidden layer (default 128, as before).

  Note: the ReLU inserted between the layers shifts the output layer's
  state_dict keys from ``net.1.*`` to ``net.2.*``.
  """
  def __init__(self, n_action, in_channels=3, in_features=4, hidden=128):
    super().__init__()
    # The ReLU is essential: two stacked Linear layers without a non-linearity
    # compose to a single affine map and cannot fit a non-linear Q-function.
    self.net = nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.ReLU(),
        nn.Linear(hidden, n_action),
    )

  def forward(self, x):
    """Return Q-values; x's last dimension must be in_features, output's is n_action."""
    return self.net(x)

In [0]:
class Agent:
  """DQN agent: epsilon-greedy acting, experience replay and a periodically synced target net.

  Args:
    env: gym-style environment exposing reset(), step(action) -> 4-tuple
      (next_state, reward, done, info) and a discrete action_space with .n / .sample().
    device: torch device for both networks and for sampled replay batches.
  """
  def __init__(self,env,device="cuda"):
    # np.roll moves the last entry of the observation shape to the front (e.g.
    # (H, W, C) -> (C, H, W)); for a 1-D observation it is a no-op.
    # NOTE(review): states are stored without being transposed, so only 1-D
    # observations are known to fit this buffer -- confirm before using image envs.
    self.memory = Memory(10_000, list(np.roll(env.reset().shape,1)), device = device)
    self.target_net = QNet(env.action_space.n).to(device)
    self.policy_net = QNet(env.action_space.n).to(device)
    self.target_net.load_state_dict(self.policy_net.state_dict())
    self.env = env
    self.bs = 64                          # replay minibatch size
    self.reward_tracker = MeanBuffer(20)  # mean over the last 20 episode returns
    self.lr = 5e-2
    self.gamma = 0.95                     # discount factor

    self.epsilon = 1.
    self.epsilon_min = 0.05
    self.epsilon_decay = 0.999            # multiplicative decay, applied once per episode

    self.optim = optim.Adam(self.policy_net.parameters(), lr=self.lr)
    self.start_frame = 1_000              # frames collected before replay training starts
    self.update_frame = 1_000             # target-sync / checkpoint interval, in frames
    self.dev = device

  def act(self,state):
    """Choose an action: random with probability epsilon, otherwise argmax of policy Q-values."""
    if np.random.rand() < self.epsilon:
      return self.env.action_space.sample()
    if not isinstance(state,torch.Tensor):
      state = torch.FloatTensor(state).unsqueeze(0).to(self.dev)
    # Action selection is inference only; no autograd graph is needed.
    with torch.no_grad():
      return torch.argmax(self.policy_net(state)).item()

  def replay(self):
    """Run one optimisation step on a sampled minibatch and return the loss tensor.

    The stored "done" column is the mask written by play(): 1 while the episode
    continues and 0 on its terminal step, so multiplying by it removes the
    bootstrap term for terminal transitions.
    """
    states, actions, rewards, next_states, not_done = self.memory.sample(self.bs)
    # The TD target must be a constant w.r.t. the optimised parameters. Computing
    # it under no_grad also stops gradients from piling up in target_net, whose
    # .grad is never zeroed because the optimizer does not own its parameters.
    with torch.no_grad():
      next_q = self.target_net(next_states).max(1)[0]
      target = rewards + self.gamma * next_q * not_done
    q_taken = self.policy_net(states).gather(1, actions.long())
    loss = F.smooth_l1_loss(q_taken, target.unsqueeze(1))
    self.optim.zero_grad()
    loss.backward()
    # Same as clamping every gradient to [-1, 1], but skips parameters with no grad.
    nn.utils.clip_grad_value_(self.policy_net.parameters(), 1.0)
    self.optim.step()

    return loss

  def member(self, state,action,reward,next_state,done):
    """Store one transition in replay memory."""
    self.memory.member(state,action,reward,next_state,done)

  def decay(self, steps_done):
    """Decay epsilon multiplicatively down to epsilon_min (steps_done is currently unused)."""
    self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

  def update_target(self):
    """Copy the policy network's weights into the target network."""
    self.target_net.load_state_dict(self.policy_net.state_dict())

  def play(self, max_frame):
    """Interact and train until at least `max_frame` frames have elapsed.

    The frame budget is checked only between episodes, so the last episode
    always runs to completion.
    """
    current_frame = 0
    while max_frame > current_frame:
      state = self.env.reset()
      episode_reward = 0
      while True:
        current_frame += 1
        action = self.act(state)
        next_state, reward, done, _ = self.env.step(action)
        episode_reward += reward
        # Store the per-step reward (storing the running episode total corrupts
        # every TD target) plus the "not done" mask that replay() expects.
        self.member(state, action, reward, next_state, 0 if done else 1)
        if self.start_frame < current_frame:
          self.replay()
        if (current_frame % self.update_frame) == 0:
          self.update_target()
          # Save the weights dict itself, not the bound state_dict method.
          torch.save(self.policy_net.state_dict(), f"breakout_policy_net_{current_frame}")

        if done:
          self.reward_tracker.append(episode_reward)
          print(f"Frame : {current_frame} Reward : {episode_reward}, mean : {self.reward_tracker.mean()} epsilon : {self.epsilon}")
          self.decay(current_frame)
          break

        state = next_state
    

In [0]:
# Breakout-v0 was created here but was immediately replaced by CartPole in the next
# cell, and QNet's first layer (Linear(4, ...)) only accepts 4-feature states.
# Creating it only added a dependency (Atari environments may need extra packages)
# that could make Restart & Run All fail, so it has been removed.

In [0]:
# Environment the agent is trained on; QNet's default input size (4) matches it.
ENV_ID = "CartPole-v0"
env = gym.make(ENV_ID)

In [0]:
# Agent defaults to device="cuda", which fails on CPU-only kernels; pick the device
# explicitly so the notebook runs either way.
agent = Agent(env, device="cuda" if torch.cuda.is_available() else "cpu")

In [8]:
# Training budget in environment frames; the episode in progress when the budget
# is reached still runs to completion.
MAX_FRAMES = 50_000
agent.play(max_frame=MAX_FRAMES)

Frame : 15 Reward : 15.0, mean : 15.0 epsilon : 1.0
Frame : 26 Reward : 11.0, mean : 13.0 epsilon : 0.999
Frame : 41 Reward : 15.0, mean : 13.666666666666666 epsilon : 0.998001
Frame : 52 Reward : 11.0, mean : 13.0 epsilon : 0.997002999
Frame : 66 Reward : 14.0, mean : 13.2 epsilon : 0.996005996001
Frame : 77 Reward : 11.0, mean : 12.833333333333334 epsilon : 0.995009990004999
Frame : 95 Reward : 18.0, mean : 13.571428571428571 epsilon : 0.994014980014994
Frame : 118 Reward : 23.0, mean : 14.75 epsilon : 0.993020965034979
Frame : 134 Reward : 16.0, mean : 14.88888888888889 epsilon : 0.9920279440699441
Frame : 144 Reward : 10.0, mean : 14.4 epsilon : 0.9910359161258742
Frame : 165 Reward : 21.0, mean : 15.0 epsilon : 0.9900448802097482
Frame : 185 Reward : 20.0, mean : 15.416666666666666 epsilon : 0.9890548353295385
Frame : 199 Reward : 14.0, mean : 15.307692307692308 epsilon : 0.988065780494209
Frame : 216 Reward : 17.0, mean : 15.428571428571429 epsilon : 0.9870777147137147
Frame : 22

  "type " + obj.__name__ + ". It won't be checked "


Frame : 1081 Reward : 20.0, mean : 20.85 epsilon : 0.9493041681462656
Frame : 1105 Reward : 24.0, mean : 21.15 epsilon : 0.9483548639781193
Frame : 1126 Reward : 21.0, mean : 21.5 epsilon : 0.9474065091141411
Frame : 1136 Reward : 10.0, mean : 21.15 epsilon : 0.946459102605027
Frame : 1157 Reward : 21.0, mean : 21.15 epsilon : 0.9455126435024219
Frame : 1199 Reward : 42.0, mean : 21.45 epsilon : 0.9445671308589195
Frame : 1244 Reward : 45.0, mean : 22.6 epsilon : 0.9436225637280606
Frame : 1260 Reward : 16.0, mean : 22.2 epsilon : 0.9426789411643326
Frame : 1287 Reward : 27.0, mean : 22.35 epsilon : 0.9417362622231683
Frame : 1303 Reward : 16.0, mean : 22.25 epsilon : 0.9407945259609451
Frame : 1320 Reward : 17.0, mean : 21.95 epsilon : 0.9398537314349842
Frame : 1340 Reward : 20.0, mean : 22.1 epsilon : 0.9389138777035492
Frame : 1367 Reward : 27.0, mean : 21.9 epsilon : 0.9379749638258457
Frame : 1388 Reward : 21.0, mean : 21.85 epsilon : 0.9370369888620198
Frame : 1408 Reward : 20.0