In [None]:
# import plotly.graph_objects as go
# from ipywidgets import interact, IntSlider

# plt.style.use("dark_background")

# @interact(t=IntSlider(min=0, max=len(zs)-1, step=1, value=0))
# def update(t):
#     plt.figure(figsize = (10,10))
#     plt.imshow(zs[t], extent=[-1, 1, -1, 1])


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import imageio
import ipywidgets

def draw_gif(zs, path='movie.gif', fps=6, stretch=1):
    """Render a sequence of 2-D arrays as an animated GIF and return it as a widget.

    Every frame is min-max rescaled on its own to 0-255 (so brightness is
    relative within a frame, not comparable across frames), converted to
    ``uint8``, optionally enlarged by pixel repetition, written to ``path``
    with ``imageio.mimsave`` and read back so it can be displayed inline.

    Parameters
    ----------
    zs : iterable of 2-D array-like
        The frames, in playback order.
    path : str, optional
        File the GIF is written to; overwritten on every call. Default ``'movie.gif'``.
    fps : int or float, optional
        Playback rate forwarded to ``imageio.mimsave``. Default 6.
    stretch : int, optional
        Integer upscaling factor applied to both axes. Default 1 (no upscaling).

    Returns
    -------
    ipywidgets.Image
        Widget holding the encoded GIF bytes.

    Raises
    ------
    ValueError
        If ``zs`` contains no frames or ``stretch`` is smaller than 1.
    """
    if stretch < 1:
        raise ValueError("stretch must be a positive integer")

    frames = []
    for image in zs:
        image = np.asarray(image, dtype=float)
        low, high = image.min(), image.max()
        if high > low:
            scaled = (image - low) / (high - low) * 255
        else:
            # A constant frame has no range to stretch; without this guard the
            # 0/0 division yields NaN, whose uint8 cast is undefined.
            scaled = np.zeros_like(image)
        frames.append(scaled.astype(np.uint8))

    if not frames:
        raise ValueError("zs must contain at least one frame")

    if stretch > 1:
        # Nearest-neighbour upscaling: repeat every pixel `stretch` times per axis.
        frames = [np.repeat(np.repeat(frame, stretch, axis=0), stretch, axis=1) for frame in frames]

    imageio.mimsave(path, frames, format='gif', fps=fps)

    # ipywidgets.Image takes the encoded file contents as bytes.
    with open(path, 'rb') as file:
        gif_bytes = file.read()

    return ipywidgets.Image(
        value=gif_bytes,
        format='gif',
    )

In [None]:
size = 150
SEED = 0  # fixed seed so Restart & Run All reproduces the same initial field and GIF
rng = np.random.default_rng(SEED)

# 2-D coordinate grid over [-1, 1] x [-1, 1]. With np.meshgrid's default 'xy'
# indexing, X varies along columns (axis 1) and Y along rows (axis 0).
x_range = np.linspace(-1, 1, size)
y_range = np.linspace(-1, 1, size)
X, Y = np.meshgrid(x_range, y_range)

# Initial field: independent uniform noise in [0, 1) at every grid point.
Z = rng.random((size, size))

In [None]:
alpha = 4
# Wavelets as (a, b, deviation): the carrier is cos(a * X + b * Y) and the
# Gaussian envelope has standard deviation `deviation`, all in the coordinate
# units of the X, Y grid.
wavelet_list = [
    # (0, 30, 0.1),
    # (15, 26, 0.1),
    # (-15, 26, 0.1),
    (0, 80, 0.1),
    (80, 0, 0.1),
]

masks = []
for a, b, deviation in wavelet_list:
    cos_mask = np.cos(X * a + Y * b)
    gauss_mask = np.exp(-(X ** 2 + Y ** 2) / (2 * deviation ** 2))
    # Normalise the envelope to sum to 1 before modulating it with the carrier.
    regularized_gauss_mask = gauss_mask / np.sum(gauss_mask)
    masks.append(cos_mask * regularized_gauss_mask)

# The response of each mask at every periodic shift of the grid is a circular
# cross-correlation with Z, which the Fourier domain gives in O(K * N^2 log N)
# time and O(K * N^2) memory. Materialising a (K, N, N, N, N) tensor of
# np.roll-ed masks instead needs ~8 GB of float64 for K=2, N=150.
# The masks never change, so their conjugate spectra are computed once here.
# NOTE(review): the masks peak at coordinate (0, 0), i.e. near array index
# (N/2, N/2), so the force at pixel (i, j) is computed from the neighbourhood
# of pixel ((i + N/2) % N, (j + N/2) % N). If the response should be centred
# on the pixel it updates, apply np.fft.ifftshift(..., axes=(-2, -1)) to the
# stacked masks before the FFT -- TODO confirm intent.
masks_fft_conj = np.conj(np.fft.rfft2(np.stack(masks)))

def update_z(Z, alpha):
    """Return Z advanced by one step of the wavelet-driven update.

    Each mask's response at every pixel is its circular (wrap-around)
    cross-correlation with Z, read from the module-level ``masks_fft_conj``.
    The K responses are combined as ``|sum_k forces_k ** 5| ** (1/5)`` (an odd
    power, so opposite-sign responses can cancel), rescaled to unit standard
    deviation and added to Z with weight ``alpha``.

    Parameters
    ----------
    Z : np.ndarray
        Current 2-D field; must have the same shape as the masks.
    alpha : float
        Step size applied to the normalised update.

    Returns
    -------
    np.ndarray
        A new array; Z itself is not modified. If the update has zero spread it
        cannot be normalised, and Z is returned unchanged (as a float copy).
    """
    # forces[k, i, j] = sum_{p,q} masks[k][(p - i) % N, (q - j) % N] * Z[p, q]
    forces = np.fft.irfft2(masks_fft_conj * np.fft.rfft2(Z), s=Z.shape)
    combined = np.sum(forces ** 5, axis=0)
    # Magnitude only; multiplying by np.sign(combined) would give the signed fifth root.
    updates = np.abs(combined) ** (1 / 5)

    # Normalise the update to unit standard deviation (guarding the 0/0 case).
    spread = np.std(updates)
    if spread == 0:
        return np.array(Z, dtype=float)
    updates = updates / spread
    return Z + updates * alpha

N_STEPS = 20  # number of update steps; the GIF shows N_STEPS + 1 frames, including the initial field

# Evolve the field step by step, keeping every intermediate state as a frame.
zs = [Z]
for _ in range(N_STEPS):
    zs.append(update_z(zs[-1], alpha))

draw_gif(zs)