In [118]:
from nesymres.architectures.model import Model
from nesymres.utils import load_metadata_hdf5
from nesymres.dclasses import FitParams, NNEquation, BFGSParams
from pathlib import Path
from functools import partial
import torch
import torch.nn.functional as F
import numpy as np
from sympy import lambdify
import sympy
import json
import omegaconf
from typing import Literal, Callable, Tuple, Union, Dict
import os
from dataclasses import dataclass

# Seed the global RNGs used in this notebook (torch: point sampling, numpy: random equation generation)
# so that "Restart & Run All" reproduces the same datasets and predictions.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

device = "cpu" # NOTE: change to cuda if your GPU can handle it

In [119]:
RES_DIR = "../res/"

# Equation-space settings of the pretrained model (token vocabulary, unary/binary operators, variable names, ...).
eq_setting = json.loads(Path(RES_DIR, "100m_eq_cfg.json").read_text())

# Architecture, dataset and inference hyper-parameters of the pretrained model.
cfg = omegaconf.OmegaConf.load(os.path.join(RES_DIR, "100m_cfg.yaml"))

In [120]:
# BFGS constant-fitting options, copied one-to-one from the inference section of the config.
bfgs = BFGSParams(**{
    option: getattr(cfg.inference.bfgs, option)
    for option in (
        "activated",
        "n_restarts",
        "add_coefficients_if_not_existing",
        "normalization_o",
        "idx_remove",
        "normalization_type",
        "stop_time",
    )
})

# Decoding parameters: the model's vocabulary and equation space, plus the search settings.
params_fit = FitParams(
    word2id=eq_setting["word2id"],
    # JSON object keys are always strings, but token IDs are looked up as ints
    id2word={int(token_id): word for token_id, word in eq_setting["id2word"].items()},
    una_ops=eq_setting["una_ops"],
    bin_ops=eq_setting["bin_ops"],
    total_variables=list(eq_setting["total_variables"]),
    total_coefficients=list(eq_setting["total_coefficients"]),
    rewrite_functions=list(eq_setting["rewrite_functions"]),
    bfgs=bfgs,
    beam_size=cfg.inference.beam_size,  # trade-off between accuracy and fitting time
)

# Load the pretrained checkpoint onto `device` and switch it to inference mode (disables dropout).
model = Model.load_from_checkpoint(os.path.join(RES_DIR, "100m.ckpt"), cfg=cfg.architecture).to(device).eval()

# Fitting entry point with the decoding parameters bound.
fitfunc = partial(model.fitfunc, cfg_params=params_fit)

