In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random

In [2]:
clauses = []

num_literals = 0
num_clause = 0

with open('../uf20-91/uf20-01.cnf') as file:
    for line in file:
        if(line[0] == 'p'):
            l = line.split()
            num_literals = int(l[2])
            num_clauses = int(l[3])
        if(line[0] != 'p' and line[0] != '%' and line[0] != 'c'):
            clause = [int(digit) for digit in line.split()]
            clauses.append(clause[:-1])

In [3]:
from numpy import linalg as LA

In [4]:
import math

In [5]:
def rbf(xi, xj, h):
    return math.exp(-((LA.norm(xi - xj))**2) / h)

In [6]:
def rbf_derivative(xj, xi, h, k):
    return -(2 / h) * k * (xj - xi)

In [16]:
#returns the likelihood of each clause being F for each sample
def likelihood_clauses(num_points, points):
    likelihoods = np.zeros((num_points, num_clauses))
    for i in range(num_points):
        for j in range(num_clauses):
            p = 1
            for literal in clauses[j]:
                if literal > 0:
                    p = p * (1 - points[i][literal-1])
                else:
                    p = p * points[i][-literal-1]
            likelihoods[i][j] = p
    return likelihoods       

In [17]:
#return the derivative for the i-th sample
def derivative_likelihood(alpha, beta, i, likelihoods, num_points, points):
    derivative = np.zeros(num_literals)
    for k in range(num_literals):
        derivative[k] += ((alpha - 1) / points[i][k]) - ((beta - 1) / (1 - points[i][k]))
        
    for c in range(num_clauses):
        likelihood = likelihoods[i][c]
        for literal in clauses[c]:
            if literal > 0:
                derivative[literal-1] += (likelihood / ( 1- points[i][literal-1])) / (1 - likelihood)
            else:
                derivative[-literal-1] += - (likelihood / points[i][-literal-1]) / (1 - likelihood)
                
    return derivative
    

In [18]:
def satisfied(polarity):
    num_satisfied = 0
    for clause in clauses:
        for literal in clause:
            if literal * polarity[abs(literal) - 1] > 0:
                num_satisfied += 1
                break
    return num_satisfied

In [19]:
def determin_h(points):
    length = int(len(points))
    distance = np.zeros(int(length * (length - 1) / 2))
    m = 0
    for i in range(length):
        for j in range(length):
            if i < j:
                distance[m] = LA.norm(points[i] - points[j])
                m += 1
    return np.median(distance)**2 / math.log(length)

In [20]:
def legal_prob(m):
    for i in range(len(m)):
        for j in range(len(m[i])):
            if m[i][j] > 1 or m[i][j] < 0:
                return False
    return True

In [37]:
def stein_update(num_epochs, num_points):
    points = np.random.beta(2, 2, size = (num_points, num_literals))
    satisfied_clauses = np.zeros(num_epochs)
    for e in range(num_epochs):
        ave = np.sum(points, axis=0) / num_points
        polarity = np.zeros(num_literals)
        for i in range(num_literals):
            if ave[i] > 0.5:
                polarity[i] = 1
            else:
                polarity[i] = -1
        #print(satisfied(polarity))
        #print(ave[:5])
        satisfied_clauses[e] = satisfied(polarity)
        kernel_matrix = np.zeros((num_points, num_points))
        h = determin_h(points)
        for i in range(num_points):
            for j in range(num_points):
                if i <= j:
                    kernel_matrix[i][j] = rbf(points[i], points[j], h)
                else:
                    kernel_matrix[i][j] = kernel_matrix[j][i]  
        #new_points = np.zeros((num_points, num_literals))
        likelihoods = likelihood_clauses(num_points, points)
        derivatives_likelihood = np.zeros((num_points, num_literals))
        for i in range(num_points):
            derivatives_likelihood[i] = derivative_likelihood(2, 2, i, likelihoods, num_points, points)
        #print(derivatives_likelihood[0][:5])
        epsilon = 10 / math.log(e + 2)
        phi_matrix = np.zeros((num_points, num_literals))
        for i in range(num_points):
            phi = 0
            for j in range(num_points):
                phi += kernel_matrix[i][j] * derivatives_likelihood[j] + rbf_derivative(j, i, h, kernel_matrix[i][j])
                #phi += kernel_matrix[i][j] * derivatives_likelihood[j]
            phi = phi / num_points
            phi_matrix[i] = phi
            #print(phi)
            #new_points[i] = points[i] + epsilon * phi
        print(LA.norm(derivatives_likelihood))
        epsilon = 0.1
        while(not legal_prob(points + epsilon * phi_matrix)):
            epsilon = epsilon / 1.2
        #print(epsilon)
        points += epsilon * phi_matrix
            
        #for i in range(num_points):
            #for j in range(num_literals):
                #if points[i][j] < 0.001:
                    #points[i][j] = 0.001
                #elif points[i][j] > 0.999:
                    #points[i][j] = 0.999
    return satisfied_clauses

