In [35]:
import torch as nn
import torch.nn.functional as F
from models_extended import *
from torchsummary import summary

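### Decoder summary
The decoder re-applies the same modules at every decoding step, so after the first step the layers are listed as `(recursive)` and their parameters are counted only once in the totals.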
```
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
├─Embedding: 1-1                         [-1, 52, 512]             4,858,880
├─Linear: 1-2                            [-1, 512]                 1,049,088
├─Linear: 1-3                            [-1, 512]                 1,049,088
├─Attention: 1-4                         [-1, 2048]                --
|    └─Linear: 2-1                       [-1, 196, 512]            1,049,088
|    └─Linear: 2-2                       [-1, 512]                 262,656
|    └─ReLU: 2-3                         [-1, 196, 512]            --
|    └─Linear: 2-4                       [-1, 196, 1]              513
|    └─Softmax: 2-5                      [-1, 196]                 --
├─Linear: 1-5                            [-1, 2048]                1,050,624
├─Sigmoid: 1-6                           [-1, 2048]                --
├─LSTMCell: 1-7                          [-1, 512]                 6,295,552
├─Dropout: 1-8                           [-1, 512]                 --
├─Linear: 1-9                            [-1, 9490]                4,868,370
├─Attention: 1-10                        [-1, 2048]                (recursive)
|    └─Linear: 2-6                       [-1, 196, 512]            (recursive)
|    └─Linear: 2-7                       [-1, 512]                 (recursive)
|    └─ReLU: 2-8                         [-1, 196, 512]            --
|    └─Linear: 2-9                       [-1, 196, 1]              (recursive)
|    └─Softmax: 2-10                     [-1, 196]                 --
├─Linear: 1-11                           [-1, 2048]                (recursive)
├─Sigmoid: 1-12                          [-1, 2048]                --
├─LSTMCell: 1-13                         [-1, 512]                 (recursive)
├─Dropout: 1-14                          [-1, 512]                 --
├─Linear: 1-15                           [-1, 9490]                (recursive)
├─Attention: 1-16                        [-1, 2048]                (recursive)
|    └─Linear: 2-11                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-12                      [-1, 512]                 (recursive)
|    └─ReLU: 2-13                        [-1, 196, 512]            --
|    └─Linear: 2-14                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-15                     [-1, 196]                 --
├─Linear: 1-17                           [-1, 2048]                (recursive)
├─Sigmoid: 1-18                          [-1, 2048]                --
├─LSTMCell: 1-19                         [-1, 512]                 (recursive)
├─Dropout: 1-20                          [-1, 512]                 --
├─Linear: 1-21                           [-1, 9490]                (recursive)
├─Attention: 1-22                        [-1, 2048]                (recursive)
|    └─Linear: 2-16                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-17                      [-1, 512]                 (recursive)
|    └─ReLU: 2-18                        [-1, 196, 512]            --
|    └─Linear: 2-19                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-20                     [-1, 196]                 --
├─Linear: 1-23                           [-1, 2048]                (recursive)
├─Sigmoid: 1-24                          [-1, 2048]                --
├─LSTMCell: 1-25                         [-1, 512]                 (recursive)
├─Dropout: 1-26                          [-1, 512]                 --
├─Linear: 1-27                           [-1, 9490]                (recursive)
├─Attention: 1-28                        [-1, 2048]                (recursive)
|    └─Linear: 2-21                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-22                      [-1, 512]                 (recursive)
|    └─ReLU: 2-23                        [-1, 196, 512]            --
|    └─Linear: 2-24                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-25                     [-1, 196]                 --
├─Linear: 1-29                           [-1, 2048]                (recursive)
├─Sigmoid: 1-30                          [-1, 2048]                --
├─LSTMCell: 1-31                         [-1, 512]                 (recursive)
├─Dropout: 1-32                          [-1, 512]                 --
├─Linear: 1-33                           [-1, 9490]                (recursive)
├─Attention: 1-34                        [-1, 2048]                (recursive)
|    └─Linear: 2-26                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-27                      [-1, 512]                 (recursive)
|    └─ReLU: 2-28                        [-1, 196, 512]            --
|    └─Linear: 2-29                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-30                     [-1, 196]                 --
├─Linear: 1-35                           [-1, 2048]                (recursive)
├─Sigmoid: 1-36                          [-1, 2048]                --
├─LSTMCell: 1-37                         [-1, 512]                 (recursive)
├─Dropout: 1-38                          [-1, 512]                 --
├─Linear: 1-39                           [-1, 9490]                (recursive)
├─Attention: 1-40                        [-1, 2048]                (recursive)
|    └─Linear: 2-31                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-32                      [-1, 512]                 (recursive)
|    └─ReLU: 2-33                        [-1, 196, 512]            --
|    └─Linear: 2-34                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-35                     [-1, 196]                 --
├─Linear: 1-41                           [-1, 2048]                (recursive)
├─Sigmoid: 1-42                          [-1, 2048]                --
├─LSTMCell: 1-43                         [-1, 512]                 (recursive)
├─Dropout: 1-44                          [-1, 512]                 --
├─Linear: 1-45                           [-1, 9490]                (recursive)
├─Attention: 1-46                        [-1, 2048]                (recursive)
|    └─Linear: 2-36                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-37                      [-1, 512]                 (recursive)
|    └─ReLU: 2-38                        [-1, 196, 512]            --
|    └─Linear: 2-39                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-40                     [-1, 196]                 --
├─Linear: 1-47                           [-1, 2048]                (recursive)
├─Sigmoid: 1-48                          [-1, 2048]                --
├─LSTMCell: 1-49                         [-1, 512]                 (recursive)
├─Dropout: 1-50                          [-1, 512]                 --
├─Linear: 1-51                           [-1, 9490]                (recursive)
├─Attention: 1-52                        [-1, 2048]                (recursive)
|    └─Linear: 2-41                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-42                      [-1, 512]                 (recursive)
|    └─ReLU: 2-43                        [-1, 196, 512]            --
|    └─Linear: 2-44                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-45                     [-1, 196]                 --
├─Linear: 1-53                           [-1, 2048]                (recursive)
├─Sigmoid: 1-54                          [-1, 2048]                --
├─LSTMCell: 1-55                         [-1, 512]                 (recursive)
├─Dropout: 1-56                          [-1, 512]                 --
├─Linear: 1-57                           [-1, 9490]                (recursive)
├─Attention: 1-58                        [-1, 2048]                (recursive)
|    └─Linear: 2-46                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-47                      [-1, 512]                 (recursive)
|    └─ReLU: 2-48                        [-1, 196, 512]            --
|    └─Linear: 2-49                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-50                     [-1, 196]                 --
├─Linear: 1-59                           [-1, 2048]                (recursive)
├─Sigmoid: 1-60                          [-1, 2048]                --
├─LSTMCell: 1-61                         [-1, 512]                 (recursive)
├─Dropout: 1-62                          [-1, 512]                 --
├─Linear: 1-63                           [-1, 9490]                (recursive)
├─Attention: 1-64                        [-1, 2048]                (recursive)
|    └─Linear: 2-51                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-52                      [-1, 512]                 (recursive)
|    └─ReLU: 2-53                        [-1, 196, 512]            --
|    └─Linear: 2-54                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-55                     [-1, 196]                 --
├─Linear: 1-65                           [-1, 2048]                (recursive)
├─Sigmoid: 1-66                          [-1, 2048]                --
├─LSTMCell: 1-67                         [-1, 512]                 (recursive)
├─Dropout: 1-68                          [-1, 512]                 --
├─Linear: 1-69                           [-1, 9490]                (recursive)
├─Attention: 1-70                        [-1, 2048]                (recursive)
|    └─Linear: 2-56                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-57                      [-1, 512]                 (recursive)
|    └─ReLU: 2-58                        [-1, 196, 512]            --
|    └─Linear: 2-59                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-60                     [-1, 196]                 --
├─Linear: 1-71                           [-1, 2048]                (recursive)
├─Sigmoid: 1-72                          [-1, 2048]                --
├─LSTMCell: 1-73                         [-1, 512]                 (recursive)
├─Dropout: 1-74                          [-1, 512]                 --
├─Linear: 1-75                           [-1, 9490]                (recursive)
├─Attention: 1-76                        [-1, 2048]                (recursive)
|    └─Linear: 2-61                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-62                      [-1, 512]                 (recursive)
|    └─ReLU: 2-63                        [-1, 196, 512]            --
|    └─Linear: 2-64                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-65                     [-1, 196]                 --
├─Linear: 1-77                           [-1, 2048]                (recursive)
├─Sigmoid: 1-78                          [-1, 2048]                --
├─LSTMCell: 1-79                         [-1, 512]                 (recursive)
├─Dropout: 1-80                          [-1, 512]                 --
├─Linear: 1-81                           [-1, 9490]                (recursive)
├─Attention: 1-82                        [-1, 2048]                (recursive)
|    └─Linear: 2-66                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-67                      [-1, 512]                 (recursive)
|    └─ReLU: 2-68                        [-1, 196, 512]            --
|    └─Linear: 2-69                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-70                     [-1, 196]                 --
├─Linear: 1-83                           [-1, 2048]                (recursive)
├─Sigmoid: 1-84                          [-1, 2048]                --
├─LSTMCell: 1-85                         [-1, 512]                 (recursive)
├─Dropout: 1-86                          [-1, 512]                 --
├─Linear: 1-87                           [-1, 9490]                (recursive)
├─Attention: 1-88                        [-1, 2048]                (recursive)
|    └─Linear: 2-71                      [-1, 196, 512]            (recursive)
|    └─Linear: 2-72                      [-1, 512]                 (recursive)
|    └─ReLU: 2-73                        [-1, 196, 512]            --
|    └─Linear: 2-74                      [-1, 196, 1]              (recursive)
|    └─Softmax: 2-75                     [-1, 196]                 --
├─Linear: 1-89                           [-1, 2048]                (recursive)
├─Sigmoid: 1-90                          [-1, 2048]                --
├─LSTMCell: 1-91                         [-1, 512]                 (recursive)
├─Dropout: 1-92                          [-1, 512]                 --
├─Linear: 1-93                           [-1, 9490]                (recursive)
==========================================================================================
Total params: 20,483,859
Trainable params: 20,483,859
Non-trainable params: 0
Total mult-adds (M): 229.28
==========================================================================================
Input size (MB): 1.53
Forward/backward pass size (MB): 1.07
Params size (MB): 78.14
Estimated Total Size (MB): 80.75
==========================================================================================
```

In [36]:
# Summarise the ResNet-101 encoder on a single 3-channel (RGB) 256x256 input.
# Expected encoder output: (batch_size, 14, 14, 2048).
# NOTE(review): the backbone `Sequential` alone yields (batch, 2048, 8, 8) for a
# 256x256 input, so the 14x14 size presumably comes from an adaptive pool +
# permute inside ResNet101Encoder — TODO confirm.
model = ResNet101Encoder()
# The trailing `;` stops Jupyter from echoing summary()'s return value. The
# duplicated table in this cell's previous output shows that summary() already
# prints it.
summary(model, (3, 256, 256));

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512

In [11]:
# Summarise the ResNet-152 encoder on the same 3-channel (RGB) 256x256 input used
# for ResNet-101. The first run downloads the pretrained torchvision weights.
model = ResNet152Encoder()
# The trailing `;` stops Jupyter from echoing summary()'s return value. The
# duplicated table in this cell's previous output shows that summary() already
# prints it.
summary(model, (3, 256, 256));

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to /home/ps4534/.cache/torch/hub/checkpoints/resnet152-394f9c45.pth


  0%|          | 0.00/230M [00:00<?, ?B/s]

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512