Skip to content

Commit

Permalink
feat(mssql): add support for string types including a framework for h…
Browse files Browse the repository at this point in the history
…andling non-UTF8 data
  • Loading branch information
mehcode committed Jun 7, 2020
1 parent 2a272bd commit 28636c1
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 10 deletions.
10 changes: 10 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion sqlx-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ default = [ "runtime-async-std" ]
postgres = [ "md-5", "sha2", "base64", "sha-1", "rand", "hmac", "futures-channel/sink", "futures-util/sink" ]
mysql = [ "sha-1", "sha2", "generic-array", "num-bigint", "base64", "digest", "rand" ]
sqlite = [ "libsqlite3-sys" ]
mssql = [ "uuid" ]
mssql = [ "uuid", "encoding_rs" ]

# types
all-types = [ "chrono", "time", "bigdecimal", "ipnetwork", "json", "uuid" ]
Expand Down Expand Up @@ -48,6 +48,7 @@ crossbeam-queue = "0.2.1"
crossbeam-channel = "0.4.2"
crossbeam-utils = { version = "0.7.2", default-features = false }
digest = { version = "0.8.1", default-features = false, optional = true, features = [ "std" ] }
encoding_rs = { version = "0.8.23", optional = true }
either = "1.5.3"
futures-channel = { version = "0.3.4", default-features = false, features = [ "alloc", "std" ] }
futures-core = { version = "0.3.4", default-features = false }
Expand Down
63 changes: 55 additions & 8 deletions sqlx-core/src/mssql/protocol/type_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@ use std::borrow::Cow;

use bitflags::bitflags;
use bytes::{Buf, Bytes};
use encoding_rs::Encoding;

use crate::encode::Encode;
use crate::error::Error;
use crate::mssql::MsSql;
use url::quirks::set_search;

bitflags! {
pub(crate) struct CollationFlags: u8 {
Expand Down Expand Up @@ -106,6 +106,31 @@ impl TypeInfo {
}
}

/// Returns the character encoding used for values of this type on the wire.
///
/// * `NChar` / `NVarChar` are always UTF-16LE.
/// * `Char` / `VarChar` / `BigChar` / `BigVarChar` depend on the locale of the
///   column collation; only locale `0x0409` (mapped to Windows-1252) is
///   currently supported.
/// * Every other data type falls back to UTF-8.
///
/// # Errors
///
/// Returns a protocol error when a single-byte character type carries no
/// collation, or when the collation's locale has no known encoding.
pub(crate) fn encoding(&self) -> Result<&'static Encoding, Error> {
    match self.ty {
        DataType::NChar | DataType::NVarChar => Ok(encoding_rs::UTF_16LE),

        DataType::VarChar | DataType::Char | DataType::BigChar | DataType::BigVarChar => {
            // The collation is expected to be present for these types; if it is
            // not, report a protocol error instead of panicking, since this
            // value is derived from what the server sent.
            let collation = self.collation.as_ref().ok_or_else(|| {
                err_protocol!("missing collation for character type {:?}", self.ty)
            })?;

            match collation.locale {
                // This is the Western encoding for Windows. It is an extension of
                // ISO-8859-1, which is known as Latin 1.
                0x0409 => Ok(encoding_rs::WINDOWS_1252),

                // Print the locale in hexadecimal so that it matches the `0x`
                // prefix in the message.
                locale => Err(err_protocol!("unsupported locale 0x{:04x}", locale)),
            }
        }

        // default to UTF-8 for anything else coming in here
        _ => Ok(encoding_rs::UTF_8),
    }
}

// reads a TYPE_INFO from the buffer
pub(crate) fn get(buf: &mut Bytes) -> Result<Self, Error> {
let ty = DataType::get(buf)?;
Expand Down Expand Up @@ -445,13 +470,35 @@ impl TypeInfo {
_ => unreachable!("invalid size {} for float"),
}),

