In [1]:
import random
import copy
from collections import deque
import itertools

import gym
from gym.spaces.box import Box
from gym import wrappers
from gym.wrappers import TransformObservation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.utils.data import DataLoader, IterableDataset
import numpy as np

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

from gym.wrappers import RecordVideo, RecordEpisodeStatistics, TimeLimit, AtariPreprocessing


device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
num_gpus = torch.cuda.device_count()

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class DRQN(nn.Module):
    def __init__(self, state_size , n_actions):
        super(DRQN, self).__init__()
        
        self.state_size = state_size
        self.conv = nn.Sequential(
                        nn.Conv2d(state_size[0], 32, 3, stride=2, padding=1),
                        nn.ELU(),
                        nn.Conv2d(32, 32, 3, stride=2, padding=1),
                        nn.ELU(),
                        nn.Conv2d(32, 32, 3, stride=2, padding=1),
                        nn.ELU(),
                        nn.Conv2d(32, 32, 3, stride=2, padding=1),
                        nn.ELU()
                    )
        conv_out_size = self._get_conv_out(state_size)
        self.fc1 = nn.Linear(conv_out_size, 256)
        self.fc_adv = nn.Linear(256, n_actions) 
        self.fc_value = nn.Linear(256, 1)
        
    def _get_conv_out(self, shape):
        conv_out = self.conv(torch.zeros(1, *shape))
        return int(np.prod(conv_out.size()))
    
    def forward(self, x):        
        o = self.conv(x.float()).view(x.shape[0], -1)
        o = F.relu(self.fc1(o))
        
        adv = self.fc_adv(o)
        value = self.fc_value(o)  
        
        return value + adv - torch.mean(adv, dim=1, keepdim=True)

In [3]:
def epsilon_greedy(state, env, net, epsilon=0.0):
    if np.random.random() < epsilon:
        action = env.action_space.sample()
    else:
        state = state.to(device)
        q_values = net(state)
        _, action = torch.max(q_values, dim=1)
        action = int(action.item())
    return action

In [4]:
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)
        
    def __len__(self):
        return len(self.buffer)
    
    def append(self, experience):
        self.buffer.append(experience)
        
    def sample(self, batch_size):
        return random.sample(self.buffer, batch_size)

In [5]:
class RLDataset(IterableDataset):
    def __init__(self, buffer, sample_size=400):
        self.buffer = buffer
        self.sample_size = sample_size
    
    def __iter__(self):
        for experience in self.buffer.sample(self.sample_size):
            yield experience

In [6]:
def create_environment(name):
    env = gym.make(name, render_mode="rgb_array")
    env.unwrapped._frameskip = 1
    env = TimeLimit(env, max_episode_steps=400)
    env = RecordVideo(env, video_folder='./videos/drqn-pong', episode_trigger=lambda x: x % 50 == 0)
    env = RecordEpisodeStatistics(env)
    env = gym.wrappers.AtariPreprocessing(env, frame_skip=8, noop_max=28, screen_size=64, terminal_on_life_loss=False, grayscale_obs=True, grayscale_newaxis=False, scale_obs=True)
    env.observation_space = Box(0.0, 1.0, [1, 64, 64])
    return env



