Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql/pgwire: explicitly handle 0, 1 or n format codes #5783

Merged
merged 1 commit into from Apr 4, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
10 changes: 10 additions & 0 deletions acceptance/php_test.go
Expand Up @@ -38,4 +38,14 @@ $result = pg_query_params('SELECT 1, 2 > $1, $1', [%v])
or kill('Query failed: ' . pg_last_error());
$arr = pg_fetch_row($result);
($arr === ['1', 'f', '3']) or kill('Unexpected: ' . print_r($arr, true));

$dbh = new PDO('pgsql:','', null, array(PDO::ATTR_ERRMODE => PDO::ERRMODE_EXCEPTION));
$dbh->exec('CREATE database bank');
$dbh->exec('CREATE table bank.accounts (id INT PRIMARY KEY, balance INT)');
$dbh->exec('INSERT INTO bank.accounts (id, balance) VALUES (1, 1000), (2, 250)');
$dbh->beginTransaction();
$stmt = $dbh->prepare('UPDATE bank.accounts SET balance = balance + :deposit WHERE id=:account');
$stmt->execute(array('account' => 1, 'deposit' => 10));
$stmt->execute(array('account' => 2, 'deposit' => -10));
$dbh->commit();
`
2 changes: 1 addition & 1 deletion acceptance/util_test.go
Expand Up @@ -283,7 +283,7 @@ func testDockerSuccess(t *testing.T, name string, cmd []string) {
}

const (
postgresTestTag = "20160203-140220"
postgresTestTag = "20160401-140120"
)

func testDocker(t *testing.T, name string, cmd []string) error {
Expand Down
75 changes: 54 additions & 21 deletions sql/pgwire/v3.go
Expand Up @@ -491,34 +491,50 @@ func (c *v3Conn) handleBind(buf *readBuffer) error {
if !ok {
return c.sendInternalError(fmt.Sprintf("unknown prepared statement %q", statementName))
}

numParams := int16(len(stmt.inTypes))
paramFormatCodes := make([]formatCode, numParams)

// From the docs on number of parameter format codes to bind:
// This can be zero to indicate that there are no parameters or that the
// parameters all use the default format (text); or one, in which case the
// specified format code is applied to all parameters; or it can equal the
// actual number of parameters.
// http://www.postgresql.org/docs/current/static/protocol-message-formats.html
numParamFormatCodes, err := buf.getInt16()
if err != nil {
return err
}
numParams := len(stmt.inTypes)
if int(numParamFormatCodes) > numParams {
return c.sendInternalError(fmt.Sprintf("too many format codes specified: %d for %d paramaters", numParamFormatCodes, numParams))
}
paramFormatCodes := make([]formatCode, numParams)
for i := range paramFormatCodes[:numParamFormatCodes] {
switch numParamFormatCodes {
case 0:
case 1:
// `1` means read one code and apply it to every param.
c, err := buf.getInt16()
if err != nil {
return err
}
paramFormatCodes[i] = formatCode(c)
}
if numParamFormatCodes == 1 {
fmtCode := paramFormatCodes[0]

fmtCode := formatCode(c)
for i := range paramFormatCodes {
paramFormatCodes[i] = fmtCode
}
case numParams:
// Read one format code for each param and apply it to that param.
for i := range paramFormatCodes {
c, err := buf.getInt16()
if err != nil {
return err
}
paramFormatCodes[i] = formatCode(c)
}
default:
return c.sendInternalError(fmt.Sprintf("wrong number of format codes specified: %d for %d paramaters", numParamFormatCodes, numParams))
}

numValues, err := buf.getInt16()
if err != nil {
return err
}
if int(numValues) != numParams {
if numValues != numParams {
return c.sendInternalError(fmt.Sprintf("expected %d parameters, got %d", numParams, numValues))
}
params := make([]parser.Datum, numParams)
Expand All @@ -542,27 +558,44 @@ func (c *v3Conn) handleBind(buf *readBuffer) error {
params[i] = d
}

numColumns := int16(len(stmt.columns))
columnFormatCodes := make([]formatCode, numColumns)

// From the docs on number of result-column format codes to bind:
// This can be zero to indicate that there are no result columns or that
// the result columns should all use the default format (text); or one, in
// which case the specified format code is applied to all result columns
// (if any); or it can equal the actual number of result columns of the
// query.
// http://www.postgresql.org/docs/current/static/protocol-message-formats.html
numColumnFormatCodes, err := buf.getInt16()
if err != nil {
return err
}
numColumns := len(stmt.columns)
columnFormatCodes := make([]formatCode, numColumns)
for i := range columnFormatCodes[:numColumnFormatCodes] {
switch numColumnFormatCodes {
case 0:
case 1:
// Read one code and apply it to every column.
c, err := buf.getInt16()
if err != nil {
return err
}
columnFormatCodes[i] = formatCode(c)
}
if numColumnFormatCodes == 1 {
fmtCode := columnFormatCodes[0]

fmtCode := formatCode(c)
for i := range columnFormatCodes {
columnFormatCodes[i] = formatCode(fmtCode)
}
case numColumns:
// Read one format code for each column and apply it to that column.
for i := range columnFormatCodes {
c, err := buf.getInt16()
if err != nil {
return err
}
columnFormatCodes[i] = formatCode(c)
}
default:
return c.sendInternalError(fmt.Sprintf("expected 0, 1, or %d for number of format codes, got %d", numColumns, numColumnFormatCodes))
}

stmt.portalNames[portalName] = struct{}{}
c.preparedPortals[portalName] = preparedPortal{
stmt: stmt,
Expand Down