In [None]:
# default_exp forward

# Forward
> forward pass

In [None]:
# export
import jax
import jax.numpy as jnp
import numpy as np
from inverse_design.utils import conv2d, randn
from inverse_design.generator import (
    new_design,
    add_void_touch,
    take_free_void_touches,
    add_solid_touch,
    take_free_solid_touches,
)
from inverse_design.generator import (
    UNASSIGNED,
    VOID,
    SOLID,
    PIXEL_IMPOSSIBLE,
    PIXEL_EXISTING,
    PIXEL_POSSIBLE,
    PIXEL_REQUIRED,
    TOUCH_REQUIRED,
    TOUCH_INVALID,
    TOUCH_EXISTING,
    TOUCH_VALID,
    TOUCH_FREE,
    TOUCH_RESOLVING,
)

In [None]:
# hide
import matplotlib.pyplot as plt
from inverse_design.generator import circular_brush, notched_square_brush, show_mask
# Brush shared by the demo cells of this notebook (passed to `transform` and
# used as the touch brush of the generator loop).
my_brush = notched_square_brush(5, 1)

## Latent Design
The paper does not explain very clearly what the latent design actually is. Here we simply assume it is an array of the same shape as the design, filled with continuous real values: random noise drawn with `randn`, optionally shifted by a constant `bias`.

In [None]:
def new_latent_design(shape, bias=0, r=None):
    """Create a random latent design.

    Args:
        shape: shape of the array to generate; forwarded to `randn`.
        bias: value added to every entry of the generated noise.
        r: forwarded to `randn` as its `r` argument
            (presumably a random seed/state -- TODO confirm against `randn`).

    Returns:
        The biased noise as a float JAX array.
    """
    noise = randn(shape, r=r)
    # Shift in place, exactly like the noise array was created by `randn`.
    noise += bias
    return jnp.asarray(noise, dtype=float)

In [None]:
# Draw a reproducible 30x30 latent design and visualise it.
latent = new_latent_design(shape=(30, 30), r=42)
latent_image = plt.imshow(latent, cmap="Greys", vmin=-3, vmax=3)
plt.colorbar(latent_image)
plt.show()

## Transform
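The transform removes some of the noise from the latent design: it convolves the latent design with the normalized brush and then squashes the result into the range (-1, 1) with `tanh(beta * ...)`.  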

In [None]:
@jax.jit
def transform(latent, brush, beta=5.0):
    """Smooth a latent design with the brush and squash it with tanh.

    Args:
        latent: latent design array.
        brush: brush array; it is cast to float and divided by its own sum
            before being used as the convolution kernel.
        beta: factor applied to the convolved latent before the tanh.

    Returns:
        `tanh(beta * conv2d(latent, normalized_brush))`.
    """
    kernel = jnp.asarray(brush, dtype=float) / brush.sum()
    smoothed = conv2d(latent, kernel)
    return jnp.tanh(beta * smoothed)

In [None]:
# Transform the latent design with the demo brush and plot the result.
latent_t = transform(latent, my_brush)
transformed_image = plt.imshow(latent_t, vmin=-1, vmax=1, cmap="Greys")
plt.colorbar(transformed_image)
plt.show()

## Generator

In [None]:
@jax.jit
def argmax2d(arr2d):
    """Return the (row, column) index of the maximum of a 2D array.

    Args:
        arr2d: two-dimensional array.

    Returns:
        A tuple `(i, j)` of scalar index arrays such that `arr2d[i, j]` is the
        entry selected by `jnp.argmax` on the flattened array.
    """
    _, n_cols = arr2d.shape
    flat_index = jnp.argmax(arr2d.ravel())
    # `ravel` flattens row by row, so flat_index == i * n_cols + j: the flat
    # index has to be split by the number of COLUMNS, not the number of rows.
    return flat_index // n_cols, flat_index % n_cols

In [None]:
# Brush handed to every add_*/take_* call of the generator loop.
brush = my_brush
# Starting design, created by `new_design` for the shape of the transformed latent.
design = new_design(latent_t.shape)

In [None]:
# The sign of latent_t decides between solid (> 0) and void (otherwise); its
# magnitude ranks the candidate pixels. It does not change inside the loop,
# so it is computed once up front.
strength = jnp.abs(latent_t)

# Keep touching the design until no pixel is left unassigned.
while (design.design == UNASSIGNED).any():
    solid_touch_mask = design.solid_touches == TOUCH_VALID
    # Valid void touches must be read from the void-touch map.
    void_touch_mask = design.void_touches == TOUCH_VALID
    solid_free_mask = design.solid_touches == TOUCH_FREE
    free_mask = solid_free_mask | (design.void_touches == TOUCH_FREE)
    resolving_mask = (design.solid_touches == TOUCH_RESOLVING) | (
        design.void_touches == TOUCH_RESOLVING
    )

    if free_mask.any():
        # Free touches are taken first, solid ones before void ones.
        print("taking free pixels...")
        if solid_free_mask.any():
            design = take_free_solid_touches(design, brush)
        else:
            design = take_free_void_touches(design, brush)
    elif resolving_mask.any():
        # Among the resolving touches, pick the pixel with the largest
        # |latent_t| and touch it with the material given by the sign.
        print("resolving required pixels...")
        resolving_selector = jnp.where(resolving_mask, strength, 0)
        i, j = argmax2d(resolving_selector)
        v = latent_t[i, j]
        if v > 0:
            design = add_solid_touch(design, brush, (i, j))
        else:
            design = add_void_touch(design, brush, (i, j))
    else:
        # Ordinary step: find the strongest valid void candidate and the
        # strongest valid solid candidate, each from its own mask.
        print("touching...", end=" ")
        void_selector = jnp.where(void_touch_mask, strength, 0)
        solid_selector = jnp.where(solid_touch_mask, strength, 0)
        i_v, j_v = argmax2d(void_selector)
        v = latent_t[i_v, j_v]
        i_s, j_s = argmax2d(solid_selector)
        print((int(i_v), int(j_v)), (int(i_s), int(j_s)))
        s = latent_t[i_s, j_s]
        # NOTE(review): the candidates are ranked by magnitude, but this
        # comparison uses the signed latent values -- confirm that `v > s`
        # is the intended criterion for choosing void over solid.
        if v > s:
            design = add_void_touch(design, brush, (i_v, j_v))
        else:
            design = add_solid_touch(design, brush, (i_s, j_s))

In [None]:
# Display the design produced by the generator loop as the cell output.
design

In [None]:
# Recompute the transformed latent and plot it again for comparison with the design.
latent_t = transform(latent, my_brush)
transformed_image = plt.imshow(latent_t, vmin=-1, vmax=1, cmap="Greys")
plt.colorbar(transformed_image)
plt.show()