diff --git a/soaptrip.go b/soaptrip.go index c418e06..b86e9fd 100644 --- a/soaptrip.go +++ b/soaptrip.go @@ -5,7 +5,7 @@ import ( "bytes" "encoding/xml" "fmt" - "io/ioutil" + "io" "net/http" "strings" ) @@ -61,22 +61,13 @@ func (sf SoapFault) Error() string { // ParseFault attempts to parse a Soap Fault from an http.Response. If a fault is found, it will return an error // of type SoapFault, otherwise it will return nil func ParseFault(resp *http.Response) error { - // read the response, but don't close it - b, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err - } - - // replace the read closer that we just used - // TODO: theres a more elegant way, maybe a MultiReader? - resp.Body = ioutil.NopCloser(bytes.NewBuffer(b)) - - reader := bytes.NewReader(b) - d := xml.NewDecoder(reader) + var buf bytes.Buffer + d := xml.NewDecoder(io.TeeReader(resp.Body, &buf)) var start xml.StartElement - fault := &SoapFault{} + fault := &SoapFault{Response: resp} found := false + depth := 0 // iterate through the tokens for { @@ -89,21 +80,31 @@ func ParseFault(resp *http.Response) error { switch t := tok.(type) { case xml.StartElement: start = t.Copy() + depth++ + if depth > 2 { // don't descend beyond Envelope>Body>Fault + break + } case xml.EndElement: start = xml.StartElement{} + depth-- case xml.CharData: - key := strings.ToLower(start.Name.Local) // fault was found, capture the values and mark as found - if key == "faultcode" { + switch strings.ToLower(start.Name.Local) { + case "faultcode": found = true fault.FaultCode = string(t) - } else if key == "faultstring" { + case "faultstring": found = true fault.FaultString = string(t) } } } + resp.Body = struct { + io.Reader + io.Closer + }{io.MultiReader(bytes.NewReader(buf.Bytes()), resp.Body), resp.Body} + if found { fault.Response = resp return fault