In [1]:
import json
import torch
import torch.nn.utils.prune as prune
import torch.nn.functional as F
import torch_pruning as tp
import torchvision.models as models
import caption

In [2]:
class Encoder(torch.nn.Module):
    """
    Encoder.
    """

    def __init__(self, encoded_image_size=14):
        super(Encoder, self).__init__()
        self.enc_image_size = encoded_image_size

        #resnet = torchvision.models.resnet101(pretrained=True)  # pretrained ImageNet ResNet-101
        resnet = models.resnet101()
        #resnet.load_state_dict(torch.load("resnet101-2.pth"))

        # Remove linear and pool layers (since we're not doing classification)
        modules = list(resnet.children())[:-2]
        self.resnet = torch.nn.Sequential(*modules)

        # Resize image to fixed size to allow input images of variable size
        self.adaptive_pool = torch.nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune()

    def forward(self, images):
        """
        Forward propagation.

        :param images: images, a tensor of dimensions (batch_size, 3, image_size, image_size)
        :return: encoded images
        """
        out = self.resnet(images)  # (batch_size, 2048, image_size/32, image_size/32)
        out = self.adaptive_pool(out)  # (batch_size, 2048, encoded_image_size, encoded_image_size)
        out = out.permute(0, 2, 3, 1)  # (batch_size, encoded_image_size, encoded_image_size, 2048)
        return out

    def fine_tune(self, fine_tune=True):
        """
        Allow or prevent the computation of gradients for convolutional blocks 2 through 4 of the encoder.

        :param fine_tune: Allow?
        """
        for p in self.resnet.parameters():
            p.requires_grad = False
        # If fine-tuning, only fine-tune convolutional blocks 2 through 4
        for c in list(self.resnet.children())[5:]:
            for p in c.parameters():
                p.requires_grad = fine_tune

In [3]:
word_map = "WORDMAP_coco_5_cap_per_img_5_min_word_freq.json"
model = "BEST_checkpoint_coco_5_cap_per_img_5_min_word_freq.pth.tar"
beam_size = 5
device = 'cpu'

checkpoint = torch.load(model, map_location=device)
decoder = checkpoint['decoder']
decoder = decoder.to(device)
decoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
#encoder.eval()

# Load word map (word2ix)
with open(word_map, 'r') as j:
    word_map = json.load(j)
rev_word_map = {v: k for k, v in word_map.items()}  # ix2word

In [4]:
myencoder = Encoder()
myencoder.eval()
encoder = checkpoint['encoder']
encoder = encoder.to(device)
encoder.eval()

Encoder(
  (resnet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2d(64, 2

In [5]:
sd = encoder.state_dict()
sd2 = myencoder.state_dict()

In [None]:
eps = 5e-3
param = dict()
for key in sd.keys():
    param[key] = torch.Tensor(sd[key].shape)

for key in sd.keys():
#     sd[key][abs(sd[key])<eps] = False
#     param = sd[key][abs(sd[key])>=eps].type(torch.float16)   
    if len(sd[key].shape) == 4:
        for i in range(sd[key].shape[0]):
            for j in range(sd[key].shape[1]):
                for k in range(sd[key].shape[2]):
                    for l in range(sd[key].shape[3]):
                        if sd[key][i,j,k,l] == 0:
                            param[key][i,j,k,l] = False
                        else:
                            param[key][i,j,k,l] = sd[key][i,j,k,l].type(torch.float16)
    else:
        param[key] = sd[key].type(torch.float16)
    
    sd2[key] = param
#     print(param.dtype)
    print(sd2[key])

In [19]:
t= sd['resnet.0.weight'].type(torch.bool)
print(t.dtype)

torch.bool


In [None]:
myencoder.load_state_dict(sd2)

In [13]:
# print(sd)
# print(sd.keys())
# print("\nNEW ENCODER\n")
# print(sd2.keys())
for key in sd.keys():
    if key not in sd2.keys():
        print('They are not the same')
print('END')

END


In [None]:
img = r"C:\Users\xiaomi\OneDrive\TUM\WS 2021-2022\Advanced Topics in Communication Electronics\image-captioning-on-pytorch-master\image-captioning-on-pytorch-master\img5.jpg"
seq, alphas = caption.caption_image_beam_search(myencoder, decoder, img, word_map, beam_size)

In [None]:
alphas = torch.FloatTensor(alphas)
words = [rev_word_map[ind] for ind in seq]
sentence = ""
for word in words:
    sentence = sentence + " " + word

In [None]:
sentence

In [None]:
for name, module in myencoder.named_modules():
    if isinstance(module, torch.nn.Conv2d):
        prune.random_unstructured(module, name='weight', amount=0.3)
        print(list(module.named_buffers()))
        print("salam")
    elif isinstance(module, torch.nn.Linear):
        print("linear")

In [None]:
for name, module in myencoder.named_modules():
    print('Name:',name, 'Module:\n',module)

In [None]:
for name, param in myencoder.named_parameters():
    print('Name:',name, 'Parameter:\n', param)

In [None]:
print(dict(myencoder.named_parameters())['resnet.0.weight_orig'])
print(dict(myencoder.named_parameters())['resnet.0.weight'])