Skip to content

Commit

Permalink
Initial commit for OCI dist spec v1.1.0 agent support
Browse files Browse the repository at this point in the history
Partially addresses kubeflow/community#682

Signed-off-by: Ramkumar Chinchani <rchincha@cisco.com>
  • Loading branch information
rchincha committed Mar 22, 2024
1 parent 11e0ab2 commit 0dfe129
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 15 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ deploy-ci: manifests
deploy-helm: manifests
helm install kserve-crd charts/kserve-crd/ --wait --timeout 180s
helm install kserve charts/kserve-resources/ --wait --timeout 180s
# deploy a OCI dist spec v1.1.0 registry
helm repo add project-zot http://zotregistry.dev/helm-charts
helm install --set service.port=5000 zot project-zot/zot

undeploy:
kubectl delete -k config/default
Expand Down
107 changes: 107 additions & 0 deletions pkg/agent/storage/oci.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
Copyright 2021 The KServe Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package storage

import (
"fmt"
"io"
"net/http"
"net/url"
"path/filepath"
"strings"
)

type OCIProvider struct {
Client *http.Client
}

func (m *OCIProvider) DownloadModel(modelDir string, modelName string, storageUri string) error {
log.Info("Download model ", "modelName", modelName, "storageUri", storageUri, "modelDir", modelDir)
uri, err := url.Parse(storageUri)
if err != nil {
return fmt.Errorf("unable to parse storage uri: %w", err)
}
OCIDownloader := &OCIDownloader{
StorageUri: storageUri,
ModelDir: modelDir,
ModelName: modelName,
Uri: uri,
}
if err := OCIDownloader.Download(*m.Client); err != nil {
return err
}
return nil
}

type OCIDownloader struct {
StorageUri string
ModelDir string
ModelName string
Uri *url.URL
}

func (h *OCIDownloader) Download(client http.Client) error {
// Create request
req, err := http.NewRequest("GET", h.StorageUri, nil)
if err != nil {
return err
}

// Query request
resp, err := client.Do(req)
if err != nil {
return fmt.Errorf("failed to make a request: %w", err)
}

defer func(Body io.ReadCloser) {
closeErr := Body.Close()
if closeErr != nil {
log.Error(closeErr, "failed to close body")
}
}(resp.Body)
if resp.StatusCode != 200 {
return fmt.Errorf("URI: %s returned a %d response code", h.StorageUri, resp.StatusCode)
}

// Write content into file(s)
contentType := resp.Header.Get("Content-type")
fileDirectory := filepath.Join(h.ModelDir, h.ModelName)

if strings.Contains(contentType, "application/zip") {
if err := extractZipFiles(resp.Body, fileDirectory); err != nil {
return err
}
} else if strings.Contains(contentType, "application/x-tar") || strings.Contains(contentType, "application/x-gtar") ||
strings.Contains(contentType, "application/x-gzip") || strings.Contains(contentType, "application/gzip") {
if err := extractTarFiles(resp.Body, fileDirectory); err != nil {
return err
}
} else {
paths := strings.Split(h.Uri.Path, "/")
fileName := paths[len(paths)-1]
fileFullName := filepath.Join(fileDirectory, fileName)
file, err := createNewFile(fileFullName)
if err != nil {
return err
}
if _, err = io.Copy(file, resp.Body); err != nil {
return fmt.Errorf("unable to copy file content: %w", err)
}
}

return nil
}
4 changes: 3 additions & 1 deletion pkg/agent/storage/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@ const (
//File Protocol = "file://"
HTTPS Protocol = "https://"
HTTP Protocol = "http://"
// OCI dist spec v1.1.0
OCI Protocol = "oci://"
)

var SupportedProtocols = []Protocol{S3, GCS, HTTPS, HTTP}
var SupportedProtocols = []Protocol{S3, GCS, HTTPS, HTTP, OCI}

func GetAllProtocol() (protocols []string) {
for _, protocol := range SupportedProtocols {
Expand Down
5 changes: 5 additions & 0 deletions pkg/agent/storage/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,11 @@ func GetProvider(providers map[Protocol]Provider, protocol Protocol) (Provider,
providers[HTTP] = &HTTPSProvider{
Client: httpsClient,
}
case OCI:
httpsClient := &http.Client{}
providers[OCI] = &OCIProvider{
Client: httpsClient,
}
}

return providers[protocol], nil
Expand Down
14 changes: 0 additions & 14 deletions pkg/apis/serving/v1beta1/openapi_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pkg/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ func TestIsPrefixSupported(t *testing.T) {
"GCS://",
"HTTP://",
"HTTPS://",
"OCI://",
}
scenarios := map[string]struct {
input string
Expand Down

0 comments on commit 0dfe129

Please sign in to comment.