Skip to content

Commit

Permalink
Merge pull request #167 from gabriel-vasile/clean_tests
Browse files Browse the repository at this point in the history
Fix deadlock for Extend and Detect
  • Loading branch information
gabriel-vasile committed Jul 16, 2021
2 parents 05e50ec + 61d7c3e commit eb51b6a
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 256 deletions.
16 changes: 9 additions & 7 deletions internal/magic/signature.go
Expand Up @@ -6,13 +6,15 @@ import (
"fmt"
)

type Detector func(raw []byte, limit uint32) bool
type xmlSig struct {
// the local name of the root tag
localName []byte
// the namespace of the XML document
xmlns []byte
}
type (
Detector func(raw []byte, limit uint32) bool
xmlSig struct {
// the local name of the root tag
localName []byte
// the namespace of the XML document
xmlns []byte
}
)

// prefix creates a Detector which returns true if any of the provided signatures
// is the prefix of the raw input.
Expand Down
2 changes: 1 addition & 1 deletion internal/magic/text.go
Expand Up @@ -149,7 +149,7 @@ func Json(raw []byte, limit uint32) bool {
return err == nil
}

return parsed == len(raw)
return parsed == len(raw) && len(raw) > 0
}

// GeoJson matches a RFC 7946 GeoJSON file.
Expand Down
19 changes: 2 additions & 17 deletions mime.go
Expand Up @@ -2,7 +2,6 @@ package mimetype

import (
"mime"
"sync"

"github.com/gabriel-vasile/mimetype/internal/charset"
"github.com/gabriel-vasile/mimetype/internal/magic"
Expand All @@ -19,8 +18,6 @@ type MIME struct {
detector magic.Detector
children []*MIME
parent *MIME

mu sync.RWMutex
}

// String returns the string representation of the MIME type, e.g., "application/zip".
Expand Down Expand Up @@ -60,8 +57,6 @@ func (m *MIME) Is(expectedMIME string) bool {
return true
}

m.mu.RLock()
defer m.mu.RUnlock()
for _, alias := range m.aliases {
if alias == expectedMIME {
return true
Expand Down Expand Up @@ -91,17 +86,13 @@ func newMIME(
}

func (m *MIME) alias(aliases ...string) *MIME {
m.mu.Lock()
m.aliases = aliases
m.mu.Unlock()
return m
}

// match does a depth-first search on the signature tree. It returns the deepest
// successful node for which all the children detection functions fail.
func (m *MIME) match(in []byte, readLimit uint32) *MIME {
m.mu.RLock()
defer m.mu.RUnlock()
for _, c := range m.children {
if c.detector(in, readLimit) {
return c.match(in, readLimit)
Expand All @@ -127,8 +118,6 @@ func (m *MIME) match(in []byte, readLimit uint32) *MIME {
// flatten transforms an hierarchy of MIMEs into a slice of MIMEs.
func (m *MIME) flatten() []*MIME {
out := []*MIME{m}
m.mu.RLock()
defer m.mu.RUnlock()
for _, c := range m.children {
out = append(out, c.flatten()...)
}
Expand All @@ -143,8 +132,6 @@ func (m *MIME) clone(ps map[string]string) *MIME {
clonedMIME = mime.FormatMediaType(m.mime, ps)
}

m.mu.RLock()
defer m.mu.RUnlock()
return &MIME{
mime: clonedMIME,
aliases: m.aliases,
Expand All @@ -167,8 +154,6 @@ func (m *MIME) cloneHierarchy(ps map[string]string) *MIME {
}

func (m *MIME) lookup(mime string) *MIME {
m.mu.RLock()
defer m.mu.RUnlock()
for _, n := range append(m.aliases, m.mime) {
if n == mime {
return m
Expand Down Expand Up @@ -196,7 +181,7 @@ func (m *MIME) Extend(detector func(raw []byte, limit uint32) bool, mime, extens
aliases: aliases,
}

m.mu.Lock()
mu.Lock()
m.children = append([]*MIME{c}, m.children...)
m.mu.Unlock()
mu.Unlock()
}
19 changes: 14 additions & 5 deletions mimetype.go
Expand Up @@ -21,10 +21,13 @@ var readLimit uint32 = 3072
// The result is always a valid MIME type, with application/octet-stream
// returned when identification failed.
func Detect(in []byte) *MIME {
// Using atomic because readLimit can be written at the same time in other goroutine.
l := atomic.LoadUint32(&readLimit)
if l > 0 && len(in) > int(l) {
in = in[:l]
}
mu.RLock()
defer mu.RUnlock()
return root.match(in, l)
}

Expand All @@ -41,24 +44,27 @@ func DetectReader(r io.Reader) (*MIME, error) {
var in []byte
var err error

// Using atomic because readLimit can be written at the same time in other goroutine.
l := atomic.LoadUint32(&readLimit)
if l == 0 {
in, err = ioutil.ReadAll(r)
if err != nil {
return root, err
return errMIME, err
}
} else {
// io.UnexpectedEOF means len(r) < len(in). It is not an error in this case,
// it just means the input file is smaller than the allocated bytes slice.
n := 0
in = make([]byte, l)
// io.UnexpectedEOF means len(r) < len(in). It is not an error in this case,
// it just means the input file is smaller than the allocated bytes slice.
n, err = io.ReadFull(r, in)
if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
return root, err
return errMIME, err
}
in = in[:n]
}

mu.RLock()
defer mu.RUnlock()
return root.match(in, l), nil
}

Expand All @@ -70,7 +76,7 @@ func DetectReader(r io.Reader) (*MIME, error) {
func DetectFile(path string) (*MIME, error) {
f, err := os.Open(path)
if err != nil {
return root, err
return errMIME, err
}
defer f.Close()

Expand Down Expand Up @@ -98,6 +104,7 @@ func EqualsAny(s string, mimes ...string) bool {
// their magical numbers towards the end of the file: docx, pptx, xlsx, etc.
// A limit of 0 means the whole input file will be used.
func SetLimit(limit uint32) {
// Using atomic because readLimit can be read at the same time in other goroutine.
atomic.StoreUint32(&readLimit, limit)
}

Expand All @@ -110,5 +117,7 @@ func Extend(detector func(raw []byte, limit uint32) bool, mime, extension string
// Lookup finds a MIME object by its string representation.
// The representation can be the main mime type, or any of its aliases.
func Lookup(mime string) *MIME {
mu.RLock()
defer mu.RUnlock()
return root.lookup(mime)
}

0 comments on commit eb51b6a

Please sign in to comment.