diff --git a/driver.go b/driver.go index 4fe8860d..3b18f0ba 100644 --- a/driver.go +++ b/driver.go @@ -131,6 +131,18 @@ type ClientDriverExtensionAvailableSpace interface { GetAvailableSpace(dirName string) (int64, error) } +// AnswerCommand is a struct to answer a command to the client +type AnswerCommand struct { + Code int + Message string +} + +// ClientDriverExtensionSite is an extension to implement if you want to handle SITE command +// yourself. You have to set DisableSite to false for this extension to be called +type ClientDriverExtensionSite interface { + Site(param string) *AnswerCommand +} + // ClientContext is implemented on the server side to provide some access to few data around the client type ClientContext interface { // Path provides the path of the current connection diff --git a/driver_test.go b/driver_test.go index ed65cbf8..58110136 100644 --- a/driver_test.go +++ b/driver_test.go @@ -119,6 +119,8 @@ func NewTestServerWithDriver(t *testing.T, driver MainDriver) *FtpServer { return NewTestServerWithDriverAndLogger(t, driver, nil) } +type authUserProvider func(user, pass string) (ClientDriver, error) + // TestServerDriver defines a minimal serverftp server driver type TestServerDriver struct { Debug bool // To display connection logs information @@ -132,6 +134,7 @@ type TestServerDriver struct { TLSVerificationReply tlsVerificationReply errPassiveListener error TLSRequirement TLSRequirement + AuthProvider authUserProvider } // TestClientDriver defines a minimal serverftp client driver @@ -254,6 +257,10 @@ var errBadUserNameOrPassword = errors.New("bad username or password") // AuthUser with authenticate users func (driver *TestServerDriver) AuthUser(_ ClientContext, user, pass string) (ClientDriver, error) { + if driver.AuthProvider != nil { + return driver.AuthProvider(user, pass) + } + if user == authUser && pass == authPass { clientdriver := NewTestClientDriver(driver) diff --git a/handle_misc.go b/handle_misc.go index cb36324d..5444777d 100644 --- a/handle_misc.go +++ b/handle_misc.go @@ -68,6 +68,17 @@ func (c *clientHandler) handleSITE(param string) error { return nil } + // If the driver implements ClientDriverExtensionSite, we call its Site method + // If it returns ErrProceedWithDefaultBehavior, we proceed with the default behavior + // Otherwise, we return the error + if site, ok := c.driver.(ClientDriverExtensionSite); ok { + if answer := site.Site(param); answer != nil { + c.writeMessage(answer.Code, answer.Message) + + return nil + } + } + spl := strings.SplitN(param, " ", 2) cmd := strings.ToUpper(spl[0]) var params string diff --git a/handle_misc_test.go b/handle_misc_test.go index 437569f3..3a820738 100644 --- a/handle_misc_test.go +++ b/handle_misc_test.go @@ -462,3 +462,72 @@ func TestREIN(t *testing.T) { require.NoError(t, err) require.Equal(t, StatusCommandNotImplemented, returnCode) } + +// Custom driver for testing ClientDriverExtensionSite +// Implements Site(param string) error +type customSiteDriver struct { + TestClientDriver +} + +var _ ClientDriverExtensionSite = (*customSiteDriver)(nil) + +func (d *customSiteDriver) Site(param string) *AnswerCommand { + switch param { + case "CUSTOMERR": + return &AnswerCommand{ + Code: StatusSyntaxErrorNotRecognised, + Message: "custom site error", + } + case "PROCEED": + return nil + case "OK": + return &AnswerCommand{ + Code: StatusOK, + Message: "OK", + } + default: + return nil + } +} + +func TestClientDriverExtensionSite(t *testing.T) { + t.Parallel() + + req := require.New(t) + + server := NewTestServerWithTestDriver(t, &TestServerDriver{ + Debug: false, + AuthProvider: func(_, _ string) (ClientDriver, error) { + return &customSiteDriver{}, nil + }, + }) + conf := goftp.Config{ + User: authUser, + Password: authPass, + } + + client, err := goftp.DialConfig(conf, server.Addr()) + req.NoError(err, "Couldn't connect") + defer func() { panicOnError(client.Close()) }() + raw, err := client.OpenRawConn() + req.NoError(err, "Couldn't open raw connection") + defer func() { require.NoError(t, raw.Close()) }() + + // Custom error from Site + returnCode, response, err := raw.SendCommand("SITE CUSTOMERR") + req.NoError(err) + req.Equal(StatusSyntaxErrorNotRecognised, returnCode) + req.Contains(response, "custom site error") + + // Default behavior fallback (should get unknown subcommand) + returnCode, response, err = raw.SendCommand("SITE PROCEED") + require.NoError(t, err) + require.Equal(t, StatusSyntaxErrorNotRecognised, returnCode) + require.Contains(t, response, "Unknown SITE subcommand") + + // Short-circuit: Site returns nil, so command is accepted (no error, no message) + returnCode, response, err = raw.SendCommand("SITE OK") + require.NoError(t, err) + require.Equal(t, StatusOK, returnCode) + require.Equal(t, "OK", response) +}