In [7]:
class DeepQLearning(LightningModule):
    def __init__(self, env_name, policy=epsilon_greedy, capacity=100_000, 
               batch_size=256, lr=1e-3, hidden_size=128, gamma=0.99, 
               loss_fn=nn.MSELoss(), optim=AdamW, eps_start=1.0, eps_end=0.15, 
               eps_last_episode=400, samples_per_epoch=1024, sync_rate=10,
               sequence_length = 8):
    
        super().__init__()
        self.env = create_environment(env_name)

        obs_size = self.env.observation_space.shape
        n_actions = self.env.action_space.n

        self.q_net = DRQN(obs_size, n_actions)

        self.target_q_net = copy.deepcopy(self.q_net)

        self.policy = policy
        self.buffer = ReplayBuffer(capacity=capacity)
        self.save_hyperparameters()

        while len(self.buffer) < self.hparams.samples_per_epoch:
            print(f"{len(self.buffer)} samples in experience buffer. Filling...")
            self.play_episode(epsilon=self.hparams.eps_start)
            
    @torch.no_grad()
    def play_episode(self, policy=None, epsilon=0.):
        state  = self.env.reset()
        state  = torch.from_numpy(state[0]).unsqueeze(dim=0)
        done = False
        
        while not done:
            if policy:
                action = policy(state.unsqueeze(dim=0), self.env, self.q_net, epsilon=epsilon)
            else:
                action = self.env.action_space.sample()
            next_state, reward, done, tru , _ = self.env.step(action)
            if tru:
                done = tru
            
            next_state = torch.from_numpy(next_state).unsqueeze(dim=0) 
            exp = (state, action, reward, done, next_state)
            
            self.buffer.append(exp)
            state = next_state
            
        self.env.close()
        
        
    def forward(self, x):
        return self.q_net(x)

    
    def configure_optimizers(self):
        q_net_optimizer = self.hparams.optim(self.q_net.parameters(), lr=self.hparams.lr)
        return [q_net_optimizer]

     # Create dataloader.
    def train_dataloader(self):
        dataset = RLDataset(self.buffer, self.hparams.samples_per_epoch)
       
        dataloader = DataLoader(
            dataset=dataset,
            batch_size=self.hparams.batch_size
        )
        return dataloader
    
    def training_step(self, batch, batch_idx):
        states, actions, rewards, dones, next_states = batch
        actions = actions.unsqueeze(1)
        rewards = rewards.unsqueeze(1)
        dones = dones.unsqueeze(1)
        
        state_action_values = self.q_net(states).gather(1, actions)

        next_action_values, _ = self.target_q_net(next_states).max(dim=1, keepdim=True)
        next_action_values[dones] = 0.0

        expected_state_action_values = rewards + self.hparams.gamma * next_action_values

        loss = self.hparams.loss_fn(state_action_values.float(), expected_state_action_values.float())
        self.log('episode/Q-Error', loss)
        return loss
    
    # Training epoch end.
    def training_epoch_end(self, training_step_outputs):
        epsilon = max(
            self.hparams.eps_end,
            self.hparams.eps_start - self.current_epoch / self.hparams.eps_last_episode
        )

        self.play_episode(policy=self.policy, epsilon=epsilon)
        self.log('episode/Return', self.env.return_queue[-1])

        if self.current_epoch % self.hparams.sync_rate == 0:
            self.target_q_net.load_state_dict(self.q_net.state_dict())
            
            
    def save_model(self):
        torch.save(self.q_net.state_dict(), "./model")
        
    def load_model(self):
        self.q_net.load_state_dict(torch.load( "./model"))


In [8]:
algo = DeepQLearning('ALE/Pong-v5')

checkpoint_callback = ModelCheckpoint(dirpath="./checkpoints/drqb-pong", save_top_k=1,mode="max", monitor="episode/Return")

trainer = Trainer(
     accelerator='gpu',
     devices=num_gpus,
     max_epochs=20_000,
     callbacks=[checkpoint_callback], # EarlyStopping(monitor='episode/Return', mode='max', patience=1000)
)

