In [None]:
import numpy as np
import pylab as plt
import time

start_time = time.time()
starts = np.random.choice(1_000_000, 5000, replace=False)


def collatz_edges(starts, *, progress=True):
    """Collect the edges of the Collatz tree reached from ``starts``.

    Each trajectory n -> f(n) -> ... -> 1 is walked, with f(n) = n // 2 for
    even n and 3n + 1 for odd n. Every step contributes the edge
    (f(n), n), i.e. (parent, child) in a tree rooted at 1.

    Parameters
    ----------
    starts : iterable of int
        Starting numbers. Values <= 1 contribute no edges.
    progress : bool, keyword-only
        Print one dot per start and a newline after every 100 starts.

    Returns
    -------
    numpy.ndarray, dtype int64, shape (n_edges, 3)
        Rows ``[parent, child, count]`` in order of first discovery, where
        ``count`` is how many starts have a trajectory using that edge.
    """
    # (parent, child) -> row in ``edges``: O(1) lookup instead of scanning
    # every stored edge on every Collatz step.
    row_of = {}
    edges = []
    for num, start in enumerate(starts):
        if progress:
            print(".", end="", flush=True)
            if (num + 1) % 100 == 0:
                print("", flush=True)

        # Plain Python int: exact arithmetic, no float rounding above 2**53.
        position = int(start)
        while position > 1:
            # Floor division keeps values integral (``/`` would yield floats).
            parent = position // 2 if position % 2 == 0 else 3 * position + 1
            key = (parent, position)
            row = row_of.get(key)
            if row is None:
                # First time this transition is seen: add it with count 1.
                row_of[key] = len(edges)
                edges.append([parent, position, 1])
            else:
                edges[row][2] += 1
            position = parent
    # reshape keeps the (0, 3) shape when no edges were found.
    return np.array(edges, dtype=np.int64).reshape(-1, 3)


chain = collatz_edges(starts)
counter = len(chain)

# Keep only the filled rows of the edge buffer, then order the edges by
# their parent value (column 0).
chain = chain[:counter]
ind = chain[:, 0].argsort()
chain = chain.take(ind, axis=0)

# Short summary: number of edges and a peek at the first few rows.
print("")
print("Tree size: ", chain.shape)
print("")
print("First 10 edges:")
print(chain[:10])

# Black canvas; subsequent plt.plot calls draw onto this figure's axes.
ax = plt.figure(figsize=(8, 4.5), layout="constrained", facecolor="black").gca()
ax.set_facecolor("black")

# Per-edge table: columns 0-2 mirror ``chain`` (parent, child, count);
# columns 3-5 start at zero and are filled in while drawing.
chain_ext = np.zeros((len(chain), 6), dtype=float)
chain_ext[:, :3] = chain

def draw(start, pos_beg, angle, *, even_turn_deg=8.65, odd_turn_deg=16.0):
    """Plot the subtree of the global ``chain`` that hangs below ``start``.

    Every edge row ``[parent, child, count]`` of ``chain`` becomes a line
    segment of length ``1 / log(child)``. The heading turns counter-clockwise
    by ``even_turn_deg`` for an even child and clockwise by ``odd_turn_deg``
    for an odd child. Line width and colour scale with ``log1p(count)``,
    normalised by the largest count in ``chain``. Each segment's end point
    and normalised weight are written to columns 3-5 of the global
    ``chain_ext`` (same row index as in ``chain``).

    Parameters
    ----------
    start : int
        Node whose descendants are drawn.
    pos_beg : array-like of two numbers
        Plot coordinates of ``start``.
    angle : float
        Heading, in radians, at ``start``.
    even_turn_deg, odd_turn_deg : float, keyword-only
        Turn angles in degrees. The defaults are about 0.15 rad and 0.28 rad.
    """
    if len(chain) == 0:
        return

    # Index edges by parent once. A stable argsort keeps rows that share a
    # parent in their original ``chain`` order, so siblings are visited in
    # the same order a boolean-mask lookup would produce.
    parents = chain[:, 0]
    by_parent = np.argsort(parents, kind="stable")
    sorted_parents = parents[by_parent]
    # Loop-invariant normaliser for line width / colour (was recomputed per edge).
    log_max = np.log1p(np.max(chain[:, 2]))
    even_turn = np.deg2rad(even_turn_deg)
    odd_turn = np.deg2rad(odd_turn_deg)

    def child_rows(node):
        """Row indices into ``chain`` of the edges whose parent is ``node``."""
        lo = np.searchsorted(sorted_parents, node, side="left")
        hi = np.searchsorted(sorted_parents, node, side="right")
        return by_parent[lo:hi]

    # Explicit stack instead of recursion (no recursion-limit risk on deep
    # trees). Children are pushed in reverse so they are popped, and drawn,
    # in the same depth-first pre-order as a recursive walk.
    stack = [(row, pos_beg, angle) for row in child_rows(start)[::-1]]
    while stack:
        row, origin, heading = stack.pop()
        _, child, count = chain[row]
        if child % 2 == 0:
            new_angle = heading + even_turn
        else:
            new_angle = heading - odd_turn

        step = np.array([np.cos(new_angle), np.sin(new_angle)]) / np.log(child)
        pos_end = origin + step
        col = np.log1p(count) / log_max
        plt.plot(
            [origin[0], pos_end[0]],
            [origin[1], pos_end[1]],
            lw=2 * col,
            c=plt.cm.spring(1 - col),
            zorder=col,
        )
        chain_ext[row, 3:] = [pos_end[0], pos_end[1], col]

        stack.extend((r, pos_end, new_angle) for r in child_rows(child)[::-1])

# Render from the root of the tree (node 1): start at the origin, heading along +x.
draw(1, np.array([0, 0]), 0)

plt.axis("off")

# Write a 300-dpi raster and a vector copy, both on a black background.
plt.savefig("collatz-1mln-5000.png", facecolor="black", dpi=300)
plt.savefig("collatz.svg", facecolor="black", format="svg")

# Wall-clock time since ``start_time`` was recorded at the top of the cell.
print(f"Elapsed time: {time.time() - start_time:.2f} seconds")

....................................................................................................
....................................................................................................
....................................................................................................
....................................................................................................
....................................................................................................
....................................................................................................
....................................................................................................
....................................................................................................
....................................................................................................
...........................................................................................