In [4]:
import torch
import torch.nn as nn
from torch.autograd import Function

In [5]:
class StepFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = torch.where(input>=0, torch.tensor(1.0), torch.tensor(0.0))
        ctx.save_for_backward(input)
        return output
    
    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = torch.zeros_like(input)
        return grad_input
    
class StepActivation(nn.Module):
    def forward(self, input):
        return StepFunction.apply(input)

In [18]:
model = nn.Sequential(
    nn.Linear(2, 3),
    StepActivation(),
    nn.Linear(3,1)
)

a = model.state_dict()
for x in a.items():
    print(x[1])

tensor([[-0.2831, -0.2200],
        [ 0.3280,  0.1197],
        [ 0.5111, -0.5031]])
tensor([-0.6133,  0.1545, -0.2820])
tensor([[-0.0715, -0.1847,  0.1690]])
tensor([-0.0634])
