Skip to content

Commit

Permalink
feat(server): create ssh client keypair
Browse files Browse the repository at this point in the history
  • Loading branch information
aymanbagabas committed May 2, 2023
1 parent 106c049 commit 468a99f
Show file tree
Hide file tree
Showing 8 changed files with 60 additions and 139 deletions.
121 changes: 8 additions & 113 deletions cmd/soft/migrate_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ var (
migrateConfig = &cobra.Command{
Use: "migrate-config",
Short: "Migrate config to new format",
RunE: func(cmd *cobra.Command, args []string) error {
RunE: func(_ *cobra.Command, _ []string) error {
keyPath := os.Getenv("SOFT_SERVE_KEY_PATH")
reposPath := os.Getenv("SOFT_SERVE_REPO_PATH")
bindAddr := os.Getenv("SOFT_SERVE_BIND_ADDRESS")
cfg := config.DefaultConfig()
sb, err := sqlite.NewSqliteBackend(cfg.DataPath)
sb, err := sqlite.NewSqliteBackend(cfg)
if err != nil {
return fmt.Errorf("failed to create sqlite backend: %w", err)
}
Expand Down Expand Up @@ -72,7 +72,7 @@ var (
return fmt.Errorf("failed to get tree: %w", err)
}

isJson := false
isJson := false // nolint: revive
te, err := tree.TreeEntry("config.yaml")
if err != nil {
te, err = tree.TreeEntry("config.json")
Expand Down Expand Up @@ -236,7 +236,7 @@ func isGitDir(path string) bool {
return true
}

// copyFile copies a single file from src to dst
// copyFile copies a single file from src to dst.
func copyFile(src, dst string) error {
var err error
var srcfd *os.File
Expand All @@ -246,12 +246,12 @@ func copyFile(src, dst string) error {
if srcfd, err = os.Open(src); err != nil {
return err
}
defer srcfd.Close()
defer srcfd.Close() // nolint: errcheck

if dstfd, err = os.Create(dst); err != nil {
return err
}
defer dstfd.Close()
defer dstfd.Close() // nolint: errcheck

if _, err = io.Copy(dstfd, srcfd); err != nil {
return err
Expand All @@ -262,7 +262,7 @@ func copyFile(src, dst string) error {
return os.Chmod(dst, srcinfo.Mode())
}

// copyDir copies a whole directory recursively
// copyDir copies a whole directory recursively.
func copyDir(src string, dst string) error {
var err error
var fds []os.DirEntry
Expand Down Expand Up @@ -296,112 +296,7 @@ func copyDir(src string, dst string) error {
return nil
}

// func copyDir(src, dst string) error {
// entries, err := os.ReadDir(src)
// if err != nil {
// return err
// }
// for _, entry := range entries {
// sourcePath := filepath.Join(src, entry.Name())
// destPath := filepath.Join(dst, entry.Name())
//
// fileInfo, err := os.Stat(sourcePath)
// if err != nil {
// return err
// }
//
// stat, ok := fileInfo.Sys().(*syscall.Stat_t)
// if !ok {
// return fmt.Errorf("failed to get raw syscall.Stat_t data for '%s'", sourcePath)
// }
//
// switch fileInfo.Mode() & os.ModeType {
// case os.ModeDir:
// if err := createIfNotExists(destPath, 0755); err != nil {
// return err
// }
// if err := copyDir(sourcePath, destPath); err != nil {
// return err
// }
// case os.ModeSymlink:
// if err := copySymLink(sourcePath, destPath); err != nil {
// return err
// }
// default:
// if err := copyFile(sourcePath, destPath); err != nil {
// return err
// }
// }
//
// if err := os.Lchown(destPath, int(stat.Uid), int(stat.Gid)); err != nil {
// return err
// }
//
// fInfo, err := entry.Info()
// if err != nil {
// return err
// }
//
// isSymlink := fInfo.Mode()&os.ModeSymlink != 0
// if !isSymlink {
// if err := os.Chmod(destPath, fInfo.Mode()); err != nil {
// return err
// }
// }
// }
// return nil
// }
//
// func copyFile(srcFile, dstFile string) error {
// out, err := os.Create(dstFile)
// if err != nil {
// return err
// }
//
// defer out.Close()
//
// in, err := os.Open(srcFile)
// defer in.Close()
// if err != nil {
// return err
// }
//
// _, err = io.Copy(out, in)
// if err != nil {
// return err
// }
//
// return nil
// }

func exists(filePath string) bool {
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return false
}

return true
}

func createIfNotExists(dir string, perm os.FileMode) error {
if exists(dir) {
return nil
}

if err := os.MkdirAll(dir, perm); err != nil {
return fmt.Errorf("failed to create directory: '%s', error: '%s'", dir, err.Error())
}

return nil
}

func copySymLink(source, dest string) error {
link, err := os.Readlink(source)
if err != nil {
return err
}
return os.Symlink(link, dest)
}

// Config is the configuration for the server.
type Config struct {
Name string `yaml:"name" json:"name"`
Host string `yaml:"host" json:"host"`
Expand Down
24 changes: 7 additions & 17 deletions server/backend/sqlite/sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ import (
"strings"
"text/template"

"github.com/charmbracelet/keygen"
"github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/git"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/utils"
"github.com/jmoiron/sqlx"
_ "modernc.org/sqlite"
Expand All @@ -26,9 +26,9 @@ var (
// SqliteBackend is a backend that uses a SQLite database as a Soft Serve
// backend.
type SqliteBackend struct {
cfg *config.Config
dp string
db *sqlx.DB
ckp string
AdditionalAdmins []string
}

Expand All @@ -39,32 +39,22 @@ func (d *SqliteBackend) reposPath() string {
}

// NewSqliteBackend creates a new SqliteBackend.
func NewSqliteBackend(dataPath string) (*SqliteBackend, error) {
func NewSqliteBackend(cfg *config.Config) (*SqliteBackend, error) {
dataPath := cfg.DataPath
if err := os.MkdirAll(dataPath, 0755); err != nil {
return nil, err
}

ckp := filepath.Join(dataPath, "ssh", "soft_serve_client")
_, err := keygen.NewWithWrite(ckp, nil, keygen.Ed25519)
if err != nil {
return nil, err
}

ckp, err = filepath.Abs(ckp)
if err != nil {
return nil, err
}

db, err := sqlx.Connect("sqlite", filepath.Join(dataPath, "soft-serve.db"+
"?_pragma=busy_timeout(5000)&_pragma=foreign_keys(1)"))
if err != nil {
return nil, err
}

d := &SqliteBackend{
cfg: cfg,
dp: dataPath,
db: db,
ckp: ckp,
}

if err := d.init(); err != nil {
Expand Down Expand Up @@ -186,8 +176,8 @@ func (d *SqliteBackend) ImportRepository(name string, remote string, opts backen
CommandOptions: git.CommandOptions{
Envs: []string{
fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
filepath.Join(filepath.Dir(d.ckp), "known_hosts"),
d.ckp,
filepath.Join(d.cfg.DataPath, "ssh", "known_hosts"),
filepath.Join(d.cfg.DataPath, d.cfg.SSH.ClientKeyPath),
),
},
},
Expand Down
9 changes: 9 additions & 0 deletions server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ type SSHConfig struct {
// KeyPath is the path to the SSH server's private key.
KeyPath string `env:"KEY_PATH" yaml:"key_path"`

// ClientKeyPath is the path to the SSH server's client private key.
ClientKeyPath string `env:"CLIENT_KEY_PATH" yaml:"client_key_path"`

// InternalKeyPath is the path to the SSH server's internal private key.
InternalKeyPath string `env:"INTERNAL_KEY_PATH" yaml:"internal_key_path"`

Expand Down Expand Up @@ -122,13 +125,19 @@ func DefaultConfig() *Config {
dataPath = "data"
}

dp, _ := filepath.Abs(dataPath)
if dp != "" {
dataPath = dp
}

cfg := &Config{
Name: "Soft Serve",
DataPath: dataPath,
SSH: SSHConfig{
ListenAddr: ":23231",
PublicURL: "ssh://localhost:23231",
KeyPath: filepath.Join("ssh", "soft_serve_host"),
ClientKeyPath: filepath.Join("ssh", "soft_serve_client"),
InternalKeyPath: filepath.Join("ssh", "soft_serve_internal"),
MaxTimeout: 0,
IdleTimeout: 120,
Expand Down
5 changes: 5 additions & 0 deletions server/config/file.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,11 @@ ssh:
# The relative path to the SSH server's private key.
key_path: "{{ .SSH.KeyPath }}"
# The relative path to the SSH server's client private key.
# This key will be used to authenticate the server to make git requests to
# ssh remotes.
client_key_path: "{{ .SSH.ClientKeyPath }}"
# The relative path to the SSH server's internal api private key.
internal_key_path: "{{ .SSH.InternalKeyPath }}"
Expand Down
7 changes: 4 additions & 3 deletions server/daemon_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,16 @@ func TestMain(m *testing.M) {
os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
fb, err := sqlite.NewSqliteBackend(tmp)
cfg := config.DefaultConfig()
d, err := NewGitDaemon(cfg)
if err != nil {
log.Fatal(err)
}
cfg := config.DefaultConfig().WithBackend(fb)
d, err := NewGitDaemon(cfg)
fb, err := sqlite.NewSqliteBackend(cfg)
if err != nil {
log.Fatal(err)
}
cfg = cfg.WithBackend(fb)
testDaemon = d
go func() {
if err := d.Start(); err != ErrServerClosed {
Expand Down
14 changes: 12 additions & 2 deletions server/jobs.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
package server

import (
"fmt"
"path/filepath"

"github.com/charmbracelet/soft-serve/git"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
)

var (
Expand All @@ -12,7 +15,8 @@ var (
)

// mirrorJob runs the (pull) mirror job task.
func mirrorJob(b backend.Backend) func() {
func mirrorJob(cfg *config.Config) func() {
b := cfg.Backend
logger := logger.WithPrefix("server.mirrorJob")
return func() {
repos, err := b.Repositories()
Expand All @@ -31,6 +35,12 @@ func mirrorJob(b backend.Backend) func() {
}

cmd := git.NewCommand("remote", "update", "--prune")
cmd.AddEnvs(
fmt.Sprintf(`GIT_SSH_COMMAND=ssh -o UserKnownHostsFile="%s" -o StrictHostKeyChecking=no -i "%s"`,
filepath.Join(cfg.DataPath, "ssh", "known_hosts"),
filepath.Join(cfg.DataPath, cfg.SSH.ClientKeyPath),
),
)
if _, err := cmd.RunInDir(r.Path); err != nil {
logger.Error("error running git remote update", "repo", repo.Name(), "err", err)
}
Expand Down
14 changes: 12 additions & 2 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Server struct {
func NewServer(cfg *config.Config) (*Server, error) {
var err error
if cfg.Backend == nil {
sb, err := sqlite.NewSqliteBackend(cfg.DataPath)
sb, err := sqlite.NewSqliteBackend(cfg)
if err != nil {
logger.Fatal(err)
}
Expand All @@ -57,6 +57,16 @@ func NewServer(cfg *config.Config) (*Server, error) {
if err != nil {
return nil, err
}

// Create client key.
_, err = keygen.NewWithWrite(
filepath.Join(cfg.DataPath, cfg.SSH.ClientKeyPath),
nil,
keygen.Ed25519,
)
if err != nil {
return nil, err
}
}

srv := &Server{
Expand All @@ -66,7 +76,7 @@ func NewServer(cfg *config.Config) (*Server, error) {
}

// Add cron jobs.
srv.Cron.AddFunc(jobSpecs["mirror"], mirrorJob(cfg.Backend))
srv.Cron.AddFunc(jobSpecs["mirror"], mirrorJob(cfg))

srv.SSHServer, err = NewSSHServer(cfg, srv)
if err != nil {
Expand Down
5 changes: 3 additions & 2 deletions server/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ func setup(tb testing.TB) *gossh.Session {
is.NoErr(os.Unsetenv("SOFT_SERVE_SSH_LISTEN_ADDR"))
is.NoErr(os.RemoveAll(dp))
})
fb, err := sqlite.NewSqliteBackend(dp)
cfg := config.DefaultConfig()
fb, err := sqlite.NewSqliteBackend(cfg)
if err != nil {
log.Fatal(err)
}
cfg := config.DefaultConfig().WithBackend(fb)
cfg = cfg.WithBackend(fb)
return testsession.New(tb, &ssh.Server{
Handler: bm.MiddlewareWithProgramHandler(SessionHandler(cfg), termenv.ANSI256)(func(s ssh.Session) {
_, _, active := s.Pty()
Expand Down

0 comments on commit 468a99f

Please sign in to comment.