In [1]:
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as VF

from jupylet.app import App
from jupylet.shadertoy import Shadertoy

from sdf import *
from render import *
from models import *

In [2]:
# 320x240 window; the shader cell reads app.width / app.height for its uv maths.
app = App(
    width=320,
    height=240,
)

In [3]:
# Rebuild the network and load the trained weights.
# map_location="cpu" lets a snapshot saved from a GPU run load on a CPU-only
# machine; the GLSL export below only reads the weights as CPU floats.
# NOTE: torch.load unpickles the file -- only load snapshots you trust.
model = Mat4Net()
model.load_state_dict(torch.load("model-snapshot.pt", map_location="cpu"))

<All keys matched successfully>

In [4]:
def tensor_to_str(t: torch.Tensor) -> str:
    """Render every element of ``t`` as a Python float, comma-separated.

    Elements are taken in flattened (row-major) order, and the result is used as the
    argument list of a GLSL ``mat4(...)`` / ``vec4(...)`` constructor.
    """
    values = t.detach().flatten().tolist()
    return ", ".join(map(str, map(float, values)))

def generate_glsl(model: "Mat4Net") -> str:
    """Emit a GLSL function ``float nn_dist(in vec3 pos)`` that evaluates ``model``.

    Each weight of layer ``i`` becomes ``const mat4 weight_i_j`` and each bias becomes
    ``const vec4 bias_i_j``. The function body then runs the layers in order on vec4
    registers ``in_k`` / ``out_k`` and returns ``out_0.x``.

    Args:
        model: network exposing ``layers``. Each layer must provide ``n_in``,
            ``n_out``, ``weights`` (16 values each, presumably 4x4), ``biases``
            (4 values each) and ``act``.

    Returns:
        GLSL source text that starts and ends with a newline.

    Raises:
        ValueError: if ``model.layers`` is empty (the generated code would not compile).
    """
    layers = list(model.layers)
    if not layers:
        raise ValueError("model has no layers; nn_dist() would not compile")

    lines = []
    max_features = 0  # enough in_/out_ registers for the widest layer
    for i, layer in enumerate(layers):
        max_features = max(max_features, layer.n_in, layer.n_out)
        # GLSL mat4(...) fills column by column, so with W flattened row-major,
        # `v * M` equals F.linear(v, W) (see the glm / torch cells at the end).
        for j, w in enumerate(layer.weights):
            lines.append(f"const mat4 weight_{i}_{j} = mat4({tensor_to_str(w)});")
        for j, b in enumerate(layer.biases):
            lines.append(f"const vec4 bias_{i}_{j} = vec4({tensor_to_str(b)});")

    lines.append("")
    lines.append("vec4 " + ", ".join(f"in_{k}" for k in range(max_features)) + ";")
    lines.append("vec4 " + ", ".join(f"out_{k}" for k in range(max_features)) + ";")

    for i, layer in enumerate(layers):
        lines += ["", f"// layer {i}", ""]
        if i == 0:
            # NOTE(review): only in_0 is initialised here, so this assumes
            # layers[0].n_in == 1 and that the model sees (x, y, z, x). Confirm
            # against Mat4Net.
            lines.append("in_0 = vec4(pos, pos.x);")
        else:
            lines += [f"in_{j} = out_{j};" for j in range(layer.n_in)]

        lines.append(" = ".join(f"out_{j}" for j in range(layer.n_out)) + " = vec4(0);")

        # Block j connects in_{j % n_in} to out_{j % n_out}. This assumes
        # layer.weights has max(n_in, n_out) entries (Mat4Net layout; confirm in models.py).
        for j in range(max(layer.n_in, layer.n_out)):
            lines.append(f"out_{j % layer.n_out} += in_{j % layer.n_in} * weight_{i}_{j};")

        for j in range(layer.n_out):
            if layer.act is None:
                lines.append(f"out_{j} += bias_{i}_{j};")
            else:
                # Any non-None activation is emitted as sin().
                # NOTE(review): assumes Mat4Net's activation is sin.
                lines.append(f"out_{j} = sin(out_{j} + bias_{i}_{j});")

    lines += ["", "return out_0.x;"]

    # Indent every non-blank line exactly once. The old template indented the
    # first line a second time and left whitespace on blank lines.
    body = "\n".join(f"    {line}" if line else "" for line in lines)
    return f"\nfloat nn_dist(in vec3 pos) {{\n{body}\n}}\n"

# GLSL source of nn_dist() for the loaded weights; spliced into the shader cell below.
code = generate_glsl(model=model)
print(code)


