In [2]:
from typing import Callable
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as VF

from jupylet.app import App
from jupylet.shadertoy import Shadertoy

from sdf import *
from render import *
from models import *

In [3]:
# Create the jupylet App at 320x240. The same width/height are later
# substituted into the shader template (via app.width / app.height).
app = App(width=320, height=240)

In [4]:
# Instantiate the network and restore its parameters from the snapshot file.
# map_location="cpu": the weights are only read back to be printed into GLSL
# source, so they are loaded onto the CPU; this also lets a snapshot that was
# saved from a CUDA device open on a machine without one.
# NOTE(review): torch.load unpickles the file - only open trusted snapshots.
model = Mat4Net()
model.load_state_dict(torch.load("model-snapshot.pt", map_location="cpu"))

<All keys matched successfully>

In [6]:
def tensor_to_str(t: torch.Tensor) -> str:
    """Format every element of ``t`` as a Python float, joined by ``", "``.

    The tensor is flattened first, so the values appear in ``torch.flatten``
    order whatever the shape of the input; an empty tensor gives ``""``.
    """
    # One bulk conversion to Python scalars instead of iterating 0-d tensors.
    scalars = torch.flatten(t.detach()).tolist()
    return ", ".join(map(str, map(float, scalars)))

def generate_glsl(model: Mat4Net):
    """Return GLSL source for a ``float nn_dist(in vec3 pos)`` function that evaluates ``model``.

    For every layer in ``model.layers`` the generated function contains:

    * one ``const mat4 weight_<layer>_<out>_<in>`` per 4x4 block of
      ``layer.weight`` (indexed as ``[out*4 : out*4+4, in*4 : in*4+4]``) and
      one ``const vec4 bias_<layer>_<out>`` per 4-element slice of
      ``layer.bias``;
    * statements that copy the previous outputs into ``in_*``, accumulate
      ``in_k * weight`` into ``out_j``, then add the bias - wrapped in
      ``sin(...)`` whenever ``layer.act`` is not ``None``.

    ``layer.n_in`` / ``layer.n_out`` are used as the number of vec4 blocks on
    the input / output side. The function returns ``out_0.x`` after the last
    layer. Every generated line is indented by four spaces inside the
    function body, and the returned string starts with a newline.

    Matrix values are written in ``torch.flatten`` (row-by-row) order; GLSL's
    ``mat4(...)`` constructor consumes them column by column, and the shader
    multiplies ``vec * mat``, so each output component is the dot product of
    the input with one row of the original weight block.

    NOTE(review): only ``in_0`` is seeded for the first layer (with
    ``vec4(pos, pos.x)``) - this looks like it assumes the first layer has
    ``n_in == 1``; confirm against Mat4Net.
    NOTE(review): any non-None ``layer.act`` is emitted as ``sin`` - confirm
    the model only uses a sine activation.
    """
    chunks = []
    widest = 0

    # Pass 1: weight / bias constants, while tracking the widest layer so that
    # enough in_*/out_* registers can be declared afterwards.
    for layer_idx, layer in enumerate(model.layers):
        widest = max(widest, layer.n_in, layer.n_out)
        for out_idx in range(layer.n_out):
            rows = slice(out_idx * 4, out_idx * 4 + 4)
            for in_idx in range(layer.n_in):
                cols = slice(in_idx * 4, in_idx * 4 + 4)
                values = tensor_to_str(layer.weight[rows, cols])
                chunks.append(
                    f"const mat4 weight_{layer_idx}_{out_idx}_{in_idx} = mat4({values});\n"
                )
        for out_idx in range(layer.n_out):
            values = tensor_to_str(layer.bias[out_idx * 4: out_idx * 4 + 4])
            chunks.append(f"const vec4 bias_{layer_idx}_{out_idx} = vec4({values});\n")

    # Register declarations: "vec4 in_0, in_1, ...;" and the same for out_*.
    chunks.append("\n")
    for prefix in ("in", "out"):
        names = ", ".join(f"{prefix}_{n}" for n in range(widest))
        chunks.append("vec4 " + names + ";\n")

    # Pass 2: the per-layer evaluation statements.
    for layer_idx, layer in enumerate(model.layers):
        chunks.append(f"\n// layer {layer_idx}\n\n")

        if layer_idx == 0:
            chunks.append("in_0 = vec4(pos, pos.x);\n")
        else:
            chunks.extend(f"in_{n} = out_{n};\n" for n in range(layer.n_in))

        # Zero all accumulators with one chained assignment.
        accumulators = [f"out_{n}" for n in range(layer.n_out)]
        chunks.append(" = ".join(accumulators) + " = vec4(0);\n")

        chunks.extend(
            f"out_{out_idx} += in_{in_idx} * weight_{layer_idx}_{out_idx}_{in_idx};\n"
            for out_idx in range(layer.n_out)
            for in_idx in range(layer.n_in)
        )

        for out_idx in range(layer.n_out):
            if layer.act is not None:
                chunks.append(
                    f"out_{out_idx} = sin(out_{out_idx} + bias_{layer_idx}_{out_idx});\n"
                )
            else:
                chunks.append(f"out_{out_idx} += bias_{layer_idx}_{out_idx};\n")

    chunks.append("\nreturn out_0.x;\n")

    # Indent every generated line (blank ones included) and wrap it in the
    # function definition.
    body = "\n".join("    " + line for line in "".join(chunks).splitlines())
    return "\nfloat nn_dist(in vec3 pos) {\n    " + body + "\n}\n"

# Generate the GLSL nn_dist() source for the loaded model and print it for
# inspection; `code` is later substituted into the shader template.
code = generate_glsl(model)
print(code)

