In [92]:
from nesymres.architectures.model import Model
from nesymres.utils import load_metadata_hdf5
from nesymres.dclasses import FitParams, NNEquation, BFGSParams
from pathlib import Path
from functools import partial
import torch
import torch.nn.functional as F
from sympy import lambdify
import json
import omegaconf
from typing import Literal, Callable, Tuple, Union
import os
from dataclasses import dataclass

# Device the model is moved to after loading (see the `.to(device)` call on the checkpoint).
device = "cpu" # NOTE: change to "cuda" if your GPU can handle it

In [93]:
# Directory holding the checkpoint and its configuration files.
RES_DIR = "../res/"

# load equation settings (vocabulary, operators, variables) used to build `FitParams`
# NOTE: JSON is decoded explicitly as UTF-8 so loading does not depend on the platform's locale encoding.
with open(os.path.join(RES_DIR, "100m_eq_cfg.json"), "r", encoding="utf-8") as json_file:
    eq_setting = json.load(json_file)

# load model config (architecture, inference and dataset settings)
cfg = omegaconf.OmegaConf.load(os.path.join(RES_DIR, "100m_cfg.yaml"))

In [94]:
# BFGS options are copied one-to-one from the inference section of the model config.
bfgs_cfg = cfg.inference.bfgs
bfgs = BFGSParams(
    activated=bfgs_cfg.activated,
    n_restarts=bfgs_cfg.n_restarts,
    add_coefficients_if_not_existing=bfgs_cfg.add_coefficients_if_not_existing,
    normalization_o=bfgs_cfg.normalization_o,
    idx_remove=bfgs_cfg.idx_remove,
    normalization_type=bfgs_cfg.normalization_type,
    stop_time=bfgs_cfg.stop_time,
)

# JSON object keys are always strings, so the id -> word table is re-keyed with integers.
id2word_int_keys = {int(token_id): word for token_id, word in eq_setting["id2word"].items()}

params_fit = FitParams(
    word2id=eq_setting["word2id"],
    id2word=id2word_int_keys,
    una_ops=eq_setting["una_ops"],
    bin_ops=eq_setting["bin_ops"],
    total_variables=list(eq_setting["total_variables"]),
    total_coefficients=list(eq_setting["total_coefficients"]),
    rewrite_functions=list(eq_setting["rewrite_functions"]),
    bfgs=bfgs,
    # This parameter is a tradeoff between accuracy and fitting time
    beam_size=cfg.inference.beam_size,
)

# load model, move it to the configured device and switch it to evaluation mode
checkpoint_path = os.path.join(RES_DIR, "100m.ckpt")
model = Model.load_from_checkpoint(checkpoint_path, cfg=cfg.architecture)
model = model.to(device)
model.eval()

# `fitfunc(X, y)` now calls `model.fitfunc` with the fit parameters already bound
fitfunc = partial(model.fitfunc, cfg_params=params_fit)

