Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ParseMediaType tolerates unencoded 8bit characters #201

Merged
merged 5 commits into from
Jul 15, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
132 changes: 76 additions & 56 deletions header.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ func consumeParam(s string) (consumed, rest string) {
valueQuotedOriginally := false
valueQuoteAdded := false
valueQuoteNeeded := false
rfc2047Needed := false

var r rune
findValueStart:
Expand All @@ -369,7 +370,8 @@ findValueStart:
case '"':
valueQuotedOriginally = true
valueQuoteAdded = true
value.WriteRune(r)
valueQuoteNeeded = true
param.WriteRune(r)

break findValueStart

Expand All @@ -381,6 +383,10 @@ findValueStart:
break findValueStart

default:
if r > 127 {
rfc2047Needed = true
}

valueQuotedOriginally = false
valueQuoteAdded = false
value.WriteRune(r)
Expand All @@ -389,6 +395,20 @@ findValueStart:
}
}

quoteIfUnquoted := func() {
if !valueQuoteNeeded {
if !valueQuoteAdded {
param.WriteByte('"')

valueQuoteAdded = true
}

valueQuoteNeeded = true
}
}

hasRest := false
jhillyerd marked this conversation as resolved.
Show resolved Hide resolved

if len(s)-i < 1 {
// parameter value starts at the end of the string, make empty
// quoted string to play nice with mime.ParseMediaType
Expand All @@ -397,105 +417,105 @@ findValueStart:
} else {
// The beginning of the value is not at the end of the string

quoteIfUnquoted := func() {
if !valueQuoteNeeded {
if !valueQuoteAdded {
param.WriteByte('"')

valueQuoteAdded = true
}

valueQuoteNeeded = true
}
}

for _, v := range []byte{'(', ')', '<', '>', '@', ',', ':', '/', '[', ']', '?', '='} {
if s[0] == v {
quoteIfUnquoted()
break
}
}

s = s[i+1:]
escaped := false

findValueEnd:
for len(s) > 0 {
switch s[0] {
for i, r = range s {
if escaped {
value.WriteRune(r)
escaped = false
continue
}

switch r {
case ';', ' ', '\t':
if valueQuotedOriginally {
// We're in a quoted string, so whitespace is allowed.
value.WriteByte(s[0])
s = s[1:]
value.WriteRune(r)
break
}

// Otherwise, we've reached the end of an unquoted value.

param.WriteString(value.String())
value.Reset()

if valueQuoteNeeded {
param.WriteByte('"')
}

param.WriteByte(s[0])
s = s[1:]

hasRest = true
s = s[i:]
break findValueEnd

case '"':
if valueQuotedOriginally {
// We're in a quoted value. This is the end of that value.
param.WriteString(value.String())
value.Reset()

param.WriteByte(s[0])
s = s[1:]

hasRest = true
s = s[i:]
break findValueEnd
}

quoteIfUnquoted()

value.WriteByte('\\')
value.WriteByte(s[0])
s = s[1:]
value.WriteRune(r)

case '\\':
if len(s) > 1 {
value.WriteByte(s[0])
s = s[1:]

// Backslash escapes the next char. Consume that next char.
value.WriteByte(s[0])

if i < len(s)-1 {
// If next char is present, escape it with backslash
value.WriteRune(r)
escaped = true
quoteIfUnquoted()
}
// Else there is no next char to consume.
s = s[1:]

case '(', ')', '<', '>', '@', ',', ':', '/', '[', ']', '?', '=':
quoteIfUnquoted()

fallthrough

default:
value.WriteByte(s[0])
s = s[1:]
if r > 127 {
rfc2047Needed = true
}
value.WriteRune(r)
}
}
}

if !hasRest {
// Whole string was processed
s = ""
}

if value.Len() > 0 {
// There is a value that ends with the string. Capture it.
param.WriteString(value.String())

if valueQuotedOriginally || valueQuoteNeeded {
// If valueQuotedOriginally is true and we got here,
// that means there was no closing quote. So we'll add one.
// Otherwise, we're here because it was an unquoted value
// with a special char in it, and we had to quote it.
param.WriteByte('"')
// Convert whole value to RFC2047 if it contains forbidden characters (ASCII > 127)
val := value.String()
if rfc2047Needed {
val = mime.BEncoding.Encode("UTF-8", val)
jhillyerd marked this conversation as resolved.
Show resolved Hide resolved
// RFC 2047 must be quoted
quoteIfUnquoted()
}

// Write the value
param.WriteString(val)
}

// Add final quote if required
if valueQuoteNeeded {
param.WriteByte('"')
}

// Write last parsed char if any
if s != "" {
if s[0] != '"' {
// When last char is quote, valueQuotedOriginally is surely true and the quote was already written.
// Otherwise output the character (; for example)
param.WriteByte(s[0])
}

// Focus the rest of the string
s = s[1:]
}

return param.String(), s
Expand Down
30 changes: 30 additions & 0 deletions header_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,36 @@ func TestFixUnquotedSpecials(t *testing.T) {
input: `text/html;charset="`,
want: `text/html;charset=""`,
},
{
// Check unquoted 8bit is encoded
input: `application/msword;name=管理.doc`,
want: `application/msword;name="=?UTF-8?b?566h77+977+955CGLmRvYw==?="`,
},
{
// Check mix of ascii and unquoted 8bit is encoded
input: `application/msword;name=15管理.doc`,
want: `application/msword;name="=?UTF-8?b?MTXnrqHnkIYuZG9j?="`,
},
{
// Check quoted 8bit is encoded
input: `application/msword;name="15管理.doc"`,
want: `application/msword;name="=?UTF-8?b?MTXnrqHnkIYuZG9j?="`,
},
{
// Check quoted 8bit with missing closing quote is encoded
input: `application/msword;name="15管理.doc`,
want: `application/msword;name="=?UTF-8?b?MTXnrqHnkIYuZG9j?="`,
},
{
// Trailing quote without starting quote is considered as part of param text for simplicity
input: `application/msword;name=15管理.doc"`,
want: `application/msword;name="=?UTF-8?b?MTXnrqHnkIYuZG9jXCI=?="`,
},
{
// Invalid UTF-8 sequence does not cause any fatal error
input: "application/msword;name=\xe2\x28\xa1.doc",
want: `application/msword;name="=?UTF-8?b?77+9KO+/vS5kb2M=?="`,
},
}
for _, tc := range testCases {
t.Run(tc.input, func(t *testing.T) {
Expand Down