In [159]:
from data.dataset import VisualGroundingDataset
from collections import OrderedDict
import os
import torch
    import pytorch_lightning as pl
    import torchvision.models as models
import torch.nn.functional as F

In [220]:
from config.config import cfg
dataset = VisualGroundingDataset('train', cfg)
img, text,length = dataset[1]
text = text.unsqueeze(0).expand(3,-1,-1)
length = [length,length,length]
img = img.unsqueeze(0).expand(3,-1,-1,-1)

elmo embedder initialized on gpu?: True


In [222]:
class VisualFeatures(pl.LightningModule):

# layer ids correspond to layers 4.1,4.3,5.1,5.3 in the vgg16 architecture
    FEATURE_LAYERS_ID = ['18','22','25','29']
    M = 18
    D = 1024
    NUM_CONV_LAYERS = 3
    alpha = 0.25
    def __init__(self):
        super().__init__()
        image_model = models.vgg16(pretrained=False)
        checkpoint = torch.load('/models/cache/hub/checkpoints/vgg16-397923af.pth')
        image_model.load_state_dict(checkpoint)
        self.pretrained_model = list(image_model.children())[0]
        
        for parameter in self.pretrained_model.parameters():
            parameter.requires_grad = False
        
        self.raw_visual_features = OrderedDict()
        self.visual_features = OrderedDict()
        
        #no explict need to reference these hooks ,but reference them for potential future use
        self.forward_hooks = []
        
        
        for l in list(self.pretrained_model._modules.keys()):
            
            if l in self.FEATURE_LAYERS_ID:
                self.forward_hooks.append(getattr(self.pretrained_model,l).register_forward_hook(self._forward_hook(l)) )
        
        
        self.resize = torch.nn.Upsample(size=(self.M,self.M), scale_factor=None, mode='bilinear', align_corners=False)
        
        for l in self.FEATURE_LAYERS_ID:
            for i in range(self.NUM_CONV_LAYERS):
                layer_name = l+'_conv_' + str(i)
                if i == 0:
                    setattr(self,layer_name,torch.nn.Conv2d(512, self.D,  kernel_size = (1,1), stride=1))
                else:
                    setattr(self,layer_name,torch.nn.Conv2d(self.D, self.D,  kernel_size = (1,1), stride=1))
        self.leaky_relu = torch.nn.LeakyReLU(self.alpha, inplace=True)
    def _forward_hook(self,layer_id):
        def hook(module,input,output):
            self.raw_visual_features[layer_id] = output
        return hook

    def forward(self,x):
        out = self.pretrained_model(x)
        for l in self.FEATURE_LAYERS_ID:
            x = self.resize(self.raw_visual_features[l])
            for i in range(self.NUM_CONV_LAYERS):
                layer_name = l+'_conv_' + str(i)
                x = getattr(self,layer_name)(x)
                x = self.leaky_relu(x)
            self.visual_features[l] = x
        x = torch.stack([self.visual_features[l] for l in self.visual_features.keys()],1)
        x = torch.nn.functional.normalize(x ,p=2,dim=2).permute((0,3,4,1,2))
        return x
        

In [223]:

class TextualFeatures(pl.LightningModule):
    """
    Decoder Object

    Note that the forward method iterates through all time steps
    It expects an input of [batch]
    """

    def __init__(self, cfg):
        
        
        """
        """
        super().__init__()
        
        
        self.input_dim = cfg['text']['lstm']['input_dims']
        self.hidden_dim = cfg['feature_hidden_dimension']
        assert self.hidden_dim%2 == 0
        self.lstm_hidden_dim = self.hidden_dim//2
        self.num_layers = cfg['text']['lstm']['num_layers']
        self.dropout = cfg['text']['lstm']['dropout']
        self.relu_alpha = cfg['leaky_relu_alpha']
        self.lstm_1 = torch.nn.LSTM(self.input_dim, self.lstm_hidden_dim,
                                num_layers=1, dropout=self.dropout, bidirectional=True,batch_first=True,)
        self.lstm_2 = torch.nn.LSTM(self.hidden_dim, self.lstm_hidden_dim,
                                num_layers=1, dropout=self.dropout, bidirectional=True,batch_first=True,)
                                   
        self.leaky_relu = torch.nn.LeakyReLU(self.relu_alpha, inplace=True)
        self.sentence_fc = torch.nn.Sequential(torch.nn.Linear(self.hidden_dim,self.hidden_dim),self.leaky_relu,torch.nn.Linear(self.hidden_dim,self.hidden_dim),self.leaky_relu)
        self.word_fc = torch.nn.Sequential(torch.nn.Linear(self.hidden_dim,self.hidden_dim),self.leaky_relu,torch.nn.Linear(self.hidden_dim,self.hidden_dim),self.leaky_relu)
        
        self.word_linear_comb = torch.nn.Linear(3,1)
        self.sentence_linear_comb = torch.nn.Linear(2,1)
    def forward(self, x,seq_len):
        """
       
        """

        packed_x = torch.nn.utils.rnn.pack_padded_sequence(x, seq_len,batch_first=True)
        packed_output_1, (hidden_1, cell_1) = self.lstm_1(packed_x)
        packed_output_2, (hidden_2, cell_2) = self.lstm_2(packed_output_1)
        
        output_1, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_1,batch_first=True)
        output_2, _ = torch.nn.utils.rnn.pad_packed_sequence(packed_output_2,batch_first=True)
        
        word_feature = self.word_linear_comb(torch.stack([x,output_1,output_2,],-1)).squeeze(-1)
        word_feature = self.word_fc(word_feature)
        
        #start embedding taken from the backward lstm
