In [1]:
# Machine-specific locations of the Vampire binary and of the TPTP library.
# Either one can be overridden with an environment variable of the same name,
# so the notebook can run on another machine without editing this cell; the
# defaults are the original hard-coded paths.
import os

VAMPIRE_PATH = os.environ.get(
    'VAMPIRE_PATH',
    '/home/apluska/.vampire/bin/vampire_z3_rel_static_casc2023_6749',
)
# os.path.join(p, '') appends a trailing separator only when it is missing, so
# plain string concatenation with this prefix (TPTP_PATH + 'Problems/...')
# stays valid even for an override given without a trailing slash.
TPTP_PATH = os.path.join(
    os.environ.get('TPTP_PATH', '/home/apluska/TPTP-v8.2.0/'),
    '',
)

In [2]:
# File name of the TPTP problem used by every later cell (read from the TPTP
# distribution and cached under 'problems/').
PROBLEM = 'PUZ001-1.p'

In [3]:
import os

from foreduce.tptp.parser import read_file as read_tptp

# TPTP keeps each problem under Problems/<DOMAIN>/, where <DOMAIN> is the
# three-letter prefix of the file name, so derive the directory from PROBLEM
# instead of hard-coding 'PUZ'.
problem_path = os.path.join(TPTP_PATH, 'Problems', PROBLEM[:3], PROBLEM)
problem = read_tptp(problem_path, include_path=TPTP_PATH, max_size=10_000)

# All function and predicate symbols occurring in the problem.
symbols = set()
symbols.update(problem.function_symbols() | problem.predicate_symbols())

# Keep a local copy of the problem as serialised by `to_tptp()`; later cells
# pass 'problems/<PROBLEM>' to Vampire. The directory is created on demand so
# a fresh checkout does not fail with FileNotFoundError, and an existing copy
# is never overwritten.
os.makedirs('problems', exist_ok=True)
if not os.path.exists('problems/' + PROBLEM):
    with open('problems/' + PROBLEM, 'w') as f:
        f.write(problem.to_tptp())

In [4]:
from collections import defaultdict

from foreduce.transformer.tokenizer import TokenConfig

# Group symbol names by arity: arity_dict[k] holds the names of every symbol
# whose arity is k.
arity_dict = defaultdict(list)
for s in symbols:
    arity_dict[s.arity].append(s.name)

# One list of names per arity that actually occurs.
# NOTE(review): `symbols` is a set, so the groups appear in first-seen
# (set-iteration) order, not in increasing arity, and absent arities produce
# no entry -- confirm TokenConfig does not assume list index == arity.
arity_list = list(arity_dict.values())
config = TokenConfig(num_functions=[len(names) for names in arity_list])

# Draw the random symbol -> token assignment exactly once and reuse it.
# Previously a second, independent random mapping was drawn for `mapping`,
# leaving `function_mapping` unused and inconsistent with it.
function_mapping = config.random_function_mapping(arity_list)

# Full token mapping: reserved tokens plus the problem's symbols.
mapping = config.reserved_token_mapping | function_mapping

In [5]:
from tqdm.auto import tqdm

from foreduce.vampire.vampire import VampireAutomatic
from foreduce.data.data import ProofTokens


# Training set of tokenised proofs, built from Vampire runs on the cached
# copy of PROBLEM.
dataset = ProofTokens(config, seq_len=24)

for _ in tqdm(range(1)):
    vampire = VampireAutomatic(
        VAMPIRE_PATH,
        'problems/' + PROBLEM,
        selection='10',
    )
    vampire.run()

    # Only runs whose output reports a refutation contribute a proof.
    if 'Refutation found' not in vampire.proof:
        continue
    dataset.add_proof(vampire.problem, vampire.tree, mapping)

  from .autonotebook import tqdm as notebook_tqdm
  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  2.68it/s]


In [9]:
from foreduce.transformer.embedding import FormulaEmbedding
from torch.utils.data import DataLoader

