In [1]:
import torch
import numpy as np
import mazelab
import matplotlib.pyplot as plt 
import cv2

In [2]:
import seaborn as sns
sns.set_style("whitegrid")

In [3]:
from mazelab.solvers.dijkstra_solver import dijkstra_solver, dijkstra_solver_full

In [4]:
from rl_trickery.envs.maze import generate_random, MazelabEnv, Maze
from rl_trickery.envs.wrappers import ResizeImage, TransposeImage
from rl_trickery.data.maze_storage import generate_dataset

In [5]:
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import TensorDataset, Dataset

# Generate mazes

Generate maze, set goal, find solutions, store all.

In [53]:
BATCH_SIZE = 16
maze_size = 15
env = MazelabEnv(maze_size=maze_size, maze_kind="maze", goal_fixed=False, maze_fixed=False, goal_reward=False, wall_reward=False)
# env = ResizeImage(env, (64, 64), antialias=True)
env = TransposeImage(env)

dl_train = DataLoader(
    generate_dataset(env, 1e3, resize=False),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    drop_last=True
)
dl_test = DataLoader(
    generate_dataset(env, 1e3, resize=False),
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    drop_last=True
)

# Init networks

Define models, init

In [54]:
from rl_trickery.models.tricky_policy_networks import *
from rl_trickery.models.conv_lstm import ConvLSTM