#         sentence_features_backwards = [output_1[:,0,self.hidden_dim:],output_1[:,0,self.hidden_dim:]],-1)
        
        #end of sentence embedding taken from the forward lstm
        sentence_end_1 = []
        sentence_end_2 = []
        for i,end in enumerate(seq_len):
            sentence_end_1.append(output_1[i,end-1,:self.lstm_hidden_dim])
            sentence_end_2.append(output_2[i,end-1,:self.lstm_hidden_dim])
#         sentence_features_forward = torch.cat([torch.stack(sentence_end_1,0),torch.stack(sentence_end_1,0)],-1)
        
        first_layer_sentence_feature = torch.cat([torch.stack(sentence_end_1,0),output_1[:,0,self.lstm_hidden_dim:]],-1)
        second_layer_sentence_feature = torch.cat([torch.stack(sentence_end_2,0),output_2[:,0,self.lstm_hidden_dim:]],-1)
        sentence_feature = self.sentence_linear_comb(torch.stack([first_layer_sentence_feature,second_layer_sentence_feature],-1)).squeeze(-1)
        sentence_feature = self.sentence_fc(sentence_feature)
        
        
    
        return  word_feature, sentence_feature

    def get_len_mask(self, batch_size, max_len, seq_lens):
        """Generates an upper-triangular matrix of -inf, with zeros on diag."""
        block = torch.ones(batch_size, max_len, max_len)
        for i in range(batch_size):
            seq_len = seq_lens[i]
            block[i, :seq_len, :seq_len] = torch.zeros(seq_len, seq_len)
        return block.bool()


In [224]:
text_model = TextualFeatures(cfg['model']).cuda()
output,output2= text_model(text,length)
viz_model = VisualFeatures()
img_output = viz_model(img)
# print(cell.shape)

In [302]:

class MultimodalAttention(pl.LightningModule):
    
    def __init__(self,cfg):
        super().__init__()
        self.L = 4
        self.M = 18
        self.gamma_1 = 5
        self.gamma_2 = 10
 
       
        
    def forward(self,word_feature, sentence_feature, image_feature,seq_lens):
        
        '''
        word_feature dims (B,T,D)
        image_feature dims (B,M,M,L,D)
        '''
        #reshape to (B,M,M,T,L,D)
        batch_size = word_feature.shape[0]
        word_image_max_score = self.get_pertinence_scores(word_feature,image_feature,batch_size)
        sentence_image_score = self.get_pertinence_scores(sentence_feature.unsqueeze(1),image_feature,batch_size).squeeze(2)
        
        
        aggregated_sentence_image_score = torch.exp(word_image_max_score * self.gamma_1) *  self.get_len_mask(batch_size , max_word_len,seq_lens)
        aggregated_sentence_image_score = torch.log(torch.sum(aggregated_sentence_image_score,2) ** (1/self.gamma_1))
        
        return aggregated_sentence_image_score, sentence_image_score 

    def get_pertinence_scores(self,word_feature,image_feature,batch_size):
        # #reshape to (B,B'M,M,T,L,D) repeated along dim 0
        reshaped_word_feature = word_feature.unsqueeze(1).unsqueeze(1).unsqueeze(4).unsqueeze(0).expand(batch_size,-1,self.M,self.M,-1,self.L,-1)
        max_word_len = reshaped_word_feature.shape[4]
        #reshape to (B,B'M,M,T,L,D) repeated along dim 1
        reshaped_image_feature = image_feature.unsqueeze(3).unsqueeze(1).expand(-1,batch_size,-1,-1,max_word_len,-1,-1)
        #heatmap dims (B,B',M,M,T,L)
        similarity_heatmap = F.relu(F.cosine_similarity(reshaped_word_feature,reshaped_image_feature,dim=6))
        
        # collapse image width and heigh dimensions into single dim for weighted summing via matrix mult
        # reshape so that dimension to sum across is at the end  
        # (B,B' T, L, 1, MXM)
        similarity_heatmap_flat = torch.flatten(similarity_heatmap, start_dim=2, end_dim=3).permute(0,1,3,4,2).unsqueeze(4)
        # (B,B',T,L,MXM,D)
        # collapse width and height dims for image_feature for weighted summing via matrix mult
        image_feature_flat = torch.flatten(image_feature, start_dim=1, end_dim=2).permute(0,2,1,3).unsqueeze(1).unsqueeze(1).expand(-1,batch_size,max_word_len,-1,-1,-1)

        visual_word_attention = torch.matmul(similarity_heatmap_flat,image_feature_flat).squeeze(4)
        
        #(B,B',T,L,D)
        visual_word_attention = torch.nn.functional.normalize(visual_word_attention ,p=2,dim=4)
        
        #(B,B',T,L)
        word_image_pertinence_score = F.cosine_similarity(word_feature.unsqueeze(2).unsqueeze(1).expand(-1,batch_size,-1,self.L,-1),visual_word_attention,dim=4)
        
        #(B,B',T,)
        word_image_max_pertinence_score,_ = torch.max(word_image_pertinence_score,dim=3)
        return word_image_max_pertinence_score
    
    def get_len_mask(self, batch_size, max_len, seq_lens):
        """Generates an 'upper-triangular matrix' with 1 in places without mask"""
        block = torch.zeros(batch_size, max_len)
        for i,l in enumerate(seq_lens):
            block[i, :l] = torch.ones(1, l)
        block = block.unsqueeze(0).expand(batch_size,-1,-1)
        return block.cuda()
        

