diff --git a/connection.go b/connection.go index 5b8a91c2..21217b8b 100644 --- a/connection.go +++ b/connection.go @@ -515,6 +515,8 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver req.Parameters = parameters } + req.EnforceEmbeddedSchemaCorrectness = c.cfg.EnforceEmbeddedSchemaCorrectness + // Add per-statement query tags if provided via context if queryTags := driverctx.QueryTagsFromContext(ctx); len(queryTags) > 0 { serialized := SerializeQueryTags(queryTags) diff --git a/connector.go b/connector.go index 9b1e0872..7834fb75 100644 --- a/connector.go +++ b/connector.go @@ -329,6 +329,15 @@ func WithEnableMetricViewMetadata(enable bool) ConnOption { } } +// WithEnforceEmbeddedSchemaCorrectness enables enforcement of embedded schema correctness +// in query execution. When set to true, the server will enforce embedded schema correctness. +// Default is false. +func WithEnforceEmbeddedSchemaCorrectness(enforce bool) ConnOption { + return func(c *config.Config) { + c.EnforceEmbeddedSchemaCorrectness = enforce + } +} + // Setup of Oauth M2m authentication func WithClientCredentials(clientID, clientSecret string) ConnOption { return func(c *config.Config) { diff --git a/internal/cli_service/cli_service.go b/internal/cli_service/cli_service.go index 71952c69..a43bcc2d 100644 --- a/internal/cli_service/cli_service.go +++ b/internal/cli_service/cli_service.go @@ -11776,6 +11776,7 @@ func (p *TSparkArrowTypes) Validate() error { // - Parameters // - MaxBytesPerBatch // - StatementConf +// - EnforceEmbeddedSchemaCorrectness type TExecuteStatementReq struct { SessionHandle *TSessionHandle `thrift:"sessionHandle,1,required" db:"sessionHandle" json:"sessionHandle"` Statement string `thrift:"statement,2,required" db:"statement" json:"statement"` @@ -11794,6 +11795,8 @@ type TExecuteStatementReq struct { MaxBytesPerBatch *int64 `thrift:"maxBytesPerBatch,1289" db:"maxBytesPerBatch" json:"maxBytesPerBatch,omitempty"` // unused fields # 1290 to 1295 StatementConf *TStatementConf `thrift:"statementConf,1296" db:"statementConf" json:"statementConf,omitempty"` + // unused fields # 1297 to 3352 + EnforceEmbeddedSchemaCorrectness bool `thrift:"enforceEmbeddedSchemaCorrectness,3353" db:"enforceEmbeddedSchemaCorrectness" json:"enforceEmbeddedSchemaCorrectness"` } func NewTExecuteStatementReq() *TExecuteStatementReq { @@ -11894,6 +11897,11 @@ func (p *TExecuteStatementReq) GetStatementConf() *TStatementConf { } return p.StatementConf } +var TExecuteStatementReq_EnforceEmbeddedSchemaCorrectness_DEFAULT bool = false + +func (p *TExecuteStatementReq) GetEnforceEmbeddedSchemaCorrectness() bool { + return p.EnforceEmbeddedSchemaCorrectness +} func (p *TExecuteStatementReq) IsSetSessionHandle() bool { return p.SessionHandle != nil } @@ -11950,6 +11958,10 @@ func (p *TExecuteStatementReq) IsSetStatementConf() bool { return p.StatementConf != nil } +func (p *TExecuteStatementReq) IsSetEnforceEmbeddedSchemaCorrectness() bool { + return p.EnforceEmbeddedSchemaCorrectness != TExecuteStatementReq_EnforceEmbeddedSchemaCorrectness_DEFAULT +} + func (p *TExecuteStatementReq) Read(ctx context.Context, iprot thrift.TProtocol) error { if _, err := iprot.ReadStructBegin(ctx); err != nil { return thrift.PrependError(fmt.Sprintf("%T read error: ", p), err) @@ -12117,6 +12129,16 @@ func (p *TExecuteStatementReq) Read(ctx context.Context, iprot thrift.TProtocol) return err } } + case 3353: + if fieldTypeId == thrift.BOOL { + if err := p.ReadField3353(ctx, iprot); err != nil { + return err + } + } else { + if err := iprot.Skip(ctx, fieldTypeId); err != nil { + return err + } + } default: if err := iprot.Skip(ctx, fieldTypeId); err != nil { return err @@ -12299,6 +12321,15 @@ func (p *TExecuteStatementReq) ReadField1296(ctx context.Context, iprot thrift. return nil } +func (p *TExecuteStatementReq) ReadField3353(ctx context.Context, iprot thrift.TProtocol) error { + if v, err := iprot.ReadBool(ctx); err != nil { + return thrift.PrependError("error reading field 3353: ", err) +} else { + p.EnforceEmbeddedSchemaCorrectness = v +} + return nil +} + func (p *TExecuteStatementReq) Write(ctx context.Context, oprot thrift.TProtocol) error { if err := oprot.WriteStructBegin(ctx, "TExecuteStatementReq"); err != nil { return thrift.PrependError(fmt.Sprintf("%T write struct begin error: ", p), err) } @@ -12318,6 +12349,7 @@ func (p *TExecuteStatementReq) Write(ctx context.Context, oprot thrift.TProtocol if err := p.writeField1288(ctx, oprot); err != nil { return err } if err := p.writeField1289(ctx, oprot); err != nil { return err } if err := p.writeField1296(ctx, oprot); err != nil { return err } + if err := p.writeField3353(ctx, oprot); err != nil { return err } } if err := oprot.WriteFieldStop(ctx); err != nil { return thrift.PrependError("write field stop error: ", err) } @@ -12525,6 +12557,18 @@ func (p *TExecuteStatementReq) writeField1296(ctx context.Context, oprot thrift. return err } +func (p *TExecuteStatementReq) writeField3353(ctx context.Context, oprot thrift.TProtocol) (err error) { + if p.IsSetEnforceEmbeddedSchemaCorrectness() { + if err := oprot.WriteFieldBegin(ctx, "enforceEmbeddedSchemaCorrectness", thrift.BOOL, 3353); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field begin error 3353:enforceEmbeddedSchemaCorrectness: ", p), err) } + if err := oprot.WriteBool(ctx, bool(p.EnforceEmbeddedSchemaCorrectness)); err != nil { + return thrift.PrependError(fmt.Sprintf("%T.enforceEmbeddedSchemaCorrectness (3353) field write error: ", p), err) } + if err := oprot.WriteFieldEnd(ctx); err != nil { + return thrift.PrependError(fmt.Sprintf("%T write field end error 3353:enforceEmbeddedSchemaCorrectness: ", p), err) } + } + return err +} + func (p *TExecuteStatementReq) Equals(other *TExecuteStatementReq) bool { if p == other { return true @@ -12584,6 +12628,7 @@ func (p *TExecuteStatementReq) Equals(other *TExecuteStatementReq) bool { if (*p.MaxBytesPerBatch) != (*other.MaxBytesPerBatch) { return false } } if !p.StatementConf.Equals(other.StatementConf) { return false } + if p.EnforceEmbeddedSchemaCorrectness != other.EnforceEmbeddedSchemaCorrectness { return false } return true } diff --git a/internal/cli_service/thrift_field_id_test.go b/internal/cli_service/thrift_field_id_test.go index f70546aa..00416517 100644 --- a/internal/cli_service/thrift_field_id_test.go +++ b/internal/cli_service/thrift_field_id_test.go @@ -20,6 +20,10 @@ import ( func TestThriftFieldIdsAreWithinAllowedRange(t *testing.T) { const maxAllowedFieldID = 3329 + allowedExceptions := map[int]bool{ + 3353: true, + } + // Get the directory of this test file _, filename, _, ok := runtime.Caller(0) if !ok { @@ -30,7 +34,7 @@ func TestThriftFieldIdsAreWithinAllowedRange(t *testing.T) { testDir := filepath.Dir(filename) cliServicePath := filepath.Join(testDir, "cli_service.go") - violations, err := validateThriftFieldIDs(cliServicePath, maxAllowedFieldID) + violations, err := validateThriftFieldIDs(cliServicePath, maxAllowedFieldID, allowedExceptions) if err != nil { t.Fatalf("Failed to validate thrift field IDs: %v", err) } @@ -51,8 +55,9 @@ func TestThriftFieldIdsAreWithinAllowedRange(t *testing.T) { } // validateThriftFieldIDs parses the cli_service.go file and extracts all thrift field IDs -// to validate they are within the allowed range. -func validateThriftFieldIDs(filePath string, maxAllowedFieldID int) ([]string, error) { +// to validate they are within the allowed range. Field IDs listed in allowedExceptions +// are permitted even if they exceed the maximum. +func validateThriftFieldIDs(filePath string, maxAllowedFieldID int, allowedExceptions map[int]bool) ([]string, error) { file, err := os.Open(filePath) //nolint:gosec // G304: path is a test fixture, not user-controlled if err != nil { return nil, fmt.Errorf("failed to open file %s: %w", filePath, err) @@ -84,7 +89,7 @@ func validateThriftFieldIDs(filePath string, maxAllowedFieldID int) ([]string, e continue } - if fieldID >= maxAllowedFieldID { + if fieldID >= maxAllowedFieldID && !allowedExceptions[fieldID] { // Extract struct/field context from the line context := extractFieldContext(line) violation := fmt.Sprintf( diff --git a/internal/config/config.go b/internal/config/config.go index b8be59cb..bdb1a17c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -101,14 +101,15 @@ type UserConfig struct { // Telemetry configuration // Uses config overlay pattern: client > server > default. // Unset = check server feature flag; explicitly true/false overrides the server. - EnableTelemetry ConfigValue[bool] - TelemetryBatchSize int // 0 = use default (100) - TelemetryFlushInterval time.Duration // 0 = use default (5s) - TelemetryRetryCount int // -1 = use default (3); 0 = disable retries; set via telemetry_retry_count - TelemetryRetryDelay time.Duration // 0 = use default (100ms); set via telemetry_retry_delay - Transport http.RoundTripper - UseLz4Compression bool - EnableMetricViewMetadata bool + EnableTelemetry ConfigValue[bool] + TelemetryBatchSize int // 0 = use default (100) + TelemetryFlushInterval time.Duration // 0 = use default (5s) + TelemetryRetryCount int // -1 = use default (3); 0 = disable retries; set via telemetry_retry_count + TelemetryRetryDelay time.Duration // 0 = use default (100ms); set via telemetry_retry_delay + Transport http.RoundTripper + UseLz4Compression bool + EnableMetricViewMetadata bool + EnforceEmbeddedSchemaCorrectness bool CloudFetchConfig } @@ -302,6 +303,13 @@ func ParseDSN(dsn string) (UserConfig, error) { ucfg.EnableMetricViewMetadata = enableMetricViewMetadata } + if enforceEmbeddedSchemaCorrectness, ok, err := params.extractAsBool("enforceEmbeddedSchemaCorrectness"); ok { + if err != nil { + return UserConfig{}, err + } + ucfg.EnforceEmbeddedSchemaCorrectness = enforceEmbeddedSchemaCorrectness + } + // Telemetry parameters if enableTelemetry, ok, err := params.extractAsBool("enableTelemetry"); ok { if err != nil {