In [1]:
import torch

In [2]:
import numpy as np 

In [3]:
#Assume that's 3SAT
def moment_bmm(mu, sigma, signal):
    #mu, sigma: torch tensor
    #signal: numpy array
    num_literal = len(mu)
    
    t = torch.zeros(num_literal, requires_grad=True)
    pdf = torch.exp(torch.matmul(mu, t) + 0.5 * torch.matmul(torch.matmul(t, sigma), t))
    final = torch.empty(3)
    
    d1 = torch.autograd.grad(pdf, t, create_graph=True)
    dx = d1[0][abs(signal[0])-1]
    dy = d1[0][abs(signal[1])-1]
    dz = d1[0][abs(signal[2])-1]
    
    d2_x = torch.autograd.grad(dx, t, create_graph=True)
    d2_y = torch.autograd.grad(dy, t, create_graph=True)
    dxz = d2_x[0][abs(signal[2])-1]
    dxy = d2_x[0][abs(signal[1])-1]
    dyz = d2_y[0][abs(signal[2])-1]
    
    
    d3 = torch.autograd.grad(dxz, t, create_graph=True)
    dxyz = d3[0][abs(signal[1]) -1]
    #print(dxyz,torch.autograd.grad(dyz, t, create_graph=True)[0][abs(signal[0]) -1])
        
    
    if signal[0] > 0 and signal[1] > 0 and signal[2] > 0:
        final = dxyz
    elif signal[0] < 0 and signal[1] > 0 and signal[2] > 0:
        final = dyz - dxyz
    elif signal[0] > 0 and signal[1] < 0 and signal[2] > 0:
        final = dxz - dxyz
    elif signal[0] > 0 and signal[1] > 0 and signal[2] < 0:
        final = dxy - dxyz        
    elif signal[0] < 0 and signal[1] < 0 and signal[2] > 0:
        final = dz - dyz - dxz + dxyz
    elif signal[0] < 0 and signal[1] > 0 and signal[2] < 0:
        final = dy - dyz - dxy + dxyz
    elif signal[0] > 0 and signal[1] < 0 and signal[2] < 0:
        final = dx - dxz - dxy + dxyz
    else:
        final = 1 - dy - dx - dz + dxy + dyz + dxz - dxyz
        
    denominator = final.item()
        
    second_derivative = torch.autograd.grad(final, t, create_graph=True)
    
    M1_change = second_derivative[0].detach().numpy()
    
    M2_change = np.zeros((num_literal, num_literal))
    
    for i in range(num_literal):
        for j in range(num_literal):
            if j>=i:
                derivative_i = second_derivative[0][i]
                second_derivative_i = torch.autograd.grad(derivative_i, t, create_graph=True)
                M2_change[i][j] = second_derivative_i[0][j].item()
            else:
                M2_change[i][j] = M2_change[j][i]
                  
    return denominator, M1_change, M2_change     

In [4]:
# Small test Gaussian over 4 variables; the covariance is diagonal (independent variables).
mu = torch.tensor([0.8, 0.9, 0.7, 0.2])
sigma = torch.diag(torch.tensor([0.1, 0.02, 0.05, 0.1]))

In [5]:
# Scratch check (presumably): p - p**2 = p * (1 - p) for p = 0.2, the mean of the 4th variable.
0.2 - 0.2 **2

0.16

In [6]:
# Signed, 1-based literal indices: moment_bmm turns this into the product (1 - x1) * (1 - x2) * x3.
signal = np.asarray([-1, -2, 3])

In [7]:
# Scratch check (presumably): E[x1] * E[x2] * E[x3] * E[x4**2] for independent variables,
# with E[x4**2] = mu4**2 + sigma44 = 0.2**2 + 0.1.
0.8 * 0.9 * 0.7 * (0.2 * 0.2 + 0.1)

0.07056000000000001

In [8]:
# With a diagonal covariance the clause moments factorise, e.g. the denominator is
# E[1 - x1] * E[1 - x2] * E[x3] = 0.2 * 0.1 * 0.7 = 0.014, matching the output below.
moment_bmm(mu, sigma, signal)

