In [3]:
%load_ext autoreload
%autoreload 2

# Imports

In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from MultiheadAttention import MultiheadAttention

# Model

## Conv

In [4]:
class ConvInputModel(nn.Module):
    def __init__(self):
        super(ConvInputModel, self).__init__()
        
        self.conv1 = nn.Conv2d(3, 24, 3, stride=2, padding=1)
        self.batchNorm1 = nn.BatchNorm2d(24)
        self.conv2 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm2 = nn.BatchNorm2d(24)
        self.conv3 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm3 = nn.BatchNorm2d(24)
        self.conv4 = nn.Conv2d(24, 24, 3, stride=2, padding=1)
        self.batchNorm4 = nn.BatchNorm2d(24)
        
    def forward(self, img):
        """convolution"""
        x = self.conv1(img)
        x = self.batchNorm1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.batchNorm2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.batchNorm3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = self.batchNorm4(x)
        x = F.relu(x)
        return x

## Qembed

In [5]:
class QuestionEmbedModel(nn.Module):
    def __init__(self, in_size, embed=32, hidden=128):
        super(QuestionEmbedModel, self).__init__()
        
        self.wembedding = nn.Embedding(in_size + 1, embed)  #word embeddings have size 32
        self.lstm = nn.LSTM(embed, hidden, batch_first=True)  # Input dim is 32, output dim is the question embedding
        self.hidden = hidden
        
    def forward(self, question):
        #calculate question embeddings
        wembed = self.wembedding(question)
        # wembed = wembed.permute(1,0,2) # in lstm minibatches are in the 2-nd dimension
        self.lstm.flatten_parameters()
        _, hidden = self.lstm(wembed) # initial state is set to zeros by default
        qst_emb = hidden[0] # hidden state of the lstm. qst = (B x 128)
        #qst_emb = qst_emb.permute(1,0,2).contiguous()
        #qst_emb = qst_emb.view(-1, self.hidden*2)
        qst_emb = qst_emb[0]
        
        return qst_emb

## RLBase

In [6]:
class RelationalLayerBase(nn.Module):
    def __init__(self, in_size, out_size, qst_size, hyp):
        super().__init__()

        # f_fc1
        self.f_fc1 = nn.Linear(hyp["g_layers"][-1], hyp["f_fc1"])
        self.mha_fc1 = MultiheadAttention(hyp["g_layers"][-1], MULTIHEADATTENTION_HEADS)
        self.identity_fc1 = nn.Identity()
        # f_fc2
        self.f_fc2 = nn.Linear(hyp["f_fc1"], hyp["f_fc2"])
        self.mha_fc2 = MultiheadAttention(hyp["f_fc1"], MULTIHEADATTENTION_HEADS)
        self.identity_fc2 = nn.Identity()
        # f_fc3
        self.f_fc3 = nn.Linear(hyp["f_fc2"], out_size)
        self.mha_fc3 = MultiheadAttention(hyp["f_fc2"], MULTIHEADATTENTION_HEADS)
        self.identity_fc3 = nn.Identity()
    
        self.dropout = nn.Dropout(p=hyp["dropout"])
        
        self.on_gpu = False
        self.hyp = hyp
        self.qst_size = qst_size
        self.in_size = in_size
        self.out_size = out_size

    def cuda(self, device=None):
        self.on_gpu = True
        super().cuda(device)

## RL

