In [80]:
import time
import torch
import torch.nn as nn
from pysmt.shortcuts import Symbol, And, Equals, Real, GT, Max, Plus, get_model
from pysmt.typing import REAL

In [100]:
class Model(nn.Module):
    """Two-layer fully connected network with a ReLU after each layer.

    The input has 3 features, the hidden layer has ``embedding_size`` units
    and the output layer has a single unit. Because the last layer is also
    passed through ReLU, the value returned by ``forward`` is never negative.

    Attributes:
        linear0: input -> hidden linear layer.
        linear1: hidden -> output linear layer.
        linear_layers: plain Python list ``[linear0, linear1]`` giving the
            layers in evaluation order.
        relu: the shared ReLU activation.
    """

    def __init__(self, embedding_size):
        super().__init__()
        hidden_layer = nn.Linear(3, embedding_size)
        output_layer = nn.Linear(embedding_size, 1)

        self.linear0 = hidden_layer
        self.linear1 = output_layer
        # Ordered view of the linear layers for code that walks them by index.
        self.linear_layers = [hidden_layer, output_layer]
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return element 0 of the network output for input ``x``."""
        hidden = self.relu(self.linear0(x))
        scores = self.relu(self.linear1(hidden))
        # Indexing with 0 drops the leading dimension of the result.
        return scores[0]

In [101]:
def train_backpropagation(model, all_data, labels, criterion, optimizer, num_epochs, log_every=1):
    """Train ``model`` one sample at a time until it classifies every sample correctly.

    Each epoch walks over ``zip(all_data, labels)``, performing one optimizer
    step per sample. A sample counts as correct when the sign-based prediction
    (1 if the model output is > 0, else 0) equals its label. Accuracy is
    therefore measured with the weights as they were when each sample was
    visited, not with the weights at the end of the epoch.

    Training stops early as soon as an epoch classifies every sample correctly.

    Args:
        model: module called as ``model(data)``; must return a single-element
            tensor.
        all_data: iterable of input samples.
        labels: sized iterable of single-element target tensors (0.0 or 1.0).
        criterion: loss called as ``criterion(output, label)``.
        optimizer: optimizer holding the parameters of ``model``.
        num_epochs: maximum number of passes over the data.
        log_every: print the epoch summary every ``log_every`` epochs (the
            epoch that reaches 100% accuracy is always printed). The default
            of 1 prints every epoch.

    Raises:
        ValueError: if ``labels`` is empty or ``log_every`` is smaller than 1.
    """
    num_samples = len(labels)
    if num_samples == 0:
        raise ValueError("train_backpropagation requires at least one labelled sample")
    if log_every < 1:
        raise ValueError("log_every must be a positive integer")

    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_correct = 0

        for data, label in zip(all_data, labels):
            optimizer.zero_grad()

            output = model(data)
            prediction = 1 if float(output) > 0 else 0

            loss = criterion(output, label)
            # Count with plain Python numbers so the statistics do not depend
            # on the tensor type, device or autograd state of the labels.
            num_correct += int(float(label) == prediction)

            running_loss += loss.item()
            loss.backward()
            optimizer.step()

        mean_loss = running_loss / num_samples
        accuracy = num_correct / num_samples
        converged = num_correct == num_samples

        if converged or epoch % log_every == 0:
            print("epoch: " + str(epoch) + ", accuracy: " + str(accuracy) + ", loss: " + str(mean_loss))

        if converged:
            return

In [102]:
def train_smt(model, all_data, labels):
    """Set the weights of ``model.linear_layers`` from an SMT solver model.

    One real-valued symbol is created per weight and per bias of every layer
    in ``model.linear_layers``. For each ``(data, label)`` pair the network is
    encoded symbolically: every layer (including the last) computes
    ``max(weights . x + bias, 0)``. The first output of the last layer is then
    constrained to equal 0 when the label is 0.0 and to be strictly positive
    otherwise. The conjunction of all per-sample constraints is handed to
    ``get_model``.

    When a solution is returned, its values are written in place into each
    layer's ``weight.data`` and ``bias.data``. Otherwise "No solution found"
    is printed and the model is left untouched.

    Args:
        model: object exposing ``linear_layers``, a sequence of layers with
            2-D ``weight`` and 1-D ``bias`` tensors.
        all_data: iterable of input samples; each is iterated element-wise and
            every element must support ``.item()``.
        labels: iterable of labels supporting ``.item()``.
    """
    layers = model.linear_layers
    # (rows, cols) of each layer's weight matrix, measured once up front.
    shapes = [
        (len(layer.weight.data), len(layer.weight.data[0]))
        for layer in layers
    ]

    # symbols[layer][row] holds that row's weight symbols followed by its
    # bias symbol in the last position (index == cols).
    symbols = []
    for layer_idx, (rows, cols) in enumerate(shapes):
        layer_symbols = []
        for row in range(rows):
            row_symbols = [
                Symbol(f"weight_{layer_idx}_{row}_{col}", REAL)
                for col in range(cols)
            ]
            row_symbols.append(Symbol(f"bias_{layer_idx}_{row}", REAL))
            layer_symbols.append(row_symbols)
        symbols.append(layer_symbols)

    constraints = []
    for data, label in zip(all_data, labels):
        # Start from the concrete input and push it through the symbolic layers.
        activations = [Real(value.item()) for value in data]

        for layer_idx, (rows, cols) in enumerate(shapes):
            layer_symbols = symbols[layer_idx]
            next_activations = []
            for row in range(rows):
                terms = [
                    layer_symbols[row][col] * activations[col]
                    for col in range(cols)
                ]
                terms.append(layer_symbols[row][cols])
                # ReLU expressed as max(pre-activation, 0).
                next_activations.append(Max(Plus(terms), Real(0.0)))
            activations = next_activations

        # Label 0 pins the output to exactly 0; any other label needs it > 0.
        relation = Equals if label.item() == 0.0 else GT
        constraints.append(relation(activations[0], Real(0.0)))

    solution = get_model(And(constraints))

    if not solution:
        print("No solution found")
        return

    # Copy the solver's assignment back into the layers.
    for layer_idx, (rows, cols) in enumerate(shapes):
        layer = layers[layer_idx]
        layer_symbols = symbols[layer_idx]
        for row in range(rows):
            for col in range(cols):
                weight_value = solution[layer_symbols[row][col]].constant_value()
                layer.weight.data[row][col] = float(weight_value)

            bias_value = solution[layer_symbols[row][cols]].constant_value()
            layer.bias.data[row] = float(bias_value)

In [103]:
def test(model, all_data, labels, criterion): 
    running_loss = 0.0
    num_correct = 0
    
    for (data, label) in zip(all_data, labels): 
        output = model(data)
        prediction = 1 if output > 0 else 0

        loss = criterion(output, label)
        num_correct += (label == prediction)
        
        running_loss += loss.item()

    loss = running_loss/len(labels)
    accuracy = num_correct/len(labels)
    
    print("testing accuracy: " + str(accuracy.item()) + " and loss: " + str(loss))

In [129]:
# Use the GPU when one is available; the models are moved to this device
# before training, so the data below is created on it as well.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix the seed so the randomly initialised model weights are reproducible
# from a fresh kernel.
SEED = 0
torch.manual_seed(SEED)

# The eight corners of the cube {-1, 1}^3.
points = torch.tensor([[-1.0, -1.0, -1.0], [1.0, -1.0, -1.0], [-1.0, 1.0, -1.0], [-1.0, -1.0, 1.0],
                       [1.0, 1.0, 1.0], [-1.0, 1.0, 1.0], [1.0, -1.0, 1.0], [1.0, 1.0, -1.0],],
                      device=device)

# A point is labelled 1.0 exactly when its first coordinate is 1.0.
# Targets are constants: they must not take part in autograd, otherwise every
# backward pass accumulates an unused gradient on them.
labels = torch.tensor([0.0, 1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0], device=device)

criterion = nn.BCEWithLogitsLoss()
num_epochs = 1000  # upper bound on backpropagation epochs per model

In [130]:
# Hidden-layer widths to benchmark. For every width one elapsed time (in
# seconds) is appended to each of the two lists below.
embedding_sizes = [10, 100, 1000]
smt_times = []
backpropagation_times = []

for embedding_size in embedding_sizes:

    # Move the model to its device first, then build the optimizer from the
    # parameters that live there.
    model = Model(embedding_size).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # perf_counter is a monotonic high-resolution clock, so the measured
    # interval is unaffected by wall-clock adjustments.
    start_time = time.perf_counter()
    train_backpropagation(model, points, labels, criterion, optimizer, num_epochs)
    backpropagation_times.append(time.perf_counter() - start_time)

    start_time = time.perf_counter()
    train_smt(model, points, labels)
    smt_times.append(time.perf_counter() - start_time)

    # Evaluates the same model object after train_smt has run on it.
    test(model, points, labels, criterion)

epoch: 0, accuracy: 0.5, loss: 0.6935068517923355
epoch: 1, accuracy: 0.625, loss: 0.6809237599372864
epoch: 2, accuracy: 0.625, loss: 0.6781993582844734
epoch: 3, accuracy: 0.625, loss: 0.6737440377473831
epoch: 4, accuracy: 0.625, loss: 0.6708152890205383
epoch: 5, accuracy: 0.875, loss: 0.6575960889458656
epoch: 6, accuracy: 0.625, loss: 0.6570715457201004
epoch: 7, accuracy: 0.875, loss: 0.6261524856090546
epoch: 8, accuracy: 0.75, loss: 0.604952160269022
epoch: 9, accuracy: 0.875, loss: 0.5835750326514244
epoch: 10, accuracy: 0.875, loss: 0.5647474899888039
epoch: 11, accuracy: 0.875, loss: 0.5461789630353451
epoch: 12, accuracy: 1.0, loss: 0.5236970223486423
testing accuracy: 1.0 and loss: 0.6628731191158295
epoch: 0, accuracy: 0.5, loss: 0.7105482034385204
epoch: 1, accuracy: 0.5, loss: 0.6842015944421291
epoch: 2, accuracy: 0.5, loss: 0.657867893576622
epoch: 3, accuracy: 0.5, loss: 0.6216005813330412
epoch: 4, accuracy: 0.5, loss: 0.5865693334490061
epoch: 5, accuracy: 0.5, lo

epoch: 210, accuracy: 0.5, loss: 0.6931471824645996
epoch: 211, accuracy: 0.5, loss: 0.6931471824645996
epoch: 212, accuracy: 0.5, loss: 0.6931471824645996
epoch: 213, accuracy: 0.5, loss: 0.6931471824645996
epoch: 214, accuracy: 0.5, loss: 0.6931471824645996
epoch: 215, accuracy: 0.5, loss: 0.6931471824645996
epoch: 216, accuracy: 0.5, loss: 0.6931471824645996
epoch: 217, accuracy: 0.5, loss: 0.6931471824645996
epoch: 218, accuracy: 0.5, loss: 0.6931471824645996
epoch: 219, accuracy: 0.5, loss: 0.6931471824645996
epoch: 220, accuracy: 0.5, loss: 0.6931471824645996
epoch: 221, accuracy: 0.5, loss: 0.6931471824645996
epoch: 222, accuracy: 0.5, loss: 0.6931471824645996
epoch: 223, accuracy: 0.5, loss: 0.6931471824645996
epoch: 224, accuracy: 0.5, loss: 0.6931471824645996
epoch: 225, accuracy: 0.5, loss: 0.6931471824645996
epoch: 226, accuracy: 0.5, loss: 0.6931471824645996
epoch: 227, accuracy: 0.5, loss: 0.6931471824645996
epoch: 228, accuracy: 0.5, loss: 0.6931471824645996
epoch: 229, 

epoch: 498, accuracy: 0.5, loss: 0.6931471824645996
epoch: 499, accuracy: 0.5, loss: 0.6931471824645996
epoch: 500, accuracy: 0.5, loss: 0.6931471824645996
epoch: 501, accuracy: 0.5, loss: 0.6931471824645996
epoch: 502, accuracy: 0.5, loss: 0.6931471824645996
epoch: 503, accuracy: 0.5, loss: 0.6931471824645996
epoch: 504, accuracy: 0.5, loss: 0.6931471824645996
epoch: 505, accuracy: 0.5, loss: 0.6931471824645996
epoch: 506, accuracy: 0.5, loss: 0.6931471824645996
epoch: 507, accuracy: 0.5, loss: 0.6931471824645996
epoch: 508, accuracy: 0.5, loss: 0.6931471824645996
epoch: 509, accuracy: 0.5, loss: 0.6931471824645996
epoch: 510, accuracy: 0.5, loss: 0.6931471824645996
epoch: 511, accuracy: 0.5, loss: 0.6931471824645996
epoch: 512, accuracy: 0.5, loss: 0.6931471824645996
epoch: 513, accuracy: 0.5, loss: 0.6931471824645996
epoch: 514, accuracy: 0.5, loss: 0.6931471824645996
epoch: 515, accuracy: 0.5, loss: 0.6931471824645996
epoch: 516, accuracy: 0.5, loss: 0.6931471824645996
epoch: 517, 

epoch: 792, accuracy: 0.5, loss: 0.6931471824645996
epoch: 793, accuracy: 0.5, loss: 0.6931471824645996
epoch: 794, accuracy: 0.5, loss: 0.6931471824645996
epoch: 795, accuracy: 0.5, loss: 0.6931471824645996
epoch: 796, accuracy: 0.5, loss: 0.6931471824645996
epoch: 797, accuracy: 0.5, loss: 0.6931471824645996
epoch: 798, accuracy: 0.5, loss: 0.6931471824645996
epoch: 799, accuracy: 0.5, loss: 0.6931471824645996
epoch: 800, accuracy: 0.5, loss: 0.6931471824645996
epoch: 801, accuracy: 0.5, loss: 0.6931471824645996
epoch: 802, accuracy: 0.5, loss: 0.6931471824645996
epoch: 803, accuracy: 0.5, loss: 0.6931471824645996
epoch: 804, accuracy: 0.5, loss: 0.6931471824645996
epoch: 805, accuracy: 0.5, loss: 0.6931471824645996
epoch: 806, accuracy: 0.5, loss: 0.6931471824645996
epoch: 807, accuracy: 0.5, loss: 0.6931471824645996
epoch: 808, accuracy: 0.5, loss: 0.6931471824645996
epoch: 809, accuracy: 0.5, loss: 0.6931471824645996
epoch: 810, accuracy: 0.5, loss: 0.6931471824645996
epoch: 811, 