In [1]:
%load_ext autoreload
%autoreload 2

In [24]:
# export
import torch
from torch import nn
import torch.nn.functional as F

In [25]:
# export
class CustomRound(torch.autograd.Function):
    """Ceil with a straight-through (identity) gradient.

    Note that, despite the class name, the forward pass applies
    ``torch.ceil`` (not ``torch.round``). ``ceil`` has zero gradient almost
    everywhere, so the backward pass passes the incoming gradient through
    unchanged to keep the operation trainable.

    Use it as ``CustomRound.apply(x)``.
    """

    @staticmethod
    def forward(ctx, x):
        """
        Return ``ceil(x)`` element-wise.

        Parameters:
            ctx : autograd context object. Nothing is stashed on it because
                  the backward pass does not depend on the input.
            x : input tensor.
        """
        # No ctx.save_for_backward(x): backward never reads the input, and
        # saving it would only keep the tensor alive longer than necessary.
        return torch.ceil(x)

    @staticmethod
    def backward(ctx, grad_output):
        """
        Return the gradient of the loss with respect to the input.

        Parameters:
            ctx : autograd context object (unused).
            grad_output : gradient of the loss with respect to the output.
        """
        # Identity gradient: hand back a copy of the upstream gradient.
        return grad_output.clone()

In [3]:
# export
class ImportanceMapMult(nn.Module):
    """Multiply the information channels of a tensor by an importance-map mask.

    One channel of the input (``MAP_CHANNEL``) is interpreted as an
    importance map; the remaining ``C`` channels carry the information
    ``z``. The map is squashed to ``[0, C]`` with a sigmoid and expanded
    into a per-channel mask ``clamp(map - c, 0, 1)`` for ``c = 0..C-1``,
    which is multiplied element-wise with ``z``.

    Parameters:
        use_map : when False, ``forward`` returns its input unchanged.
    """

    # Index of the importance-map channel; the information channels are
    # taken as everything after it, so this is expected to stay 0.
    MAP_CHANNEL = 0
    # Inputs are assumed to be NCHW, so the channel dimension is 1.
    CHANNEL_DIM = 1

    def __init__(self, use_map: bool = True):
        super().__init__()
        self.use_map = use_map

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        forward prop.
        Parameters:
            x : tensor including z-hat with importance map channel,
                assumed to be laid out as N(C+1)HW.
        Returns:
            ``x`` itself when ``use_map`` is False; otherwise the masked
            information channels, shape NCHW (one channel fewer than ``x``).
        """
        if not self.use_map:
            return x

        # Number of information channels (total minus the importance map).
        info_channels = x.shape[self.CHANNEL_DIM] - 1

        # Choose the first channel as the importance map and scale to [0, C].
        importance_map = torch.sigmoid(x[:, self.MAP_CHANNEL, ...]) * info_channels  # NHW
        importance_map = importance_map.unsqueeze(self.CHANNEL_DIM)  # N1HW

        # Channel offsets 0..C-1, shaped (C, 1, 1) so they broadcast against
        # the N1HW importance map. Built on the input's device/dtype.
        c = torch.arange(
            info_channels, dtype=importance_map.dtype, device=importance_map.device
        ).reshape(info_channels, 1, 1)

        # Information channels: everything after the importance-map channel.
        z = x[:, self.MAP_CHANNEL + 1:, ...]  # NCHW

        # if importance_map[x, y] == C, then mask[x, y, c] == 1
        # \forall c \in {0, ..., C-1}
        # TODO (open question from the original author): maybe the mask
        # should additionally be binarized with CustomRound.
        mask = torch.clamp(importance_map - c, min=0, max=1)  # NCHW
        return mask * z
        
        

In [13]:
# Scratch input: a 2x2x2x2 tensor of random whole numbers in [0, 10].
x = (torch.rand(2, 2, 2, 2) * 10).round()
print(x)
print(x.shape)

tensor([[[[0., 5.],
          [1., 2.]],

         [[4., 9.],
          [7., 8.]]],


        [[[1., 5.],
          [0., 6.]],

         [[7., 3.],
          [9., 4.]]]])
torch.Size([2, 2, 2, 2])


In [16]:
# Take the last channel of every sample (index -1 along dim 1).
mymap = x[:, -1]
mymap

tensor([[[4., 9.],
         [7., 8.]],

        [[7., 3.],
         [9., 4.]]])

# TODO: continue reading autoencoder_imgcomp._get_heatmap3D