In [1]:
import torch
import torchvision.models as models
import os
import pytorch_lightning as pl
from collections import OrderedDict 



In [88]:
# Redirect torch.hub's cache so downloaded weights are stored under
# /models/cache/hub/checkpoints, the path later cells load checkpoints from.
os.environ.update(TORCH_HOME='/models/cache')
# NOTE(review): despite its name, `resnet` holds a VGG16, not a ResNet.
resnet = models.vgg16(pretrained=True)

# Alternative loader via torch.hub (not used):
# model = torch.hub.load('pytorch/vision:v0.10.0', 'vgg11', pretrained=True)

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /models/cache/hub/checkpoints/vgg16-397923af.pth


  0%|          | 0.00/528M [00:00<?, ?B/s]

In [4]:
class VisualFeatures(pl.LightningModule):
    """Multi-level visual feature extractor on top of a frozen VGG16 trunk.

    Forward hooks capture the activations of the VGG16 ``features`` layers
    listed in ``FEATURE_LAYERS_ID``. Each captured map is bilinearly resized to
    ``M x M`` and passed through ``NUM_CONV_LAYERS`` trainable 1x1 convolutions
    (512 -> D, then D -> D), each followed by a LeakyReLU. The per-layer
    results are stacked, L2-normalised over the channel axis and returned with
    shape ``(B, M, M, L, D)``, where ``L == len(FEATURE_LAYERS_ID)``.
    """

    # Indices into vgg16().features; they correspond to layers 4.1, 4.3, 5.1
    # and 5.3 of the VGG16 architecture.
    FEATURE_LAYERS_ID = ['18', '22', '25', '29']
    M = 18  # spatial size every hooked feature map is resized to
    D = 1024  # output channels of the projection convolutions
    NUM_CONV_LAYERS = 3  # 1x1 convolutions applied to each hooked map
    alpha = 0.25  # negative slope of the LeakyReLU

    def __init__(self, checkpoint_path='/models/cache/hub/checkpoints/vgg16-397923af.pth'):
        """Load the frozen VGG16 trunk, register hooks and build the projections.

        Args:
            checkpoint_path: path to a torchvision VGG16 ``state_dict``. The
                default is the file torch.hub caches under
                ``TORCH_HOME=/models/cache``.

        Raises:
            KeyError: if an id in ``FEATURE_LAYERS_ID`` is not a child of the
                VGG16 ``features`` module.
        """
        super().__init__()
        image_model = models.vgg16(pretrained=False)
        # map_location='cpu' lets a checkpoint saved from a GPU load on a
        # CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        image_model.load_state_dict(checkpoint)
        # The first child of torchvision's VGG is the convolutional `features`
        # Sequential; the pooling/classifier children are not used here.
        self.pretrained_model = list(image_model.children())[0]

        # The backbone is a fixed feature extractor: only the projections train.
        for parameter in self.pretrained_model.parameters():
            parameter.requires_grad = False

        # layer id -> activation captured by the hook during the latest forward.
        self.raw_visual_features = OrderedDict()
        # layer id -> projected (B, D, M, M) map from the latest forward.
        self.visual_features = OrderedDict()

        # The hooks do not need to be referenced, but their handles are kept so
        # they can be removed later if needed.
        self.forward_hooks = []
        for layer_id in self.FEATURE_LAYERS_ID:
            # Indexing _modules directly fails fast on an unknown layer id
            # instead of silently skipping it.
            layer = self.pretrained_model._modules[layer_id]
            self.forward_hooks.append(layer.register_forward_hook(self._forward_hook(layer_id)))

        self.resize = torch.nn.Upsample(size=(self.M, self.M), scale_factor=None, mode='bilinear', align_corners=False)

        # Attribute names such as '18_conv_0' are part of the state_dict keys,
        # so they are kept exactly as before.
        for layer_id in self.FEATURE_LAYERS_ID:
            for i in range(self.NUM_CONV_LAYERS):
                in_channels = 512 if i == 0 else self.D
                setattr(self, layer_id + '_conv_' + str(i),
                        torch.nn.Conv2d(in_channels, self.D, kernel_size=(1, 1), stride=1))
        self.leaky_relu = torch.nn.LeakyReLU(self.alpha, inplace=True)

    def _forward_hook(self, layer_id):
        """Return a hook that stores the hooked module's output under `layer_id`."""
        def hook(module, input, output):
            self.raw_visual_features[layer_id] = output
        return hook

    def forward(self, x):
        """Extract the stacked, normalised multi-level features.

        Args:
            x: image batch of shape (B, 3, H, W); the first VGG16 layer is
                Conv2d(3, 64) and the bilinear resize needs a 4-D tensor.

        Returns:
            Tensor of shape (B, M, M, L, D), unit L2 norm along the last axis.

        Raises:
            RuntimeError: if a hooked layer did not fire during the pass.
        """
        # Drop activations from the previous call so nothing stale can be used.
        self.raw_visual_features.clear()
        # The trunk's own output is not needed; running it fires the hooks.
        self.pretrained_model(x)

        missing = [l for l in self.FEATURE_LAYERS_ID if l not in self.raw_visual_features]
        if missing:
            raise RuntimeError('forward hooks did not fire for layers: %s' % missing)

        for layer_id in self.FEATURE_LAYERS_ID:
            feature = self.resize(self.raw_visual_features[layer_id])
            for i in range(self.NUM_CONV_LAYERS):
                feature = getattr(self, layer_id + '_conv_' + str(i))(feature)
                feature = self.leaky_relu(feature)
            self.visual_features[layer_id] = feature

        # (B, L, D, M, M) -> normalise over D -> (B, M, M, L, D)
        stacked = torch.stack([self.visual_features[l] for l in self.visual_features.keys()], 1)
        return torch.nn.functional.normalize(stacked, p=2, dim=2).permute((0, 3, 4, 1, 2))
        

