In [1]:
# LIBRARIES USED
#!pip install torch-summary
import json
import torch
import torch.nn.utils.prune as prune
import torch.nn.functional as F
#import torch_pruning as tp
import torchvision.models as models
import caption
import ResnetQuant as RQuant
from torchsummary import summary
import torch.quantization as quantization
import torch.nn as nn

In [2]:
# MACROS
# Paths and decoding settings shared by the cells below.
img = "img2.jpg"  # image handed to inference()
word_map = "WORDMAP_coco_5_cap_per_img_5_min_word_freq.json"  # word-map JSON path (the loading cell rebinds this name to the parsed dict)
model = "BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar"  # checkpoint file read with torch.load
beam_size = 5  # beam width forwarded to caption.caption_image_beam_search
device = "cpu"  # device used for map_location and .to(...)

In [3]:
# INFERENCE FUNCTION
def inference(encoder, decoder, img, word_map = word_map, beam_size = beam_size):
    """
    Caption one image with beam search and return the caption as a single string.

    :param encoder: encoder model forwarded to caption.caption_image_beam_search
    :param decoder: decoder model forwarded to caption.caption_image_beam_search
    :param img: image argument forwarded unchanged (the notebook passes a file name)
    :param word_map: word map forwarded unchanged.
        NOTE(review): this default is captured when the cell is executed; at that point the
        global `word_map` is still the JSON file name, not the loaded dict, so callers should
        keep passing the loaded word map explicitly.
    :param beam_size: beam width forwarded unchanged
    :return: the decoded words, each preceded by a single space (so a non-empty caption starts
        with a space); an empty string when the decoded sequence is empty
    """
    # The attention weights returned alongside the sequence are not needed for the sentence,
    # so they are no longer converted to a tensor.
    seq, _ = caption.caption_image_beam_search(encoder, decoder, img, word_map, beam_size)
    # rev_word_map (index -> word) is a notebook global built in the checkpoint-loading cell.
    return "".join(" " + rev_word_map[ind] for ind in seq)

