In [17]:
import torch

features_and_labels = torch.load("../scripts/test_output.pt")

# Keep only the samples whose feature tensor has exactly five entries equal to 1.
# Despite its name, `non_trivial_indices` is a boolean mask over the first (sample)
# dimension; summing over dims 1-3 assumes the features are rank-4 (sample + 3 axes).
non_trivial_indices = (features_and_labels["features"] == 1).sum(dim=(1, 2, 3)) == 5
features_and_labels = {
    key: features_and_labels[key][non_trivial_indices]
    for key in ("features", "labels")
}

In [18]:
# Split the filtered dataset dict into two standalone tensors for the cells below.
features, labels = (features_and_labels[key] for key in ("features", "labels"))

In [19]:
# Report the dataset size and the label sum (the count of positives, assuming 0/1 labels).
for label_stat in (labels.shape, labels.sum()):
    print(label_stat)

torch.Size([4467])
tensor(4348)


In [20]:
from torch.utils.data import random_split
from torch.utils.data import TensorDataset

# Fixed sizes for the held-out splits; the training split receives the remainder.
validation_len = 100
test_len = 3000
train_len = labels.shape[0] - validation_len - test_len
assert train_len > 0, (
    f"only {labels.shape[0]} samples: not enough for {validation_len} validation "
    f"+ {test_len} test samples plus a non-empty training split"
)

# Seed the split so Restart & Run All reproduces the same partition (and therefore
# the same metrics); an unseeded random_split reshuffles on every run.
SPLIT_SEED = 42

# NOTE(review): the label sum printed above (4348 of 4467) suggests ~97% positives
# if labels are 0/1, so a 100-sample random validation split is expected to hold only
# ~2-3 negatives and may hold none, making ROC AUC undefined. Consider a stratified
# split. TODO confirm.
train, validation, test = random_split(
    TensorDataset(features, labels),
    [train_len, validation_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [21]:
from torch.utils.data import DataLoader

# Loaders in (train, validation, test) order, each with batches of 32. Only the
# training loader shuffles; validation/test batches are produced in a fixed order.
data_loaders = [
    DataLoader(train, batch_size=32, shuffle=True),
    DataLoader(validation, batch_size=32),
    DataLoader(test, batch_size=32),
]
train_loader, validation_loader, test_loader = data_loaders

In [22]:
from torch.nn import Linear, Sequential, Sigmoid, Module, ReLU, Dropout
from torch import Tensor
from neural_semigroups.constants import CURRENT_DEVICE

class SATClassifier(Module):
    """Fully connected two-class classifier over flattened ``cardinality ** 3`` inputs.

    ``forward`` returns raw, unnormalised logits of shape ``(batch, 2)``. No output
    activation is applied on purpose: ``CrossEntropyLoss`` expects logits and applies
    log-softmax itself. A trailing ``Sigmoid`` would squash both outputs into (0, 1),
    which caps the logit margin and keeps the loss from ever approaching zero. Use
    ``model(x).softmax(dim=1)`` to obtain class probabilities.

    The layer indices inside ``self.layers`` are unchanged from the earlier
    Sigmoid-terminated version (the Sigmoid had no parameters and sat last), so
    previously saved ``state_dict``s still load.

    Args:
        cardinality: side length ``c``; every input sample must contain exactly
            ``c ** 3`` elements.
    """

    def __init__(self, cardinality: int):
        super().__init__()
        self.cardinality = cardinality
        flat_size = cardinality ** 3
        self.layers = Sequential(
            Linear(flat_size, flat_size),
            ReLU(),
            Linear(flat_size, flat_size),
            ReLU(),
            Linear(flat_size, cardinality ** 2),
            ReLU(),
            Linear(cardinality ** 2, cardinality),
            ReLU(),
            Linear(cardinality, 2),  # logits for the two classes
        # Parameters are placed on the project's default device here; presumably the
        # training loop moves input batches there as well (TODO confirm).
        ).to(CURRENT_DEVICE)

    def forward(self, x: Tensor) -> Tensor:
        """Flatten each sample to ``cardinality ** 3`` values and return ``(batch, 2)`` logits."""
        # reshape (unlike view) also accepts non-contiguous inputs
        return self.layers(x.reshape(-1, self.cardinality ** 3))

model = SATClassifier(13)

In [23]:
from torch.nn import CrossEntropyLoss

# Two-class cross-entropy averaged over the batch ("mean" is the default reduction,
# spelled out). CrossEntropyLoss expects unnormalised logits; it applies log-softmax itself.
loss = CrossEntropyLoss(reduction="mean")

In [24]:
# Delete the ./runs directory before training so results from earlier runs don't mix in.
# Presumably this is the TensorBoard log directory written by learning_pipeline (TODO confirm).
!rm -rf runs

In [25]:
from neural_semigroups.training_helpers import learning_pipeline
from ignite.metrics import Loss
from ignite.contrib.metrics import ROC_AUC

def _positive_class_score(output):
    """Map an ignite ``(y_pred, y)`` pair to ``(score for class 1, y)`` for ROC AUC.

    The score is the softmax probability of class 1, which equals
    ``sigmoid(out1 - out0)``. It therefore ranks samples by the margin between the two
    model outputs, which is what the argmax decision uses. Scoring with column 1 alone
    would ignore the class-0 output entirely.
    """
    y_pred, y = output
    # detach: ROC_AUC converts predictions to NumPy, which fails on tensors that require grad
    return y_pred.detach().softmax(dim=1)[:, 1], y


learning_pipeline(
    params={"learning_rate": 0.0001, "epochs": 1000},
    model=model,
    loss=loss,
    metrics={
        "loss": Loss(loss),
        "ROC_AUC": ROC_AUC(output_transform=_positive_class_score),
    },
    data_loaders=data_loaders,
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))

In [26]:
# Persist only the learned parameters (the state_dict), not the pickled module object.
torch.save(obj=model.state_dict(), f="sat_classifier.pt")