In [None]:
import torch
import torch.nn as nn

# Let's create an autoencoder model...

# Encoder: five 2-D convolutions (ReLU between them, none after the last) that
# compress a 1-channel image into an 8-channel feature map.
# With no padding each conv gives out = (in - kernel_size) // stride + 1, so a
# 64x64 input shrinks 64 -> 60 -> 56 -> 27 -> 13 -> 10.
encoder = nn.Sequential(
  nn.Conv2d(in_channels=1, out_channels=32, kernel_size=5, stride=1),
  nn.ReLU(inplace=True),
  nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1),
  nn.ReLU(inplace=True),
  nn.Conv2d(in_channels=32, out_channels=32, kernel_size=4, stride=2),
  nn.ReLU(inplace=True),
  nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2),
  nn.ReLU(inplace=True),
  nn.Conv2d(in_channels=32, out_channels=8, kernel_size=4, stride=1),
)
print(encoder)

# Smoke test: a random batch of 16 single-channel 64x64 images.
x = torch.randn((16, 1, 64, 64))
print(f'x.shape: {x.shape}')
y = encoder(x)
print(f'y.shape: {y.shape}')

# Decoder: transposed convolutions that mirror the encoder's kernel/stride stack.
class decoder(nn.Module):
  """Map an 8-channel latent feature map back to a 1-channel image.

  With no padding each transposed conv gives
  out = (in - 1) * stride + kernel_size, so a (N, 8, 10, 10) input grows
  10 -> 13 -> 27 -> 56 -> 60 -> 64 and comes out as (N, 1, 64, 64).
  """

  def __init__(self):
    super().__init__()
    # Names and registration order (convT1..convT5, then relu) determine the
    # printed repr and the state_dict keys, so they are kept as-is.
    self.convT1 = nn.ConvTranspose2d(in_channels=8, out_channels=32, kernel_size=4, stride=1)
    self.convT2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=3, stride=2)
    self.convT3 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2)
    self.convT4 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=5, stride=1)
    self.convT5 = nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=5, stride=1)

    # One parameter-free ReLU module, reused after every hidden layer.
    self.relu = nn.ReLU()

  def forward(self, x):
    """Decode ``x``; every layer but the last is followed by ReLU (linear output)."""
    hidden = x
    for layer in (self.convT1, self.convT2, self.convT3, self.convT4):
      hidden = self.relu(layer(hidden))
    return self.convT5(hidden)

# Instantiate the decoder and push the encoder's latent map `y` through it.
dec = decoder()
print(dec)

z = dec(y)
print('z.shape:', z.shape)
# The decoder must exactly undo the encoder's spatial downsampling; fail loudly
# (rather than relying on eyeballing the print) if the architectures drift apart.
assert z.shape == x.shape, f'reconstruction shape {tuple(z.shape)} != input shape {tuple(x.shape)}'


Sequential(
  (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (3): ReLU(inplace=True)
  (4): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2))
  (5): ReLU(inplace=True)
  (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
  (7): ReLU(inplace=True)
  (8): Conv2d(32, 8, kernel_size=(4, 4), stride=(1, 1))
)
x.shape: torch.Size([16, 1, 64, 64])
y.shape: torch.Size([16, 8, 10, 10])
decoder(
  (convT1): ConvTranspose2d(8, 32, kernel_size=(4, 4), stride=(1, 1))
  (convT2): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
  (convT3): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2))
  (convT4): ConvTranspose2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
  (convT5): ConvTranspose2d(32, 1, kernel_size=(5, 5), stride=(1, 1))
  (relu): ReLU()
)
z.shape: torch.Size([16, 1, 64, 64])


In [None]:
# Sub-blocks can be combined with nn.Sequential(): it chains whole sub-networks
# just like single layers. The `encoder` and `dec` *instances* are reused, so
# this model shares their weights rather than copying them.
# NOTE(review): a later cell rebinds the name `AutoEncoder` to an nn.Module
# subclass; if cells are re-run out of order, `AutoEncoder` may refer to that
# class instead of this instance.
AutoEncoder = nn.Sequential(
    encoder,
    dec
)

print(AutoEncoder)

z_pred = AutoEncoder(x)
print(z_pred.shape)
# An autoencoder's reconstruction must have the same shape as its input.
assert z_pred.shape == x.shape, f'reconstruction shape {tuple(z_pred.shape)} != input shape {tuple(x.shape)}'

Sequential(
  (0): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2))
    (5): ReLU(inplace=True)
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
    (7): ReLU(inplace=True)
    (8): Conv2d(32, 8, kernel_size=(4, 4), stride=(1, 1))
  )
  (1): decoder(
    (convT1): ConvTranspose2d(8, 32, kernel_size=(4, 4), stride=(1, 1))
    (convT2): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
    (convT3): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2))
    (convT4): ConvTranspose2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
    (convT5): ConvTranspose2d(32, 1, kernel_size=(5, 5), stride=(1, 1))
    (relu): ReLU()
  )
)
torch.Size([16, 1, 64, 64])


In [None]:
!pip install torchinfo
from torchinfo import summary
summary(AutoEncoder, (1, 1, 64, 64))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