In [5]:
# Smoke test: run one random 224x224 RGB image through the extractor and
# check the (B, M, M, L, D) output layout.
torch.manual_seed(0)  # reproducible random input
sample = torch.rand(1, 3, 224, 224)
viz = VisualFeatures()
with torch.no_grad():  # inspection only; no autograd graph is needed
    result = viz(sample)

expected_shape = (1, VisualFeatures.M, VisualFeatures.M,
                  len(VisualFeatures.FEATURE_LAYERS_ID), VisualFeatures.D)
assert tuple(result.shape) == expected_shape, result.shape
print(result.shape)

# Optional TorchScript export (disabled):
# viz.eval()
# script = viz.to_torchscript()
# torch.jit.save(script, os.path.join('/models',"model.pt"))


torch.Size([1, 18, 18, 4, 1024])


In [112]:
# Sanity check: VisualFeatures returns (B, M, M, L, D) with every D-dim feature
# vector L2-normalised, so the norm over the LAST axis must be ~1 everywhere.
feature_norms = torch.norm(result, dim=-1, p=2)
assert torch.allclose(feature_norms, torch.ones_like(feature_norms), atol=1e-4)
feature_norms.shape

torch.Size([1, 1024, 18, 18])

In [76]:
# NOTE(review): the output recorded below is stale. It shows an OrderedDict
# keyed by layer '10', whereas VisualFeatures.forward now returns a single
# stacked tensor. Re-run this cell to refresh it.
result

