In [1]:
import gymnasium as gym
from gymnasium.wrappers import GrayScaleObservation, TransformObservation
from gymnasium.wrappers import FrameStack, FlattenObservation, ResizeObservation

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch import nn

from collections import deque
from tqdm import tqdm

In [2]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(s, a, r, s_, d)`` transitions.

    Once ``length`` transitions are held, storing another evicts the oldest
    one. ``sample()`` draws a uniformly random mini-batch (with replacement).
    """

    def __init__(self, length=10000, batch_size = 32, device=None):
        """
        Args:
            length: maximum number of transitions kept in memory.
            batch_size: number of transitions returned by each ``sample()``.
            device: torch device the sampled tensors are placed on; CPU when
                omitted.
        """
        self.mem = deque(maxlen=length)
        self.length = length
        # Kept as a one-element list: it is handed to ``torch.randint`` as the
        # ``size`` argument, which expects a sequence.
        self.batch_size = [batch_size]
        self.device = device if device else torch.device("cpu")
    
    def store(self, s, a, r, s_, d):
        """Append one transition (state, action, reward, next state, done flag)."""
        self.mem.append([s, a, r, s_, d])
    
    def sample(self):
        """Draw ``batch_size`` stored transitions uniformly, with replacement.

        Returns:
            A 5-tuple ``(s, a, r, s_, d)`` of float32 tensors on
            ``self.device``; the first dimension of each is the batch.

        Raises:
            RuntimeError: if no transition has been stored yet.
        """
        if len(self) == 0:
            raise RuntimeError("cannot sample from an empty ReplayBuffer")
        # Plain Python ints index the deque directly.
        idxs = torch.randint(high=len(self), size=self.batch_size).tolist()
        batch = [self.mem[idx] for idx in idxs]
        # Transpose the list of transitions into five columns and turn each
        # into a single contiguous array before handing it to torch; the
        # result is materialised (not a lazy map) and honours ``self.device``.
        return tuple(
            torch.as_tensor(np.asarray(column), dtype=torch.float32, device=self.device)
            for column in zip(*batch)
        )
        
    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.mem)

In [None]:
# Pong from pixels. Wrappers are applied innermost-first:
# grayscale -> resize (shape argument 64) -> stack of 3 frames -> flatten -> divide by 255.
env = TransformObservation(
    FlattenObservation(
        FrameStack(
            ResizeObservation(GrayScaleObservation(gym.make("ALE/Pong-v5")), 64),
            3,
        )
    ),
    lambda obs: obs / 255.0,
)

In [3]:
# CartPole with its observation passed through FlattenObservation.
# (Frame stacking, i.e. FrameStack(env, 3), is left disabled here.)
env = FlattenObservation(gym.make("CartPole-v1"))

In [None]:
# Alternative environment: rebinds `env` to FrozenLake-v1, replacing whichever
# environment an earlier cell created.
# NOTE(review): the model cell reads env.observation_space.shape[0]; presumably
# FrozenLake's observation space does not provide that dimension — confirm
# before running this cell ahead of the model cell.
env = gym.make("FrozenLake-v1")

In [4]:
# Q-network sizes come straight from the environment: one input per
# observation component, one output per discrete action.
input_dim = env.observation_space.shape[0]
output_dim = env.action_space.n

# Two hidden layers of 64 ReLU units, then a linear head with one value per action.
layers = [
    nn.Linear(input_dim, 64),
    nn.ReLU(),
    nn.Linear(64, 64),
    nn.ReLU(),
    nn.Linear(64, output_dim),
]
model = nn.Sequential(*layers)

# Mean-squared error between predicted and target values, optimised with Adam.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
loss_fn = nn.MSELoss()

In [5]:
# --- Learning / exploration settings ---
epsilon = 0.1        # probability of taking a random action instead of the greedy one
gamma = 0.99         # discount applied to the bootstrapped next-state value
batch_size = 256     # transitions drawn from the replay buffer per update
n_steps = 1          # an update is performed whenever step % n_steps == 0

# --- Run bookkeeping ---
max_episodes = 10_000   # the training cell assigns its own value before looping
step = 0                # global environment-step counter, used as the x-axis for logged scalars

# --- Replay memory and TensorBoard logging ---
writer = SummaryWriter()
rb = ReplayBuffer(batch_size=batch_size, length=1_000_000)

In [9]:
# Online Q-learning with experience replay: act epsilon-greedily, store every
# transition, and regress Q(s, a) onto the one-step TD target
#     r + gamma * max_a' Q(s', a') * (1 - terminal)
max_episodes = 100000
for episode in tqdm(range(max_episodes)):
    obs, info = env.reset()
    done = False
    
    ep_len = 0
    ep_loss = []
    ep_reward = 0.0
    
    while not done:
        step += 1

        # Epsilon-greedy action selection. Acting never needs gradients, so the
        # forward pass runs under no_grad.
        with torch.no_grad():
            best_act = torch.argmax(model(torch.as_tensor(obs, dtype=torch.float32)))
        action = env.action_space.sample() if (torch.rand(1) < epsilon) else best_act.item()

        obs_, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # Store `terminated`, not `done`: a time-limit truncation ends the
        # episode but the next state still has value, so the target must keep
        # bootstrapping from it.
        rb.store(obs, action, reward, obs_, terminated)
        obs = obs_
        
        if step % n_steps == 0:
            s, a, r, s_, d = rb.sample()

            # Q-value of the action actually taken in each sampled transition.
            actual = model(s).gather(1, a.long().unsqueeze(1)).squeeze(1)

            # The target is a fixed regression label: compute it without
            # gradients so the loss only updates Q(s, a). The reward is always
            # included; only the bootstrap term is masked on terminal states.
            with torch.no_grad():
                qs_max = model(s_).max(dim=1).values
                target = r + gamma * qs_max * (1 - d)

            loss = loss_fn(actual, target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Record the loss only on steps where an update really happened.
            ep_loss.append(loss.item())
            
        ep_len += 1
        ep_reward += reward
        
    writer.add_scalar('Train/length', ep_len, step)
    if ep_loss:  # an episode may finish without any update when n_steps > 1
        writer.add_scalar('Train/loss', np.mean(ep_loss), step)
    writer.add_scalar('Train/reward', ep_reward, step)

 49%|████▉     | 49160/100000 [10:29:48<10:51:19,  1.30it/s]    


KeyboardInterrupt: 

In [None]:
# Roll out the current (epsilon-greedy) policy and render each step.
# No learning happens here; transitions are not stored.
for episode in tqdm(range(max_episodes)):
    obs, info = env.reset()
    done = False
    
    while not done:
        # NOTE(review): `step` is the same counter the training cell uses as
        # the x-axis for its logged scalars — confirm that advancing it during
        # rollouts is intended.
        step += 1
        with torch.no_grad():
            best_act = torch.argmax(model(torch.as_tensor(obs, dtype=torch.float32)))
        action = env.action_space.sample() if (torch.rand(1) < epsilon) else best_act.item()
        obs_, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # Advance to the new observation; without this the policy would keep
        # acting on the observation returned by reset().
        obs = obs_
        
        env.render()

In [None]:
# Scratch: look at what one replay-buffer batch looks like after each of the
# five returned items is passed through torch.tensor.
# s, a, r, s_, d = list(rb.sample())
s, a, r, s_, d = (torch.tensor(item) for item in list(rb.sample()))

d.shape
type(s)
# torch.tensor(s_).shape

In [None]:
# Scratch: multiply a float tensor by a bool tensor, then add another float
# tensor, to see how a boolean flag behaves as an arithmetic mask.
t = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([False, False, True])
b2 = torch.tensor([10.0, 11.0, 12.0])
torch.add(torch.mul(t, b), b2)

In [None]:
# Scratch: apply torch.tensor to each element of one row and show the
# resulting list.
dd = [[1.0, 1.5], [2.0, 2.5]]
# list(zip(*dd))

ddd = [[1.0, 1.5], [2.0, 2.5]]
ss = [torch.tensor(value) for value in ddd[0]]
ss