From 58af1df4077b639bcf57a1c5abb68cea5f8b8be9 Mon Sep 17 00:00:00 2001 From: frankw Date: Thu, 7 May 2020 19:41:01 -0700 Subject: [PATCH] Simplify AVX512 code to use 16x wide 32-bit pointers --- block16_amd64.s | 34 ++++++++++------------------------ block16_amd64_test.go | 34 +++++++++++++++++++++------------- block_amd64.go | 23 +++++++++++++++-------- 3 files changed, 46 insertions(+), 45 deletions(-) diff --git a/block16_amd64.s b/block16_amd64.s index 5bd4e47..d32c122 100644 --- a/block16_amd64.s +++ b/block16_amd64.s @@ -5,12 +5,8 @@ // This is the AVX512 implementation of the MD5 block function (16-way parallel) #define prep(index) \ - KMOVB kmaskL, ktmp \ - VPGATHERQD index*4(base)(ptrsLow*1), ktmp, ymemLow \ - KMOVB kmaskH, ktmp \ - VPGATHERQD index*4(base)(ptrsHigh*1), ktmp, ymemHigh \ - VALIGND $8, memHigh, memHigh, memHigh \ - VPORD memHigh, mem, mem + KMOVQ kmask, ktmp \ + VPGATHERDD index*4(base)(ptrs*1), ktmp, mem #define ROUND1(a, b, c, d, index, const, shift) \ VXORPS c, tmp, tmp \ @@ -62,13 +58,13 @@ VXORPS c, ones, tmp \ VPADDD b, a, a -TEXT ·block16(SB),4,$0-32 +TEXT ·block16(SB),4,$0-40 MOVQ state+0(FP), BX - XORQ SI, SI // null out base pointer (using absolute 64-bit pointers) - MOVQ ptrs+8(FP), AX - KMOVQ mask+16(FP), K1 - MOVQ n+24(FP), DX + MOVQ base+8(FP), SI + MOVQ ptrs+16(FP), AX + KMOVQ mask+24(FP), K1 + MOVQ n+32(FP), DX MOVQ ·avx512md5consts+0(SB), DI #define a Z0 @@ -83,16 +79,11 @@ TEXT ·block16(SB),4,$0-32 #define tmp Z8 #define tmp2 Z9 -#define ptrsLow Z10 -#define ptrsHigh Z11 +#define ptrs Z10 #define ones Z12 #define mem Z15 -#define ymemLow Y15 -#define memHigh Z14 -#define ymemHigh Y14 -#define kmaskL K1 -#define kmaskH K2 +#define kmask K1 #define ktmp K3 // ---------------------------------------------------------- @@ -112,12 +103,7 @@ TEXT ·block16(SB),4,$0-32 VMOVUPD 0xc0(dig), d // load source pointers - VMOVUPD 0x00(AX), ptrsLow - VMOVUPD 0x40(AX), ptrsHigh - - // setup masks - KMOVW kmaskL, kmaskH - KSHIFTRW $8, kmaskH, kmaskH + VMOVUPD 0x00(AX), ptrs MOVQ $-1, AX VPBROADCASTQ AX, ones diff --git a/block16_amd64_test.go b/block16_amd64_test.go index 56a9f13..2c3b6be 100644 --- a/block16_amd64_test.go +++ b/block16_amd64_test.go @@ -59,14 +59,16 @@ func TestBlock16(t *testing.T) { s.v0[i], s.v1[i], s.v2[i], s.v3[i] = init0, init1, init2, init3 } - ptrs := [16]int64{} + bufs := [16]int32{4, 4 + internalBlockSize, 4 + internalBlockSize*2, 4 + internalBlockSize*3, 4 + internalBlockSize*4, 4 + internalBlockSize*5, 4 + internalBlockSize*6, 4 + internalBlockSize*7, + 4 + internalBlockSize*8, 4 + internalBlockSize*9, 4 + internalBlockSize*10, 4 + internalBlockSize*11, 4 + internalBlockSize*12, 4 + internalBlockSize*13, 4 + internalBlockSize*14, 4 + internalBlockSize*15} - for i := range ptrs { - ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0])))) - // fmt.Printf("%016x\n", ptrs[i]) + base := make([]byte, 4+16*internalBlockSize) + + for i := 0; i < len(input); i++ { + copy(base[bufs[i]:], input[i]) } - block16(&s.v0[0], &ptrs[0], 0xffff, 64) + block16(&s.v0[0], uintptr(unsafe.Pointer(&(base[0]))), &bufs[0], 0xffff, 64) want := `00000000 82 3c 09 52 b9 77 11 2a 65 ee 4c 82 f9 ad 4d 28 |.<.R.w.*e.L...M(| @@ -124,15 +126,18 @@ func TestBlock16Masked(t *testing.T) { s.v0[i], s.v1[i], s.v2[i], s.v3[i] = init0, init1, init2, init3 } - ptrs := [16]int64{} + bufs := [16]int32{4, 4 + internalBlockSize, 4 + internalBlockSize*2, 4 + internalBlockSize*3, 4 + internalBlockSize*4, 4 + internalBlockSize*5, 4 + internalBlockSize*6, 4 + internalBlockSize*7, + 4 + internalBlockSize*8, 4 + internalBlockSize*9, 4 + internalBlockSize*10, 4 + internalBlockSize*11, 4 + internalBlockSize*12, 4 + internalBlockSize*13, 4 + internalBlockSize*14, 4 + internalBlockSize*15} + + base := make([]byte, 4+16*internalBlockSize) - for i := range ptrs { + for i := 0; i < len(input); i++ { if input[i] != nil { - ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0])))) + copy(base[bufs[i]:], input[i]) } } - block16(&s.v0[0], &ptrs[0], mask, 64) + block16(&s.v0[0], uintptr(unsafe.Pointer(&(base[0]))), &bufs[0], mask, 64) want := `00000000 82 3c 09 52 ac 1d 1f 03 65 ee 4c 82 ac 1d 1f 03 |.<.R....e.L.....| @@ -223,10 +228,13 @@ func BenchmarkBlock16(b *testing.B) { s.v0[i], s.v1[i], s.v2[i], s.v3[i] = init0, init1, init2, init3 } - ptrs := [16]int64{} + bufs := [16]int32{4, 4 + internalBlockSize, 4 + internalBlockSize*2, 4 + internalBlockSize*3, 4 + internalBlockSize*4, 4 + internalBlockSize*5, 4 + internalBlockSize*6, 4 + internalBlockSize*7, + 4 + internalBlockSize*8, 4 + internalBlockSize*9, 4 + internalBlockSize*10, 4 + internalBlockSize*11, 4 + internalBlockSize*12, 4 + internalBlockSize*13, 4 + internalBlockSize*14, 4 + internalBlockSize*15} + + base := make([]byte, 4+16*internalBlockSize) - for i := range ptrs { - ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0])))) + for i := 0; i < len(input); i++ { + copy(base[bufs[i]:], input[i]) } b.SetBytes(int64(size * 16)) @@ -234,6 +242,6 @@ func BenchmarkBlock16(b *testing.B) { b.ResetTimer() for j := 0; j < b.N; j++ { - block16(&s.v0[0], &ptrs[0], 0xffff, size) + block16(&s.v0[0], uintptr(unsafe.Pointer(&(base[0]))), &bufs[0], 0xffff, size) } } diff --git a/block_amd64.go b/block_amd64.go index 330b9a8..27d6ce0 100644 --- a/block_amd64.go +++ b/block_amd64.go @@ -21,7 +21,7 @@ var hasAVX512 bool func block8(state *uint32, base uintptr, bufs *int32, cache *byte, n int) //go:noescape -func block16(state *uint32, ptrs *int64, mask uint64, n int) +func block16(state *uint32, base uintptr, ptrs *int32, mask uint64, n int) // 8-way 4x uint32 digests in 4 ymm registers // (ymm0, ymm1, ymm2, ymm3) @@ -89,7 +89,7 @@ func init() { // Interface function to assembly code func (s *md5Server) blockMd5_x16(d *digest16, input [16][]byte, half bool) { if hasAVX512 { - blockMd5_avx512(d, input, &s.maskRounds16) + blockMd5_avx512(d, input, s.allBufs, &s.maskRounds16) } else { d8a, d8b := digest8{}, digest8{} for i := range d8a.v0 { @@ -125,14 +125,21 @@ func (s *md5Server) blockMd5_x16(d *digest16, input [16][]byte, half bool) { } // Interface function to AVX512 assembly code -func blockMd5_avx512(s *digest16, input [16][]byte, maskRounds *[16]maskRounds) { - ptrs := [16]int64{} +func blockMd5_avx512(s *digest16, input [16][]byte, base []byte, maskRounds *[16]maskRounds) { + baseMin := uint64(uintptr(unsafe.Pointer(&(base[0])))) + ptrs := [16]int32{} + for i := range ptrs { - if input[i] != nil { + if len(input[i]) > 0 { if len(input[i]) > internalBlockSize { panic(fmt.Sprintf("Sanity check fails for lane %d: maximum input length cannot exceed internalBlockSize", i)) } - ptrs[i] = int64(uintptr(unsafe.Pointer(&(input[i][0])))) + + off := uint64(uintptr(unsafe.Pointer(&(input[i][0])))) - baseMin + if off > math.MaxUint32 { + panic(fmt.Sprintf("invalid buffer sent with offset %x", off)) + } + ptrs[i] = int32(off) } } @@ -143,10 +150,10 @@ func blockMd5_avx512(s *digest16, input [16][]byte, maskRounds *[16]maskRounds) for r := 0; r < rounds; r++ { m := maskRounds[r] - block16(&sdup.v0[0], &ptrs[0], m.mask, int(64*m.rounds)) + block16(&sdup.v0[0], uintptr(baseMin), &ptrs[0], m.mask, int(64*m.rounds)) for j := 0; j < len(ptrs); j++ { - ptrs[j] += int64(64 * m.rounds) // update pointers for next round + ptrs[j] += int32(64 * m.rounds) // update pointers for next round if m.mask&(1<