float nn_dist(in vec3 pos) {
        const mat4 weight_0_0 = mat4(-1.2277699708938599, 1.519213318824768, -1.1872285604476929, 1.2413098812103271, -1.108780860900879, -0.8853718042373657, 0.7356535792350769, 1.1389192342758179, 0.11244277656078339, 2.498598575592041, 0.0949987843632698, -0.09940953552722931, -0.232490673661232, 0.9679884314537048, 1.6390388011932373, 0.2377111315727234);
    const mat4 weight_0_1 = mat4(-1.5602127313613892, 1.812886357307434, -1.9229179620742798, 1.599096655845642, -2.0006299018859863, -0.2122923731803894, 0.04328341782093048, -0.0741896703839302, 0.1166880875825882, 1.475764513015747, 1.2103097438812256, -0.11480896174907684, 0.9728386998176575, 0.30226218700408936, 0.020459486171603203, 1.6759905815124512);
    const mat4 weight_0_2 = mat4(-0.6555724740028381, 1.5137975215911865, -0.17084719240665436, 2.217674493789673, 0.14028114080429077, 1.8049113750457764, 1.4405932426452637, 0.8297991156578064, 1.8818145990371704, -1.5221549272537231, 0.1980926

In [5]:
# Build the raymarching shader.
# - %(code)s splices in the nn_dist() GLSL produced by generate_glsl.
# - %(w)s / %(h)s bake the current window size into the uv normalisation.
# The template is %-formatted, so any literal '%' in the GLSL would need to be '%%'.
# NOTE(review): inside raymarch(), `visible` is computed but never read, and
# rotate() is never called. Either remove them or finish the intended
# bounding-sphere early-out.
st = Shadertoy("""
    %(code)s

    #line 3
    float scene_dist(in vec3 pos) {
        //float d = length(pos) - 1.;
        float d = nn_dist(pos);
        return d;
    }
    
    vec3 scene_norm(in vec3 pos) {
        const vec2 e = vec2(0.0001, 0);
        return normalize(vec3(
            scene_dist(pos + e.xyy) - scene_dist(pos - e.xyy),
            scene_dist(pos + e.yxy) - scene_dist(pos - e.yxy),
            scene_dist(pos + e.yyx) - scene_dist(pos - e.yyx)
        ));
    }
    
    vec2 rotate(in vec2 x, in float a) {
        float si = sin(a), co = cos(a);
        return vec2(x.x * co - x.y * si, x.x * si + x.y * co);
    }
    
    vec3 rotate_y(in vec3 x, in float a) {
        float si = sin(a), co = cos(a);
        return vec3(x.x * co - x.z * si, x.y, x.x * si + x.z * co);
    }
    
    vec3 raymarch(in vec3 pos, in vec3 dir) {
       vec3 p = pos + vec3(0, 0, -1.+sin(iTime*.91));
       for (int i=0; i<100; ++i) {
           bool visible = length(p) <= 2.;
           vec3 p_trans = rotate_y(p, iTime);
           float d = scene_dist(p_trans);
           if (abs(d) < 0.001) {
               return scene_norm(p_trans) * .5 + .5;
           }
           p += d * dir;
       }
       return vec3(0);
    }
    
    void mainImage(out vec4 fragColor, in vec2 fragCoord)
    {
        vec2 uv = (fragCoord - vec2(%(w)s, %(h)s) * .5) / %(h)s;
        
        vec3 dir = normalize(vec3(uv, 1));
        vec3 col = raymarch(vec3(0, .5, -2), dir);
        //vec3 col = vec3(nn_dist(vec3(uv, 0)))-3.;
        fragColor = vec4(col,1.0);
    }

""" % {"w": app.width, "h": app.height, "code": code})

@app.event
def render(ct, dt):
    """Draw callback registered through ``app.event``.

    Clears the window, then draws the shader ``st``. ``ct`` and ``dt`` are passed to
    ``st.draw`` unchanged (presumably current time and frame delta; confirm in the
    jupylet docs).
    """
    window = app.window
    window.clear()
    st.draw(ct, dt)

In [6]:
# Start the jupylet app; the rendered frames appear as this cell's output widget.
app.run()

Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x08\x06\x0…

In [7]:
import glm

# Sanity check for the mat4 convention used by generate_glsl. glm, like GLSL, fills a
# mat4 column by column, so `v * M` dots v with each group of four constructor args.
# With args 1..16 the result is (30, 70, 110, 150). That matches torch's F.linear(v, W)
# for W = [[1..4], [5..8], [9..12], [13..16]], which is checked in the next cell.
m_glm = glm.mat4(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
v_times_m = glm.vec4(1, 2, 3, 4) * m_glm
assert v_times_m == glm.vec4(30, 70, 110, 150), v_times_m
v_times_m

vec4( 30, 70, 110, 150 )

In [8]:
# The same 4x4 matrix as a torch Linear weight (rows = output features).
# F.linear(v, W) = v @ W.T, which must match glm's `vec4 * mat4` result from the
# previous cell, (30, 70, 110, 150). torch.flatten is row-major and mat4(...) fills
# column-major, so mat4(flatten(W)) puts W's rows into the columns. This is the
# convention generate_glsl relies on.
w = torch.tensor(
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
    dtype=torch.float32,
)
linear_out = F.linear(torch.tensor([[1.0, 2.0, 3.0, 4.0]]), w)
assert linear_out.tolist() == [[30.0, 70.0, 110.0, 150.0]], linear_out
torch.flatten(w)

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16.])