In [7]:
class RelationalLayer(RelationalLayerBase):
    def __init__(self, in_size, out_size, qst_size, hyp, extraction=False):
        super().__init__(in_size, out_size, qst_size, hyp)

        self.quest_inject_position = hyp["question_injection_position"]
        self.in_size = in_size

	    #create all g layers
        self.g_layers = []
        self.g_layers_size = hyp["g_layers"]

        #create all multiheadattention layers
        self.mha_layers = []
        self.identity_layers = []

        for idx,g_layer_size in enumerate(hyp["g_layers"]):
            in_s = in_size if idx==0 else hyp["g_layers"][idx-1]
            out_s = g_layer_size
            if idx==self.quest_inject_position:
                #create the h layer. Now, for better code organization, it is part of the g layers pool. 
                l = nn.Linear(in_s+qst_size, out_s)
                mha = MultiheadAttention(in_s+qst_size, MULTIHEADATTENTION_HEADS)
            else:
                #create a standard g layer.
                l = nn.Linear(in_s, out_s)
                mha = MultiheadAttention(in_s, MULTIHEADATTENTION_HEADS)
            self.g_layers.append(l)
            self.mha_layers.append(mha)
            self.identity_layers.append(nn.Identity())


        self.g_layers = nn.ModuleList(self.g_layers)
        self.mha_layers = nn.ModuleList(self.mha_layers)
        self.identity_layers = nn.ModuleList(self.identity_layers)
        self.extraction = extraction
    
    def forward(self, x, qst):
        # x = (B x 8*8 x 24)
        # qst = (B x 128)
        """g"""
        b, d, k = x.size()
        qst_size = qst.size()[1]
        l1_reg = 0
        
        # add question everywhere
        qst = torch.unsqueeze(qst, 1)                      # (B x 1 x 128)
        query = qst.clone().transpose(1, 0)
        qst = qst.repeat(1, d, 1)                       # (B x 64 x 128)
        qst = torch.unsqueeze(qst, 2)                      # (B x 64 x 1 x 128)
        
        # cast all pairs against each other
        x_i = torch.unsqueeze(x, 1)                   # (B x 1 x 64 x 26)
        x_i = x_i.repeat(1, d, 1, 1)                    # (B x 64 x 64 x 26)
        x_j = torch.unsqueeze(x, 2)                   # (B x 64 x 1 x 26)
        #x_j = torch.cat([x_j, qst], 3)
        x_j = x_j.repeat(1, 1, d, 1)                    # (B x 64 x 64 x 26)
        
        # concatenate all together
        x_full = torch.cat([x_i, x_j], 3)                  # (B x 64 x 64 x 2*26)
        
        # reshape for passing through network
        x_ = x_full.view(b * d**2, self.in_size)

        #create g and inject the question at the position pointed by quest_inject_position.
        for idx, (g_layer, mha_layer, g_layer_size, identity) in enumerate(zip(self.g_layers, self.mha_layers, self.g_layers_size, self.identity_layers)):
            if idx==self.quest_inject_position:
                in_size = self.in_size if idx==0 else self.g_layers_size[idx-1]

                # questions inserted
                x_img = x_.view(b,d,d,in_size)
                qst = qst.repeat(1,1,d,1)
                x_concat = torch.cat([x_img,qst],3) #(B x 64 x 64 x 128 + 2 * 26)

                # h layer
                x_ = x_concat.view(b*(d**2),in_size+self.qst_size)
                x_ = g_layer(x_)
                x_ = F.relu(x_)
            else:
                x_ = g_layer(x_)
                x_ = F.relu(x_)
                # Pass through multiheadattention layer
                weights = torch.unsqueeze(g_layer.weight, 0).repeat(b, 1, 1).transpose(1, 0)
                _, attn_output_weights = mha_layer(query, weights, weights)
                l1_reg += (attn_output_weights.abs().sum() / (attn_output_weights.size(0) * attn_output_weights.size(2)))
                attn_output_weights = attn_output_weights.repeat(1, d**2, 1)
                # Apply attn_output_weights to x_
                x_ = x_.view(b, d**2, g_layer_size) * attn_output_weights
                x_ = x_.view(b * (d ** 2), g_layer_size)
            x_ = identity(x_)

        if self.extraction:
            return None
        
        # reshape again and sum
        x_g = x_.view(b, d**2, self.g_layers_size[-1])
        x_g = x_g.sum(1).squeeze(1)
        
        """f"""
        # f_fc1
        x_f = self.f_fc1(x_g)
        x_f = F.relu(x_f)
        weights = torch.unsqueeze(self.f_fc1.weight, 0).repeat(b, 1, 1).transpose(1, 0)
        _, attn_output_weights = self.mha_fc1(query, weights, weights)
        l1_reg += (attn_output_weights.abs().sum() / (attn_output_weights.size(0) * attn_output_weights.size(2)))
        x_f = x_f * attn_output_weights.squeeze(1)
        x_f = self.identity_fc1(x_f)
        # f_fc2
        x_f = self.f_fc2(x_f)
        x_f = self.dropout(x_f)
        x_f = F.relu(x_f)
        weights = torch.unsqueeze(self.f_fc2.weight, 0).repeat(b, 1, 1).transpose(1, 0)
        _, attn_output_weights = self.mha_fc2(query, weights, weights)
        l1_reg += (attn_output_weights.abs().sum() / (attn_output_weights.size(0) * attn_output_weights.size(2)))
        x_f = x_f * attn_output_weights.squeeze(1)
        x_f = self.identity_fc2(x_f)
        # f_fc3
        x_f = self.f_fc3(x_f)
        weights = torch.unsqueeze(self.f_fc3.weight, 0).repeat(b, 1, 1).transpose(1, 0)
        _, attn_output_weights = self.mha_fc3(query, weights, weights)
        l1_reg += (attn_output_weights.abs().sum() / (attn_output_weights.size(0) * attn_output_weights.size(2)))
        x_f = x_f * attn_output_weights.squeeze(1)
        x_f = self.identity_fc3(x_f)
        return F.log_softmax(x_f, dim=1), l1_reg 

