In [47]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

Again a parameter-estimation problem, but now $y = x \cdot Z$, where $Z \sim \mathrm{Bern}(a)$ and the $x$ are random integers.

The difficulty here comes from both the stochasticity and the discreteness of $Z$: a sampled 0/1 value
provides no useful gradient with respect to $a$. To work around this I used the Gumbel-Softmax trick.



In [48]:
def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel(0, 1) noise of the requested shape.

    Inverse-transform sampling: for ``U ~ Uniform[0, 1)`` the quantity
    ``-log(-log(U))`` follows a standard Gumbel distribution.

    Args:
        shape: size passed straight to ``torch.rand``.
        eps: small constant added inside both logarithms so that neither
            argument is exactly zero.

    Returns:
        Tensor of Gumbel noise with the given shape.
    """
    uniform = torch.rand(shape)
    neg_log_uniform = -torch.log(uniform + eps)
    return -torch.log(neg_log_uniform + eps)

def gumbel_softmax_sample(logits, temperature):
    """Return a relaxed categorical sample using the Gumbel-softmax trick.

    Gumbel noise of the same size as ``logits`` is added to them, and the
    perturbed values are divided by ``temperature`` before a softmax over
    the last dimension.

    Args:
        logits: tensor of unnormalised log-probabilities.
        temperature: divisor applied to the perturbed logits before softmax.

    Returns:
        Tensor shaped like ``logits`` whose entries along the last dimension
        sum to one.
    """
    noise = sample_gumbel(logits.size())
    perturbed = logits + noise
    return F.softmax(perturbed / temperature, dim=-1)

## Custom autograd Function (STEFunction, defined right after the seed is set)

# Fix the global RNG seed so the random draws made after this point are repeatable.
torch.manual_seed(0)

class STEFunction(torch.autograd.Function):
    """Straight-through estimator for ``y = x * Z`` with ``Z ~ Bern(param)``.

    The forward pass turns pre-drawn Gumbel noise into a hard 0/1 decision
    per column (argmax over the two noisy log-probabilities) and multiplies
    it with ``x``.  That decision is not differentiable, so the backward
    pass passes the incoming gradient on to ``param`` as if the forward
    pass were the identity.
    """

    @staticmethod
    def forward(ctx, x, sample_gumbel, temperature, param):
        """Compute ``x * z`` where ``z`` is a hard 0/1 sample per column.

        Args:
            ctx: autograd context; stores the shape of ``param`` for backward.
            x: 2-D tensor; its second dimension indexes the samples.
            sample_gumbel: noise tensor that is added to the ``(2, x.size(1))``
                logits; row 0 perturbs ``log(1 - param)`` and row 1
                perturbs ``log(param)``.
            temperature: divisor applied to the noisy logits before softmax.
            param: single-element tensor holding the Bernoulli probability.

        Returns:
            Tensor equal to ``x`` where class 1 was selected and 0 elsewhere.
        """
        n_cols = x.size(1)
        ctx.param_shape = param.shape

        prob = param.detach().reshape(())
        if not prob.is_floating_point():
            prob = prob.to(torch.get_default_dtype())
        # An SGD step can push the parameter outside (0, 1); log() of a
        # non-positive number would then yield NaN/-inf logits and a
        # meaningless argmax.  Keep the probability strictly inside the
        # interval (values already inside are left untouched).
        margin = torch.finfo(prob.dtype).eps
        prob = prob.clamp(margin, 1 - margin)

        # Row 0: log P(Z = 0), row 1: log P(Z = 1); same logits for every column.
        logits = torch.stack((torch.log(1 - prob), torch.log(prob)))
        logits = logits.view(2, 1).expand(2, n_cols)

        z = F.softmax((logits + sample_gumbel) / temperature, dim=0)
        return z.argmax(dim=0) * x

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through gradient: identity with respect to ``param``.

        One value is returned per forward input (x, sample_gumbel,
        temperature, param); only ``param`` receives a gradient.  The
        incoming gradient has the shape of the output, so it is summed down
        to the shape of ``param`` explicitly instead of relying on autograd
        to reduce a mismatched shape.
        """
        # Alternative that was tried: F.hardtanh(grad_output) to clip the gradient.
        return None, None, None, grad_output.sum_to_size(ctx.param_shape)

class MyModel(nn.Module):
    """Model with a single learnable parameter ``a`` used as the Bernoulli probability."""

    def __init__(self):
        super().__init__()
        # The estimate of 'a' starts at 0.5.
        self.a = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x, sample_gumbel, temperature):
        """Run ``STEFunction`` on ``x`` with the given noise, temperature and ``self.a``."""
        return STEFunction.apply(x, sample_gumbel, temperature, self.a)

# --- Synthetic data ---
# Inputs: a (1, 1000) row of integers drawn from 0..20, stored as floats.
x_train = torch.randint(low=0, high=21, size=(1, 1000)).float()
# Ground-truth value of the parameter 'a' that training should recover.
a_true = torch.tensor(1/3)
# One Bernoulli(a_true) draw per input element.
Z_train = torch.distributions.Bernoulli(a_true).sample(x_train.shape)
y_train = x_train * Z_train

# --- Model and optimiser ---
model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 1000

# --- Training loop ---
for epoch in range(epochs):
    # Fresh Gumbel noise is drawn every epoch; drawing it once before the
    # loop would arguably be fair as well, and faster.
    gumbel_sample = sample_gumbel((2, x_train.size(1)))

    y_pred = model(x=x_train, sample_gumbel=gumbel_sample, temperature=0.1)

    # Squared difference between the predicted and the observed mean.
    loss = (y_pred.mean() - y_train.mean()).pow(2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Report every 100th epoch.
    if epoch % 100 == 99:
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()}')

print(y_train.mean(), y_pred.mean())
print("True parameter 'a':", a_true.item(),"Learned parameter 'a':", model.a.item())

Epoch [100/1000], Loss: 0.016640987247228622
Epoch [200/1000], Loss: 0.005041013937443495
Epoch [300/1000], Loss: 0.009408977814018726
Epoch [400/1000], Loss: 2.4998760636663064e-05
Epoch [500/1000], Loss: 0.007744010537862778
Epoch [600/1000], Loss: 0.019881010055541992
Epoch [700/1000], Loss: 0.034595951437950134
Epoch [800/1000], Loss: 0.001155994599685073
Epoch [900/1000], Loss: 0.00672400277107954
Epoch [1000/1000], Loss: 0.11902502179145813
tensor(3.5530) tensor(3.8980, grad_fn=<MeanBackward0>)
True parameter 'a': 0.3333333432674408 Learned parameter 'a': 0.3426192104816437


In [49]:
## Simple check that the Gumbel-softmax trick works:
## the argmax of each sample should behave like a Bern(a_true) draw,
## so the printed frequency of class 1 should be close to a_true.
n_trials = 10000
# The logits are the same for every draw, so build them once outside the loop.
logits = torch.stack((torch.log(1 - a_true), torch.log(a_true)))
tot = 0
for i in range(n_trials):
    z = gumbel_softmax_sample(logits, 1)
    tot += z.argmax()  # argmax is a position in z, hence either 0 or 1 <=> flip a coin
print(tot / n_trials, a_true)



tensor(0.3363) tensor(0.3333)
