In [1]:
import torch
import torch.nn as nn
import numpy as np

In [2]:
import copy

In [3]:
linear = nn.Linear(20, 50)

In [4]:
class mixout(nn.Module):
    def __init__(self, linear, p, norm_flag=True):
        super().__init__()
        self.layer = linear
        self.norm_flag = norm_flag
        self.p = p
        self.layer_frozen = copy.deepcopy(linear)
        for param in self.layer_frozen.parameters():
            param.requires_grad = False
    def forward(self, x):
        if isinstance(x, np.ndarray):
            x = torch.Tensor(x)
        if not self.training or self.p == 0:
            return self.layer(x)
        self.noise = torch.FloatTensor(x.shape[0], linear.out_features, linear.in_features).uniform_(0, 1)
        self.mask = (self.noise < self.p)
        self.mask = self.mask.type(torch.FloatTensor)
        # mask bs, input, output
        # layer frozen input, output
        self.frozen_masked = self.mask * torch.unsqueeze(self.layer_frozen.weight, 0)
        self.learned_masked = (1-self.mask) * torch.unsqueeze(self.layer.weight, 0)
        # bs, input, output
        self.masked_layer = (self.frozen_masked + self.learned_masked)
        if self.norm_flag:
            self.masked_layer = self.masked_layer * torch.norm(
                self.layer.weight) / torch.norm(self.masked_layer, dim = [1, 2]).unsqueeze(1).unsqueeze(2)
        #bs, output
        self.output = (x.unsqueeze(1) * self.masked_layer).sum(2) + self.layer.bias.unsqueeze(0)
        return self.output
        

In [5]:
m = mixout(linear, p=0.5)

In [6]:
x = np.random.rand(1, linear.in_features)

In [7]:
y = m(x)

In [8]:
loss = y.sum()

In [9]:
optimizer = torch.optim.Adam(params=m.parameters())

In [10]:
m.zero_grad()
# loss.backward()

In [11]:
m.layer.weight.grad

In [12]:
loss.backward()

In [14]:
#not zero because of the normalization
#if you set norm_flag=False in the mixout
#then the corresponding zeroed weights will have zero gradients
m.layer.weight.grad

tensor([[ 3.9967e-01,  5.3937e-01,  4.1424e-02,  2.0204e-01,  8.5268e-01,
          5.6803e-01,  8.9640e-01,  3.0182e-02,  9.4533e-01,  2.0347e-01,
          6.1209e-03,  8.1988e-01, -1.4640e-02, -4.2149e-02,  3.7392e-01,
          6.2557e-01,  9.2927e-01, -1.1693e-02, -1.2218e-02,  5.2663e-01],
        [ 1.8804e-02,  1.1339e-02,  7.8458e-01,  2.0204e-01,  4.0620e-02,
          5.6803e-01,  4.1135e-02, -2.6580e-02,  1.6119e-02,  2.0347e-01,
         -2.0497e-03, -4.2364e-02,  2.0113e-02, -2.0310e-02,  3.7392e-01,
         -2.8640e-02,  9.2927e-01,  2.9010e-02,  4.0579e-01,  1.6971e-02],
        [ 3.9967e-01, -2.8823e-02, -3.2549e-02, -2.2311e-02, -2.4722e-02,
          3.9495e-02,  8.9640e-01,  8.7633e-01,  9.4533e-01,  2.0347e-01,
          2.8092e-02,  8.1988e-01,  7.3709e-01,  4.2386e-02, -8.9548e-03,
          1.8949e-02,  9.2927e-01,  2.9010e-02,  4.0579e-01,  5.2663e-01],
        [ 3.9967e-01,  5.3937e-01,  7.8458e-01,  2.0204e-01, -5.5682e-03,
          5.6803e-01,  2.0592e-02, 

In [13]:
# TODO check no gradient update