/
cachemanager.go
156 lines (141 loc) · 4.66 KB
/
cachemanager.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package cachemanager
import (
"errors"
"net/http"
"net/url"
"os"
"strconv"
"sync"
"time"
"github.com/mKaloer/tfservingcache/pkg/tfservingproxy"
log "github.com/sirupsen/logrus"
)
type (
	// Model is an entry in the local model cache.
	Model struct {
		// Identifier names the model and version this entry holds.
		Identifier ModelIdentifier
		// Path locates the model's files. The on-disk location is resolved
		// through LocalCache.ModelPath, so NOTE(review): confirm whether
		// this is absolute or relative to the cache base directory.
		Path string
		// SizeOnDisk is the model's storage footprint (presumably in bytes —
		// TODO confirm).
		SizeOnDisk int64
	}

	// ModelIdentifier is the key used to look models up in, and store them
	// into, the local cache: a model name plus a numeric version.
	ModelIdentifier struct {
		ModelName string
		Version   int64
	}
)
// CacheManager makes sure a requested model is present in the local cache and
// loaded by TF Serving before a REST request for it is proxied to TF Serving.
type CacheManager struct {
	// RestProxy forwards REST requests; New installs a director on it that
	// fetches the requested model before rewriting the request URL.
	RestProxy *tfservingproxy.RestProxy
	// localRestUrl is the parsed TF Serving REST endpoint given to New.
	localRestUrl url.URL
	// ModelProvider supplies model sizes and loads models into the cache dir.
	ModelProvider ModelProvider
	// LocalCache tracks which models are cached locally.
	LocalCache ModelCache
	// TFServingServerModelBasePath is passed to ServingController.ReloadConfig
	// together with the list of cached models.
	TFServingServerModelBasePath string
	// ServingController reloads TF Serving's config and queries model status.
	ServingController TFServingController
	ModelFetchTimeout float32 // model fetch timeout in seconds
	// rwMux guards cache lookups (read lock) and model fetching (write lock).
	rwMux sync.RWMutex
}
// ServeRest returns an HTTP handler function that proxies TF Serving REST
// requests through the cache manager's RestProxy. The receiver is named
// `cache` for consistency with the type's other methods.
func (cache *CacheManager) ServeRest() func(http.ResponseWriter, *http.Request) {
	return cache.RestProxy.Serve()
}
// fetchModel makes sure the model identified by identifier is in the local
// cache and reported AVAILABLE by TF Serving.
//
// If the model is already cached (and its files exist on disk) it returns nil
// immediately. Otherwise, while holding the write lock, it:
//   - asks the ModelProvider for the model size and calls
//     LocalCache.EnsureFreeBytes with it,
//   - loads the model into LocalCache.BaseDir() and Puts it in the cache,
//   - reloads TF Serving's config with every cached model, and
//   - polls the model status every 500 ms until it is AVAILABLE or
//     ModelFetchTimeout seconds of polling have elapsed.
//
// It returns the first error from the provider or the serving controller, or
// a timeout error if the model never became available. At least one status
// check is always made, even when ModelFetchTimeout <= 0.
//
// NOTE(review): the write lock is held for the whole load-and-wait, so every
// other request (cache hits included) blocks while one model is fetched.
// NOTE(review): any result of EnsureFreeBytes is not inspected — confirm
// whether it can fail.
func (cache *CacheManager) fetchModel(identifier ModelIdentifier) error {
	if _, isPresent := cache.tryGetModelFromCache(identifier); isPresent {
		return nil
	}

	cache.rwMux.Lock()
	defer cache.rwMux.Unlock()

	// Re-check under the write lock: a concurrent request for the same model
	// may have fetched it between the read-locked lookup above and acquiring
	// this lock. tryGetModelFromCache cannot be used here because it takes the
	// read lock, and sync.RWMutex is not reentrant.
	if cached, isPresent := cache.LocalCache.Get(identifier); isPresent &&
		fileOrDirExists(cache.LocalCache.ModelPath(cached)) {
		return nil
	}

	// Model does not exist - get size, then put in cache
	modelSize, err := cache.ModelProvider.ModelSize(identifier.ModelName, identifier.Version)
	if err != nil {
		log.WithError(err).Error("Error while retrieving model size")
		return err
	}
	cache.LocalCache.EnsureFreeBytes(modelSize)

	model, err := cache.ModelProvider.LoadModel(identifier.ModelName, identifier.Version, cache.LocalCache.BaseDir())
	if err != nil {
		log.WithError(err).Error("Error while retrieving model")
		return err
	}
	cache.LocalCache.Put(identifier, model)

	if err := cache.ServingController.ReloadConfig(cache.LocalCache.ListModels(), cache.TFServingServerModelBasePath); err != nil {
		log.WithError(err).Error("Error while loading model")
		return err
	}

	// Poll until AVAILABLE. `waited` counts time spent sleeping between polls
	// (not wall-clock time, matching the original accounting). The model is
	// checked at waited = 0, 0.5s, 1s, ... while waited < timeout, and we stop
	// right after the last allowed check instead of sleeping once more.
	const pollInterval = 500 * time.Millisecond
	timeout := time.Duration(float64(cache.ModelFetchTimeout) * float64(time.Second))
	for waited := time.Duration(0); ; waited += pollInterval {
		status, err := cache.ServingController.GetModelStatus(model)
		switch {
		case err != nil:
			log.WithError(err).Errorf("Error getting model status. Duration: %fs", waited.Seconds())
		case status == ModelVersionStatus_AVAILABLE:
			log.Info("Model available")
			return nil
		default:
			log.Debugf("Model not yet available: %s. Duration: %fs", status.String(), waited.Seconds())
		}
		if waited+pollInterval >= timeout {
			return errors.New("Timeout: Model did not load in time")
		}
		time.Sleep(pollInterval)
	}
}
// tryGetModelFromCache looks identifier up in the local cache under the read
// lock. It reports true only when the cache has an entry AND the path returned
// by LocalCache.ModelPath for it exists on disk. An entry whose files are
// missing is logged as a warning and reported as absent.
//
// The returned Model is whatever LocalCache.Get returned, even when the bool
// is false.
func (cache *CacheManager) tryGetModelFromCache(identifier ModelIdentifier) (Model, bool) {
	cache.rwMux.RLock()
	defer cache.rwMux.RUnlock()

	model, isPresent := cache.LocalCache.Get(identifier)
	if !isPresent {
		// Don't resolve a path for the placeholder value Get returns on a
		// miss (presumably a zero-value Model).
		return model, false
	}

	hostModelPath := cache.LocalCache.ModelPath(model)
	if !fileOrDirExists(hostModelPath) {
		log.Warnf("Model in cache but not present on disk. Name: %s, Version: %d, path: %s",
			identifier.ModelName, identifier.Version, hostModelPath)
		return model, false
	}
	return model, true
}
// New builds a CacheManager whose RestProxy serves TF Serving REST requests.
// For every proxied request, the director:
//  1. parses the requested version as a base-10 int64,
//  2. calls fetchModel so the model is cached and available in TF Serving,
//  3. rewrites the request URL to tfservingServerRESTHost, keeping the
//     incoming path and query string.
//
// Parameters:
//   - modelProvider: source of model sizes and model files.
//   - modelCache: local cache the models are stored in.
//   - tfServingServerBasePath: base path passed to TF Serving config reloads.
//   - tfservingServerGRPCHost, tfservingServerRESTHost: TF Serving endpoints
//     given to the serving controller; the REST host is also the proxy target.
//   - modelFetchTimeout: seconds to wait for a model to become available.
//
// New returns nil (after logging) if tfservingServerRESTHost cannot be parsed.
func New(
	modelProvider ModelProvider,
	modelCache ModelCache,
	tfServingServerBasePath string,
	tfservingServerGRPCHost string,
	tfservingServerRESTHost string,
	modelFetchTimeout float32,
) *CacheManager {
	restUrl, err := url.Parse(tfservingServerRESTHost)
	if err != nil {
		log.WithError(err).Errorf("Invalid TF Serving REST host: %q", tfservingServerRESTHost)
		return nil
	}

	servingController := TFServingController{grpcHost: tfservingServerGRPCHost, restHost: tfservingServerRESTHost}
	h := &CacheManager{
		localRestUrl:                 *restUrl,
		ModelProvider:                modelProvider,
		LocalCache:                   modelCache,
		ServingController:            servingController,
		TFServingServerModelBasePath: tfServingServerBasePath,
		ModelFetchTimeout:            modelFetchTimeout,
	}

	director := func(req *http.Request, modelName string, version string) {
		log.Info("Fetching model...")
		modelVersion, err := strconv.ParseInt(version, 10, 64)
		if err != nil {
			log.WithError(err).Errorf("Error handling request. Version must be valid integer: '%s'", version)
			markProxyRequestFailed(req, http.StatusInternalServerError)
			return
		}

		identifier := ModelIdentifier{ModelName: modelName, Version: modelVersion}
		if err := h.fetchModel(identifier); err != nil {
			log.WithError(err).Errorf("Error handling request. Aborting: %s", req.URL.String())
			markProxyRequestFailed(req, http.StatusInternalServerError)
			return
		}

		// Point the request at the local TF Serving REST endpoint, keeping the
		// incoming path (both decoded and raw forms) and query string.
		localUrl := *restUrl
		localUrl.Path = req.URL.Path
		localUrl.RawPath = req.URL.RawPath
		localUrl.RawQuery = req.URL.RawQuery
		log.Infof("Forwarding to %s", localUrl.String())
		req.URL = &localUrl

		if _, ok := req.Header["User-Agent"]; !ok {
			// explicitly disable User-Agent so it's not set to default value
			req.Header.Set("User-Agent", "")
		}
	}
	h.RestProxy = tfservingproxy.NewRestProxy(director)
	return h
}

// markProxyRequestFailed records statusCode on req.Response, creating the
// Response when it is nil. net/http only populates Request.Response for client
// redirects, so for server-side (proxied) requests it is nil, and assigning to
// req.Response.StatusCode directly would panic.
// NOTE(review): whether this status reaches the client depends on how
// tfservingproxy consumes req.Response — TODO confirm.
func markProxyRequestFailed(req *http.Request, statusCode int) {
	if req.Response == nil {
		req.Response = &http.Response{StatusCode: statusCode}
		return
	}
	req.Response.StatusCode = statusCode
}
// fileOrDirExists reports whether filename should be treated as present on
// disk. Only an os.Stat error satisfying os.IsNotExist counts as absent; any
// other Stat error (e.g. permission denied) is treated as present.
func fileOrDirExists(filename string) bool {
	if _, statErr := os.Stat(filename); os.IsNotExist(statErr) {
		return false
	}
	return true
}