Skip to content

Commit

Permalink
Merge pull request #381 from MichaelS11/OCIStmtPrepare2
Browse files Browse the repository at this point in the history
Switched to OCIStmtPrepare2 & Improve context done
  • Loading branch information
mattn committed Feb 5, 2020
2 parents aaa5bac + e24785c commit 8b842f2
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 40 deletions.
54 changes: 29 additions & 25 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
// Ping database connection
func (conn *OCI8Conn) Ping(ctx context.Context) error {
done := make(chan struct{})
go conn.ociBreak(ctx, done)
go conn.ociBreakDone(ctx, done)
result := C.OCIPing(conn.svc, conn.errHandle, C.OCI_DEFAULT)
close(done)
if result == C.OCI_SUCCESS || result == C.OCI_SUCCESS_WITH_INFO {
Expand Down Expand Up @@ -96,24 +96,23 @@ func (conn *OCI8Conn) PrepareContext(ctx context.Context, query string) (driver.
defer C.free(unsafe.Pointer(queryP))

// statement handle
stmt, _, err := conn.ociHandleAlloc(C.OCI_HTYPE_STMT, 0)
if err != nil {
return nil, fmt.Errorf("allocate statement handle error: %v", err)
}

if rv := C.OCIStmtPrepare(
(*C.OCIStmt)(*stmt),
conn.errHandle,
queryP,
C.ub4(len(query)),
C.ub4(C.OCI_NTV_SYNTAX),
C.ub4(C.OCI_DEFAULT),
var stmtTemp *C.OCIStmt
stmt := &stmtTemp
if rv := C.OCIStmtPrepare2(
conn.svc, // service context handle
stmt, // pointer to the statement handle returned
conn.errHandle, // error handle
queryP, // statement text
C.ub4(len(query)), // statement text length
nil, // key to be used for searching the statement in the statement cache
C.ub4(0), // length of the key
C.ub4(C.OCI_NTV_SYNTAX), // syntax - OCI_NTV_SYNTAX: syntax depends upon the version of the server
C.ub4(C.OCI_DEFAULT), // mode
); rv != C.OCI_SUCCESS {
C.OCIHandleFree(*stmt, C.OCI_HTYPE_STMT)
return nil, conn.getError(rv)
}

return &OCI8Stmt{conn: conn, stmt: (*C.OCIStmt)(*stmt)}, nil
return &OCI8Stmt{conn: conn, stmt: *stmt}, nil
}

// Begin starts a transaction
Expand Down Expand Up @@ -567,23 +566,28 @@ func appendSmallInt(slice []byte, num int) []byte {
return append(slice, byte('0'+num/10), byte('0'+(num%10)))
}

// ociBreak calls OCIBreak if ctx.Done is finished before done chan is closed
func (conn *OCI8Conn) ociBreak(ctx context.Context, done chan struct{}) {
// ociBreakDone calls OCIBreak if ctx.Done is finished before done chan is closed
func (conn *OCI8Conn) ociBreakDone(ctx context.Context, done chan struct{}) {
select {
case <-done:
case <-ctx.Done():
// select again to avoid race condition if both are done
select {
case <-done:
default:
result := C.OCIBreak(
unsafe.Pointer(conn.svc), // The service context handle or the server context handle.
conn.errHandle, // An error handle
)
err := conn.getError(result)
if err != nil {
conn.logger.Print("OCIBreak error: ", err)
}
conn.ociBreak()
}
}
}

// ociBreak calls OCIBreak
func (conn *OCI8Conn) ociBreak() {
result := C.OCIBreak(
unsafe.Pointer(conn.svc), // service or server context handle
conn.errHandle, // error handle
)
err := conn.getError(result)
if err != nil {
conn.logger.Print("OCIBreak error: ", err)
}
}
3 changes: 2 additions & 1 deletion globals.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import "C"
// noPkgConfig is a Go tag for disabling using pkg-config and using environmental settings like CGO_CFLAGS and CGO_LDFLAGS instead

import (
"context"
"database/sql"
"errors"
"io/ioutil"
Expand Down Expand Up @@ -82,7 +83,6 @@ type (
conn *OCI8Conn
stmt *C.OCIStmt
closed bool
pbind []oci8Bind // bind params
}

// OCI8Result is Oracle result
Expand Down Expand Up @@ -120,6 +120,7 @@ type (
defines []oci8Define
e bool
closed bool
ctx context.Context
done chan struct{}
}
)
Expand Down
5 changes: 4 additions & 1 deletion rows.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ func (rows *OCI8Rows) Next(dest []driver.Value) error {
return nil
}

if rows.ctx.Err() != nil {
return rows.ctx.Err()
}

result := C.OCIStmtFetch2(
rows.stmt.stmt,
rows.stmt.conn.errHandle,
Expand Down Expand Up @@ -198,7 +202,6 @@ func (rows *OCI8Rows) Next(dest []driver.Value) error {
return fmt.Errorf("Unhandled column type: %d", rows.defines[i].dataType)

}

}

return nil
Expand Down
39 changes: 26 additions & 13 deletions statement.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ func (stmt *OCI8Stmt) Close() error {
}
stmt.closed = true

C.OCIHandleFree(unsafe.Pointer(stmt.stmt), C.OCI_HTYPE_STMT)

result := C.OCIStmtRelease(
stmt.stmt, // statement handle
stmt.conn.errHandle, // error handle
nil, // key to be associated with the statement in the cache
C.ub4(0), // length of the key
C.ub4(C.OCI_DEFAULT), // mode
)
stmt.stmt = nil
stmt.pbind = nil

return nil
return stmt.conn.getError(result)
}

// NumInput returns the number of input
Expand Down Expand Up @@ -66,11 +70,9 @@ func (stmt *OCI8Stmt) bindValues(ctx context.Context, values []driver.Value, nam
}

for i := 0; i < count; i++ {
select {
case <-ctx.Done():
if ctx.Err() != nil {
freeBinds(binds)
return nil, ctx.Err()
default:
}

var valueInterface interface{}
Expand Down Expand Up @@ -410,8 +412,12 @@ func (stmt *OCI8Stmt) query(ctx context.Context, binds []oci8Bind) (driver.Rows,
mode = mode | C.OCI_COMMIT_ON_SUCCESS
}

if ctx.Err() != nil {
return nil, ctx.Err()
}

done := make(chan struct{})
go stmt.conn.ociBreak(ctx, done)
go stmt.conn.ociBreakDone(ctx, done)
err = stmt.ociStmtExecute(iter, mode)
close(done)
if err != nil {
Expand All @@ -428,11 +434,9 @@ func (stmt *OCI8Stmt) query(ctx context.Context, binds []oci8Bind) (driver.Rows,
defines := make([]oci8Define, paramCount)

for i := 0; i < paramCount; i++ {
select {
case <-ctx.Done():
if ctx.Err() != nil {
freeDefines(defines)
return nil, ctx.Err()
default:
}

var param *C.OCIParam
Expand Down Expand Up @@ -618,13 +622,18 @@ func (stmt *OCI8Stmt) query(ctx context.Context, binds []oci8Bind) (driver.Rows,
}
}

if ctx.Err() != nil {
return nil, ctx.Err()
}

rows := &OCI8Rows{
stmt: stmt,
defines: defines,
ctx: ctx,
done: make(chan struct{}),
}

go stmt.conn.ociBreak(ctx, rows.done)
go stmt.conn.ociBreakDone(ctx, rows.done)

return rows, nil
}
Expand Down Expand Up @@ -692,8 +701,12 @@ func (stmt *OCI8Stmt) exec(ctx context.Context, binds []oci8Bind) (driver.Result
mode = mode | C.OCI_COMMIT_ON_SUCCESS
}

if ctx.Err() != nil {
return nil, ctx.Err()
}

done := make(chan struct{})
go stmt.conn.ociBreak(ctx, done)
go stmt.conn.ociBreakDone(ctx, done)
err := stmt.ociStmtExecute(1, mode)
close(done)
if err != nil && err != ErrOCISuccessWithInfo {
Expand Down

0 comments on commit 8b842f2

Please sign in to comment.