In [1]:
import plotly.graph_objects as go
import numpy as np
import math

In [2]:
# Exactly n
def num_trees_of_size_n(n, k):
    return math.log2(math.comb(n*k, n) // (k-1)*(n+1))

def num_trees(n, k):
    return np.sum([num_trees_of_size_n(i, k) for i in range(1, n+1) if i >= k])

In [3]:
N_MIN = 1
K_MIN = 2
N_RANGE = 100
K_RANGE = 100

n_values = np.array(range(N_MIN, N_RANGE+1), dtype=object)
k_values = np.array(range(K_MIN, K_RANGE+1), dtype=object)
z_values = np.array([[num_trees(n, k) for n in n_values] for k in k_values])

fig = go.Figure(data=[go.Surface(
    x=n_values,
    y=k_values,
    z=z_values
)])

fig.update_layout(scene = dict(
                    xaxis_title="N",
                    yaxis_title="K",
                    zaxis_title="NUM_TREES"),
                    title="Number of trees", autosize=False,
                    width=800, height=800,
                    margin=dict(l=65, r=50, b=65, t=90))

fig.show()

In [4]:
# Find the optimal point
k_idx, n_idx = np.unravel_index(np.argmax(z_values), z_values.shape)
k = k_idx + K_MIN
n = n_idx + N_MIN

print(f"The optimal K overall is {k}")

print(f"Optimal K for each N:")
for n, k in [(n_idx+N_MIN, np.argmax(row)+K_MIN) for n_idx, row in enumerate(z_values.T)]:
    print(f"n={n} k={k}")

The optimal K overall is 32
Optimal K for each N:
n=1 k=2
n=2 k=2
n=3 k=2
n=4 k=2
n=5 k=3
n=6 k=3
n=7 k=4
n=8 k=4
n=9 k=4
n=10 k=5
n=11 k=5
n=12 k=5
n=13 k=6
n=14 k=6
n=15 k=6
n=16 k=7
n=17 k=7
n=18 k=8
n=19 k=8
n=20 k=8
n=21 k=9
n=22 k=9
n=23 k=9
n=24 k=10
n=25 k=10
n=26 k=10
n=27 k=10
n=28 k=11
n=29 k=11
n=30 k=11
n=31 k=12
n=32 k=12
n=33 k=12
n=34 k=13
n=35 k=13
n=36 k=13
n=37 k=14
n=38 k=14
n=39 k=14
n=40 k=15
n=41 k=15
n=42 k=15
n=43 k=16
n=44 k=16
n=45 k=16
n=46 k=16
n=47 k=17
n=48 k=17
n=49 k=17
n=50 k=18
n=51 k=18
n=52 k=18
n=53 k=19
n=54 k=19
n=55 k=19
n=56 k=19
n=57 k=20
n=58 k=20
n=59 k=20
n=60 k=21
n=61 k=21
n=62 k=21
n=63 k=22
n=64 k=22
n=65 k=22
n=66 k=22
n=67 k=23
n=68 k=23
n=69 k=23
n=70 k=24
n=71 k=24
n=72 k=24
n=73 k=25
n=74 k=25
n=75 k=25
n=76 k=25
n=77 k=26
n=78 k=26
n=79 k=26
n=80 k=27
n=81 k=27
n=82 k=27
n=83 k=27
n=84 k=28
n=85 k=28
n=86 k=28
n=87 k=29
n=88 k=29
n=89 k=29
n=90 k=30
n=91 k=30
n=92 k=30
n=93 k=30
n=94 k=31
n=95 k=31
n=96 k=31
n=97 k=32
n=98 k=32
n=