In [None]:
import networkx as nx
import matplotlib.pyplot as plt

# Model parameters for the binomial tree
T = 3  # number of periods in the tree
S0 = 4  # stock price at time 0
u, d = 2, 0.5  # multiplicative up / down factors applied each period


# Function to compute children nodes given a parent node (time, S, m)
def next_nodes(node):
    t, S, m = node
    next_time = t + 1
    # Up branch: new stock price and updated maximum
    S_up = S * u
    m_up = max(m, S_up)
    # Down branch: new stock price and updated maximum
    S_down = S * d
    m_down = max(m, S_down)
    return [(next_time, S_up, m_up), (next_time, S_down, m_down)]


# Build the tree: nodes stored as (time, S, m) and list edges
nodes = {}  # dictionary to store node labels
edges = []  # list to store parent-child relations

# Start at time 0
root = (0, S0, S0)
nodes[root] = f"({S0:.1f},{S0:.1f})"
current_level = [root]

# Generate tree for T periods
while current_level:
    next_level = []
    for node in current_level:
        t, S, m = node
        if t < T:
            for child in next_nodes(node):
                edges.append((node, child))
                if child not in nodes:
                    # Label nodes as (S, m)
                    nodes[child] = f"({child[1]:.1f},{child[2]:.1f})"
                next_level.append(child)
    current_level = next_level

# Create a directed graph with NetworkX
G = nx.DiGraph()
for node, label in nodes.items():
    # For plotting, use time as x-coordinate and negative S as y-coordinate (to visualize tree downward)
    G.add_node(node, label=label, pos=(node[0], -node[1]))
G.add_edges_from(edges)

# Extract positions and labels from node attributes
pos = nx.get_node_attributes(G, "pos")
labels = nx.get_node_attributes(G, "label")

# Plot the binomial tree
plt.figure(figsize=(8, 6))
nx.draw(
    G,
    pos,
    with_labels=True,
    labels=labels,
    node_size=800,
    node_color="lightblue",
    arrows=True,
)
plt.title("Three-Period Binomial Tree")
plt.xlabel("Time")
plt.ylabel("Stock Price (inverted for visualization)")
plt.show()