Skip to content

Commit

Permalink
Simplify AVX512 code to use 16x wide 32-bit pointers
Browse files Browse the repository at this point in the history
  • Loading branch information
fwessels committed May 8, 2020
1 parent cb281b5 commit 58af1df
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 45 deletions.
34 changes: 10 additions & 24 deletions block16_amd64.s
Expand Up @@ -5,12 +5,8 @@
// This is the AVX512 implementation of the MD5 block function (16-way parallel)

#define prep(index) \
KMOVB kmaskL, ktmp \
VPGATHERQD index*4(base)(ptrsLow*1), ktmp, ymemLow \
KMOVB kmaskH, ktmp \
VPGATHERQD index*4(base)(ptrsHigh*1), ktmp, ymemHigh \
VALIGND $8, memHigh, memHigh, memHigh \
VPORD memHigh, mem, mem
KMOVQ kmask, ktmp \
VPGATHERDD index*4(base)(ptrs*1), ktmp, mem

#define ROUND1(a, b, c, d, index, const, shift) \
VXORPS c, tmp, tmp \
Expand Down Expand Up @@ -62,13 +58,13 @@
VXORPS c, ones, tmp \
VPADDD b, a, a

TEXT ·block16(SB),4,$0-32
TEXT ·block16(SB),4,$0-40

MOVQ state+0(FP), BX
XORQ SI, SI // null out base pointer (using absolute 64-bit pointers)
MOVQ ptrs+8(FP), AX
KMOVQ mask+16(FP), K1
MOVQ n+24(FP), DX
MOVQ base+8(FP), SI
MOVQ ptrs+16(FP), AX
KMOVQ mask+24(FP), K1
MOVQ n+32(FP), DX
MOVQ ·avx512md5consts+0(SB), DI

#define a Z0
Expand All @@ -83,16 +79,11 @@ TEXT ·block16(SB),4,$0-32

#define tmp Z8
#define tmp2 Z9
#define ptrsLow Z10
#define ptrsHigh Z11
#define ptrs Z10
#define ones Z12
#define mem Z15
#define ymemLow Y15
#define memHigh Z14
#define ymemHigh Y14

#define kmaskL K1
#define kmaskH K2
#define kmask K1
#define ktmp K3

// ----------------------------------------------------------
Expand All @@ -112,12 +103,7 @@ TEXT ·block16(SB),4,$0-32
VMOVUPD 0xc0(dig), d

// load source pointers
VMOVUPD 0x00(AX), ptrsLow
VMOVUPD 0x40(AX), ptrsHigh

// setup masks
KMOVW kmaskL, kmaskH
KSHIFTRW $8, kmaskH, kmaskH
VMOVUPD 0x00(AX), ptrs

MOVQ $-1, AX
VPBROADCASTQ AX, ones
Expand Down
34 changes: 21 additions & 13 deletions block16_amd64_test.go
Expand Up @@ -59,14 +59,16 @@ func TestBlock16(t *testing.T) {
s.v0[i], s.v1[i], s.v2[i], s.v3[i] = init0, init1, init2, init3
}

ptrs := [16]int64{}
bufs := [16]int32{4, 4 + internalBlockSize, 4 + internalBlockSize*2, 4 + internalBlockSize*3, 4 + internalBlockSize*4, 4 + internalBlockSize*5, 4 + internalBlockSize*6, 4 + internalBlockSize*7,
4 + internalBlockSize*8, 4 + internalBlockSize*9, 4 + internalBlockSize*10, 4 + internalBlockSize*11, 4 + internalBlockSize*12, 4 + internalBlockSize*13, 4 + internalBlockSize*14, 4 + internalBlockSize*15}

for i := range ptrs {
ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0]))))
// fmt.Printf("%016x\n", ptrs[i])
base := make([]byte, 4+16*internalBlockSize)

for i := 0; i < len(input); i++ {
copy(base[bufs[i]:], input[i])
}

block16(&s.v0[0], &ptrs[0], 0xffff, 64)
block16(&s.v0[0], uintptr(unsafe.Pointer(&(base[0]))), &bufs[0], 0xffff, 64)

want :=
`00000000 82 3c 09 52 b9 77 11 2a 65 ee 4c 82 f9 ad 4d 28 |.<.R.w.*e.L...M(|
Expand Down Expand Up @@ -124,15 +126,18 @@ func TestBlock16Masked(t *testing.T) {
s.v0[i], s.v1[i], s.v2[i], s.v3[i] = init0, init1, init2, init3
}

ptrs := [16]int64{}
bufs := [16]int32{4, 4 + internalBlockSize, 4 + internalBlockSize*2, 4 + internalBlockSize*3, 4 + internalBlockSize*4, 4 + internalBlockSize*5, 4 + internalBlockSize*6, 4 + internalBlockSize*7,
4 + internalBlockSize*8, 4 + internalBlockSize*9, 4 + internalBlockSize*10, 4 + internalBlockSize*11, 4 + internalBlockSize*12, 4 + internalBlockSize*13, 4 + internalBlockSize*14, 4 + internalBlockSize*15}

base := make([]byte, 4+16*internalBlockSize)

for i := range ptrs {
for i := 0; i < len(input); i++ {
if input[i] != nil {
ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0]))))
copy(base[bufs[i]:], input[i])
}
}

block16(&s.v0[0], &ptrs[0], mask, 64)
block16(&s.v0[0], uintptr(unsafe.Pointer(&(base[0]))), &bufs[0], mask, 64)

want :=
`00000000 82 3c 09 52 ac 1d 1f 03 65 ee 4c 82 ac 1d 1f 03 |.<.R....e.L.....|
Expand Down Expand Up @@ -223,17 +228,20 @@ func BenchmarkBlock16(b *testing.B) {
s.v0[i], s.v1[i], s.v2[i], s.v3[i] = init0, init1, init2, init3
}