## RN

In [8]:
class RN(nn.Module):
    def __init__(self, args, hyp, extraction=False):
        super(RN, self).__init__()
        self.coord_tensor = None
        self.on_gpu = False
        
        # CNN
        self.conv = ConvInputModel()
        self.state_desc = hyp['state_description']            
            
        # LSTM
        hidden_size = hyp["lstm_hidden"]
        self.text = QuestionEmbedModel(args.qdict_size, embed=hyp["lstm_word_emb"], hidden=hidden_size)
        
        # RELATIONAL LAYER
        self.rl_in_size = hyp["rl_in_size"]
        self.rl_out_size = args.adict_size
        self.rl = RelationalLayer(self.rl_in_size, self.rl_out_size, hidden_size, hyp, extraction) 
        if hyp["question_injection_position"] != 0:          
            print('Supposing IR model')
        else:     
            print('Supposing original DeepMind model')

    def forward(self, img, qst_idxs):
        if self.state_desc:
            x = img # (B x 12 x 8)
        else:
            x = self.conv(img)  # (B x 24 x 8 x 8)
            b, k, d, _ = x.size()
            x = x.view(b,k,d*d) # (B x 24 x 8*8)
            
            # add coordinates
            if self.coord_tensor is None or torch.cuda.device_count() == 1:
                self.build_coord_tensor(b, d)                  # (B x 2 x 8 x 8)
                self.coord_tensor = self.coord_tensor.view(b,2,d*d) # (B x 2 x 8*8)
            
            x = torch.cat([x, self.coord_tensor], 1)    # (B x 24+2 x 8*8)
            x = x.permute(0, 2, 1)    # (B x 64 x 24+2)
        
        qst = self.text(qst_idxs)
        y = self.rl(x, qst)
        return y
       
    # prepare coord tensor
    def build_coord_tensor(self, b, d):
        coords = torch.linspace(-d/2., d/2., d)
        x = coords.unsqueeze(0).repeat(d, 1)
        y = coords.unsqueeze(1).repeat(1, d)
        ct = torch.stack((x,y))
        # broadcast to all batches
        # TODO: upgrade pytorch and use broadcasting
        ct = ct.unsqueeze(0).repeat(b, 1, 1, 1)
        self.coord_tensor = Variable(ct, requires_grad=False)
        if self.on_gpu:
            self.coord_tensor = self.coord_tensor.cuda()
    
    def cuda(self, device=None):
        self.on_gpu = True
        self.rl.cuda(device)
        super(RN, self).cuda(device)