In [303]:
class VisualGroundingModel(pl.LightningModule):
    self.gamma_2 = 10
    def get_multimodal_loss(self,sentence_image_score):
        
        score = torch.exp(sentence_image_score * self.gamma_2)
        fixed_image_score_ = score / torch.sum(score,dim=1,keepdim=True)
        fixed_sentence_score = score / torch.sum(score,dim=0,keepdim=True)
        
        loss = -torch.sum(torch.log(torch.diagonal(fixed_image_score, 0)) + torch.log(torch.diagonal(fixed_image_score, 0)))
        return loss
    
    def get_full_multimodal_loss(self,aggregated_sentence_image_score,sentence_image_score):
        return self.get_multimodal_loss(aggregated_sentence_image_score) + self.get_multimodal_loss(sentence_image_score)
    

SyntaxError: duplicate argument 'sentence_image_score' in function definition (3968648132.py, line 3)

In [304]:
multimodal_attn = MultimodalAttention(cfg)
m_output1,m_output2 = multimodal_attn(output.cuda(),output2.cuda(),img_output.cuda(),length)

In [319]:
score = torch.exp(m_output1 * gamma_2)

In [328]:
loss = -torch.sum(torch.log(torch.diagonal(fixed_image_score, 0)) + torch.log(torch.diagonal(fixed_image_score, 0)))

In [329]:
loss

tensor(6.5917, device='cuda:0', grad_fn=<NegBackward>)

In [324]:
fixed_image_score = score / torch.sum(score,dim=1,keepdim=True)
fixed_sentence_score = score / torch.sum(score,dim=0,keepdim=True)

In [315]:
torch.sum(aggregated_score,dim=1,keepdim=True).shape

torch.Size([3, 1])

In [226]:
word_feature = output
image_feature = img_output

In [256]:
reshaped_word_feature = word_feature.unsqueeze(1).unsqueeze(1).unsqueeze(4).unsqueeze(0).expand(batch_size,-1,M,M,-1,L,-1).cuda()

In [257]:
reshaped_image_feature = image_feature.unsqueeze(3).unsqueeze(1).expand(-1,batch_size,-1,-1,max_word_len,-1,-1).cuda()

In [258]:
similarity_heatmap = F.relu(F.cosine_similarity(reshaped_word_feature,reshaped_image_feature,dim=6))

In [254]:
reshaped_image_feature.shape

torch.Size([3, 3, 18, 18, 26, 4, 1024])

In [229]:
batch_size =3

In [282]:
image_feature_flat.shape

torch.Size([3, 3, 26, 4, 324, 1024])

In [299]:
gamma_1 =5

In [None]:
aggregated_sentence_image_score = torch.exp(word_image_max_score * gamma_1) *  self.get_len_mask(batch_size , max_word_len,seq_lens)

In [267]:
similarity_heatmap_flat = torch.flatten(similarity_heatmap, start_dim=2, end_dim=3).permute(0,1,3,4,2).unsqueeze(4)

In [284]:
image_feature_flat = torch.flatten(image_feature, start_dim=1, end_dim=2).permute(0,2,1,3).unsqueeze(1).unsqueeze(1).expand(-1,batch_size,max_word_len,-1,-1,-1).cuda()


In [287]:
visual_word_attention = torch.matmul(similarity_heatmap_flat,image_feature_flat).squeeze(4)

In [290]:
visual_word_attention.shape

torch.Size([3, 3, 26, 4, 1024])

In [297]:
word_image_pertinence_score = F.cosine_similarity(word_feature.unsqueeze(2).unsqueeze(1).expand(-1,batch_size,-1,L,-1),visual_word_attention,dim=4)

In [298]:
word_image_pertinence_score.shape

torch.Size([3, 3, 26, 4])