diff --git a/pkg/distribution/internal/bundle/parse.go b/pkg/distribution/internal/bundle/parse.go index cf769cea7..b2b640f1f 100644 --- a/pkg/distribution/internal/bundle/parse.go +++ b/pkg/distribution/internal/bundle/parse.go @@ -101,13 +101,25 @@ func parseRuntimeConfig(rootDir string) (types.ModelConfig, error) { // top-level match in modelDir and falls back to a recursive search when needed. // Hidden files are ignored. func findModelFile(modelDir, ext string) (string, error) { + return findModelFileExcluding(modelDir, ext, nil) +} + +// findModelFileExcluding finds a supported model file by extension, skipping +// any file for which the exclude function returns true. It prefers a top-level +// match in modelDir and falls back to a recursive search when needed. Hidden +// files are ignored. +func findModelFileExcluding(modelDir, ext string, exclude func(string) bool) (string, error) { pattern := filepath.Join(modelDir, "[^.]*"+ext) paths, err := filepath.Glob(pattern) if err != nil { return "", fmt.Errorf("find %s files: %w", ext, err) } - if len(paths) > 0 { - return filepath.Base(paths[0]), nil + for _, p := range paths { + name := filepath.Base(p) + if exclude != nil && exclude(name) { + continue + } + return name, nil } var firstFound string @@ -126,6 +138,9 @@ func findModelFile(modelDir, ext string) (string, error) { strings.HasPrefix(info.Name(), ".") { return nil } + if exclude != nil && exclude(info.Name()) { + return nil + } rel, relErr := filepath.Rel(modelDir, path) if relErr != nil { @@ -146,8 +161,10 @@ func findModelFile(modelDir, ext string) (string, error) { } func findGGUFFile(modelDir string) (string, error) { - // GGUF files are optional. - return findModelFile(modelDir, ".gguf") + // GGUF files are optional. Use findModelFileExcluding to skip mmproj + // files that also carry a .gguf extension (common in CNCF ModelPack + // format where mmproj files are named e.g. "mmproj-BF16.gguf"). + return findModelFileExcluding(modelDir, ".gguf", isMMProjFilePath) } func findSafetensorsFile(modelDir string) (string, error) { @@ -161,17 +178,39 @@ func findDDUFFile(modelDir string) (string, error) { } func findMultiModalProjectorFile(modelDir string) (string, error) { + // First, look for files with the traditional .mmproj extension. mmprojPaths, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.mmproj")) if err != nil { return "", err } - if len(mmprojPaths) == 0 { - return "", nil - } if len(mmprojPaths) > 1 { return "", fmt.Errorf("found multiple .mmproj files, but only 1 is supported") } - return filepath.Base(mmprojPaths[0]), nil + if len(mmprojPaths) == 1 { + return filepath.Base(mmprojPaths[0]), nil + } + + // Fall back to detecting mmproj files with a .gguf extension. + // CNCF ModelPack format packages mmproj files as generic weight layers, + // preserving their original filename (e.g., "mmproj-BF16.gguf"). + ggufPaths, err := filepath.Glob(filepath.Join(modelDir, "[^.]*.gguf")) + if err != nil { + return "", err + } + var mmprojGGUF []string + for _, p := range ggufPaths { + if isMMProjFilePath(p) { + mmprojGGUF = append(mmprojGGUF, p) + } + } + if len(mmprojGGUF) > 1 { + return "", fmt.Errorf("found multiple mmproj .gguf files, but only 1 is supported") + } + if len(mmprojGGUF) == 1 { + return filepath.Base(mmprojGGUF[0]), nil + } + + return "", nil } func findChatTemplateFile(modelDir string) (string, error) { diff --git a/pkg/distribution/internal/bundle/parse_test.go b/pkg/distribution/internal/bundle/parse_test.go index eb601e3d3..b6a334115 100644 --- a/pkg/distribution/internal/bundle/parse_test.go +++ b/pkg/distribution/internal/bundle/parse_test.go @@ -362,6 +362,95 @@ func TestParse_WithNestedDDUF(t *testing.T) { } } +func TestParse_WithCNCFMMProjGGUF(t *testing.T) { + // Simulate a cached CNCF ModelPack bundle where the mmproj file has a + // .gguf extension (e.g., "mmproj-BF16.gguf") instead of .mmproj. + tempDir := t.TempDir() + modelDir := filepath.Join(tempDir, ModelSubdir) + if err := os.MkdirAll(modelDir, 0755); err != nil { + t.Fatalf("Failed to create model directory: %v", err) + } + + // Create the main GGUF weight and the mmproj .gguf file + if err := os.WriteFile(filepath.Join(modelDir, "gemma-4-E2B-it-UD-Q4_K_XL.gguf"), []byte("main model"), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(modelDir, "mmproj-BF16.gguf"), []byte("mmproj"), 0644); err != nil { + t.Fatal(err) + } + + // Create a valid config.json + cfg := types.Config{Format: types.FormatGGUF} + cfgBytes, marshalErr := json.Marshal(cfg) + if marshalErr != nil { + t.Fatal(marshalErr) + } + if err := os.WriteFile(filepath.Join(tempDir, "config.json"), cfgBytes, 0644); err != nil { + t.Fatal(err) + } + + bundle, err := Parse(tempDir) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + // The main GGUF file should NOT be the mmproj + if bundle.ggufFile == "" { + t.Fatal("Expected ggufFile to be set") + } + if isMMProjFilePath(bundle.ggufFile) { + t.Errorf("ggufFile should not be an mmproj file, got: %s", bundle.ggufFile) + } + + // The mmproj should be detected from mmproj-BF16.gguf + if bundle.mmprojPath == "" { + t.Fatal("Expected mmprojPath to be set for CNCF mmproj .gguf file") + } + if bundle.mmprojPath != "mmproj-BF16.gguf" { + t.Errorf("mmprojPath = %q, want %q", bundle.mmprojPath, "mmproj-BF16.gguf") + } + if bundle.MMPROJPath() == "" { + t.Fatal("Expected MMPROJPath() to return non-empty path") + } +} + +func TestParse_WithTraditionalMMProj(t *testing.T) { + // Ensure traditional .mmproj files still work + tempDir := t.TempDir() + modelDir := filepath.Join(tempDir, ModelSubdir) + if err := os.MkdirAll(modelDir, 0755); err != nil { + t.Fatal(err) + } + + if err := os.WriteFile(filepath.Join(modelDir, "model.gguf"), []byte("main"), 0644); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join(modelDir, "model.mmproj"), []byte("mmproj"), 0644); err != nil { + t.Fatal(err) + } + + cfg := types.Config{Format: types.FormatGGUF} + cfgBytes, marshalErr := json.Marshal(cfg) + if marshalErr != nil { + t.Fatal(marshalErr) + } + if err := os.WriteFile(filepath.Join(tempDir, "config.json"), cfgBytes, 0644); err != nil { + t.Fatal(err) + } + + bundle, err := Parse(tempDir) + if err != nil { + t.Fatalf("Parse failed: %v", err) + } + + if bundle.mmprojPath != "model.mmproj" { + t.Errorf("mmprojPath = %q, want %q", bundle.mmprojPath, "model.mmproj") + } + if bundle.ggufFile != "model.gguf" { + t.Errorf("ggufFile = %q, want %q", bundle.ggufFile, "model.gguf") + } +} + func TestParse_WithBothFormats(t *testing.T) { // Create a temporary directory for the test bundle tempDir := t.TempDir() diff --git a/pkg/distribution/internal/bundle/unpack.go b/pkg/distribution/internal/bundle/unpack.go index f1bec8e61..00efee6ff 100644 --- a/pkg/distribution/internal/bundle/unpack.go +++ b/pkg/distribution/internal/bundle/unpack.go @@ -40,6 +40,8 @@ func isV02Model(model types.ModelArtifact) bool { } // isCNCFModel checks if the model was packaged using the CNCF ModelPack format. +// Detection uses the manifest's artifactType field, which is required by the +// CNCF model-spec ("application/vnd.cncf.model.manifest.v1+json"). // CNCF ModelPack uses a layer-per-file approach with filepath annotations, // similar to V0.2, so it can be unpacked using UnpackFromLayers. func isCNCFModel(model types.ModelArtifact) bool { @@ -47,7 +49,7 @@ func isCNCFModel(model types.ModelArtifact) bool { if err != nil { return false } - return manifest.Config.MediaType == modelpack.MediaTypeModelConfigV1 + return manifest.ArtifactType == modelpack.ArtifactTypeModelManifest } // unpackLegacy is the original V0.1 unpacking logic that uses model.GGUFPaths(), model.SafetensorsPaths(), etc. @@ -859,6 +861,17 @@ func updateBundleFieldsFromLayer(bundle *Bundle, mediaType oci.MediaType, relPat default: // Handle format-agnostic CNCF weight types (e.g., .raw) by checking the model config format. if modelpack.IsModelPackGenericWeightMediaType(string(mediaType)) { + // Detect mmproj files by filepath annotation before treating as + // regular weight. In CNCF format, mmproj files share the same + // generic weight media type as model weights and can only be + // distinguished by their filename (same heuristic used by + // huggingface.isMMProjFile). + if isMMProjFilePath(relPath) { + if bundle.mmprojPath == "" { + bundle.mmprojPath = relPath + } + return + } switch types.Format(modelFormat) { case types.FormatGGUF: if bundle.ggufFile == "" { @@ -877,6 +890,13 @@ func updateBundleFieldsFromLayer(bundle *Bundle, mediaType oci.MediaType, relPat } } +// isMMProjFilePath checks if a filepath refers to a multimodal projector file +// by looking for "mmproj" in the filename (case-insensitive). This is the same +// heuristic used by huggingface.isMMProjFile. +func isMMProjFilePath(path string) bool { + return strings.Contains(strings.ToLower(filepath.Base(path)), "mmproj") +} + // unpackGenericFileLayers unpacks layers with MediaTypeModelFile using their filepath annotation. // This supports the new format where each config file is packaged as an individual layer // with its relative path preserved in the annotation. diff --git a/pkg/distribution/internal/bundle/unpack_test.go b/pkg/distribution/internal/bundle/unpack_test.go index 61c49d32e..1bdd2e7d9 100644 --- a/pkg/distribution/internal/bundle/unpack_test.go +++ b/pkg/distribution/internal/bundle/unpack_test.go @@ -231,15 +231,102 @@ func TestUpdateBundleFieldsFromLayer_CNCFMediaTypes(t *testing.T) { } } +func TestUpdateBundleFieldsFromLayer_CNCFMMProj(t *testing.T) { + tests := []struct { + name string + mediaType oci.MediaType + relPath string + modelFormat string + expectMMProj string + expectGGUF string + }{ + { + name: "CNCF generic weight raw with mmproj filename", + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + relPath: "mmproj-BF16.gguf", + modelFormat: string(types.FormatGGUF), + expectMMProj: "mmproj-BF16.gguf", + expectGGUF: "", + }, + { + name: "CNCF generic weight raw with mmproj in path (case-insensitive)", + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + relPath: "MMProj-model-f16.gguf", + modelFormat: string(types.FormatGGUF), + expectMMProj: "MMProj-model-f16.gguf", + expectGGUF: "", + }, + { + name: "CNCF generic weight raw with regular GGUF (not mmproj)", + mediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + relPath: "model-Q4_K_XL.gguf", + modelFormat: string(types.FormatGGUF), + expectMMProj: "", + expectGGUF: "model-Q4_K_XL.gguf", + }, + { + name: "Docker mmproj media type still works", + mediaType: types.MediaTypeMultimodalProjector, + relPath: "model.mmproj", + modelFormat: "", + expectMMProj: "model.mmproj", + expectGGUF: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bundle := &Bundle{} + updateBundleFieldsFromLayer(bundle, tt.mediaType, tt.relPath, tt.modelFormat) + + if bundle.mmprojPath != tt.expectMMProj { + t.Errorf("mmprojPath = %q, want %q", bundle.mmprojPath, tt.expectMMProj) + } + if bundle.ggufFile != tt.expectGGUF { + t.Errorf("ggufFile = %q, want %q", bundle.ggufFile, tt.expectGGUF) + } + }) + } +} + +func TestIsMMProjFilePath(t *testing.T) { + tests := []struct { + path string + expected bool + }{ + {"mmproj-BF16.gguf", true}, + {"mmproj-model-f16.gguf", true}, + {"mmproj-model-f32.gguf", true}, + {"MMProj-model.gguf", true}, + {"MMPROJ-model.gguf", true}, + {"some/path/mmproj-BF16.gguf", true}, + {"model-Q4_K_XL.gguf", false}, + {"model.gguf", false}, + {"model.safetensors", false}, + {"", false}, + } + + for _, tt := range tests { + t.Run(tt.path, func(t *testing.T) { + got := isMMProjFilePath(tt.path) + if got != tt.expected { + t.Errorf("isMMProjFilePath(%q) = %v, want %v", tt.path, got, tt.expected) + } + }) + } +} + func TestIsCNCFModel(t *testing.T) { tests := []struct { name string configMediaType oci.MediaType + artifactType string expected bool }{ { name: "CNCF ModelPack config V1", configMediaType: modelpack.MediaTypeModelConfigV1, + artifactType: modelpack.ArtifactTypeModelManifest, expected: true, }, { @@ -256,9 +343,10 @@ func TestIsCNCFModel(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Create a minimal artifact with the given config media type + // Create a minimal artifact with the given config media type and artifact type artifact := &testArtifactWithConfigMediaType{ configMediaType: tt.configMediaType, + artifactType: tt.artifactType, } result := isCNCFModel(artifact) if result != tt.expected { @@ -271,10 +359,12 @@ func TestIsCNCFModel(t *testing.T) { // testArtifactWithConfigMediaType is a minimal ModelArtifact for testing isCNCFModel/isV02Model. type testArtifactWithConfigMediaType struct { configMediaType oci.MediaType + artifactType string } func (a *testArtifactWithConfigMediaType) Manifest() (*oci.Manifest, error) { return &oci.Manifest{ + ArtifactType: a.artifactType, Config: oci.Descriptor{ MediaType: a.configMediaType, }, @@ -581,6 +671,54 @@ func TestUnpackFromLayers_PathSanitizationRejectsCollapsedPath(t *testing.T) { } } +func TestUnpackFromLayers_CNCFModelWithMMProj(t *testing.T) { + // Simulate the exact scenario from the bug: a CNCF ModelPack model with + // two layers using MediaTypeWeightRaw — one is the main GGUF weight and + // the other is the mmproj file. Both share the same media type; mmproj + // detection relies on the filepath annotation containing "mmproj". + artifact := testutil.NewModelPackArtifact( + t, + modelpack.Model{ + Config: modelpack.ModelConfig{Format: string(types.FormatGGUF)}, + }, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.gguf"), + RelativePath: "gemma-4-E2B-it-UD-Q4_K_XL.gguf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + }, + testutil.LayerSpec{ + Path: filepath.Join("..", "..", "assets", "dummy.mmproj"), + RelativePath: "mmproj-BF16.gguf", + MediaType: oci.MediaType(modelpack.MediaTypeWeightRaw), + }, + ) + + bundleRoot := t.TempDir() + bundle, err := UnpackFromLayers(bundleRoot, artifact) + if err != nil { + t.Fatalf("UnpackFromLayers failed: %v", err) + } + + // The main weight file should be tracked as ggufFile. + if bundle.ggufFile != "gemma-4-E2B-it-UD-Q4_K_XL.gguf" { + t.Errorf("ggufFile = %q, want %q", bundle.ggufFile, "gemma-4-E2B-it-UD-Q4_K_XL.gguf") + } + if _, err := os.Stat(bundle.GGUFPath()); err != nil { + t.Fatalf("Expected GGUF file to exist at %s, got: %v", bundle.GGUFPath(), err) + } + + // The mmproj file should be tracked as mmprojPath. + if bundle.mmprojPath != "mmproj-BF16.gguf" { + t.Errorf("mmprojPath = %q, want %q", bundle.mmprojPath, "mmproj-BF16.gguf") + } + if bundle.MMPROJPath() == "" { + t.Fatal("Expected MMPROJPath() to return non-empty path") + } + if _, err := os.Stat(bundle.MMPROJPath()); err != nil { + t.Fatalf("Expected mmproj file to exist at %s, got: %v", bundle.MMPROJPath(), err) + } +} + func TestValidatePathWithinDirectory_RealFilesystem(t *testing.T) { // Create a temporary directory structure baseDir := t.TempDir() diff --git a/pkg/distribution/oci/remote/extract_diffids_test.go b/pkg/distribution/oci/remote/extract_diffids_test.go new file mode 100644 index 000000000..bffc88d4b --- /dev/null +++ b/pkg/distribution/oci/remote/extract_diffids_test.go @@ -0,0 +1,195 @@ +package remote + +import ( + "encoding/json" + "testing" + + "github.com/docker/model-runner/pkg/distribution/oci" +) + +// Valid 64-char hex strings for SHA256 test hashes. +const ( + hexA = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + hexB = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb" + hexC = "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc" + hex1 = "1111111111111111111111111111111111111111111111111111111111111111" + hex2 = "2222222222222222222222222222222222222222222222222222222222222222" +) + +func TestExtractDiffIDs_DockerFormat(t *testing.T) { + config := map[string]interface{}{ + "rootfs": map[string]interface{}{ + "type": "rootfs", + "diff_ids": []string{"sha256:" + hexA, "sha256:" + hexB, "sha256:" + hexC}, + }, + } + raw, err := json.Marshal(config) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + + tests := []struct { + name string + index int + wantHex string + wantOk bool + }{ + {"first layer", 0, hexA, true}, + {"second layer", 1, hexB, true}, + {"last layer", 2, hexC, true}, + {"index out of bounds", 3, "", false}, + {"negative index", -1, "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, err := extractDiffIDs(raw, tt.index) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.wantOk { + if h == (oci.Hash{}) { + t.Fatal("expected non-zero hash, got zero") + } + if h.Hex != tt.wantHex { + t.Errorf("expected hex %q, got %q", tt.wantHex, h.Hex) + } + } else { + if h != (oci.Hash{}) { + t.Errorf("expected zero hash, got %v", h) + } + } + }) + } +} + +func TestExtractDiffIDs_CNCFModelPackFormat(t *testing.T) { + config := map[string]interface{}{ + "modelfs": map[string]interface{}{ + "type": "layers", + "diffIds": []string{"sha256:" + hex1, "sha256:" + hex2}, + }, + } + raw, err := json.Marshal(config) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + + tests := []struct { + name string + index int + wantHex string + wantOk bool + }{ + {"first layer", 0, hex1, true}, + {"second layer", 1, hex2, true}, + {"index out of bounds", 2, "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h, err := extractDiffIDs(raw, tt.index) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if tt.wantOk { + if h == (oci.Hash{}) { + t.Fatal("expected non-zero hash, got zero") + } + if h.Hex != tt.wantHex { + t.Errorf("expected hex %q, got %q", tt.wantHex, h.Hex) + } + } else { + if h != (oci.Hash{}) { + t.Errorf("expected zero hash, got %v", h) + } + } + }) + } +} + +func TestExtractDiffIDs_DockerTakesPrecedence(t *testing.T) { + // When both rootfs and modelfs are present, Docker format should win. + config := map[string]interface{}{ + "rootfs": map[string]interface{}{ + "type": "rootfs", + "diff_ids": []string{"sha256:" + hexA}, + }, + "modelfs": map[string]interface{}{ + "type": "layers", + "diffIds": []string{"sha256:" + hex1}, + }, + } + raw, err := json.Marshal(config) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + + h, err := extractDiffIDs(raw, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h.Hex != hexA { + t.Errorf("expected Docker format to take precedence (hex %q), got %q", hexA, h.Hex) + } +} + +func TestExtractDiffIDs_EmptyConfig(t *testing.T) { + raw := []byte(`{}`) + h, err := extractDiffIDs(raw, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h != (oci.Hash{}) { + t.Errorf("expected zero hash for empty config, got %v", h) + } +} + +func TestExtractDiffIDs_InvalidJSON(t *testing.T) { + raw := []byte(`not valid json`) + _, err := extractDiffIDs(raw, 0) + if err == nil { + t.Fatal("expected error for invalid JSON, got nil") + } +} + +func TestExtractDiffIDs_MalformedRootFS(t *testing.T) { + // rootfs exists but is not an object — should fall through gracefully. + config := map[string]interface{}{ + "rootfs": "not an object", + } + raw, err := json.Marshal(config) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + + h, err := extractDiffIDs(raw, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h != (oci.Hash{}) { + t.Errorf("expected zero hash for malformed rootfs, got %v", h) + } +} + +func TestExtractDiffIDs_MalformedModelFS(t *testing.T) { + // modelfs exists but diffIds contains invalid hashes (not valid SHA256). + config := map[string]interface{}{ + "modelfs": map[string]interface{}{ + "type": "layers", + "diffIds": []string{"not-a-valid-hash"}, + }, + } + raw, err := json.Marshal(config) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + + h, err := extractDiffIDs(raw, 0) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if h != (oci.Hash{}) { + t.Errorf("expected zero hash for malformed modelfs hash, got %v", h) + } +} diff --git a/pkg/distribution/oci/remote/remote.go b/pkg/distribution/oci/remote/remote.go index 987412071..a908aa483 100644 --- a/pkg/distribution/oci/remote/remote.go +++ b/pkg/distribution/oci/remote/remote.go @@ -652,19 +652,62 @@ func (l *remoteLayer) Digest() (oci.Hash, error) { // DiffID returns the uncompressed layer digest. // For remote layers, we look up the diff ID from the image config. +// Supports both Docker format (rootfs.diff_ids) and CNCF ModelPack format +// (modelfs.diffIds). func (l *remoteLayer) DiffID() (oci.Hash, error) { - // Get the config file to look up the diff ID - config, err := l.image.ConfigFile() + raw, err := l.image.RawConfigFile() if err != nil { - return oci.Hash{}, fmt.Errorf("getting config file for diff ID lookup: %w", err) + return oci.Hash{}, fmt.Errorf("getting raw config for diff ID lookup: %w", err) } - // Check if the layer index is within bounds of the diff IDs - if l.index < 0 || l.index >= len(config.RootFS.DiffIDs) { - return l.desc.Digest, nil // Fallback to digest if diff ID not available + // Try to extract diffIds from the raw config generically, so we support + // both Docker format (rootfs.diff_ids) and CNCF ModelPack (modelfs.diffIds). + diffIDs, err := extractDiffIDs(raw, l.index) + if err != nil || diffIDs == (oci.Hash{}) { + // Fall back to the descriptor digest (works for uncompressed layers). + return l.desc.Digest, nil + } + return diffIDs, nil +} + +// extractDiffIDs parses a raw config blob and returns the DiffID at the given +// layer index. It tries Docker format (rootfs.diff_ids) first, then CNCF +// ModelPack format (modelfs.diffIds). +func extractDiffIDs(raw []byte, index int) (oci.Hash, error) { + // Parse as a generic map to support both config formats. + var parsed map[string]json.RawMessage + if err := json.Unmarshal(raw, &parsed); err != nil { + return oci.Hash{}, err + } + + // Try Docker format: rootfs.diff_ids + if rootfsRaw, ok := parsed["rootfs"]; ok { + var rootfs struct { + DiffIDs []oci.Hash `json:"diff_ids"` + } + if err := json.Unmarshal(rootfsRaw, &rootfs); err == nil { + if index >= 0 && index < len(rootfs.DiffIDs) { + return rootfs.DiffIDs[index], nil + } + } } - return config.RootFS.DiffIDs[l.index], nil + // Try CNCF ModelPack format: modelfs.diffIds + if modelfsRaw, ok := parsed["modelfs"]; ok { + var modelfs struct { + DiffIDs []string `json:"diffIds"` + } + if err := json.Unmarshal(modelfsRaw, &modelfs); err == nil { + if index >= 0 && index < len(modelfs.DiffIDs) { + h, err := oci.NewHash(modelfs.DiffIDs[index]) + if err == nil { + return h, nil + } + } + } + } + + return oci.Hash{}, nil } // Compressed returns the compressed layer contents. @@ -880,8 +923,15 @@ func Write(ref reference.Reference, img oci.Image, w io.Writer, opts ...Option) return fmt.Errorf("getting config name: %w", err) } + // Use the config media type from the manifest rather than a hardcoded value, + // so that both Docker-format and CNCF ModelPack artifacts are pushed + // with the correct media type. + pushManifest, err := img.Manifest() + if err != nil { + return fmt.Errorf("getting manifest for config media type: %w", err) + } configDesc := v1.Descriptor{ - MediaType: "application/vnd.docker.container.image.v1+json", + MediaType: string(pushManifest.Config.MediaType), Digest: godigest.Digest(configName.String()), Size: int64(len(rawConfig)), }