Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Perf(zstd): Improve 'matchLen' performance by vector instructions. #823

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
107 changes: 107 additions & 0 deletions zstd/_generate/gen_matchlen.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
package main

//go:generate go run gen_matchlen.go -out ../matchlen_amd64.s -pkg=zstd

import (
"flag"

. "github.com/mmcloughlin/avo/build"
"github.com/mmcloughlin/avo/buildtags"
. "github.com/mmcloughlin/avo/operand"
)

func main() {
flag.Parse()

Constraint(buildtags.Not("appengine").ToConstraint())
Constraint(buildtags.Not("noasm").ToConstraint())
Constraint(buildtags.Term("gc").ToConstraint())
Constraint(buildtags.Not("noasm").ToConstraint())
generateMatchLen()
Generate()
}

func generateMatchLen() {
Package("github.com/klauspost/compress/zstd")
TEXT("matchLen", NOSPLIT, "func (a, b []byte) int")
Pragma("noescape")
Comment("load param")
zzzzwc marked this conversation as resolved.
Show resolved Hide resolved
aptr := Load(Param("a").Base(), GP64())
alen := Load(Param("a").Len(), GP64())
bptr := Load(Param("b").Base(), GP64())
equalMaskBits := GP64()
ret := GP64()
XORQ(ret, ret)

Comment("find the minimum length slice")
remainLen := alen

Label("loop")
{
CMPQ(remainLen, U8(32))
JB(LabelRef("last_loop"))
Comment("load 32 bytes into YMM registers")
adata := YMM()
bdata := YMM()
equalMaskBytes := YMM()
VMOVDQU(Mem{Base: aptr}, adata)
VMOVDQU(Mem{Base: bptr}, bdata)
Comment("compare bytes in adata and bdata, like 'bytewise XNOR'",
"if the byte is the same in adata and bdata, VPCMPEQB will store 0xFF in the same position in equalMaskBytes")
VPCMPEQB(adata, bdata, equalMaskBytes)
Comment("like convert byte to bit, store equalMaskBytes into general reg")
VPMOVMSKB(equalMaskBytes, equalMaskBits.As32())
CMPL(equalMaskBits.As32(), U32(0xffffffff))
JNE(LabelRef("cal_prefix"))
ADDQ(U8(32), aptr)
ADDQ(U8(32), bptr)
SUBQ(U8(32), remainLen)
ADDQ(U8(32), ret)
JMP(LabelRef("loop"))
}

Label("last_loop")
{
TESTQ(remainLen, remainLen)
JZ(LabelRef("ret"))
adata := YMM()
bdata := YMM()
equalMaskBytes := YMM()
VMOVDQU(Mem{Base: aptr}, adata)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can tell you are over-reading. This is not feasible. You cannot read longer than provided slices.

VMOVDQU(Mem{Base: bptr}, bdata)
VPCMPEQB(adata, bdata, equalMaskBytes)
VPMOVMSKB(equalMaskBytes, equalMaskBits.As32())
CMPL(equalMaskBits.As32(), U32(0xffffffff))
JNE(LabelRef("cal_last_prefix"))
Comment("if last bytes are all equal, just add remaining len on ret and return")
ADDQ(remainLen, ret)
JMP(LabelRef("ret"))
}

Label("cal_last_prefix")
{
matchedLen := GP64()
NOTQ(equalMaskBits)
Comment("store first not equal position into matchedLen")
TZCNTQ(equalMaskBits, matchedLen)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check for BMI

Comment("if matched len > remaining len, just add remaining on ret")
CMPQ(remainLen, matchedLen)
CMOVQLT(remainLen, matchedLen)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be fairly predictable, so I suspect a branch will be faster.

If you keep it, check CMOV - not entirely sure this is always present on AMD64.

ADDQ(matchedLen, ret)
JMP(LabelRef("ret"))
}

Label("cal_prefix")
{
matchedLen := GP64()
NOTQ(equalMaskBits)
TZCNTQ(equalMaskBits, matchedLen)
ADDQ(matchedLen, ret)
}

Label("ret")
{
Store(ret, ReturnIndex(0))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

VZEROUPPER missing.

RET()
}
}
1 change: 0 additions & 1 deletion zstd/_generate/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ require (
github.com/klauspost/compress v1.15.15
github.com/mmcloughlin/avo v0.5.0
golang.org/x/tools v0.6.0 // indirect
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect
)

