Train a neural network to fit an earth texture on a unit sphere and export GLSL code.

In [65]:
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch import nn, tensor
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets
from torchvision.transforms import transforms, ToTensor
import torchvision.utils as vutils

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Load the earth texture. Average two UV images.

In [66]:
import requests

def load_image(url, size=(512, 256), timeout=30):
    """Download an image and return it as a float32 RGB array scaled to [0, 1].

    Parameters
    ----------
    url : str
        Address of the image to download.
    size : tuple of int, optional
        ``(width, height)`` passed to ``Image.resize``. Defaults to
        ``(512, 256)``.
    timeout : float, optional
        Seconds to wait for the server before giving up, so that a stalled
        connection cannot hang the notebook indefinitely.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(height, width, 3)`` and dtype ``float32``.

    Raises
    ------
    requests.HTTPError
        If the server answers with a 4xx/5xx status code.
    """
    import io  # local import: only needed to decode the response in memory

    req = requests.get(url, headers={
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36',
        'accept': '*/*',
        'accept-language': 'en-US,en;q=0.5',
        'cache-control': 'no-cache',
        'pragma': 'no-cache'
    }, timeout=timeout)
    print(req.status_code)
    # Fail here with a clear HTTP error rather than later, when PIL would be
    # asked to decode an error page.
    req.raise_for_status()
    # Decode straight from memory: no scratch file is left in the working
    # directory and the decoder's handle is released by the context manager.
    with Image.open(io.BytesIO(req.content)) as raw:
        image = raw.convert("RGB").resize(size)
    return (np.asarray(image) / 255.0).astype(np.float32)

# Download two Earth maps and blend them with equal weight.
earth_img_1 = load_image(
    "https://www.solarsystemscope.com/textures/download/2k_earth_daymap.jpg"
)
earth_img_2 = load_image(
    "https://eoimages.gsfc.nasa.gov/images/imagerecords/73000/73801/world.topo.bathy.200409.3x5400x2700.jpg"
)
earth_img = (earth_img_1 + earth_img_2) * 0.5

# Preview the blended texture that the network will be fitted to.
plt.imshow(earth_img)
plt.show()

In [67]:
class PixelDataSet(Dataset):
    """Per-pixel dataset pairing image values with points on the unit sphere.

    The image is read as a longitude/latitude map: column centres are spread
    over longitudes in (-pi, pi), and row centres run from latitude just
    below +pi/2 (first row) to just above -pi/2 (last row).

    Each item is the list ``[pixel, coord, weight]`` where

    * ``pixel``  is ``image[row][col]``,
    * ``coord``  is the float32 ``(x, y, z)`` of the pixel centre on the unit
      sphere,
    * ``weight`` is float32 ``cos(latitude)`` repeated three times, i.e. the
      areal element of that pixel, with one copy per coordinate so it can be
      multiplied element-wise with a 3-component error.

    Parameters
    ----------
    image : numpy.ndarray
        Array whose first two axes are ``(height, width)``.
    """

    def __init__(self, image):
        self.shape = image.shape[:2]
        self.image = image

        # Generate points on a unit sphere, one per pixel centre.
        height, width = self.shape
        lon = 2.0*np.pi*((np.arange(width)+0.5)/width-0.5)
        lat = -np.pi*((np.arange(height)+0.5)/height-0.5)
        cos_lat = np.cos(lat)
        x = np.outer(cos_lat, np.cos(lon))
        y = np.outer(cos_lat, np.sin(lon))
        z = np.outer(np.sin(lat), np.ones(width))
        # Areal element: rows near the poles cover less of the sphere.
        area = np.outer(cos_lat, np.ones(width))

        self.coords = np.stack([x, y, z], axis=-1).astype(np.float32)
        self.weights = np.repeat(
            area[:, :, np.newaxis], 3, axis=-1).astype(np.float32)

    def __len__(self):
        """Return the number of pixels as a plain Python ``int``."""
        return int(self.shape[0] * self.shape[1])

    def __getitem__(self, i):
        """Return ``[pixel, coord, weight]`` for the flat row-major index ``i``."""
        row, col = i // self.shape[1], i % self.shape[1]
        return [self.image[row][col],
                self.coords[row][col],
                self.weights[row][col]]

