In [1]:
import numpy as np
import matplotlib.pyplot as plt
import random

In [2]:
clauses = []

num_literals = 0
num_clause = 0

with open('../UF250.1065.100/uf250-01.cnf') as file:
    for line in file:
        if(line[0] == 'p'):
            l = line.split()
            num_literals = int(l[2])
            num_clauses = int(l[3])
        if(line[0] != 'p' and line[0] != '%' and line[0] != 'c'):
            clause = [int(digit) for digit in line.split()]
            clauses.append(clause[:-1])

In [3]:
from numpy import linalg as LA

In [4]:
import math

In [5]:
from scipy.stats import norm

In [6]:
def rbf(xi, xj, h):
    return math.exp(-((LA.norm(xi - xj))**2) / h)

In [7]:
def rbf_derivative(xj, xi, h, k):
    return -(2 / h) * k * (xj - xi)

In [8]:
def sigmoid(x):
    return 1 / (1 + math.exp(-x))

In [9]:
sigmoid_vector = np.vectorize(sigmoid)

In [10]:
def sigmoid_derivative(x):
    s = sigmoid(x)
    return s * (1 - s)

In [11]:
#returns the likelihood of each clause being F for each sample
def likelihood_clauses(num_points, points):
    likelihoods = np.zeros((num_points, num_clauses))
    for i in range(num_points):
        for j in range(num_clauses):
            p = 1
            for literal in clauses[j]:
                if literal > 0:
                    p = p * (1 - sigmoid(points[i][literal-1]))
                else:
                    p = p * sigmoid(points[i][-literal-1])
            likelihoods[i][j] = p
    return likelihoods       

In [12]:
#return the derivative for the i-th sample
def derivative_likelihood(i, likelihoods, num_points, points):
    derivative = np.zeros(num_literals)
    for k in range(num_literals):
        derivative[k] -= 0.5 * norm.pdf(x = points[i][k], loc =3, scale =1) * (points[i][k] - 3) + 0.5 * norm.pdf(x = points[i][k], loc =-3, scale =1) * (points[i][k] + 3) 
        
    for c in range(num_clauses):
        likelihood = likelihoods[i][c]
        for literal in clauses[c]:
            if literal > 0:
                derivative[literal-1] += (likelihood  * sigmoid_derivative(points[i][literal-1]) / ( 1- sigmoid(points[i][literal-1]))) / (1 - likelihood)
            else:
                derivative[-literal-1] += - (likelihood * sigmoid_derivative(points[i][-literal-1]) / sigmoid(points[i][-literal-1])) / (1 - likelihood)
                
    return derivative
    

In [13]:
def satisfied(polarity):
    num_satisfied = 0
    for clause in clauses:
        for literal in clause:
            if literal * polarity[abs(literal) - 1] > 0:
                num_satisfied += 1
                break
    return num_satisfied

In [14]:
def determin_h(points):
    length = int(len(points))
    distance = np.zeros(int(length * (length - 1) / 2))
    m = 0
    for i in range(length):
        for j in range(length):
            if i < j:
                distance[m] = LA.norm(points[i] - points[j])
                m += 1
    return np.median(distance)**2 / math.log(length)

In [16]:
import random

In [18]:
def mixture_gaussian(w, h):
    samples = np.zeros((w, h))
    for i in range(w):
        for j in range(h):
            if random.uniform(0, 1) < 0.5:
                mu = -3
                sigma = 1
                samples[i][j] = np.random.normal(mu, sigma, 1)[0]
            else:
                mu = 3
                sigma = 1
                samples[i][j] = np.random.normal(mu, sigma, 1)[0]
    return samples

