In [1]:
from torch.nn import Linear, Sequential, Sigmoid, Module, ReLU, Dropout
from torch import Tensor
from neural_semigroups.constants import CURRENT_DEVICE

class SATClassifier(Module):
    """
    Fully connected network mapping a flattened ``cardinality ** 3`` input
    to two sigmoid-activated outputs.

    The layer widths shrink as
    ``c^3 -> c^3 -> c^3 -> c^2 -> c -> 2`` (``c`` being ``cardinality``),
    with a ReLU after every hidden layer. The layer stack is moved to
    ``CURRENT_DEVICE`` on construction.

    :param cardinality: size ``c`` of one side of the input cube
    """

    def __init__(self, cardinality: int):
        super().__init__()
        self.cardinality = cardinality
        cube_size = cardinality ** 3
        table_size = cardinality ** 2
        # NOTE: the order of the modules fixes the state-dict keys
        # (``layers.0.weight`` etc.), so it must stay as is for saved
        # checkpoints to keep loading.
        self.layers = Sequential(
            Linear(cube_size, cube_size),
            ReLU(),
            Linear(cube_size, cube_size),
            ReLU(),
            Linear(cube_size, table_size),
            ReLU(),
            Linear(table_size, cardinality),
            ReLU(),
            Linear(cardinality, 2),
            Sigmoid()
        ).to(CURRENT_DEVICE)

    def forward(self, x: Tensor) -> Tensor:
        """
        Flatten the input to rows of ``cardinality ** 3`` values and run
        them through the layer stack.

        :param x: tensor whose total number of elements is a multiple of
            ``cardinality ** 3``
        :returns: tensor of shape ``(batch, 2)`` with values in ``(0, 1)``
        """
        # ``reshape`` instead of ``view``: identical for contiguous input,
        # but also accepts non-contiguous tensors (e.g. after ``permute``).
        return self.layers(x.reshape(-1, self.cardinality ** 3))

# Instantiate the classifier for cardinality 13, i.e. 13 ** 3 = 2197 input features.
model = SATClassifier(13)

In [2]:
import torch

# Restore the trained weights. `map_location` remaps the stored tensors onto
# CURRENT_DEVICE, so a checkpoint saved on a GPU also loads on a CPU-only host.
# NOTE(review): `torch.load` unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load("sat_classifier.pt", map_location=CURRENT_DEVICE))
model.to(CURRENT_DEVICE)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

SATClassifier(
  (layers): Sequential(
    (0): Linear(in_features=2197, out_features=2197, bias=True)
    (1): ReLU()
    (2): Linear(in_features=2197, out_features=2197, bias=True)
    (3): ReLU()
    (4): Linear(in_features=2197, out_features=169, bias=True)
    (5): ReLU()
    (6): Linear(in_features=169, out_features=13, bias=True)
    (7): ReLU()
    (8): Linear(in_features=13, out_features=2, bias=True)
    (9): Sigmoid()
  )
)

In [3]:
from scripts.parse_mace4_output import get_cube_from_output

# Build a single-sample batch from a three-equation input and score it.
# NOTE(review): presumably `get_cube_from_output` parses Mace4-style text into
# a tensor for the given cardinality (13) — confirm against its definition.
cube = (
    get_cube_from_output(
        """
INPUT
0 * 1 = 1.
1 * 2 = 2.
0 * 2 = 1.
end of input
""",
        13,
    )
    .view(1, 13, 13, 13)
    .to(CURRENT_DEVICE)
)
model(cube)

tensor([[3.2252e-10, 9.9964e-01]], device='cuda:0', grad_fn=<SigmoidBackward>)

In [63]:
# Load the evaluation set (a dict with "features" and "labels" entries).
# Tensors are mapped to the CPU: the labels are later handed to scikit-learn,
# which needs host memory, and batches are moved to the device one at a time.
# NOTE(review): `torch.load` unpickles the file — only load trusted data.
features_and_labels = torch.load("../scripts/test_output.pt", map_location="cpu")

In [64]:
# Per-sample mask: count the entries equal to 1 over the last three axes and
# keep only the samples where that count is exactly 13.
non_trivial_indices = torch.einsum("ijkl -> i", features_and_labels["features"] == 1) == 13
# Apply the same mask to features and labels so they stay aligned.
features_and_labels = {
    key: features_and_labels[key][non_trivial_indices]
    for key in ("features", "labels")
}

In [65]:
from torch.utils.data import TensorDataset, DataLoader

In [66]:
# Serve the filtered samples in batches of 1024. No `shuffle` argument is
# passed, so the loader relies on the default ordering; later cells compare
# the concatenated outputs against the labels position by position.
loader = DataLoader(
    TensorDataset(
        features_and_labels["features"],
        features_and_labels["labels"],
    ),
    batch_size=1024,
)

In [67]:
from tqdm.notebook import tqdm

# Run the model over every batch and collect the outputs on the CPU.
# `torch.no_grad()` stops autograd from recording a graph for each forward
# pass, which the original only discarded afterwards via `.detach()`.
result = []
with torch.no_grad():
    for batch in tqdm(loader):
        x, y = batch  # the labels (y) are not needed for inference
        result.append(model(x.to(CURRENT_DEVICE)).cpu())

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [68]:
from sklearn.metrics import roc_auc_score, f1_score, classification_report

# Labels of the filtered evaluation set.
y_true = features_and_labels["labels"]
# Stack the per-batch outputs and use column 1 of the two model outputs as the score.
y_score = torch.cat(result, dim=0)[:, 1]
# Hard predictions at a 0.5 cutoff.
y_pred = y_score.gt(0.5)
# Prints ROC AUC (from the raw scores) followed by F1 (from the hard predictions).
print(roc_auc_score(y_true, y_score), f1_score(y_true, y_pred))

0.4760547320410491 0.05189189189189188


In [69]:
# Per-class precision / recall / F1. `zero_division=0` reports 0 for a class
# with no predicted samples (same numbers as the default) without emitting
# an UndefinedMetricWarning.
print(classification_report(y_true, y_pred, zero_division=0))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00      4385
           1       0.03      1.00      0.05       120

    accuracy                           0.03      4505
   macro avg       0.01      0.50      0.03      4505
weighted avg       0.00      0.03      0.00      4505



  _warn_prf(average, modifier, msg_start, len(result))


In [11]:
# Sweep the decision cutoff from 0.0 to 0.9 in steps of 0.1 and print the F1
# score at each. The thresholded predictions get their own name so that the
# `y_pred` computed at the 0.5 cutoff is not overwritten; otherwise re-running
# the classification report after this cell would report the 0.9 cutoff.
for i in range(10):
    cutoff = i / 10
    y_pred_at_cutoff = y_score > cutoff
    print(cutoff, f1_score(y_true, y_pred_at_cutoff))

0.0 0.9865002836074873
0.1 0.9865002836074873
0.2 0.9865002836074873
0.3 0.9865002836074873
0.4 0.9865002836074873
0.5 0.9865002836074873
0.6 0.9865002836074873
0.7 0.9865002836074873
0.8 0.9865002836074873
0.9 0.9865002836074873
