# Generative Flow Network Demo
From https://colab.research.google.com/drive/1fUMwgu2OhYpQagpzU5mhe9_Esib3Q2VR


In [None]:
import os
import sys

import matplotlib.pyplot as matplotlib_pyplot
import numpy as numpy
import torch
import tqdm
from torch.distributions.categorical import Categorical

# Add the project root directory to Python path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.append(project_root)

import gflownet as gflownet

In [None]:
def plot_faces(faces):
    """Draw every face in ``faces`` next to each other in a one-row figure.

    A new figure with one subplot column per face is created and each face
    is drawn with ``gflownet.Face.draw_face`` on its own axes.
    """
    _figure, axes = matplotlib_pyplot.subplots(1, len(faces))
    has_several = len(faces) > 1
    for column, face in enumerate(faces):
        # With a single column, subplots() hands back one Axes object rather
        # than an array, so there is nothing to index in that case.
        if has_several:
            matplotlib_pyplot.sca(axes[column])
        gflownet.Face.draw_face(face)


# Two faces made of one mouth and one eyebrow on each side.
smiling_face = gflownet.Face(['smile', 'left_eb_down', 'right_eb_down'])
frowning_face = gflownet.Face(['frown', 'left_eb_up', 'right_eb_up'])

# Faces made of only some of the parts: eyebrows only, frown only, smile only.
plot_faces([
    gflownet.Face(['left_eb_up', 'right_eb_up']),
    gflownet.Face(['frown']),
    gflownet.Face(['smile']),
])

# Faces that combine both variants of the same part (up and down for one
# eyebrow, or smile together with frown).
plot_faces([
    gflownet.Face(['left_eb_up', 'left_eb_down']),
    gflownet.Face(['right_eb_up', 'right_eb_down']),
    gflownet.Face(['left_eb_up', 'left_eb_down', 'right_eb_up', 'right_eb_down']),
    gflownet.Face(['smile', 'frown']),
])

# The two faces defined at the top of this cell, frowning one first.
plot_faces([frowning_face, smiling_face])

![Differing representations of invalid faces, frowning and smiling face](./images/faces_1.png)
<br>
![Differing representations of invalid faces, frowning and smiling face](./images/faces_2.png)
<br>
![Differing representations of invalid faces, frowning and smiling face](./images/faces_3.png)

In [None]:
# Ask the Face class for the states and transitions derived from the full list
# of face-part keys.
enumerated_states, transitions = gflownet.Face.enumerate_states_transitions(gflownet.Face.sorted_keys)

# Keep only the first face encountered for each distinct *set* of patches, so
# faces that differ only in the order of their patches are collapsed.
# A set of frozensets gives an O(1) membership test instead of rebuilding and
# scanning a list of sets for every enumerated face.
unique_states = []
seen_patch_sets = set()
for face in enumerated_states:
    patch_set = frozenset(face.patches)
    if patch_set not in seen_patch_sets:
        seen_patch_sets.add(patch_set)
        unique_states.append(face)

In [None]:
# Plot the de-duplicated states together with the enumerated transitions.
# No model is passed here (compare the later call that passes the trained
# flow-matching model).
gflownet.Network.plot(unique_states, transitions)

![All possible faces - state space](./images/state_space.png)

# Flow Matching
https://arxiv.org/abs/2106.04399

In [None]:
def face_parents(state):
    """Enumerate the parents of ``state`` and the actions leading from them.

    ``state`` is a list of face-part names. For every part in it there is one
    parent: ``state`` with every entry equal to that part filtered out. The
    matching action is the index of that part in ``gflownet.Face.sorted_keys``.

    Returns a ``(parent_states, parent_actions)`` pair of equally long lists,
    ordered like the parts in ``state``.
    """
    # One parent per face part: the state without that part.
    parent_states = [
        [part for part in state if part != removed]
        for removed in state
    ]
    # The action that goes from each parent back to `state` adds the removed part.
    parent_actions = [
        gflownet.Face.sorted_keys.index(removed)
        for removed in state
    ]
    return parent_states, parent_actions

In [None]:
# Instantiate model and optimizer
# NOTE(review): 512 is presumably the hidden-layer width of FlowModel -- confirm
# against the gflownet package.
flow_matching_model = gflownet.FlowModel(512)
opt = torch.optim.Adam(flow_matching_model.parameters(), 3e-4)

