In [1]:
VAMPIRE_PATH = '/home/apluska/.vampire/bin/vampire_z3_rel_static_casc2023_6749'
TPTP_PATH = '/home/apluska/TPTP-v8.2.0/'

In [2]:
PROBLEM = 'GRP001-1.p'

In [3]:
import os

from foreduce.tptp.parser import read_file as read_tptp

symbols = set()

problem = read_tptp(TPTP_PATH + 'Problems/GRP/' + PROBLEM, include_path=TPTP_PATH, max_size=10_000)
symbols.update(problem.function_symbols() | problem.predicate_symbols())

if not os.path.exists('problems/' + PROBLEM):
    with open('problems/' + PROBLEM, 'w') as f:
        f.write(problem.to_tptp())

In [4]:
from collections import defaultdict

from foreduce.transformer.tokenizer import TokenConfig

arity_dict = defaultdict(list)

for s in symbols:
    arity_dict[s.arity].append(s.name)

arity_list = list(arity_dict.values())
config = TokenConfig(num_functions=[len(s) for s in arity_list])
function_mapping = config.random_function_mapping(arity_list)


mapping = config.reserved_token_mapping | config.random_function_mapping(arity_list)

In [5]:
from tqdm.auto import tqdm

from foreduce.vampire.vampire import VampireAutomatic
from foreduce.data.data import ProofTokens


dataset = ProofTokens(config, seq_len=24)
for _ in tqdm(range(1)):
    vampire = VampireAutomatic(VAMPIRE_PATH, 'problems/' + PROBLEM, selection='10')
    vampire.run()
    if 'Refutation found' in vampire.proof:    
        dataset.add_proof(vampire.problem, vampire.tree, mapping)

  from .autonotebook import tqdm as notebook_tqdm
100%|██████████| 1/1 [00:02<00:00,  2.01s/it]


In [6]:
from foreduce.transformer.embedding import FormulaEmbedding
from torch.utils.data import DataLoader

BATCHES = 256

