Skip to content

Commit

Permalink
Merge pull request #16 from crystal-lang/feature/unprepared
Browse files Browse the repository at this point in the history
implement TextProtocol for unprepared statements without arguments
  • Loading branch information
bcardiff committed Dec 13, 2016
2 parents 400dc4a + 6191dbf commit 96dbaab
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 14 deletions.
35 changes: 22 additions & 13 deletions spec/driver_spec.cr
Expand Up @@ -4,11 +4,11 @@ def with_db(&block : DB::Database ->)
DB.open "mysql://root@localhost", &block
end

def with_test_db(&block : DB::Database ->)
def with_test_db(options = "", &block : DB::Database ->)
DB.open "mysql://root@localhost" do |db|
db.exec "DROP DATABASE IF EXISTS crystal_mysql_test"
db.exec "CREATE DATABASE crystal_mysql_test"
DB.open "mysql://root@localhost/crystal_mysql_test", &block
DB.open "mysql://root@localhost/crystal_mysql_test?#{options}", &block
db.exec "DROP DATABASE IF EXISTS crystal_mysql_test"
end
end
Expand Down Expand Up @@ -78,9 +78,10 @@ describe Driver do
end

# "SELECT 1" returns a Int64. So this test are not to be used as is on all DB::Any
{% for prepared_statements in [true, false] %}
{% for value in [1_i64, "hello", 1.5] %}
it "executes and select {{value.id}}" do
with_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.scalar("select #{sql({{value}})}").should eq({{value}})

db.query "select #{sql({{value}})}" do |rs|
Expand All @@ -89,6 +90,7 @@ describe Driver do
end
end
{% end %}
{% end %}

it "executes with bind nil" do
with_db do |db|
Expand All @@ -97,15 +99,17 @@ describe Driver do
end

{% for value in [54_i16, 1_i8, 5_i8, 1, 1_i64, "hello", 1.5, 1.5_f32] %}
{% for prepared_statements in [true, false] %}
it "executes and select nil as type of {{value.id}}" do
with_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.scalar("select null").should be_nil

db.query "select null" do |rs|
assert_single_read rs, typeof({{value}} || nil), nil
end
end
end
{% end %}

it "executes with bind {{value.id}}" do
with_db do |db|
Expand Down Expand Up @@ -134,8 +138,9 @@ describe Driver do
end
end

{% for prepared_statements in [true, false] %}
it "executes and selects blob" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table t1 (b1 BLOB)"
db.exec "insert into t1 (b1) values (X'415A617A')"
slice = db.scalar(%(select b1 from t1)).as(Bytes)
Expand All @@ -150,11 +155,12 @@ describe Driver do
{"type" => "LONGBLOB", "size" => 1000000},
].each do |row|
it "set/get " + row["type"].as(String) do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
ary = UInt8[0x41, 0x5A, 0x61, 0x7A] * row["size"].as(Int32)
slice = Bytes.new(ary.to_unsafe, ary.size)
db.exec "create table t1 (b1 " + row["type"].as(String) + ")"
db.exec "insert into t1 (b1) values (?)", slice
# TODO remove when unprepared statements support args
db.prepared.exec "insert into t1 (b1) values (?)", slice
slice = db.scalar(%(select b1 from t1)).as(Bytes)
slice.to_a.should eq(ary)
end
Expand All @@ -168,18 +174,19 @@ describe Driver do
{"type" => "LONGTEXT", "size" => 100000},
].each do |row|
it "set/get " + row["type"].as(String) do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
txt = "Ham Sandwich" * row["size"].as(Int32)
db.exec "create table tab1 (txt1 " + row["type"].as(String) + ")"
db.exec "insert into tab1 (txt1) values (?)", txt
# TODO remove when unprepared statements support args
db.prepared.exec "insert into tab1 (txt1) values (?)", txt
text = db.scalar(%(select txt1 from tab1))
text.should eq(txt)
end
end
end

it "gets column count" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table person (name varchar(25), age integer)"
db.query "select * from person" do |rs|
rs.column_count.should eq(2)
Expand All @@ -188,7 +195,7 @@ describe Driver do
end

it "gets column name" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table person (name varchar(25), age integer)"

db.query "select * from person" do |rs|
Expand All @@ -199,7 +206,7 @@ describe Driver do
end

it "gets last insert row id" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table person (id int not null primary key auto_increment, name varchar(25), age int)"
db.exec %(insert into person (name, age) values ("foo", 10))
res = db.exec %(insert into person (name, age) values ("foo", 10))
Expand All @@ -210,9 +217,10 @@ describe Driver do

