| 
 | 1 | +package websocket  | 
 | 2 | + | 
 | 3 | +import (  | 
 | 4 | +	"bytes"  | 
 | 5 | +	"crypto/rand"  | 
 | 6 | +	"encoding/binary"  | 
 | 7 | +	"math/big"  | 
 | 8 | +	"math/bits"  | 
 | 9 | +	"testing"  | 
 | 10 | + | 
 | 11 | +	"nhooyr.io/websocket/internal/test/assert"  | 
 | 12 | +)  | 
 | 13 | + | 
 | 14 | +func basicMask(b []byte, key uint32) uint32 {  | 
 | 15 | +	for i := range b {  | 
 | 16 | +		b[i] ^= byte(key)  | 
 | 17 | +		key = bits.RotateLeft32(key, -8)  | 
 | 18 | +	}  | 
 | 19 | +	return key  | 
 | 20 | +}  | 
 | 21 | + | 
 | 22 | +func basicMask2(b []byte, key uint32) uint32 {  | 
 | 23 | +	keyb := binary.LittleEndian.AppendUint32(nil, key)  | 
 | 24 | +	pos := 0  | 
 | 25 | +	for i := range b {  | 
 | 26 | +		b[i] ^= keyb[pos&3]  | 
 | 27 | +		pos++  | 
 | 28 | +	}  | 
 | 29 | +	return bits.RotateLeft32(key, (pos&3)*-8)  | 
 | 30 | +}  | 
 | 31 | + | 
 | 32 | +func TestMask(t *testing.T) {  | 
 | 33 | +	t.Parallel()  | 
 | 34 | + | 
 | 35 | +	testMask(t, "basicMask", basicMask)  | 
 | 36 | +	testMask(t, "maskGo", maskGo)  | 
 | 37 | +	testMask(t, "basicMask2", basicMask2)  | 
 | 38 | +}  | 
 | 39 | + | 
 | 40 | +func testMask(t *testing.T, name string, fn func(b []byte, key uint32) uint32) {  | 
 | 41 | +	t.Run(name, func(t *testing.T) {  | 
 | 42 | +		t.Parallel()  | 
 | 43 | +	for i := 0; i < 9999; i++ {  | 
 | 44 | +		keyb := make([]byte, 4)  | 
 | 45 | +		_, err := rand.Read(keyb)  | 
 | 46 | +		assert.Success(t, err)  | 
 | 47 | +		key := binary.LittleEndian.Uint32(keyb)  | 
 | 48 | + | 
 | 49 | +		n, err := rand.Int(rand.Reader, big.NewInt(1<<16))  | 
 | 50 | +		assert.Success(t, err)  | 
 | 51 | + | 
 | 52 | +		b := make([]byte, 1+n.Int64())  | 
 | 53 | +		_, err = rand.Read(b)  | 
 | 54 | +		assert.Success(t, err)  | 
 | 55 | + | 
 | 56 | +		b2 := make([]byte, len(b))  | 
 | 57 | +		copy(b2, b)  | 
 | 58 | +		b3 := make([]byte, len(b))  | 
 | 59 | +		copy(b3, b)  | 
 | 60 | + | 
 | 61 | +		key2 := basicMask(b2, key)  | 
 | 62 | +		key3 := fn(b3, key)  | 
 | 63 | + | 
 | 64 | +		if key2 != key3 {  | 
 | 65 | +			t.Errorf("expected key %X but got %X", key2, key3)  | 
 | 66 | +		}  | 
 | 67 | +		if !bytes.Equal(b2, b3) {  | 
 | 68 | +			t.Error("bad bytes")  | 
 | 69 | +			return  | 
 | 70 | +		}  | 
 | 71 | +	}  | 
 | 72 | +	})  | 
 | 73 | +}  | 
0 commit comments