In [1]:
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as VF

from jupylet.app import App
from jupylet.shadertoy import Shadertoy

from sdf import *
from render import *
from models import *

In [2]:
# 320x240 window; the Shadertoy cell reads app.width / app.height for its uv mapping.
app = App(height=240, width=320)

In [3]:
model = Mat4Net()
# map_location="cpu": the weights are only exported as text below, and this lets a
# snapshot saved from GPU tensors load on a machine without CUDA.
# NOTE(review): torch.load unpickles the file -- only load snapshots you trust.
model.load_state_dict(torch.load("model-snapshot.pt", map_location="cpu"))

<All keys matched successfully>

In [4]:
def tensor_to_str(t: torch.Tensor) -> str:
    """Render the elements of ``t`` as a comma-separated list of float literals.

    ``t`` is flattened in row-major order (``torch.flatten``) and every element is
    passed through ``float()``, so integer tensors still print as ``"1.0"``.
    Example: ``tensor_to_str(torch.tensor([[1, 2], [3, 4]]))`` -> ``"1.0, 2.0, 3.0, 4.0"``.
    """
    return ", ".join(str(float(v)) for v in torch.flatten(t).tolist())


def _glsl_const(glsl_type: str, name: str, t: torch.Tensor) -> str:
    """Return one GLSL ``const`` declaration binding ``name`` to the values of ``t``."""
    # GLSL has no literal spelling for NaN/Inf; str(float("nan")) == "nan" would only
    # show up later as an opaque shader compile error, so fail here with context.
    if not bool(torch.isfinite(t).all()):
        raise ValueError(
            f"{name} contains NaN or Inf values, which cannot be written as GLSL literals"
        )
    return f"const {glsl_type} {name} = {glsl_type}({tensor_to_str(t)});"


def generate_glsl(model: "Mat4Net") -> str:
    """Compile a trained ``Mat4Net`` into GLSL source for ``float nn_dist(in vec3 pos)``.

    Reads ``model.layers``; for every layer it uses ``n_in``, ``n_out``,
    ``weights``, ``biases`` and ``act``. Each weight/bias tensor is baked into the
    shader as ``const mat4 weight_<layer>_<j>`` / ``const vec4 bias_<layer>_<j>``,
    so they must flatten to 16 / 4 values (the GLSL constructor arity).

    Evaluation in the emitted code:
      * layer 0 input is ``in_0 = vec4(pos, pos.x)``; only ``in_0`` is set, so this
        assumes the first layer has ``n_in == 1`` (TODO confirm against Mat4Net);
      * each layer accumulates ``out_{j % n_out} += in_{j % n_in} * weight_<layer>_<j>``
        for ``j`` in ``range(max(n_in, n_out))``;
      * layers with ``act is None`` just add the bias; every other layer is emitted as
        ``sin(out + bias)`` (presumably Mat4Net's activation is a sine -- verify);
      * the function returns ``out_0.x``.

    Args:
        model: object exposing ``layers`` as described above.

    Returns:
        The GLSL function as a string (leading newline, body indented four spaces).

    Raises:
        ValueError: if ``model.layers`` is empty, or a weight/bias has NaN/Inf values.
    """
    layers = list(model.layers)
    if not layers:
        # Without layers the output would declare `vec4 ;` and return an undeclared out_0.
        raise ValueError("generate_glsl: model has no layers")

    body = []
    max_features = 0
    for i, layer in enumerate(layers):
        max_features = max(max_features, layer.n_in, layer.n_out)
        body.extend(
            _glsl_const("mat4", f"weight_{i}_{j}", w) for j, w in enumerate(layer.weights)
        )
        body.extend(
            _glsl_const("vec4", f"bias_{i}_{j}", b) for j, b in enumerate(layer.biases)
        )

    # One shared set of vec4 variables, sized for the widest layer, reused by every layer.
    body.append("")
    body.append("vec4 " + ", ".join(f"in_{k}" for k in range(max_features)) + ";")
    body.append("vec4 " + ", ".join(f"out_{k}" for k in range(max_features)) + ";")

    for i, layer in enumerate(layers):
        body += ["", f"// layer {i}", ""]
        if i == 0:
            body.append("in_0 = vec4(pos, pos.x);")
        else:
            # The previous layer's outputs become this layer's inputs.
            body.extend(f"in_{j} = out_{j};" for j in range(layer.n_in))

        body.append(" = ".join(f"out_{j}" for j in range(layer.n_out)) + " = vec4(0);")

        # GLSL `vec4 * mat4` treats the vector as a row against the matrix columns; with
        # mat4(...) filled column by column from the row-major flatten, each product
        # equals torch's F.linear(x, w) = x @ w.T.
        for j in range(max(layer.n_in, layer.n_out)):
            body.append(f"out_{j % layer.n_out} += in_{j % layer.n_in} * weight_{i}_{j};")

        for j in range(layer.n_out):
            if layer.act is None:
                body.append(f"out_{j} += bias_{i}_{j};")
            else:
                body.append(f"out_{j} = sin(out_{j} + bias_{i}_{j});")

    body += ["", "return out_0.x;"]

    # Indent every non-blank line exactly once; blank lines stay free of trailing spaces.
    indented = "\n".join(f"    {line}" if line else "" for line in body)
    return f"\nfloat nn_dist(in vec3 pos) {{\n{indented}\n}}\n"