In [28]:
def stein_update(num_epochs, num_points):
    points = mixture_gaussian(num_points, num_literals)
    satisfied_clauses = np.zeros(num_epochs)
    for e in range(num_epochs):
        prob = sigmoid_vector(points)
        ave = np.sum(prob, axis=0) / num_points
        polarity = np.zeros(num_literals)
        for i in range(num_literals):
            if ave[i] > 0.5:
                polarity[i] = 1
            else:
                polarity[i] = -1
        print(satisfied(polarity))
        #print(ave[:5])
        satisfied_clauses[e] = satisfied(polarity)
        kernel_matrix = np.zeros((num_points, num_points))
        h = determin_h(points)
        for i in range(num_points):
            for j in range(num_points):
                if i <= j:
                    kernel_matrix[i][j] = rbf(points[i], points[j], h)
                else:
                    kernel_matrix[i][j] = kernel_matrix[j][i]  
        #new_points = np.zeros((num_points, num_literals))
        likelihoods = likelihood_clauses(num_points, points)
        derivatives_likelihood = np.zeros((num_points, num_literals))
        for i in range(num_points):
            derivatives_likelihood[i] = derivative_likelihood(i, likelihoods, num_points, points)
        #print(derivatives_likelihood[0][:5])
        phi_matrix = np.zeros((num_points, num_literals))
        for i in range(num_points):
            phi = 0
            for j in range(num_points):
                phi += kernel_matrix[i][j] * derivatives_likelihood[j] + rbf_derivative(j, i, h, kernel_matrix[i][j])
            phi = phi / num_points
            phi_matrix[i] = phi
        print(LA.norm(derivatives_likelihood))
        print()
        epsilon = 1
        points += epsilon * phi_matrix
    return satisfied_clauses

In [29]:
stein_update(5000, 20)

928
40.70384053784747

931
40.515648462041106

932
40.30387933089882

935
40.066199734447494

935
39.80086149586271

943
39.50687503183072

943
39.18416579901286

945
38.8335271778716

946
38.45653992937354

946
38.05537711333616

949
37.632593498726635

951
37.19092762626264

956
36.73311977332678

957
36.26188361885592

959
35.779827766990635

964
35.28930758404055

963
34.792485726714126

964
34.291264496524

964
33.787310717824695

965
33.28214314578342

969
32.77712109283531

970
32.27346328017418

972
31.772314506268494

972
31.274826641753194

975
30.782056310000527

976
30.295003974589516

980
29.81465200902836

981
29.341884059728493

984
28.87745487546857

987
28.421971901373496

988
27.97591587025588

989
27.5395422874686

990
27.113010773204675

991
26.696276720389623

991
26.289191705177892

993
25.891561566729745

993
25.503122819620224

993
25.123526992447157

996
24.752396662526305

997
24.3893059354724

997
24.033925753198133

997
23.68583909777702

997
23.344706940414

5.425968763148325

1034
5.415074231713472

1034
5.404172605084527

1034
5.393266030280828

1034
5.382356566416004

1034
5.371446175822831

1034
5.360536718366443

1034
5.3496299488828924

1034
5.3387275176049815

1034
5.327830957662376

1034
5.316941734100838

1034
5.306061226296059

1034
5.2951907351319285

1034
5.284331494402654

1034
5.273484684701238

1034
5.262651439811178

1034
5.251832848711613

1034
5.241029988677731

1034
5.230243939203798

1034
5.219475807188026

1034
5.208726723361799

1034
5.197997853444126

1034
5.18729041051662

1034
5.1766056565912315

1034
5.165944909441938

1034
5.155309556230289

1034
5.14470105245587

1034
5.134120921918507

1034
5.12357075455276

1034
5.113052202170242

1034
5.10256697218931

1034
5.092116826968248

1034
5.081703571265118

1034
5.07132902807465

1034
5.0609950332282745

1034
5.05070342166067

1034
5.040456012925699

1034
5.03025459626243

1034
5.020100915522352

1034
5.0099966542725705

1034
4.999943421386522

1034
4.989942745893322

3.1782261124363385

1044
3.1735920599989647

1044
3.168977982585081

1044
3.164382123500115

1044
3.159802738336299

1044
3.1552381066969297

1044
3.1506865436111857

1044
3.146146410530962

1044
3.141616125804196

1044
3.1370941630693143

1044
3.1325790259643616

1044
3.1280693626508382

1044
3.1235639128795256

1044
3.119061514503513

