In [None]:
import torch
import numpy as np
import cv2
import ipywidgets as widgets
import asyncio
import matplotlib.pyplot as plt
import skimage

# Disable autograd for the whole kernel session (used as a function, not a
# context manager, so the setting is global). Nothing visible in this notebook
# calls .backward(); the simulation only runs forward convolutions, so building
# autograd graphs would just waste memory and time.
torch.set_grad_enabled(False)

In [None]:
# Simulation grid size in pixels (height, width). Larger presets that have been
# used: 1440 x 2560 (a full QHD screen) or a quarter of that, 360 x 640.
screen_h, screen_w = 200, 200

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
print("GPU is available" if _use_cuda else "GPU not available, using CPU instead")


def NormalizeEnergy(energy, width=3):
    '''
    Scale an energy tensor into [0, 1], saturating outliers.

    Values are clipped to [0, mean + width * std] and divided by that upper
    cutoff, so anything more than `width` standard deviations above the mean
    maps to 1 and anything negative maps to 0. Statistics are taken over all
    elements of `energy`, regardless of its shape.

    Args:
        energy: floating-point tensor of any shape.
        width: number of standard deviations above the mean used as cutoff.

    Returns:
        Tensor with the same shape and dtype as `energy`, values in [0, 1].
        When the cutoff is not a positive finite number (e.g. an all-zero or
        empty input) a zero tensor is returned instead of NaNs.
    '''
    flat = energy.flatten()
    mean = torch.mean(flat)
    # torch.std applies Bessel's correction and returns NaN for a single
    # element; treat the spread of one sample as zero instead.
    std = torch.std(flat) if flat.numel() > 1 else torch.zeros_like(mean)
    h_cutoff = float(mean + std * width)
    # `not x > 0` also rejects NaN (e.g. mean of an empty tensor).
    if not h_cutoff > 0:
        return torch.zeros_like(energy)
    return torch.clip(energy, 0, h_cutoff) / h_cutoff

def NormalizeSine(arr):
    '''Rescale values from the sine range [-1, 1] onto [0, 1].'''
    shifted = arr + 1.0
    return shifted / 2.0
    
def Colorize(arr, lightness=60, chroma=128):
    '''
    Map angles (radians) to colours on a hue circle.

    Each value is treated as an angle in the CIELAB a*/b* plane at a fixed
    lightness and chroma (a* = chroma*cos, b* = chroma*sin) and the result is
    converted to RGB with skimage.color.lab2rgb. Because cos/sin are periodic,
    angles that differ by 2*pi get the same colour.

    Args:
        arr: array-like of angles with shape (h, w) or (h, w, 1). Torch
            tensors should be moved to the CPU / converted to NumPy first.
        lightness: CIELAB L* used for every pixel (default 60).
        chroma: radius in the a*/b* plane (default 128). Large values at
            this lightness likely fall outside the sRGB gamut and get clipped
            by the conversion; lower values give more uniform hues.

    Returns:
        Float ndarray of shape (h, w, 3) in R, G, B channel order, as
        returned by skimage.color.lab2rgb. Note cv2.imshow expects B, G, R,
        so reverse the last axis before displaying with OpenCV.
    '''
    arr = np.asarray(arr)
    # Drop a trailing singleton channel so (h, w, 1) input yields (h, w, 3)
    # rather than (h, w, 1, 3).
    if arr.ndim == 3 and arr.shape[-1] == 1:
        arr = arr[..., 0]
    # Build the CIELAB image: constant L*, angle encoded in (a*, b*).
    lab = np.stack([lightness * np.ones_like(arr), chroma * np.cos(arr), chroma * np.sin(arr)], axis=-1)
    rgb = skimage.color.lab2rgb(lab)
    return rgb
    

