Skip to content

Commit

Permalink
refactor(server): move into modules
Browse files Browse the repository at this point in the history
  • Loading branch information
aymanbagabas committed May 2, 2023
1 parent 3d7eb7b commit 2d5089e
Show file tree
Hide file tree
Showing 11 changed files with 113 additions and 84 deletions.
45 changes: 27 additions & 18 deletions server/daemon.go → server/daemon/daemon.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package server
package daemon

import (
"bytes"
Expand All @@ -9,14 +9,20 @@ import (
"sync"
"time"

"github.com/charmbracelet/log"
"github.com/charmbracelet/soft-serve/server/backend"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/git"
"github.com/charmbracelet/soft-serve/server/utils"
"github.com/go-git/go-git/v5/plumbing/format/pktline"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
)

var (
logger = log.WithPrefix("server.daemon")
)

var (
uploadPackGitCounter = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "soft_serve",
Expand All @@ -33,8 +39,11 @@ var (
}, []string{"repo"})
)

// ErrServerClosed indicates that the server has been closed.
var ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
var (

// ErrServerClosed indicates that the server has been closed.
ErrServerClosed = fmt.Errorf("git: %w", net.ErrClosed)
)

// connections synchronizes access to to a net.Conn pool.
type connections struct {
Expand Down Expand Up @@ -133,7 +142,7 @@ func (d *GitDaemon) Start() error {
// Close connection if there are too many open connections.
if d.conns.Size()+1 >= d.cfg.Git.MaxConnections {
logger.Debugf("git: max connections reached, closing %s", conn.RemoteAddr())
fatal(conn, ErrMaxConnections)
fatal(conn, git.ErrMaxConnections)
continue
}

Expand All @@ -146,7 +155,7 @@ func (d *GitDaemon) Start() error {
}

func fatal(c net.Conn, err error) {
writePktline(c, err)
git.WritePktline(c, err)
if err := c.Close(); err != nil {
logger.Debugf("git: error closing connection: %v", err)
}
Expand Down Expand Up @@ -176,10 +185,10 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
if !s.Scan() {
if err := s.Err(); err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
fatal(c, ErrTimeout)
fatal(c, git.ErrTimeout)
} else {
logger.Debugf("git: error scanning pktline: %v", err)
fatal(c, ErrSystemMalfunction)
fatal(c, git.ErrSystemMalfunction)
}
}
return
Expand All @@ -197,32 +206,32 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
line := s.Bytes()
split := bytes.SplitN(line, []byte{' '}, 2)
if len(split) != 2 {
fatal(c, ErrInvalidRequest)
fatal(c, git.ErrInvalidRequest)
return
}

gitPack := uploadPack
gitPack := git.UploadPack
counter := uploadPackGitCounter
cmd := string(split[0])
switch cmd {
case uploadPackBin:
gitPack = uploadPack
case uploadArchiveBin:
gitPack = uploadArchive
case git.UploadPackBin:
gitPack = git.UploadPack
case git.UploadArchiveBin:
gitPack = git.UploadArchive
counter = uploadArchiveGitCounter
default:
fatal(c, ErrInvalidRequest)
fatal(c, git.ErrInvalidRequest)
return
}

opts := bytes.Split(split[1], []byte{'\x00'})
if len(opts) == 0 {
fatal(c, ErrInvalidRequest)
fatal(c, git.ErrInvalidRequest)
return
}

if !d.cfg.Backend.AllowKeyless() {
fatal(c, ErrNotAuthed)
fatal(c, git.ErrNotAuthed)
return
}

Expand All @@ -233,14 +242,14 @@ func (d *GitDaemon) handleClient(conn net.Conn) {
// https://git-scm.com/docs/gitrepository-layout
repo := name + ".git"
reposDir := filepath.Join(d.cfg.DataPath, "repos")
if err := ensureWithin(reposDir, repo); err != nil {
if err := git.EnsureWithin(reposDir, repo); err != nil {
fatal(c, err)
return
}

auth := d.cfg.Backend.AccessLevel(name, "")
if auth < backend.ReadOnlyAccess {
fatal(c, ErrNotAuthed)
fatal(c, git.ErrNotAuthed)
return
}

Expand Down
14 changes: 8 additions & 6 deletions server/daemon_test.go → server/daemon/daemon_test.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package server
package daemon

import (
"bytes"
Expand All @@ -14,6 +14,8 @@ import (

"github.com/charmbracelet/soft-serve/server/backend/sqlite"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/git"
"github.com/charmbracelet/soft-serve/server/test"
"github.com/go-git/go-git/v5/plumbing/format/pktline"
)

Expand All @@ -29,7 +31,7 @@ func TestMain(m *testing.M) {
os.Setenv("SOFT_SERVE_GIT_MAX_CONNECTIONS", "3")
os.Setenv("SOFT_SERVE_GIT_MAX_TIMEOUT", "100")
os.Setenv("SOFT_SERVE_GIT_IDLE_TIMEOUT", "1")
os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
os.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
cfg := config.DefaultConfig()
d, err := NewGitDaemon(cfg)
if err != nil {
Expand Down Expand Up @@ -67,8 +69,8 @@ func TestIdleTimeout(t *testing.T) {
if err != nil && !errors.Is(err, io.EOF) {
t.Fatalf("expected nil, got error: %v", err)
}
if out != ErrTimeout.Error() || out == "" {
t.Fatalf("expected %q error, got %q", ErrTimeout, out)
if out != git.ErrTimeout.Error() || out == "" {
t.Fatalf("expected %q error, got %q", git.ErrTimeout, out)
}
}

Expand All @@ -84,8 +86,8 @@ func TestInvalidRepo(t *testing.T) {
if err != nil {
t.Fatalf("expected nil, got error: %v", err)
}
if out != ErrInvalidRepo.Error() {
t.Fatalf("expected %q error, got %q", ErrInvalidRepo, out)
if out != git.ErrInvalidRepo.Error() {
t.Fatalf("expected %q error, got %q", git.ErrInvalidRepo, out)
}
}

Expand Down
44 changes: 22 additions & 22 deletions server/git.go → server/git/git.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package server
package git

import (
"errors"
Expand Down Expand Up @@ -36,45 +36,45 @@ var (

// Git protocol commands.
const (
receivePackBin = "git-receive-pack"
uploadPackBin = "git-upload-pack"
uploadArchiveBin = "git-upload-archive"
ReceivePackBin = "git-receive-pack"
UploadPackBin = "git-upload-pack"
UploadArchiveBin = "git-upload-archive"
)

// uploadPack runs the git upload-pack protocol against the provided repo.
func uploadPack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
// UploadPack runs the git upload-pack protocol against the provided repo.
func UploadPack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
exists, err := fileExists(repoDir)
if !exists {
return ErrInvalidRepo
}
if err != nil {
return err
}
return runGit(in, out, er, "", uploadPackBin[4:], repoDir)
return RunGit(in, out, er, "", UploadPackBin[4:], repoDir)
}

// uploadArchive runs the git upload-archive protocol against the provided repo.
func uploadArchive(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
// UploadArchive runs the git upload-archive protocol against the provided repo.
func UploadArchive(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
exists, err := fileExists(repoDir)
if !exists {
return ErrInvalidRepo
}
if err != nil {
return err
}
return runGit(in, out, er, "", uploadArchiveBin[4:], repoDir)
return RunGit(in, out, er, "", UploadArchiveBin[4:], repoDir)
}

// receivePack runs the git receive-pack protocol against the provided repo.
func receivePack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
if err := runGit(in, out, er, "", receivePackBin[4:], repoDir); err != nil {
// ReceivePack runs the git receive-pack protocol against the provided repo.
func ReceivePack(in io.Reader, out io.Writer, er io.Writer, repoDir string) error {
if err := RunGit(in, out, er, "", ReceivePackBin[4:], repoDir); err != nil {
return err
}
return ensureDefaultBranch(in, out, er, repoDir)
return EnsureDefaultBranch(in, out, er, repoDir)
}

// runGit runs a git command in the given repo.
func runGit(in io.Reader, out io.Writer, err io.Writer, dir string, args ...string) error {
// RunGit runs a git command in the given repo.
func RunGit(in io.Reader, out io.Writer, err io.Writer, dir string, args ...string) error {
c := git.NewCommand(args...)
return c.RunInDirWithOptions(dir, git.RunInDirOptions{
Stdin: in,
Expand All @@ -83,8 +83,8 @@ func runGit(in io.Reader, out io.Writer, err io.Writer, dir string, args ...stri
})
}

// writePktline encodes and writes a pktline to the given writer.
func writePktline(w io.Writer, v ...interface{}) {
// WritePktline encodes and writes a pktline to the given writer.
func WritePktline(w io.Writer, v ...interface{}) {
msg := fmt.Sprintln(v...)
pkt := pktline.NewEncoder(w)
if err := pkt.EncodeString(msg); err != nil {
Expand All @@ -95,8 +95,8 @@ func writePktline(w io.Writer, v ...interface{}) {
}
}

// ensureWithin ensures the given repo is within the repos directory.
func ensureWithin(reposDir string, repo string) error {
// EnsureWithin ensures the given repo is within the repos directory.
func EnsureWithin(reposDir string, repo string) error {
repoDir := filepath.Join(reposDir, repo)
absRepos, err := filepath.Abs(reposDir)
if err != nil {
Expand Down Expand Up @@ -129,7 +129,7 @@ func fileExists(path string) (bool, error) {
return true, err
}

func ensureDefaultBranch(in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
func EnsureDefaultBranch(in io.Reader, out io.Writer, er io.Writer, repoPath string) error {
r, err := git.Open(repoPath)
if err != nil {
return err
Expand All @@ -144,7 +144,7 @@ func ensureDefaultBranch(in io.Reader, out io.Writer, er io.Writer, repoPath str
// Rename the default branch to the first branch available
_, err = r.HEAD()
if err == git.ErrReferenceNotExist {
err = runGit(in, out, er, repoPath, "branch", "-M", brs[0])
err = RunGit(in, out, er, repoPath, "branch", "-M", brs[0])
if err != nil {
return err
}
Expand Down
22 changes: 13 additions & 9 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import (
"github.com/charmbracelet/soft-serve/server/backend/sqlite"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/cron"
"github.com/charmbracelet/soft-serve/server/daemon"
sshsrv "github.com/charmbracelet/soft-serve/server/ssh"
"github.com/charmbracelet/soft-serve/server/stats"
"github.com/charmbracelet/soft-serve/server/web"
"github.com/charmbracelet/ssh"
"golang.org/x/sync/errgroup"
)
Expand All @@ -23,10 +27,10 @@ var (

// Server is the Soft Serve server.
type Server struct {
SSHServer *SSHServer
GitDaemon *GitDaemon
HTTPServer *HTTPServer
StatsServer *StatsServer
SSHServer *sshsrv.SSHServer
GitDaemon *daemon.GitDaemon
HTTPServer *web.HTTPServer
StatsServer *stats.StatsServer
Cron *cron.CronScheduler
Config *config.Config
Backend backend.Backend
Expand Down Expand Up @@ -81,22 +85,22 @@ func NewServer(ctx context.Context, cfg *config.Config) (*Server, error) {
// Add cron jobs.
srv.Cron.AddFunc(jobSpecs["mirror"], mirrorJob(cfg))

srv.SSHServer, err = NewSSHServer(cfg, srv)
srv.SSHServer, err = sshsrv.NewSSHServer(cfg, srv)
if err != nil {
return nil, err
}

srv.GitDaemon, err = NewGitDaemon(cfg)
srv.GitDaemon, err = daemon.NewGitDaemon(cfg)
if err != nil {
return nil, err
}

srv.HTTPServer, err = NewHTTPServer(cfg)
srv.HTTPServer, err = web.NewHTTPServer(cfg)
if err != nil {
return nil, err
}

srv.StatsServer, err = NewStatsServer(cfg)
srv.StatsServer, err = stats.NewStatsServer(cfg)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -124,7 +128,7 @@ func (s *Server) Start() error {
errg, ctx := errgroup.WithContext(s.ctx)
errg.Go(func() error {
logger.Print("Starting Git daemon", "addr", s.Config.Git.ListenAddr)
if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, ErrServerClosed) {
if err := start(ctx, s.GitDaemon.Start); !errors.Is(err, daemon.ErrServerClosed) {
return err
}
return nil
Expand Down
12 changes: 3 additions & 9 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,34 +3,28 @@ package server
import (
"context"
"fmt"
"net"
"path/filepath"
"strings"
"testing"

"github.com/charmbracelet/keygen"
"github.com/charmbracelet/soft-serve/server/config"
"github.com/charmbracelet/soft-serve/server/test"
"github.com/charmbracelet/ssh"
"github.com/matryer/is"
gossh "golang.org/x/crypto/ssh"
)

func randomPort() int {
addr, _ := net.Listen("tcp", ":0") //nolint:gosec
_ = addr.Close()
return addr.Addr().(*net.TCPAddr).Port
}

func setupServer(tb testing.TB) (*Server, *config.Config, string) {
tb.Helper()
tb.Log("creating keypair")
pub, pkPath := createKeyPair(tb)
dp := tb.TempDir()
sshPort := fmt.Sprintf(":%d", randomPort())
sshPort := fmt.Sprintf(":%d", test.RandomPort())
tb.Setenv("SOFT_SERVE_DATA_PATH", dp)
tb.Setenv("SOFT_SERVE_INITIAL_ADMIN_KEY", authorizedKey(pub))
tb.Setenv("SOFT_SERVE_SSH_LISTEN_ADDR", sshPort)
tb.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", randomPort()))
tb.Setenv("SOFT_SERVE_GIT_LISTEN_ADDR", fmt.Sprintf(":%d", test.RandomPort()))
cfg := config.DefaultConfig()
tb.Log("configuring server")
ctx := context.TODO()
Expand Down
2 changes: 1 addition & 1 deletion server/session.go → server/ssh/session.go
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package server
package ssh

import (
"fmt"
Expand Down

0 comments on commit 2d5089e

Please sign in to comment.