In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# export
import torch
from torch import nn
import torch.nn.functional as F

In [35]:
# export
class MinMaxMap(torch.autograd.Function):
    
    @staticmethod
    def forward(ctx, x):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that IS used
        to stash information for backward computation.
        """
        ctx.save_for_backward(x)
        MAP_CHANNEL = 0 # if changed need to fix indexing further in the class
        
        # assume NCHW so channel dim number is 1 
        CHANNEL_DIM =  1
        INFO_CHANNELS =  x.shape[CHANNEL_DIM] - 1 # substract importance map
        
        c = nn.Parameter(torch.arange(start=0,end=INFO_CHANNELS,dtype=torch.float32,requires_grad=False))
        print(c.shape,INFO_CHANNELS)
        c = torch.reshape(c,(INFO_CHANNELS, 1, 1))                      
        
        # choose the first channel as the importance map
        importance_map = x[:,MAP_CHANNEL,...] # NHW
        importance_map = torch.sigmoid(importance_map) * INFO_CHANNELS
        importance_map.unsqueeze_(CHANNEL_DIM) # N1HW
        
        z = x[:,MAP_CHANNEL:,...]
        
        
        importance_map = torch.max(
            torch.min(importance_map - c, torch.tensor(1.0, dtype=torch.float32, requires_grad=False)),
                      torch.tensor(0.0, dtype=torch.float32, requires_grad = False))  # NCHW
        return torch.mul(importance_map, z)
                         
    @staticmethod
    def backward(ctx, grad_output):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
#         x, = ctx.saved_tensors
        grad_input = grad_output.clone()
        return grad_input # identity grad

In [36]:
# export
class ImportanceMapMult(nn.Module):
    def __init__(self,use_map=True):
        super().__init__()
        self.use_map = use_map
        
    
    def forward(self, x):
        """
        forward prop.
        Parameters:
            x : tensor including z-hat with importance map channel.
        """
        if not self.use_map:
            return x

        return MinMaxMap.apply(x)
        
        

In [37]:
x =torch.round(torch.rand([2,2,2,2])*10)
model = ImportanceMapMult()
y= model(x)
# print(x)
# print(y)
# print(x.shape)


torch.Size([1]) 1
