diff --git a/.github/workflows/lint-witness.yaml b/.github/workflows/lint-witness.yaml new file mode 100644 index 0000000..84737bc --- /dev/null +++ b/.github/workflows/lint-witness.yaml @@ -0,0 +1,36 @@ +name: lint-witness + +on: + push: + branches: + - master + pull_request: + +jobs: + lint: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v4 + with: + go-version: 1.20.4 + - name: lint witness + uses: golangci/golangci-lint-action@v3 + with: + version: v1.52.2 + working-directory: witness + - name: lint witness/wazero + uses: golangci/golangci-lint-action@v3 + with: + version: v1.52.2 + working-directory: witness/wazero + - name: lint witness/wasmer + uses: golangci/golangci-lint-action@v3 + with: + version: v1.52.2 + working-directory: witness/wasmer + - name: lint witness/test_wasm_impls + uses: golangci/golangci-lint-action@v3 + with: + version: v1.52.2 + working-directory: witness/test_wasm_impls diff --git a/.github/workflows/test-witness.yaml b/.github/workflows/test-witness.yaml new file mode 100644 index 0000000..ccd60c1 --- /dev/null +++ b/.github/workflows/test-witness.yaml @@ -0,0 +1,32 @@ +name: test-witness + +on: + push: + branches: + - master + pull_request: + +jobs: + test: + strategy: + matrix: + containers: + - 1.18.10-bullseye + - 1.19.9-bullseye + - 1.20.4-bullseye + runs-on: ubuntu-20.04 + container: golang:${{matrix.containers}} + steps: + - uses: actions/checkout@v3 + - uses: actions/cache@v3 + with: + path: | + ~/.cache/go-build + /go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + - run: cd witness && go test -race -timeout=60s -v ./... + - run: cd witness/wazero && go test -race -timeout=60s -v ./... + - run: cd witness/wasmer && go test -race -timeout=60s -v ./... + - run: cd witness/test_wasm_impls && go test -race -timeout=300s -v ./... diff --git a/witness/README.md b/witness/README.md index aced525..3561c8e 100644 --- a/witness/README.md +++ b/witness/README.md @@ -5,7 +5,7 @@ Calculates witness, that can be passed to a prover ([snarkjs](https://github.com ## Installation ``` -go get github.com/iden3/go-rapidsnark/witness +go get github.com/iden3/go-rapidsnark/witness/v2 ``` ## Dependencies diff --git a/witness/circom2witnesscalc_test.go b/witness/circom2witnesscalc_test.go deleted file mode 100644 index 23fa091..0000000 --- a/witness/circom2witnesscalc_test.go +++ /dev/null @@ -1,129 +0,0 @@ -package witness - -import ( - "os" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestCircom2CalculateWitness(t *testing.T) { - wasmBytes, err := os.ReadFile("testdata/circom2/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("testdata/circom2/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WitnessCalculator(wasmBytes, true) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witness, err := calc.CalculateWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witness) -} - -func TestCircom2CalculateBinWitness(t *testing.T) { - wasmBytes, err := os.ReadFile("testdata/circom2/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("testdata/circom2/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WitnessCalculator(wasmBytes, true) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witnessBytes, err := calc.CalculateBinWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witnessBytes) -} - -func TestCircom2CalculateWTNSBin(t *testing.T) { - wasmBytes, err := os.ReadFile("testdata/circom2/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("testdata/circom2/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WitnessCalculator(wasmBytes, true) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - wtnsBytes, err := calc.CalculateWTNSBin(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, wtnsBytes) - - //_ = os.WriteFile("testdata/circom2/witness.wtns", wtnsBytes, fs.FileMode(defaultFileMode)) -} - -// TestCircom2CalculateWitness210 tests the calculation of the witness for the circom 2.1.0 -func TestCircom2CalculateWitness210(t *testing.T) { - wasmBytes, err := os.ReadFile("testdata/circom2_1_0/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("testdata/circom2_1_0/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WitnessCalculator(wasmBytes, true) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witness, err := calc.CalculateWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witness) -} - -// TestCircom2CalculateBinWitness210 tests the calculation of the witness for the circom 2.1.0 -func TestCircom2CalculateBinWitness210(t *testing.T) { - wasmBytes, err := os.ReadFile("testdata/circom2_1_0/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("testdata/circom2_1_0/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WitnessCalculator(wasmBytes, true) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witnessBytes, err := calc.CalculateBinWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witnessBytes) -} - -// TestCircom2CalculateWTNSBin210 tests the calculation of the witness for the circom 2.1.0 -func TestCircom2CalculateWTNSBin210(t *testing.T) { - wasmBytes, err := os.ReadFile("testdata/circom2_1_0/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("testdata/circom2_1_0/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WitnessCalculator(wasmBytes, true) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - wtnsBytes, err := calc.CalculateWTNSBin(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, wtnsBytes) - - //_ = os.WriteFile("testdata/circom2_1_0/witness.wtns", wtnsBytes, fs.FileMode(defaultFileMode)) -} diff --git a/witness/circom2witnesscalc_wazero_test.go b/witness/circom2witnesscalc_wazero_test.go deleted file mode 100644 index 46bb8be..0000000 --- a/witness/circom2witnesscalc_wazero_test.go +++ /dev/null @@ -1,189 +0,0 @@ -package witness - -import ( - "crypto/md5" - "encoding/hex" - "math/big" - "os" - "testing" - - "github.com/stretchr/testify/require" -) - -func TestWZCircom2CalculateWitness(t *testing.T) { - wasmBytes, err := os.ReadFile("test_files/circom2/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("test_files/circom2/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WZWitnessCalculator(wasmBytes) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witness, err := calc.CalculateWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witness) - require.Equal(t, "c1780821352c069392e9d0fab4330531", hashInts(witness)) -} - -func TestWZCircom2CalculateBinWitness(t *testing.T) { - wasmBytes, err := os.ReadFile("test_files/circom2/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("test_files/circom2/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WZWitnessCalculator(wasmBytes) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witnessBytes, err := calc.CalculateBinWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witnessBytes) - require.Equal(t, "d2c0486d7fd6f0715d04d535765f028b", - hashBytes(witnessBytes)) -} - -func TestWZCircom2CalculateWTNSBin(t *testing.T) { - wasmBytes, err := os.ReadFile("test_files/circom2/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("test_files/circom2/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WZWitnessCalculator(wasmBytes) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - wtnsBytes, err := calc.CalculateWTNSBin(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, wtnsBytes) - require.Equal(t, "1709fbda942dabed641044f39b466e94", - hashBytes(wtnsBytes)) - -} - -// TestWZCircom2CalculateWitness210 tests the calculation of the witness for the circom 2.1.0 -func TestWZCircom2CalculateWitness210(t *testing.T) { - wasmBytes, err := os.ReadFile("test_files/circom2_1_0/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("test_files/circom2_1_0/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WZWitnessCalculator(wasmBytes) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witness, err := calc.CalculateWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witness) - require.Equal(t, "c0a2b43f5a333310c2bb8d357db46d3b", hashInts(witness)) -} - -// TestWZCircom2CalculateBinWitness210 tests the calculation of the witness -// for the circom 2.1.0 -func TestWZCircom2CalculateBinWitness210(t *testing.T) { - wasmBytes, err := os.ReadFile("test_files/circom2_1_0/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("test_files/circom2_1_0/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WZWitnessCalculator(wasmBytes) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - witnessBytes, err := calc.CalculateBinWitness(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, witnessBytes) - require.Equal(t, "2b38b66035d8e923eacc028ea0f1dad2", - hashBytes(witnessBytes)) -} - -// TestWZCircom2CalculateWTNSBin210 tests the calculation of the witness -// for the circom 2.1.0 -func TestWZCircom2CalculateWTNSBin210(t *testing.T) { - wasmBytes, err := os.ReadFile("test_files/circom2_1_0/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("test_files/circom2_1_0/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WZWitnessCalculator(wasmBytes) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - - wtnsBytes, err := calc.CalculateWTNSBin(inputs, true) - require.NoError(t, err) - require.NotEmpty(t, wtnsBytes) - require.Equal(t, "75c5682a7195c20868b59d6580852fce", - hashBytes(wtnsBytes)) -} - -// TestWZCircom2CalculateWTNSBin210 tests the calculation of the witness -// for the circom 2.1.0 -func TestWZCircom2CalculateWTNSBin210_Error(t *testing.T) { - wasmBytes, err := os.ReadFile("test_files/circom2_1_0/circuit.wasm") - require.NoError(t, err) - - inputBytes, err := os.ReadFile("test_files/circom2_1_0/input.json") - require.NoError(t, err) - - calc, err := NewCircom2WZWitnessCalculator(wasmBytes) - require.NoError(t, err) - require.NotEmpty(t, calc) - - inputs, err := ParseInputs(inputBytes) - require.NoError(t, err) - wrongSmtRoot, ok := big.NewInt(0).SetString( - "23891407091237035626910338386637210028103224489833886255774452947213913989795", - 10) - require.True(t, ok) - inputs["globalSmtRoot"] = wrongSmtRoot - - _, err = calc.CalculateWTNSBin(inputs, true) - require.EqualError(t, err, `error code: 4: Assert Failed. -Error in template ForceEqualIfEnabled_234 line: 56 -Error in template SMTVerifier_235 line: 134 -Error in template AuthV2_347 line: 93`) -} - -func hashInts(in []*big.Int) string { - h := md5.New() - for _, i := range in { - h.Write(i.Bytes()) - } - return hex.EncodeToString(h.Sum(nil)) -} - -func hashBytes(in []byte) string { - h := md5.New() - n, err := h.Write(in) - if err != nil { - panic(err) - } - if n != len(in) { - panic("incorrect size") - } - return hex.EncodeToString(h.Sum(nil)) -} diff --git a/witness/go.mod b/witness/go.mod index f0d4ebf..bc7e1d8 100644 --- a/witness/go.mod +++ b/witness/go.mod @@ -1,16 +1,15 @@ -module github.com/iden3/go-rapidsnark/witness +module github.com/iden3/go-rapidsnark/witness/v2 go 1.18 require ( github.com/iden3/go-iden3-crypto v0.0.15 - github.com/iden3/wasmer-go v0.0.1 github.com/stretchr/testify v1.8.2 - github.com/tetratelabs/wazero v1.1.0 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/sys v0.6.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/witness/go.sum b/witness/go.sum index 3097dab..da8815d 100644 --- a/witness/go.sum +++ b/witness/go.sum @@ -3,8 +3,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/iden3/go-iden3-crypto v0.0.15 h1:4MJYlrot1l31Fzlo2sF56u7EVFeHHJkxGXXZCtESgK4= github.com/iden3/go-iden3-crypto v0.0.15/go.mod h1:dLpM4vEPJ3nDHzhWFXDjzkn1qHoBeOT/3UEhXsEsP3E= -github.com/iden3/wasmer-go v0.0.1 h1:TZKh8Se8B/73PvWrcu+FTU9L1k5XYAmtFbioj7l0Uog= -github.com/iden3/wasmer-go v0.0.1/go.mod h1:ZnZBAO012M7o+Q1INXLRIxKQgEcH2FuwL0Iga8A4ufg= +github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -14,8 +13,8 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ= -github.com/tetratelabs/wazero v1.1.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/witness/test_wasm_impls/go.mod b/witness/test_wasm_impls/go.mod new file mode 100644 index 0000000..723f778 --- /dev/null +++ b/witness/test_wasm_impls/go.mod @@ -0,0 +1,27 @@ +module github.com/iden3/go-rapidsnark/witness/test-wasm-impls + +go 1.18 + +require ( + github.com/iden3/go-rapidsnark/witness/v2 v2.0.0-20230523125954-fcfab2575c4d + github.com/iden3/go-rapidsnark/witness/wasmer v0.0.0 + github.com/iden3/go-rapidsnark/witness/wazero v0.0.0 + github.com/stretchr/testify v1.8.2 + +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/iden3/go-iden3-crypto v0.0.15 // indirect + github.com/iden3/wasmer-go v0.0.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/tetratelabs/wazero v1.1.0 // indirect + golang.org/x/sys v0.6.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + +replace ( + github.com/iden3/go-rapidsnark/witness/v2 => ../ + github.com/iden3/go-rapidsnark/witness/wasmer => ../wasmer + github.com/iden3/go-rapidsnark/witness/wazero => ../wazero +) diff --git a/witness/test_wasm_impls/go.sum b/witness/test_wasm_impls/go.sum new file mode 100644 index 0000000..8bbbede --- /dev/null +++ b/witness/test_wasm_impls/go.sum @@ -0,0 +1,26 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/iden3/go-iden3-crypto v0.0.15 h1:4MJYlrot1l31Fzlo2sF56u7EVFeHHJkxGXXZCtESgK4= +github.com/iden3/go-iden3-crypto v0.0.15/go.mod h1:dLpM4vEPJ3nDHzhWFXDjzkn1qHoBeOT/3UEhXsEsP3E= +github.com/iden3/wasmer-go v0.0.1 h1:TZKh8Se8B/73PvWrcu+FTU9L1k5XYAmtFbioj7l0Uog= +github.com/iden3/wasmer-go v0.0.1/go.mod h1:ZnZBAO012M7o+Q1INXLRIxKQgEcH2FuwL0Iga8A4ufg= +github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ= +github.com/tetratelabs/wazero v1.1.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/witness/testdata/circom2/circuit.wasm b/witness/test_wasm_impls/testdata/circom2/circuit.wasm similarity index 100% rename from witness/testdata/circom2/circuit.wasm rename to witness/test_wasm_impls/testdata/circom2/circuit.wasm diff --git a/witness/testdata/circom2/input.json b/witness/test_wasm_impls/testdata/circom2/input.json similarity index 100% rename from witness/testdata/circom2/input.json rename to witness/test_wasm_impls/testdata/circom2/input.json diff --git a/witness/testdata/circom2_1_0/circuit.wasm b/witness/test_wasm_impls/testdata/circom2_1_0/circuit.wasm similarity index 100% rename from witness/testdata/circom2_1_0/circuit.wasm rename to witness/test_wasm_impls/testdata/circom2_1_0/circuit.wasm diff --git a/witness/testdata/circom2_1_0/input.json b/witness/test_wasm_impls/testdata/circom2_1_0/input.json similarity index 100% rename from witness/testdata/circom2_1_0/input.json rename to witness/test_wasm_impls/testdata/circom2_1_0/input.json diff --git a/witness/test_wasm_impls/witness_test.go b/witness/test_wasm_impls/witness_test.go new file mode 100644 index 0000000..df8944e --- /dev/null +++ b/witness/test_wasm_impls/witness_test.go @@ -0,0 +1,130 @@ +package witness + +import ( + "crypto/md5" + "encoding/hex" + "math/big" + "os" + "testing" + + "github.com/iden3/go-rapidsnark/witness/v2" + "github.com/iden3/go-rapidsnark/witness/wasmer" + "github.com/iden3/go-rapidsnark/witness/wazero" + "github.com/stretchr/testify/require" +) + +func TestEngines(t *testing.T) { + engineTestCases := []struct { + title string + engine func(code []byte) (witness.CalculatorImpl, error) + wantErr string + }{ + { + title: "Wazero", + engine: wazero.NewCircom2WZWitnessCalculator, + }, + { + title: "Wasmer", + engine: wasmer.NewCircom2WitnessCalculator, + }, + { + title: "empty", + wantErr: "witness calculator wasm engine not set", + }, + } + + circomTestCases := []struct { + wasmFile string + inputs string + wantWtnsHex string + wantBinWtnsHex string + wantWTNSBinHex string + }{ + { + wasmFile: "testdata/circom2/circuit.wasm", + inputs: "testdata/circom2/input.json", + wantWtnsHex: "c1780821352c069392e9d0fab4330531", + wantBinWtnsHex: "d2c0486d7fd6f0715d04d535765f028b", + wantWTNSBinHex: "1709fbda942dabed641044f39b466e94", + }, + { + wasmFile: "testdata/circom2_1_0/circuit.wasm", + inputs: "testdata/circom2_1_0/input.json", + wantWtnsHex: "c0a2b43f5a333310c2bb8d357db46d3b", + wantBinWtnsHex: "2b38b66035d8e923eacc028ea0f1dad2", + wantWTNSBinHex: "75c5682a7195c20868b59d6580852fce", + }, + } + + for i := range engineTestCases { + engTC := engineTestCases[i] + t.Run(engTC.title, func(t *testing.T) { + for _, circomTC := range circomTestCases { + t.Run(circomTC.wasmFile, func(t *testing.T) { + wasmBytes, err := os.ReadFile(circomTC.wasmFile) + require.NoError(t, err) + inputBytes, err := os.ReadFile(circomTC.inputs) + require.NoError(t, err) + + var ops []witness.Option + if engTC.engine != nil { + ops = append(ops, witness.WithWasmEngine(engTC.engine)) + } + calc, err := witness.NewCalculator(wasmBytes, ops...) + if engTC.wantErr != "" { + require.EqualError(t, err, engTC.wantErr) + return + } + + require.NoError(t, err) + + inputs, err := witness.ParseInputs(inputBytes) + require.NoError(t, err) + + t.Run("CalculateWitness", func(t *testing.T) { + wtns, err2 := calc.CalculateWitness(inputs, true) + require.NoError(t, err2) + require.NotEmpty(t, wtns) + require.Equal(t, circomTC.wantWtnsHex, hashInts(wtns)) + }) + + t.Run("CalculateBinWitness", func(t *testing.T) { + wtns, err2 := calc.CalculateBinWitness(inputs, true) + require.NoError(t, err2) + require.NotEmpty(t, wtns) + require.Equal(t, circomTC.wantBinWtnsHex, + hashBytes(wtns)) + }) + + t.Run("CalculateWTNSBin", func(t *testing.T) { + wtns, err2 := calc.CalculateWTNSBin(inputs, true) + require.NoError(t, err2) + require.NotEmpty(t, wtns) + require.Equal(t, circomTC.wantWTNSBinHex, + hashBytes(wtns)) + }) + }) + } + }) + } +} + +func hashInts(in []*big.Int) string { + h := md5.New() + for _, i := range in { + h.Write(i.Bytes()) + } + return hex.EncodeToString(h.Sum(nil)) +} + +func hashBytes(in []byte) string { + h := md5.New() + n, err := h.Write(in) + if err != nil { + panic(err) + } + if n != len(in) { + panic("incorrect size") + } + return hex.EncodeToString(h.Sum(nil)) +} diff --git a/witness/utils.go b/witness/utils.go index b156dce..b1eaf45 100644 --- a/witness/utils.go +++ b/witness/utils.go @@ -3,19 +3,18 @@ package witness import ( "encoding/json" "fmt" - "hash/fnv" "math/big" "reflect" ) -// parseInput is a recurisve helper function for ParseInputs +// parseInput is a recursive helper function for ParseInputs func parseInput(v interface{}) (interface{}, error) { rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.String: n, ok := new(big.Int).SetString(v.(string), 0) if !ok { - return nil, fmt.Errorf("Error parsing input %v", v) + return nil, fmt.Errorf("error parsing input %v", v) } return n, nil case reflect.Float64: @@ -26,12 +25,12 @@ func parseInput(v interface{}) (interface{}, error) { var err error res[i], err = parseInput(rv.Index(i).Interface()) if err != nil { - return nil, fmt.Errorf("Error parsing input %v: %w", v, err) + return nil, fmt.Errorf("error parsing input %v: %w", v, err) } } return res, nil default: - return nil, fmt.Errorf("Unexpected type for input %v: %T", v, v) + return nil, fmt.Errorf("unexpected type for input %v: %T", v, v) } } @@ -74,11 +73,3 @@ func flatSlice(v interface{}) []*big.Int { _flatSlice(&res, v) return res } - -// fnvHash returns the 64 bit FNV-1a hash split into two 32 bit values: (MSB, LSB) -func fnvHash(s string) (int32, int32) { - hash := fnv.New64a() - hash.Write([]byte(s)) - h := hash.Sum64() - return int32(h >> 32), int32(h & 0xffffffff) -} diff --git a/witness/circom2witnesscalc.go b/witness/wasmer/circom2witnesscalc.go similarity index 79% rename from witness/circom2witnesscalc.go rename to witness/wasmer/circom2witnesscalc.go index 683447f..6df24d0 100644 --- a/witness/circom2witnesscalc.go +++ b/witness/wasmer/circom2witnesscalc.go @@ -1,12 +1,16 @@ -package witness +package wasmer import ( "bytes" "encoding/binary" "errors" "fmt" + "hash/fnv" "math/big" + "reflect" + "github.com/iden3/go-iden3-crypto/utils" + "github.com/iden3/go-rapidsnark/witness/v2" "github.com/iden3/wasmer-go/wasmer" ) @@ -17,7 +21,6 @@ type Circom2WitnessCalculator struct { module *wasmer.Module instance *wasmer.Instance store *wasmer.Store - sanityCheck bool n32 int32 version int32 witnessSize int32 @@ -38,11 +41,12 @@ type Circom2WitnessCalculator struct { msgStr bytes.Buffer } -// NewCircom2WitnessCalculator creates a new WitnessCalculator from the WitnessCalc +// NewCircom2WitnessCalculator creates a new CalculatorImpl from the WitnessCalc // loaded WASM module in the runtime. -func NewCircom2WitnessCalculator(wasmBytes []byte, sanityCheck bool) (*Circom2WitnessCalculator, error) { +func NewCircom2WitnessCalculator( + wasmBytes []byte) (witness.CalculatorImpl, error) { + wc := Circom2WitnessCalculator{} - wc.sanityCheck = sanityCheck wc.engine = wasmer.NewEngine() wc.store = wasmer.NewStore(wc.engine) @@ -199,139 +203,8 @@ func NewCircom2WitnessCalculator(wasmBytes []byte, sanityCheck bool) (*Circom2Wi } // CalculateWitness calculates the witness given the inputs. -func (wc *Circom2WitnessCalculator) CalculateWitness(inputs map[string]interface{}, sanityCheck bool) ([]*big.Int, error) { - - w := make([]*big.Int, wc.witnessSize) - - err := wc.doCalculateWitness(inputs, sanityCheck) - if err != nil { - return nil, err - } - - for i := 0; i < int(wc.witnessSize); i++ { - _, err := wc.getWitness(i) - if err != nil { - return nil, err - } - arr := make([]uint32, wc.n32) - for j := 0; j < int(wc.n32); j++ { - val, err := wc.readSharedRWMemory(int32(j)) - if err != nil { - return nil, err - } - arr[int(wc.n32)-1-j] = uint32(val.(int32)) - } - w[i] = fromArray32(arr) - } - - return w, nil -} - -// CalculateBinWitness calculates the witness in binary given the inputs. -func (wc *Circom2WitnessCalculator) CalculateBinWitness(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) { - buff := new(bytes.Buffer) - - err := wc.doCalculateWitness(inputs, sanityCheck) - if err != nil { - return nil, err - } - - for i := 0; i < int(wc.witnessSize); i++ { - _, err := wc.getWitness(i) - if err != nil { - return nil, err - } - - for j := 0; j < int(wc.n32); j++ { - val, err := wc.readSharedRWMemory(j) - if err != nil { - return nil, err - } - _ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32))) - } - } - - return buff.Bytes(), nil -} - -// CalculateWTNSBin calculates the witness in binary given the inputs. -func (wc *Circom2WitnessCalculator) CalculateWTNSBin(inputs map[string]interface{}, sanityCheck bool) ([]byte, error) { - buff := new(bytes.Buffer) - - err := wc.doCalculateWitness(inputs, sanityCheck) - if err != nil { - return nil, err - } - - buff.Grow(int(wc.witnessSize*wc.n32 + wc.n32 + 11)) - - // wtns - _ = buff.WriteByte('w') - _ = buff.WriteByte('t') - _ = buff.WriteByte('n') - _ = buff.WriteByte('s') - - //version 2 - _ = binary.Write(buff, binary.LittleEndian, uint32(2)) - - //number of sections: 2 - _ = binary.Write(buff, binary.LittleEndian, uint32(2)) - - //id section 1 - _ = binary.Write(buff, binary.LittleEndian, uint32(1)) - - n8 := wc.n32 * 4 - //id section 1 length in 64bytes - idSection1length := 8 + n8 - _ = binary.Write(buff, binary.LittleEndian, uint64(idSection1length)) - - //this.n32 - _ = binary.Write(buff, binary.LittleEndian, uint32(n8)) - - //prime number - _, err = wc.getRawPrime() - if err != nil { - return nil, err - } - - for j := 0; j < int(wc.n32); j++ { - val, err := wc.readSharedRWMemory(int32(j)) - if err != nil { - return nil, err - } - _ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32))) - } - - // witness size - _ = binary.Write(buff, binary.LittleEndian, uint32(wc.witnessSize)) - - //id section 2 - _ = binary.Write(buff, binary.LittleEndian, uint32(2)) - - // section 2 length - idSection2length := n8 * wc.witnessSize - _ = binary.Write(buff, binary.LittleEndian, uint64(idSection2length)) - - for i := 0; i < int(wc.witnessSize); i++ { - _, err := wc.getWitness(i) - if err != nil { - return nil, err - } - - for j := 0; j < int(wc.n32); j++ { - val, err := wc.readSharedRWMemory(j) - if err != nil { - return nil, err - } - _ = binary.Write(buff, binary.LittleEndian, uint32(val.(int32))) - } - } - - return buff.Bytes(), nil -} - -// CalculateWitness calculates the witness given the inputs. -func (wc *Circom2WitnessCalculator) doCalculateWitness(inputs map[string]interface{}, sanityCheck bool) (funcErr error) { +func (wc *Circom2WitnessCalculator) doCalculateWitness(inputs map[string]interface{}, + sanityCheck bool) (funcErr error) { //input is assumed to be a map from signals to arrays of bigInts sanityCheckVal := int32(0) if sanityCheck { @@ -584,3 +457,79 @@ func fromArray32(arr []uint32) *big.Int { } return res } + +// fnvHash returns the 64 bit FNV-1a hash split into two 32 bit values: (MSB, LSB) +func fnvHash(s string) (int32, int32) { + hash := fnv.New64a() + hash.Write([]byte(s)) + h := hash.Sum64() + return int32(h >> 32), int32(h & 0xffffffff) +} + +// _flatSlice is a recursive helper function for flatSlice. +func _flatSlice(acc *[]*big.Int, v interface{}) { + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Slice: + for i := 0; i < rv.Len(); i++ { + _flatSlice(acc, rv.Index(i).Interface()) + } + default: + *acc = append(*acc, v.(*big.Int)) + } +} + +// flatSlice takes a structure that contains a recursive combination of slices +// and *big.Int and flattens it into a single slice. +func flatSlice(v interface{}) []*big.Int { + res := make([]*big.Int, 0) + _flatSlice(&res, v) + return res +} + +func (wc *Circom2WitnessCalculator) Calculate(inputs map[string]interface{}, + sanityCheck bool) (wtns witness.Witness, err error) { + + err = wc.doCalculateWitness(inputs, sanityCheck) + if err != nil { + return wtns, err + } + + //prime number + _, err = wc.getRawPrime() + if err != nil { + return wtns, err + } + + n8 := wc.n32 * 4 + bigIntBuf := make([]byte, n8) + for j := 0; j < int(wc.n32); j++ { + val, err := wc.readSharedRWMemory(int32(j)) + if err != nil { + return wtns, err + } + binary.LittleEndian.PutUint32(bigIntBuf[j*4:], uint32(val.(int32))) + } + + wtns.Prime = new(big.Int).SetBytes(utils.SwapEndianness(bigIntBuf)) + wtns.N32 = int(wc.n32) + + wtns.Witness = make([]*big.Int, wc.witnessSize) + for i := 0; i < int(wc.witnessSize); i++ { + _, err := wc.getWitness(i) + if err != nil { + return wtns, err + } + + for j := 0; j < int(wc.n32); j++ { + val, err := wc.readSharedRWMemory(j) + if err != nil { + return wtns, err + } + binary.LittleEndian.PutUint32(bigIntBuf[j*4:], uint32(val.(int32))) + } + wtns.Witness[i] = new(big.Int).SetBytes(utils.SwapEndianness(bigIntBuf)) + } + + return wtns, nil +} diff --git a/witness/wasmer/go.mod b/witness/wasmer/go.mod new file mode 100644 index 0000000..1a05969 --- /dev/null +++ b/witness/wasmer/go.mod @@ -0,0 +1,13 @@ +module github.com/iden3/go-rapidsnark/witness/wasmer + +go 1.18 + +require ( + github.com/iden3/go-iden3-crypto v0.0.15 + github.com/iden3/go-rapidsnark/witness/v2 v2.0.0-20230523125954-fcfab2575c4d + github.com/iden3/wasmer-go v0.0.1 +) + +require golang.org/x/sys v0.6.0 // indirect + +replace github.com/iden3/go-rapidsnark/witness/v2 => ../ diff --git a/witness/wasmer/go.sum b/witness/wasmer/go.sum new file mode 100644 index 0000000..df3411f --- /dev/null +++ b/witness/wasmer/go.sum @@ -0,0 +1,11 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/iden3/go-iden3-crypto v0.0.15 h1:4MJYlrot1l31Fzlo2sF56u7EVFeHHJkxGXXZCtESgK4= +github.com/iden3/go-iden3-crypto v0.0.15/go.mod h1:dLpM4vEPJ3nDHzhWFXDjzkn1qHoBeOT/3UEhXsEsP3E= +github.com/iden3/wasmer-go v0.0.1 h1:TZKh8Se8B/73PvWrcu+FTU9L1k5XYAmtFbioj7l0Uog= +github.com/iden3/wasmer-go v0.0.1/go.mod h1:ZnZBAO012M7o+Q1INXLRIxKQgEcH2FuwL0Iga8A4ufg= +github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/witness/circom2witnesscalc_wazero.go b/witness/wazero/circom2witnesscalc_wazero.go similarity index 68% rename from witness/circom2witnesscalc_wazero.go rename to witness/wazero/circom2witnesscalc_wazero.go index 6c0dd55..cea8514 100644 --- a/witness/circom2witnesscalc_wazero.go +++ b/witness/wazero/circom2witnesscalc_wazero.go @@ -1,31 +1,32 @@ -package witness +package wazero import ( "bytes" "context" - "encoding/binary" "errors" "fmt" + "hash/fnv" "log" "math" "math/big" "strings" "github.com/iden3/go-iden3-crypto/constants" - "github.com/tetratelabs/wazero" + "github.com/iden3/go-rapidsnark/witness/v2" + wz "github.com/tetratelabs/wazero" "github.com/tetratelabs/wazero/api" ) type Circom2WZWitnessCalculator struct { - runtime wazero.Runtime + runtime wz.Runtime modRuntime api.Module - compiledModule wazero.CompiledModule + compiledModule wz.CompiledModule } func NewCircom2WZWitnessCalculator( - wasmBytes []byte) (*Circom2WZWitnessCalculator, error) { + wasmBytes []byte) (witness.CalculatorImpl, error) { - runtime := wazero.NewRuntime(context.Background()) + runtime := wz.NewRuntime(context.Background()) ctx := context.Background() modRuntime, err := runtime.NewHostModuleBuilder("runtime"). @@ -84,211 +85,8 @@ func (w *Circom2WZWitnessCalculator) Close() error { return err } -// CalculateWitness calculates the witness given the inputs. -func (wc *Circom2WZWitnessCalculator) CalculateWitness(inputs map[string]interface{}, - sanityCheck bool) (wtns []*big.Int, err error) { - - wCtxState := &witnessCtxState{} - ctx := withWtnsCtx(context.Background(), wCtxState) - - cfg := wazero.NewModuleConfig() - var instance api.Module - instance, err = wc.runtime.InstantiateModule(ctx, wc.compiledModule, cfg) - if err != nil { - return nil, err - } - defer closeWithErrOrLog(ctx, instance, &err) - - var wCtx witnessCtx - wCtx, err = calculateWtnsCtx(ctx, instance) - if err != nil { - return nil, err - } - - err = wc.doCalculateWitness(ctx, instance, wCtx, inputs, sanityCheck) - if err != nil { - return nil, err - } - - wtns = make([]*big.Int, wCtx.witnessSize) - - for i := 0; i < int(wCtx.witnessSize); i++ { - err = wCtx.getWitness(ctx, int32(i)) - if err != nil { - return nil, err - } - arr := make([]uint32, wCtx.n32) - for j := 0; j < int(wCtx.n32); j++ { - var val int32 - val, err = wCtx.readSharedRWMemory(ctx, int32(j)) - if err != nil { - return nil, err - } - arr[int(wCtx.n32)-1-j] = uint32(val) - } - wtns[i] = fromArray32(arr) - } - - return wtns, wCtxState.err() -} - -// CalculateBinWitness calculates the witness in binary given the inputs. -func (wc *Circom2WZWitnessCalculator) CalculateBinWitness(inputs map[string]interface{}, - sanityCheck bool) (wtns []byte, err error) { - - wCtxState := &witnessCtxState{} - ctx := withWtnsCtx(context.Background(), wCtxState) - - cfg := wazero.NewModuleConfig() - var instance api.Module - instance, err = wc.runtime.InstantiateModule(ctx, wc.compiledModule, cfg) - if err != nil { - return nil, err - } - defer closeWithErrOrLog(ctx, instance, &err) - - var wCtx witnessCtx - // wCtx is closure around ctx, do not return it, use only in this function. - wCtx, err = calculateWtnsCtx(ctx, instance) - if err != nil { - return nil, err - } - - err = wc.doCalculateWitness(ctx, instance, wCtx, inputs, sanityCheck) - if err != nil { - return nil, err - } - - buff := new(bytes.Buffer) - - for i := 0; i < int(wCtx.witnessSize); i++ { - err = wCtx.getWitness(ctx, int32(i)) - if err != nil { - return nil, err - } - - for j := 0; j < int(wCtx.n32); j++ { - val, err := wCtx.readSharedRWMemory(ctx, int32(j)) - if err != nil { - return nil, err - } - _ = binary.Write(buff, binary.LittleEndian, uint32(val)) - } - } - - return buff.Bytes(), wCtxState.err() -} - -// CalculateWTNSBin calculates the witness in binary given the inputs. -func (wc *Circom2WZWitnessCalculator) CalculateWTNSBin(inputs map[string]interface{}, - sanityCheck bool) (wtns []byte, err error) { - - wCtxState := &witnessCtxState{} - ctx := withWtnsCtx(context.Background(), wCtxState) - - cfg := wazero.NewModuleConfig() - var instance api.Module - instance, err = wc.runtime.InstantiateModule(ctx, wc.compiledModule, cfg) - if err != nil { - return nil, err - } - defer closeWithErrOrLog(ctx, instance, &err) - - var wCtx witnessCtx - // wCtx is closure around ctx, do not return it, use only in this function. - wCtx, err = calculateWtnsCtx(ctx, instance) - if err != nil { - return nil, err - } - - err = wc.doCalculateWitness(ctx, instance, wCtx, inputs, sanityCheck) - if err != nil { - return nil, err - } - - buff := new(bytes.Buffer) - - var wResult []uint64 - wResult, err = instance.ExportedFunction("getWitnessSize").Call(ctx) - if err != nil { - return nil, err - } - witnessSize := api.DecodeI32(wResult[0]) - - buff.Grow(int(witnessSize*wCtx.n32 + wCtx.n32 + 11)) - - // wtns - _ = buff.WriteByte('w') - _ = buff.WriteByte('t') - _ = buff.WriteByte('n') - _ = buff.WriteByte('s') - - //version 2 - _ = binary.Write(buff, binary.LittleEndian, uint32(2)) - - //number of sections: 2 - _ = binary.Write(buff, binary.LittleEndian, uint32(2)) - - //id section 1 - _ = binary.Write(buff, binary.LittleEndian, uint32(1)) - - n8 := wCtx.n32 * 4 - //id section 1 length in 64bytes - idSection1length := 8 + n8 - _ = binary.Write(buff, binary.LittleEndian, uint64(idSection1length)) - - //this.n32 - _ = binary.Write(buff, binary.LittleEndian, uint32(n8)) - - //prime number - _, err = instance.ExportedFunction("getRawPrime").Call(ctx) - if err != nil { - return nil, err - } - - for j := 0; j < int(wCtx.n32); j++ { - data, err := wCtx.readSharedRWMemory(ctx, int32(j)) - if err != nil { - return nil, err - } - _ = binary.Write(buff, binary.LittleEndian, uint32(data)) - } - - // witness size - _ = binary.Write(buff, binary.LittleEndian, uint32(witnessSize)) - - //id section 2 - _ = binary.Write(buff, binary.LittleEndian, uint32(2)) - - // section 2 length - idSection2length := n8 * witnessSize - _ = binary.Write(buff, binary.LittleEndian, uint64(idSection2length)) - - getWitness := instance.ExportedFunction("getWitness") - for i := 0; i < int(witnessSize); i++ { - _, err = getWitness.Call(ctx, api.EncodeI32(int32(i))) - if err != nil { - return nil, err - } - - for j := 0; j < int(wCtx.n32); j++ { - var data int32 - data, err = wCtx.readSharedRWMemory(ctx, int32(j)) - if err != nil { - return nil, err - } - _ = binary.Write(buff, binary.LittleEndian, uint32(data)) - } - } - - return buff.Bytes(), wCtxState.err() -} - func (w *Circom2WZWitnessCalculator) doCalculateWitness(ctx context.Context, - instance api.Module, - wCtx witnessCtx, - inputs map[string]any, - sanityCheck bool) (err error) { + wCtx witnessCtx, inputs map[string]any, sanityCheck bool) (err error) { if err = wCtx.init(ctx, sanityCheck); err != nil { return err @@ -348,6 +146,30 @@ type witnessCtx struct { setInputSignal func(ctx context.Context, hMSB, hLSB, z int32) error readSharedRWMemory func(ctx context.Context, i int32) (int32, error) getWitness func(ctx context.Context, i int32) error + getRawPrime func(ctx context.Context) error +} + +func (wCtx *witnessCtx) prime(ctx context.Context) (*big.Int, error) { + + err := wCtx.getRawPrime(ctx) + if err != nil { + return nil, err + } + + return wCtx.readInt(ctx) +} + +func (wCtx *witnessCtx) readInt(ctx context.Context) (*big.Int, error) { + arr := make([]uint32, wCtx.n32) + for j := 0; j < int(wCtx.n32); j++ { + val, err := wCtx.readSharedRWMemory(ctx, int32(j)) + if err != nil { + return nil, err + } + arr[int(wCtx.n32)-1-j] = uint32(val) + } + + return fromArray32(arr), nil } func calculateWtnsCtx(ctx context.Context, @@ -431,6 +253,12 @@ func calculateWtnsCtx(ctx context.Context, return err2 } + _getRawPrime := instance.ExportedFunction("getRawPrime") + wCtx.getRawPrime = func(ctx context.Context) error { + _, err2 := _getRawPrime.Call(ctx) + return err2 + } + return wCtx, nil } @@ -640,3 +468,67 @@ func printSharedRWMemory(ctx context.Context, m api.Module) { wtnsCtx.msgStrs = append(wtnsCtx.msgStrs, fromArray32(arr).Text(10)) } + +func fromArray32(arr []uint32) *big.Int { + res := new(big.Int) + radix := big.NewInt(0x100000000) + for i := 0; i < len(arr); i++ { + res.Mul(res, radix) + res.Add(res, big.NewInt(int64(arr[i]))) + } + return res +} + +// fnvHash returns the 64 bit FNV-1a hash split into two 32 bit values: (MSB, LSB) +func fnvHash(s string) (int32, int32) { + hash := fnv.New64a() + hash.Write([]byte(s)) + h := hash.Sum64() + return int32(h >> 32), int32(h & 0xffffffff) +} + +// Calculate calculates the witness given the inputs. +func (wc *Circom2WZWitnessCalculator) Calculate(inputs map[string]interface{}, + sanityCheck bool) (wtns witness.Witness, err error) { + + wCtxState := &witnessCtxState{} + ctx := withWtnsCtx(context.Background(), wCtxState) + + cfg := wz.NewModuleConfig() + var instance api.Module + instance, err = wc.runtime.InstantiateModule(ctx, wc.compiledModule, cfg) + if err != nil { + return wtns, err + } + defer closeWithErrOrLog(ctx, instance, &err) + + var wCtx witnessCtx + wCtx, err = calculateWtnsCtx(ctx, instance) + if err != nil { + return wtns, err + } + + wtns.N32 = int(wCtx.n32) + + err = wc.doCalculateWitness(ctx, wCtx, inputs, sanityCheck) + if err != nil { + return wtns, err + } + + wtns.Witness = make([]*big.Int, wCtx.witnessSize) + + for i := 0; i < int(wCtx.witnessSize); i++ { + err = wCtx.getWitness(ctx, int32(i)) + if err != nil { + return wtns, err + } + wtns.Witness[i], err = wCtx.readInt(ctx) + } + + wtns.Prime, err = wCtx.prime(ctx) + if err != nil { + return wtns, err + } + + return wtns, wCtxState.err() +} diff --git a/witness/wazero/go.mod b/witness/wazero/go.mod new file mode 100644 index 0000000..73ad2e2 --- /dev/null +++ b/witness/wazero/go.mod @@ -0,0 +1,13 @@ +module github.com/iden3/go-rapidsnark/witness/wazero + +go 1.18 + +require ( + github.com/iden3/go-iden3-crypto v0.0.15 + github.com/iden3/go-rapidsnark/witness/v2 v2.0.0-20230523125954-fcfab2575c4d + github.com/tetratelabs/wazero v1.1.0 +) + +require golang.org/x/sys v0.6.0 // indirect + +replace github.com/iden3/go-rapidsnark/witness/v2 => ../ diff --git a/witness/wazero/go.sum b/witness/wazero/go.sum new file mode 100644 index 0000000..ee7bf7b --- /dev/null +++ b/witness/wazero/go.sum @@ -0,0 +1,11 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/iden3/go-iden3-crypto v0.0.15 h1:4MJYlrot1l31Fzlo2sF56u7EVFeHHJkxGXXZCtESgK4= +github.com/iden3/go-iden3-crypto v0.0.15/go.mod h1:dLpM4vEPJ3nDHzhWFXDjzkn1qHoBeOT/3UEhXsEsP3E= +github.com/leanovate/gopter v0.2.9 h1:fQjYxZaynp97ozCzfOyOuAGOU4aU/z37zf/tOujFk7c= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8= +github.com/tetratelabs/wazero v1.1.0 h1:EByoAhC+QcYpwSZJSs/aV0uokxPwBgKxfiokSUwAknQ= +github.com/tetratelabs/wazero v1.1.0/go.mod h1:wYx2gNRg8/WihJfSDxA1TIL8H+GkfLYm+bIfbblu9VQ= +golang.org/x/sys v0.6.0 h1:MVltZSvRTcU2ljQOhs94SXPftV6DCNnZViHeQps87pQ= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/witness/witness.go b/witness/witness.go new file mode 100644 index 0000000..69ad56b --- /dev/null +++ b/witness/witness.go @@ -0,0 +1,168 @@ +package witness + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "math/big" + + "github.com/iden3/go-iden3-crypto/utils" +) + +type Option func(cfg *calcConfig) + +func WithWasmEngine(calculator func([]byte) (CalculatorImpl, error)) Option { + return func(cfg *calcConfig) { + cfg.wasmEngine = calculator + } +} + +type CalculatorImpl interface { + Calculate(inputs map[string]interface{}, + sanityCheck bool) (wtns Witness, err error) +} + +type Calculator interface { + CalculateWitness(inputs map[string]interface{}, + sanityCheck bool) ([]*big.Int, error) + CalculateBinWitness(inputs map[string]interface{}, + sanityCheck bool) ([]byte, error) + CalculateWTNSBin(inputs map[string]interface{}, + sanityCheck bool) ([]byte, error) +} + +type calcConfig struct { + wasmEngine func([]byte) (CalculatorImpl, error) +} + +type calc struct { + wc CalculatorImpl +} + +func (c *calc) CalculateWitness(inputs map[string]interface{}, + sanityCheck bool) ([]*big.Int, error) { + + wtns, err := c.wc.Calculate(inputs, sanityCheck) + if err != nil { + return nil, err + } + return wtns.Witness, nil +} + +func (c *calc) CalculateBinWitness(inputs map[string]interface{}, + sanityCheck bool) ([]byte, error) { + + wtns, err := c.wc.Calculate(inputs, sanityCheck) + if err != nil { + return nil, err + } + + var b bytes.Buffer + b.Grow(wtns.N32 * 4 * len(wtns.Witness)) + for _, i := range wtns.Witness { + bs := utils.SwapEndianness(i.Bytes()) + b.Write(bs) + if len(bs) < wtns.N32*4 { + for j := 0; j < (wtns.N32*4)-len(bs); j++ { + b.WriteByte(0) + } + } + } + + return b.Bytes(), nil +} + +func (c *calc) CalculateWTNSBin(inputs map[string]interface{}, + sanityCheck bool) ([]byte, error) { + + wtns, err := c.wc.Calculate(inputs, sanityCheck) + if err != nil { + return nil, err + } + + buff := new(bytes.Buffer) + + n8 := wtns.N32 * 4 + idSection2length := n8 * len(wtns.Witness) + + totalLn := 4 + 4 + 4 + 4 + 8 + 4 + n8 + 4 + 4 + 8 + idSection2length + buff.Grow(totalLn) + + // wtns + _, _ = buff.Write([]byte("wtns")) + + //version 2 + _ = binary.Write(buff, binary.LittleEndian, uint32(2)) + + //number of sections: 2 + _ = binary.Write(buff, binary.LittleEndian, uint32(2)) + + //id section 1 + _ = binary.Write(buff, binary.LittleEndian, uint32(1)) + + //id section 1 length in 64bytes + idSection1length := 8 + n8 + _ = binary.Write(buff, binary.LittleEndian, uint64(idSection1length)) + + //this.n32 + _ = binary.Write(buff, binary.LittleEndian, uint32(n8)) + + err = writeInt(buff, wtns.Prime, n8) + if err != nil { + return nil, err + } + + // witness size + _ = binary.Write(buff, binary.LittleEndian, uint32(len(wtns.Witness))) + + //id section 2 + _ = binary.Write(buff, binary.LittleEndian, uint32(2)) + + // section 2 length + _ = binary.Write(buff, binary.LittleEndian, uint64(idSection2length)) + + for _, i := range wtns.Witness { + err = writeInt(buff, i, n8) + if err != nil { + return nil, err + } + } + + return buff.Bytes(), nil +} + +func writeInt(out io.Writer, i *big.Int, bytesLn int) error { + bs := utils.SwapEndianness(i.Bytes()) + _, err := out.Write(bs) + if err != nil { + return err + } + if len(bs) < bytesLn { + _, err = out.Write(make([]byte, bytesLn-len(bs))) + } + + return err +} + +func NewCalculator(wasm []byte, ops ...Option) (Calculator, error) { + var config calcConfig + for _, op := range ops { + op(&config) + } + if config.wasmEngine == nil { + return nil, errors.New("witness calculator wasm engine not set") + } + wc, err := config.wasmEngine(wasm) + if err != nil { + return nil, err + } + return &calc{wc: wc}, nil +} + +type Witness struct { + // number of int32 values required to represent the *big.Int + N32 int + Prime *big.Int + Witness []*big.Int +}