# Let's keep track of the losses and the faces we sample
losses = []
sampled_faces = []
# To not complicate the code, the per-transition losses are accumulated here
# and one gradient step is taken on their sum every `update_freq` episodes.
minibatch_loss = 0
update_freq = 4
for episode in tqdm.tqdm(range(50000), ncols=40):
    # Each episode starts with an "empty state"
    state = []
    # Predict F(s, a)
    edge_flow_prediction = flow_matching_model(gflownet.Face(state).face_to_tensor())
    for t in range(3):
        # The policy is just normalizing, and gives us the probability of each action
        policy = edge_flow_prediction / edge_flow_prediction.sum()
        # Sample the action
        action = Categorical(probs=policy).sample()
        # "Go" to the next state
        new_state = state + [gflownet.Face.sorted_keys[action]]

        # Now we want to compute the loss, we'll first enumerate the parents
        parent_states, parent_actions = face_parents(new_state)
        # And compute the edge flows F(s, a) of each parent: run the model on
        # every parent at once, then pick the flow of the action that leads
        # from that parent to `new_state`.
        px = torch.stack([gflownet.Face(p).face_to_tensor() for p in parent_states])
        pa = torch.tensor(parent_actions).long()
        parent_edge_flow_preds = flow_matching_model(px)[torch.arange(len(parent_states)), pa]
        # Now we need to compute the reward and F(s, a) of the current state,
        # which is currently `new_state`
        if t == 2:
            # If we've built a complete face, we're done, so the reward is > 0
            # (unless the face is invalid)
            reward = gflownet.Face(new_state).face_reward()
            # and since there are no children to this state F(s,a) = 0 \forall a
            edge_flow_prediction = torch.zeros(6)
        else:
            # Otherwise we keep going, and compute F(s, a)
            reward = 0
            edge_flow_prediction = flow_matching_model(gflownet.Face(new_state).face_to_tensor())

        # Flow-matching loss for `new_state`:
        # (flow in from the parents - flow out to the children - reward)^2
        flow_mismatch = (parent_edge_flow_preds.sum() - edge_flow_prediction.sum() - reward).pow(2)
        minibatch_loss += flow_mismatch  # Accumulate
        # Continue iterating
        state = new_state

    # We're done with the episode, add the face to the list, and if we are at an
    # update episode, take a gradient step.
    sampled_faces.append(gflownet.Face(state))
    # `episode` is zero-based, so test `episode + 1`: every step then covers
    # exactly `update_freq` episodes. Testing `episode` itself would step after
    # the very first episode alone and never apply the final
    # `update_freq - 1` episodes.
    if (episode + 1) % update_freq == 0:
        losses.append(minibatch_loss.item())
        minibatch_loss.backward()
        opt.step()
        opt.zero_grad()
        minibatch_loss = 0

In [None]:
# Plot the recorded flow-matching losses (one value per gradient step) in a
# wide figure, using a logarithmic y-axis.
matplotlib_pyplot.figure(figsize=(10, 3))
matplotlib_pyplot.plot(losses)
matplotlib_pyplot.yscale('log')

![Flow Matching Loss over Training](./images/flow_matching_loss.png)

In [None]:
f, ax = matplotlib_pyplot.subplots(8, 8, figsize=(4, 4))
# Summary statistics over the 128 most recently sampled faces.
recent_faces = sampled_faces[-128:]
print('Ratio of faces with a smile:', sum('smile' in face.patches for face in recent_faces) / 128)
print('Ratio of valid faces:', sum(face.face_reward() > 0 for face in recent_faces) / 128)
# Draw the last 64 samples into the 8x8 grid, filling it row by row.
for i, face in enumerate(sampled_faces[-64:]):
    # divmod(i, 8) is the (row, column) pair (i // 8, i % 8).
    matplotlib_pyplot.sca(ax[divmod(i, 8)])
    face.draw_face()

![Sampled Faces for Flow Matching](./images/flow_matching_faces.png)

In [None]:
# Plot the state graph again, this time also passing the trained flow-matching
# model to the plotting routine.
gflownet.Network.plot(unique_states, transitions, flow_matching_model)

![All Edge Flows from Starting Face](./images/flows.png)

In [None]:
# Sum of the predicted edge flows F(s0, a) out of the empty face, i.e. the
# total flow the model sends out of the initial state.
flow_matching_model(gflownet.Face([]).face_to_tensor()).sum()

# Trajectory Balancing

https://arxiv.org/abs/2201.13259

In [None]:
# Instantiate model and optimizer
# NOTE(review): 512 is presumably the hidden-layer width of TBModel -- confirm
# against the gflownet package.
tb_model = gflownet.TBModel(512)
opt = torch.optim.Adam(tb_model.parameters(), 3e-4)

