In [51]:
import models
import torch
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np

%load_ext autoreload
%autoreload 2

# Default device for the encoder/decoder smoke tests below.
device = torch.device("cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [52]:
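# Disabled: pulls one real Breakout frame (resized to 64x64) to use as `obs`
# instead of the random array generated further down.
# NOTE(review): assuming the resized frame is (H, W, C), `obs.T` reverses every
# axis and yields (C, W, H) -- the image is spatially transposed. If this cell
# is re-enabled, `np.transpose(obs, (2, 0, 1))` gives a true (C, H, W) layout.
# env = gym.make('ALE/Breakout-v5', render_mode="rgb_array")
# env = gym.wrappers.ResizeObservation(env, shape=(64,64))
# # env = gym.wrappers.NormalizeObservation(env)
# obs, _ = env.reset()
# obs = obs.T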

In [53]:
# Build the convolutional encoder, place it on `device`, and print its layers
# as a quick architecture check.
net = models.ConvEncoder()
net = net.to(device)  # Module.to() returns the same module, now on `device`
print(net)

ConvEncoder(
  (conv_layer): Sequential(
    (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=valid)
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=valid)
    (3): ReLU()
    (4): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=valid)
    (5): ReLU()
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=valid)
    (7): ReLU()
  )
)


In [54]:
# Smoke test: run one random (3, 64, 64) array through the encoder.
# NumPy produces float64, so the array is cast to float32 before the forward pass.
obs = np.random.random_sample((3, 64, 64))
print(obs.shape)
out = net(torch.from_numpy(obs).float().to(device))

(3, 64, 64)


In [55]:
# Build the convolutional decoder and print its layers.
# The printed first Linear layer takes 1200 features, so the two 600-wide
# arguments are presumably concatenated inside the model -- confirm in models.py.
decoder = models.ConvDecoder(600, 600)
decoder = decoder.to(device)  # Module.to() returns the same module
print(decoder)

ConvDecoder(
  (net): Sequential(
    (0): Linear(in_features=1200, out_features=1024, bias=True)
    (1): Unflatten(dim=1, unflattened_size=(1024, 1))
    (2): Unflatten(dim=2, unflattened_size=(1, 1))
    (3): ConvTranspose2d(1024, 128, kernel_size=(5, 5), stride=(2, 2))
    (4): ReLU()
    (5): ConvTranspose2d(128, 64, kernel_size=(5, 5), stride=(2, 2))
    (6): ReLU()
    (7): ConvTranspose2d(64, 32, kernel_size=(6, 6), stride=(2, 2))
    (8): ReLU()
    (9): ConvTranspose2d(32, 3, kernel_size=(6, 6), stride=(2, 2))
  )
)


In [56]:
# Smoke test: feed two random (50, 49, 600) tensors to the decoder and print
# the object it returns. The tensors are drawn in order: `post`, then `det`.
post, det = (torch.rand(50, 49, 600).to(device) for _ in range(2))
decoded = decoder(post, det)
print(decoded)

Independent(Normal(loc: torch.Size([50, 49, 3, 64, 64]), scale: torch.Size([50, 49, 3, 64, 64])), 3)


In [57]:
# Prefer Apple's Metal (MPS) backend, but fall back to the CPU when it is not
# available so this cell (and the RSSM cells that follow) still run on machines
# without MPS support. The name `mps_device` is kept because later cells use it.
mps_device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

# Build the RSSM and print its sub-modules.
# NOTE(review): judging by the printed layer sizes, the arguments look like
# (stochastic=600, embedded-obs=1024, deterministic=600, hidden=200, action=18)
# -- confirm against the RSSM signature in models.py.
rssm_net = models.RSSM(600, 1024, 600, 200, 18).to(mps_device)
print(rssm_net)

RSSM(
  (recurrent_linear): Sequential(
    (0): Linear(in_features=618, out_features=200, bias=True)
    (1): ELU(alpha=1.0)
  )
  (gru_cell): GRUCell(200, 600)
  (representatio_model): Sequential(
    (0): Linear(in_features=1624, out_features=200, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=200, out_features=1200, bias=True)
  )
  (transition_model): Sequential(
    (0): Linear(in_features=600, out_features=200, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=200, out_features=1200, bias=True)
  )
)


In [58]:
# Smoke test for the recurrent step: random (50, 18) and (50, 600) inputs.
# Draw order is preserved: action first, then posterior, then deterministic state.
test_act = torch.rand(50, 18).to(mps_device)
test_post, test_deter = (torch.rand(50, 600).to(mps_device) for _ in range(2))
out = rssm_net.recurrent(test_post, test_act, test_deter)

In [59]:
# Display the size of the recurrent step's output.
out.size()

torch.Size([50, 600])

In [60]:
# Smoke test for the representation model: a random (50, 1024) tensor stands in
# for the embedded observation, paired with the `test_deter` tensor from the
# recurrent test. The call returns a pair, unpacked as (dist, out).
test_embed_obs = torch.rand(50, 1024).to(device=mps_device)
dist, out = rssm_net.representation(test_embed_obs, test_deter)

In [61]:
# Show the returned distribution and the shape of the accompanying tensor,
# one per line.
print(dist, out.shape, sep="\n")

Normal(loc: torch.Size([50, 600]), scale: torch.Size([50, 600]))
torch.Size([50, 600])


In [62]:
# Smoke test for the transition model, which is called with `test_deter` only.
dist, out = rssm_net.transition(test_deter)
# Distribution first, then the tensor's shape, one per line.
print(dist, out.shape, sep="\n")

Normal(loc: torch.Size([50, 600]), scale: torch.Size([50, 600]))
torch.Size([50, 600])


In [64]:
# Build the actor (left on the default CPU device) and print its layers.
# NOTE(review): argument meanings are inferred from the printed layer sizes
# and should be confirmed against models.Actor.
actor_net = models.Actor(
    1200,   # width of the first Linear layer's input
    400,    # hidden width
    18,     # presumably the action size; the printed final layer emits 36 = 2 * 18
    False,  # flag whose meaning is not visible here -- see models.Actor
)
print(actor_net)

Actor(
  (net): Sequential(
    (0): Linear(in_features=1200, out_features=400, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=400, out_features=36, bias=True)
  )
)


In [66]:
# Smoke test: run the actor on two random (50, 49, 600) CPU tensors and print
# the size of its output. The tensors are drawn in order: `post`, then `deter`.
post, deter = (torch.rand(50, 49, 600) for _ in range(2))
out = actor_net(post, deter)
print(out.size())

torch.Size([50, 49, 18])


In [70]:
# Build the critic (left on the default CPU device) and print its layers.
# NOTE(review): argument meanings are inferred from the printed layer sizes
# and should be confirmed against models.Critic.
critic_net = models.Critic(
    1200,  # width of the first Linear layer's input
    400,   # hidden width
)
print(critic_net)

Critic(
  (net): Sequential(
    (0): Linear(in_features=1200, out_features=400, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=400, out_features=400, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=400, out_features=1, bias=True)
  )
)


In [79]:
# Smoke test: run the critic on two random (50, 49, 600) CPU tensors and print
# the object it returns. The tensors are drawn in order: `post`, then `deter`.
post, deter = (torch.rand(50, 49, 600) for _ in range(2))
out = critic_net(post, deter)
print(out)

Independent(Normal(loc: torch.Size([50, 49, 1]), scale: torch.Size([50, 49, 1])), 1)
