In [1]:
import torch
from matplotlib import pyplot as plt
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

In [2]:
def sample_gumbel(shape, eps=1e-20):
    """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

    Uses the inverse-CDF transform -log(-log(U)) with U ~ Uniform[0, 1);
    ``eps`` is added inside both logarithms so neither ever sees exactly 0.
    """
    uniform = torch.rand(shape)
    neg_log_u = -torch.log(uniform + eps)
    return -torch.log(neg_log_u + eps)

def gumbel_softmax_sample(logits, temperature):
    """Relaxed (Gumbel-softmax) sample from the categorical defined by ``logits``.

    Gumbel noise of the same size as ``logits`` is added, and the result is
    passed through a softmax over the last dimension after dividing by
    ``temperature``.
    """
    noisy_logits = logits + sample_gumbel(logits.size())
    scaled_logits = noisy_logits / temperature
    return F.softmax(scaled_logits, dim=-1)


def sum_neighbours(x,y,grid):
    """Sum the 8 cells surrounding ``grid[x, y]`` (the cell itself is excluded).

    For a 0/1 grid this is the number of infected neighbours of (x, y).
    Indices are not bounds-checked, so callers are expected to pass interior
    points only.
    """
    offsets = [(dx, dy)
               for dx in (-1, 0, 1)
               for dy in (-1, 0, 1)
               if (dx, dy) != (0, 0)]
    return sum(grid[x + dx, y + dy] for dx, dy in offsets)

def accept_infection(x,y,grid_old, grid_new,N,beta):
    """Possibly infect a healthy cell, writing the outcome into ``grid_new``.

    If ``grid_old[x, y]`` is 0 (healthy), a Bernoulli variable with success
    probability 1 - (1 - beta)^N is drawn, where N is the number of infected
    neighbours, and stored in ``grid_new[x, y]``. Cells that are already
    infected are left untouched (S.I. model: no recovery).

    Always returns 0; the result is communicated through ``grid_new``.
    """
    if grid_old[x, y] != 0:
        return 0
    p_infection = 1 - (1 - beta) ** N
    draw = torch.distributions.bernoulli.Bernoulli(p_infection)
    grid_new[x, y] = draw.sample(torch.tensor([1]))
    return 0


def one_simulation(k, grid_new,beta, verbose=True):
    """Run ``k`` synchronous time steps of the stochastic S.I. model.

    Parameters
    ----------
    k : int
        Number of time steps to simulate.
    grid_new : 2-D tensor
        Current state, 0 = healthy, 1 = infected. It is updated IN PLACE, so
        after the call it holds the final state.
    beta : float
        Per-infected-neighbour infection probability.
    verbose : bool, default True
        Print ``m / 10`` every 10 steps as a progress indicator.

    Returns
    -------
    Tensor of shape (k, rows, cols): a snapshot of the grid after each step.
    If ``k == 0`` an empty (0, rows, cols) tensor is returned.

    Notes
    -----
    Only interior cells (rows/cols 1 .. size-2) are updated; the outer border is
    never touched because ``sum_neighbours`` does not bounds-check.
    """
    # Use the grid's own size rather than a global ``n`` defined in another cell.
    rows, cols = grid_new.shape
    snapshots = []
    for m in range(k):
        # Every cell's update in this step reads the state at the START of the step.
        grid_old = grid_new.clone().detach()
        for i in range(1, rows - 1):
            for j in range(1, cols - 1):
                N = sum_neighbours(i, j, grid_old)
                accept_infection(i, j, grid_old, grid_new, N, beta)
        snapshots.append(grid_new.clone())
        if verbose and m % 10 == 0:
            print(m / 10)
    if not snapshots:
        # torch.stack([]) raises; return a correctly shaped empty history instead.
        return grid_new.new_empty((0, rows, cols))
    return torch.stack(snapshots, dim=0)



In [3]:
# Sanity check: a 3x3 convolution with an all-ones kernel whose centre is zeroed
# gives, for every cell, the number of its 8 neighbours that equal 1 (infected).
binary_tensor = torch.tensor([[0, 1, 0, 1],
                              [1, 0, 0, 1],
                              [1, 1, 0, 0],
                              [0, 0, 1, 0]], dtype=torch.float32)

# F.conv2d needs a (C, H, W) or (N, C, H, W) input: add a channel dim -> (1, 4, 4).
binary_tensor = binary_tensor.unsqueeze(0)

# Kernel of shape (out_channels=1, in_channels=1, 3, 3), all ones except the
# centre, so a cell does not count itself.
kernel = torch.ones(1, 1, 3, 3, dtype=torch.float32)
kernel[0, 0, 1, 1] = 0

# padding=1 keeps the output the same size as the input; cells outside the grid
# are zero-padded, i.e. treated as healthy.
neighbors_count = F.conv2d(binary_tensor, kernel, padding=1)

print("Original binary tensor:")
print(binary_tensor.squeeze())
print("\nNumber of infected neighbours (neighbour cells equal to one):")
print(neighbors_count.squeeze())