#embedding = FormulaEmbedding(config, seq_len=24, dim=1536, n_layers=24, n_heads=16)
embedding = FormulaEmbedding(config, seq_len=24, dim=512, n_layers=64, n_heads=16)
data_loader = DataLoader(dataset, batch_size=len(dataset) // BATCHES, shuffle=True, num_workers=4, persistent_workers=True, drop_last=True)

In [7]:
from lightning import Trainer
import wandb
from lightning.pytorch.loggers import WandbLogger
import torch

torch.set_float32_matmul_precision('medium')

wandb.init(project='foreduce')

trainer = Trainer(max_epochs=256, logger=WandbLogger(), accumulate_grad_batches=BATCHES, log_every_n_steps=1, accelerator="auto", devices="auto")

trainer.fit(embedding, data_loader)

wandb.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlexpk[0m. Use [1m`wandb login --relogin`[0m to force relogin


Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/apluska/miniconda3/envs/foreduce/lib/python3.12/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/apluska/miniconda3/envs/foreduce/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name       | Type       | Params | Mode 
-------

Epoch 11:  83%|████████▎ | 213/256 [01:39<00:20,  2.14it/s, v_num=m3es, train_loss_step=0.787, train_loss_epoch=0.992]

In [8]:
embedding = embedding.cpu().eval()
__x = torch.zeros((len(vampire.problem.clauses), 24), dtype=torch.long)
dependencies = [set() for _ in vampire.problem.clauses]
for idx in range(len(vampire.problem.clauses)):
    if vampire.tree[idx]:
        dependencies[idx] = {idx} | set.union(*[dependencies[j] for j in vampire.tree[idx]])
    else:
        dependencies[idx] = {idx}
target = torch.zeros(len(vampire.problem.clauses), dtype=torch.float)

for i, clause in enumerate(vampire.problem.clauses):
    tokens = vampire.problem.clauses[i].tokenize(config, mapping)
    for j, token in enumerate(tokens[:24]):
        __x[i, j] = tokens[j]

_tokens, indices = torch.unique(__x, return_inverse=True, dim=0)
index_mapping = {idx : [] for idx in range(indices.max().item() + 1)}
for i, idx in enumerate(indices):
    index_mapping[idx.item()].append(i)

for i in range(len(vampire.problem.clauses)):
    idx, _ = min(
        (
            (idx, len(dependencies[idx] | dependencies[-1]) + 1 / (1 + len(dependencies[idx] & dependencies[-1]))) for idx in index_mapping[indices[i].item()]
        ), key=lambda x: x[1]
    )
    target[i] = len(dependencies[idx] & dependencies[-1]) / (len(dependencies[idx]) * len(dependencies[-1]))**0.5
        
_x = embedding(__x)
similarities = sorted([
    (
        i,
        target[i],
        torch.nn.functional.cosine_similarity(_x[i], _x[-1], dim=0).item()
    )
    for i in range(len(vampire.problem.clauses))
], key=lambda x: x[1], reverse=True)

for i, target, sim in similarities:
    print(f"{target:.2f}", f"{sim:.2f}", vampire.problem.clauses[i])

1.00 1.00 $false
0.78 0.71 ~killed(butler, agatha)
0.72 0.69 richer(butler, agatha)
0.66 0.66 richer(butler, agatha) | ~lives(butler)
0.63 0.55 killed(butler, agatha)
0.59 0.59 ~hates(butler, butler)
0.55 0.55 killed(butler, agatha) | ~hates(agatha, agatha)
0.47 0.47 killed(butler, agatha) | hates(charles, agatha)
0.47 0.47 ~hates(butler, butler) | ~hates(butler, agatha)
0.39 0.40 ~hates(butler, butler) | richer(charles, agatha) | ~lives(charles)
0.39 0.40 richer(butler, agatha) | ~killed(charles, agatha)
0.36 0.38 hates(butler, agatha)
0.36 0.37 hates(butler, charles)
0.36 0.36 killed(charles, agatha) | killed(butler, agatha)
0.36 0.38 hates(butler, agatha)
0.36 0.38 richer(charles, agatha) | richer(butler, agatha)
0.35 0.36 ~hates(butler, butler) | richer(charles, agatha)
0.33 0.34 richer(charles, agatha) | richer(butler, agatha) | ~lives(butler)
0.28 0.28 ~hates(agatha, butler)
0.24 0.25 ~hates(agatha, butler) | ~hates(agatha, agatha)
0.24 0.24 ~hates(butler, butler) | ~hates(butler

In [9]:
print(vampire.proof)

% Running in auto input_syntax mode. Trying TPTP
[SA] new: 1. lives(agatha) [input]
[SA] new: 2. lives(butler) [input]
[SA] new: 3. lives(charles) [input]
[SA] new: 4. ~richer(X0,X1) | ~killed(X0,X1) [input]
[SA] new: 5. ~hates(charles,X2) | ~hates(agatha,X2) [input]
[SA] new: 6. ~hates(X3,charles) | ~hates(X3,butler) | ~hates(X3,agatha) [input]
[SA] new: 7. hates(agatha,agatha) [input]
[SA] new: 8. hates(agatha,charles) [input]
[SA] new: 9. hates(X4,X5) | ~killed(X4,X5) [input]
[SA] new: 10. hates(butler,X6) | ~hates(agatha,X6) [input]
[SA] new: 11. hates(butler,X7) | richer(X7,agatha) | ~lives(X7) [input]
[SA] new: 12. goal_0 | killed(charles,agatha) | killed(butler,agatha) [input]
[SA] new: 13. ~goal_0 [input]
[SA] new: 14. hates(butler,agatha) [resolution 10,7]
[SA] new: 15. hates(butler,charles) [resolution 10,8]
[SA] new: 16. killed(charles,agatha) | killed(butler,agatha) [subsumption resolution 12,13]
[SA] new: 17. killed(butler,agatha) | hates(charles,agatha) [resolution 16,9]


In [10]:
for i, target, sim in similarities:
    if i in dependencies[-1]:
        print(f"{target/len(dependencies[i])**0.5:.2f}", f"{sim/len(dependencies[i])**0.5:.2f}", vampire.problem.clauses[i])

0.21 0.21 $false
0.21 0.19 ~killed(butler, agatha)
0.21 0.20 richer(butler, agatha)
0.21 0.21 richer(butler, agatha) | ~lives(butler)
0.21 0.18 killed(butler, agatha)
0.21 0.21 ~hates(butler, butler)
0.21 0.21 killed(butler, agatha) | ~hates(agatha, agatha)
0.21 0.21 killed(butler, agatha) | hates(charles, agatha)
0.21 0.21 ~hates(butler, butler) | ~hates(butler, agatha)
0.21 0.22 hates(butler, agatha)
0.21 0.21 hates(butler, charles)
0.21 0.21 killed(charles, agatha) | killed(butler, agatha)
0.21 0.21 lives(butler)
0.21 0.20 ~richer(X0, X1) | ~killed(X0, X1)
0.21 0.22 ~hates(charles, X2) | ~hates(agatha, X2)
0.21 0.21 ~hates(X3, charles) | ~hates(X3, butler) | ~hates(X3, agatha)
0.21 0.23 hates(agatha, agatha)
0.21 0.21 hates(agatha, charles)
0.21 0.21 hates(X4, X5) | ~killed(X4, X5)
0.21 0.22 hates(butler, X6) | ~hates(agatha, X6)
0.21 0.21 hates(butler, X7) | richer(X7, agatha) | ~lives(X7)
0.21 0.21 goal_0 | killed(charles, agatha) | killed(butler, agatha)
0.21 0.21 ~goal_0


In [11]:
goal = torch.zeros(24, dtype=torch.long)
for i, token in enumerate(vampire.problem.clauses[-1].tokenize(config, mapping)):
    goal[i] = token

In [12]:
from sortedcontainers import SortedList

from foreduce.vampire.vampire import VampireInteractive

MAX_STEP = 200

goal_embedding = embedding(goal.unsqueeze(0))

with VampireInteractive(VAMPIRE_PATH, 'problems/' + PROBLEM) as interactive:
    seen = 0
    similarities = SortedList()
    premise_count = []
    
    while not interactive.finished and interactive.step_count < MAX_STEP:
        new_clauses = interactive.problem.clauses[seen:]
        if new_clauses:
            tokens = [clause.tokenize(config, mapping) for clause in new_clauses]
            __x = torch.zeros((len(new_clauses), 24), dtype=torch.long)
            for i, clause in enumerate(tokens):
                for j, token in enumerate(clause[:24]):
                    __x[i, j] = clause[j]
            with torch.no_grad():
                sim = torch.nn.functional.cosine_similarity(embedding(__x), goal_embedding, dim=-1)
            
            for i, (s, p) in enumerate(zip(sim, interactive.tree[seen:])):
                premise_count.append(1 + sum(premise_count[idx] for idx in p))
                similarities.add((s.item() / premise_count[-1]**0.5, seen + i))
                
            sim.detach()
            
            seen = len(interactive.problem.clauses)
        
        next_clause = similarities.pop(-1)[1]
        interactive.step(next_clause)

In [13]:
len(interactive.active), similarities[-10:], [(i, count) for i, count in enumerate(premise_count) if i in [sim[1] for sim in similarities[-10:]]]

(29,
 [(0.009308381006121635, 2),
  (0.009402859024703503, 0),
  (0.11529046385817768, 28),
  (0.12591255692398315, 19),
  (0.13941754951860794, 18),
  (0.1428321984179353, 16),
  (0.15166795635407074, 21)],
 [(0, 1), (2, 1), (16, 3), (18, 3), (19, 5), (21, 7), (28, 11)])

In [14]:
from itertools import chain

dependencies = []
for i, p in enumerate(vampire.tree):
    dependencies.append([i] + list(chain(*[dependencies[idx] for idx in p])))

for i in dependencies[-1]:
    print(f"{torch.nn.functional.cosine_similarity(_x[i], goal_embedding, dim=-1).item() / len(dependencies[i]) ** 0.5:.2f}", vampire.problem.clauses[i])

0.20 $false
0.18 ~killed(butler, agatha)
0.19 richer(butler, agatha)
0.20 richer(butler, agatha) | ~lives(butler)
0.20 ~hates(butler, butler)
0.21 ~hates(butler, butler) | ~hates(butler, agatha)
0.21 ~hates(X3, charles) | ~hates(X3, butler) | ~hates(X3, agatha)
0.21 hates(butler, charles)
0.22 hates(butler, X6) | ~hates(agatha, X6)
0.21 hates(agatha, charles)
0.22 hates(butler, agatha)
0.22 hates(butler, X6) | ~hates(agatha, X6)
0.23 hates(agatha, agatha)
0.21 hates(butler, X7) | richer(X7, agatha) | ~lives(X7)
0.21 lives(butler)
0.20 ~richer(X0, X1) | ~killed(X0, X1)
0.18 killed(butler, agatha)
0.21 killed(butler, agatha) | ~hates(agatha, agatha)
0.21 killed(butler, agatha) | hates(charles, agatha)
0.21 killed(charles, agatha) | killed(butler, agatha)
0.21 goal_0 | killed(charles, agatha) | killed(butler, agatha)
0.21 ~goal_0
0.21 hates(X4, X5) | ~killed(X4, X5)
0.22 ~hates(charles, X2) | ~hates(agatha, X2)
0.23 hates(agatha, agatha)


In [15]:
interactive

0: lives(agatha)
[1m 1: lives(butler)[0m
2: lives(charles)
[1m 3: ~richer(X0, X1) | ~killed(X0, X1)[0m
[1m 4: ~hates(charles, X2) | ~hates(agatha, X2)[0m
[1m 5: ~hates(X3, charles) | ~hates(X3, butler) | ~hates(X3, agatha)[0m
[1m 6: hates(agatha, agatha)[0m
[1m 7: hates(agatha, charles)[0m
[1m 8: hates(X4, X5) | ~killed(X4, X5)[0m
[1m 9: hates(butler, X6) | ~hates(agatha, X6)[0m
[1m 10: hates(butler, X7) | richer(X7, agatha) | ~lives(X7)[0m
[1m 11: goal_0 | killed(charles, agatha) | killed(butler, agatha)[0m
[1m 12: ~goal_0[0m
[1m 13: hates(butler, agatha)[0m
[1m 14: hates(butler, charles)[0m
[1m 15: killed(charles, agatha) | killed(butler, agatha)[0m
16: ~hates(agatha, butler) | ~hates(agatha, agatha)
[1m 17: ~hates(butler, butler) | ~hates(butler, agatha)[0m
18: ~hates(butler, butler) | ~hates(butler, agatha) | richer(charles, agatha) | ~lives(charles)
19: ~hates(agatha, butler)
[1m 20: ~hates(butler, butler)[0m
21: ~hates(butler, butler) | richer(char

In [16]:
print(interactive.proof)

% Running in auto input_syntax mode. Trying TPTP
[SA] new: 1. lives(agatha) [input]
[SA] new: 2. lives(butler) [input]
[SA] new: 3. lives(charles) [input]
[SA] new: 4. ~richer(X0,X1) | ~killed(X0,X1) [input]
[SA] new: 5. ~hates(charles,X2) | ~hates(agatha,X2) [input]
[SA] new: 6. ~hates(X3,charles) | ~hates(X3,butler) | ~hates(X3,agatha) [input]
[SA] new: 7. hates(agatha,agatha) [input]
[SA] new: 8. hates(agatha,charles) [input]
[SA] new: 9. hates(X4,X5) | ~killed(X4,X5) [input]
[SA] new: 10. hates(butler,X6) | ~hates(agatha,X6) [input]
[SA] new: 11. hates(butler,X7) | richer(X7,agatha) | ~lives(X7) [input]
[SA] new: 12. goal_0 | killed(charles,agatha) | killed(butler,agatha) [input]
[SA] new: 13. ~goal_0 [input]
[SA] new: 14. hates(butler,agatha) [resolution 10,7]
[SA] new: 15. hates(butler,charles) [resolution 8,10]
[SA] new: 16. killed(charles,agatha) | killed(butler,agatha) [subsumption resolution 12,13]
[SA] new: 17. ~hates(agatha,butler) | ~hates(agatha,agatha) [resolution 6,8]
[

In [17]:
print(vampire.proof)

% Running in auto input_syntax mode. Trying TPTP
[SA] new: 1. lives(agatha) [input]
[SA] new: 2. lives(butler) [input]
[SA] new: 3. lives(charles) [input]
[SA] new: 4. ~richer(X0,X1) | ~killed(X0,X1) [input]
[SA] new: 5. ~hates(charles,X2) | ~hates(agatha,X2) [input]
[SA] new: 6. ~hates(X3,charles) | ~hates(X3,butler) | ~hates(X3,agatha) [input]
[SA] new: 7. hates(agatha,agatha) [input]
[SA] new: 8. hates(agatha,charles) [input]
[SA] new: 9. hates(X4,X5) | ~killed(X4,X5) [input]
[SA] new: 10. hates(butler,X6) | ~hates(agatha,X6) [input]
[SA] new: 11. hates(butler,X7) | richer(X7,agatha) | ~lives(X7) [input]
[SA] new: 12. goal_0 | killed(charles,agatha) | killed(butler,agatha) [input]
[SA] new: 13. ~goal_0 [input]
[SA] new: 14. hates(butler,agatha) [resolution 10,7]
[SA] new: 15. hates(butler,charles) [resolution 10,8]
[SA] new: 16. killed(charles,agatha) | killed(butler,agatha) [subsumption resolution 12,13]
[SA] new: 17. killed(butler,agatha) | hates(charles,agatha) [resolution 16,9]
