Skip to content

Commit

Permalink
Add support for prepared statements
Browse files Browse the repository at this point in the history
  • Loading branch information
lxn committed Aug 7, 2010
1 parent 746833f commit 6e4a863
Show file tree
Hide file tree
Showing 3 changed files with 412 additions and 0 deletions.
83 changes: 83 additions & 0 deletions parameter.go
@@ -0,0 +1,83 @@
// Copyright 2010 Alexander Neumann. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package pgsql

import (
"os"
)

// Parameter is used to set the value of a parameter in a Statement.
type Parameter struct {
name string
stmt *Statement
typ Type
value interface{}
}

// NewParameter returns a new Parameter with the specified properties.
func NewParameter(name string, typ Type) *Parameter {
return &Parameter{name: name, typ: typ}
}

// Name returns the name of the Parameter.
func (p *Parameter) Name() string {
return p.name
}

// Type returns the PostgreSQL data type of the Parameter.
func (p *Parameter) Type() Type {
return p.typ
}

// Value returns the current value of the Parameter.
func (p *Parameter) Value() interface{} {
return p.value
}

// SetValue sets the current value of the Parameter.
func (p *Parameter) SetValue(v interface{}) (err os.Error) {
defer func() {
if x := recover(); x != nil {
if p.stmt == nil {
err = x.(os.Error)
} else {
err = p.stmt.conn.logAndConvertPanic(x)
}
}
}()

if p.stmt != nil && p.stmt.conn.LogLevel >= LogVerbose {
defer p.stmt.conn.logExit(p.stmt.conn.logEnter("*Parameter.SetValue"))
}

var ok bool
switch p.typ {
case Boolean:
p.value = v.(bool)

case Char, Text, Varchar:
p.value = v.(string)

case Real:
p.value = v.(float32)

case Double:
p.value = v.(float64)

case Smallint:
p.value = v.(int16)

case Integer:
p.value, ok = v.(int32)
if !ok {
p.value = v.(int)
}

case Bigint:
p.value = v.(int64)
}

return
}
232 changes: 232 additions & 0 deletions statement.go
@@ -0,0 +1,232 @@
// Copyright 2010 Alexander Neumann. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package pgsql

import (
"bytes"
"fmt"
"os"
"regexp"
"strings"
)

var nextStatementId, nextPortalId uint64
var quoteRegExp = regexp.MustCompile("['][^']*[']")

// Statement is a means to efficiently execute a parameterized SQL command multiple times.
// Call *Conn.Prepare to create a new prepared Statement.
type Statement struct {
conn *Conn
name, portalName, command, actualCommand string
isClosed bool
params []*Parameter
name2param map[string]*Parameter
}

func replaceParameterName(command, old, new string) string {
buf := bytes.NewBuffer(nil)

quoteIndices := quoteRegExp.ExecuteString(command)
prevQuoteEnd := 0
for i := 0; i < len(quoteIndices); i += 2 {
quoteStart := quoteIndices[i]
quoteEnd := quoteIndices[i+1]

buf.WriteString(strings.Replace(command[prevQuoteEnd:quoteStart], old, new, -1))
buf.WriteString(command[quoteStart:quoteEnd])

prevQuoteEnd = quoteEnd
}

if buf.Len() > 0 {
buf.WriteString(strings.Replace(command[prevQuoteEnd:], old, new, -1))

return buf.String()
}

return strings.Replace(command, old, new, -1)
}

func adjustCommand(command string, params []*Parameter) string {
for i, p := range params {
command = replaceParameterName(command, p.name, fmt.Sprintf("$%d", i+1))
}

return command
}

func newStatement(conn *Conn, command string, params []*Parameter) *Statement {
if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("newStatement"))
}

stmt := new(Statement)

stmt.name2param = make(map[string]*Parameter)

for _, param := range params {
if param.stmt != nil {
panic(fmt.Sprintf("parameter '%s' already used in another statement", param.name))
}
param.stmt = stmt

stmt.name2param[param.name] = param
}

stmt.conn = conn

stmt.name = fmt.Sprint("stmt", nextStatementId)
nextStatementId++

stmt.portalName = fmt.Sprint("prtl", nextPortalId)
nextPortalId++

stmt.command = command
stmt.actualCommand = adjustCommand(command, params)

stmt.params = make([]*Parameter, len(params))
copy(stmt.params, params)

return stmt
}

// Parameter returns the Parameter with the specified name or nil, if the Statement has no Parameter with that name.
func (stmt *Statement) Parameter(name string) *Parameter {
conn := stmt.conn

if conn.LogLevel >= LogVerbose {
defer conn.logExit(conn.logEnter("*Statement.Parameter"))
}

param, ok := stmt.name2param[name]
if !ok {
return nil
}

return param
}

// Parameters returns a slice containing the parameters of the Statement.
func (stmt *Statement) Parameters() []*Parameter {
conn := stmt.conn

if conn.LogLevel >= LogVerbose {
defer conn.logExit(conn.logEnter("*Statement.Parameters"))
}

params := make([]*Parameter, len(stmt.params))
copy(params, stmt.params)
return params
}

// IsClosed returns if the Statement has been closed.
func (stmt *Statement) IsClosed() bool {
conn := stmt.conn

if conn.LogLevel >= LogVerbose {
defer conn.logExit(conn.logEnter("*Statement.IsClosed"))
}

return stmt.isClosed
}

// Close closes the Statement, releasing resources on the server.
func (stmt *Statement) Close() (err os.Error) {
conn := stmt.conn

defer func() {
if x := recover(); x != nil {
err = conn.logAndConvertPanic(x)
}
}()

if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Statement.Close"))
}

stmt.conn.state.closeStatement(stmt)

stmt.isClosed = true
return
}

// ActualCommand returns the actual command text that is sent to the server.
// The original command is automatically adjusted if it contains parameters so
// it complies with what PostgreSQL expects. Refer to the return value of this
// method to make sense of the position information contained in many error
// messages.
func (stmt *Statement) ActualCommand() string {
conn := stmt.conn

if conn.LogLevel >= LogVerbose {
defer conn.logExit(conn.logEnter("*Statement.ActualCommand"))
}

return stmt.actualCommand
}

// Command is the original command text as given to *Conn.Prepare.
func (stmt *Statement) Command() string {
conn := stmt.conn

if conn.LogLevel >= LogVerbose {
defer conn.logExit(conn.logEnter("*Statement.Command"))
}

return stmt.command
}

// Query executes the Statement and returns a
// Reader for row-by-row retrieval of the results.
// The returned Reader must be closed before sending another
// query or command to the server over the same connection.
func (stmt *Statement) Query() (reader *Reader, err os.Error) {
conn := stmt.conn

defer func() {
if x := recover(); x != nil {
err = conn.logAndConvertPanic(x)
}
}()

if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Statement.Query"))
}

r := newReader(conn)

conn.state.execute(stmt, r)

reader = r

return
}

// Execute executes the Statement and returns the number
// of rows affected. If the results of a query are needed, use the
// Query method instead.
func (stmt *Statement) Execute() (rowsAffected int64, err os.Error) {
conn := stmt.conn

defer func() {
if x := recover(); x != nil {
err = conn.logAndConvertPanic(x)
}
}()

if conn.LogLevel >= LogDebug {
defer conn.logExit(conn.logEnter("*Statement.Execute"))
}

reader, err := stmt.Query()
if err != nil {
return
}

err = reader.Close()

rowsAffected = reader.rowsAffected
return
}

0 comments on commit 6e4a863

Please sign in to comment.