In [1]:
from torch.nn import Linear, Sequential, Sigmoid, Module, ReLU, Dropout
from torch import Tensor
from neural_semigroups.constants import CURRENT_DEVICE

class SATClassifier(Module):
    """
    Fully connected classifier over a flattened ``cardinality ** 3`` cube.

    The stack narrows from ``cardinality ** 3`` features to
    ``cardinality ** 2``, then ``cardinality``, then 2 outputs. ReLU follows
    every linear layer except the last one, which is followed by an
    element-wise Sigmoid (each of the two outputs is squashed independently,
    so they need not sum to one).

    :param cardinality: edge length of the input cube; every sample must hold
        exactly ``cardinality ** 3`` values
    """

    def __init__(self, cardinality: int):
        super().__init__()
        self.cardinality = cardinality
        cube_size = cardinality ** 3
        table_size = cardinality ** 2
        # The attribute name ``layers`` and the position of each module define
        # the state-dict keys ("layers.0.weight", ...); keep both stable so
        # that previously saved state dicts keep loading.
        self.layers = Sequential(
            Linear(cube_size, cube_size),
            ReLU(),
            Linear(cube_size, cube_size),
            ReLU(),
            Linear(cube_size, table_size),
            ReLU(),
            Linear(table_size, cardinality),
            ReLU(),
            Linear(cardinality, 2),
            Sigmoid()
        ).to(CURRENT_DEVICE)

    def forward(self, x: Tensor) -> Tensor:
        """
        Flatten the input to rows of ``cardinality ** 3`` values and run the stack.

        :param x: tensor whose total number of elements is a multiple of
            ``cardinality ** 3``; it does not have to be contiguous
        :returns: tensor of shape ``(batch, 2)`` with values in ``[0, 1]``
        """
        # ``reshape`` instead of ``view``: ``view`` raises on non-contiguous
        # inputs (e.g. permuted or sliced cubes), while ``reshape`` copies only
        # when it has to and is identical to ``view`` for contiguous tensors.
        return self.layers(x.reshape(-1, self.cardinality ** 3))

# Classifier for cubes with edge length 13 (13 ** 3 = 2197 input features).
model = SATClassifier(cardinality=13)

In [2]:
import torch

# Restore the trained weights. ``map_location`` remaps the stored tensors onto
# CURRENT_DEVICE, so a checkpoint written from a CUDA session still loads on a
# machine without a GPU (without it, torch.load tries to restore each tensor to
# the device it was saved from).
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(
    torch.load("sat_classifier.pt", map_location=CURRENT_DEVICE)
)
model.to(CURRENT_DEVICE)
# Switch to evaluation mode; as the last expression of the cell this also
# displays the module summary.
model.eval()

SATClassifier(
  (layers): Sequential(
    (0): Linear(in_features=2197, out_features=2197, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2197, out_features=2197, bias=True)
    (3): ReLU()
    (4): Linear(in_features=2197, out_features=169, bias=True)
    (5): ReLU()
    (6): Linear(in_features=169, out_features=13, bias=True)
    (7): ReLU()
    (8): Linear(in_features=13, out_features=2, bias=True)
    (9): Sigmoid()
  )
)

In [28]:
from scripts.parse_mace4_output import get_cube_from_output

# Hand-written input in the format consumed by ``get_cube_from_output``:
# three products of a multiplication table.
mace4_text = """
INPUT
0 * 1 = 1.
1 * 2 = 2.
0 * 2 = 1.
end of input
"""
# NOTE(review): the second argument is presumably the cardinality (13, the
# same value the model was built with); confirm against the helper.
cube = get_cube_from_output(mace4_text, 13)
# Shape it as a batch of one 13 x 13 x 13 cube on the model's device.
cube = cube.view(1, 13, 13, 13).to(CURRENT_DEVICE)
model(cube)

tensor([[6.2399e-05, 9.9999e-01]], device='cuda:0', grad_fn=<SigmoidBackward>)

In [4]:
# Load the evaluation set (indexed below by "features" and "labels").
# ``map_location="cpu"`` keeps the tensors on the CPU regardless of the device
# they were saved from: the features are moved to CURRENT_DEVICE batch by batch
# during inference, and the labels are handed to scikit-learn, which cannot
# read CUDA tensors.
# NOTE(review): torch.load unpickles the file; only load data you trust.
data = torch.load("../scripts/some.pt", map_location="cpu")

In [5]:
from torch.utils.data import TensorDataset, DataLoader

In [6]:
# Pair every feature tensor with its label and serve them in batches of 1024,
# in the stored order (no shuffling is requested).
loader = DataLoader(
    TensorDataset(data["features"], data["labels"]),
    batch_size=1024,
)

In [9]:
from tqdm.notebook import tqdm

# Run the model over the whole data set and keep the per-batch outputs on the CPU.
result = []
# Inference only: ``torch.no_grad()`` stops autograd from recording a graph for
# every forward pass (previously the graph was built and then discarded with
# ``.detach()``), which saves memory and time without changing the values.
with torch.no_grad():
    for batch in tqdm(loader):
        # The labels in ``y`` are not needed here; metrics read them from ``data``.
        x, y = batch
        result.append(model(x.to(CURRENT_DEVICE)).cpu())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=977.0), HTML(value='')))




In [72]:
from sklearn.metrics import roc_auc_score, f1_score, classification_report

y_true = data["labels"]
# Column 1 of the concatenated model outputs is used as the score of class 1.
y_score = torch.cat(result)[:, 1]
# Hard predictions use a very strict cutoff on the score.
y_pred = y_score > 0.99998

# Threshold-free ranking quality first, then F1 at the cutoff above.
roc_auc = roc_auc_score(y_true, y_score)
f1 = f1_score(y_true, y_pred)
print(roc_auc, f1)

0.9945757831981856 0.8298340082773351


In [73]:
# for i in range(10):
#     cutoff = 0.9999 + i / 100000
#     y_pred = y_score > cutoff
#     print(cutoff, f1_score(y_true, y_pred))

In [74]:
# Per-class precision / recall / F1 for the thresholded predictions.
report = classification_report(y_true, y_pred)
print(report)

              precision    recall  f1-score   support

           0       0.99      0.99      0.99    945746
           1       0.80      0.86      0.83     54241

    accuracy                           0.98    999987
   macro avg       0.90      0.92      0.91    999987
weighted avg       0.98      0.98      0.98    999987

