Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix support for Postgres array of custom types #1483

Merged
merged 1 commit into from
Dec 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion sqlx-core/src/postgres/connection/describe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ use std::sync::Arc;
/// Describes the type of the `pg_type.typtype` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypType {
Base,
Composite,
Expand Down Expand Up @@ -45,6 +46,7 @@ impl TryFrom<u8> for TypType {
/// Describes the type of the `pg_type.typcategory` column
///
/// See <https://www.postgresql.org/docs/13/catalog-pg-type.html#CATALOG-TYPCATEGORY-TABLE>
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum TypCategory {
Array,
Boolean,
Expand Down Expand Up @@ -198,7 +200,9 @@ impl PgConnection {

(Ok(TypType::Base), Ok(TypCategory::Array)) => {
Ok(PgTypeInfo(PgType::Custom(Arc::new(PgCustomType {
kind: PgTypeKind::Array(self.fetch_type_by_oid(element).await?),
kind: PgTypeKind::Array(
self.maybe_fetch_type_info_by_oid(element, true).await?,
),
name: name.into(),
oid,
}))))
Expand Down
120 changes: 120 additions & 0 deletions sqlx-core/src/postgres/type_info.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![allow(dead_code)]

use std::borrow::Cow;
use std::fmt::{self, Display, Formatter};
use std::ops::Deref;
use std::sync::Arc;
Expand Down Expand Up @@ -750,6 +751,125 @@ impl PgType {
}
}
}

/// If `self` is an array type, return the type info for its element.
///
/// This method should only be called on resolved types: calling it on
/// a type that is merely declared (DeclareWithOid/Name) is a bug.
pub(crate) fn try_array_element(&self) -> Option<Cow<'_, PgTypeInfo>> {
// We explicitly match on all the `None` cases to ensure an exhaustive match.
match self {
PgType::Bool => None,
PgType::BoolArray => Some(Cow::Owned(PgTypeInfo(PgType::Bool))),
PgType::Bytea => None,
PgType::ByteaArray => Some(Cow::Owned(PgTypeInfo(PgType::Bytea))),
PgType::Char => None,
PgType::CharArray => Some(Cow::Owned(PgTypeInfo(PgType::Char))),
PgType::Name => None,
PgType::NameArray => Some(Cow::Owned(PgTypeInfo(PgType::Name))),
PgType::Int8 => None,
PgType::Int8Array => Some(Cow::Owned(PgTypeInfo(PgType::Int8))),
PgType::Int2 => None,
PgType::Int2Array => Some(Cow::Owned(PgTypeInfo(PgType::Int2))),
PgType::Int4 => None,
PgType::Int4Array => Some(Cow::Owned(PgTypeInfo(PgType::Int4))),
PgType::Text => None,
PgType::TextArray => Some(Cow::Owned(PgTypeInfo(PgType::Text))),
PgType::Oid => None,
PgType::OidArray => Some(Cow::Owned(PgTypeInfo(PgType::Oid))),
PgType::Json => None,
PgType::JsonArray => Some(Cow::Owned(PgTypeInfo(PgType::Json))),
PgType::Point => None,
PgType::PointArray => Some(Cow::Owned(PgTypeInfo(PgType::Point))),
PgType::Lseg => None,
PgType::LsegArray => Some(Cow::Owned(PgTypeInfo(PgType::Lseg))),
PgType::Path => None,
PgType::PathArray => Some(Cow::Owned(PgTypeInfo(PgType::Path))),
PgType::Box => None,
PgType::BoxArray => Some(Cow::Owned(PgTypeInfo(PgType::Box))),
PgType::Polygon => None,
PgType::PolygonArray => Some(Cow::Owned(PgTypeInfo(PgType::Polygon))),
PgType::Line => None,
PgType::LineArray => Some(Cow::Owned(PgTypeInfo(PgType::Line))),
PgType::Cidr => None,
PgType::CidrArray => Some(Cow::Owned(PgTypeInfo(PgType::Cidr))),
PgType::Float4 => None,
PgType::Float4Array => Some(Cow::Owned(PgTypeInfo(PgType::Float4))),
PgType::Float8 => None,
PgType::Float8Array => Some(Cow::Owned(PgTypeInfo(PgType::Float8))),
PgType::Circle => None,
PgType::CircleArray => Some(Cow::Owned(PgTypeInfo(PgType::Circle))),
PgType::Macaddr8 => None,
PgType::Macaddr8Array => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr8))),
PgType::Money => None,
PgType::MoneyArray => Some(Cow::Owned(PgTypeInfo(PgType::Money))),
PgType::Macaddr => None,
PgType::MacaddrArray => Some(Cow::Owned(PgTypeInfo(PgType::Macaddr))),
PgType::Inet => None,
PgType::InetArray => Some(Cow::Owned(PgTypeInfo(PgType::Inet))),
PgType::Bpchar => None,
PgType::BpcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Bpchar))),
PgType::Varchar => None,
PgType::VarcharArray => Some(Cow::Owned(PgTypeInfo(PgType::Varchar))),
PgType::Date => None,
PgType::DateArray => Some(Cow::Owned(PgTypeInfo(PgType::Date))),
PgType::Time => None,
PgType::TimeArray => Some(Cow::Owned(PgTypeInfo(PgType::Time))),
PgType::Timestamp => None,
PgType::TimestampArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamp))),
PgType::Timestamptz => None,
PgType::TimestamptzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timestamptz))),
PgType::Interval => None,
PgType::IntervalArray => Some(Cow::Owned(PgTypeInfo(PgType::Interval))),
PgType::Timetz => None,
PgType::TimetzArray => Some(Cow::Owned(PgTypeInfo(PgType::Timetz))),
PgType::Bit => None,
PgType::BitArray => Some(Cow::Owned(PgTypeInfo(PgType::Bit))),
PgType::Varbit => None,
PgType::VarbitArray => Some(Cow::Owned(PgTypeInfo(PgType::Varbit))),
PgType::Numeric => None,
PgType::NumericArray => Some(Cow::Owned(PgTypeInfo(PgType::Numeric))),
PgType::Record => None,
PgType::RecordArray => Some(Cow::Owned(PgTypeInfo(PgType::Record))),
PgType::Uuid => None,
PgType::UuidArray => Some(Cow::Owned(PgTypeInfo(PgType::Uuid))),
PgType::Jsonb => None,
PgType::JsonbArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonb))),
PgType::Int4Range => None,
PgType::Int4RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int4Range))),
PgType::NumRange => None,
PgType::NumRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::NumRange))),
PgType::TsRange => None,
PgType::TsRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TsRange))),
PgType::TstzRange => None,
PgType::TstzRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::TstzRange))),
PgType::DateRange => None,
PgType::DateRangeArray => Some(Cow::Owned(PgTypeInfo(PgType::DateRange))),
PgType::Int8Range => None,
PgType::Int8RangeArray => Some(Cow::Owned(PgTypeInfo(PgType::Int8Range))),
PgType::Jsonpath => None,
PgType::JsonpathArray => Some(Cow::Owned(PgTypeInfo(PgType::Jsonpath))),
// There is no `UnknownArray`
PgType::Unknown => None,
// There is no `VoidArray`
PgType::Void => None,
PgType::Custom(ty) => match &ty.kind {
PgTypeKind::Simple => None,
PgTypeKind::Pseudo => None,
PgTypeKind::Domain(_) => None,
PgTypeKind::Composite(_) => None,
PgTypeKind::Array(ref elem_type_info) => Some(Cow::Borrowed(elem_type_info)),
PgTypeKind::Enum(_) => None,
PgTypeKind::Range(_) => None,
},
PgType::DeclareWithOid(oid) => {
unreachable!("(bug) use of unresolved type declaration [oid={}]", oid);
}
PgType::DeclareWithName(name) => {
unreachable!("(bug) use of unresolved type declaration [name={}]", name);
}
}
}
}

