In [1]:
# LIBRARIES USED
import json
import torch
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch_pruning as tp
import torchvision.models as models
import caption
from torchsummary import summary
import torch.quantization as quantization

In [2]:
# MACROS
img = r"C:\Users\xiaomi\OneDrive\TUM\WS 2021-2022\Advanced Topics in Communication Electronics\image-captioning-on-pytorch-master\image-captioning-on-pytorch-master\img4.jpg"
word_map = r"C:\Users\xiaomi\OneDrive\TUM\WS 2021-2022\Advanced Topics in Communication Electronics\image-captioning-on-pytorch-master\image-captioning-on-pytorch-master\WORDMAP_coco_5_cap_per_img_5_min_word_freq.json"
model = r"C:\Users\xiaomi\OneDrive\TUM\WS 2021-2022\Advanced Topics in Communication Electronics\image-captioning-on-pytorch-master\image-captioning-on-pytorch-master\BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar"
beam_size = 5
device = "cpu"

In [3]:
# INFERENCE FUNCTION
def inference(encoder, decoder, img, word_map = word_map, beam_size = beam_size):
    seq, alphas = caption.caption_image_beam_search(encoder, decoder, img, word_map, beam_size)
    alphas = torch.FloatTensor(alphas)
    words = [rev_word_map[ind] for ind in seq]
    sentence = ""
    for word in words:
        sentence = sentence + " " + word
    return sentence

In [4]:
# PRUNE FUNCTION
def prune_my_encoder(encoder):
    for name, module in encoder.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(module, name="weight", amount=0.05, n=2, dim=0)
            prune.remove(module, 'weight')
        elif isinstance(module, torch.nn.Linear):
            prune.ln_structured(module, name="weight", amount=0.05, n=2, dim=0)
            prune.remove(module, 'weight')
    eps = 1e-3
    with torch.no_grad():
        for name, module in myencoder.named_modules():
            if isinstance(module, torch.nn.Conv2d):
                module.weight[abs(module.weight)<eps] = 0
            elif isinstance(module, torch.nn.Linear):
                module.weight[abs(module.weight)<eps] = 0            

In [5]:
# ENCODER CLASS
class Encoder(torch.nn.Module):
    """
    Encoder.
    """

    def __init__(self, encoded_image_size=14):
        super(Encoder, self).__init__()
        self.enc_image_size = encoded_image_size
        self.quant = quantization.QuantStub()
        self.dequant = quantization.DeQuantStub()
        #resnet = torchvision.models.resnet101(pretrained=True)  # pretrained ImageNet ResNet-101
        resnet = models.resnet101()
        #resnet.load_state_dict(torch.load("resnet101-2.pth"))

        # Remove linear and pool layers (since we're not doing classification)
        modules = list(resnet.children())[:-2]
        self.resnet = torch.nn.Sequential(*modules)

        # Resize image to fixed size to allow input images of variable size
        self.adaptive_pool = torch.nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune()
        self.qconfig = quantization.get_default_qconfig('fbgemm')
        # set the qengine to control weight packing
        torch.backends.quantized.engine = 'fbgemm'

    def forward(self, images):
        """
        Forward propagation.

        :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
        :return: encoded images
        """