replace github.com/klauspost/compress => ../..
24 changes: 1 addition & 23 deletions zstd/_generate/go.sum
Original file line number Diff line number Diff line change
@@ -1,40 +1,26 @@
github.com/mmcloughlin/avo v0.4.0 h1:jeHDRktVD+578ULxWpQHkilor6pkdLF7u7EiTzDbfcU=
github.com/mmcloughlin/avo v0.4.0/go.mod h1:RW9BfYA3TgO9uCdNrKU2h6J8cPD8ZLznvfgHAeszb1s=
github.com/mmcloughlin/avo v0.5.0 h1:nAco9/aI9Lg2kiuROBY6BhCI/z0t5jEvJfjWbL8qXLU=
github.com/mmcloughlin/avo v0.5.0/go.mod h1:ChHFdoV7ql95Wi7vuq2YT1bwCJqiWdZrQ1im3VujLYM=
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.1.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/mod v0.4.2 h1:Gz96sIWK3OalVv/I/qNygP42zyoKp3xptRVCWRFEBvo=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI=
golang.org/x/mod v0.8.0 h1:LUYupSeNrTNCGzR/hVBk2NHZO4hXcVaW1k4Qx7rjPx8=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021 h1:giLT+HuUP/gXYrG2Plg9WTjj4qhfgaW424ZIFog3rlk=
golang.org/x/sys v0.0.0-20211030160813-b3129d9d1021/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
Expand All @@ -46,22 +32,14 @@ golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.1.7 h1:6j8CgantCy3yc8JGBqkDLMKWqZ0RDU2g1HVgacojGWQ=
golang.org/x/tools v0.1.7/go.mod h1:LGqMHiF4EqQNHR1JncWGqT5BVaXmza+X+BDGol+dOxo=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA=
golang.org/x/tools v0.6.0 h1:BOw41kyTf3PuCW1pVQf8+Cyg8pMlkYB1oo9iJ6D/lKM=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 h1:H2TDz8ibqkAF6YGhCdN3jS9O0/s90v0rJh3X/OLHEUk=
golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2/go.mod h1:K8+ghG5WaK9qNqU5K3HdILfMLy1f3aNYFI/wnl100a8=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
10 changes: 10 additions & 0 deletions zstd/matchlen_amd64.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//go:build amd64 && !appengine && !noasm && gc
// +build amd64,!appengine,!noasm,gc

package zstd

// matchLen returns the maximum common prefix length of a and b.
// a must be the shortest of the two.
//
//go:noescape
func matchLen(a, b []byte) int
72 changes: 72 additions & 0 deletions zstd/matchlen_amd64.s
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Code generated by command: go run gen_matchlen.go -out ../matchlen_amd64.s -pkg=zstd. DO NOT EDIT.

//go:build !appengine && !noasm && gc && !noasm

#include "textflag.h"

// func matchLen(a []byte, b []byte) int
// Requires: AVX, AVX2, BMI, CMOV
TEXT ·matchLen(SB), NOSPLIT, $0-56
// load param
MOVQ a_base+0(FP), AX
MOVQ a_len+8(FP), CX
MOVQ b_base+24(FP), DX
XORQ SI, SI

// find the minimum length slice
loop:
CMPQ CX, $0x20
JB last_loop

// load 32 bytes into YMM registers
VMOVDQU (AX), Y0
VMOVDQU (DX), Y1

// compare bytes in adata and bdata, like 'bytewise XNOR'
// if the byte is the same in adata and bdata, VPCMPEQB will store 0xFF in the same position in equalMaskBytes
VPCMPEQB Y0, Y1, Y0

// like convert byte to bit, store equalMaskBytes into general reg
VPMOVMSKB Y0, BX
CMPL BX, $0xffffffff
JNE cal_prefix
ADDQ $0x20, AX
ADDQ $0x20, DX
SUBQ $0x20, CX
ADDQ $0x20, SI
JMP loop

last_loop:
TESTQ CX, CX
JZ ret
VMOVDQU (AX), Y0
VMOVDQU (DX), Y1
VPCMPEQB Y0, Y1, Y0
VPMOVMSKB Y0, BX
CMPL BX, $0xffffffff
JNE cal_last_prefix

// if last bytes are all equal, just add remaining len on ret and return
ADDQ CX, SI
JMP ret

cal_last_prefix:
NOTQ BX

// store first not equal position into matchedLen
TZCNTQ BX, AX

// if matched len > remaining len, just add remaining on ret
CMPQ CX, AX
CMOVQLT CX, AX
ADDQ AX, SI
JMP ret

cal_prefix:
NOTQ BX
TZCNTQ BX, AX
ADDQ AX, SI

ret:
MOVQ SI, ret+48(FP)
RET
30 changes: 30 additions & 0 deletions zstd/matchlen_generic.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
//go:build !amd64 || appengine || !gc || noasm
// +build !amd64 appengine !gc noasm

package zstd

import (
"encoding/binary"
"math/bits"
)

// matchLen returns the maximum common prefix length of a and b.
// a must be the shortest of the two.
func matchLen(a, b []byte) (n int) {
for ; len(a) >= 8 && len(b) >= 8; a, b = a[8:], b[8:] {
diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b)
if diff != 0 {
return n + bits.TrailingZeros64(diff)>>3
}
n += 8
}

for i := range a {
if a[i] != b[i] {
break
}
n++
}
return n

}
22 changes: 0 additions & 22 deletions zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"errors"
"log"
"math"
"math/bits"
)

// enable debug printing
Expand Down Expand Up @@ -106,27 +105,6 @@ func printf(format string, a ...interface{}) {
}
}

// matchLen returns the maximum common prefix length of a and b.
// a must be the shortest of the two.
func matchLen(a, b []byte) (n int) {
for ; len(a) >= 8 && len(b) >= 8; a, b = a[8:], b[8:] {
diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b)
if diff != 0 {
return n + bits.TrailingZeros64(diff)>>3
}
n += 8
}

for i := range a {
if a[i] != b[i] {
break
}
n++
}
return n

}

func load3232(b []byte, i int32) uint32 {
return binary.LittleEndian.Uint32(b[:len(b):len(b)][i:])
}
Expand Down