Skip to content

Commit

Permalink
mysql(chrono): support decoding ZERO_DATES
Browse files Browse the repository at this point in the history
mysql: refactor to a generic MysqlZeroDate enum

mysql: test MysqlZeroDate
  • Loading branch information
blackwolf12333 committed Jun 28, 2020
1 parent e4005bb commit 4ca50b9
Show file tree
Hide file tree
Showing 4 changed files with 223 additions and 6 deletions.
106 changes: 100 additions & 6 deletions sqlx-core/src/mysql/types/chrono.rs
Expand Up @@ -10,6 +10,19 @@ use crate::mysql::protocol::text::ColumnType;
use crate::mysql::type_info::MySqlTypeInfo;
use crate::mysql::{MySql, MySqlValueFormat, MySqlValueRef};
use crate::types::Type;
use crate::mysql::types::MysqlZeroDate;


#[derive(Debug, Clone)]
struct ZeroDateError;

impl std::fmt::Display for ZeroDateError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Unexpected ZERO_DATE encountered!")
}
}

impl std::error::Error for ZeroDateError {}

impl Type<MySql> for DateTime<Utc> {
fn type_info() -> MySqlTypeInfo {
Expand All @@ -29,7 +42,10 @@ impl Encode<'_, MySql> for DateTime<Utc> {

impl<'r> Decode<'r, MySql> for DateTime<Utc> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
let naive: NaiveDateTime = Decode::<MySql>::decode(value)?;
let naive: NaiveDateTime = match Decode::<MySql>::decode(value)? {
MysqlZeroDate::Zero => return Err(Box::new(ZeroDateError)),
MysqlZeroDate::NotZero(date) => date,
};

Ok(DateTime::from_utc(naive, Utc))
}
Expand Down Expand Up @@ -97,12 +113,44 @@ impl<'r> Decode<'r, MySql> for NaiveTime {
}
}

impl Type<MySql> for MysqlZeroDate<NaiveDate> {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Date)
}
}

impl Type<MySql> for NaiveDate {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Date)
}
}

impl Encode<'_, MySql> for MysqlZeroDate<NaiveDateTime> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
match *self {
MysqlZeroDate::Zero => {
buf.push(0);

IsNull::No
},
MysqlZeroDate::NotZero(date) => Encode::<MySql>::encode(date, buf)
}
}
}

impl Encode<'_, MySql> for MysqlZeroDate<NaiveDate> {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
match *self {
MysqlZeroDate::Zero => {
buf.push(0);

IsNull::No
},
MysqlZeroDate::NotZero(date) => Encode::<MySql>::encode(date, buf)
}
}
}

impl Encode<'_, MySql> for NaiveDate {
fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
buf.push(4);
Expand All @@ -118,18 +166,47 @@ impl Encode<'_, MySql> for NaiveDate {
}

impl<'r> Decode<'r, MySql> for NaiveDate {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
let naive: NaiveDate = match Decode::<MySql>::decode(value)? {
MysqlZeroDate::Zero => return Err(Box::new(ZeroDateError)),
MysqlZeroDate::NotZero(date) => date,
};

Ok(naive)
}
}

impl<'r> Decode<'r, MySql> for MysqlZeroDate<NaiveDate> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => Ok(decode_date(&value.as_bytes()?[1..])),
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;
let len = buf[0];
if len == 0 {
Ok(MysqlZeroDate::Zero)
} else {
Ok(MysqlZeroDate::NotZero(decode_date(
&value.as_bytes()?[1..],
)))
}
}

MySqlValueFormat::Text => {
let s = value.as_str()?;
NaiveDate::parse_from_str(s, "%Y-%m-%d").map_err(Into::into)
NaiveDate::parse_from_str(s, "%Y-%m-%d")
.map(MysqlZeroDate::NotZero)
.map_err(Into::into)
}
}
}
}

impl Type<MySql> for MysqlZeroDate<NaiveDateTime> {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Datetime)
}
}

impl Type<MySql> for NaiveDateTime {
fn type_info() -> MySqlTypeInfo {
MySqlTypeInfo::binary(ColumnType::Datetime)
Expand Down Expand Up @@ -171,14 +248,29 @@ impl Encode<'_, MySql> for NaiveDateTime {
}
}
}

impl<'r> Decode<'r, MySql> for NaiveDateTime {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
let naive: NaiveDateTime = match Decode::<MySql>::decode(value)? {
MysqlZeroDate::Zero => return Err(Box::new(ZeroDateError)),
MysqlZeroDate::NotZero(date) => date,
};

Ok(naive)
}
}

