In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In this example I am going to apply Straight Through Estimator technique in order to
be able to compute the gradient of a discrete function.

The way I am going to apply this is through parameter estimation.

Starting from a given function y= a * sign(x) [sign(x) = x/|x|], I create some synthetic data
from a given 'a' of my chice and then create a model that tries to estimate the correct parameter

In [19]:
torch.manual_seed(0)

#customized function for automatic differentiation
class STEFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, param):
        return torch.sign(input)*param

    @staticmethod
    def backward(ctx, grad_output):
        #return None, F.hardtanh(grad_output)
        return None, grad_output             # <- return identity function
        #return None, torch.tanh(grad_output)


class StraightThroughEstimator(nn.Module):
    def __init__(self):
        super(StraightThroughEstimator, self).__init__()
        self.a = nn.Parameter(torch.tensor([1.0]))
    def forward(self, x):
        x = STEFunction.apply(x,self.a)
        return x

# Generate input data
x = torch.randint(-100, 101, (1, 1000)).float()  # Generate 1000 random integers between -100 and 100
a_true = torch.tensor([5.0])  # True value of parameter 'a'

y_true = a_true * torch.sign(x)

model = StraightThroughEstimator()

# Define optimizer
optimizer = optim.SGD(model.parameters(), lr=0.1)
epochs = 1000
for epoch in range(epochs):

    y_pred = model(x)
    loss = nn.functional.mse_loss(y_pred, y_true)
    optimizer.zero_grad()
    loss.backward()

    optimizer.step()

    if (epoch+1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')
# Get the learned value of parameter 'a'
a_learned = model.a.item()
print("True parameter 'a':", a_true.item(),"Learned parameter 'a':", model.a.item())

Epoch [100/1000], Loss: 2.878908157348633
Epoch [200/1000], Loss: 0.5116955041885376
Epoch [300/1000], Loss: 0.09094812721014023
Epoch [400/1000], Loss: 0.01616493985056877
Epoch [500/1000], Loss: 0.0028730907943099737
Epoch [600/1000], Loss: 0.000510769197717309
Epoch [700/1000], Loss: 9.079359006136656e-05
Epoch [800/1000], Loss: 1.6134763427544385e-05
Epoch [900/1000], Loss: 2.864013595171855e-06
Epoch [1000/1000], Loss: 5.090328727419546e-07
True parameter 'a': 5.0 Learned parameter 'a': 4.999290943145752