1044
3.114561109011342

1044
3.110061746048801

1044
3.105562586912913

1044
3.1010629070167033

1044
3.0965620973381216

1044
3.0920596883227707

1044
3.0875553088933816

1044
3.0830486944484465

1044
3.0785396924071518

1044
3.0740282592365014

1044
3.0695144827427736

1044
3.064998548402579

1044
3.0604807137175056

1044
3.055961327507361

1044
3.0514408204774046

1044
3.0469196421826465

1044
3.0423983839164417

1044
3.037877704744042

1044
3.033358325435505

1044
3.02884102244587

1044
3.024326622001842

1044
3.019815994346973

1044
3.015310048189877

1044
3.010809725392576

1044
3.006315995928798

1044
3.0018298531350376

1044
2.9973523092705965



2.427555017244842

1043
2.4249115394454286

1043
2.4222502774341725

1043
2.419571752112922

1043
2.4168765130692145

1043
2.414165134872353

1043
2.4114382133755305

1043
2.408696362051743

1043
2.405940208389585

1043
2.403170387810668

1043
2.4003874289091796

1043
2.3975919830676746

1043
2.3947847022548236

1043
2.391966236058712

1043
2.38913722889856

1043
2.3862983174283467

1043
2.38345012814413

1043
2.3805932846748386

1043
2.3777283916130125

1043
2.374856032674271

1043
2.37197677405993

1043
2.3690911632007503

1043
2.3661997277407965

1043
2.3633029747618775

1043
2.36040139024748

1043
2.3574954387835216

1043
2.3545855634916744

1043
2.351672186189474

1043
2.348755707769896

1043
2.345836508791579

1043
2.3429149502694555

1043
2.339991374654161

1043
2.337066123500238

1043
2.33413951439284

1043
2.331211840770049

1043
2.328283383051904

1043
2.3253544103474932

1043
2.3224251822580073

1043
2.319495950757607

1043
2.316566962133642

1043
2.313638474188739

1043
2.3

1.7579929166019257

1045
1.7564900918294897

1045
1.7549933280588825

1045
1.753502707830478

1045
1.7520183142065262

1045
1.750540230751515

1045
1.749068541501285

1045
1.7476033309208356

1045
1.7461446838508021

1045
1.7446926854426443

1045
1.7432474210826618

1045
1.7418089763050026

1045
1.7403774366938969

1045
1.7389528877754359

1045
1.7375354148992368

1045
1.7361251031104488

1045
1.7347220370125704

1045
1.7333263006216253

1045
1.731937977212295

1045
1.7305571491566292

1045
1.7291838977560237

1045
1.7278183030671508

1045
1.726460443722574

1045
1.7251103967467885

1045
1.7237682373684189

1045
1.7224340388293287

1045
1.7211078721913682

1045
1.7197898061414825

1045
1.7184799067958683

1045
1.717178237503854

1045
1.7158848586521296

1045
1.7145998274699217

1045
1.7133231978356762

1045
1.7120550200857445

1045
1.7107953408255563

1045
1.7095442027436856

1045
1.7083016444291992

1045
1.7070677001926104

1045
1.705842399890742

1045
1.704625768755743

1045
1.703417

1.4521631734229898

1044
1.4513112643650052

1044
1.4504708422305348

1044
1.4496424910935466

1044
1.4488267836607154

1044
1.4480242818346594

1044
1.4472355372865449

1044
1.446461092027849

1044
1.4457014789711913

1044
1.444957222470288

1044
1.4442288388292668

1044
1.4435168367718003

1044
1.4428217178607536

1044
1.4421439768593092

1044
1.4414841020248415

1044
1.4408425753271068

1044
1.4402198725826716

1044
1.43961646349785

1044
1.439032811612805

1044
1.4384693741398547

1044
1.4379266016894643

1044
1.4374049378778184

1044
1.4369048188103495

1044
1.4364266724360848

1044
1.4359709177681799

1044
1.4355379639665706

1044
1.4351282092792532

1044
1.434742039839327

1044
1.434379828315619

