Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

first commit, core features ready.

  • Loading branch information...
commit 803b00dcab2350e30222b47619365d57b2460df6 0 parents
@nightsailer authored
3  .gitignore
@@ -0,0 +1,3 @@
+.DS_Store
+log/**/*
+tmp/**/*
12 lib/resty/mongo.lua
@@ -0,0 +1,12 @@
+module('resty.mongo',package.seeall)
+
+require("resty.mongo.support")
+
+local Connection = require('resty.mongo.connection')
+
+local function new_connection( ... ) return Connection.new(...) end
+
+return {
+ new_connection = new_connection,
+}
+
215 lib/resty/mongo/bson.lua
@@ -0,0 +1,215 @@
+-- bson encoder/decoder
+
+require("resty.mongo.support")
+module(...,package.seeall)
+local assert , error = assert , error
+local pairs = pairs
+local getmetatable = getmetatable
+local type = type
+local tonumber , tostring = tonumber , tostring
+local t_insert = table.insert
+local t_concat = table.concat
+local strformat = string.format
+local strmatch = string.match
+local util = require('resty.mongo.util')
+local le_uint_to_num = util.le_uint_to_num
+local le_int_to_num = util.le_int_to_num
+local num_to_le_uint = util.num_to_le_uint
+local from_double = util.from_double
+local to_double = util.to_double
+local read_terminated_string = util.read_terminated_string
+local new_str_buffer = util.new_str_buffer
+
+local oid = require('resty.mongo.object_id')
+local new_object_id = oid.new
+local object_id_mt = oid.metatable
+
+-- read document from a string buffer
+local function read_document( strbuf , numerical )
+ local bytes = le_uint_to_num( strbuf(4) )
+
+ local ho , hk , hv = false , false , false
+ local t = { }
+ while true do
+ local op = strbuf(1)
+ if op == "\0" then break end
+
+ local e_name = read_terminated_string(strbuf)
+ local v
+ if op == "\1" then -- Double
+ v = from_double(strbuf( 8 ))
+ elseif op == "\2" then -- String
+ local len = le_uint_to_num( strbuf(4) )
+ v = strbuf( len - 1 )
+ assert ( strbuf( 1 ) == "\0" )
+ elseif op == "\3" then -- Embedded document
+ v = read_document(strbuf, false)
+ elseif op == "\4" then -- Array
+ v = read_document(strbuf, true )
+ elseif op == "\5" then -- Binary
+ local len = le_uint_to_num(strbuf(4))
+ local subtype = strbuf( 1 )
+ v = strbuf( len )
+ elseif op == "\7" then -- ObjectId
+ v = new_object_id(strbuf( 12 ) )
+ elseif op == "\8" then -- Boolean
+ local f = strbuf( 1 )
+ if f == "\0" then
+ v = false
+ elseif f == "\1" then
+ v = true
+ else
+ error( f:byte() )
+ end
+ elseif op == "\9" then -- unix time
+ v = le_uint_to_num( strbuf( 8 ) , 1 , 8 )
+ elseif op == "\10" then -- Null
+ v = nil
+ elseif op == "\11" then -- Regullar expression
+ error( "BSON type:'Regullar expression' not support yet")
+ elseif op == "\12" then -- DBPointer — Deprecated
+ error( "BSON type:'DBPointer' not support yet")
+ elseif op == "\13" then -- JavaScript code
+ error( "BSON type:'JavaScript code' not support yet")
+ elseif op == "\14" then -- Symbol
+ error( "BSON type:'Symbol' not support yet")
+ elseif op == "\15" then -- JavaScript code w/ scope
+ error( "BSON type:'JavaScript code w/ scope' not support yet")
+ elseif op == "\16" then -- int32
+ v = le_int_to_num( strbuf(4), 1,8)
+ elseif op == "\17" then --timestamp
+ error( "BSON type:'Timestamp' not support yet")
+ elseif op == "\18" then --int64
+ error( "BSON type:'Int64' not support yet")
+ else
+ error ( "Unknown BSON type" .. strbyte ( op ) )
+ end
+
+ if numerical then
+ t [ tonumber ( e_name ) ] = v
+ else
+ t [ e_name ] = v
+ end
+
+ -- Check for special universal map
+ if e_name == "_keys" then
+ hk = v
+ elseif e_name == "_vals" then
+ hv = v
+ else
+ ho = true
+ end
+ end
+
+ if not ho and hk and hv then
+ t = { }
+ for i=1,#hk do
+ t [ hk [ i ] ] = hv [ i ]
+ end
+ end
+
+ return t
+end
+
+local function from_bson_buf( strbuf )
+ local t = read_document(strbuf , false)
+ return t
+end
+
+
+local function from_bson(str)
+ local t = read_document(new_str_buffer(str), false)
+ return t
+end
+
+local to_bson
+
+local function pack( k , v )
+ local ot = type ( v )
+ local mt = getmetatable ( v )
+
+ if ot == "number" then
+ return "\1" .. k .. "\0" .. to_double ( v )
+ elseif ot == "nil" then
+ return "\10" .. k .. "\0"
+ elseif ot == "string" then
+ return "\2" .. k .. "\0" .. num_to_le_uint ( #v + 1 ) .. v .. "\0"
+ elseif ot == "boolean" then
+ if v == false then
+ return "\8" .. k .. "\0\0"
+ else
+ return "\8" .. k .. "\0\1"
+ end
+ elseif mt == object_id_mt then
+ return "\7" .. k .. "\0" .. v.id
+ elseif ot == "table" then
+ local doc , array = to_bson ( v )
+ if array then
+ return "\4" .. k .. "\0" .. doc
+ else
+ return "\3" .. k .. "\0" .. doc
+ end
+ else
+ error ( "Failure converting " .. ot ..": " .. tostring ( v ) )
+ end
+end
+
+function to_bson( ob )
+ -- Find out if ob if an array; string->value map; or general table
+ local onlyarray = true
+ local seen_n , high_n = { } , 0
+ local onlystring = true
+ for k , v in pairs ( ob ) do
+ local t_k = type ( k )
+ onlystring = onlystring and ( t_k == "string" )
+ if onlyarray then
+ if t_k == "number" and k >= 0 then
+ if k > high_n then
+ high_n = k
+ seen_n [ k ] = v
+ end
+ else
+ onlyarray = false
+ end
+ end
+ if not onlyarray and not onlystring then break end
+ end
+
+ local retarray , m = false
+ if onlystring then -- Do string first so the case of an empty table is done properly
+ local r = { }
+ for k , v in pairs ( ob ) do
+ t_insert ( r , pack ( k , v ) )
+ end
+ m = t_concat ( r )
+ elseif onlyarray then
+ local r = { }
+
+ local low = 1
+ if seen_n [ 0 ] then low = 0 end
+
+ for i=1 , high_n do
+ r [ i ] = pack ( i , seen_n [ i ] )
+ end
+
+ m = t_concat ( r , "" , low , high_n )
+ retarray = true
+ else
+ local ni = 1
+ local keys , vals = { } , { }
+ for k , v in pairs ( ob ) do
+ keys [ ni ] = k
+ vals [ ni ] = v
+ ni = ni + 1
+ end
+ return to_bson ( { _keys = keys , _vals = vals } )
+ end
+
+ return num_to_le_uint ( #m + 4 + 1 ) .. m .. "\0" , retarray
+end
+
+return {
+ from_bson = from_bson,
+ from_bson_buf = from_bson_buf,
+ to_bson = to_bson,
+}
181 lib/resty/mongo/collection.lua
@@ -0,0 +1,181 @@
+-- collection class
+
+module(... , package.seeall)
+
+local t_ordered = require("resty.mongo.orderedtable")
+local Cursor = require("resty.mongo.cursor")
+local protocol = require("resty.mongo.protocol")
+local ZERO32 = protocol.ZERO32
+local ZEROID = protocol.ZEROID
+local NS = protocol.NS
+local mongo_insert_message = protocol.insert_message
+local mongo_delete_message = protocol.delete_message
+local mongo_update_message = protocol.update_message
+local mongo_query_message = protocol.query_message
+local mongo_send_message = protocol.send_message
+local mongo_send_message_with_safe = protocol.send_message_with_safe
+local mongo_recv_message = protocol.recv_message
+local db_ns = protocol.db_ns
+local collection = {}
+local collection_mt = { __index = collection }
+
+----------------------------
+-- attributes
+----------------------------
+collection.name = nil
+collection.ns = nil
+
+-- -------------------------
+-- instance methods
+---------------------------
+
+function collection:find(query,opts)
+ return Cursor.new(self.db,self.name,query,opts)
+end
+
+function collection:find_one(query,opts)
+ opts = opts or {}
+ opts.limit = -1
+ local i,doc = self:find(query,opts):next_doc()
+ return doc
+end
+
+function collection:insert(docs,options)
+ local t = {}
+ local single = #docs < 1
+ if single then
+ docs = { docs }
+ end
+ options = options or {}
+ local continue_err = options.continue_on_err
+ local no_ids = options.no_ids
+ local _,m = mongo_insert_message(self.ns, docs, continue_err,no_ids)
+ -- todo, check message bson size
+ if options.safe then
+ return self:with_safe(m,options)
+ else
+ mongo_send_message(self.conn.sock,m)
+ return true
+ end
+end
+
+function collection:with_safe(m,options)
+ local ok,error = mongo_send_message_with_safe(self.conn.sock,m,self.db.name, {
+ w = options.w or self.conn.w,
+ wtimeout = options.wtimeout or self.conn.wtimeout,
+ j = options.j,
+ fsync = options.fsync,
+ })
+ if not ok then
+ return nil,error
+ end
+ return ok
+end
+
+function collection:update(selector,obj,options)
+ selector = selector or {}
+ options = options or {}
+ local multiple = options.multiple or false
+ local upsert = options.upsert or false
+ local _,m = mongo_update_message(self.ns,selector or {}, obj, upsert, multiple)
+ if options.safe then
+ return self:with_safe(m,options)
+ else
+ mongo_send_message(self.conn.sock,m)
+ return true
+ end
+end
+
+-- function collection:find_and_modify(options) end
+
+function collection:remove(selector,options)
+ selector = selector or {}
+ options = options or {}
+ local _,m = mongo_delete_message(self.ns,selector,options.single_remove)
+ if options.safe then
+ return self:with_safe(m,options)
+ else
+ mongo_send_message(self.conn.sock,m)
+ return true
+ end
+end
+
+function collection:ensure_index(keys,options)
+ assert(keys,"ensure_index:keys is nil")
+ assert(type(keys) == "table","ensure_index:keys must be table")
+ local doc = t_ordered({"ns",self.ns})
+ local _keys = t_ordered():merge(keys)
+ doc.key = _keys
+ if options.name then
+ doc.name = options.name
+ else
+ doc.name = t_concat(_keys,'_')
+ end
+ local _v = {}
+ for i,v in ipairs({"unique", "drop_dups", "background", "sparse"}) do
+ if options[v] ~= nil then
+ doc[v] = options[v] and true or false
+ options[v] = nil
+ end
+ end
+ options.name = nil
+ options.no_ids = true
+ return self.db:get_collection(NS.SYSTEM_INDEX_COLLECTION):insert(doc,options)
+end
+
+function collection:save(doc,options)
+ if doc._id ~= nil then
+ if options == nil or type(options) ~= "table" then
+ options = { upsert = true }
+ else
+ options.upsert = true
+ end
+ return self:update({ _id = doc._id },doc,options)
+ else
+ return self:insert(doc,options)
+ end
+end
+
+function collection:count(query)
+ return Cursor.new(self.db,self.name,query):count()
+ -- local r = self.db:run_command(t_ordered({ "count", self.name, "query", query }))
+ -- if r.ok == 1 then return r.n end
+ -- if r.missing or r.errmsg == "ns missing" then return 0 end
+ -- error("count failed:" .. r.errmsg)
+end
+
+function collection:validate()
+ return self.db:run_command({ validate = self.name })
+end
+
+function collection:drop_indexes() return self:drop_index("*") end
+
+function collection:drop_index(index_name)
+ assert(index_name,"drop_index:index_name is nil")
+ return self.db:run_command(t_ordered("deleteIndexes",self.name, "index",index_name))
+end
+
+function collection:get_indexes()
+ return self.db:get_collection(NS.SYSTEM_INDEX_COLLECTION):find({ns = self.ns}):all()
+end
+
+function collection:drop()
+ return self.db:drop_collection(self.name)
+end
+
+-- -- -------------------------
+-- consturctor
+-- -- -------------------------
+
+local function new(name,db)
+ return setmetatable({
+ name = name,
+ db = db,
+ conn = db.conn,
+ ns = db_ns(db.name,name),
+ },collection_mt)
+end
+
+return {
+ new = new,
+}
113 lib/resty/mongo/connection.lua
@@ -0,0 +1,113 @@
+-- connection class
+
+module(...,package.seeall)
+
+-- import
+
+local util = require("resty.mongo.util")
+local tcp = util.socket.tcp
+local split = util.split
+local bson = require("resty.mongo.bson")
+local substr = string.sub
+local Database = require("resty.mongo.database")
+
+
+-- class body
+
+local connection = {}
+local connection_mt = { __index = connection }
+
+-- -------------------------
+-- attributes
+-- -------------------------
+connection.host = "127.0.0.1"
+connection.port = 27017
+connection.w = 1
+connection.wtimeout = 1000
+connection.auto_connect = true
+connection.user_name = nil
+connection.password = nil
+connection.db_name = 'admin'
+connection.query_timeout = 1000
+connection.max_bson_size = 4*1024*1024
+connection.find_master = false;
+connection.sock = nil
+connection.connected = false
+
+-- dynmatic find master
+connection.hosts = {}
+connection.arbiters = {}
+connection.passives = {}
+
+-- -------------------------
+-- instance methods
+-- -------------------------
+
+function connection:connect(...)
+ local host,port = ...
+ host = host or self.host or "127.0.0.1"
+ port = port or self.port or 27017
+ local sock = self.sock
+ assert(sock:connect(host,port),"connect failed")
+ self.connected = true
+end
+
+function connection:database_names()
+ local r = self:get_database("admin"):run_command({ listDatabases = true })
+ if r.ok == 1 then
+ return r.databases
+ end
+ error("failed to get database names:"..r.errmsg)
+end
+
+--[[ todo
+
+function connection:get_master()
+end
+--]]
+
+
+function connection:get_database(name)
+ return Database.new(name,self)
+end
+
+--[[ todo
+function connection:auth(dbname,user,password,is_digest)
+
+end
+--]]
+
+function connection:get_max_bson_size()
+ local buildinfo = self:get_database("admin"):run_command({buildinfo = true})
+ if buildinfo then
+ return buildinfo.maxBsonObjectSize or 4194304
+ end
+ return 4194304
+end
+
+function connection:init()
+ local h = substr(self.host,10)
+ local _t = split(h,":")
+ local host,port = _t[1],_t[2]
+ self.host = host
+ if port then
+ self.port = port
+ end
+ self.sock = tcp()
+ if self.auto_connect then
+ self:connect(host,port)
+ end
+end
+-----------------------------
+-- consturctor
+-----------------------------
+local function new(option)
+ option = option or {}
+ local obj = setmetatable(option, connection_mt)
+ obj:init()
+ return obj;
+end
+
+return {
+ new = new,
+}
287 lib/resty/mongo/cursor.lua
@@ -0,0 +1,287 @@
+-- cursor class
+
+module(...,package.seeall)
+
+local t_ordered = require("resty.mongo.orderedtable")
+local t_insert = table.insert
+local t_remove = table.remove
+local t_concat = table.concat
+
+local strbyte = string.byte
+local strformat = string.format
+
+local protocol = require("resty.mongo.protocol")
+local ZERO32 = protocol.ZERO32
+local ZEROID = protocol.ZEROID
+local mongo_get_more_message = protocol.get_more_message
+local mongo_query_message = protocol.query_message
+local mongo_send_message = protocol.send_message
+local mongo_recv_message = protocol.recv_message
+local mongo_killcusors_message = protocol.kill_cursors_message
+local db_ns = protocol.db_ns
+
+local cursor = { }
+local cursor_mt = { __index = cursor }
+
+cursor_mt.__gc = function(self)
+ self:kill_cursor()
+end
+
+cursor_mt.__tostring = function(ob)
+ local t = { }
+ for i = 1 , 8 do
+ t_insert( t , strformat ( "%02x" , strbyte ( ob.id , i , i ) ) )
+ end
+ return "CursorId(" .. t_concat ( t ) .. ")"
+end
+
+-- -------------------------
+-- attributes
+-- -------------------------
+cursor.slave_ok = false
+cursor.tailable = false
+cursor.immortal = false
+cursor.batch_size = 10
+
+cursor._limit = 0
+cursor._skip = 0
+cursor._fields = false
+
+cursor.ns = nil
+cursor.query_run = false
+cursor.closed = false
+cursor.id = false
+cursor.at = 0
+cursor.req_id = 0
+cursor.number_received = 0
+
+-- private
+
+cursor._result_cache = {}
+cursor._special = nil
+
+local function assert_state(self)
+ assert(not self.query_run and not self.closed, "Cannot modify the query once it has been run or closed.")
+end
+
+local function close_cursor_if_query_complete(self)
+ local limit = self._limit
+ if limit > 0 and self.number_received >= limit then
+ self:close()
+ end
+end
+
+local function send_init_query(self)
+ -- print(">> send_init_query")
+ if self.query_run then
+ -- print("<< send_init_query run again")
+ return false
+ end
+ local opts = {
+ tailable = self.taiable,
+ slave_ok = self.slave_ok,
+ immortal = self.immortal,
+ partial = self.partial,
+ }
+ local sock = self.sock
+ local query = self._special or self.query or {}
+ -- special query process
+ if self._special then
+ query.query = self.query
+ end
+ local req_id,message = mongo_query_message(self.ns,query,self._fields,self._limit,self._skip,opts)
+ mongo_send_message(sock,message)
+ local number_received = 0
+ self.id,number_received,err,self._result_cache = mongo_recv_message(sock,req_id)
+ self.query_run = true
+ if self.id == ZEROID then
+ self.id = false
+ end
+ self.number_received = self.number_received + number_received
+ close_cursor_if_query_complete(self)
+ -- print("<< send_init_query")
+ return true
+end
+
+local function refill_via_getmore(self)
+ -- print(">> refill_via_getmore")
+ if not self.id then
+ return
+ end
+ -- print("id:",self.id,"batch_size:",self.batch_size)
+ local req_id,message = mongo_get_more_message(self.ns,self.id,self.batch_size)
+ local sock = self.sock
+ mongo_send_message(sock,message)
+ local id, number_received, err,docs = mongo_recv_message(sock,req_id)
+ -- empty result
+ if id == ZEROID then
+ self.id = false
+ end
+ if err.CURSOR_NOT_FOUND then
+ -- print("<< refill_via_getmore ERR")
+ self.id = false
+ self.last_error_msg = "cursor not found"
+ self:close()
+ end
+ if err.QUERY_FAILURE then
+ self.id = false
+ self.last_error_msg = "query failue"
+ self:close()
+ end
+ self.last_error = err
+ if number_received == 0 then
+ -- print("<< refill_via_getmore number_received == 0")
+ self:close()
+ return
+ end
+ local result = self._result_cache
+ for i,v in ipairs(docs) do
+ t_insert(result,v)
+ end
+ self.number_received = self.number_received + number_received
+ close_cursor_if_query_complete(self)
+ -- print("<< refill_via_getmore ***")
+end
+
+local function _add_special_query(self,k,v)
+ local special = self._special or {}
+ special[k] = v
+ self._special = special
+end
+-- -------------------------
+-- instance methods
+-- -------------------------
+
+function cursor:fields(fields)
+ assert_state(self)
+ self._fields = fields
+ return self
+end
+
+function cursor:sort(order)
+ assert_state(self)
+ _add_special_query(self,'orderby',order)
+ return self
+end
+
+function cursor:limit(size)
+ assert_state(self)
+ self._limit = size
+ return self
+end
+
+function cursor:skip(size)
+ assert_state(self)
+ self._skip = size >= 0 and size or 0
+ return self
+end
+
+function cursor:snapshot()
+ assert_state(self)
+ _add_special_query(self,"$snapshot",true)
+ return self
+end
+
+function cursor:count(include_all)
+ local cmd = t_ordered({ "count", self.name, "query", self.query })
+ if include_all then
+ if self._limit ~= 0 then
+ cmd.limit = self._limit
+ end
+ if self._skip ~= 0 then
+ cmd.skip = self._skip
+ end
+ end
+ local r = self.db:run_command(cmd)
+ if r.ok == 1 then return r.n end
+ if r.errmsg == "ns missing" then return 0 end
+ error("count failed:" .. r.errmsg)
+end
+
+function cursor:hint(index)
+ assert_state(self)
+ _add_special_query(self,"$hint",index)
+ return self
+end
+
+function cursor:reset()
+ self.id = false
+ self.closed = false
+ self.at = 0
+ self.query_run = false
+ self._result_cache = {}
+end
+
+function cursor:all()
+ local r = {}
+ for i,v in self:next() do
+ r[i] = v
+ end
+ return r
+end
+
+function cursor:next_doc()
+ -- first
+ if not self.query_run then
+ send_init_query(self)
+ end
+ if self.id and #self._result_cache == 0 and (self._limit <= 0 or self.at < self._limit ) then
+ refill_via_getmore(self)
+ end
+ local v = t_remove(self._result_cache,1)
+ if v ~= nil then
+ self.at = self.at+1
+ return self.at,v
+ end
+ return nil
+end
+
+-- interator
+
+function cursor:next()
+ return self.next_doc,self
+end
+
+function cursor:close( )
+ self:kill_cursor()
+ self.closed = true
+end
+
+function cursor:kill_cursor()
+ local id = self.id
+ if id then
+ local m = mongo_killcusors_message({ id })
+ mongo_send_message(self.conn.sock, m )
+ self.id = false
+ end
+end
+-----------------------------
+-- consturctor
+-----------------------------
+
+local function new(db, name, query , opts)
+ local c = {}
+ c.db = db
+ c.sock = db.conn.sock
+ c.name = name
+ c.ns = db_ns(db.name,name)
+ c.query = query or {}
+ local _limit, _skip, _snapshot, _fields,_sort_by = c.limit,c.skip,c.snapshot,c.fields,c.sort_by
+ c._limit = _limit or 0
+ c._skip = _skip or 0
+ setmetatable(c,cursor_mt)
+ if _snapshot then
+ c:snapshot()
+ end
+ if _fields then
+ c:fields(_fields)
+ end
+ if _sort_by then
+ c:sort(_sort_by)
+ end
+ return c
+end
+
+return {
+ new = new,
+}
85 lib/resty/mongo/database.lua
@@ -0,0 +1,85 @@
+-- database class
+
+module(..., package.seeall)
+
+local Collection = require("resty.mongo.collection")
+local Cursor = require("resty.mongo.cursor")
+local protocol = require("resty.mongo.protocol")
+local NS = protocol.NS
+local t_ordered = require("resty.mongo.orderedtable")
+
+local database = {}
+local database_mt = { __index = database }
+
+-- -------------------------
+-- attributes
+-- -------------------------
+database.name = nil
+database.conn = nil
+
+-- -------------------------
+-- instance methods
+-- -------------------------
+
+function database:collection_names() end
+
+function database:get_collection(name)
+ return Collection.new(name,self)
+end
+
+--[[ todo
+
+function database:get_gridfs(prefix)
+end
+--]]
+
+function database:drop()
+ return self:run_command({ dropDatabase = true })
+end
+
+function database:drop_collection( name )
+ local ok = self:run_command({ drop = name })
+ return ok.ok == 1 or ok.ok == true
+end
+
+function database:get_last_error(options)
+ options = options or {}
+ local w = options.w or self.conn.w
+ local wtimeout = options.wtimeout or self.conn.wtimeout
+ local cmd = t_ordered({"getlasterror",true, "w",w,"wtimeout",wtimeout})
+ if options.fsync then cmd.fsync = true end
+ if options.j then cmd.j = true end
+ return self:run_command(cmd)
+end
+
+function database:run_command(cmd)
+ local cursor = Cursor.new(self, NS.SYSTEM_COMMAND_COLLECTION,cmd)
+ local result = cursor:limit(-1):all()
+ if not result[1] then
+ -- raise error?
+ -- return nil,cursor.last_error_msg
+ return { ok = 0, errmsg = cursor.last_error_msg }
+ end
+ return result[1]
+end
+
+--[[ todo
+
+function database:eval(code,args)
+end
+--]]
+
+-----------------------------
+-- consturctor
+-----------------------------
+
+local function new(name,conn)
+ assert(name,"Database name not provide")
+ assert(conn,"Connection is nil")
+ local obj = { name = name, conn = conn }
+ return setmetatable(obj, database_mt)
+end
+
+return {
+ new = new,
+}
54 lib/resty/mongo/object_id.lua
@@ -0,0 +1,54 @@
+-- object id class
+
+module(...,package.seeall)
+
+local setmetatable = setmetatable
+local strbyte = string.byte
+local strformat = string.format
+local t_insert = table.insert
+local t_concat = table.concat
+
+local md5 = require "md5"
+local util = require "resty.mongo.util"
+local num_to_le_uint = util.num_to_le_uint
+local num_to_be_uint = util.num_to_be_uint
+
+local oid_mt = {
+ __tostring = function( ob )
+ local t = { }
+ for i = 1 , 12 do
+ t_insert( t , strformat ( "%02x" , strbyte( ob.id , i , i ) ) )
+ end
+ return "ObjectId(" .. t_concat ( t ) .. ")"
+ end,
+ __eq = function( a , b ) return a.id == b.id end,
+}
+
+local machineid = md5.sum( util.machineid() ):sub(1,3)
+
+local pid = util.getpid() % 0xffff
+pid = num_to_le_uint(pid,2)
+
+local inc = 0
+
+local function generate_id()
+ inc = inc + 1
+ -- "A BSON ObjectID is a 12-byte value consisting of a 4-byte timestamp (seconds since epoch),
+ -- a 3-byte machine id, a 2-byte process id, and a 3-byte counter.
+ -- Note that the timestamp and counter fields must be stored big endian unlike the rest of BSON"
+ return num_to_be_uint( util.time(), 4 ) .. machineid .. pid .. num_to_be_uint( inc , 3 )
+end
+
+local function new_object_id( str )
+ if str then
+ assert( #str == 12 )
+ else
+ str = generate_id()
+ end
+ return setmetatable( { id = str } , oid_mt )
+end
+
+return {
+ new = new_object_id,
+ metatable = oid_mt,
+}
59 lib/resty/mongo/orderedtable.lua
@@ -0,0 +1,59 @@
+require("resty.mongo.support")
+
+local setmetatable = setmetatable
+local t_insert = table.insert
+local rawget,rawset = rawget,rawset
+
+local ordered_mt = {}
+ordered_mt.__newindex = function(t,key,value)
+ local _keys = t._keys
+ if rawget(t,key) == nil then
+ t_insert(_keys,key)
+ rawset(t,key,value)
+ else
+ rawset(t,key,value)
+ end
+end
+
+function ordered_mt.__pairs(t)
+ local _i = 1
+ local _next = function(t,k)
+ local _k = rawget(t._keys,_i)
+ if _k == nil then
+ _i = 1
+ return nil
+ end
+ local _v = rawget(t,_k)
+ if _v == nil then
+ _i = 1
+ return nil
+ end
+ _i = _i + 1
+ return _k,_v
+ end
+ return _next,t
+end
+
+local function merge(self,t)
+ if type(t) ~= "table" then
+ return
+ end
+ for k,v in pairs(t) do
+ self[k] = v
+ end
+ return self
+end
+
+local function ordered_table(a)
+ local t = { _keys = {} }
+ t.merge = merge
+ setmetatable(t,ordered_mt)
+ if a then
+ for i=1,#a,2 do
+ t[a[i]] = a[i+1]
+ end
+ end
+ return t, ordered_mt
+end
+
+return ordered_table
414 lib/resty/mongo/protocol.lua
@@ -0,0 +1,414 @@
+--------------------------------------------------------
+-- MongoDB Lua driver for OpenResty
+--------------------------------------------------------
+-- License: MIT
+-- Copyright(c) 2012 Pan Fan( Night Sailer)
+--------------------------------------------------------
+-- Mongo Wire Protocol
+--
+-- There are two types of messages, client requests and database responses, each having a slightly different structure.
+--
+-- Client Request Messages
+--
+-- Standard Message Header
+--
+-- In general, each message consists of a standard message header followed by request-specific data.
+-- The standard message header is structured as follows :
+--
+-- struct MsgHeader {
+-- int32 messageLength; // total message size, including this
+-- int32 requestID; // identifier for this message
+-- int32 responseTo; // requestID from the original request
+-- // (used in reponses from db)
+-- int32 opCode; // request type - see table below
+-- }
+--
+-- OP_UPDATE
+--
+-- struct OP_UPDATE {
+-- MsgHeader header; // standard message header
+-- int32 ZERO; // 0 - reserved for future use
+-- cstring fullCollectionName; // "dbname.collectionname"
+-- int32 flags; // bit vector. see below
+-- document selector; // the query to select the document
+-- document update; // specification of the update to perform
+-- }
+--
+-- OP_INSERT
+--
+-- The OP_INSERT message is used to insert one or more documents into a collection.
+-- The format of the OP_INSERT message is
+--
+-- struct {
+-- MsgHeader header; // standard message header
+-- int32 ZERO; // 0 - reserved for future use
+-- cstring fullCollectionName; // "dbname.collectionname"
+-- document* documents; // one or more documents to insert into the collection
+-- }
+--
+--
+-- OP_QUERY
+--
+-- The OP_QUERY message is used to query the database for documents in a collection.
+-- The format of the OP_QUERY message is :
+--
+-- struct OP_QUERY {
+-- MsgHeader header; // standard message header
+-- int32 flags; // bit vector of query options. See below for details.
+-- cstring fullCollectionName; // "dbname.collectionname"
+-- int32 numberToSkip; // number of documents to skip
+-- int32 numberToReturn; // number of documents to return
+-- // in the first OP_REPLY batch
+-- document query; // query object. See below for details.
+-- [ document returnFieldSelector; ] // Optional. Selector indicating the fields
+-- // to return. See below for details.
+-- }
+--
+--
+-- OP_GETMORE
+--
+-- The OP_GETMORE message is used to query the database for documents in a collection.
+-- The format of the OP_GETMORE message is :
+--
+-- struct {
+-- MsgHeader header; // standard message header
+-- int32 ZERO; // 0 - reserved for future use
+-- cstring fullCollectionName; // "dbname.collectionname"
+-- int32 numberToReturn; // number of documents to return
+-- int64 cursorID; // cursorID from the OP_REPLY
+-- }
+--
+-- OP_DELETE
+--
+-- The OP_DELETE message is used to remove one or more messages from a collection.
+-- The format of the OP_DELETE message is :
+--
+-- struct {
+-- MsgHeader header; // standard message header
+-- int32 ZERO; // 0 - reserved for future use
+-- cstring fullCollectionName; // "dbname.collectionname"
+-- int32 flags; // bit vector - see below for details.
+-- document selector; // query object. See below for details.
+-- }
+--
+-- OP_KILL_CURSORS
+--
+-- The OP_KILL_CURSORS message is used to close an active cursor in the database. This is necessary to ensure
+-- that database resources are reclaimed at the end of the query. The format of the OP_KILL_CURSORS message is :
+--
+-- struct {
+-- MsgHeader header; // standard message header
+-- int32 ZERO; // 0 - reserved for future use
+-- int32 numberOfCursorIDs; // number of cursorIDs in message
+-- int64* cursorIDs; // sequence of cursorIDs to close
+-- }
+--
+--
+-- Database Response Messages
+--
+-- OP_REPLY
+--
+-- The OP_REPLY message is sent by the database in response to an
+-- OP_QUERY or OP_GET_MORE
+-- message. The format of an OP_REPLY message is:
+--
+-- struct {
+-- MsgHeader header; // standard message header
+-- int32 responseFlags; // bit vector - see details below
+-- int64 cursorID; // cursor id if client needs to do get more's
+-- int32 startingFrom; // where in the cursor this reply is starting
+-- int32 numberReturned; // number of documents in the reply
+-- document* documents; // documents
+-- }
+---------------------------------------------------------------
+-- More detail about Mongo Wire Protocol, please visit:
+-- http://www.mongodb.org/display/DOCS/Mongo+Wire+Protocol
+---------------------------------------------------------------
+
+-- Mongo Wire Protocol support functions
+
+module(...,package.seeall)
+
+local bson = require('resty.mongo.bson')
+local to_bson,from_bson,from_bson_buf = bson.to_bson,bson.from_bson,bson.from_bson_buf
+local t_concat,t_insert = table.concat, table.insert
+
+local util = require('resty.mongo.util')
+local num_to_le_uint,num_to_le_int = util.num_to_le_uint,util.num_to_le_int
+local new_str_buffer = util.new_str_buffer
+local le_bpeek = util.le_bpeek
+local slice_le_uint, extract_flag_bits = util.slice_le_uint,util.extract_flag_bits
+
+local oid = require("resty.mongo.object_id")
+local t_ordered = require("resty.mongo.orderedtable")
+
+-- reserved collection namespace
+
+local ns = {
+ SYSTEM_NAMESPACE_COLLECTION = "system.namespaces",
+ SYSTEM_INDEX_COLLECTION = "system.indexes",
+ SYSTEM_PROFILE_COLLECTION = "system.profile",
+ SYSTEM_USER_COLLECTION = "system.users",
+ SYSTEM_JS_COLLECTION = "system.js",
+ SYSTEM_COMMAND_COLLECTION = "$cmd",
+}
+
+
+-- opcodes
+
+local op_codes = {
+ OP_REPLY = 1,
+ OP_MSG = 1000,
+ OP_UPDATE = 2001,
+ OP_INSERT = 2002,
+ RESERVED = 2003,
+ OP_QUERY = 2004,
+ OP_GETMORE = 2005,
+ OP_DELETE = 2006,
+ OP_KILL_CURSORS = 2007,
+}
+
+-- message header size
+
+local STANDARD_HEADER_SIZE = 16
+local RESPONSE_HEADER_SIZE = 20
+
+-- place holder
+
+local ZERO32 = "\0\0\0\0"
+local ZEROID = "\0\0\0\0\0\0\0\0"
+
+-- flag bit constant
+
+local flags = {
+ -- used in update message
+ update = {
+ -- If set, the database will insert the supplied object into the collection if no matching document is found.
+ Upsert = 1,
+ -- If set, the database will update all matching objects in the collection.
+ -- Otherwise only updates first matching doc.
+ MultiUpdate = 2,
+ -- 2-31 reserved
+ },
+ -- used in insert message
+ insert = {
+ -- If set, the database will not stop processing a bulk insert if one fails (eg due to duplicate IDs).
+ -- This makes bulk insert behave similarly to a series of single inserts, except lastError will be set if any insert fails, not just the last one.
+ -- If multiple errors occur, only the most recent will be reported by getLastError. (new in 1.9.1)
+ ContinueOnError = 1,
+ },
+ -- used in query message
+ query = {
+ -- Tailable means cursor is not closed when the last data is retrieved.
+ -- Rather, the cursor marks the final object's position.
+ -- You can resume using the cursor later, from where it was located, if more data were received.
+ -- Like any "latent cursor", the cursor may become invalid at some point (CursorNotFound)
+ -- – for example if the final object it references were deleted.
+ TailableCursor = 2,
+ -- Allow query of replica slave. Normally these return an error except for namespace "local".
+ SlaveOk = 4,
+ -- Internal replication use only - driver should not set
+ OplogReplay = 8,
+ -- The server normally times out idle cursors after an inactivity period (10 minutes) to prevent excess memory use.
+ -- Set this option to prevent that.
+ NoCursorTimeout = 16,
+ -- Use with TailableCursor. If we are at the end of the data, block for a while rather than returning no data.
+ -- After a timeout period, we do return as normal.
+ AwaitData = 32,
+ -- Stream the data down full blast in multiple "more" packages, on the assumption that the client will fully read all data queried.
+ -- Faster when you are pulling a lot of data and know you want to pull it all down.
+ -- Note: the client is not allowed to not read all the data unless it closes the connection.
+ Exhaust = 64,
+ -- Get partial results from a mongos if some shards are down (instead of throwing an error)
+ Partial = 128,
+ },
+ -- used in delete message
+ delete = {
+ SingleRemove = 1,
+ },
+ -- used in reponse message
+ reply = {
+ -- CursorNotFound: Set when getMore is called but the cursor id is not valid at the server.
+ -- Returned with zero results.
+ REPLY_CURSOR_NOT_FOUND = 1,
+ -- QueryFailure: Set when query failed. Results consist of one document containing an "$err" field describing the failure.
+ REPLY_QUERY_FAILURE = 2,
+ -- ShardConfigStale: Drivers should ignore this. Only mongos will ever see this set, in which case,
+ -- it needs to update config from the server.
+ REPLY_SHARD_CONFIG_STALE = 4,
+ -- AwaitCapable: Set when the server supports the AwaitData Query option.
+ -- If it doesn't, a client should sleep a little between getMore's of a Tailable cursor.
+ -- Mongod version 1.6 supports AwaitData and thus always sets AwaitCapable.
+ REPLY_AWAIT_CAPABLE = 8,
+ -- Reserved 4-31
+ },
+}
+
+local ERR = {
+ CURSOR_NOT_FOUND = 1,
+ QUERY_FAILURE = 2,
+}
+
+local current_request_id = 0;
+
+local function with_header(opcode,message,response_to)
+ current_request_id = current_request_id+1
+ local request_id = num_to_le_uint(current_request_id)
+ response_to = response_to or ZERO32
+ opcode = num_to_le_uint(assert(op_codes[opcode]))
+ -- header(length,request_id,response_to,opcode) + message
+ -- print("message size",#message+STANDARD_HEADER_SIZE)
+ return current_request_id, num_to_le_uint (#message + STANDARD_HEADER_SIZE)
+ .. request_id .. response_to .. opcode .. message
+end
+
+local function query_message(full_collection_name,query,fields,limit,skip,options)
+ skip = skip or 0
+ local flag = 0
+ if options then
+ flag = (options.tailable and flags.query.TailableCursor or 0)
+ + (options.slave_ok and flags.query.SlaveOk or 0 )
+ + (options.oplog_replay and flags.query.OplogReplay or 0)
+ + (options.immortal and flags.query.NoCursorTimeout or 0)
+ + (options.await_data and flags.query.AwaitData or 0)
+ + (options.exhaust and flags.query.Exhaust or 0)
+ + (options.partial and flags.query.Partial or 0)
+ end
+ query = to_bson(query)
+ if fields then
+ fields = to_bson(fields)
+ else
+ fields = ""
+ end
+ return with_header("OP_QUERY",
+ num_to_le_uint(flag) .. full_collection_name .. num_to_le_uint(skip) .. num_to_le_int(limit)
+ .. query .. fields
+ )
+end
+
+local function get_more_message(full_collection_name, cursor_id, limit)
+ return with_header("OP_GETMORE", ZERO32 .. full_collection_name .. num_to_le_int(limit or 0) .. cursor_id )
+end
+
+local function delete_message(full_collection_name,selector,singleremove)
+ local flags = (singleremove and flags.delete.SingleRemove or 0)
+ selector = to_bson(selector)
+ return with_header('OP_DELETE', ZERO32 .. full_collection_name .. num_to_le_uint(flags) .. selector)
+end
+
+local function update_message(full_collection_name,selector,update,upsert,multiupdate)
+ local flags = (upsert and flags.update.Upsert or 0) + ( multiupdate and flags.update.MultiUpdate or 0)
+ selector = to_bson(selector)
+ update = to_bson(update)
+ return with_header('OP_UPDATE',ZERO32 .. full_collection_name .. num_to_le_uint(flags) .. selector .. update)
+end
+
+local function insert_message(full_collection_name,docs,continue_on_error,no_ids)
+ local flags = ( continue_on_error and flags.insert.ContinueOnError or 0 )
+ local r = {}
+ -- local oids = {}
+ for i,v in ipairs(docs) do
+ local _id = v._id
+ if not _id and not no_ids then
+ _id = oid.new()
+ v._id = _id
+ end
+ r[i] = to_bson(v)
+ end
+ return with_header("OP_INSERT", num_to_le_uint(flags) .. full_collection_name .. t_concat(r))
+end
+
+local function kill_cursors_message(cursor_id)
+ local n = #cursor_id
+ cursor_id = t_concat(cursor_id)
+ return with_header('OP_KILL_CURSORS',ZERO32 .. num_to_le_uint(n) .. cursor_id )
+end
+
+local function recv_message(sock, request_id)
+ -- print("recv_message,reqid",request_id)
+ -- msg header
+ local header = assert(sock:receive(STANDARD_HEADER_SIZE))
+ local msg_length,req_id,response_to,opcode = slice_le_uint(header,4)
+ -- print("msg_length:",msg_length,"req_id",req_id,"response_to",response_to,"opcode",opcode)
+ assert(request_id == response_to, "response_to:".. response_to .. " should:" .. request_id)
+ assert(opcode == op_codes.OP_REPLY,"invalid response opcode")
+ -- read message data
+ local msg_data = assert(sock:receive(msg_length-STANDARD_HEADER_SIZE))
+ local msg_buf = new_str_buffer(msg_data)
+ -- response header,20 bytes
+ local response_flags,cursor_id = msg_buf(4), msg_buf(8)
+ local starting_from,number_returned = slice_le_uint(msg_buf(8),2)
+ local err = {}
+ -- parse reponse flags
+ local cursor_not_found,query_failure,shard_config_stale,await_capable = extract_flag_bits(response_flags,4)
+
+ -- print('cursor_id:',cursor_id,"starting_from:",starting_from,"number_returned:",number_returned,"cursor_not_found:",
+ -- cursor_not_found,"query_failure:",query_failure)
+
+ -- todo: validate flags?
+ -- assert(not cursor_not_found,'cursor not found')
+ if cursor_not_found then
+ -- print("ERR:cursor_not_found")
+ err.CURSOR_NOT_FOUND = true
+ end
+ if query_failure then
+ -- print("ERR:query_failure")
+ err.QUERY_FAILURE = true
+ end
+ -- print("number_returned:"..number_returned)
+ -- client should ignore this flag
+ -- assert(not shard_config_stale,'shard confi is stale')
+ local docs = {}
+ -- documents
+ if not cursor_not_found then
+ for i=1,number_returned do
+ docs[i] = from_bson_buf(msg_buf)
+ end
+ end
+ return cursor_id,number_returned,err,docs
+end
+
+local function db_ns(db,name )
+ return db .. "." .. name .."\0"
+end
+
+local function send_message( sock, message ) return sock:send(message) end
+
+local function send_message_with_safe(sock,message,dbname,opts)
+ local cmd = t_ordered({"getlasterror",true, "w",opts.w,"wtimeout",opts.wtimeout})
+ if opts.fsync then cmd.fsync = true end
+ if opts.j then cmd.j = true end
+ local req_id,last_error_msg = query_message(db_ns(dbname,ns.SYSTEM_COMMAND_COLLECTION),cmd,nil,-1,0)
+ sock:send(message .. last_error_msg)
+ local _, number,err, docs = recv_message(sock,req_id)
+ if number == 1 and ( docs[1]['err'] or docs[1]['errmsg'] ) then
+ return false, docs[1]
+ end
+ return docs[1]
+end
+
+
+return {
+
+-- exported constants
+
+ OPCODES = op_codes,
+ NS = ns,
+ FLAGS = flags,
+ ZERO32 = ZERO32,
+ ZEROID = ZEROID,
+ ERR = ERR,
+
+-- exported functions
+
+ db_ns = db_ns,
+ update_message = update_message,
+ get_more_message = get_more_message,
+ delete_message = delete_message,
+ query_message = query_message,
+ insert_message = insert_message,
+ kill_cursors_message = kill_cursors_message,
+ recv_message = recv_message,
+ send_message = send_message,
+ send_message_with_safe = send_message_with_safe,
+}
38 lib/resty/mongo/support.lua
@@ -0,0 +1,38 @@
+-- some magick support
+
+if _G._resty_mongo then return end
+
+local getmetatable , setmetatable = getmetatable , setmetatable
+local pairs = pairs
+local next = next
+do
+ -- check support __pairs natively if lua 5.2
+ local run = false
+ local _t = setmetatable({} , { __pairs = function() run = true end })
+ pairs(_t)
+ if not supported then
+ _G.pairs = function( t )
+ local mt = getmetatable(t)
+ if mt then
+ local f = mt.__pairs
+ if f then
+ return f(t)
+ end
+ end
+ return pairs(t)
+ end
+ _G.pairs(_t)
+ assert(run)
+ end
+ _G.dump = function(v)
+ if type(v) == 'table' then
+ print(v)
+ for _k,_v in _G.pairs(v) do
+ print("",_k,_v)
+ end
+ return
+ end
+ print("type:",type(v), "value:",v)
+ end
+end
+_G._resty_mongo = true
308 lib/resty/mongo/util.lua
@@ -0,0 +1,308 @@
+-- lowlevel support utlities
+-- most functions implementation are stolen from https://github.com/mongol
+-- Thanks!
+
+module(..., package.seeall)
+
+local assert = assert
+local unpack = unpack
+local floor = math.floor
+local strbyte , strchar = string.byte , string.char
+local strsub = string.sub
+local t_insert = table.insert
+local t_concat = table.concat
+local tonumber = tonumber
+-- check nginx lua env
+local ngx = ngx
+local hasffi , ffi = pcall ( require , "ffi" )
+
+local util = { }
+
+-- little-endian functions
+
+local le_uint_to_num = function( s , i , j )
+ i , j = i or 1 , j or #s
+ local b = { strbyte( s , i , j ) }
+ local n = 0
+ for i=#b , 1 , -1 do
+ n = n*2^8 + b[ i ]
+ end
+ return n
+end
+local le_int_to_num = function( s , i , j )
+ i , j = i or 1 , j or #s
+ local n = le_uint_to_num( s , i , j )
+ local overflow = 2^(8*(j-i) + 7)
+ if n > 2^overflow then
+ n = - ( n % 2^overflow )
+ end
+ return n
+end
+
+local num_to_le_uint = function( n , bytes )
+ bytes = bytes or 4
+ local b = { }
+ for i=1 , bytes do
+ b[ i ] , n = n % 2^8 , floor(n / 2^8)
+ end
+ assert( n == 0 )
+ return strchar( unpack(b) )
+end
+
+local num_to_le_int = function( n , bytes )
+ bytes = bytes or 4
+ if n < 0 then -- Converted to unsigned.
+ n = 2^(8*bytes) + n
+ end
+ return num_to_le_uint(n , bytes)
+end
+
+-- Look at ith bit in given string (indexed from 0)
+-- Returns boolean
+local le_bpeek = function( s , bitnum )
+ local byte = floor( bitnum / 8 ) + 1
+ local bit = bitnum % 8
+ local char = strbyte( s , byte )
+ return floor( ( char % 2^(bit+1) ) / 2^bit ) == 1
+end
+
+-- big-edian unpack function
+
+local be_uint_to_num = function( s , i , j )
+ i , j = i or 1 , j or #s
+ local b = { strbyte ( s , i , j ) }
+ local n = 0
+ for i=1 , #b do
+ n = n*2^8 + b [ i ]
+ end
+ return n
+end
+local num_to_be_uint = function( n , bytes )
+ bytes = bytes or 4
+ local b = { }
+ for i=bytes , 1 , -1 do
+ b [ i ] , n = n % 2^8 , floor( n / 2^8 )
+ end
+ assert ( n == 0 )
+ return strchar( unpack( b ) )
+end
+
+-- Returns (as a number); bits i to j (indexed from 0)
+local extract_bits = function( s , i , j )
+ j = j or i
+ local i_byte = floor( i / 8 ) + 1
+ local j_byte = floor( j / 8 ) + 1
+
+ local n = be_uint_to_num( s , i_byte , j_byte )
+ n = n % 2^( j_byte*8 - i )
+ n = floor( n / 2^( (-(j+1) ) % 8 ) )
+ return n
+end
+
+-- Test with:
+-- local sum = 0
+-- for i=0,31 do
+-- v = le_bpeek( num_to_le_uint(N , 4 ) , i)
+-- sum=sum + ( v and 1 or 0 )*2^i
+-- end
+-- assert( sum == N )
+local be_bpeek = function( s , bitnum )
+ local byte = floor( bitnum / 8 ) + 1
+ local bit = 7-bitnum % 8
+ local char = strbyte ( s , byte )
+ return floor( ( char % 2^(bit+1) ) / 2^bit ) == 1
+end
+
+-- Test with:
+-- local sum = 0
+-- for i=0,31 do
+ -- v = be_bpeek( num_to_be_uint(N , 4), i)
+-- sum=sum + ( v and 1 or 0 )*2^(31-i)
+-- end
+-- assert ( sum == N )
+local to_double , from_double
+do
+ local s , e , d
+ if hasffi then
+ d = ffi.new ( "double[1]" )
+ else
+ -- Can't use with to_double as we can't strip debug info :(
+ d = string.dump ( loadstring ( [[return 523123.123145345]] ) )
+ s , e = d:find ( "\3\54\208\25\126\204\237\31\65" )
+ s = d:sub ( 1 , s )
+ e = d:sub ( e+1 , -1 )
+ end
+ function to_double( n )
+ if hasffi then
+ d [ 0 ] = n
+ return ffi.string ( d , 8 )
+ else
+ -- Should be the 8 bytes following the second \3 (LUA_TSTRING == '\3')
+ local str = string.dump ( loadstring ( [[return ]] .. n ) )
+ local loc , en , mat = str:find ( "\3(........)" , str:find ( "\3" ) + 1 )
+ return mat
+ end
+ end
+ function from_double( str )
+ assert ( #str == 8 )
+ if hasffi then
+ ffi.copy ( d , str , 8 )
+ return d [ 0 ]
+ else
+ str = s .. str .. e
+ return loadstring ( str ) ( )
+ end
+ end
+end
+
+-- a simple string buffer
+-- todo: maybe rewrite with ffi
+local function new_str_buffer(s,i)
+ i = i or 1
+ return function( n )
+ if not n then -- Rest of string
+ n = #s - i + 1
+ end
+ i = i + n
+ assert ( i-1 <= #s , "Unable to read enough characters" )
+ return strsub( s , i-n , i-1 )
+ end , function ( new_i )
+ if new_i then i = new_i end
+ return i
+ end
+end
+
+local function string_to_array_of_chars(s)
+ local t = { }
+ for i = 1 , #s do
+ t[ i ] = strsub(s , i , i)
+ end
+ return t
+end
+
+-- read from string buffer until got the terminators
+local function read_terminated_string(strbuf , terminators)
+ local terminators = string_to_array_of_chars( terminators or "\0" )
+ local str = { }
+ local found = 0
+ while found < #terminators do
+ local c = strbuf(1)
+ if c == terminators[ found + 1 ] then
+ found = found + 1
+ else
+ found = 0
+ end
+ t_insert( str , c )
+ end
+ return t_concat( str , "" , 1 , #str - #terminators )
+end
+
+local function slice_le_uint(buf,num)
+ local t = {}
+ -- local i = 0
+ -- while num > 0
+ -- do
+ -- t_insert(t, le_uint_to_num(buf,i,i+3))
+ -- i = i + 4
+ -- num = num -1
+ -- end
+ local i = 1
+ for j=1,num do
+ t[j] = le_uint_to_num(buf,i,i+3)
+ i = i+4
+ end
+ return unpack(t)
+end
+
+
+local function extract_flag_bits(flag, bits)
+ local t = {}
+ for i=1, bits do
+ t[i] = le_bpeek(flag,i-1)
+ end
+ return unpack(t)
+end
+
+local function machineid()
+ if hasposix then
+ return posix.uname ( "%n" )
+ else
+ return assert ( io.popen ( "uname -n" ) ):read ( "*l" )
+ end
+end
+
+local function getpid()
+ if ngx then
+ return ngx.var.pid
+ end
+ if hasposix then
+ return posix.getpid().pid
+ else
+ return assert( tonumber( assert( io.popen ( "ps -o ppid= -p $$") ):read ( "*a" ) ) )
+ end
+end
+
+-- nginx lua agent
+local md5,time,socket;
+if not ngx then
+ local md5sum = require("md5").sum
+ socket = require('socket')
+ function md5(...) return md5sum(...) end
+ function time() return os.time() end
+else
+ md5 = ngx.md5
+ time = function()
+ return ngx.time()
+ end
+ socket = ngx.socket
+end
+
+function split(s,sep)
+ local sep, fields = sep or ":", {}
+ local pattern = string.format("([^%s]+)", sep)
+ s:gsub(pattern, function(c) fields[#fields+1] = c end)
+ return fields
+end
+
+-- dump table
+local function table_print(t)
+ for k, v in pairs(t) do
+ if type(v) == [[table]] then
+ table_print(v)
+ else
+ if k then print(k,":") end
+ print(v)
+ end
+ end
+end
+
+return {
+ le_uint_to_num = le_uint_to_num,
+ le_int_to_num = le_int_to_num,
+ num_to_le_uint = num_to_le_uint,
+ num_to_le_int = num_to_le_int,
+ slice_le_uint = slice_le_uint,
+
+ be_uint_to_num = be_uint_to_num,
+ num_to_be_uint = num_to_be_uint,
+
+ extract_bits = extract_bits,
+ extract_flag_bits = extract_flag_bits,
+
+ le_bpeek = le_bpeek,
+ be_bpeek = be_bpeek,
+
+ to_double = to_double,
+ from_double = from_double,
+ new_str_buffer = new_str_buffer,
+ read_terminated_string = read_terminated_string,
+
+ split = split,
+ table_print = table_print,
+
+ machineid = machineid,
+ getpid = getpid,
+ time = time,
+ md5 = md5,
+ socket = socket,
+}
39 t/bson.t
@@ -0,0 +1,39 @@
+#!/usr/bin/env lua
+
+require "Test.More"
+
+plan 'no_plan'
+
+if not require_ok("resty.mongo.bson") then
+ BAIL_OUT "no lib"
+end
+
+local util = require "resty.mongo.util"
+local bson = require "resty.mongo.bson"
+local to_bson = bson.to_bson
+local from_bson = bson.from_bson
+local from_bson_buf = bson.from_bson_buf
+
+local oid = require "resty.mongo.object_id"
+local new_str_buffer = util.new_str_buffer
+
+local o = {
+ a = "lol" ,
+ b = "foo" ,
+ c = 42,
+ d = { 5 , 4 , 3 , 2 , 1 },
+ e = { { { { } } } },
+ f = { [true] = {baz = "mars"} },
+ g = oid.new("abcdefghijkl" ),
+ --z = { [{}] = {} } ; -- Can't test as tables are unique
+}
+
+local b = to_bson( o )
+
+local f = from_bson(b)
+local d = from_bson_buf(new_str_buffer(b))
+
+is_deeply(f,o,"from/to bson")
+is_deeply(f,d,"from bson buffer")
+
+done_testing()
61 t/collection.t
@@ -0,0 +1,61 @@
+#!/usr/bin/env lua
+require 'Test.More'
+plan 'no_plan'
+
+local test = require("t.testutil")
+local db = test.test_db()
+
+local col = db:get_collection('foo')
+
+local c = 0
+
+do
+ is(col:count(), c ,"count/empty collection")
+end
+
+do
+ local d = { t = 301 }
+ ok(col:insert(d), "insert")
+ ok(d._id, "single doc /oid generated")
+ c = c + 1
+
+ d = { {t=302 },{ t = 303 }}
+ ok(col:insert(d),"batch insert")
+ c = c + #d
+ is(col:count(),c,"check insert count")
+end
+
+do
+ col:insert({ _id =2, t = 304 })
+ c = c + 1
+ local _ok,err = col:insert({ _id = 2 }, { safe = true })
+ ok(_ok == nil, "insert/safe")
+end
+
+do
+ local query = { t = {} }
+ query.t["$gt"] = 100
+ n = col:count(query)
+ is(n,c,"count/query")
+end
+
+do
+ col:update({_id = 2 }, { t = 305 })
+ local _ok,err = col:update({_id = 2 }, { t = 305 },{ safe = true })
+ is(_ok.n,1,"update/safe")
+end
+do
+ local d = col:find_one({_id = 2})
+ is(d.t,305,"find_one")
+end
+do
+ col:remove({t = 305 })
+ c = c-1
+ is(col:count(),c, 'remove(selector)')
+ col:remove()
+ is(col:count(),0,"remove all")
+end
+
+col:drop()
+
+done_testing()
10 t/connection.t
@@ -0,0 +1,10 @@
+#!/usr/bin/env lua
+require 'Test.More'
+plan 'no_plan'
+
+local mongo_connection = require('resty.mongo.connection')
+local con = mongo_connection.new()
+
+diag(con.host)
+ok(con)
+done_testing()
41 t/cursor.t
@@ -0,0 +1,41 @@
+#!/usr/bin/env lua
+
+require 'Test.More'
+plan 'no_plan'
+
+local util = require('resty.mongo.util')
+local protocol = require("resty.mongo.protocol")
+local t_print = util.table_print
+local Connection = require("resty.mongo.connection")
+local Collection = require("resty.mongo.collection")
+local Cursor = require('resty.mongo.cursor')
+
+local host,port = "127.0.0.1",27017
+local conn = Connection.new({host = host, port = port })
+local db = conn:get_database("test")
+local col = db:get_collection("foo")
+for i=1,100 do
+ col:insert({t = i})
+end
+
+local query = { t = {} }
+query.t["$gte"] = 20
+local cursor = Cursor.new(db,"foo",query)
+cursor:sort({t = -1})
+local _t = false
+for i,item in cursor:next() do
+ if item.t ~= (100-i+1) then
+ _t = false
+ else
+ _t = true
+ end
+end
+ok(_t, "cursor:next()")
+
+cursor:reset()
+
+local i, item = cursor:next_doc()
+ok(i == 1 and item.t == 100, "reset")
+
+col:drop()
+done_testing()
19 t/db.t
@@ -0,0 +1,19 @@
+#!/usr/bin/env lua
+require 'Test.More'
+plan 'no_plan'
+
+local test = require("t.testutil")
+local t_ordered = require("resty.mongo.orderedtable")
+local db = test.test_db()
+local list = db:run_command({ buildInfo = true })
+
+for k,v in pairs(list) do
+ print("k",k)
+end
+ok(list.version, "run_command/buildInfo")
+
+local q = { t = {}}
+q.t["$gte"] = 100
+local r = db:run_command(t_ordered({"count","foo","query", q }))
+dump(r)
+done_testing()
21 t/orderedtable.t
@@ -0,0 +1,21 @@
+#!/usr/bin/env lua
+require 'Test.More'
+plan 'no_plan'
+local newtable = require("resty.mongo.orderedtable")
+local keys = { 'a','b','c' }
+local t = newtable()
+t.a = 1
+t.b = 2
+t.c = 3
+for k,v in pairs(t) do
+ is(keys[v],k,"key:"..k .. " index:"..v)
+end
+local b = newtable({"a",1,"b",2,"c",3})
+for k,v in pairs(b) do
+ is(keys[v],k,"init talbe, key:" .. k .. " index:"..v)
+end
+local c = newtable()
+c:merge({t = 1})
+is(c.t,1,"merge")
+
+done_testing()
21 t/testutil.lua
@@ -0,0 +1,21 @@
+module(...,package.seeall)
+
+require("resty.mongo.support")
+local util = require("resty.mongo.util")
+local Connection = require("resty.mongo.connection")
+local Cursor = require('resty.mongo.cursor')
+local host,port = "127.0.0.1",27017
+local conn = Connection.new({host = host, port = port })
+
+local function new_conn()
+ return conn
+end
+
+local function test_db(db)
+ db = db or "test"
+ return conn:get_database(db)
+end
+return {
+ new_conn = new_conn,
+ test_db = test_db,
+}
30 t/util.t
@@ -0,0 +1,30 @@
+#!/usr/bin/env lua
+
+require 'Test.More'
+
+plan 'no_plan'
+
+local posix = require 'posix'
+
+local util = require 'resty.mongo.util'
+
+local md5 = require('md5')
+
+local util_md5 = util.md5;
+
+is(util.time(),os.time(),'time')
+is(util.getpid(),posix.getpid().pid,'getpid')
+
+is(util_md5('123abc'), md5.sum('123abc'),'md5')
+
+local t = util.split("127.0.0.1:22",":")
+
+is(#t,2,"split")
+
+is(t[1],"127.0.0.1","split[1]")
+is(t[2],"22","split[2]")
+
+local long = util.num_to_le_uint(66001 % 0xffff,2)
+ok(util.le_uint_to_num(long) > 0 , 'Int64')
+
+done_testing()
Please sign in to comment.
Something went wrong with that request. Please try again.