Merged
109 changes: 109 additions & 0 deletions examples/cmd/benchmark_experimental.go
@@ -0,0 +1,109 @@
//nolint:forbidigo // We use Println here extensively because we are printing markdown.
package cmd

import (
"context"
"crypto/rand"
"fmt"
"sync"
"time"

"connectrpc.com/connect"
"github.com/opentdf/platform/lib/ocrypto"
kasp "github.com/opentdf/platform/protocol/go/kas"
"github.com/opentdf/platform/protocol/go/kas/kasconnect"
"github.com/opentdf/platform/protocol/go/policy"

"github.com/opentdf/platform/sdk/experimental/tdf"
"github.com/opentdf/platform/sdk/httputil"
"github.com/spf13/cobra"
)

var (
payloadSize int
segmentChunk int
testAttr = "https://example.com/attr/attr1/value/value1"
)

func init() {
benchmarkCmd := &cobra.Command{
Use: "benchmark-experimental-writer",
Short: "Benchmark experimental TDF writer speed",
Long: `Benchmark the experimental TDF writer with configurable payload size.`,
RunE: runExperimentalWriterBenchmark,
}
//nolint: mnd // no magic number, this is just default value for payload size
benchmarkCmd.Flags().IntVar(&payloadSize, "payload-size", 1024*1024, "Payload size in bytes") // Default 1MB
//nolint: mnd // same as above
benchmarkCmd.Flags().IntVar(&segmentChunk, "segment-chunks", 16*1024, "Segment chunk size in bytes") // Default 16 KiB
ExamplesCmd.AddCommand(benchmarkCmd)
}

func runExperimentalWriterBenchmark(_ *cobra.Command, _ []string) error {
payload := make([]byte, payloadSize)
_, err := rand.Read(payload)
if err != nil {
return fmt.Errorf("failed to generate random payload: %w", err)
}

	httpClient := httputil.SafeHTTPClient()
	fmt.Println("endpoint:", platformEndpoint)
	serviceClient := kasconnect.NewAccessServiceClient(httpClient, platformEndpoint)
resp, err := serviceClient.PublicKey(context.Background(), connect.NewRequest(&kasp.PublicKeyRequest{Algorithm: string(ocrypto.RSA2048Key)}))
if err != nil {
return fmt.Errorf("failed to get public key from KAS: %w", err)
}
var attrs []*policy.Value

	simpleKey := &policy.SimpleKasKey{
KasUri: platformEndpoint,
KasId: "id",
PublicKey: &policy.SimpleKasPublicKey{
Kid: resp.Msg.GetKid(),
Pem: resp.Msg.GetPublicKey(),
Algorithm: policy.Algorithm_ALGORITHM_RSA_2048,
},
}

	attrs = append(attrs, &policy.Value{Fqn: testAttr, KasKeys: []*policy.SimpleKasKey{simpleKey}, Attribute: &policy.Attribute{Namespace: &policy.Namespace{Name: "example.com"}, Fqn: testAttr}})
	writer, err := tdf.NewWriter(context.Background(), tdf.WithDefaultKASForWriter(simpleKey), tdf.WithInitialAttributes(attrs), tdf.WithSegmentIntegrityAlgorithm(tdf.HS256))
if err != nil {
return fmt.Errorf("failed to create writer: %w", err)
}
	wg := sync.WaitGroup{}
	// Round up so a trailing partial chunk is still written.
	segs := (len(payload) + segmentChunk - 1) / segmentChunk
	wg.Add(segs)
	start := time.Now()
	for segment := 0; segment < segs; segment++ {
		go func(segment int) {
			defer wg.Done()
			begin := segment * segmentChunk
			end := min(begin+segmentChunk, len(payload))
			// Use a local err; assigning to the enclosing err from
			// concurrent goroutines would be a data race.
			if _, err := writer.WriteSegment(context.Background(), segment, payload[begin:end]); err != nil {
				fmt.Println(err)
				panic(err)
			}
		}(segment)
	}
wg.Wait()

end := time.Now()
result, err := writer.Finalize(context.Background())
if err != nil {
return fmt.Errorf("failed to finalize writer: %w", err)
}
totalTime := end.Sub(start)

fmt.Printf("# Benchmark Experimental TDF Writer Results:\n")
fmt.Printf("| Metric | Value |\n")
fmt.Printf("|--------------------|--------------|\n")
fmt.Printf("| Payload Size (B) | %d |\n", payloadSize)
fmt.Printf("| Output Size (B) | %d |\n", len(result.Data))
fmt.Printf("| Total Time | %s |\n", totalTime)

return nil
}
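The benchmark fans out one goroutine per segment, which is fine for the 1 MiB default but unbounded for very large payloads. A minimal sketch of a bounded variant, assuming the same per-index WriteSegment contract; writeChunks and the worker limit are illustrative, not part of the SDK:

```go
package benchmark

import (
	"context"

	"golang.org/x/sync/errgroup"
)

// writeChunks splits payload into fixed-size chunks and writes each one
// through write (a thin wrapper around writer.WriteSegment) with at most
// workers goroutines in flight.
func writeChunks(ctx context.Context, write func(ctx context.Context, index int, data []byte) error, payload []byte, chunk, workers int) error {
	g, ctx := errgroup.WithContext(ctx)
	g.SetLimit(workers) // cap concurrent goroutines

	segs := (len(payload) + chunk - 1) / chunk // ceiling division keeps the tail chunk
	for i := 0; i < segs; i++ {
		// Capturing i is safe on Go >= 1.22 (per-iteration loop variables).
		g.Go(func() error {
			begin := i * chunk
			end := min(begin+chunk, len(payload))
			return write(ctx, i, payload[begin:end])
		})
	}
	return g.Wait()
}
```

errgroup also cancels the shared context on the first failure, so in-flight writes stop early instead of panicking as the example command does.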
4 changes: 2 additions & 2 deletions examples/go.mod
@@ -5,9 +5,10 @@ go 1.24.0
toolchain go1.24.9

require (
+	connectrpc.com/connect v1.18.1
	github.com/opentdf/platform/lib/ocrypto v0.7.0
	github.com/opentdf/platform/protocol/go v0.13.0
-	github.com/opentdf/platform/sdk v0.7.0
+	github.com/opentdf/platform/sdk v0.10.1
github.com/spf13/cobra v1.9.1
github.com/stretchr/testify v1.10.0
google.golang.org/grpc v1.73.0
@@ -16,7 +17,6 @@ require (

require (
buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go v1.36.6-20250603165357-b52ab10f4468.1 // indirect
-	connectrpc.com/connect v1.18.1 // indirect
github.com/Masterminds/semver/v3 v3.3.1 // indirect
github.com/cpuguy83/go-md2man/v2 v2.0.6 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
4 changes: 2 additions & 2 deletions examples/go.sum
@@ -120,8 +120,8 @@ github.com/opentdf/platform/lib/ocrypto v0.7.0 h1:uBZJXisuXU3V8681aP8FVMJkyWBrwW
github.com/opentdf/platform/lib/ocrypto v0.7.0/go.mod h1:sYhoBL1bQYgQVSSNpxU13RsrE5JAk8BABT1hfr9L3j8=
github.com/opentdf/platform/protocol/go v0.13.0 h1:vrOOHyhYDPzJgNenz/1g0M5nWtkOYKkPggMNHKzeMcs=
github.com/opentdf/platform/protocol/go v0.13.0/go.mod h1:GRycoDGDxaz91sOvGZFWVEKJLluZFg2wM3NJmhucDHo=
-github.com/opentdf/platform/sdk v0.7.0 h1:8hczDycXGY1ucdIXSrP17oW/Eyu3vsb4LEX4hc7tvVY=
-github.com/opentdf/platform/sdk v0.7.0/go.mod h1:CTJR1NXeYe896M1/VN0h+1Ff54SdBtxv4z18BGTi8yk=
+github.com/opentdf/platform/sdk v0.10.1 h1:kBrTK48xle7mdGc+atlr4kDh94f6kVj+0OB76K8rozI=
+github.com/opentdf/platform/sdk v0.10.1/go.mod h1:+yaTi/c/GWHZPPmO27sq2s7Tcb2P/USkK8LuW1krhI8=
github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs=
github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
15 changes: 15 additions & 0 deletions lib/ocrypto/aes_gcm.go
@@ -47,6 +47,21 @@ func (aesGcm AesGcm) Encrypt(data []byte) ([]byte, error) {
return cipherText, nil
}

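+// EncryptInPlace encrypts data with the symmetric key, reusing the input
+// slice's backing array for the ciphertext where capacity permits. Unlike
+// Encrypt, which prepends the nonce to the ciphertext, the nonce is
+// returned separately.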
+func (aesGcm AesGcm) EncryptInPlace(data []byte) ([]byte, []byte, error) {
+	nonce, err := RandomBytes(GcmStandardNonceSize)
+	if err != nil {
+		return nil, nil, err
+	}
+
+	gcm, err := cipher.NewGCMWithNonceSize(aesGcm.block, GcmStandardNonceSize)
+	if err != nil {
+		return nil, nil, fmt.Errorf("cipher.NewGCMWithNonceSize failed: %w", err)
+	}
+
+	cipherText := gcm.Seal(data[:0], nonce, data, nil)
+	return cipherText, nonce, nil
+}

// EncryptWithIV encrypts data with symmetric key.
// NOTE: This method use default auth tag as aes block size(16 bytes)
// and expects iv of 16 bytes.
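For reviewers unfamiliar with the AEAD in-place idiom: passing data[:0] as Seal's dst reuses data's backing array only when it has spare capacity for the 16-byte GCM tag; otherwise Seal allocates a fresh slice. A standalone sketch of the idiom (stdlib only, not this library's API):

```go
package main

import (
	"crypto/aes"
	"crypto/cipher"
	"crypto/rand"
	"fmt"
)

func main() {
	key := make([]byte, 32)
	if _, err := rand.Read(key); err != nil {
		panic(err)
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		panic(err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		panic(err)
	}

	// Reserve tag capacity up front so Seal can stay in the original buffer.
	buf := make([]byte, 11, 11+gcm.Overhead())
	copy(buf, "hello world")

	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		panic(err)
	}

	ct := gcm.Seal(buf[:0], nonce, buf, nil)
	// Same backing array: the plaintext was encrypted in place.
	fmt.Println(&buf[0] == &ct[0]) // true
}
```

Callers that pass an exactly-sized slice still get a correct result, just with one extra allocation for the tag.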
71 changes: 45 additions & 26 deletions sdk/experimental/tdf/writer.go
@@ -10,6 +10,7 @@ import (
"errors"
"fmt"
"hash/crc32"
"io"
"log/slog"
"sort"
"sync"
@@ -34,12 +35,11 @@ const (

// SegmentResult contains the result of writing a segment
type SegmentResult struct {
-	Data          []byte `json:"data"`          // Encrypted segment bytes (for streaming)
-	Index         int    `json:"index"`         // Segment index
-	Hash          string `json:"hash"`          // Base64-encoded integrity hash
-	PlaintextSize int64  `json:"plaintextSize"` // Original data size
-	EncryptedSize int64  `json:"encryptedSize"` // Encrypted data size
-	CRC32         uint32 `json:"crc32"`         // CRC32 checksum
+	TDFData       io.Reader // Reader for the full TDF segment (nonce + encrypted data + zip structures)
+	Index         int    `json:"index"`         // Segment index
+	Hash          string `json:"hash"`          // Base64-encoded integrity hash
+	PlaintextSize int64  `json:"plaintextSize"` // Original data size
+	EncryptedSize int64  `json:"encryptedSize"` // Encrypted data size
}

// FinalizeResult contains the complete TDF creation result
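Dropping the buffered Data field in favor of an io.Reader means a segment can be piped straight to its destination without holding the encrypted bytes in memory. A minimal consumer sketch, assuming this package's Writer; uploadPart is a hypothetical sink (e.g. one part of an S3 multipart upload), not part of the SDK:

```go
import (
	"context"
	"io"

	"github.com/opentdf/platform/sdk/experimental/tdf"
)

// streamSegment writes one plaintext chunk and forwards the resulting TDF
// bytes (zip header + nonce + ciphertext) to uploadPart.
func streamSegment(ctx context.Context, w *tdf.Writer, idx int, plaintext []byte, uploadPart func(io.Reader) error) error {
	res, err := w.WriteSegment(ctx, idx, plaintext)
	if err != nil {
		return err
	}
	// TDFData is a one-shot reader; consume it exactly once.
	return uploadPart(res.TDFData)
}
```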
@@ -110,7 +110,7 @@ type Writer struct {

// segments stores segment metadata using sparse map for memory efficiency
// Maps segment index to Segment metadata (hash, size information)
-	segments map[int]Segment
+	segments map[int]*Segment
// maxSegmentIndex tracks the highest segment index written
maxSegmentIndex int

@@ -183,7 +183,7 @@ func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error
WriterConfig: *config,
archiveWriter: archiveWriter,
dek: dek,
-		segments:          make(map[int]Segment), // Initialize sparse storage
+		segments:          make(map[int]*Segment), // Initialize sparse storage
block: block,
initialAttributes: config.initialAttributes,
initialDefaultKAS: config.initialDefaultKAS,
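The map[int]Segment to map[int]*Segment switch is what lets WriteSegment publish results through a pointer held across an unlock: Go map values are not addressable, so a value map forces a copy-modify-store round trip on every update. A small illustration:

```go
package main

import "fmt"

type seg struct{ size int64 }

func main() {
	byValue := map[int]seg{0: {}}
	// byValue[0].size = 1 // compile error: map values are not addressable
	v := byValue[0] // copy out, mutate, store back
	v.size = 1
	byValue[0] = v

	byPointer := map[int]*seg{0: {}}
	p := byPointer[0]
	p.size = 1 // a held pointer stays valid with no further map access
	fmt.Println(byValue[0].size, byPointer[0].size) // 1 1
}
```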
@@ -232,58 +232,76 @@ func NewWriter(_ context.Context, opts ...Option[*WriterConfig]) (*Writer, error
// uploadToS3(segment1, "part-001")
func (w *Writer) WriteSegment(ctx context.Context, index int, data []byte) (*SegmentResult, error) {
w.mutex.Lock()
-	defer w.mutex.Unlock()

if w.finalized {
+		w.mutex.Unlock()
return nil, ErrAlreadyFinalized
}

if index < 0 {
+		w.mutex.Unlock()
return nil, ErrInvalidSegmentIndex
}

// Check for duplicate segments using map lookup
if _, exists := w.segments[index]; exists {
+		w.mutex.Unlock()
return nil, ErrSegmentAlreadyWritten
}

if index > w.maxSegmentIndex {
w.maxSegmentIndex = index
}
+	seg := &Segment{
+		Size: -1, // indicates not filled yet
+	}
+	w.segments[index] = seg

-	// Calculate CRC32 before encryption for integrity tracking
-	crc32Checksum := crc32.ChecksumIEEE(data)
+	w.mutex.Unlock()

-	// Encrypt directly without unnecessary copying - the archive layer will handle copying if needed
-	segmentCipher, err := w.block.Encrypt(data)
+	segmentCipher, nonce, err := w.block.EncryptInPlace(data)
if err != nil {
return nil, err
}

segmentSig, err := calculateSignature(segmentCipher, w.dek, w.segmentIntegrityAlgorithm, false) // Don't ever hex encode new tdf's
if err != nil {
return nil, err
}

segmentHash := string(ocrypto.Base64Encode([]byte(segmentSig)))
-	w.segments[index] = Segment{
-		Hash:          segmentHash,
-		Size:          int64(len(data)), // Use original data length
-		EncryptedSize: int64(len(segmentCipher)),
-	}
+	w.mutex.Lock()
+	seg.Size = int64(len(data))
+	seg.EncryptedSize = int64(len(segmentCipher)) + int64(len(nonce))
+	seg.Hash = segmentHash
+	w.mutex.Unlock()

-	zipBytes, err := w.archiveWriter.WriteSegment(ctx, index, segmentCipher)
+	crc := crc32.NewIEEE()
+	_, err = crc.Write(nonce)
+	if err != nil {
+		return nil, err
+	}
+	_, err = crc.Write(segmentCipher)
+	if err != nil {
+		return nil, err
+	}
+	header, err := w.archiveWriter.WriteSegment(ctx, index, uint64(seg.EncryptedSize), crc.Sum32())
	if err != nil {
		return nil, err
	}
+	var reader io.Reader
+	if len(header) == 0 {
+		reader = io.MultiReader(bytes.NewReader(nonce), bytes.NewReader(segmentCipher))
+	} else {
+		reader = io.MultiReader(bytes.NewReader(header), bytes.NewReader(nonce), bytes.NewReader(segmentCipher))
+	}

return &SegmentResult{
-		Data:          zipBytes,
+		TDFData:       reader,
		Index:         index,
-		Hash:          segmentHash,
-		PlaintextSize: int64(len(data)),
-		EncryptedSize: int64(len(segmentCipher)),
-		CRC32:         crc32Checksum,
+		Hash:          seg.Hash,
+		PlaintextSize: seg.Size,
+		EncryptedSize: seg.EncryptedSize,
}, nil
}

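The new locking shape is a reserve/encrypt/publish pattern: claim the index under the mutex with a sentinel so duplicates are rejected immediately, run the expensive encryption unlocked, then take the lock again briefly to publish through the held pointer. A distilled sketch of the pattern with illustrative names (not the SDK API):

```go
package main

import (
	"errors"
	"sync"
)

var errDuplicate = errors.New("slot already claimed")

type slot struct{ value []byte }

type table struct {
	mu    sync.Mutex
	slots map[int]*slot
}

// put claims index under the lock, does the expensive work unlocked, then
// publishes the result. Concurrent callers for distinct indices only contend
// for the two short critical sections.
func (t *table) put(index int, work func() []byte) error {
	t.mu.Lock()
	if _, exists := t.slots[index]; exists {
		t.mu.Unlock()
		return errDuplicate
	}
	s := &slot{} // sentinel: claimed but not yet filled
	t.slots[index] = s
	t.mu.Unlock()

	v := work() // e.g. encryption, outside the lock

	t.mu.Lock()
	s.value = v
	t.mu.Unlock()
	return nil
}
```

The trade-off versus the old defer-based lock is that segments for different indices now encrypt in parallel instead of serializing on one mutex.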
@@ -505,7 +523,7 @@ func (w *Writer) getManifest(ctx context.Context, cfg *WriterFinalizeConfig) (*M
// Copy segments to manifest in finalize order (pack densely)
for i, idx := range order {
if segment, exists := w.segments[idx]; exists {
-			encryptInfo.Segments[i] = segment
+			encryptInfo.Segments[i] = *segment
}
}

@@ -524,7 +542,8 @@ func (w *Writer) getManifest(ctx context.Context, cfg *WriterFinalizeConfig) (*M
var totalPlaintextSize, totalEncryptedSize int64
for _, i := range order {
segment, exists := w.segments[i]
-		if !exists {
+		// a negative size means the segment was reserved but never filled;
+		// Finalize was called before all writes completed
+		if !exists || w.segments[i].Size < 0 {
return nil, 0, 0, fmt.Errorf("segment %d not written; cannot finalize", i)
}
if segment.Hash != "" {
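The extra Size < 0 check closes a race the sentinel introduces: a slot can exist because WriteSegment reserved it while its encryption has not yet published. A distilled version of the finalize-time guard, assuming this package's Segment type and a dense finalize order:

```go
// checkComplete rejects finalization while any reserved slot still carries
// the -1 "not filled yet" sentinel.
func checkComplete(segments map[int]*Segment, order []int) error {
	for _, i := range order {
		seg, exists := segments[i]
		if !exists || seg.Size < 0 {
			return fmt.Errorf("segment %d not written; cannot finalize", i)
		}
	}
	return nil
}
```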
2 changes: 1 addition & 1 deletion sdk/experimental/tdf/writer_test.go
@@ -669,7 +669,7 @@ func testErrorConditions(t *testing.T) {
require.NoError(t, err)

// Manually corrupt segment hash to test error handling
-	writer.segments[0] = Segment{Hash: "", Size: 10, EncryptedSize: 26}
+	writer.segments[0] = &Segment{Hash: "", Size: 10, EncryptedSize: 26}
writer.maxSegmentIndex = 0

attributes := []*policy.Value{
16 changes: 11 additions & 5 deletions sdk/internal/zipstream/benchmark_test.go
@@ -3,6 +3,7 @@
package zipstream

import (
"hash/crc32"
"testing"
)

@@ -55,7 +56,8 @@ func BenchmarkSegmentWriter_CRC32ContiguousProcessing(b *testing.B) {

// Write segments in specified order
for _, segIdx := range writeOrder {
-		_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+		crc := crc32.ChecksumIEEE(segmentData)
+		_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
if err != nil {
b.Fatal(err)
}
@@ -99,7 +101,8 @@ func BenchmarkSegmentWriter_SparseIndices(b *testing.B) {
// Write sparse indices in order
for k := 0; k < n; k++ {
idx := k * stride
-		if _, err := w.WriteSegment(ctx, idx, data); err != nil {
+		crc := crc32.ChecksumIEEE(data)
+		if _, err := w.WriteSegment(ctx, idx, uint64(len(data)), crc); err != nil {
b.Fatal(err)
}
}
@@ -153,7 +156,8 @@ func BenchmarkSegmentWriter_VariableSegmentSizes(b *testing.B) {
segmentData[j] = byte((segIdx * j) % 256)
}

-		_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+		crc := crc32.ChecksumIEEE(segmentData)
+		_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
if err != nil {
b.Fatal(err)
}
@@ -234,7 +238,8 @@ func BenchmarkSegmentWriter_MemoryPressure(b *testing.B) {
segmentData[j] = byte((orderIdx * j) % 256)
}

-		_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+		crc := crc32.ChecksumIEEE(segmentData)
+		_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
if err != nil {
b.Fatal(err)
}
@@ -305,7 +310,8 @@ func BenchmarkSegmentWriter_ZIPGeneration(b *testing.B) {

// Write all segments
for segIdx := 0; segIdx < tc.segmentCount; segIdx++ {
-		_, err := writer.WriteSegment(ctx, segIdx, segmentData)
+		crc := crc32.ChecksumIEEE(segmentData)
+		_, err := writer.WriteSegment(ctx, segIdx, uint64(len(segmentData)), crc)
if err != nil {
b.Fatal(err)
}
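The zipstream WriteSegment now receives a precomputed size and CRC instead of raw bytes, so callers hash exactly what they will emit. For the writer that is the nonce followed by the ciphertext, and the incremental form used there matches a one-shot checksum over the concatenation, as these benchmarks compute it. A quick check:

```go
package main

import (
	"fmt"
	"hash/crc32"
)

func main() {
	nonce := []byte{1, 2, 3}
	cipherText := []byte{4, 5, 6, 7}

	// Incremental, as in Writer.WriteSegment.
	crc := crc32.NewIEEE()
	crc.Write(nonce) // hash.Hash writes never return an error
	crc.Write(cipherText)

	// One-shot over the concatenation, as in these benchmarks.
	oneShot := crc32.ChecksumIEEE(append(append([]byte{}, nonce...), cipherText...))

	fmt.Println(crc.Sum32() == oneShot) // true
}
```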