/home/morris/miniconda3/envs/symreg/lib/python3.9/site-packages/pytorch_lightning/utilities/migration/migration.py:208: You have multiple `ModelCheckpoint` callback states in this checkpoint, but we found state keys that would end up colliding with each other after an upgrade, which means we can't differentiate which of your checkpoint callbacks needs which states. At least one of your `ModelCheckpoint` callbacks will not be able to reload the state.
Lightning automatically upgraded your loaded checkpoint from v1.3.3 to v2.5.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint ../res/100m.ckpt`


In [121]:
@dataclass
class HookPoint:
    """
    Location of a decoder component to hook into (see `register_decoder_hook`).

    # Raises
    * `ValueError`: On construction, if `component` is neither `"mlp"` nor a `("self" | "cross", head_idx)` pair.
    """

    # Index of the decoder layer.
    layer: int
    # Which part of the decoder layer to hook into. Can either be the MLP, a self-attention or cross-attention head.
    # Heads are given as a pair `(kind, head_idx)`, e.g. `("self", 0)` or `("cross", 3)`.
    component: Union[Literal["mlp"], Tuple[Literal["self", "cross"], int]]

    def __post_init__(self):
        # Reject malformed components early; otherwise it only surfaces later, when the hook is registered.
        if self.component == "mlp":
            return
        is_head = (
            isinstance(self.component, (tuple, list))
            and len(self.component) == 2
            and self.component[0] in ("self", "cross")
        )
        if not is_head:
            raise ValueError(f"Unknown hook component: {self.component}")

def register_decoder_hook(model: Model, hook_fn: Callable, hook: HookPoint) -> torch.utils.hooks.RemovableHandle:
    """
    Hooks a function into the decoder part of the model. This allows for reading or manipulating the output of a specific component.

    NOTE: To remove the hook, call `.remove()` on the returned `RemovableHandle`.

    # Args
    * `model`: The model to hook into.
    * `hook_fn`: Callable that takes the output of the hooked component as a `torch.Tensor` and the hooked location as a `HookPoint`.
        The function should return an updated output. For the MLP it receives the full (seq-first) output tensor;
        returning `None` there leaves the output unchanged. For an attention head it receives the
        [seq_len x batch_size x head_dim] slice of that head and must return a tensor of the same shape.
    * `hook`: Description of the component to hook into.

    # Returns
    The `RemovableHandle` of the registered forward hook.

    # Raises
    * `ValueError`: If `hook.component` is neither `"mlp"` nor a `("self" | "cross", head_idx)` pair.
    """

    component = hook.component
    is_head = isinstance(component, (tuple, list)) and len(component) == 2
    decoder_layer = model.decoder_transfomer.layers[hook.layer]

    if component == "mlp":
        # hook into 2nd linear layer of MLP
        target = decoder_layer.linear2
    elif is_head and component[0] == "self":
        # hook into self-attention layer
        target = decoder_layer.self_attn
    elif is_head and component[0] == "cross":
        # hook into cross-attention layer
        target = decoder_layer.multihead_attn
    else:
        raise ValueError(f"Unknown hook component: {component}")

    def hook_wrapper(_module, _input, output):
        if component == "mlp":
            # `linear2` returns a single seq-first tensor, so the whole output is replaced.
            # (Indexing `output[0]` here would only touch the first sequence position.)
            return hook_fn(output, hook)

        # Attention modules return a tuple whose first element is the attention output.
        head_idx = component[1]

        # multihead and self-attention layer have same number of heads
        num_head = decoder_layer.multihead_attn.num_heads

        # view data in terms of [seq_len x batch_size x num_head x head_dim]
        # NOTE(review): `nn.MultiheadAttention` applies its output projection (`out_proj`) before returning,
        #  so this slice is a `head_dim`-wide chunk of the projected output rather than the isolated output
        #  of head `head_idx` - confirm this is the intended intervention.
        seq_len, bsz, _ = output[0].size()
        output_heads = output[0].view(seq_len, bsz, num_head, -1)

        # hook output of specified head (written in place into the attention output)
        output_heads[:, :, head_idx, :] = hook_fn(output_heads[:, :, head_idx, :], hook)
        return output

    return target.register_forward_hook(hook_wrapper)



# Intervention demo: replace the MLP output of every decoder layer with random noise.
# Set the flag to True (and re-run the cells below) to try it; the model then fails to fit the correct equation.
RANDOMIZE_DECODER_MLPS = False


def randomize_output_hook(output, _hook):
    """Hook function (see `register_decoder_hook`) that replaces the hooked output with standard-normal noise of the same shape."""
    return torch.randn_like(output)


# Keep the handles so the hooks can be removed again with `handle.remove()`.
mlp_noise_handles = []
if RANDOMIZE_DECODER_MLPS:
    mlp_noise_handles = [
        register_decoder_hook(model, randomize_output_hook, HookPoint(layer, "mlp"))
        for layer in range(len(model.decoder_transfomer.layers))
    ]

In [122]:
def sample_equation(
    eq: str, vars_used: list[str], vars_model: list[str],
    point_count: int,
    min_support: float, max_support: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Uniformly samples an equation using `sympy.lambdify`.

    # Args
    * `eq`: The equation as a string that sympy can parse, e.g. `"sin(abs(x_1))+x_2"`.
    * `vars_used`: The names of the variables in the equation.
    * `vars_model`: The total list of names of the variables the model expects. If the equation has less variables,
        the input will be padded with zeros.
    * `point_count`: The number of points to sample.
    * `min_support` & `max_support`: The range of the domain to sample the equation from.

    # Returns
    A tuple containing:
    * `X`: A tensor of shape [`point_count` x len(vars_model)] containing the sampled independent variables.
        Columns of variables not in `vars_used` are zero.
    * `y`: A float tensor of shape [`point_count`] containing the evaluated equation values. Constant equations
        (e.g. `x_1-x_1`, which sympy simplifies to `0`) are broadcast to this shape as well.
    """

    # of shape [N x D]
    X = torch.zeros(point_count, len(vars_model))
    X_dict = {}

    # set used variables to random values
    for idx, var in enumerate(vars_model):
        if var in vars_used:
            # sample uniformly and scale to the supported range
            X[:, idx] = torch.rand(point_count) * (max_support - min_support) + min_support

        X_dict[var] = X[:, idx]

    # evaluate equation
    y = sympy.lambdify(",".join(vars_model), eq)(**X_dict)

    # lambdify returns a bare Python/NumPy scalar for expressions without free variables;
    # normalize so `y` is always a tensor of shape [point_count] with the same dtype as `X`.
    y = torch.broadcast_to(torch.as_tensor(y, dtype=X.dtype), (point_count,)).clone()

    return X, y
    
def sample_equation_from_config(
    eq: str, vars_used: list[str], point_count: int,
    model_cfg: omegaconf.DictConfig,
    eq_cfg: Dict,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convenience wrapper around `sample_equation`. The variable names the model expects are taken from
    `eq_cfg["total_variables"]`, and the sampling range from the training-set support
    `model_cfg.dataset_train.fun_support` (`min` / `max`).
    """

    support = model_cfg.dataset_train.fun_support
    return sample_equation(
        eq, vars_used, eq_cfg["total_variables"], point_count,
        support["min"], support["max"],
    )

# Alternative single-variable example:
# X, y = sample_equation_from_config("exp(cos(x_1))", ["x_1"], 500, cfg, eq_setting)

# Example target: 500 points of a two-variable equation, sampled over the training support.
X, y = sample_equation_from_config(
    eq="sin(abs(x_1))+cos(abs(sin(x_2)))",
    vars_used=["x_1", "x_2"],
    point_count=500,
    model_cfg=cfg,
    eq_cfg=eq_setting,
)

In [None]:
def generate_function(vars: list[str], identity_prob=0.5, nest_prob=0.3, max_depth=3) -> Tuple[str, list[str]]:
    """
    Randomly builds a (possibly nested) unary function of a single variable, drawing every operator uniformly.
    Examples are a naked variable, `sin(x_1)` or `cos(exp(x_2))`; the result never contains addition,
    subtraction, multiplication or division.

    # Args
    * `vars`: The variable names to draw the function's single variable from.
    * `identity_prob`: Probability of returning the naked variable (identity function).
    * `nest_prob`: Otherwise, the probability of wrapping one more operator around the current chain (per level).
    * `max_depth`: Maximum nesting depth (at least one operator is always applied in the non-identity case).

    # Returns
    A tuple containing:
    * `equation`: The generated equation string.
    * `vars_used`: A single-element list with the variable the function uses.
    """

    # NOTE: operators with incomplete domains (log, tan, etc.) are left out on purpose
    unary_ops = ["abs", "cos", "exp", "sin"]

    variable = np.random.choice(vars)

    # identity function: just the variable
    if np.random.sample() < identity_prob:
        return variable, [variable]

    # outermost operator first; keep nesting while the coin flip succeeds and the depth allows it
    chain = [np.random.choice(unary_ops)]
    while np.random.sample() < nest_prob and len(chain) < max_depth:
        chain.append(np.random.choice(unary_ops))

    opening = "".join(op + "(" for op in chain)
    return opening + variable + ")" * len(chain), [variable]


def generate_dataset_pairs(
    strategy: Literal["sign-bias", "complexity-bias"], point_count: int, num_eq: int,
    model_cfg: omegaconf.DictConfig, eq_cfg: Dict,
    second_dataset_sample_rate: Union[int, None] = None,
) -> Tuple[torch.Tensor, torch.Tensor, list[Tuple[str, str]]]:
    """
    Generates dataset pairs based on a given strategy. Depending on the strategy, random equations pairs
    are generated, which are then sampled to form dataset pairs.

    # Args
    * `strategy`:
        1. If equal to `"sign-bias"`, equations are of the following form: `function_1(variable_1) ± function_2(variable_2)`.
        The first equation in the pair will apply the `+` operator and the second the `-` operator.
        2. If equal to `"complexity-bias"`, equations are sampled randomly. Each dataset in a pair samples the same
        equation, but the second dataset with a different sample rate determined by `second_dataset_sample_rate`.
        NOTE: not implemented yet, see the TODO below.
    * `point_count`: The number of points per dataset.
    * `num_eq`: The number of equation pairs to generate.
    * `model_cfg` & `eq_cfg`: Model and equation configuration, forwarded to `sample_equation_from_config`.
    * `second_dataset_sample_rate`: Required for `"complexity-bias"`, ignored otherwise.

    # Returns
    A tuple containing:
    * `X`: A tensor of sampled equation input variables of size [2 x Ne x Np x D], where 2 denotes each pair,
        Ne the number of equations, Np the number of points, and D the dimensionality.
    * `y`: A tensor of shape [2 x Ne x Np] containing the evaluated equation values.
    * `equations`: A list of Ne `(first, second)` equation tuples.

    # Raises
    * `ValueError`: If `strategy` is unknown, or `second_dataset_sample_rate` is missing for `"complexity-bias"`.
    * `NotImplementedError`: For `"complexity-bias"`, whose equation generation is still missing.
    """
    if strategy not in ("sign-bias", "complexity-bias"):
        raise ValueError(f"Unknown strategy: {strategy}")

    if strategy == "complexity-bias":
        if second_dataset_sample_rate is None:
            raise ValueError(f"second_dataset_sample_rate not set, but should be for strategy: {strategy}")
        # TODO: implement. Planned approach: generate one equation per pair, sample it with `point_count`
        #  points for the first dataset and keep every `second_dataset_sample_rate`-th point for the second one.
        #  Note that the subsampled dataset then has fewer than `point_count` points, so it cannot be stored in
        #  `X[1, i]` / `y[1, i]` as-is; it needs padding or a different return type.
        raise NotImplementedError(
            "The 'complexity-bias' strategy is not implemented yet (equation generation is missing)."
        )

    X = torch.empty((2, num_eq, point_count, len(eq_cfg["total_variables"])))
    y = torch.empty((2, num_eq, point_count))
    equations = []

    for i in range(num_eq):
        # generate equation
        eq_part0, vars0 = generate_function(eq_cfg["total_variables"])
        eq_part1, vars1 = generate_function(eq_cfg["total_variables"])
        vars_used = vars0 + vars1

        # NOTE: if both parts are identical, the "-" equation is the constant 0 (e.g. `x_2-x_2`).
        eq_plus = eq_part0 + "+" + eq_part1
        eq_min = eq_part0 + "-" + eq_part1

        # sample equation
        Xe0, ye0 = sample_equation_from_config(eq_plus, vars_used, point_count, model_cfg, eq_cfg)
        Xe1, ye1 = sample_equation_from_config(eq_min, vars_used, point_count, model_cfg, eq_cfg)

        # store data
        X[0, i], y[0, i] = Xe0, ye0
        X[1, i], y[1, i] = Xe1, ye1
        equations.append((eq_plus, eq_min))

    return X, y, equations


# Preview: the equation pairs of a small sign-bias dataset (10 equation pairs, 10 points per dataset).
generate_dataset_pairs(
    strategy="sign-bias",
    point_count=10,
    num_eq=10,
    model_cfg=cfg,
    eq_cfg=eq_setting,
)[2]

[('x_2+x_2', 'x_2-x_2'),
 ('x_2+cos(x_1)', 'x_2-cos(x_1)'),
 ('abs(x_2)+x_1', 'abs(x_2)-x_1'),
 ('x_3+cos(exp(sin(x_3)))', 'x_3-cos(exp(sin(x_3)))'),
 ('sin(sin(exp(x_2)))+x_1', 'sin(sin(exp(x_2)))-x_1'),
 ('sin(x_3)+abs(x_1)', 'sin(x_3)-abs(x_1)'),
 ('x_3+exp(x_3)', 'x_3-exp(x_3)'),
 ('x_3+x_3', 'x_3-x_3'),
 ('x_2+x_1', 'x_2-x_1'),
 ('x_1+x_1', 'x_1-x_1')]

In [124]:
def greedy_predict(
        model: Model, params: FitParams,
        X: torch.Tensor = None, y: torch.Tensor = None, sequence: torch.Tensor = None, enc_embed: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Greedily predicts the next token in the sequence. Can be used in two ways:
    1. If `X` and `y` are provided, the model will use them to compute the encoder embedding and then predict the next token.
    2. If `enc_embed` is provided, the model will use this embedding directly to predict the next token.

    Runs without gradient tracking. Token ID 0 is treated as padding: the current length is the number of
    non-zero tokens, and the predicted token is written right after them.

    # Args
    * `model`: The model to predict with.
    * `params`: Fit parameters; `params.word2id["S"]` is used as the start token.
    * `X`: Function domain of shape [B, N, D], where B is the batch size, N is the number of samples,
        and D the input dimensionality (no more than 3).
    * `y`: Function image corresponding to `X` of shape [B, N].
    * `sequence`: The initial tokens to predict the next token from of shape [B, S], where S is the maximum sequence length.
        All samples in the batch are expected to be at the same current sequence length.
        If `None`, a new sequence of length `model.cfg.length_eq` starting with the start token is created.
        NOTE: a given `sequence` tensor is updated in place.
    * `enc_embed`: Encoder embedding. Can be reused if the same `X` and `y` are used repeatedly for prediction.

    # Returns
    A tuple containing:
    * `next_token`: The predicted token IDs for each sample in the batch.
    * `sequence`: The sequence of tokens generated so far, which is updated with the predicted token.
    * `enc_embed`: The encoder embedding, which may be reused.

    # Raises
    * `ValueError`: If neither `enc_embed` nor both `X` and `y` are given, or if `sequence` has no free position left.
    """

    if enc_embed is None and (X is None or y is None):
        raise ValueError("Either `enc_embed` or both `X` and `y` must be provided.")

    # Pure inference: skip autograd bookkeeping, which would otherwise keep the activations of every
    # decoding step (and of the cached `enc_embed`) alive.
    with torch.no_grad():
        if enc_embed is None:
            # compute encoder embedding from the concatenated [X | y] points
            enc_input = torch.cat((X, y.unsqueeze(-1)), dim=-1).to(model.device)
            enc_embed = model.enc(enc_input)

        batch_size = enc_embed.size(0)

        if sequence is None:
            # initialize sequence with start token, followed by padding (0)
            sequence = torch.zeros((batch_size, model.cfg.length_eq), dtype=torch.long, device=model.device)
            sequence[:, 0] = params.word2id["S"]

        cur_len = (sequence != 0).sum(dim=1).max().item()
        if cur_len >= sequence.size(1):
            raise ValueError(
                f"`sequence` is already full ({cur_len} of {sequence.size(1)} positions used); "
                "cannot predict another token."
            )

        # generate decoder masks
        mask1, mask2 = model.make_trg_mask(sequence[:, :cur_len])

        # compute positional embeddings
        positions = torch.arange(0, cur_len, device=sequence.device).unsqueeze(0).repeat(sequence.shape[0], 1)
        pos = model.pos_embedding(positions.type_as(sequence))

        # embed tokens
        seq_embed = model.tok_embedding(sequence[:, :cur_len])
        seq_embed = model.dropout(seq_embed + pos)

        # run decoder (it expects seq-first inputs, hence the permutes)
        output = model.decoder_transfomer(
            seq_embed.permute(1, 0, 2),
            enc_embed.permute(1, 0, 2),
            mask2.float(),
            tgt_key_padding_mask=mask1.bool(),
        )
        output = model.fc_out(output)
        output = output.permute(1, 0, 2).contiguous()

        # add next token (greedy: most likely token at the last position)
        # NOTE: softmax not really necessary here, but may come in handy later
        token_probs = F.softmax(output[:, -1:, :], dim=-1).squeeze(1)
        next_token = torch.argmax(token_probs, dim=-1)
        sequence[:, cur_len] = next_token

    return next_token, sequence, enc_embed

def tokens_to_text(tokens: torch.Tensor, params: FitParams) -> list[str]:
    """
    Converts a batch of token IDs to their space-separated text representations.

    Each row is decoded up to (not including) the first end token `"F"` or padding ID 0,
    and the leading start token is dropped.

    # Args
    * `tokens`: Of shape [B, S], where B is the batch size and S is the maximum sequence length.
    * `params`: Provides the `word2id` / `id2word` vocabulary mappings.

    # Returns
    One decoded string per batch row.
    """

    sentences = []
    for row in tokens.tolist():
        words = []
        for token_id in row:
            if token_id == params.word2id["F"] or token_id == 0:
                break
            words.append(params.id2word[token_id])
        # skip start token
        sentences.append(" ".join(words[1:]))
    return sentences


In [125]:
# Initial token prediction: this initializes the sequence and caches the encoder embedding (saves computation time).
tok, seq, enc_embed = greedy_predict(model, params_fit, X.unsqueeze(0), y.unsqueeze(0))

# Repeatedly predict the next token greedily. Stop early once every sequence contains the end token "F":
# `tokens_to_text` ignores everything after it, so further steps would only waste compute.
MAX_GREEDY_STEPS = 30
end_token_id = params_fit.word2id["F"]
for _ in range(MAX_GREEDY_STEPS):
    if (seq == end_token_id).any(dim=1).all():
        break
    seq = greedy_predict(model, params_fit, enc_embed=enc_embed, sequence=seq)[1]

# ideally this is (roughly) the equation sampled above
tokens_to_text(seq, params_fit)



['add cos cos x_1 cos cos x_1']

In [126]:
# Fit the model with beam search plus BFGS constant fitting instead of greedy decoding (takes a lot longer).
# NOTE: the warning output below points at `fitfunc`'s own `torch.tensor(X, ...)` / `torch.tensor(y, ...)`
#  calls, which re-wrap the tensors passed in here; presumably harmless, but verify if results look off.
output = fitfunc(X, y) 

  X = torch.tensor(X,device=self.device).unsqueeze(0)
  y = torch.tensor(y,device=self.device).unsqueeze(0)


Memory footprint of the encoder: 4.096e-05GB 

Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...
Constructing BFGS loss...
Flag idx remove ON, Removing indeces with high values...
checking input values range...
Loss constructed, starting new BFGS optmization...


In [127]:
# Here you can see the fitted equations: all candidates after BFGS constant fitting with their losses
# (`all_bfgs_preds` / `all_bfgs_loss`) and the best one (`best_bfgs_preds` / `best_bfgs_loss`).
output

{'all_bfgs_preds': ['((cos(cos(x_1)))+(cos(cos(x_1))))',
  '((cos((cos(x_1))**(2)))+(cos(x_2)))'],
 'all_bfgs_loss': [0.85989344, 1.0511613],
 'best_bfgs_preds': ['((cos(cos(x_1)))+(cos(cos(x_1))))'],
 'best_bfgs_loss': [0.85989344]}