In [None]:
import plotly.graph_objects as go
import numpy as np
import math

In [None]:
def _count_trees_of_size_n(n, k):
    """Return the exact number of k-ary trees with exactly ``n`` nodes.

    This is the Fuss-Catalan number C(k*n, n) / ((k - 1)*n + 1) (k=2 gives
    the Catalan numbers). The division is always exact, so integer division
    loses nothing and the result stays an arbitrary-precision Python int.
    """
    # The denominator must be grouped as ((k-1)*n + 1); without the outer
    # parentheses `//` binds before `*` and the result is not a tree count.
    return math.comb(n * k, n) // ((k - 1) * n + 1)


# Exactly n
def num_trees_of_size_n(n, k):
    """Return log2 of the number of k-ary trees with exactly ``n`` nodes."""
    return math.log2(_count_trees_of_size_n(n, k))


def num_trees(n, k):
    """Return log2 of the total number of k-ary trees with between k and n nodes.

    Only sizes ``i`` with ``max(k, 1) <= i <= n`` are counted, mirroring the
    original ``i >= k`` filter. NOTE(review): confirm why sizes below k are
    excluded.

    Returns 0.0 when that size range is empty (e.g. ``n < k``), matching the
    previous ``np.sum([])`` result so the plotted grid has no holes.
    """
    # Sum the exact integer counts first and take the log once: summing the
    # per-size log2 values would give log2 of the *product* of the counts.
    total = sum(_count_trees_of_size_n(i, k) for i in range(max(k, 1), n + 1))
    return math.log2(total) if total else 0.0

In [None]:
# Inclusive bounds of the (N, K) grid that gets evaluated and plotted.
N_MIN = 1
K_MIN = 2
N_RANGE = 100
K_RANGE = 100

# dtype=object stores plain Python ints (presumably so n*k and math.comb
# work on unbounded ints rather than fixed-width numpy ints).
n_values = np.array([*range(N_MIN, N_RANGE + 1)], dtype=object)
k_values = np.array([*range(K_MIN, K_RANGE + 1)], dtype=object)

# One row per K value and one column per N value:
# z_values[row, col] == num_trees(n_values[col], k_values[row]).
z_values = np.array([
    [num_trees(size, arity) for size in n_values]
    for arity in k_values
])

# Surface of the per-(N, K) tree-count metric; z holds log2 values, not raw counts.
fig = go.Figure(data=[go.Surface(
    x=n_values,
    y=k_values,
    z=z_values,  # one row per K (y), one column per N (x)
)])

fig.update_layout(
    scene=dict(
        xaxis_title="N",
        yaxis_title="K",
        # num_trees returns a log2 quantity, so label the axis accordingly.
        zaxis_title="log2(NUM_TREES)",
    ),
    title="Number of trees (log2 scale)",
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=65, r=50, b=65, t=90),
)

fig.show()

In [None]:
# Find the optimal point: the grid cell with the largest value in z_values.
# z_values has one row per K value and one column per N value.
k_idx, n_idx = np.unravel_index(np.argmax(z_values), z_values.shape)
# Look the values up instead of adding *_MIN offsets, so this stays correct
# if the grid ever uses a non-unit step.
k = k_values[k_idx]
n = n_values[n_idx]

print(f"The optimal K overall is {k} (at N={n})")

# For each N (a column of z_values), the K with the largest value.
# np.argmax returns the first maximum, so ties resolve to the smallest K.
print("Optimal K for each N:")
for size, column in zip(n_values, z_values.T):
    print(f"n={size} k={k_values[np.argmax(column)]}")