Skip to content

Commit

Permalink
Merge pull request #9 from lpabon/sshscp
Browse files Browse the repository at this point in the history
ssh: SCP PUT support
  • Loading branch information
lpabon authored Feb 22, 2017
2 parents 0e37244 + ae4e860 commit 1a408a6
Show file tree
Hide file tree
Showing 5 changed files with 237 additions and 32 deletions.
6 changes: 4 additions & 2 deletions glide.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

59 changes: 59 additions & 0 deletions ssh/cmd/scpdemo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package main

import (
"fmt"
"os"

"github.com/heketi/utils/ssh"
)

func runDemo(s ssh.SshExecutor) {
// scp scpdemo
fmt.Print("Copying scpdemo to server...")
host := "127.0.0.1:22"
err := s.CopyPath("scpdemo.go", host, "/tmp/scpdemo-copy.go")
if err != nil {
fmt.Println(err)
os.Exit(3)
}
fmt.Println("Done")

// run a few commands
fmt.Println("Running commands...")
commands := []string{
"date",
"echo \"HELLO\" > /tmp/file",
"cat /tmp/file",
"ls -al",
"rm /tmp/file",
"rm /tmp/scpdemo-copy.go",
}

out, err := s.Exec(host, commands, 10, false)
if err != nil {
fmt.Println(err)
os.Exit(2)
}

fmt.Printf("%+v\n", out)
}

func main() {
fmt.Println("- Real Demo -")
s, err := ssh.NewSshExecWithAuth(os.Getenv("USER"))
if err != nil {
fmt.Println(err)
os.Exit(1)
}
runDemo(s)

fmt.Println("- Mock Demo -")
// Now run with a mock demo
m := ssh.NewMockSshExecutor()
m.MockExec = func(host string, commands []string, timeoutMinutes int, useSudo bool) ([]string, error) {
return []string{
"In Mock function",
}, nil
}
runDemo(m)
}
74 changes: 74 additions & 0 deletions ssh/mock.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
//
// Copyright (c) 2017 The heketi Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

// Package to mock ssh interface for unit tests
// Create a mock object and override the functions to test
// functions which return expected values
package ssh

import (
"io"
"os"
)

type MockSshExec struct {
MockCopy func(size int64, mode os.FileMode, fileName string, contents io.Reader, host, destinationPath string) error
MockCopyPath func(sourcePath, host, destinationPath string) error
MockExec func(host string, commands []string, timeoutMinutes int, useSudo bool) ([]string, error)
}

func NewMockSshExecutor() *MockSshExec {
m := &MockSshExec{}
m.MockCopy = func(size int64,
mode os.FileMode,
fileName string,
contents io.Reader,
host, destinationPath string) error {
return nil
}

m.MockCopyPath = func(sourcePath, host, destinationPath string) error {
return nil
}

m.MockExec = func(host string,
commands []string,
timeoutMinutes int,
useSudo bool) ([]string, error) {
return []string{""}, nil
}

return m
}

func (m *MockSshExec) Copy(size int64,
mode os.FileMode,
fileName string,
contents io.Reader,
host, destinationPath string) error {
return m.MockCopy(size, mode, fileName, contents, host, destinationPath)
}

func (m *MockSshExec) CopyPath(sourcePath, host, destinationPath string) error {
return m.MockCopyPath(sourcePath, host, destinationPath)
}

func (m *MockSshExec) Exec(host string,
commands []string,
timeoutMinutes int,
useSudo bool) ([]string, error) {
return m.MockExec(host, commands, timeoutMinutes, useSudo)
}
102 changes: 72 additions & 30 deletions ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,20 @@ package ssh

import (
"bytes"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"time"

"github.com/heketi/utils"
"github.com/tmc/scp"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
)

type SshExec struct {
clientConfig *ssh.ClientConfig
logger *utils.Logger
}

