In [14]:
from nesymres.architectures.model import Model
from nesymres.utils import load_metadata_hdf5
from nesymres.dclasses import FitParams, NNEquation, BFGSParams
from pathlib import Path
from functools import partial
import torch
from sympy import lambdify
import json
import omegaconf
from typing import Literal, Callable, Tuple, Union
import os
from dataclasses import dataclass

device = "cpu" # NOTE: change to cuda if your GPU can handle it

In [15]:
RES_DIR = "../res/"

# load model config
with open(os.path.join(RES_DIR, "100m_eq_cfg.json"), "r") as json_file:
  eq_setting = json.load(json_file)

cfg = omegaconf.OmegaConf.load(os.path.join(RES_DIR, "100m_cfg.yaml"))

In [16]:
bfgs = BFGSParams(
    activated= cfg.inference.bfgs.activated,
    n_restarts=cfg.inference.bfgs.n_restarts,
    add_coefficients_if_not_existing=cfg.inference.bfgs.add_coefficients_if_not_existing,
    normalization_o=cfg.inference.bfgs.normalization_o,
    idx_remove=cfg.inference.bfgs.idx_remove,
    normalization_type=cfg.inference.bfgs.normalization_type,
    stop_time=cfg.inference.bfgs.stop_time,
)

params_fit = FitParams(word2id=eq_setting["word2id"], 
    id2word={int(k): v for k,v in eq_setting["id2word"].items()}, 
    una_ops=eq_setting["una_ops"], 
    bin_ops=eq_setting["bin_ops"], 
    total_variables=list(eq_setting["total_variables"]),  
    total_coefficients=list(eq_setting["total_coefficients"]),
    rewrite_functions=list(eq_setting["rewrite_functions"]),
    bfgs=bfgs,
    beam_size=cfg.inference.beam_size #This parameter is a tradeoff between accuracy and fitting time
)

# load model
model = Model.load_from_checkpoint(os.path.join(RES_DIR, "100m.ckpt"), cfg=cfg.architecture).to(device)
model.eval()

fitfunc = partial(model.fitfunc, cfg_params=params_fit)

