In [None]:
# default_exp algorithm

# Algorithm
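> Brush-based design algorithm (void/solid pixels and touches) following the method by Google X: M. F. Schubert et al., "Inverse design of photonic devices with strict foundry fabrication constraints", ACS Photonics (2022).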

In [None]:
# exporti
from typing import NamedTuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from fastcore.basics import patch_to
from inverse_design.utils import batch_conv2d, dilute
from matplotlib.colors import ListedColormap

In [None]:
# export
# Integer state codes stored in the `Design` arrays. The three groups use
# disjoint, consecutive values (0..12), which lets `visualize` draw every field
# with one colormap spanning UNASSIGNED..TOUCH_RESOLVING.

# States of the `design` array itself.
UNASSIGNED, VOID, SOLID = range(0, 3)

# States of the `void_pixels` / `solid_pixels` arrays.
PIXEL_IMPOSSIBLE, PIXEL_EXISTING, PIXEL_POSSIBLE, PIXEL_REQUIRED = range(3, 7)

# States of the `void_touches` / `solid_touches` arrays.
(
    TOUCH_REQUIRED,
    TOUCH_INVALID,
    TOUCH_EXISTING,
    TOUCH_VALID,
    TOUCH_FREE,
    TOUCH_RESOLVING,
) = range(7, 13)

In [None]:
# export
class Design(NamedTuple):
    """Complete state of a design: the assigned pixels plus per-phase bookkeeping.

    Every field holds an array of integer state codes (see the module-level
    constants); ``new_design`` creates all five with the same shape and dtype
    ``uint8``.
    """

    # Assigned state per pixel: UNASSIGNED, VOID or SOLID.
    design: jnp.ndarray
    # PIXEL_* state of each pixel for the void phase.
    void_pixels: jnp.ndarray
    # PIXEL_* state of each pixel for the solid phase.
    solid_pixels: jnp.ndarray
    # TOUCH_* state of each touch position for the void phase.
    void_touches: jnp.ndarray
    # TOUCH_* state of each touch position for the solid phase.
    solid_touches: jnp.ndarray

    @property
    def shape(self):
        """Shape of the ``design`` array."""
        return jnp.shape(self.design)

In [None]:
# export
def new_design(shape):
    """Create a fresh design in which nothing has been decided yet.

    Args:
        shape: shape of the design grid, e.g. ``(rows, cols)``.

    Returns:
        A ``Design`` whose ``design`` array is filled with ``UNASSIGNED``, whose
        pixel arrays are filled with ``PIXEL_POSSIBLE`` and whose touch arrays
        are filled with ``TOUCH_VALID``; every array has dtype ``uint8``.
    """

    def _filled(value):
        # jnp.full works for any shape, unlike zeros(...).at[:, :].set(...),
        # which only accepts 2-D shapes.
        return jnp.full(shape, value, dtype=jnp.uint8)

    return Design(
        design=_filled(UNASSIGNED),
        void_pixels=_filled(PIXEL_POSSIBLE),
        solid_pixels=_filled(PIXEL_POSSIBLE),
        void_touches=_filled(TOUCH_VALID),
        solid_touches=_filled(TOUCH_VALID),
    )

In [None]:
# export
def circular_brush(diameter):
    """Return a ``diameter`` x ``diameter`` boolean disk-shaped brush.

    Along each axis, ``diameter`` evenly spaced coordinates from
    ``-diameter/2`` to ``+diameter/2`` (inclusive) are truncated toward zero to
    integers. A pixel is set when ``x**2 + y**2 < (diameter/2)**2``.

    NOTE(review): for an even ``diameter`` two samples truncate to 0
    (e.g. diameter 4 -> -2, 0, 0, 2), so the central row/column is duplicated
    (diameter 4 yields a 2x2 block). Confirm this is intended.
    """
    radius = diameter / 2
    axis = slice(-radius, radius, 1j * diameter)  # complex step = number of samples
    xs, ys = jnp.mgrid[axis, axis]
    xi, yi = xs.astype(int), ys.astype(int)
    return xi ** 2 + yi ** 2 < radius ** 2

In [None]:
# export
def notched_square_brush(diameter):
    """Return a square boolean brush with its four corner pixels removed.

    Only ``diameter == 5`` is supported, which gives a 5x5 brush with 21 set
    pixels.

    Args:
        diameter: side length of the square brush in pixels (must equal 5).

    Returns:
        Boolean array of shape ``(5, 5)``.

    Raises:
        NotImplementedError: for any diameter other than 5.
    """
    if diameter != 5:
        raise NotImplementedError("Can only create notched square brush of size 5")
    size = int(diameter)
    # Build the mask directly as booleans. Stepping by ``size - 1`` selects exactly
    # the first and last row/column, i.e. the four corner pixels.
    return jnp.ones((size, size), dtype=bool).at[:: size - 1, :: size - 1].set(False)

In [None]:
# export
def show_mask(brush):
    """Draw a 2-D boolean mask on the current axes as an outlined pixel grid.

    ``False`` pixels are drawn white and ``True`` pixels grey. Every pixel is
    outlined and the axes are labelled with row/column indices.

    Args:
        brush: 2-D boolean mask (e.g. a brush or a dilated touch mask); rows are
            drawn along the y axis and columns along the x axis.
    """
    nx, ny = brush.shape
    cmap = ListedColormap(["#ffffff", "#929292"])
    ax = plt.gca()
    # Major ticks sit on pixel edges (unlabelled) so the grid outlines each pixel;
    # minor ticks sit on pixel centres and carry the indices.
    ax.set_yticks([row + 0.5 for row in range(nx)], [""] * nx)
    ax.set_xticks([col + 0.5 for col in range(ny)], [""] * ny)
    ax.set_yticks(list(range(nx)), [str(row) for row in range(nx)], minor=True)
    ax.set_xticks(list(range(ny)), [str(col) for col in range(ny)], minor=True)
    plt.grid(True, color="k")
    # Pin the value range to [0, 1]. With auto-scaling, a mask that is entirely
    # True has vmin == vmax, normalises to 0 and is drawn in the False colour.
    plt.imshow(brush, cmap=cmap, vmin=0, vmax=1)

In [None]:
# A 13-pixel circular brush.
show_mask(brush=circular_brush(diameter=13));

In [None]:
# The notched square brush (only diameter 5 is supported).
show_mask(brush=notched_square_brush(diameter=5))

In [None]:
# export
def visualize(design):
    """Plot every field of a ``Design`` side by side, one panel per field.

    Each state constant (``UNASSIGNED`` .. ``TOUCH_RESOLVING``) has a fixed
    colour, so the same colour means the same state in every panel.

    Args:
        design: the ``Design`` to draw; panels are titled after its field names.
    """
    colors = {
        UNASSIGNED: "#929292",
        VOID: "#cbcbcb",
        SOLID: "#515151",
        PIXEL_IMPOSSIBLE: "#8dd3c7",
        PIXEL_EXISTING: "#ffffb3",
        PIXEL_POSSIBLE: "#bebada",
        PIXEL_REQUIRED: "#fb7f72",
        TOUCH_REQUIRED: "#00ff00",
        TOUCH_INVALID: "#7fb1d3",
        TOUCH_EXISTING: "#fdb462",
        TOUCH_VALID: "#b3de69",
        TOUCH_FREE: "#fccde5",
        TOUCH_RESOLVING: "#e0e0e0",
    }
    # Colour i of the colormap must belong to state i. Order by state value
    # instead of relying on the insertion order of the dict above.
    cmap = ListedColormap([colors[state] for state in sorted(colors)], name="cmap")
    nx, ny = design.design.shape
    fields = design._fields
    # squeeze=False keeps `axs` 2-D regardless of the number of fields.
    _, axs = plt.subplots(1, len(fields), figsize=(15, 3 * nx / ny), squeeze=False)
    for ax, field, values in zip(axs[0], fields, design):
        ax.set_title(field.replace("_", " "))
        # Fixed vmin/vmax keep each state's colour even when a panel only
        # contains a subset of the states.
        ax.imshow(values, cmap=cmap, vmin=UNASSIGNED, vmax=TOUCH_RESOLVING)
        # Major ticks on pixel edges (unlabelled) draw the pixel grid; minor
        # ticks on pixel centres carry the row/column indices.
        ax.set_yticks([row + 0.5 for row in range(nx)], [""] * nx)
        ax.set_xticks([col + 0.5 for col in range(ny)], [""] * ny)
        ax.set_yticks(list(range(nx)), [str(row) for row in range(nx)], minor=True)
        ax.set_xticks(list(range(ny)), [str(col) for col in range(ny)], minor=True)
        ax.set_xlim(-0.5, ny - 0.5)
        ax.set_ylim(nx - 0.5, -0.5)
        ax.grid(visible=True, which="major", c="k")

@patch_to(Design)
def _repr_html_(self):
    """IPython rich-display hook attached to ``Design``.

    Draws the design with ``visualize``, which emits matplotlib figures as a side
    effect. It then returns an empty HTML string, so nothing else is rendered.
    """
    visualize(self)
    return ""
    

In [None]:
# An empty 6x8 design: pixels PIXEL_POSSIBLE, touches TOUCH_VALID.
design = new_design(shape=(6, 8))
design

In [None]:
# Brush used by all of the touch examples below.
brush = notched_square_brush(diameter=5)
show_mask(brush=brush)

In [None]:
# export

def _apply_free_touches(void_touches_mask, void_pixels_mask, brush=None):
    """Return the touch positions whose brush footprint lies inside the pixels.

    A position is "free" when stamping ``brush`` there would only cover pixels
    that are already set in ``void_pixels_mask``. Despite the parameter names,
    this helper is also used for the solid phase (by ``add_solid_touch``).

    Args:
        void_touches_mask: mask of existing touches; only its shape is used.
        void_pixels_mask: boolean mask of pixels already assigned to the phase.
        brush: boolean brush to stamp. If omitted, falls back to a global named
            ``brush`` for backward compatibility with older callers.

    Returns:
        Boolean array with the same shape as ``void_touches_mask``.

    Raises:
        TypeError: if ``brush`` is omitted and no global ``brush`` exists (e.g.
            in the exported module).

    Note: builds an (m*n, m, n) stack of candidates, so memory grows as (m*n)**2.
    """
    if brush is None:
        # Legacy path: this helper used to read the notebook-global `brush`, which
        # is wrong for any other brush and undefined once exported to a module.
        brush = globals().get("brush")
        if brush is None:
            raise TypeError("_apply_free_touches() missing required argument: 'brush'")
    m, n = jnp.shape(void_touches_mask)
    # One candidate touch per pixel: candidate k is a one-hot image at flat
    # (row-major) pixel index k.
    candidates = jnp.eye(m * n, dtype=bool).reshape(m * n, m, n)
    # Assumes batch_conv2d stamps `brush` onto each candidate independently
    # (as the original implementation relied on) - TODO confirm.
    footprints = batch_conv2d(candidates, brush[None]) | void_pixels_mask
    # Free iff stamping adds no new pixel, i.e. the merged mask is unchanged.
    is_free = (footprints == void_pixels_mask).all((1, 2))
    # Candidate k sits at pixel (k // n, k % n), so the union of all free
    # candidates is simply the per-candidate flags laid back out on the grid.
    return is_free.reshape(m, n)

def add_void_touch(design, brush, pos, apply_free_touches=True):
    """Place a void touch of ``brush`` at ``pos`` and propagate its consequences.

    Args:
        design: the current ``Design``.
        brush: boolean brush stamped at every existing void touch.
        pos: index into ``design.void_touches`` for the new touch, e.g. ``(row, col)``.
        apply_free_touches: if True, additionally mark as ``TOUCH_EXISTING`` every
            void touch whose brush footprint already lies inside the void pixels;
            otherwise only the previously existing touches plus ``pos`` are marked.

    Returns:
        A new ``Design``. The input arrays are updated functionally with
        ``.at[...].set`` and are left unchanged.
    """
    # All void touches that exist after adding the one at `pos`.
    void_touches_mask = design.void_touches.at[pos].set(TOUCH_EXISTING) == TOUCH_EXISTING
    # Pixels produced by those touches (`dilute` presumably dilates by `brush`).
    mask = dilute(void_touches_mask, brush)
    # The void pixels dilated once more; these solid-touch positions are
    # marked TOUCH_INVALID below.
    dilated_mask = dilute(mask, brush)
    if apply_free_touches:
        free_touches_mask = _apply_free_touches(void_touches_mask, mask, brush)
    else:
        free_touches_mask = void_touches_mask
    return Design(
        design=design.design.at[mask].set(VOID),
        void_pixels=design.void_pixels.at[mask].set(PIXEL_EXISTING),
        solid_pixels=design.solid_pixels.at[mask].set(PIXEL_IMPOSSIBLE),
        void_touches=design.void_touches.at[free_touches_mask].set(TOUCH_EXISTING),
        solid_touches=design.solid_touches.at[dilated_mask].set(TOUCH_INVALID),
    )

In [None]:
# A single void touch at (0, 6) on an empty 6x8 design.
design = new_design(shape=(6, 8))
design = add_void_touch(design, brush, pos=(0, 6))
design

In [None]:
# export
def add_solid_touch(design, brush, pos, apply_free_touches=True):
    """Place a solid touch of ``brush`` at ``pos``; the solid mirror of ``add_void_touch``.

    Args:
        design: the current ``Design``.
        brush: boolean brush stamped at every existing solid touch.
        pos: index into ``design.solid_touches`` for the new touch, e.g. ``(row, col)``.
        apply_free_touches: if True, additionally mark as ``TOUCH_EXISTING`` every
            solid touch reported free by ``_apply_free_touches``; otherwise only
            the previously existing touches plus ``pos`` are marked.

    Returns:
        A new ``Design``. The input arrays are updated functionally with
        ``.at[...].set`` and are left unchanged.
    """
    # All solid touches that exist after adding the one at `pos`.
    solid_touches_mask = design.solid_touches.at[pos].set(TOUCH_EXISTING) == TOUCH_EXISTING
    # Pixels produced by those touches (`dilute` presumably dilates by `brush`).
    mask = dilute(solid_touches_mask, brush)
    # The solid pixels dilated once more; these void-touch positions are
    # marked TOUCH_INVALID below.
    dilated_mask = dilute(mask, brush)
    if apply_free_touches:
        # NOTE(review): `brush` is not passed to `_apply_free_touches`, so the free
        # touches are not guaranteed to be computed with this function's `brush`.
        # TODO: pass it through.
        free_touches_mask = _apply_free_touches(solid_touches_mask, mask)
    else:
        free_touches_mask = solid_touches_mask
    return Design(
        design=design.design.at[mask].set(SOLID),
        void_pixels=design.void_pixels.at[mask].set(PIXEL_IMPOSSIBLE),
        solid_pixels=design.solid_pixels.at[mask].set(PIXEL_EXISTING),
        void_touches=design.void_touches.at[dilated_mask].set(TOUCH_INVALID),
        solid_touches=design.solid_touches.at[free_touches_mask].set(TOUCH_EXISTING),
    )

In [None]:
# One void touch at (0, 6) followed by one solid touch at (0, 0).
design = new_design(shape=(6, 8))
design = add_void_touch(design, brush, pos=(0, 6))
design = add_solid_touch(design, brush, pos=(0, 0))
design

In [None]:
# Same as above, plus a second void touch at (4, 6) placed without
# applying free touches.
design = new_design(shape=(6, 8))
design = add_void_touch(design, brush, pos=(0, 6))
design = add_solid_touch(design, brush, pos=(0, 0))
design = add_void_touch(design, brush, pos=(4, 6), apply_free_touches=False)
design

In [None]:
# Existing void touches passed through `dilute` with the brush once.
r = dilute((design.void_touches == TOUCH_EXISTING), brush)
show_mask(brush=r)

In [None]:
# Re-display the current design (unchanged since the cell that built it) for
# comparison with the masks around it.
design

In [None]:
# Apply `dilute` a second time. `add_void_touch` does the same to find the solid
# touches it marks TOUCH_INVALID.
R = dilute(r, brush)
show_mask(brush=R)