const mat4 weight_0_0_0 = mat4(1.2021186351776123, -1.0139050483703613, 2.302286148071289, -1.0851049423217773, -0.6602197885513306, 0.9151622653007507, -0.7561661005020142, -0.8238990902900696, 1.532228708267212, -1.2937939167022705, -1.57724130153656, -1.966712236404419, -0.644068717956543, 1.0166679620742798, 1.4305334091186523, 0.7981642484664917);
const mat4 weight_0_1_0 = mat4(1.963820219039917, 1.3149985074996948, -1.1309620141983032, -1.49079430103302, -0.059448711574077606, 0.5998917818069458, -1.7208845615386963, -0.5969899892807007, 0.7533749341964722, 1.5925933122634888, 1.1643946170806885, -1.6383633613586426, -0.25616568326950073, -1.49824857711792, -1.7035496234893799, -1.1360278129577637);
const mat4 weight_0_2_0 = mat4(-0.7063148021697998, 1.0059007406234741, -0.9115124344825745, -1.309776782989502, -0.7764936089515686, 0.045450545847415924, 1.143848180770874, -0.8240973353385925, -0.5465065836906433, -1.3238753080368042, 0.5442268252372742, -1.3461506366729736, -0.996

In [None]:
# Build the Shadertoy from a GLSL template filled in with Python %-formatting:
#   %(code)s      - the generated nn_dist() source held in `code`
#   %(w)s, %(h)s  - app.width / app.height, used in mainImage to centre
#                   fragCoord and divide it by the height
# What the shader does:
#   scene_dist  - forwards to nn_dist (an analytic sphere is left commented out)
#   scene_norm  - normalised central differences of scene_dist, step 0.0001
#   raymarch    - up to 100 steps; each step rotates the sample point about Y
#                 by iTime, advances by the returned distance, and on
#                 |d| < 0.001 returns the normal remapped with *.5+.5;
#                 otherwise returns black. The ray origin is offset along z by
#                 -1 + sin(iTime * .91).
# NOTE(review): the `visible` local in raymarch and the rotate() helper are
# never used inside this template.
# NOTE(review): `#line 3` presumably resets the reported GLSL line numbers
# after the generated code - confirm.
st = Shadertoy("""
    %(code)s

    #line 3
    float scene_dist(in vec3 pos) {
        //float d = length(pos) - 1.;
        float d = nn_dist(pos);
        return d;
    }
    
    vec3 scene_norm(in vec3 pos) {
        const vec2 e = vec2(0.0001, 0);
        return normalize(vec3(
            scene_dist(pos + e.xyy) - scene_dist(pos - e.xyy),
            scene_dist(pos + e.yxy) - scene_dist(pos - e.yxy),
            scene_dist(pos + e.yyx) - scene_dist(pos - e.yyx)
        ));
    }
    
    vec2 rotate(in vec2 x, in float a) {
        float si = sin(a), co = cos(a);
        return vec2(x.x * co - x.y * si, x.x * si + x.y * co);
    }
    
    vec3 rotate_y(in vec3 x, in float a) {
        float si = sin(a), co = cos(a);
        return vec3(x.x * co - x.z * si, x.y, x.x * si + x.z * co);
    }
    
    vec3 raymarch(in vec3 pos, in vec3 dir) {
       vec3 p = pos + vec3(0, 0, -1.+sin(iTime*.91));
       for (int i=0; i<100; ++i) {
           bool visible = length(p) <= 2.;
           vec3 p_trans = rotate_y(p, iTime);
           float d = scene_dist(p_trans);
           if (abs(d) < 0.001) {
               return scene_norm(p_trans) * .5 + .5;
           }
           p += d * dir;
       }
       return vec3(0);
    }
    
    void mainImage(out vec4 fragColor, in vec2 fragCoord)
    {
        vec2 uv = (fragCoord - vec2(%(w)s, %(h)s) * .5) / %(h)s;
        
        vec3 dir = normalize(vec3(uv, 1));
        vec3 col = raymarch(vec3(0, .5, -2), dir);
        //vec3 col = vec3(nn_dist(vec3(uv, 0)))-3.;
        fragColor = vec4(col,1.0);
    }

""" % {"w": app.width, "h": app.height, "code": code})

@app.event
def render(ct, dt):
    """Clear the app window, then draw the Shadertoy `st`.

    Registered on `app` through the ``@app.event`` decorator. ``ct`` and
    ``dt`` are passed straight through to ``st.draw``.
    NOTE(review): presumably the current time and the frame delta - confirm
    against the jupylet documentation.
    """
    app.window.clear()
    st.draw(ct, dt)

In [None]:
# Start the app; the `render` handler was registered on it via @app.event.
app.run()

In [None]:
# Scratch check of the vector * matrix convention, using PyGLM.
# NOTE(review): assuming glm.mat4 fills its columns in argument order, as the
# GLSL mat4 constructor does, vec4 * mat4 dots the vector with each column,
# i.e. (30, 70, 110, 150) here - confirm by running the cell.
import glm
w = glm.mat4(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16)
glm.vec4(1, 2, 3, 4) * w

In [None]:
# Reference for the layout generate_glsl relies on: F.linear computes
# x @ w.T, so each output component is the dot product of x with one ROW of
# w, while torch.flatten(w) lists w row by row - the order in which the
# values are written into the GLSL mat4(...) constructors.
w = torch.tensor(
    [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]],
    dtype=torch.float32,
)
x_row = torch.tensor([[1, 2, 3, 4]], dtype=torch.float32)
linear_out = F.linear(x_row, w)
# Only the last expression of a cell is displayed, so print the product
# explicitly instead of letting it be silently discarded.
print(linear_out)
torch.flatten(w)