# Creating Custom layers

In [20]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

# Layer without parameters
Creating a layer which performs centering and feature scaling over the input

In [8]:
class CustomLayerWoParameters(nn.Module):

  def __init__(self):
    super().__init__()

  def forward(self, X):
    X = X - X.mean()
    return X/(X.max() - X.min())

In [9]:
center_layer = CustomLayerWoParameters()
center_layer(torch.FloatTensor([1,2,3,4,5,6]))

tensor([-0.5000, -0.3000, -0.1000,  0.1000,  0.3000,  0.5000])

Now using this layer in a model

In [11]:
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), CustomLayerWoParameters(), nn.Linear(128, 16))

In [12]:
y = model(torch.randn(784))

In [13]:
print(y)

tensor([ 0.1062,  0.0806, -0.1107,  0.0729,  0.1580,  0.1286,  0.2094,  0.2878,
        -0.0973, -0.0501,  0.0990,  0.0805, -0.0517,  0.0681,  0.1358,  0.1860],
       grad_fn=<AddBackward0>)


# Layer with parameter
Gives us ability of end to end training, Here I'm trying to create a thresholding layer. Idea is to keep threshold as learnable parameter

In [26]:
class AdaThreshold(nn.Module):

  def __init__(self, out_features):
    super().__init__()
    self.threshold = Variable(torch.ones(out_features))

  def forward(self, X):
    thresholded_vec = (X<self.threshold)
    return torch.tensor(thresholded_vec, dtype=torch.float32)

In [27]:
threshold_layer = AdaThreshold(out_features=10)
threshold_layer(torch.randn(10))

  if __name__ == '__main__':


tensor([1., 1., 1., 1., 1., 0., 0., 1., 1., 0.])

putting this layer into model

In [30]:
model = nn.Sequential(nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 128), nn.ReLU(), AdaThreshold(128), nn.Linear(128, 16))
y = model(torch.randn(784))
print(y)

tensor([ 0.1891, -0.4544,  0.9712, -0.3627,  0.1832, -0.4668, -0.0545, -0.7840,
         0.9422, -0.5710, -0.5363, -1.0384, -0.1755,  0.6081, -0.3865, -0.3153],
       grad_fn=<AddBackward0>)


  if __name__ == '__main__':
