In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
import copy

In [3]:
linear = nn.Linear(20, 50)

In [4]:
t = torch.tensor(np.ones((5, 3, 4, 6, 12)))

In [5]:
torch.flatten(t, end_dim=-2).view(t.shape).shape

torch.Size([5, 3, 4, 6, 12])

In [6]:
t.view(-1, t.shape[-1]).shape

torch.Size([360, 12])

In [20]:
??t.detach

In [7]:

class mixout_layer(nn.Module):
    def __init__(self, linear, p, norm_flag=True):
        super().__init__()
        self.layer = linear
        self.norm_flag = norm_flag
        self.p = p
        self.layer_frozen = copy.deepcopy(linear)
        for param in self.layer_frozen.parameters():
            param.requires_grad = False

    def forward(self, x):
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        if not self.training or self.p == 0:
            return self.layer(x)

        x_shape = x.shape
        x = torch.flatten(x, end_dim=-2)
        learned_layer_output = self.layer(x)
        frozen_layer_output = self.layer_frozen(x)
        self.noise = torch.FloatTensor(
            x.shape[0], self.layer.out_features).uniform_(0, 1)
        self.mask = (self.noise < self.p)
        self.mask = self.mask.type(torch.FloatTensor)
        self.masked_learned = learned_layer_output * (1-self.mask)
        self.masked_frozen = frozen_layer_output * self.mask
        self.raw_output = self.masked_learned + self.masked_frozen
        self.num_scale = self.normalize(
                learned_layer_output, frozen_layer_output)
        self.denom_scale = self.normalize(
            self.raw_output, frozen_layer_output, keepdim = True, dim=[1])
        self.output = self.raw_output * self.normalize(
                learned_layer_output, frozen_layer_output) / self.normalize(
            self.raw_output, frozen_layer_output, keepdim = True, dim=[1])
#         torch.norm(self.raw_output - learned_layer_output, dim=[1], keepdim=True,
#                                                            p = lambda x: sum(abs(x)))
        self.output = self.output.view(*x_shape[:-1], -1)
        return self.output
    
    def normalize(self, x, x_frozen, dim = None, keepdim = False):
        return torch.norm(x - x_frozen, dim=dim, keepdim=keepdim, p = 1).detach() + 1e-10
#         self.noise = torch.FloatTensor(
#             x.shape[0], self.layer.out_features, self.layer.in_features).uniform_(0, 1)
#         self.mask = (self.noise < self.p)
#         self.mask = self.mask.type(torch.FloatTensor)
#         # mask bs, input, output
#         # layer frozen input, output
#         self.frozen_masked = self.mask * \
#             torch.unsqueeze(self.layer_frozen.weight, 0)
#         self.learned_masked = (1 - self.mask) * \
#             torch.unsqueeze(self.layer.weight, 0)
#         # bs, input, output
#         self.masked_layer = (self.frozen_masked + self.learned_masked)
#         if self.norm_flag:
#             self.masked_layer = self.masked_layer * torch.norm(
#                 self.layer.weight) / torch.norm(self.masked_layer, dim=[1, 2]).unsqueeze(1).unsqueeze(2)
#         # bs, output
#         self.output = (x.unsqueeze(1) * self.masked_layer).sum(2) + \
#             self.layer.bias.unsqueeze(0)
#         self.output = self.output.view(*x_shape[:-1], -1)
#         return self.output


In [8]:
m = mixout_layer(linear, p=0.5)

In [9]:
x = np.random.rand(1, linear.in_features)

In [10]:
y = m(x)

In [11]:
loss = y.sum()

In [12]:
optimizer = torch.optim.Adam(params=m.parameters())

In [13]:
m.zero_grad()
# loss.backward()

In [14]:
m.layer.weight.grad

In [15]:
loss.backward()

In [16]:
#not zero because of the normalization
#if you set norm_flag=False in the mixout
#then the corresponding zeroed weights will have zero gradients
m.layer.weight.grad

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.0576, 0.8538, 0.8362, 0.7750, 0.2693, 0.1928, 0.2513, 0.1599, 0.4966,
         0.4106, 0.8997, 0.6378, 0.8760, 0.0515, 0.8907, 0.3837, 0.8708, 0.3842,
         0.5066, 0.3579],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0

In [17]:
# TODO check no gradient update

In [18]:
m.num_scale

tensor(1.0000e-10)

In [19]:
m.denom_scale

tensor([[1.0000e-10]])