Original binary tensor:
tensor([[0., 1., 0., 1.],
        [1., 0., 0., 1.],
        [1., 1., 0., 0.],
        [0., 0., 1., 0.]])

Number of neighbors with value greater than one:
tensor([[2., 1., 3., 1.],
        [3., 4., 4., 1.],
        [2., 3., 3., 2.],
        [2., 3., 1., 1.]])


In [4]:
# Only healthy cells can become infected, so zero the neighbour count of cells
# that are already infected.
neighbors_count.masked_fill_(binary_tensor == 1, 0)
print(neighbors_count[0].size())

# Mimic one S.I. step: P(infection) = 1 - (1 - beta)^N, which is differentiable in beta.
betaaa = torch.tensor([0.3], requires_grad=True)
xx = 1 - (1 - betaaa) ** neighbors_count
# Arbitrary scalar used only to check that gradients flow back to beta.
fakeloss = xx.sum() - 5
neighbors_count

torch.Size([4, 4])


tensor([[[2., 0., 3., 0.],
         [0., 4., 4., 0.],
         [0., 0., 3., 2.],
         [2., 3., 0., 1.]]])

In [5]:
# Backpropagate the toy loss; a finite, non-zero betaaa.grad confirms that
# 1 - (1 - beta)^N is differentiable with respect to beta.
# NOTE(review): backward() frees the graph by default, so re-running this cell
# without re-running the previous one is expected to raise.
torch.autograd.backward(fakeloss)
betaaa.grad

tensor([12.3540])

In [6]:
class STEFunction(torch.autograd.Function):
    """One hard (0/1) S.I. update of a whole grid, with a straight-through gradient.

    Forward:
    - The number of infected neighbours N of every cell is obtained with a 3x3
      convolution.
    - A Gumbel-max draw over the two classes (healthy, infected), with
      probabilities (1 - p, p) where p = 1 - (1 - param)^N, decides each cell's
      next state.

    Backward: argmax is not differentiable, so a straight-through estimator is
    used. The upstream gradient is passed through as if d(output)/d(param) were 1
    for every cell.
    """

    @staticmethod
    def forward(ctx, x, sample_gumbel, temperature, param):
        """Draw the next grid state.

        Parameters
        ----------
        x : 2-D float tensor (H, W)
            Current grid, 0 = healthy, 1 = infected.
        sample_gumbel : tensor of shape (2, H, W)
            Gumbel noise, one slice per class (healthy, infected). The name
            shadows the module-level function of the same name.
        temperature : float > 0
            Softmax temperature. Only the argmax of the softmax is kept, so for
            any positive temperature it does not change the result.
        param : tensor
            The infection probability beta (the learnable parameter).

        Returns
        -------
        float32 tensor (H, W) of 0/1 values: the sampled next state.
        """
        ctx.param_shape = param.shape
        x = x.unsqueeze(0)  # (1, H, W): conv2d wants an explicit channel dim
        height, width = x.size()[1], x.size()[2]

        # 3x3 kernel of ones with the centre zeroed -> counts the 8 neighbours.
        kernel = torch.ones(1, 1, 3, 3, dtype=x.dtype, device=x.device)
        kernel[0, 0, 1, 1] = 0
        neighbors_count = F.conv2d(x, kernel, padding=1)  # (1, H, W)

        # Gumbel-max trick behaving like Bernoulli(1 - (1 - beta)^N) per cell.
        p_infect = 1 - (1 - param) ** neighbors_count
        # Class 0 = stays healthy (log(1 - p)), class 1 = infected (log p).
        # log(0) = -inf where p == 0, which simply makes that class impossible.
        logits = torch.stack((torch.log(1 - p_infect), torch.log(p_infect))).reshape((2, height, width))
        z = F.softmax((logits + sample_gumbel) / temperature, dim=0)
        y_pred = z.argmax(dim=0)

        # A healthy cell with no infected neighbour cannot become infected.
        # Order matters: this must run BEFORE the "already infected" rule below,
        # otherwise isolated infected cells (e.g. the seed at t=0, whose own
        # neighbour count is 0) would be reset to healthy.
        y_pred[neighbors_count[0] == 0] = 0
        # Infected cells stay infected (S.I. model: no recovery).
        y_pred[x[0] == 1] = 1

        # .float() keeps the tensor on its device (the old FloatTensor cast forced CPU).
        return y_pred.float()

    @staticmethod
    def backward(ctx, grad_output):
        """Straight-through estimator: treat d(y_pred)/d(param) as 1 for every cell.

        ``grad_output`` has the grid's shape while ``param`` is typically a single
        value, so the per-cell gradients are summed down to ``param``'s shape.
        No gradient is returned for x, the Gumbel noise or the temperature.
        """
        grad_param = grad_output.sum_to_size(ctx.param_shape)
        return None, None, None, grad_param