# Smoke test: draw a single shuffled mini-batch and report the dtype and
# shape of each of its three tensors.
pixel, coord, weight = next(iter(
    DataLoader(PixelDataSet(earth_img), batch_size=16, shuffle=True)
))
print(pixel.dtype, coord.dtype, weight.dtype)
print(pixel.shape, coord.shape, weight.shape)

Define model.

Hidden layers: sine activation

Output layer: sigmoid activation

In [68]:
class Siren(nn.Module):
    """Parameter-free activation that applies ``sin`` element-wise."""

    def forward(self, x):
        """Return the element-wise sine of ``x``."""
        return x.sin()

class Model(nn.Module):
    """Fully connected network from a 3-vector to a 3-vector in (0, 1).

    ``hidden_layers`` is a list of hidden-layer widths. Each hidden linear
    layer is followed by a sine activation; the final linear layer, which has
    three outputs, is followed by a sigmoid.
    """

    def __init__(self, hidden_layers):
        super().__init__()
        widths = [3] + hidden_layers + [3]

        modules = []
        # Hidden stack: one Linear -> Siren pair per consecutive pair of
        # widths, leaving the last pair for the output head.
        for n_in, n_out in zip(widths[:-2], widths[1:-1]):
            modules.append(nn.Linear(n_in, n_out))
            modules.append(Siren())
        # Output head: squash to (0, 1).
        modules.append(nn.Linear(widths[-2], widths[-1]))
        modules.append(nn.Sigmoid())

        self.main = nn.Sequential(*modules)

    def forward(self, x):
        """Run ``x`` through the whole layer stack."""
        return self.main(x)


# Every hidden width is a multiple of 4: the GLSL export cell at the end of
# this notebook asserts that, because it emits the layers in groups of four
# (vec4 / mat4 operands).
model = Model([12, 12, 12, 8]).to(device)

Loss function

In [69]:
def lossFun(output, expected_output, weight):
    """Weighted mean squared error.

    Computes ``sum(weight * (expected_output - output)**2) / sum(weight)``
    over all elements and returns it as a scalar tensor.
    """
    squared_error = (expected_output - output).pow(2)
    weighted_total = (weight * squared_error).sum()
    return weighted_total / weight.sum()

Training - stochastic gradient descent with the Adam optimizer

In [70]:
# Materialise one shuffled pass over the dataset up front; the same list of
# mini-batches is then replayed, in the same order, on every epoch.
dataloader = list(
    DataLoader(PixelDataSet(earth_img), batch_size=64, shuffle=True)
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.999))

print("ADAM")
count = 0  # optimizer steps taken so far, across all epochs
for epoch in range(1, 11):
    print("Epoch", epoch)
    for pixel, coord, weight in dataloader:
        pixel, coord, weight = (
            pixel.to(device), coord.to(device), weight.to(device)
        )

        optimizer.zero_grad()
        output = model(coord)
        loss = lossFun(output, pixel, weight)
        loss.backward()
        optimizer.step()

        count += 1
        # Report the current mini-batch loss once every 1024 steps.
        if count % 1024 != 0:
            continue
        print("Iteration {} - loss = {}".format(count, loss.item()))

Training - L-BFGS

In [71]:
# A single batch holding every pixel, unshuffled, so each closure call
# evaluates the loss on the whole image.
dataloader = list(DataLoader(
    PixelDataSet(earth_img),
    batch_size=int(np.prod(earth_img.shape[:2])),
    shuffle=False,
))

optimizer = torch.optim.LBFGS(model.parameters(), max_iter=20)

