In [2]:
import torch
import torch.nn as nn
from pysmt.shortcuts import Symbol, And, Equals, Real, GT, LT, get_model
from pysmt.typing import REAL

In [3]:
class Node(nn.Module):
    def __init__(self):
        super(Node, self).__init__()             
        self.linear = nn.Linear(2, 1)
        self.relu = nn.ReLU()

    def forward(self, x):     
        x = self.linear(x)
        x = self.relu(x)
        
        #option1 = torch.tensor(1.0, requires_grad=True)
        #option2 = torch.tensor(0.0, requires_grad=True)
        #Replace 0 with self.threshold.weight
        #x = torch.where(x > 0, option1, option2, requires_grad=True)
        
        return x

In [4]:
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()             
        self.node1 = Node()
        
    def forward(self, x):   
        x = self.node1(x)
        return x[0]

In [5]:
def train(model, all_data, labels, criterion, optimizer, num_epochs):
    model.train()
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_correct = 0
    
        for (data, label) in zip(all_data, labels): 
            optimizer.zero_grad()
        
            output = model(data)
            prediction = 1 if output > 0 else 0

            loss = criterion(output, label)
            num_correct += (output == prediction)
        
            running_loss += loss.item()
            loss.backward()
            optimizer.step()

        loss = running_loss/len(labels)
        accuracy = num_correct/len(labels)
    
        print("epoch: " + str(epoch) + ", accuracy: " + str(accuracy.item()) + ", loss: " + str(loss))

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Model()
model.to(device)

points = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
labels = torch.tensor([0.0, 0.0, 1.0, 1.0], requires_grad=True)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.1)
num_epochs = 10

In [7]:
train(model, points, labels, criterion, optimizer, num_epochs)

epoch: 0, accuracy: 0.25, loss: 0.9669813737273216
epoch: 1, accuracy: 0.25, loss: 0.8610153347253799
epoch: 2, accuracy: 0.5, loss: 0.7879737168550491
epoch: 3, accuracy: 0.5, loss: 0.7507683336734772
epoch: 4, accuracy: 0.25, loss: 0.7141583561897278
epoch: 5, accuracy: 0.25, loss: 0.6447553113102913
epoch: 6, accuracy: 0.25, loss: 0.5875185877084732
epoch: 7, accuracy: 0.25, loss: 0.5402806177735329
epoch: 8, accuracy: 0.5, loss: 0.5080373510718346
epoch: 9, accuracy: 0.5, loss: 0.49350449442863464


In [18]:
w1 = Symbol("w1", REAL)
w2 = Symbol("w2", REAL)

equations = []

for (point, label) in zip(points, labels):
    x = point[0].item()
    y = point[1].item()
    calculation = w1 * Real(x) + w2 * Real(y)
    zero = Real(0.0)
    
    if label.item() == 0.0:
        equation = LT(calculation, zero)
    else:
        equation = GT(calculation, zero)
    equations.append(equation)

formula = And(equations)
print(formula)

model = get_model(formula)
if model:
    print(model)
    print(model[w1])
    print(model[w2])

else:
    print("No solution found")

((((w1 * -1.0) + (w2 * -1.0)) < 0.0) & (((w1 * -1.0) + (w2 * 1.0)) < 0.0) & (0.0 < ((w1 * 1.0) + (w2 * -1.0))) & (0.0 < ((w1 * 1.0) + (w2 * 1.0))))
w1 := 1/2
w2 := 0.0
1/2
0.0
