Skip to content

Commit

Permalink
Postgres OID resolution query does not take into account current `sea…
Browse files Browse the repository at this point in the history
…rch_path` (#2133)

* Fix oid resolution query

* Address review comments
  • Loading branch information
95ulisse committed Feb 15, 2023
1 parent 7970dad commit 3e611e1
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 11 deletions.
18 changes: 7 additions & 11 deletions sqlx-postgres/src/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -344,17 +344,13 @@ WHERE rngtypid = $1
}

// language=SQL
let (oid,): (Oid,) = query_as(
"
SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1
",
)
.bind(name)
.fetch_optional(&mut *self)
.await?
.ok_or_else(|| Error::TypeNotFound {
type_name: String::from(name),
})?;
let (oid,): (Oid,) = query_as("SELECT $1::regtype::oid")
.bind(name)
.fetch_optional(&mut *self)
.await?
.ok_or_else(|| Error::TypeNotFound {
type_name: String::from(name),
})?;

self.cache_type_oid.insert(name.to_string().into(), oid);
Ok(oid)
Expand Down
2 changes: 2 additions & 0 deletions sqlx-postgres/src/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ impl Connection for PgConnection {

fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
Box::pin(async move {
self.cache_type_oid.clear();

let mut cleared = 0_usize;

self.wait_until_ready().await?;
Expand Down
63 changes: 63 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1388,6 +1388,69 @@ VALUES
Ok(())
}

#[sqlx_macros::test]
async fn custom_type_resolution_respects_search_path() -> anyhow::Result<()> {
let mut conn = new::<Postgres>().await?;

conn.execute(
r#"
DROP TYPE IF EXISTS some_enum_type;
DROP SCHEMA IF EXISTS another CASCADE;
CREATE SCHEMA another;
CREATE TYPE some_enum_type AS ENUM ('a', 'b', 'c');
CREATE TYPE another.some_enum_type AS ENUM ('d', 'e', 'f');
"#,
)
.await?;

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct SomeEnumType(String);

impl sqlx::Type<Postgres> for SomeEnumType {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("some_enum_type")
}

fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool {
*ty == Self::type_info()
}
}

impl<'r> sqlx::Decode<'r, Postgres> for SomeEnumType {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
Ok(Self(<String as sqlx::Decode<Postgres>>::decode(value)?))
}
}

impl<'q> sqlx::Encode<'q, Postgres> for SomeEnumType {
fn encode_by_ref(
&self,
buf: &mut sqlx::postgres::PgArgumentBuffer,
) -> sqlx::encode::IsNull {
<String as sqlx::Encode<Postgres>>::encode_by_ref(&self.0, buf)
}
}

let mut conn = new::<Postgres>().await?;

sqlx::query("set search_path = 'another'")
.execute(&mut conn)
.await?;

let result = sqlx::query("SELECT 1 WHERE $1::some_enum_type = 'd'::some_enum_type;")
.bind(SomeEnumType("d".into()))
.fetch_all(&mut conn)
.await;

let result = result.unwrap();
assert_eq!(result.len(), 1);

Ok(())
}

#[sqlx_macros::test]
async fn test_pg_server_num() -> anyhow::Result<()> {
let conn = new::<Postgres>().await?;
Expand Down

0 comments on commit 3e611e1

Please sign in to comment.