# Generate the shader function once; the Shadertoy cell splices `code` into its source.
code = generate_glsl(model=model)
print(code)


float nn_dist(in vec3 pos) {
        const mat4 weight_0_0 = mat4(2.0616865158081055, 0.4601484537124634, -2.0115034580230713, -2.060988664627075, 0.3239569664001465, 5.250946998596191, 0.8862974643707275, -0.37656134366989136, 0.28273501992225647, 1.3478621244430542, -1.6889636516571045, -0.2962493300437927, -1.236717700958252, -2.0213096141815186, 1.5221377611160278, 1.2772239446640015);
    const mat4 weight_0_1 = mat4(-1.758532166481018, -3.1988749504089355, -0.00401068851351738, -1.8241500854492188, -2.2292885780334473, 0.07957334816455841, 0.01722952350974083, -3.753303050994873, -0.7611516118049622, -2.4813034534454346, 0.9524009823799133, -0.030554838478565216, -1.4161354303359985, 2.0103347301483154, -0.1559772491455078, 0.1319631189107895);
    const mat4 weight_0_2 = mat4(-0.563083827495575, -3.396986484527588, 1.0684194564819336, 0.4195560812950134, 3.5080628395080566, 0.5053042769432068, 0.08828899264335632, -1.4327408075332642, -2.1785879135131836, 1.9827799797058105, 0.

In [14]:
# Fragment shader for the network's distance field. The generated nn_dist is spliced
# in through %(code)s, and %(w)s / %(h)s supply the window size for the pixel -> uv mapping.
# raymarch(): up to 50 steps from camera (0, .5, -2). The field is rotated about the y
# axis by iTime. A hit (|d| < 0.001) counts only while the point is within radius 2.
# Hits are coloured by surface normal (normal * .5 + .5); misses stay black.
# NOTE(review): rotate() is currently unused, and the commented-out lines are debug
# alternatives. iResolution could replace the hard-coded w/h if jupylet's Shadertoy
# provides it -- confirm before switching.
st = Shadertoy("""
    %(code)s

    #line 3
    float scene_dist(in vec3 pos) {
        //float d = length(pos) - 1.;
        float d = nn_dist(pos);
        return d;
    }
    
    vec3 scene_norm(in vec3 pos) {
        const vec2 e = vec2(0.0001, 0);
        return normalize(vec3(
            scene_dist(pos + e.xyy) - scene_dist(pos - e.xyy),
            scene_dist(pos + e.yxy) - scene_dist(pos - e.yxy),
            scene_dist(pos + e.yyx) - scene_dist(pos - e.yyx)
        ));
    }
    
    vec2 rotate(in vec2 x, in float a) {
        float si = sin(a), co = cos(a);
        return vec2(x.x * co - x.y * si, x.x * si + x.y * co);
    }
    
    vec3 rotate_y(in vec3 x, in float a) {
        float si = sin(a), co = cos(a);
        return vec3(x.x * co - x.z * si, x.y, x.x * si + x.z * co);
    }
    
    vec3 raymarch(in vec3 pos, in vec3 dir) {
       vec3 p = pos;
       for (int i=0; i<50; ++i) {
           bool visible = length(p) <= 2.;
           vec3 p_trans = rotate_y(p, iTime);
           float d = scene_dist(p_trans);
           if (visible) {
               if (abs(d) < 0.001) {
                   return scene_norm(p_trans) * .5 + .5;
               }
           }
           p += d * dir;
       }
       return vec3(0);
    }
    
    void mainImage(out vec4 fragColor, in vec2 fragCoord)
    {
        vec2 uv = (fragCoord - vec2(%(w)s, %(h)s) * .5) / %(h)s;
        
        vec3 dir = normalize(vec3(uv, 1));
        vec3 col = raymarch(vec3(0, .5, -2), dir);
        //vec3 col = vec3(nn_dist(vec3(uv, 0)))-3.;
        fragColor = vec4(col,1.0);
    }

""" % {"w": app.width, "h": app.height, "code": code})

@app.event
def render(ct, dt):
    """Frame callback: clear the window, then draw the Shadertoy ``st``.

    ``ct`` / ``dt`` are forwarded unchanged to ``st.draw`` (presumably the current
    time and the frame delta -- per jupylet's event convention).
    """
    window = app.window
    window.clear()
    st.draw(ct, dt)

In [12]:
# Launch the app; the @app.event `render` callback draws `st` (cell output is an Image widget).
app.run()

Image(value=b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x00\x00\x01\x00\x01\x00\x00\xff\xdb\x00C\x00\x05\x03\x0…

In [None]:
import glm

# GLSL-side layout check: mat4 takes its 16 scalars column by column, and
# `vec4 * mat4` treats the vector as a row -- the same form as the shader's
# `in_* * weight_*` products.
w = glm.mat4(*range(1, 17))
glm.vec4(1, 2, 3, 4) * w

In [None]:
# PyTorch side of the layout check: F.linear(x, w) computes x @ w.T, and torch.flatten
# lists w row by row (the order tensor_to_str writes into mat4(...)).
# Only a cell's last expression is displayed, so the linear result is kept and shown too.
w = torch.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
linear_out = F.linear(torch.Tensor([[1, 2, 3, 4]]), w)
assert torch.equal(linear_out, torch.Tensor([[30, 70, 110, 150]]))
linear_out, torch.flatten(w)