# SE+ATTN

In [41]:
MULTIHEADATTENTION_HEADS = 1

In [78]:
def squeeze(weights):
    return weights.mean(dim=1)

class SEAttend(nn.Module):
    def __init__(self, in_dim=256, out_dim=256, squeeze_dim=16):
        super().__init__()
        
        self.in_dim = in_dim # Cantidad de neuronas capa anterior
        self.out_dim = out_dim # Cantidad de neuronas capa siguiente (mascara sobre estas)
        
        self.excite = nn.Sequential(
            nn.Linear(out_dim, squeeze_dim),
            nn.ReLU(),
            nn.Linear(squeeze_dim, out_dim),
            nn.Sigmoid(),
        )
        
        self.attend = MultiheadAttention(
            in_dim,
            MULTIHEADATTENTION_HEADS,
            dropout=0.1,
        )
    
    def forward(self, qst, weights):
        bsz = qst.size(0)
        
        scale = self.squeeze(weights) # scale: [out_dim (256)]
        scale = self.excite(scale.unsqueeze(0)) # scale: [1, out_dim (256)]
        weights = weights * scale.t() # weights: [out_dim, in_dim]
        weights = weights.unsqueeze(1).expand(self.out_dim, bsz, self.in_dim) # weights: [out_dim, bsz, in_dim]
        
        _, attn_output_weights = self.attend(qst.unsqueeze(0), weights, weights)
        
        # Retorno None para mantener el formato
        return None, attn_output_weights
    
    @staticmethod
    def squeeze(weights):
        return squeeze(weights)

In [75]:
layer = nn.Linear(255, 256)
sea = SEAttend(out_dim=256, in_dim=255)
qst = torch.randn(2, 255)
_, ret = sea(qst, layer.weight)
print(ret.size())
print(ret)

torch.Size([2, 1, 256])
tensor([[[0.3954, 0.0000, 0.3832, 0.7434, 0.9056, 0.7858, 0.3215, 0.0000,
          0.6960, 0.3804, 0.4661, 0.3918, 0.8876, 0.9478, 0.7458, 0.7132,
          0.6596, 0.8342, 0.2089, 0.4842, 0.0000, 0.8956, 0.5129, 0.3723,
          0.4412, 0.9768, 0.4684, 0.1167, 0.0000, 0.4218, 0.2370, 0.2961,
          0.6493, 0.5177, 0.2612, 0.0000, 0.4412, 0.6539, 0.1700, 0.7018,
          0.7742, 0.6343, 0.3036, 0.6198, 0.0000, 0.6323, 0.8880, 0.3315,
          0.4574, 0.3043, 0.2708, 0.5403, 0.4379, 0.5808, 0.0964, 0.6952,
          0.4946, 0.8681, 0.2941, 0.3637, 0.3320, 0.1906, 0.2941, 0.1529,
          0.6513, 0.6320, 0.9429, 0.8656, 0.7018, 0.9351, 0.1732, 0.7580,
          0.0000, 0.7699, 0.7505, 0.6371, 0.7836, 0.4200, 0.7134, 0.2379,
          0.5477, 0.7258, 0.3581, 0.0000, 0.4896, 0.6800, 0.8775, 0.4228,
          0.2313, 0.7232, 0.9392, 0.6202, 0.5477, 0.6598, 0.4646, 0.9394,
          0.0000, 0.5675, 0.3592, 0.0000, 0.6514, 0.5870, 0.6984, 0.5621,
          0.44

In [23]:
# Cada fila representa todos los pesos que llegan a una
# neurona de la capa siguiente
layer.weight.size()

torch.Size([256, 255])

In [29]:
means = layer.weight.mean(dim=1)
means.size()

torch.Size([256])

In [None]:
q = torch.randn()