In [4]:
# PRUNE FUNCTION
def prune_my_encoder(encoder, amount=0.05, eps=1e-3):
    """
    Prune the Conv2d and Linear layers of a model in place.

    Two passes are made over the Conv2d / Linear sub-modules of ``encoder``:

    1. Structured L2-norm pruning along dim 0 of each ``weight``
       (``prune.ln_structured`` with ``n=2``), made permanent immediately with
       ``prune.remove`` so each layer is left with a plain ``weight`` parameter.
    2. Every weight whose absolute value is below ``eps`` is set to zero.

    :param encoder: module whose Conv2d / Linear sub-modules are pruned (modified in place)
    :param amount: ``amount`` argument forwarded to ``prune.ln_structured``
    :param eps: magnitude threshold; weights with ``abs(w) < eps`` are zeroed
    :return: None
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    layers = [module for _, module in encoder.named_modules()
              if isinstance(module, prunable_types)]

    for layer in layers:
        prune.ln_structured(layer, name="weight", amount=amount, n=2, dim=0)
        prune.remove(layer, 'weight')

    # Clamp small weights on the model that was passed in (not on a notebook global).
    with torch.no_grad():
        for layer in layers:
            layer.weight[layer.weight.abs() < eps] = 0

In [5]:
# ENCODER CLASS
class Encoder(torch.nn.Module):
    """
    Image encoder: a ResNet-101 trunk followed by adaptive average pooling.
    """

    def __init__(self, encoded_image_size=14):
        """
        :param encoded_image_size: side length of the pooled feature map
        """
        super().__init__()
        self.enc_image_size = encoded_image_size

        # ResNet-101 is instantiated without pretrained weights; the notebook fills the
        # weights in afterwards through load_state_dict.
        backbone = models.resnet101()

        # Drop the last two children (pooling and linear layers) -- no classification is done.
        trunk_layers = list(backbone.children())[:-2]
        self.resnet = torch.nn.Sequential(*trunk_layers)

        # Pool to a fixed spatial size so input images may vary in size.
        self.adaptive_pool = torch.nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        # Set the default requires_grad flags.
        # Note: no QuantStub / DeQuantStub is attached to this module.
        self.fine_tune()

    def forward(self, images):
        """
        Forward propagation.

        :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
        :return: encoded images, channels last
        """
        features = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        pooled = self.adaptive_pool(features)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        # Move the channel axis to the end:
        # (batch_size, encoded_image_size, encoded_image_size, 2048)
        return pooled.permute(0, 2, 3, 1)

    def fine_tune(self, fine_tune=True):
        """
        Freeze the whole trunk, then set requires_grad on the children from index 5 onwards
        (convolutional blocks 2 through 4 of the ResNet).

        :param fine_tune: value assigned to requires_grad for those later blocks
        """
        for param in self.resnet.parameters():
            param.requires_grad = False

        later_blocks = list(self.resnet.children())[5:]
        for block in later_blocks:
            for param in block.parameters():
                param.requires_grad = fine_tune

In [7]:
# SECURITY NOTE: the checkpoint stores whole model objects, so torch.load unpickles arbitrary
# Python objects from the file -- only load checkpoints that come from a trusted source.
checkpoint = torch.load(model, map_location=device)
decoder = checkpoint['decoder']
#torch.save(decoder,'DecoderOrig.pth')
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
#torch.save(encoder,'EncoderOrig.pth')
encoder = encoder.to(device)
encoder.eval()
# Load word map (word2ix).
# `word_map` starts out as the JSON file name and is rebound to the parsed dict here; the
# isinstance guard keeps the cell re-runnable (a second run would otherwise call open() on
# the dict and fail).
if isinstance(word_map, str):
    with open(word_map, 'r') as j:
        word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word



In [8]:
# Fresh Encoder instance, moved to the configured device and switched to eval mode.
myencoder = Encoder().to(device)
myencoder.eval()

Encoder(
  (resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 2

In [9]:
sd = encoder.state_dict()
sd2 = myencoder.state_dict()

In [10]:
# Overwrite the fresh encoder's state-dict entries with the checkpoint encoder's, key by key.
sd2.update(sd)

In [11]:
myencoder.load_state_dict(sd2)

<All keys matched successfully>

In [12]:
inference(encoder, decoder, img, word_map, beam_size)

' <start> a large jetliner flying through a blue sky <end>'

In [13]:
inference(myencoder, decoder, img, word_map, beam_size)

' <start> a large jetliner flying through a blue sky <end>'

In [14]:
prune_my_encoder(myencoder)
#torch.save(myencoder,'EncoderPrunedForEval.pth')
#torch.save(decoder,'DecoderPruned.pth')
#prune_my_encoder(decoder)

In [15]:
inference(myencoder, decoder, img, word_map, beam_size)

' <start> a large jetliner flying through a blue sky <end>'

In [16]:
# Report the share of exactly-zero weights per Conv2d / Linear layer and over all of them.
total_encoder_sparsity = 0
den = 0  # running count of all weights seen
num = 0  # running count of weights equal to zero
for name, module in myencoder.named_modules():
    if not isinstance(module, (torch.nn.Conv2d, torch.nn.Linear)):
        continue
    zero_count = float(torch.sum(module.weight == 0))
    weight_count = float(module.weight.nelement())
    print("Sparsity in {} (shape = {}): {:.2f}%".format(
        name, module.weight.shape, 100. * zero_count / weight_count))
    num += zero_count
    den += weight_count
total_encoder_sparsity = 100.*num/den
print("Total sparsity in the pruned encoder: {:.2f}%".format(total_encoder_sparsity))

Sparsity in resnet.0 (shape = torch.Size([64, 3, 7, 7])): 24.29%
Sparsity in resnet.4.0.conv1 (shape = torch.Size([64, 64, 1, 1])): 40.16%
Sparsity in resnet.4.0.conv2 (shape = torch.Size([64, 64, 3, 3])): 27.43%
Sparsity in resnet.4.0.conv3 (shape = torch.Size([256, 64, 1, 1])): 18.80%
Sparsity in resnet.4.0.downsample.0 (shape = torch.Size([256, 64, 1, 1])): 33.92%
Sparsity in resnet.4.1.conv1 (shape = torch.Size([64, 256, 1, 1])): 52.14%
Sparsity in resnet.4.1.conv2 (shape = torch.Size([64, 64, 3, 3])): 65.67%
Sparsity in resnet.4.1.conv3 (shape = torch.Size([256, 64, 1, 1])): 52.01%
Sparsity in resnet.4.2.conv1 (shape = torch.Size([64, 256, 1, 1])): 17.96%
Sparsity in resnet.4.2.conv2 (shape = torch.Size([64, 64, 3, 3])): 16.64%
Sparsity in resnet.4.2.conv3 (shape = torch.Size([256, 64, 1, 1])): 25.15%
Sparsity in resnet.5.0.conv1 (shape = torch.Size([128, 256, 1, 1])): 6.99%
Sparsity in resnet.5.0.conv2 (shape = torch.Size([128, 128, 3, 3])): 6.88%
Sparsity in resnet.5.0.conv3 (sh

In [18]:
# Dynamic int8 quantization of the decoder (returns a quantized copy; `decoder` stays float).
# NOTE(review): the printed model below still lists decode_step as a plain LSTMCell --
# presumably because only torch.nn.LSTM and torch.nn.Linear are requested here; confirm
# whether the cell was meant to be quantized too.
decoder.to("cpu")
quantized_decoder = quantization.quantize_dynamic(
    decoder,
    qconfig_spec={torch.nn.LSTM, torch.nn.Linear},
    dtype=torch.qint8,
)
print(quantized_decoder)

DecoderWithAttention(
  (attention): Attention(
    (encoder_att): DynamicQuantizedLinear(in_features=2048, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (decoder_att): DynamicQuantizedLinear(in_features=512, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (full_att): DynamicQuantizedLinear(in_features=512, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (relu): ReLU()
    (softmax): Softmax(dim=1)
  )
  (embedding): Embedding(9490, 512)
  (dropout): Dropout(p=0.5, inplace=False)
  (decode_step): LSTMCell(2560, 512, bias=1)
  (init_h): DynamicQuantizedLinear(in_features=2048, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (init_c): DynamicQuantizedLinear(in_features=2048, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (f_beta): DynamicQuantizedLinear(in_features=512, out_features=2048, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (sigmoid): Sigmoid()


In [19]:
# QUANTIZATION FUNCTION
def quantize_my_model(model, calibration_fn=None):
    """
    Post-training static quantization of ``model`` in place with the 'fbgemm' backend.

    Steps: attach the default 'fbgemm' qconfig, select the 'fbgemm' quantized engine, switch
    to eval mode, insert observers (``prepare``), optionally calibrate, then swap modules for
    their quantized counterparts (``convert``).

    :param model: module to quantize; modified in place.
        NOTE(review): eager-mode static quantization normally expects the model to quantize
        its inputs and dequantize its outputs itself (QuantStub / DeQuantStub) -- confirm the
        model passed in does so.
    :param calibration_fn: optional callable that receives the prepared model and should run
        representative inputs through it so the observers can record value ranges. It is
        invoked under ``torch.no_grad()`` between ``prepare`` and ``convert``. When omitted
        (the original behaviour) conversion happens without any calibration pass, so the
        observers have seen no data when quantization parameters are derived.
    :return: None
    """
    # set the qconfig for PTQ
    model.qconfig = quantization.get_default_qconfig('fbgemm')
    # set the qengine to control weight packing
    torch.backends.quantized.engine = 'fbgemm'
    # put model in eval mode
    model.eval()
    quantization.prepare(model, inplace=True)
    # let the observers see data before quantization parameters are computed
    if calibration_fn is not None:
        with torch.no_grad():
            calibration_fn(model)
    quantization.convert(model, inplace=True)

In [None]:
quantize_my_model(myencoder)
#quantize_my_model(decoder)
print(myencoder)

In [20]:
inference(myencoder, quantized_decoder, img, word_map, beam_size)
# model.to('QuantizedCPU')
# summary(myencoder, (3,256,256))

' <start> a large jetliner flying through a blue sky <end>'

In [None]:
# # Specify quantization configuration
# # Start with simple min/max range estimation and per-tensor
# # quantization of weights
# myencoder.qconfig = quantization.default_qconfig
# # Convert to quantized model
# quantization.convert(myencoder, inplace=True)
# # Calibrate with the training set
# # evaluate(myModel, criterion, data_loader, neval_batches=num_calibration_batches)
# for name, module in myencoder.named_modules(): 
#     for n, p in module.named_parameters():
#             print(n, 'has type:', p.dtype)
# # prunedencoder = Encoder()
# # myencoder_state_dict = myencoder.state_dict()
# # prunedencoder_state_dict = prunedencoder.state_dict()
# # for key in myencoder_state_dict.keys():
# #     print(key)
#     #     if key in prunedencoder_state_dict.keys():
# #         prunedencoder_state_dict[key] = myencoder_state_dict[key]
# # prunedencoder.load_state_dict(myencoder_state_dict)

In [None]:

# summary(myencoder, (3,256,256))
for name, param in myencoder.named_parameters():
    print(param.dtype)

In [None]:
#Checkpoint2 = {'encoder': myencoder}
#torch.save(quantized_decoder,'QuantizedDecoderForEval.pth') - didn't work
torch.save(quantized_decoder.state_dict(), "QuantizedDecoderStateDict.pth")
# The state dict of the dynamically quantized decoder does not have the same entries as the
# float `decoder`, so it is loaded back into a module quantized the same way. quantize_dynamic
# returns a new module by default, which leaves the float `decoder` untouched.
reloaded_quantized_decoder = quantization.quantize_dynamic(
    decoder, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
)
reloaded_quantized_decoder.load_state_dict(torch.load("QuantizedDecoderStateDict.pth"))
#torch.save(myencoder.state_dict(), "QuantEncwStateDict2.pth")

In [None]:
#checkpoint = torch.load('QuantizedModel.pth', map_location=device)
#theencoder = checkpoint['encoder']
#torch.save(theencoder,'EncoderQuantized.pth')
#theencoder = Encoder()
#theencoder = theencoder.to(device)
#theencoder.eval()
#theencoder.load_state_dict(torch.load("QuantEncwStateDict2.pth"))


In [None]:
# NOTE(review): `theencoder` is only assigned in the commented-out cell just above, so on a
# fresh kernel this call raises NameError -- restore that cell or pass an existing encoder.
inference(theencoder, decoder, img, word_map, beam_size)

In [21]:
# NOTE(review): the only torch.save call for "EncoderPrunedForEval.pth" in this notebook is
# commented out (pruning cell), so the file must already exist on disk for this cell to run.
# torch.load unpickles a whole model object here -- only load files from a trusted source.
# map_location keeps the load working on a machine without the device the file was saved from.
encodertry = torch.load("EncoderPrunedForEval.pth", map_location=device)
#decodertry = torch.load("QuantizedDecoderForEval.pth")
encodertry.to(device)
#decodertry.to(device)
encodertry.eval()
#decodertry.eval()

Encoder(
  (resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 2

In [22]:
inference(encodertry, decoder, img, word_map, beam_size)

' <start> a large jetliner flying through a blue sky <end>'

In [23]:
inference(encodertry, quantized_decoder, img, word_map, beam_size)

' <start> a large jetliner flying through a blue sky <end>'

In [24]:
summary(encodertry,(3,256,256),device=device)

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512

In [None]:
summary(quantized_decoder,input_size = None)

In [25]:
print(quantized_decoder)

DecoderWithAttention(
  (attention): Attention(
    (encoder_att): DynamicQuantizedLinear(in_features=2048, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (decoder_att): DynamicQuantizedLinear(in_features=512, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (full_att): DynamicQuantizedLinear(in_features=512, out_features=1, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (relu): ReLU()
    (softmax): Softmax(dim=1)
  )
  (embedding): Embedding(9490, 512)
  (dropout): Dropout(p=0.5, inplace=False)
  (decode_step): LSTMCell(2560, 512, bias=1)
  (init_h): DynamicQuantizedLinear(in_features=2048, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (init_c): DynamicQuantizedLinear(in_features=2048, out_features=512, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (f_beta): DynamicQuantizedLinear(in_features=512, out_features=2048, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (sigmoid): Sigmoid()


In [None]:
summary(quantized_decoder.to(device),(3,400,600))

In [None]:
myencoder.to(device)
inference(myencoder, decoder, img, word_map, beam_size)

In [26]:
summary(myencoder,(3,256,256))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 2048, 8, 8]          --
|    └─Conv2d: 2-1                       [-1, 64, 128, 128]        (9,408)
|    └─BatchNorm2d: 2-2                  [-1, 64, 128, 128]        (128)
|    └─ReLU: 2-3                         [-1, 64, 128, 128]        --
|    └─MaxPool2d: 2-4                    [-1, 64, 64, 64]          --
|    └─Sequential: 2-5                   [-1, 256, 64, 64]         --
|    |    └─Bottleneck: 3-1              [-1, 256, 64, 64]         (75,008)
|    |    └─Bottleneck: 3-2              [-1, 256, 64, 64]         (70,400)
|    |    └─Bottleneck: 3-3              [-1, 256, 64, 64]         (70,400)
|    └─Sequential: 2-6                   [-1, 512, 32, 32]         --
|    |    └─Bottleneck: 3-4              [-1, 512, 32, 32]         379,392
|    |    └─Bottleneck: 3-5              [-1, 512, 32, 32]         280,064
|    |    └─Bottleneck: 3-6              [-1, 512

In [None]:
summary(encoder,(3,256,256))

In [27]:
caption_lengths = torch.tensor([[11],
        [12],
        [10],
        [14],
        [14],
        [10],
        [12],
        [12],
        [10],
        [11],
        [11],
        [12],
        [12],
        [11],
        [13],
        [10]], device='cpu')

In [28]:
encoded_captions = torch.tensor([[9488,    1,  375,  976,  157,   61,   23,  156, 1717, 1458, 9489,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488, 1819,    1, 1725,   28,  349,  823,  217,   32,    1,  764, 9489,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1,    2,   11,    1,  287,    6,    1, 1490, 9489,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1,    4,  430,   28,   46, 1752,   28,  156,  419, 3329,   28,
          580, 9489,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1,    2,   71,    6,   14,   16,  743,   17,    1,  184,   91,
          336, 9489,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1,   45,  336,    3,    1,  696,  690,  577, 9489,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488, 2078, 2257,    6,    1,  102,   28,   90,   35, 2342,   86, 9489,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1,   92,   17,   93,  609,    6,  205,   17,   14, 2826, 9489,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,  747,  887,  119,  564,    3, 2653,   28, 4110, 9489,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,  167, 1220,   32,  288,    3,  349,  603,   28,  419, 9489,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1, 1881,   98,   79,   14, 1533,   17,    1, 2039, 9489,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,   14,   21,   35,   43,  245,    1, 2343,    3,    1,   53, 9489,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,   49,  179,  164, 1287,   17,    1,    2, 2077,    1, 2015, 9489,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1,  230,  996,  134,  164, 1640,  371,    1,  940, 9489,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1, 1428, 1051,   28,   41,  369,  179,   66, 4450,   23, 4450,
         9489,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0],
        [9488,    1,    2,   71,    6,    1,  720,   11, 1348, 9489,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0]], device='cpu')

In [None]:
print(caption_lengths)

In [None]:
encoder_out = torch.tensor([[[1.5874, 0.0000, 3.9018,  ..., 1.1873, 0.0000, 0.0000],
         [0.7937, 0.0000, 1.9509,  ..., 0.5936, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.4390, 0.0000, 1.3281,  ..., 0.2475, 0.0000, 0.0000],
         [1.8537, 0.9871, 1.0184,  ..., 0.2423, 0.1009, 0.0000],
         [2.2683, 1.9742, 0.7086,  ..., 0.2370, 0.2018, 0.0000]],

        [[1.5874, 0.0000, 3.9018,  ..., 1.1873, 0.0000, 0.0000],
         [0.7937, 0.0000, 1.9509,  ..., 0.5936, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.4390, 0.0000, 1.3281,  ..., 0.2475, 0.0000, 0.0000],
         [1.8537, 0.9871, 1.0184,  ..., 0.2423, 0.1009, 0.0000],
         [2.2683, 1.9742, 0.7086,  ..., 0.2370, 0.2018, 0.0000]],

        [[1.5874, 0.0000, 3.9018,  ..., 1.1873, 0.0000, 0.0000],
         [0.7937, 0.0000, 1.9509,  ..., 0.5936, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.4390, 0.0000, 1.3281,  ..., 0.2475, 0.0000, 0.0000],
         [1.8537, 0.9871, 1.0184,  ..., 0.2423, 0.1009, 0.0000],
         [2.2683, 1.9742, 0.7086,  ..., 0.2370, 0.2018, 0.0000]],

        ...,

        [[1.5764, 0.0000, 3.8673,  ..., 1.1372, 0.0000, 0.0000],
         [0.7882, 0.0000, 1.9337,  ..., 0.5686, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.2420, 0.0000, 0.0000,  ..., 0.0000, 0.4471, 0.0000],
         [0.4089, 0.0457, 0.0000,  ..., 0.0000, 0.5250, 0.0000],
         [0.5757, 0.0915, 0.0000,  ..., 0.0000, 0.6030, 0.0000]],

        [[1.5874, 0.0000, 3.9018,  ..., 1.1873, 0.0000, 0.0000],
         [0.7937, 0.0000, 1.9509,  ..., 0.5936, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.4390, 0.0000, 1.3281,  ..., 0.2475, 0.0000, 0.0000],
         [1.8537, 0.9871, 1.0184,  ..., 0.2423, 0.1009, 0.0000],
         [2.2683, 1.9742, 0.7086,  ..., 0.2370, 0.2018, 0.0000]],

        [[1.5874, 0.0000, 3.9018,  ..., 1.1873, 0.0000, 0.0000],
         [0.7937, 0.0000, 1.9509,  ..., 0.5936, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [1.4390, 0.0000, 1.3281,  ..., 0.2475, 0.0000, 0.0000],
         [1.8537, 0.9871, 1.0184,  ..., 0.2423, 0.1009, 0.0000],
         [2.2683, 1.9742, 0.7086,  ..., 0.2370, 0.2018, 0.0000]]],
       device='cuda:0')


In [29]:
encoder_out = 2*torch.rand(16,196,2048,device = "cpu")

In [30]:
summary(model=decoder,input_data=encoder_out,encoded_captions=encoded_captions,caption_lengths=caption_lengths)

Layer (type:depth-idx)                   Output Shape              Param #
├─Embedding: 1-1                         [-1, 52, 512]             4,858,880
├─Linear: 1-2                            [-1, 512]                 1,049,088
├─Linear: 1-3                            [-1, 512]                 1,049,088
├─Attention: 1-4                         [-1, 2048]                --
|    └─Linear: 2-1                       [-1, 196, 512]            1,049,088
|    └─Linear: 2-2                       [-1, 512]                 262,656
|    └─ReLU: 2-3                         [-1, 196, 512]            --
|    └─Linear: 2-4                       [-1, 196, 1]              513
|    └─Softmax: 2-5                      [-1, 196]                 --
├─Linear: 1-5                            [-1, 2048]                1,050,624
├─Sigmoid: 1-6                           [-1, 2048]                --
├─LSTMCell: 1-7                          [-1, 512]                 6,295,552
├─Dropout: 1-8                       

Layer (type:depth-idx)                   Output Shape              Param #
├─Embedding: 1-1                         [-1, 52, 512]             4,858,880
├─Linear: 1-2                            [-1, 512]                 1,049,088
├─Linear: 1-3                            [-1, 512]                 1,049,088
├─Attention: 1-4                         [-1, 2048]                --
|    └─Linear: 2-1                       [-1, 196, 512]            1,049,088
|    └─Linear: 2-2                       [-1, 512]                 262,656
|    └─ReLU: 2-3                         [-1, 196, 512]            --
|    └─Linear: 2-4                       [-1, 196, 1]              513
|    └─Softmax: 2-5                      [-1, 196]                 --
├─Linear: 1-5                            [-1, 2048]                1,050,624
├─Sigmoid: 1-6                           [-1, 2048]                --
├─LSTMCell: 1-7                          [-1, 512]                 6,295,552
├─Dropout: 1-8                       

In [32]:
# NOTE(review): the recorded run of this cell ends in "RuntimeError: Failed to run
# torchsummary" after the Embedding layer. Presumably torchsummary cannot trace the
# dynamically quantized layers -- TODO confirm, then drop or guard this call so that
# Restart & Run All can complete.
#encoder_out = torch.quantization.QuantStub(encoder_out,qconfig)
#encoded_captions = torch.quantization.QuantStub(encoded_captions)
#caption_lengths = torch.quantization.QuantStub(caption_lengths)
#quantized_decoder.to("cpu")
summary(model=quantized_decoder,input_data=encoder_out,encoded_captions=encoded_captions,caption_lengths=caption_lengths)

RuntimeError: Failed to run torchsummary. See above stack traces for more details. Executed layers up to: [Embedding: 1-1]