Skip to content

Commit

Permalink
Fix redirects ignoring AllowURLRevisit=false
Browse files Browse the repository at this point in the history
This commit introduces a breaking change: ErrAlreadyVisited is replaced
with AlreadyVisitedError, which allows the user to know the redirect
destination, which might not match the URL passed to Visit when multiple
redirects are followed.

See #405
  • Loading branch information
WGH- committed Mar 10, 2022
1 parent b151a08 commit 0be3b71
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 10 deletions.
66 changes: 57 additions & 9 deletions colly.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,26 @@ type ScrapedCallback func(*Response)
// ProxyFunc is a type alias for proxy setter functions.
type ProxyFunc func(*http.Request) (*url.URL, error)

// AlreadyVisitedError is the error type for already visited URLs.
//
// It's returned synchronously by Visit when the URL passed to Visit
// is already visited.
//
// When already visited URL is encountered after following
// redirects, this error appears in OnError callback, and if Async
// mode is not enabled, is also returned by Visit.
type AlreadyVisitedError struct {
// Destination is the URL that was attempted to be visited.
// It might not match the URL passed to Visit if redirect
// was followed.
Destination *url.URL
}

// Error implements error interface.
func (e *AlreadyVisitedError) Error() string {
return fmt.Sprintf("%q already visited", e.Destination)
}

type htmlCallbackContainer struct {
Selector string
Function HTMLCallback
Expand Down Expand Up @@ -196,8 +216,6 @@ var (
// ErrNoURLFiltersMatch is the error thrown if visiting
// a URL which is not allowed by URLFilters
ErrNoURLFiltersMatch = errors.New("No URLFilters match")
// ErrAlreadyVisited is the error type for already visited URLs
ErrAlreadyVisited = errors.New("URL already visited")
// ErrRobotsTxtBlocked is the error type for robots.txt errors
ErrRobotsTxtBlocked = errors.New("URL blocked by robots.txt")
// ErrNoCookieJar is the error type for missing cookie jar
Expand Down Expand Up @@ -603,7 +621,7 @@ func (c *Collector) scrape(u, method string, depth int, requestData io.Reader, c
// note: once 1.13 is minimum supported Go version,
// replace this with http.NewRequestWithContext
req = req.WithContext(c.Context)
if err := c.requestCheck(u, parsedURL, method, req.GetBody, depth, checkRevisit); err != nil {
if err := c.requestCheck(parsedURL, method, req.GetBody, depth, checkRevisit); err != nil {
return err
}
u = parsedURL.String()
Expand Down Expand Up @@ -694,10 +712,8 @@ func (c *Collector) fetch(u, method string, depth int, requestData io.Reader, ct
return err
}

func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, getBody func() (io.ReadCloser, error), depth int, checkRevisit bool) error {
if u == "" {
return ErrMissingURL
}
func (c *Collector) requestCheck(parsedURL *url.URL, method string, getBody func() (io.ReadCloser, error), depth int, checkRevisit bool) error {
u := parsedURL.String()
if c.MaxDepth > 0 && c.MaxDepth < depth {
return ErrMaxDepth
}
Expand Down Expand Up @@ -732,7 +748,7 @@ func (c *Collector) requestCheck(u string, parsedURL *url.URL, method string, ge
return err
}
if visited {
return ErrAlreadyVisited
return &AlreadyVisitedError{parsedURL}
}
return c.store.Visited(uHash)
}
Expand Down Expand Up @@ -1292,6 +1308,31 @@ func (c *Collector) checkRedirectFunc() func(req *http.Request, via []*http.Requ
if err := c.checkFilters(req.URL.String(), req.URL.Hostname()); err != nil {
return fmt.Errorf("Not following redirect to %q: %w", req.URL, err)
}

if !c.AllowURLRevisit {
var body io.ReadCloser
if req.GetBody != nil {
var err error
body, err = req.GetBody()
if err != nil {
return err
}
defer body.Close()
}
uHash := requestHash(req.URL.String(), body)
visited, err := c.store.IsVisited(uHash)
if err != nil {
return err
}
if visited {
return &AlreadyVisitedError{req.URL}
}
err = c.store.Visited(uHash)
if err != nil {
return err
}
}

if c.redirectHandler != nil {
return c.redirectHandler(req, via)
}
Expand Down Expand Up @@ -1442,7 +1483,14 @@ func isMatchingFilter(fs []*regexp.Regexp, d []byte) bool {

func requestHash(url string, body io.Reader) uint64 {
h := fnv.New64a()
h.Write([]byte(url))
// reparse the url to fix ambiguities such as
// "http://example.com" vs "http://example.com/"
parsedWhatwgURL, err := whatwgUrl.Parse(url)
if err == nil {
h.Write([]byte(parsedWhatwgURL.String()))
} else {
h.Write([]byte(url))
}
if body != nil {
io.Copy(h, body)
}
Expand Down
38 changes: 37 additions & 1 deletion colly_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ func newTestServer() *httptest.Server {
})

mux.Handle("/redirect", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/redirected/", http.StatusSeeOther)
destination := "/redirected/"
if d := r.URL.Query().Get("d"); d != "" {
destination = d
}
http.Redirect(w, r, destination, http.StatusSeeOther)

}))

Expand Down Expand Up @@ -674,6 +678,38 @@ func TestCollectorURLRevisitCheck(t *testing.T) {
if visited != true {
t.Error("Expected URL to have been visited")
}

errorTestCases := []struct {
Path string
DestinationError string
}{
{"/", "/"},
{"/redirect?d=/", "/"},
// now that /redirect?d=/ itself is recorded as visited,
// it's now returned in error
{"/redirect?d=/", "/redirect?d=/"},
{"/redirect?d=/redirect%3Fd%3D/", "/redirect?d=/"},
{"/redirect?d=/redirect%3Fd%3D/", "/redirect?d=/redirect%3Fd%3D/"},
{"/redirect?d=/redirect%3Fd%3D/&foo=bar", "/redirect?d=/"},
}

for i, testCase := range errorTestCases {
err := c.Visit(ts.URL + testCase.Path)
if testCase.DestinationError == "" {
if err != nil {
t.Errorf("got unexpected error in test %d: %q", i, err)
}
} else {
var ave *AlreadyVisitedError
if !errors.As(err, &ave) {
t.Errorf("err=%q returned when trying to revisit, expected AlreadyVisitedError", err)
} else {
if got, want := ave.Destination.String(), ts.URL+testCase.DestinationError; got != want {
t.Errorf("wrong destination in AlreadyVisitedError in test %d, got=%q want=%q", i, got, want)
}
}
}
}
}

func TestCollectorPostURLRevisitCheck(t *testing.T) {
Expand Down

0 comments on commit 0be3b71

Please sign in to comment.