diff --git a/go.mod b/go.mod index eefcace..b2cedda 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/docker/oci go 1.25.0 require ( + github.com/klauspost/compress v1.18.6 github.com/opencontainers/go-digest v1.0.0 github.com/opencontainers/image-spec v1.1.1 github.com/rogpeppe/go-internal v1.14.1 diff --git a/go.sum b/go.sum index be5ce8d..fb6b14f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/klauspost/compress v1.18.6 h1:2jupLlAwFm95+YDR+NwD2MEfFO9d4z4Prjl1XXDjuao= +github.com/klauspost/compress v1.18.6/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= diff --git a/ocifs/README.md b/ocifs/README.md new file mode 100644 index 0000000..99623f3 --- /dev/null +++ b/ocifs/README.md @@ -0,0 +1,99 @@ +# ocifs + +Package `ocifs` provides an [`io/fs.FS`](https://pkg.go.dev/io/fs#FS) backed by an OCI image. It downloads each layer once, builds compressed-stream and tar indices, and then serves any file in the merged image as a random-access read — without ever unpacking layers to disk. + +## How it works + +An OCI image is a stack of compressed tar layers. `ocifs` composes four sub-packages into a single overlay filesystem: + +``` +Registry + └─ blobra — range-request io.ReaderAt over each compressed layer blob + └─ gzipr / zstdr — checkpoint index → random-access decompressed io.ReaderAt + └─ tarfs — tar entry index → io/fs.FS over the decompressed stream + └─ ocifs (overlay) — merges all layers with OCI whiteout semantics +``` + +On the first call to `New`, every layer blob is streamed once: the decompressor builds its checkpoint index while the tar scanner builds its entry index. Subsequent `Open` / `ReadAt` calls fetch only the compressed bytes covering the requested file — no full-layer downloads. + +## Quick start + +```go +import ( + "context" + "io/fs" + + "github.com/docker/oci" + "github.com/docker/oci/ocifs" +) + +reg := oci.New(/* ... */) +ociFS, err := ocifs.New(ctx, reg, "library/alpine", "latest") +if err != nil { + return err +} +defer ociFS.Close() + +// Walk the entire filesystem (root is always ".", never "/"). +fs.WalkDir(ociFS, ".", func(path string, d fs.DirEntry, err error) error { + fmt.Println(path) + return err +}) + +// Read a single file. +data, err := ociFS.ReadFile("etc/os-release") +``` + +## Persisting the index + +Scanning large layers on every startup is expensive. Save the index after the first `New` call and reuse it with `NewWithIndex`: + +```go +// First run: build and persist. +ociFS, _ := ocifs.New(ctx, reg, repo, ref) +idx := ociFS.ImageIndex() +f, _ := os.Create("index.json") +idx.Encode(f) + +// Subsequent runs: restore from disk. +f, _ := os.Open("index.json") +idx, _ := ocifs.DecodeImageIndex(f) +ociFS, err := ocifs.NewWithIndex(ctx, reg, repo, ref, idx) +if errors.Is(err, ocifs.ErrIndexStale) { + // Image was re-pushed; fall back to full scan. + ociFS, err = ocifs.New(ctx, reg, repo, ref) +} +``` + +`NewWithIndex` re-fetches the manifest to verify that layer digests still match the persisted index before accepting it. If the tag was re-pushed or a layer changed, `ErrIndexStale` is returned. + +## Key types + +| Type | Purpose | +|------|---------| +| `FS` | The `io/fs.FS` implementation. Also implements `ReadDirFS`, `StatFS`, `ReadFileFS`, and `Lstat`. | +| `ImageIndex` | Serializable bundle of per-layer checkpoint and tar indices. | +| `LayerIndex` | Per-layer index: digest, media type, compressed size, decompressor index, tar entry list. | + +## Overlay semantics + +Layers are merged bottom-to-top following the [OCI image layer spec](https://github.com/opencontainers/image-spec/blob/main/layer.md): + +- **Whiteout** (`.wh.`): deletes `` from lower layers. +- **Opaque whiteout** (`.wh..wh..opq`): hides the entire directory contents from lower layers, keeping only what the current layer adds. +- **Hardlinks**: resolved to the target entry's content at open time. +- **Symlinks**: followed up to 255 hops; circular chains return `ErrSymlinkLoop`. + +Whiteout markers are excluded from `ReadDir` results. Whiteout targets that are `"."`, `".."`, or contain `/` are silently ignored. + +## Supported layer types + +| Media type | Decompressor | +|-----------|-------------| +| `application/vnd.oci.image.layer.v1.tar+gzip` | `gzipr` | +| `application/vnd.docker.image.rootfs.diff.tar.gzip` | `gzipr` | +| `application/vnd.oci.image.layer.v1.tar+zstd` | `zstdr` | + +## Concurrency + +`FS` is safe for concurrent use after construction. Each `Open` call may issue range requests to the registry; the context passed to `New` governs all such requests. Call `Close` to release decompressor pools when the `FS` is no longer needed. diff --git a/ocifs/blobra/README.md b/ocifs/blobra/README.md new file mode 100644 index 0000000..f49b1bf --- /dev/null +++ b/ocifs/blobra/README.md @@ -0,0 +1,50 @@ +# blobra + +Package `blobra` adapts an OCI registry's range-request API to [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt), allowing decompressors (`gzipr`, `zstdr`) to fetch arbitrary byte ranges of a compressed layer blob without downloading the entire blob. + +## How it works + +Each `ReadAt(p, off)` call translates to a single `GetBlobRange` registry request for the half-open byte range `[off, off+len(p))`. The response body is read in full into `p`. No buffering or caching is performed — callers are responsible for requesting only the ranges they need. + +## Usage + +```go +import ( + "context" + "github.com/docker/oci" + "github.com/docker/oci/ocifs/blobra" +) + +// desc is an oci.Descriptor with Digest, MediaType, and Size set. +ra := blobra.New(ctx, registry, "library/alpine", desc) + +// ra now satisfies io.ReaderAt over the compressed blob. +buf := make([]byte, 512) +n, err := ra.ReadAt(buf, 1024) // fetches bytes [1024, 1536) from the registry +``` + +`blobra.New` performs no I/O. Registry requests are only issued by `ReadAt`. + +## Key types + +### `BlobRanger` + +The narrow interface `blobra` requires from the registry. Any `oci.Interface` implementation satisfies it: + +```go +type BlobRanger interface { + GetBlobRange(ctx context.Context, repo string, digest oci.Digest, offset0, offset1 int64) (oci.BlobReader, error) +} +``` + +Using this interface rather than the full `oci.Interface` makes `blobra.Reader` easy to test with a small fake. + +### `Reader` + +`*Reader` implements `io.ReaderAt`. It also exposes `Size() int64` (from the descriptor) and `Descriptor() oci.Descriptor`. + +## `io.ReaderAt` contract notes + +- `ReadAt(p, off)` where `off >= desc.Size` returns `(0, io.EOF)` without issuing a request. +- A read that is clamped by end-of-blob (i.e. `off + len(p) > desc.Size`) returns `(n, io.EOF)` where `n < len(p)`. This signals a legitimate short read at end-of-blob. +- If the registry returns fewer bytes than the requested range, the error is surfaced as `io.ErrUnexpectedEOF` (never `io.EOF`), distinguishing a protocol violation from a legitimate end-of-blob. diff --git a/ocifs/blobra/blobra.go b/ocifs/blobra/blobra.go new file mode 100644 index 0000000..12ae2ba --- /dev/null +++ b/ocifs/blobra/blobra.go @@ -0,0 +1,106 @@ +// Package blobra adapts a range-capable OCI blob source to [io.ReaderAt]. +// +// The package depends on [oci.Digest] and [oci.BlobReader] from the parent +// oci package, but intentionally does not depend on the sealed +// [oci.Interface] type. Instead it consumes the narrow [BlobRanger] +// interface, which any [oci.Interface] implementation satisfies and which +// test fakes can implement directly. +package blobra + +import ( + "context" + "fmt" + "io" + + "github.com/docker/oci" +) + +// Compile-time interface check. +var _ io.ReaderAt = (*Reader)(nil) + +// BlobRanger is the subset of [oci.Interface] that [Reader] requires. +// Any [oci.Interface] implementation satisfies it. +// +// The range is half-open: [offset0, offset1). The response body contains +// exactly offset1-offset0 bytes when both endpoints are within the blob. +type BlobRanger interface { + GetBlobRange(ctx context.Context, repo string, digest oci.Digest, offset0, offset1 int64) (oci.BlobReader, error) +} + +// Reader serves [io.ReaderAt] calls against a single OCI blob using +// range requests against an underlying [BlobRanger]. +type Reader struct { + ctx context.Context + ranger BlobRanger + repo string + desc oci.Descriptor +} + +// New returns a Reader that serves ReadAt calls via range requests on +// ranger. desc.Size is the blob size in bytes; it is used to detect +// end-of-blob without issuing a probe request. No I/O is performed by +// this constructor. +func New(ctx context.Context, ranger BlobRanger, repo string, desc oci.Descriptor) *Reader { + return &Reader{ + ctx: ctx, + ranger: ranger, + repo: repo, + desc: desc, + } +} + +// Descriptor returns the descriptor that the Reader was constructed with. +func (r *Reader) Descriptor() oci.Descriptor { + return r.desc +} + +// Size returns the size of the underlying blob in bytes. +func (r *Reader) Size() int64 { + return r.desc.Size +} + +// ReadAt implements [io.ReaderAt]. See the package documentation for the +// full contract; in particular, a server-side truncation of a clamped +// range request surfaces as [io.ErrUnexpectedEOF] rather than [io.EOF] +// so that downstream consumers (gzipr, zstdr) cannot mistake a protocol +// violation for legitimate end-of-blob. +func (r *Reader) ReadAt(p []byte, off int64) (int, error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 { + return 0, fmt.Errorf("blobra: negative offset %d", off) + } + if off >= r.desc.Size { + return 0, io.EOF + } + + n := int64(len(p)) + if remaining := r.desc.Size - off; n > remaining { + n = remaining + } + + br, err := r.ranger.GetBlobRange(r.ctx, r.repo, r.desc.Digest, off, off+n) + if err != nil { + return 0, err + } + defer br.Close() + + if _, err := io.ReadFull(br, p[:n]); err != nil { + // io.ReadFull maps a fully empty stream to io.EOF and a partial + // stream to io.ErrUnexpectedEOF. Both indicate the registry + // returned fewer bytes than the clamped range demanded — a + // protocol violation. Surface it uniformly as + // io.ErrUnexpectedEOF; never propagate io.EOF here, which is + // reserved for off >= desc.Size. + if err == io.EOF || err == io.ErrUnexpectedEOF { + return 0, io.ErrUnexpectedEOF + } + return 0, err + } + + if int64(len(p)) > n { + return int(n), io.EOF + } + return int(n), nil +} diff --git a/ocifs/blobra/blobra_test.go b/ocifs/blobra/blobra_test.go new file mode 100644 index 0000000..0e7958a --- /dev/null +++ b/ocifs/blobra/blobra_test.go @@ -0,0 +1,298 @@ +package blobra_test + +import ( + "bytes" + "context" + "errors" + "io" + "testing" + + "github.com/opencontainers/go-digest" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/oci" + "github.com/docker/oci/ocifs/blobra" + "github.com/docker/oci/ocimem" +) + +const ( + testRepo = "example/blob" + testMediaType = "application/octet-stream" +) + +func pushBlob(t *testing.T, reg *ocimem.Registry, data []byte) oci.Descriptor { + t.Helper() + desc := oci.Descriptor{ + MediaType: testMediaType, + Size: int64(len(data)), + Digest: digest.FromBytes(data), + } + out, err := reg.PushBlob(context.Background(), testRepo, desc, bytes.NewReader(data)) + require.NoError(t, err) + return out +} + +func TestNewAndAccessors(t *testing.T) { + reg := ocimem.New() + data := []byte("hello, world!") + desc := pushBlob(t, reg, data) + + r := blobra.New(context.Background(), reg, testRepo, desc) + assert.Equal(t, desc, r.Descriptor()) + assert.Equal(t, int64(len(data)), r.Size()) +} + +func TestReadAtAgainstOCIMem(t *testing.T) { + reg := ocimem.New() + data := []byte("0123456789ABCDEF") + desc := pushBlob(t, reg, data) + r := blobra.New(context.Background(), reg, testRepo, desc) + + tests := []struct { + name string + bufLen int + off int64 + wantN int + wantErr error + wantBuf []byte + }{ + {name: "EmptyBuffer", bufLen: 0, off: 0, wantN: 0, wantErr: nil, wantBuf: []byte{}}, + {name: "EmptyBufferAtEnd", bufLen: 0, off: int64(len(data)), wantN: 0, wantErr: nil, wantBuf: []byte{}}, + {name: "FullReadFromStart", bufLen: len(data), off: 0, wantN: len(data), wantErr: nil, wantBuf: data}, + {name: "PartialFromStart", bufLen: 4, off: 0, wantN: 4, wantErr: nil, wantBuf: data[:4]}, + {name: "InteriorRead", bufLen: 5, off: 3, wantN: 5, wantErr: nil, wantBuf: data[3:8]}, + {name: "ReadEndingExactlyAtBoundary", bufLen: 6, off: 10, wantN: 6, wantErr: nil, wantBuf: data[10:]}, + {name: "ReadPastEndShortFill", bufLen: 8, off: 10, wantN: 6, wantErr: io.EOF, wantBuf: data[10:]}, + {name: "OffsetAtSize", bufLen: 4, off: int64(len(data)), wantN: 0, wantErr: io.EOF, wantBuf: []byte{}}, + {name: "OffsetBeyondSize", bufLen: 4, off: int64(len(data) + 100), wantN: 0, wantErr: io.EOF, wantBuf: []byte{}}, + {name: "SingleByteAtEnd", bufLen: 1, off: int64(len(data) - 1), wantN: 1, wantErr: nil, wantBuf: data[len(data)-1:]}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + buf := make([]byte, tc.bufLen) + n, err := r.ReadAt(buf, tc.off) + assert.Equal(t, tc.wantN, n) + assert.Equal(t, tc.wantErr, err) + assert.Equal(t, tc.wantBuf, buf[:n]) + }) + } +} + +func TestReadAtNegativeOffset(t *testing.T) { + data := []byte("abcd") + fr := newFake(data) + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + + n, err := r.ReadAt(make([]byte, 1), -1) + assert.Equal(t, 0, n) + assert.ErrorContains(t, err, "negative offset") + assert.Equal(t, 0, fr.calls, "negative offsets should not issue range requests") +} + +// fakeRanger is a minimal BlobRanger used to exercise paths that ocimem +// cannot produce, such as a truncated or empty range response. +type fakeRanger struct { + data []byte + desc oci.Descriptor + + err error + truncateTo int // if > 0 and < requested, return only this many bytes + emptyResponse bool // return zero-byte body regardless of range + gotOff0 int64 // recorded for assertions + gotOff1 int64 + gotRepo string + gotDigest oci.Digest + calls int +} + +func (f *fakeRanger) GetBlobRange(ctx context.Context, repo string, dig oci.Digest, o0, o1 int64) (oci.BlobReader, error) { + f.calls++ + f.gotOff0 = o0 + f.gotOff1 = o1 + f.gotRepo = repo + f.gotDigest = dig + if f.err != nil { + return nil, f.err + } + var body []byte + switch { + case f.emptyResponse: + body = nil + case f.truncateTo > 0 && f.truncateTo < int(o1-o0): + body = f.data[o0 : o0+int64(f.truncateTo)] + default: + body = f.data[o0:o1] + } + return ocimem.NewBytesReader(body, f.desc), nil +} + +func newFake(data []byte) *fakeRanger { + return &fakeRanger{ + data: data, + desc: oci.Descriptor{ + MediaType: testMediaType, + Size: int64(len(data)), + Digest: digest.FromBytes(data), + }, + } +} + +func TestReadAtProtocolViolations(t *testing.T) { + data := []byte("abcdefghijklmnop") + + t.Run("EmptyResponseBecomesUnexpectedEOF", func(t *testing.T) { + fr := newFake(data) + fr.emptyResponse = true + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + + buf := make([]byte, 4) + n, err := r.ReadAt(buf, 0) + assert.Equal(t, 0, n) + assert.Equal(t, io.ErrUnexpectedEOF, err) + assert.Equal(t, 1, fr.calls) + }) + + t.Run("TruncatedResponseBecomesUnexpectedEOF", func(t *testing.T) { + fr := newFake(data) + fr.truncateTo = 2 + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + + buf := make([]byte, 8) + n, err := r.ReadAt(buf, 0) + assert.Equal(t, 0, n) + assert.Equal(t, io.ErrUnexpectedEOF, err) + }) + + t.Run("TruncatedNearEndStillUnexpectedEOF", func(t *testing.T) { + // Even when the request would be clamped at end-of-blob, a short + // response from the server must surface as ErrUnexpectedEOF, not EOF. + fr := newFake(data) + fr.truncateTo = 1 + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + + buf := make([]byte, 8) + n, err := r.ReadAt(buf, int64(len(data)-4)) + assert.Equal(t, 0, n) + assert.Equal(t, io.ErrUnexpectedEOF, err) + }) +} + +func TestReadAtRangerErrorPropagated(t *testing.T) { + data := []byte("abcdefgh") + sentinel := errors.New("network down") + + fr := newFake(data) + fr.err = sentinel + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + + buf := make([]byte, 4) + n, err := r.ReadAt(buf, 0) + assert.Equal(t, 0, n) + assert.ErrorIs(t, err, sentinel) +} + +func TestReadAtContextCanceled(t *testing.T) { + // A non-EOF, non-ErrUnexpectedEOF error from the body reader path can + // only arrive in ReadFull. We synthesise this with a reader whose Read + // returns context.Canceled. + data := []byte("abcdefgh") + fr := &cancelingRanger{ + desc: oci.Descriptor{ + MediaType: testMediaType, + Size: int64(len(data)), + Digest: digest.FromBytes(data), + }, + } + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + buf := make([]byte, 4) + n, err := r.ReadAt(buf, 0) + assert.Equal(t, 0, n) + assert.ErrorIs(t, err, context.Canceled) + assert.True(t, fr.closed, "underlying BlobReader must be closed on error") +} + +func TestReadAtDoesNotCallRangerForTrivialCases(t *testing.T) { + data := []byte("abcd") + fr := newFake(data) + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + + // len(p) == 0 must not issue a range request. + n, err := r.ReadAt(nil, 0) + require.NoError(t, err) + assert.Equal(t, 0, n) + + // off >= size must not issue a range request. + n, err = r.ReadAt(make([]byte, 4), int64(len(data))) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) + + assert.Equal(t, 0, fr.calls, "no range request should be issued for trivial cases") +} + +func TestReadAtClampsRangeUpperBound(t *testing.T) { + // When the read crosses the end-of-blob, the ranger must be called + // with offset1 == desc.Size, never beyond it. + data := []byte("abcdefghij") + fr := newFake(data) + r := blobra.New(context.Background(), fr, testRepo, fr.desc) + + buf := make([]byte, 100) + n, err := r.ReadAt(buf, 7) + assert.Equal(t, 3, n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, []byte("hij"), buf[:n]) + assert.Equal(t, int64(7), fr.gotOff0) + assert.Equal(t, int64(10), fr.gotOff1) +} + +func TestReadAtPassesCorrectArgsToRanger(t *testing.T) { + data := []byte("abcdefghij") + fr := newFake(data) + r := blobra.New(context.Background(), fr, "some/repo", fr.desc) + + buf := make([]byte, 4) + _, err := r.ReadAt(buf, 2) + require.NoError(t, err) + assert.Equal(t, "some/repo", fr.gotRepo) + assert.Equal(t, fr.desc.Digest, fr.gotDigest) + assert.Equal(t, int64(2), fr.gotOff0) + assert.Equal(t, int64(6), fr.gotOff1) +} + +// cancelingRanger returns a BlobReader whose Read returns context.Canceled +// on first call. Used to exercise the "other error" branch in ReadAt. +type cancelingRanger struct { + desc oci.Descriptor + closed bool +} + +func (c *cancelingRanger) GetBlobRange(ctx context.Context, repo string, dig oci.Digest, o0, o1 int64) (oci.BlobReader, error) { + return &cancelingReader{parent: c, desc: c.desc}, nil +} + +type cancelingReader struct { + parent *cancelingRanger + desc oci.Descriptor +} + +func (c *cancelingReader) Read(p []byte) (int, error) { + return 0, context.Canceled +} + +func (c *cancelingReader) Close() error { + c.parent.closed = true + return nil +} + +func (c *cancelingReader) Descriptor() oci.Descriptor { + return c.desc +} + +// Compile-time checks: *Reader is an io.ReaderAt, and oci.Interface +// satisfies blobra.BlobRanger directly (no adapter needed). +var ( + _ io.ReaderAt = (*blobra.Reader)(nil) + _ blobra.BlobRanger = (oci.Interface)(nil) +) diff --git a/ocifs/gzipr/README.md b/ocifs/gzipr/README.md new file mode 100644 index 0000000..1b258ff --- /dev/null +++ b/ocifs/gzipr/README.md @@ -0,0 +1,69 @@ +# gzipr + +Package `gzipr` provides sequential gzip scanning and random-access decompressed reading via DEFLATE block-boundary checkpoints. + +## The problem it solves + +A gzip stream must normally be decompressed from the beginning to reach any given byte. For large OCI layers this makes random file access prohibitively expensive. `gzipr` solves this by recording the full DEFLATE decompressor state (bit buffer + 32 KB sliding-window history) at periodic block boundaries during an initial scan. Later, any byte range of the decompressed stream can be reached by seeking to the nearest checkpoint in the compressed stream and decompressing only forward from there. + +## Two-phase usage + +### Phase 1 — Scan (once per blob) + +`Scan` makes a single sequential pass, decompressing to an `io.Writer` and emitting checkpoints at a configurable decompressed-byte interval (default: 1 MiB). + +```go +import "github.com/docker/oci/ocifs/gzipr" + +idx, err := gzipr.Scan(compressedReader, io.Discard) +// idx.Checkpoints holds the checkpoint sequence. +// idx.Size is the total decompressed length. +``` + +The `Index` is JSON-serializable; persist it alongside the blob to avoid re-scanning on every start. + +### Phase 2 — Random-access reads + +```go +// Build from a fresh scan: +reader, err := gzipr.NewReader(blobReaderAt, blobSize) + +// Or rebuild from a persisted index (no I/O at construction time): +reader, err := gzipr.NewReaderWithIndex(blobReaderAt, idx, blobSize) +defer reader.Close() + +buf := make([]byte, 4096) +n, err := reader.ReadAt(buf, decompressedOffset) +``` + +`ReadAt` finds the highest checkpoint at or before `decompressedOffset`, fetches only the compressed bytes between that checkpoint and the next, resumes the DEFLATE decompressor from the saved state, and discards forward to the exact byte. + +## Key types + +| Type | Purpose | +|------|---------| +| `Index` | Checkpoint sequence + total decompressed size. JSON-serializable. | +| `Checkpoint` | DEFLATE decompressor state at a block boundary: compressed offset (`In`), decompressed offset (`Out`), bit buffer (`B`, `NB`), sliding-window history (`Hist`). | +| `Reader` | `io.ReaderAt` + `io.Closer` over the decompressed stream. | + +## Options + +```go +gzipr.WithSpan(2 << 20) // checkpoint every 2 MiB (default: 1 MiB) +gzipr.WithMaxReaders(4) // allow up to 4 concurrent ReadAt calls (default: 8) +``` + +Smaller spans improve seek latency at the cost of more checkpoint memory (~32 KiB of sliding-window history per checkpoint). + +## Concurrency + +`Reader.ReadAt` is safe for concurrent use. Concurrent decompressions are bounded by `WithMaxReaders`; callers that exceed the cap block until a slot is available. `Close` unblocks any waiting callers with `ErrClosed`. + +## Internal DEFLATE fork + +`gzipr/internal/flate` is a fork of the Go standard library's `compress/flate` package extended with two entry points: + +- `NewReaderCallback` — invokes a callback at each DEFLATE block boundary during `Scan`. +- `NewReaderResume` — resumes decompression from a saved `(Hist, B, NB)` state for `ReadAt`. + +These additions are the only changes from the upstream stdlib code. diff --git a/ocifs/gzipr/gzipr.go b/ocifs/gzipr/gzipr.go new file mode 100644 index 0000000..3917a49 --- /dev/null +++ b/ocifs/gzipr/gzipr.go @@ -0,0 +1,96 @@ +// Package gzipr provides sequential gzip scanning and random-access +// decompressed reading via DEFLATE block-boundary checkpoints. +// +// A checkpoint captures the full DEFLATE decompressor state (bit buffer +// plus 32 KB sliding-window history) at a block boundary, which allows +// resuming decompression at that exact point. Scan walks a gzip stream +// once, decompressing to an [io.Writer] and emitting checkpoints at a +// configurable decompressed-byte interval. The resulting [Index] is +// then used by [NewReaderWithIndex] to back an [io.ReaderAt] whose +// ReadAt fetches only the compressed span needed for the requested +// decompressed range. +package gzipr + +import ( + "errors" +) + +// ErrInvalidFormat is returned by [Scan], [NewReader], and +// [Reader.ReadAt] when the underlying gzip or DEFLATE stream cannot be +// parsed. +var ErrInvalidFormat = errors.New("gzipr: invalid compressed format") + +// ErrClosed is returned by [Reader.ReadAt] after [Reader.Close] has +// been called. +var ErrClosed = errors.New("gzipr: reader has been closed") + +// Checkpoint records the DEFLATE decompressor state at a block boundary. +// +// In is the compressed-byte offset to seek to for resume. Out is the +// total decompressed bytes produced before the captured boundary. +// B and NB hold the bit buffer (up to 32 bits, NB ≤ 32). Hist contains +// the in-order sliding-window history (at most 32 KB) that primes the +// LZ77 dictionary on resume. +type Checkpoint struct { + In int64 `json:"in"` + Out int64 `json:"out"` + B uint32 `json:"b"` + NB uint `json:"nb"` + Hist []byte `json:"hist"` +} + +// Index holds the checkpoint sequence for one gzip stream plus the +// total decompressed size. +// +// Checkpoints is sorted by Out ascending. It may be empty (non-nil +// zero-length slice) when the stream decompresses to fewer bytes than +// the configured checkpoint span. Size is the total decompressed +// length of the stream. +type Index struct { + Checkpoints []*Checkpoint `json:"checkpoints"` + Size int64 `json:"size"` +} + +// Option configures [Scan], [NewReader], and [NewReaderWithIndex]. +type Option func(*config) + +type config struct { + span int64 + maxReaders int +} + +func defaultConfig() *config { + return &config{ + span: 1 << 20, + maxReaders: 8, + } +} + +func applyOpts(opts []Option) *config { + c := defaultConfig() + for _, o := range opts { + o(c) + } + if c.span <= 0 { + c.span = 1 << 20 + } + if c.maxReaders <= 0 { + c.maxReaders = 8 + } + return c +} + +// WithSpan sets the minimum decompressed-byte interval between +// checkpoints emitted by [Scan] and [NewReader]. The default is 1 MiB. +// Smaller spans yield finer-grained random access at the cost of +// proportionally more checkpoint memory (~32 KiB per checkpoint). +func WithSpan(decompressedBytes int64) Option { + return func(c *config) { c.span = decompressedBytes } +} + +// WithMaxReaders sets the maximum number of concurrent decompression +// streams a [Reader] will run for [Reader.ReadAt]. Calls beyond the cap +// block until a slot is returned. The default is 8. +func WithMaxReaders(n int) Option { + return func(c *config) { c.maxReaders = n } +} diff --git a/ocifs/gzipr/gzipr_test.go b/ocifs/gzipr/gzipr_test.go new file mode 100644 index 0000000..0e0f271 --- /dev/null +++ b/ocifs/gzipr/gzipr_test.go @@ -0,0 +1,528 @@ +package gzipr + +import ( + "bytes" + "compress/gzip" + "crypto/rand" + "errors" + "io" + mrand "math/rand" + "sync" + "sync/atomic" + "testing" +) + +func makeGzip(t *testing.T, payload []byte) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write(payload); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + return buf.Bytes() +} + +// makeGzipFlush writes payload in chunks separated by Flush calls so +// the stream contains many DEFLATE block boundaries — an exercise of +// the checkpoint emission path. +func makeGzipFlush(t *testing.T, payload []byte, chunk int) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + for off := 0; off < len(payload); off += chunk { + end := off + chunk + if end > len(payload) { + end = len(payload) + } + if _, err := gw.Write(payload[off:end]); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Flush(); err != nil { + t.Fatalf("gzip flush: %v", err) + } + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + return buf.Bytes() +} + +func randPayload(n int) []byte { + b := make([]byte, n) + _, _ = rand.Read(b) + return b +} + +func TestScanRoundTrip(t *testing.T) { + payload := randPayload(64 * 1024) + gz := makeGzip(t, payload) + + var out bytes.Buffer + idx, err := Scan(bytes.NewReader(gz), &out) + if err != nil { + t.Fatalf("Scan: %v", err) + } + if !bytes.Equal(out.Bytes(), payload) { + t.Fatal("decompressed output mismatch") + } + if idx.Size != int64(len(payload)) { + t.Fatalf("idx.Size = %d; want %d", idx.Size, len(payload)) + } + if idx.Checkpoints == nil { + t.Fatal("Checkpoints must be non-nil even when empty") + } +} + +func TestScanShortStreamHasEmptyCheckpoints(t *testing.T) { + payload := []byte("hello world") + gz := makeGzip(t, payload) + idx, err := Scan(bytes.NewReader(gz), io.Discard) + if err != nil { + t.Fatalf("Scan: %v", err) + } + if idx.Checkpoints == nil { + t.Fatal("Checkpoints must be non-nil empty slice") + } + if len(idx.Checkpoints) != 0 { + t.Fatalf("len(Checkpoints) = %d; want 0", len(idx.Checkpoints)) + } +} + +func TestScanCheckpointsEmittedAtSpan(t *testing.T) { + payload := randPayload(2 * 1024 * 1024) + gz := makeGzipFlush(t, payload, 64*1024) + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(256*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + if len(idx.Checkpoints) == 0 { + t.Fatal("expected checkpoints with frequent flush + small span") + } + for i := 1; i < len(idx.Checkpoints); i++ { + if idx.Checkpoints[i].Out <= idx.Checkpoints[i-1].Out { + t.Fatalf("checkpoints not Out-sorted at %d", i) + } + if idx.Checkpoints[i].Out-idx.Checkpoints[i-1].Out < 256*1024 { + t.Fatalf("checkpoint span violation at %d: %d -> %d", + i, idx.Checkpoints[i-1].Out, idx.Checkpoints[i].Out) + } + } +} + +func TestScanWriteErrorUnwrapped(t *testing.T) { + payload := randPayload(8 * 1024) + gz := makeGzip(t, payload) + + sentinel := errors.New("boom") + w := &errWriter{err: sentinel, after: 10} + _, err := Scan(bytes.NewReader(gz), w) + if !errors.Is(err, sentinel) { + t.Fatalf("expected sentinel write error, got %v", err) + } + if errors.Is(err, ErrInvalidFormat) { + t.Fatalf("write error must not be wrapped as ErrInvalidFormat: %v", err) + } +} + +func TestScanInvalidFormat(t *testing.T) { + bad := []byte("not gzip at all") + _, err := Scan(bytes.NewReader(bad), io.Discard) + if !errors.Is(err, ErrInvalidFormat) { + t.Fatalf("expected ErrInvalidFormat, got %v", err) + } +} + +func TestScanClosedPipe(t *testing.T) { + payload := randPayload(8 * 1024) + gz := makeGzip(t, payload) + pr, pw := io.Pipe() + pr.CloseWithError(io.ErrClosedPipe) + _, err := Scan(bytes.NewReader(gz), pw) + if !errors.Is(err, io.ErrClosedPipe) { + t.Fatalf("expected io.ErrClosedPipe, got %v", err) + } +} + +type errWriter struct { + w bytes.Buffer + err error + after int +} + +func (e *errWriter) Write(p []byte) (int, error) { + if e.w.Len()+len(p) > e.after { + _, _ = e.w.Write(p) + return len(p), e.err + } + return e.w.Write(p) +} + +func TestReaderRandomAccess(t *testing.T) { + payload := randPayload(2 * 1024 * 1024) + gz := makeGzipFlush(t, payload, 32*1024) + + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(64*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + if len(idx.Checkpoints) < 8 { + t.Fatalf("expected several checkpoints, got %d", len(idx.Checkpoints)) + } + + r := NewReaderWithIndex(bytes.NewReader(gz), idx, int64(len(gz))) + defer r.Close() + + if r.Size() != int64(len(payload)) { + t.Fatalf("Size() = %d; want %d", r.Size(), len(payload)) + } + + rng := mrand.New(mrand.NewSource(0xC0FFEE)) + for i := 0; i < 100; i++ { + off := rng.Int63n(int64(len(payload))) + max := int64(len(payload)) - off + if max > 64*1024 { + max = 64 * 1024 + } + n := rng.Int63n(max) + 1 + got := make([]byte, n) + nn, err := r.ReadAt(got, off) + if err != nil && err != io.EOF { + t.Fatalf("ReadAt(%d, %d): %v", n, off, err) + } + if int64(nn) != n { + t.Fatalf("ReadAt(%d, %d) returned %d", n, off, nn) + } + want := payload[off : off+n] + if !bytes.Equal(got, want) { + t.Fatalf("ReadAt(%d, %d) mismatch", n, off) + } + } +} + +func TestReaderClampedReadEOF(t *testing.T) { + payload := randPayload(100 * 1024) + gz := makeGzipFlush(t, payload, 16*1024) + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(16*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + r := NewReaderWithIndex(bytes.NewReader(gz), idx, int64(len(gz))) + defer r.Close() + + want := payload[len(payload)-100:] + got := make([]byte, 200) + n, err := r.ReadAt(got, int64(len(payload)-100)) + if err != io.EOF { + t.Fatalf("expected io.EOF, got %v", err) + } + if n != 100 { + t.Fatalf("n = %d; want 100", n) + } + if !bytes.Equal(got[:n], want) { + t.Fatal("clamped tail data mismatch") + } +} + +func TestReaderExactBoundaryFullRead(t *testing.T) { + payload := randPayload(100 * 1024) + gz := makeGzipFlush(t, payload, 16*1024) + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(16*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + r := NewReaderWithIndex(bytes.NewReader(gz), idx, int64(len(gz))) + defer r.Close() + + got := make([]byte, 100) + n, err := r.ReadAt(got, int64(len(payload)-100)) + if err != nil { + t.Fatalf("expected nil err, got %v", err) + } + if n != 100 { + t.Fatalf("n = %d; want 100", n) + } +} + +func TestReaderEmpty(t *testing.T) { + payload := randPayload(8 * 1024) + gz := makeGzip(t, payload) + idx, err := Scan(bytes.NewReader(gz), io.Discard) + if err != nil { + t.Fatalf("Scan: %v", err) + } + r := NewReaderWithIndex(bytes.NewReader(gz), idx, int64(len(gz))) + defer r.Close() + + n, err := r.ReadAt(nil, 0) + if err != nil || n != 0 { + t.Fatalf("ReadAt(empty) = (%d, %v); want (0, nil)", n, err) + } + n, err = r.ReadAt(make([]byte, 1), int64(len(payload))) + if err != io.EOF || n != 0 { + t.Fatalf("ReadAt past end = (%d, %v); want (0, io.EOF)", n, err) + } +} + +func TestReaderClosed(t *testing.T) { + payload := randPayload(8 * 1024) + gz := makeGzip(t, payload) + idx, err := Scan(bytes.NewReader(gz), io.Discard) + if err != nil { + t.Fatalf("Scan: %v", err) + } + r := NewReaderWithIndex(bytes.NewReader(gz), idx, int64(len(gz))) + r.Close() + + _, err = r.ReadAt(make([]byte, 8), 0) + if !errors.Is(err, ErrClosed) { + t.Fatalf("expected ErrClosed, got %v", err) + } + // Idempotent close + if err := r.Close(); err != nil { + t.Fatalf("second Close: %v", err) + } +} + +func TestNewReader(t *testing.T) { + payload := randPayload(512 * 1024) + gz := makeGzipFlush(t, payload, 16*1024) + r, err := NewReader(bytes.NewReader(gz), int64(len(gz)), WithSpan(64*1024)) + if err != nil { + t.Fatalf("NewReader: %v", err) + } + defer r.Close() + + if r.Size() != int64(len(payload)) { + t.Fatalf("Size = %d; want %d", r.Size(), len(payload)) + } + got := make([]byte, len(payload)) + n, err := r.ReadAt(got, 0) + if err != nil { + t.Fatalf("ReadAt: %v", err) + } + if n != len(payload) || !bytes.Equal(got, payload) { + t.Fatal("data mismatch") + } +} + +func TestIndexEncodeDecode(t *testing.T) { + payload := randPayload(256 * 1024) + gz := makeGzipFlush(t, payload, 16*1024) + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(64*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + var buf bytes.Buffer + if err := idx.Encode(&buf); err != nil { + t.Fatalf("Encode: %v", err) + } + got, err := DecodeIndex(&buf) + if err != nil { + t.Fatalf("DecodeIndex: %v", err) + } + if got.Size != idx.Size { + t.Fatalf("Size mismatch") + } + if len(got.Checkpoints) != len(idx.Checkpoints) { + t.Fatalf("Checkpoints length mismatch") + } + for i, c := range got.Checkpoints { + o := idx.Checkpoints[i] + if c.In != o.In || c.Out != o.Out || c.B != o.B || c.NB != o.NB || !bytes.Equal(c.Hist, o.Hist) { + t.Fatalf("Checkpoint[%d] mismatch", i) + } + } + + r := NewReaderWithIndex(bytes.NewReader(gz), got, int64(len(gz))) + defer r.Close() + out := make([]byte, 1024) + n, err := r.ReadAt(out, 100*1024) + if err != nil || n != 1024 { + t.Fatalf("ReadAt after decode = (%d, %v)", n, err) + } + if !bytes.Equal(out, payload[100*1024:100*1024+1024]) { + t.Fatal("data mismatch after Encode/Decode round-trip") + } +} + +func TestReaderConcurrent(t *testing.T) { + payload := randPayload(1024 * 1024) + gz := makeGzipFlush(t, payload, 16*1024) + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(32*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + r := NewReaderWithIndex(bytes.NewReader(gz), idx, int64(len(gz))) + defer r.Close() + + const goroutines = 16 + var wg sync.WaitGroup + wg.Add(goroutines) + errCh := make(chan error, goroutines) + for g := 0; g < goroutines; g++ { + seed := int64(g) + go func() { + defer wg.Done() + rng := mrand.New(mrand.NewSource(seed)) + for i := 0; i < 50; i++ { + off := rng.Int63n(int64(len(payload))) + max := int64(len(payload)) - off + if max > 4096 { + max = 4096 + } + n := rng.Int63n(max) + 1 + got := make([]byte, n) + _, err := r.ReadAt(got, off) + if err != nil && err != io.EOF { + errCh <- err + return + } + if !bytes.Equal(got, payload[off:off+n]) { + errCh <- errors.New("data mismatch in goroutine") + return + } + } + }() + } + wg.Wait() + close(errCh) + for e := range errCh { + t.Fatal(e) + } +} + +// countingReaderAt records the byte ranges fetched. +type countingReaderAt struct { + ra io.ReaderAt + mu sync.Mutex + calls int32 + ranges [][2]int64 + bytes int64 +} + +func (c *countingReaderAt) ReadAt(p []byte, off int64) (int, error) { + atomic.AddInt32(&c.calls, 1) + c.mu.Lock() + c.ranges = append(c.ranges, [2]int64{off, int64(len(p))}) + c.bytes += int64(len(p)) + c.mu.Unlock() + return c.ra.ReadAt(p, off) +} + +func TestReaderRandomAccessNoFlush(t *testing.T) { + // Use default gzip (no Flush) with a large repeating-pattern payload + // so the encoder is forced to split into multiple non-byte-aligned + // DEFLATE blocks, exercising the (nb > 0) checkpoint path. + const N = 4 * 1024 * 1024 + payload := make([]byte, N) + for i := range payload { + payload[i] = byte(i*17 + 13) + } + // Make one large random burst in the middle to defeat compression + // homogeneity but keep stream non-trivially compressible. + rng := mrand.New(mrand.NewSource(42)) + for i := N / 4; i < N/2; i++ { + payload[i] = byte(rng.Intn(256)) + } + + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write(payload); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + gz := buf.Bytes() + + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(64*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + t.Logf("emitted %d checkpoints over %d byte payload", len(idx.Checkpoints), N) + + r := NewReaderWithIndex(bytes.NewReader(gz), idx, int64(len(gz))) + defer r.Close() + + rng2 := mrand.New(mrand.NewSource(7)) + for i := 0; i < 100; i++ { + off := rng2.Int63n(int64(N)) + max := int64(N) - off + if max > 64*1024 { + max = 64 * 1024 + } + n := rng2.Int63n(max) + 1 + got := make([]byte, n) + nn, err := r.ReadAt(got, off) + if err != nil && err != io.EOF { + t.Fatalf("ReadAt(%d, %d): %v", n, off, err) + } + if int64(nn) != n { + t.Fatalf("ReadAt(%d, %d) returned %d", n, off, nn) + } + want := payload[off : off+n] + if !bytes.Equal(got, want) { + // Find first diff for a useful error message. + for k := range got { + if got[k] != want[k] { + t.Fatalf("ReadAt(%d, %d) mismatch at +%d: got %02x want %02x", + n, off, k, got[k], want[k]) + } + } + } + } +} + +func TestScanWithFNAMEHeader(t *testing.T) { + payload := randPayload(64 * 1024) + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + gw.Name = "data.bin" + gw.Comment = "a comment" + if _, err := gw.Write(payload); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + gz := buf.Bytes() + + var out bytes.Buffer + idx, err := Scan(bytes.NewReader(gz), &out) + if err != nil { + t.Fatalf("Scan: %v", err) + } + if !bytes.Equal(out.Bytes(), payload) { + t.Fatal("data mismatch with FNAME/FCOMMENT") + } + if idx.Size != int64(len(payload)) { + t.Fatal("Size mismatch") + } +} + +func TestReaderFetchesOnlyNeededSpan(t *testing.T) { + payload := randPayload(512 * 1024) + gz := makeGzipFlush(t, payload, 16*1024) + idx, err := Scan(bytes.NewReader(gz), io.Discard, WithSpan(32*1024)) + if err != nil { + t.Fatalf("Scan: %v", err) + } + if len(idx.Checkpoints) < 4 { + t.Skipf("not enough checkpoints (%d) to verify range narrowing", len(idx.Checkpoints)) + } + cra := &countingReaderAt{ra: bytes.NewReader(gz)} + r := NewReaderWithIndex(cra, idx, int64(len(gz))) + defer r.Close() + + c := idx.Checkpoints[1] + got := make([]byte, 64) + if _, err := r.ReadAt(got, c.Out); err != nil { + t.Fatalf("ReadAt: %v", err) + } + if cra.bytes >= int64(len(gz)) { + t.Fatalf("fetched %d bytes; whole blob is %d — range not narrowed", cra.bytes, len(gz)) + } +} diff --git a/ocifs/gzipr/index.go b/ocifs/gzipr/index.go new file mode 100644 index 0000000..0a6254a --- /dev/null +++ b/ocifs/gzipr/index.go @@ -0,0 +1,20 @@ +package gzipr + +import ( + "encoding/json" + "io" +) + +// Encode serializes the [Index] as JSON to w. +func (idx *Index) Encode(w io.Writer) error { + return json.NewEncoder(w).Encode(idx) +} + +// DecodeIndex deserializes a JSON-encoded [Index] from r. +func DecodeIndex(r io.Reader) (*Index, error) { + var idx Index + if err := json.NewDecoder(r).Decode(&idx); err != nil { + return nil, err + } + return &idx, nil +} diff --git a/ocifs/gzipr/internal/flate/deflate.go b/ocifs/gzipr/internal/flate/deflate.go new file mode 100644 index 0000000..6697f3a --- /dev/null +++ b/ocifs/gzipr/internal/flate/deflate.go @@ -0,0 +1,743 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "errors" + "fmt" + "io" + "math" +) + +const ( + NoCompression = 0 + BestSpeed = 1 + BestCompression = 9 + DefaultCompression = -1 + + // HuffmanOnly disables Lempel-Ziv match searching and only performs Huffman + // entropy encoding. This mode is useful in compressing data that has + // already been compressed with an LZ style algorithm (e.g. Snappy or LZ4) + // that lacks an entropy encoder. Compression gains are achieved when + // certain bytes in the input stream occur more frequently than others. + // + // Note that HuffmanOnly produces a compressed output that is + // RFC 1951 compliant. That is, any valid DEFLATE decompressor will + // continue to be able to decompress this output. + HuffmanOnly = -2 +) + +const ( + logWindowSize = 15 + windowSize = 1 << logWindowSize + windowMask = windowSize - 1 + + // The LZ77 step produces a sequence of literal tokens and + // pair tokens. The offset is also known as distance. The underlying wire + // format limits the range of lengths and offsets. For example, there are + // 256 legitimate lengths: those in the range [3, 258]. This package's + // compressor uses a higher minimum match length, enabling optimizations + // such as finding matches via 32-bit loads and compares. + baseMatchLength = 3 // The smallest match length per the RFC section 3.2.5 + minMatchLength = 4 // The smallest match length that the compressor actually emits + maxMatchLength = 258 // The largest match length + baseMatchOffset = 1 // The smallest match offset + maxMatchOffset = 1 << 15 // The largest match offset + + // The maximum number of tokens we put into a single flate block, just to + // stop things from getting too large. + maxFlateBlockTokens = 1 << 14 + maxStoreBlockSize = 65535 + hashBits = 17 // After 17 performance degrades + hashSize = 1 << hashBits + hashMask = (1 << hashBits) - 1 + maxHashOffset = 1 << 24 + + skipNever = math.MaxInt32 +) + +type compressionLevel struct { + level, good, lazy, nice, chain, fastSkipHashing int +} + +var levels = []compressionLevel{ + {0, 0, 0, 0, 0, 0}, // NoCompression. + {1, 0, 0, 0, 0, 0}, // BestSpeed uses a custom algorithm; see deflatefast.go. + // For levels 2-3 we don't bother trying with lazy matches. + {2, 4, 0, 16, 8, 5}, + {3, 4, 0, 32, 32, 6}, + // Levels 4-9 use increasingly more lazy matching + // and increasingly stringent conditions for "good enough". + {4, 4, 4, 16, 16, skipNever}, + {5, 8, 16, 32, 32, skipNever}, + {6, 8, 16, 128, 128, skipNever}, + {7, 8, 32, 128, 256, skipNever}, + {8, 32, 128, 258, 1024, skipNever}, + {9, 32, 258, 258, 4096, skipNever}, +} + +type compressor struct { + compressionLevel + + w *huffmanBitWriter + bulkHasher func([]byte, []uint32) + + // compression algorithm + fill func(*compressor, []byte) int // copy data to window + step func(*compressor) // process window + bestSpeed *deflateFast // Encoder for BestSpeed + + // Input hash chains + // hashHead[hashValue] contains the largest inputIndex with the specified hash value + // If hashHead[hashValue] is within the current window, then + // hashPrev[hashHead[hashValue] & windowMask] contains the previous index + // with the same hash value. + chainHead int + hashHead [hashSize]uint32 + hashPrev [windowSize]uint32 + hashOffset int + + // input window: unprocessed data is window[index:windowEnd] + index int + window []byte + windowEnd int + blockStart int // window index where current tokens start + byteAvailable bool // if true, still need to process window[index-1]. + + sync bool // requesting flush + + // queued output tokens + tokens []token + + // deflate state + length int + offset int + maxInsertIndex int + err error + + // hashMatch must be able to contain hashes for the maximum match length. + hashMatch [maxMatchLength - 1]uint32 +} + +func (d *compressor) fillDeflate(b []byte) int { + if d.index >= 2*windowSize-(minMatchLength+maxMatchLength) { + // shift the window by windowSize + copy(d.window, d.window[windowSize:2*windowSize]) + d.index -= windowSize + d.windowEnd -= windowSize + if d.blockStart >= windowSize { + d.blockStart -= windowSize + } else { + d.blockStart = math.MaxInt32 + } + d.hashOffset += windowSize + if d.hashOffset > maxHashOffset { + delta := d.hashOffset - 1 + d.hashOffset -= delta + d.chainHead -= delta + + // Iterate over slices instead of arrays to avoid copying + // the entire table onto the stack (Issue #18625). + for i, v := range d.hashPrev[:] { + if int(v) > delta { + d.hashPrev[i] = uint32(int(v) - delta) + } else { + d.hashPrev[i] = 0 + } + } + for i, v := range d.hashHead[:] { + if int(v) > delta { + d.hashHead[i] = uint32(int(v) - delta) + } else { + d.hashHead[i] = 0 + } + } + } + } + n := copy(d.window[d.windowEnd:], b) + d.windowEnd += n + return n +} + +func (d *compressor) writeBlock(tokens []token, index int) error { + if index > 0 { + var window []byte + if d.blockStart <= index { + window = d.window[d.blockStart:index] + } + d.blockStart = index + d.w.writeBlock(tokens, false, window) + return d.w.err + } + return nil +} + +// fillWindow will fill the current window with the supplied +// dictionary and calculate all hashes. +// This is much faster than doing a full encode. +// Should only be used after a reset. +func (d *compressor) fillWindow(b []byte) { + // Do not fill window if we are in store-only mode. + if d.compressionLevel.level < 2 { + return + } + if d.index != 0 || d.windowEnd != 0 { + panic("internal error: fillWindow called with stale data") + } + + // If we are given too much, cut it. + if len(b) > windowSize { + b = b[len(b)-windowSize:] + } + // Add all to window. + n := copy(d.window, b) + + // Calculate 256 hashes at the time (more L1 cache hits) + loops := (n + 256 - minMatchLength) / 256 + for j := 0; j < loops; j++ { + index := j * 256 + end := index + 256 + minMatchLength - 1 + if end > n { + end = n + } + toCheck := d.window[index:end] + dstSize := len(toCheck) - minMatchLength + 1 + + if dstSize <= 0 { + continue + } + + dst := d.hashMatch[:dstSize] + d.bulkHasher(toCheck, dst) + for i, val := range dst { + di := i + index + hh := &d.hashHead[val&hashMask] + // Get previous value with the same hash. + // Our chain should point to the previous value. + d.hashPrev[di&windowMask] = *hh + // Set the head of the hash chain to us. + *hh = uint32(di + d.hashOffset) + } + } + // Update window information. + d.windowEnd = n + d.index = n +} + +// Try to find a match starting at index whose length is greater than prevSize. +// We only look at chainCount possibilities before giving up. +func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) { + minMatchLook := maxMatchLength + if lookahead < minMatchLook { + minMatchLook = lookahead + } + + win := d.window[0 : pos+minMatchLook] + + // We quit when we get a match that's at least nice long + nice := len(win) - pos + if d.nice < nice { + nice = d.nice + } + + // If we've got a match that's good enough, only look in 1/4 the chain. + tries := d.chain + length = prevLength + if length >= d.good { + tries >>= 2 + } + + wEnd := win[pos+length] + wPos := win[pos:] + minIndex := pos - windowSize + + for i := prevHead; tries > 0; tries-- { + if wEnd == win[i+length] { + n := matchLen(win[i:], wPos, minMatchLook) + + if n > length && (n > minMatchLength || pos-i <= 4096) { + length = n + offset = pos - i + ok = true + if n >= nice { + // The match is good enough that we don't try to find a better one. + break + } + wEnd = win[pos+n] + } + } + if i == minIndex { + // hashPrev[i & windowMask] has already been overwritten, so stop now. + break + } + i = int(d.hashPrev[i&windowMask]) - d.hashOffset + if i < minIndex || i < 0 { + break + } + } + return +} + +func (d *compressor) writeStoredBlock(buf []byte) error { + if d.w.writeStoredHeader(len(buf), false); d.w.err != nil { + return d.w.err + } + d.w.writeBytes(buf) + return d.w.err +} + +const hashmul = 0x1e35a7bd + +// hash4 returns a hash representation of the first 4 bytes +// of the supplied slice. +// The caller must ensure that len(b) >= 4. +func hash4(b []byte) uint32 { + return ((uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24) * hashmul) >> (32 - hashBits) +} + +// bulkHash4 will compute hashes using the same +// algorithm as hash4. +func bulkHash4(b []byte, dst []uint32) { + if len(b) < minMatchLength { + return + } + hb := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 + dst[0] = (hb * hashmul) >> (32 - hashBits) + end := len(b) - minMatchLength + 1 + for i := 1; i < end; i++ { + hb = (hb << 8) | uint32(b[i+3]) + dst[i] = (hb * hashmul) >> (32 - hashBits) + } +} + +// matchLen returns the number of matching bytes in a and b +// up to length 'max'. Both slices must be at least 'max' +// bytes in size. +func matchLen(a, b []byte, max int) int { + a = a[:max] + b = b[:len(a)] + for i, av := range a { + if b[i] != av { + return i + } + } + return max +} + +// encSpeed will compress and store the currently added data, +// if enough has been accumulated or we at the end of the stream. +// Any error that occurred will be in d.err +func (d *compressor) encSpeed() { + // We only compress if we have maxStoreBlockSize. + if d.windowEnd < maxStoreBlockSize { + if !d.sync { + return + } + + // Handle small sizes. + if d.windowEnd < 128 { + switch { + case d.windowEnd == 0: + return + case d.windowEnd <= 16: + d.err = d.writeStoredBlock(d.window[:d.windowEnd]) + default: + d.w.writeBlockHuff(false, d.window[:d.windowEnd]) + d.err = d.w.err + } + d.windowEnd = 0 + d.bestSpeed.reset() + return + } + + } + // Encode the block. + d.tokens = d.bestSpeed.encode(d.tokens[:0], d.window[:d.windowEnd]) + + // If we removed less than 1/16th, Huffman compress the block. + if len(d.tokens) > d.windowEnd-(d.windowEnd>>4) { + d.w.writeBlockHuff(false, d.window[:d.windowEnd]) + } else { + d.w.writeBlockDynamic(d.tokens, false, d.window[:d.windowEnd]) + } + d.err = d.w.err + d.windowEnd = 0 +} + +func (d *compressor) initDeflate() { + d.window = make([]byte, 2*windowSize) + d.hashOffset = 1 + d.tokens = make([]token, 0, maxFlateBlockTokens+1) + d.length = minMatchLength - 1 + d.offset = 0 + d.byteAvailable = false + d.index = 0 + d.chainHead = -1 + d.bulkHasher = bulkHash4 +} + +func (d *compressor) deflate() { + if d.windowEnd-d.index < minMatchLength+maxMatchLength && !d.sync { + return + } + + d.maxInsertIndex = d.windowEnd - (minMatchLength - 1) + +Loop: + for { + if d.index > d.windowEnd { + panic("index > windowEnd") + } + lookahead := d.windowEnd - d.index + if lookahead < minMatchLength+maxMatchLength { + if !d.sync { + break Loop + } + if d.index > d.windowEnd { + panic("index > windowEnd") + } + if lookahead == 0 { + // Flush current output block if any. + if d.byteAvailable { + // There is still one pending token that needs to be flushed + d.tokens = append(d.tokens, literalToken(uint32(d.window[d.index-1]))) + d.byteAvailable = false + } + if len(d.tokens) > 0 { + if d.err = d.writeBlock(d.tokens, d.index); d.err != nil { + return + } + d.tokens = d.tokens[:0] + } + break Loop + } + } + if d.index < d.maxInsertIndex { + // Update the hash + hash := hash4(d.window[d.index : d.index+minMatchLength]) + hh := &d.hashHead[hash&hashMask] + d.chainHead = int(*hh) + d.hashPrev[d.index&windowMask] = uint32(d.chainHead) + *hh = uint32(d.index + d.hashOffset) + } + prevLength := d.length + prevOffset := d.offset + d.length = minMatchLength - 1 + d.offset = 0 + minIndex := d.index - windowSize + if minIndex < 0 { + minIndex = 0 + } + + if d.chainHead-d.hashOffset >= minIndex && + (d.fastSkipHashing != skipNever && lookahead > minMatchLength-1 || + d.fastSkipHashing == skipNever && lookahead > prevLength && prevLength < d.lazy) { + if newLength, newOffset, ok := d.findMatch(d.index, d.chainHead-d.hashOffset, minMatchLength-1, lookahead); ok { + d.length = newLength + d.offset = newOffset + } + } + if d.fastSkipHashing != skipNever && d.length >= minMatchLength || + d.fastSkipHashing == skipNever && prevLength >= minMatchLength && d.length <= prevLength { + // There was a match at the previous step, and the current match is + // not better. Output the previous match. + if d.fastSkipHashing != skipNever { + d.tokens = append(d.tokens, matchToken(uint32(d.length-baseMatchLength), uint32(d.offset-baseMatchOffset))) + } else { + d.tokens = append(d.tokens, matchToken(uint32(prevLength-baseMatchLength), uint32(prevOffset-baseMatchOffset))) + } + // Insert in the hash table all strings up to the end of the match. + // index and index-1 are already inserted. If there is not enough + // lookahead, the last two strings are not inserted into the hash + // table. + if d.length <= d.fastSkipHashing { + var newIndex int + if d.fastSkipHashing != skipNever { + newIndex = d.index + d.length + } else { + newIndex = d.index + prevLength - 1 + } + index := d.index + for index++; index < newIndex; index++ { + if index < d.maxInsertIndex { + hash := hash4(d.window[index : index+minMatchLength]) + // Get previous value with the same hash. + // Our chain should point to the previous value. + hh := &d.hashHead[hash&hashMask] + d.hashPrev[index&windowMask] = *hh + // Set the head of the hash chain to us. + *hh = uint32(index + d.hashOffset) + } + } + d.index = index + + if d.fastSkipHashing == skipNever { + d.byteAvailable = false + d.length = minMatchLength - 1 + } + } else { + // For matches this long, we don't bother inserting each individual + // item into the table. + d.index += d.length + } + if len(d.tokens) == maxFlateBlockTokens { + // The block includes the current character + if d.err = d.writeBlock(d.tokens, d.index); d.err != nil { + return + } + d.tokens = d.tokens[:0] + } + } else { + if d.fastSkipHashing != skipNever || d.byteAvailable { + i := d.index - 1 + if d.fastSkipHashing != skipNever { + i = d.index + } + d.tokens = append(d.tokens, literalToken(uint32(d.window[i]))) + if len(d.tokens) == maxFlateBlockTokens { + if d.err = d.writeBlock(d.tokens, i+1); d.err != nil { + return + } + d.tokens = d.tokens[:0] + } + } + d.index++ + if d.fastSkipHashing == skipNever { + d.byteAvailable = true + } + } + } +} + +func (d *compressor) fillStore(b []byte) int { + n := copy(d.window[d.windowEnd:], b) + d.windowEnd += n + return n +} + +func (d *compressor) store() { + if d.windowEnd > 0 && (d.windowEnd == maxStoreBlockSize || d.sync) { + d.err = d.writeStoredBlock(d.window[:d.windowEnd]) + d.windowEnd = 0 + } +} + +// storeHuff compresses and stores the currently added data +// when the d.window is full or we are at the end of the stream. +// Any error that occurred will be in d.err +func (d *compressor) storeHuff() { + if d.windowEnd < len(d.window) && !d.sync || d.windowEnd == 0 { + return + } + d.w.writeBlockHuff(false, d.window[:d.windowEnd]) + d.err = d.w.err + d.windowEnd = 0 +} + +func (d *compressor) write(b []byte) (n int, err error) { + if d.err != nil { + return 0, d.err + } + n = len(b) + for len(b) > 0 { + d.step(d) + b = b[d.fill(d, b):] + if d.err != nil { + return 0, d.err + } + } + return n, nil +} + +func (d *compressor) syncFlush() error { + if d.err != nil { + return d.err + } + d.sync = true + d.step(d) + if d.err == nil { + d.w.writeStoredHeader(0, false) + d.w.flush() + d.err = d.w.err + } + d.sync = false + return d.err +} + +func (d *compressor) init(w io.Writer, level int) (err error) { + d.w = newHuffmanBitWriter(w) + + switch { + case level == NoCompression: + d.window = make([]byte, maxStoreBlockSize) + d.fill = (*compressor).fillStore + d.step = (*compressor).store + case level == HuffmanOnly: + d.window = make([]byte, maxStoreBlockSize) + d.fill = (*compressor).fillStore + d.step = (*compressor).storeHuff + case level == BestSpeed: + d.compressionLevel = levels[level] + d.window = make([]byte, maxStoreBlockSize) + d.fill = (*compressor).fillStore + d.step = (*compressor).encSpeed + d.bestSpeed = newDeflateFast() + d.tokens = make([]token, maxStoreBlockSize) + case level == DefaultCompression: + level = 6 + fallthrough + case 2 <= level && level <= 9: + d.compressionLevel = levels[level] + d.initDeflate() + d.fill = (*compressor).fillDeflate + d.step = (*compressor).deflate + default: + return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level) + } + return nil +} + +func (d *compressor) reset(w io.Writer) { + d.w.reset(w) + d.sync = false + d.err = nil + switch d.compressionLevel.level { + case NoCompression: + d.windowEnd = 0 + case BestSpeed: + d.windowEnd = 0 + d.tokens = d.tokens[:0] + d.bestSpeed.reset() + default: + d.chainHead = -1 + clear(d.hashHead[:]) + clear(d.hashPrev[:]) + d.hashOffset = 1 + d.index, d.windowEnd = 0, 0 + d.blockStart, d.byteAvailable = 0, false + d.tokens = d.tokens[:0] + d.length = minMatchLength - 1 + d.offset = 0 + d.maxInsertIndex = 0 + } +} + +func (d *compressor) close() error { + if d.err == errWriterClosed { + return nil + } + if d.err != nil { + return d.err + } + d.sync = true + d.step(d) + if d.err != nil { + return d.err + } + if d.w.writeStoredHeader(0, true); d.w.err != nil { + return d.w.err + } + d.w.flush() + if d.w.err != nil { + return d.w.err + } + d.err = errWriterClosed + return nil +} + +// NewWriter returns a new [Writer] compressing data at the given level. +// Following zlib, levels range from 1 ([BestSpeed]) to 9 ([BestCompression]); +// higher levels typically run slower but compress more. Level 0 +// ([NoCompression]) does not attempt any compression; it only adds the +// necessary DEFLATE framing. +// Level -1 ([DefaultCompression]) uses the default compression level. +// Level -2 ([HuffmanOnly]) will use Huffman compression only, giving +// a very fast compression for all types of input, but sacrificing considerable +// compression efficiency. +// +// If level is in the range [-2, 9] then the error returned will be nil. +// Otherwise the error returned will be non-nil. +func NewWriter(w io.Writer, level int) (*Writer, error) { + var dw Writer + if err := dw.d.init(w, level); err != nil { + return nil, err + } + return &dw, nil +} + +// NewWriterDict is like [NewWriter] but initializes the new +// [Writer] with a preset dictionary. The returned [Writer] behaves +// as if the dictionary had been written to it without producing +// any compressed output. The compressed data written to w +// can only be decompressed by a reader initialized with the +// same dictionary (see [NewReaderDict]). +func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { + dw := &dictWriter{w} + zw, err := NewWriter(dw, level) + if err != nil { + return nil, err + } + zw.d.fillWindow(dict) + zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method. + return zw, nil +} + +type dictWriter struct { + w io.Writer +} + +func (w *dictWriter) Write(b []byte) (n int, err error) { + return w.w.Write(b) +} + +var errWriterClosed = errors.New("flate: closed writer") + +// A Writer takes data written to it and writes the compressed +// form of that data to an underlying writer (see [NewWriter]). +type Writer struct { + d compressor + dict []byte +} + +// Write writes data to w, which will eventually write the +// compressed form of data to its underlying writer. +func (w *Writer) Write(data []byte) (n int, err error) { + return w.d.write(data) +} + +// Flush flushes any pending data to the underlying writer. +// It is useful mainly in compressed network protocols, to ensure that +// a remote reader has enough data to reconstruct a packet. +// Flush does not return until the data has been written. +// Calling Flush when there is no pending data still causes the [Writer] +// to emit a sync marker of at least 4 bytes. +// If the underlying writer returns an error, Flush returns that error. +// +// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH. +func (w *Writer) Flush() error { + // For more about flushing: + // https://www.bolet.org/~pornin/deflate-flush.html + return w.d.syncFlush() +} + +// Close flushes and closes the writer. +func (w *Writer) Close() error { + return w.d.close() +} + +// Reset discards the writer's state and makes it equivalent to +// the result of [NewWriter] or [NewWriterDict] called with dst +// and w's level and dictionary. +func (w *Writer) Reset(dst io.Writer) { + if dw, ok := w.d.w.writer.(*dictWriter); ok { + // w was created with NewWriterDict + dw.w = dst + w.d.reset(dw) + w.d.fillWindow(w.dict) + } else { + // w was created with NewWriter + w.d.reset(dst) + } +} diff --git a/ocifs/gzipr/internal/flate/deflatefast.go b/ocifs/gzipr/internal/flate/deflatefast.go new file mode 100644 index 0000000..e5554d6 --- /dev/null +++ b/ocifs/gzipr/internal/flate/deflatefast.go @@ -0,0 +1,307 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import "math" + +// This encoding algorithm, which prioritizes speed over output size, is +// based on Snappy's LZ77-style encoder: github.com/golang/snappy + +const ( + tableBits = 14 // Bits used in the table. + tableSize = 1 << tableBits // Size of the table. + tableMask = tableSize - 1 // Mask for table indices. Redundant, but can eliminate bounds checks. + tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32. + + // Reset the buffer offset when reaching this. + // Offsets are stored between blocks as int32 values. + // Since the offset we are checking against is at the beginning + // of the buffer, we need to subtract the current and input + // buffer to not risk overflowing the int32. + bufferReset = math.MaxInt32 - maxStoreBlockSize*2 +) + +func load32(b []byte, i int32) uint32 { + b = b[i : i+4 : len(b)] // Help the compiler eliminate bounds checks on the next line. + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func load64(b []byte, i int32) uint64 { + b = b[i : i+8 : len(b)] // Help the compiler eliminate bounds checks on the next line. + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | + uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 +} + +func hash(u uint32) uint32 { + return (u * 0x1e35a7bd) >> tableShift +} + +// These constants are defined by the Snappy implementation so that its +// assembly implementation can fast-path some 16-bytes-at-a-time copies. They +// aren't necessary in the pure Go implementation, as we don't use those same +// optimizations, but using the same thresholds doesn't really hurt. +const ( + inputMargin = 16 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin +) + +type tableEntry struct { + val uint32 // Value at destination + offset int32 +} + +// deflateFast maintains the table for matches, +// and the previous byte block for cross block matching. +type deflateFast struct { + table [tableSize]tableEntry + prev []byte // Previous block, zero length if unknown. + cur int32 // Current match offset. +} + +func newDeflateFast() *deflateFast { + return &deflateFast{cur: maxStoreBlockSize, prev: make([]byte, 0, maxStoreBlockSize)} +} + +// encode encodes a block given in src and appends tokens +// to dst and returns the result. +func (e *deflateFast) encode(dst []token, src []byte) []token { + // Ensure that e.cur doesn't wrap. + if e.cur >= bufferReset { + e.shiftOffsets() + } + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + e.cur += maxStoreBlockSize + e.prev = e.prev[:0] + return emitLiteral(dst, src) + } + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + nextEmit := int32(0) + s := int32(0) + cv := load32(src, s) + nextHash := hash(cv) + + for { + // Copied from the C++ snappy implementation: + // + // Heuristic match skipping: If 32 bytes are scanned with no matches + // found, start looking only at every other byte. If 32 more bytes are + // scanned (or skipped), look at every third byte, etc.. When a match + // is found, immediately go back to looking at every byte. This is a + // small loss (~5% performance, ~0.1% density) for compressible data + // due to more bookkeeping, but for non-compressible data (such as + // JPEG) it's a huge win since the compressor quickly "realizes" the + // data is incompressible and doesn't bother looking for matches + // everywhere. + // + // The "skip" variable keeps track of how many bytes there are since + // the last match; dividing it by 32 (ie. right-shifting by five) gives + // the number of bytes to move ahead for each iteration. + skip := int32(32) + + nextS := s + var candidate tableEntry + for { + s = nextS + bytesBetweenHashLookups := skip >> 5 + nextS = s + bytesBetweenHashLookups + skip += bytesBetweenHashLookups + if nextS > sLimit { + goto emitRemainder + } + candidate = e.table[nextHash&tableMask] + now := load32(src, nextS) + e.table[nextHash&tableMask] = tableEntry{offset: s + e.cur, val: cv} + nextHash = hash(now) + + offset := s - (candidate.offset - e.cur) + if offset > maxMatchOffset || cv != candidate.val { + // Out of range or not matched. + cv = now + continue + } + break + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + dst = emitLiteral(dst, src[nextEmit:s]) + + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + // + s += 4 + t := candidate.offset - e.cur + 4 + l := e.matchLen(s, t, src) + + // matchToken is flate's equivalent of Snappy's emitCopy. (length,offset) + dst = append(dst, matchToken(uint32(l+4-baseMatchLength), uint32(s-t-baseMatchOffset))) + s += l + nextEmit = s + if s >= sLimit { + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load64(src, s-1) + prevHash := hash(uint32(x)) + e.table[prevHash&tableMask] = tableEntry{offset: e.cur + s - 1, val: uint32(x)} + x >>= 8 + currHash := hash(uint32(x)) + candidate = e.table[currHash&tableMask] + e.table[currHash&tableMask] = tableEntry{offset: e.cur + s, val: uint32(x)} + + offset := s - (candidate.offset - e.cur) + if offset > maxMatchOffset || uint32(x) != candidate.val { + cv = uint32(x >> 8) + nextHash = hash(cv) + s++ + break + } + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + dst = emitLiteral(dst, src[nextEmit:]) + } + e.cur += int32(len(src)) + e.prev = e.prev[:len(src)] + copy(e.prev, src) + return dst +} + +func emitLiteral(dst []token, lit []byte) []token { + for _, v := range lit { + dst = append(dst, literalToken(uint32(v))) + } + return dst +} + +// matchLen returns the match length between src[s:] and src[t:]. +// t can be negative to indicate the match is starting in e.prev. +// We assume that src[s-4:s] and src[t-4:t] already match. +func (e *deflateFast) matchLen(s, t int32, src []byte) int32 { + s1 := int(s) + maxMatchLength - 4 + if s1 > len(src) { + s1 = len(src) + } + + // If we are inside the current block + if t >= 0 { + b := src[t:] + a := src[s:s1] + b = b[:len(a)] + // Extend the match to be as long as possible. + for i := range a { + if a[i] != b[i] { + return int32(i) + } + } + return int32(len(a)) + } + + // We found a match in the previous block. + tp := int32(len(e.prev)) + t + if tp < 0 { + return 0 + } + + // Extend the match to be as long as possible. + a := src[s:s1] + b := e.prev[tp:] + if len(b) > len(a) { + b = b[:len(a)] + } + a = a[:len(b)] + for i := range b { + if a[i] != b[i] { + return int32(i) + } + } + + // If we reached our limit, we matched everything we are + // allowed to in the previous block and we return. + n := int32(len(b)) + if int(s+n) == s1 { + return n + } + + // Continue looking for more matches in the current block. + a = src[s+n : s1] + b = src[:len(a)] + for i := range a { + if a[i] != b[i] { + return int32(i) + n + } + } + return int32(len(a)) + n +} + +// Reset resets the encoding history. +// This ensures that no matches are made to the previous block. +func (e *deflateFast) reset() { + e.prev = e.prev[:0] + // Bump the offset, so all matches will fail distance check. + // Nothing should be >= e.cur in the table. + e.cur += maxMatchOffset + + // Protect against e.cur wraparound. + if e.cur >= bufferReset { + e.shiftOffsets() + } +} + +// shiftOffsets will shift down all match offset. +// This is only called in rare situations to prevent integer overflow. +// +// See https://golang.org/issue/18636 and https://github.com/golang/go/issues/34121. +func (e *deflateFast) shiftOffsets() { + if len(e.prev) == 0 { + // We have no history; just clear the table. + clear(e.table[:]) + e.cur = maxMatchOffset + 1 + return + } + + // Shift down everything in the table that isn't already too far away. + for i := range e.table[:] { + v := e.table[i].offset - e.cur + maxMatchOffset + 1 + if v < 0 { + // We want to reset e.cur to maxMatchOffset + 1, so we need to shift + // all table entries down by (e.cur - (maxMatchOffset + 1)). + // Because we ignore matches > maxMatchOffset, we can cap + // any negative offsets at 0. + v = 0 + } + e.table[i].offset = v + } + e.cur = maxMatchOffset + 1 +} diff --git a/ocifs/gzipr/internal/flate/dict_decoder.go b/ocifs/gzipr/internal/flate/dict_decoder.go new file mode 100644 index 0000000..d2c1904 --- /dev/null +++ b/ocifs/gzipr/internal/flate/dict_decoder.go @@ -0,0 +1,182 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +// dictDecoder implements the LZ77 sliding dictionary as used in decompression. +// LZ77 decompresses data through sequences of two forms of commands: +// +// - Literal insertions: Runs of one or more symbols are inserted into the data +// stream as is. This is accomplished through the writeByte method for a +// single symbol, or combinations of writeSlice/writeMark for multiple symbols. +// Any valid stream must start with a literal insertion if no preset dictionary +// is used. +// +// - Backward copies: Runs of one or more symbols are copied from previously +// emitted data. Backward copies come as the tuple (dist, length) where dist +// determines how far back in the stream to copy from and length determines how +// many bytes to copy. Note that it is valid for the length to be greater than +// the distance. Since LZ77 uses forward copies, that situation is used to +// perform a form of run-length encoding on repeated runs of symbols. +// The writeCopy and tryWriteCopy are used to implement this command. +// +// For performance reasons, this implementation performs little to no sanity +// checks about the arguments. As such, the invariants documented for each +// method call must be respected. +type dictDecoder struct { + hist []byte // Sliding window history + + // Invariant: 0 <= rdPos <= wrPos <= len(hist) + wrPos int // Current output position in buffer + rdPos int // Have emitted hist[:rdPos] already + full bool // Has a full window length been written yet? +} + +// init initializes dictDecoder to have a sliding window dictionary of the given +// size. If a preset dict is provided, it will initialize the dictionary with +// the contents of dict. +func (dd *dictDecoder) init(size int, dict []byte) { + *dd = dictDecoder{hist: dd.hist} + + if cap(dd.hist) < size { + dd.hist = make([]byte, size) + } + dd.hist = dd.hist[:size] + + if len(dict) > len(dd.hist) { + dict = dict[len(dict)-len(dd.hist):] + } + dd.wrPos = copy(dd.hist, dict) + if dd.wrPos == len(dd.hist) { + dd.wrPos = 0 + dd.full = true + } + dd.rdPos = dd.wrPos +} + +// histSize reports the total amount of historical data in the dictionary. +func (dd *dictDecoder) histSize() int { + if dd.full { + return len(dd.hist) + } + return dd.wrPos +} + +// availRead reports the number of bytes that can be flushed by readFlush. +func (dd *dictDecoder) availRead() int { + return dd.wrPos - dd.rdPos +} + +// availWrite reports the available amount of output buffer space. +func (dd *dictDecoder) availWrite() int { + return len(dd.hist) - dd.wrPos +} + +// writeSlice returns a slice of the available buffer to write data to. +// +// This invariant will be kept: len(s) <= availWrite() +func (dd *dictDecoder) writeSlice() []byte { + return dd.hist[dd.wrPos:] +} + +// writeMark advances the writer pointer by cnt. +// +// This invariant must be kept: 0 <= cnt <= availWrite() +func (dd *dictDecoder) writeMark(cnt int) { + dd.wrPos += cnt +} + +// writeByte writes a single byte to the dictionary. +// +// This invariant must be kept: 0 < availWrite() +func (dd *dictDecoder) writeByte(c byte) { + dd.hist[dd.wrPos] = c + dd.wrPos++ +} + +// writeCopy copies a string at a given (dist, length) to the output. +// This returns the number of bytes copied and may be less than the requested +// length if the available space in the output buffer is too small. +// +// This invariant must be kept: 0 < dist <= histSize() +func (dd *dictDecoder) writeCopy(dist, length int) int { + dstBase := dd.wrPos + dstPos := dstBase + srcPos := dstPos - dist + endPos := dstPos + length + if endPos > len(dd.hist) { + endPos = len(dd.hist) + } + + // Copy non-overlapping section after destination position. + // + // This section is non-overlapping in that the copy length for this section + // is always less than or equal to the backwards distance. This can occur + // if a distance refers to data that wraps-around in the buffer. + // Thus, a backwards copy is performed here; that is, the exact bytes in + // the source prior to the copy is placed in the destination. + if srcPos < 0 { + srcPos += len(dd.hist) + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:]) + srcPos = 0 + } + + // Copy possibly overlapping section before destination position. + // + // This section can overlap if the copy length for this section is larger + // than the backwards distance. This is allowed by LZ77 so that repeated + // strings can be succinctly represented using (dist, length) pairs. + // Thus, a forwards copy is performed here; that is, the bytes copied is + // possibly dependent on the resulting bytes in the destination as the copy + // progresses along. This is functionally equivalent to the following: + // + // for i := 0; i < endPos-dstPos; i++ { + // dd.hist[dstPos+i] = dd.hist[srcPos+i] + // } + // dstPos = endPos + // + for dstPos < endPos { + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos]) + } + + dd.wrPos = dstPos + return dstPos - dstBase +} + +// tryWriteCopy tries to copy a string at a given (distance, length) to the +// output. This specialized version is optimized for short distances. +// +// This method is designed to be inlined for performance reasons. +// +// This invariant must be kept: 0 < dist <= histSize() +func (dd *dictDecoder) tryWriteCopy(dist, length int) int { + dstPos := dd.wrPos + endPos := dstPos + length + if dstPos < dist || endPos > len(dd.hist) { + return 0 + } + dstBase := dstPos + srcPos := dstPos - dist + + // Copy possibly overlapping section before destination position. + for dstPos < endPos { + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos]) + } + + dd.wrPos = dstPos + return dstPos - dstBase +} + +// readFlush returns a slice of the historical buffer that is ready to be +// emitted to the user. The data returned by readFlush must be fully consumed +// before calling any other dictDecoder methods. +func (dd *dictDecoder) readFlush() []byte { + toRead := dd.hist[dd.rdPos:dd.wrPos] + dd.rdPos = dd.wrPos + if dd.wrPos == len(dd.hist) { + dd.wrPos, dd.rdPos = 0, 0 + dd.full = true + } + return toRead +} diff --git a/ocifs/gzipr/internal/flate/huffman_bit_writer.go b/ocifs/gzipr/internal/flate/huffman_bit_writer.go new file mode 100644 index 0000000..d68c77f --- /dev/null +++ b/ocifs/gzipr/internal/flate/huffman_bit_writer.go @@ -0,0 +1,693 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "io" +) + +const ( + // The largest offset code. + offsetCodeCount = 30 + + // The special code used to mark the end of a block. + endBlockMarker = 256 + + // The first length code. + lengthCodesStart = 257 + + // The number of codegen codes. + codegenCodeCount = 19 + badCode = 255 + + // bufferFlushSize indicates the buffer size + // after which bytes are flushed to the writer. + // Should preferably be a multiple of 6, since + // we accumulate 6 bytes between writes to the buffer. + bufferFlushSize = 240 + + // bufferSize is the actual output byte buffer size. + // It must have additional headroom for a flush + // which can contain up to 8 bytes. + bufferSize = bufferFlushSize + 8 +) + +// The number of extra bits needed by length code X - LENGTH_CODES_START. +var lengthExtraBits = []int8{ + /* 257 */ 0, 0, 0, + /* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, + /* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, + /* 280 */ 4, 5, 5, 5, 5, 0, +} + +// The length indicated by length code X - LENGTH_CODES_START. +var lengthBase = []uint32{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, + 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, + 64, 80, 96, 112, 128, 160, 192, 224, 255, +} + +// offset code word extra bits. +var offsetExtraBits = []int8{ + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, + 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, +} + +var offsetBase = []uint32{ + 0x000000, 0x000001, 0x000002, 0x000003, 0x000004, + 0x000006, 0x000008, 0x00000c, 0x000010, 0x000018, + 0x000020, 0x000030, 0x000040, 0x000060, 0x000080, + 0x0000c0, 0x000100, 0x000180, 0x000200, 0x000300, + 0x000400, 0x000600, 0x000800, 0x000c00, 0x001000, + 0x001800, 0x002000, 0x003000, 0x004000, 0x006000, +} + +// The odd order in which the codegen code sizes are written. +var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} + +type huffmanBitWriter struct { + // writer is the underlying writer. + // Do not use it directly; use the write method, which ensures + // that Write errors are sticky. + writer io.Writer + + // Data waiting to be written is bytes[0:nbytes] + // and then the low nbits of bits. Data is always written + // sequentially into the bytes array. + bits uint64 + nbits uint + bytes [bufferSize]byte + codegenFreq [codegenCodeCount]int32 + nbytes int + literalFreq []int32 + offsetFreq []int32 + codegen []uint8 + literalEncoding *huffmanEncoder + offsetEncoding *huffmanEncoder + codegenEncoding *huffmanEncoder + err error +} + +func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { + return &huffmanBitWriter{ + writer: w, + literalFreq: make([]int32, maxNumLit), + offsetFreq: make([]int32, offsetCodeCount), + codegen: make([]uint8, maxNumLit+offsetCodeCount+1), + literalEncoding: newHuffmanEncoder(maxNumLit), + codegenEncoding: newHuffmanEncoder(codegenCodeCount), + offsetEncoding: newHuffmanEncoder(offsetCodeCount), + } +} + +func (w *huffmanBitWriter) reset(writer io.Writer) { + w.writer = writer + w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil +} + +func (w *huffmanBitWriter) flush() { + if w.err != nil { + w.nbits = 0 + return + } + n := w.nbytes + for w.nbits != 0 { + w.bytes[n] = byte(w.bits) + w.bits >>= 8 + if w.nbits > 8 { // Avoid underflow + w.nbits -= 8 + } else { + w.nbits = 0 + } + n++ + } + w.bits = 0 + w.write(w.bytes[:n]) + w.nbytes = 0 +} + +func (w *huffmanBitWriter) write(b []byte) { + if w.err != nil { + return + } + _, w.err = w.writer.Write(b) +} + +func (w *huffmanBitWriter) writeBits(b int32, nb uint) { + if w.err != nil { + return + } + w.bits |= uint64(b) << w.nbits + w.nbits += nb + if w.nbits >= 48 { + bits := w.bits + w.bits >>= 48 + w.nbits -= 48 + n := w.nbytes + bytes := w.bytes[n : n+6] + bytes[0] = byte(bits) + bytes[1] = byte(bits >> 8) + bytes[2] = byte(bits >> 16) + bytes[3] = byte(bits >> 24) + bytes[4] = byte(bits >> 32) + bytes[5] = byte(bits >> 40) + n += 6 + if n >= bufferFlushSize { + w.write(w.bytes[:n]) + n = 0 + } + w.nbytes = n + } +} + +func (w *huffmanBitWriter) writeBytes(bytes []byte) { + if w.err != nil { + return + } + n := w.nbytes + if w.nbits&7 != 0 { + w.err = InternalError("writeBytes with unfinished bits") + return + } + for w.nbits != 0 { + w.bytes[n] = byte(w.bits) + w.bits >>= 8 + w.nbits -= 8 + n++ + } + if n != 0 { + w.write(w.bytes[:n]) + } + w.nbytes = 0 + w.write(bytes) +} + +// RFC 1951 3.2.7 specifies a special run-length encoding for specifying +// the literal and offset lengths arrays (which are concatenated into a single +// array). This method generates that run-length encoding. +// +// The result is written into the codegen array, and the frequencies +// of each code is written into the codegenFreq array. +// Codes 0-15 are single byte codes. Codes 16-18 are followed by additional +// information. Code badCode is an end marker +// +// numLiterals The number of literals in literalEncoding +// numOffsets The number of offsets in offsetEncoding +// litenc, offenc The literal and offset encoder to use +func (w *huffmanBitWriter) generateCodegen(numLiterals int, numOffsets int, litEnc, offEnc *huffmanEncoder) { + clear(w.codegenFreq[:]) + // Note that we are using codegen both as a temporary variable for holding + // a copy of the frequencies, and as the place where we put the result. + // This is fine because the output is always shorter than the input used + // so far. + codegen := w.codegen // cache + // Copy the concatenated code sizes to codegen. Put a marker at the end. + cgnl := codegen[:numLiterals] + for i := range cgnl { + cgnl[i] = uint8(litEnc.codes[i].len) + } + + cgnl = codegen[numLiterals : numLiterals+numOffsets] + for i := range cgnl { + cgnl[i] = uint8(offEnc.codes[i].len) + } + codegen[numLiterals+numOffsets] = badCode + + size := codegen[0] + count := 1 + outIndex := 0 + for inIndex := 1; size != badCode; inIndex++ { + // INVARIANT: We have seen "count" copies of size that have not yet + // had output generated for them. + nextSize := codegen[inIndex] + if nextSize == size { + count++ + continue + } + // We need to generate codegen indicating "count" of size. + if size != 0 { + codegen[outIndex] = size + outIndex++ + w.codegenFreq[size]++ + count-- + for count >= 3 { + n := 6 + if n > count { + n = count + } + codegen[outIndex] = 16 + outIndex++ + codegen[outIndex] = uint8(n - 3) + outIndex++ + w.codegenFreq[16]++ + count -= n + } + } else { + for count >= 11 { + n := 138 + if n > count { + n = count + } + codegen[outIndex] = 18 + outIndex++ + codegen[outIndex] = uint8(n - 11) + outIndex++ + w.codegenFreq[18]++ + count -= n + } + if count >= 3 { + // count >= 3 && count <= 10 + codegen[outIndex] = 17 + outIndex++ + codegen[outIndex] = uint8(count - 3) + outIndex++ + w.codegenFreq[17]++ + count = 0 + } + } + count-- + for ; count >= 0; count-- { + codegen[outIndex] = size + outIndex++ + w.codegenFreq[size]++ + } + // Set up invariant for next time through the loop. + size = nextSize + count = 1 + } + // Marker indicating the end of the codegen. + codegen[outIndex] = badCode +} + +// dynamicSize returns the size of dynamically encoded data in bits. +func (w *huffmanBitWriter) dynamicSize(litEnc, offEnc *huffmanEncoder, extraBits int) (size, numCodegens int) { + numCodegens = len(w.codegenFreq) + for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 { + numCodegens-- + } + header := 3 + 5 + 5 + 4 + (3 * numCodegens) + + w.codegenEncoding.bitLength(w.codegenFreq[:]) + + int(w.codegenFreq[16])*2 + + int(w.codegenFreq[17])*3 + + int(w.codegenFreq[18])*7 + size = header + + litEnc.bitLength(w.literalFreq) + + offEnc.bitLength(w.offsetFreq) + + extraBits + + return size, numCodegens +} + +// fixedSize returns the size of dynamically encoded data in bits. +func (w *huffmanBitWriter) fixedSize(extraBits int) int { + return 3 + + fixedLiteralEncoding.bitLength(w.literalFreq) + + fixedOffsetEncoding.bitLength(w.offsetFreq) + + extraBits +} + +// storedSize calculates the stored size, including header. +// The function returns the size in bits and whether the block +// fits inside a single block. +func (w *huffmanBitWriter) storedSize(in []byte) (int, bool) { + if in == nil { + return 0, false + } + if len(in) <= maxStoreBlockSize { + return (len(in) + 5) * 8, true + } + return 0, false +} + +func (w *huffmanBitWriter) writeCode(c hcode) { + if w.err != nil { + return + } + w.bits |= uint64(c.code) << w.nbits + w.nbits += uint(c.len) + if w.nbits >= 48 { + bits := w.bits + w.bits >>= 48 + w.nbits -= 48 + n := w.nbytes + bytes := w.bytes[n : n+6] + bytes[0] = byte(bits) + bytes[1] = byte(bits >> 8) + bytes[2] = byte(bits >> 16) + bytes[3] = byte(bits >> 24) + bytes[4] = byte(bits >> 32) + bytes[5] = byte(bits >> 40) + n += 6 + if n >= bufferFlushSize { + w.write(w.bytes[:n]) + n = 0 + } + w.nbytes = n + } +} + +// Write the header of a dynamic Huffman block to the output stream. +// +// numLiterals The number of literals specified in codegen +// numOffsets The number of offsets specified in codegen +// numCodegens The number of codegens used in codegen +func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) { + if w.err != nil { + return + } + var firstBits int32 = 4 + if isEof { + firstBits = 5 + } + w.writeBits(firstBits, 3) + w.writeBits(int32(numLiterals-257), 5) + w.writeBits(int32(numOffsets-1), 5) + w.writeBits(int32(numCodegens-4), 4) + + for i := 0; i < numCodegens; i++ { + value := uint(w.codegenEncoding.codes[codegenOrder[i]].len) + w.writeBits(int32(value), 3) + } + + i := 0 + for { + var codeWord int = int(w.codegen[i]) + i++ + if codeWord == badCode { + break + } + w.writeCode(w.codegenEncoding.codes[uint32(codeWord)]) + + switch codeWord { + case 16: + w.writeBits(int32(w.codegen[i]), 2) + i++ + case 17: + w.writeBits(int32(w.codegen[i]), 3) + i++ + case 18: + w.writeBits(int32(w.codegen[i]), 7) + i++ + } + } +} + +func (w *huffmanBitWriter) writeStoredHeader(length int, isEof bool) { + if w.err != nil { + return + } + var flag int32 + if isEof { + flag = 1 + } + w.writeBits(flag, 3) + w.flush() + w.writeBits(int32(length), 16) + w.writeBits(int32(^uint16(length)), 16) +} + +func (w *huffmanBitWriter) writeFixedHeader(isEof bool) { + if w.err != nil { + return + } + // Indicate that we are a fixed Huffman block + var value int32 = 2 + if isEof { + value = 3 + } + w.writeBits(value, 3) +} + +// writeBlock will write a block of tokens with the smallest encoding. +// The original input can be supplied, and if the huffman encoded data +// is larger than the original bytes, the data will be written as a +// stored block. +// If the input is nil, the tokens will always be Huffman encoded. +func (w *huffmanBitWriter) writeBlock(tokens []token, eof bool, input []byte) { + if w.err != nil { + return + } + + tokens = append(tokens, endBlockMarker) + numLiterals, numOffsets := w.indexTokens(tokens) + + var extraBits int + storedSize, storable := w.storedSize(input) + if storable { + // We only bother calculating the costs of the extra bits required by + // the length of offset fields (which will be the same for both fixed + // and dynamic encoding), if we need to compare those two encodings + // against stored encoding. + for lengthCode := lengthCodesStart + 8; lengthCode < numLiterals; lengthCode++ { + // First eight length codes have extra size = 0. + extraBits += int(w.literalFreq[lengthCode]) * int(lengthExtraBits[lengthCode-lengthCodesStart]) + } + for offsetCode := 4; offsetCode < numOffsets; offsetCode++ { + // First four offset codes have extra size = 0. + extraBits += int(w.offsetFreq[offsetCode]) * int(offsetExtraBits[offsetCode]) + } + } + + // Figure out smallest code. + // Fixed Huffman baseline. + var literalEncoding = fixedLiteralEncoding + var offsetEncoding = fixedOffsetEncoding + var size = w.fixedSize(extraBits) + + // Dynamic Huffman? + var numCodegens int + + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + dynamicSize, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, extraBits) + + if dynamicSize < size { + size = dynamicSize + literalEncoding = w.literalEncoding + offsetEncoding = w.offsetEncoding + } + + // Stored bytes? + if storable && storedSize < size { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + return + } + + // Huffman. + if literalEncoding == fixedLiteralEncoding { + w.writeFixedHeader(eof) + } else { + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + } + + // Write the tokens. + w.writeTokens(tokens, literalEncoding.codes, offsetEncoding.codes) +} + +// writeBlockDynamic encodes a block using a dynamic Huffman table. +// This should be used if the symbols used have a disproportionate +// histogram distribution. +// If input is supplied and the compression savings are below 1/16th of the +// input size the block is stored. +func (w *huffmanBitWriter) writeBlockDynamic(tokens []token, eof bool, input []byte) { + if w.err != nil { + return + } + + tokens = append(tokens, endBlockMarker) + numLiterals, numOffsets := w.indexTokens(tokens) + + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + size, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, 0) + + // Store bytes, if we don't get a reasonable improvement. + if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + return + } + + // Write Huffman table. + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + + // Write the tokens. + w.writeTokens(tokens, w.literalEncoding.codes, w.offsetEncoding.codes) +} + +// indexTokens indexes a slice of tokens, and updates +// literalFreq and offsetFreq, and generates literalEncoding +// and offsetEncoding. +// The number of literal and offset tokens is returned. +func (w *huffmanBitWriter) indexTokens(tokens []token) (numLiterals, numOffsets int) { + clear(w.literalFreq) + clear(w.offsetFreq) + + for _, t := range tokens { + if t < matchType { + w.literalFreq[t.literal()]++ + continue + } + length := t.length() + offset := t.offset() + w.literalFreq[lengthCodesStart+lengthCode(length)]++ + w.offsetFreq[offsetCode(offset)]++ + } + + // get the number of literals + numLiterals = len(w.literalFreq) + for w.literalFreq[numLiterals-1] == 0 { + numLiterals-- + } + // get the number of offsets + numOffsets = len(w.offsetFreq) + for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 { + numOffsets-- + } + if numOffsets == 0 { + // We haven't found a single match. If we want to go with the dynamic encoding, + // we should count at least one offset to be sure that the offset huffman tree could be encoded. + w.offsetFreq[0] = 1 + numOffsets = 1 + } + w.literalEncoding.generate(w.literalFreq, 15) + w.offsetEncoding.generate(w.offsetFreq, 15) + return +} + +// writeTokens writes a slice of tokens to the output. +// codes for literal and offset encoding must be supplied. +func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) { + if w.err != nil { + return + } + for _, t := range tokens { + if t < matchType { + w.writeCode(leCodes[t.literal()]) + continue + } + // Write the length + length := t.length() + lengthCode := lengthCode(length) + w.writeCode(leCodes[lengthCode+lengthCodesStart]) + extraLengthBits := uint(lengthExtraBits[lengthCode]) + if extraLengthBits > 0 { + extraLength := int32(length - lengthBase[lengthCode]) + w.writeBits(extraLength, extraLengthBits) + } + // Write the offset + offset := t.offset() + offsetCode := offsetCode(offset) + w.writeCode(oeCodes[offsetCode]) + extraOffsetBits := uint(offsetExtraBits[offsetCode]) + if extraOffsetBits > 0 { + extraOffset := int32(offset - offsetBase[offsetCode]) + w.writeBits(extraOffset, extraOffsetBits) + } + } +} + +// huffOffset is a static offset encoder used for huffman only encoding. +// It can be reused since we will not be encoding offset values. +var huffOffset *huffmanEncoder + +func init() { + offsetFreq := make([]int32, offsetCodeCount) + offsetFreq[0] = 1 + huffOffset = newHuffmanEncoder(offsetCodeCount) + huffOffset.generate(offsetFreq, 15) +} + +// writeBlockHuff encodes a block of bytes as either +// Huffman encoded literals or uncompressed bytes if the +// results only gains very little from compression. +func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte) { + if w.err != nil { + return + } + + // Clear histogram + clear(w.literalFreq) + + // Add everything as literals + histogram(input, w.literalFreq) + + w.literalFreq[endBlockMarker] = 1 + + const numLiterals = endBlockMarker + 1 + w.offsetFreq[0] = 1 + const numOffsets = 1 + + w.literalEncoding.generate(w.literalFreq, 15) + + // Figure out smallest code. + // Always use dynamic Huffman or Store + var numCodegens int + + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, huffOffset) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + size, numCodegens := w.dynamicSize(w.literalEncoding, huffOffset, 0) + + // Store bytes, if we don't get a reasonable improvement. + if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + return + } + + // Huffman. + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + encoding := w.literalEncoding.codes[:257] + n := w.nbytes + for _, t := range input { + // Bitwriting inlined, ~30% speedup + c := encoding[t] + w.bits |= uint64(c.code) << w.nbits + w.nbits += uint(c.len) + if w.nbits < 48 { + continue + } + // Store 6 bytes + bits := w.bits + w.bits >>= 48 + w.nbits -= 48 + bytes := w.bytes[n : n+6] + bytes[0] = byte(bits) + bytes[1] = byte(bits >> 8) + bytes[2] = byte(bits >> 16) + bytes[3] = byte(bits >> 24) + bytes[4] = byte(bits >> 32) + bytes[5] = byte(bits >> 40) + n += 6 + if n < bufferFlushSize { + continue + } + w.write(w.bytes[:n]) + if w.err != nil { + return // Return early in the event of write failures + } + n = 0 + } + w.nbytes = n + w.writeCode(encoding[endBlockMarker]) +} + +// histogram accumulates a histogram of b in h. +// +// len(h) must be >= 256, and h's elements must be all zeroes. +func histogram(b []byte, h []int32) { + h = h[:256] + for _, t := range b { + h[t]++ + } +} diff --git a/ocifs/gzipr/internal/flate/huffman_code.go b/ocifs/gzipr/internal/flate/huffman_code.go new file mode 100644 index 0000000..6f69cab --- /dev/null +++ b/ocifs/gzipr/internal/flate/huffman_code.go @@ -0,0 +1,345 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "math" + "math/bits" + "sort" +) + +// hcode is a huffman code with a bit code and bit length. +type hcode struct { + code, len uint16 +} + +type huffmanEncoder struct { + codes []hcode + freqcache []literalNode + bitCount [17]int32 + lns byLiteral // stored to avoid repeated allocation in generate + lfs byFreq // stored to avoid repeated allocation in generate +} + +type literalNode struct { + literal uint16 + freq int32 +} + +// A levelInfo describes the state of the constructed tree for a given depth. +type levelInfo struct { + // Our level. for better printing + level int32 + + // The frequency of the last node at this level + lastFreq int32 + + // The frequency of the next character to add to this level + nextCharFreq int32 + + // The frequency of the next pair (from level below) to add to this level. + // Only valid if the "needed" value of the next lower level is 0. + nextPairFreq int32 + + // The number of chains remaining to generate for this level before moving + // up to the next level + needed int32 +} + +// set sets the code and length of an hcode. +func (h *hcode) set(code uint16, length uint16) { + h.len = length + h.code = code +} + +func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxInt32} } + +func newHuffmanEncoder(size int) *huffmanEncoder { + return &huffmanEncoder{codes: make([]hcode, size)} +} + +// Generates a HuffmanCode corresponding to the fixed literal table. +func generateFixedLiteralEncoding() *huffmanEncoder { + h := newHuffmanEncoder(maxNumLit) + codes := h.codes + var ch uint16 + for ch = 0; ch < maxNumLit; ch++ { + var bits uint16 + var size uint16 + switch { + case ch < 144: + // size 8, 000110000 .. 10111111 + bits = ch + 48 + size = 8 + case ch < 256: + // size 9, 110010000 .. 111111111 + bits = ch + 400 - 144 + size = 9 + case ch < 280: + // size 7, 0000000 .. 0010111 + bits = ch - 256 + size = 7 + default: + // size 8, 11000000 .. 11000111 + bits = ch + 192 - 280 + size = 8 + } + codes[ch] = hcode{code: reverseBits(bits, byte(size)), len: size} + } + return h +} + +func generateFixedOffsetEncoding() *huffmanEncoder { + h := newHuffmanEncoder(30) + codes := h.codes + for ch := range codes { + codes[ch] = hcode{code: reverseBits(uint16(ch), 5), len: 5} + } + return h +} + +var fixedLiteralEncoding *huffmanEncoder = generateFixedLiteralEncoding() +var fixedOffsetEncoding *huffmanEncoder = generateFixedOffsetEncoding() + +func (h *huffmanEncoder) bitLength(freq []int32) int { + var total int + for i, f := range freq { + if f != 0 { + total += int(f) * int(h.codes[i].len) + } + } + return total +} + +const maxBitsLimit = 16 + +// bitCounts computes the number of literals assigned to each bit size in the Huffman encoding. +// It is only called when list.length >= 3. +// The cases of 0, 1, and 2 literals are handled by special case code. +// +// list is an array of the literals with non-zero frequencies +// and their associated frequencies. The array is in order of increasing +// frequency and has as its last element a special element with frequency +// MaxInt32. +// +// maxBits is the maximum number of bits that should be used to encode any literal. +// It must be less than 16. +// +// bitCounts returns an integer slice in which slice[i] indicates the number of literals +// that should be encoded in i bits. +func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 { + if maxBits >= maxBitsLimit { + panic("flate: maxBits too large") + } + n := int32(len(list)) + list = list[0 : n+1] + list[n] = maxNode() + + // The tree can't have greater depth than n - 1, no matter what. This + // saves a little bit of work in some small cases + if maxBits > n-1 { + maxBits = n - 1 + } + + // Create information about each of the levels. + // A bogus "Level 0" whose sole purpose is so that + // level1.prev.needed==0. This makes level1.nextPairFreq + // be a legitimate value that never gets chosen. + var levels [maxBitsLimit]levelInfo + // leafCounts[i] counts the number of literals at the left + // of ancestors of the rightmost node at level i. + // leafCounts[i][j] is the number of literals at the left + // of the level j ancestor. + var leafCounts [maxBitsLimit][maxBitsLimit]int32 + + for level := int32(1); level <= maxBits; level++ { + // For every level, the first two items are the first two characters. + // We initialize the levels as if we had already figured this out. + levels[level] = levelInfo{ + level: level, + lastFreq: list[1].freq, + nextCharFreq: list[2].freq, + nextPairFreq: list[0].freq + list[1].freq, + } + leafCounts[level][level] = 2 + if level == 1 { + levels[level].nextPairFreq = math.MaxInt32 + } + } + + // We need a total of 2*n - 2 items at top level and have already generated 2. + levels[maxBits].needed = 2*n - 4 + + level := maxBits + for { + l := &levels[level] + if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 { + // We've run out of both leaves and pairs. + // End all calculations for this level. + // To make sure we never come back to this level or any lower level, + // set nextPairFreq impossibly large. + l.needed = 0 + levels[level+1].nextPairFreq = math.MaxInt32 + level++ + continue + } + + prevFreq := l.lastFreq + if l.nextCharFreq < l.nextPairFreq { + // The next item on this row is a leaf node. + n := leafCounts[level][level] + 1 + l.lastFreq = l.nextCharFreq + // Lower leafCounts are the same of the previous node. + leafCounts[level][level] = n + l.nextCharFreq = list[n].freq + } else { + // The next item on this row is a pair from the previous row. + // nextPairFreq isn't valid until we generate two + // more values in the level below + l.lastFreq = l.nextPairFreq + // Take leaf counts from the lower level, except counts[level] remains the same. + copy(leafCounts[level][:level], leafCounts[level-1][:level]) + levels[l.level-1].needed = 2 + } + + if l.needed--; l.needed == 0 { + // We've done everything we need to do for this level. + // Continue calculating one level up. Fill in nextPairFreq + // of that level with the sum of the two nodes we've just calculated on + // this level. + if l.level == maxBits { + // All done! + break + } + levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq + level++ + } else { + // If we stole from below, move down temporarily to replenish it. + for levels[level-1].needed > 0 { + level-- + } + } + } + + // Somethings is wrong if at the end, the top level is null or hasn't used + // all of the leaves. + if leafCounts[maxBits][maxBits] != n { + panic("leafCounts[maxBits][maxBits] != n") + } + + bitCount := h.bitCount[:maxBits+1] + bits := 1 + counts := &leafCounts[maxBits] + for level := maxBits; level > 0; level-- { + // chain.leafCount gives the number of literals requiring at least "bits" + // bits to encode. + bitCount[bits] = counts[level] - counts[level-1] + bits++ + } + return bitCount +} + +// Look at the leaves and assign them a bit count and an encoding as specified +// in RFC 1951 3.2.2 +func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) { + code := uint16(0) + for n, bits := range bitCount { + code <<= 1 + if n == 0 || bits == 0 { + continue + } + // The literals list[len(list)-bits] .. list[len(list)-bits] + // are encoded using "bits" bits, and get the values + // code, code + 1, .... The code values are + // assigned in literal order (not frequency order). + chunk := list[len(list)-int(bits):] + + h.lns.sort(chunk) + for _, node := range chunk { + h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)} + code++ + } + list = list[0 : len(list)-int(bits)] + } +} + +// Update this Huffman Code object to be the minimum code for the specified frequency count. +// +// freq is an array of frequencies, in which freq[i] gives the frequency of literal i. +// maxBits The maximum number of bits to use for any literal. +func (h *huffmanEncoder) generate(freq []int32, maxBits int32) { + if h.freqcache == nil { + // Allocate a reusable buffer with the longest possible frequency table. + // Possible lengths are codegenCodeCount, offsetCodeCount and maxNumLit. + // The largest of these is maxNumLit, so we allocate for that case. + h.freqcache = make([]literalNode, maxNumLit+1) + } + list := h.freqcache[:len(freq)+1] + // Number of non-zero literals + count := 0 + // Set list to be the set of all non-zero literals and their frequencies + for i, f := range freq { + if f != 0 { + list[count] = literalNode{uint16(i), f} + count++ + } else { + h.codes[i].len = 0 + } + } + + list = list[:count] + if count <= 2 { + // Handle the small cases here, because they are awkward for the general case code. With + // two or fewer literals, everything has bit length 1. + for i, node := range list { + // "list" is in order of increasing literal value. + h.codes[node.literal].set(uint16(i), 1) + } + return + } + h.lfs.sort(list) + + // Get the number of literals for each bit count + bitCount := h.bitCounts(list, maxBits) + // And do the assignment + h.assignEncodingAndSize(bitCount, list) +} + +type byLiteral []literalNode + +func (s *byLiteral) sort(a []literalNode) { + *s = byLiteral(a) + sort.Sort(s) +} + +func (s byLiteral) Len() int { return len(s) } + +func (s byLiteral) Less(i, j int) bool { + return s[i].literal < s[j].literal +} + +func (s byLiteral) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +type byFreq []literalNode + +func (s *byFreq) sort(a []literalNode) { + *s = byFreq(a) + sort.Sort(s) +} + +func (s byFreq) Len() int { return len(s) } + +func (s byFreq) Less(i, j int) bool { + if s[i].freq == s[j].freq { + return s[i].literal < s[j].literal + } + return s[i].freq < s[j].freq +} + +func (s byFreq) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func reverseBits(number uint16, bitLength byte) uint16 { + return bits.Reverse16(number << (16 - bitLength)) +} diff --git a/ocifs/gzipr/internal/flate/inflate.go b/ocifs/gzipr/internal/flate/inflate.go new file mode 100644 index 0000000..42c28e3 --- /dev/null +++ b/ocifs/gzipr/internal/flate/inflate.go @@ -0,0 +1,923 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file has been modified from the original Go standard library +// compress/flate. Modifications add DEFLATE block-boundary checkpoint +// support so callers can resume decompression at any block boundary. +// The checkpoint mechanism follows the design used by +// jonjohnsonjr/targz (MIT). + +// Package flate implements the DEFLATE compressed data format, described in +// RFC 1951. The [compress/gzip] and [compress/zlib] packages implement access +// to DEFLATE-based file formats. +package flate + +import ( + "bufio" + "io" + "math/bits" + "strconv" + "sync" +) + +const ( + maxCodeLen = 16 // max length of Huffman code + // The next three numbers come from the RFC section 3.2.7, with the + // additional proviso in section 3.2.5 which implies that distance codes + // 30 and 31 should never occur in compressed data. + maxNumLit = 286 + maxNumDist = 30 + numCodes = 19 // number of codes in Huffman meta-code +) + +// Initialize the fixedHuffmanDecoder only once upon first use. +var fixedOnce sync.Once +var fixedHuffmanDecoder huffmanDecoder + +// A CorruptInputError reports the presence of corrupt input at a given offset. +type CorruptInputError int64 + +func (e CorruptInputError) Error() string { + return "flate: corrupt input before offset " + strconv.FormatInt(int64(e), 10) +} + +// An InternalError reports an error in the flate code itself. +type InternalError string + +func (e InternalError) Error() string { return "flate: internal error: " + string(e) } + +// A ReadError reports an error encountered while reading input. +// +// Deprecated: No longer returned. +type ReadError struct { + Offset int64 // byte offset where error occurred + Err error // error returned by underlying Read +} + +func (e *ReadError) Error() string { + return "flate: read error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error() +} + +// A WriteError reports an error encountered while writing output. +// +// Deprecated: No longer returned. +type WriteError struct { + Offset int64 // byte offset where error occurred + Err error // error returned by underlying Write +} + +func (e *WriteError) Error() string { + return "flate: write error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error() +} + +// Resetter resets a ReadCloser returned by [NewReader] or [NewReaderDict] +// to switch to a new underlying [Reader]. This permits reusing a ReadCloser +// instead of allocating a new one. +type Resetter interface { + // Reset discards any buffered data and resets the Resetter as if it was + // newly initialized with the given reader. + Reset(r io.Reader, dict []byte) error +} + +// The data structure for decoding Huffman tables is based on that of +// zlib. There is a lookup table of a fixed bit width (huffmanChunkBits), +// For codes smaller than the table width, there are multiple entries +// (each combination of trailing bits has the same value). For codes +// larger than the table width, the table contains a link to an overflow +// table. The width of each entry in the link table is the maximum code +// size minus the chunk width. +// +// Note that you can do a lookup in the table even without all bits +// filled. Since the extra bits are zero, and the DEFLATE Huffman codes +// have the property that shorter codes come before longer ones, the +// bit length estimate in the result is a lower bound on the actual +// number of bits. +// +// See the following: +// https://github.com/madler/zlib/raw/master/doc/algorithm.txt + +// chunk & 15 is number of bits +// chunk >> 4 is value, including table link + +const ( + huffmanChunkBits = 9 + huffmanNumChunks = 1 << huffmanChunkBits + huffmanCountMask = 15 + huffmanValueShift = 4 +) + +type huffmanDecoder struct { + min int // the minimum code length + chunks [huffmanNumChunks]uint32 // chunks as described above + links [][]uint32 // overflow links + linkMask uint32 // mask the width of the link table +} + +// Initialize Huffman decoding tables from array of code lengths. +// Following this function, h is guaranteed to be initialized into a complete +// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a +// degenerate case where the tree has only a single symbol with length 1. Empty +// trees are permitted. +func (h *huffmanDecoder) init(lengths []int) bool { + // Sanity enables additional runtime tests during Huffman + // table construction. It's intended to be used during + // development to supplement the currently ad-hoc unit tests. + const sanity = false + + if h.min != 0 { + *h = huffmanDecoder{} + } + + // Count number of codes of each length, + // compute min and max length. + var count [maxCodeLen]int + var min, max int + for _, n := range lengths { + if n == 0 { + continue + } + if min == 0 || n < min { + min = n + } + if n > max { + max = n + } + count[n]++ + } + + // Empty tree. The decompressor.huffSym function will fail later if the tree + // is used. Technically, an empty tree is only valid for the HDIST tree and + // not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree + // is guaranteed to fail since it will attempt to use the tree to decode the + // codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is + // guaranteed to fail later since the compressed data section must be + // composed of at least one symbol (the end-of-block marker). + if max == 0 { + return true + } + + code := 0 + var nextcode [maxCodeLen]int + for i := min; i <= max; i++ { + code <<= 1 + nextcode[i] = code + code += count[i] + } + + // Check that the coding is complete (i.e., that we've + // assigned all 2-to-the-max possible bit sequences). + // Exception: To be compatible with zlib, we also need to + // accept degenerate single-code codings. See also + // TestDegenerateHuffmanCoding. + if code != 1< huffmanChunkBits { + numLinks := 1 << (uint(max) - huffmanChunkBits) + h.linkMask = uint32(numLinks - 1) + + // create link tables + link := nextcode[huffmanChunkBits+1] >> 1 + h.links = make([][]uint32, huffmanNumChunks-link) + for j := uint(link); j < huffmanNumChunks; j++ { + reverse := int(bits.Reverse16(uint16(j))) + reverse >>= uint(16 - huffmanChunkBits) + off := j - uint(link) + if sanity && h.chunks[reverse] != 0 { + panic("impossible: overwriting existing chunk") + } + h.chunks[reverse] = uint32(off<>= uint(16 - n) + if n <= huffmanChunkBits { + for off := reverse; off < len(h.chunks); off += 1 << uint(n) { + // We should never need to overwrite + // an existing chunk. Also, 0 is + // never a valid chunk, because the + // lower 4 "count" bits should be + // between 1 and 15. + if sanity && h.chunks[off] != 0 { + panic("impossible: overwriting existing chunk") + } + h.chunks[off] = chunk + } + } else { + j := reverse & (huffmanNumChunks - 1) + if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 { + // Longer codes should have been + // associated with a link table above. + panic("impossible: not an indirect chunk") + } + value := h.chunks[j] >> huffmanValueShift + linktab := h.links[value] + reverse >>= huffmanChunkBits + for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) { + if sanity && linktab[off] != 0 { + panic("impossible: overwriting existing chunk") + } + linktab[off] = chunk + } + } + } + + if sanity { + // Above we've sanity checked that we never overwrote + // an existing entry. Here we additionally check that + // we filled the tables completely. + for i, chunk := range h.chunks { + if chunk == 0 { + // As an exception, in the degenerate + // single-code case, we allow odd + // chunks to be missing. + if code == 1 && i%2 == 1 { + continue + } + panic("impossible: missing chunk") + } + } + for _, linktab := range h.links { + for _, chunk := range linktab { + if chunk == 0 { + panic("impossible: missing chunk") + } + } + } + } + + return true +} + +// The actual read interface needed by [NewReader]. +// If the passed in [io.Reader] does not also have ReadByte, +// the [NewReader] will introduce its own buffering. +type Reader interface { + io.Reader + io.ByteReader +} + +// Decompress state. +type decompressor struct { + // Input source. + r Reader + rBuf *bufio.Reader // created if provided io.Reader does not implement io.ByteReader + roffset int64 + + // Input bits, in top of b. + b uint32 + nb uint + + // Huffman decoders for literal/length, distance. + h1, h2 huffmanDecoder + + // Length arrays used to define Huffman codes. + bits *[maxNumLit + maxNumDist]int + codebits *[numCodes]int + + // Output history, buffer. + dict dictDecoder + + // Temporary buffer (avoids repeated allocation). + buf [4]byte + + // Next step in the decompression, + // and decompression state. + step func(*decompressor) + stepState int + final bool + err error + toRead []byte + hl, hd *huffmanDecoder + copyLen int + copyDist int + + // Checkpoint support (extension over stdlib flate). + // + // checkpoint, if non-nil, is invoked at every DEFLATE block boundary + // immediately before the next block's header bits are read. The + // arguments capture sufficient state to resume decompression from + // that exact boundary. + // + // out is the running total of decompressed bytes WRITTEN to the + // sliding-window dictionary so far. It includes bytes that have been + // produced by the decompressor but not yet delivered to the + // consumer; conceptually it is the consumer's read offset after a + // hypothetical full drain at this point. The counter is advanced + // synchronously with every dict write (writeByte, writeCopy, and + // writeMark) so that checkpoints record the correct decompressed + // offset for resume. + checkpoint func(in, out int64, b uint32, nb uint, hist []byte) + out int64 +} + +func (f *decompressor) nextBlock() { + if f.checkpoint != nil { + // in is the offset of the NEXT byte to read from the + // underlying reader. The bits already loaded into f.b come + // from earlier bytes; on resume the caller positions the + // reader at in and explicitly pre-loads f.b/f.nb, so the + // already-buffered bits are not re-read. + in := f.roffset + var histCopy []byte + if f.dict.full { + histCopy = make([]byte, len(f.dict.hist)) + copy(histCopy, f.dict.hist[f.dict.wrPos:]) + copy(histCopy[len(f.dict.hist)-f.dict.wrPos:], f.dict.hist[:f.dict.wrPos]) + } else { + histCopy = make([]byte, f.dict.wrPos) + copy(histCopy, f.dict.hist[:f.dict.wrPos]) + } + f.checkpoint(in, f.out, f.b, f.nb, histCopy) + } + for f.nb < 1+2 { + if f.err = f.moreBits(); f.err != nil { + return + } + } + f.final = f.b&1 == 1 + f.b >>= 1 + typ := f.b & 3 + f.b >>= 2 + f.nb -= 1 + 2 + switch typ { + case 0: + f.dataBlock() + case 1: + // compressed, fixed Huffman tables + f.hl = &fixedHuffmanDecoder + f.hd = nil + f.huffmanBlock() + case 2: + // compressed, dynamic Huffman tables + if f.err = f.readHuffman(); f.err != nil { + break + } + f.hl = &f.h1 + f.hd = &f.h2 + f.huffmanBlock() + default: + // 3 is reserved. + f.err = CorruptInputError(f.roffset) + } +} + +func (f *decompressor) Read(b []byte) (int, error) { + for { + if len(f.toRead) > 0 { + n := copy(b, f.toRead) + f.toRead = f.toRead[n:] + if len(f.toRead) == 0 { + return n, f.err + } + return n, nil + } + if f.err != nil { + return 0, f.err + } + f.step(f) + if f.err != nil && len(f.toRead) == 0 { + f.toRead = f.dict.readFlush() // Flush what's left in case of error + } + } +} + +func (f *decompressor) Close() error { + if f.err == io.EOF { + return nil + } + return f.err +} + +// RFC 1951 section 3.2.7. +// Compression with dynamic Huffman codes + +var codeOrder = [...]int{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} + +func (f *decompressor) readHuffman() error { + // HLIT[5], HDIST[5], HCLEN[4]. + for f.nb < 5+5+4 { + if err := f.moreBits(); err != nil { + return err + } + } + nlit := int(f.b&0x1F) + 257 + if nlit > maxNumLit { + return CorruptInputError(f.roffset) + } + f.b >>= 5 + ndist := int(f.b&0x1F) + 1 + if ndist > maxNumDist { + return CorruptInputError(f.roffset) + } + f.b >>= 5 + nclen := int(f.b&0xF) + 4 + // numCodes is 19, so nclen is always valid. + f.b >>= 4 + f.nb -= 5 + 5 + 4 + + // (HCLEN+4)*3 bits: code lengths in the magic codeOrder order. + for i := 0; i < nclen; i++ { + for f.nb < 3 { + if err := f.moreBits(); err != nil { + return err + } + } + f.codebits[codeOrder[i]] = int(f.b & 0x7) + f.b >>= 3 + f.nb -= 3 + } + for i := nclen; i < len(codeOrder); i++ { + f.codebits[codeOrder[i]] = 0 + } + if !f.h1.init(f.codebits[0:]) { + return CorruptInputError(f.roffset) + } + + // HLIT + 257 code lengths, HDIST + 1 code lengths, + // using the code length Huffman code. + for i, n := 0, nlit+ndist; i < n; { + x, err := f.huffSym(&f.h1) + if err != nil { + return err + } + if x < 16 { + // Actual length. + f.bits[i] = x + i++ + continue + } + // Repeat previous length or zero. + var rep int + var nb uint + var b int + switch x { + default: + return InternalError("unexpected length code") + case 16: + rep = 3 + nb = 2 + if i == 0 { + return CorruptInputError(f.roffset) + } + b = f.bits[i-1] + case 17: + rep = 3 + nb = 3 + b = 0 + case 18: + rep = 11 + nb = 7 + b = 0 + } + for f.nb < nb { + if err := f.moreBits(); err != nil { + return err + } + } + rep += int(f.b & uint32(1<>= nb + f.nb -= nb + if i+rep > n { + return CorruptInputError(f.roffset) + } + for j := 0; j < rep; j++ { + f.bits[i] = b + i++ + } + } + + if !f.h1.init(f.bits[0:nlit]) || !f.h2.init(f.bits[nlit:nlit+ndist]) { + return CorruptInputError(f.roffset) + } + + // As an optimization, we can initialize the min bits to read at a time + // for the HLIT tree to the length of the EOB marker since we know that + // every block must terminate with one. This preserves the property that + // we never read any extra bytes after the end of the DEFLATE stream. + if f.h1.min < f.bits[endBlockMarker] { + f.h1.min = f.bits[endBlockMarker] + } + + return nil +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBlock() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + v, err := f.huffSym(f.hl) + if err != nil { + f.err = err + return + } + var n uint // number of bits extra + var length int + switch { + case v < 256: + f.dict.writeByte(byte(v)) + f.out++ + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBlock + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + n = 0 + case v < 269: + length = v*2 - (265*2 - 11) + n = 1 + case v < 273: + length = v*4 - (269*4 - 19) + n = 2 + case v < 277: + length = v*8 - (273*8 - 35) + n = 3 + case v < 281: + length = v*16 - (277*16 - 67) + n = 4 + case v < 285: + length = v*32 - (281*32 - 131) + n = 5 + case v < maxNumLit: + length = 258 + n = 0 + default: + f.err = CorruptInputError(f.roffset) + return + } + if n > 0 { + for f.nb < n { + if err = f.moreBits(); err != nil { + f.err = err + return + } + } + length += int(f.b & uint32(1<>= n + f.nb -= n + } + + var dist int + if f.hd == nil { + for f.nb < 5 { + if err = f.moreBits(); err != nil { + f.err = err + return + } + } + dist = int(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + if dist, err = f.huffSym(f.hd); err != nil { + f.err = err + return + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << nb + for f.nb < nb { + if err = f.moreBits(); err != nil { + f.err = err + return + } + } + extra |= int(f.b & uint32(1<>= nb + f.nb -= nb + dist = 1<<(nb+1) + 1 + extra + default: + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > f.dict.histSize() { + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, dist + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.out += int64(cnt) + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBlock // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Copy a single uncompressed data block from input to output. +func (f *decompressor) dataBlock() { + // Uncompressed. + // Discard current half-byte. + f.nb = 0 + f.b = 0 + + // Length then ones-complement of length. + nr, err := io.ReadFull(f.r, f.buf[0:4]) + f.roffset += int64(nr) + if err != nil { + f.err = noEOF(err) + return + } + n := int(f.buf[0]) | int(f.buf[1])<<8 + nn := int(f.buf[2]) | int(f.buf[3])<<8 + if uint16(nn) != uint16(^n) { + f.err = CorruptInputError(f.roffset) + return + } + + if n == 0 { + f.toRead = f.dict.readFlush() + f.finishBlock() + return + } + + f.copyLen = n + f.copyData() +} + +// copyData copies f.copyLen bytes from the underlying reader into f.hist. +// It pauses for reads when f.hist is full. +func (f *decompressor) copyData() { + buf := f.dict.writeSlice() + if len(buf) > f.copyLen { + buf = buf[:f.copyLen] + } + + cnt, err := io.ReadFull(f.r, buf) + f.roffset += int64(cnt) + f.copyLen -= cnt + f.dict.writeMark(cnt) + f.out += int64(cnt) + if err != nil { + f.err = noEOF(err) + return + } + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).copyData + return + } + f.finishBlock() +} + +func (f *decompressor) finishBlock() { + if f.final { + if f.dict.availRead() > 0 { + f.toRead = f.dict.readFlush() + } + f.err = io.EOF + } + f.step = (*decompressor).nextBlock +} + +// noEOF returns err, unless err == io.EOF, in which case it returns io.ErrUnexpectedEOF. +func noEOF(e error) error { + if e == io.EOF { + return io.ErrUnexpectedEOF + } + return e +} + +func (f *decompressor) moreBits() error { + c, err := f.r.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + return nil +} + +// Read the next Huffman-encoded symbol from f according to h. +func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) { + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(h.min) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := f.r.ReadByte() + if err != nil { + f.b = b + f.nb = nb + return 0, noEOF(err) + } + f.roffset++ + b |= uint32(c) << (nb & 31) + nb += 8 + } + chunk := h.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + f.err = CorruptInputError(f.roffset) + return 0, f.err + } + f.b = b >> (n & 31) + f.nb = nb - n + return int(chunk >> huffmanValueShift), nil + } + } +} + +func (f *decompressor) makeReader(r io.Reader) { + if rr, ok := r.(Reader); ok { + f.rBuf = nil + f.r = rr + return + } + // Reuse rBuf if possible. Invariant: rBuf is always created (and owned) by decompressor. + if f.rBuf != nil { + f.rBuf.Reset(r) + } else { + // bufio.NewReader will not return r, as r does not implement flate.Reader, so it is not bufio.Reader. + f.rBuf = bufio.NewReader(r) + } + f.r = f.rBuf +} + +func fixedHuffmanDecoderInit() { + fixedOnce.Do(func() { + // These come from the RFC section 3.2.6. + var bits [288]int + for i := 0; i < 144; i++ { + bits[i] = 8 + } + for i := 144; i < 256; i++ { + bits[i] = 9 + } + for i := 256; i < 280; i++ { + bits[i] = 7 + } + for i := 280; i < 288; i++ { + bits[i] = 8 + } + fixedHuffmanDecoder.init(bits[:]) + }) +} + +func (f *decompressor) Reset(r io.Reader, dict []byte) error { + *f = decompressor{ + rBuf: f.rBuf, + bits: f.bits, + codebits: f.codebits, + dict: f.dict, + step: (*decompressor).nextBlock, + } + f.makeReader(r) + f.dict.init(maxMatchOffset, dict) + return nil +} + +// NewReader returns a new ReadCloser that can be used +// to read the uncompressed version of r. +// If r does not also implement [io.ByteReader], +// the decompressor may read more data than necessary from r. +// The reader returns [io.EOF] after the final block in the DEFLATE stream has +// been encountered. Any trailing data after the final block is ignored. +// +// The [io.ReadCloser] returned by NewReader also implements [Resetter]. +func NewReader(r io.Reader) io.ReadCloser { + fixedHuffmanDecoderInit() + + var f decompressor + f.makeReader(r) + f.bits = new([maxNumLit + maxNumDist]int) + f.codebits = new([numCodes]int) + f.step = (*decompressor).nextBlock + f.dict.init(maxMatchOffset, nil) + return &f +} + +// NewReaderDict is like [NewReader] but initializes the reader +// with a preset dictionary. The returned reader behaves as if +// the uncompressed data stream started with the given dictionary, +// which has already been read. NewReaderDict is typically used +// to read data compressed by [NewWriterDict]. +// +// The ReadCloser returned by NewReaderDict also implements [Resetter]. +func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + fixedHuffmanDecoderInit() + + var f decompressor + f.makeReader(r) + f.bits = new([maxNumLit + maxNumDist]int) + f.codebits = new([numCodes]int) + f.step = (*decompressor).nextBlock + f.dict.init(maxMatchOffset, dict) + return &f +} + +// NewReaderCallback returns a [io.ReadCloser] like [NewReaderDict] that +// invokes fn at every DEFLATE block boundary, immediately before the next +// block's header bits are read. The callback receives: +// +// - in: compressed byte offset to seek to for resume. +// - out: total decompressed bytes produced so far. +// - b: current bit-buffer contents. +// - nb: number of valid bits in b. +// - hist: a copy of the (logical, in-order) sliding window — at most 32 KB. +// +// fn is invoked at every block boundary; callers are responsible for +// filtering by interval. dict, when non-nil, primes the sliding window +// (used to resume decompression from a previously captured checkpoint). +// +// If fn is nil, NewReaderCallback behaves identically to [NewReaderDict]. +func NewReaderCallback(r io.Reader, dict []byte, fn func(in, out int64, b uint32, nb uint, hist []byte)) io.ReadCloser { + return NewReaderResume(r, dict, 0, 0, fn) +} + +// NewReaderResume is like [NewReaderCallback] but additionally pre-loads +// the bit buffer from a previously captured checkpoint. b and nb come from +// the [Checkpoint]'s B and NB fields. r should be positioned at the +// checkpoint's compressed offset (In). dict, when non-nil, is the +// captured sliding-window history. +// +// fn may be nil if the caller does not need further checkpoints during +// the resumed read. +func NewReaderResume(r io.Reader, dict []byte, b uint32, nb uint, fn func(in, out int64, b uint32, nb uint, hist []byte)) io.ReadCloser { + fixedHuffmanDecoderInit() + + var f decompressor + f.makeReader(r) + f.bits = new([maxNumLit + maxNumDist]int) + f.codebits = new([numCodes]int) + f.step = (*decompressor).nextBlock + f.dict.init(maxMatchOffset, dict) + f.b = b + f.nb = nb + f.checkpoint = fn + return &f +} diff --git a/ocifs/gzipr/internal/flate/token.go b/ocifs/gzipr/internal/flate/token.go new file mode 100644 index 0000000..fc0e494 --- /dev/null +++ b/ocifs/gzipr/internal/flate/token.go @@ -0,0 +1,97 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +const ( + // 2 bits: type 0 = literal 1=EOF 2=Match 3=Unused + // 8 bits: xlength = length - MIN_MATCH_LENGTH + // 22 bits xoffset = offset - MIN_OFFSET_SIZE, or literal + lengthShift = 22 + offsetMask = 1< pair into a match token. +func matchToken(xlength uint32, xoffset uint32) token { + return token(matchType + xlength<> lengthShift) } + +func lengthCode(len uint32) uint32 { return lengthCodes[len] } + +// Returns the offset code corresponding to a specific offset. +func offsetCode(off uint32) uint32 { + if off < uint32(len(offsetCodes)) { + return offsetCodes[off] + } + if off>>7 < uint32(len(offsetCodes)) { + return offsetCodes[off>>7] + 14 + } + return offsetCodes[off>>14] + 28 +} diff --git a/ocifs/gzipr/reader.go b/ocifs/gzipr/reader.go new file mode 100644 index 0000000..aa1a779 --- /dev/null +++ b/ocifs/gzipr/reader.go @@ -0,0 +1,299 @@ +package gzipr + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "sync" + "sync/atomic" + + "github.com/docker/oci/ocifs/gzipr/internal/flate" +) + +// Reader provides random-access reads over a gzip stream backed by an +// [io.ReaderAt] and a previously-built [Index]. ReadAt is safe for +// concurrent use; the number of in-flight decompressions is capped via +// [WithMaxReaders]. +type Reader struct { + ra io.ReaderAt + idx *Index + size int64 // compressed blob size + + pool chan struct{} // bounded slot pool; receive = acquire, send = release + done chan struct{} // closed by Close to unblock waiting acquires + closed atomic.Bool + + closeOnce sync.Once +} + +// NewReader scans the entire gzip blob at ra (using sequential +// random-access reads) to build a checkpoint [Index], then returns a +// [Reader] backed by the same ra. size is the compressed blob size. +func NewReader(ra io.ReaderAt, size int64, opts ...Option) (*Reader, error) { + idx, err := Scan(io.NewSectionReader(ra, 0, size), io.Discard, opts...) + if err != nil { + return nil, err + } + return newReaderFromIndex(ra, idx, size, opts...), nil +} + +// NewReaderWithIndex returns a [Reader] backed by ra and the +// pre-computed [Index] idx. size is the compressed blob size; idx.Size +// is the decompressed size and is what [Reader.Size] returns. No I/O +// is performed. +// +// Panics if idx is nil, idx.Size <= 0, or idx.Checkpoints is nil. +// Callers who construct an Index from a persisted source (rather than +// from [Scan] or [NewReader]) should validate it before invoking this +// constructor. +func NewReaderWithIndex(ra io.ReaderAt, idx *Index, size int64, opts ...Option) *Reader { + if idx == nil { + panic("gzipr: NewReaderWithIndex: idx is nil") + } + if idx.Size <= 0 { + panic("gzipr: NewReaderWithIndex: idx.Size must be > 0") + } + if idx.Checkpoints == nil { + panic("gzipr: NewReaderWithIndex: idx.Checkpoints must not be nil") + } + return newReaderFromIndex(ra, idx, size, opts...) +} + +func newReaderFromIndex(ra io.ReaderAt, idx *Index, size int64, opts ...Option) *Reader { + cfg := applyOpts(opts) + r := &Reader{ + ra: ra, + idx: idx, + size: size, + pool: make(chan struct{}, cfg.maxReaders), + done: make(chan struct{}), + } + for i := 0; i < cfg.maxReaders; i++ { + r.pool <- struct{}{} + } + return r +} + +// Size returns the total decompressed size of the gzip stream. +func (r *Reader) Size() int64 { + if r.idx == nil { + return 0 + } + return r.idx.Size +} + +// Index returns the underlying [Index]. Callers must not mutate it. +func (r *Reader) Index() *Index { + return r.idx +} + +// Close releases the bounded reader pool. After Close, [Reader.ReadAt] +// returns [ErrClosed]; callers blocked acquiring a pool slot also +// receive [ErrClosed]. Close is idempotent. +func (r *Reader) Close() error { + r.closeOnce.Do(func() { + r.closed.Store(true) + close(r.done) + }) + return nil +} + +// acquire blocks until a pool slot is available or the reader is +// closed. It returns ErrClosed if the reader was closed before a slot +// was obtained. +func (r *Reader) acquire() error { + if r.closed.Load() { + return ErrClosed + } + select { + case <-r.pool: + if r.closed.Load() { + r.release() + return ErrClosed + } + return nil + case <-r.done: + return ErrClosed + } +} + +func (r *Reader) release() { + select { + case r.pool <- struct{}{}: + default: + panic("gzipr: pool over-release: release called without a matching acquire") + } +} + +// findFirst returns the index of the highest-Out checkpoint with +// Out <= off, or -1 if none exists. +func (r *Reader) findFirst(off int64) int { + cps := r.idx.Checkpoints + lo, hi := 0, len(cps) + for lo < hi { + mid := (lo + hi) / 2 + if cps[mid].Out <= off { + lo = mid + 1 + } else { + hi = mid + } + } + return lo - 1 +} + +// ReadAt implements [io.ReaderAt]. See the package design notes for the +// full algorithm. +func (r *Reader) ReadAt(p []byte, off int64) (int, error) { + if r.closed.Load() { + return 0, ErrClosed + } + if len(p) == 0 { + return 0, nil + } + if off < 0 { + return 0, fmt.Errorf("gzipr: negative offset %d", off) + } + size := r.Size() + if off >= size { + return 0, io.EOF + } + + target := int64(len(p)) + if remaining := size - off; target > remaining { + target = remaining + } + + if err := r.acquire(); err != nil { + return 0, err + } + defer r.release() + + cps := r.idx.Checkpoints + firstIdx := r.findFirst(off) + lastIdx := r.findFirst(off + target - 1) + + var ( + firstIn int64 + firstOut int64 + firstB uint32 + firstNB uint + firstHist []byte + ) + if firstIdx >= 0 { + c := cps[firstIdx] + firstIn = c.In + firstOut = c.Out + firstB = c.B + firstNB = c.NB + firstHist = c.Hist + } + + var lastNextIn int64 + switch { + case lastIdx >= 0 && lastIdx+1 < len(cps): + lastNextIn = cps[lastIdx+1].In + case lastIdx >= 0 && lastIdx+1 == len(cps): + lastNextIn = r.size + case lastIdx < 0 && len(cps) > 0: + lastNextIn = cps[0].In + default: + lastNextIn = r.size + } + + if firstIn >= lastNextIn { + return 0, ErrInvalidFormat + } + + compressed := make([]byte, lastNextIn-firstIn) + n, err := r.ra.ReadAt(compressed, firstIn) + if n < len(compressed) { + // Underlying ReadAt returned fewer bytes than requested. Use the + // error it provided; if it returned nil despite the short read, + // synthesize one. io.EOF is a valid io.ReaderAt return for an + // exact-length read (contract §ReadAt), so we only treat it as + // an error when the read is short. + if err == nil { + err = io.ErrUnexpectedEOF + } + return 0, err + } + // n == len(compressed); io.EOF here is the valid "exact-length" form + // of the io.ReaderAt contract — swallow it. + + // If we're starting from the virtual start-of-stream, we must skip + // the gzip header before handing bytes to the DEFLATE decompressor. + br := bufio.NewReader(bytes.NewReader(compressed)) + if firstIdx < 0 { + if _, err := readGzipHeader(br); err != nil { + return 0, err + } + } + + zr := flate.NewReaderResume(br, firstHist, firstB, firstNB, nil) + defer zr.Close() + + discard := off - firstOut + if discard < 0 { + return 0, fmt.Errorf("%w: negative discard %d", ErrInvalidFormat, discard) + } + if err := discardBytes(zr, discard); err != nil { + return 0, err + } + + accumulated := int64(0) + dst := p[:target] + for accumulated < target { + n, err := zr.Read(dst[accumulated:]) + accumulated += int64(n) + if accumulated >= target { + break + } + switch { + case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): + return 0, ErrInvalidFormat + case err != nil: + return 0, err + } + } + + if target < int64(len(p)) { + return int(target), io.EOF + } + return int(target), nil +} + +// Compile-time interface checks. +var ( + _ io.ReaderAt = (*Reader)(nil) + _ io.Closer = (*Reader)(nil) +) + +// discardBytes advances zr past n bytes, applying the discard-phase +// rules from the package design. +func discardBytes(zr io.Reader, n int64) error { + if n == 0 { + return nil + } + buf := make([]byte, 32*1024) + remaining := n + for remaining > 0 { + want := int64(len(buf)) + if want > remaining { + want = remaining + } + k, err := zr.Read(buf[:want]) + remaining -= int64(k) + if remaining == 0 { + return nil + } + switch { + case errors.Is(err, io.EOF), errors.Is(err, io.ErrUnexpectedEOF): + return ErrInvalidFormat + case err != nil: + return err + } + } + return nil +} diff --git a/ocifs/gzipr/scan.go b/ocifs/gzipr/scan.go new file mode 100644 index 0000000..1f1f19f --- /dev/null +++ b/ocifs/gzipr/scan.go @@ -0,0 +1,187 @@ +package gzipr + +import ( + "bufio" + "encoding/binary" + "errors" + "fmt" + "io" + + "github.com/docker/oci/ocifs/gzipr/internal/flate" +) + +// gzipMagic is the two-byte gzip member-start signature (RFC 1952 §2.3.1). +var gzipMagic = [2]byte{0x1f, 0x8b} + +const ( + flagHCRC = 1 << 1 + flagExtra = 1 << 2 + flagName = 1 << 3 + flagComment = 1 << 4 +) + +// readGzipHeader consumes a single gzip member header from br and +// returns the number of bytes consumed. The reader is positioned at +// the first DEFLATE byte on success. +func readGzipHeader(br *bufio.Reader) (int64, error) { + var fixed [10]byte + if _, err := io.ReadFull(br, fixed[:]); err != nil { + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) { + return 0, fmt.Errorf("%w: short gzip header", ErrInvalidFormat) + } + return 0, err + } + if fixed[0] != gzipMagic[0] || fixed[1] != gzipMagic[1] { + return 0, fmt.Errorf("%w: bad gzip magic", ErrInvalidFormat) + } + if fixed[2] != 8 { + return 0, fmt.Errorf("%w: unknown gzip method %d", ErrInvalidFormat, fixed[2]) + } + flag := fixed[3] + consumed := int64(10) + + if flag&flagExtra != 0 { + var x [2]byte + if _, err := io.ReadFull(br, x[:]); err != nil { + return consumed, fmt.Errorf("%w: truncated FEXTRA", ErrInvalidFormat) + } + consumed += 2 + xlen := int(binary.LittleEndian.Uint16(x[:])) + if _, err := io.CopyN(io.Discard, br, int64(xlen)); err != nil { + return consumed, fmt.Errorf("%w: truncated FEXTRA payload", ErrInvalidFormat) + } + consumed += int64(xlen) + } + if flag&flagName != 0 { + n, err := discardNullTerminated(br) + consumed += n + if err != nil { + return consumed, fmt.Errorf("%w: truncated FNAME", ErrInvalidFormat) + } + } + if flag&flagComment != 0 { + n, err := discardNullTerminated(br) + consumed += n + if err != nil { + return consumed, fmt.Errorf("%w: truncated FCOMMENT", ErrInvalidFormat) + } + } + if flag&flagHCRC != 0 { + var c [2]byte + if _, err := io.ReadFull(br, c[:]); err != nil { + return consumed, fmt.Errorf("%w: truncated FHCRC", ErrInvalidFormat) + } + consumed += 2 + } + return consumed, nil +} + +// maxNullTerminatedLen caps the length of null-terminated FNAME/FCOMMENT +// fields to prevent a crafted blob from looping indefinitely. +const maxNullTerminatedLen = 65536 + +func discardNullTerminated(br *bufio.Reader) (int64, error) { + var n int64 + for { + if n >= maxNullTerminatedLen { + return n, fmt.Errorf("%w: FNAME/FCOMMENT field exceeds %d bytes", ErrInvalidFormat, maxNullTerminatedLen) + } + b, err := br.ReadByte() + if err != nil { + return n, err + } + n++ + if b == 0 { + return n, nil + } + } +} + +// Scan performs a single sequential pass over the gzip stream r, +// writing the decompressed bytes to out and recording a [Checkpoint] +// at every DEFLATE block boundary whose decompressed offset is at +// least one configured span past the previous checkpoint. +// +// Write errors from out are returned unwrapped so callers can +// distinguish them (e.g. [io.ErrClosedPipe]) from format errors. +// Format errors wrap [ErrInvalidFormat]. +// +// When the stream decompresses to fewer bytes than one span, the +// returned Index has a non-nil empty Checkpoints slice. +func Scan(r io.Reader, out io.Writer, opts ...Option) (*Index, error) { + cfg := applyOpts(opts) + + br := bufio.NewReader(r) + + headerLen, err := readGzipHeader(br) + if err != nil { + return nil, err + } + // Compressed offset of the first DEFLATE byte. cr counts every byte + // consumed from r; bufio.Reader may have buffered ahead, so we use + // the explicit headerLen rather than cr.n here. + deflateStart := headerLen + + idx := &Index{Checkpoints: []*Checkpoint{}} + + var ( + writeErr error + emittedOut int64 = -1 // ensures first eligible boundary qualifies + ) + checkpoint := func(in, outOff int64, b uint32, nb uint, hist []byte) { + if writeErr != nil { + return + } + // Skip the start-of-stream boundary: ReadAt handles the + // [0, firstCheckpoint.Out) range implicitly, so emitting a + // checkpoint here would be redundant. + if outOff == 0 && in == 0 { + emittedOut = 0 + return + } + if outOff-emittedOut < cfg.span { + return + } + histCopy := make([]byte, len(hist)) + copy(histCopy, hist) + idx.Checkpoints = append(idx.Checkpoints, &Checkpoint{ + In: in + deflateStart, + Out: outOff, + B: b, + NB: nb, + Hist: histCopy, + }) + emittedOut = outOff + } + + zr := flate.NewReaderCallback(br, nil, checkpoint) + defer zr.Close() + + buf := make([]byte, 32*1024) + var total int64 + for { + n, rerr := zr.Read(buf) + if n > 0 { + if _, werr := out.Write(buf[:n]); werr != nil { + writeErr = werr + return nil, werr + } + total += int64(n) + } + if rerr == io.EOF { + break + } + if rerr != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidFormat, rerr) + } + } + + // We intentionally do not validate the gzip trailer (CRC32 + ISIZE). + // The DEFLATE stream may end mid-byte, leaving leftover bits in the + // flate reader's buffer; the underlying byte stream has already been + // advanced past the last fully-consumed DEFLATE byte. Trailer + // validation would require careful re-alignment that is not needed + // for our random-access use case. + idx.Size = total + return idx, nil +} diff --git a/ocifs/imageidx.go b/ocifs/imageidx.go new file mode 100644 index 0000000..08e3cc6 --- /dev/null +++ b/ocifs/imageidx.go @@ -0,0 +1,121 @@ +package ocifs + +import ( + "encoding/json" + "errors" + "io" + + "github.com/docker/oci" + "github.com/docker/oci/ocifs/gzipr" + "github.com/docker/oci/ocifs/tarfs" + "github.com/docker/oci/ocifs/zstdr" +) + +// ErrIndexStale is returned by [NewWithIndex] when the manifest currently +// served by the registry no longer matches the persisted index (e.g. +// because the tag was re-pushed or a layer's digest changed). Callers are +// expected to fall back to [New] to rebuild the index. +var ErrIndexStale = errors.New("ocifs: image has changed since index was built") + +// LayerIndex captures everything required to reconstruct one layer's +// random-access reader from the registry without re-scanning the +// compressed blob. Exactly one of GzipIndex or ZstdIndex is populated, +// matching the layer's MediaType. +type LayerIndex struct { + Digest oci.Digest `json:"digest"` + MediaType string `json:"mediaType"` + Size int64 `json:"size"` + GzipIndex *gzipr.Index `json:"gzip,omitempty"` + ZstdIndex *zstdr.Index `json:"zstd,omitempty"` + TarIndex []tarfs.Entry `json:"tar"` +} + +// ImageIndex bundles per-layer indices for a single OCI image. +// +// The index does NOT include the image config blob; that is re-fetched +// from the registry by [NewWithIndex]. +type ImageIndex struct { + Layers []LayerIndex `json:"layers"` +} + +// Clone returns a deep copy of idx. Mutating the result does not affect idx +// or any FS that produced idx via ImageIndex. +func (idx *ImageIndex) Clone() *ImageIndex { + if idx == nil { + return nil + } + out := &ImageIndex{Layers: make([]LayerIndex, len(idx.Layers))} + for i := range idx.Layers { + out.Layers[i] = cloneLayerIndex(idx.Layers[i]) + } + return out +} + +// Encode writes the JSON encoding of idx to w. +func (idx *ImageIndex) Encode(w io.Writer) error { + return json.NewEncoder(w).Encode(idx) +} + +// DecodeImageIndex reads a JSON-encoded [ImageIndex] from r. +func DecodeImageIndex(r io.Reader) (*ImageIndex, error) { + var idx ImageIndex + if err := json.NewDecoder(r).Decode(&idx); err != nil { + return nil, err + } + return &idx, nil +} + +func cloneLayerIndex(in LayerIndex) LayerIndex { + out := LayerIndex{ + Digest: in.Digest, + MediaType: in.MediaType, + Size: in.Size, + GzipIndex: cloneGzipIndex(in.GzipIndex), + ZstdIndex: cloneZstdIndex(in.ZstdIndex), + } + if in.TarIndex != nil { + out.TarIndex = make([]tarfs.Entry, len(in.TarIndex)) + copy(out.TarIndex, in.TarIndex) + } + return out +} + +func cloneGzipIndex(in *gzipr.Index) *gzipr.Index { + if in == nil { + return nil + } + out := &gzipr.Index{ + Size: in.Size, + Checkpoints: make([]*gzipr.Checkpoint, len(in.Checkpoints)), + } + for i, cp := range in.Checkpoints { + if cp == nil { + continue + } + cpCopy := *cp + if cp.Hist != nil { + cpCopy.Hist = make([]byte, len(cp.Hist)) + copy(cpCopy.Hist, cp.Hist) + } + out.Checkpoints[i] = &cpCopy + } + return out +} + +func cloneZstdIndex(in *zstdr.Index) *zstdr.Index { + if in == nil { + return nil + } + out := &zstdr.Index{ + Size: in.Size, + Frames: make([]*zstdr.FrameCheckpoint, len(in.Frames)), + } + for i, fr := range in.Frames { + if fr == nil { + continue + } + frCopy := *fr + out.Frames[i] = &frCopy + } + return out +} diff --git a/ocifs/mediatype.go b/ocifs/mediatype.go new file mode 100644 index 0000000..6a13a02 --- /dev/null +++ b/ocifs/mediatype.go @@ -0,0 +1,79 @@ +package ocifs + +import ( + "errors" + "fmt" + "strings" +) + +// Manifest media types recognised by ocifs. +const ( + // MediaTypeOCIManifest is the OCI image manifest media type. + MediaTypeOCIManifest = "application/vnd.oci.image.manifest.v1+json" + // MediaTypeOCIIndex is the OCI image index media type. ocifs returns + // ErrUnsupportedManifest when this is encountered (multi-platform + // resolution is not implemented in v1). + MediaTypeOCIIndex = "application/vnd.oci.image.index.v1+json" + // MediaTypeDockerManifest is the Docker v2 manifest media type. + MediaTypeDockerManifest = "application/vnd.docker.distribution.manifest.v2+json" + // MediaTypeDockerManifestList is the Docker v2 manifest list media type. + // ocifs returns ErrUnsupportedManifest when this is encountered. + MediaTypeDockerManifestList = "application/vnd.docker.distribution.manifest.list.v2+json" +) + +// Layer media types recognised by ocifs. +const ( + // MediaTypeLayerOCIGzip is the OCI gzipped layer media type. + MediaTypeLayerOCIGzip = "application/vnd.oci.image.layer.v1.tar+gzip" + // MediaTypeLayerDockerGzip is the Docker rootfs diff (gzipped) layer media type. + MediaTypeLayerDockerGzip = "application/vnd.docker.image.rootfs.diff.tar.gzip" + // MediaTypeLayerOCIZstd is the OCI zstd-compressed layer media type. + MediaTypeLayerOCIZstd = "application/vnd.oci.image.layer.v1.tar+zstd" +) + +// ErrUnsupportedMediaType is returned when a layer's MediaType is not one of +// the gzip/zstd variants supported by ocifs. Callers can test with +// [errors.Is]. +var ErrUnsupportedMediaType = errors.New("ocifs: unsupported layer media type") + +// ErrUnsupportedManifest is returned when the resolved manifest's MediaType +// is not a supported single-image manifest (e.g. an OCI image index). +var ErrUnsupportedManifest = errors.New("ocifs: unsupported manifest type") + +// layerCompression identifies the compression algorithm used by a layer. +type layerCompression int + +const ( + compressionNone layerCompression = iota + compressionGzip + compressionZstd +) + +// classifyLayerMediaType returns the compression algorithm associated with +// mt, or compressionNone (and a wrapped ErrUnsupportedMediaType) if mt is +// not recognized. +func classifyLayerMediaType(mt string) (layerCompression, error) { + switch mt { + case MediaTypeLayerOCIGzip, MediaTypeLayerDockerGzip: + return compressionGzip, nil + case MediaTypeLayerOCIZstd: + return compressionZstd, nil + } + return compressionNone, fmt.Errorf("%w: %s", ErrUnsupportedMediaType, mt) +} + +const ( + // whiteoutPrefix is the prefix attached to the deleted file's name in an + // OCI overlay whiteout entry. + whiteoutPrefix = ".wh." + // whiteoutOpaque is the special whiteout filename that hides every entry + // in the parent directory below this layer. + whiteoutOpaque = ".wh..wh..opq" +) + +// whiteoutTarget returns the name of the entry being deleted by a whiteout +// marker. It assumes base has the whiteoutPrefix ".wh." prefix and is not +// the opaque marker. +func whiteoutTarget(base string) string { + return strings.TrimPrefix(base, whiteoutPrefix) +} diff --git a/ocifs/ocifs.go b/ocifs/ocifs.go new file mode 100644 index 0000000..21a4fba --- /dev/null +++ b/ocifs/ocifs.go @@ -0,0 +1,459 @@ +// Package ocifs provides an io/fs.FS implementation backed by an OCI image. +// +// FS composes the per-layer io.ReaderAt readers built by the blobra, gzipr, +// zstdr, and tarfs sub-packages into a single overlay filesystem reflecting +// the merged contents of all image layers (with whiteout semantics). +package ocifs + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "sync/atomic" + + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/docker/oci" + "github.com/docker/oci/ocifs/blobra" + "github.com/docker/oci/ocifs/gzipr" + "github.com/docker/oci/ocifs/tarfs" + "github.com/docker/oci/ocifs/zstdr" + "github.com/docker/oci/ociref" +) + +// ErrFSClosed is returned (wrapped in *fs.PathError) by Open/ReadDir/Stat/ +// ReadFile after Close has been called on the FS. +var ErrFSClosed = errors.New("ocifs: filesystem has been closed") + +// ErrSymlinkLoop is returned (wrapped in *fs.PathError) when a path +// resolution exceeds the per-call symlink hop budget (255). +var ErrSymlinkLoop = errors.New("ocifs: too many levels of symbolic links") + +// LayerReader is the interface satisfied by both *gzipr.Reader and +// *zstdr.Reader. It is the only surface tarfs.FS and the overlay need +// from the decompressor. +type LayerReader interface { + io.ReaderAt + io.Closer + Size() int64 +} + +// layerState bundles the per-layer readers built during New. +type layerState struct { + desc oci.Descriptor + decompRA LayerReader + tarFS *tarfs.FS + gzipIndex *gzipr.Index + zstdIndex *zstdr.Index + tarIndex []tarfs.Entry +} + +// FS is an io/fs.FS over an OCI image's merged layer contents. It is +// constructed by [New] or [NewWithIndex] and is safe for concurrent use +// after construction. Call [FS.Close] to release decompressor pools. +type FS struct { + layers []layerState + index overlayIndex + dirs overlayDirs + config []byte + closed atomic.Bool +} + +// ConfigBlob returns the raw image config JSON. It is always available on +// a successfully constructed FS, including after Close. The returned slice +// is owned by the FS; callers must not modify it. +func (f *FS) ConfigBlob() []byte { return f.config } + +// ImageIndex returns a snapshot of the per-layer indices that can be +// persisted via [ImageIndex.Encode] and reused with [NewWithIndex]. +// The returned struct is a deep copy; callers may mutate it without +// affecting the FS. +func (f *FS) ImageIndex() *ImageIndex { + out := &ImageIndex{Layers: make([]LayerIndex, len(f.layers))} + for i, l := range f.layers { + out.Layers[i] = cloneLayerIndex(LayerIndex{ + Digest: l.desc.Digest, + MediaType: l.desc.MediaType, + Size: l.desc.Size, + GzipIndex: l.gzipIndex, + ZstdIndex: l.zstdIndex, + TarIndex: l.tarIndex, + }) + } + return out +} + +// Close releases decompressor reader pools. After Close, all Open/ReadDir/ +// Stat/ReadFile calls return *fs.PathError wrapping ErrFSClosed. Close is +// idempotent and always returns nil; ConfigBlob and ImageIndex remain safe +// to call. +func (f *FS) Close() error { + if !f.closed.CompareAndSwap(false, true) { + return nil + } + for i := range f.layers { + if f.layers[i].decompRA != nil { + _ = f.layers[i].decompRA.Close() + } + } + return nil +} + +// New resolves the image at name@ref, downloads each layer once to build +// its checkpoint and tar entry indices, and returns a ready-to-use FS. +// The context is captured by the underlying blob readers and governs all +// subsequent registry range requests issued by file Read/ReadAt calls on +// the returned FS; cancel it to abort in-flight reads. +func New(ctx context.Context, reg oci.Interface, name string, ref string) (*FS, error) { + manifestBytes, manifestMT, err := fetchManifest(ctx, reg, name, ref) + if err != nil { + return nil, err + } + manifest, err := decodeManifest(manifestBytes, manifestMT) + if err != nil { + return nil, err + } + configBytes, err := fetchConfig(ctx, reg, name, manifest.Config.Digest) + if err != nil { + return nil, err + } + + layers, err := buildLayers(ctx, reg, name, manifest.Layers) + if err != nil { + return nil, err + } + + return finalizeFS(configBytes, layers), nil +} + +// NewWithIndex re-uses a previously persisted ImageIndex to skip layer +// scanning. The manifest is still re-fetched so the layer digests can be +// verified; if the digests do not match the persisted index, ErrIndexStale +// is returned and the caller should fall back to [New]. +func NewWithIndex(ctx context.Context, reg oci.Interface, name string, ref string, idx *ImageIndex) (*FS, error) { + if idx == nil { + return nil, errors.New("ocifs: NewWithIndex: idx is nil") + } + idx = idx.Clone() + manifestBytes, manifestMT, err := fetchManifest(ctx, reg, name, ref) + if err != nil { + return nil, err + } + manifest, err := decodeManifest(manifestBytes, manifestMT) + if err != nil { + return nil, err + } + if len(manifest.Layers) != len(idx.Layers) { + return nil, ErrIndexStale + } + for i, ml := range manifest.Layers { + li := idx.Layers[i] + if ml.Digest != li.Digest { + return nil, ErrIndexStale + } + } + configBytes, err := fetchConfig(ctx, reg, name, manifest.Config.Digest) + if err != nil { + return nil, err + } + + layers := make([]layerState, len(manifest.Layers)) + for i, ml := range manifest.Layers { + li := idx.Layers[i] + comp, err := classifyLayerMediaType(li.MediaType) + if err != nil { + return nil, err + } + if li.Size != ml.Size { + return nil, ErrIndexStale + } + if li.MediaType != ml.MediaType { + return nil, ErrIndexStale + } + if li.TarIndex == nil { + return nil, ErrIndexStale + } + switch comp { + case compressionGzip: + if li.GzipIndex == nil || li.ZstdIndex != nil { + return nil, ErrIndexStale + } + if li.GzipIndex.Size <= 0 { + return nil, ErrIndexStale + } + if li.GzipIndex.Checkpoints == nil { + return nil, ErrIndexStale + } + // Validate individual checkpoint fields so a tampered index + // cannot trigger unbounded allocations or silent data corruption. + for _, cp := range li.GzipIndex.Checkpoints { + if cp == nil || cp.In < 0 || cp.In > ml.Size || cp.Out < 0 || cp.NB > 32 { + return nil, ErrIndexStale + } + } + case compressionZstd: + if li.ZstdIndex == nil || li.GzipIndex != nil { + return nil, ErrIndexStale + } + if li.ZstdIndex.Size <= 0 { + return nil, ErrIndexStale + } + if li.ZstdIndex.Frames == nil { + return nil, ErrIndexStale + } + for _, fr := range li.ZstdIndex.Frames { + if fr == nil || fr.In < 0 || fr.In > ml.Size || fr.Out < 0 { + return nil, ErrIndexStale + } + } + } + // Validate TarIndex entry fields to guard against tampered persisted data. + for _, e := range li.TarIndex { + if e.Offset < 0 || e.Header.Size < 0 { + return nil, ErrIndexStale + } + } + desc := oci.Descriptor{ + Digest: li.Digest, + MediaType: li.MediaType, + Size: ml.Size, + } + blobRA := blobra.New(ctx, reg, name, desc) + var decompRA LayerReader + switch comp { + case compressionGzip: + decompRA = gzipr.NewReaderWithIndex(blobRA, li.GzipIndex, ml.Size) + case compressionZstd: + decompRA = zstdr.NewReaderWithIndex(blobRA, li.ZstdIndex, ml.Size) + default: + // classifyLayerMediaType already returned an error for unknown + // types, so this branch is unreachable in correct usage. + return nil, fmt.Errorf("ocifs: layer %d: unsupported compression type", i) + } + tarFS, err := tarfs.NewFromEntries(decompRA, li.TarIndex) + if err != nil { + // Close the just-allocated decompressor and all prior ones + // before returning to avoid leaking readers. + if decompRA != nil { + _ = decompRA.Close() + } + for j := 0; j < i; j++ { + if layers[j].decompRA != nil { + _ = layers[j].decompRA.Close() + } + } + return nil, fmt.Errorf("ocifs: layer %d: %w", i, err) + } + layers[i] = layerState{ + desc: desc, + decompRA: decompRA, + tarFS: tarFS, + gzipIndex: li.GzipIndex, + zstdIndex: li.ZstdIndex, + tarIndex: li.TarIndex, + } + } + + return finalizeFS(configBytes, layers), nil +} + +// finalizeFS builds the overlay maps and assembles the final *FS. +func finalizeFS(configBytes []byte, layers []layerState) *FS { + entriesByLayer := make([][]tarfs.Entry, len(layers)) + for i, l := range layers { + entriesByLayer[i] = l.tarFS.Entries() + } + idx, dirs := buildOverlay(entriesByLayer) + return &FS{ + layers: layers, + index: idx, + dirs: dirs, + config: configBytes, + } +} + +// fetchManifest fetches the manifest by tag or digest and returns the raw +// JSON, the manifest's media type, and any error. +func fetchManifest(ctx context.Context, reg oci.Interface, name string, ref string) ([]byte, string, error) { + var br oci.BlobReader + var err error + if ociref.IsValidDigest(ref) { + br, err = reg.GetManifest(ctx, name, oci.Digest(ref)) + } else { + br, err = reg.GetTag(ctx, name, ref) + } + if err != nil { + return nil, "", err + } + defer br.Close() + body, err := io.ReadAll(br) + if err != nil { + return nil, "", err + } + mt := br.Descriptor().MediaType + return body, mt, nil +} + +// decodeManifest parses the manifest payload and verifies that it is a +// supported single-image manifest. +func decodeManifest(body []byte, mt string) (*ocispec.Manifest, error) { + // If the descriptor lacks a media type (e.g. some implementations don't + // set it on tag responses), peek at the JSON itself. + if mt == "" { + var probe struct { + MediaType string `json:"mediaType"` + } + _ = json.Unmarshal(body, &probe) + mt = probe.MediaType + } + switch mt { + case MediaTypeOCIIndex, MediaTypeDockerManifestList: + return nil, fmt.Errorf("%w: %s", ErrUnsupportedManifest, mt) + case MediaTypeOCIManifest, MediaTypeDockerManifest, "": + // Fall through: many registries omit the embedded mediaType field; + // allow empty if the descriptor was also empty. + default: + return nil, fmt.Errorf("%w: %s", ErrUnsupportedManifest, mt) + } + var m ocispec.Manifest + if err := json.Unmarshal(body, &m); err != nil { + return nil, fmt.Errorf("ocifs: decode manifest: %w", err) + } + if m.Config.Digest == "" { + return nil, fmt.Errorf("ocifs: manifest is missing config") + } + return &m, nil +} + +// fetchConfig retrieves and buffers the image config blob. +func fetchConfig(ctx context.Context, reg oci.Interface, name string, dgst oci.Digest) ([]byte, error) { + br, err := reg.GetBlob(ctx, name, dgst) + if err != nil { + return nil, err + } + defer br.Close() + return io.ReadAll(br) +} + +// buildLayers performs the single-pass layer construction loop. +func buildLayers(ctx context.Context, reg oci.Interface, name string, descs []oci.Descriptor) ([]layerState, error) { + out := make([]layerState, 0, len(descs)) + for i, desc := range descs { + layer, err := buildLayer(ctx, reg, name, desc) + if err != nil { + // Close any partial layers built so far. + for _, l := range out { + if l.decompRA != nil { + _ = l.decompRA.Close() + } + } + return nil, fmt.Errorf("ocifs: layer %d: %w", i, err) + } + out = append(out, layer) + } + return out, nil +} + +// buildLayer performs single-pass construction for a single layer +// descriptor. It downloads the compressed blob exactly once via GetBlob, +// and pipes the decompressed stream to tarfs.Index while gzipr.Scan or +// zstdr.Scan emits checkpoints. On success it returns a fully-wired +// layerState whose decompRA is backed by blobra (random-access). +func buildLayer(ctx context.Context, reg oci.Interface, name string, desc oci.Descriptor) (layerState, error) { + comp, err := classifyLayerMediaType(desc.MediaType) + if err != nil { + return layerState{}, err + } + + br, err := reg.GetBlob(ctx, name, desc.Digest) + if err != nil { + return layerState{}, err + } + defer br.Close() + + pr, pw := io.Pipe() + + scanErr := make(chan error, 1) + var ( + gzipIdx *gzipr.Index + zstdIdx *zstdr.Index + ) + go func() { + var serr error + switch comp { + case compressionGzip: + gzipIdx, serr = gzipr.Scan(br, pw) + case compressionZstd: + zstdIdx, serr = zstdr.Scan(br, pw) + } + pw.CloseWithError(serr) + scanErr <- serr + }() + + entries, indexErr := tarfs.Index(pr) + if indexErr != nil { + pr.CloseWithError(indexErr) + } else { + go func() { + // Drain the pipe so the scanner goroutine is never + // blocked. Errors are ignored: all outcomes are delivered + // via the scanErr channel below. + io.Copy(io.Discard, pr) //nolint:errcheck + pr.Close() + }() + } + scanErrVal := <-scanErr + + switch { + case indexErr == nil && scanErrVal == nil: + // success + case indexErr == nil: + return layerState{}, scanErrVal + case scanErrVal == nil: + return layerState{}, indexErr + case errors.Is(scanErrVal, indexErr): + return layerState{}, indexErr + default: + // Both sides failed independently. Prefer indexErr: the tar parse + // error is typically the root cause (the scan error is often a + // pipe-closed propagation of indexErr). + return layerState{}, indexErr + } + + // Build the random-access readers from the freshly-built indices. + blobRA := blobra.New(ctx, reg, name, desc) + var decompRA LayerReader + switch comp { + case compressionGzip: + // Guard against an invalid/corrupt gzip index that would + // otherwise panic inside NewReaderWithIndex. + if gzipIdx == nil || gzipIdx.Size <= 0 || gzipIdx.Checkpoints == nil { + return layerState{}, fmt.Errorf("ocifs: gzip layer has invalid decompression index") + } + decompRA = gzipr.NewReaderWithIndex(blobRA, gzipIdx, desc.Size) + case compressionZstd: + // Guard against an invalid/corrupt zstd index. + if zstdIdx == nil || zstdIdx.Size <= 0 || zstdIdx.Frames == nil { + return layerState{}, fmt.Errorf("ocifs: zstd layer has invalid decompression index") + } + decompRA = zstdr.NewReaderWithIndex(blobRA, zstdIdx, desc.Size) + default: + // Unreachable: classifyLayerMediaType already rejected unknown types. + return layerState{}, fmt.Errorf("ocifs: unsupported compression type %d", comp) + } + tarFS, err := tarfs.NewFromEntries(decompRA, entries) + if err != nil { + _ = decompRA.Close() + return layerState{}, err + } + return layerState{ + desc: desc, + decompRA: decompRA, + tarFS: tarFS, + gzipIndex: gzipIdx, + zstdIndex: zstdIdx, + tarIndex: entries, + }, nil +} diff --git a/ocifs/ocifs_test.go b/ocifs/ocifs_test.go new file mode 100644 index 0000000..9bbd51c --- /dev/null +++ b/ocifs/ocifs_test.go @@ -0,0 +1,707 @@ +package ocifs + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "encoding/json" + "errors" + "io/fs" + "strings" + "testing" + "testing/fstest" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/opencontainers/go-digest" + ocispec "github.com/opencontainers/image-spec/specs-go/v1" + + "github.com/docker/oci" + "github.com/docker/oci/ocimem" +) + +// tarFile describes a single entry to write into a synthetic test layer. +type tarFile struct { + Name string + Linkname string + Mode int64 + Type byte + Body string +} + +// buildLayerTar serializes the supplied entries into a tar bytestream. +func buildLayerTar(t *testing.T, files []tarFile) []byte { + t.Helper() + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + for _, f := range files { + mode := f.Mode + if mode == 0 { + switch f.Type { + case tar.TypeDir: + mode = 0o755 + case tar.TypeSymlink: + mode = 0o777 + default: + mode = 0o644 + } + } + hdr := &tar.Header{ + Name: f.Name, + Linkname: f.Linkname, + Mode: mode, + Typeflag: f.Type, + Size: int64(len(f.Body)), + ModTime: time.Unix(1700000000, 0).UTC(), + } + if f.Type == tar.TypeDir || f.Type == tar.TypeSymlink || f.Type == tar.TypeLink { + hdr.Size = 0 + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("tar header %q: %v", f.Name, err) + } + if hdr.Size > 0 { + if _, err := tw.Write([]byte(f.Body)); err != nil { + t.Fatalf("tar body %q: %v", f.Name, err) + } + } + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + return buf.Bytes() +} + +// gzipBytes gzips data with default compression. +func gzipBytes(t *testing.T, data []byte) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + if _, err := gw.Write(data); err != nil { + t.Fatalf("gzip write: %v", err) + } + if err := gw.Close(); err != nil { + t.Fatalf("gzip close: %v", err) + } + return buf.Bytes() +} + +// zstdBytes encodes data as a zstd stream. +func zstdBytes(t *testing.T, data []byte) []byte { + t.Helper() + var buf bytes.Buffer + zw, err := zstd.NewWriter(&buf) + if err != nil { + t.Fatalf("zstd writer: %v", err) + } + if _, err := zw.Write(data); err != nil { + t.Fatalf("zstd write: %v", err) + } + if err := zw.Close(); err != nil { + t.Fatalf("zstd close: %v", err) + } + return buf.Bytes() +} + +// pushBlob pushes data to the registry as a blob and returns the descriptor. +func pushBlob(t *testing.T, ctx context.Context, reg oci.Interface, repo, mediaType string, data []byte) oci.Descriptor { + t.Helper() + desc := oci.Descriptor{ + MediaType: mediaType, + Digest: digest.FromBytes(data), + Size: int64(len(data)), + } + out, err := reg.PushBlob(ctx, repo, desc, bytes.NewReader(data)) + if err != nil { + t.Fatalf("push blob: %v", err) + } + return out +} + +// pushImage assembles a multi-layer image, pushes the layers, config and +// manifest, and returns the manifest digest. +func pushImage(t *testing.T, ctx context.Context, reg oci.Interface, repo, tag string, layers [][]tarFile) oci.Digest { + t.Helper() + return pushImageWithMediaType(t, ctx, reg, repo, tag, MediaTypeLayerOCIGzip, layers) +} + +// pushImageWithMediaType is like pushImage but uses the supplied layer +// media type (and matching compressor) for every layer. +func pushImageWithMediaType(t *testing.T, ctx context.Context, reg oci.Interface, repo, tag, mediaType string, layers [][]tarFile) oci.Digest { + t.Helper() + layerDescs := make([]oci.Descriptor, len(layers)) + for i, files := range layers { + raw := buildLayerTar(t, files) + var compressed []byte + switch mediaType { + case MediaTypeLayerOCIGzip, MediaTypeLayerDockerGzip: + compressed = gzipBytes(t, raw) + case MediaTypeLayerOCIZstd: + compressed = zstdBytes(t, raw) + default: + compressed = raw + } + layerDescs[i] = pushBlob(t, ctx, reg, repo, mediaType, compressed) + } + + cfg := ocispec.Image{} + cfgBytes, err := json.Marshal(cfg) + if err != nil { + t.Fatalf("marshal config: %v", err) + } + cfgDesc := pushBlob(t, ctx, reg, repo, ocispec.MediaTypeImageConfig, cfgBytes) + + mf := ocispec.Manifest{ + MediaType: MediaTypeOCIManifest, + Config: cfgDesc, + Layers: layerDescs, + } + mf.SchemaVersion = 2 + mfBytes, err := json.Marshal(mf) + if err != nil { + t.Fatalf("marshal manifest: %v", err) + } + mDesc, err := reg.PushManifest(ctx, repo, mfBytes, MediaTypeOCIManifest, &oci.PushManifestParameters{Tags: []string{tag}}) + if err != nil { + t.Fatalf("push manifest: %v", err) + } + return mDesc.Digest +} + +func TestNew_SingleLayer(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + { + {Name: "hello.txt", Type: tar.TypeReg, Body: "hello world"}, + {Name: "etc/", Type: tar.TypeDir}, + {Name: "etc/version", Type: tar.TypeReg, Body: "1.0"}, + }, + }) + + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + data, err := f.ReadFile("hello.txt") + if err != nil { + t.Fatalf("ReadFile hello.txt: %v", err) + } + if string(data) != "hello world" { + t.Errorf("hello.txt = %q, want %q", data, "hello world") + } + + data, err = f.ReadFile("etc/version") + if err != nil { + t.Fatalf("ReadFile etc/version: %v", err) + } + if string(data) != "1.0" { + t.Errorf("etc/version = %q, want %q", data, "1.0") + } +} + +func TestNew_MultiLayerOverwrite(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + { + {Name: "a.txt", Type: tar.TypeReg, Body: "lower"}, + {Name: "b.txt", Type: tar.TypeReg, Body: "lower-b"}, + }, + { + {Name: "a.txt", Type: tar.TypeReg, Body: "upper"}, + }, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + got, _ := f.ReadFile("a.txt") + if string(got) != "upper" { + t.Errorf("a.txt = %q, want upper", got) + } + got, _ = f.ReadFile("b.txt") + if string(got) != "lower-b" { + t.Errorf("b.txt = %q, want lower-b", got) + } +} + +func TestNew_Whiteout(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + { + {Name: "delete-me.txt", Type: tar.TypeReg, Body: "old"}, + {Name: "keep.txt", Type: tar.TypeReg, Body: "kept"}, + }, + { + {Name: ".wh.delete-me.txt", Type: tar.TypeReg, Body: ""}, + }, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + _, err = f.Open("delete-me.txt") + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("Open(delete-me.txt) err = %v, want ErrNotExist", err) + } + if _, err := f.ReadFile("keep.txt"); err != nil { + t.Errorf("ReadFile(keep.txt): %v", err) + } + + // ReadDir should not return the deleted file. + entries, err := f.ReadDir(".") + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + for _, e := range entries { + if e.Name() == "delete-me.txt" { + t.Errorf("delete-me.txt should not appear in root listing") + } + if strings.HasPrefix(e.Name(), ".wh.") { + t.Errorf("whiteout marker leaked into ReadDir: %q", e.Name()) + } + } +} + +func TestNew_OpaqueWhiteout(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + { + {Name: "etc/", Type: tar.TypeDir}, + {Name: "etc/old1", Type: tar.TypeReg, Body: "1"}, + {Name: "etc/old2", Type: tar.TypeReg, Body: "2"}, + }, + { + {Name: "etc/.wh..wh..opq", Type: tar.TypeReg}, + {Name: "etc/new", Type: tar.TypeReg, Body: "fresh"}, + }, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + for _, p := range []string{"etc/old1", "etc/old2"} { + _, err := f.Open(p) + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("Open(%q) err = %v, want ErrNotExist", p, err) + } + } + got, _ := f.ReadFile("etc/new") + if string(got) != "fresh" { + t.Errorf("etc/new = %q, want fresh", got) + } +} + +func TestNew_Symlink(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + { + {Name: "usr/", Type: tar.TypeDir}, + {Name: "usr/lib/", Type: tar.TypeDir}, + {Name: "usr/lib/foo.so", Type: tar.TypeReg, Body: "binary"}, + {Name: "lib", Type: tar.TypeSymlink, Linkname: "usr/lib"}, + }, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + got, err := f.ReadFile("lib/foo.so") + if err != nil { + t.Fatalf("ReadFile(lib/foo.so): %v", err) + } + if string(got) != "binary" { + t.Errorf("lib/foo.so = %q, want binary", got) + } + // Lstat on the symlink itself should preserve ModeSymlink. + fi, err := f.Lstat("lib") + if err != nil { + t.Fatalf("Lstat(lib): %v", err) + } + if fi.Mode()&fs.ModeSymlink == 0 { + t.Errorf("Lstat(lib).Mode = %v, want ModeSymlink set", fi.Mode()) + } +} + +func TestNew_FsTestPasses(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + { + {Name: "a/", Type: tar.TypeDir}, + {Name: "a/b/", Type: tar.TypeDir}, + {Name: "a/b/c.txt", Type: tar.TypeReg, Body: "deep"}, + {Name: "top.txt", Type: tar.TypeReg, Body: "top"}, + }, + { + {Name: "extra.txt", Type: tar.TypeReg, Body: "extra body"}, + }, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + if err := fstest.TestFS(f, "a/b/c.txt", "top.txt", "extra.txt"); err != nil { + t.Fatalf("fstest.TestFS: %v", err) + } +} + +func TestNew_UnsupportedMediaType(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + repo := "test/img" + tag := "v1" + + // Push a layer with an unsupported media type. + raw := buildLayerTar(t, []tarFile{{Name: "x", Type: tar.TypeReg, Body: "x"}}) + layerDesc := pushBlob(t, ctx, reg, repo, "application/vnd.oci.image.layer.v1.tar", raw) + cfg := ocispec.Image{} + cfgBytes, _ := json.Marshal(cfg) + cfgDesc := pushBlob(t, ctx, reg, repo, ocispec.MediaTypeImageConfig, cfgBytes) + mf := ocispec.Manifest{ + MediaType: MediaTypeOCIManifest, + Config: cfgDesc, + Layers: []oci.Descriptor{layerDesc}, + } + mf.SchemaVersion = 2 + mfBytes, _ := json.Marshal(mf) + if _, err := reg.PushManifest(ctx, repo, mfBytes, MediaTypeOCIManifest, &oci.PushManifestParameters{Tags: []string{tag}}); err != nil { + t.Fatalf("push manifest: %v", err) + } + + if _, err := New(ctx, reg, repo, tag); !errors.Is(err, ErrUnsupportedMediaType) { + t.Errorf("New err = %v, want ErrUnsupportedMediaType", err) + } +} + +func TestNew_BadTag(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "x", Type: tar.TypeReg, Body: "x"}}, + }) + if _, err := New(ctx, reg, "test/img", "no-such-tag"); err == nil { + t.Errorf("New with bad tag should fail") + } +} + +func TestNew_ConfigBlob(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "x", Type: tar.TypeReg, Body: "x"}}, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + cfg := f.ConfigBlob() + if len(cfg) == 0 { + t.Errorf("ConfigBlob is empty") + } + var m map[string]any + if err := json.Unmarshal(cfg, &m); err != nil { + t.Errorf("ConfigBlob is not valid JSON: %v", err) + } +} + +func TestNew_Close(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "x.txt", Type: tar.TypeReg, Body: "data"}}, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + if err := f.Close(); err != nil { + t.Fatalf("Close: %v", err) + } + // A second Close is a no-op. + if err := f.Close(); err != nil { + t.Errorf("second Close: %v", err) + } + + // Operations after Close return ErrFSClosed. + _, err = f.Open("x.txt") + if !errors.Is(err, ErrFSClosed) { + t.Errorf("Open after Close err = %v, want ErrFSClosed", err) + } + _, err = f.ReadDir(".") + if !errors.Is(err, ErrFSClosed) { + t.Errorf("ReadDir after Close err = %v, want ErrFSClosed", err) + } + _, err = f.Stat("x.txt") + if !errors.Is(err, ErrFSClosed) { + t.Errorf("Stat after Close err = %v, want ErrFSClosed", err) + } + _, err = f.ReadFile("x.txt") + if !errors.Is(err, ErrFSClosed) { + t.Errorf("ReadFile after Close err = %v, want ErrFSClosed", err) + } + // ConfigBlob still works. + if len(f.ConfigBlob()) == 0 { + t.Errorf("ConfigBlob after Close should still work") + } +} + +func TestNewWithIndex_RoundTrip(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + { + {Name: "a/", Type: tar.TypeDir}, + {Name: "a/b.txt", Type: tar.TypeReg, Body: "bee"}, + }, + { + {Name: "c.txt", Type: tar.TypeReg, Body: "cee"}, + }, + }) + + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + idx := f.ImageIndex() + if len(idx.Layers) != 2 { + t.Fatalf("idx.Layers = %d, want 2", len(idx.Layers)) + } + for i, li := range idx.Layers { + if li.GzipIndex == nil { + t.Errorf("layer %d GzipIndex is nil", i) + } + if li.TarIndex == nil { + t.Errorf("layer %d TarIndex is nil", i) + } + } + f.Close() + + // Encode/decode round-trip. + var buf bytes.Buffer + if err := idx.Encode(&buf); err != nil { + t.Fatalf("Encode: %v", err) + } + dec, err := DecodeImageIndex(&buf) + if err != nil { + t.Fatalf("DecodeImageIndex: %v", err) + } + + // Wrap registry to count GetBlob calls. + rec := &recordingReg{Interface: reg} + f2, err := NewWithIndex(ctx, rec, "test/img", "v1", dec) + if err != nil { + t.Fatalf("NewWithIndex: %v", err) + } + defer f2.Close() + + // NewWithIndex must NOT call GetBlob for layers (only for the config). + if rec.layerBlobReads != 0 { + t.Errorf("NewWithIndex made %d layer GetBlob calls, want 0", rec.layerBlobReads) + } + + got, err := f2.ReadFile("a/b.txt") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "bee" { + t.Errorf("a/b.txt = %q, want bee", got) + } + got, err = f2.ReadFile("c.txt") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "cee" { + t.Errorf("c.txt = %q, want cee", got) + } +} + +func TestImageIndex_DefensiveCopies(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "a.txt", Type: tar.TypeReg, Body: "unchanged"}}, + }) + + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + + idx := f.ImageIndex() + idx.Layers[0].TarIndex[0].Header.Size = 1 + idx.Layers[0].GzipIndex.Size = 1 + if len(idx.Layers[0].GzipIndex.Checkpoints) > 0 { + idx.Layers[0].GzipIndex.Checkpoints[0].Out = 12345 + } + + got, err := f.ReadFile("a.txt") + if err != nil { + t.Fatalf("ReadFile after mutating ImageIndex result: %v", err) + } + if string(got) != "unchanged" { + t.Fatalf("a.txt after mutating ImageIndex result = %q, want unchanged", got) + } + + fresh := f.ImageIndex() + if fresh.Layers[0].TarIndex[0].Header.Size != int64(len("unchanged")) { + t.Fatalf("fresh ImageIndex tar size = %d, want %d", fresh.Layers[0].TarIndex[0].Header.Size, len("unchanged")) + } + if fresh.Layers[0].GzipIndex.Size == 1 { + t.Fatalf("fresh ImageIndex reused mutated gzip index") + } +} + +func TestNewWithIndex_DefensivelyCopiesInput(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "a.txt", Type: tar.TypeReg, Body: "from-index"}}, + }) + + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + idx := f.ImageIndex() + f.Close() + + f2, err := NewWithIndex(ctx, reg, "test/img", "v1", idx) + if err != nil { + t.Fatalf("NewWithIndex: %v", err) + } + defer f2.Close() + + idx.Layers[0].TarIndex[0].Header.Size = 1 + idx.Layers[0].GzipIndex.Size = 1 + if len(idx.Layers[0].GzipIndex.Checkpoints) > 0 { + idx.Layers[0].GzipIndex.Checkpoints[0].In = 12345 + } + + got, err := f2.ReadFile("a.txt") + if err != nil { + t.Fatalf("ReadFile after mutating NewWithIndex input: %v", err) + } + if string(got) != "from-index" { + t.Fatalf("a.txt after mutating NewWithIndex input = %q, want from-index", got) + } +} + +func TestNewWithIndex_Stale(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "a.txt", Type: tar.TypeReg, Body: "first"}}, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + idx := f.ImageIndex() + f.Close() + + // Re-push with different content under the same tag. + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "a.txt", Type: tar.TypeReg, Body: "second"}}, + }) + + if _, err := NewWithIndex(ctx, reg, "test/img", "v1", idx); !errors.Is(err, ErrIndexStale) { + t.Errorf("NewWithIndex err = %v, want ErrIndexStale", err) + } +} + +// recordingReg wraps an oci.Interface and counts GetBlob calls used to +// fetch any layer (i.e. anything with a tar+gzip or tar+zstd media type). +type recordingReg struct { + oci.Interface + layerBlobReads int +} + +func (r *recordingReg) GetBlob(ctx context.Context, repo string, dgst oci.Digest) (oci.BlobReader, error) { + br, err := r.Interface.GetBlob(ctx, repo, dgst) + if err != nil { + return nil, err + } + mt := br.Descriptor().MediaType + switch mt { + case MediaTypeLayerOCIGzip, MediaTypeLayerDockerGzip, MediaTypeLayerOCIZstd: + r.layerBlobReads++ + } + return br, nil +} + +// TestNew_ZstdLayer pushes a single-layer image with a zstd-compressed +// layer and verifies that ocifs reads it correctly. +func TestNew_ZstdLayer(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImageWithMediaType(t, ctx, reg, "test/img", "v1", MediaTypeLayerOCIZstd, [][]tarFile{ + { + {Name: "hello.txt", Type: tar.TypeReg, Body: "hello zstd"}, + }, + }) + f, err := New(ctx, reg, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + got, err := f.ReadFile("hello.txt") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "hello zstd" { + t.Errorf("hello.txt = %q, want %q", got, "hello zstd") + } + // ImageIndex should record a ZstdIndex, not a GzipIndex. + idx := f.ImageIndex() + if len(idx.Layers) != 1 { + t.Fatalf("layers = %d, want 1", len(idx.Layers)) + } + if idx.Layers[0].ZstdIndex == nil { + t.Errorf("ZstdIndex is nil") + } + if idx.Layers[0].GzipIndex != nil { + t.Errorf("GzipIndex should be nil for a zstd layer") + } +} + +// SinglePassConstruction verifies that New makes exactly one GetBlob call +// per layer (in addition to the config blob fetch). +func TestNew_SinglePassConstruction(t *testing.T) { + ctx := context.Background() + reg := ocimem.New() + pushImage(t, ctx, reg, "test/img", "v1", [][]tarFile{ + {{Name: "a", Type: tar.TypeReg, Body: "1"}}, + {{Name: "b", Type: tar.TypeReg, Body: "2"}}, + {{Name: "c", Type: tar.TypeReg, Body: "3"}}, + }) + rec := &recordingReg{Interface: reg} + f, err := New(ctx, rec, "test/img", "v1") + if err != nil { + t.Fatalf("New: %v", err) + } + defer f.Close() + if rec.layerBlobReads != 3 { + t.Errorf("layer GetBlob calls = %d, want 3", rec.layerBlobReads) + } +} diff --git a/ocifs/overlay.go b/ocifs/overlay.go new file mode 100644 index 0000000..d6b1342 --- /dev/null +++ b/ocifs/overlay.go @@ -0,0 +1,670 @@ +package ocifs + +import ( + "archive/tar" + "io" + "io/fs" + "log/slog" + "path" + "sort" + "strings" + "time" + + "github.com/docker/oci/ocifs/tarfs" +) + +// overlayEntry is the bottom-of-stack record stored in overlayIndex for a +// resolved path. layerIdx selects the layer whose decompRA backs file +// content; entry holds the resolved tarfs.Entry (for hardlinks the target +// entry, never the link entry itself). +type overlayEntry struct { + layerIdx int + entry tarfs.Entry +} + +// overlayIndex maps a normalized path (no leading "/" or "./") to its +// resolved overlay entry. The root directory is NEVER stored here — it is +// synthesized lazily by Open/Stat/ReadDir. +type overlayIndex map[string]*overlayEntry + +// overlayDirEntry is the concrete fs.DirEntry stored in overlayDirs. +type overlayDirEntry struct { + name string + entry tarfs.Entry +} + +func (d *overlayDirEntry) Name() string { return d.name } +func (d *overlayDirEntry) IsDir() bool { return d.entry.Header.Mode.IsDir() } +func (d *overlayDirEntry) Type() fs.FileMode { return d.entry.Header.Mode.Type() } +func (d *overlayDirEntry) Info() (fs.FileInfo, error) { + return &overlayFileInfo{name: d.name, header: d.entry.Header}, nil +} + +// overlayDirs maps a directory path to its sorted child fs.DirEntry slice. +type overlayDirs map[string][]*overlayDirEntry + +// overlayFileInfo is the fs.FileInfo built from an overlay entry's header. +type overlayFileInfo struct { + name string + header tarfs.Header +} + +func (fi *overlayFileInfo) Name() string { return fi.name } +func (fi *overlayFileInfo) Size() int64 { return fi.header.Size } +func (fi *overlayFileInfo) Mode() fs.FileMode { return fi.header.Mode } +func (fi *overlayFileInfo) ModTime() time.Time { return time.Unix(0, fi.header.ModTime).UTC() } +func (fi *overlayFileInfo) IsDir() bool { return fi.header.Mode.IsDir() } +func (fi *overlayFileInfo) Sys() any { return nil } + +// rootHeader is the synthesized header for the root directory. +func rootHeader() tarfs.Header { + return tarfs.Header{ + Typeflag: tar.TypeDir, + Mode: fs.ModeDir | 0o755, + } +} + +// rootFileInfo returns the fs.FileInfo for the root, named by the original +// argument the caller passed to Open/Stat/ReadDir. +func rootFileInfo(openArg string) fs.FileInfo { + return &overlayFileInfo{name: path.Base(openArg), header: rootHeader()} +} + +// splitPath strips any leading "/" then splits on "/", filtering empty +// components. "/" maps to an empty slice. +func splitPath(p string) []string { + for strings.HasPrefix(p, "/") { + p = p[1:] + } + if p == "" { + return nil + } + parts := strings.Split(p, "/") + out := parts[:0] + for _, c := range parts { + if c == "" { + continue + } + out = append(out, c) + } + return out +} + +// splitParentBase splits a normalized path into (parent, base) where +// parent is "" (NOT ".") for root-level entries. +func splitParentBase(p string) (parent, base string) { + i := strings.LastIndex(p, "/") + if i < 0 { + return "", p + } + return p[:i], p[i+1:] +} + +// buildOverlay processes layer tar entries bottom-to-top and produces the +// final overlayIndex/overlayDirs maps. layerEntries[i] must hold the +// entries from layer i (bottom = 0, top = len-1). +func buildOverlay(layerEntries [][]tarfs.Entry) (overlayIndex, overlayDirs) { + idx := make(overlayIndex) + dirs := make(overlayDirs) + + for layerIdx, entries := range layerEntries { + applyLayer(layerIdx, entries, idx, dirs) + } + + // Sort each directory listing in lexical order (deferred for O(F log F)). + for k, v := range dirs { + sort.Slice(v, func(i, j int) bool { return v[i].Name() < v[j].Name() }) + dirs[k] = v + } + return idx, dirs +} + +// applyLayer applies one layer's entries to the (idx, dirs) maps. +func applyLayer(layerIdx int, entries []tarfs.Entry, idx overlayIndex, dirs overlayDirs) { + for _, entry := range entries { + p := entry.Filename + if p == "" { + continue + } + parent, base := splitParentBase(p) + + // 1. Opaque whiteout + if base == whiteoutOpaque { + if parent == "" { + clearOverlay(idx, dirs) + } else { + removePrefix(idx, parent+"/") + removeDirsPrefix(dirs, parent+"/") + delete(dirs, parent) + } + continue + } + + // 2. Regular whiteout + if strings.HasPrefix(base, whiteoutPrefix) { + deletedName := whiteoutTarget(base) + // Reject empty, ".", "..", or names that contain a slash — + // any of these would allow path traversal outside the parent + // directory. + if deletedName == "" || deletedName == "." || deletedName == ".." || strings.Contains(deletedName, "/") { + continue + } + var target string + if parent == "" { + target = deletedName + } else { + target = parent + "/" + deletedName + } + wasDir := false + if cur, ok := idx[target]; ok && cur != nil { + wasDir = cur.entry.Header.Mode.IsDir() + } + delete(idx, target) + removeFromDirList(dirs, parent, deletedName) + if wasDir { + removePrefix(idx, target+"/") + removeDirsPrefix(dirs, target+"/") + delete(dirs, target) + } + continue + } + + // 3. Directory + if entry.Header.Typeflag == tar.TypeDir || entry.Header.Mode.IsDir() { + synthesizeParents(p, layerIdx, idx, dirs) + idx[p] = &overlayEntry{layerIdx: layerIdx, entry: entry} + if _, ok := dirs[p]; !ok { + dirs[p] = nil + } + if p != "" { + upsertDirEntry(dirs, parent, &overlayDirEntry{name: base, entry: entry}) + } + continue + } + + // 4. File, symlink, or hardlink + var resolved tarfs.Entry + var resolvedLayer int + if entry.Header.Typeflag == tar.TypeLink { + target := resolveHardlink(entry.Header.Linkname, idx) + if target == nil { + slog.Default().Warn("ocifs: hardlink target absent", + "path", p, + "linkname", entry.Header.Linkname, + ) + continue + } + resolved = target.entry + resolvedLayer = target.layerIdx + } else { + resolved = entry + resolvedLayer = layerIdx + } + + idx[p] = &overlayEntry{layerIdx: resolvedLayer, entry: resolved} + synthesizeParents(p, layerIdx, idx, dirs) + upsertDirEntry(dirs, parent, &overlayDirEntry{name: base, entry: resolved}) + } +} + +// synthesizeParents ensures that each strict ancestor of path is present in +// idx and its parent's dirs entry as a directory. Insert-if-absent only — +// it must never overwrite real metadata. The root ("") is never touched; +// it is synthesized at call time. +func synthesizeParents(p string, layerIdx int, idx overlayIndex, dirs overlayDirs) { + parts := strings.Split(p, "/") + if len(parts) <= 1 { + return + } + for i := 1; i < len(parts); i++ { + ancestor := strings.Join(parts[:i], "/") + if ancestor == "" { + continue + } + ancParent, ancBase := splitParentBase(ancestor) + if _, ok := idx[ancestor]; !ok { + synth := tarfs.Entry{ + Header: tarfs.Header{ + Typeflag: tar.TypeDir, + Mode: fs.ModeDir | 0o755, + }, + Filename: ancestor, + } + idx[ancestor] = &overlayEntry{layerIdx: layerIdx, entry: synth} + if _, ok := dirs[ancestor]; !ok { + dirs[ancestor] = nil + } + insertIfAbsentDirEntry(dirs, ancParent, &overlayDirEntry{name: ancBase, entry: synth}) + } else { + // Real entry exists; do not overwrite. Ensure its dirs slot + // exists (it should from a prior real upsert, but defensive). + if _, ok := dirs[ancestor]; !ok { + dirs[ancestor] = nil + } + } + } +} + +// resolveHardlink looks up linkname in the overlay index, normalizing +// "./"/"/" prefixes first. Returns nil if absent. +func resolveHardlink(linkname string, idx overlayIndex) *overlayEntry { + for strings.HasPrefix(linkname, "./") { + linkname = linkname[2:] + } + for strings.HasPrefix(linkname, "/") { + linkname = linkname[1:] + } + for strings.HasSuffix(linkname, "/") { + linkname = linkname[:len(linkname)-1] + } + if linkname == "" || linkname == "." { + return nil + } + return idx[linkname] +} + +// upsertDirEntry inserts e into dirs[parent], replacing any existing entry +// with the same Name(). +func upsertDirEntry(dirs overlayDirs, parent string, e *overlayDirEntry) { + cur := dirs[parent] + for i, existing := range cur { + if existing.Name() == e.Name() { + cur[i] = e + dirs[parent] = cur + return + } + } + dirs[parent] = append(cur, e) +} + +// insertIfAbsentDirEntry inserts e into dirs[parent] only if no entry with +// the same Name() is already present. +func insertIfAbsentDirEntry(dirs overlayDirs, parent string, e *overlayDirEntry) { + cur := dirs[parent] + for _, existing := range cur { + if existing.Name() == e.Name() { + return + } + } + dirs[parent] = append(cur, e) +} + +// removeFromDirList removes any entry with name == base from dirs[parent]. +func removeFromDirList(dirs overlayDirs, parent, base string) { + cur, ok := dirs[parent] + if !ok { + return + } + out := cur[:0] + for _, e := range cur { + if e.Name() == base { + continue + } + out = append(out, e) + } + dirs[parent] = out +} + +// clearOverlay erases every key from idx and dirs. +func clearOverlay(idx overlayIndex, dirs overlayDirs) { + for k := range idx { + delete(idx, k) + } + for k := range dirs { + delete(dirs, k) + } +} + +// removePrefix deletes every key in idx that starts with the given prefix. +func removePrefix(idx overlayIndex, prefix string) { + for k := range idx { + if strings.HasPrefix(k, prefix) { + delete(idx, k) + } + } +} + +// removeDirsPrefix deletes every key in dirs that starts with the given prefix. +func removeDirsPrefix(dirs overlayDirs, prefix string) { + for k := range dirs { + if strings.HasPrefix(k, prefix) { + delete(dirs, k) + } + } +} + +// resolvePathResult bundles the outcome of component-wise traversal. +type resolvePathResult struct { + resolvedPath string // canonical normalized path; "" for root + entry *overlayEntry // overlayIndex[resolvedPath] when non-root, may be nil for root + isRoot bool // true when resolvedPath == "" +} + +// resolvePath performs the component-wise overlay traversal described in +// the design (Open steps 0–3). When followFinalSymlink is false, a +// final-component symlink is returned without dereferencing — used by +// Lstat. Errors are returned as *fs.PathError with the supplied op. +func (f *FS) resolvePath(name, op string, followFinalSymlink bool) (resolvePathResult, error) { + if name == "." { + return resolvePathResult{isRoot: true}, nil + } + if !fs.ValidPath(name) { + return resolvePathResult{}, &fs.PathError{Op: op, Path: name, Err: fs.ErrInvalid} + } + + components := splitPath(name) + stack := make([]string, 0, len(components)) + hops := 256 + var current *overlayEntry + resolvedPath := "" + + for len(components) > 0 { + c := components[0] + components = components[1:] + + switch c { + case ".": + continue + case "..": + if len(stack) > 0 { + stack = stack[:len(stack)-1] + } + resolvedPath = strings.Join(stack, "/") + current = nil + continue + } + + stack = append(stack, c) + resolvedPath = strings.Join(stack, "/") + current = f.index[resolvedPath] + if current == nil { + return resolvePathResult{}, &fs.PathError{Op: op, Path: name, Err: fs.ErrNotExist} + } + + isSymlink := current.entry.Header.Mode&fs.ModeSymlink != 0 + isFinal := len(components) == 0 + if isSymlink && (!isFinal || followFinalSymlink) { + if current.entry.Header.Linkname == "" { + return resolvePathResult{}, &fs.PathError{Op: op, Path: name, Err: fs.ErrNotExist} + } + hops-- + if hops == 0 { + return resolvePathResult{}, &fs.PathError{Op: op, Path: name, Err: ErrSymlinkLoop} + } + target := current.entry.Header.Linkname + var rebuilt []string + if path.IsAbs(target) { + rebuilt = append(rebuilt, splitPath(target)...) + } else { + rebuilt = append(rebuilt, stack[:len(stack)-1]...) + rebuilt = append(rebuilt, splitPath(target)...) + } + rebuilt = append(rebuilt, components...) + components = rebuilt + stack = stack[:0] + current = nil + resolvedPath = "" + continue + } + + if current.entry.Header.Mode.IsDir() { + continue + } + + // Non-directory, non-symlink (or final symlink with !followFinalSymlink). + if !isFinal { + return resolvePathResult{}, &fs.PathError{Op: op, Path: name, Err: fs.ErrNotExist} + } + } + + if len(stack) == 0 { + return resolvePathResult{isRoot: true}, nil + } + if current == nil { + current = f.index[resolvedPath] + if current == nil { + return resolvePathResult{}, &fs.PathError{Op: op, Path: name, Err: fs.ErrNotExist} + } + } + return resolvePathResult{resolvedPath: resolvedPath, entry: current}, nil +} + +// Open implements fs.FS. See the design document for the full algorithm. +func (f *FS) Open(name string) (fs.File, error) { + if f.closed.Load() { + return nil, &fs.PathError{Op: "open", Path: name, Err: ErrFSClosed} + } + res, err := f.resolvePath(name, "open", true) + if err != nil { + return nil, err + } + if res.isRoot { + return f.openDir("", name), nil + } + if res.entry.entry.Header.Mode.IsDir() { + return f.openDir(res.resolvedPath, name), nil + } + return f.openFile(res.entry, name), nil +} + +func (f *FS) openFile(oe *overlayEntry, openArg string) fs.File { + layer := &f.layers[oe.layerIdx] + sr := io.NewSectionReader(layer.decompRA, oe.entry.Offset, oe.entry.Header.Size) + return &overlayFile{ + entry: oe.entry, + openArg: openArg, + sr: sr, + } +} + +func (f *FS) openDir(resolvedPath, openArg string) fs.File { + return &overlayDirFile{ + fs: f, + resolvedPath: resolvedPath, + openArg: openArg, + entries: f.dirs[resolvedPath], + } +} + +// ReadDir implements fs.ReadDirFS. +func (f *FS) ReadDir(name string) ([]fs.DirEntry, error) { + if f.closed.Load() { + return nil, &fs.PathError{Op: "readdir", Path: name, Err: ErrFSClosed} + } + res, err := f.resolvePath(name, "readdir", true) + if err != nil { + return nil, err + } + if !res.isRoot { + if res.entry == nil { + return nil, &fs.PathError{Op: "readdir", Path: name, Err: fs.ErrNotExist} + } + if !res.entry.entry.Header.Mode.IsDir() { + return nil, &fs.PathError{Op: "readdir", Path: name, Err: fs.ErrInvalid} + } + } + raw := f.dirs[res.resolvedPath] + out := make([]fs.DirEntry, len(raw)) + for i, e := range raw { + out[i] = e + } + return out, nil +} + +// Stat implements fs.StatFS. Symlinks at the final component are followed. +func (f *FS) Stat(name string) (fs.FileInfo, error) { + if f.closed.Load() { + return nil, &fs.PathError{Op: "stat", Path: name, Err: ErrFSClosed} + } + res, err := f.resolvePath(name, "stat", true) + if err != nil { + return nil, err + } + if res.isRoot { + return rootFileInfo(name), nil + } + return &overlayFileInfo{name: path.Base(name), header: res.entry.entry.Header}, nil +} + +// Lstat returns metadata for the named entry without following a final +// symlink. +func (f *FS) Lstat(name string) (fs.FileInfo, error) { + if f.closed.Load() { + return nil, &fs.PathError{Op: "lstat", Path: name, Err: ErrFSClosed} + } + res, err := f.resolvePath(name, "lstat", false) + if err != nil { + return nil, err + } + if res.isRoot { + return rootFileInfo(name), nil + } + return &overlayFileInfo{name: path.Base(name), header: res.entry.entry.Header}, nil +} + +// ReadFile opens name, reads its entire contents, and closes it. Returns +// *fs.PathError wrapping ErrFSClosed if the FS has been closed. +func (f *FS) ReadFile(name string) ([]byte, error) { + if f.closed.Load() { + return nil, &fs.PathError{Op: "readfile", Path: name, Err: ErrFSClosed} + } + file, err := f.Open(name) + if err != nil { + return nil, err + } + defer file.Close() + return io.ReadAll(file) +} + +// overlayFile is the fs.File returned for regular and hardlink-resolved entries. +type overlayFile struct { + entry tarfs.Entry + openArg string + sr *io.SectionReader +} + +func (f *overlayFile) Read(p []byte) (int, error) { return f.sr.Read(p) } +func (f *overlayFile) ReadAt(p []byte, off int64) (int, error) { return f.sr.ReadAt(p, off) } +func (f *overlayFile) Seek(off int64, whence int) (int64, error) { + return f.sr.Seek(off, whence) +} + +// WriteTo allows io.Copy to bypass its 32 KB staging buffer. We stream +// directly from the SectionReader in 64 KB chunks rather than relying on +// io.Copy's internal staging buffer. +func (f *overlayFile) WriteTo(w io.Writer) (int64, error) { + pos, err := f.sr.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + remaining := f.entry.Header.Size - pos + if remaining <= 0 { + return 0, nil + } + const chunk = 64 * 1024 + buf := make([]byte, chunk) + var written int64 + for remaining > 0 { + want := int64(len(buf)) + if want > remaining { + want = remaining + } + n, rerr := f.sr.Read(buf[:want]) + if n > 0 { + wn, werr := w.Write(buf[:n]) + written += int64(wn) + if werr != nil { + return written, werr + } + if wn != n { + return written, io.ErrShortWrite + } + remaining -= int64(n) + } + if rerr == io.EOF { + break + } + if rerr != nil { + return written, rerr + } + } + return written, nil +} + +func (f *overlayFile) Close() error { return nil } +func (f *overlayFile) Stat() (fs.FileInfo, error) { + return &overlayFileInfo{name: path.Base(f.openArg), header: f.entry.Header}, nil +} + +// overlayDirFile is the fs.ReadDirFile returned for directory entries. +type overlayDirFile struct { + fs *FS + resolvedPath string + openArg string + entries []*overlayDirEntry + cursor int +} + +func (d *overlayDirFile) Read([]byte) (int, error) { + pathForError := d.resolvedPath + if pathForError == "" { + pathForError = "." + } + return 0, &fs.PathError{Op: "read", Path: pathForError, Err: fs.ErrInvalid} +} + +func (d *overlayDirFile) Close() error { return nil } + +func (d *overlayDirFile) Stat() (fs.FileInfo, error) { + if d.resolvedPath == "" { + return rootFileInfo(d.openArg), nil + } + if oe := d.fs.index[d.resolvedPath]; oe != nil { + return &overlayFileInfo{name: path.Base(d.openArg), header: oe.entry.Header}, nil + } + return rootFileInfo(d.openArg), nil +} + +func (d *overlayDirFile) ReadDir(n int) ([]fs.DirEntry, error) { + remaining := len(d.entries) - d.cursor + if n <= 0 { + out := make([]fs.DirEntry, remaining) + for i, e := range d.entries[d.cursor:] { + out[i] = e + } + d.cursor = len(d.entries) + return out, nil + } + if remaining == 0 { + return nil, io.EOF + } + take := n + if take > remaining { + take = remaining + } + out := make([]fs.DirEntry, take) + for i, e := range d.entries[d.cursor : d.cursor+take] { + out[i] = e + } + d.cursor += take + if take < n { + return out, io.EOF + } + return out, nil +} + +// Compile-time interface checks. +var ( + _ fs.FS = (*FS)(nil) + _ fs.ReadDirFS = (*FS)(nil) + _ fs.StatFS = (*FS)(nil) + _ fs.ReadFileFS = (*FS)(nil) + + _ fs.File = (*overlayFile)(nil) + _ io.ReaderAt = (*overlayFile)(nil) + _ io.Seeker = (*overlayFile)(nil) + _ io.WriterTo = (*overlayFile)(nil) + _ fs.ReadDirFile = (*overlayDirFile)(nil) + _ fs.DirEntry = (*overlayDirEntry)(nil) + _ fs.FileInfo = (*overlayFileInfo)(nil) +) diff --git a/ocifs/overlay_test.go b/ocifs/overlay_test.go new file mode 100644 index 0000000..6384d0e --- /dev/null +++ b/ocifs/overlay_test.go @@ -0,0 +1,388 @@ +package ocifs + +import ( + "archive/tar" + "io/fs" + "reflect" + "sort" + "testing" + + "github.com/docker/oci/ocifs/tarfs" +) + +// makeFile constructs a regular-file tar entry for testing. +func makeFile(name, body string) tarfs.Entry { + return tarfs.Entry{ + Header: tarfs.Header{ + Typeflag: tar.TypeReg, + Name: name, + Size: int64(len(body)), + Mode: 0o644, + }, + Filename: name, + } +} + +// makeDir constructs a directory tar entry for testing. +func makeDir(name string) tarfs.Entry { + return tarfs.Entry{ + Header: tarfs.Header{ + Typeflag: tar.TypeDir, + Name: name, + Mode: fs.ModeDir | 0o755, + }, + Filename: name, + } +} + +// makeSymlink constructs a symlink tar entry for testing. +func makeSymlink(name, target string) tarfs.Entry { + return tarfs.Entry{ + Header: tarfs.Header{ + Typeflag: tar.TypeSymlink, + Name: name, + Linkname: target, + Mode: fs.ModeSymlink | 0o777, + }, + Filename: name, + } +} + +// makeWhiteout constructs a whiteout tar entry for testing. +func makeWhiteout(parentDir, base string) tarfs.Entry { + name := whiteoutPrefix + base + if parentDir != "" { + name = parentDir + "/" + name + } + return tarfs.Entry{ + Header: tarfs.Header{ + Typeflag: tar.TypeReg, + Name: name, + Mode: 0o644, + }, + Filename: name, + } +} + +// makeOpaqueWhiteout constructs an opaque whiteout marker. +func makeOpaqueWhiteout(parentDir string) tarfs.Entry { + name := whiteoutOpaque + if parentDir != "" { + name = parentDir + "/" + name + } + return tarfs.Entry{ + Header: tarfs.Header{ + Typeflag: tar.TypeReg, + Name: name, + Mode: 0o644, + }, + Filename: name, + } +} + +// makeHardlink constructs a hardlink entry pointing to target. +func makeHardlink(name, target string) tarfs.Entry { + return tarfs.Entry{ + Header: tarfs.Header{ + Typeflag: tar.TypeLink, + Name: name, + Linkname: target, + Mode: 0o644, + }, + Filename: name, + } +} + +func TestBuildOverlay_SimpleUpsert(t *testing.T) { + layers := [][]tarfs.Entry{ + {makeFile("a.txt", "lower")}, + {makeFile("a.txt", "upper")}, + } + idx, dirs := buildOverlay(layers) + + oe, ok := idx["a.txt"] + if !ok { + t.Fatalf("a.txt missing from overlayIndex") + } + if oe.layerIdx != 1 { + t.Errorf("a.txt layerIdx = %d, want 1", oe.layerIdx) + } + root := dirs[""] + if len(root) != 1 { + t.Fatalf("root has %d entries, want 1: %v", len(root), entryNames(root)) + } + if root[0].Name() != "a.txt" { + t.Errorf("root[0] = %q, want a.txt", root[0].Name()) + } +} + +func TestBuildOverlay_WhiteoutDeletion(t *testing.T) { + layers := [][]tarfs.Entry{ + { + makeFile("a.txt", "lower"), + makeFile("b.txt", "lower-b"), + }, + {makeWhiteout("", "a.txt")}, + } + idx, dirs := buildOverlay(layers) + + if _, ok := idx["a.txt"]; ok { + t.Errorf("a.txt should be removed from overlayIndex") + } + if _, ok := idx["b.txt"]; !ok { + t.Errorf("b.txt should still exist in overlayIndex") + } + root := dirs[""] + if len(root) != 1 { + t.Fatalf("root has %d entries, want 1: %v", len(root), entryNames(root)) + } + if root[0].Name() != "b.txt" { + t.Errorf("root[0] = %q, want b.txt", root[0].Name()) + } +} + +func TestBuildOverlay_OpaqueWhiteout(t *testing.T) { + layers := [][]tarfs.Entry{ + { + makeDir("etc"), + makeFile("etc/a.txt", "lower"), + makeFile("etc/b.txt", "lower-b"), + }, + { + makeOpaqueWhiteout("etc"), + makeFile("etc/c.txt", "upper-c"), + }, + } + idx, dirs := buildOverlay(layers) + + if _, ok := idx["etc/a.txt"]; ok { + t.Errorf("etc/a.txt should be removed") + } + if _, ok := idx["etc/b.txt"]; ok { + t.Errorf("etc/b.txt should be removed") + } + if _, ok := idx["etc/c.txt"]; !ok { + t.Errorf("etc/c.txt should be present") + } + if _, ok := idx["etc"]; !ok { + t.Errorf("etc directory should still exist") + } + etcDir := dirs["etc"] + if len(etcDir) != 1 { + t.Fatalf("etc has %d entries, want 1: %v", len(etcDir), entryNames(etcDir)) + } + if etcDir[0].Name() != "c.txt" { + t.Errorf("etc[0] = %q, want c.txt", etcDir[0].Name()) + } +} + +func TestBuildOverlay_RootOpaqueWhiteout(t *testing.T) { + layers := [][]tarfs.Entry{ + { + makeFile("a.txt", "lower"), + makeDir("usr"), + makeFile("usr/lib", "lib-content"), + }, + { + makeOpaqueWhiteout(""), + makeFile("new.txt", "after"), + }, + } + idx, dirs := buildOverlay(layers) + + if _, ok := idx["a.txt"]; ok { + t.Errorf("a.txt should be removed by root opaque whiteout") + } + if _, ok := idx["usr/lib"]; ok { + t.Errorf("usr/lib should be removed by root opaque whiteout") + } + if _, ok := idx["new.txt"]; !ok { + t.Errorf("new.txt should be present after root opaque whiteout") + } + root := dirs[""] + if len(root) != 1 || root[0].Name() != "new.txt" { + t.Errorf("root entries = %v, want [new.txt]", entryNames(root)) + } +} + +func TestBuildOverlay_SynthesizeParents(t *testing.T) { + layers := [][]tarfs.Entry{ + { + // No explicit "a" or "a/b" directory entries; they must be synthesized. + makeFile("a/b/c.txt", "deep"), + }, + } + idx, dirs := buildOverlay(layers) + + if oe, ok := idx["a"]; !ok || !oe.entry.Header.Mode.IsDir() { + t.Errorf("a should be synthesized as a directory") + } + if oe, ok := idx["a/b"]; !ok || !oe.entry.Header.Mode.IsDir() { + t.Errorf("a/b should be synthesized as a directory") + } + if _, ok := idx["a/b/c.txt"]; !ok { + t.Errorf("a/b/c.txt should be present") + } + + rootChildren := entryNames(dirs[""]) + if !reflect.DeepEqual(rootChildren, []string{"a"}) { + t.Errorf("root children = %v, want [a]", rootChildren) + } + aChildren := entryNames(dirs["a"]) + if !reflect.DeepEqual(aChildren, []string{"b"}) { + t.Errorf("a children = %v, want [b]", aChildren) + } + abChildren := entryNames(dirs["a/b"]) + if !reflect.DeepEqual(abChildren, []string{"c.txt"}) { + t.Errorf("a/b children = %v, want [c.txt]", abChildren) + } +} + +func TestBuildOverlay_MultiLayerMerge(t *testing.T) { + layers := [][]tarfs.Entry{ + { + makeDir("usr"), + makeFile("usr/a.txt", "L0-a"), + makeFile("usr/b.txt", "L0-b"), + }, + { + makeFile("usr/a.txt", "L1-a-overwritten"), + makeFile("usr/c.txt", "L1-c"), + }, + { + makeFile("usr/b.txt", "L2-b-overwritten"), + }, + } + idx, dirs := buildOverlay(layers) + + if oe := idx["usr/a.txt"]; oe == nil || oe.layerIdx != 1 { + t.Errorf("usr/a.txt layerIdx mismatch: %+v", oe) + } + if oe := idx["usr/b.txt"]; oe == nil || oe.layerIdx != 2 { + t.Errorf("usr/b.txt layerIdx mismatch: %+v", oe) + } + if oe := idx["usr/c.txt"]; oe == nil || oe.layerIdx != 1 { + t.Errorf("usr/c.txt layerIdx mismatch: %+v", oe) + } + usr := entryNames(dirs["usr"]) + if !reflect.DeepEqual(usr, []string{"a.txt", "b.txt", "c.txt"}) { + t.Errorf("usr children = %v, want [a.txt b.txt c.txt]", usr) + } +} + +func TestBuildOverlay_DirectoryWhiteoutRemovesSubtree(t *testing.T) { + layers := [][]tarfs.Entry{ + { + makeDir("etc"), + makeFile("etc/a.txt", "x"), + makeDir("etc/sub"), + makeFile("etc/sub/y.txt", "y"), + }, + { + // Whiteout the "etc" directory itself. + makeWhiteout("", "etc"), + }, + } + idx, dirs := buildOverlay(layers) + + for _, k := range []string{"etc", "etc/a.txt", "etc/sub", "etc/sub/y.txt"} { + if _, ok := idx[k]; ok { + t.Errorf("idx[%q] should be removed", k) + } + } + for _, k := range []string{"etc", "etc/sub"} { + if _, ok := dirs[k]; ok { + t.Errorf("dirs[%q] should be removed", k) + } + } + root := entryNames(dirs[""]) + if len(root) != 0 { + t.Errorf("root should be empty, got %v", root) + } +} + +func TestBuildOverlay_HardlinkResolvesToTarget(t *testing.T) { + body := "the body" + target := makeFile("usr/bin/ls", body) + link := makeHardlink("bin/myls", "usr/bin/ls") + layers := [][]tarfs.Entry{ + {target, link}, + } + idx, dirs := buildOverlay(layers) + + oe, ok := idx["bin/myls"] + if !ok { + t.Fatalf("bin/myls missing") + } + if oe.entry.Header.Size != int64(len(body)) { + t.Errorf("hardlink size = %d, want %d", oe.entry.Header.Size, len(body)) + } + // The DirEntry must use the link's own name, not the target's name. + binDir := entryNames(dirs["bin"]) + if !reflect.DeepEqual(binDir, []string{"myls"}) { + t.Errorf("bin children = %v, want [myls]", binDir) + } +} + +func TestBuildOverlay_HardlinkWhiteoutDeletedTarget(t *testing.T) { + layers := [][]tarfs.Entry{ + {makeFile("usr/bin/ls", "body")}, + {makeWhiteout("usr/bin", "ls")}, + {makeHardlink("bin/myls", "usr/bin/ls")}, + } + idx, _ := buildOverlay(layers) + + if _, ok := idx["bin/myls"]; ok { + t.Errorf("bin/myls should be skipped because target was whiteout-deleted") + } +} + +func TestBuildOverlay_CrossLayerSymlink(t *testing.T) { + layers := [][]tarfs.Entry{ + { + makeDir("usr"), + makeDir("usr/lib"), + makeFile("usr/lib/foo.so", "binary"), + }, + { + makeSymlink("lib", "usr/lib"), + }, + } + idx, _ := buildOverlay(layers) + + if oe := idx["lib"]; oe == nil || oe.entry.Header.Mode&fs.ModeSymlink == 0 { + t.Fatalf("lib should be a symlink: %+v", oe) + } + // Direct lookup of "lib/foo.so" is not in idx; resolution happens at Open time. + // Here we just confirm the components are present. + if _, ok := idx["usr/lib/foo.so"]; !ok { + t.Errorf("usr/lib/foo.so should be present") + } +} + +func TestBuildOverlay_DirectoriesSorted(t *testing.T) { + layers := [][]tarfs.Entry{ + { + makeFile("z.txt", "z"), + makeFile("a.txt", "a"), + makeFile("m.txt", "m"), + }, + } + _, dirs := buildOverlay(layers) + got := entryNames(dirs[""]) + want := []string{"a.txt", "m.txt", "z.txt"} + if !sort.IsSorted(sort.StringSlice(got)) { + t.Errorf("root not sorted: %v", got) + } + if !reflect.DeepEqual(got, want) { + t.Errorf("root children = %v, want %v", got, want) + } +} + +func entryNames(es []*overlayDirEntry) []string { + out := make([]string, len(es)) + for i, e := range es { + out[i] = e.Name() + } + return out +} diff --git a/ocifs/tarfs/README.md b/ocifs/tarfs/README.md new file mode 100644 index 0000000..2ac74d4 --- /dev/null +++ b/ocifs/tarfs/README.md @@ -0,0 +1,63 @@ +# tarfs + +Package `tarfs` builds a read-only [`io/fs.FS`](https://pkg.go.dev/io/fs#FS) over a tar archive accessed via [`io.ReaderAt`](https://pkg.go.dev/io#ReaderAt). The filesystem is immutable after construction and safe for concurrent use. + +## Two-phase usage + +### Phase 1 — Index (once per archive) + +`Index` makes a single sequential pass over the tar stream and records each entry's byte offset, normalized path, and header metadata. No file content is read. + +```go +import "github.com/docker/oci/ocifs/tarfs" + +entries, err := tarfs.Index(decompressedReader) +``` + +The `[]Entry` slice is JSON-serializable; persist it to avoid re-scanning on every start. + +### Phase 2 — Serve reads + +```go +// From a fresh scan (combines Index + NewFromEntries): +tfs, err := tarfs.New(decompressedReaderAt, -1) + +// From a persisted index (no I/O at construction time): +tfs, err := tarfs.NewFromEntries(decompressedReaderAt, entries) + +// Use as any io/fs.FS: +f, err := tfs.Open("etc/passwd") +data, err := fs.ReadFile(tfs, "usr/lib/libz.so.1") +entries, err := tfs.ReadDir("usr/bin") +``` + +File content is fetched on demand via `io.SectionReader` over the `io.ReaderAt`; the entry index records where in the decompressed stream each file's data begins. + +## Key types + +| Type | Purpose | +|------|---------| +| `FS` | The `io/fs.FS` implementation. Also implements `ReadDirFS`, `StatFS`, and `Lstat`. | +| `Entry` | One tar entry: normalized `Filename`, byte `Offset` in the decompressed stream, and a `Header` with metadata. | +| `Header` | Subset of `tar.Header` that is JSON-serializable with stable round-trip semantics. `ModTime` is stored as Unix nanoseconds. | + +## Symlinks and hardlinks + +- **Symlinks** are followed by `Open` and `Stat` up to 255 hops. Circular chains return `ErrSymlinkLoop`. `Lstat` does not follow a final-component symlink. +- **Hardlinks** are resolved to the target entry at construction time so that `DirEntry.Info` reports the real size and mode. A hardlink whose target is absent is excluded from directory listings. + +## Synthetic directories + +Tar archives do not always include explicit entries for every ancestor directory. `tarfs` synthesizes a minimal directory entry (`Mode: fs.ModeDir|0755`) for any ancestor implied by a file path but not present in the archive. Synthesized entries are excluded from the `Entries()` output used for index persistence. + +## Security + +Tar entries whose normalized path contains a `".."` component are silently dropped during both `Index` and `NewFromEntries`. This prevents [TarSlip/ZipSlip](https://security.snyk.io/research/zip-slip-vulnerability) path traversal: a crafted archive cannot inject `".."` entries into directory listings or cause reads outside the archive root. + +## OCI whiteout markers + +Entries whose base name is `.wh..wh..opq` (opaque whiteout) or starts with `.wh.` (per-file whiteout) are included in the path index so that `Open` succeeds on them, but they are excluded from `ReadDir` results. The higher-level `ocifs` overlay layer interprets these markers and removes the appropriate lower-layer entries. + +## `io/fs` path conventions + +Following the `io/fs` spec, all paths use `"."` as the root, never `"/"`. Pass `"."` to `Open`, `ReadDir`, or `fs.WalkDir` to address the root directory. diff --git a/ocifs/tarfs/file.go b/ocifs/tarfs/file.go new file mode 100644 index 0000000..9803400 --- /dev/null +++ b/ocifs/tarfs/file.go @@ -0,0 +1,133 @@ +package tarfs + +import ( + "io" + "io/fs" + "path" +) + +// Compile-time interface checks. +var ( + _ fs.File = (*file)(nil) + _ io.ReaderAt = (*file)(nil) + _ io.Seeker = (*file)(nil) + _ io.WriterTo = (*file)(nil) + _ fs.ReadDirFile = (*dirFile)(nil) +) + +// file is the fs.File returned for regular and hardlink-resolved entries. +// Implements io.WriterTo by forwarding to the underlying SectionReader so +// io.Copy avoids allocating a 32 KB buffer. +type file struct { + entry *Entry + openArg string + sr *io.SectionReader +} + +func (f *file) Read(p []byte) (int, error) { return f.sr.Read(p) } +func (f *file) ReadAt(p []byte, off int64) (int, error) { return f.sr.ReadAt(p, off) } +func (f *file) Seek(off int64, whence int) (int64, error) { + return f.sr.Seek(off, whence) +} + +// WriteTo allows io.Copy to bypass its 32 KB staging buffer. We stream +// directly from the backing io.ReaderAt into w in 64 KB chunks, advancing +// the SectionReader's cursor so subsequent Reads continue from the right +// position. +func (f *file) WriteTo(w io.Writer) (int64, error) { + pos, err := f.sr.Seek(0, io.SeekCurrent) + if err != nil { + return 0, err + } + remaining := f.entry.Header.Size - pos + if remaining <= 0 { + return 0, nil + } + const chunk = 64 * 1024 + buf := make([]byte, chunk) + var written int64 + for remaining > 0 { + want := int64(len(buf)) + if want > remaining { + want = remaining + } + n, rerr := f.sr.Read(buf[:want]) + if n > 0 { + wn, werr := w.Write(buf[:n]) + written += int64(wn) + if werr != nil { + return written, werr + } + if wn != n { + return written, io.ErrShortWrite + } + remaining -= int64(n) + } + if rerr == io.EOF { + break + } + if rerr != nil { + return written, rerr + } + } + return written, nil +} +func (f *file) Close() error { return nil } +func (f *file) Stat() (fs.FileInfo, error) { + return entryFileInfo(f.entry, path.Base(f.openArg)), nil +} + +// dirFile is the fs.ReadDirFile returned for directory entries. Its cursor +// makes it not safe for concurrent use; callers needing parallel iteration +// must Open separately. +type dirFile struct { + fs *FS + resolved string + openArg string + entries []fs.DirEntry + cursor int +} + +func (d *dirFile) Read(p []byte) (int, error) { + pathForError := d.resolved + if pathForError == "" { + pathForError = "." + } + return 0, &fs.PathError{Op: "read", Path: pathForError, Err: fs.ErrInvalid} +} + +func (d *dirFile) Close() error { return nil } + +func (d *dirFile) Stat() (fs.FileInfo, error) { + if d.resolved == "" { + return rootFileInfo(d.openArg), nil + } + if i, ok := d.fs.index[d.resolved]; ok { + return entryFileInfo(&d.fs.files[i], path.Base(d.openArg)), nil + } + return rootFileInfo(d.openArg), nil +} + +func (d *dirFile) ReadDir(n int) ([]fs.DirEntry, error) { + remaining := len(d.entries) - d.cursor + if n <= 0 { + out := make([]fs.DirEntry, remaining) + copy(out, d.entries[d.cursor:]) + d.cursor = len(d.entries) + return out, nil + } + if remaining == 0 { + return nil, io.EOF + } + take := n + if take > remaining { + take = remaining + } + out := make([]fs.DirEntry, take) + copy(out, d.entries[d.cursor:d.cursor+take]) + d.cursor += take + if take < n { + return out, io.EOF + } + return out, nil +} diff --git a/ocifs/tarfs/tarfs.go b/ocifs/tarfs/tarfs.go new file mode 100644 index 0000000..79c2248 --- /dev/null +++ b/ocifs/tarfs/tarfs.go @@ -0,0 +1,624 @@ +// Package tarfs builds an fs.FS over a decompressed tar stream backed by an +// io.ReaderAt. The filesystem is immutable after construction; all read paths +// are safe for concurrent use. +package tarfs + +import ( + "archive/tar" + "errors" + "fmt" + "io" + "io/fs" + "math" + "path" + "sort" + "strings" + "time" +) + +// ErrSymlinkLoop is returned (wrapped in *fs.PathError) when symlink resolution +// in Open exceeds the per-call hop budget of 255 hops. +var ErrSymlinkLoop = errors.New("tarfs: too many levels of symbolic links") + +const ( + maxSymlinkHops = 256 + maxHardlinkHops = 8 +) + +// Header holds the subset of tar entry metadata that tarfs requires. It is +// JSON-serializable with stable round-trip semantics; ModTime is stored as +// Unix nanoseconds rather than time.Time to avoid timezone drift. +type Header struct { + Typeflag byte `json:"typeflag"` + Name string `json:"name"` + Linkname string `json:"linkname,omitempty"` + Size int64 `json:"size"` + Mode fs.FileMode `json:"mode"` + Uid int `json:"uid,omitempty"` + Gid int `json:"gid,omitempty"` + Uname string `json:"uname,omitempty"` + Gname string `json:"gname,omitempty"` + ModTime int64 `json:"modtime,omitempty"` +} + +// Entry pairs a Header with its byte offset in the decompressed tar stream +// and a normalized Filename (no leading ./ or /, no trailing /). +type Entry struct { + Header Header `json:"header"` + Offset int64 `json:"offset"` + Filename string `json:"filename"` +} + +// FS is a read-only fs.FS over a tar archive accessed via io.ReaderAt. +type FS struct { + ra io.ReaderAt + files []Entry + index map[string]int + dirs map[string][]fs.DirEntry +} + +// New scans the tar archive in ra (length size; pass -1 if unknown) and +// returns a ready-to-use FS. Implicit ancestor directories are synthesized. +func New(ra io.ReaderAt, size int64) (*FS, error) { + if size < 0 { + size = math.MaxInt64 + } + sr := io.NewSectionReader(ra, 0, size) + entries, err := scan(sr) + if err != nil { + return nil, err + } + return NewFromEntries(ra, entries) +} + +// Index scans the tar stream r and returns the raw entry list. Implicit +// ancestor directories are NOT synthesized; that is the responsibility of +// NewFromEntries when constructing a live FS. +func Index(r io.Reader) ([]Entry, error) { + return scan(r) +} + +// NewFromEntries constructs a live FS from a pre-built entry slice and a +// random-access decompressed reader. Synthesizes implicit ancestor +// directories. Whiteout entries (.wh. prefix, .wh..wh..opq) remain in the +// path index (so Open succeeds) but are excluded from ReadDir results. +// Entries whose Filename contains a ".." component are silently dropped to +// prevent TarSlip/ZipSlip path traversal. +func NewFromEntries(ra io.ReaderAt, entries []Entry) (*FS, error) { + safe := entries[:0:0] + for _, e := range entries { + if !containsDotDot(e.Filename) { + safe = append(safe, e) + } + } + entries = safe + + files := make([]Entry, len(entries)) + copy(files, entries) + + idx := make(map[string]int, len(files)) + for i, e := range files { + idx[e.Filename] = i + } + + synth := synthesizeAncestors(files, idx) + if len(synth) > 0 { + base := len(files) + files = append(files, synth...) + for i, e := range synth { + idx[e.Filename] = base + i + } + } + + dirs := buildDirs(files, idx) + + return &FS{ra: ra, files: files, index: idx, dirs: dirs}, nil +} + +// Entries returns a copy of the underlying entry slice. Synthetic ancestor +// directory entries created during construction are excluded; whiteouts and +// hardlinks are included. +func (f *FS) Entries() []Entry { + out := make([]Entry, 0, len(f.files)) + for _, e := range f.files { + if e.synthetic() { + continue + } + out = append(out, e) + } + return out +} + +func (e Entry) synthetic() bool { + // Synthetic ancestor directories are minted with a sentinel header + // (Mode=fs.ModeDir|0755, Typeflag=tar.TypeDir, all other fields zero, + // Offset=0, Size=0). The cleanest way to tag them without bloating + // Entry is by Header.Name being empty: real entries always have a + // non-empty Header.Name, while synthesized ones leave it blank. + return e.Header.Name == "" && e.Header.Typeflag == tar.TypeDir +} + +func scan(r io.Reader) ([]Entry, error) { + cr := &countingReader{r: r} + tr := tar.NewReader(cr) + var entries []Entry + for { + th, err := tr.Next() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("tarfs: read tar header: %w", err) + } + offset := cr.n + mode := th.FileInfo().Mode() + filename := normalize(th.Name) + // Skip entries with ".." components to prevent TarSlip path traversal. + if containsDotDot(filename) { + continue + } + entries = append(entries, Entry{ + Header: Header{ + Typeflag: th.Typeflag, + Name: th.Name, + Linkname: th.Linkname, + Size: th.Size, + Mode: mode, + Uid: th.Uid, + Gid: th.Gid, + Uname: th.Uname, + Gname: th.Gname, + ModTime: th.ModTime.UnixNano(), + }, + Offset: offset, + Filename: filename, + }) + } + return entries, nil +} + +// containsDotDot reports whether any component of the normalized path is "..". +// Such entries are rejected to prevent TarSlip/ZipSlip path traversal attacks. +func containsDotDot(p string) bool { + for _, c := range strings.Split(p, "/") { + if c == ".." { + return true + } + } + return false +} + +// countingReader wraps an io.Reader to track total bytes consumed so the +// scan loop can record each entry's data offset. +type countingReader struct { + r io.Reader + n int64 +} + +func (c *countingReader) Read(p []byte) (int, error) { + n, err := c.r.Read(p) + c.n += int64(n) + return n, err +} + +// normalize converts a tar entry name to the canonical FS key form: strip +// leading "./" or "/", strip trailing "/". +func normalize(name string) string { + for strings.HasPrefix(name, "./") { + name = name[2:] + } + for strings.HasPrefix(name, "/") { + name = name[1:] + } + for strings.HasSuffix(name, "/") { + name = name[:len(name)-1] + } + if name == "." { + name = "" + } + return name +} + +// isWhiteout reports whether the path's base is an OCI whiteout marker. +func isWhiteout(filename string) bool { + base := path.Base(filename) + return base == ".wh..wh..opq" || strings.HasPrefix(base, ".wh.") +} + +// synthesizeAncestors returns synthetic directory entries for any ancestor +// path implied by entries' filenames but missing from idx. +func synthesizeAncestors(entries []Entry, idx map[string]int) []Entry { + seen := make(map[string]bool) + var out []Entry + for _, e := range entries { + if e.Filename == "" { + continue + } + parts := strings.Split(e.Filename, "/") + for i := 1; i < len(parts); i++ { + anc := strings.Join(parts[:i], "/") + if anc == "" { + continue + } + if _, ok := idx[anc]; ok { + continue + } + if seen[anc] { + continue + } + seen[anc] = true + out = append(out, Entry{ + Header: Header{ + Typeflag: tar.TypeDir, + Mode: fs.ModeDir | 0o755, + }, + Filename: anc, + }) + } + } + return out +} + +// buildDirs constructs the directory listing map. Whiteouts are excluded. +// Each slice is sorted by Name() ascending. Hardlink entries reference the +// resolved target so their DirEntry.Info reports the real Size/Mode. +func buildDirs(entries []Entry, idx map[string]int) map[string][]fs.DirEntry { + dirs := make(map[string][]fs.DirEntry) + dirs[""] = nil + for _, e := range entries { + if e.Header.Mode.IsDir() { + if _, ok := dirs[e.Filename]; !ok { + dirs[e.Filename] = nil + } + } + } + for i := range entries { + e := entries[i] + if e.Filename == "" { + continue + } + if isWhiteout(e.Filename) { + continue + } + parent, base := splitParent(e.Filename) + target := &entries[i] + if target.Header.Typeflag == tar.TypeLink { + if resolved, ok := resolveHardlink(entries, idx, target); ok { + target = resolved + } + } + dirs[parent] = append(dirs[parent], &dirEntry{ + name: base, + entry: target, + }) + } + for k, v := range dirs { + sort.Slice(v, func(i, j int) bool { return v[i].Name() < v[j].Name() }) + dirs[k] = v + } + return dirs +} + +func resolveHardlink(entries []Entry, idx map[string]int, e *Entry) (*Entry, bool) { + cur := e + for hop := 0; hop < maxHardlinkHops && cur.Header.Typeflag == tar.TypeLink; hop++ { + i, ok := idx[normalize(cur.Header.Linkname)] + if !ok { + return nil, false + } + cur = &entries[i] + } + if cur.Header.Typeflag == tar.TypeLink { + return nil, false + } + return cur, true +} + +func splitParent(p string) (parent, base string) { + i := strings.LastIndex(p, "/") + if i < 0 { + return "", p + } + return p[:i], p[i+1:] +} + +// Open implements fs.FS. Symlinks are followed up to maxSymlinkHops-1 +// hops; hardlinks up to maxHardlinkHops hops. +func (f *FS) Open(name string) (fs.File, error) { + if !fs.ValidPath(name) { + return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrInvalid} + } + if name == "." { + return f.openDir("", name), nil + } + + hops := maxSymlinkHops + components := splitPath(name) + stack := make([]string, 0, len(components)) + var current *Entry + + for len(components) > 0 { + c := components[0] + components = components[1:] + switch c { + case ".": + continue + case "..": + if len(stack) > 0 { + stack = stack[:len(stack)-1] + } + current = nil + continue + } + stack = append(stack, c) + resolved := strings.Join(stack, "/") + i, ok := f.index[resolved] + if !ok { + return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} + } + current = &f.files[i] + + if current.Header.Mode&fs.ModeSymlink != 0 { + if current.Header.Linkname == "" { + return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} + } + hops-- + if hops == 0 { + return nil, &fs.PathError{Op: "open", Path: name, Err: ErrSymlinkLoop} + } + target := current.Header.Linkname + var rebuilt []string + if path.IsAbs(target) { + rebuilt = append(rebuilt, splitPath(target)...) + } else { + rebuilt = append(rebuilt, stack[:len(stack)-1]...) + rebuilt = append(rebuilt, splitPath(target)...) + } + rebuilt = append(rebuilt, components...) + components = rebuilt + stack = stack[:0] + current = nil + continue + } + + if current.Header.Mode.IsDir() { + continue + } + + // non-directory, non-symlink — must be the final component + if len(components) > 0 { + return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} + } + } + + if len(stack) == 0 { + return f.openDir("", name), nil + } + + resolved := strings.Join(stack, "/") + if current == nil { + i, ok := f.index[resolved] + if !ok { + return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} + } + current = &f.files[i] + } + + if current.Header.Mode.IsDir() { + return f.openDir(resolved, name), nil + } + + target, ok := resolveHardlink(f.files, f.index, current) + if !ok { + return nil, &fs.PathError{Op: "open", Path: name, Err: fs.ErrNotExist} + } + return f.openFile(target, name), nil +} + +func (f *FS) openFile(e *Entry, openArg string) fs.File { + return &file{ + entry: e, + openArg: openArg, + sr: io.NewSectionReader(f.ra, e.Offset, e.Header.Size), + } +} + +func (f *FS) openDir(resolved, openArg string) fs.File { + return &dirFile{ + fs: f, + resolved: resolved, + openArg: openArg, + entries: f.dirs[resolved], + } +} + +// ReadDir implements fs.ReadDirFS. Calls Open and forwards to the directory +// handle's ReadDir(-1) so symlinks naming directories are followed. +func (f *FS) ReadDir(name string) ([]fs.DirEntry, error) { + file, err := f.Open(name) + if err != nil { + return nil, err + } + defer file.Close() + rdf, ok := file.(fs.ReadDirFile) + if !ok { + return nil, &fs.PathError{Op: "readdir", Path: name, Err: fs.ErrInvalid} + } + return rdf.ReadDir(-1) +} + +// Stat returns metadata for name with symlinks followed. +func (f *FS) Stat(name string) (fs.FileInfo, error) { + file, err := f.Open(name) + if err != nil { + var pe *fs.PathError + if errors.As(err, &pe) { + pe.Op = "stat" + } + return nil, err + } + defer file.Close() + return file.Stat() +} + +// Lstat returns metadata for name without following a final-component symlink. +// Intermediate components are still resolved through symlinks. +func (f *FS) Lstat(name string) (fs.FileInfo, error) { + if !fs.ValidPath(name) { + return nil, &fs.PathError{Op: "lstat", Path: name, Err: fs.ErrInvalid} + } + if name == "." { + return rootFileInfo(name), nil + } + + parent, base := splitParent(name) + var dirPath string + if parent == "" { + dirPath = "" + } else { + // Resolve the parent through Open's traversal so intermediate + // symlinks are followed; then assemble the final-component lookup. + fi, err := f.Stat(parent) + if err != nil { + pe := &fs.PathError{Op: "lstat", Path: name, Err: fs.ErrNotExist} + if errors.Is(err, fs.ErrInvalid) { + pe.Err = fs.ErrInvalid + } + return nil, pe + } + if !fi.IsDir() { + return nil, &fs.PathError{Op: "lstat", Path: name, Err: fs.ErrNotExist} + } + var rdErr error + dirPath, rdErr = f.resolveDir(parent) + if rdErr != nil { + return nil, &fs.PathError{Op: "lstat", Path: name, Err: rdErr} + } + } + + var lookup string + if dirPath == "" { + lookup = base + } else { + lookup = dirPath + "/" + base + } + i, ok := f.index[lookup] + if !ok { + return nil, &fs.PathError{Op: "lstat", Path: name, Err: fs.ErrNotExist} + } + return entryFileInfo(&f.files[i], base), nil +} + +// resolveDir runs Open's traversal but returns the canonical directory path +// instead of a file handle. Used by Lstat for parent resolution. Caller +// guarantees name is a valid directory. Returns ErrSymlinkLoop if the hop +// budget is exhausted while resolving intermediate symlinks. +func (f *FS) resolveDir(name string) (string, error) { + hops := maxSymlinkHops + components := splitPath(name) + stack := make([]string, 0, len(components)) + for len(components) > 0 { + c := components[0] + components = components[1:] + switch c { + case ".": + continue + case "..": + if len(stack) > 0 { + stack = stack[:len(stack)-1] + } + continue + } + stack = append(stack, c) + resolved := strings.Join(stack, "/") + i, ok := f.index[resolved] + if !ok { + return resolved, nil + } + cur := &f.files[i] + if cur.Header.Mode&fs.ModeSymlink != 0 && cur.Header.Linkname != "" { + hops-- + if hops == 0 { + return "", ErrSymlinkLoop + } + target := cur.Header.Linkname + var rebuilt []string + if path.IsAbs(target) { + rebuilt = append(rebuilt, splitPath(target)...) + } else { + rebuilt = append(rebuilt, stack[:len(stack)-1]...) + rebuilt = append(rebuilt, splitPath(target)...) + } + rebuilt = append(rebuilt, components...) + components = rebuilt + stack = stack[:0] + continue + } + } + return strings.Join(stack, "/"), nil +} + +// splitPath strips any leading "/" then splits on "/" filtering empty parts. +func splitPath(p string) []string { + for strings.HasPrefix(p, "/") { + p = p[1:] + } + if p == "" { + return nil + } + parts := strings.Split(p, "/") + out := parts[:0] + for _, c := range parts { + if c == "" { + continue + } + out = append(out, c) + } + return out +} + +// dirEntry implements fs.DirEntry backed by an Entry pointer. +type dirEntry struct { + name string + entry *Entry +} + +func (d *dirEntry) Name() string { return d.name } +func (d *dirEntry) IsDir() bool { return d.entry.Header.Mode.IsDir() } +func (d *dirEntry) Type() fs.FileMode { + return d.entry.Header.Mode.Type() +} +func (d *dirEntry) Info() (fs.FileInfo, error) { + return entryFileInfo(d.entry, d.name), nil +} + +// fileInfo implements fs.FileInfo built from a Header. +type fileInfo struct { + name string + h Header +} + +func (fi *fileInfo) Name() string { return fi.name } +func (fi *fileInfo) Size() int64 { return fi.h.Size } +func (fi *fileInfo) Mode() fs.FileMode { return fi.h.Mode } +func (fi *fileInfo) ModTime() time.Time { return time.Unix(0, fi.h.ModTime).UTC() } +func (fi *fileInfo) IsDir() bool { return fi.h.Mode.IsDir() } +func (fi *fileInfo) Sys() any { return nil } + +func entryFileInfo(e *Entry, name string) fs.FileInfo { + return &fileInfo{name: name, h: e.Header} +} + +func rootFileInfo(openArg string) fs.FileInfo { + return &fileInfo{ + name: path.Base(openArg), + h: Header{Typeflag: tar.TypeDir, Mode: fs.ModeDir | 0o755}, + } +} + +// Compile-time interface checks. +var ( + _ fs.FS = (*FS)(nil) + _ fs.ReadDirFS = (*FS)(nil) + _ fs.StatFS = (*FS)(nil) +) diff --git a/ocifs/tarfs/tarfs_test.go b/ocifs/tarfs/tarfs_test.go new file mode 100644 index 0000000..181df08 --- /dev/null +++ b/ocifs/tarfs/tarfs_test.go @@ -0,0 +1,632 @@ +package tarfs + +import ( + "archive/tar" + "bytes" + "errors" + "io" + "io/fs" + "path" + "reflect" + "sort" + "strings" + "testing" + "testing/fstest" + "time" +) + +// fixtureFile describes one entry to write into a synthetic tar archive. +type fixtureFile struct { + Name string + Linkname string + Mode int64 + Type byte + Body string +} + +// buildTarBytes writes the supplied fixtures into a tar archive and returns +// the raw bytes. +func buildTarBytes(t *testing.T, files []fixtureFile) []byte { + t.Helper() + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + for _, f := range files { + mode := f.Mode + if mode == 0 { + switch f.Type { + case tar.TypeDir: + mode = 0o755 + case tar.TypeSymlink: + mode = 0o777 + default: + mode = 0o644 + } + } + hdr := &tar.Header{ + Name: f.Name, + Linkname: f.Linkname, + Mode: mode, + Typeflag: f.Type, + Size: int64(len(f.Body)), + ModTime: time.Unix(1700000000, 0).UTC(), + } + if f.Type == tar.TypeDir || f.Type == tar.TypeSymlink || f.Type == tar.TypeLink { + hdr.Size = 0 + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("tar header %q: %v", f.Name, err) + } + if hdr.Size > 0 { + if _, err := tw.Write([]byte(f.Body)); err != nil { + t.Fatalf("tar body %q: %v", f.Name, err) + } + } + } + if err := tw.Close(); err != nil { + t.Fatalf("tar close: %v", err) + } + return buf.Bytes() +} + +// buildTar wraps buildTarBytes in a *bytes.Reader (which satisfies io.ReaderAt). +func buildTar(t *testing.T, files []fixtureFile) *bytes.Reader { + t.Helper() + return bytes.NewReader(buildTarBytes(t, files)) +} + +func basicFixture() []fixtureFile { + return []fixtureFile{ + {Name: "etc/", Type: tar.TypeDir}, + {Name: "etc/hostname", Type: tar.TypeReg, Body: "node\n"}, + {Name: "etc/hosts", Type: tar.TypeReg, Body: "127.0.0.1\n"}, + {Name: "bin/sh", Type: tar.TypeReg, Body: "#!/bin/sh\necho hi\n"}, + {Name: "bin/ash", Type: tar.TypeLink, Linkname: "bin/sh"}, + {Name: "bin/sh-link", Type: tar.TypeSymlink, Linkname: "sh"}, + {Name: "var/log/", Type: tar.TypeDir}, + {Name: "var/log/.wh.secret", Type: tar.TypeReg, Body: ""}, + } +} + +func TestNew_BasicFixture(t *testing.T) { + ra := buildTar(t, basicFixture()) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + + t.Run("Open regular", func(t *testing.T) { + f, err := tfs.Open("etc/hostname") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + got, err := io.ReadAll(f) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if string(got) != "node\n" { + t.Errorf("got %q want %q", got, "node\n") + } + }) + + t.Run("Stat name is base of openArg", func(t *testing.T) { + f, err := tfs.Open("etc/hostname") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + t.Fatalf("Stat: %v", err) + } + if fi.Name() != "hostname" { + t.Errorf("Name() = %q want %q", fi.Name(), "hostname") + } + }) + + t.Run("Open root", func(t *testing.T) { + f, err := tfs.Open(".") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + fi, err := f.Stat() + if err != nil { + t.Fatalf("Stat: %v", err) + } + if !fi.IsDir() { + t.Errorf("root should be dir") + } + if fi.Name() != "." { + t.Errorf("Name() = %q want %q", fi.Name(), ".") + } + rdf, ok := f.(fs.ReadDirFile) + if !ok { + t.Fatalf("root not ReadDirFile") + } + entries, err := rdf.ReadDir(-1) + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + names := make([]string, 0, len(entries)) + for _, e := range entries { + names = append(names, e.Name()) + } + sort.Strings(names) + want := []string{"bin", "etc", "var"} + if !reflect.DeepEqual(names, want) { + t.Errorf("root entries = %v want %v", names, want) + } + }) + + t.Run("Symlink follow", func(t *testing.T) { + f, err := tfs.Open("bin/sh-link") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + got, err := io.ReadAll(f) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !strings.Contains(string(got), "echo hi") { + t.Errorf("symlink content = %q", got) + } + fi, _ := f.Stat() + if fi.Name() != "sh-link" { + t.Errorf("Stat.Name() = %q want sh-link", fi.Name()) + } + }) + + t.Run("Hardlink resolved", func(t *testing.T) { + f, err := tfs.Open("bin/ash") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + got, err := io.ReadAll(f) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if !strings.Contains(string(got), "echo hi") { + t.Errorf("hardlink content = %q (expected sh body)", got) + } + }) + + t.Run("Whiteout invisible in ReadDir", func(t *testing.T) { + entries, err := tfs.ReadDir("var/log") + if err != nil { + t.Fatalf("ReadDir: %v", err) + } + for _, e := range entries { + if strings.HasPrefix(e.Name(), ".wh.") { + t.Errorf("whiteout %q leaked into ReadDir", e.Name()) + } + } + }) + + t.Run("Whiteout openable by Open", func(t *testing.T) { + f, err := tfs.Open("var/log/.wh.secret") + if err != nil { + t.Fatalf("Open whiteout: %v", err) + } + f.Close() + }) + + t.Run("Lstat returns symlink itself", func(t *testing.T) { + fi, err := tfs.Lstat("bin/sh-link") + if err != nil { + t.Fatalf("Lstat: %v", err) + } + if fi.Mode()&fs.ModeSymlink == 0 { + t.Errorf("Lstat mode = %v, want symlink bit set", fi.Mode()) + } + if fi.Name() != "sh-link" { + t.Errorf("Name() = %q", fi.Name()) + } + }) + + t.Run("Stat follows symlink", func(t *testing.T) { + fi, err := tfs.Stat("bin/sh-link") + if err != nil { + t.Fatalf("Stat: %v", err) + } + if fi.Mode()&fs.ModeSymlink != 0 { + t.Errorf("Stat mode should not include symlink bit, got %v", fi.Mode()) + } + }) + + t.Run("Missing path is *fs.PathError", func(t *testing.T) { + _, err := tfs.Open("does/not/exist") + var pe *fs.PathError + if !errors.As(err, &pe) { + t.Fatalf("want *fs.PathError, got %T %v", err, err) + } + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("want fs.ErrNotExist, got %v", err) + } + }) + + t.Run("Invalid path", func(t *testing.T) { + _, err := tfs.Open("/abs") + if !errors.Is(err, fs.ErrInvalid) { + t.Errorf("want fs.ErrInvalid, got %v", err) + } + }) +} + +func TestIndex_RoundTrip(t *testing.T) { + raw := buildTarBytes(t, basicFixture()) + ra := bytes.NewReader(raw) + + entries, err := Index(bytes.NewReader(raw)) + if err != nil { + t.Fatalf("Index: %v", err) + } + + tfs, err := NewFromEntries(ra, entries) + if err != nil { + t.Fatalf("NewFromEntries: %v", err) + } + + got, err := fs.ReadFile(tfs, "etc/hosts") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "127.0.0.1\n" { + t.Errorf("got %q", got) + } +} + +func TestIndex_NoSyntheticDirs(t *testing.T) { + files := []fixtureFile{ + {Name: "deep/nested/path/file.txt", Type: tar.TypeReg, Body: "x"}, + } + raw := buildTarBytes(t, files) + entries, err := Index(bytes.NewReader(raw)) + if err != nil { + t.Fatalf("Index: %v", err) + } + if len(entries) != 1 { + t.Errorf("Index returned %d entries, want 1 (no synthesis)", len(entries)) + } +} + +func TestNewFromEntries_SynthesizesAncestors(t *testing.T) { + files := []fixtureFile{ + {Name: "deep/nested/path/file.txt", Type: tar.TypeReg, Body: "x"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + for _, p := range []string{"deep", "deep/nested", "deep/nested/path"} { + fi, err := tfs.Stat(p) + if err != nil { + t.Errorf("Stat(%q): %v", p, err) + continue + } + if !fi.IsDir() { + t.Errorf("Stat(%q) not dir", p) + } + } +} + +func TestEntries_ExcludesSynthetic(t *testing.T) { + files := []fixtureFile{ + {Name: "a/b/c.txt", Type: tar.TypeReg, Body: "y"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + got := tfs.Entries() + if len(got) != 1 { + t.Errorf("Entries() = %d items, want 1; got %+v", len(got), got) + } + if got[0].Filename != "a/b/c.txt" { + t.Errorf("Filename = %q", got[0].Filename) + } +} + +func TestEntries_IncludesWhiteouts(t *testing.T) { + ra := buildTar(t, basicFixture()) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + var found bool + for _, e := range tfs.Entries() { + if path.Base(e.Filename) == ".wh.secret" { + found = true + } + } + if !found { + t.Errorf("Entries() missing whiteout") + } +} + +func TestSymlinkLoop(t *testing.T) { + files := []fixtureFile{ + {Name: "a", Type: tar.TypeSymlink, Linkname: "b"}, + {Name: "b", Type: tar.TypeSymlink, Linkname: "a"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + _, err = tfs.Open("a") + if !errors.Is(err, ErrSymlinkLoop) { + t.Fatalf("want ErrSymlinkLoop, got %v", err) + } + var pe *fs.PathError + if !errors.As(err, &pe) { + t.Errorf("want *fs.PathError") + } +} + +func TestSymlinkAbsolute(t *testing.T) { + files := []fixtureFile{ + {Name: "data/payload", Type: tar.TypeReg, Body: "PAYLOAD"}, + {Name: "ptr", Type: tar.TypeSymlink, Linkname: "/data/payload"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + got, err := fs.ReadFile(tfs, "ptr") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "PAYLOAD" { + t.Errorf("got %q", got) + } +} + +func TestSymlinkRelativeWithDotDot(t *testing.T) { + files := []fixtureFile{ + {Name: "etc/conf", Type: tar.TypeReg, Body: "K=V"}, + {Name: "links/conf", Type: tar.TypeSymlink, Linkname: "../etc/conf"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + got, err := fs.ReadFile(tfs, "links/conf") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "K=V" { + t.Errorf("got %q", got) + } +} + +func TestHardlinkChain(t *testing.T) { + files := []fixtureFile{ + {Name: "real", Type: tar.TypeReg, Body: "REAL"}, + {Name: "hl1", Type: tar.TypeLink, Linkname: "real"}, + {Name: "hl2", Type: tar.TypeLink, Linkname: "hl1"}, + {Name: "hl3", Type: tar.TypeLink, Linkname: "./hl2"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + got, err := fs.ReadFile(tfs, "hl3") + if err != nil { + t.Fatalf("ReadFile: %v", err) + } + if string(got) != "REAL" { + t.Errorf("got %q", got) + } +} + +func TestHardlinkBroken(t *testing.T) { + files := []fixtureFile{ + {Name: "broken", Type: tar.TypeLink, Linkname: "missing"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + _, err = tfs.Open("broken") + if !errors.Is(err, fs.ErrNotExist) { + t.Errorf("want fs.ErrNotExist, got %v", err) + } +} + +func TestReadDirCursorSemantics(t *testing.T) { + files := []fixtureFile{ + {Name: "d/a", Type: tar.TypeReg, Body: "a"}, + {Name: "d/b", Type: tar.TypeReg, Body: "b"}, + {Name: "d/c", Type: tar.TypeReg, Body: "c"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + + t.Run("n<=0 returns all then empty nil", func(t *testing.T) { + f, err := tfs.Open("d") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + rdf := f.(fs.ReadDirFile) + got, err := rdf.ReadDir(-1) + if err != nil { + t.Fatalf("ReadDir(-1): %v", err) + } + if len(got) != 3 { + t.Errorf("first batch = %d", len(got)) + } + got2, err := rdf.ReadDir(-1) + if err != nil { + t.Errorf("second ReadDir(-1) err = %v, want nil", err) + } + if len(got2) != 0 { + t.Errorf("second batch len = %d, want 0", len(got2)) + } + }) + + t.Run("n>0 paginates with EOF on final batch", func(t *testing.T) { + f, err := tfs.Open("d") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + rdf := f.(fs.ReadDirFile) + got, err := rdf.ReadDir(2) + if err != nil { + t.Fatalf("ReadDir(2) #1: %v", err) + } + if len(got) != 2 { + t.Errorf("got %d", len(got)) + } + got, err = rdf.ReadDir(2) + if !errors.Is(err, io.EOF) { + t.Errorf("ReadDir(2) #2 err = %v, want io.EOF", err) + } + if len(got) != 1 { + t.Errorf("ReadDir(2) #2 got %d", len(got)) + } + got, err = rdf.ReadDir(2) + if !errors.Is(err, io.EOF) { + t.Errorf("ReadDir(2) #3 err = %v, want io.EOF", err) + } + if len(got) != 0 { + t.Errorf("ReadDir(2) #3 got %d", len(got)) + } + }) + + t.Run("n>0 exact match returns nil err", func(t *testing.T) { + f, err := tfs.Open("d") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + rdf := f.(fs.ReadDirFile) + got, err := rdf.ReadDir(3) + if err != nil { + t.Errorf("err = %v, want nil (exact-fit)", err) + } + if len(got) != 3 { + t.Errorf("got %d", len(got)) + } + _, err = rdf.ReadDir(1) + if !errors.Is(err, io.EOF) { + t.Errorf("subsequent ReadDir(1) err = %v, want io.EOF", err) + } + }) +} + +func TestDirectoryRead(t *testing.T) { + files := []fixtureFile{{Name: "d/", Type: tar.TypeDir}} + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + f, err := tfs.Open("d") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + _, err = f.Read(make([]byte, 1)) + if !errors.Is(err, fs.ErrInvalid) { + t.Errorf("Read on dir err = %v, want fs.ErrInvalid", err) + } +} + +// writerToProbe records whether WriteTo was invoked on the source. +type writerToProbe struct { + buf bytes.Buffer +} + +func (w *writerToProbe) Write(p []byte) (int, error) { return w.buf.Write(p) } + +func TestFile_ImplementsWriterTo(t *testing.T) { + files := []fixtureFile{{Name: "blob", Type: tar.TypeReg, Body: strings.Repeat("z", 1024)}} + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + f, err := tfs.Open("blob") + if err != nil { + t.Fatalf("Open: %v", err) + } + defer f.Close() + if _, ok := f.(io.WriterTo); !ok { + t.Fatalf("file does not implement io.WriterTo") + } + probe := &writerToProbe{} + n, err := io.Copy(probe, f) + if err != nil { + t.Fatalf("Copy: %v", err) + } + if n != 1024 { + t.Errorf("copied %d bytes, want 1024", n) + } + if probe.buf.Len() != 1024 { + t.Errorf("probe got %d bytes", probe.buf.Len()) + } +} + +func TestFstestTestFS(t *testing.T) { + files := []fixtureFile{ + {Name: "etc/", Type: tar.TypeDir}, + {Name: "etc/conf", Type: tar.TypeReg, Body: "k=v"}, + {Name: "bin/", Type: tar.TypeDir}, + {Name: "bin/run", Type: tar.TypeReg, Body: "RUN"}, + {Name: "bin/run-link", Type: tar.TypeSymlink, Linkname: "run"}, + {Name: "bin/run-hardlink", Type: tar.TypeLink, Linkname: "bin/run"}, + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + want := []string{ + "etc/conf", + "bin/run", + "bin/run-link", + "bin/run-hardlink", + } + if err := fstest.TestFS(tfs, want...); err != nil { + t.Fatalf("TestFS: %v", err) + } +} + +func TestSize_NegativeOne(t *testing.T) { + ra := buildTar(t, basicFixture()) + if _, err := New(ra, -1); err != nil { + t.Fatalf("New(size=-1): %v", err) + } +} + +func TestModeBitsPreserved(t *testing.T) { + files := []fixtureFile{ + {Name: "exe", Type: tar.TypeReg, Mode: 0o4755, Body: "x"}, // setuid + } + ra := buildTar(t, files) + tfs, err := New(ra, int64(ra.Len())) + if err != nil { + t.Fatalf("New: %v", err) + } + fi, err := tfs.Stat("exe") + if err != nil { + t.Fatalf("Stat: %v", err) + } + if fi.Mode()&fs.ModeSetuid == 0 { + t.Errorf("setuid bit dropped: mode = %v", fi.Mode()) + } +} diff --git a/ocifs/zstdr/README.md b/ocifs/zstdr/README.md new file mode 100644 index 0000000..3f63177 --- /dev/null +++ b/ocifs/zstdr/README.md @@ -0,0 +1,68 @@ +# zstdr + +Package `zstdr` provides sequential zstd scanning and random-access decompressed reading via frame-boundary checkpoints. + +## The problem it solves + +A zstd stream is a sequence of independently decompressible frames. Each frame can be decompressed from its magic bytes without any prior context, making random access straightforward: record the compressed and decompressed byte offset of every frame, and any byte range of the decompressed stream can be reached by seeking to the nearest frame start and decompressing only that frame. + +## Two-phase usage + +### Phase 1 — Scan (once per blob) + +`Scan` makes a single sequential pass over the zstd stream, decompressing to an `io.Writer` and recording one `FrameCheckpoint` per non-skippable frame. + +```go +import "github.com/docker/oci/ocifs/zstdr" + +idx, err := zstdr.Scan(compressedReader, io.Discard) +// idx.Frames holds one entry per frame. +// idx.Size is the total decompressed length. +``` + +Skippable frames (magic `0x184D2A5x`) are consumed and discarded; they do not appear in the index. The `Index` is JSON-serializable; persist it to avoid re-scanning. + +### Phase 2 — Random-access reads + +```go +// Build from a fresh scan: +reader, err := zstdr.NewReader(blobReaderAt, blobSize) + +// Or rebuild from a persisted index (no I/O at construction time): +reader := zstdr.NewReaderWithIndex(blobReaderAt, idx, blobSize) +defer reader.Close() + +buf := make([]byte, 4096) +n, err := reader.ReadAt(buf, decompressedOffset) +``` + +`ReadAt` finds the highest frame whose decompressed offset is at or before the requested position, fetches the compressed bytes for that frame, decompresses the whole frame into memory, then copies the requested slice out. + +## Key types + +| Type | Purpose | +|------|---------| +| `Index` | Frame checkpoint sequence + total decompressed size. JSON-serializable. | +| `FrameCheckpoint` | Compressed offset (`In`) and decompressed offset (`Out`) of a frame's first byte. | +| `Reader` | `io.ReaderAt` + `io.Closer` over the decompressed stream. | + +## Options + +```go +zstdr.WithMaxReaders(4) // allow up to 4 concurrent ReadAt calls (default: 8) +``` + +Unlike gzip checkpoints, zstd frame checkpoints carry no per-checkpoint memory overhead (no sliding-window history). The trade-off between seek granularity and memory is governed by how the zstd encoder chose frame sizes when the blob was created. + +## Concurrency + +`Reader.ReadAt` is safe for concurrent use. Each in-flight `ReadAt` acquires a slot from a bounded pool (size = `WithMaxReaders`) and obtains a `*zstd.Decoder` from a `sync.Pool`. Callers that exceed the concurrency cap block until a slot is returned. `Close` unblocks waiting callers with `ErrClosed`. + +## Comparison with gzipr + +| | `gzipr` | `zstdr` | +|--|---------|---------| +| Resume granularity | Configurable (default 1 MiB) | One frame (encoder-defined) | +| Per-checkpoint cost | ~32 KiB (sliding window) | 16 bytes | +| Decompression on ReadAt | Partial frame only | Whole frame | +| Standard | RFC 1952 + RFC 1951 | RFC 8878 | diff --git a/ocifs/zstdr/scan.go b/ocifs/zstdr/scan.go new file mode 100644 index 0000000..d4f6879 --- /dev/null +++ b/ocifs/zstdr/scan.go @@ -0,0 +1,280 @@ +package zstdr + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + + "github.com/klauspost/compress/zstd" +) + +const ( + // zstdFrameMagic is the little-endian uint32 form of the zstd + // standard frame magic (0x28 0xB5 0x2F 0xFD). + zstdFrameMagic = 0xFD2FB528 + + // skippableMagicMask and skippableMagicBase together identify + // skippable frames: (v & mask) == base. + skippableMagicMask = 0xFFFFFFF0 + skippableMagicBase = 0x184D2A50 +) + +// Scan performs a single sequential pass over the zstd stream r, +// writing all decompressed bytes to out and recording one +// FrameCheckpoint per non-skippable frame. Skippable frames are +// passed over (their bytes are read from r but not written to out and +// no checkpoint is emitted). +// +// opts are accepted for API symmetry with gzipr.Scan; no current +// option affects scanning behaviour. +// +// When the stream contains no non-skippable frames, the returned +// Index has a non-nil empty Frames slice (rather than nil), so JSON +// round-trip produces "frames":[] not "frames":null. +func Scan(r io.Reader, out io.Writer, opts ...Option) (*Index, error) { + // opts is accepted for API symmetry with gzipr; no option currently + // affects scanning behaviour. + + idx := &Index{ + Frames: []*FrameCheckpoint{}, + } + + // Reuse a single Decoder across frames. + dec, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(1)) + if err != nil { + return nil, fmt.Errorf("zstdr: %w", err) + } + defer dec.Close() + + // We read frames one at a time. For each frame we accumulate the + // raw frame bytes into frameBuf and then DecodeAll them. This + // keeps the implementation independent of the decoder's + // multi-frame stream handling. + var ( + inOff int64 + outOff int64 + frameBuf bytes.Buffer + decompBuf []byte + ) + + for { + // Peek at the first 4 bytes to identify the frame type. + var magicBuf [4]byte + _, err := io.ReadFull(r, magicBuf[:]) + if err == io.EOF { + // Clean end of stream at a frame boundary. + idx.Size = outOff + return idx, nil + } + if err == io.ErrUnexpectedEOF { + return nil, ErrInvalidFormat + } + if err != nil { + return nil, err + } + + magic := binary.LittleEndian.Uint32(magicBuf[:]) + + switch { + case (magic & skippableMagicMask) == skippableMagicBase: + // Skippable frame: 4-byte magic + 4-byte size + size bytes. + var sizeBuf [4]byte + if _, err := io.ReadFull(r, sizeBuf[:]); err != nil { + return nil, mapTrunc(err) + } + size := int64(binary.LittleEndian.Uint32(sizeBuf[:])) + if size > 0 { + if _, err := io.CopyN(io.Discard, r, size); err != nil { + return nil, mapTrunc(err) + } + } + inOff += 8 + size + + case magic == zstdFrameMagic: + // Standard frame: record checkpoint, then read the + // entire frame body using a frame walker. + idx.Frames = append(idx.Frames, &FrameCheckpoint{ + In: inOff, + Out: outOff, + }) + + frameBuf.Reset() + frameBuf.Write(magicBuf[:]) + frameLen, err := readFrameBody(r, &frameBuf) + if err != nil { + return nil, err + } + + // Decompress and stream out. + decompBuf, err = dec.DecodeAll(frameBuf.Bytes(), decompBuf[:0]) + if err != nil { + return nil, fmt.Errorf("%w: %v", ErrInvalidFormat, err) + } + if _, err := out.Write(decompBuf); err != nil { + return nil, err + } + + inOff += 4 + frameLen + outOff += int64(len(decompBuf)) + + default: + return nil, fmt.Errorf("%w: unrecognised frame magic 0x%08x at offset %d", ErrInvalidFormat, magic, inOff) + } + } +} + +// readFrameBody reads the bytes of a single zstd standard frame +// following the 4-byte magic that has already been consumed by the +// caller. The bytes are appended to dst. Returns the number of bytes +// appended (i.e. the size of the frame excluding its 4-byte magic). +// +// The walker parses the frame header (variable length per RFC 8878 +// §3.1.1.1) followed by a sequence of Block_Header (3 bytes) + +// block content tuples, terminating after the block whose Last_Block +// flag is set. If the frame's Content_Checksum_Flag is set, a 4-byte +// trailing checksum is also consumed. +func readFrameBody(r io.Reader, dst *bytes.Buffer) (int64, error) { + startLen := dst.Len() + + // Frame_Header_Descriptor (1 byte). + var fhd [1]byte + if _, err := io.ReadFull(r, fhd[:]); err != nil { + return 0, mapTrunc(err) + } + dst.Write(fhd[:]) + + dictIDFlag := fhd[0] & 0x03 + checksumFlag := (fhd[0] >> 2) & 0x01 + reservedBit := (fhd[0] >> 3) & 0x01 + singleSegment := (fhd[0]>>5)&0x01 != 0 + fcsFlag := fhd[0] >> 6 + + if reservedBit != 0 { + return 0, fmt.Errorf("%w: reserved bit set in frame header", ErrInvalidFormat) + } + + // Window_Descriptor (1 byte) — present iff Single_Segment_Flag is 0. + if !singleSegment { + var wd [1]byte + if _, err := io.ReadFull(r, wd[:]); err != nil { + return 0, mapTrunc(err) + } + dst.Write(wd[:]) + } + + // Dictionary_ID — 0/1/2/4 bytes per dictIDFlag. + dictIDSize := dictIDFieldSize(dictIDFlag) + if dictIDSize > 0 { + buf := make([]byte, dictIDSize) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, mapTrunc(err) + } + dst.Write(buf) + } + + // Frame_Content_Size — variable per fcsFlag and singleSegment. + fcsSize := frameContentSizeFieldSize(fcsFlag, singleSegment) + if fcsSize > 0 { + buf := make([]byte, fcsSize) + if _, err := io.ReadFull(r, buf); err != nil { + return 0, mapTrunc(err) + } + dst.Write(buf) + } + + // Block sequence: read Block_Header (3 bytes) + content until + // Last_Block flag is set. + for { + var bh [3]byte + if _, err := io.ReadFull(r, bh[:]); err != nil { + return 0, mapTrunc(err) + } + dst.Write(bh[:]) + + // Block_Header is 24 bits, little-endian. + bhVal := uint32(bh[0]) | uint32(bh[1])<<8 | uint32(bh[2])<<16 + lastBlock := bhVal & 0x01 + blockType := (bhVal >> 1) & 0x03 + blockSize := bhVal >> 3 + + var contentLen int64 + switch blockType { + case 0: // Raw_Block + contentLen = int64(blockSize) + case 1: // RLE_Block — single byte repeated Block_Size times. + contentLen = 1 + case 2: // Compressed_Block + contentLen = int64(blockSize) + case 3: + return 0, fmt.Errorf("%w: reserved block type", ErrInvalidFormat) + } + + if contentLen > 0 { + if _, err := io.CopyN(dst, r, contentLen); err != nil { + return 0, mapTrunc(err) + } + } + + if lastBlock != 0 { + break + } + } + + // Optional 4-byte content checksum. + if checksumFlag != 0 { + var ck [4]byte + if _, err := io.ReadFull(r, ck[:]); err != nil { + return 0, mapTrunc(err) + } + dst.Write(ck[:]) + } + + return int64(dst.Len() - startLen), nil +} + +// dictIDFieldSize returns the size of the Dictionary_ID field in bytes +// for the two-bit Dictionary_ID_Flag. +func dictIDFieldSize(flag byte) int { + switch flag { + case 0: + return 0 + case 1: + return 1 + case 2: + return 2 + case 3: + return 4 + } + return 0 +} + +// frameContentSizeFieldSize returns the size of the Frame_Content_Size +// field in bytes given the FCS flag (top 2 bits of FHD) and the +// Single_Segment_Flag. +func frameContentSizeFieldSize(flag byte, singleSegment bool) int { + switch flag { + case 0: + if singleSegment { + return 1 + } + return 0 + case 1: + return 2 + case 2: + return 4 + case 3: + return 8 + } + return 0 +} + +// mapTrunc maps a truncation-style error from a Read into the Scan +// function's error contract: a partial frame surfaces as +// ErrInvalidFormat; other errors propagate. +func mapTrunc(err error) error { + if err == io.EOF || err == io.ErrUnexpectedEOF { + return ErrInvalidFormat + } + return err +} diff --git a/ocifs/zstdr/zstdr.go b/ocifs/zstdr/zstdr.go new file mode 100644 index 0000000..f0d2ae9 --- /dev/null +++ b/ocifs/zstdr/zstdr.go @@ -0,0 +1,395 @@ +// Package zstdr provides random-access reading of zstd-compressed blobs. +// +// The package mirrors gzipr but uses zstd frame boundaries (rather than +// DEFLATE block checkpoints) as resume points: each zstd frame is +// independently decompressible, so a frame index recording each frame's +// (compressed, decompressed) offset pair is sufficient to seek into the +// decompressed stream. +package zstdr + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "sort" + "sync" + + "github.com/klauspost/compress/zstd" +) + +// ErrInvalidFormat is returned when the compressed stream does not have +// the structure the index predicts (truncated frame, premature EOF, etc.). +var ErrInvalidFormat = errors.New("zstdr: invalid compressed format") + +// ErrClosed is returned by [Reader.ReadAt] after [Reader.Close] has been +// called. +var ErrClosed = errors.New("zstdr: reader has been closed") + +// FrameCheckpoint records the start of a non-skippable zstd frame. +// +// In is the byte offset within the compressed blob where the frame's +// magic bytes (0x28 0xB5 0x2F 0xFD) begin. Out is the byte offset within +// the concatenated decompressed stream where the frame's first +// decompressed byte will appear. +type FrameCheckpoint struct { + In int64 `json:"in"` + Out int64 `json:"out"` +} + +// Index is the persisted frame map for a zstd-compressed blob. +// +// Frames is sorted by ascending In (and therefore Out) and contains one +// entry per non-skippable frame. Skippable frames are not represented: +// the gap between consecutive Frames[i].In and Frames[i+1].In may be +// larger than the i-th frame's compressed length when a skippable frame +// follows. Size is the total decompressed length of the stream. +type Index struct { + Frames []*FrameCheckpoint `json:"frames"` + Size int64 `json:"size"` +} + +// Encode writes the JSON encoding of idx to w. +func (idx *Index) Encode(w io.Writer) error { + return json.NewEncoder(w).Encode(idx) +} + +// DecodeIndex reads a JSON-encoded Index from r. +func DecodeIndex(r io.Reader) (*Index, error) { + var idx Index + if err := json.NewDecoder(r).Decode(&idx); err != nil { + return nil, err + } + return &idx, nil +} + +// Option configures a Reader (or, for API symmetry with gzipr, a Scan +// call). Options are applied in order; later values override earlier +// ones for the same setting. +type Option func(*config) + +type config struct { + maxReaders int +} + +func defaultConfig() config { + return config{maxReaders: 8} +} + +// WithMaxReaders bounds the number of concurrent ReadAt operations the +// returned Reader will service in parallel. Calls beyond the cap block +// until a reader is returned to the pool. n must be >= 1; values < 1 +// are silently clamped to 1. +func WithMaxReaders(n int) Option { + return func(c *config) { + if n < 1 { + n = 1 + } + c.maxReaders = n + } +} + +// Reader serves io.ReaderAt over the decompressed bytes of a +// zstd-compressed blob, given an io.ReaderAt over the compressed blob +// and a frame Index. +type Reader struct { + ra io.ReaderAt + idx *Index + cmpSize int64 + + mu sync.Mutex + cond *sync.Cond + closed bool + inUse int + maxCap int + decoderPool sync.Pool // pool of *zstd.Decoder; keyed by Reset before use +} + +// NewReader scans the compressed blob via ra to build a frame Index and +// returns a Reader. size is the COMPRESSED blob size in bytes. +func NewReader(ra io.ReaderAt, size int64, opts ...Option) (*Reader, error) { + idx, err := Scan(io.NewSectionReader(ra, 0, size), io.Discard, opts...) + if err != nil { + return nil, err + } + return NewReaderWithIndex(ra, idx, size), nil +} + +// NewReaderWithIndex constructs a Reader from a pre-built Index without +// any blob scanning. size is the COMPRESSED blob size; idx.Size is the +// DECOMPRESSED size (returned by Reader.Size). +// +// Precondition: idx.Size > 0 and idx.Frames != nil. +func NewReaderWithIndex(ra io.ReaderAt, idx *Index, size int64, opts ...Option) *Reader { + if idx == nil { + panic("zstdr: NewReaderWithIndex: idx is nil") + } + if idx.Size <= 0 { + panic("zstdr: NewReaderWithIndex: idx.Size must be > 0") + } + if idx.Frames == nil { + panic("zstdr: NewReaderWithIndex: idx.Frames must not be nil") + } + cfg := defaultConfig() + for _, o := range opts { + o(&cfg) + } + r := &Reader{ + ra: ra, + idx: idx, + cmpSize: size, + maxCap: cfg.maxReaders, + } + r.cond = sync.NewCond(&r.mu) + r.decoderPool.New = func() any { + dec, err := zstd.NewReader(bytes.NewReader(nil), zstd.WithDecoderConcurrency(1)) + if err != nil { + panic("zstdr: failed to allocate decoder: " + err.Error()) + } + return dec + } + return r +} + +// Size returns the total decompressed length of the stream. +func (r *Reader) Size() int64 { + return r.idx.Size +} + +// Index returns the Reader's frame index. The returned pointer is the +// same value held by the Reader; callers must not mutate it. +func (r *Reader) Index() *Index { + return r.idx +} + +// Close drains the bounded reader pool. Subsequent ReadAt calls return +// ErrClosed; in-flight callers blocked on a pool slot also receive +// ErrClosed. +func (r *Reader) Close() error { + r.mu.Lock() + r.closed = true + r.cond.Broadcast() + for r.inUse > 0 { + r.cond.Wait() + } + r.mu.Unlock() + return nil +} + +// acquire reserves a pool slot. Returns ErrClosed if the Reader has +// been closed (either before the call or while waiting for a slot). +func (r *Reader) acquire() error { + r.mu.Lock() + defer r.mu.Unlock() + for { + if r.closed { + return ErrClosed + } + if r.inUse < r.maxCap { + r.inUse++ + return nil + } + r.cond.Wait() + } +} + +// release returns a pool slot. +func (r *Reader) release() { + r.mu.Lock() + r.inUse-- + r.cond.Broadcast() + r.mu.Unlock() +} + +// ReadAt implements io.ReaderAt over the decompressed stream. +// +// Returns (0, nil) for len(p) == 0 and (0, io.EOF) for off >= r.Size(). +// Otherwise returns (n, io.EOF) when n < len(p) (clamped read crossing +// end-of-stream) or (n, nil) when n == len(p) (full read, including a +// full read that ends exactly at end-of-stream). Format corruption +// surfaces as ErrInvalidFormat; underlying ReaderAt errors are returned +// as-is. +func (r *Reader) ReadAt(p []byte, off int64) (int, error) { + if len(p) == 0 { + return 0, nil + } + if off < 0 { + return 0, fmt.Errorf("zstdr: negative offset %d", off) + } + if off >= r.idx.Size { + return 0, io.EOF + } + if err := r.acquire(); err != nil { + return 0, err + } + defer r.release() + return r.readAtLocked(p, off) +} + +// readAtLocked performs the seeking algorithm described in the design +// document. The caller must hold a pool slot. +func (r *Reader) readAtLocked(p []byte, off int64) (int, error) { + target64 := int64(len(p)) + if remaining := r.idx.Size - off; target64 > remaining { + target64 = remaining + } + target := int(target64) + + first := r.findFirst(off) + lastReal, lastIdx := r.findLast(off + target64 - 1) + + // Determine the upper bound of the compressed range to fetch. + var nextIn int64 + switch { + case lastReal && lastIdx < len(r.idx.Frames)-1: + nextIn = r.idx.Frames[lastIdx+1].In + case lastReal && lastIdx == len(r.idx.Frames)-1: + nextIn = r.cmpSize + case !lastReal && len(r.idx.Frames) > 0: + // Virtual start-of-stream as F_last AND real frames exist: + // upper bound is the start of the first real frame. + nextIn = r.idx.Frames[0].In + default: + // Virtual start-of-stream AND no real frames. + nextIn = r.cmpSize + } + + if first.In >= nextIn { + return 0, ErrInvalidFormat + } + + // Fetch the compressed span. + cmpLen := nextIn - first.In + if cmpLen <= 0 { + return 0, ErrInvalidFormat + } + cmp := make([]byte, cmpLen) + n, err := r.ra.ReadAt(cmp, first.In) + if int64(n) < cmpLen { + // Underlying ReadAt returned fewer bytes than requested. Use the + // error it provided; synthesize one if it returned nil despite + // the short read. io.EOF is valid for an exact-length ReadAt + // (io.ReaderAt contract), so only treat it as an error here. + if err == nil { + err = io.ErrUnexpectedEOF + } + return 0, err + } + // n == cmpLen; swallow io.EOF (valid exact-length io.ReaderAt form). + + // Obtain a decoder from the pool and reset it to read from cmp. + dec := r.decoderPool.Get().(*zstd.Decoder) + if resetErr := dec.Reset(bytes.NewReader(cmp)); resetErr != nil { + r.decoderPool.Put(dec) + return 0, fmt.Errorf("zstdr: %w", resetErr) + } + defer r.decoderPool.Put(dec) + + // Discard phase: drop (off - first.Out) bytes from the decompressor. + discardRemaining := off - first.Out + if discardRemaining < 0 { + // Index is corrupt. + return 0, ErrInvalidFormat + } + if err := discardN(dec, discardRemaining); err != nil { + return 0, err + } + + // Read phase: pull `target` bytes into p, applying the design's + // 4-rule loop (post-update accumulation). + accumulated := 0 + for accumulated < target { + nr, rerr := dec.Read(p[accumulated:target]) + accumulated += nr + if accumulated >= target { + // Rule 1: success regardless of rerr. + break + } + switch rerr { + case nil: + // Rule 4: partial read, loop. + continue + case io.EOF, io.ErrUnexpectedEOF: + // Rule 2: stream ended before target. + return 0, ErrInvalidFormat + default: + // Rule 3: other non-nil error. + return 0, rerr + } + } + + if target < len(p) { + return target, io.EOF + } + return target, nil +} + +// findFirst returns the highest-Out frame with Out <= off, or a virtual +// start-of-stream frame {In: 0, Out: 0} when no real frame qualifies. +func (r *Reader) findFirst(off int64) *FrameCheckpoint { + frames := r.idx.Frames + // Binary search for the largest i with frames[i].Out <= off. + i := sort.Search(len(frames), func(i int) bool { + return frames[i].Out > off + }) - 1 + if i < 0 { + return &FrameCheckpoint{In: 0, Out: 0} + } + return frames[i] +} + +// findLast returns the index in r.idx.Frames of the highest-Out frame +// with Out <= bound, or (-1, false) when no real frame qualifies (the +// caller should treat this as the virtual start-of-stream frame). +func (r *Reader) findLast(bound int64) (bool, int) { + frames := r.idx.Frames + i := sort.Search(len(frames), func(i int) bool { + return frames[i].Out > bound + }) - 1 + if i < 0 { + return false, -1 + } + return true, i +} + +// Compile-time interface checks. +var ( + _ io.ReaderAt = (*Reader)(nil) + _ io.Closer = (*Reader)(nil) +) + +// discardN reads and discards n bytes from dec, applying the design's +// discard-phase rules: io.EOF or io.ErrUnexpectedEOF before n bytes +// have been consumed map to ErrInvalidFormat; other errors propagate. +func discardN(dec io.Reader, n int64) error { + if n == 0 { + return nil + } + const chunk = 32 * 1024 + buf := make([]byte, chunk) + for n > 0 { + want := int64(len(buf)) + if want > n { + want = n + } + got, err := dec.Read(buf[:want]) + n -= int64(got) + if n == 0 { + // Done. A trailing io.EOF is fine here — the read phase + // will surface ErrInvalidFormat if it cannot make further + // progress. + return nil + } + switch err { + case nil: + // Forward progress assumed (see design note); loop. + continue + case io.EOF, io.ErrUnexpectedEOF: + return ErrInvalidFormat + default: + return err + } + } + return nil +} diff --git a/ocifs/zstdr/zstdr_test.go b/ocifs/zstdr/zstdr_test.go new file mode 100644 index 0000000..957c6fe --- /dev/null +++ b/ocifs/zstdr/zstdr_test.go @@ -0,0 +1,664 @@ +package zstdr_test + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "math/rand" + "sync" + "sync/atomic" + "testing" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/docker/oci/ocifs/zstdr" +) + +// encodeFrames encodes the given byte slices as a sequence of +// concatenated zstd frames (one frame per slice). Useful to produce +// multi-frame fixtures with predictable boundaries. +func encodeFrames(t *testing.T, parts ...[]byte) []byte { + t.Helper() + var out bytes.Buffer + enc, err := zstd.NewWriter(&out) + require.NoError(t, err) + for i, p := range parts { + _, err := enc.Write(p) + require.NoError(t, err) + if i < len(parts)-1 { + require.NoError(t, enc.Close()) + enc.Reset(&out) + } + } + require.NoError(t, enc.Close()) + return out.Bytes() +} + +// encodeSingleFrame encodes data as a single zstd frame. +func encodeSingleFrame(t *testing.T, data []byte) []byte { + t.Helper() + var out bytes.Buffer + enc, err := zstd.NewWriter(&out) + require.NoError(t, err) + _, err = enc.Write(data) + require.NoError(t, err) + require.NoError(t, enc.Close()) + return out.Bytes() +} + +// makeSkippableFrame returns a skippable frame (magic + size + payload) +// with the given user ID (0..15). +func makeSkippableFrame(userID byte, payload []byte) []byte { + if userID > 15 { + panic("userID out of range") + } + magic := uint32(0x184D2A50) | uint32(userID) + hdr := make([]byte, 8) + binary.LittleEndian.PutUint32(hdr[0:4], magic) + binary.LittleEndian.PutUint32(hdr[4:8], uint32(len(payload))) + return append(hdr, payload...) +} + +func TestScanSingleFrameRoundTrip(t *testing.T) { + data := bytes.Repeat([]byte("hello, zstd! "), 4096) + compressed := encodeSingleFrame(t, data) + + var out bytes.Buffer + idx, err := zstdr.Scan(bytes.NewReader(compressed), &out) + require.NoError(t, err) + assert.Equal(t, data, out.Bytes()) + require.NotNil(t, idx.Frames) + require.Len(t, idx.Frames, 1) + assert.Equal(t, int64(0), idx.Frames[0].In) + assert.Equal(t, int64(0), idx.Frames[0].Out) + assert.Equal(t, int64(len(data)), idx.Size) +} + +func TestScanMultiFrame(t *testing.T) { + parts := [][]byte{ + bytes.Repeat([]byte("aaaa"), 1024), + bytes.Repeat([]byte("bbbb"), 2048), + bytes.Repeat([]byte("cccc"), 512), + } + compressed := encodeFrames(t, parts...) + + var out bytes.Buffer + idx, err := zstdr.Scan(bytes.NewReader(compressed), &out) + require.NoError(t, err) + + var expected []byte + for _, p := range parts { + expected = append(expected, p...) + } + assert.Equal(t, expected, out.Bytes()) + require.Len(t, idx.Frames, 3) + + // Frame Out offsets should partition the decompressed stream by + // part length. + var outOff int64 + for i, p := range parts { + assert.Equal(t, outOff, idx.Frames[i].Out, "frame %d", i) + outOff += int64(len(p)) + } + assert.Equal(t, int64(len(expected)), idx.Size) + + // In offsets must be strictly increasing and start at 0. + assert.Equal(t, int64(0), idx.Frames[0].In) + for i := 1; i < len(idx.Frames); i++ { + assert.Greater(t, idx.Frames[i].In, idx.Frames[i-1].In, "frame %d", i) + } +} + +func TestScanEmptyInput(t *testing.T) { + var out bytes.Buffer + idx, err := zstdr.Scan(bytes.NewReader(nil), &out) + require.NoError(t, err) + require.NotNil(t, idx.Frames) + assert.Len(t, idx.Frames, 0) + assert.Equal(t, int64(0), idx.Size) + assert.Equal(t, 0, out.Len()) +} + +func TestScanSkippableFramesArePassedThrough(t *testing.T) { + dataA := bytes.Repeat([]byte("AAAA"), 256) + dataB := bytes.Repeat([]byte("BBBB"), 256) + frameA := encodeSingleFrame(t, dataA) + frameB := encodeSingleFrame(t, dataB) + skip := makeSkippableFrame(3, []byte("hello-skippable-payload")) + skipEmpty := makeSkippableFrame(0, nil) + + // Stream layout: + // skipEmpty | frameA | skip | frameB + var compressed bytes.Buffer + compressed.Write(skipEmpty) + compressed.Write(frameA) + compressed.Write(skip) + compressed.Write(frameB) + + var out bytes.Buffer + idx, err := zstdr.Scan(bytes.NewReader(compressed.Bytes()), &out) + require.NoError(t, err) + assert.Equal(t, append(append([]byte{}, dataA...), dataB...), out.Bytes()) + require.Len(t, idx.Frames, 2) + assert.Equal(t, int64(len(skipEmpty)), idx.Frames[0].In) + assert.Equal(t, int64(len(skipEmpty)+len(frameA)+len(skip)), idx.Frames[1].In) + assert.Equal(t, int64(0), idx.Frames[0].Out) + assert.Equal(t, int64(len(dataA)), idx.Frames[1].Out) + assert.Equal(t, int64(len(dataA)+len(dataB)), idx.Size) +} + +func TestScanRejectsInvalidMagic(t *testing.T) { + // Four bytes that are neither zstd magic nor skippable. + var out bytes.Buffer + _, err := zstdr.Scan(bytes.NewReader([]byte{0x00, 0x01, 0x02, 0x03}), &out) + assert.ErrorIs(t, err, zstdr.ErrInvalidFormat) +} + +func TestScanRejectsTruncatedFrame(t *testing.T) { + data := bytes.Repeat([]byte("payload"), 100) + compressed := encodeSingleFrame(t, data) + + // Truncate halfway through the frame body. + truncated := compressed[:len(compressed)/2] + var out bytes.Buffer + _, err := zstdr.Scan(bytes.NewReader(truncated), &out) + assert.ErrorIs(t, err, zstdr.ErrInvalidFormat) +} + +func TestIndexEncodeDecodeRoundTrip(t *testing.T) { + idx := &zstdr.Index{ + Frames: []*zstdr.FrameCheckpoint{ + {In: 0, Out: 0}, + {In: 1024, Out: 4096}, + {In: 2048, Out: 8192}, + }, + Size: 16384, + } + var buf bytes.Buffer + require.NoError(t, idx.Encode(&buf)) + got, err := zstdr.DecodeIndex(&buf) + require.NoError(t, err) + assert.Equal(t, idx, got) +} + +func TestIndexEncodeEmptyFramesIsArrayNotNull(t *testing.T) { + idx := &zstdr.Index{Frames: []*zstdr.FrameCheckpoint{}, Size: 1} + var buf bytes.Buffer + require.NoError(t, idx.Encode(&buf)) + assert.Contains(t, buf.String(), `"frames":[]`) + assert.NotContains(t, buf.String(), `"frames":null`) +} + +// countingReaderAt records the byte ranges fetched, so tests can +// verify that ReadAt only fetches the expected compressed span. +type countingReaderAt struct { + data []byte + mu sync.Mutex + calls int + totalBytes int64 + ranges []offRange +} + +type offRange struct { + off int64 + n int +} + +func (c *countingReaderAt) ReadAt(p []byte, off int64) (int, error) { + c.mu.Lock() + c.calls++ + c.ranges = append(c.ranges, offRange{off: off, n: len(p)}) + c.totalBytes += int64(len(p)) + c.mu.Unlock() + if off < 0 || off >= int64(len(c.data)) { + return 0, io.EOF + } + n := copy(p, c.data[off:]) + if n < len(p) { + return n, io.EOF + } + return n, nil +} + +func TestNewReaderAndReadAt(t *testing.T) { + parts := [][]byte{ + bytes.Repeat([]byte("alpha-"), 1000), + bytes.Repeat([]byte("beta--"), 1500), + bytes.Repeat([]byte("gamma!"), 2000), + } + var expected []byte + for _, p := range parts { + expected = append(expected, p...) + } + compressed := encodeFrames(t, parts...) + + cra := &countingReaderAt{data: compressed} + r, err := zstdr.NewReader(cra, int64(len(compressed))) + require.NoError(t, err) + t.Cleanup(func() { r.Close() }) + + assert.Equal(t, int64(len(expected)), r.Size()) + require.Len(t, r.Index().Frames, 3) + + t.Run("FullReadFromZero", func(t *testing.T) { + buf := make([]byte, len(expected)) + n, err := r.ReadAt(buf, 0) + assert.NoError(t, err) + assert.Equal(t, len(expected), n) + assert.Equal(t, expected, buf) + }) + + t.Run("EmptyBuffer", func(t *testing.T) { + n, err := r.ReadAt(nil, 0) + assert.NoError(t, err) + assert.Equal(t, 0, n) + }) + + t.Run("OffsetAtEnd", func(t *testing.T) { + buf := make([]byte, 8) + n, err := r.ReadAt(buf, r.Size()) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) + }) + + t.Run("OffsetBeyondEnd", func(t *testing.T) { + buf := make([]byte, 8) + n, err := r.ReadAt(buf, r.Size()+100) + assert.Equal(t, 0, n) + assert.Equal(t, io.EOF, err) + }) + + t.Run("ClampedRead", func(t *testing.T) { + buf := make([]byte, 1024) + off := r.Size() - 100 + n, err := r.ReadAt(buf, off) + assert.Equal(t, 100, n) + assert.Equal(t, io.EOF, err) + assert.Equal(t, expected[off:], buf[:n]) + }) + + t.Run("FullReadEndingExactlyAtBoundary", func(t *testing.T) { + buf := make([]byte, 100) + off := r.Size() - 100 + n, err := r.ReadAt(buf, off) + assert.NoError(t, err) + assert.Equal(t, 100, n) + assert.Equal(t, expected[off:], buf) + }) +} + +func TestReadAtMatchesSequentialDecompress(t *testing.T) { + // Build a fixture with several frames. + rng := rand.New(rand.NewSource(42)) + var parts [][]byte + for i := 0; i < 5; i++ { + size := 1024 + rng.Intn(8192) + p := make([]byte, size) + // Mix of compressible and incompressible-ish data. + for j := range p { + p[j] = byte(rng.Intn(256)) + } + // Inject a repeating pattern in part of the buffer to make + // blocks compressible. + if i%2 == 0 { + pat := []byte("repeating-pattern-") + for j := 0; j+len(pat) <= len(p)/2; j += len(pat) { + copy(p[j:], pat) + } + } + parts = append(parts, p) + } + var expected []byte + for _, p := range parts { + expected = append(expected, p...) + } + compressed := encodeFrames(t, parts...) + + r, err := zstdr.NewReader(bytes.NewReader(compressed), int64(len(compressed))) + require.NoError(t, err) + t.Cleanup(func() { r.Close() }) + require.Equal(t, int64(len(expected)), r.Size()) + + // 100 random ReadAt calls. + for i := 0; i < 100; i++ { + off := int64(rng.Intn(len(expected) + 100)) + size := 1 + rng.Intn(len(expected)/4) + + buf := make([]byte, size) + n, err := r.ReadAt(buf, off) + + switch { + case off >= int64(len(expected)): + assert.Equal(t, 0, n, "i=%d off=%d", i, off) + assert.Equal(t, io.EOF, err, "i=%d off=%d", i, off) + case off+int64(size) > int64(len(expected)): + expN := int(int64(len(expected)) - off) + assert.Equal(t, expN, n, "i=%d off=%d", i, off) + assert.Equal(t, io.EOF, err, "i=%d off=%d", i, off) + assert.Equal(t, expected[off:off+int64(expN)], buf[:n], "i=%d off=%d", i, off) + default: + assert.NoError(t, err, "i=%d off=%d size=%d", i, off, size) + assert.Equal(t, size, n, "i=%d off=%d", i, off) + assert.Equal(t, expected[off:off+int64(size)], buf[:n], "i=%d off=%d", i, off) + } + } +} + +func TestReadAtRangeIsBoundedByFrameWindow(t *testing.T) { + // Use predictable per-frame data so the index has well-known + // frame boundaries; assert that reading from inside one frame + // only fetches that frame's compressed span (or a contiguous + // span ending at the next frame's start). + parts := [][]byte{ + bytes.Repeat([]byte("AAAAAAAA"), 4096), + bytes.Repeat([]byte("BBBBBBBB"), 4096), + bytes.Repeat([]byte("CCCCCCCC"), 4096), + } + compressed := encodeFrames(t, parts...) + cra := &countingReaderAt{data: compressed} + r, err := zstdr.NewReader(cra, int64(len(compressed))) + require.NoError(t, err) + t.Cleanup(func() { r.Close() }) + + idx := r.Index() + require.Len(t, idx.Frames, 3) + + // Reset call counter — NewReader scans which uses the + // underlying io.Reader, but our cra is io.ReaderAt; scan goes + // through io.NewSectionReader so it still hits ReadAt. + cra.mu.Lock() + cra.ranges = nil + cra.calls = 0 + cra.totalBytes = 0 + cra.mu.Unlock() + + // Read 8 bytes from inside the second frame. Only the second + // frame's compressed span should be fetched. + off := idx.Frames[1].Out + 16 + buf := make([]byte, 8) + n, err := r.ReadAt(buf, off) + require.NoError(t, err) + assert.Equal(t, 8, n) + + // Exactly one ReadAt should have hit the underlying ra, with + // off == idx.Frames[1].In and length == idx.Frames[2].In - idx.Frames[1].In. + cra.mu.Lock() + defer cra.mu.Unlock() + require.Len(t, cra.ranges, 1) + assert.Equal(t, idx.Frames[1].In, cra.ranges[0].off) + assert.Equal(t, int(idx.Frames[2].In-idx.Frames[1].In), cra.ranges[0].n) +} + +func TestNewReaderWithIndexRoundTrip(t *testing.T) { + parts := [][]byte{ + bytes.Repeat([]byte("xxx"), 1000), + bytes.Repeat([]byte("yyy"), 1000), + } + compressed := encodeFrames(t, parts...) + + // Build via Scan. + var sink bytes.Buffer + idx, err := zstdr.Scan(bytes.NewReader(compressed), &sink) + require.NoError(t, err) + + // Round-trip the index via JSON. + var jsonBuf bytes.Buffer + require.NoError(t, idx.Encode(&jsonBuf)) + idx2, err := zstdr.DecodeIndex(&jsonBuf) + require.NoError(t, err) + + r := zstdr.NewReaderWithIndex(bytes.NewReader(compressed), idx2, int64(len(compressed))) + t.Cleanup(func() { r.Close() }) + + expected := append(append([]byte{}, parts[0]...), parts[1]...) + buf := make([]byte, len(expected)) + n, err := r.ReadAt(buf, 0) + require.NoError(t, err) + assert.Equal(t, len(expected), n) + assert.Equal(t, expected, buf) +} + +func TestNewReaderWithIndexNilFramesPanics(t *testing.T) { + idx := &zstdr.Index{Size: 1} + assert.Panics(t, func() { + zstdr.NewReaderWithIndex(bytes.NewReader(nil), idx, 0) + }) +} + +func TestNewReaderWithIndexZeroSizePanics(t *testing.T) { + idx := &zstdr.Index{Frames: []*zstdr.FrameCheckpoint{}, Size: 0} + assert.Panics(t, func() { + zstdr.NewReaderWithIndex(bytes.NewReader(nil), idx, 0) + }) +} + +func TestCloseReturnsErrClosed(t *testing.T) { + parts := [][]byte{bytes.Repeat([]byte("z"), 1000)} + compressed := encodeFrames(t, parts...) + + r, err := zstdr.NewReader(bytes.NewReader(compressed), int64(len(compressed))) + require.NoError(t, err) + require.NoError(t, r.Close()) + + buf := make([]byte, 4) + n, err := r.ReadAt(buf, 0) + assert.Equal(t, 0, n) + assert.ErrorIs(t, err, zstdr.ErrClosed) +} + +func TestNewReaderRejectsCorrupt(t *testing.T) { + junk := []byte{0xDE, 0xAD, 0xBE, 0xEF, 0xDE, 0xAD, 0xBE, 0xEF} + _, err := zstdr.NewReader(bytes.NewReader(junk), int64(len(junk))) + assert.ErrorIs(t, err, zstdr.ErrInvalidFormat) +} + +func TestConcurrentReadAt(t *testing.T) { + parts := [][]byte{ + bytes.Repeat([]byte("AAAA"), 4096), + bytes.Repeat([]byte("BBBB"), 4096), + bytes.Repeat([]byte("CCCC"), 4096), + bytes.Repeat([]byte("DDDD"), 4096), + } + var expected []byte + for _, p := range parts { + expected = append(expected, p...) + } + compressed := encodeFrames(t, parts...) + + r, err := zstdr.NewReader(bytes.NewReader(compressed), int64(len(compressed)), zstdr.WithMaxReaders(4)) + require.NoError(t, err) + t.Cleanup(func() { r.Close() }) + + const goroutines = 16 + const reads = 50 + var failures int64 + + var wg sync.WaitGroup + wg.Add(goroutines) + for g := 0; g < goroutines; g++ { + go func(seed int) { + defer wg.Done() + rng := rand.New(rand.NewSource(int64(seed))) + for i := 0; i < reads; i++ { + size := 1 + rng.Intn(len(expected)/2) + off := int64(rng.Intn(len(expected) - size + 1)) + buf := make([]byte, size) + n, err := r.ReadAt(buf, off) + if err != nil || n != size || !bytes.Equal(buf, expected[off:off+int64(size)]) { + atomic.AddInt64(&failures, 1) + } + } + }(g) + } + wg.Wait() + assert.Equal(t, int64(0), failures) +} + +func TestReadAtClosedDuringWait(t *testing.T) { + // A maxReaders=1 reader holds its slot while we artificially + // block in another goroutine. + parts := [][]byte{bytes.Repeat([]byte("a"), 1000)} + compressed := encodeFrames(t, parts...) + r, err := zstdr.NewReader(bytes.NewReader(compressed), int64(len(compressed)), zstdr.WithMaxReaders(1)) + require.NoError(t, err) + + hold := &slowReaderAt{ReaderAt: bytes.NewReader(compressed)} + idx := r.Index() + r2 := zstdr.NewReaderWithIndex(hold, idx, int64(len(compressed)), zstdr.WithMaxReaders(1)) + t.Cleanup(func() { r2.Close() }) + + hold.gate = make(chan struct{}) + + // Goroutine 1 holds the only pool slot. + doneG1 := make(chan struct{}) + go func() { + defer close(doneG1) + buf := make([]byte, 16) + _, _ = r2.ReadAt(buf, 0) + }() + + // Wait for goroutine 1 to be inside ReadAt and blocked on the + // gate. + hold.waitEntered() + + // Goroutine 2 will block on the pool slot. + type res struct { + n int + err error + } + doneG2 := make(chan res, 1) + go func() { + buf := make([]byte, 16) + n, err := r2.ReadAt(buf, 0) + doneG2 <- res{n: n, err: err} + }() + + // Close should unblock goroutine 2 with ErrClosed while it waits + // for goroutine 1's in-flight call to drain. Run Close in the + // background: calling it synchronously before releasing hold.gate + // would deadlock by construction. + closeDone := make(chan error, 1) + go func() { + closeDone <- r2.Close() + }() + + got := <-doneG2 + assert.Equal(t, 0, got.n) + assert.ErrorIs(t, got.err, zstdr.ErrClosed) + + close(hold.gate) + require.NoError(t, <-closeDone) + <-doneG1 +} + +// slowReaderAt blocks the first ReadAt call on a gate channel so a +// second goroutine can be observed waiting on the bounded pool. +type slowReaderAt struct { + io.ReaderAt + gate chan struct{} + entered chan struct{} + once sync.Once +} + +func (s *slowReaderAt) ReadAt(p []byte, off int64) (int, error) { + s.once.Do(func() { + if s.entered == nil { + s.entered = make(chan struct{}) + } + close(s.entered) + <-s.gate + }) + return s.ReaderAt.ReadAt(p, off) +} + +func (s *slowReaderAt) waitEntered() { + if s.entered == nil { + s.entered = make(chan struct{}) + } + <-s.entered +} + +func TestScanFrameContentSizeFastPath(t *testing.T) { + // The klauspost zstd writer sets Frame_Content_Size when the + // stream is a single Write call (it knows the size up front). + // We exercise both paths by scanning a file with FCS set + // (single-Write) and one without FCS (multi-Write streaming). + // In both cases the resulting Index must be identical apart + // from compressed-frame layout details we can't fully control. + data := bytes.Repeat([]byte("data-with-fcs-"), 2048) + + // Write with FCS set: single Write before Close. + var withFCS bytes.Buffer + enc, err := zstd.NewWriter(&withFCS) + require.NoError(t, err) + _, err = enc.Write(data) + require.NoError(t, err) + require.NoError(t, enc.Close()) + + // Write without FCS: many small Writes; encoder may not know + // total size at frame-header time. + var withoutFCS bytes.Buffer + enc2, err := zstd.NewWriter(&withoutFCS) + require.NoError(t, err) + const chunk = 17 + for i := 0; i < len(data); i += chunk { + j := i + chunk + if j > len(data) { + j = len(data) + } + _, err := enc2.Write(data[i:j]) + require.NoError(t, err) + } + require.NoError(t, enc2.Close()) + + for _, tc := range []struct { + name string + buf []byte + }{ + {"WithFCS", withFCS.Bytes()}, + {"WithoutFCS", withoutFCS.Bytes()}, + } { + t.Run(tc.name, func(t *testing.T) { + var out bytes.Buffer + idx, err := zstdr.Scan(bytes.NewReader(tc.buf), &out) + require.NoError(t, err) + assert.Equal(t, data, out.Bytes()) + assert.Equal(t, int64(len(data)), idx.Size) + require.GreaterOrEqual(t, len(idx.Frames), 1) + assert.Equal(t, int64(0), idx.Frames[0].In) + assert.Equal(t, int64(0), idx.Frames[0].Out) + }) + } +} + +func TestReadAtRangerErrorPropagated(t *testing.T) { + parts := [][]byte{ + bytes.Repeat([]byte("xx"), 2048), + bytes.Repeat([]byte("yy"), 2048), + } + compressed := encodeFrames(t, parts...) + idx, err := zstdr.Scan(bytes.NewReader(compressed), io.Discard) + require.NoError(t, err) + + sentinel := errors.New("network down") + er := &erroringReaderAt{err: sentinel} + r := zstdr.NewReaderWithIndex(er, idx, int64(len(compressed))) + t.Cleanup(func() { r.Close() }) + + buf := make([]byte, 16) + n, err := r.ReadAt(buf, 0) + assert.Equal(t, 0, n) + assert.ErrorIs(t, err, sentinel) +} + +type erroringReaderAt struct{ err error } + +func (e *erroringReaderAt) ReadAt(p []byte, off int64) (int, error) { + return 0, e.err +} + +// Compile-time assertion: *Reader is an io.ReaderAt. +var _ io.ReaderAt = (*zstdr.Reader)(nil)