diff --git a/driver.go b/driver.go index 3f8fba51..aae2501c 100644 --- a/driver.go +++ b/driver.go @@ -689,7 +689,7 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti return err } for { - err = f(ctx, tx) + err = protected(ctx, tx, f) errDuringCommit := false if err == nil { err = tx.Commit() @@ -742,6 +742,15 @@ func runTransactionWithOptions(ctx context.Context, db *sql.DB, opts *sql.TxOpti } } } + +} +func protected(ctx context.Context, tx *sql.Tx, f func(ctx context.Context, tx *sql.Tx) error) (err error) { + defer func() { + if x := recover(); x != nil { + err = spanner.ToSpannerError(status.Errorf(codes.Unknown, "transaction function panic: %v", x)) + } + }() + return f(ctx, tx) } func resetTransactionForRetry(ctx context.Context, conn *sql.Conn, errDuringCommit bool) error { diff --git a/driver_with_mockserver_test.go b/driver_with_mockserver_test.go index 54219b83..af66e19e 100644 --- a/driver_with_mockserver_test.go +++ b/driver_with_mockserver_test.go @@ -4227,6 +4227,21 @@ func TestRunTransactionCommitError(t *testing.T) { } } +func TestRunTransactionPanics(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db, _, teardown := setupTestDBConnection(t) + defer teardown() + + err := RunTransaction(ctx, db, nil, func(ctx context.Context, tx *sql.Tx) error { + panic(nil) + }) + if err == nil { + t.Fatal("missing error from transaction runner") + } +} + func TestTransactionWithLevelDisableRetryAborts(t *testing.T) { t.Parallel()