1044
1.4340419324154163

1044
1.4337286932386422

1044
1.4334404334836386

1044
1.4331774555056556

1044
1.4329400392301455

1044
1.4327284399240467

1044
1.4325428858294258

1044
1.432383575665118

1044
1.432250676003396

1044
1.4321443185301752

1044
1.432064597198859

1044
1.43201156

1.438943407534156

1045
1.4385683606545323

1045
1.4381273572423303

1045
1.4376222676110457

1045
1.4370549955229848

1045
1.4364274660326695

1045
1.435741614433835

1045
1.4349993763425088

1045
1.4342026789256042

1045
1.43335343326364

1045
1.4324535278179569

1045
1.4315048229571903

1045
1.430509146484913

1045
1.4294682901001592

1045
1.4283840067149736

1045
1.427258008547944

1045
1.426091965909769

1045
1.424887506595974

1045
1.4236462158027183

1045
1.4223696364839713

1045
1.4210592700718991

1045
1.419716577486875

1045
1.4183429803688572

1045
1.4169398624677347

1045
1.415508571136477

1045
1.4140504188772884

1045
1.412566684897382

1045
1.4110586166372814

1045
1.4095274312406234

1045
1.4079743169402095

1045
1.40640043434046

1045
1.4048069175813989

1045
1.4031948753738461

1045
1.4015653918995479

1045
1.3999195275735918

1045
1.3982583196695728

1045
1.396582782810658

1045
1.394893909331948

1045
1.393192669521361

1045
1.3914800117477306

1045
1.38975686248590

1.0671562397285745

1045
1.0674079788762791

1045
1.067672950645027

1045
1.067950885590661

1045
1.0682415067623718

1045
1.0685445302801766

1045
1.06885966594237

1045
1.0691866178604825

1045
1.0695250851189608

1045
1.0698747624565133

1045
1.0702353409657697

1045
1.0706065088076546

1045
1.0709879519366081

1045
1.071379354832582

1045
1.0717804012355252

1045
1.0721907748779007

1045
1.0726101602106433

1045
1.0730382431178385

1045
1.073474711615349

1045
1.0739192565285627

1045
1.0743715721444482

1045
1.074831356833137

1045
1.075298313634348

1045
1.0757721508040816

1045
1.07625258231719

1045
1.0767393283216296

1045
1.0772321155404583

1045
1.07773067761792

1045
1.0782347554062899

1045
1.0787440971905022

1045
1.0792584588479894

1045
1.0797776039415548

1045
1.0803013037435707

1045
1.0808293371902349

1045
1.081361490765117

1045
1.0818975583117079

1045
1.0824373407751957

1045
1.082980645874195

1045
1.083527287703668

1045
1.0840770862707683

1045
1.0846298614920

1.0182276485506712

1045
1.0178420031756383

1045
1.0174426860967951

1045
1.0170296758059034

1045
1.0166029710543465

1045
1.01616259080762

1045
1.0157085741402838

1045
1.0152409800717692

1045
1.0147598873437516

1045
1.0142653941401212

1045
1.0137576177508862

1045
1.0132366941816442

1045
1.0127027777105309

1045
1.0121560403948342

1045
1.0115966715296931

1045
1.0110248770615422

1045
1.0104408789591584

1045
1.0098449145453512

1045
1.0092372357924957

1045
1.0086181085852373

1045
1.0079878119538104

1045
1.0073466372814799

1045
1.0066948874896906

1045
1.0060328762045134

1045
1.0053609269080084

1045
1.0046793720780856

1045
1.0039885523204115

1045
1.00328881549585

1045
1.0025805158468284

1045
1.0018640131259406

1045
1.0011396717299554

1045
1.0004078598422885

1045
0.9996689485868306

1045
0.9989233111958857

1045
0.9981713221947877

1045
0.9974133566056083

1045
0.9966497891721785

1045
0.9958809936084646

1045
0.9951073418721607

1045
0.99432920346517

1045
0.9935

0.8507717189162496

1045
0.8498882112002957

1045
0.8490015977694326

