Skip to content

Commit

Permalink
add nullability info to Describe
Browse files Browse the repository at this point in the history
implement nullability check for Postgres as a query on pg_attribute
  • Loading branch information
abonander committed Jan 25, 2020
1 parent f0c88da commit 71d2876
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 16 deletions.
9 changes: 9 additions & 0 deletions sqlx-core/src/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ where
pub name: Option<Box<str>>,
pub table_id: Option<DB::TableId>,
pub type_info: DB::TypeInfo,
/// Whether or not the column may be `NULL` (or if that is even known).
pub nullability: Nullability,
}

impl<DB> Debug for Column<DB>
Expand All @@ -55,3 +57,10 @@ where
.finish()
}
}

#[derive(Debug, PartialEq, Eq)]
pub enum Nullability {
NonNull,
Nullable,
Unknown
}
92 changes: 76 additions & 16 deletions sqlx-core/src/postgres/executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ use std::sync::Arc;
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;

use crate::describe::{Column, Describe};
use crate::postgres::protocol::{self, Encode, Message, StatementId, TypeFormat};
use crate::describe::{Column, Describe, Nullability};
use crate::postgres::protocol::{self, Encode, Message, StatementId, TypeFormat, Field};
use crate::postgres::{PgArguments, PgRow, PgTypeInfo, Postgres};
use crate::row::Row;
use crate::arguments::Arguments;
use futures_util::TryStreamExt;
use crate::encode::IsNull::No;

#[derive(Debug)]
enum Step {
Expand All @@ -31,7 +35,7 @@ impl super::PgConnection {
query,
param_types: &*args.types,
}
.encode(self.stream.buffer_mut());
.encode(self.stream.buffer_mut());

self.statement_cache.put(query.to_owned(), id);

Expand All @@ -53,7 +57,7 @@ impl super::PgConnection {
values: &*args.values,
result_formats: &[TypeFormat::Binary],
}
.encode(self.stream.buffer_mut());
.encode(self.stream.buffer_mut());
}

fn write_execute(&mut self, portal: &str, limit: i32) {
Expand Down Expand Up @@ -280,28 +284,84 @@ impl super::PgConnection {
}
};

while let Some(_) = self.step().await? {}

Ok(Describe {
param_types: params
.ids
.iter()
.map(|id| PgTypeInfo::new(*id))
.collect::<Vec<_>>()
.into_boxed_slice(),
result_columns: result
.map(|r| r.fields)
.unwrap_or_default()
.into_vec()
.into_iter()
// TODO: Should [Column] just wrap [protocol::Field] ?
.map(|field| Column {
name: field.name,
table_id: field.table_id,
type_info: PgTypeInfo::new(field.type_id),
})
.collect::<Vec<_>>()
result_columns: self.map_result_columns(
result
.map(|r| r.fields)
.unwrap_or_default()
).await?
.into_boxed_slice(),
})
}

async fn map_result_columns(&mut self, fields: Box<[Field]>) -> crate::Result<Vec<Column<Postgres>>> {
use crate::describe::Nullability::*;
use std::fmt::Write;

if fields.is_empty() { return Ok(vec![]); }

let mut query = "select col.idx, pg_attribute.attnotnull from (VALUES ".to_string();

let mut pushed = false;

let mut iter = fields
.iter()
.enumerate()
.flat_map(|(i, field)| {
field.table_id.and_then(|table_id| {
// column_id = 0 means not a real column, < 0 means system column
// TODO: how do we want to handle system columns, always non-null?
if field.column_id > 0 { Some((i, table_id, field.column_id)) } else { None }
})
});

let mut args = PgArguments::default();

for ((i, table_id, column_id), bind) in iter.zip((1 ..).step_by(3)) {
if pushed {
query += ", ";
}

pushed = true;
let _ = write!(query, "(${}, ${}, ${})", bind, bind + 1, bind + 2);

args.add(i as i32);
args.add(table_id as i32);
args.add(column_id);
}

let mut columns: Vec<_> = fields.into_vec().into_iter().map(|field| Column {
name: field.name,
table_id: field.table_id,
type_info: PgTypeInfo::new(field.type_id),
nullability: Unknown
}).collect();

query += ") as col(idx, table_id, col_idx) \
inner join pg_catalog.pg_attribute on attrelid = table_id and attnum = col_idx \
order by idx;";

log::trace!("describe pg_attribute query: {:#?}", query);

let mut result = self.fetch(&query, args);

while let Some(row) = result.try_next().await? {
let i = row.get::<i32, _>(0);
let nonnull = row.get::<bool, _>(1);

columns[i as usize].nullability = if nonnull { NonNull } else { Nullable };
}

Ok(columns)
}
}

impl crate::Executor for super::PgConnection {
Expand Down
26 changes: 26 additions & 0 deletions tests/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,32 @@ async fn it_remains_stable_issue_30() -> anyhow::Result<()> {
Ok(())
}

#[cfg_attr(feature = "runtime-async-std", async_std::test)]
#[cfg_attr(feature = "runtime-tokio", tokio::test)]
async fn test_describe_nullability() -> anyhow::Result<()> {
use sqlx::describe::Nullability::*;

let mut conn = connect().await?;

let _ = conn.send(r#"
CREATE TEMP TABLE nullability_test (
id SERIAL primary key,
name text not null,
address text
)
"#).await?;

let describe = conn.describe("select nt.*, ''::text from nullability_test nt")
.await?;

assert_eq!(describe.result_columns[0].nullability, NonNull);
assert_eq!(describe.result_columns[1].nullability, NonNull);
assert_eq!(describe.result_columns[2].nullability, Nullable);
assert_eq!(describe.result_columns[3].nullability, Unknown);

Ok(())
}

async fn connect() -> anyhow::Result<PgConnection> {
let _ = dotenv::dotenv();
let _ = env_logger::try_init();
Expand Down

0 comments on commit 71d2876

Please sign in to comment.