impl TypeInfo for PgTypeInfo {
Expand Down
7 changes: 4 additions & 3 deletions sqlx-core/src/postgres/types/array.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use bytes::Buf;
use std::borrow::Cow;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand Down Expand Up @@ -77,7 +78,6 @@ where
T: for<'a> Decode<'a, Postgres> + Type<Postgres>,
{
fn decode(value: PgValueRef<'r>) -> Result<Self, BoxDynError> {
let element_type_info;
let format = value.format();

match format {
Expand Down Expand Up @@ -105,7 +105,8 @@ where

// the OID of the element
let element_type_oid = buf.get_u32();
element_type_info = PgTypeInfo::try_from_oid(element_type_oid)
let element_type_info: PgTypeInfo = PgTypeInfo::try_from_oid(element_type_oid)
.or_else(|| value.type_info.try_array_element().map(Cow::into_owned))
.unwrap_or_else(|| PgTypeInfo(PgType::DeclareWithOid(element_type_oid)));

// length of the array axis
Expand Down Expand Up @@ -133,7 +134,7 @@ where

PgValueFormat::Text => {
// no type is provided from the database for the element
element_type_info = T::type_info();
let element_type_info = T::type_info();

let s = value.as_str()?;

Expand Down
87 changes: 87 additions & 0 deletions tests/postgres/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1094,6 +1094,93 @@ CREATE TABLE heating_bills (
Ok(())
}

#[sqlx_macros::test]
async fn it_resolves_custom_type_in_array() -> anyhow::Result<()> {
// Only supported in Postgres 11+
let mut conn = new::<Postgres>().await?;
if matches!(conn.server_version_num(), Some(version) if version < 110000) {
return Ok(());
}

// language=PostgreSQL
conn.execute(
r#"
DROP TABLE IF EXISTS pets;
DROP TYPE IF EXISTS pet_name_and_race;

CREATE TYPE pet_name_and_race AS (
name TEXT,
race TEXT
);
CREATE TABLE pets (
owner TEXT NOT NULL,
name TEXT NOT NULL,
race TEXT NOT NULL,
PRIMARY KEY (owner, name)
);
INSERT INTO pets(owner, name, race)
VALUES
('Alice', 'Foo', 'cat');
INSERT INTO pets(owner, name, race)
VALUES
('Alice', 'Bar', 'dog');
"#,
)
.await?;

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct PetNameAndRace {
name: String,
race: String,
}

impl sqlx::Type<Postgres> for PetNameAndRace {
fn type_info() -> sqlx::postgres::PgTypeInfo {
sqlx::postgres::PgTypeInfo::with_name("pet_name_and_race")
}
}

impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRace {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
let mut decoder = sqlx::postgres::types::PgRecordDecoder::new(value)?;
let name = decoder.try_decode::<String>()?;
let race = decoder.try_decode::<String>()?;
Ok(Self { name, race })
}
}

#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
struct PetNameAndRaceArray(Vec<PetNameAndRace>);

impl sqlx::Type<Postgres> for PetNameAndRaceArray {
fn type_info() -> sqlx::postgres::PgTypeInfo {
// Array type name is the name of the element type prefixed with `_`
sqlx::postgres::PgTypeInfo::with_name("_pet_name_and_race")
}
}

impl<'r> sqlx::Decode<'r, Postgres> for PetNameAndRaceArray {
fn decode(
value: sqlx::postgres::PgValueRef<'r>,
) -> Result<Self, Box<dyn std::error::Error + 'static + Send + Sync>> {
Ok(Self(Vec::<PetNameAndRace>::decode(value)?))
}
}

let mut conn = new::<Postgres>().await?;

let row = sqlx::query("select owner, array_agg(row(name, race)::pet_name_and_race) as pets from pets group by owner")
.fetch_one(&mut conn)
.await?;

let pets: PetNameAndRaceArray = row.get("pets");

assert_eq!(pets.0.len(), 2);
Ok(())
}

#[sqlx_macros::test]
async fn test_pg_server_num() -> anyhow::Result<()> {
use sqlx::postgres::PgConnectionInfo;
Expand Down