Permalink
Browse files

encoding/base64: This change modifies Go to take strict option when d…

…ecoding base64

If strict option is enabled, when decoding, instead of skip the padding
bits, it will do strict check to enforce they are set to zero.

Fixes #15656

Change-Id: I869fb725a39cc9dde44dbc4ff0046446e7abc642
Reviewed-on: https://go-review.googlesource.com/24964
Reviewed-by: Russ Cox <rsc@golang.org>
Run-TryBot: Russ Cox <rsc@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
  • Loading branch information...
xuyangkang authored and rsc committed Jul 17, 2016
1 parent 56b5546 commit 87b1aaa37cefec8deacdf9c3c30d26015bdfb00b
Showing with 42 additions and 3 deletions.
  1. +21 −3 src/encoding/base64/base64.go
  2. +21 −0 src/encoding/base64/base64_test.go
@@ -23,6 +23,7 @@ type Encoding struct {
encode [64]byte
decodeMap [256]byte
padChar rune
strict bool
}

const (
@@ -62,6 +63,14 @@ func (enc Encoding) WithPadding(padding rune) *Encoding {
return &enc
}

// Strict creates a new encoding identical to enc except with
// strict decoding enabled. In this mode, the decoder requires that
// trailing padding bits are zero, as described in RFC 4648 section 3.5.
func (enc Encoding) Strict() *Encoding {
enc.strict = true
return &enc
}

// StdEncoding is the standard base64 encoding, as defined in
// RFC 4648.
var StdEncoding = NewEncoding(encodeStd)
@@ -311,15 +320,24 @@ func (enc *Encoding) decode(dst, src []byte) (n int, end bool, err error) {

// Convert 4x 6bit source bytes into 3 bytes
val := uint(dbuf[0])<<18 | uint(dbuf[1])<<12 | uint(dbuf[2])<<6 | uint(dbuf[3])
dbuf[2], dbuf[1], dbuf[0] = byte(val>>0), byte(val>>8), byte(val>>16)
switch dlen {
case 4:
dst[2] = byte(val >> 0)
dst[2] = dbuf[2]
dbuf[2] = 0
fallthrough
case 3:
dst[1] = byte(val >> 8)
dst[1] = dbuf[1]
if enc.strict && dbuf[2] != 0 {
return n, end, CorruptInputError(si - 1)
}
dbuf[1] = 0
fallthrough
case 2:
dst[0] = byte(val >> 16)
dst[0] = dbuf[0]
if enc.strict && (dbuf[1] != 0 || dbuf[2] != 0) {
return n, end, CorruptInputError(si - 2)
}
}
dst = dst[dinc:]
n += dlen - 1
@@ -85,6 +85,11 @@ var encodingTests = []encodingTest{
{RawStdEncoding, rawRef},
{RawURLEncoding, rawUrlRef},
{funnyEncoding, funnyRef},
{StdEncoding.Strict(), stdRef},
{URLEncoding.Strict(), urlRef},
{RawStdEncoding.Strict(), rawRef},
{RawURLEncoding.Strict(), rawUrlRef},
{funnyEncoding.Strict(), funnyRef},
}

var bigtest = testpair{
@@ -436,6 +441,22 @@ func TestDecoderIssue7733(t *testing.T) {
}
}

func TestDecoderIssue15656(t *testing.T) {
_, err := StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
want := CorruptInputError(22)
if !reflect.DeepEqual(want, err) {
t.Errorf("Error = %v; want CorruptInputError(22)", err)
}
_, err = StdEncoding.Strict().DecodeString("WvLTlMrX9NpYDQlEIFlnDA==")
if err != nil {
t.Errorf("Error = %v; want nil", err)
}
_, err = StdEncoding.DecodeString("WvLTlMrX9NpYDQlEIFlnDB==")
if err != nil {
t.Errorf("Error = %v; want nil", err)
}
}

func BenchmarkEncodeToString(b *testing.B) {
data := make([]byte, 8192)
b.SetBytes(int64(len(data)))

0 comments on commit 87b1aaa

Please sign in to comment.