diff --git a/examples/cmd/benchmark_experimental.go b/examples/cmd/benchmark_experimental.go
new file mode 100644
index 0000000000..f50c5cc348
--- /dev/null
+++ b/examples/cmd/benchmark_experimental.go
@@ -0,0 +1,109 @@
+//nolint:forbidigo // We use Println here extensively because we are printing markdown.
+package cmd
+
+import (
+    "context"
+    "crypto/rand"
+    "fmt"
+    "sync"
+    "time"
+
+    "connectrpc.com/connect"
+    "github.com/opentdf/platform/lib/ocrypto"
+    kasp "github.com/opentdf/platform/protocol/go/kas"
+    "github.com/opentdf/platform/protocol/go/kas/kasconnect"
+    "github.com/opentdf/platform/protocol/go/policy"
+
+    "github.com/opentdf/platform/sdk/experimental/tdf"
+    "github.com/opentdf/platform/sdk/httputil"
+    "github.com/spf13/cobra"
+)
+
+var (
+    payloadSize  int
+    segmentChunk int
+    testAttr     = "https://example.com/attr/attr1/value/value1"
+)
+
+func init() {
+    benchmarkCmd := &cobra.Command{
+        Use:   "benchmark-experimental-writer",
+        Short: "Benchmark experimental TDF writer speed",
+        Long:  `Benchmark the experimental TDF writer with configurable payload size.`,
+        RunE:  runExperimentalWriterBenchmark,
+    }
+    //nolint:mnd // default payload size, not a magic number
+    benchmarkCmd.Flags().IntVar(&payloadSize, "payload-size", 1024*1024, "Payload size in bytes") // Default 1MB
+    //nolint:mnd // same as above
+    benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "Segment chunk size in bytes") // Default 16KiB chunks
+    ExamplesCmd.AddCommand(benchmarkCmd)
+}
+
+func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
+    payload := make([]byte, payloadSize)
+    _, err := rand.Read(payload)
+    if err != nil {
+        return fmt.Errorf("failed to generate random payload: %w", err)
+    }
+
+    httpClient := httputil.SafeHTTPClient()
+    fmt.Println("endpoint:", platformEndpoint)
+    serviceClient := kasconnect.NewAccessServiceClient(httpClient, platformEndpoint)
+    resp, err := serviceClient.PublicKey(context.Background(), connect.NewRequest(&kasp.PublicKeyRequest{Algorithm: string(ocrypto.RSA2048Key)}))
+    if err != nil {
+        return fmt.Errorf("failed to get public key from KAS: %w", err)
+    }
+    var attrs []*policy.Value
+
+    simpleKey := &policy.SimpleKasKey{
+        KasUri: platformEndpoint,
+        KasId:  "id",
+        PublicKey: &policy.SimpleKasPublicKey{
+            Kid:       resp.Msg.GetKid(),
+            Pem:       resp.Msg.GetPublicKey(),
+            Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
+        },
+    }
+
+    attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}})
+    writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256))
+    if err != nil {
+        return fmt.Errorf("failed to create writer: %w", err)
+    }
+    i := 0
+    wg := sync.WaitGroup{}
+    segs := (len(payload) + segmentChunk - 1) / segmentChunk // round up so a trailing partial chunk is still written
+    wg.Add(segs)
+    start := time.Now()
+    for i < segs {
+        segment := i
+        go func() {
+            defer wg.Done()
+            segStart := segment * segmentChunk
+            segEnd := min(segStart+segmentChunk, len(payload))
+            _, err := writer.WriteSegment(context.Background(), segment, payload[segStart:segEnd])
+            if err != nil {
+                fmt.Println(err)
+                panic(err)
+            }
+        }()
+        i++
+    }
+    wg.Wait()
+
+    end := time.Now()
+    result, err := writer.Finalize(context.Background())
+    if err != nil {
+        return fmt.Errorf("failed to finalize writer: %w", err)
+    }
+    totalTime := end.Sub(start)
+
+    fmt.Printf("# Benchmark Experimental TDF Writer Results:\n")
+    fmt.Printf("| Metric             | Value        |\n")
+    fmt.Printf("|--------------------|--------------|\n")
+    fmt.Printf("| Payload Size (B)   | %d |\n", payloadSize)
+    fmt.Printf("| Output Size (B)    | %d |\n", len(result.Data))
+    fmt.Printf("| Total Time         | %s |\n", totalTime)
+
+    return nil
+}
diff --git a/examples/go.mod b/examples/go.mod
index dd48e2eaef..0adaedfae2 100644
--- a/examples/go.mod
+++ b/examples/go.mod
@@ -5,9 +5,10 @@ go 1.24.0
 toolchain go1.24.9
 
 require (
+    connectrpc.com/connect v1.18.1
     github.com/opentdf/platform/lib/ocrypto v0.7.0
     github.com/opentdf/platform/protocol/go v0.13.0
-    github.com/opentdf/platform/sdk v0.7.0
+    github.com/opentdf/platform/sdk v0.10.1
     github.com/spf13/cobra v1.9.1
     github.com/stretchr/testify v1.10.0
     google.golang.org/grpc v1.73.0
@@ -16,7 +17,6 @@
 
 require (
     buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.6-20250603165357-b52ab10f4468.1 // indirect
-    connectrpc.com/connect v1.18.1 // indirect
     github.com/Masterminds/semver/v3 v3.3.1 // indirect
     github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect
     github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
diff --git a/examples/go.sum b/examples/go.sum
index 3b4cd41af2..b05ea7e203 100644
--- a/examples/go.sum
+++ b/examples/go.sum
@@ -120,8 +120,8 @@ github.com/opentdf/platform/lib/ocrypto v0.7.0 h1:uBZJXisuXU3V8681aP8FVMJkyWBrwW
 github.com/opentdf/platform/lib/ocrypto v0.7.0/go.mod h1:sYhoBL1bQYgQVSSNpxU13RsrE5JAk8BABT1hfr9L3j8=
 github.com/opentdf/platform/protocol/go v0.13.0 h1:vrOOHyhYDPzJgNenz/1g0M5nWtkOYKkPggMNHKzeMcs=
 github.com/opentdf/platform/protocol/go v0.13.0/go.mod h1:GRycoDGDxaz91sOvGZFWVEKJLluZFg2wM3NJmhucDHo=
-github.com/opentdf/platform/sdk v0.7.0 h1:8hczDycXGY1ucdIXSrP17oW/Eyu3vsb4LEX4hc7tvVY=
-github.com/opentdf/platform/sdk v0.7.0/go.mod h1:CTJR1NXeYe896M1/VN0h+1Ff54SdBtxv4z18BGTi8yk=
+github.com/opentdf/platform/sdk v0.10.1 h1:kBrTK48xle7mdGc+atlr4kDh94f6kVj+0OB76K8rozI=
+github.com/opentdf/platform/sdk v0.10.1/go.mod h1:+yaTi/c/GWHZPPmO27sq2s7Tcb2P/USkK8LuW1krhI8=
 github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
 github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
 github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
diff --git a/lib/ocrypto/aes_gcm.go b/lib/ocrypto/aes_gcm.go
index 77aaf647ea..6866587050 100644
--- a/lib/ocrypto/aes_gcm.go
+++ b/lib/ocrypto/aes_gcm.go
@@ -47,6 +47,25 @@ func (aesGcm AesGcm) Encrypt(data []byte) ([]byte, error) {
     return cipherText, nil
 }
 
+// EncryptInPlace encrypts data with the symmetric key, reusing the input
+// buffer for the ciphertext where capacity allows; the caller must not
+// reuse data afterwards. Unlike Encrypt, the 12-byte nonce is returned
+// separately rather than being prepended to the ciphertext.
+func (aesGcm AesGcm) EncryptInPlace(data []byte) ([]byte, []byte, error) {
+    nonce, err := RandomBytes(GcmStandardNonceSize)
+    if err != nil {
+        return nil, nil, err
+    }
+
+    gcm, err := cipher.NewGCMWithNonceSize(aesGcm.block, GcmStandardNonceSize)
+    if err != nil {
+        return nil, nil, fmt.Errorf("cipher.NewGCMWithNonceSize failed: %w", err)
+    }
+
+    cipherText := gcm.Seal(data[:0], nonce, data, nil)
+    return cipherText, nonce, nil
+}
+
 // EncryptWithIV encrypts data with symmetric key.
 // NOTE: This method use default auth tag as aes block size(16 bytes)
 // and expects iv of 16 bytes.
diff --git a/sdk/experimental/tdf/writer.go b/sdk/experimental/tdf/writer.go
index ed7b12d97e..02d2f8af53 100644
--- a/sdk/experimental/tdf/writer.go
+++ b/sdk/experimental/tdf/writer.go
@@ -10,6 +10,7 @@ import (
     "errors"
     "fmt"
     "hash/crc32"
+    "io"
     "log/slog"
     "sort"
     "sync"
@@ -34,12 +35,11 @@ const (
 
 // SegmentResult contains the result of writing a segment
 type SegmentResult struct {
-    Data          []byte `json:"data"`          // Encrypted segment bytes (for streaming)
-    Index         int    `json:"index"`         // Segment index
-    Hash          string `json:"hash"`          // Base64-encoded integrity hash
-    PlaintextSize int64  `json:"plaintextSize"` // Original data size
-    EncryptedSize int64  `json:"encryptedSize"` // Encrypted data size
-    CRC32         uint32 `json:"crc32"`         // CRC32 checksum
+    TDFData       io.Reader // Reader for the full TDF segment (nonce + encrypted data + zip structures)
+    Index         int       `json:"index"`         // Segment index
+    Hash          string    `json:"hash"`          // Base64-encoded integrity hash
+    PlaintextSize int64     `json:"plaintextSize"` // Original data size
+    EncryptedSize int64     `json:"encryptedSize"` // Encrypted data size
 }
 
 // FinalizeResult contains the complete TDF creation result
@@ -110,7 +110,7 @@ type Writer struct {
 
     // segments stores segment metadata using sparse map for memory efficiency
     // Maps segment index to Segment metadata (hash, size information)
-    segments map[int]Segment
+    segments map[int]*Segment
 
     // maxSegmentIndex tracks the highest segment index written
     maxSegmentIndex int
@@ -183,7 +183,7 @@ func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error
         WriterConfig:      *config,
         archiveWriter:     archiveWriter,
         dek:               dek,
-        segments:          make(map[int]Segment), // Initialize sparse storage
+        segments:          make(map[int]*Segment), // Initialize sparse storage
         block:             block,
         initialAttributes: config.initialAttributes,
         initialDefaultKAS: config.initialDefaultKAS,
@@ -232,58 +232,76 @@ func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error
 //	uploadToS3(segment1, "part-001")
 func (w *Writer) WriteSegment(ctx context.Context, index int, data []byte) (*SegmentResult, error) {
     w.mutex.Lock()
-    defer w.mutex.Unlock()
 
     if w.finalized {
+        w.mutex.Unlock()
         return nil, ErrAlreadyFinalized
     }
 
     if index < 0 {
+        w.mutex.Unlock()
         return nil, ErrInvalidSegmentIndex
     }
 
     // Check for duplicate segments using map lookup
     if _, exists := w.segments[index]; exists {
+        w.mutex.Unlock()
        return nil, ErrSegmentAlreadyWritten
     }
 
     if index > w.maxSegmentIndex {
         w.maxSegmentIndex = index
     }
+    seg := &Segment{
+        Size: -1, // indicates not filled yet
+    }
+    w.segments[index] = seg
 
-    // Calculate CRC32 before encryption for integrity tracking
-    crc32Checksum := crc32.ChecksumIEEE(data)
+    w.mutex.Unlock()
 
     // Encrypt directly without unnecessary copying - the archive layer will handle copying if needed
-    segmentCipher, err := w.block.Encrypt(data)
+    segmentCipher, nonce, err := w.block.EncryptInPlace(data)
     if err != nil {
         return nil, err
     }
-
     segmentSig, err := calculateSignature(segmentCipher, w.dek, w.segmentIntegrityAlgorithm, false) // Don't ever hex encode new tdf's
     if err != nil {
         return nil, err
     }
 
     segmentHash := string(ocrypto.Base64Encode([]byte(segmentSig)))
-    w.segments[index] = Segment{
-        Hash:          segmentHash,
-        Size:          int64(len(data)), // Use original data length
-        EncryptedSize: int64(len(segmentCipher)),
-    }
+    w.mutex.Lock()
+    seg.Size = int64(len(data))
+    seg.EncryptedSize = int64(len(segmentCipher)) + int64(len(nonce))
+    seg.Hash = segmentHash
+    w.mutex.Unlock()
 
-    zipBytes, err := w.archiveWriter.WriteSegment(ctx, index, segmentCipher)
+    crc := crc32.NewIEEE()
+    _, err = crc.Write(nonce)
+    if err != nil {
+        return nil, err
+    }
+    _, err = crc.Write(segmentCipher)
+    if err != nil {
+        return nil, err
+    }
+    header, err := w.archiveWriter.WriteSegment(ctx, index, uint64(seg.EncryptedSize), crc.Sum32())
     if err != nil {
         return nil, err
     }
 
+    var reader io.Reader
+    if len(header) == 0 {
+        reader = io.MultiReader(bytes.NewReader(nonce), bytes.NewReader(segmentCipher))
+    } else {
+        reader = io.MultiReader(bytes.NewReader(header), bytes.NewReader(nonce), bytes.NewReader(segmentCipher))
+    }
     return &SegmentResult{
-        Data:          zipBytes,
+        TDFData:       reader,
         Index:         index,
-        Hash:          segmentHash,
-        PlaintextSize: int64(len(data)),
-        EncryptedSize: int64(len(segmentCipher)),
-        CRC32:         crc32Checksum,
+        Hash:          seg.Hash,
+        PlaintextSize: seg.Size,
+        EncryptedSize: seg.EncryptedSize,
     }, nil
 }
 
@@ -505,7 +523,7 @@ func (w *Writer) getManifest(ctx context.Context, cfg *WriterFinalizeConfig) (*M
     // Copy segments to manifest in finalize order (pack densely)
     for i, idx := range order {
         if segment, exists := w.segments[idx]; exists {
-            encryptInfo.Segments[i] = segment
+            encryptInfo.Segments[i] = *segment
         }
     }
 
@@ -524,7 +542,8 @@ func (w *Writer) getManifest(ctx context.Context, cfg *WriterFinalizeConfig) (*M
     var totalPlaintextSize, totalEncryptedSize int64
     for _, i := range order {
         segment, exists := w.segments[i]
-        if !exists {
+        // A negative size means the slot was reserved but never written; Finalize was called too early.
+        if !exists || segment.Size < 0 {
             return nil, 0, 0, fmt.Errorf("segment %d not written; cannot finalize", i)
         }
         if segment.Hash != "" {
diff --git a/sdk/experimental/tdf/writer_test.go b/sdk/experimental/tdf/writer_test.go
index a4d8251af5..296c582041 100644
--- a/sdk/experimental/tdf/writer_test.go
+++ b/sdk/experimental/tdf/writer_test.go
@@ -669,7 +669,7 @@ func testErrorConditions(t *testing.T) {
     require.NoError(t, err)
 
     // Manually corrupt segment hash to test error handling
-    writer.segments[0] = Segment{Hash: "", Size: 10, EncryptedSize: 26}
+    writer.segments[0] = &Segment{Hash: "", Size: 10, EncryptedSize: 26}
     writer.maxSegmentIndex = 0
 
     attributes := []*policy.Value{
diff --git a/sdk/internal/zipstream/benchmark_test.go b/sdk/internal/zipstream/benchmark_test.go
index 4ef7e38d83..24263c1d3d 100644
--- a/sdk/internal/zipstream/benchmark_test.go
+++ b/sdk/internal/zipstream/benchmark_test.go
@@ -3,6 +3,7 @@
 package zipstream
 
 import (
+    "hash/crc32"
     "testing"
 )
 
@@ -55,7 +56,8 @@ func BenchmarkSegmentWriter_CRC32ContiguousProcessing(b *testing.B) {
 
     // Write segments in specified order
     for _, segIdx := range writeOrder {
-        _, err := writer.WriteSegment(ctx, segIdx, segmentData)
+        crc := crc32.ChecksumIEEE(segmentData)
+        _, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
         if err != nil {
             b.Fatal(err)
         }
@@ -99,7 +101,8 @@ func BenchmarkSegmentWriter_SparseIndices(b *testing.B) {
     // Write sparse indices in order
     for k := 0; k < n; k++ {
         idx := k * stride
-        if _, err := w.WriteSegment(ctx, idx, data); err != nil {
+        crc := crc32.ChecksumIEEE(data)
+        if _, err := w.WriteSegment(ctx, idx, uint64(len(data)), crc); err != nil {
             b.Fatal(err)
         }
     }
@@ -153,7 +156,8 @@ func BenchmarkSegmentWriter_VariableSegmentSizes(b *testing.B) {
             segmentData[j] = byte((segIdx * j) % 256)
         }
 
-        _, err := writer.WriteSegment(ctx, segIdx, segmentData)
+        crc := crc32.ChecksumIEEE(segmentData)
+        _, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
         if err != nil {
             b.Fatal(err)
         }
@@ -234,7 +238,8 @@ func BenchmarkSegmentWriter_MemoryPressure(b *testing.B) {
             segmentData[j] = byte((orderIdx * j) % 256)
         }
 
-        _, err := writer.WriteSegment(ctx, segIdx, segmentData)
+        crc := crc32.ChecksumIEEE(segmentData)
+        _, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
         if err != nil {
             b.Fatal(err)
         }
@@ -305,7 +310,8 @@ func BenchmarkSegmentWriter_ZIPGeneration(b *testing.B) {
 
         // Write all segments
         for segIdx := 0; segIdx < tc.segmentCount; segIdx++ {
-            _, err := writer.WriteSegment(ctx, segIdx, segmentData)
+            crc := crc32.ChecksumIEEE(segmentData)
+            _, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
             if err != nil {
                 b.Fatal(err)
             }
diff --git a/sdk/internal/zipstream/segment_writer.go b/sdk/internal/zipstream/segment_writer.go
index 9f78d0ed89..2c2500be8d 100644
--- a/sdk/internal/zipstream/segment_writer.go
+++ b/sdk/internal/zipstream/segment_writer.go
@@ -49,7 +49,7 @@ func NewSegmentTDFWriter(expectedSegments int, opts ...Option) SegmentWriter {
 }
 
 // WriteSegment writes a segment with deterministic output based on segment index
-func (sw *segmentWriter) WriteSegment(ctx context.Context, index int, data []byte) ([]byte, error) {
+func (sw *segmentWriter) WriteSegment(ctx context.Context, index int, size uint64, crc uint32) ([]byte, error) {
     sw.mu.Lock()
     defer sw.mu.Unlock()
 
@@ -79,9 +79,7 @@ func (sw *segmentWriter) WriteSegment(ctx context.Context, index int, data []byt
     default:
     }
 
-    // CRC32 over stored segment bytes (what goes into the ZIP entry)
-    originalCRC := crc32.ChecksumIEEE(data)
-    originalSize := uint64(len(data))
+    originalSize := size
 
     // Create segment buffer for this segment's output
     buffer := &bytes.Buffer{}
@@ -94,20 +92,15 @@ func (sw *segmentWriter) WriteSegment(ctx context.Context, index int, data []byt
         }
     }
 
-    // All segments: write the encrypted data
-    if _, err := buffer.Write(data); err != nil {
-        return nil, &Error{Op: "write-segment", Type: "segment", Err: err}
-    }
-
     // Record segment metadata only (no payload retention). Payload bytes are returned
     // to the caller and may be uploaded; we keep only CRC and size for finalize.
-    if err := sw.metadata.AddSegment(index, data, originalSize, originalCRC); err != nil {
+    if err := sw.metadata.AddSegment(index, size, crc); err != nil {
         return nil, &Error{Op: "write-segment", Type: "segment", Err: err}
     }
 
     // Update payload entry metadata
     sw.payloadEntry.Size += originalSize
-    sw.payloadEntry.CompressedSize += uint64(len(data)) // Encrypted size
+    sw.payloadEntry.CompressedSize += originalSize // Encrypted size
 
     // Return the bytes for this segment
     return buffer.Bytes(), nil
diff --git a/sdk/internal/zipstream/segment_writer_test.go b/sdk/internal/zipstream/segment_writer_test.go
index 130af5aa21..8636778fc5 100644
--- a/sdk/internal/zipstream/segment_writer_test.go
+++ b/sdk/internal/zipstream/segment_writer_test.go
@@ -7,6 +7,7 @@ import (
     "bytes"
     "context"
     "fmt"
+    "hash/crc32"
     "io"
     "testing"
 
@@ -32,21 +33,22 @@ func TestSegmentWriter_SequentialOrder(t *testing.T) {
 
     // Write segments in sequential order
     for i, data := range testSegments {
-        segmentBytes, err := writer.WriteSegment(ctx, i, data)
+        crc := crc32.ChecksumIEEE(data)
+        segmentBytes, err := writer.WriteSegment(ctx, i, uint64(len(data)), crc)
         require.NoError(t, err, "Failed to write segment %d", i)
-        assert.NotEmpty(t, segmentBytes, "Segment %d should have bytes", i)
 
         t.Logf("Sequential segment %d: %d bytes", i, len(segmentBytes))
 
         if i == 0 {
             // Segment 0 should be larger due to ZIP header
-            assert.Greater(t, len(segmentBytes), len(data), "Segment 0 should include ZIP header")
+            assert.NotEmpty(t, segmentBytes, "Segment 0 should include ZIP header")
         } else {
-            // Other segments should be approximately the size of the data
-            assert.Len(t, data, len(segmentBytes), "Segment %d should be raw data", i)
+            // Other segments return no ZIP bytes; the caller appends the payload itself
+            assert.Empty(t, segmentBytes, "Segment %d should have no zip bytes", i)
         }
 
         allBytes = append(allBytes, segmentBytes...)
+        allBytes = append(allBytes, data...)
     }
 
     t.Logf("Sequential total payload bytes before finalization: %d", len(allBytes))
@@ -102,16 +104,21 @@ func TestSegmentWriter_OutOfOrder(t *testing.T) {
 
     for _, index := range writeOrder {
         data := testSegments[index]
-        bytes, err := writer.WriteSegment(ctx, index, data)
+        crc := crc32.ChecksumIEEE(data)
+        bytes, err := writer.WriteSegment(ctx, index, uint64(len(data)), crc)
         require.NoError(t, err, "Failed to write segment %d out of order", index)
-        assert.NotEmpty(t, bytes, "Segment %d should have bytes", index)
 
         if index == 0 {
             // Segment 0 should always include ZIP header, regardless of write order
-            assert.Greater(t, len(bytes), len(data), "Segment 0 should include ZIP header even when written out of order")
+            assert.NotEmpty(t, bytes, "Segment 0 should include ZIP header")
+        } else {
+            assert.Empty(t, bytes, "Segment %d should have no zip bytes", index)
         }
 
-        segmentBytes[index] = bytes
+        var allBytes []byte
+        allBytes = append(allBytes, bytes...)
+        allBytes = append(allBytes, data...)
+        segmentBytes[index] = allBytes
     }
 
     // Reassemble in logical order (as S3 would do)
@@ -180,15 +187,18 @@ func TestSegmentWriter_SparseIndices_InOrder(t *testing.T) {
     segmentBytes := make(map[int][]byte)
     for _, index := range order {
         data := testSegments[index]
-        bytes, err := writer.WriteSegment(ctx, index, data)
+        crc := crc32.ChecksumIEEE(data)
+        bytes, err := writer.WriteSegment(ctx, index, uint64(len(data)), crc)
         require.NoError(t, err, "write segment %d failed", index)
-        assert.NotEmpty(t, bytes, "segment %d should yield bytes", index)
         if index == 0 {
-            assert.Greater(t, len(bytes), len(data), "segment 0 should include ZIP header")
+            assert.NotEmpty(t, bytes, "segment 0 should include ZIP header")
         } else {
-            assert.Len(t, bytes, len(data), "non-zero segments are raw payload bytes")
+            assert.Empty(t, bytes, "segment %d should have no zip bytes", index)
         }
-        segmentBytes[index] = bytes
+        var totalBytes []byte
+        totalBytes = append(totalBytes, bytes...)
+        totalBytes = append(totalBytes, data...)
+        segmentBytes[index] = totalBytes
     }
 
     // Assemble full file: concatenate segment bytes in ascending index order
@@ -244,10 +254,13 @@ func TestSegmentWriter_SparseIndices_OutOfOrder(t *testing.T) {
     segmentBytes := make(map[int][]byte)
     for _, index := range writeOrder {
         data := testSegments[index]
-        bytes, err := writer.WriteSegment(ctx, index, data)
+        crc := crc32.ChecksumIEEE(data)
+        bytes, err := writer.WriteSegment(ctx, index, uint64(len(data)), crc)
         require.NoError(t, err, "write segment %d failed", index)
-        assert.NotEmpty(t, bytes, "segment %d should yield bytes", index)
-        segmentBytes[index] = bytes
+        var totalBytes []byte
+        totalBytes = append(totalBytes, bytes...)
+        totalBytes = append(totalBytes, data...)
+        segmentBytes[index] = totalBytes
     }
 
     // Assemble full file in final (ascending) order regardless of write order
@@ -285,10 +298,10 @@ func TestSegmentWriter_DuplicateSegments(t *testing.T) {
     ctx := t.Context()
 
     // Write segment 1 twice
-    _, err := writer.WriteSegment(ctx, 1, []byte("first"))
+    _, err := writer.WriteSegment(ctx, 1, 10, crc32.ChecksumIEEE([]byte("first write")))
     require.NoError(t, err, "First write of segment 1 should succeed")
 
-    _, err = writer.WriteSegment(ctx, 1, []byte("duplicate"))
+    _, err = writer.WriteSegment(ctx, 1, 10, crc32.ChecksumIEEE([]byte("second write")))
     require.Error(t, err, "Duplicate segment should fail")
     assert.Contains(t, err.Error(), "duplicate", "Error should mention duplicate")
 
@@ -310,7 +323,7 @@ func TestSegmentWriter_InvalidSegmentIndex(t *testing.T) {
 
     for _, tc := range testCases {
         t.Run(tc.name, func(t *testing.T) {
-            _, err := writer.WriteSegment(ctx, tc.index, []byte("test"))
+            _, err := writer.WriteSegment(ctx, tc.index, 10, crc32.ChecksumIEEE([]byte("test")))
             require.Error(t, err, "Negative segment index should fail")
             assert.Contains(t, err.Error(), "invalid", "Error should mention invalid")
         })
@@ -318,7 +331,7 @@ func TestSegmentWriter_InvalidSegmentIndex(t *testing.T) {
 
     // Test that large indices are actually allowed (dynamic expansion)
     t.Run("large_index_allowed", func(t *testing.T) {
-        _, err := writer.WriteSegment(ctx, 100, []byte("test"))
+        _, err := writer.WriteSegment(ctx, 100, 10, crc32.ChecksumIEEE([]byte("large index")))
         require.NoError(t, err, "Large segment index should be allowed for dynamic expansion")
     })
 
@@ -331,13 +344,13 @@ func TestSegmentWriter_AllowsGapsOnFinalize(t *testing.T) {
     ctx := t.Context()
 
     // Write only segments 0, 1, 3 (2 is missing)
-    _, err := writer.WriteSegment(ctx, 0, []byte("first"))
+    _, err := writer.WriteSegment(ctx, 0, 5, crc32.ChecksumIEEE([]byte("first")))
     require.NoError(t, err)
 
-    _, err = writer.WriteSegment(ctx, 1, []byte("second"))
+    _, err = writer.WriteSegment(ctx, 1, 6, crc32.ChecksumIEEE([]byte("second")))
     require.NoError(t, err)
 
-    _, err = writer.WriteSegment(ctx, 3, []byte("fourth"))
+    _, err = writer.WriteSegment(ctx, 3, 5, crc32.ChecksumIEEE([]byte("fourth")))
     require.NoError(t, err)
 
     // Finalize should succeed (auto-dense behavior)
@@ -356,7 +369,7 @@ func TestSegmentWriter_CleanupSegment(t *testing.T) {
     testData := []byte("test data for cleanup")
 
     // Write a segment
-    _, err := writer.WriteSegment(ctx, 1, testData)
+    _, err := writer.WriteSegment(ctx, 1, uint64(len(testData)), crc32.ChecksumIEEE(testData))
     require.NoError(t, err)
 
     // Verify segment exists before cleanup
@@ -385,7 +398,7 @@ func TestSegmentWriter_ContextCancellation(t *testing.T) {
     cancel()
 
     // Try to write segment with cancelled context
-    _, err := writer.WriteSegment(ctx, 0, []byte("test"))
+    _, err := writer.WriteSegment(ctx, 0, 10, crc32.ChecksumIEEE([]byte("data")))
     require.Error(t, err, "Should fail with cancelled context")
     assert.Contains(t, err.Error(), "context", "Error should mention context")
 
@@ -408,7 +421,7 @@ func TestSegmentWriter_LargeNumberOfSegments(t *testing.T) {
 
     // Write all segments in reverse order
     for i := segmentCount - 1; i >= 0; i-- {
-        bytes, err := writer.WriteSegment(ctx, i, testSegments[i])
+        bytes, err := writer.WriteSegment(ctx, i, uint64(len(testSegments[i])), crc32.ChecksumIEEE(testSegments[i]))
         require.NoError(t, err, "Failed to write segment %d", i)
 
         // Store in logical order for final assembly
@@ -439,13 +452,13 @@ func TestSegmentWriter_EmptySegments(t *testing.T) {
     ctx := t.Context()
 
     // Write segments with empty data
-    _, err := writer.WriteSegment(ctx, 0, []byte(""))
+    _, err := writer.WriteSegment(ctx, 0, 0, 0)
     require.NoError(t, err, "Should handle empty segment 0")
 
-    _, err = writer.WriteSegment(ctx, 1, []byte("non-empty"))
+    _, err = writer.WriteSegment(ctx, 1, 10, crc32.ChecksumIEEE([]byte("not empty")))
     require.NoError(t, err, "Should handle non-empty segment")
 
-    _, err = writer.WriteSegment(ctx, 2, []byte(""))
+    _, err = writer.WriteSegment(ctx, 2, 0, 0)
     require.NoError(t, err, "Should handle empty segment 2")
 
     // Finalize
@@ -500,7 +513,9 @@ func benchmarkSegmentWriter(b *testing.B, name string, writeOrder []int) {
 
     // Write segments in specified order
     for _, index := range writeOrder {
-        _, err := writer.WriteSegment(ctx, index, testSegments[index])
+        data := testSegments[index]
+        crc := crc32.ChecksumIEEE(data)
+        _, err := writer.WriteSegment(ctx, index, uint64(len(data)), crc)
         if err != nil {
             b.Fatal(err)
         }
diff --git a/sdk/internal/zipstream/writer.go b/sdk/internal/zipstream/writer.go
index 6a2f5c8106..bebd50fcc8 100644
--- a/sdk/internal/zipstream/writer.go
+++ b/sdk/internal/zipstream/writer.go
@@ -22,7 +22,7 @@ type Writer interface {
 
 // SegmentWriter handles out-of-order segments with deterministic output
 type SegmentWriter interface {
     Writer
-    WriteSegment(ctx context.Context, index int, data []byte) ([]byte, error)
+    WriteSegment(ctx context.Context, index int, size uint64, crc uint32) ([]byte, error)
     Finalize(ctx context.Context, manifest []byte) ([]byte, error)
     // CleanupSegment removes the presence marker for a segment index.
     // Calling this before Finalize will cause IsComplete() to fail for that index.
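The zipstream change inverts the data flow: `WriteSegment` no longer receives payload bytes at all, only their size and CRC32, and returns just the ZIP structure bytes (non-empty only for the segment carrying the local file header, per the updated tests). A rough package-internal sketch of the resulting assembly loop follows; `assembleSegments` is a hypothetical helper name, and the `Close` placement mirrors the tests above.

```go
package zipstream

import (
	"context"
	"hash/crc32"
)

// assembleSegments is a hypothetical helper showing the new contract:
// the writer records only size + CRC per segment, and the caller splices
// its own payload bytes between the returned ZIP structures.
func assembleSegments(ctx context.Context, segments [][]byte, manifest []byte) ([]byte, error) {
	w := NewSegmentTDFWriter(len(segments))

	var out []byte
	for i, data := range segments {
		// Metadata only; the payload never enters the writer.
		header, err := w.WriteSegment(ctx, i, uint64(len(data)), crc32.ChecksumIEEE(data))
		if err != nil {
			return nil, err
		}
		out = append(out, header...) // ZIP header bytes for segment 0, empty otherwise
		out = append(out, data...)   // the stored segment bytes themselves
	}

	// Finalize returns the trailing ZIP structures (manifest entry, central directory).
	fin, err := w.Finalize(ctx, manifest)
	if err != nil {
		return nil, err
	}
	w.Close()
	return append(out, fin...), nil
}
```

This is why every call site in the tests now appends the raw segment data right after the bytes returned by `WriteSegment` before reassembling the archive.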
diff --git a/sdk/internal/zipstream/zip64_mode_test.go b/sdk/internal/zipstream/zip64_mode_test.go
index 8c57613a50..b986c35846 100644
--- a/sdk/internal/zipstream/zip64_mode_test.go
+++ b/sdk/internal/zipstream/zip64_mode_test.go
@@ -5,6 +5,7 @@ package zipstream
 import (
     "archive/zip"
     "bytes"
+    "hash/crc32"
     "io"
     "testing"
 )
@@ -24,16 +25,18 @@ func TestZip64Mode_Auto_Small_UsesZip32(t *testing.T) {
     w := NewSegmentTDFWriter(2, WithZip64Mode(Zip64Auto))
 
     var parts [][]byte
-    p0, err := w.WriteSegment(t.Context(), 0, []byte("hello "))
+    p0, err := w.WriteSegment(t.Context(), 0, 5, crc32.ChecksumIEEE([]byte("hello")))
     if err != nil {
         t.Fatal(err)
     }
     parts = append(parts, p0)
-    p1, err := w.WriteSegment(t.Context(), 1, []byte("world"))
+    parts = append(parts, []byte("hello"))
+    p1, err := w.WriteSegment(t.Context(), 1, 5, crc32.ChecksumIEEE([]byte("world")))
     if err != nil {
         t.Fatal(err)
     }
     parts = append(parts, p1)
+    parts = append(parts, []byte("world"))
 
     fin, err := w.Finalize(t.Context(), []byte(`{"m":1}`))
     if err != nil {
@@ -75,7 +78,7 @@ func TestZip64Mode_Auto_Small_UsesZip32(t *testing.T) {
 
 func TestZip64Mode_Always_Small_UsesZip64(t *testing.T) {
     w := NewSegmentTDFWriter(1, WithZip64Mode(Zip64Always))
-    seg, err := w.WriteSegment(t.Context(), 0, []byte("data"))
+    seg, err := w.WriteSegment(t.Context(), 0, 4, crc32.ChecksumIEEE([]byte("data")))
     if err != nil {
         t.Fatal(err)
     }
@@ -85,7 +88,7 @@ func TestZip64Mode_Always_Small_UsesZip64(t *testing.T) {
     }
     w.Close()
 
-    data := buildZip(t, [][]byte{seg}, fin)
+    data := buildZip(t, [][]byte{seg, []byte("data")}, fin)
     // Basic open check; many readers accept ZIP64 regardless of size.
     if _, err := zip.NewReader(bytes.NewReader(data), int64(len(data))); err != nil {
         t.Fatalf("zip open failed (zip64 always): %v", err)
@@ -101,7 +104,7 @@ func TestZip64Mode_Never_Overflow_Fails(t *testing.T) {
         t.Fatal("writer type assertion failed")
     }
     // Write minimal segment to initialize structures
-    if _, err := w.WriteSegment(t.Context(), 0, []byte("x")); err != nil {
+    if _, err := w.WriteSegment(t.Context(), 0, 1, crc32.ChecksumIEEE([]byte("x"))); err != nil {
         t.Fatal(err)
     }
     sw.payloadEntry.Size = uint64(^uint32(0)) + 1 // exceed 32-bit
diff --git a/sdk/internal/zipstream/zip_primitives.go b/sdk/internal/zipstream/zip_primitives.go
index 148cf35618..b695889625 100644
--- a/sdk/internal/zipstream/zip_primitives.go
+++ b/sdk/internal/zipstream/zip_primitives.go
@@ -55,7 +55,7 @@ func NewSegmentMetadata(expectedCount int) *SegmentMetadata {
 }
 
 // AddSegment records metadata for a segment (size + CRC) without retaining payload bytes.
-func (sm *SegmentMetadata) AddSegment(index int, _ []byte, originalSize uint64, originalCRC32 uint32) error {
+func (sm *SegmentMetadata) AddSegment(index int, originalSize uint64, originalCRC32 uint32) error {
     if index < 0 {
         return ErrInvalidSegment
     }
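One detail worth calling out: the experimental writer now feeds the nonce and ciphertext into a running `crc32.NewIEEE()` hash instead of checksumming a concatenated buffer. This is a standard-library equivalence, demonstrated by the self-contained sketch below (the example byte values are arbitrary).

```go
package main

import (
	"bytes"
	"fmt"
	"hash/crc32"
)

func main() {
	nonce := []byte{0x01, 0x02, 0x03}
	cipherText := []byte("encrypted segment bytes")

	// Incremental CRC over nonce || ciphertext, with no concatenation needed.
	// hash.Hash's Write never returns an error, so the results are ignored.
	crc := crc32.NewIEEE()
	crc.Write(nonce)
	crc.Write(cipherText)

	// Same checksum as hashing the joined bytes in one call.
	joined := bytes.Join([][]byte{nonce, cipherText}, nil)
	fmt.Println(crc.Sum32() == crc32.ChecksumIEEE(joined)) // true
}
```

That equivalence is what lets the zipstream layer accept a single uint32 from the caller, regardless of how the segment bytes were produced, while still writing a valid CRC into the ZIP entry.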