print("L-BFGS")

count = 0    # closure evaluations in total
count_t = 0  # closure evaluations inside the current optimizer.step() call


def closure():
    """Recompute the loss and its gradients for ``optimizer.step``.

    Also counts evaluations and, on the first evaluation of every tenth
    outer iteration, prints the running evaluation count and the loss.
    """
    global count, count_t
    optimizer.zero_grad()
    for pixel, coord, weight in dataloader:
        pixel, coord, weight = (
            t.to(device) for t in (pixel, coord, weight)
        )
        output = model(coord)
        loss = lossFun(output, pixel, weight)
        loss.backward()
    count += 1
    count_t += 1
    first_eval_of_step = count_t == 1
    if first_eval_of_step and iteration % 10 == 0:
        print("Evaluation {} - loss = {}".format(count, loss.item()))
    return loss


# 100 outer steps; `iteration` is read by `closure` to throttle its logging.
for iteration in range(1, 101):
    count_t = 0
    optimizer.step(closure)

Test - generate an image

In [72]:
# Evaluate the trained network at the pixel centres of a 128-row by
# 256-column longitude/latitude grid and display the predicted colors.
h, w = 128, 256
u = 2.0*np.pi*((np.arange(w)+0.5)/w-0.5)  # one value per column
v = -np.pi*((np.arange(h)+0.5)/h-0.5)     # one value per row

cos_v = np.cos(v)
x = np.outer(cos_v, np.cos(u))
y = np.outer(cos_v, np.sin(u))
z = np.outer(np.sin(v), np.ones(len(u)))

# Stack into an (h, w, 3) array of unit-sphere points.
coords = np.stack([x, y, z], axis=-1)
coords = torch.tensor(coords, dtype=torch.float).to(device)

colors = model(coords).detach().cpu().numpy()
plt.imshow(colors)
plt.show()

Export GLSL code

In [73]:
def num2str(x, d=3):
    """Format ``x`` as the shortest equivalent fixed-point literal.

    ``x`` is rounded to ``d`` decimal places. Trailing zeros of the
    fractional part, a dangling decimal point, and a lone leading zero before
    the decimal point are then dropped, e.g. ``0.50 -> '.5'``,
    ``-0.25 -> '-.25'``, ``10.0 -> '10'``. Anything that rounds to zero,
    including negative zero, becomes ``'0'``.

    Digits of the integer part are never removed. Whole numbers are emitted
    without a decimal point.

    Parameters
    ----------
    x : float
        Value to format.
    d : int, optional
        Number of decimal places to round to.

    Returns
    -------
    str
    """
    s = "{:.{prec}f}".format(x, prec=d)

    # Only strip zeros that sit after a decimal point; an unconditional strip
    # would turn "10.000" into "1".
    if '.' in s:
        s = s.rstrip('0').rstrip('.')

    sign = ''
    if s.startswith('-'):
        sign, s = '-', s[1:]

    # "0.5" -> ".5"; a bare "0" is left alone and handled just below.
    if s.startswith('0.'):
        s = s[1:]

    if s in ('', '0'):
        return '0'  # also collapses "-0" to "0"
    return sign + s

def vec2str(v, d=3):
    """Format the sequence ``v`` as a GLSL ``vecN(...)`` constructor.

    ``N`` is ``len(v)``; each component is rendered by ``num2str`` with ``d``
    decimal places.
    """
    components = [num2str(component, d=d) for component in v]
    body = ','.join(components)
    return f'vec{len(v)}(' + body + ')'

def mat2str(m, d=3):
    """Format the array ``m`` as a GLSL ``mat4(...)`` constructor.

    The entries are written in the order produced by ``m.flatten()``, each
    rendered by ``num2str`` with ``d`` decimal places.
    """
    entries = []
    for value in m.flatten():
        entries.append(num2str(value, d=d))
    return 'mat4(' + ','.join(entries) + ')'

