Skip to content

Commit

Permalink
add setInfo getInfo
Browse files Browse the repository at this point in the history
  • Loading branch information
Jack Tang committed Apr 22, 2021
1 parent d02fb33 commit 65ef442
Showing 1 changed file with 56 additions and 0 deletions.
56 changes: 56 additions & 0 deletions src/xgboost.nim
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,62 @@ proc slice*(handle: XGDMatrix, idx: seq[int]): XGDMatrix =
result.self.addr
)

proc setInfo*(handle: XGDMatrix, field: string, arr: seq[float32]) =
## set float vector to a content in info
var dummy: float32
check: XGDMatrixSetFloatInfo(
handle.self, field,
if arr.len == 0: dummy.addr else: arr[0].unsafeAddr,
arr.len.uint64
)

proc setInfo*(handle: XGDMatrix, field: string, arr: seq[uint32]) =
## set uint32 vector to a content in info
var dummy: uint32
check: XGDMatrixSetUIntInfo(
handle.self, field,
if arr.len == 0: dummy.addr else: arr[0].unsafeAddr,
arr.len.uint64
)

proc setInfo*(handle: XGDMatrix, field: string, arr: seq[string]) =
## Set string encoded information of all features.
var data = allocCStringArray(arr)
check: XGDMatrixSetStrFeatureInfo(
handle.self, field,
data,
arr.len.uint64
)

proc saveBinary*(handle: XGDMatrix, fname: string, silent: int = 1) =
## load a data matrix into binary file
check: XGDMatrixSaveBinary(handle.self, fname, silent.int32)

proc getFloatInfo*(handle: XGDMatrix, field: string): seq[float32] =
## get float info vector from matrix.
var outLen: uint64
var outResult: ptr float32
check: XGDMatrixGetFloatInfo(handle.self, field, outLen.addr, outResult.addr)
result = newSeq[float32](outLen.int)
for i in 0 ..< outLen.int:
result[i] = cast[float32](cast[ByteAddress](outResult) +% i * sizeof(float32))

proc getUIntInfo*(handle: XGDMatrix, field: string): seq[uint32] =
## get uint32 info vector from matrix
var outLen: uint64
var outResult: ptr uint32
check: XGDMatrixGetUIntInfo(handle.self, field, outLen.addr, outResult.addr)
result = newSeq[uint32](outLen.int)
for i in 0 ..< outLen.int:
result[i] = cast[uint32](cast[ByteAddress](outResult) +% i * sizeof(uint32))

proc getStrFeatureInfo*(handle: XGDMatrix, field: string): seq[string] =
## Get string encoded information of all features.
var outLen: uint64
var outResult: cstringArray
check: XGDMatrixGetStrFeatureInfo(handle.self, field, outLen.addr, outResult.addr)
result = cstringArrayToSeq(outResult)

# -------------------------------------------------------------
# XGBooster

Expand Down

0 comments on commit 65ef442

Please sign in to comment.