DataType::NVarChar => {
s.push_str("nvarchar(");
let _ = itoa::fmt(&mut *s, self.size / 2);
s.push_str(")");
DataType::VarChar
| DataType::NVarChar
| DataType::BigVarChar
| DataType::Char
| DataType::BigChar
| DataType::NChar => {
// name
s.push_str(match self.ty {
DataType::VarChar => "varchar",
DataType::NVarChar => "nvarchar",
DataType::BigVarChar => "bigvarchar",
DataType::Char => "char",
DataType::BigChar => "bigchar",
DataType::NChar => "nchar",

_ => unreachable!(),
});

// size
if self.size < 8000 && self.size > 0 {
s.push_str("(");
let _ = itoa::fmt(&mut *s, self.size);
s.push_str(")");
} else {
s.push_str("(max)");
}
}

_ => unimplemented!("unsupported data type {:?}", self.ty),
_ => unimplemented!("fmt: unsupported data type {:?}", self.ty),
}
}
}
Expand Down Expand Up @@ -511,8 +558,8 @@ impl DataType {

impl Collation {
pub(crate) fn get(buf: &mut Bytes) -> Collation {
let locale_sort_version = buf.get_u32();
let locale = locale_sort_version & 0xF_FFFF;
let locale_sort_version = buf.get_u32_le();
let locale = locale_sort_version & 0xfffff;
let flags = CollationFlags::from_bits_truncate(((locale_sort_version >> 20) & 0xFF) as u8);
let version = (locale_sort_version >> 28) as u8;
let sort = buf.get_u8();
Expand Down
43 changes: 42 additions & 1 deletion sqlx-core/src/mssql/types/str.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use byteorder::{ByteOrder, LittleEndian};
use bytes::Buf;

use crate::database::{Database, HasArguments, HasValueRef};
use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
use crate::error::BoxDynError;
use crate::mssql::io::MsSqlBufMutExt;
use crate::mssql::protocol::type_info::{DataType, TypeInfo};
use crate::mssql::protocol::type_info::{Collation, DataType, TypeInfo};
use crate::mssql::{MsSql, MsSqlTypeInfo, MsSqlValueRef};
use crate::types::Type;

Expand All @@ -15,6 +16,12 @@ impl Type<MsSql> for str {
}
}

/// `String` shares its MSSQL type information with `str`.
impl Type<MsSql> for String {
/// Returns exactly what `<str as Type<MsSql>>::type_info()` returns, so the
/// owned and borrowed string types always agree.
fn type_info() -> MsSqlTypeInfo {
<str as Type<MsSql>>::type_info()
}
}

impl Encode<'_, MsSql> for &'_ str {
fn produces(&self) -> MsSqlTypeInfo {
MsSqlTypeInfo(TypeInfo::new(DataType::NVarChar, (self.len() * 2) as u32))
Expand All @@ -26,3 +33,37 @@ impl Encode<'_, MsSql> for &'_ str {
IsNull::No
}
}

/// Encodes an owned `String` by deferring to the `&str` implementation.
impl Encode<'_, MsSql> for String {
    /// Reports the same type information that the borrowed string slice
    /// would produce.
    fn produces(&self) -> MsSqlTypeInfo {
        let text: &str = self.as_str();

        <&str as Encode<MsSql>>::produces(&text)
    }

    /// Writes the string into `buf` using the `&str` encoder and forwards
    /// its `IsNull` result unchanged.
    fn encode_by_ref(&self, buf: &mut Vec<u8>) -> IsNull {
        let text: &str = self.as_str();

        <&str as Encode<MsSql>>::encode_by_ref(&text, buf)
    }
}

/// Decodes MSSQL character data into an owned `String`.
impl Decode<'_, MsSql> for String {
    /// Accepts the six character data types: `Char`, `VarChar`, `BigChar`,
    /// `BigVarChar`, `NChar` and `NVarChar`.
    fn accepts(ty: &MsSqlTypeInfo) -> bool {
        matches!(
            ty.0.ty,
            DataType::Char
                | DataType::VarChar
                | DataType::BigChar
                | DataType::BigVarChar
                | DataType::NChar
                | DataType::NVarChar
        )
    }

    /// Converts the raw bytes of `value` to UTF-8 using the encoding reported
    /// by the value's type information.
    ///
    /// Fails if the encoding cannot be determined or the value's bytes cannot
    /// be obtained. The decoder's "had errors" flag is discarded, so the
    /// conversion is lossy: per the `encoding_rs` documentation, malformed
    /// sequences are replaced rather than reported.
    fn decode(value: MsSqlValueRef<'_>) -> Result<Self, BoxDynError> {
        // Resolve the encoding first so an unsupported collation is reported
        // before the value bytes are touched.
        let encoding = value.type_info.0.encoding()?;
        let raw = value.as_bytes()?;

        // No BOM sniffing: the encoding is dictated by the column type.
        let (text, _had_errors) = encoding.decode_without_bom_handling(raw);

        Ok(text.into_owned())
    }
}

0 comments on commit 28636c1

Please sign in to comment.