-
Notifications
You must be signed in to change notification settings - Fork 3
/
daemon.go
134 lines (111 loc) · 3.34 KB
/
daemon.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
package cmd
// This file contains all the daemon-related commands when starting `cedana daemon ...`
import (
"context"
"fmt"
"io"
"net/http"
"os"
"github.com/cedana/cedana/api"
"github.com/cedana/cedana/utils"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/spf13/viper"
)
// Names and install locations of the GPU artifacts that the start command
// fetches via pullGPUBinary when the "gpu_enabled" setting is true. Each
// *Name is appended to the download URL; each *Path is where the downloaded
// file is written on disk.
const (
// gpuControllerBinaryName is the remote name of the GPU controller binary.
gpuControllerBinaryName = "gpucontroller"
// gpuControllerBinaryPath is the local destination of the GPU controller binary.
gpuControllerBinaryPath = "/usr/local/bin/gpu-controller"
// gpuSharedLibName is the remote name of the GPU shared library.
gpuSharedLibName = "libcedana"
// gpuSharedLibPath is the local destination of the GPU shared library.
gpuSharedLibPath = "/usr/local/lib/libcedana-gpu.so"
)
// daemonCmd is the "daemon" command. It defines no Run function of its own;
// init attaches it to rootCmd and registers startDaemonCmd beneath it.
var daemonCmd = &cobra.Command{
Use: "daemon",
Short: "Start daemon for cedana client. Must be run as root, needed for all other cedana functionality.",
}
// startDaemonCmd is the "start" subcommand of daemonCmd.
//
// Its Run function:
//  1. reads the *zerolog.Logger stored in the command context under "logger",
//  2. refuses to continue unless the process runs as root (uid 0),
//  3. initializes OpenTelemetry; a failure here is logged as a warning and is
//     not fatal,
//  4. starts the profiler in a goroutine when "profiling_enabled" is set,
//  5. pulls the GPU controller and shared library when "gpu_enabled" is set,
//     aborting if either download fails,
//  6. calls api.StartServer and logs any error it returns.
//
// Failures are reported through the logger only; Run has no error return.
var startDaemonCmd = &cobra.Command{
	Use:   "start",
	Short: "Starts the rpc server. To run as a daemon, use the provided script (systemd) or use systemd/sysv/upstart.",
	Run: func(cmd *cobra.Command, args []string) {
		ctx := cmd.Context()
		logger := ctx.Value("logger").(*zerolog.Logger)

		if os.Getuid() != 0 {
			logger.Error().Msg("daemon must be run as root")
			return
		}

		// NOTE(review): cmd.Parent() is the "daemon" command, which sets no
		// Version in this file, while the startup log below uses
		// rootCmd.Version — confirm which one otel is meant to receive.
		stopOtel, err := utils.InitOtel(ctx, cmd.Parent().Version)
		if err != nil {
			logger.Warn().Err(err).Msg("Failed to initialize otel")
		}
		// The daemon keeps going when otel fails to initialize, so only defer
		// the shutdown hook if one was returned: deferring a nil function
		// value panics when Run returns.
		if stopOtel != nil {
			defer stopOtel(ctx)
		}

		if viper.GetBool("profiling_enabled") {
			go startProfiler()
		}

		if viper.GetBool("gpu_enabled") {
			if err := pullGPUBinary(ctx, gpuControllerBinaryName, gpuControllerBinaryPath); err != nil {
				logger.Error().Err(err).Msg("could not pull gpu controller")
				return
			}
			if err := pullGPUBinary(ctx, gpuSharedLibName, gpuSharedLibPath); err != nil {
				logger.Error().Err(err).Msg("could not pull libcedana")
				return
			}
		}

		logger.Info().Msgf("starting daemon version %s", rootCmd.Version)

		if err := api.StartServer(ctx); err != nil {
			logger.Error().Err(err).Msg("stopping daemon")
		}
	},
}
// startProfiler calls utils.StartPprofServer.
// Used for debugging and profiling only! The start command launches it in its
// own goroutine when the "profiling_enabled" setting is true.
func startProfiler() {
utils.StartPprofServer()
}
// init wires up the command tree: daemonCmd is attached to rootCmd, and
// startDaemonCmd is attached beneath daemonCmd.
func init() {
rootCmd.AddCommand(daemonCmd)
daemonCmd.AddCommand(startDaemonCmd)
}
// pullGPUBinary downloads the GPU artifact named binary and installs it at
// filePath with mode 0755.
//
// If os.Stat succeeds on filePath the download is skipped. Otherwise the
// artifact is fetched with a GET request to
// https://<connection.cedana_url>/checkpoint/gpu/<binary>, sending the
// "connection.cedana_auth_token" setting as a bearer token. The request is
// bound to ctx, which must also carry a *zerolog.Logger under the "logger"
// key.
//
// The payload is streamed to a temporary sibling file (filePath + ".tmp") and
// renamed into place only after it has been fully written and closed, so an
// interrupted download never leaves a partial file at filePath that a later
// call would mistake for a finished install.
//
// It returns nil when the file already exists or was installed, and a wrapped
// error otherwise — including when the server answers with any status other
// than 200 OK. Errors are returned rather than logged here; the caller is
// expected to log them.
func pullGPUBinary(ctx context.Context, binary string, filePath string) error {
	logger := ctx.Value("logger").(*zerolog.Logger)

	if _, err := os.Stat(filePath); err == nil {
		// file exists, do nothing.
		// TODO NR - check version of binary
		logger.Debug().Msgf("binary exists at %s, doing nothing", filePath)
		return nil
	}

	url := "https://" + viper.GetString("connection.cedana_url") + "/checkpoint/gpu/" + binary
	logger.Debug().Msgf("pulling %s from %s", binary, url)

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fmt.Errorf("creating request for %s: %w", binary, err)
	}
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", viper.GetString("connection.cedana_auth_token")))

	httpClient := &http.Client{}
	resp, err := httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("requesting %s: %w", url, err)
	}
	defer resp.Body.Close()

	// Do only reports transport-level failures; a non-2xx answer arrives with
	// a nil error and has to be rejected explicitly.
	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("requesting %s: unexpected status %s", url, resp.Status)
	}

	tmpPath := filePath + ".tmp"
	file, err := os.OpenFile(tmpPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0755)
	if err != nil {
		return fmt.Errorf("creating %s: %w", tmpPath, err)
	}

	installed := false
	defer func() {
		if !installed {
			// Best-effort cleanup of the partial download; the error that
			// caused the failure is the one worth reporting.
			_ = os.Remove(tmpPath)
		}
	}()

	_, err = io.Copy(file, resp.Body)
	// Close before renaming, and surface its error: a failed close can mean
	// the data never reached the disk.
	if closeErr := file.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		return fmt.Errorf("writing %s: %w", tmpPath, err)
	}

	// The mode passed to OpenFile is filtered by the umask, so set the
	// executable bits explicitly.
	if err := os.Chmod(tmpPath, 0755); err != nil {
		return fmt.Errorf("setting permissions on %s: %w", tmpPath, err)
	}

	if err := os.Rename(tmpPath, filePath); err != nil {
		return fmt.Errorf("installing %s: %w", filePath, err)
	}
	installed = true

	logger.Debug().Msgf("%s downloaded", binary)
	return nil
}