func getKeyFile(file string) (key ssh.Signer, err error) {
Expand All @@ -49,50 +47,44 @@ func getKeyFile(file string) (key ssh.Signer, err error) {
return
}

func NewSshExecWithAuth(logger *utils.Logger, user string) *SshExec {
func NewSshExecWithAuth(user string) (SshExecutor, error) {

sshexec := &SshExec{}
sshexec.logger = logger

authSocket := os.Getenv("SSH_AUTH_SOCK")
if authSocket == "" {
log.Fatal("SSH_AUTH_SOCK required, check that your ssh agent is running")
return nil
return nil, fmt.Errorf("SSH_AUTH_SOCK not set")
}

agentUnixSock, err := net.Dial("unix", authSocket)
if err != nil {
log.Fatal(err)
return nil
return nil, fmt.Errorf("Cannot connect to SSH_AUTH_SOCK")
}

agent := agent.NewClient(agentUnixSock)
signers, err := agent.Signers()
if err != nil {
log.Fatal(err)
return nil
return nil, fmt.Errorf("Could not get key signatures: %v", err)
}

sshexec.clientConfig = &ssh.ClientConfig{
User: user,
Auth: []ssh.AuthMethod{ssh.PublicKeys(signers...)},
}

return sshexec
return sshexec, nil
}

func NewSshExecWithKeyFile(logger *utils.Logger, user string, file string) *SshExec {
func NewSshExecWithKeyFile(user string, file string) (SshExecutor, error) {

var key ssh.Signer
var err error

sshexec := &SshExec{}
sshexec.logger = logger

// Now in the main function DO:
if key, err = getKeyFile(file); err != nil {
fmt.Println("Unable to get keyfile")
return nil
return nil, fmt.Errorf("Unable to get keyfile")
}
// Define the Client Config as :
sshexec.clientConfig = &ssh.ClientConfig{
Expand All @@ -102,19 +94,72 @@ func NewSshExecWithKeyFile(logger *utils.Logger, user string, file string) *SshE
},
}

return sshexec
return sshexec, nil
}

func (s *SshExec) Copy(size int64,
mode os.FileMode,
fileName string,
contents io.Reader,
host, destinationPath string) error {

// Create a connection to the server
client, err := ssh.Dial("tcp", host, s.clientConfig)
if err != nil {
return err
}
defer client.Close()

// Create a session
session, err := client.NewSession()
if err != nil {
return err
}
defer session.Close()

// Copy Data
err = scp.Copy(size, mode, fileName, contents, destinationPath, session)
if err != nil {
return err
}

return nil
}

func (s *SshExec) CopyPath(sourcePath, host, destinationPath string) error {

// Create a connection to the server
client, err := ssh.Dial("tcp", host, s.clientConfig)
if err != nil {
return err
}
defer client.Close()

// Create a session
session, err := client.NewSession()
if err != nil {
return err
}
defer session.Close()

// Copy Data
err = scp.CopyPath(sourcePath, destinationPath, session)
if err != nil {
return err
}

return nil
}

// This function was based from https://github.com/coreos/etcd-manager/blob/master/main.go
func (s *SshExec) ConnectAndExec(host string, commands []string, timeoutMinutes int, useSudo bool) ([]string, error) {
func (s *SshExec) Exec(host string, commands []string, timeoutMinutes int, useSudo bool) ([]string, error) {

buffers := make([]string, len(commands))

// :TODO: Will need a timeout here in case the server does not respond
client, err := ssh.Dial("tcp", host, s.clientConfig)
if err != nil {
s.logger.Warning("Failed to create SSH connection to %v: %v", host, err)
return nil, err
return nil, fmt.Errorf("Failed to create SSH connection to %v: %v", host, err)
}
defer client.Close()

Expand All @@ -123,8 +168,7 @@ func (s *SshExec) ConnectAndExec(host string, commands []string, timeoutMinutes

session, err := client.NewSession()
if err != nil {
s.logger.LogError("Unable to create SSH session: %v", err)
return nil, err
return nil, fmt.Errorf("Unable to create SSH session: %v", err)
}
defer session.Close()

Expand Down Expand Up @@ -156,22 +200,20 @@ func (s *SshExec) ConnectAndExec(host string, commands []string, timeoutMinutes
select {
case err := <-errch:
if err != nil {
s.logger.LogError("Failed to run command [%v] on %v: Err[%v]: Stdout [%v]: Stderr [%v]",
return nil, fmt.Errorf("Failed to run command [%v] on %v: Err[%v]: Stdout [%v]: Stderr [%v]",
command, host, err, b.String(), berr.String())
return nil, fmt.Errorf("%s", berr.String())
}
s.logger.Debug("Host: %v Command: %v\nResult: %v", host, command, b.String())
//LOG("Host: %v Command: %v\nResult: %v", host, command, b.String())
buffers[index] = b.String()

case <-timeout:
s.logger.LogError("Timeout on command [%v] on %v: Err[%v]: Stdout [%v]: Stderr [%v]",
command, host, err, b.String(), berr.String())
err := session.Signal(ssh.SIGKILL)
if err != nil {
s.logger.LogError("Unable to send kill signal to command [%v] on host [%v]: %v",
return nil, fmt.Errorf("Command timed out and unable to send kill signal to command [%v] on host [%v]: %v",
command, host, err)
}
return nil, errors.New("SSH command timeout")
return nil, fmt.Errorf("Timeout on command [%v] on %v: Err[%v]: Stdout [%v]: Stderr [%v]",
command, host, err, b.String(), berr.String())
}
}

Expand Down
28 changes: 28 additions & 0 deletions ssh/sshexecutor.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//
// Copyright (c) 2017 The heketi Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

package ssh

import (
"io"
"os"
)

type SshExecutor interface {
Copy(size int64, mode os.FileMode, fileName string, contents io.Reader, host, destinationPath string) error
CopyPath(sourcePath, host, destinationPath string) error
Exec(host string, commands []string, timeoutMinutes int, useSudo bool) ([]string, error)
}

0 comments on commit 1a408a6

Please sign in to comment.