In [28]:
from model import load_model, greedy_predict, tokens_to_text
from hooks import HookPoint, register_decoder_hook, register_encoder_hook
from data import generate_dataset_pairs
import torch
import sympy

# Device string passed to load_model below; change to "cuda" if your GPU can handle it.
device = "cpu"

In [29]:
# Load the model, passing the device chosen above. Later cells use the returned
# object through its .model, .model_cfg, .eq_cfg and .params_fit attributes and
# its .fitfunc method.
model = load_model(device=device)

/home/morris/miniconda3/envs/symreg/lib/python3.9/site-packages/pytorch_lightning/utilities/migration/migration.py:208: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.3.3 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../res/100m.ckpt`


In [30]:
# Optional experiment (disabled): the triple-quoted string below holds code that,
# once un-quoted and run, registers a hook on the "mlp" hook point of decoder
# layers 0-3. The hook returns Gaussian noise shaped like the value it receives
# (torch.randn_like), i.e. it is meant to overwrite the decoder MLP outputs with
# random values as an intervention.
# Per the original author's note, the model will then no longer be able to fit
# the correct equation.
# NOTE(review): the code is disabled by wrapping it in a bare string literal, so
# the notebook echoes that string as this cell's output.

"""
def test_hook(output, _hook: HookPoint):
    return torch.randn_like(output)

for layer in range(4):
    register_decoder_hook(model.model, test_hook, HookPoint(layer, "mlp"))
"""


'\ndef test_hook(output, _hook: HookPoint):\n    return torch.randn_like(output)\n\nfor layer in range(4):\n    register_decoder_hook(model.model, test_hook, HookPoint(layer, "mlp"))\n'

In [None]:
def test_hook(output, _hook: HookPoint):
    """Zero-ablation hook: return an all-zero tensor shaped like ``output``.

    Args:
        output: Tensor handed to the hook (presumably the hooked component's
            output - confirm against the hooks module). Only its shape, dtype
            and device are used, via ``torch.zeros_like``.
        _hook: The ``HookPoint`` the hook was registered for (unused).

    Returns:
        A tensor of zeros with the same shape, dtype and device as ``output``.
    """
    return torch.zeros_like(output)


# Ablating all these model components with zeros will lead to a very incorrect output.
# NOTE(review): re-running this cell registers the hooks again; reload the model
# to get back to an un-hooked state.

# Zero the "mlp" hook point of encoder layers 0-6.
for i in range(7):
    register_encoder_hook(model.model, test_hook, HookPoint(i, "mlp"), model.model_cfg)

# Zero the ("self", j) hook point of layer j, for j = 0..5. This loop used to be
# nested inside the layer loop above without using `i`, so the same six hooks
# were registered seven times; registering them once ablates the same components.
# NOTE(review): HookPoint(j, ("self", j)) uses j as both the layer index and the
# second element of the ("self", ...) tuple. If the intent was to ablate every
# ("self", j) point in every layer, this should be HookPoint(i, ("self", j))
# inside the layer loop - confirm against the hooks module before changing.
for j in range(6):
    register_encoder_hook(model.model, test_hook, HookPoint(j, ("self", j)), model.model_cfg)

In [32]:
# Build the "complexity-bias" dataset pairs; all arguments are forwarded to
# generate_dataset_pairs exactly as before.
complexity_dataset = generate_dataset_pairs(
    "complexity-bias",
    500,
    10,
    model.model_cfg,
    model.eq_cfg,
    second_dataset_sample_rate=2,
)

# Take the first entry of the "X0" / "y0" collections as the problem to fit.
X, y = complexity_dataset["X0"][0], complexity_dataset["y0"][0]

print("Ground truth function")
# Left as the trailing expression so the notebook renders the parsed equation.
sympy.sympify(complexity_dataset["equations"][0][0])

Ground truth function


x_1 - x_3/x_2

In [33]:
# Initial token prediction: this initializes the sequence and caches the encoder
# embedding so the following steps can reuse it (saves computation time).
_, seq, enc_embed = greedy_predict(model.model, model.params_fit, X.unsqueeze(0), y.unsqueeze(0))

# Repeatedly predict the next token greedily (30 further decoding steps), feeding
# the cached encoder embedding and the sequence produced so far back in.
for _ in range(30):
    seq = greedy_predict(model.model, model.params_fit, enc_embed=enc_embed, sequence=seq)[1]

# This should result in (roughly) the correct equation.
greedy_pred = tokens_to_text(seq, model.params_fit)

print("Greedy predicted equation:")
for eq in greedy_pred:
    try:
        display(sympy.sympify(eq))
    except Exception:
        # Best-effort rendering: a decoded string that cannot be parsed is reported
        # instead of aborting the cell. `Exception` (rather than a bare `except:`)
        # keeps KeyboardInterrupt / SystemExit working, so the cell can still be
        # interrupted.
        print(f"Error parsing expression: {eq}")



Greedy predicted equation:


c*x_1 + (c + x_1)*exp(-cos(c*x_2))

In [34]:
# fit model with beam search instead of greedy + constant fitting (takes a lot longer)
output = model.fitfunc(X, y) 
# here you can see the fitted equations
output

Memory footprint of the encoder: 4.096e-05GB 



  X = torch.tensor(X,device=self.device).unsqueeze(0)
  y = torch.tensor(y,device=self.device).unsqueeze(0)


Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...
Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...


{'all_bfgs_preds': ['(0.0543208490998112*x_1 + x_2/(x_1 + 2.8034590773846))*(x_1 + 2.73562357072212)',
  '(-0.00653201202497072*x_1 + (x_1 - x_2)/x_1)*(x_1 - 0.0374967344092601)'],
 'all_bfgs_loss': [102.47861, 79.118225],
 'best_bfgs_preds': ['(-0.00653201202497072*x_1 + (x_1 - x_2)/x_1)*(x_1 - 0.0374967344092601)'],
 'best_bfgs_loss': [79.118225]}

In [35]:
# Parse the first entry of 'best_bfgs_preds' from the fit result with SymPy; as the
# trailing expression it is rendered by the notebook.
print("Best BFGS prediction:")
sympy.sympify(output["best_bfgs_preds"][0])

Best BFGS prediction:


(-0.00653201202497072*x_1 + (x_1 - x_2)/x_1)*(x_1 - 0.0374967344092601)