Skip to content

Commit

Permalink
TSL: Introduce ShaderCallNode & tslFn improvements (#26824)
Browse files Browse the repository at this point in the history
* Add ShaderCallNode

* cleanup

* Use tslFn as default
  • Loading branch information
sunag committed Sep 23, 2023
1 parent 45505ed commit 9d2d7eb
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 45 deletions.
4 changes: 2 additions & 2 deletions examples/jsm/nodes/procedural/CheckerNode.js
Expand Up @@ -25,9 +25,9 @@ class CheckerNode extends TempNode {

}

generate( builder ) {
construct() {

return checkerShaderNode( { uv: this.uvNode } ).build( builder );
return checkerShaderNode( { uv: this.uvNode } );

}

Expand Down
74 changes: 56 additions & 18 deletions examples/jsm/nodes/shadernode/ShaderNode.js
Expand Up @@ -196,42 +196,86 @@ const ShaderNodeImmutable = function ( NodeClass, ...params ) {

};

class ShaderNodeInternal extends Node {
class ShaderCallNodeInternal extends Node {

constructor( jsFunc ) {
constructor( shaderNode, inputNodes ) {

super();

this._jsFunc = jsFunc;
this.shaderNode = shaderNode;
this.inputNodes = inputNodes;

}

call( inputs, stack, builder ) {
getNodeType( builder ) {

inputs = nodeObjects( inputs );
const { outputNode } = builder.getNodeProperties( this );

return nodeObject( this._jsFunc( inputs, stack, builder ) );
return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );

}

getNodeType( builder ) {
call( builder ) {

const { outputNode } = builder.getNodeProperties( this );
const { shaderNode, inputNodes } = this;

return outputNode ? outputNode.getNodeType( builder ) : super.getNodeType( builder );
const jsFunc = shaderNode.jsFunc;
const outputNode = inputNodes !== null ? jsFunc( nodeObjects( inputNodes ), builder.stack, builder ) : jsFunc( builder.stack, builder );

return nodeObject( outputNode );

}

construct( builder ) {

builder.addStack();

builder.stack.outputNode = nodeObject( this._jsFunc( builder.stack, builder ) );
builder.stack.outputNode = this.call( builder );

return builder.removeStack();

}

generate( builder, output ) {

const { outputNode } = builder.getNodeProperties( this );

if ( outputNode === null ) {

// TSL: It's recommended to use `tslFn` in construct() pass.

return this.call( builder ).build( builder, output );

}

return super.generate( builder, output );

}

}

class ShaderNodeInternal extends Node {

constructor( jsFunc ) {

super();

this.jsFunc = jsFunc;

}

call( inputs = null ) {

return nodeObject( new ShaderCallNodeInternal( this, inputs ) );

}

construct() {

return this.call();

}

}

const bools = [ false, true ];
Expand Down Expand Up @@ -349,15 +393,9 @@ export const shader = ( jsFunc ) => { // @deprecated, r154

export const tslFn = ( jsFunc ) => {

let shaderNode = null;
const shaderNode = new ShaderNode( jsFunc );

return ( ...params ) => {

if ( shaderNode === null ) shaderNode = new ShaderNode( jsFunc );

return shaderNode.call( ...params );

};
return ( inputs ) => shaderNode.call( inputs );

};

Expand Down
16 changes: 8 additions & 8 deletions examples/jsm/nodes/utils/LoopNode.js
@@ -1,7 +1,7 @@
import Node, { addNodeClass } from '../core/Node.js';
import { expression } from '../code/ExpressionNode.js';
import { bypass } from '../core/BypassNode.js';
import { context as contextNode } from '../core/ContextNode.js';
import { context } from '../core/ContextNode.js';
import { addNodeElement, nodeObject, nodeArray } from '../shadernode/ShaderNode.js';

class LoopNode extends Node {
Expand Down Expand Up @@ -65,13 +65,11 @@ class LoopNode extends Node {

const properties = this.getProperties( builder );

const context = { tempWrite: false };
const contextData = { tempWrite: false };

const params = this.params;
const stackNode = properties.stackNode;

const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';

for ( let i = 0, l = params.length - 1; i < l; i ++ ) {

const param = params[ i ];
Expand All @@ -82,7 +80,7 @@ class LoopNode extends Node {
if ( param.isNode ) {

start = '0';
end = param.generate( builder, 'int' );
end = param.build( builder, 'int' );
direction = 'forward';

} else {
Expand All @@ -92,10 +90,10 @@ class LoopNode extends Node {
direction = param.direction;

if ( typeof start === 'number' ) start = start.toString();
else if ( start && start.isNode ) start = start.generate( builder, 'int' );
else if ( start && start.isNode ) start = start.build( builder, 'int' );

if ( typeof end === 'number' ) end = end.toString();
else if ( end && end.isNode ) end = end.generate( builder, 'int' );
else if ( end && end.isNode ) end = end.build( builder, 'int' );

if ( start !== undefined && end === undefined ) {

Expand Down Expand Up @@ -159,7 +157,9 @@ class LoopNode extends Node {

}

const stackSnippet = contextNode( stackNode, context ).build( builder, 'void' );
const stackSnippet = context( stackNode, contextData ).build( builder, 'void' );

const returnsSnippet = properties.returnsNode ? properties.returnsNode.build( builder ) : '';

builder.removeFlowTab().addFlowCode( '\n' + builder.tab + stackSnippet );

Expand Down
7 changes: 3 additions & 4 deletions examples/webgpu_audio_processing.html
Expand Up @@ -31,7 +31,7 @@
<script type="module">

import * as THREE from 'three';
import { ShaderNode, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';
import { tslFn, uniform, storage, instanceIndex, float, texture, viewportTopLeft, color } from 'three/nodes';

import { GUI } from 'three/addons/libs/lil-gui.module.min.js';

Expand Down Expand Up @@ -136,11 +136,10 @@

// compute (shader-node)

const computeShaderNode = new ShaderNode( ( stack ) => {
const computeShaderFn = tslFn( ( stack ) => {

const index = float( instanceIndex );


// pitch

const time = index.mul( pitch );
Expand Down Expand Up @@ -171,7 +170,7 @@

// compute

computeNode = computeShaderNode.compute( waveBuffer.length );
computeNode = computeShaderFn().compute( waveBuffer.length );


// gui
Expand Down
10 changes: 5 additions & 5 deletions examples/webgpu_compute.html
Expand Up @@ -26,7 +26,7 @@
<script type="module">

import * as THREE from 'three';
import { ShaderNode, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';
import { tslFn, uniform, storage, attribute, float, vec2, vec3, color, instanceIndex, PointsNodeMaterial } from 'three/nodes';

import { GUI } from 'three/addons/libs/lil-gui.module.min.js';

Expand Down Expand Up @@ -74,7 +74,7 @@

// create function

const computeShaderNode = new ShaderNode( ( stack ) => {
const computeShaderFn = tslFn( ( stack ) => {

const particle = particleBufferNode.element( instanceIndex );
const velocity = velocityBufferNode.element( instanceIndex );
Expand All @@ -98,10 +98,10 @@

// compute

computeNode = computeShaderNode.compute( particleNum );
computeNode = computeShaderFn().compute( particleNum );
computeNode.onInit = ( { renderer } ) => {

const precomputeShaderNode = new ShaderNode( ( stack ) => {
const precomputeShaderNode = tslFn( ( stack ) => {

const particleIndex = float( instanceIndex );

Expand All @@ -117,7 +117,7 @@

} );

renderer.compute( precomputeShaderNode.compute( particleNum ) );
renderer.compute( precomputeShaderNode().compute( particleNum ) );

};

Expand Down
14 changes: 7 additions & 7 deletions examples/webgpu_compute_particles.html
Expand Up @@ -26,7 +26,7 @@
<script type="module">

import * as THREE from 'three';
import { ShaderNode, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';
import { tslFn, uniform, texture, instanceIndex, float, vec3, storage, SpriteNodeMaterial } from 'three/nodes';

import WebGPU from 'three/addons/capabilities/WebGPU.js';
import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
Expand Down Expand Up @@ -83,7 +83,7 @@

// compute

const computeInit = new ShaderNode( ( stack ) => {
const computeInit = tslFn( ( stack ) => {

const position = positionBuffer.element( instanceIndex );
const color = colorBuffer.element( instanceIndex );
Expand All @@ -98,11 +98,11 @@

stack.assign( color, vec3( randX, randY, randZ ) );

} ).compute( particleCount );
} )().compute( particleCount );

//

const computeUpdate = new ShaderNode( ( stack ) => {
const computeUpdate = tslFn( ( stack ) => {

const position = positionBuffer.element( instanceIndex );
const velocity = velocityBuffer.element( instanceIndex );
Expand All @@ -128,7 +128,7 @@

} );

computeParticles = computeUpdate.compute( particleCount );
computeParticles = computeUpdate().compute( particleCount );

// create nodes

Expand Down Expand Up @@ -179,7 +179,7 @@

// click event

const computeHit = new ShaderNode( ( stack ) => {
const computeHit = tslFn( ( stack ) => {

const position = positionBuffer.element( instanceIndex );
const velocity = velocityBuffer.element( instanceIndex );
Expand All @@ -193,7 +193,7 @@

stack.assign( velocity, velocity.add( direction.mul( relativePower ) ) );

} ).compute( particleCount );
} )().compute( particleCount );

//

Expand Down
2 changes: 1 addition & 1 deletion examples/webgpu_materials.html
Expand Up @@ -29,7 +29,7 @@
import * as THREE from 'three';
import * as Nodes from 'three/nodes';

import { tslFn, wgslFn, attribute, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';
import { tslFn, wgslFn, positionLocal, positionWorld, normalLocal, normalWorld, normalView, color, texture, uv, float, vec2, vec3, vec4, oscSine, triplanarTexture, viewportBottomLeft, js, string, global, loop, MeshBasicNodeMaterial, NodeObjectLoader } from 'three/nodes';

import WebGPU from 'three/addons/capabilities/WebGPU.js';
import WebGPURenderer from 'three/addons/renderers/webgpu/WebGPURenderer.js';
Expand Down

0 comments on commit 9d2d7eb

Please sign in to comment.