(0.013999998569488525,
 array([0.0042, 0.0098, 0.0108, 0.0028], dtype=float32),
 array([[-0.00084001,  0.00294   ,  0.00324   ,  0.00084   ],
        [ 0.00294   ,  0.00658   ,  0.00756   ,  0.00196   ],
        [ 0.00324   ,  0.00756   ,  0.00896   ,  0.00216   ],
        [ 0.00084   ,  0.00196   ,  0.00216   ,  0.00196   ]]))

In [9]:
def sigma_M2(sigma, mu):
    """Second raw-moment matrix from covariance and mean.

    Returns a new float64 array with entries sigma[r][c] + mu[r] * mu[c]
    (i.e. Sigma + mu mu^T). Neither input is modified.
    """
    n = len(sigma)
    second_moment = np.empty((n, n))
    for row in range(n):
        for col in range(n):
            second_moment[row, col] = sigma[row][col] + mu[row] * mu[col]
    return second_moment

In [10]:
def M2_sigma(M2, mu):
    """Covariance from second raw moments and mean: Sigma = M2 - mu mu^T.

    Only the upper triangle of ``M2`` (column >= row) is read; the lower
    triangle of the result is mirrored from it, so the output is symmetric.
    Returns a new float64 array.
    """
    n = len(M2)
    cov = np.empty((n, n))
    for r in range(n):
        for c in range(r, n):
            cov[r, c] = cov[c, r] = M2[r][c] - mu[r] * mu[c]
    return cov

In [11]:
def modify_M1(mu, epsilon):
    """Pull out-of-range means back toward [0, 1], mutating ``mu`` in place.

    Entries above 1 become ``1 - epsilon``; then (re-reading the possibly
    updated entry) entries below 0 become ``epsilon``. Returns ``mu``.
    """
    ceiling = 1 - epsilon
    for k in range(len(mu)):
        if mu[k] > 1:
            mu[k] = ceiling
        # Re-read mu[k]: the ceiling itself is negative when epsilon > 1.
        if mu[k] < 0:
            mu[k] = epsilon
    return mu

In [12]:
def modify_M2(mu, M2):
    """Repair second moments that are inconsistent with the means, in place.

    For a variable in [0, 1], mu**2 <= E[x**2] <= mu. A diagonal entry above
    ``mu[i]`` is reset to 0.8 * mu + 0.2 * mu**2; one below ``mu[i]**2``
    (negative variance) is reset to 0.8 * mu**2 + 0.2 * mu. Both are convex
    combinations, so for mu[i] in [0, 1] the result lies in [mu**2, mu].
    Off-diagonal entries: negatives become 0.1 * min(mu_i, mu_j); entries
    above either mean become 0.8 * min(mu_i, mu_j).

    Mutates ``M2`` and returns it.
    """
    length = len(M2)
    for i in range(length):
        for j in range(length):
            if i == j:
                if M2[i][i] > mu[i]:
                    # Above the upper bound mu: move 80% of the way to mu.
                    M2[i][i] = mu[i] * 0.8 + (mu[i] ** 2) * 0.2
                if M2[i][i] < mu[i] ** 2:
                    # Below the lower bound mu**2: move 80% of the way to mu**2.
                    # Weights must sum to 1, otherwise the result can exceed mu.
                    M2[i][i] = (mu[i] ** 2) * 0.8 + mu[i] * 0.2
            else:
                if M2[i][j] < 0:
                    M2[i][j] = 0.1 * min(mu[i], mu[j])
                if M2[i][j] > mu[i] or M2[i][j] > mu[j]:
                    M2[i][j] = 0.8 * min(mu[i], mu[j])
    return M2
                    
                