In [38]:
stein_update(1000, 20)

138.46456544230952
169.64814653211647
182.9909630499796
88.14801289860875
84.1518584285069
79.9556283923259
78.30472633100572
77.14543014385612
75.92711622457661
76.35918113548543
76.99680902888713
75.51710317628086
80.92162682970557
84.40536441984085
317.76006681551513
122.67726902001696
77.87815086994331
104.10339955808705
98.8627702191389
140.6275428368602
84.70809746025309
149.7706690311532
100.00693247795675
164.36094675021334
143.53611485838834
152.12317316338132
239.94999385711017
184.872960538013
90.56244251655023
97.98807653805882
297.4169247894202
239.1672793745816
91.4312479570763
92.02827554680678
130.44151454342835
104.54406578793288
136.3194510367862
73.01374087038904
152.32521534224472
139.24305630945156
209.56879057389213
78.16673664608261
123.13708506543486
80.00359307395533
143.62241907509238
170.97364711182274
184.2795078675265
137.75148567955213
125.98451487857112
93.25990644423605
277.40850993977534
165.45302064464394
92.03243905790134
133.05487548429699
141.192405

128.26011916681335
298.5174594148033
128.4984579159576
94.96694442122448
120.20901362416315
644.1911300481427
132.83648960864673
86.699573724347
131.31730096850882
116.25558861781427
197.97287773300118
107.7874019477059
160.08491339394334
124.9867278471848
152.72122875644476
116.37325797096725
178.65111967740526
143.61984818284986
145.21122408988924
5908.9516579887395
108.0779357930028
92.60790225004945
125.34991751053273
229.44358403176977
939.10260911418
108.09110273547988
123.6265242231834
95.07930566464009
104.18259480608701
103.51733608541569
156.40807252199454
3054.2695573644683
114.76145061580625
91.42277737681198
761.6294187982755
100.41175944798204
75.29477265505764
112.74177178057764
142.25051741643108
207.69896212840183
78.51594650560126
145.23319037429764
248.37151601399015
84.81851263404583
402.6583490583834
97.65537052022947
84.15268919270642
112.73095439765991
475.3872810321173
160.40106444478445
92.7427874078963
115.5768570963146
131.41919143570564
232.4105729018961
112

87.621477190072
151.72315227964023
104.90467412604814
172.92285810338288
106.20665750762221
121.92568277265568
80.09181963457154
134.16634204847136
270.80201457930343
502.8800622680887
95.20389777470803
87.30253188783581
143.60436668775446
132.5491399173611
97.14366579384854
131.6744120368247
101.40495963162167
147.01270231972376
72.61746694752794
227.14541752055172
148.0775702548452
192.7460801562418
501.5276226970983
86.23983243640056
83.04361641784453
222.12132478607424
111.19467979903752
118.2032914697376
98.31398085642624
184.63864043121697
106.9082227061141
92.41431075729922
111.51468876062576
184.79722974022877
172.23913841830714
92.13454728184013
145.40978656523527
114.24852586967256
157.10093289201865
167.9543080508369
408.6049655379071
100.70410724353513
74.16442867061778
166.59434772555616
97.47574450721655
194.00023384474417
90.67391541473859
53190.308192838056
364.52363366474265
141.66983903078386
179.07706725970542
100.62021754893634
1562.003833303142
113.00614459577267
8

array([76., 78., 80., 84., 85., 86., 86., 86., 87., 87., 87., 87., 87.,
       87., 87., 89., 88., 89., 87., 87., 87., 87., 88., 88., 87., 87.,
       87., 89., 89., 88., 87., 87., 86., 88., 87., 87., 87., 87., 87.,
       87., 86., 87., 87., 87., 87., 87., 88., 87., 87., 87., 87., 87.,
       88., 88., 86., 87., 87., 87., 88., 88., 87., 87., 86., 87., 87.,
       88., 89., 87., 86., 86., 86., 88., 87., 84., 86., 87., 87., 87.,
       88., 86., 87., 87., 88., 89., 87., 87., 87., 89., 89., 87., 87.,
       87., 87., 87., 86., 87., 87., 87., 87., 88., 88., 87., 87., 87.,
       89., 88., 88., 88., 87., 84., 87., 87., 87., 88., 89., 89., 87.,
       86., 87., 87., 87., 87., 86., 87., 88., 87., 87., 87., 87., 87.,
       87., 87., 88., 87., 87., 88., 87., 87., 87., 88., 88., 87., 87.,
       87., 87., 88., 87., 87., 87., 87., 87., 88., 87., 86., 87., 88.,
       88., 87., 87., 87., 87., 87., 88., 87., 87., 87., 87., 87., 88.,
       88., 88., 87., 87., 87., 87., 87., 88., 87., 86., 86., 86

In [51]:
x = np.asarray([1, 2])

In [54]:
LA.norm(x)

2.23606797749979