/
server.go
116 lines (95 loc) · 2.19 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package server
import (
"context"
"net"
"os"
"path/filepath"
"strconv"
"github.com/michaeladler/shellsnoop/internal/log"
"github.com/michaeladler/shellsnoop/internal/storage"
"golang.org/x/sync/errgroup"
)
func StartServer(ctx context.Context, socketFile string, uid *int) error {
logger := log.Logger
dir := filepath.Dir(socketFile)
logger.Debug("Creating directory", "dir", dir)
if err := os.Mkdir(dir, os.FileMode(0755)); err != nil && !os.IsExist(err) {
logger.Warn("Failed to create directory", "err", err)
}
listener, err := net.Listen("unix", socketFile)
if err != nil {
return err
}
if uid != nil {
logger.Debug("Changing owner", uid, *uid)
_ = os.Chown(socketFile, *uid, 0)
_ = os.Chmod(socketFile, 0600)
}
defer listener.Close()
logger.Info("Listening on Unix domain socket", "socketFile", socketFile)
g, ctx := errgroup.WithContext(ctx)
running := true
g.Go(func() error {
// Accept and handle incoming connections.
for running {
conn, err := listener.Accept()
if err != nil {
if err != net.ErrClosed {
logger.Error("Failed to accept connection", "err", err)
}
continue
}
defer conn.Close()
// Handle the connection in a new goroutine.
g.Go(func() error {
err := handleConnection(conn)
if err != nil {
logger.Error("Something went wrong with the client", "err", err)
}
return nil
})
}
return nil
})
doneChan := make(chan any)
go func() {
_ = g.Wait()
doneChan <- nil
}()
// wait for quit event
select {
case <-doneChan:
running = false
case <-ctx.Done():
running = false
}
listener.Close()
_ = g.Wait()
return nil
}
func handleConnection(conn net.Conn) error {
// Read incoming data.
buffer := make([]byte, 4096)
n, err := conn.Read(buffer)
if err != nil {
return err
}
data := string(buffer[:n])
pid, err := strconv.Atoi(data)
if err != nil {
return err
}
logger := log.Logger
cmdline := storage.Get(pid)
if cmdline == "" {
logger.Warn("Last command not available", "pid", pid)
_, _ = conn.Write([]byte("\x00"))
return nil
}
logger.Info("Providing last command to client", "pid", pid)
_, err = conn.Write([]byte(cmdline))
if err != nil {
return err
}
return nil
}