In [None]:
# Build a bank of fixed circular-padded convolutions, keyed by kernel side
# length (3, 9, 15, 21, 27). Every kernel holds positive weights summing to 1
# around the centre and -1 at the centre, so convolving a field with it gives
# "weighted neighbourhood mean minus the pixel's own value" at that scale.
#   * first kernel (3x3): the full 8-neighbourhood, each weight 1/8
#   * later kernels: a ring of width `grow` surrounding the previous
#     kernel's footprint (which is zeroed), weights 1/ring_count
convs = {}

size = 3   # side length of the kernel built in the current iteration
grow = 3   # cells added on every side for each successive kernel
for i in range(5):
    conv = torch.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=size, bias=False,
                           padding='same', padding_mode='circular', device=device)
    if i > 0:
        # Ring: zero the previous footprint, pad it with ones `grow` cells wide,
        # then normalise the ones to sum to 1.
        k = torch.nn.functional.pad(torch.zeros((last_size, last_size)).to(device),
                                    (grow,) * 4, value=1.0)
        k = k / torch.sum(k)
    else:
        # Full 3x3 block; divide by (count - 1) = 8 because the centre cell is
        # overwritten with -1 just below.
        k = torch.ones((size, size)).to(device)
        k = k / (torch.sum(k) - 1.0)

    k[k.shape[0] // 2, k.shape[1] // 2] = -1
    # Sanity print: size, kernel shape, conv weight shape, kernel sum (~0).
    print(size, k.shape, conv.weight.shape, k.sum())
    conv.weight = torch.nn.Parameter(k[None, None, ...])
    convs[k.shape[0]] = conv
    last_size, size = size, size + 2 * grow
    print(k)

In [None]:
# Initial state: an independent random phase per pixel, uniform in [0, 10).
phase = (torch.rand(screen_h, screen_w) * 10).to(device)
frequency = 0.1   # phase advance (radians) per simulation step
t = 0             # simulation step counter

# One random coupling constant per kernel in `convs`, uniform in [-0.05, 0.05).
# Kept on the CPU; indexing yields 0-dim tensors that broadcast like scalars.
coupling_constants = 0.1 * (torch.rand(len(convs)) - 0.5)

# Run the oscillator simulation until 'q' or Esc is pressed in one of the
# OpenCV windows (or the kernel is interrupted); windows are always closed.
#
# Each step, every pixel's phase is nudged by -sum_k c_k * sin(K_k * phase),
# where K_k are the kernels in `convs` ("neighbourhood mean minus centre").
# NOTE(review): this is sin(mean neighbour phase - own phase) rather than the
# Kuramoto-style mean of sin(phase_j - phase_i) — confirm that is intended.
DISPLAY_EVERY = 10   # steps between redraws / progress prints
try:
    while True:
        # Calculate the oscillators' positions.
        img_out = torch.sin(frequency*t + phase)
        t += 1

        # Accumulate the coupling term from every kernel scale.
        error_sum = torch.zeros_like(img_out)
        for k_idx, k_key in enumerate(convs):
            # Conv2d wants (N, C, H, W); take [0, 0] back to (H, W) without
            # squeezing away genuine size-1 grid dimensions.
            error = torch.sin(convs[k_key](phase[None, None, ...]))[0, 0]
            error_sum -= error*coupling_constants[k_idx]
        phase += error_sum

        if t % DISPLAY_EVERY == 0:
            # One progress line per redraw (printing per kernel per step
            # flooded the output and forced a device sync every iteration).
            print(t, torch.abs(error_sum).mean())
            img_out_waves = Colorize(img_out.cpu().detach().numpy())
            img_out_phases = Colorize(phase.cpu().detach().numpy())
            # Colorize returns RGB, but cv2.imshow interprets channels as BGR.
            cv2.imshow('Waves', np.ascontiguousarray(img_out_waves[..., ::-1]))
            cv2.imshow('Phases', np.ascontiguousarray(img_out_phases[..., ::-1]))
            key = cv2.waitKey(1) & 0xFF
            if key in (ord('q'), 27):   # 27 == Esc
                break
finally:
    cv2.destroyAllWindows()
        
        