trainer.fit(algo)

  logger.warn(
  rank_zero_warn(
  logger.warn(
  logger.warn(


0 samples in experience buffer. Filling...
48 samples in experience buffer. Filling...
95 samples in experience buffer. Filling...
144 samples in experience buffer. Filling...
191 samples in experience buffer. Filling...
241 samples in experience buffer. Filling...
288 samples in experience buffer. Filling...
336 samples in experience buffer. Filling...
383 samples in experience buffer. Filling...
431 samples in experience buffer. Filling...
480 samples in experience buffer. Filling...
527 samples in experience buffer. Filling...
576 samples in experience buffer. Filling...
625 samples in experience buffer. Filling...
675 samples in experience buffer. Filling...
723 samples in experience buffer. Filling...
771 samples in experience buffer. Filling...
820 samples in experience buffer. Filling...
869 samples in experience buffer. Filling...
917 samples in experience buffer. Filling...


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


965 samples in experience buffer. Filling...
1014 samples in experience buffer. Filling...


  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type | Params
--------------------------------------
0 | q_net        | DRQN | 161 K 
1 | target_q_net | DRQN | 161 K 
--------------------------------------
322 K     Trainable params
0         Non-trainable params
322 K     Total params
1.290     Total estimated model params size (MB)
  rank_zero_warn(


Epoch 5659: : 1it [00:00, 23.95it/s, loss=0.0129, v_num=5] 

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [97]:

frames = []
env = create_environment("ALE/Pong-v5")
for episode in range(1):
  done = False
  obs, _  = env.reset()
  while not done:
    frames.append(obs)
    action = env.action_space.sample()
    obs, reward, done, t ,_ = env.step(action)
    #print(reward)
    if t:
        done = True
    env.close()
    
print(len(frames))



50


In [94]:
env = algo.env
policy = algo.policy
q_net = algo.q_net.cuda()
frames = []

for episode in range(10):
    state = env.reset()
    state = state[0]
    state = torch.from_numpy(state).unsqueeze(dim=0)
    done = False
    hidden = None
    while not done:
        action, hidden = policy(state, env, q_net, epsilon=0, hidden=hidden)
        next_state, reward, done, tru , _ = env.step(action)
        next_state = torch.from_numpy(next_state).unsqueeze(dim=0)
        if tru:
            done = tru
        state = next_state
        frame = state.squeeze(dim=0)
        frame = frame.numpy()
        frames.append(frame)
        


In [98]:
import matplotlib.pyplot as plt
from IPython.display import clear_output


print(frames[0].shape)
for frame in frames:
    plt.imshow(frame)
    plt.show()
    clear_output(wait=True)
print(frames[0])   


[[0.23529412 0.23529412 0.23529412 ... 0.34117648 0.34117648 0.34117648]
 [0.34117648 0.34117648 0.34117648 ... 0.34117648 0.34117648 0.34117648]
 [0.34117648 0.34117648 0.34117648 ... 0.34117648 0.34117648 0.34117648]
 ...
 [0.9254902  0.9254902  0.9254902  ... 0.9254902  0.9254902  0.9254902 ]
 [0.9254902  0.9254902  0.9254902  ... 0.9254902  0.9254902  0.9254902 ]
 [0.9254902  0.9254902  0.9254902  ... 0.9254902  0.9254902  0.9254902 ]]


In [None]:
loader = algo.train_dataloader()

ite = iter(loader)

x = ite.next()
x[0]

In [None]:
x = Variable(torch.tensor([[1,2,3,4],[1,2,3,4]]))
x = x.squeeze(dim=1)
x.shape

In [None]:
x = torch.tensor([[[1,2,3,4]]])
x.device

In [None]:
x = torch.tensor([[[[[1,1]]],[[[2,2]]],[[[3,3]]],[[[4,4]]],[[[5,5]]],[[[6,6]]]],[[[[7,7]]],[[[8,8]]],[[[1,2]]],[[[1,2]]],[[[1,2]]],[[[1,2]]]]])
x.shape

In [None]:
x[:,3,:,:,:]

In [None]:
x = torch.tensor([1,1])
y = torch.tensor([2,2])
stack = torch.stack([x,y],dim=0)
stack

In [None]:
new = torch.tensor([3,3]).unsqueeze(dim=0)


In [None]:
stack

In [None]:
x = torch.tensor([1,1,1,1])


In [None]:
l = [1,2,3,4,5,6]

for i in l[::2]:
    print(i)

In [None]:
x = torch.tensor([[1,1,1,1,1],[2,2,2,2,2],[3,3,3,3,3],[4,4,4,4,4]])
x.shape

In [None]:
for i in range(2):
    for j in range(2):
        print(x[i+j:2:])
    

In [10]:
x = ReplayBuffer(10)
x.append((1,1,0))
x.append((1,1,0))
x.append((1,1,1))
x.append((1,1,0))
x.append((1,1,0))
x.append((1,1,-1))
x.append((1,1,0))
x.append((1,1,0))
x.append((1,1,0))
x.append((1,1,1))
x.append((1,1,0))
x.append((1,1,-1))

In [12]:
x.sample(5,3)

[(1, 1, 0),
 (1, 1, 0),
 (1, 1, -1),
 (1, 1, 0),
 (1, 1, 0),
 (1, 1, -1),
 (1, 1, 0),
 (1, 1, 0),
 (1, 1, -1)]

In [88]:
newx

[(1, 1, 0),
 (1, 1, 0),
 (1, 1, 1),
 (1, 1, 0),
 (1, 1, 0),
 (1, 1, -1),
 (1, 1, 0),
 (1, 1, 0),
 (1, 1, 1),
 (1, 1, 1),
 (1, 1, 0),
 (1, 1, -1)]