OrderedDict([('10',
              tensor([[[[ 5.7290,  3.7643,  4.2850,  ...,  6.4441,  6.9683, 10.6676],
                        [ 5.2510,  3.3421,  4.2545,  ...,  3.5091,  2.2189,  5.6018],
                        [ 3.9606,  4.4606,  4.3926,  ...,  4.4884,  2.5008,  4.4092],
                        ...,
                        [ 4.0569,  1.0795,  2.2917,  ...,  0.4592,  1.7108,  3.5214],
                        [ 5.0973,  2.7780,  1.9630,  ...,  3.1584,  1.2424,  4.1274],
                        [ 9.1784,  8.0466,  6.0983,  ...,  7.0364,  6.0461,  9.3533]],
              
                       [[ 1.2641,  0.0000,  1.6852,  ...,  0.6854,  0.1268,  3.2701],
                        [ 0.6663,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                        [ 0.0000,  0.0000,  1.0350,  ...,  0.0000,  0.0000,  2.4748],
                        ...,
                        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  4.9356],
                        [ 3.3929,  0.0000,  0.

In [46]:
# NOTE(review): the recorded output (3, 224, 224) is stale. `sample` is now
# built with torch.rand(1, 3, 224, 224). Re-run this cell to refresh it.
sample.shape

torch.Size([3, 224, 224])

In [34]:
# Load VGG11 from the locally cached torchvision checkpoint and freeze its
# convolutional trunk.
image_model = models.vgg11(pretrained=False)
checkpoint = torch.load('/models/cache/hub/checkpoints/vgg11-8a719046.pth', map_location='cpu')
image_model.load_state_dict(checkpoint)

# The first child of torchvision's VGG is the `features` Sequential.
# Take it from `image_model`: no `model` variable is defined in this notebook.
newmodel = list(image_model.children())[0]
print(newmodel)

for parameter in newmodel.parameters():
    parameter.requires_grad = False

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): ReLU(inplace=True)
  (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (14): ReLU(inplace=True)
  (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (16): Conv2d(512, 512, kernel_size=(3, 3), stride=

In [27]:
# Names of the trunk's children, in registration order ('0', '1', ...).
list(newmodel._modules.keys())

['0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 '10',
 '11',
 '12',
 '13',
 '14',
 '15',
 '16',
 '17',
 '18',
 '19',
 '20']

In [24]:
# Freeze every weight of the trunk so no gradients are computed for it.
for weight in newmodel.parameters():
    weight.requires_grad_(False)

In [22]:
# Display the `features` trunk of the VGG loaded into `image_model`
# (there is no `model` variable in this notebook).
list(image_model.children())[0]

[19,]

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): ReLU(inplace=True)
  (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (12): ReLU(inplace=True)
  (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (14): ReLU(inplace=True)
  (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (16): Conv2d(512, 512, kernel_size=(3, 3), stride=

In [3]:
# Compile the VGG16 held in `resnet` to TorchScript.
# NOTE(review): the name `scripted_gate` does not describe a VGG model.
scripted_gate = torch.jit.script(resnet)

In [5]:
# Serialize the scripted model to disk.
# NOTE(review): 'wrapped_rnn.pt' is a misleading file name for a VGG16. It is
# left unchanged in case something downstream loads this exact path.
scripted_gate.save('/models/vgg/wrapped_rnn.pt')

In [11]:
# Display the module tree of the VGG16 loaded into `resnet`.
resnet


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU(inplace=True)
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU(inplace=True)
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
 

In [89]:

# Per-layer summary (output shapes, parameter counts) of VGG16 for a
# 3x224x224 input.
from torchsummary import summary

image_model = models.vgg16(pretrained=False)
checkpoint = torch.load('/models/cache/hub/checkpoints/vgg16-397923af.pth', map_location='cpu')
image_model.load_state_dict(checkpoint)
# Fall back to CPU so the cell also runs without a GPU. torchsummary's
# `device` argument must match where the model lives.
summary_device = 'cuda' if torch.cuda.is_available() else 'cpu'
image_model.to(summary_device)
summary(image_model, (3, 224, 224), device=summary_device)

# NOTE(review): the meaning of this list is not stated; possibly candidate
# layer indices.
# [19,18,15,13,10]

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256,

In [97]:
# Inspect child 22 of the VGG16 trunk (one of the layers VisualFeatures hooks).
image_model.features[22]

ReLU(inplace=True)

In [None]:

class MultimodalAttentionOld(pl.LightningModule):
    """Text-to-image pertinence scoring over multi-level visual features.

    Each text vector gets a ReLU'd cosine-similarity map over every image
    location and feature level. That map weights a sum of the image features,
    and the result is compared (cosine) with the text vector. The best level
    is kept. Word-level scores are then aggregated into one score per caption.
    """

    def __init__(self, cfg):
        super().__init__()
        # Nominal number of feature levels / spatial size, kept for backward
        # compatibility. The scoring code reads the actual sizes from
        # `image_feature`.
        self.L = 4
        self.M = 18
        self.gamma_1 = 5  # sharpness of the word-score aggregation in forward()
        self.gamma_2 = 10  # NOTE(review): currently unused
        # NOTE(review): `cfg` is accepted but not used.

    def forward(self, word_feature, sentence_feature, image_feature, seq_lens):
        """Score captions against images.

        Args:
            word_feature: per-word text features, (B, T, D).
            sentence_feature: one text feature per caption, (B, D).
            image_feature: visual features, (B, M, M, L, D).
            seq_lens: valid word count of each caption (length B); padded word
                positions are masked out of the aggregation.

        Returns:
            (aggregated_sentence_image_score, sentence_image_score), each (B,).
        """
        batch_size, max_word_len = word_feature.shape[0], word_feature.shape[1]
        # (B, T)
        word_image_max_score = self.get_pertinence_scores(word_feature, image_feature)
        # Treat each sentence vector as a one-word caption: (B, 1) -> (B,)
        sentence_image_score = self.get_pertinence_scores(sentence_feature.unsqueeze(1), image_feature).squeeze(1)

        len_mask = self.get_len_mask(batch_size, max_word_len, seq_lens, device=word_image_max_score.device)
        # log((sum over valid words of exp(gamma_1 * score)) ** (1 / gamma_1))
        aggregated_sentence_image_score = torch.exp(word_image_max_score * self.gamma_1) * len_mask
        aggregated_sentence_image_score = torch.log(torch.sum(aggregated_sentence_image_score, 1) ** (1 / self.gamma_1))

        return aggregated_sentence_image_score, sentence_image_score

    def get_pertinence_scores(self, word_feature, image_feature):
        """Best-level pertinence score of each text vector against its image.

        Args:
            word_feature: text vectors, (B, T, D).
            image_feature: visual features, (B, M, M, L, D).

        Returns:
            (B, T): for each text vector, the maximum over the L levels of the
            cosine similarity between the vector and its attention-pooled
            image feature.
        """
        functional = torch.nn.functional
        _, grid_h, grid_w, num_levels, _ = image_feature.shape
        max_word_len = word_feature.shape[1]

        # Broadcast both inputs to (B, M, M, T, L, D).
        reshaped_word_feature = word_feature.unsqueeze(1).unsqueeze(1).unsqueeze(4).expand(-1, grid_h, grid_w, -1, num_levels, -1)
        reshaped_image_feature = image_feature.unsqueeze(3).expand(-1, -1, -1, max_word_len, -1, -1)

        # Heatmap (B, M, M, T, L): non-negative word-to-location similarity.
        similarity_heatmap = functional.relu(functional.cosine_similarity(reshaped_word_feature, reshaped_image_feature, dim=5))

        # Flatten the spatial grid and put it last so the weighted sum over
        # locations becomes a matrix product: (B, T, L, 1, M*M)
        similarity_heatmap_flat = torch.flatten(similarity_heatmap, start_dim=1, end_dim=2).permute(0, 2, 3, 1).unsqueeze(3)

        # Matching layout for the image features: (B, T, L, M*M, D)
        image_feature_flat = torch.flatten(image_feature, start_dim=1, end_dim=2).permute(0, 2, 1, 3).unsqueeze(1).expand(-1, max_word_len, -1, -1, -1)

        # Heatmap-weighted sum over locations: (B, T, L, D), L2-normalised.
        visual_word_attention = torch.matmul(similarity_heatmap_flat, image_feature_flat).squeeze(3)
        visual_word_attention = functional.normalize(visual_word_attention, p=2, dim=3)

        # (B, T, L): each word against its attended feature at every level.
        word_image_pertinence_score = functional.cosine_similarity(
            word_feature.unsqueeze(2).expand(-1, -1, num_levels, -1), visual_word_attention, dim=3)

        # Keep the best level per word: (B, T)
        word_image_max_pertinence_score, _ = torch.max(word_image_pertinence_score, dim=2)
        return word_image_max_pertinence_score

    def get_len_mask(self, batch_size, max_len, seq_lens, device=None):
        """Build a (batch_size, max_len) length mask: 1 for valid positions, 0 for padding.

        Row i holds ones in its first seq_lens[i] positions and zeros after.

        Args:
            batch_size: number of rows; must equal len(seq_lens).
            max_len: number of columns (padded sequence length).
            seq_lens: per-row valid lengths (sequence of ints or 1-D tensor).
            device: device for the mask. None keeps the previous behaviour of
                returning a CUDA tensor.

        Returns:
            Float tensor of the default dtype, shape (batch_size, max_len).
        """
        if device is None:
            device = torch.device('cuda')
        lens = torch.as_tensor(seq_lens, device=device).reshape(batch_size, 1)
        positions = torch.arange(max_len, device=device).unsqueeze(0)
        return (positions < lens).to(torch.get_default_dtype())