Layer (type:depth-idx)                   Output Shape              Param #
Sequential                               [1, 1, 64, 64]            --
├─Sequential: 1-1                        [1, 8, 10, 10]            --
│    └─Conv2d: 2-1                       [1, 32, 60, 60]           832
│    └─ReLU: 2-2                         [1, 32, 60, 60]           --
│    └─Conv2d: 2-3                       [1, 32, 56, 56]           25,632
│    └─ReLU: 2-4                         [1, 32, 56, 56]           --
│    └─Conv2d: 2-5                       [1, 32, 27, 27]           16,416
│    └─ReLU: 2-6                         [1, 32, 27, 27]           --
│    └─Conv2d: 2-7                       [1, 32, 13, 13]           9,248
│    └─ReLU: 2-8                         [1, 32, 13, 13]           --
│    └─Conv2d: 2-9                       [1, 8, 10, 10]            4,104
├─decoder: 1-2                           [1, 1, 64, 64]            --
│    └─ConvTranspose2d: 2-10             [1, 32, 13, 13]           4,1

In [None]:
# Sub-blocks can also be combined under a custom nn.Module.
# NOTE: this class statement rebinds the name `AutoEncoder`, which an earlier
# cell used for an nn.Sequential instance.
class AutoEncoder(nn.Module):
  """Run ``encoder`` then ``decoder`` as named sub-modules of one model.

  Unlike an ``nn.Sequential`` composition, the parts are reachable by name
  (``model.encoder`` / ``model.decoder``) and appear under those names in the
  printed repr and the state_dict.

  Args:
    encoder: module applied first to the input.
    decoder: module applied to the encoder's output.
  """

  def __init__(self, encoder, decoder):
    super().__init__()
    # Assigning nn.Module instances as attributes registers them as children.
    self.encoder = encoder
    self.decoder = decoder

  def forward(self, x):
    """Return ``decoder(encoder(x))``."""
    return self.decoder(self.encoder(x))

# Wrap the already-built encoder/decoder instances; their weights are shared
# with `encoder` and `dec` (and with the nn.Sequential model), not copied.
AutoEnc = AutoEncoder(encoder=encoder, decoder=dec)
print(AutoEnc)


AutoEncoder(
  (encoder): Sequential(
    (0): Conv2d(1, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
    (3): ReLU(inplace=True)
    (4): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2))
    (5): ReLU(inplace=True)
    (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
    (7): ReLU(inplace=True)
    (8): Conv2d(32, 8, kernel_size=(4, 4), stride=(1, 1))
  )
  (decoder): decoder(
    (convT1): ConvTranspose2d(8, 32, kernel_size=(4, 4), stride=(1, 1))
    (convT2): ConvTranspose2d(32, 32, kernel_size=(3, 3), stride=(2, 2))
    (convT3): ConvTranspose2d(32, 32, kernel_size=(4, 4), stride=(2, 2))
    (convT4): ConvTranspose2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
    (convT5): ConvTranspose2d(32, 1, kernel_size=(5, 5), stride=(1, 1))
    (relu): ReLU()
  )
)


In [None]:
# NOTE: shadows any earlier `summary` import (e.g. torchinfo's).
# NOTE(review): torchsummary is not installed by this notebook; this assumes the
# runtime already provides it.
from torchsummary import summary

# torchsummary's input_size excludes the batch dimension: (C, H, W).
# Its `device` argument defaults to "cuda"; on a GPU runtime it would then build
# CUDA inputs for this model, which was never moved off the CPU, and fail with a
# device mismatch. Pin the device explicitly.
summary(AutoEnc, (1, 64, 64), device="cpu")

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 60, 60]             832
              ReLU-2           [-1, 32, 60, 60]               0
            Conv2d-3           [-1, 32, 56, 56]          25,632
              ReLU-4           [-1, 32, 56, 56]               0
            Conv2d-5           [-1, 32, 27, 27]          16,416
              ReLU-6           [-1, 32, 27, 27]               0
            Conv2d-7           [-1, 32, 13, 13]           9,248
              ReLU-8           [-1, 32, 13, 13]               0
            Conv2d-9            [-1, 8, 10, 10]           4,104
  ConvTranspose2d-10           [-1, 32, 13, 13]           4,128
             ReLU-11           [-1, 32, 13, 13]               0
  ConvTranspose2d-12           [-1, 32, 27, 27]           9,248
             ReLU-13           [-1, 32, 27, 27]               0
  ConvTranspose2d-14           [-1, 32,

In [None]:
!pip install torchinfo
from torchinfo import summary
summary(AutoEnc, (1, 1, 64, 64))

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchinfo
  Downloading torchinfo-1.7.1-py3-none-any.whl (22 kB)
Installing collected packages: torchinfo
Successfully installed torchinfo-1.7.1


Layer (type:depth-idx)                   Output Shape              Param #
AutoEncoder                              [1, 1, 64, 64]            --
├─Sequential: 1-1                        [1, 8, 10, 10]            --
│    └─Conv2d: 2-1                       [1, 32, 60, 60]           832
│    └─ReLU: 2-2                         [1, 32, 60, 60]           --
│    └─Conv2d: 2-3                       [1, 32, 56, 56]           25,632
│    └─ReLU: 2-4                         [1, 32, 56, 56]           --
│    └─Conv2d: 2-5                       [1, 32, 27, 27]           16,416
│    └─ReLU: 2-6                         [1, 32, 27, 27]           --
│    └─Conv2d: 2-7                       [1, 32, 13, 13]           9,248
│    └─ReLU: 2-8                         [1, 32, 13, 13]           --
│    └─Conv2d: 2-9                       [1, 8, 10, 10]            4,104
├─decoder: 1-2                           [1, 1, 64, 64]            --
│    └─ConvTranspose2d: 2-10             [1, 32, 13, 13]           4,1