### Importing libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

### Defining the environment
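The particle starts at $(0, 0)$. At each timestep it moves one unit to the right and either up, straight, or down, i.e. from $(x, y)$ to $(x+1, y+1)$, $(x+1, y)$, or $(x+1, y-1)$. The episode terminates after 5 timesteps. The reward is the absolute $y$ value of the final position, floored at $0.1$ so that endings at $y = 0$ still have a strictly positive reward: $R = \max(|y|, 0.1)$. A trajectory $\tau = (s_0 \to s_1 \to \ldots \to s_5)$ is generated step by step by a learned forward policy. The goal is to sample the terminal state $x = s_5$ with probability proportional to its reward, $p(x) \propto R(x)$. Training with the trajectory-balance loss achieves this: the loss drives $Z \, P_F(\tau) = R(x) \, P_B(\tau \mid x)$ for every complete trajectory.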

In [63]:
class GridEnvironment:
    """A particle that walks rightwards on an integer grid for a fixed number of steps.

    Every step advances ``x`` by one and shifts ``y`` by ``action - 1``, so the
    actions 0, 1 and 2 mean down-right, right and up-right respectively.
    """

    def __init__(self, max_steps):
        # Number of steps after which ``step`` reports the episode as done.
        self.max_steps = max_steps
        self.reset()

    def reset(self):
        """Put the particle back at the origin, clear the step counter, and return the state."""
        self.step_count = 0
        self.current_state = (0, 0)
        return self.current_state

    def step(self, action):
        """Apply one move and return ``(new_state, reward, done)``.

        The reward is 0 until the step counter reaches ``max_steps``. From then on
        it is ``max(|y|, 0.1)``; the 0.1 floor keeps the reward strictly positive.
        Stepping past ``max_steps`` is not prevented: ``done`` simply stays True.
        """
        x_old, y_old = self.current_state
        new_state = (x_old + 1, y_old + (action - 1))
        self.current_state = new_state
        self.step_count += 1

        done = self.step_count >= self.max_steps
        if done:
            reward = max(abs(new_state[1]), 0.1)
        else:
            reward = 0
        return new_state, reward, done

In [80]:
class GFlowNet(nn.Module):
    """MLP with a shared trunk and two 3-way heads for the forward and backward policies.

    The input is a 2-feature state ``(x, y)``. Both heads return positive
    (softplus) scores that callers normalise into probabilities. ``log_Z`` is a
    learned scalar that the trajectory-balance loss uses as the log partition
    function.
    """

    def __init__(self, hidden_size):
        super().__init__()
        # Trunk: input layer followed by two hidden layers, each with a ReLU.
        # The modules are created in the same order as a hand-written Sequential,
        # so both parameter initialisation and state_dict keys are unchanged.
        trunk_layers = [nn.Linear(2, hidden_size), nn.ReLU()]
        for _ in range(2):
            trunk_layers.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])
        self.network = nn.Sequential(*trunk_layers)
        self.fwp = nn.Linear(hidden_size, 3)  # forward-policy head
        self.bwp = nn.Linear(hidden_size, 3)  # backward-policy head
        self.log_Z = nn.Parameter(torch.zeros(1))

    def forward(self, state):
        """Return ``(forward_scores, backward_scores)``, each strictly positive via softplus."""
        features = self.network(state)
        forward_scores = F.softplus(self.fwp(features))
        backward_scores = F.softplus(self.bwp(features))
        return forward_scores, backward_scores

def generate_trajectory(env, model, max_steps):
    """Roll out one trajectory by sampling actions from the model's forward policy.

    Args:
        env: environment with ``reset() -> state`` and
            ``step(action) -> (state, reward, done)``.
        model: callable mapping a float32 state tensor to
            ``(forward_scores, backward_scores)``; only the forward scores are
            used here. They are assumed positive, e.g. softplus outputs.
        max_steps: upper bound on the number of steps. The rollout also stops
            as soon as the environment reports ``done``.

    Returns:
        trajectory: visited states, starting with the state returned by ``reset``.
        forward_probs: scalar tensors holding the forward probability of each
            sampled action. They stay attached to the autograd graph so the loss
            can backpropagate through them.
        actions: integer action indices, one per step.
        reward: reward from the last ``env.step`` call, or 0 if no step was taken.
    """
    state = env.reset()
    trajectory = [state]
    forward_probs = []
    actions = []
    # Bound up front so a rollout with max_steps <= 0 returns instead of raising
    # UnboundLocalError.
    reward = 0

    for _ in range(max_steps):
        flow_values = model(torch.tensor(state, dtype=torch.float32))[0]
        # Normalise the positive forward scores into a categorical distribution
        # over the three moves.
        probs = flow_values / flow_values.sum()
        action = torch.multinomial(probs, 1).item()
        next_state, reward, done = env.step(action)

        forward_probs.append(probs[action])
        actions.append(action)
        trajectory.append(next_state)
        state = next_state

        if done:
            break

    return trajectory, forward_probs, actions, reward

def calculate_backward_probs(model, trajectory, actions):
    """Return the backward-policy probability P_B(s_{t-1} | s_t) for each transition.

    Backward action ``a`` at state ``(x, y)`` points to the parent
    ``(x - 1, y + 1 - a)``, mirroring the forward encoding ``dy = a - 1``. The
    model's backward scores for ``s_t`` are masked so that only parents
    reachable from the origin keep non-zero mass, then normalised.

    Args:
        model: callable mapping a float32 state tensor to
            ``(forward_scores, backward_scores)``, each of length 3. The scores
            are assumed positive, e.g. softplus outputs.
        trajectory: list of integer ``(x, y)`` states starting at ``(0, 0)`` and
            advancing ``x`` by one per step.
        actions: unused; the backward action is recovered from consecutive
            states. Kept for call-site compatibility.

    Returns:
        List of scalar tensors, one per transition (empty for a one-state
        trajectory). They stay attached to the autograd graph.
    """
    backward_probs = []
    for prev_state, curr_state in zip(trajectory, trajectory[1:]):
        x, y = curr_state
        backward_flow_values = model(torch.tensor(curr_state, dtype=torch.float32))[1]

        # Starting from (0, 0) with |dy| <= 1, a parent (x - 1, y') is reachable
        # iff |y'| <= x - 1. Build the mask out of place instead of writing zeros
        # into the model output. The in-place version only works while the final
        # activation's backward does not need its own output.
        valid = torch.tensor([float(abs(y + 1 - a) <= x - 1) for a in range(3)])
        masked = backward_flow_values * valid
        prev_probs = masked / masked.sum()

        # dy in {-1, 0, 1} maps to backward action {0, 1, 2}.
        backward_action = (y - prev_state[1]) + 1
        backward_probs.append(prev_probs[backward_action])

    return backward_probs

def trajectory_balance_loss(model, trajectory, forward_probs, backward_probs, reward):
    """Squared trajectory-balance residual for one complete trajectory.

    Computes ``(log Z + sum log P_F - log R - sum log P_B) ** 2``.

    Args:
        model: object exposing a ``log_Z`` tensor (the learned log partition
            function).
        trajectory: unused; kept for call-site compatibility.
        forward_probs: scalar tensors, the forward probability of each action
            taken. May be empty.
        backward_probs: scalar tensors, the backward probability of each
            transition. May be empty.
        reward: terminal reward R. 1e-10 is added before the log so that a zero
            reward gives a large finite penalty instead of -inf.

    Returns:
        Tensor with the same shape as ``model.log_Z``.
    """
    def _sum_log(probs):
        # An empty product contributes log(1) = 0; torch.stack raises on an
        # empty list.
        if not probs:
            return torch.tensor(0.0)
        return torch.log(torch.stack(probs)).sum()

    forward_log_prob = _sum_log(forward_probs)
    backward_log_prob = _sum_log(backward_probs)
    log_reward = torch.log(torch.tensor(reward + 1e-10))

    log_ratio = model.log_Z + forward_log_prob - log_reward - backward_log_prob
    return log_ratio ** 2

def train_gflownet(num_episodes, max_steps, hidden_size, learning_rate):
    """Train a ``GFlowNet`` on ``GridEnvironment`` with the trajectory-balance loss.

    Each episode samples one on-policy trajectory, computes its forward and
    backward probabilities, and takes one Adam step. Adam uses weight decay
    1e-5, and a cosine-annealing schedule spans all episodes. Progress is
    printed every 1000 episodes.

    Returns:
        The trained model.
    """
    env = GridEnvironment(max_steps)
    model = GFlowNet(hidden_size)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_episodes)
    log_every = 1000

    # Count episodes from 1 so the logging condition and message need no offset.
    for episode_no in range(1, num_episodes + 1):
        traj, fwd_probs, taken_actions, final_reward = generate_trajectory(env, model, max_steps)
        bwd_probs = calculate_backward_probs(model, traj, taken_actions)
        loss = trajectory_balance_loss(model, traj, fwd_probs, bwd_probs, final_reward)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if episode_no % log_every == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Episode {episode_no}, Loss: {loss.item():.4f}, Reward: {final_reward}, LR: {current_lr:.6f}")

    return model

### Training the model

In [81]:
from itertools import product

# Training parameters
num_episodes = 20000  # Increased number of episodes
max_steps = 5
hidden_size = 32
learning_rate = 0.001

trained_model = train_gflownet(num_episodes, max_steps, hidden_size, learning_rate)

# Generate sample trajectories using the trained model. Evaluation needs no
# gradients, so disable autograd instead of building 1000 unused graphs.
env = GridEnvironment(max_steps)
num_samples = 1000  # Increased number of samples
with torch.no_grad():
    sample_trajectories = [generate_trajectory(env, trained_model, max_steps)[0] for _ in range(num_samples)]

# Analyze the distribution of end states
end_states = [trajectory[-1][1] for trajectory in sample_trajectories]
unique, counts = np.unique(end_states, return_counts=True)
for y, count in zip(unique, counts):
    print(f"End state y={y}: {count/num_samples:.2%}")

# Target: with trajectory balance, a terminal state should be sampled with
# probability R(x) / sum_x' R(x'). Collect R for every reachable terminal y by
# rolling out all 3 ** max_steps action sequences through the environment
# itself, so the reward definition lives in one place. This is cheap for small
# max_steps only.
terminal_rewards = {}
for action_seq in product(range(3), repeat=max_steps):
    env.reset()
    for a in action_seq:
        (_, y_end), r, _ = env.step(a)
    terminal_rewards[y_end] = r
reward_total = sum(terminal_rewards.values())
for y in sorted(terminal_rewards):
    print(f"Target    y={y}: {terminal_rewards[y] / reward_total:.2%}")

Episode 1000, Loss: 2.5397, Reward: 4, LR: 0.000994
Episode 2000, Loss: 0.1439, Reward: 3, LR: 0.000976
Episode 3000, Loss: 0.1668, Reward: 5, LR: 0.000946
Episode 4000, Loss: 0.3631, Reward: 3, LR: 0.000905
Episode 5000, Loss: 0.4042, Reward: 5, LR: 0.000854
Episode 6000, Loss: 0.0863, Reward: 2, LR: 0.000794
Episode 7000, Loss: 0.0001, Reward: 5, LR: 0.000727
Episode 8000, Loss: 0.0002, Reward: 4, LR: 0.000655
Episode 9000, Loss: 0.0041, Reward: 5, LR: 0.000578
Episode 10000, Loss: 0.0071, Reward: 3, LR: 0.000500
Episode 11000, Loss: 0.0939, Reward: 2, LR: 0.000422
Episode 12000, Loss: 0.0059, Reward: 5, LR: 0.000345
Episode 13000, Loss: 0.0034, Reward: 5, LR: 0.000273
Episode 14000, Loss: 0.0206, Reward: 2, LR: 0.000206
Episode 15000, Loss: 0.0000, Reward: 4, LR: 0.000146
Episode 16000, Loss: 0.0000, Reward: 4, LR: 0.000095
Episode 17000, Loss: 0.0054, Reward: 4, LR: 0.000054
Episode 18000, Loss: 0.0183, Reward: 2, LR: 0.000024
Episode 19000, Loss: 0.0002, Reward: 5, LR: 0.000006
Ep