/home/morris/miniconda3/envs/symreg/lib/python3.9/site-packages/pytorch_lightning/utilities/migration/migration.py:208: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.3.3 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../res/100m.ckpt`


In [None]:
@dataclass
class HookPoint:
    layer: int
    # Which part of the decoder layer to hook into. Can either be the MLP, a self-attention or cross-attention head.
    component: Union[Literal["mlp"], Tuple[Literal["self", "cross"], int]]

def register_decoder_hook(model: Model, hook_fn: Callable, hook: HookPoint) -> torch.utils.hooks.RemovableHandle:
    """
    Hooks a function into the decoder part of the model. This allows for reading or manulipulating the output of a specific component.

    NOTE: To remove the hook, the returned `RemovableHandle` must be called with `.remove()`.

    # Args
    * `model`: The model to hook into.
    * `hook_fn`: Callable that takes the output of the hooked component as a`torch.Tensor` and the hooked location as a `HookPoint`.
        The function should return an updated output.
    * `hook`: Description of the component to hook into.
    """

    def hook_wrapper(_module, _input, output):
        if hook.component == "mlp":
            output[0] = hook_fn(output[0], hook)
        elif hook.component[0] == "self" or hook.component[0] == "cross":
            head_idx = hook.component[1]

            # multihead and self-attention layer have same number of heads
            num_head = model.decoder_transfomer.layers[hook.layer].multihead_attn.num_heads
            
            # view data in terms of [seq_len x batch_size x num_head x head_dim])
            seq_len, bsz, _ = output[0].size()
            output_heads = output[0].view(seq_len, bsz, num_head, -1)

            # hook output of specified head
            output_heads[:, :, head_idx, :] = hook_fn(output_heads[:, :, head_idx, :], hook)
        else:
            raise ValueError(f"Unknown hook component: {hook.component}")
        
        return output
    
    if hook.component == "mlp":
        # hook into 2nd linear layer of MLP
        return model.decoder_transfomer.layers[hook.layer].linear2.register_forward_hook(hook_wrapper)
    elif hook.component[0] == "self":
        # hook into self-attention layer
        return model.decoder_transfomer.layers[hook.layer].self_attn.register_forward_hook(hook_wrapper)
    elif hook.component[0] == "cross":
        # hook into cross-attention layer
        return model.decoder_transfomer.layers[hook.layer].multihead_attn.register_forward_hook(hook_wrapper)

def test_hook(output, _hook: HookPoint):
    return torch.randn_like(output)

# For this intervention etst, we set all decoder MLP outputs to random values.
#  As you'll see below, the model won't be able to fit the correct equation (:omg:).
for layer in range(4):
    register_decoder_hook(model, test_hook, HookPoint(layer, "mlp"))

In [18]:
# create dummy data
number_of_points = 500
n_variables = 1

# to get best results make sure that your support inside the max and mix support
max_supp = cfg.dataset_train.fun_support["max"] 
min_supp = cfg.dataset_train.fun_support["min"]
X = torch.rand(number_of_points,len(list(eq_setting["total_variables"])))*(max_supp-min_supp)+min_supp
X[:,n_variables:] = 0
#target_eq = "x_1*sin(x_1)" #Use x_1,x_2 and x_3 as independent variables
target_eq = "exp(cos(x_1))" #Use x_1,x_2 and x_3 as independent variables
X_dict = {x:X[:,idx].cpu() for idx, x in enumerate(eq_setting["total_variables"])} 
y = lambdify(",".join(eq_setting["total_variables"]), target_eq)(**X_dict)

print("X shape: ", X.shape)
print("y shape: ", y.shape)

X shape:  torch.Size([500, 3])
y shape:  torch.Size([500])


In [19]:
# fit model
output = fitfunc(X, y) 

Memory footprint of the encoder: 4.096e-05GB 



  X = torch.tensor(X,device=self.device).unsqueeze(0)
  y = torch.tensor(y,device=self.device).unsqueeze(0)


Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...


  final_loss = np.mean(np.square(y_found-y.cpu()).numpy())


All-Nan slice encountered
Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...


  return 0.0147309266766406*(0.367878993469832*(-0.0565576553344727)**(c0**((0.003198768376933*c1 + 1)**c2)) - 1)**2 + 0.014540075415087*(0.367867253829069*(-0.127516746520996)**(c0**((0.0162605206433*c1 + 1)**c2)) - 1)**2 + 0.0143659278663763*(0.367842271068122*(-0.168389320373535)**(c0**((0.0283549632158611*c1 + 1)**c2)) - 1)**2 + 0.0139982582818476*(0.36774180679865*(-0.233368873596191)**(c0**((0.0544610311635552*c1 + 1)**c2)) - 1)**2 + 0.0138811305323402*(0.367695354064706*(-0.250890731811523)**(c0**((0.0629461593089218*c1 + 1)**c2)) - 1)**2 + 0.0134043654028375*(0.367427214007051*(-0.313640594482422)**(c0**((0.098370422507287*c1 + 1)**c2)) - 1)**2 + 0.0133570260984125*(0.367393234237256*(-0.319324493408203)**(c0**((0.101968132090406*c1 + 1)**c2)) - 1)**2 + 0.0127964100915989*(0.366879112307059*(-0.38176441192627)**(c0**((0.14574406621341*c1 + 1)**c2)) - 1)**2 + 0.0125060589501346*(0.366524359963729*(-0.411472320556641)**(c0**((0.169309470584267*c1 + 1)**c2)) - 1)**2 + 0.0123687725

All-Nan slice encountered


  return 0.0147309266766406*(0.367878993469832*(-0.0565576553344727)**(c0**((0.003198768376933*c1 + 1)**c2)) - 1)**2 + 0.014540075415087*(0.367867253829069*(-0.127516746520996)**(c0**((0.0162605206433*c1 + 1)**c2)) - 1)**2 + 0.0143659278663763*(0.367842271068122*(-0.168389320373535)**(c0**((0.0283549632158611*c1 + 1)**c2)) - 1)**2 + 0.0139982582818476*(0.36774180679865*(-0.233368873596191)**(c0**((0.0544610311635552*c1 + 1)**c2)) - 1)**2 + 0.0138811305323402*(0.367695354064706*(-0.250890731811523)**(c0**((0.0629461593089218*c1 + 1)**c2)) - 1)**2 + 0.0134043654028375*(0.367427214007051*(-0.313640594482422)**(c0**((0.098370422507287*c1 + 1)**c2)) - 1)**2 + 0.0133570260984125*(0.367393234237256*(-0.319324493408203)**(c0**((0.101968132090406*c1 + 1)**c2)) - 1)**2 + 0.0127964100915989*(0.366879112307059*(-0.38176441192627)**(c0**((0.14574406621341*c1 + 1)**c2)) - 1)**2 + 0.0125060589501346*(0.366524359963729*(-0.411472320556641)**(c0**((0.169309470584267*c1 + 1)**c2)) - 1)**2 + 0.0123687725

In [20]:
output

{'all_bfgs_preds': ['(((x_1)**(((x_1)**(-2))**((-3)/(4))))*(cos(cos(x_1))))',
  'x_1**(0.31891064740011**((0.0402562403114958*x_1**2 + 1)**0.984259719019639))*cos(x_1)'],
 'all_bfgs_loss': nan,
 'best_bfgs_preds': [],
 'best_bfgs_loss': None}