ptrs := [16]int64{}
bufs := [16]int32{4, 4 + internalBlockSize, 4 + internalBlockSize*2, 4 + internalBlockSize*3, 4 + internalBlockSize*4, 4 + internalBlockSize*5, 4 + internalBlockSize*6, 4 + internalBlockSize*7,
4 + internalBlockSize*8, 4 + internalBlockSize*9, 4 + internalBlockSize*10, 4 + internalBlockSize*11, 4 + internalBlockSize*12, 4 + internalBlockSize*13, 4 + internalBlockSize*14, 4 + internalBlockSize*15}

base := make([]byte, 4+16*internalBlockSize)

for i := range ptrs {
ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0]))))
for i := 0; i < len(input); i++ {
copy(base[bufs[i]:], input[i])
}

b.SetBytes(int64(size * 16))
b.ReportAllocs()
b.ResetTimer()

for j := 0; j < b.N; j++ {
block16(&s.v0[0], &ptrs[0], 0xffff, size)
block16(&s.v0[0], uintptr(unsafe.Pointer(&(base[0]))), &bufs[0], 0xffff, size)
}
}
23 changes: 15 additions & 8 deletions block_amd64.go
Expand Up @@ -21,7 +21,7 @@ var hasAVX512 bool
func block8(state *uint32, base uintptr, bufs *int32, cache *byte, n int)

//go:noescape
func block16(state *uint32, ptrs *int64, mask uint64, n int)
func block16(state *uint32, base uintptr, ptrs *int32, mask uint64, n int)

// 8-way 4x uint32 digests in 4 ymm registers
// (ymm0, ymm1, ymm2, ymm3)
Expand Down Expand Up @@ -89,7 +89,7 @@ func init() {
// Interface function to assembly code
func (s *md5Server) blockMd5_x16(d *digest16, input [16][]byte, half bool) {
if hasAVX512 {
blockMd5_avx512(d, input, &s.maskRounds16)
blockMd5_avx512(d, input, s.allBufs, &s.maskRounds16)
} else {
d8a, d8b := digest8{}, digest8{}
for i := range d8a.v0 {
Expand Down Expand Up @@ -125,14 +125,21 @@ func (s *md5Server) blockMd5_x16(d *digest16, input [16][]byte, half bool) {
}

// Interface function to AVX512 assembly code
func blockMd5_avx512(s *digest16, input [16][]byte, maskRounds *[16]maskRounds) {
ptrs := [16]int64{}
func blockMd5_avx512(s *digest16, input [16][]byte, base []byte, maskRounds *[16]maskRounds) {
baseMin := uint64(uintptr(unsafe.Pointer(&(base[0]))))
ptrs := [16]int32{}

for i := range ptrs {
if input[i] != nil {
if len(input[i]) > 0 {
if len(input[i]) > internalBlockSize {
panic(fmt.Sprintf("Sanity check fails for lane %d: maximum input length cannot exceed internalBlockSize", i))
}
ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0]))))

off := uint64(uintptr(unsafe.Pointer(&(input[i][0])))) - baseMin
if off > math.MaxUint32 {
panic(fmt.Sprintf("invalid buffer sent with offset %x", off))
}
ptrs[i] = int32(off)
}
}

Expand All @@ -143,10 +150,10 @@ func blockMd5_avx512(s *digest16, input [16][]byte, maskRounds *[16]maskRounds)
for r := 0; r < rounds; r++ {
m := maskRounds[r]

block16(&sdup.v0[0], &ptrs[0], m.mask, int(64*m.rounds))
block16(&sdup.v0[0], uintptr(baseMin), &ptrs[0], m.mask, int(64*m.rounds))

for j := 0; j < len(ptrs); j++ {
ptrs[j] += int64(64 * m.rounds) // update pointers for next round
ptrs[j] += int32(64 * m.rounds) // update pointers for next round
if m.mask&(1<<j) != 0 { // update digest if still masked as active
(*s).v0[j], (*s).v1[j], (*s).v2[j], (*s).v3[j] = sdup.v0[j], sdup.v1[j], sdup.v2[j], sdup.v3[j]
}
Expand Down

0 comments on commit 58af1df

Please sign in to comment.