Skip to content

Commit

Permalink
feat(ssh): support create cluster with stored key
Browse files Browse the repository at this point in the history
The stored key will be copy to <cfgPath>/<cluster name>/id_rsa and
<cfgPath>/<cluster name>/pub.cert(if necessary) when creating cluster
with --ssh-key-name flag and the ssh_key_path will be set.

And add logs when trying to ssh to host.
  • Loading branch information
orangedeng committed Dec 19, 2022
1 parent 3c095bb commit 45b618b
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 40 deletions.
5 changes: 2 additions & 3 deletions cmd/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"strconv"

"github.com/cnrancher/autok3s/pkg/common"
"github.com/cnrancher/autok3s/pkg/utils"

// import custom provider
_ "github.com/cnrancher/autok3s/pkg/providers/alibaba"
Expand Down Expand Up @@ -66,8 +65,8 @@ func Command() *cobra.Command {
}

func initCfg() {
if err := utils.EnsureFolderExist(common.GetLogPath()); err != nil {
logrus.Fatalln(err)
if err := common.MoveLogs(); err != nil {
logrus.Errorf("failed to relocate cluster logs, %v", err)
}

kubeCfg := filepath.Join(common.CfgPath, common.KubeCfgFile)
Expand Down
2 changes: 1 addition & 1 deletion cmd/sshkey/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ func create(cmd *cobra.Command, args []string) error {
}

if sshKeyFlags.Generate {
if err := pathsNotExists(sshKeyFlags.OutputPath, privateKeyFilename, publicKeyFilename); err != nil {
if err := pathsNotExists(sshKeyFlags.OutputPath, pkgsshkey.PrivateKeyFilename, pkgsshkey.PublicKeyFilename); err != nil {
return err
}
infoMsg := fmt.Sprintf("generating RSA ssh key pair with %d bit size", sshKeyFlags.Bits)
Expand Down
19 changes: 7 additions & 12 deletions cmd/sshkey/export.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,12 @@ import (
"path/filepath"

"github.com/cnrancher/autok3s/pkg/common"
pkgsshkey "github.com/cnrancher/autok3s/pkg/sshkey"
"github.com/cnrancher/autok3s/pkg/utils"

"github.com/spf13/cobra"
)

const (
privateKeyFilename = "id_rsa"
publicKeyFilename = "id_rsa.pub"
certificateFilename = "pub.cert"
)

var exportCmd = &cobra.Command{
Use: "export <name>",
Args: cobra.ExactArgs(1),
Expand Down Expand Up @@ -46,9 +41,9 @@ func validateFiles(cmd *cobra.Command, args []string) error {
}
target := rtn[0]
checkmap := map[string]string{
privateKeyFilename: target.SSHKey,
publicKeyFilename: target.SSHPublicKey,
certificateFilename: target.SSHCert,
pkgsshkey.PrivateKeyFilename: target.SSHKey,
pkgsshkey.PublicKeyFilename: target.SSHPublicKey,
pkgsshkey.CertificateFilename: target.SSHCert,
}
for filename, toCheck := range checkmap {
if toCheck == "" {
Expand Down Expand Up @@ -89,9 +84,9 @@ func pathsNotExists(basePath string, paths ...string) error {

func exportKeyFiles(path string, target *common.SSHKey) error {
dataMap := map[string]string{
privateKeyFilename: target.SSHKey,
publicKeyFilename: target.SSHPublicKey,
certificateFilename: target.SSHCert,
pkgsshkey.PrivateKeyFilename: target.SSHKey,
pkgsshkey.PublicKeyFilename: target.SSHPublicKey,
pkgsshkey.CertificateFilename: target.SSHCert,
}
for filename, data := range dataMap {
if data == "" {
Expand Down
30 changes: 24 additions & 6 deletions pkg/cluster/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/cnrancher/autok3s/pkg/hosts"
"github.com/cnrancher/autok3s/pkg/providers"
putil "github.com/cnrancher/autok3s/pkg/providers/utils"
pkgsshkey "github.com/cnrancher/autok3s/pkg/sshkey"
"github.com/cnrancher/autok3s/pkg/types"
"github.com/cnrancher/autok3s/pkg/utils"

Expand Down Expand Up @@ -270,6 +271,12 @@ func (p *ProviderBase) GetSSHOptions() []types.Flag {
V: p.SSHAgentAuth,
Usage: "Enable ssh agent",
},
{
Name: "ssh-key-name",
P: &p.SSHKeyName,
V: p.SSHKeyName,
Usage: "Use the stored ssh key with name",
},
}
}

Expand Down Expand Up @@ -340,12 +347,23 @@ func (p *ProviderBase) InitCluster(options interface{}, deployPlugins func() []s
p.Logger.Infof("[%s] begin to create cluster %s...", p.Provider, p.Name)
c.Status.Status = common.StatusCreating
// save cluster.
err = common.DefaultDB.SaveCluster(c)
if err != nil {
if err = common.DefaultDB.SaveCluster(c); err != nil {
return err
}
// store ssh key
if newSSH, err := pkgsshkey.StoreClusterSSHKeys(p.ContextName, &c.SSH); err != nil {
return err
} else if newSSH != nil {
p.Logger.Infof("[%s] cluster's ssh keys saved", p.Name)
c.SSH = *newSSH
// update cluster with stored key
if err = common.DefaultDB.SaveCluster(c); err != nil {
return err
}
p.SSH = *newSSH
}

c, err = cloudInstanceFunc(&p.SSH)
c, err = cloudInstanceFunc(&c.SSH)
if err != nil {
return err
}
Expand Down Expand Up @@ -449,7 +467,7 @@ func (p *ProviderBase) JoinNodes(cloudInstanceFunc func(ssh *types.SSH) (*types.
return err
}

c, err := cloudInstanceFunc(&p.SSH)
c, err := cloudInstanceFunc(&state.SSH)
if err != nil {
p.Logger.Errorf("[%s] failed to prepare instance, got error %v", p.Provider, err)
return err
Expand Down Expand Up @@ -595,7 +613,7 @@ func (p *ProviderBase) DeleteCluster(force bool, delete func(f bool) (string, er
defer func() {
_ = logFile.Close()
// remove log file.
_ = os.Remove(filepath.Join(common.GetLogPath(), p.ContextName))
_ = os.RemoveAll(common.GetClusterContextPath(p.ContextName))
}()
state, err := common.DefaultDB.GetCluster(p.Name, p.Provider)
if err != nil && !force {
Expand Down Expand Up @@ -975,7 +993,7 @@ func (p *ProviderBase) ReleaseManifests() error {
masterIP := p.IP
for _, n := range p.Status.MasterNodes {
if n.InternalIPAddress[0] == masterIP {
dialer, err := hosts.NewSSHDialer(&n, true)
dialer, err := hosts.NewSSHDialer(&n, true, p.Logger)
if err != nil {
return err
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/cluster/cluster.go
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,7 @@ func (p *ProviderBase) execute(n *types.Node, cmds []string) (string, error) {
return "", nil
}

dialer, err := hosts.NewSSHDialer(n, true)
dialer, err := hosts.NewSSHDialer(n, true, p.Logger)
if err != nil {
return "", err
}
Expand All @@ -492,7 +492,7 @@ func (p *ProviderBase) execute(n *types.Node, cmds []string) (string, error) {
}

func terminal(n *types.Node) error {
dialer, err := hosts.NewSSHDialer(n, true)
dialer, err := hosts.NewSSHDialer(n, true, common.NewLogger(nil))
if err != nil {
return err
}
Expand Down Expand Up @@ -867,7 +867,7 @@ func nodeByInstanceID(nodes []types.Node) map[string]types.Node {
}

func (p *ProviderBase) scpFiles(clusterName string, pkg *common.Package, node *types.Node) error {
dialer, err := hosts.NewSSHDialer(node, true)
dialer, err := hosts.NewSSHDialer(node, true, p.Logger)
if err != nil {
return err
}
Expand Down
48 changes: 44 additions & 4 deletions pkg/common/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package common

import (
"io"
"io/fs"
"os"
"path/filepath"

Expand All @@ -22,14 +23,25 @@ func NewLogger(w *os.File) (logger *logrus.Logger) {
return
}

// GetLogPath returns log path.
func GetLogPath() string {
// GetOldLogPath returns old log path.
func GetOldLogPath() string {
return filepath.Join(CfgPath, "logs")
}

func GetLogFilePath(clusterName string) string {
return filepath.Join(GetClusterContextPath(clusterName), "log")
}

func GetClusterContextPath(clusterName string) string {
return filepath.Join(CfgPath, clusterName)
}

// GetLogFile open and return log file.
func GetLogFile(name string) (logFile *os.File, err error) {
logFilePath := filepath.Join(GetLogPath(), name)
func GetLogFile(clusterName string) (logFile *os.File, err error) {
logFilePath := GetLogFilePath(clusterName)
if err = os.MkdirAll(filepath.Dir(logFilePath), 0755); err != nil {
return nil, err
}
// check file exist
_, err = os.Stat(logFilePath)
if err != nil {
Expand All @@ -51,3 +63,31 @@ func InitLogger(logger *logrus.Logger) {
FullTimestamp: true,
})
}

func MoveLogs() error {
oldRoot := GetOldLogPath()
_, err := os.Stat(oldRoot)
if os.IsNotExist(err) {
return nil
}
newRoot := CfgPath

if err := filepath.Walk(oldRoot, func(path string, info fs.FileInfo, err error) error {
// skip all the dirs because we store all the logs with cluster context name and no dirs exists in logs dir
if info.IsDir() {
return nil
}
// assuming all the relative path should only be logs file
rel, _ := filepath.Rel(oldRoot, path)
if err := os.MkdirAll(filepath.Join(newRoot, rel), 0755); err != nil {
return err
}
if err := os.Rename(path, GetLogFilePath(rel)); err != nil {
return err
}
return nil
}); err != nil {
return err
}
return os.RemoveAll(oldRoot)
}
10 changes: 7 additions & 3 deletions pkg/hosts/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/cnrancher/autok3s/pkg/types"
"github.com/cnrancher/autok3s/pkg/utils"
"github.com/sirupsen/logrus"

"github.com/moby/term"
"golang.org/x/crypto/ssh"
Expand All @@ -20,7 +21,7 @@ import (
)

var defaultBackoff = wait.Backoff{
Duration: 30 * time.Second,
Duration: 15 * time.Second,
Factor: 1,
Steps: 5,
}
Expand Down Expand Up @@ -55,7 +56,7 @@ type SSHDialer struct {
}

// NewSSHDialer returns new ssh dialer.
func NewSSHDialer(n *types.Node, timeout bool) (*SSHDialer, error) {
func NewSSHDialer(n *types.Node, timeout bool, logger *logrus.Logger) (*SSHDialer, error) {
if len(n.PublicIPAddress) <= 0 && n.InstanceID == "" {
return nil, errors.New("[ssh-dialer] no node IP or node ID is specified")
}
Expand Down Expand Up @@ -91,7 +92,10 @@ func NewSSHDialer(n *types.Node, timeout bool) (*SSHDialer, error) {
}
}

try := 0
if err := wait.ExponentialBackoff(defaultBackoff, func() (bool, error) {
try++
logger.Infof("the %d/%d time tring to ssh to %s with user %s", try, defaultBackoff.Steps, d.sshAddress, d.username)
c, err := d.Dial(timeout)
if err != nil {
return false, nil
Expand All @@ -109,7 +113,7 @@ func NewSSHDialer(n *types.Node, timeout bool) (*SSHDialer, error) {

// Dial handshake with ssh address.
func (d *SSHDialer) Dial(t bool) (*ssh.Client, error) {
timeout := 30 * time.Second
timeout := defaultBackoff.Duration
if !t {
timeout = 0
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/alibaba/alibaba.go
Original file line number Diff line number Diff line change
Expand Up @@ -1514,7 +1514,7 @@ func (p *Alibaba) getSecurityGroup(id string) (*ecs.DescribeSecurityGroupAttribu
}

func (p *Alibaba) uploadKeyPair(node types.Node, publicKey string) error {
dialer, err := hosts.NewSSHDialer(&node, true)
dialer, err := hosts.NewSSHDialer(&node, true, p.Logger)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/native/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ func (p *Native) CreateK3sCluster() (err error) {
if p.SSHUser == "" {
p.SSHUser = defaultUser
}
if p.SSHPassword == "" && p.SSHKeyPath == "" {
if p.SSHKeyName == "" && p.SSHPassword == "" && p.SSHKeyPath == "" {
p.SSHKeyPath = defaultSSHKeyPath
}

Expand Down
2 changes: 1 addition & 1 deletion pkg/providers/tencent/tencent.go
Original file line number Diff line number Diff line change
Expand Up @@ -1446,7 +1446,7 @@ func (p *Tencent) allocateEIPForInstance(num int, master bool) ([]uint64, error)
}

func (p *Tencent) uploadKeyPair(node types.Node, publicKey string) error {
dialer, err := hosts.NewSSHDialer(&node, true)
dialer, err := hosts.NewSSHDialer(&node, true, p.logger)
if err != nil {
return err
}
Expand Down
3 changes: 1 addition & 2 deletions pkg/server/store/websocket/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"net/http"
"os"
"path/filepath"

"github.com/cnrancher/autok3s/pkg/common"

Expand Down Expand Up @@ -83,7 +82,7 @@ func logHandler(apiOp *types.APIRequest) error {
w.Header().Set("Transfer-Encoding", "chunked")
w.Header().Set("Access-Control-Allow-Origin", "*")

logFilePath := filepath.Join(common.GetLogPath(), cluster)
logFilePath := common.GetLogFilePath(cluster)
state, err := common.DefaultDB.GetClusterByID(cluster)
if err != nil {
return err
Expand Down
2 changes: 1 addition & 1 deletion pkg/server/store/websocket/ssh/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ func newDialer(id, node string, conn *websocket.Conn) (*hosts.WebSocketDialer, e
wsDialer = hosts.NewWebSocketDialer(conn, dialer)
return wsDialer, nil
}
dialer, err := hosts.NewSSHDialer(&n, true)
dialer, err := hosts.NewSSHDialer(&n, true, common.NewLogger(nil))
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 45b618b

Please sign in to comment.