Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,18 @@ type ClientDriverExtensionAvailableSpace interface {
GetAvailableSpace(dirName string) (int64, error)
}

// AnswerCommand is a struct to answer a command to the client
type AnswerCommand struct {
Code int
Message string
}

// ClientDriverExtensionSite is an extension to implement if you want to handle SITE command
// yourself. You have to set DisableSite to false for this extension to be called
type ClientDriverExtensionSite interface {
Site(param string) *AnswerCommand
}

// ClientContext is implemented on the server side to provide some access to few data around the client
type ClientContext interface {
// Path provides the path of the current connection
Expand Down
7 changes: 7 additions & 0 deletions driver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ func NewTestServerWithDriver(t *testing.T, driver MainDriver) *FtpServer {
return NewTestServerWithDriverAndLogger(t, driver, nil)
}

type authUserProvider func(user, pass string) (ClientDriver, error)

// TestServerDriver defines a minimal serverftp server driver
type TestServerDriver struct {
Debug bool // To display connection logs information
Expand All @@ -132,6 +134,7 @@ type TestServerDriver struct {
TLSVerificationReply tlsVerificationReply
errPassiveListener error
TLSRequirement TLSRequirement
AuthProvider authUserProvider
}

// TestClientDriver defines a minimal serverftp client driver
Expand Down Expand Up @@ -254,6 +257,10 @@ var errBadUserNameOrPassword = errors.New("bad username or password")

// AuthUser with authenticate users
func (driver *TestServerDriver) AuthUser(_ ClientContext, user, pass string) (ClientDriver, error) {
if driver.AuthProvider != nil {
return driver.AuthProvider(user, pass)
}

if user == authUser && pass == authPass {
clientdriver := NewTestClientDriver(driver)

Expand Down
11 changes: 11 additions & 0 deletions handle_misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,17 @@ func (c *clientHandler) handleSITE(param string) error {
return nil
}

// If the driver implements ClientDriverExtensionSite, we call its Site method
// If it returns ErrProceedWithDefaultBehavior, we proceed with the default behavior
// Otherwise, we return the error
if site, ok := c.driver.(ClientDriverExtensionSite); ok {
if answer := site.Site(param); answer != nil {
c.writeMessage(answer.Code, answer.Message)

return nil
}
}

spl := strings.SplitN(param, " ", 2)
cmd := strings.ToUpper(spl[0])
var params string
Expand Down
69 changes: 69 additions & 0 deletions handle_misc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,3 +462,72 @@ func TestREIN(t *testing.T) {
require.NoError(t, err)
require.Equal(t, StatusCommandNotImplemented, returnCode)
}

// Custom driver for testing ClientDriverExtensionSite
// Implements Site(param string) error
type customSiteDriver struct {
TestClientDriver
}

var _ ClientDriverExtensionSite = (*customSiteDriver)(nil)

func (d *customSiteDriver) Site(param string) *AnswerCommand {
switch param {
case "CUSTOMERR":
return &AnswerCommand{
Code: StatusSyntaxErrorNotRecognised,
Message: "custom site error",
}
case "PROCEED":
return nil
case "OK":
return &AnswerCommand{
Code: StatusOK,
Message: "OK",
}
default:
return nil
}
}

func TestClientDriverExtensionSite(t *testing.T) {
t.Parallel()

req := require.New(t)

server := NewTestServerWithTestDriver(t, &TestServerDriver{
Debug: false,
AuthProvider: func(_, _ string) (ClientDriver, error) {
return &customSiteDriver{}, nil
},
})
conf := goftp.Config{
User: authUser,
Password: authPass,
}

client, err := goftp.DialConfig(conf, server.Addr())
req.NoError(err, "Couldn't connect")
defer func() { panicOnError(client.Close()) }()
raw, err := client.OpenRawConn()
req.NoError(err, "Couldn't open raw connection")
defer func() { require.NoError(t, raw.Close()) }()

// Custom error from Site
returnCode, response, err := raw.SendCommand("SITE CUSTOMERR")
req.NoError(err)
req.Equal(StatusSyntaxErrorNotRecognised, returnCode)
req.Contains(response, "custom site error")

// Default behavior fallback (should get unknown subcommand)
returnCode, response, err = raw.SendCommand("SITE PROCEED")
require.NoError(t, err)
require.Equal(t, StatusSyntaxErrorNotRecognised, returnCode)
require.Contains(t, response, "Unknown SITE subcommand")

// Short-circuit: Site returns nil, so command is accepted (no error, no message)
returnCode, response, err = raw.SendCommand("SITE OK")
require.NoError(t, err)
require.Equal(t, StatusOK, returnCode)
require.Equal(t, "OK", response)
}
Loading