In [296]:
import torch

In [297]:
import numpy as np 

In [320]:
#Assume that's 3SAT
def moment_bmm(mu, sigma, signal):
    #mu, sigma: torch tensor
    #signal: numpy array
    num_literal = len(mu)
    
    t = torch.zeros(num_literal, requires_grad=True)
    pdf = torch.exp(torch.matmul(mu, t) + 0.5 * torch.matmul(torch.matmul(t, sigma), t))
    final = torch.empty(3)
    
    d1 = torch.autograd.grad(pdf, t, create_graph=True)
    dx = d1[0][abs(signal[0])-1]
    dy = d1[0][abs(signal[1])-1]
    dz = d1[0][abs(signal[2])-1]
    
    d2_x = torch.autograd.grad(dx, t, create_graph=True)
    d2_y = torch.autograd.grad(dy, t, create_graph=True)
    dxz = d2_x[0][abs(signal[2])-1]
    dxy = d2_x[0][abs(signal[1])-1]
    dyz = d2_y[0][abs(signal[2])-1]
    
    
    d3 = torch.autograd.grad(dxz, t, create_graph=True)
    dxyz = d3[0][abs(signal[1]) -1]
    #print(dxyz,torch.autograd.grad(dyz, t, create_graph=True)[0][abs(signal[0]) -1])
    
    
    if signal[0] > 0 and signal[1] > 0 and signal[2] > 0:
        final = dxyz
    elif signal[0] < 0 and signal[1] > 0 and signal[2] > 0:
        final = dyz - dxyz
    elif signal[0] > 0 and signal[1] < 0 and signal[2] > 0:
        final = dxz - dxyz
    elif signal[0] > 0 and signal[1] > 0 and signal[2] < 0:
        final = dxy - dxyz        
    elif signal[0] < 0 and signal[1] < 0 and signal[2] > 0:
        final = dz - dyz - dxz + dxyz
    elif signal[0] < 0 and signal[1] > 0 and signal[2] < 0:
        final = dy - dyz - dxy + dxyz
    elif signal[0] > 0 and signal[1] < 0 and signal[2] < 0:
        final = dx - dxz - dxy + dxyz
    else:
        final = 1 - dy - dx - dz + dxy + dyz + dxz - dxyz
        
    denominator = final.item()
        
    second_derivative = torch.autograd.grad(final, t, create_graph=True)
    
    M1_change = second_derivative[0].detach().numpy()
    
    M2_change = np.zeros((num_literal, num_literal))
    
    for i in range(num_literal):
        for j in range(num_literal):
            if j>=i:
                derivative_i = second_derivative[0][i]
                second_derivative_i = torch.autograd.grad(derivative_i, t, create_graph=True)
                M2_change[i][j] = second_derivative_i[0][j].item()
            else:
                M2_change[i][j] = M2_change[j][i]
                  
    return denominator, M1_change, M2_change     

In [321]:
mu = torch.tensor([0.8, 0.9, 0.7, 0.2])
sigma = torch.tensor([[0.2, 0, 0, 0], [0, 0.35, 0, 0],  [0, 0, 0.1, 0], [0, 0, 0, 0.2]])

In [322]:
signal = np.asarray([2, 3, -4])

In [323]:
0.9 * 0.7 * 0.8

0.504

In [324]:
moment_bmm(mu, sigma, signal)

(0.5040000081062317,
 array([ 0.4032    ,  0.64959997,  0.4248    , -0.0252    ], dtype=float32),
 array([[ 0.42336002,  0.51968002,  0.33983999, -0.02016   ],
        [ 0.51968002,  0.93743998,  0.54751998, -0.03248   ],
        [ 0.33983999,  0.54751998,  0.39816001, -0.02123999],
        [-0.02016   , -0.03248   , -0.02123999,  0.07056   ]]))

In [325]:
def sigma_M2(sigma, mu):
    length = len(sigma)
    M2 = np.eye(length)
    for i in range(length):
        for j in range(length):
            M2[i][j] = sigma[i][j] + mu[i] * mu[j]
    return M2

In [326]:
def M2_sigma(M2, mu):
    length = len(M2)
    sigma = np.eye(length)
    for i in range(length):
        for j in range(length):
            if j >= i:
                sigma[i][j] = M2[i][j] - mu[i] * mu[j]
            else:
                sigma[i][j] = sigma[j][i]
    return sigma

In [404]:
def bmm_update(clauses, num_literal, num_clauses, num_epochs):
    mu = np.asarray([0.5] * num_literal, dtype=np.float32)
    sigma = np.eye(num_literal, dtype=np.float32) * 0.01
    M2 = sigma_M2(sigma, mu)
    
    t = torch.zeros(len(mu), requires_grad=True)

    
    for i in range(num_epochs):
        for clause in clauses:
            #print(M2)
            mu = mu.astype(dtype=np.float32)
            sigma = sigma.astype(dtype=np.float32)
            mu_torch = torch.from_numpy(mu)
            sigma_torch = torch.from_numpy(sigma)
            denominator, M1_change, M2_change = moment_bmm(mu_torch, sigma_torch, -1 * clause)
            mu -= M1_change
            M2 -= M2_change
            mu = mu / (1 - denominator)
            M2 = M2 / (1 - denominator)
            sigma = M2_sigma(M2, mu)
                
    return mu

In [405]:
clauses = np.asarray([[-1, -2, 3], [1, 2, 4], [2, -3, 4]])

In [406]:
clauses[0]

array([-1, -2,  3])

In [407]:
num_literal = 4
num_clauses = 3
num_epochs = 10

In [408]:
bmm_update(clauses, num_literal, num_clauses, num_epochs)

[[0.25999999 0.25       0.25       0.25      ]
 [0.25       0.25999999 0.25       0.25      ]
 [0.25       0.25       0.25999999 0.25      ]
 [0.25       0.25       0.25       0.25999999]]
[[0.25714284 0.24708571 0.25005714 0.24857143]
 [0.24708571 0.25714285 0.25005714 0.24857143]
 [0.25005714 0.25005714 0.26285713 0.25142857]
 [0.24857143 0.24857143 0.25142857 0.25999999]]
[[0.55433268 0.53270538 0.53774965 0.53591826]
 [0.53270538 0.5543327  0.53774967 0.53591827]
 [0.53774965 0.53774967 0.56378452 0.54070922]
 [0.53591826 0.53591827 0.54070922 0.5605128 ]]
[[-1.23787075 -1.24359849 -1.26096507 -1.2508191 ]
 [-1.24359849 -1.20607142 -1.24494573 -1.23511479]
 [-1.26096507 -1.24494573 -1.24141781 -1.25246041]
 [-1.2508191  -1.23511479 -1.25246041 -1.2206392 ]]
[[39.91621046 39.91383921 40.2855862  40.22868965]
 [39.91383921 39.95385731 40.30500586 40.2470108 ]
 [40.2855862  40.30500586 40.69934949 40.62177547]
 [40.22868965 40.2470108  40.62177547 40.58604474]]
[[-3277.08281321 -3248.

array([nan, nan, nan, nan], dtype=float32)