# Let's keep track of the losses and the faces we sample
tb_losses = []
tb_sampled_faces = []
# To not complicate the code, the per-trajectory losses are accumulated here
# and one gradient step is taken on their sum every `update_freq` episodes.
minibatch_loss = 0
update_freq = 2

# Value of the model's logZ after each gradient step.
logZs = []
for episode in tqdm.tqdm(range(50000), ncols=40):
    # Each episode starts with an "empty state"
    state = []
    # Predict P_F, P_B
    prob_forward, prob_backward = tb_model(gflownet.Face(state).face_to_tensor())
    # Running sums of log P_F and log P_B along the trajectory.
    total_prob_forward = 0
    total_prob_backward = 0
    for t in range(3):
        # Here P_F is logits, so we want the Categorical to compute the softmax for us
        cat = Categorical(logits=prob_forward)
        action = cat.sample()
        # "Go" to the next state
        new_state = state + [gflownet.Face.sorted_keys[action]]
        # Accumulate the P_F sum
        total_prob_forward += cat.log_prob(action)

        if t == 2:
            # If we've built a complete face, we're done, so the reward is > 0
            # (unless the face is invalid)
            reward = torch.tensor(gflownet.Face(new_state).face_reward()).float()
        # We recompute P_F and P_B for new_state
        prob_forward, prob_backward = tb_model(gflownet.Face(new_state).face_to_tensor())
        # Here we accumulate P_B, going backwards from `new_state`. We're also just
        # going to use opposite semantics for the backward policy. I.e., for P_F action
        # `i` just added the face part `i`, for P_B we'll assume action `i` removes
        # face part `i`, this way we can just keep the same indices.
        total_prob_backward += Categorical(logits=prob_backward).log_prob(action)

        # Continue iterating
        state = new_state

    # We're done with the trajectory, let's compute its loss:
    # (logZ + sum log P_F - log R - sum log P_B)^2. Since the reward can
    # sometimes be zero, instead of log(0) we'll clip the log-reward to -20.
    loss = (tb_model.logZ + total_prob_forward - torch.log(reward).clip(-20) - total_prob_backward).pow(2)
    minibatch_loss += loss

    # Add the face to the list, and if we are at an
    # update episode, take a gradient step.
    tb_sampled_faces.append(state)
    # `episode` is zero-based, so test `episode + 1`: every step then covers
    # exactly `update_freq` episodes. Testing `episode` itself would step after
    # the very first episode alone and never apply the final
    # `update_freq - 1` episodes.
    if (episode + 1) % update_freq == 0:
        tb_losses.append(minibatch_loss.item())
        minibatch_loss.backward()
        opt.step()
        opt.zero_grad()
        minibatch_loss = 0
        logZs.append(tb_model.logZ.item())

In [None]:
# Two stacked panels: the trajectory-balance loss on top, the estimated Z below.
f, ax = matplotlib_pyplot.subplots(2, 1, figsize=(10, 6))
# Top panel: recorded minibatch losses on a logarithmic y-axis.
matplotlib_pyplot.sca(ax[0])
matplotlib_pyplot.plot(tb_losses)
matplotlib_pyplot.yscale('log')
matplotlib_pyplot.ylabel('loss')
# Bottom panel: the recorded logZ values, exponentiated to show Z itself.
matplotlib_pyplot.sca(ax[1])
matplotlib_pyplot.plot(numpy.exp(logZs))
matplotlib_pyplot.ylabel('estimated Z');

![Trajectory Balancing Loss and Z Estimation](./images/trajectory_balancing_loss.png)

In [None]:
f, ax = matplotlib_pyplot.subplots(8, 8, figsize=(4, 4))
# Summary statistics over the 128 most recent samples. Each sample here is a
# plain list of face-part names, so it is wrapped in a Face where needed.
recent_samples = tb_sampled_faces[-128:]
print('Ratio of faces with a smile:', sum('smile' in parts for parts in recent_samples) / 128)
print('Ratio of valid faces:', sum(gflownet.Face(parts).face_reward() > 0 for parts in recent_samples) / 128)
# Draw the last 64 samples into the 8x8 grid, filling it row by row.
for i, face in enumerate(tb_sampled_faces[-64:]):
    # divmod(i, 8) is the (row, column) pair (i // 8, i % 8).
    matplotlib_pyplot.sca(ax[divmod(i, 8)])
    gflownet.Face(face).draw_face()

![Sampled Faces for Trajectory Balancing](./images/trajectory_balancing_faces.png)

In [None]:
# Exponentiate the model's logZ to display its estimate of Z.
tb_model.logZ.exp()