In [1]:
from foreduce.vampire.vampire import VampireInteractive
import os
from dotenv import load_dotenv
from torch_geometric.utils.convert import from_networkx
from torch_geometric.utils import index_to_mask
import torch

from foreduce.transformer.model import GraphModel
from foreduce.data.data import _type_mapping

# Load variables from a local .env file (python-dotenv) into the process
# environment, then read the VAMPIRE entry; it is None when the variable is unset.
# NOTE(review): VAMPIRE is later passed as the first argument to
# VampireInteractive — presumably the path to the prover binary; confirm.
load_dotenv()
VAMPIRE = os.environ.get("VAMPIRE")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Restore the trained GraphModel from the stage1 checkpoint on disk.
checkpoint_path = './models/stage1.ckpt'
model = GraphModel.load_from_checkpoint(checkpoint_path)
# Switch to evaluation mode; as the cell's last expression this also displays
# the module summary.
model.eval()

GraphModel(
  (gnn): GNN(
    (type_embedding): Embedding(8, 256)
    (arity_embedding): Embedding(10, 256, padding_idx=0)
    (position_embedding): Embedding(10, 256, padding_idx=0)
    (act): ReLU()
    (norms): ModuleList(
      (0-7): 8 x GraphNorm(256)
    )
    (conv_layers): ModuleList(
      (0-7): 8 x GCNConv(256, 256)
    )
  )
  (out): Sequential(
    (0): Linear(in_features=256, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=1, bias=True)
  )
)

In [None]:
# Open an interactive prover session on the SWV243-2 problem file.
# The path is a plain string literal: it has no placeholders, so no f-prefix.
interactive = VampireInteractive(VAMPIRE, './problems/SWV/SWV243-2.p')
# NOTE(review): __enter__ is called by hand so the session stays open across
# cells; no matching interactive.__exit__(None, None, None) is visible in this
# part of the notebook — confirm the session is closed when the work is done.
interactive.__enter__()
# Build the graph view of the problem; `mapping` and `clauses` are reused by the
# later cells (`clauses` is fed to index_to_mask as node indices).
graph, mapping, clauses = interactive.problem.to_graph(depth=8)
graph.depth()

In [16]:
def _embedding_index(value, cap=8):
    """Map an optional integer attribute to an embedding index.

    Returns 0 when ``value`` is None, otherwise ``value + 1`` capped at
    ``cap + 1``. With the default cap the result lies in 0..9, which matches the
    10-row arity/position embeddings (padding_idx=0) shown in the model summary.
    """
    return 0 if value is None else min(cap + 1, value + 1)


# Convert the networkx graph into a PyG data object and turn the per-node
# attributes into integer tensors.
data = from_networkx(graph)
data.type = torch.tensor([_type_mapping[t] for t in data.type], dtype=torch.int)
data.arity = torch.tensor([_embedding_index(a) for a in data.arity], dtype=torch.int)
data.pos = torch.tensor([_embedding_index(p) for p in data.pos], dtype=torch.int)
# Boolean mask over all nodes that is True at the indices listed in `clauses`.
data.clauses = index_to_mask(torch.tensor(clauses), size=data.num_nodes)
# Inference only: the scores are read back with .item() later, so skip autograd
# bookkeeping for the forward pass.
with torch.no_grad():
    score = model(data)

In [17]:
# Collapse repeated clauses while remembering, for every original position,
# the index of its first occurrence in the de-duplicated list.
all_clauses = interactive.problem.clauses
deduped = []
_mapping = {}
for position, clause in enumerate(all_clauses):
    try:
        _mapping[position] = deduped.index(clause)
    except ValueError:
        # First time this clause is seen: it takes the next free slot.
        _mapping[position] = len(deduped)
        deduped.append(clause)

# Collect (score, position) pairs for every clause that is not active yet.
vals = []
for position in range(len(all_clauses)):
    if interactive.active[position]:
        continue
    vals.append((score[_mapping[position]].item(), position))

# Highest score wins; equal scores are resolved by the larger position because
# the pairs are compared as tuples.
_, next_clause = max(vals)
interactive.step(next_clause)

In [18]:
# Grow the existing graph after the prover step. The cell rebinds its own
# inputs (graph, mapping, clauses), so it is not idempotent: running it twice
# passes a different clause count the second time.
clause_count_before = len(clauses)
graph, mapping, clauses = interactive.problem.extend_graph(
    graph,
    mapping,
    clause_count_before,
    depth=8,
)

In [49]:
# Print the text held in the session's `proof` attribute (judging by the output
# below, this is the prover's log of newly derived clauses).
print(interactive.proof)