class MyModel(nn.Module):
    """Single-parameter model whose forward pass is one S.I. step via STEFunction.

    ``a`` is the learnable infection probability beta, initialised to 0.5.
    """

    def __init__(self):
        super().__init__()
        # nn.Parameter requires grad by default.
        self.a = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x, sample_gumbel, temperature):
        """Return the sampled next grid for the current grid ``x``."""
        return STEFunction.apply(x, sample_gumbel, temperature, self.a)

In [9]:
### IF YOU HAVE THE FILE, UPLOAD IT BY RUNNING THIS CELL,
### ONLY OTHERWISE RUN THE NEXT CELL
# NOTE(review): torch.load unpickles its input; only load files you trust.
res = torch.load('Beta is 03333 and Nsim is 20.pt')

# Settings the cached file is assumed to have been generated with.
n = 100
grid_size = (n, n)
tau = 50               # time steps per simulation
# NOTE(review): the file name says 20 simulations but this is 10 -- confirm.
NSimumlation = 10
betaTrue = 0.3333333   # true infection probability used by the simulator

In [15]:
betaTrue = 0.3333333

# THIS TAKES A WHILE (pure-Python loops over every cell and time step).

# Generate NSimumlation independent simulations of an n x n population, each
# running for tau time steps and starting from a single infected pixel in the
# middle. The runs are concatenated along the time axis, so res has shape
# (NSimumlation * tau, n, n) and run s occupies rows s*tau .. (s+1)*tau - 1.
n = 100
grid_size = (n, n)
tau = 50
NSimumlation = 10
assert NSimumlation >= 1, "need at least one simulation"

runs = []
for sim in range(NSimumlation):
    grid = torch.zeros(grid_size, dtype=torch.int8)
    grid[n // 2, n // 2] = 1
    runs.append(one_simulation(tau, grid, betaTrue))
    print(sim)
# Concatenate once at the end instead of re-copying res on every iteration.
res = torch.cat(runs, dim=0)


0.0
1.0
2.0
3.0
4.0
0
0.0
1.0
2.0
3.0
4.0
1
0.0
1.0
2.0
3.0
4.0
2
0.0
1.0
2.0
3.0
4.0
3
0.0
1.0
2.0
3.0
4.0
4
0.0
1.0
2.0
3.0
4.0
5
0.0
1.0
2.0
3.0
4.0
6
0.0
1.0
2.0
3.0
4.0
7
0.0
1.0
2.0
3.0
4.0
8
0.0
1.0
2.0
3.0
4.0
9


In [10]:
model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
epochs = 15
for epoch in range(epochs):
    # res[i + 1] must exist, so stop one short of the end.
    for i in range(res.size(0) - 1):
        # Skip the last time step of each simulation: its successor res[i + 1]
        # is the first step of the NEXT run, not a real transition.
        if i % tau == tau - 1:
            continue
        input_data = res[i].float()
        # Fresh Gumbel noise for every transition: one (H, W) slice per class.
        gumbel_sample = sample_gumbel((2, *input_data.shape))
        y_pred = model(x=input_data, sample_gumbel=gumbel_sample, temperature=0.1)

        target = res[i + 1].float()
        loss = torch.mean((y_pred - target) ** 2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reported loss is that of the last transition of the epoch, not an average.
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item()},  parameter: {model.a.item()}')
print(f'True parameter: {betaTrue}, Learned param: {model.a.item()}')

Epoch [1/15], Loss: 0.017999999225139618,  parameter: 0.43882182240486145
Epoch [2/15], Loss: 0.01679999940097332,  parameter: 0.3983713686466217
Epoch [3/15], Loss: 0.019200000911951065,  parameter: 0.37266868352890015
Epoch [4/15], Loss: 0.01860000006854534,  parameter: 0.3571324348449707
Epoch [5/15], Loss: 0.01759999990463257,  parameter: 0.34787771105766296
Epoch [6/15], Loss: 0.019999999552965164,  parameter: 0.34176549315452576
Epoch [7/15], Loss: 0.017899999395012856,  parameter: 0.33822357654571533
Epoch [8/15], Loss: 0.020099999383091927,  parameter: 0.33620136976242065
Epoch [9/15], Loss: 0.021299999207258224,  parameter: 0.334945410490036
Epoch [10/15], Loss: 0.01860000006854534,  parameter: 0.3345251977443695
Epoch [11/15], Loss: 0.019300000742077827,  parameter: 0.33389705419540405
Epoch [12/15], Loss: 0.018200000748038292,  parameter: 0.33329835534095764
Epoch [13/15], Loss: 0.01759999990463257,  parameter: 0.3334141969680786
Epoch [14/15], Loss: 0.017899999395012856,  p

In [33]:
# Cache the simulations. The simulation count in the file name is derived from
# res; the beta in the name ("03333") is hard-coded and does not follow betaTrue.
torch.save(res, "Beta is 03333 and Nsim is {}.pt".format(int(res.shape[0] / tau)))