# **1. Setup.**

In [None]:
# Install the packages this notebook needs into the kernel's own environment
# (`%pip`, unlike `!pip`, targets the running kernel).
# NOTE(review): no versions are pinned, so a fresh install may pull releases
# newer than the ones the notebook was written against. The training cell uses
# Trainer.add_argparse_args / Trainer.from_argparse_args, which presumably
# requires a pytorch-lightning release older than 2.0 — TODO confirm and pin.
%pip install torch
%pip install torchtext
%pip install tensorboard
%pip install pytorch-lightning

In [4]:
from argparse import ArgumentParser

import torch
from pytorch_lightning import Trainer, seed_everything

from module import TPN2FModule
from tpn2f.model import TPN2FModel

# Notebook-wide settings, grouped by the stage that consumes them.

# Data: forwarded to the argument parser as defaults in the training cell.
dataset = "MathQA"
data_path = ".data"
num_workers = 0

# Training: forwarded to the argument parser as the `max_epochs` default.
max_epochs = 60

# Inference: the checkpoint to restore and the second argument of the model call.
formula_limit = 10
checkpoint_path = "lightning_logs/version_0/checkpoints/epoch=0.ckpt"

# **2. Train.**

In [None]:
# Load the TensorBoard notebook extension and open it on `lightning_logs/`,
# the same top-level directory that `checkpoint_path` points into.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/

In [None]:
# Seed the random number generators so that repeated runs are comparable.
seed_everything(42)

# The model and the Trainer each register their own options on one shared
# parser; the notebook-level settings then replace the parser defaults.
parser = ArgumentParser()
parser = TPN2FModel.add_model_specific_args(parser)
parser = Trainer.add_argparse_args(parser)
parser.set_defaults(max_epochs=max_epochs, learning_rate=0.00115, num_workers=num_workers,
                    data_path=data_path, dataset=dataset)

# Parse an explicit empty argument list. Called without arguments, argparse
# reads sys.argv, which inside a notebook kernel holds the kernel's own launch
# flags (e.g. `-f <connection file>`) and makes parsing abort with SystemExit.
# Every value therefore comes from the defaults configured above.
hparams = parser.parse_args([])
trainer = Trainer.from_argparse_args(hparams, deterministic=True)
model = TPN2FModule(hparams)
trainer.fit(model)

# **3. Infer.**

In [None]:
problem = 'On the coast there are 3 lighthouses . The first light shines for 3 seconds then goes off for 3 seconds . The second light shines for 4 seconds then goes off for 4 seconds . The third light shines for 5 seconds then goes off for 5 seconds . All three lights have just come on together . When is the first time all three lights will be off ?' #@param {type:"string"}
# NOTE(review): encoding and translation go through `model.dataset`, i.e. the
# module built in the training cell, so that cell must have run in this kernel
# first — TODO confirm whether the restored module exposes the same dataset.
problem = model.dataset.encode_problem(problem)

# Restore the trained weights from disk and switch the module to evaluation mode.
tpn2f = TPN2FModule.load_from_checkpoint(checkpoint_path)
tpn2f.eval()
# eval() only changes layer behaviour; no_grad() additionally stops autograd
# from recording the forward pass, which inference has no use for.
with torch.no_grad():
    formula = tpn2f(problem, formula_limit)
relation_tuples = model.dataset.translate_formula(tpn2f.decode_formula(formula))
# Render each tuple as `relation(arg1,arg2,...)`, with tuples separated by `|`.
print('|'.join([f"{relation}({','.join(args)})" for relation, *args in relation_tuples]))