diff --git a/cmd/gogit/daemon.go b/cmd/gogit/daemon.go new file mode 100644 index 0000000..c42317d --- /dev/null +++ b/cmd/gogit/daemon.go @@ -0,0 +1,102 @@ +package main + +import ( + "log" + "net" + "path/filepath" + "strconv" + + gitserver "github.com/go-git/cli/server/git" + "github.com/go-git/go-billy/v6" + "github.com/go-git/go-billy/v6/osfs" + gitbackend "github.com/go-git/go-git/v6/backend/git" + "github.com/go-git/go-git/v6/plumbing/transport" + "github.com/go-git/go-git/v6/storage" + "github.com/spf13/cobra" +) + +var ( + daemonExportAll bool + daemonPort int + daemonListen string +) + +func init() { + daemonCmd.Flags().BoolVarP(&daemonExportAll, "export-all", "", false, "Export all repositories") + daemonCmd.Flags().IntVarP(&daemonPort, "port", "", 9418, "Port to run the Git daemon on") + daemonCmd.Flags().StringVarP(&daemonListen, "listen", "", "", "Address to listen on (default: all interfaces)") + + rootCmd.AddCommand(daemonCmd) +} + +var daemonCmd = &cobra.Command{ + Use: "daemon [] [...]", + Short: "Start a Git daemon server", + RunE: func(cmd *cobra.Command, args []string) error { + var dirs []string + if len(args) == 0 { + dirs = append(dirs, ".") + } + + loader := NewDirsLoader(dirs, false, daemonExportAll) + addr := net.JoinHostPort(daemonListen, strconv.Itoa(daemonPort)) + be := gitbackend.NewBackend(loader) + srv := &gitserver.Server{ + Addr: addr, + Handler: gitserver.LoggingMiddleware(log.Default(), be), + ErrorLog: log.Default(), + } + + log.Printf("Starting Git daemon on %q", addr) + return srv.ListenAndServe() + }, +} + +type dirsLoader struct { + loaders []transport.Loader + fss []billy.Filesystem + exportAll bool +} + +var _ transport.Loader = (*dirsLoader)(nil) + +// NewDirsLoader creates a new dirsLoader with the given directories. +func NewDirsLoader(dirs []string, strict, exportAll bool) *dirsLoader { + var loaders []transport.Loader + var fss []billy.Filesystem + for _, dir := range dirs { + abs, err := filepath.Abs(dir) + if err != nil { + continue + } + fs := osfs.New(abs, osfs.WithBoundOS()) + fss = append(fss, fs) + loaders = append(loaders, transport.NewFilesystemLoader(fs, strict)) + } + return &dirsLoader{loaders: loaders, fss: fss, exportAll: exportAll} +} + +// Load implements transport.Loader. +func (d *dirsLoader) Load(ep *transport.Endpoint) (storage.Storer, error) { + for i, loader := range d.loaders { + storer, err := loader.Load(ep) + if err == nil { + if !d.exportAll { + // We need to check if git-daemon-export-ok + // file exists and if it does not, we skip this + // repository. + dfs := d.fss[i] + okFile := filepath.Join(ep.Path, "git-daemon-export-ok") + stat, err := dfs.Lstat(okFile) + if err != nil || (stat != nil && !stat.Mode().IsRegular()) { + // If the file does not exist or is a directory, + // we skip this repository. + continue + } + + } + return storer, nil + } + } + return nil, transport.ErrRepositoryNotFound +} diff --git a/cmd/gogit/main.go b/cmd/gogit/main.go index ad4c30a..c3523e9 100644 --- a/cmd/gogit/main.go +++ b/cmd/gogit/main.go @@ -1,6 +1,7 @@ package main import ( + "errors" "fmt" "os" "strconv" @@ -43,7 +44,12 @@ func init() { func main() { if err := rootCmd.Execute(); err != nil { - fmt.Fprintln(os.Stderr, err) + var rerr *transport.RemoteError + if errors.As(err, &rerr) { + fmt.Fprintln(os.Stderr, rerr) + } else { + fmt.Fprintln(os.Stderr, err) + } os.Exit(1) } } diff --git a/server/git/logging.go b/server/git/logging.go new file mode 100644 index 0000000..4ee5ce0 --- /dev/null +++ b/server/git/logging.go @@ -0,0 +1,24 @@ +package git + +import ( + "context" + "io" + "time" + + "github.com/go-git/go-git/v6/plumbing/protocol/packp" +) + +type Logger interface { + Printf(format string, v ...interface{}) +} + +func LoggingMiddleware(logger Logger, next Handler) HandlerFunc { + return func(ctx context.Context, c io.ReadWriteCloser, r *packp.GitProtoRequest) { + now := time.Now() + next.ServeTCP(ctx, c, r) + elapsedTime := time.Since(now) + if logger != nil { + logger.Printf("%s %s %s %v %v", r.Host, r.RequestCommand, r.Pathname, r.ExtraParams, elapsedTime) + } + } +} diff --git a/server/git/server.go b/server/git/server.go new file mode 100644 index 0000000..8da794c --- /dev/null +++ b/server/git/server.go @@ -0,0 +1,381 @@ +package git + +import ( + "context" + "errors" + "fmt" + "io" + "log" + "math/rand" + "net" + "sync" + "sync/atomic" + "time" + + gitbackend "github.com/go-git/go-git/v6/backend/git" + "github.com/go-git/go-git/v6/plumbing/format/pktline" + "github.com/go-git/go-git/v6/plumbing/protocol/packp" + "github.com/go-git/go-git/v6/plumbing/transport" + "github.com/go-git/go-git/v6/utils/ioutil" +) + +// DefaultAddr is the default address to listen on for Git protocol server. +const DefaultAddr = ":9418" + +// ErrServerClosed indicates that the server has been closed. +var ErrServerClosed = errors.New("server closed") + +// DefaultBackend is the default global Git transport server handler. +var DefaultBackend = gitbackend.NewBackend(nil) + +// ServerContextKey is the context key used to store the server in the context. +var ServerContextKey = &contextKey{"git-server"} + +// Handler is the interface that handles TCP requests for the Git protocol. +type Handler interface { + // ServeTCP handles a TCP connection for the Git protocol. + ServeTCP(ctx context.Context, c io.ReadWriteCloser, req *packp.GitProtoRequest) +} + +// HandlerFunc is a function that implements the Handler interface. +type HandlerFunc func(ctx context.Context, c io.ReadWriteCloser, req *packp.GitProtoRequest) + +// ServeTCP implements the Handler interface. +func (f HandlerFunc) ServeTCP(ctx context.Context, c io.ReadWriteCloser, req *packp.GitProtoRequest) { + f(ctx, c, req) +} + +// Server is a TCP server that handles Git protocol requests. +type Server struct { + // Addr is the address to listen on. If empty, it defaults to ":9418". + Addr string + + // Handler is the handler for Git protocol requests. It uses + // [DefaultHandler] when nil. + Handler Handler + + // ErrorLog is the logger used to log errors. When nil, it won't log + // errors. + ErrorLog *log.Logger + + // BaseContext optionally specifies a function to create a base context for + // the server listeners. If nil, [context.Background] will be used. + // The provided listener is the specific listener that is about to start + // accepting connections. + BaseContext func(net.Listener) context.Context + + // ConnContext optionally specifies a function to create a context for each + // connection. If nil, the context will be derived from the server's base + // context. + ConnContext func(context.Context, net.Conn) context.Context + + inShutdown atomic.Bool // true when server is in shutdown + mu sync.Mutex + listeners map[*net.Listener]struct{} + listenerGroup sync.WaitGroup + activeConn map[*conn]struct{} // active connections being served +} + +// shutdownPollIntervalMax is the maximum interval for polling +// idle connections during shutdown. +const shutdownPollIntervalMax = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server, waiting for all active +// connections to finish. +func (s *Server) Shutdown(ctx context.Context) error { + s.inShutdown.Store(true) + + s.mu.Lock() + lnerr := s.closeListenersLocked() + s.mu.Unlock() + s.listenerGroup.Wait() + + pollIntervalBase := time.Millisecond + nextPollInterval := func() time.Duration { + // Add 10% jitter. + interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) + // Double and clamp for next time. + pollIntervalBase *= 2 + if pollIntervalBase > shutdownPollIntervalMax { + pollIntervalBase = shutdownPollIntervalMax + } + return interval + } + + timer := time.NewTimer(nextPollInterval()) + for { + if s.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + timer.Reset(nextPollInterval()) + } + } +} + +// Close immediately closes the server and all active connections. It returns +// any error returned from closing the underlying listeners. +func (s *Server) Close() error { + s.inShutdown.Store(true) + + s.mu.Lock() + defer s.mu.Unlock() + err := s.closeListenersLocked() + + // We need to unlock the mutex while waiting for listenersGroup. + s.mu.Unlock() + s.listenerGroup.Wait() + s.mu.Lock() + + for c := range s.activeConn { + c.Close() //nolint:errcheck + delete(s.activeConn, c) + } + return err +} + +// ListenAndServe listens on the TCP network address and serves Git +// protocol requests using the provided handler. +func (s *Server) ListenAndServe() error { + if s.shuttingDown() { + return ErrServerClosed + } + addr := s.Addr + if addr == "" { + addr = DefaultAddr // Default Git protocol port + } + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + return s.Serve(ln) +} + +// Serve starts the server and listens for incoming connections on the given +// listener. +func (s *Server) Serve(ln net.Listener) error { + origLn := ln + l := &onceCloseListener{Listener: ln} + defer l.Close() //nolint:errcheck + + if !s.trackListener(&l.Listener, true) { + return ErrServerClosed + } + defer s.trackListener(&l.Listener, false) + + baseCtx := context.Background() + if s.BaseContext != nil { + baseCtx = s.BaseContext(origLn) + if baseCtx == nil { + panic("git: BaseContext returned nil context") + } + } + + var tempDelay time.Duration // how long to sleep on accept failure + ctx := context.WithValue(baseCtx, ServerContextKey, s) + for { + rw, err := l.Accept() + if err != nil { + if s.shuttingDown() { + return ErrServerClosed + } + if ne, ok := err.(net.Error); ok && ne.Temporary() { + if tempDelay == 0 { + tempDelay = 5 * time.Millisecond + } else { + tempDelay *= 2 + } + if max := 1 * time.Second; tempDelay > max { + tempDelay = max + } + s.logf("git: Accept error: %v; retrying in %v", err, tempDelay) + time.Sleep(tempDelay) + continue + } + return err + } + connCtx := ctx + if cc := s.ConnContext; cc != nil { + connCtx = cc(ctx, rw) + if connCtx == nil { + panic("git: ConnContext returned nil context") + } + } + tempDelay = 0 + c := s.newConn(rw) + s.trackConn(c, true) + go c.serve(connCtx) //nolint:errcheck + } +} + +func (s *Server) shuttingDown() bool { + return s.inShutdown.Load() +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := (*ln).Close(); cerr != nil && err == nil { + err = cerr + } + } + return err +} + +// handler delegates to either the server's Handler or the DefaultBackend. +func (s *Server) handler(ctx context.Context, c net.Conn, req *packp.GitProtoRequest) { + if s.Handler != nil { + s.Handler.ServeTCP(ctx, c, req) + } else { + DefaultBackend.ServeTCP(ctx, c, req) + } +} + +// trackListener adds or removes a net.Listener to the set of tracked +// listeners. +// +// We store a pointer to interface in the map set, in case the +// net.Listener is not comparable. This is safe because we only call +// trackListener via Serve and can track+defer untrack the same +// pointer to local variable there. We never need to compare a +// Listener from another caller. +// +// It reports whether the server is still up (not Shutdown or Closed). +func (s *Server) trackListener(ln *net.Listener, add bool) bool { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[*net.Listener]struct{}) + } + if add { + if s.shuttingDown() { + return false + } + s.listeners[ln] = struct{}{} + s.listenerGroup.Add(1) + } else { + delete(s.listeners, ln) + s.listenerGroup.Done() + } + return true +} + +// closeIdleConns closes all idle connections. It returns true only if no new +// connection was found. +func (s *Server) closeIdleConns() bool { + idle := true + for c := range s.activeConn { + unixSec := c.unixSec.Load() + if unixSec == 0 { + // New connection, skip it. + idle = false + continue + } + c.Close() //nolint:errcheck + delete(s.activeConn, c) + } + return idle +} + +func (s *Server) logf(format string, args ...interface{}) { + if s.ErrorLog != nil { + s.ErrorLog.Printf(format, args...) + } +} + +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + c.unixSec.Store(uint64(time.Now().Unix())) + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + +// conn represents a server connection that is being handled. +type conn struct { + // Conn is the underlying net.Conn that is being used to read and write Git + // protocol messages. + net.Conn + // unix timestamp in seconds when the connection was established + unixSec atomic.Uint64 + // s the server that is handling this connection. + s *Server +} + +// newConn creates a new conn instance with the given net.Conn. +func (s *Server) newConn(rwc net.Conn) *conn { + return &conn{ + s: s, + Conn: rwc, + } +} + +// logf logs a message using the server's ErrorLog, if set. +func (c *conn) logf(format string, args ...interface{}) { + if c.s.ErrorLog != nil { + c.s.logf(format, args...) + } +} + +// serve serves a new connection. +func (c *conn) serve(ctx context.Context) { + defer func() { + if err := recover(); err != nil { + c.s.logf("git: panic serving connection: %v", err) + if cerr := c.Conn.Close(); cerr != nil { + c.s.logf("git: error closing connection: %v", cerr) + } + } + }() + + r := ioutil.NewContextReadCloser(ctx, c) + + var req packp.GitProtoRequest + if err := req.Decode(r); err != nil { + c.s.logf("git: error decoding request: %v", err) + if rErr := renderError(c, fmt.Errorf("error decoding request: %s", transport.ErrInvalidRequest)); rErr != nil { + c.s.logf("git: error writing error response: %v", rErr) + } + return + } + + c.s.handler(ctx, c.Conn, &req) +} + +// onceCloseListener wraps a net.Listener, protecting it from +// multiple Close calls. +type onceCloseListener struct { + net.Listener + once sync.Once + closeErr error +} + +func (oc *onceCloseListener) Close() error { + oc.once.Do(oc.close) + return oc.closeErr +} + +func (oc *onceCloseListener) close() { oc.closeErr = oc.Listener.Close() } + +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func renderError(rw io.WriteCloser, err error) error { + if _, err := pktline.WriteError(rw, err); err != nil { + rw.Close() //nolint:errcheck + return err + } + return rw.Close() +}