In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
import copy

In [3]:
# Seed every RNG this notebook draws from (layer init, mixout masks,
# np.random inputs) so Restart & Run All reproduces the same outputs.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Toy layer that gets wrapped by mixout below: 20 inputs -> 50 outputs.
linear = nn.Linear(in_features=20, out_features=50)

In [4]:
# 5-D all-ones float64 tensor used to check flatten/view round-trips.
t = torch.ones(5, 3, 4, 6, 12, dtype=torch.float64)

In [5]:
# Collapse all but the last dim, then view back: the shape must round-trip.
t.flatten(end_dim=-2).view_as(t).shape

torch.Size([5, 3, 4, 6, 12])

In [6]:
# Equivalent 2-D view: (prod of leading dims, last dim).
t.reshape(-1, t.size(-1)).shape

torch.Size([360, 12])

In [8]:

class mixout_layer(nn.Module):
    """Wrap a layer with output-level Mixout regularization.

    While training, each output unit of each sample is independently taken,
    with probability ``p``, from a frozen snapshot of the layer (made at
    construction time) instead of from the trainable layer. Because one
    output unit depends only on one weight row and one bias entry, this is
    the same as swapping whole weight rows between the two layers per
    sample. (An earlier draft of this notebook mixed individual weights per
    sample instead; mixing outputs avoids materializing a weight matrix per
    sample.) In eval mode, or when ``p == 0``, the trainable layer is used
    as-is.

    Args:
        linear: layer to train, presumably an ``nn.Linear``. It is stored
            as ``self.layer`` (not copied) so it keeps receiving gradients;
            a deep copy with ``requires_grad=False`` is stored as
            ``self.layer_frozen``.
        p: probability in ``[0, 1]`` of using the frozen output for a unit.
        norm_flag: if True, rescale each sample's mixed output so its L2
            norm equals that of the trainable layer's output for the same
            sample. If False, the mixed output is returned unscaled.

    Raises:
        ValueError: if ``p`` is outside ``[0, 1]``.
    """

    def __init__(self, linear, p, norm_flag=True):
        super().__init__()
        if not 0 <= p <= 1:
            raise ValueError(f"mixout probability p must be in [0, 1], got {p}")
        self.layer = linear
        self.norm_flag = norm_flag
        self.p = p
        # deepcopy so later updates to `linear` never leak into the snapshot.
        # The copy is still a registered submodule (moves with .to(device)),
        # but its parameters never get gradients.
        self.layer_frozen = copy.deepcopy(linear)
        for param in self.layer_frozen.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Apply the layer, mixing in frozen outputs while training.

        Args:
            x: tensor or NumPy array whose last dimension is the layer's
                input size; all leading dimensions are treated as batch.
                NumPy input is converted to the layer's dtype and device.

        Returns:
            Tensor of shape ``(*x.shape[:-1], out_features)``.
        """
        if isinstance(x, np.ndarray):
            weight = self.layer.weight
            x = torch.as_tensor(x, dtype=weight.dtype, device=weight.device)
        if not self.training or self.p == 0:
            return self.layer(x)

        x_shape = x.shape
        # One batch dim; reshape (unlike flatten(end_dim=-2)) also accepts
        # 1-D and non-contiguous input.
        x = x.reshape(-1, x_shape[-1])
        learned_layer_output = self.layer(x)
        frozen_layer_output = self.layer_frozen(x)

        # Intermediates are kept as attributes so they can be inspected from
        # the notebook; they hold the autograd graph until the next call.
        # rand_like puts the noise on the outputs' device and dtype.
        self.noise = torch.rand_like(learned_layer_output)
        # mask == 1 -> unit comes from the frozen layer (probability p).
        self.mask = (self.noise < self.p).to(learned_layer_output.dtype)
        self.masked_learned = learned_layer_output * (1 - self.mask)
        self.masked_frozen = frozen_layer_output * self.mask
        self.raw_output = self.masked_learned + self.masked_frozen
        if self.norm_flag:
            # Both norms are per sample (row), so the scale factor does not
            # depend on batch size. clamp_min avoids 0/0 for an all-zero row.
            target_norm = learned_layer_output.norm(dim=1, keepdim=True)
            raw_norm = self.raw_output.norm(dim=1, keepdim=True).clamp_min(1e-12)
            self.output = self.raw_output * target_norm / raw_norm
        else:
            self.output = self.raw_output
        # Explicit last dim (not -1) so an empty batch still reshapes.
        self.output = self.output.reshape(*x_shape[:-1], self.output.shape[-1])
        return self.output


In [9]:
# Wrap the toy layer; half of the output units come from the frozen copy.
m = mixout_layer(linear=linear, p=0.5)

In [10]:
# One random input row (NumPy) matching the layer's input width.
x = np.random.random_sample((1, linear.in_features))

In [11]:
# Forward pass in training mode (m was never switched to eval), so the
# mixout mixing path runs; the NumPy input is converted to a tensor inside.
y = m(x)

In [12]:
# Scalar dummy loss so every output unit contributes to the gradient.
loss = torch.sum(y)

In [13]:
# m.parameters() also yields the frozen copy's parameters; they have
# requires_grad=False, never get a .grad, and so Adam skips them.
optimizer = torch.optim.Adam(m.parameters())

In [14]:
m.zero_grad()
# Clears any gradients on m's parameters; nothing has been backpropagated
# yet at this point, so the next cell is expected to show no gradient.

In [15]:
# No backward pass has run yet, so on a fresh kernel there is no gradient.
# The assert fails loudly if this cell is run after loss.backward().
assert m.layer.weight.grad is None
m.layer.weight.grad

In [16]:
# Backpropagate the dummy loss into the trainable layer's parameters.
torch.autograd.backward(loss)

In [17]:
# The weight gradient is non-zero even for rows whose output unit was taken
# from the frozen copy: the normalization rescales the output by the norm of
# the learned layer's output, which depends on every learned weight.
# Without that rescaling (norm_flag=False -- NOTE(review): only if forward
# actually honors norm_flag), those rows would get zero gradient here.
m.layer.weight.grad

tensor([[ 4.3478e-01,  1.2173e-01,  3.0279e-01,  8.4129e-01,  1.2487e-01,
          9.4919e-01,  9.4362e-02,  7.2605e-02,  7.5299e-01,  2.6924e-01,
          7.7921e-01,  4.4160e-02,  3.4909e-01,  6.3390e-01,  6.7537e-01,
          8.5538e-01,  4.4209e-02,  2.1641e-01,  3.6781e-01,  3.7342e-02],
        [-1.5703e-02, -4.3966e-03, -1.0936e-02, -3.0385e-02, -4.5100e-03,
         -3.4282e-02, -3.4081e-03, -2.6223e-03, -2.7196e-02, -9.7241e-03,
         -2.8143e-02, -1.5949e-03, -1.2608e-02, -2.2895e-02, -2.4392e-02,
         -3.0894e-02, -1.5967e-03, -7.8162e-03, -1.3284e-02, -1.3487e-03],
        [-1.2175e-03, -3.4087e-04, -8.4788e-04, -2.3558e-03, -3.4967e-04,
         -2.6579e-03, -2.6423e-04, -2.0331e-04, -2.1085e-03, -7.5393e-04,
         -2.1820e-03, -1.2366e-04, -9.7752e-04, -1.7751e-03, -1.8912e-03,
         -2.3953e-03, -1.2380e-04, -6.0600e-04, -1.0299e-03, -1.0457e-04],
        [-7.1108e-02, -1.9909e-02, -4.9520e-02, -1.3759e-01, -2.0422e-02,
         -1.5524e-01, -1.5433e-02, 

In [18]:
# TODO check no gradient update