impl<'r> Decode<'r, MySql> for MysqlZeroDate<NaiveDateTime> {
fn decode(value: MySqlValueRef<'r>) -> Result<Self, BoxDynError> {
match value.format() {
MySqlValueFormat::Binary => {
let buf = value.as_bytes()?;

let len = buf[0];

if len == 0 {
return Ok(MysqlZeroDate::Zero);
}

let date = decode_date(&buf[1..]);

let dt = if len > 4 {
Expand All @@ -187,12 +279,14 @@ impl<'r> Decode<'r, MySql> for NaiveDateTime {
date.and_hms(0, 0, 0)
};

Ok(dt)
Ok(MysqlZeroDate::NotZero(dt))
}

MySqlValueFormat::Text => {
let s = value.as_str()?;
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f").map_err(Into::into)
NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f")
.map(MysqlZeroDate::NotZero)
.map_err(Into::into)
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions sqlx-core/src/mysql/types/mod.rs
Expand Up @@ -68,6 +68,9 @@ mod float;
mod int;
mod str;
mod uint;
mod mysql_zero_date;

pub use mysql_zero_date::MysqlZeroDate;

#[cfg(feature = "bigdecimal")]
mod bigdecimal;
Expand Down
18 changes: 18 additions & 0 deletions sqlx-core/src/mysql/types/mysql_zero_date.rs
@@ -0,0 +1,18 @@
#[derive(Debug, PartialEq)]
pub enum MysqlZeroDate<T> {
Zero,
NotZero(T)
}

#[derive(Debug, Clone)]
struct ZeroDateError;

impl std::fmt::Display for ZeroDateError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Unexpected ZERO_DATE encountered!")
}
}

impl std::error::Error for ZeroDateError {}


102 changes: 102 additions & 0 deletions tests/mysql/types.rs
Expand Up @@ -42,6 +42,7 @@ test_type!(bytes<Vec<u8>>(MySql,
mod chrono {
use super::*;
use sqlx::types::chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc};
use sqlx::mysql::types::MysqlZeroDate;

test_type!(chrono_date<NaiveDate>(
MySql,
Expand All @@ -67,6 +68,107 @@ mod chrono {
Utc,
)
));

test_type!(chrono_maybe_zero_datetime<MysqlZeroDate<NaiveDateTime>>(
MySql,
"DATE '2019-01-02 05:10:20'" == MysqlZeroDate::NotZero(NaiveDate::from_ymd(2019, 1, 2).and_hms(5, 10, 20))
));

#[sqlx_macros::test]
async fn test_prepared_type_chrono_zero_datetime () -> anyhow::Result<()> {
use sqlx::Row;
use sqlx::Executor;

let mut conn = sqlx_test::new::<MySql>().await?;

// In order for this test to work we need to allow zero dates on the connection
// which is disabled by default
conn.execute("SET @@sql_mode := REPLACE(@@sql_mode, 'NO_ZERO_IN_DATE', '');").await?;
conn.execute("SET @@sql_mode := REPLACE(@@sql_mode, 'NO_ZERO_DATE', '');").await?;

let query = format!("SELECT {0} <=> ?, {0} as _2, ? as _3;", "TIMESTAMP '0000-00-00 00:00:00'");

let row = sqlx::query(&query)
.bind(MysqlZeroDate::<NaiveDateTime>::Zero)
.bind(MysqlZeroDate::<NaiveDateTime>::Zero)
.fetch_one(&mut conn)
.await?;

let matches: i32 = row.try_get(0)?;
let returned: MysqlZeroDate<NaiveDateTime> = row.try_get::<MysqlZeroDate<NaiveDateTime>, _>(1)?;
let round_trip: MysqlZeroDate<NaiveDateTime> = row.try_get::<MysqlZeroDate<NaiveDateTime>, _>(2)?;

assert!(matches != 0,
"[1] DB value mismatch; given value: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
MysqlZeroDate::<NaiveDateTime>::Zero, returned, round_trip);

assert_eq!(MysqlZeroDate::<NaiveDateTime>::Zero, returned,
"[2] DB value mismatch; given value: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
MysqlZeroDate::<NaiveDateTime>::Zero, returned, round_trip);

assert_eq!(MysqlZeroDate::<NaiveDateTime>::Zero, round_trip,
"[3] DB value mismatch; given value: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
MysqlZeroDate::<NaiveDateTime>::Zero, returned, round_trip);

Ok(())
}


test_type!(chrono_maybe_zero_date<MysqlZeroDate<NaiveDate>>(
MySql,
"DATE '2019-01-02'" == MysqlZeroDate::NotZero(NaiveDate::from_ymd(2019, 1, 2))
));

#[sqlx_macros::test]
async fn test_prepared_type_chrono_zero_date () -> anyhow::Result<()> {
use sqlx::Row;
use sqlx::Executor;

let mut conn = sqlx_test::new::<MySql>().await?;

// In order for this test to work we need to allow zero dates on the connection
// which is disabled by default
conn.execute("SET @@sql_mode := REPLACE(@@sql_mode, 'NO_ZERO_IN_DATE', '');").await?;
conn.execute("SET @@sql_mode := REPLACE(@@sql_mode, 'NO_ZERO_DATE', '');").await?;

let query = format!("SELECT {0} <=> ?, {0} as _2, ? as _3;", "DATE '0000-00-00'");

let row = sqlx::query(&query)
.bind(MysqlZeroDate::<NaiveDate>::Zero)
.bind(MysqlZeroDate::<NaiveDate>::Zero)
.fetch_one(&mut conn)
.await?;

let matches: i32 = row.try_get(0)?;
let returned: MysqlZeroDate<NaiveDate> = row.try_get::<MysqlZeroDate<NaiveDate>, _>(1)?;
let round_trip: MysqlZeroDate<NaiveDate> = row.try_get::<MysqlZeroDate<NaiveDate>, _>(2)?;

assert!(matches != 0,
"[1] DB value mismatch; given value: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
MysqlZeroDate::<NaiveDate>::Zero, returned, round_trip);

assert_eq!(MysqlZeroDate::<NaiveDate>::Zero, returned,
"[2] DB value mismatch; given value: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
MysqlZeroDate::<NaiveDate>::Zero, returned, round_trip);

assert_eq!(MysqlZeroDate::<NaiveDate>::Zero, round_trip,
"[3] DB value mismatch; given value: {:?}\n\
as returned: {:?}\n\
round-trip: {:?}",
MysqlZeroDate::<NaiveDate>::Zero, returned, round_trip);

Ok(())
}
}

#[cfg(feature = "time")]
Expand Down

0 comments on commit 4ca50b9

Please sign in to comment.