In [55]:
class SeqSolverRNN(nn.Module):
    def __init__(
        self,
        obs_space,
        state_channels=32,
    ):
        super(SeqSolverRNN, self).__init__()
        
        n_channels = obs_space.shape[0]
        im_size = obs_space.shape[1]

        init_relu = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0),
                                   nn.init.calculate_gain('relu'))
        
        self.encoder = nn.Sequential(
            init_relu(nn.Conv2d(n_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            Flatten(),
            init_relu(nn.Linear(state_channels*im_size*im_size, 256)),
            nn.ReLU(),
        )
        
        self.recurse = nn.LSTMCell(256, 256)
        
        self.out_layers = nn.Sequential(
            init_relu(nn.Linear(256, 256)),
            nn.ReLU(),
            init_relu(nn.Linear(256, 1)),
        )

    def forward(self, obs, recurse_depth=5):
        h = self.encoder(obs)
        
        outputs = []
        rnn_out = None
        for i in range(recurse_depth):
            rnn_out = self.recurse(h, rnn_out)
            h_cur, c_cur = rnn_out
            out = self.out_layers(h_cur)
            outputs.append(out)
            
        outputs = torch.stack(outputs)
        return outputs


In [56]:
class SeqSolverFF12(nn.Module):
    def __init__(
        self,
        obs_space,
        state_channels=32,
    ):
        super(SeqSolverFF12, self).__init__()
        
        n_channels = obs_space.shape[0]
        im_size = obs_space.shape[1]

        init_relu = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0),
                                   nn.init.calculate_gain('relu'))
        
        self.encoder = nn.Sequential(
            init_relu(nn.Conv2d(n_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
        )
        
        self.recurse = nn.Sequential(
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            
        )
        
        self.out_layers = nn.Sequential(
#             init_relu(nn.Conv2d(state_channels, 3, kernel_size=(3, 3), padding=1)),
            Flatten(),
            init_relu(nn.Linear(state_channels*im_size*im_size, 256)),
            nn.ReLU(),
            init_relu(nn.Linear(256, 1)),
        )

    def forward(self, obs, recurse_depth=5):
        h = self.encoder(obs)
        h = self.recurse(h)
        out = self.out_layers(h)
        outputs = torch.stack([out])
            
        return outputs


In [57]:
class SeqSolverFF5(nn.Module):
    def __init__(
        self,
        obs_space,
        state_channels=32,
    ):
        super(SeqSolverFF5, self).__init__()
        
        n_channels = obs_space.shape[0]
        im_size = obs_space.shape[1]

        init_relu = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0),
                                   nn.init.calculate_gain('relu'))
        
        self.encoder = nn.Sequential(
            init_relu(nn.Conv2d(n_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
        )
        
        self.recurse = nn.Sequential(
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
        )
        
        self.out_layers = nn.Sequential(
#             init_relu(nn.Conv2d(state_channels, 3, kernel_size=(3, 3), padding=1)),
            Flatten(),
            init_relu(nn.Linear(state_channels*im_size*im_size, 256)),
            nn.ReLU(),
            init_relu(nn.Linear(256, 1)),
        )

    def forward(self, obs, recurse_depth=5):
        h = self.encoder(obs)
        h = self.recurse(h)
        out = self.out_layers(h)
        outputs = torch.stack([out])
            
        return outputs


In [58]:
class SeqSolverFF1(nn.Module):
    def __init__(
        self,
        obs_space,
        state_channels=32,
    ):
        super(SeqSolverFF1, self).__init__()
        
        n_channels = obs_space.shape[0]
        im_size = obs_space.shape[1]

        init_relu = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0),
                                   nn.init.calculate_gain('relu'))
        
        self.encoder = nn.Sequential(
            init_relu(nn.Conv2d(n_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
        )
        
        self.recurse = nn.Sequential(
            init_relu(nn.Conv2d(state_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
        )
        
        self.out_layers = nn.Sequential(
#             init_relu(nn.Conv2d(state_channels, 3, kernel_size=(3, 3), padding=1)),
#             nn.MaxPool2d(2),
            Flatten(),
            init_relu(nn.Linear(state_channels*im_size*im_size, 256)),
            nn.ReLU(),
            init_relu(nn.Linear(256, 1)),
        )

    def forward(self, obs, recurse_depth=5):
        h = self.encoder(obs)
        h = self.recurse(h)
        out = self.out_layers(h)
        outputs = torch.stack([out])
            
        return outputs


In [59]:
class SeqSolverCRNN(nn.Module):
    def __init__(
        self,
        obs_space,
        state_channels=32,
    ):
        super(SeqSolverCRNN, self).__init__()
        
        n_channels = obs_space.shape[0]
        im_size = obs_space.shape[1]

        init_relu = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0),
                                   nn.init.calculate_gain('relu'))
        
        self.encoder = nn.Sequential(
            init_relu(nn.Conv2d(n_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
        )
        
        self.recurse = ConvLSTMCell(state_channels, state_channels, kernel_size=(3, 3), bias=True)
        
        self.out_layers = nn.Sequential(
#             init_relu(nn.Conv2d(state_channels, 3, kernel_size=(3, 3), padding=1)),
            Flatten(),
            init_relu(nn.Linear(state_channels*im_size*im_size, 256)),
            nn.ReLU(),
            init_relu(nn.Linear(256, 1)),
        )

    def forward(self, obs, recurse_depth=5):
        h = self.encoder(obs)
        
        outputs = []
        rnn_out = None
        for i in range(recurse_depth):
            rnn_out = self.recurse(h, rnn_out)
            h_cur, c_cur = rnn_out
            out = self.out_layers(h_cur)
            outputs.append(out)
            
        outputs = torch.stack(outputs)
        return outputs


In [60]:
class SeqSolverMuZero(nn.Module):
    def __init__(
        self,
        obs_space,
        state_channels=32,
    ):
        super(SeqSolverMuZero, self).__init__()
        
        n_channels = obs_space.shape[0]
        im_size = obs_space.shape[1]

        init_relu = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0),
                                   nn.init.calculate_gain('relu'))
        
        self.encoder = nn.Sequential(
            init_relu(nn.Conv2d(n_channels, state_channels, kernel_size=(3, 3), padding=1)),
            nn.ReLU(),
        )
        
        self.recurse = DynamicsNetwork(obs_space)
        
        self.out_layers = nn.Sequential(
#             init_relu(nn.Conv2d(state_channels, 3, kernel_size=(3, 3), padding=1)),
            Flatten(),
            init_relu(nn.Linear(state_channels*im_size*im_size, 256)),
            nn.ReLU(),
            init_relu(nn.Linear(256, 1)),
        )

    def forward(self, obs, recurse_depth=5):
        h = self.encoder(obs)
        h = self.recurse(h)
        out = self.out_layers(h)
        outputs = torch.stack([out])
            
        return outputs



In [63]:
net = SeqSolverMuZero(
    env.observation_space,
).cuda()

In [62]:
net = SeqSolverCRNN(
    env.observation_space,
).cuda()

In [14]:
net = SeqSolverFF1(
    env.observation_space,
).cuda()

In [15]:
net = SeqSolverFF12(
    env.observation_space,
).cuda()

In [16]:
net = SeqSolverFF5(
    env.observation_space,
).cuda()

In [17]:
net = SeqSolverRNN(
    env.observation_space,
).cuda()

In [64]:
EPOCHS = 500
opt = torch.optim.Adam(net.parameters(), lr=0.001)

for i_epoch in range(EPOCHS):
    total_loss = 0
    for i_batch, data in enumerate(dl_train):
        net.zero_grad()
        
        x, y = data
        y = y.cuda()

        y_hat = net(x.cuda(), 7)
        y = y.repeat(y_hat.size(0), 1, 1)
        loss = F.smooth_l1_loss(y_hat[-5:], y[-5:])
        loss.backward()
        opt.step()
        total_loss += loss
    total_loss /= i_batch+1
        
            
    with torch.no_grad():
        test_loss = 0
        test_loss_1 = 0
        for i_batch, data in enumerate(dl_test):
            x, y = data
            y = y.cuda()
            

            y_hat = net(x.cuda(), 7)
            y = y.repeat(y_hat.size(0), 1, 1)
            test_loss_1 += F.smooth_l1_loss(y_hat, y, reduction="none").mean(dim=(1,2))
            test_loss += F.smooth_l1_loss(y_hat[-5:], y[-5:])

        test_loss_1 /= i_batch+1
        test_loss /= i_batch+1
    
    

    print("epoch", i_epoch)
    print("train", total_loss.item(), "test", test_loss.item(), "var", (y_hat[-1].var()/y[-1].var()).item())
    print("test_long", test_loss_1)


epoch 0
train 10.465069770812988 test 5.174776077270508 var 0.32366320490837097
test_long tensor([5.1748], device='cuda:0')
epoch 1
train 5.495233058929443 test 4.460792064666748 var 0.5373578071594238
test_long tensor([4.4608], device='cuda:0')
epoch 2
train 4.533514499664307 test 4.32859992980957 var 0.596479594707489
test_long tensor([4.3286], device='cuda:0')
epoch 3
train 4.8481292724609375 test 4.922457695007324 var 0.4477035701274872
test_long tensor([4.9225], device='cuda:0')
epoch 4
train 4.044459819793701 test 4.030269145965576 var 0.5511994361877441
test_long tensor([4.0303], device='cuda:0')
epoch 5
train 4.135699272155762 test 4.060540199279785 var 0.4620460271835327
test_long tensor([4.0605], device='cuda:0')
epoch 6
train 4.022524356842041 test 3.9714102745056152 var 0.4983888864517212
test_long tensor([3.9714], device='cuda:0')
epoch 7
train 4.236393928527832 test 4.638983726501465 var 0.25521183013916016
test_long tensor([4.6390], device='cuda:0')
epoch 8
train 4.24130

epoch 67
train 2.2828071117401123 test 2.3756608963012695 var 0.737372636795044
test_long tensor([2.3757], device='cuda:0')
epoch 68
train 2.311039686203003 test 2.147235870361328 var 0.8799191117286682
test_long tensor([2.1472], device='cuda:0')
epoch 69
train 2.064668655395508 test 2.2403993606567383 var 0.6387246251106262
test_long tensor([2.2404], device='cuda:0')
epoch 70
train 2.113986015319824 test 2.549696207046509 var 0.9789568781852722
test_long tensor([2.5497], device='cuda:0')
epoch 71
train 2.3055672645568848 test 2.2138140201568604 var 2.0319530963897705
test_long tensor([2.2138], device='cuda:0')
epoch 72
train 2.3583247661590576 test 2.066009044647217 var 0.591119110584259
test_long tensor([2.0660], device='cuda:0')
epoch 73
train 2.2268176078796387 test 2.1302669048309326 var 0.8484196066856384
test_long tensor([2.1303], device='cuda:0')
epoch 74
train 2.210050582885742 test 2.306448459625244 var 0.8822386264801025
test_long tensor([2.3064], device='cuda:0')
epoch 75
t

epoch 133
train 1.068302035331726 test 1.3300771713256836 var 0.8825905323028564
test_long tensor([1.3301], device='cuda:0')
epoch 134
train 1.250789999961853 test 1.2392843961715698 var 0.8087695837020874
test_long tensor([1.2393], device='cuda:0')
epoch 135
train 1.1472887992858887 test 1.3014475107192993 var 1.0288670063018799
test_long tensor([1.3014], device='cuda:0')
epoch 136
train 1.4232127666473389 test 1.656859040260315 var 0.9950733780860901
test_long tensor([1.6569], device='cuda:0')
epoch 137
train 1.2008938789367676 test 1.148435354232788 var 0.7749805450439453
test_long tensor([1.1484], device='cuda:0')
epoch 138
train 1.181136965751648 test 1.161110758781433 var 0.8320150375366211
test_long tensor([1.1611], device='cuda:0')
epoch 139
train 1.2258965969085693 test 1.112381935119629 var 0.812635064125061
test_long tensor([1.1124], device='cuda:0')
epoch 140
train 1.1053354740142822 test 1.2684930562973022 var 0.866024374961853
test_long tensor([1.2685], device='cuda:0')
e

epoch 199
train 0.7977225184440613 test 0.824255108833313 var 1.0721559524536133
test_long tensor([0.8243], device='cuda:0')
epoch 200
train 0.8252397179603577 test 1.0204304456710815 var 0.717557430267334
test_long tensor([1.0204], device='cuda:0')
epoch 201
train 0.6972968578338623 test 0.7092879414558411 var 0.764660120010376
test_long tensor([0.7093], device='cuda:0')
epoch 202
train 0.8459187746047974 test 0.8639279007911682 var 0.9492583274841309
test_long tensor([0.8639], device='cuda:0')
epoch 203
train 0.7552265524864197 test 0.8426892757415771 var 0.940020740032196
test_long tensor([0.8427], device='cuda:0')
epoch 204
train 0.7246654629707336 test 0.840030312538147 var 1.5579884052276611
test_long tensor([0.8400], device='cuda:0')
epoch 205
train 0.7020742893218994 test 0.7967822551727295 var 0.888885498046875
test_long tensor([0.7968], device='cuda:0')
epoch 206
train 0.7732316851615906 test 1.11326003074646 var 0.8664326071739197
test_long tensor([1.1133], device='cuda:0')


epoch 265
train 0.5523657202720642 test 0.8830793499946594 var 0.8979906439781189
test_long tensor([0.8831], device='cuda:0')
epoch 266
train 0.6112418174743652 test 0.6466649174690247 var 0.978419303894043
test_long tensor([0.6467], device='cuda:0')
epoch 267
train 0.5149427056312561 test 0.7461268305778503 var 0.9230627417564392
test_long tensor([0.7461], device='cuda:0')
epoch 268
train 0.5123536586761475 test 0.8095065951347351 var 1.0664905309677124
test_long tensor([0.8095], device='cuda:0')
epoch 269
train 0.7316500544548035 test 1.046293020248413 var 0.9144373536109924
test_long tensor([1.0463], device='cuda:0')
epoch 270
train 0.6218879818916321 test 0.7101009488105774 var 0.9895703196525574
test_long tensor([0.7101], device='cuda:0')
epoch 271
train 0.5718905329704285 test 0.6811654567718506 var 1.0771957635879517
test_long tensor([0.6812], device='cuda:0')
epoch 272
train 0.550055742263794 test 0.7829280495643616 var 0.8649644255638123
test_long tensor([0.7829], device='cuda

epoch 331
train 0.5033696889877319 test 0.5503646731376648 var 0.87681645154953
test_long tensor([0.5504], device='cuda:0')
epoch 332
train 0.44636037945747375 test 0.631424605846405 var 1.0580041408538818
test_long tensor([0.6314], device='cuda:0')
epoch 333
train 0.4595438838005066 test 0.6308294534683228 var 0.9521017670631409
test_long tensor([0.6308], device='cuda:0')
epoch 334
train 0.48241594433784485 test 0.5910577774047852 var 1.1725714206695557
test_long tensor([0.5911], device='cuda:0')
epoch 335
train 0.4540734589099884 test 0.5364121198654175 var 1.0032424926757812
test_long tensor([0.5364], device='cuda:0')
epoch 336
train 0.4255520701408386 test 0.5846245884895325 var 1.017091989517212
test_long tensor([0.5846], device='cuda:0')
epoch 337
train 0.44366535544395447 test 0.5906643271446228 var 0.8908866047859192
test_long tensor([0.5907], device='cuda:0')
epoch 338
train 0.44575610756874084 test 0.5576006174087524 var 1.110438585281372
test_long tensor([0.5576], device='cu

epoch 396
train 0.37869367003440857 test 0.5246864557266235 var 0.972496509552002
test_long tensor([0.5247], device='cuda:0')
epoch 397
train 0.3734147548675537 test 0.4772006571292877 var 0.8466070890426636
test_long tensor([0.4772], device='cuda:0')
epoch 398
train 0.3675929605960846 test 0.5839464664459229 var 0.7593094110488892
test_long tensor([0.5839], device='cuda:0')
epoch 399
train 0.37445488572120667 test 0.4516356289386749 var 0.8946686387062073
test_long tensor([0.4516], device='cuda:0')
epoch 400
train 0.40352845191955566 test 0.5568632483482361 var 1.1289278268814087
test_long tensor([0.5569], device='cuda:0')
epoch 401
train 0.4556216299533844 test 0.539543628692627 var 0.9518262147903442
test_long tensor([0.5395], device='cuda:0')
epoch 402
train 0.3928239345550537 test 0.5073076486587524 var 1.0253578424453735
test_long tensor([0.5073], device='cuda:0')
epoch 403
train 0.39652758836746216 test 0.474418580532074 var 0.7940428256988525
test_long tensor([0.4744], device='

epoch 461
train 0.36249345541000366 test 0.44064566493034363 var 1.0781279802322388
test_long tensor([0.4406], device='cuda:0')
epoch 462
train 0.3795504868030548 test 0.40944868326187134 var 0.8489493727684021
test_long tensor([0.4094], device='cuda:0')
epoch 463
train 0.31134262681007385 test 0.5064818859100342 var 1.0257666110992432
test_long tensor([0.5065], device='cuda:0')
epoch 464
train 0.36415398120880127 test 0.5887435078620911 var 1.2289592027664185
test_long tensor([0.5887], device='cuda:0')
epoch 465
train 0.36333340406417847 test 0.49785318970680237 var 0.9867194294929504
test_long tensor([0.4979], device='cuda:0')
epoch 466
train 0.29469895362854004 test 0.45501115918159485 var 1.036777138710022
test_long tensor([0.4550], device='cuda:0')
epoch 467
train 0.3380342423915863 test 0.43189728260040283 var 0.8428900837898254
test_long tensor([0.4319], device='cuda:0')
epoch 468
train 0.33611607551574707 test 0.5298059582710266 var 0.8160582780838013
test_long tensor([0.5298],

In [None]:
# ff12
print("epoch", i_epoch)
print("train", total_loss.item(), "test", test_loss.item(), "var", (y_hat[-1].var()/y[-1].var()).item())
print("test_long", test_loss_1)

In [None]:
# ff1
print("epoch", i_epoch)
print("train", total_loss.item(), "test", test_loss.item(), "var", (y_hat[-1].var()/y[-1].var()).item())
print("test_long", test_loss_1)

In [None]:
# rnn
print("epoch", i_epoch)
print("train", total_loss.item(), "test", test_loss.item(), "var", (y_hat[-1].var()/y[-1].var()).item())
print("test_long", test_loss_1)

In [None]:
# crnn
print("epoch", i_epoch)
print("train", total_loss.item(), "test", test_loss.item(), "var", (y_hat[-1].var()/y[-1].var()).item())
print("test_long", test_loss_1)

In [None]:
with torch.no_grad():
    test_loss = 0
    test_loss_1 = 0
    for i_batch, data in enumerate(dl_test):
        x, y = data
        y = y.cuda()


        y_hat = net(x.cuda(), 7)
        y = y.repeat(y_hat.size(0), 1, 1)
        test_loss_1 += F.smooth_l1_loss(y_hat, y, reduction="none").mean(dim=(1,2))
        test_loss += F.smooth_l1_loss(y_hat[-5:], y[-5:])

    test_loss_1 /= i_batch+1
    test_loss /= i_batch+1

In [None]:
i = 0

In [None]:
idx = ((y[-1] - y_hat[-1])**2).argmax()
im = x.cpu().numpy()[idx].transpose([1, 2, 0])
plt.imshow(im)
print(y[-1, idx])
print(y_hat[-1, idx])


In [None]:
i += 1
im = x.cpu().numpy()[i].transpose([1, 2, 0])
plt.imshow(im)
print(y_hat[-15:, i], y[-1, i])

In [None]:
torch.cat([y_hat, y, (y - y_hat)], axis=1).detach().cpu().numpy()

In [None]:
net(x.cuda())

In [48]:
import torch

def conv3x3(in_channels, out_channels, stride=1):
    return torch.nn.Conv2d(
        in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
    )


# Residual block
class ResidualBlock(torch.nn.Module):
    def __init__(self, num_channels, stride=1):
        super().__init__()
        self.conv1 = conv3x3(num_channels, num_channels, stride)
        self.bn1 = torch.nn.BatchNorm2d(num_channels)
        self.relu = torch.nn.ReLU()
        self.conv2 = conv3x3(num_channels, num_channels)
        self.bn2 = torch.nn.BatchNorm2d(num_channels)

    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += x
        out = self.relu(out)
        return out


class FullyConnectedNetwork(torch.nn.Module):
    def __init__(self, input_size, layer_sizes, output_size, activation=None):
        super().__init__()
        size_list = [input_size] + layer_sizes
        layers = []
        if 1 < len(size_list):
            for i in range(len(size_list) - 1):
                layers.extend(
                    [
                        torch.nn.Linear(size_list[i], size_list[i + 1]),
                        torch.nn.LeakyReLU(),
                    ]
                )
        layers.append(torch.nn.Linear(size_list[-1], output_size))
        if activation:
            layers.append(activation)
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class DynamicsNetwork(torch.nn.Module):
    def __init__(
        self,
        obs_space,
        num_blocks=16,
        num_channels=33,
        reduced_channels=32,
    ):
        super().__init__()
        observation_shape = obs_space.shape
        self.observation_shape = observation_shape
        self.conv = conv3x3(num_channels -1, num_channels - 1)
        self.bn = torch.nn.BatchNorm2d(num_channels - 1)
        self.relu = torch.nn.ReLU()
        self.resblocks = torch.nn.ModuleList(
            [ResidualBlock(num_channels - 1) for _ in range(num_blocks)]
        )

        self.conv1x1 = torch.n
        n.Conv2d(num_channels - 1, reduced_channels, 1)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        for block in self.resblocks:
            out = block(out)
        state = out
        return state

In [38]:
nett = DynamicsNetwork(
    env.observation_space,
    num_blocks=5,
)

In [39]:
nett(x).size()

torch.Size([16, 32, 9, 9])