/home/morris/miniconda3/envs/symreg/lib/python3.9/site-packages/pytorch_lightning/utilities/migration/migration.py:208: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.3.3 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../res/100m.ckpt`


In [95]:
@dataclass
class HookPoint:
    """
    Describes a single location inside the decoder that a hook can be attached to.
    """
    # Index of the decoder layer (used to index `model.decoder_transfomer.layers`).
    layer: int
    # Which part of the decoder layer to hook into. Can either be the MLP, a self-attention or cross-attention head.
    # Either the string "mlp", or a tuple `(kind, head_idx)` where `kind` is "self" or "cross"
    # and `head_idx` is the index of the attention head.
    component: Union[Literal["mlp"], Tuple[Literal["self", "cross"], int]]

def register_decoder_hook(model: Model, hook_fn: Callable, hook: HookPoint) -> torch.utils.hooks.RemovableHandle:
    """
    Hooks a function into the decoder part of the model. This allows for reading or manipulating the output of a specific component.

    NOTE: To remove the hook, the returned `RemovableHandle` must be called with `.remove()`.

    # Args
    * `model`: The model to hook into.
    * `hook_fn`: Callable that takes the output of the hooked component as a `torch.Tensor` and the hooked location as a `HookPoint`.
        The function should return an updated output. For attention components it receives (and must return) only the
        slice belonging to the selected head.
    * `hook`: Description of the component to hook into.

    # Returns
    The handle of the registered forward hook.

    # Raises
    * `ValueError`: If `hook.component` is neither `"mlp"` nor a `("self" | "cross", head_idx)` pair.
    """

    decoder_layer = model.decoder_transfomer.layers[hook.layer]
    component = hook.component

    if component == "mlp":
        def mlp_hook(_module, _input, output):
            # `linear2` returns a plain tensor (not a tuple), so the whole tensor is handed to `hook_fn`.
            # A non-None return value of a forward hook replaces the module's output.
            return hook_fn(output, hook)

        # hook into 2nd linear layer of MLP
        return decoder_layer.linear2.register_forward_hook(mlp_hook)

    is_attention = (
        isinstance(component, (tuple, list))
        and len(component) == 2
        and component[0] in ("self", "cross")
    )
    if not is_attention:
        # fail at registration time instead of silently returning `None`
        raise ValueError(f"Unknown hook component: {hook.component}")

    attn_kind, head_idx = component
    # self-attention vs. cross-attention module of the selected layer
    attn_module = decoder_layer.self_attn if attn_kind == "self" else decoder_layer.multihead_attn
    num_head = attn_module.num_heads

    def attention_hook(_module, _input, output):
        # the attention module returns a tuple whose first element is the attention output
        attn_output = output[0]

        # view data in terms of [seq_len x batch_size x num_head x head_dim]
        seq_len, bsz, _ = attn_output.size()
        output_heads = attn_output.view(seq_len, bsz, num_head, -1)

        # NOTE(review): if this is a standard `nn.MultiheadAttention`, the returned tensor has already passed
        # through the output projection, so this slice may not correspond to a single head's raw output -- confirm.
        # The view shares memory with `attn_output`, so the in-place write updates the module output.
        output_heads[:, :, head_idx, :] = hook_fn(output_heads[:, :, head_idx, :], hook)

        return output

    return attn_module.register_forward_hook(attention_hook)



# If we uncomment the code below, we set all decoder MLP outputs to random values using interventions.
#  As you'll see below, the model won't be able to fit the correct equation (:omg:).
# NOTE: the example does not keep the handles returned by `register_decoder_hook`, so the hooks cannot be
#  removed with `.remove()` afterwards; store the handles if you want to undo the intervention.
# NOTE(review): `range(4)` presumably matches the number of decoder layers -- confirm against the model config.
"""
def test_hook(output, _hook: HookPoint):
    return torch.randn_like(output)
for layer in range(4):
    register_decoder_hook(model, test_hook, HookPoint(layer, "mlp"))
"""

'\ndef test_hook(output, _hook: HookPoint):\n    return torch.randn_like(output)\nfor layer in range(4):\n    register_decoder_hook(model, test_hook, HookPoint(layer, "mlp"))\n'

In [96]:
# create dummy data
number_of_points = 500
n_variables = 1

# Sample from a private, seeded generator so the points are reproducible after a kernel restart
# without modifying torch's global RNG state.
DATA_SEED = 0
data_rng = torch.Generator().manual_seed(DATA_SEED)

# to get best results make sure that your support lies inside the max and min support
max_supp = cfg.dataset_train.fun_support["max"]
min_supp = cfg.dataset_train.fun_support["min"]
total_variables = list(eq_setting["total_variables"])

# uniform samples in [min_supp, max_supp) for every variable the model knows about
X = torch.rand(number_of_points, len(total_variables), generator=data_rng) * (max_supp - min_supp) + min_supp
# variables that are not used by the target equation are zeroed out
X[:, n_variables:] = 0
#target_eq = "x_1*sin(x_1)" #Use x_1,x_2 and x_3 as independent variables
target_eq = "exp(cos(x_1))" #Use x_1,x_2 and x_3 as independent variables
X_dict = {x: X[:, idx].cpu() for idx, x in enumerate(total_variables)}
# evaluate the target equation on the sampled points
y = lambdify(",".join(total_variables), target_eq)(**X_dict)

print("X shape: ", X.shape)
print("y shape: ", y.shape)

X shape:  torch.Size([500, 3])
y shape:  torch.Size([500])


In [97]:
def greedy_predict(
        model: Model, params: FitParams,
        X: torch.Tensor = None, y: torch.Tensor = None, sequence: torch.Tensor = None, enc_embed: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Greedily predicts the next token in the sequence. Can be used in two ways:
    1. If `X` and `y` are provided, the model will use them to compute the encoder embedding and then predict the next token.
    2. If `enc_embed` is provided, the model will use this embedding directly to predict the next token.

    # Args
    * `model`: The model used for encoding and decoding.
    * `params`: Fit parameters; only `word2id["S"]` (start token) is read here.
    * `X`: Function domain of shape [B, N, D], where B is the batch size, N is the number of samples,
        and D the input dimensionality (no more than 3).
    * `y`: Function image corresponding to `X` of shape [B, N].
    * `sequence`: The initial tokens to predict the next token from of shape [B, S], where S is the maximum sequence length.
        All samples in the batch are expected to be at the same current sequence length. Token id 0 is treated as
        padding. If omitted, a new sequence containing only the start token is created.
        NOTE: the tensor is updated in place.
    * `enc_embed`: Encoder embedding. Can be reused if the same `X` and `y` are used repeatedly for prediction.

    # Returns
    A tuple containing:
    * `torch.Tensor`: The predicted token IDs for each sample in the batch.
    * `torch.Tensor`: The sequence of tokens generated so far, which is updated with the predicted token.
    * `torch.Tensor`: The encoder embedding, which may be reused.

    # Raises
    * `ValueError`: If neither `enc_embed` nor both `X` and `y` are provided.
    * `IndexError`: If `sequence` is already full and cannot hold another token.
    """

    if enc_embed is None:
        if X is None or y is None:
            raise ValueError("Either `enc_embed` or both `X` and `y` must be provided.")

        # compute encoder embedding
        enc_input = torch.cat((X, y.unsqueeze(-1)), dim=-1).to(model.device)
        enc_embed = model.enc(enc_input)

    batch_size = enc_embed.size(0)

    if sequence is None:
        # initialize sequence with start token
        sequence = torch.zeros((batch_size, model.cfg.length_eq), dtype=torch.long, device=model.device)
        sequence[:, 0] = params.word2id["S"]

    # number of non-padding tokens of the longest sample in the batch
    cur_len = (sequence != 0).sum(dim=1).max().item()

    # fail before running the decoder instead of on the final write
    if cur_len >= sequence.size(1):
        raise IndexError(
            f"Sequence is full ({cur_len} of {sequence.size(1)} tokens used); cannot append another token."
        )

    current_tokens = sequence[:, :cur_len]

    # generate decoder masks
    mask1, mask2 = model.make_trg_mask(current_tokens)

    # compute positional embeddings (position ids are created directly on the sequence's device)
    positions = (
        torch.arange(0, cur_len, dtype=sequence.dtype, device=sequence.device)
            .unsqueeze(0)
            .repeat(sequence.shape[0], 1)
    )
    pos = model.pos_embedding(positions)

    # embed tokens
    seq_embed = model.tok_embedding(current_tokens)
    seq_embed = model.dropout(seq_embed + pos)

    # run decoder (inputs are permuted to sequence-first layout)
    output = model.decoder_transfomer(
        seq_embed.permute(1, 0, 2),
        enc_embed.permute(1, 0, 2),
        mask2.float(),
        tgt_key_padding_mask=mask1.bool(),
    )
    output = model.fc_out(output)
    output = output.permute(1, 0, 2).contiguous()

    # add next token
    # NOTE: softmax not really necessary here, but may come in handy later
    token_probs = F.softmax(output[:, -1:, :], dim=-1).squeeze(1)
    next_token = torch.argmax(token_probs, dim=-1)
    sequence[:, cur_len] = next_token

    return next_token, sequence, enc_embed

def tokens_to_text(tokens: torch.Tensor, params: FitParams) -> list[str]:
    """
    Turns a batch of token-ID sequences into space-separated token strings.

    Decoding of a sequence stops at the first `"F"` token or at the first token with ID 0.
    The first decoded token (the start token) is not included in the text.

    # Args
    * `tokens`: Of shape [B, S], where B is the batch size and S is the maximum sequence length.
    * `params`: Provides the `word2id` and `id2word` lookup tables.

    # Returns
    One string per sequence in the batch.
    """

    def row_to_text(token_ids):
        words = []
        for token_id in token_ids:
            reached_end = token_id == params.word2id["F"] or token_id == 0
            if reached_end:
                break
            words.append(params.id2word[token_id])
        # drop the leading start token
        return " ".join(words[1:])

    return [row_to_text(row.tolist()) for row in tokens]


In [None]:
# Upper bound on the number of tokens generated after the initial prediction.
MAX_NEW_TOKENS = 30
# `tokens_to_text` stops decoding at this token, so nothing generated after it is ever shown.
finish_id = params_fit.word2id["F"]

# initial token prediction, this initializes the sequence and caches the encoder embedding (saves computation time).
tok, seq, enc_embed = greedy_predict(model, params_fit, X.unsqueeze(0), y.unsqueeze(0))
finished = tok == finish_id

# repeatedly predict next token greedily, stopping early once every sample has produced the finish token
for step in range(MAX_NEW_TOKENS):
    if finished.all():
        break
    tok, seq = greedy_predict(model, params_fit, enc_embed=enc_embed, sequence=seq)[:2]
    finished = finished | (tok == finish_id)

# this should result in (roughly) the correct equation
tokens_to_text(seq, params_fit)

['exp cos x_1']

In [None]:
# fit model with beam search instead of greedy + constant fitting (takes a lot longer)
# `fitfunc` is `model.fitfunc` with `cfg_params=params_fit` already bound via `functools.partial`.
output = fitfunc(X, y) 

  X = torch.tensor(X,device=self.device).unsqueeze(0)
  y = torch.tensor(y,device=self.device).unsqueeze(0)


Memory footprint of the encoder: 4.096e-05GB 

Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...
Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...


In [None]:
# here you can see the fitted equations
# In the recorded run the result is a dict with the keys `all_bfgs_preds`, `all_bfgs_loss`,
# `best_bfgs_preds` and `best_bfgs_loss`.
output

{'all_bfgs_preds': ['(exp(cos(x_1)))', 'exp(1.00000003956864*cos(x_1))'],
 'all_bfgs_loss': [0.0, 0.0],
 'best_bfgs_preds': ['(exp(cos(x_1)))'],
 'best_bfgs_loss': [0.0]}