{% for value in [54_i16, 1_i8, 5_i8, 1, 1_i64, "hello", 1.5, 1.5_f32] %}
it "insert/get value {{value.id}} from table" do
with_test_db do |db|
with_test_db "prepared_statements=#{{{prepared_statements}}}" do |db|
db.exec "create table table1 (col1 #{mysql_type_for({{value}})})"
db.exec %(insert into table1 (col1) values (#{sql({{value}})}))

db.scalar("select col1 from table1").should eq({{value}})
end
end
Expand All @@ -226,6 +234,7 @@ describe Driver do
end
end
{% end %}
{% end %}

# zero dates http://dev.mysql.com/doc/refman/5.7/en/datetime.html - work on some mysql not others,
# NO_ZERO_IN_DATE enabled as part of strict mode in MySQL 5.7.8. - http://dev.mysql.com/doc/refman/5.7/en/sql-mode.html#sql-mode-changes
Expand Down
6 changes: 5 additions & 1 deletion src/mysql/connection.cr
Expand Up @@ -137,7 +137,11 @@ class MySql::Connection < DB::Connection
end
end

def build_statement(query)
def build_prepared_statement(query)
MySql::Statement.new(self, query)
end

def build_unprepared_statement(query)
MySql::UnpreparedStatement.new(self, query)
end
end
84 changes: 84 additions & 0 deletions src/mysql/text_result_set.cr
@@ -0,0 +1,84 @@
# Implementation of ProtocolText::Resultset.
# Used for unprepared statements.
class MySql::TextResultSet < DB::ResultSet
getter columns

@conn : MySql::Connection
@row_packet : MySql::ReadPacket?
@header : UInt8

def initialize(statement, column_count)
super(statement)
@conn = statement.connection.as(MySql::Connection)

columns = @columns = [] of ColumnSpec
@conn.read_column_definitions(columns, column_count)

@column_index = 0 # next column index to return

@header = 0u8
@eof_reached = false
end

def do_close
super

while move_next
end

if row_packet = @row_packet
row_packet.discard
end
end

def move_next : Bool
return false if @eof_reached

# skip previous row_packet
if row_packet = @row_packet
row_packet.discard
end

@row_packet = row_packet = @conn.build_read_packet

@header = row_packet.read_byte!
if @header == 0xfe # EOF
@eof_reached = true
return false
end

@column_index = 0
# TODO remove row_packet.read(@null_bitmap_slice)
return true
end

def column_count : Int32
@columns.size
end

def column_name(index : Int32) : String
@columns[index].name
end

def read
row_packet = @row_packet.not_nil!

is_nil = @header == 0xfb
col = @column_index
@column_index += 1
if is_nil
nil
else
length = row_packet.read_lenenc_int(@header)
val = row_packet.read_string(length)
val = @columns[col].column_type.parse(val)

# http://dev.mysql.com/doc/internals/en/character-set.html
if val.is_a?(Slice(UInt8)) && @columns[col].character_set != 63
::String.new(val)
else
val
end
end
end
end
30 changes: 30 additions & 0 deletions src/mysql/types.cr
Expand Up @@ -76,6 +76,12 @@ abstract struct MySql::Type
raise "not supported read"
end

# Parse from str a value in TextProtocol format of the type
# specified by self.
def self.parse(str : ::String)
raise "not supported"
end

macro decl_type(name, value, db_any_type = nil)
struct {{name}} < Type
@@hex_value = {{value}}
Expand All @@ -92,6 +98,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_bytes {{db_any_type}}, IO::ByteFormat::LittleEndian
end

def self.parse(str : ::String)
{{db_any_type}}.new(str)
end
{% end %}

{{yield}}
Expand All @@ -110,6 +120,10 @@ abstract struct MySql::Type
def self.read(packet)
nil
end

def self.parse(str : ::String)
nil
end
end
decl_type Timestamp, 0x07u8
decl_type LongLong, 0x08u8, ::Int64
Expand All @@ -135,6 +149,10 @@ abstract struct MySql::Type
ms = packet.read_int.to_i32 / 1000 # returns microseconds, time only supports milliseconds
return ::Time.new(year, month, day, hour, minute, second, ms)
end

def self.parse(str : ::String)
raise "TextProtocol::Time not implemented"
end
end
decl_type Year, 0x0du8
decl_type VarChar, 0x0fu8
Expand All @@ -157,6 +175,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_blob
end

def self.parse(str : ::String)
str.to_slice
end
end
decl_type VarString, 0xfdu8, ::String do
def self.write(packet, v : ::String)
Expand All @@ -166,6 +188,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_lenenc_string
end

def self.parse(str : ::String)
str
end
end
decl_type String, 0xfeu8, ::String do
def self.write(packet, v : ::String)
Expand All @@ -175,6 +201,10 @@ abstract struct MySql::Type
def self.read(packet)
packet.read_lenenc_string
end

def self.parse(str : ::String)
str
end
end
decl_type Geometry, 0xffu8
end
41 changes: 41 additions & 0 deletions src/mysql/unprepared_statement.cr
@@ -0,0 +1,41 @@
class MySql::UnpreparedStatement < DB::Statement
def initialize(connection, @sql : String)
super(connection)
end

protected def conn
@connection.as(Connection)
end

protected def perform_query(args : Enumerable) : DB::ResultSet
perform_exec_or_query(args).as(DB::ResultSet)
end

protected def perform_exec(args : Enumerable) : DB::ExecResult
perform_exec_or_query(args).as(DB::ExecResult)
end

private def perform_exec_or_query(args : Enumerable)
raise "exec/query with args is not supported" if args.size > 0

conn = self.conn
conn.write_packet do |packet|
packet.write_byte 0x03u8
packet << @sql
# TODO to support args an interpolation needs to be done
end

conn.read_packet do |packet|
case header = packet.read_byte.not_nil!
when 255 # err packet
conn.handle_err_packet(packet)
when 0 # ok packet
affected_rows = packet.read_lenenc_int
last_insert_id = packet.read_lenenc_int
DB::ExecResult.new affected_rows, last_insert_id
else
MySql::TextResultSet.new(self, packet.read_lenenc_int(header))
end
end
end
end

0 comments on commit 96dbaab

Please sign in to comment.