@@ -970,6 +970,99 @@ func TestAgent_SCP(t *testing.T) {
970970 require .NoError (t , err )
971971}
972972
973+ func TestAgent_FileTransferBlocked (t * testing.T ) {
974+ t .Parallel ()
975+
976+ assertFileTransferBlocked := func (t * testing.T , errorMessage string ) {
977+ // NOTE: Checking content of the error message is flaky. Most likely there is a race condition, which results
978+ // in stopping the client in different phases, and returning different errors:
979+ // - client read the full error message: File transfer has been disabled.
980+ // - client's stream was terminated before reading the error message: EOF
981+ // - client just read the error code (Windows): Process exited with status 65
982+ isErr := strings .Contains (errorMessage , agentssh .BlockedFileTransferErrorMessage ) ||
983+ strings .Contains (errorMessage , "EOF" ) ||
984+ strings .Contains (errorMessage , "Process exited with status 65" )
985+ require .True (t , isErr , fmt .Sprintf ("Message: " + errorMessage ))
986+ }
987+
988+ t .Run ("SFTP" , func (t * testing.T ) {
989+ t .Parallel ()
990+
991+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
992+ defer cancel ()
993+
994+ //nolint:dogsled
995+ conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
996+ o .BlockFileTransfer = true
997+ })
998+ sshClient , err := conn .SSHClient (ctx )
999+ require .NoError (t , err )
1000+ defer sshClient .Close ()
1001+ _ , err = sftp .NewClient (sshClient )
1002+ require .Error (t , err )
1003+ assertFileTransferBlocked (t , err .Error ())
1004+ })
1005+
1006+ t .Run ("SCP with go-scp package" , func (t * testing.T ) {
1007+ t .Parallel ()
1008+
1009+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
1010+ defer cancel ()
1011+
1012+ //nolint:dogsled
1013+ conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
1014+ o .BlockFileTransfer = true
1015+ })
1016+ sshClient , err := conn .SSHClient (ctx )
1017+ require .NoError (t , err )
1018+ defer sshClient .Close ()
1019+ scpClient , err := scp .NewClientBySSH (sshClient )
1020+ require .NoError (t , err )
1021+ defer scpClient .Close ()
1022+ tempFile := filepath .Join (t .TempDir (), "scp" )
1023+ err = scpClient .CopyFile (context .Background (), strings .NewReader ("hello world" ), tempFile , "0755" )
1024+ require .Error (t , err )
1025+ assertFileTransferBlocked (t , err .Error ())
1026+ })
1027+
1028+ t .Run ("Forbidden commands" , func (t * testing.T ) {
1029+ t .Parallel ()
1030+
1031+ for _ , c := range agentssh .BlockedFileTransferCommands {
1032+ t .Run (c , func (t * testing.T ) {
1033+ t .Parallel ()
1034+
1035+ ctx , cancel := context .WithTimeout (context .Background (), testutil .WaitLong )
1036+ defer cancel ()
1037+
1038+ //nolint:dogsled
1039+ conn , _ , _ , _ , _ := setupAgent (t , agentsdk.Manifest {}, 0 , func (_ * agenttest.Client , o * agent.Options ) {
1040+ o .BlockFileTransfer = true
1041+ })
1042+ sshClient , err := conn .SSHClient (ctx )
1043+ require .NoError (t , err )
1044+ defer sshClient .Close ()
1045+
1046+ session , err := sshClient .NewSession ()
1047+ require .NoError (t , err )
1048+ defer session .Close ()
1049+
1050+ stdout , err := session .StdoutPipe ()
1051+ require .NoError (t , err )
1052+
1053+ //nolint:govet // we don't need `c := c` in Go 1.22
1054+ err = session .Start (c )
1055+ require .NoError (t , err )
1056+ defer session .Close ()
1057+
1058+ msg , err := io .ReadAll (stdout )
1059+ require .NoError (t , err )
1060+ assertFileTransferBlocked (t , string (msg ))
1061+ })
1062+ }
1063+ })
1064+ }
1065+
9731066func TestAgent_EnvironmentVariables (t * testing.T ) {
9741067 t .Parallel ()
9751068 key := "EXAMPLE"
0 commit comments