-
Notifications
You must be signed in to change notification settings - Fork 2
/
rest.go
115 lines (96 loc) · 3.31 KB
/
rest.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
package rest
import (
	"fmt"
	"log"
	"net/http"
	"time"

	"github.com/go-chi/chi"
	"github.com/grupawp/tensorflow-deploy/app"
	"github.com/grupawp/tensorflow-deploy/exterr"
	"github.com/grupawp/tensorflow-deploy/lock"
	"github.com/grupawp/tensorflow-deploy/logging"
)
// Package-level error values produced by the rest package.
var (
	// logInvalidChecksumErrorCode is the numeric code attached to
	// errorInvalidChecksum.
	logInvalidChecksumErrorCode = 1001

	// errorInvalidChecksum carries the message "invalid checksum", tagged
	// with the REST component (app.ComponentRest) and with
	// logInvalidChecksumErrorCode.
	errorInvalidChecksum = exterr.
				NewErrorWithMessage("invalid checksum").
				WithComponent(app.ComponentRest).
				WithCode(logInvalidChecksumErrorCode)
)
// REST holds the services and settings shared by the restful API handlers
// that Mount registers on the router.
type REST struct {
	// Backing services for the model and module endpoints.
	modelsService  ModelsService
	modulesService ModulesService

	// Keys under which upload requests provide the archive and its checksum
	// (NewREST sets them to "archive_data" and "archive_hash").
	// NOTE(review): presumably form field names read by the upload
	// handlers — confirm there.
	uploadFileName, uploadFileChecksum string

	// listenPort is used as the http.Server address in Mount; version is
	// reported by the /ping endpoint.
	listenPort, version string

	// lock is created by NewREST. NOTE(review): it is not referenced by the
	// code in this file; presumably the handlers use it to serialize
	// model/module operations — confirm.
	lock *lock.Lock
}
// NewREST wires the model and module services into a ready-to-mount REST
// value. listenPort becomes the HTTP server address used by Mount, and
// version is what the /ping endpoint reports. The upload key names are fixed
// to "archive_data" (archive) and "archive_hash" (checksum), and a fresh
// lock is allocated for the instance.
func NewREST(modelsSrv ModelsService, modulesSrv ModulesService, listenPort, version string) *REST {
	return &REST{
		modelsService:      modelsSrv,
		modulesService:     modulesSrv,
		uploadFileName:     "archive_data",
		uploadFileChecksum: "archive_hash",
		listenPort:         listenPort,
		version:            version,
		lock:               lock.New(),
	}
}
// Mount registers every restful endpoint on a new chi router and then serves
// it on rest.listenPort. It blocks for the lifetime of the HTTP server.
// Previously the error returned by ListenAndServe was discarded, so a failure
// to start (e.g. the address already in use) went unnoticed; it is now logged.
func (rest *REST) Mount() {
	// readHeaderTimeout bounds only how long a client may take to send the
	// request headers (protection against slow-header clients).
	const readHeaderTimeout = 30 * time.Second

	r := chi.NewRouter()

	// logging middlewares
	r.Use(logging.HTTPCtxValuesMiddleware)
	r.Use(logging.HTTPRequestMiddleware())

	// common
	r.Get("/ping", rest.pingHandler)

	// v1: models
	r.Route("/v1/models", func(r chi.Router) {
		r.Get("/list", rest.listModelsHandler)
	})
	r.Route("/v1/models/{team}/{project}", func(r chi.Router) {
		r.Get("/config", rest.configFileHandler)
		r.Get("/list", rest.listModelsByProjectHandler)
		r.Post("/reload", rest.reloadHandler)
	})
	r.Route("/v1/models/{team}/{project}/names/{name}", func(r chi.Router) {
		r.Post("/", rest.uploadModelHandler)
		r.Get("/list", rest.listModelsByNameHandler)
		r.Put("/revert", rest.revertModelHandler)
	})
	r.Route("/v1/models/{team}/{project}/names/{name}/labels/{label}", func(r chi.Router) {
		r.Get("/", rest.downloadModelByLabelHandler)
		r.Delete("/", rest.deleteModelLabelHandler)
		r.Post("/", rest.uploadModelWithLabelHandler)
		r.Delete("/remove_version", rest.deleteModelByLabelHandler)
	})
	r.Route("/v1/models/{team}/{project}/names/{name}/versions/{version}", func(r chi.Router) {
		r.Get("/", rest.downloadModelByVersionHandler)
		r.Delete("/", rest.deleteModelByVersionHandler)
		r.Put("/labels/stable", rest.setModelLabelToStableHandler)
		r.Put("/labels/{label}", rest.setModelLabelHandler)
	})

	// v1: modules
	r.Route("/v1/modules", func(r chi.Router) {
		r.Get("/list", rest.listModulesHandler)
	})
	r.Route("/v1/modules/{team}/{project}", func(r chi.Router) {
		r.Get("/list", rest.listModulesByProjectHandler)
	})
	r.Route("/v1/modules/{team}/{project}/names/{name}", func(r chi.Router) {
		r.Post("/", rest.uploadModuleHandler)
		r.Get("/list", rest.listModulesByNameHandler)
		r.Get("/versions/{version}", rest.downloadModuleByVersionHandler)
		r.Delete("/versions/{version}", rest.deleteModuleHandler)
	})

	server := &http.Server{
		Addr:    rest.listenPort,
		Handler: r,
		// Only header reading is time-bounded. Body read/write timeouts are
		// deliberately left unset: the upload/download handlers above may
		// transfer large archives that legitimately take a long time.
		ReadHeaderTimeout: readHeaderTimeout,
	}

	// ListenAndServe always returns a non-nil error; ErrServerClosed is the
	// expected result of a graceful Shutdown and is not worth reporting.
	// The error is logged rather than escalated so control flow for callers
	// (Mount returning) is unchanged.
	if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
		log.Printf("rest: http server on %q stopped: %v", rest.listenPort, err)
	}
}
// pingHandler answers GET /ping with HTTP 200 and a JSON body of the form
// {"name": "tensorflow-deploy:<version>", "version": "<version>"}, built from
// the version string the REST value was created with.
func (rest *REST) pingHandler(w http.ResponseWriter, r *http.Request) {
	type pingResponse struct {
		Name    string `json:"name"`
		Version string `json:"version"`
	}

	payload := pingResponse{
		Name:    fmt.Sprintf("tensorflow-deploy:%s", rest.version),
		Version: rest.version,
	}
	writeJSONSuccessResponse(w, r, http.StatusOK, payload)
}