In [34]:
def bmm_update(clauses, num_literal, num_clauses, num_epochs):
    """Sequential moment-matching of a Gaussian over the variables, clause by clause.

    For each clause, moment_bmm is called on the negated clause to get
    E[f], E[X f] and E[X X^T f] for the product f of the negated literals.
    The moments are then replaced by
        mu <- (mu - E[X f]) / (1 - E[f]),   M2 <- (M2 - E[X X^T f]) / (1 - E[f]),
    i.e. (treating x_k as the probability variable k is true) the moments
    conditioned on the clause being satisfied.

    Args:
        clauses: iterable of clauses, each a sequence of signed 1-based literal indices.
        num_literal: number of variables.
        num_clauses: unused; kept for interface compatibility.
        num_epochs: number of passes over ``clauses``.

    Returns:
        The final mean vector (float32 numpy array).

    Prints mu and sigma at the start of every epoch.
    """
    mu = np.asarray([0.5] * num_literal, dtype=np.float32)
    sigma = np.eye(num_literal, dtype=np.float32) * 0.1
    M2 = sigma_M2(sigma, mu)

    for i in range(num_epochs):
        print("------Epoch " + str(i) + " ------" )
        print(mu)
        print(sigma)
        print(" ")
        for clause in clauses:
            mu = mu.astype(dtype=np.float32)
            sigma = sigma.astype(dtype=np.float32)
            # Negate the clause so moment_bmm measures the all-literals-false event.
            # np.negative also handles plain lists, where `-1 * clause` would be [].
            violation = np.negative(np.asarray(clause))
            denominator, M1_change, M2_change = moment_bmm(
                torch.from_numpy(mu), torch.from_numpy(sigma), violation
            )
            # NOTE(review): denominator == 1 (clause surely violated) divides by zero here.
            satisfied_mass = 1 - denominator
            mu = (mu - M1_change) / satisfied_mass
            M2 = (M2 - M2_change) / satisfied_mass
            sigma = M2_sigma(M2, mu)

    return mu

In [35]:
# Problem size and number of passes for the bmm_update run below.
num_literal, num_clauses, num_epochs = 4, 2, 5

In [36]:
# Each row is one clause of signed, 1-based variable indices (negative = negated literal);
# presumably a small 3-SAT instance in CNF: (x1 or x2 or x3) and (not x1 or not x2 or x4).
clauses = np.asarray([[1, 2, 3], [-1, -2, 4]])

In [37]:
# Run the clause-by-clause moment updates; the cell's value is the final mean vector.
# NOTE(review): the recorded output below diverges (means leave [0, 1], sigma turns
# negative after the first epoch); it should not be read as a result.
bmm_update(clauses, num_literal, num_clauses, num_epochs)

------Epoch 0 ------
[0.5 0.5 0.5 0.5]
[[0.1 0.  0.  0. ]
 [0.  0.1 0.  0. ]
 [0.  0.  0.1 0. ]
 [0.  0.  0.  0.1]]
 
------Epoch 1 ------
[ 0.38717636  0.387176    0.03436452 -0.2964895 ]
[[-5.42283922 -5.65291891 -6.36771339 -5.66206524]
 [-5.65291891 -5.42283955 -6.36771467 -5.662066  ]
 [-6.36771339 -6.36771467 -6.99706565 -6.45429356]
 [-5.66206524 -5.662066   -6.45429356 -5.64298415]]
 
------Epoch 2 ------
[3.7949994 3.7949998 3.869064  3.148047 ]
[[-251.69708779 -251.94775211 -284.76647365 -255.26365213]
 [-251.94775211 -251.6970759  -284.76647555 -255.26365308]
 [-284.76647365 -284.76647555 -321.68409297 -288.61862587]
 [-255.26365213 -255.26365308 -288.61862587 -258.63030259]]
 
------Epoch 3 ------
[-66.08878  -66.08878  -75.12932  -67.687355]
[[-34873.04352235 -34873.29347096 -39421.92144278 -35347.94785075]
 [-34873.29347096 -34873.04588036 -39421.92144278 -35347.94785075]
 [-39421.92144278 -39421.92144278 -44563.6706978  -39958.59314241]
 [-35347.94785075 -35347.94785075 

array([15063.77 , 15063.77 , 17028.184, 15268.151], dtype=float32)