In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
import copy

In [3]:
# Base layer wrapped by the demo below: 20 input features -> 50 output features.
linear = nn.Linear(in_features=20, out_features=50)

In [4]:
class mixout(nn.Module):
    """Linear layer whose weights are randomly mixed with a frozen snapshot.

    At construction time a deep copy of ``linear`` is stored with
    ``requires_grad=False``. In training mode, every forward pass draws an
    independent mask per sample and per weight entry: with probability ``p``
    an entry is taken from the frozen copy, otherwise from the trainable
    layer. In eval mode (or when ``p == 0``) the trainable layer is applied
    unchanged.

    Args:
        linear: the ``nn.Linear`` to wrap. It is kept by reference as
            ``self.layer`` and remains trainable.
        p: probability that a weight entry is replaced by its frozen value.
        norm_flag: if True, each per-sample mixed weight matrix is rescaled
            so that its norm equals the norm of the trainable weight.
            Because that scale factor depends on every trainable weight,
            masked-out entries still receive a non-zero gradient.

    Note:
        The intermediate tensors of the last training-mode forward pass are
        kept as attributes (``noise``, ``mask``, ``frozen_masked``,
        ``learned_masked``, ``masked_layer``, ``output``) for inspection.
    """

    def __init__(self, linear, p, norm_flag=True):
        super().__init__()
        self.layer = linear
        self.norm_flag = norm_flag
        self.p = p
        # Snapshot of the layer as it is right now; never updated by training.
        self.layer_frozen = copy.deepcopy(linear)
        for param in self.layer_frozen.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Apply the (possibly mixed) linear map to ``x``.

        Args:
            x: input of shape ``(batch, in_features)``, either a
                ``torch.Tensor`` or a ``numpy.ndarray`` (converted to the
                dtype and device of the layer weight).

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        weight = self.layer.weight
        if isinstance(x, np.ndarray):
            x = torch.as_tensor(x, dtype=weight.dtype, device=weight.device)
        if not self.training or self.p == 0:
            return self.layer(x)

        # One mask entry per (sample, output unit, input unit). The shape is
        # taken from the wrapped layer itself, not from any global variable,
        # and the noise is created on the same device as the weights.
        self.noise = torch.rand(
            x.shape[0],
            self.layer.out_features,
            self.layer.in_features,
            device=weight.device,
        )
        # 1.0 where the frozen weight is used, 0.0 where the learned one is.
        self.mask = (self.noise < self.p).to(weight.dtype)

        # mask:          (batch, out_features, in_features)
        # layer weights: (out_features, in_features) -> broadcast over batch
        self.frozen_masked = self.mask * torch.unsqueeze(self.layer_frozen.weight, 0)
        self.learned_masked = (1 - self.mask) * torch.unsqueeze(weight, 0)
        # (batch, out_features, in_features)
        self.masked_layer = self.frozen_masked + self.learned_masked

        if self.norm_flag:
            # Rescale every per-sample matrix to the norm of the trainable weight.
            self.masked_layer = self.masked_layer * torch.norm(weight) / torch.norm(
                self.masked_layer, dim=[1, 2], keepdim=True
            )

        # x: (batch, in_features) -> (batch, 1, in_features); summing over the
        # input axis yields (batch, out_features).
        output = (x.unsqueeze(1) * self.masked_layer).sum(2)
        # nn.Linear(..., bias=False) has bias set to None.
        if self.layer.bias is not None:
            output = output + self.layer.bias.unsqueeze(0)
        self.output = output
        return self.output
        

In [5]:
# Wrap the base layer; each weight entry is swapped for its frozen value with
# probability 0.5. norm_flag is left at its default (True).
m = mixout(linear, p=0.5)

In [6]:
# One random input row of shape (1, in_features); mixout.forward converts
# numpy arrays to tensors itself.
x = np.random.rand(1, linear.in_features)

In [7]:
# Forward pass through the mixout wrapper.
y = m(x)

In [8]:
# Reduce the output to a scalar so it can be back-propagated.
loss = y.sum()

In [9]:
# Adam over all parameters of the wrapper; no optimizer.step() is called in the
# cells shown here.
optimizer = torch.optim.Adam(params=m.parameters())

In [10]:
# Reset gradients before inspecting them.
m.zero_grad()
# backward is deferred to a later cell -- presumably so the next cell can show
# the gradient before any backward pass has run:
# loss.backward()

In [11]:
# Gradient of the trainable weight before backward (no output was recorded for
# this cell).
m.layer.weight.grad

In [12]:
# Back-propagate the scalar loss.
loss.backward()

In [13]:
# The gradient is non-zero even for entries that were replaced by the frozen
# weights, because the normalization factor depends on the norm of the whole
# trainable weight matrix.
# With norm_flag=False in mixout, the replaced (masked-out) weights would get
# zero gradients.
m.layer.weight.grad

tensor([[ 4.0455e-02,  3.4325e-01,  9.6155e-05,  3.5436e-02,  1.4308e-01,
          1.5740e-02,  4.8726e-01,  1.7637e-01,  5.8140e-04,  7.0515e-02,
         -3.3551e-02,  4.0705e-01, -1.1545e-02,  1.1571e-02,  2.6168e-02,
          8.0852e-01,  4.7836e-01,  5.3721e-01,  1.0545e-01,  1.5865e-02],
        [ 1.1746e-01,  4.2322e-03,  3.0313e-02,  5.6444e-01,  1.4308e-01,
          9.5551e-01,  4.8726e-01,  1.7637e-01,  5.8138e-04, -5.7993e-03,
          8.4597e-01,  4.5085e-03,  2.4175e-01,  3.0503e-01,  3.3928e-02,
          8.0852e-01, -2.3421e-03, -3.2000e-02,  1.0545e-01,  3.2199e-01],
        [-5.3046e-04,  1.0329e-02,  8.8153e-01,  5.6444e-01,  1.4308e-01,
          9.5551e-01,  4.8726e-01,  8.0877e-03,  3.3358e-02,  7.0515e-02,
          8.4597e-01,  4.0705e-01,  2.4175e-01,  2.2099e-02,  1.2476e-02,
          3.7256e-02,  4.7836e-01,  8.4730e-03, -3.7546e-02, -6.2013e-03],
        [ 1.1746e-01,  3.4325e-01, -3.8205e-02,  4.8901e-03,  1.4308e-01,
          9.5551e-01, -2.7136e-02, 

In [14]:
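# TODO: check that no gradient update happens where none is expected
# (e.g. m.layer_frozen, whose parameters have requires_grad=False).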