# Johnson-Mehl tessellation picture

I'd like to draw a picture of the Johnson-Mehl tessellation and a Voronoi tessellation so the viewer can compare them.

One interesting fact which I hadn't realised until trying to draw these pictures: the boundaries between cells in the JM tessellation _aren't straight_.
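
A short derivation of why, as a sketch: assume every seed grows at unit speed, so that seed $i$ (location $x_i$, arrival time $t_i$) reaches a point $y$ at time $\| x_i - y \| + t_i$. A point on the boundary between cells $i$ and $j$ is reached by both seeds at the same moment:

$$
\| y - x_i \| + t_i = \| y - x_j \| + t_j
\quad\Longleftrightarrow\quad
\| y - x_i \| - \| y - x_j \| = t_j - t_i .
$$

In the plane, the set of points whose distances to two fixed foci differ by a constant is one branch of a hyperbola with foci $x_i$ and $x_j$. It degenerates to the perpendicular bisector (the Voronoi boundary) only when $t_i = t_j$, which is why the Voronoi picture has straight edges and the JM picture doesn't.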

I've used the algorithm described by Moulinec in ["A simple and fast algorithm for computing discrete Voronoi, Johnson-Mehl or Laguerre diagrams of points"](https://www.sciencedirect.com/science/article/pii/S0965997822000618), mainly because it's simple, although being fast is also an advantage.

## To do:
1. Draw the "tessellation" when the radii are too small to cover the whole space, with uncovered areas. This is a pretty minor change.

## Outline
1. Sample the arrival times and locations, and prune them.
2. Assign each pixel in an image to its appropriate cell, based on which Johnson-Mehl seed grows to cover it first.
3. Compute adjacency: I want to know which cells border each other so I can colour them appropriately.
4. Colour the cells nicely. It's a planar map, so a four-colouring exists, but to keep things simple I'll use a greedy colouring: no two neighbouring cells share a colour, though the number of colours may not be optimal.

In [None]:
import numpy as np
from scipy.spatial import KDTree
from unconstrained import sample_points, prune_arrivals
from tqdm import trange
import timeit

### Sampling arrival times and pruning

In [None]:
def get_arrival_times( rho, max_time=1.0, R=0 ):
    # NOTE:
    # Restarting whenever more than Nmax points arrive would bias the arrival
    # times away from a homogeneous Poisson process. Instead we append further
    # exponential inter-arrival times, which keeps the distribution exact.
    rate = rho*(1+2*R)**2
    Nmax = max(1, int(max_time*rate + 2*np.sqrt(max_time*rate))) # Two standard deviations above the mean (at least 1, so the loop below always makes progress)
    interarrival_times = np.random.exponential(scale=1/rate,size=Nmax)
    arrival_times = np.cumsum(interarrival_times)
    too_late = np.searchsorted(arrival_times,max_time,side='right') # First index where the arrival time exceeds max_time
    while too_late == len(arrival_times): # Every point generated so far arrived by max_time, so we generate more points.
        interarrival_times = np.append(interarrival_times, np.random.exponential(scale=1/rate,size=Nmax))
        arrival_times = np.cumsum(interarrival_times)
        too_late = np.searchsorted(arrival_times,max_time,side='right') # First index where the arrival time exceeds max_time
    return arrival_times[:too_late].copy()
    # N = np.random.poisson(lam=rho*max_time*(1+2*R)**2)
    # return np.sort(np.random.uniform(low=0.0, high=max_time, size=N))
# # Needed redefining because of my foolishly using a global variable (rng)
# # to define the version of this function in unconstrained.py.

In [None]:
rho = 10

times = get_arrival_times(rho)
seeds = sample_points(len(times))
arrived = prune_arrivals(times, seeds)
print(f'{len(arrived)} out of {len(times)} seeds germinated.')
times = times[arrived]
seeds = seeds[arrived]

If I try to take $\rho$ larger than around $5 \times 10^7$, the kernel dies (presumably from running out of memory).

Is it possible to add points in batches? It should be... We just need to merge two sorted lists (easy) and rearrange the list of locations to match.
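
One way to do the merge without an explicit Python loop, as a rough sketch: concatenate both batches and sort by time, using the sort order to rearrange the locations as well. The name `merge_by_sorting` is made up for this illustration, and I haven't timed it against a hand-written merge.

```python
import numpy as np

def merge_by_sorting(t1, l1, t2, l2):
    """Merge two (times, locations) pairs by concatenating and sorting on time."""
    times = np.concatenate([t1, t2])
    seeds = np.concatenate([l1, l2])
    # A stable sort keeps entries from the first batch ahead of equal times
    # from the second batch (ties have probability zero for continuous times).
    order = np.argsort(times, kind="stable")
    return times[order], seeds[order]
```

Sorting costs $O(n \log n)$ rather than the $O(n)$ of a true merge, but it runs in compiled numpy code, so it may still be quicker in practice.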

In [None]:
def merge_jm_arrivals(t1,l1,t2,l2):
    """
    Given two sets of arrival times and locations from a time-homogeneous PPP,
    merges them into a single pair.

    The arguments are all numpy arrays, and both t1 and t2 should be sorted in increasing order.
    """
    totallen = len(t1)+len(t2)
    outtimes = np.empty(totallen)
    outseeds = np.empty((totallen,2))
    i1 = 0
    i2 = 0
    while i1 < len(t1) and i2 < len(t2):
        if t1[i1] < t2[i2]:
            outtimes[i1+i2] = t1[i1]
            outseeds[i1+i2] = l1[i1]
            i1 += 1
        else:
            outtimes[i1+i2] = t2[i2]
            outseeds[i1+i2] = l2[i2]
            i2 += 1
    if i1 == len(t1):
        outtimes[i1+i2:] = t2[i2:]
        outseeds[i1+i2:] = l2[i2:]
    else:
        outtimes[i1+i2:] = t1[i1:]
        outseeds[i1+i2:] = l1[i1:]
    return outtimes, outseeds

In [None]:
batch_rho = 1.0e5
n_batches = 100
rho = batch_rho * n_batches
max_time = 1.5*( (2*np.log(rho) + 4*np.log(np.log(rho))) / (np.pi*rho) )**(1/3)
display(f'Running until max time {max_time:.5f}.')
times = get_arrival_times(batch_rho,max_time=max_time)
seeds = sample_points(len(times))
arrived = prune_arrivals(times, seeds)
times = times[arrived]
seeds = seeds[arrived]
progress = trange(n_batches-1)
for i in progress:
    progress.set_description("Finding new arrivals")
    new_times = get_arrival_times(batch_rho,max_time=max_time)
    new_seeds = sample_points(len(new_times))
    progress.set_description("Pruning new arrivals")
    arrived = prune_arrivals(new_times, new_seeds)
    new_times = new_times[arrived]
    new_seeds = new_seeds[arrived]
    progress.set_description("Merging all arrivals")
    times, seeds = merge_jm_arrivals(times,seeds,new_times,new_seeds) # Is it faster to just stick all the arrays together and sort them at the end? (Since we don't prune in the middle any more.)
    progress.set_description("Merged. Weird pause.")
    # if len(times) >= 1000000:
    #     progress.set_description("We have a huge list, pruning as an intermediate step...")
    #     arrived = prune_arrivals(times,seeds)
    #     times = times[arrived]
    #     seeds = seeds[arrived]
print("Arrivals all generated, now for the last pruning...") # There's a weird pause after the loop but before this message is printed. Not sure why.
arrived = prune_arrivals(times,seeds)
times = times[arrived]
seeds = seeds[arrived]
print(f'We have a total of {len(seeds)} arrivals with rate {rho} (that\'s {rho:.0e}).')

#### Idea to speed this up a bit more:
Currently we merge and prune an increasingly large list.
A better idea might be a sort of binary recursive structure.
Merge generation 1 arrival processes to get the generation 2 processes,
then merge the generation 2 processes, etc.
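
A minimal sketch of that generation-by-generation merge, assuming each batch is a `(times, seeds)` pair with `times` already sorted. The names `merge_sorted` and `tree_merge` are hypothetical; `merge_sorted` is only a stand-in for the merge function, and pruning between generations is left out because it depends on `prune_arrivals`.

```python
import numpy as np

def merge_sorted(a, b):
    """Merge two (times, seeds) pairs into one pair sorted by time."""
    times = np.concatenate([a[0], b[0]])
    seeds = np.concatenate([a[1], b[1]])
    order = np.argsort(times, kind="stable")
    return times[order], seeds[order]

def tree_merge(batches, merge=merge_sorted):
    """Merge a list of (times, seeds) batches pairwise, one generation at a time."""
    if not batches:
        return np.empty(0), np.empty((0, 2))
    while len(batches) > 1:
        # Merge neighbours (0,1), (2,3), ... to form the next generation.
        next_gen = [merge(batches[k], batches[k + 1])
                    for k in range(0, len(batches) - 1, 2)]
        if len(batches) % 2 == 1:
            next_gen.append(batches[-1])  # the odd one out is carried forward
        batches = next_gen
    return batches[0]
```

With $2^g$ batches this does $g$ generations of merges, and each arrival takes part in only $g$ merges instead of up to $2^g - 1$.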

In [None]:
def recursive_sampling(batch_rho, generations, prune_limit=1000000):
    pass

### Assigning pixels to their cells

This is the bit using Moulinec's method. Moulinec has two separate steps: first assigning the pixels which are covered by time $T$, then the pixels which were not covered by time $T$. He chooses $T$ to optimise the speed of the algorithm. We can simplify the algorithm by choosing $T$ to be the coverage time, so that there is no second step.

---

The algorithm works as follows: we start with an array $\mathcal{D}$ of "running minimum coverage times" and an array $\mathcal{I}$ of assignments, both the same shape as the output image. We initialise $\mathcal{D}$ to be full of $\infty$. We order the seeds $x_1, \dots, x_N$ with corresponding arrival times $t_1, \dots, t_N$.

Then for each $i = 1, \dots, N$ in turn: for every pixel $y$ in the ball centred at $x_i$ of radius $T-t_i$, this pixel was first reached by seed $i$ at time $\| x_i - y \| + t_i$. If $\| x_i - y \| + t_i < \mathcal{D}(y)$, then we set $\mathcal{I}(y) = i$ (overwriting its previous value if it had one) and set the new running minimum $\mathcal{D}(y) = \| x_i - y \| + t_i$.

Once we have done this for all $N$ seeds, every pixel is correctly assigned.
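
As a sketch of the update rule above, here is a numpy version that handles each seed's whole bounding box in one step. It is not necessarily how Moulinec implements it. The function name `assign_cells_vectorised` is invented for this example, and it assumes pixel $(i, j)$ sits at the point $(i, j)/(n-1)$ of $[0,1]^2$ for an $n \times n$ image.

```python
import numpy as np

def assign_cells_vectorised(seeds, times, img_size, T=1.0):
    """Assign each pixel to the seed that reaches it first (by time T), else -1."""
    scale = img_size - 1                          # pixel (i, j) is the point (i, j) / scale
    min_cov = np.full((img_size, img_size), np.inf)   # running minimum coverage times
    owner = np.full((img_size, img_size), -1, dtype=int)
    for k, ((x, y), t) in enumerate(zip(seeds, times)):
        r = T - t                                 # radius of seed k's ball at time T
        if r <= 0:
            continue
        # Bounding box of the ball, clipped to the image.
        i0 = max(0, int(np.floor((x - r) * scale)))
        i1 = min(img_size - 1, int(np.ceil((x + r) * scale)))
        j0 = max(0, int(np.floor((y - r) * scale)))
        j1 = min(img_size - 1, int(np.ceil((y + r) * scale)))
        if i0 > i1 or j0 > j1:
            continue                              # the ball misses the image entirely
        ii = np.arange(i0, i1 + 1)[:, None] / scale
        jj = np.arange(j0, j1 + 1)[None, :] / scale
        cov = np.hypot(ii - x, jj - y) + t        # time at which seed k reaches each pixel
        box_cov = min_cov[i0:i1 + 1, j0:j1 + 1]   # views, so writes update the full arrays
        box_own = owner[i0:i1 + 1, j0:j1 + 1]
        better = (cov <= T) & (cov < box_cov)
        box_cov[better] = cov[better]
        box_own[better] = k
    return owner
```

The trade-off is that the box includes corner pixels outside the ball, so slightly more distances are computed than strictly needed, in exchange for avoiding a Python loop over pixels.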

In [None]:
def get_ball_pixels(centre, radius, img_size):
    """
    Returns the indices of the pixels in the picture
    corresponding to a ball centred at a point in [0,1]^2
    of a given radius.
    Also returns the corresponding (squared) distances, in pixel units.
    
    I suspect a numpy-ish method would be faster:
    create a 2d array containing the (squared) distance between each point in [min_i,max_i]x[min_j,max_j]
    and v, then turn that into an array of bools which we can return along with the distances.
    We might need to also then return (min_i, min_j) so the bool array can be aligned within the image.    
    """
    if radius <= 0:
        return [], []
    v = (img_size-1)*centre
    x,y = v[0], v[1]
    r = (img_size-1)*radius
    r2 = r*r
    min_i = max( 0, int(x-r) )
    max_i = min( img_size-1, int(x+r)+1 )
    min_j = max( 0, int(y-r) )
    max_j = min( img_size-1, int(y+r)+1 )
    in_ball = []
    sq_distances = []
    for i in range(min_i, max_i+1):
        dx2 = (x-i)*(x-i)
        if dx2 > r2:
            continue
        w = np.sqrt( r2 - dx2 )
        for j in range(max(int(y-w),min_j), min(int(y+w)+2,max_j+1)):
            d2 = dx2 + (y-j)**2
            if d2 <= r2:
                in_ball.append((i,j))
                sq_distances.append(d2)
    return in_ball, sq_distances

def assign_cells( seeds, times, img_size, T=1.0 ):
    """
    Assigns all the pixels in an img_size x img_size picture
    to their respective Johnson-Mehl cells.
    T should be a decent upper bound on the coverage time - smaller T
    means we check fewer points.
    This is a modified version of Moulinec's algorithm,
    in which we assign things which were covered by time T,
    and leave the rest unassigned.
    """
    min_cov_times = np.full((img_size,img_size),np.inf) # running minimum coverage times
    assignments = np.full((img_size,img_size),-1,dtype=int) # everything uncovered is assigned to a separate class.

    for i in trange(len(times)):
        xi = seeds[i]
        ti = times[i]
        indices, d2s = get_ball_pixels(xi, T-ti, img_size)
        for k, ij_pair in enumerate(indices):
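            cov_time = np.sqrt(d2s[k])/(img_size-1) + ti # Convert pixel distance back to [0,1] units, matching the (img_size-1) scaling in get_ball_pixels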
            if cov_time < min_cov_times[ij_pair]:
                assignments[ij_pair] = i
                min_cov_times[ij_pair] = cov_time
    return assignments

In [None]:
img_size = 1920

max_time = 1.5*( (2*np.log(rho) + 4*np.log(np.log(rho))) / (np.pi*rho) )**(1/3)
I = assign_cells(seeds, times, img_size, T=max_time)
# print(I)

### Computing adjacency

The method is easy: for each pixel, check if its cell differs from the one below and the one to the right. If they differ, then record the pair of cell IDs as an edge in the adjacency graph. This might be a little slow, but it is miles faster than assigning the pixels in the first place. As long as the resolution is high enough this will, with high probability, give us the correct adjacency structure.
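
If the pixel loop ever becomes the bottleneck, the same comparison can be done on shifted copies of the whole array. This is only a sketch: `adjacent_pairs` is a hypothetical helper, and it assumes uncovered pixels are labelled `-1`.

```python
import numpy as np

def adjacent_pairs(cells, blank=-1):
    """Return the unique pairs (a, b), a < b, of cell IDs that share a pixel edge."""
    # Pair every pixel with its neighbour along the first axis, then the second.
    pairs = np.concatenate([
        np.stack([cells[:-1, :].ravel(), cells[1:, :].ravel()], axis=1),
        np.stack([cells[:, :-1].ravel(), cells[:, 1:].ravel()], axis=1),
    ])
    pairs = pairs[pairs[:, 0] != pairs[:, 1]]       # drop pairs inside a single cell
    pairs = pairs[(pairs != blank).all(axis=1)]     # drop pairs touching uncovered pixels
    if len(pairs) == 0:
        return pairs
    pairs = np.sort(pairs, axis=1)                  # write each pair as (small, large)
    return np.unique(pairs, axis=0)
```

The result can be passed to `networkx.Graph.add_edges_from` (after `.tolist()`). A cell with no neighbours at all would not appear as a node unless it is added separately.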

In [None]:
import networkx # Contains a Graph object which we'll use to store the cell structure.
def get_adjacency(cell_assignments, blanklabel=-1):
    G = networkx.Graph()
    #G.add_nodes_from(range(cell_assignments.max()+1)) # Uncomment this to include cells with zero pixels
    N = cell_assignments.shape[0]
    for i in range(N-1): # All columns except the last
        for j in range(N-1): # All rows except the last
            G.add_edge(cell_assignments[i,j], cell_assignments[i+1,j])
            G.add_edge(cell_assignments[i,j], cell_assignments[i,j+1])
        G.add_edge(cell_assignments[i,N-1],cell_assignments[i+1,N-1])
    for j in range(N-1):
        G.add_edge(cell_assignments[N-1,j],cell_assignments[N-1,j+1])
    G.remove_edges_from(networkx.selfloop_edges(G)) # Not necessary for the colouring but if we want to look at the graph structure it makes it a bit cleaner.
    G.remove_nodes_from([blanklabel]) # If there are uncovered cells, remove them from the adjacency graph.
    return G

In [None]:
cell_structure = get_adjacency(I)
# networkx.draw(cell_structure, with_labels=False, node_size=20)
print(cell_structure)

In [None]:
def colour_graph(G):
    """
    Uses a greedy algorithm to colour G.
    The "colours" are just integers, which can be replaced
    with a suitable set of colours when drawing the picture later.
    Even with a few thousand cells I've never seen it use more than
    7 colours.

    Returns a dictionary indexed by the elements of G.nodes
    """
    cells = list(G.nodes).copy()
    colours = dict.fromkeys(G.nodes)
    
    np.random.shuffle(cells)
    for cell in cells:
        new_colour = 0
        while new_colour in [colours[v] for v in G.neighbors(cell)]:
            new_colour += 1
        colours[cell] = new_colour
    return colours

In [None]:
# # Normally it's possible to find a 5-colouring in a few thousand tries, which is pretty quick.
# # There is a non-zero (but rather small) probability that you'll get a 4-colouring, if you're feeling patient.
colours = colour_graph(cell_structure)

# i=1
# while max(colours.values())+1 > 6:
#     colours = colour_graph(cell_structure)
#     i+=1
# print(f'{i} attempts to get a {max(colours.values())+1}-colouring.')

print(f'We have a {len(set(colours.values()))}-colouring of the cells.')
# networkx.draw(cell_structure, node_size=50, node_color=list(colours.values()))

Next we pick suitable colours.

I might use two independent colourings of the cells:
one choosing between colours of the same luminosity, and one changing the brightness.
This means it will be a colourful diagram on the screen but will still have a valid colouring when printed in greyscale.

In [None]:
import colorspace
c = colorspace.hcl_palettes().get_palette(name="Reds 2")

In [None]:
from PIL import Image, ImageColor

hex_colours = c(max(colours.values())+1) # One palette entry per colour class; unassigned regions use bg_colour below.
rgb_colours = [ImageColor.getcolor(col,"RGB") for col in hex_colours]
bg_colour = (252, 15,192)

data = np.full((img_size, img_size, 3),0, dtype=np.uint8)
N = I.shape[0]
for i in range(N):
    for j in range(N):
        if I[i,j] >= 0:
            data[i,j,:] = rgb_colours[colours[I[i,j]]]
        else:
            data[i,j,:] = bg_colour

image = Image.fromarray(data)
# image.show() # opens in system image viewer
display(image)

In [None]:
image.show()
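
The double loop over pixels above can probably be replaced by a single table lookup. A minimal sketch, assuming cell IDs are non-negative integers with `-1` for uncovered pixels; `paint_cells` and its parameter names are made up for this example.

```python
import numpy as np

def paint_cells(cells, cell_colour, palette, bg_colour=(252, 15, 192)):
    """Turn a 2D array of cell IDs into an (n, n, 3) uint8 RGB array.

    cell_colour maps cell ID -> colour index, palette[index] is an RGB tuple.
    """
    largest_id = max(max(cell_colour, default=-1), int(cells.max()))
    # One row per cell ID plus a final background row, so that an ID of -1
    # indexes the background directly. Cells without a colour stay background.
    lookup = np.tile(np.array(bg_colour, dtype=np.uint8), (largest_id + 2, 1))
    for cell, colour_index in cell_colour.items():
        lookup[cell] = palette[colour_index]
    return lookup[cells]
```

With the variables above, this would be used as `Image.fromarray(paint_cells(I, colours, rgb_colours))`.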