## GFlowNet As a BSR Sampler: Demo & Tutorial

In this brief tutorial, we demonstrate the potential of GFlowNet (GFN for short), a new class of generative models proposed by [Bengio et al.](https://yoshuabengio.org/2022/03/05/generative-flow-networks/), for Bayesian symbolic regression problems. Some prior knowledge of GFN is recommended, especially its specification (the concepts of "flows" and "transitions") from the original paper or the official tutorial, but we walk through the setup of the model from scratch.

The problem that the GFN in this tutorial solves is a simplified version of Symbolic Regression (SR): we are given a known (or confidently assumed) expression tree structure, and we search only over the specific features/operators that go into each tree node. Compared with previous methods such as BMS, this approach removes the need to search over the structure **S**. We make this simplification because (1) other routines might be used to determine **S** (such as enumerating tree structures), and (2) this is an MVP model, and we hope it inspires more sophisticated extensions in the near future.


### Step 1: Expression Tree Representation

The problem begins with a given expression tree structure **S**. A tree structure is an uninstantiated binary tree: it only tells us which nodes are leaves, which are unary, and which are binary, without specifying the operator or feature on each node. For example, we might have the structure:

```python
from binarytree import Node

demo_tree = Node("-", Node("-", Node("-")), Node("-"))
print(demo_tree)
```

```
    -
   / \
  -   -
 /
-
```



Expressions such as $\sin(x) + y$ or $\frac{\log(y)}{x}$ both fit this structure:

```python
demo_tree2 = Node("+", Node("sin", Node("x")), Node("y"))
demo_tree3 = Node("/", Node("log", Node("y")), Node("x"))
print(demo_tree2, demo_tree3)
```

```
     _+
    /  \
  sin   y
 /
x

     _/
    /  \
  log   x
 /
y
```



To work better with neural networks, however, we use an alternative, vector-based representation of a tree structure. For a tree structure **S** of depth **D**, the vector representation lives in $\mathbb{R}^n$, where $n = 2^D - 1$ is the number of nodes in a full binary tree of depth **D**. We construct the vector by scanning the positions of the full binary tree level by level (left to right within each level), appending a placeholder (such as 1) if **S** has a node at that position and 0 otherwise. For example, the tree structure above translates to

```python
[1, 1, 1, 1, 0, 0, 0]
```

The first two levels are full (three 1s), while the third level only contains the leftmost leaf node (a 1 followed by three 0s). The vector has dimension $7 = 2^3 - 1$, the number of nodes in a full binary tree of depth 3.
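To make the level-order encoding concrete, here is a minimal sketch that converts a tree structure into its 0/1 vector without relying on `binarytree`. The `Shape` class and the `depth` and `structure_to_vector` functions are hypothetical helpers for illustration, not part of the tutorial's code. It uses the standard level-order indexing in which the node at 0-based position $i$ has children at positions $2i + 1$ and $2i + 2$.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Shape:
    """Hypothetical uninstantiated tree node: only the shape matters."""
    left: Optional["Shape"] = None
    right: Optional["Shape"] = None


def depth(node: Optional[Shape]) -> int:
    """Number of levels in the tree (0 for an empty tree)."""
    if node is None:
        return 0
    return 1 + max(depth(node.left), depth(node.right))


def structure_to_vector(root: Shape) -> List[int]:
    """Level-order 0/1 encoding over the full binary tree of the same depth."""
    n = 2 ** depth(root) - 1
    vec = [0] * n
    stack = [(root, 0)]
    while stack:
        node, i = stack.pop()
        if node is None:
            continue
        vec[i] = 1  # a node occupies position i
        stack.append((node.left, 2 * i + 1))
        stack.append((node.right, 2 * i + 2))
    return vec


# The structure from the example: root, two children, one left grandchild
demo = Shape(Shape(Shape()), Shape())
print(structure_to_vector(demo))  # [1, 1, 1, 1, 0, 0, 0]
```

A depth-first stack is enough here because each position's index is computed from its parent's index, so the visiting order does not affect the result.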

A valid expression tree **T** with structure **S**, on the other hand, is fully instantiated: every node carries a feature or an operator. Such a tree can also be encoded as a vector, but the entries of that vector now depend on the **operator space**. Consider this simple setup:

```python
import torch

NUM_FEATURES = 2  # e.g. the dataset has two covariates x0 and x1
BIN_OPS = [
    torch.add,
    torch.mul,
    torch.div,
    torch.sub
]  # 4 binary operators

UN_OPS = [
    torch.sin,
    torch.cos,
    torch.exp,
    torch.abs
]  # 4 unary operators
```

In the vector representation, 0 denotes an empty node (e.g. the right child of a unary node), 1-2 denote the features (for leaf nodes), 3-6 the 4 binary operators, and 7-10 the 4 unary operators, each in the order listed above. The expression $\sin(x) + y$ is thus encoded as $[3, 7, 2, 1, 0, 0, 0]$ if our features are $[x, y]$. A helper function performs this conversion:

```python
from tree_utils import encoding_to_tree

ecd = torch.Tensor([3, 7, 2, 1, 0, 0, 0])
print(encoding_to_tree(ecd))
```

```
      _+
     /  \
  _sin   X1
 /
X0
```
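The integer codes can also be decoded without any tree library. The sketch below is a hypothetical stand-in for `encoding_to_tree` that returns a plain infix string instead of a `binarytree` object. It assumes the code layout described above (0 empty, 1-2 features, 3-6 binary operators in the order of `BIN_OPS`, 7-10 unary operators in the order of `UN_OPS`) and that a unary node's operand is its left child, as in the `sin` example.

```python
from typing import Sequence

NUM_FEATURES = 2
BIN_NAMES = ["+", "*", "/", "-"]          # codes 3-6, same order as BIN_OPS
UN_NAMES = ["sin", "cos", "exp", "abs"]   # codes 7-10, same order as UN_OPS


def decode(vec: Sequence[int], i: int = 0) -> str:
    """Recursively turn a level-order code vector into an infix string."""
    code = int(vec[i])
    if 1 <= code <= NUM_FEATURES:
        return f"x{code - 1}"  # feature k is stored as code k + 1
    left, right = 2 * i + 1, 2 * i + 2
    first_bin = NUM_FEATURES + 1
    first_un = first_bin + len(BIN_NAMES)
    if first_bin <= code < first_un:
        op = BIN_NAMES[code - first_bin]
        return f"({decode(vec, left)} {op} {decode(vec, right)})"
    if first_un <= code < first_un + len(UN_NAMES):
        return f"{UN_NAMES[code - first_un]}({decode(vec, left)})"
    raise ValueError(f"position {i} is empty or has an unknown code {code}")


print(decode([3, 7, 2, 1, 0, 0, 0]))  # (sin(x0) + x1)
```

Decoding an incomplete state such as `[3, 7, 2, 0, 0, 0, 0]` raises a `ValueError`, since `sin` has no operand yet.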



### Step 2: Action and State Space

With these representations in place, we can formally define the GFN problem we want to solve. Between an uninstantiated (empty) structure and a valid expression tree lie many intermediate **states**. An example state looks like

```python
ecd2 = torch.Tensor([3, 7, 2, 0, 0, 0, 0])
print(encoding_to_tree(ecd2))
```

```
   _+
  /  \
sin   X1
```



which is not yet a valid expression tree because `sin` has no operand. Once we define **actions** on **states**, however, we can turn it into the expression above.

In this simple setup, given a state **X**, we compare its vector representation with that of the tree structure **S**; the comparison gives us the next uninstantiated node and its type (leaf, unary, or binary). Based on this information, the action space consists of the features or operators that can instantiate that node. In the example above, the leftmost leaf node is the next (and final) uninstantiated node, so the action space is simply $A = \{X_0, X_1\}$. This formulation of the action space looks simple but is restrictive: because of the way the vector representation is constructed, the method essentially builds a valid expression tree from an empty tree, left to right and level by level.
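The comparison between a state and the structure can be sketched in a few lines of plain Python. `node_type`, `next_node`, and `allowed_actions` are hypothetical helpers (the tutorial's `FixedTree` environment may implement this differently), and the code layout assumes 2 features, 4 binary, and 4 unary operators as in the setup above. The node type is read off **S**: a node with no children is a leaf, a node with only a left child is unary, and a node with two children is binary.

```python
from typing import List, Optional, Sequence, Tuple

NUM_FEATURES, NUM_BIN, NUM_UN = 2, 4, 4


def node_type(structure: Sequence[int], i: int) -> str:
    """Classify position i of the structure by which children it has."""
    n = len(structure)
    has_left = 2 * i + 1 < n and structure[2 * i + 1] == 1
    has_right = 2 * i + 2 < n and structure[2 * i + 2] == 1
    if has_left and has_right:
        return "binary"
    return "unary" if has_left else "leaf"


def next_node(structure: Sequence[int],
              state: Sequence[int]) -> Optional[Tuple[int, str]]:
    """First position (in level order) present in S but not yet filled."""
    for i, (s, x) in enumerate(zip(structure, state)):
        if s == 1 and x == 0:
            return i, node_type(structure, i)
    return None  # the state is already a complete expression tree


def allowed_actions(kind: str) -> List[int]:
    """Codes that may be placed at a node of the given type."""
    first_bin = NUM_FEATURES + 1
    first_un = first_bin + NUM_BIN
    if kind == "leaf":
        return list(range(1, first_bin))
    if kind == "binary":
        return list(range(first_bin, first_un))
    return list(range(first_un, first_un + NUM_UN))


S = [1, 1, 1, 1, 0, 0, 0]
X = [3, 7, 2, 0, 0, 0, 0]  # the intermediate state shown above
idx, kind = next_node(S, X)
print(idx, kind, allowed_actions(kind))  # 3 leaf [1, 2]
```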

A forward policy $\pi$ in GFN is a stochastic mapping from the state space to the action space, usually represented as a neural network. In our case, we use a 4-layer fully connected NN (with 32 hidden units, see the `FTForwardPolicy` class) with `LeakyReLU` activations in between and a `sigmoid` output layer. The policy is stochastic because it outputs a probability for each action, and we choose the action by sampling from the categorical distribution defined by these probabilities. While the policy is stochastic, the transitions in GFN are deterministic: taking action $a$ in state $s$ always leads to the same next state $s'$.

GFN also requires a `backward policy` network, which estimates the sources of incoming flows for a given state (the `forward policy` can be thought of as estimating the destinations of outgoing flows). In our setup, each state has exactly one source state (you can pause and think about why; maybe revisit the action specification above?), so the backward policy is trivial.
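One way to see why the backward policy is trivial: nodes are filled in a fixed level order, so a state can only have been reached by filling its last non-empty position last. The hypothetical `parent` function below recovers that unique parent by clearing the last filled entry, which means every backward transition has probability 1.

```python
from typing import List, Optional, Sequence


def parent(state: Sequence[int]) -> Optional[List[int]]:
    """Unique parent of a state: clear the most recently filled entry."""
    filled = [i for i, code in enumerate(state) if code != 0]
    if not filled:
        return None  # the empty tree is the initial state and has no parent
    prev = list(state)
    prev[filled[-1]] = 0  # level order == fill order, so the last non-zero was filled last
    return prev


# Walking backwards from a complete tree visits each ancestor exactly once
s = [3, 7, 2, 1, 0, 0, 0]
while s is not None:
    print(s)
    s = parent(s)
```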

<u>Side note for anyone interested in the code</u>

Since our policy is represented by a NN, its output dimension (the number of actions) is fixed. How, then, do we account for the different actions available in different states (e.g. if the next node to instantiate is a leaf node, only features, i.e. actions 1 and 2, should be allowed)? The solution is a `mask` function that sets the probabilities of all "unavailable" actions in a given state to zero. We then rescale the remaining probabilities so that they sum to one, which guarantees that every sampled action is valid.
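A minimal plain-Python sketch of this masking step follows. `mask_and_normalize` and `masked_sample` are hypothetical helpers, not the tutorial's `mask` function: they take unnormalized scores (such as sigmoid outputs), zero out invalid actions, renormalize, and sample from the resulting categorical distribution.

```python
import random
from typing import List, Sequence, Tuple


def mask_and_normalize(scores: Sequence[float],
                       valid: Sequence[int]) -> List[float]:
    """Zero out actions not in `valid`, then rescale the rest to sum to one."""
    valid_set = set(valid)
    masked = [p if a in valid_set else 0.0 for a, p in enumerate(scores)]
    total = sum(masked)
    if total <= 0:
        raise ValueError("all valid actions have zero probability")
    return [p / total for p in masked]


def masked_sample(scores: Sequence[float], valid: Sequence[int],
                  rng: random.Random) -> Tuple[int, List[float]]:
    """Sample one valid action from the renormalized distribution."""
    probs = mask_and_normalize(scores, valid)
    action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return action, probs


# Hypothetical policy scores for codes 0-10; the next node is a leaf,
# so only the feature codes 1 and 2 are valid
scores = [0.9, 0.2, 0.6, 0.7, 0.1, 0.4, 0.3, 0.8, 0.5, 0.5, 0.2]
action, probs = masked_sample(scores, valid=[1, 2], rng=random.Random(0))
print(action, [round(p, 2) for p in probs[:3]])
```

After masking, the leaf node's two feature codes get probabilities 0.25 and 0.75, and every other code has probability 0, so the sampled action is always 1 or 2.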

### Step 3: Rewards and Training

In the context of symbolic regression, we reward a valid expression tree according to how well the expression fits the data we have. There are many ways to do this; the naive one we use is

$$
\max\left(1,\ 100 \left(1 - \frac{\mathrm{MSE}(\hat{y}, y)}{\mathrm{TSS}(y)}\right)\right)
$$

where $\mathrm{TSS}(y) = \lVert y - \bar{y} \rVert^2$ is the total sum of squares of the original data (the same quantity as in linear regression), which reflects the performance of a baseline predictor that always predicts the data mean. This reward maps the ratio between our predictor's MSE and this baseline onto a scale from 1 to 100, with smaller MSE giving higher rewards.
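The reward formula translates directly into code. This sketch implements the formula exactly as written above (MSE divided by the total sum of squares); the function name `reward` is hypothetical, and the tutorial's environment may compute it differently, e.g. on tensors.

```python
from typing import Sequence


def reward(y_hat: Sequence[float], y: Sequence[float]) -> float:
    """max(1, 100 * (1 - MSE(y_hat, y) / TSS(y))), as in the formula above.

    Assumes y is not constant, so that TSS(y) > 0.
    """
    n = len(y)
    mse = sum((a - b) ** 2 for a, b in zip(y_hat, y)) / n
    y_bar = sum(y) / n
    tss = sum((b - y_bar) ** 2 for b in y)
    return max(1.0, 100.0 * (1.0 - mse / tss))


y = [0.0, 1.0, 2.0, 3.0]
print(reward(y, y))                            # 100.0 (perfect fit)
print(reward([10.0, -10.0, 10.0, -10.0], y))   # 1.0 (clipped at the floor)
```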

Once we have the rewards, we can formulate the **trajectory balance loss** (TBL) from the GFN literature. We omit the details here; the gist is that we minimize this loss by optimizing our forward policy NN (through backpropagation with the `Adam` optimizer) so that the incoming and outgoing flows of the states match.
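For reference, the trajectory balance loss of a single complete trajectory $\tau = (s_0 \to s_1 \to \dots \to s_n = x)$ is commonly written as $\left(\log Z + \sum_t \log P_F(s_{t+1} \mid s_t) - \log R(x) - \sum_t \log P_B(s_t \mid s_{t+1})\right)^2$, where $Z$ is a learned estimate of the total flow. The plain-Python sketch below computes it from per-step probabilities; `tb_loss` is a hypothetical name and not necessarily how `gflownet.utils.trajectory_balance_loss` is implemented. With a trivial backward policy every $P_B$ equals 1, so the last sum vanishes.

```python
import math
from typing import Sequence


def tb_loss(log_z: float, fwd_probs: Sequence[float], reward: float,
            back_probs: Sequence[float]) -> float:
    """Squared trajectory balance residual for one trajectory."""
    log_pf = sum(math.log(p) for p in fwd_probs)
    log_pb = sum(math.log(p) for p in back_probs)
    return (log_z + log_pf - math.log(reward) - log_pb) ** 2


# A 4-step trajectory (e.g. building [3, 7, 2, 1, 0, 0, 0]) with trivial P_B
fwd = [0.5, 0.25, 0.5, 0.5]
back = [1.0] * len(fwd)

# The loss vanishes when Z * prod(P_F) = R * prod(P_B): here Z = 50 / 0.03125
print(round(tb_loss(math.log(1600.0), fwd, 50.0, back), 12))  # 0.0
# Any other Z leaves a positive residual that training pushes toward zero
print(tb_loss(math.log(50.0), fwd, 50.0, back))
```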

There are two ways to train a GFN, on-policy and off-policy, borrowing terminology from Reinforcement Learning (RL). On-policy training repeatedly uses trajectories (paths from an empty tree to a valid expression tree that receives a reward) generated by the GFN being trained, whereas off-policy training can generate these paths with some other routine, independently of the GFN being updated. We only implement on-policy training, for simplicity.

### Step 4: Training and results

The following function trains our GFN for the ground-truth expression $y = (x_1 + x_0) \cdot x_0$ with a given `batch_size` and `num_epochs`.

```python
import torch
from torch.optim import Adam
from tqdm import tqdm

from fixed_tree import FixedTree, FTForwardPolicy, FTBackwardPolicy
from gflownet.gflownet import GFlowNet
from gflownet.utils import trajectory_balance_loss


def train_fixed_tree(batch_size, num_epochs):
    # 20 data points with two covariates x0, x1 drawn uniformly from [-1, 1]
    X = torch.empty(20, 2).uniform_(-1, 1)
    # ground truth: y = (x1 + x0) * x0
    y = (X[:, 1] + X[:, 0]) * X[:, 0]
    # structure S in level order: binary root, binary left child, leaf right
    # child, and two leaves under the left child
    temp = torch.tensor([1, 1, 1, 1, 1, 0, 0])
    env = FixedTree(temp, X, y)
    forward_policy = FTForwardPolicy(env.state_dim, 32, env.num_actions)
    backward_policy = FTBackwardPolicy(env.state_dim, num_actions=env.num_actions)
    model = GFlowNet(forward_policy, backward_policy, env)
    opt = Adam(model.parameters(), lr=5e-3)

    for i in (p := tqdm(range(num_epochs))):
        # every trajectory starts from the empty tree
        s0 = torch.zeros(batch_size, env.state_dim)
        s, log = model.sample_states(s0, return_log=True)
        # each state has exactly one parent, so all backward probabilities are 1
        log.back_probs.fill_(1.0)
        loss = trajectory_balance_loss(log.total_flow,
                                       log.rewards,
                                       log.fwd_probs,
                                       log.back_probs)
        loss.backward()
        opt.step()
        opt.zero_grad()
        if i % 10 == 0:
            p.set_description(f"{loss.item():.3f}")

    return model, env
```

We train with `batch_size = 64` and `num_epochs = 20000`, then draw 20 samples for testing.

```python
model, env = train_fixed_tree(64, 20000)
```

```
6.585:  73%|███████▎  | 14563/20000 [00:53<00:19, 275.65it/s]
```

```python
s0 = torch.zeros(20, env.state_dim)
s = model.sample_states(s0, return_log=False)
```

Training should be fairly fast (around a minute), and sampling is nearly instantaneous. This is why GFN is sometimes described as amortized MCMC: it trades a longer training phase for a very fast sampling process. You should also see the loss drop substantially during training.

```python
s
```

```
tensor([[4., 3., 1., 1., 2., 0., 0.],
        [4., 3., 1., 1., 2., 0., 0.],
        [3., 4., 2., 1., 1., 0., 0.],
        [3., 4., 1., 1., 1., 0., 0.],
        [3., 6., 1., 1., 1., 0., 0.],
        [3., 5., 1., 1., 1., 0., 0.],
        [3., 5., 1., 2., 2., 0., 0.],
        [5., 5., 2., 2., 2., 0., 0.],
        [3., 3., 1., 2., 1., 0., 0.],
        [4., 3., 1., 1., 2., 0., 0.],
        [3., 6., 1., 2., 2., 0., 0.],
        [4., 3., 1., 1., 2., 0., 0.],
        [3., 6., 1., 2., 2., 0., 0.],
        [3., 4., 2., 1., 1., 0., 0.],
        [4., 4., 1., 1., 1., 0., 0.],
        [3., 6., 1., 2., 2., 0., 0.],
        [4., 5., 1., 1., 1., 0., 0.],
        [5., 6., 2., 2., 2., 0., 0.],
        [6., 4., 1., 1., 2., 0., 0.],
        [4., 3., 1., 1., 2., 0., 0.]])
```

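Out of the 20 samples, 5 are [4, 3, 1, 1, 2, 0, 0], which encodes $(x_0 + x_1) \cdot x_0$. This is exactly the ground-truth solution up to the commutativity of addition; the ground-truth encoding [4, 3, 1, 2, 1, 0, 0] corresponds to the tree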

```python
print(encoding_to_tree(torch.Tensor([4, 3, 1, 2, 1, 0, 0])))
```

```
     ___*
    /    \
  _+      X0
 /  \
X1   X0
```
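To double-check that [4, 3, 1, 1, 2, 0, 0] and [4, 3, 1, 2, 1, 0, 0] describe the same function, we can evaluate both encodings on random inputs. `evaluate` is a hypothetical evaluator written with the standard library instead of `torch`, assuming the code layout 0 empty, 1-2 features, 3-6 for `+`, `*`, `/`, `-`, and 7-10 for `sin`, `cos`, `exp`, `abs`.

```python
import math
import operator
import random
from typing import Sequence

NUM_FEATURES = 2
BIN_FUNCS = [operator.add, operator.mul, operator.truediv, operator.sub]  # codes 3-6
UN_FUNCS = [math.sin, math.cos, math.exp, abs]                            # codes 7-10


def evaluate(vec: Sequence[int], x: Sequence[float], i: int = 0) -> float:
    """Evaluate the level-order encoded tree at a single input point x."""
    code = int(vec[i])
    if 1 <= code <= NUM_FEATURES:
        return x[code - 1]
    first_bin = NUM_FEATURES + 1
    first_un = first_bin + len(BIN_FUNCS)
    if first_bin <= code < first_un:
        return BIN_FUNCS[code - first_bin](evaluate(vec, x, 2 * i + 1),
                                           evaluate(vec, x, 2 * i + 2))
    if first_un <= code < first_un + len(UN_FUNCS):
        return UN_FUNCS[code - first_un](evaluate(vec, x, 2 * i + 1))
    raise ValueError(f"position {i} is empty or has an unknown code {code}")


rng = random.Random(0)
points = [(rng.uniform(-1, 1), rng.uniform(-1, 1)) for _ in range(5)]
sampled = [4, 3, 1, 1, 2, 0, 0]  # (x0 + x1) * x0
truth = [4, 3, 1, 2, 1, 0, 0]    # (x1 + x0) * x0
print(all(math.isclose(evaluate(sampled, p), evaluate(truth, p)) for p in points))  # True
```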



With around 50,000 epochs of training (3-5 minutes), most samples are the correct answer. To sum up, this simple demo already shows the potential of GFN for SR tasks.


Training is not always this smooth, however: in another run with `batch_size = 32`, the loss was already NaN at the first iteration, after which the forward policy produced NaN action probabilities and sampling failed:

```python
model2, env2 = train_fixed_tree(32, 20000)
```

```
nan:   0%|          | 1/20000 [00:00<04:36, 72.39it/s]

ValueError: Expected parameter probs (Tensor of shape (32, 12)) of distribution Categorical(probs: torch.Size([32, 12])) to satisfy the constraint Simplex(), but found invalid values:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        ...
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       grad_fn=<DivBackward0>)
```

### Step 5: Potential Next Steps

There are many ways that this MVP GFN model can be improved. To name a few priority items:

- encoder: our vector-based encoding of an expression tree is very sparse, especially as the tree grows deeper. Alternative encoders include Tree-RNN, VAE, or a stripped-down version of the current vector representation.
- action: our current actions assign features/operators to a given template (structure) from left to right, level by level. There may be other ways to formulate the action space that better match a natural construction process.
- reward function: there are other possible reward functions (e.g. likelihood-based or utility-theory-based ones) besides the current MSE-based one.
- structure proposal: train a structure-proposal GFN simultaneously?
- ...: let's come together and think of more.