# Number of decimal places kept for every weight and bias written to the GLSL.
digits = 2

def print_input_layer(l: int, weight, bias):
    """Print the GLSL for the first layer, which reads the 3-D point ``p``.

    Output units are handled four at a time. Each group becomes one line of
    the form ``vec4 v{l}{k} = sin(A*p.x+B*p.y+C*p.z+D);`` where ``A``, ``B``
    and ``C`` hold the group's weights for the matching component of ``p``
    and ``D`` holds its biases.
    """
    assert weight.shape[1] == 3
    for start in range(0, len(bias), 4):
        stop = start + 4
        # After transposing, row c holds the group's coefficients for the
        # c-th input component.
        per_axis = weight[start:stop].T
        parts = []
        for axis, suffix in enumerate(('*p.x', '*p.y', '*p.z')):
            parts.append(vec2str(per_axis[axis], digits) + suffix)
        parts.append(vec2str(bias[start:stop], digits))
        expr = '+'.join(parts)
        print(f"  vec4 v{l}{start//4} = sin({expr});")

def print_hidden_layer(l: int, weight, bias):
    """Print the GLSL for a hidden layer fed by the previous layer's vec4s.

    Output units are handled four at a time. Each group becomes one line
    ``vec4 v{l}{k} = sin(M0*v{l-1}0+M1*v{l-1}1+...+B);`` with one ``mat4``
    per group of four inputs and ``B`` the group's biases.
    """
    for out_start in range(0, len(bias), 4):
        out_stop = out_start + 4
        # Transposed so that rows index inputs and columns index this
        # group's outputs.
        group = weight[out_start:out_stop].T
        terms = [
            mat2str(group[in_start:in_start+4], digits) + f"*v{l-1}{in_start//4}"
            for in_start in range(0, len(group), 4)
        ]
        terms.append(vec2str(bias[out_start:out_stop], digits))
        expr = '+'.join(terms)
        print(f"  vec4 v{l}{out_start//4} = sin({expr});")

def print_output_layer(l: int, weight, bias):
    """Print the GLSL for the 3-unit output layer, one ``float`` per unit.

    Each line has the form
    ``float v{l}{i} = sigmoid(dot(W0,v{l-1}0)+dot(W1,v{l-1}1)+...+b);``.
    The emitted code calls ``sigmoid``; this function does not emit a
    definition for it.
    """
    assert weight.shape[0] == 3 and len(bias) == 3
    for channel in range(3):
        row = weight[channel]
        terms = []
        for start in range(0, len(row), 4):
            chunk = vec2str(row[start:start+4], digits)
            terms.append(f"dot({chunk},v{l-1}{start//4})")
        terms.append(num2str(bias[channel], digits))
        expr = '+'.join(terms)
        print(f"  float v{l}{channel} = sigmoid({expr});")

# Gather (weight, bias) NumPy pairs from every module that has a weight,
# in network order; activation modules are skipped.
layers = []
for layer in model.main:
    if not hasattr(layer, 'weight'):
        continue
    assert hasattr(layer, 'bias')
    layers.append((
        layer.weight.detach().cpu().numpy(),
        layer.bias.detach().cpu().numpy(),
    ))

# Emit one GLSL snippet per layer, picking the printer from the weight shape:
# 3 inputs -> input layer, 3 outputs -> output layer, anything else -> hidden.
for i, (weight, bias) in enumerate(layers):
    assert weight.shape[0] == len(bias)
    assert len(bias) == 3 or len(bias) % 4 == 0
    print('  //', weight.shape, bias.shape)
    n_out, n_in = weight.shape
    if n_in == 3:
        print_input_layer(i, weight, bias)
    elif n_out == 3:
        print_output_layer(i, weight, bias)
    else:
        print_hidden_layer(i, weight, bias)

# `i` is still the index of the last layer, i.e. the output layer.
print(f"  return vec3(v{i}0, v{i}1, v{i}2);")