1045
0.8481121227863035

1045
0.8472200321223542

1045
0.8463255731064102

1045
0.8454289942799889

1045
0.8445305451597621

1045
0.8436304760075259

1045
0.8427290376080017

1045
0.8418264810547521

1045
0.8409230575444528

1045
0.8400190181797234

1045
0.8391146137806884

1045
0.8382100947053958

1045
0.837305710679196

1045
0.8364017106331437

1045
0.8354983425514576

1045
0.8345958533280471

1045
0.8336944886320815

1045
0.8327944927825541

1045
0.831896108631774

1045
0.8309995774576868

1045
0.8301051388649159

1045
0.8292130306943856

1045
0.8283234889413814

1045
0.827436747681878

1045
0.8265530390069562

1045
0.8256725929651156

1045
0.8247956375122799

1045
0.8239223984692796

1045
0.8230530994865914

1045
0.822187962016102

1045
0.8213272052896632

1045
0.8204710463041944

1045
0.8196196998130878

1045
0.81877337832367

1045
0.8179322921004641

1045
0.817096649174007

1045
0.8162666553549596

1045
0.81544

0.7682311527566797

1045
0.7678809009251057

1045
0.7675418253504888

1045
0.7672140580759477

1045
0.7668977272448098

1045
0.7665929570036681

1045
0.7662998674064194

1045
0.7660185743192273

1045
0.765749189326348

1045
0.7654918196367795

1045
0.7652465679916732

1045
0.7650135325724781

1045
0.7647928069097718

1045
0.764584479792759

1045
0.7643886351794114

1045
0.7642053521072422

1045
0.764034704604713

1045
0.7638767616032843

1045
0.7637315868501324

1045
0.7635992388215697

1045
0.7634797706372121

1045
0.7633732299749648

1045
0.7632796589868984

1045
0.7631990942161098

1045
0.7631315665146823

1045
0.7630771009628675

1045
0.763035716789636

1045
0.7630074272947573

1045
0.7629922397725928

1045
0.7629901554377932

1045
0.7630011693531279

1045
0.7630252703596754

1045
0.7630624412869326

1045
0.7631126611298187

1045
0.7631758995855261

1045
0.7632521198823518

1045
0.7633412787257956

1045
0.7634433262488334

1045
0.7635582059667257

1045
0.7636858547367404

1045
0.76

0.9527854856421744

1046
0.953764919916217

1046
0.9547379968578782

1046
0.9557033516165694

1046
0.9566595489782161

1046
0.9576050857042616

1046
0.958538393469507

1046
0.959457842420045

1046
0.9603617453676688

1046
0.9612483626315597

1046
0.9621159075317886

1046
0.9629625525322639

1046
0.9637864360232534

1046
0.9645856697255983

1046
0.9653583466902909

1046
0.9661025498583526

1046
0.9668163611370105

1046
0.9674978709392301

1046
0.9681451881248162

1046
0.9687564502728111

1046
0.9693298342068576

1046
0.969863566687872

1046
0.9703559351818722

1046
0.9708052986053595

1046
0.9712100979464304

1046
0.9715688666569083

1046
0.9718802407094123

1046
0.9721429682134919

1046
0.9723559184868261

1046
0.9725180904810611

1046
0.9726286204671147

1046
0.9726867888916902

1046
0.9726920263251995

1046
0.972643918431202

1046
0.9725422098986354

1046
0.9723868072903501

1046
0.9721777807745599

1046
0.9719153647194971

1046
0.9715999571455856

1046
0.9712321180435108

1046
0.970

0.729537153300509

1045
0.7281091666374006

1045
0.7266927883833292

1045
0.7252880733985111

1045
0.7238950667558856

1045
0.722513804282428

1045
0.7211443130814871

1045
0.7197866120361615

1045
0.7184407122938022

1045
0.717106617731787

1045
0.7157843254047697

1045
0.7144738259736517

1045
0.713175104116571

1045
0.7118881389222316

1045
0.710612904265936

1045
0.7093493691687071

1045
0.7080974981399011

