In [8]:
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


In [9]:
def sample_gumbel(shape, eps=1e-20, device=None, dtype=None):
    """Draw i.i.d. samples from the standard Gumbel(0, 1) distribution.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U ~ Uniform[0, 1)``.

    Args:
        shape: Output shape (tuple or ``torch.Size``).
        eps: Small constant added inside both logs so that ``U == 0`` (or ``U``
            very close to 1) cannot produce ``log(0)``.
        device: Device of the returned tensor (``None`` -> PyTorch default device).
        dtype: Floating dtype of the returned tensor (``None`` -> PyTorch default dtype).

    Returns:
        Tensor of the requested shape filled with Gumbel noise.
    """
    U = torch.rand(shape, device=device, dtype=dtype)
    return -torch.log(-torch.log(U + eps) + eps)

def gumbel_softmax_sample(logits, temperature):
    """Draw a relaxed (soft) one-hot sample from the categorical given by ``logits``.

    Adds Gumbel noise to ``logits`` and applies a temperature-scaled softmax over
    the last dimension; smaller ``temperature`` pushes the output towards one-hot.

    Args:
        logits: Unnormalised log-probabilities; classes along the last dimension.
        temperature: Positive softmax temperature.

    Returns:
        Tensor of the same shape as ``logits`` whose last dimension sums to 1.
    """
    # Draw the noise where the logits live so CUDA / float64 inputs work. Integer
    # logits keep the default floating dtype (and are promoted by the addition).
    noise_dtype = logits.dtype if logits.is_floating_point() else None
    y = logits + sample_gumbel(logits.size(), device=logits.device, dtype=noise_dtype)
    return F.softmax(y / temperature, dim=-1)

In [10]:
class STEFunction(torch.autograd.Function):
    """Hard Bernoulli sampler with a straight-through (STE) gradient for ``param``.

    Forward: for each element of ``x`` computes ``p = 1 - (1 - param) ** x``, forms
    the two-class log-probabilities ``[log(1 - p), log(p)]``, perturbs them with the
    caller-supplied Gumbel noise and returns the arg-max class (0.0 or 1.0).

    Backward: does not differentiate the sampler. The incoming gradient is clipped
    to ``[-1, 1]`` with ``hardtanh`` and passed straight through to ``param``,
    summed down to ``param``'s shape. ``x``, the noise and the temperature receive
    no gradient.
    """

    @staticmethod
    def forward(ctx, x, sample_gumbel, temperature, param):
        """Compute hard 0/1 samples.

        Args:
            ctx: Autograd context.
            x: Exponents. The notebook passes shape ``(1, N)``; any shape is
                flattened to ``x.numel()`` samples.
            sample_gumbel: Gumbel noise reshapeable/broadcastable to
                ``(2, x.numel())`` (row 0 -> class 0, row 1 -> class 1). Note: this
                argument shadows the module-level function of the same name.
            temperature: Softmax temperature. For ``temperature > 0`` the softmax is
                monotonic, so it does not change the arg-max and has no effect on
                the returned samples.
            param: Success parameter ``a``; the logs below are only finite for
                ``0 < p < 1``, i.e. ``param`` inside ``(0, 1)`` and ``x > 0``.

        Returns:
            Tensor of 0.0/1.0 samples of shape ``(x.numel(),)``, in the dtype of the
            computed probabilities and on the inputs' device.
        """
        # Remember param's shape so backward can reduce the per-sample gradient to it.
        ctx.param_shape = param.shape
        probs = 1 - (1 - param) ** x
        # reshape(2, -1) equals the former reshape(2, x.size()[1]) for x of shape
        # (1, N), and additionally accepts 1-D x.
        logits = torch.stack((torch.log(1 - probs), torch.log(probs))).reshape(2, -1)
        z = F.softmax((logits + sample_gumbel) / temperature, dim=0)
        y_pred = z.argmax(dim=0)
        # .to(dtype) keeps the device; .type(torch.FloatTensor) forced a CPU tensor.
        return y_pred.to(logits.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient: clipped ``grad_output`` summed to ``param``'s shape."""
        # Explicitly reduce (N,) -> param.shape instead of relying on autograd's
        # implicit sum_to when a gradient has the wrong shape.
        grad_param = F.hardtanh(grad_output).sum_to_size(ctx.param_shape)
        return None, None, None, grad_param

class MyModel(nn.Module):
    """Single-parameter model that produces hard 0/1 samples via ``STEFunction``.

    Holds one learnable parameter ``a`` (a 1-element tensor initialised to 0.5).
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter tracks gradients by default (requires_grad=True).
        initial_a = torch.tensor([0.5])
        self.a = nn.Parameter(initial_a)

    def forward(self, x, sample_gumbel, temperature):
        """Return ``STEFunction.apply(x, sample_gumbel, temperature, self.a)``."""
        return STEFunction.apply(x, sample_gumbel, temperature, self.a)

In [11]:
# Synthetic data: integer exponents x in {1, ..., 8} and Bernoulli labels with
# success probability 1 - (1 - a_true) ** x.
torch.manual_seed(0)  # make data, Gumbel noise and training reproducible on re-run

x_train = torch.randint(1, 9, (1, 1000)).float()  # integers 1..8 (randint's upper bound is exclusive)
a_true = torch.tensor(0.33333)      # True value of parameter 'a'

parameters = 1-(1-a_true)**x_train  # per-sample success probabilities
y_train = torch.bernoulli(parameters)

model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 1000
temperature = 0.1  # softmax temperature passed to the model

for epoch in range(epochs):
    # Fresh Gumbel noise every epoch, so each step sees a new stochastic sample.
    # Drawing it once outside the loop would be faster, but it would freeze the
    # noise and the model would be fitted to a single fixed noise realisation.
    gumbel_sample = sample_gumbel((2, x_train.size(1)))

    y_pred = model(x=x_train, sample_gumbel=gumbel_sample, temperature=temperature)
    # Only the mean firing rate is matched, not the individual labels.
    loss = (y_pred.mean() - y_train.mean()) ** 2

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')


print(y_train.mean(), y_pred.mean())
print("True parameter 'a':", a_true.item(),"Learned parameter 'a':", model.a.item())


Epoch [100/1000], Loss: 0.00025599912623874843
Epoch [200/1000], Loss: 6.400026177288964e-05
Epoch [300/1000], Loss: 1.600006544322241e-05
Epoch [400/1000], Loss: 1.600006544322241e-05
Epoch [500/1000], Loss: 1.600006544322241e-05
Epoch [600/1000], Loss: 0.0
Epoch [700/1000], Loss: 0.0
Epoch [800/1000], Loss: 0.0004840006586164236
Epoch [900/1000], Loss: 0.0002249995741294697
Epoch [1000/1000], Loss: 0.00012099950981792063
tensor(0.7640) tensor(0.7750, grad_fn=<MeanBackward0>)
True parameter 'a': 0.3333300054073334 Learned parameter 'a': 0.32972025871276855