#         images = quantization.QuantStub(images)
        images = self.quant(images)
        out = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        out = self.adaptive_pool(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
#         out = quantization.DeQuantStub(out)
        out = self.dequant(out)

        return out

    def fine_tune(self, fine_tune=True):
        """
        Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.

        :param fine_tune: Allow?
        """
        for p in self.resnet.parameters():
            p.requires_grad = False
        # If fine-tuning, only fine-tune convolutional blocks 2 through 4
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

In [6]:
checkpoint = torch.load(model, map_location=device)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()
# Load word map (word2ix)
with open(word_map, 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word



In [7]:
myencoder = Encoder()
# myencoder.eval()
myencoder = myencoder.to(device)
myencoder.eval()

Encoder(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (resnet): Sequential(
    (0): QuantStub()
    (1): DeQuantStub()
    (2): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (6): Sequential(
      (0): Bottleneck(
        (quant): QuantStub()
        (dequant): DeQuantStub()
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=

In [8]:
sd = encoder.state_dict()
sd2 = myencoder.state_dict()

In [9]:
for key in sd.keys():
    sd2[key] = sd[key].type(torch.float16)
    print(sd2[key].dtype)

torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.float16
torch.

In [10]:
for name, param in myencoder.named_parameters():
    print(name,param.dtype)

resnet.2.weight torch.float32
resnet.3.weight torch.float32
resnet.3.bias torch.float32
resnet.6.0.conv1.weight torch.float32
resnet.6.0.bn1.weight torch.float32
resnet.6.0.bn1.bias torch.float32
resnet.6.0.conv2.weight torch.float32
resnet.6.0.bn2.weight torch.float32
resnet.6.0.bn2.bias torch.float32
resnet.6.0.conv3.weight torch.float32
resnet.6.0.bn3.weight torch.float32
resnet.6.0.bn3.bias torch.float32
resnet.6.0.downsample.0.weight torch.float32
resnet.6.0.downsample.1.weight torch.float32
resnet.6.0.downsample.1.bias torch.float32
resnet.6.1.conv1.weight torch.float32
resnet.6.1.bn1.weight torch.float32
resnet.6.1.bn1.bias torch.float32
resnet.6.1.conv2.weight torch.float32
resnet.6.1.bn2.weight torch.float32
resnet.6.1.bn2.bias torch.float32
resnet.6.1.conv3.weight torch.float32
resnet.6.1.bn3.weight torch.float32
resnet.6.1.bn3.bias torch.float32
resnet.6.2.conv1.weight torch.float32
resnet.6.2.bn1.weight torch.float32
resnet.6.2.bn1.bias torch.float32
resnet.6.2.conv2.weight

In [11]:
myencoder.load_state_dict(sd2)

RuntimeError: Error(s) in loading state_dict for Encoder:
	Unexpected key(s) in state_dict: "resnet.0.weight", "resnet.1.weight", "resnet.1.bias", "resnet.1.running_mean", "resnet.1.running_var", "resnet.4.0.conv1.weight", "resnet.4.0.bn1.weight", "resnet.4.0.bn1.bias", "resnet.4.0.bn1.running_mean", "resnet.4.0.bn1.running_var", "resnet.4.0.conv2.weight", "resnet.4.0.bn2.weight", "resnet.4.0.bn2.bias", "resnet.4.0.bn2.running_mean", "resnet.4.0.bn2.running_var", "resnet.4.0.conv3.weight", "resnet.4.0.bn3.weight", "resnet.4.0.bn3.bias", "resnet.4.0.bn3.running_mean", "resnet.4.0.bn3.running_var", "resnet.4.0.downsample.0.weight", "resnet.4.0.downsample.1.weight", "resnet.4.0.downsample.1.bias", "resnet.4.0.downsample.1.running_mean", "resnet.4.0.downsample.1.running_var", "resnet.4.1.conv1.weight", "resnet.4.1.bn1.weight", "resnet.4.1.bn1.bias", "resnet.4.1.bn1.running_mean", "resnet.4.1.bn1.running_var", "resnet.4.1.conv2.weight", "resnet.4.1.bn2.weight", "resnet.4.1.bn2.bias", "resnet.4.1.bn2.running_mean", "resnet.4.1.bn2.running_var", "resnet.4.1.conv3.weight", "resnet.4.1.bn3.weight", "resnet.4.1.bn3.bias", "resnet.4.1.bn3.running_mean", "resnet.4.1.bn3.running_var", "resnet.4.2.conv1.weight", "resnet.4.2.bn1.weight", "resnet.4.2.bn1.bias", "resnet.4.2.bn1.running_mean", "resnet.4.2.bn1.running_var", "resnet.4.2.conv2.weight", "resnet.4.2.bn2.weight", "resnet.4.2.bn2.bias", "resnet.4.2.bn2.running_mean", "resnet.4.2.bn2.running_var", "resnet.4.2.conv3.weight", "resnet.4.2.bn3.weight", "resnet.4.2.bn3.bias", "resnet.4.2.bn3.running_mean", "resnet.4.2.bn3.running_var", "resnet.5.0.conv1.weight", "resnet.5.0.bn1.weight", "resnet.5.0.bn1.bias", "resnet.5.0.bn1.running_mean", "resnet.5.0.bn1.running_var", "resnet.5.0.conv2.weight", "resnet.5.0.bn2.weight", "resnet.5.0.bn2.bias", "resnet.5.0.bn2.running_mean", "resnet.5.0.bn2.running_var", "resnet.5.0.conv3.weight", "resnet.5.0.bn3.weight", "resnet.5.0.bn3.bias", "resnet.5.0.bn3.running_mean", "resnet.5.0.bn3.running_var", "resnet.5.0.downsample.0.weight", "resnet.5.0.downsample.1.weight", "resnet.5.0.downsample.1.bias", "resnet.5.0.downsample.1.running_mean", "resnet.5.0.downsample.1.running_var", "resnet.5.1.conv1.weight", "resnet.5.1.bn1.weight", "resnet.5.1.bn1.bias", "resnet.5.1.bn1.running_mean", "resnet.5.1.bn1.running_var", "resnet.5.1.conv2.weight", "resnet.5.1.bn2.weight", "resnet.5.1.bn2.bias", "resnet.5.1.bn2.running_mean", "resnet.5.1.bn2.running_var", "resnet.5.1.conv3.weight", "resnet.5.1.bn3.weight", "resnet.5.1.bn3.bias", "resnet.5.1.bn3.running_mean", "resnet.5.1.bn3.running_var", "resnet.5.2.conv1.weight", "resnet.5.2.bn1.weight", "resnet.5.2.bn1.bias", "resnet.5.2.bn1.running_mean", "resnet.5.2.bn1.running_var", "resnet.5.2.conv2.weight", "resnet.5.2.bn2.weight", "resnet.5.2.bn2.bias", "resnet.5.2.bn2.running_mean", "resnet.5.2.bn2.running_var", "resnet.5.2.conv3.weight", "resnet.5.2.bn3.weight", "resnet.5.2.bn3.bias", "resnet.5.2.bn3.running_mean", "resnet.5.2.bn3.running_var", "resnet.5.3.conv1.weight", "resnet.5.3.bn1.weight", "resnet.5.3.bn1.bias", "resnet.5.3.bn1.running_mean", "resnet.5.3.bn1.running_var", "resnet.5.3.conv2.weight", "resnet.5.3.bn2.weight", "resnet.5.3.bn2.bias", "resnet.5.3.bn2.running_mean", "resnet.5.3.bn2.running_var", "resnet.5.3.conv3.weight", "resnet.5.3.bn3.weight", "resnet.5.3.bn3.bias", "resnet.5.3.bn3.running_mean", "resnet.5.3.bn3.running_var", "resnet.6.3.conv1.weight", "resnet.6.3.bn1.weight", "resnet.6.3.bn1.bias", "resnet.6.3.bn1.running_mean", "resnet.6.3.bn1.running_var", "resnet.6.3.conv2.weight", "resnet.6.3.bn2.weight", "resnet.6.3.bn2.bias", "resnet.6.3.bn2.running_mean", "resnet.6.3.bn2.running_var", "resnet.6.3.conv3.weight", "resnet.6.3.bn3.weight", "resnet.6.3.bn3.bias", "resnet.6.3.bn3.running_mean", "resnet.6.3.bn3.running_var", "resnet.6.4.conv1.weight", "resnet.6.4.bn1.weight", "resnet.6.4.bn1.bias", "resnet.6.4.bn1.running_mean", "resnet.6.4.bn1.running_var", "resnet.6.4.conv2.weight", "resnet.6.4.bn2.weight", "resnet.6.4.bn2.bias", "resnet.6.4.bn2.running_mean", "resnet.6.4.bn2.running_var", "resnet.6.4.conv3.weight", "resnet.6.4.bn3.weight", "resnet.6.4.bn3.bias", "resnet.6.4.bn3.running_mean", "resnet.6.4.bn3.running_var", "resnet.6.5.conv1.weight", "resnet.6.5.bn1.weight", "resnet.6.5.bn1.bias", "resnet.6.5.bn1.running_mean", "resnet.6.5.bn1.running_var", "resnet.6.5.conv2.weight", "resnet.6.5.bn2.weight", "resnet.6.5.bn2.bias", "resnet.6.5.bn2.running_mean", "resnet.6.5.bn2.running_var", "resnet.6.5.conv3.weight", "resnet.6.5.bn3.weight", "resnet.6.5.bn3.bias", "resnet.6.5.bn3.running_mean", "resnet.6.5.bn3.running_var", "resnet.6.6.conv1.weight", "resnet.6.6.bn1.weight", "resnet.6.6.bn1.bias", "resnet.6.6.bn1.running_mean", "resnet.6.6.bn1.running_var", "resnet.6.6.conv2.weight", "resnet.6.6.bn2.weight", "resnet.6.6.bn2.bias", "resnet.6.6.bn2.running_mean", "resnet.6.6.bn2.running_var", "resnet.6.6.conv3.weight", "resnet.6.6.bn3.weight", "resnet.6.6.bn3.bias", "resnet.6.6.bn3.running_mean", "resnet.6.6.bn3.running_var", "resnet.6.7.conv1.weight", "resnet.6.7.bn1.weight", "resnet.6.7.bn1.bias", "resnet.6.7.bn1.running_mean", "resnet.6.7.bn1.running_var", "resnet.6.7.conv2.weight", "resnet.6.7.bn2.weight", "resnet.6.7.bn2.bias", "resnet.6.7.bn2.running_mean", "resnet.6.7.bn2.running_var", "resnet.6.7.conv3.weight", "resnet.6.7.bn3.weight", "resnet.6.7.bn3.bias", "resnet.6.7.bn3.running_mean", "resnet.6.7.bn3.running_var", "resnet.6.8.conv1.weight", "resnet.6.8.bn1.weight", "resnet.6.8.bn1.bias", "resnet.6.8.bn1.running_mean", "resnet.6.8.bn1.running_var", "resnet.6.8.conv2.weight", "resnet.6.8.bn2.weight", "resnet.6.8.bn2.bias", "resnet.6.8.bn2.running_mean", "resnet.6.8.bn2.running_var", "resnet.6.8.conv3.weight", "resnet.6.8.bn3.weight", "resnet.6.8.bn3.bias", "resnet.6.8.bn3.running_mean", "resnet.6.8.bn3.running_var", "resnet.6.9.conv1.weight", "resnet.6.9.bn1.weight", "resnet.6.9.bn1.bias", "resnet.6.9.bn1.running_mean", "resnet.6.9.bn1.running_var", "resnet.6.9.conv2.weight", "resnet.6.9.bn2.weight", "resnet.6.9.bn2.bias", "resnet.6.9.bn2.running_mean", "resnet.6.9.bn2.running_var", "resnet.6.9.conv3.weight", "resnet.6.9.bn3.weight", "resnet.6.9.bn3.bias", "resnet.6.9.bn3.running_mean", "resnet.6.9.bn3.running_var", "resnet.6.10.conv1.weight", "resnet.6.10.bn1.weight", "resnet.6.10.bn1.bias", "resnet.6.10.bn1.running_mean", "resnet.6.10.bn1.running_var", "resnet.6.10.conv2.weight", "resnet.6.10.bn2.weight", "resnet.6.10.bn2.bias", "resnet.6.10.bn2.running_mean", "resnet.6.10.bn2.running_var", "resnet.6.10.conv3.weight", "resnet.6.10.bn3.weight", "resnet.6.10.bn3.bias", "resnet.6.10.bn3.running_mean", "resnet.6.10.bn3.running_var", "resnet.6.11.conv1.weight", "resnet.6.11.bn1.weight", "resnet.6.11.bn1.bias", "resnet.6.11.bn1.running_mean", "resnet.6.11.bn1.running_var", "resnet.6.11.conv2.weight", "resnet.6.11.bn2.weight", "resnet.6.11.bn2.bias", "resnet.6.11.bn2.running_mean", "resnet.6.11.bn2.running_var", "resnet.6.11.conv3.weight", "resnet.6.11.bn3.weight", "resnet.6.11.bn3.bias", "resnet.6.11.bn3.running_mean", "resnet.6.11.bn3.running_var", "resnet.6.12.conv1.weight", "resnet.6.12.bn1.weight", "resnet.6.12.bn1.bias", "resnet.6.12.bn1.running_mean", "resnet.6.12.bn1.running_var", "resnet.6.12.conv2.weight", "resnet.6.12.bn2.weight", "resnet.6.12.bn2.bias", "resnet.6.12.bn2.running_mean", "resnet.6.12.bn2.running_var", "resnet.6.12.conv3.weight", "resnet.6.12.bn3.weight", "resnet.6.12.bn3.bias", "resnet.6.12.bn3.running_mean", "resnet.6.12.bn3.running_var", "resnet.6.13.conv1.weight", "resnet.6.13.bn1.weight", "resnet.6.13.bn1.bias", "resnet.6.13.bn1.running_mean", "resnet.6.13.bn1.running_var", "resnet.6.13.conv2.weight", "resnet.6.13.bn2.weight", "resnet.6.13.bn2.bias", "resnet.6.13.bn2.running_mean", "resnet.6.13.bn2.running_var", "resnet.6.13.conv3.weight", "resnet.6.13.bn3.weight", "resnet.6.13.bn3.bias", "resnet.6.13.bn3.running_mean", "resnet.6.13.bn3.running_var", "resnet.6.14.conv1.weight", "resnet.6.14.bn1.weight", "resnet.6.14.bn1.bias", "resnet.6.14.bn1.running_mean", "resnet.6.14.bn1.running_var", "resnet.6.14.conv2.weight", "resnet.6.14.bn2.weight", "resnet.6.14.bn2.bias", "resnet.6.14.bn2.running_mean", "resnet.6.14.bn2.running_var", "resnet.6.14.conv3.weight", "resnet.6.14.bn3.weight", "resnet.6.14.bn3.bias", "resnet.6.14.bn3.running_mean", "resnet.6.14.bn3.running_var", "resnet.6.15.conv1.weight", "resnet.6.15.bn1.weight", "resnet.6.15.bn1.bias", "resnet.6.15.bn1.running_mean", "resnet.6.15.bn1.running_var", "resnet.6.15.conv2.weight", "resnet.6.15.bn2.weight", "resnet.6.15.bn2.bias", "resnet.6.15.bn2.running_mean", "resnet.6.15.bn2.running_var", "resnet.6.15.conv3.weight", "resnet.6.15.bn3.weight", "resnet.6.15.bn3.bias", "resnet.6.15.bn3.running_mean", "resnet.6.15.bn3.running_var", "resnet.6.16.conv1.weight", "resnet.6.16.bn1.weight", "resnet.6.16.bn1.bias", "resnet.6.16.bn1.running_mean", "resnet.6.16.bn1.running_var", "resnet.6.16.conv2.weight", "resnet.6.16.bn2.weight", "resnet.6.16.bn2.bias", "resnet.6.16.bn2.running_mean", "resnet.6.16.bn2.running_var", "resnet.6.16.conv3.weight", "resnet.6.16.bn3.weight", "resnet.6.16.bn3.bias", "resnet.6.16.bn3.running_mean", "resnet.6.16.bn3.running_var", "resnet.6.17.conv1.weight", "resnet.6.17.bn1.weight", "resnet.6.17.bn1.bias", "resnet.6.17.bn1.running_mean", "resnet.6.17.bn1.running_var", "resnet.6.17.conv2.weight", "resnet.6.17.bn2.weight", "resnet.6.17.bn2.bias", "resnet.6.17.bn2.running_mean", "resnet.6.17.bn2.running_var", "resnet.6.17.conv3.weight", "resnet.6.17.bn3.weight", "resnet.6.17.bn3.bias", "resnet.6.17.bn3.running_mean", "resnet.6.17.bn3.running_var", "resnet.6.18.conv1.weight", "resnet.6.18.bn1.weight", "resnet.6.18.bn1.bias", "resnet.6.18.bn1.running_mean", "resnet.6.18.bn1.running_var", "resnet.6.18.conv2.weight", "resnet.6.18.bn2.weight", "resnet.6.18.bn2.bias", "resnet.6.18.bn2.running_mean", "resnet.6.18.bn2.running_var", "resnet.6.18.conv3.weight", "resnet.6.18.bn3.weight", "resnet.6.18.bn3.bias", "resnet.6.18.bn3.running_mean", "resnet.6.18.bn3.running_var", "resnet.6.19.conv1.weight", "resnet.6.19.bn1.weight", "resnet.6.19.bn1.bias", "resnet.6.19.bn1.running_mean", "resnet.6.19.bn1.running_var", "resnet.6.19.conv2.weight", "resnet.6.19.bn2.weight", "resnet.6.19.bn2.bias", "resnet.6.19.bn2.running_mean", "resnet.6.19.bn2.running_var", "resnet.6.19.conv3.weight", "resnet.6.19.bn3.weight", "resnet.6.19.bn3.bias", "resnet.6.19.bn3.running_mean", "resnet.6.19.bn3.running_var", "resnet.6.20.conv1.weight", "resnet.6.20.bn1.weight", "resnet.6.20.bn1.bias", "resnet.6.20.bn1.running_mean", "resnet.6.20.bn1.running_var", "resnet.6.20.conv2.weight", "resnet.6.20.bn2.weight", "resnet.6.20.bn2.bias", "resnet.6.20.bn2.running_mean", "resnet.6.20.bn2.running_var", "resnet.6.20.conv3.weight", "resnet.6.20.bn3.weight", "resnet.6.20.bn3.bias", "resnet.6.20.bn3.running_mean", "resnet.6.20.bn3.running_var", "resnet.6.21.conv1.weight", "resnet.6.21.bn1.weight", "resnet.6.21.bn1.bias", "resnet.6.21.bn1.running_mean", "resnet.6.21.bn1.running_var", "resnet.6.21.conv2.weight", "resnet.6.21.bn2.weight", "resnet.6.21.bn2.bias", "resnet.6.21.bn2.running_mean", "resnet.6.21.bn2.running_var", "resnet.6.21.conv3.weight", "resnet.6.21.bn3.weight", "resnet.6.21.bn3.bias", "resnet.6.21.bn3.running_mean", "resnet.6.21.bn3.running_var", "resnet.6.22.conv1.weight", "resnet.6.22.bn1.weight", "resnet.6.22.bn1.bias", "resnet.6.22.bn1.running_mean", "resnet.6.22.bn1.running_var", "resnet.6.22.conv2.weight", "resnet.6.22.bn2.weight", "resnet.6.22.bn2.bias", "resnet.6.22.bn2.running_mean", "resnet.6.22.bn2.running_var", "resnet.6.22.conv3.weight", "resnet.6.22.bn3.weight", "resnet.6.22.bn3.bias", "resnet.6.22.bn3.running_mean", "resnet.6.22.bn3.running_var". 
	size mismatch for resnet.6.0.conv1.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 64, 1, 1]).
	size mismatch for resnet.6.0.bn1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.bn1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.bn1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.bn1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for resnet.6.0.bn2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.bn2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.bn2.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.bn2.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.0.conv3.weight: copying a param with shape torch.Size([1024, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 64, 1, 1]).
	size mismatch for resnet.6.0.bn3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.0.bn3.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.0.bn3.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.0.bn3.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.0.downsample.0.weight: copying a param with shape torch.Size([1024, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 64, 1, 1]).
	size mismatch for resnet.6.0.downsample.1.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.0.downsample.1.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.0.downsample.1.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.0.downsample.1.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.1.conv1.weight: copying a param with shape torch.Size([256, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 256, 1, 1]).
	size mismatch for resnet.6.1.bn1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.bn1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.bn1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.bn1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for resnet.6.1.bn2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.bn2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.bn2.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.bn2.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.1.conv3.weight: copying a param with shape torch.Size([1024, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 64, 1, 1]).
	size mismatch for resnet.6.1.bn3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.1.bn3.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.1.bn3.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.1.bn3.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.2.conv1.weight: copying a param with shape torch.Size([256, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 256, 1, 1]).
	size mismatch for resnet.6.2.bn1.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.bn1.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.bn1.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.bn1.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.conv2.weight: copying a param with shape torch.Size([256, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([64, 64, 3, 3]).
	size mismatch for resnet.6.2.bn2.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.bn2.bias: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.bn2.running_mean: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.bn2.running_var: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([64]).
	size mismatch for resnet.6.2.conv3.weight: copying a param with shape torch.Size([1024, 256, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 64, 1, 1]).
	size mismatch for resnet.6.2.bn3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.2.bn3.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.2.bn3.running_mean: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.6.2.bn3.running_var: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([256]).
	size mismatch for resnet.7.0.conv1.weight: copying a param with shape torch.Size([512, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 256, 1, 1]).
	size mismatch for resnet.7.0.bn1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.bn1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.bn1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.bn1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for resnet.7.0.bn2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.bn2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.bn2.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.bn2.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.0.conv3.weight: copying a param with shape torch.Size([2048, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 128, 1, 1]).
	size mismatch for resnet.7.0.bn3.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.0.bn3.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.0.bn3.running_mean: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.0.bn3.running_var: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.0.downsample.0.weight: copying a param with shape torch.Size([2048, 1024, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 256, 1, 1]).
	size mismatch for resnet.7.0.downsample.1.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.0.downsample.1.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.0.downsample.1.running_mean: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.0.downsample.1.running_var: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.1.conv1.weight: copying a param with shape torch.Size([512, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
	size mismatch for resnet.7.1.bn1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.bn1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.bn1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.bn1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for resnet.7.1.bn2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.bn2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.bn2.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.bn2.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.1.conv3.weight: copying a param with shape torch.Size([2048, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 128, 1, 1]).
	size mismatch for resnet.7.1.bn3.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.1.bn3.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.1.bn3.running_mean: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.1.bn3.running_var: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.2.conv1.weight: copying a param with shape torch.Size([512, 2048, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 512, 1, 1]).
	size mismatch for resnet.7.2.bn1.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.bn1.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.bn1.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.bn1.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.conv2.weight: copying a param with shape torch.Size([512, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([128, 128, 3, 3]).
	size mismatch for resnet.7.2.bn2.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.bn2.bias: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.bn2.running_mean: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.bn2.running_var: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([128]).
	size mismatch for resnet.7.2.conv3.weight: copying a param with shape torch.Size([2048, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([512, 128, 1, 1]).
	size mismatch for resnet.7.2.bn3.weight: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.2.bn3.bias: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.2.bn3.running_mean: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for resnet.7.2.bn3.running_var: copying a param with shape torch.Size([2048]) from checkpoint, the shape in current model is torch.Size([512]).

In [12]:
inference(encoder, decoder, img, word_map, beam_size)

AttributeError: 'Bottleneck' object has no attribute 'dequant'

In [None]:
inference(myencoder, decoder, img, word_map, beam_size)

In [None]:
prune_my_encoder(encoder)

In [None]:
inference(myencoder, decoder, img, word_map, beam_size)

In [None]:
total_encoder_sparsity = 0
den = 0
num = 0
for name, module in myencoder.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        print("Sparsity in {} (shape = {}): {:.2f}%".format
        (name, module.weight.shape,
        100. * float(torch.sum(module.weight == 0))/ float(module.weight.nelement())
        ))
        num += float(torch.sum(module.weight == 0))
        den += float(module.weight.nelement())
    elif isinstance(module, torch.nn.Linear):
        print("Sparsity in {} (shape = {}): {:.2f}%".format
        (name, module.weight.shape,
        100. * float(torch.sum(module.weight == 0))/ float(module.weight.nelement())
        ))
        num += float(torch.sum(module.weight == 0))
        den += float(module.weight.nelement())
total_encoder_sparsity = 100.*num/den        
print("Total sparsity in the pruned encoder: {:.2f}%".format(total_encoder_sparsity))        

In [None]:
for name, module in myencoder.named_modules():
    if isinstance(module, torch.nn.Linear) or isinstance(module, torch.nn.LSTM):
        quantized_model = quantization.quantize_dynamic(module, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8)
        for n, p in quantized_model.named_parameters():
            print(n, 'has type:', p.dtype)

In [None]:
# QUANTIZATION FUNCTION
def quantize_my_model(model):
    # set the qconfig for PTQ
    model.qconfig = quantization.get_default_qconfig('fbgemm')
    # set the qengine to control weight packing
    torch.backends.quantized.engine = 'fbgemm'
    # put model in eval mode
    model.eval()
    quantization.prepare(model, inplace=True)
    quantization.convert(myencoder, inplace=True)

In [None]:
quantize_my_model(myencoder)
print(myencoder)

In [None]:
inference(myencoder, decoder, img, word_map, beam_size)
# model.to('QuantizedCPU')
# summary(myencoder, (3,256,256))

In [None]:
# # Specify quantization configuration
# # Start with simple min/max range estimation and per-tensor
# # quantization of weights
# myencoder.qconfig = quantization.default_qconfig
# # Convert to quantized model
# quantization.convert(myencoder, inplace=True)
# # Calibrate with the training set
# # evaluate(myModel, criterion, data_loader, neval_batches=num_calibration_batches)
# for name, module in myencoder.named_modules(): 
#     for n, p in module.named_parameters():
#             print(n, 'has type:', p.dtype)
# # prunedencoder = Encoder()
# # myencoder_state_dict = myencoder.state_dict()
# # prunedencoder_state_dict = prunedencoder.state_dict()
# # for key in myencoder_state_dict.keys():
# #     print(key)
#     #     if key in prunedencoder_state_dict.keys():
# #         prunedencoder_state_dict[key] = myencoder_state_dict[key]
# # prunedencoder.load_state_dict(myencoder_state_dict)

In [None]:

# summary(myencoder, (3,256,256))
for name, param in myencoder.named_parameters():
    print(param.dtype)

In [None]:
summary(encoder, (3,256,256))