# Model that embeds tokenised formulas.
# NOTE(review): seq_len is 32 here while the ProofTokens dataset was built
# with seq_len=24 -- confirm the mismatch is intentional.
embedding = FormulaEmbedding(
    config,
    seq_len=32,
    dim=2048,
    n_layers=24,
    n_heads=16,
)

# Shuffled mini-batches over the proof dataset.
data_loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
)

In [10]:
from lightning import Trainer
import wandb
from lightning.pytorch.loggers import WandbLogger
import torch

# Allow reduced-precision float32 matmuls (see the PyTorch docs for what
# 'medium' permits on the current hardware).
torch.set_float32_matmul_precision('medium')

# Start the W&B run explicitly; per Lightning's own warning, a WandbLogger
# created while a run is in progress reuses that run.
wandb.init(project='foreduce')
try:
    trainer = Trainer(
        max_epochs=32,
        logger=WandbLogger(),
        # Gradients are accumulated over up to 256 batches per optimizer step.
        accumulate_grad_batches=256,
        log_every_n_steps=1,
    )
    trainer.fit(embedding, data_loader)
finally:
    # Always close the run, even when training fails or is interrupted;
    # otherwise the run stays open and the next execution of this cell
    # silently attaches to it.
    wandb.finish()

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlexpk[0m. Use [1m`wandb login --relogin`[0m to force relogin


Trainer will use only 1 of 4 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=4)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/apluska/miniconda3/envs/foreduce/lib/python3.12/site-packages/lightning/pytorch/trainer/configuration_validator.py:70: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.
/home/apluska/miniconda3/envs/foreduce/lib/python3.12/site-packages/lightning/pytorch/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name       | Type       | Params | Mode 
-------

Epoch 31: 100%|██████████| 17/17 [00:05<00:00,  3.10it/s, v_num=m6ex, train_loss_step=0.00356, train_loss_epoch=0.00392]

`Trainer.fit` stopped: `max_epochs=32` reached.


Epoch 31: 100%|██████████| 17/17 [00:41<00:00,  0.41it/s, v_num=m6ex, train_loss_step=0.00356, train_loss_epoch=0.00392]


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train_loss_epoch,▅▄▅▃▅▇▇█▆▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁
train_loss_step,▄▆▄▃▄▅▄█▄▅▄▃▃▂▁▃▂▃▂▄▂▁▄▁▂▃▄▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████

0,1
epoch,31.0
train_loss_epoch,0.00392
train_loss_step,0.00356
trainer/global_step,31.0


In [11]:
# Offline evaluation: compare, for every clause of the automatic Vampire run,
# a dependency-overlap target against the cosine similarity between the
# clause embedding and the embedding of the last clause.
embedding = embedding.cpu().eval()

num_clauses = len(vampire.problem.clauses)

# dependencies[i] = {i} plus the dependencies of every premise of clause i.
# Assumes each premise index in vampire.tree[i] is smaller than i, so its set
# is already complete when clause i is processed -- TODO confirm.
dependencies = [set() for _ in range(num_clauses)]
for idx in range(num_clauses):
    # With no premises the union over an empty sequence is just {idx}.
    dependencies[idx] = {idx}.union(*(dependencies[j] for j in vampire.tree[idx]))

# Token ids of every clause, truncated / zero-padded to 32 positions.
__x = torch.zeros((num_clauses, 32), dtype=torch.long)
for i, clause in enumerate(vampire.problem.clauses):
    tokens = clause.tokenize(config, mapping)
    for j, token in enumerate(tokens[:__x.shape[1]]):
        __x[i, j] = token

# Clauses with identical token rows are indistinguishable to the model:
# index_mapping[u] lists the clause indices that share unique row u.
_tokens, indices = torch.unique(__x, return_inverse=True, dim=0)
index_mapping = {idx: [] for idx in range(indices.max().item() + 1)}
for i, idx in enumerate(indices):
    index_mapping[idx.item()].append(i)

# Dependencies of the last clause (printed as `$false` in the recorded output).
proof_deps = dependencies[-1]

# Target for clause i: among all clauses with the same token row, take the
# first one minimising |deps ∪ proof| + 1 / (1 + |deps ∩ proof|), then score
# it with |deps ∩ proof| / sqrt(|deps| * |proof|).
target = torch.zeros(num_clauses, dtype=torch.float)
for i in range(num_clauses):
    idx = min(
        index_mapping[indices[i].item()],
        key=lambda k: len(dependencies[k] | proof_deps)
        + 1 / (1 + len(dependencies[k] & proof_deps)),
    )
    target[i] = len(dependencies[idx] & proof_deps) / (len(dependencies[idx]) * len(proof_deps))**0.5

# Inference only: skip autograd bookkeeping, as the interactive cell does.
with torch.no_grad():
    _x = embedding(__x)

# (clause index, target, cosine similarity to the last clause), best target first.
similarities = sorted([
    (
        i,
        target[i],
        torch.nn.functional.cosine_similarity(_x[i], _x[-1], dim=0).item()
    )
    for i in range(num_clauses)
], key=lambda entry: entry[1], reverse=True)

# The loop variable gets its own name so the `target` tensor is not rebound
# to a scalar by this loop.
for i, target_score, sim in similarities:
    print(f"{target_score:.2f}", f"{sim:.2f}", vampire.problem.clauses[i])

1.00 1.00 $false
0.78 0.84 ~killed(butler, agatha)
0.72 0.64 richer(butler, agatha)
0.66 0.64 richer(butler, agatha) | ~lives(butler)
0.63 0.56 killed(butler, agatha)
0.59 0.30 ~hates(butler, butler)
0.55 0.66 killed(butler, agatha) | ~hates(agatha, agatha)
0.47 0.29 killed(butler, agatha) | hates(charles, agatha)
0.47 0.24 ~hates(butler, butler) | ~hates(butler, agatha)
0.39 0.38 ~hates(butler, butler) | richer(charles, agatha) | ~lives(charles)
0.39 0.72 richer(butler, agatha) | ~killed(charles, agatha)
0.36 0.40 hates(butler, agatha)
0.36 -0.06 hates(butler, charles)
0.36 0.41 killed(charles, agatha) | killed(butler, agatha)
0.36 0.40 hates(butler, agatha)
0.36 0.64 richer(charles, agatha) | richer(butler, agatha)
0.35 0.48 ~hates(butler, butler) | richer(charles, agatha)
0.33 0.58 richer(charles, agatha) | richer(butler, agatha) | ~lives(butler)
0.28 0.35 ~hates(agatha, butler)
0.24 0.24 ~hates(agatha, butler) | ~hates(agatha, agatha)
0.24 0.26 ~hates(butler, butler) | ~hates(butle

In [12]:
# Show only clauses the last clause depends on, with target and similarity
# both divided by sqrt(|dependencies of the clause|).
# Expects `similarities` to hold the (index, target, similarity) triples from
# the evaluation cell, not the SortedList that the interactive cell stores
# under the same name.
for i, target, sim in similarities:
    if i not in dependencies[-1]:
        continue
    scale = len(dependencies[i])**0.5
    print(f"{target/scale:.2f}", f"{sim/scale:.2f}", vampire.problem.clauses[i])

0.21 0.21 $false
0.21 0.23 ~killed(butler, agatha)
0.21 0.19 richer(butler, agatha)
0.21 0.20 richer(butler, agatha) | ~lives(butler)
0.21 0.19 killed(butler, agatha)
0.21 0.11 ~hates(butler, butler)
0.21 0.25 killed(butler, agatha) | ~hates(agatha, agatha)
0.21 0.13 killed(butler, agatha) | hates(charles, agatha)
0.21 0.11 ~hates(butler, butler) | ~hates(butler, agatha)
0.21 0.23 hates(butler, agatha)
0.21 -0.04 hates(butler, charles)
0.21 0.24 killed(charles, agatha) | killed(butler, agatha)
0.21 0.13 lives(butler)
0.21 0.51 ~richer(X0, X1) | ~killed(X0, X1)
0.21 0.30 ~hates(charles, X2) | ~hates(agatha, X2)
0.21 0.27 ~hates(X3, charles) | ~hates(X3, butler) | ~hates(X3, agatha)
0.21 0.53 hates(agatha, agatha)
0.21 0.26 hates(agatha, charles)
0.21 0.35 hates(X4, X5) | ~killed(X4, X5)
0.21 0.32 hates(butler, X6) | ~hates(agatha, X6)
0.21 0.34 hates(butler, X7) | richer(X7, agatha) | ~lives(X7)
0.21 0.33 goal_0 | killed(charles, agatha) | killed(butler, agatha)
0.21 0.75 ~goal_0


In [13]:
# Token ids of the last clause of the automatic run, zero-padded to 32
# positions; used as the "goal" whose embedding guides clause selection.
goal = torch.zeros(32, dtype=torch.long)
goal_tokens = vampire.problem.clauses[-1].tokenize(config, mapping)
# Truncate to the buffer length (as the evaluation cell does with
# `tokens[:32]`) so a longer clause cannot raise IndexError.
for i, token in enumerate(goal_tokens[:goal.shape[0]]):
    goal[i] = token

In [23]:
from sortedcontainers import SortedList

from foreduce.vampire.vampire import VampireInteractive

# Upper bound on the number of interactive selection steps.
MAX_STEP = 100

# Inference only: the goal embedding needs no autograd graph.
with torch.no_grad():
    goal_embedding = embedding(goal.unsqueeze(0))

# Guided saturation: repeatedly activate the not-yet-selected clause with the
# highest score, where score = cosine similarity to the goal embedding
# divided by sqrt(premise_count).
with VampireInteractive(VAMPIRE_PATH, 'problems/' + PROBLEM) as interactive:
    seen = 0                      # number of clauses already scored
    similarities = SortedList()   # (score, clause index), ascending by score
    premise_count = []            # premise_count[i] = 1 + sum over premises of i

    while not interactive.finished and interactive.step_count < MAX_STEP:
        new_clauses = interactive.problem.clauses[seen:]
        if new_clauses:
            # Token ids of the new clauses, truncated / zero-padded to 24.
            # NOTE(review): the goal is padded to 32 positions while clauses
            # here use 24 -- confirm the embedding is insensitive to padding.
            __x = torch.zeros((len(new_clauses), 24), dtype=torch.long)
            for i, clause in enumerate(new_clauses):
                tokens = clause.tokenize(config, mapping)
                for j, token in enumerate(tokens[:__x.shape[1]]):
                    __x[i, j] = token
            with torch.no_grad():
                sim = torch.nn.functional.cosine_similarity(embedding(__x), goal_embedding, dim=-1)

            for i, (s, p) in enumerate(zip(sim, interactive.tree[seen:])):
                # Premises are counted with multiplicity (size of the
                # derivation tree, not of the set of distinct ancestors).
                premise_count.append(1 + sum(premise_count[idx] for idx in p))
                similarities.add((s.item() / premise_count[-1]**0.5, seen + i))

            seen = len(interactive.problem.clauses)

        # Nothing left to select: stop instead of raising IndexError on pop().
        if not similarities:
            break

        next_clause = similarities.pop(-1)[1]
        interactive.step(next_clause)

In [15]:
# Quick look at the search state: number of active clauses, the (up to) ten
# best remaining candidates, and the premise counts of those candidates.
top = similarities[-10:]
top_ids = [clause_id for _, clause_id in top]
len(interactive.active), top, [(i, count) for i, count in enumerate(premise_count) if i in top_ids]

(32, [(0.09212842647629772, 28)], [(28, 5)])

In [16]:
from itertools import chain

# Rebuild `dependencies` as lists: clause i followed by the dependency lists
# of its premises. Shared ancestors are repeated, so len(dependencies[i]) is
# the derivation-tree size rather than the number of distinct ancestors.
# This overwrites the set-based `dependencies` from the evaluation cell.
dependencies = []
for i, premises in enumerate(vampire.tree):
    inherited = chain.from_iterable(dependencies[idx] for idx in premises)
    dependencies.append([i, *inherited])

# Score of every clause occurrence in the derivation of the last clause:
# cosine similarity to the goal embedding divided by sqrt(tree size).
for i in dependencies[-1]:
    score = torch.nn.functional.cosine_similarity(_x[i], goal_embedding, dim=-1).item() / len(dependencies[i]) ** 0.5
    print(f"{score:.2f}", vampire.problem.clauses[i])

0.20 $false
0.22 ~killed(butler, agatha)
0.18 richer(butler, agatha)
0.19 richer(butler, agatha) | ~lives(butler)
0.10 ~hates(butler, butler)
0.11 ~hates(butler, butler) | ~hates(butler, agatha)
0.27 ~hates(X3, charles) | ~hates(X3, butler) | ~hates(X3, agatha)
-0.04 hates(butler, charles)
0.32 hates(butler, X6) | ~hates(agatha, X6)
0.26 hates(agatha, charles)
0.23 hates(butler, agatha)
0.32 hates(butler, X6) | ~hates(agatha, X6)
0.53 hates(agatha, agatha)
0.34 hates(butler, X7) | richer(X7, agatha) | ~lives(X7)
0.13 lives(butler)
0.51 ~richer(X0, X1) | ~killed(X0, X1)
0.19 killed(butler, agatha)
0.25 killed(butler, agatha) | ~hates(agatha, agatha)
0.13 killed(butler, agatha) | hates(charles, agatha)
0.24 killed(charles, agatha) | killed(butler, agatha)
0.33 goal_0 | killed(charles, agatha) | killed(butler, agatha)
0.75 ~goal_0
0.35 hates(X4, X5) | ~killed(X4, X5)
0.30 ~hates(charles, X2) | ~hates(agatha, X2)
0.53 hates(agatha, agatha)


In [24]:
# Dump the raw output collected from the interactive Vampire session
# (print() keeps the multi-line text readable).
print(interactive.proof)

% Running in auto input_syntax mode. Trying TPTP
[SA] new: 1. lives(agatha) [input]
[SA] new: 2. lives(butler) [input]
[SA] new: 3. lives(charles) [input]
[SA] new: 4. ~richer(X0,X1) | ~killed(X0,X1) [input]
[SA] new: 5. ~hates(charles,X2) | ~hates(agatha,X2) [input]
[SA] new: 6. ~hates(X3,charles) | ~hates(X3,butler) | ~hates(X3,agatha) [input]
[SA] new: 7. hates(agatha,agatha) [input]
[SA] new: 8. hates(agatha,charles) [input]
[SA] new: 9. hates(X4,X5) | ~killed(X4,X5) [input]
[SA] new: 10. hates(butler,X6) | ~hates(agatha,X6) [input]
[SA] new: 11. hates(butler,X7) | richer(X7,agatha) | ~lives(X7) [input]
[SA] new: 12. goal_0 | killed(charles,agatha) | killed(butler,agatha) [input]
[SA] new: 13. ~goal_0 [input]
[SA] new: 14. killed(charles,agatha) | killed(butler,agatha) [subsumption resolution 12,13]
[SA] new: 15. hates(butler,agatha) [resolution 10,7]
[SA] new: 16. hates(butler,charles) [resolution 8,10]
[SA] new: 17. killed(butler,agatha) | hates(charles,agatha) [resolution 14,9]
