In [1]:
import torch
import torch.nn as nn
from pysmt.shortcuts import Symbol, And, Equals, Real, GT, Max, Plus, get_model
from pysmt.typing import REAL

In [2]:
class Model(nn.Module):
    """Single-neuron classifier computing ``ReLU(w · x + b)`` for 2-feature inputs.

    The output is non-negative; callers in this notebook treat ``output > 0``
    as a prediction of class 1.

    NOTE(review): the notebook trains this with ``BCEWithLogitsLoss``, which
    interprets the output as a logit. Because the ReLU output is >= 0, the
    implied probability ``sigmoid(output)`` never drops below 0.5, so the loss
    on a label-0 sample can never go below ``log(2)``.
    """

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 1)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the activation for ``x``.

        Args:
            x: one sample of shape ``(2,)`` or a batch of shape ``(B, 2)``.

        Returns:
            A 0-d tensor for a single sample, or a ``(B,)`` tensor for a batch.
        """
        activation = self.relu(self.linear(x))
        # Drop only the size-1 output dimension. Indexing with [0] would have
        # kept just the first sample when given a batch.
        return activation.squeeze(-1)

In [3]:
def train_backpropagation(model, all_data, labels, criterion, optimizer, num_epochs):
    """Train ``model`` with per-sample gradient steps, printing accuracy/loss each epoch.

    A sample counts as correct when ``1 if output > 0 else 0`` equals its label.
    The optimizer is stepped after every individual sample.

    Args:
        model: module mapping one sample to a scalar output.
        all_data: sequence of input tensors, one per sample.
        labels: sequence of scalar label tensors aligned with ``all_data``.
        criterion: loss function called as ``criterion(output, label)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over the data.

    Raises:
        ValueError: if ``labels`` is empty or its length differs from ``all_data``.
    """
    num_samples = len(labels)
    if num_samples == 0:
        raise ValueError("train_backpropagation needs at least one sample")
    if len(all_data) != num_samples:
        raise ValueError(
            f"all_data has {len(all_data)} samples but labels has {num_samples}"
        )

    # Feed inputs to whichever device holds the model's parameters, so a CUDA
    # model does not receive CPU tensors.
    device = next(model.parameters()).device
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_correct = 0

        for data, label in zip(all_data, labels):
            data = data.to(device)
            # Labels are targets, not trainable values: detach so backward()
            # does not accumulate gradients into the caller's label tensor.
            label = label.detach().to(device)

            optimizer.zero_grad()
            output = model(data)
            prediction = 1 if output.item() > 0 else 0

            sample_loss = criterion(output, label)
            sample_loss.backward()
            optimizer.step()

            running_loss += sample_loss.item()
            num_correct += int(prediction == label.item())

        epoch_loss = running_loss / num_samples
        accuracy = num_correct / num_samples
        print(f"epoch: {epoch}, accuracy: {accuracy}, loss: {epoch_loss}")

In [4]:
def train_smt(model, all_data, labels):
    """Fit ``model.linear`` by solving an SMT problem instead of descending a loss.

    For each sample ``x`` with label ``y``, the formula requires
    ``max(w · x + b, 0) == 0`` when ``y == 0`` and ``max(w · x + b, 0) > 0``
    for any other label, where ``w`` and ``b`` are real-valued SMT variables.
    If pysmt finds a satisfying assignment, it is written into
    ``model.linear.weight[0]`` and ``model.linear.bias[0]``. Otherwise
    "No solution found" is printed and the model is left unchanged.

    Args:
        model: module with a ``linear`` layer whose first output row is overwritten.
        all_data: sequence of 1-D input tensors, all of the same length.
        labels: sequence of scalar label tensors aligned with ``all_data``.

    Raises:
        ValueError: if ``all_data`` is empty or its length differs from ``labels``.
    """
    if len(all_data) == 0:
        raise ValueError("train_smt needs at least one sample")
    if len(all_data) != len(labels):
        raise ValueError(
            f"all_data has {len(all_data)} samples but labels has {len(labels)}"
        )

    data_dim = len(all_data[0])

    # Every unknown needs its own name: pysmt returns the existing symbol for a
    # repeated name, so reusing a weight's name for the bias would silently
    # force the bias to equal that weight.
    weights = [Symbol("weight_" + str(i), REAL) for i in range(data_dim)]
    bias = Symbol("bias", REAL)

    equations = []
    for data, label in zip(all_data, labels):
        terms = [weight * Real(data[i].item()) for i, weight in enumerate(weights)]
        pre_activation = Plus(terms + [bias])
        relu = Max(pre_activation, Real(0.0))

        if label.item() == 0.0:
            equations.append(Equals(relu, Real(0.0)))
        else:
            equations.append(GT(relu, Real(0.0)))

    # get_model returns None when the formula is unsatisfiable.
    solution = get_model(And(equations))
    if solution is None:
        print("No solution found")
        return

    # In-place edits of parameters belong under no_grad (instead of going
    # through the discouraged .data attribute).
    with torch.no_grad():
        for i, weight in enumerate(weights):
            model.linear.weight[0, i] = float(solution[weight].constant_value())
        model.linear.bias[0] = float(solution[bias].constant_value())

In [5]:
def test(model, all_data, labels, criterion): 
    running_loss = 0.0
    num_correct = 0
    
    for (data, label) in zip(all_data, labels): 
        output = model(data)
        prediction = 1 if output > 0 else 0

        loss = criterion(output, label)
        num_correct += (label == prediction)
        
        running_loss += loss.item()

    loss = running_loss/len(labels)
    accuracy = num_correct/len(labels)
    
    print("testing accuracy: " + str(accuracy.item()) + " and loss: " + str(loss))

In [6]:
# Seed the RNG so the initial weights, and therefore the training log, are reproducible.
SEED = 0
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model()
model.to(device)

# Four points in the plane; the label is 1 exactly when the first coordinate is positive.
# Data lives on the same device as the model. Labels are targets, so they must not require grad.
points = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]], device=device)
labels = torch.tensor([0.0, 0.0, 1.0, 1.0], device=device)

LEARNING_RATE = 0.1
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)
num_epochs = 10

In [7]:
# Fit the weights with gradient descent, then evaluate on the same four training points.
train_backpropagation(model, points, labels, criterion, optimizer, num_epochs=num_epochs)
test(model, points, labels, criterion=criterion)

epoch: 0, accuracy: 0.25, loss: 0.9644601494073868
epoch: 1, accuracy: 0.25, loss: 0.8604625016450882
epoch: 2, accuracy: 0.5, loss: 0.8021301701664925
epoch: 3, accuracy: 0.5, loss: 0.763194352388382
epoch: 4, accuracy: 0.5, loss: 0.7279316633939743
epoch: 5, accuracy: 0.5, loss: 0.6962470337748528
epoch: 6, accuracy: 0.75, loss: 0.6597120463848114
epoch: 7, accuracy: 0.75, loss: 0.5981166362762451
epoch: 8, accuracy: 0.75, loss: 0.5475847870111465
epoch: 9, accuracy: 1.0, loss: 0.5188740119338036
testing accuracy: 1.0 and loss: 0.5043608397245407


In [8]:
# Overwrite the learned weights with an SMT solution (when one exists), then re-evaluate.
train_smt(model, all_data=points, labels=labels)
test(model, all_data=points, labels=labels, criterion=criterion)

testing accuracy: 1.0 and loss: 0.5836120843887329