% Running in auto input_syntax mode. Trying TPTP
[SA] new: 1. goal_0 | c_Message_Oanalz(c_Message_Osynth(v_H)) != c_union(c_Message_Oanalz(v_H),c_Message_Osynth(v_H),tc_Message_Omsg) [input]
[SA] new: 2. c_Message_Oanalz(c_union(c_Message_Osynth(X0),X1,tc_Message_Omsg)) = c_union(c_Message_Oanalz(c_union(X0,X1,tc_Message_Omsg)),c_Message_Osynth(X0),tc_Message_Omsg) [input]
[SA] new: 3. c_union(X2,c_emptyset,X3) = X2 [input]
[SA] new: 4. ~goal_0 [input]
[SA] new: 5. c_Message_Oanalz(c_union(c_Message_Osynth(c_Message_Oanalz(c_union(X0,X1,tc_Message_Omsg))),c_Message_Osynth(X0),tc_Message_Omsg)) = c_union(c_Message_Oanalz(c_Message_Oanalz(c_union(c_Message_Osynth(X0),X1,tc_Message_Omsg))),c_Message_Osynth(c_Message_Oanalz(c_union(X0,X1,tc_Message_Omsg))),tc_Message_Omsg) [superposition 2,2]
[SA] new: 6. c_Message_Oanalz(c_union(c_Message_Osynth(c_Message_Oanalz(c_Message_Oanalz(c_union(c_Message_Osynth(X0),X1,tc_Message_Omsg)))),c_Message_Osynth(c_Message_Oanalz(c_union(X0,X1,tc_Message_

In [50]:
# Debug: display the current list of clause node indices.
clauses

[2, 22, 40, 49, 52, 81, 123]

In [51]:
# Debug: neighbours of node 123, the last entry of `clauses` displayed above.
list(graph.neighbors(123))

[125]

In [52]:
# Debug: neighbours of node 125, the only neighbour reported for node 123.
list(graph.neighbors(125))

[1, 126, 146, 125, 123]

In [57]:
# Debug: neighbours of node 128.
list(graph.neighbors(128))

[10, 129, 127]

In [58]:
# Debug: reverse lookup — print every key of `mapping` whose value is node 128.
target_node = 128
for key, node_id in mapping.items():
    if node_id != target_node:
        continue
    print(key)

In [62]:
# Debug: display the session's repr (a numbered clause listing, see output).
interactive

0: goal_0 | c_Message_Oanalz(c_Message_Osynth(v_H)) = c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)
[1m 1: c_Message_Oanalz(c_union(c_Message_Osynth(X0), X1, tc_Message_Omsg)) = c_union(c_Message_Oanalz(c_union(X0, X1, tc_Message_Omsg)), c_Message_Osynth(X0), tc_Message_Omsg)[0m
2: c_union(X2, c_emptyset, X3) = X2
[1m 3: ~goal_0[0m
[1m 4: c_Message_Oanalz(c_union(c_Message_Osynth(c_Message_Oanalz(c_union(X0, X1, tc_Message_Omsg))), c_Message_Osynth(X0), tc_Message_Omsg)) = c_union(c_Message_Oanalz(c_Message_Oanalz(c_union(c_Message_Osynth(X0), X1, tc_Message_Omsg))), c_Message_Osynth(c_Message_Oanalz(c_union(X0, X1, tc_Message_Omsg))), tc_Message_Omsg)[0m
5: c_Message_Oanalz(c_union(c_Message_Osynth(c_Message_Oanalz(c_Message_Oanalz(c_union(c_Message_Osynth(X0), X1, tc_Message_Omsg)))), c_Message_Osynth(c_Message_Oanalz(c_union(X0, X1, tc_Message_Omsg))), tc_Message_Omsg)) = c_union(c_Message_Oanalz(c_Message_Oanalz(c_union(c_Message_Osynth(c_Message_Oana

In [61]:
# Debug: dump `mapping` inverted and ordered by value, one "value key" pair per
# line. Pairs are compared as tuples, so equal values fall back to comparing keys.
l = sorted((node_id, key) for key, node_id in mapping.items())
for node_id, key in l:
    print(node_id, key)

0 $false/0
1 eq/2
2 goal_0 | c_Message_Oanalz(c_Message_Osynth(v_H)) = c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)
4 goal_0/0
5 goal_0goal_0 | c_Message_Oanalz(c_Message_Osynth(v_H)) = c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)
7 c_Message_Oanalz(c_Message_Osynth(v_H)) = c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)goal_0 | c_Message_Oanalz(c_Message_Osynth(v_H)) = c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)
8 c_Message_Oanalz/1
9 c_Message_Oanalz(c_Message_Osynth(v_H))goal_0 | c_Message_Oanalz(c_Message_Osynth(v_H)) = c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)
10 c_Message_Osynth/1
12 v_H/0
14 c_union/3
15 c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)goal_0 | c_Message_Oanalz(c_Message_Osynth(v_H)) = c_union(c_Message_Oanalz(v_H), c_Message_Osynth(v_H), tc_Message_Omsg)
16 c_Message_Oanalz(v_H)goal_0 | c_Message_Oanalz(c_Messag