diff --git a/sqlx-core/src/postgres/connection/describe.rs b/sqlx-core/src/postgres/connection/describe.rs index dda7ada4a8..813bcdec2e 100644 --- a/sqlx-core/src/postgres/connection/describe.rs +++ b/sqlx-core/src/postgres/connection/describe.rs @@ -346,7 +346,7 @@ WHERE rngtypid = $1 // language=SQL let (oid,): (Oid,) = query_as( " -SELECT oid FROM pg_catalog.pg_type WHERE typname ILIKE $1 +SELECT $1::regtype::oid ", ) .bind(name) diff --git a/tests/postgres/postgres.rs b/tests/postgres/postgres.rs index d6f9cbac37..73b0ff4a1b 100644 --- a/tests/postgres/postgres.rs +++ b/tests/postgres/postgres.rs @@ -1388,6 +1388,69 @@ VALUES Ok(()) } +#[sqlx_macros::test] +async fn custom_type_resolution_respects_search_path() -> anyhow::Result<()> { + let mut conn = new::().await?; + + conn.execute( + r#" +DROP TYPE IF EXISTS some_enum_type; +DROP SCHEMA IF EXISTS another CASCADE; + +CREATE SCHEMA another; +CREATE TYPE some_enum_type AS ENUM ('a', 'b', 'c'); +CREATE TYPE another.some_enum_type AS ENUM ('d', 'e', 'f'); + "#, + ) + .await?; + + #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] + struct SomeEnumType(String); + + impl sqlx::Type for SomeEnumType { + fn type_info() -> sqlx::postgres::PgTypeInfo { + sqlx::postgres::PgTypeInfo::with_name("some_enum_type") + } + + fn compatible(ty: &sqlx::postgres::PgTypeInfo) -> bool { + *ty == Self::type_info() + } + } + + impl<'r> sqlx::Decode<'r, Postgres> for SomeEnumType { + fn decode( + value: sqlx::postgres::PgValueRef<'r>, + ) -> Result> { + Ok(Self(>::decode(value)?)) + } + } + + impl<'q> sqlx::Encode<'q, Postgres> for SomeEnumType { + fn encode_by_ref( + &self, + buf: &mut sqlx::postgres::PgArgumentBuffer, + ) -> sqlx::encode::IsNull { + >::encode_by_ref(&self.0, buf) + } + } + + let mut conn = new::().await?; + + sqlx::query("set search_path = 'another'") + .execute(&mut conn) + .await?; + + let result = sqlx::query("SELECT 1 WHERE $1::some_enum_type = 'd'::some_enum_type;") + .bind(SomeEnumType("d".into())) + .fetch_all(&mut conn) + .await; + + let result = result.unwrap(); + assert_eq!(result.len(), 1); + + Ok(()) +} + #[sqlx_macros::test] async fn test_pg_server_num() -> anyhow::Result<()> { let conn = new::().await?;