1045
0.7068572515037451

1045
0.7056285857102338

1045
0.7044114536308398

1045
0.7032058048394976

1045
0.7020115858793275

1045
0.700828740515567

1045
0.6996572099751815

1045
0.6984969331736284

1045
0.6973478469292325

1045
0.6962098861656497

1045
0.6950829841028661

1045
0.6939670724371929

1045
0.6928620815107015

1045
0.6917679404705327

1045
0.6906845774185134

1045
0.6896119195514988

1045
0.6885498932928474

1045
0.6874984244154345

1045
0.6864574381565856

1045
0.6854268593253123

1045
0.6844066124022204

1045
0.6833966216324392

1045
0.6823968111119229

1045
0.6814

0.6184904912277043

1045
0.6181662183244754

1045
0.6178349290631917

1045
0.6174968165162268

1045
0.6171520786175282

1045
0.6168009177187902

1045
0.6164435401412998

1045
0.6160801557250462

1045
0.6157109773766609

1045
0.6153362206177287

1045
0.6149561031349644

1045
0.6145708443337115

1045
0.6141806648961659

1045
0.613785786345665

1045
0.6133864306183262

1045
0.6129828196432461

1045
0.612575174932402

1045
0.6121637171813236

1045
0.6117486658815153

1045
0.6113302389455469

1045
0.6109086523456274

1045
0.6104841197664097

1045
0.6100568522726854

1045
0.6096270579925394

1045
0.6091949418164657

1045
0.6087607051128496

1045
0.6083245454601496

1045
0.6078866563960347

1045
0.6074472271836554

1045
0.6070064425951541

1045
0.60656448271245

1045
0.6061215227452741

1045
0.6056777328663522

1045
0.6052332780635964

1045
0.6047883180090861

1045
0.6043430069445874

1045
0.6038974935833008

1045
0.6034519210274889

1045
0.6030064267015911

1045
0.6025611423004036

1045
0.60

0.48900000577537045

1045
0.4887610683002758

1045
0.4885243479651048

1045
0.4882898180669724

1045
0.4880574516172221

1045
0.48782722136812934

1045
0.4875990998398553

1045
0.48737305934764186

1045
0.4871490720292351

1045
0.48692710987253535

1045
0.4867071447434584

1045
0.4864891484139922

1045
0.4862730925904449

1045
0.48605894894186186

1045
0.4858466891286

1045
0.4856362848310403

1045
0.48542770777842303

1045
0.4852209297777827

1045
0.4850159227429659

1045
0.4848126587237046

1045
0.4846111099347286

1045
0.4844112487848847

1045
0.4842130479062438

1045
0.48401648018316185

1045
0.4838215187812734

1045
0.48362813717638115

1045
0.4834363091832147

1045
0.48324600898402825

1045
0.4830572111569973

1045
0.48286989070438785

1045
0.4826840230804594

1045
0.4824995842190645

1045
0.482316550560913

1045
0.48213489908045654

1045
0.48195460731236217

1045
0.4817756533775306

1045
0.48159801600862356

1045
0.48142167457505736

1045
0.4812466091074251

1045
0.4810728072039

0.4660732484354806

1045
0.46591492449328814

1045
0.4657563886999588

1045
0.46559769074393154

1045
0.46543888008735756

1045
0.4652800059482449

1045
0.4651211172832205

1045
0.4649622627709013

1045
0.46480349079587424

1045
0.4646448494332722

1045
0.4644863864339435

1045
0.46432814921020277

1045
0.4641701848221591

1045
0.4640125399646073

1045
0.4638552609544762

1045
0.46369839371882304

1045
0.4635419837833627

1045
0.46338607626151984

1045
0.4632307158439963

1045
0.463075946788837

1045
0.4629218129119868

1045
0.46276835757832374



array([ 928.,  931.,  932., ..., 1045., 1045., 1045.])

In [51]:
x = np.asarray([1, 2])

In [54]:
LA.norm(x)

2.23606797749979

In [72]:
np.random.normal(1, 2, 1)

array([3.03731959])