Skip to content

Commit

Permalink
add single row prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack Tang committed Apr 22, 2021
1 parent efe71e0 commit a3e314d
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
27 changes: 27 additions & 0 deletions example/load_saved_model.nim
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,35 @@ proc main() =

booster.loadModel("agaricus.txt.model")

# predict
let dtest = newXGDMatrix("agaricus.txt.test", silent=0)
echo booster.predict(dtest)

# predict single data
var row: array[126, float32]
row[1] = 1
row[9] = 1
row[19] = 1
row[21] = 1
row[24] = 1
row[34] = 1
row[36] = 1
row[39] = 1
row[42] = 1
row[53] = 1
row[56] = 1
row[65] = 1
row[69] = 1
row[77] = 1
row[86] = 1
row[88] = 1
row[92] = 1
row[95] = 1
row[102] = 1
row[106] = 1
row[117] = 1
row[122] = 1
echo booster.predict(@row)

when isMainModule:
main()
27 changes: 19 additions & 8 deletions src/xgboost.nim
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import strformat

export libxgboost

# Default for the `missing` parameter of the `newXGDMatrix` overloads that
# take in-memory data and of the single-row `predict` in this module.
# Open question kept from the author: or NaN.float32 ?
const DEFAULT_MISSING* = 0.0f32

type
XGError* = object of CatchableError
returnCode*: int
Expand Down Expand Up @@ -60,7 +63,7 @@ proc newXGDMatrix*(fname: string, silent: int = 1): XGDMatrix =
result.new(finalize)
check: XGDMatrixCreateFromFile(fname, silent.cint, result.self.addr)

proc newXGDMatrix*(data: seq[float32], nRow, nCol: int, missing: float32 = NaN.float32): XGDMatrix =
proc newXGDMatrix*(data: seq[float32], nRow, nCol: int, missing: float32 = DEFAULT_MISSING): XGDMatrix =
if data.len != nRow * nCol:
raise newException(XGError, fmt"invalid length of data data.len={data.len} nRow={nRow} nCol={nCol}")

Expand All @@ -77,18 +80,21 @@ proc newXGDMatrix*(data: seq[float32], nRow, nCol: int, missing: float32 = NaN.f
result.self.addr
)

proc newXGDMatrix*(data: seq[float32], nRow: int, missing: float32 = NaN.float32): XGDMatrix =
proc newXGDMatrix*(data: seq[float32], nRow: int, missing: float32 = DEFAULT_MISSING): XGDMatrix =
  ## Builds a matrix from flat `data` holding `nRow` rows; the column count
  ## is derived as `data.len div nRow`.
  ##
  ## `missing` is forwarded unchanged to the `(data, nRow, nCol, missing)`
  ## overload.
  ##
  ## Raises `XGError` when `nRow` is not positive or when `data.len` is not
  ## an exact multiple of `nRow`.
  # Guard first: `div` by zero raises a Defect, and a negative nRow can
  # satisfy the multiplication check below (e.g. len 4, nRow -2).
  if nRow <= 0:
    raise newException(XGError, fmt"invalid nRow={nRow}, must be positive")
  let nCol = data.len div nRow
  if nCol * nRow != data.len:
    raise newException(XGError, fmt"invalid length of data data.len={data.len} nRow={nRow}")
  result = newXGDMatrix(data, nRow, nCol, missing)

proc newXGDMatrix*[N: static int](data: seq[array[N, float32]], missing: float32 = NaN.float32): XGDMatrix =
let nRow = N
let nCol = data.len div nRow
if nCol * nRow != data.len:
raise newException(XGError, fmt"invalid length of data data.len={data.len} nRow={nRow}")
result = newXGDMatrix(data, nRow, nCol, missing)
proc newXGDMatrix*[N: static int](data: seq[array[N, float32]], missing: float32 = DEFAULT_MISSING): XGDMatrix =
  ## Builds a matrix from a sequence of fixed-width rows: `data.len` rows
  ## of `N` columns each. `missing` is forwarded to the flat-data overload.
  # todo: use iterator, not copy
  # Flatten row by row, so element (i, j) lands at index i * N + j.
  var flat = newSeqOfCap[float32](data.len * N)
  for row in data:
    for value in row:
      flat.add value
  result = newXGDMatrix(flat, data.len, N, missing)

proc nRow*(m: XGDMatrix): int =
var tmp: uint64
Expand Down Expand Up @@ -204,6 +210,11 @@ proc predict*(
for i in 0 ..< size:
result[i] = cast[ptr float32](cast[ByteAddress](outResultPtr) +% i.int * sizeof(float32))[]

proc predict*(b: XGBooster, v: seq[float32], missing: float32 = DEFAULT_MISSING): float32 =
  ## Predicts for a single row of features.
  ##
  ## `v` is wrapped in a 1 x `v.len` matrix (with `missing` forwarded to
  ## `newXGDMatrix`), the matrix-based `predict` is run on it, and the first
  ## element of that result is returned.
  ##
  ## Raises `XGError` if `v` is empty or if no prediction value comes back.
  if v.len == 0:
    raise newException(XGError, "cannot predict on an empty feature vector")
  let rowMatrix = newXGDMatrix(v, 1, v.len, missing)
  let scores = b.predict(rowMatrix)
  # Fail with a catchable error instead of an index Defect on `scores[0]`.
  if scores.len == 0:
    raise newException(XGError, "booster returned no prediction for the given row")
  result = scores[0]

proc saveModel*(b: XGBooster, fname: string) =
  ## Saves the booster `b` to the file `fname` via `XGBoosterSaveModel`.
  # NOTE(review): `check` is defined outside this view; presumably it turns
  # a failing return code into an XGError - confirm.
  check:
    XGBoosterSaveModel(b.self, fname)

Expand Down

0 comments on commit a3e314d

Please sign in to comment.