In [None]:
# default_exp forward

# Forward
> Forward pass: random latent design → smoothed/squashed latent (transform) → generated design.

In [None]:
# export
import jax
import jax.numpy as jnp
import numpy as np
from inverse_design.utils import conv2d, randn
from inverse_design.generator import (
    new_design,
    add_void_touch,
    take_free_void_touches,
    add_solid_touch,
    take_free_solid_touches,
)
from itertools import count
from inverse_design.generator import (
    UNASSIGNED,
    VOID,
    SOLID,
    PIXEL_IMPOSSIBLE,
    PIXEL_EXISTING,
    PIXEL_POSSIBLE,
    PIXEL_REQUIRED,
    TOUCH_REQUIRED,
    TOUCH_INVALID,
    TOUCH_EXISTING,
    TOUCH_VALID,
    TOUCH_FREE,
    TOUCH_RESOLVING,
)

In [None]:
# hide
import matplotlib.pyplot as plt
from inverse_design.generator import circular_brush, notched_square_brush, show_mask
# Brush used throughout this notebook. (notched_square_brush(5, 1) is an
# alternative brush; swap it in here to experiment.)
my_brush = circular_brush(3)
show_mask(my_brush)

## Latent Design
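The paper does not clearly explain what the latent design actually is. Here we simply assume it is a real-valued array with the same shape as the design, filled with random values drawn with `randn` and shifted by an optional `bias`. The values are therefore continuous and are *not* restricted to the range [0, 1]; the plot below uses a −3…3 color scale.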

In [None]:
def new_latent_design(shape, bias=0, r=None):
    """Draw a random latent design.

    Args:
        shape: shape passed to ``randn`` for the random draw.
        bias: constant offset added to every entry of the draw.
        r: forwarded unchanged to ``randn`` (presumably a seed or random
            state -- confirm against ``inverse_design.utils.randn``).

    Returns:
        ``randn(shape, r=r) + bias`` converted to a JAX floating-point array.
    """
    latent = randn(shape, r=r)
    latent += bias
    return jnp.asarray(latent, dtype=float)

In [None]:
# Sample a 30x30 latent with a fixed r for reproducibility and show it on a
# fixed, symmetric -3..3 grey scale.
latent = new_latent_design((30, 30), r=42)
image = plt.imshow(latent, cmap="Greys", vmin=-3, vmax=3)
plt.colorbar(image)
plt.show()

## Transform
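The transform removes some of the noise from the latent design: it convolves the latent with the brush (normalized to sum to one) and then squashes the result with $\tanh(\beta\,\cdot)$, so values are pushed towards $-1$ or $+1$ (a larger $\beta$ gives a sharper transition).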

In [None]:
@jax.jit
def transform(latent, brush, beta=5.0):
    """Smooth the latent design with the brush and squash it into [-1, 1].

    Args:
        latent: latent design array.
        brush: brush mask; it is cast to float and divided by its own sum so
            the kernel entries sum to one.
        beta: steepness of the final ``tanh``; larger values push the output
            closer to -1 / +1.

    Returns:
        ``tanh(beta * conv2d(latent, normalized_brush))``.
    """
    kernel = jnp.asarray(brush, dtype=float) / brush.sum()
    smoothed = conv2d(latent, kernel)
    return jnp.tanh(beta * smoothed)

In [None]:
# Transformed latent: tanh output lies in [-1, 1], so plot on exactly that range.
latent_t = transform(latent, my_brush)
image = plt.imshow(latent_t, cmap="Greys", vmin=-1, vmax=1)
plt.colorbar(image)
plt.show()

## Generator

In [None]:
@jax.jit
def argmax2d(arr2d):
    """Return the (row, col) index of the largest element of a 2D array.

    Ties resolve to the first occurrence in row-major order, because this uses
    ``jnp.argmax`` on the flattened array.

    Args:
        arr2d: 2D array of shape (m, n); it does not need to be square.

    Returns:
        Tuple ``(i, j)`` of scalar integer arrays locating the maximum.
    """
    # ravel() is row-major, so the flat index is k == i * n + j. unravel_index
    # inverts that for any (m, n); dividing by the row count m (as before) was
    # only correct for square arrays.
    k = jnp.argmax(arr2d.ravel())
    return jnp.unravel_index(k, arr2d.shape)

In [None]:
@jax.jit
def argmin2d(arr2d):
    """Return the (row, col) index of the smallest element of a 2D array.

    Ties resolve to the first occurrence in row-major order, because this uses
    ``jnp.argmin`` on the flattened array.

    Args:
        arr2d: 2D array of shape (m, n); it does not need to be square.

    Returns:
        Tuple ``(i, j)`` of scalar integer arrays locating the minimum.
    """
    # ravel() is row-major, so the flat index is k == i * n + j. unravel_index
    # inverts that for any (m, n); dividing by the row count m (as before) was
    # only correct for square arrays.
    k = jnp.argmin(arr2d.ravel())
    return jnp.unravel_index(k, arr2d.shape)

In [None]:
# Generator inputs: a design created by new_design with the same shape as the
# transformed latent, and the notebook-wide brush.
design = new_design(latent_t.shape)
brush = my_brush

In [None]:
# NOTE(review): the generator loop below rebinds I via `for I in count()`, so this
# initial value is overwritten on the first iteration and appears to be unused.
I = 0 

In [None]:
# Greedy generator loop. Each iteration commits exactly one step to `design`,
# checked in priority order:
#   1. any TOUCH_FREE position      -> take_free_void_touches / take_free_solid_touches,
#   2. any TOUCH_RESOLVING position -> add one void or solid touch,
#   3. any TOUCH_VALID position     -> add one void or solid touch,
# and the loop stops once none of these remain. Void candidates are ranked by
# the most negative transformed-latent value (argmin), solid candidates by the
# most positive one (argmax); when both kinds compete, the larger |latent| wins.
for I in count():
    void_touch_mask = design.void_touches == TOUCH_VALID
    solid_touch_mask = design.solid_touches == TOUCH_VALID
    touch_mask = void_touch_mask | solid_touch_mask

    void_free_mask = design.void_touches == TOUCH_FREE
    solid_free_mask = design.solid_touches == TOUCH_FREE
    free_mask = void_free_mask | solid_free_mask

    void_resolving_mask = design.void_touches == TOUCH_RESOLVING
    solid_resolving_mask = design.solid_touches == TOUCH_RESOLVING
    resolving_mask = void_resolving_mask | solid_resolving_mask

    if free_mask.any():
        # Take the free touches of whichever kind has the larger absolute summed
        # latent value over its free positions.
        void_selector = jnp.where(void_free_mask, latent_t, 0)
        solid_selector = jnp.where(solid_free_mask, latent_t, 0)
        if abs(void_selector.sum()) > abs(solid_selector.sum()):
            design = take_free_void_touches(design, brush)
            print(f"{I} take free void...")
        else:
            design = take_free_solid_touches(design, brush)
            print(f"{I} take free solid...")
    elif resolving_mask.any():
        void_needs_resolving = void_resolving_mask.any()
        solid_needs_resolving = solid_resolving_mask.any()
        # Fill non-candidate positions with +inf (void, argmin) / -inf (solid,
        # argmax) so they can never be selected.
        void_selector = jnp.where(void_resolving_mask, latent_t, np.inf)
        solid_selector = jnp.where(solid_resolving_mask, latent_t, -np.inf)

        if void_needs_resolving and (not solid_needs_resolving):
            i_v, j_v = argmin2d(void_selector)
            design = add_void_touch(design, brush, (i_v, j_v))
            print(f"{I} resolve void {int(i_v), int(j_v)}...")
        elif (not void_needs_resolving) and solid_needs_resolving:
            i_s, j_s = argmax2d(solid_selector)
            design = add_solid_touch(design, brush, (i_s, j_s))
            print(f"{I} resolve solid {int(i_s), int(j_s)}...")
        else: # both need resolving. TODO: figure out if we actually need this case...
            i_v, j_v = argmin2d(void_selector)
            v = latent_t[i_v, j_v]
            i_s, j_s = argmax2d(solid_selector)
            s = latent_t[i_s, j_s]
            if abs(v) > abs(s):
                design = add_void_touch(design, brush, (i_v, j_v))
                print(f"{I} resolve void {int(i_v), int(j_v)}...")
            else:
                design = add_solid_touch(design, brush, (i_s, j_s))
                print(f"{I} resolve solid {int(i_s), int(j_s)}...")

    elif touch_mask.any():
        # Only one kind may have valid touches. An all-(+/-)inf selector makes
        # argmin2d/argmax2d return (0, 0), which is not a valid touch for that
        # kind, so each kind is only considered when it has candidates.
        void_can_touch = void_touch_mask.any()
        solid_can_touch = solid_touch_mask.any()
        if void_can_touch:
            i_v, j_v = argmin2d(jnp.where(void_touch_mask, latent_t, np.inf))
            v = latent_t[i_v, j_v]
        if solid_can_touch:
            i_s, j_s = argmax2d(jnp.where(solid_touch_mask, latent_t, -np.inf))
            s = latent_t[i_s, j_s]
        # touch_mask.any() guarantees at least one kind has candidates, so the
        # else branch is only reached when solid candidates exist.
        if void_can_touch and ((not solid_can_touch) or abs(v) > abs(s)):
            design = add_void_touch(design, brush, (i_v, j_v))
            print(f"{I} touch void  {int(i_v), int(j_v)}...")
        else:
            design = add_solid_touch(design, brush, (i_s, j_s))
            print(f"{I} touch solid  {int(i_s), int(j_s)}...")
    else:
        # Nothing free, resolving, or valid is left: the design is complete.
        break

In [None]:
# Show the generated design (its repr) as this cell's output.
design

In [None]:
# NOTE(review): this recomputes latent_t from the same `latent` and `my_brush`
# as the Transform section and plots it the same way; confirm whether this cell
# was meant to visualize the generated `design` instead.
latent_t = transform(latent, my_brush)
image = plt.imshow(latent_t, cmap="Greys", vmin=-1, vmax=1)
plt.colorbar(image)
plt.show()