Skip to content

Commit 3405ee8

Browse files
authored Nov 1, 2023
[python-package] Allow to pass Arrow table as training data (#6034)
1 parent fcf76bc commit 3405ee8

File tree

13 files changed

+1017
-18
lines changed

13 files changed

+1017
-18
lines changed
 

‎.ci/test-python-oldest.sh

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
#
88
echo "installing lightgbm's dependencies"
99
pip install \
10+
'cffi==1.15.1' \
1011
'dataclasses' \
11-
'numpy==1.12.0' \
12+
'numpy==1.16.6' \
1213
'pandas==0.24.0' \
14+
'pyarrow==6.0.1' \
1315
'scikit-learn==0.18.2' \
1416
'scipy==0.19.0' \
1517
|| exit -1

‎.ci/test.sh

+4
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,13 @@ fi
130130
# including python=version[build=*cpython] to ensure that conda doesn't fall back to pypy
131131
mamba create -q -y -n $CONDA_ENV \
132132
${CONSTRAINED_DEPENDENCIES} \
133+
cffi \
133134
cloudpickle \
134135
joblib \
135136
matplotlib \
136137
numpy \
137138
psutil \
139+
pyarrow \
138140
pytest \
139141
${CONDA_PYTHON_REQUIREMENT} \
140142
python-graphviz \
@@ -315,11 +317,13 @@ matplotlib.use\(\"Agg\"\)\
315317

316318
# importing the library should succeed even if all optional dependencies are not present
317319
conda uninstall --force --yes \
320+
cffi \
318321
dask \
319322
distributed \
320323
joblib \
321324
matplotlib \
322325
psutil \
326+
pyarrow \
323327
python-graphviz \
324328
scikit-learn || exit -1
325329
python -c "import lightgbm" || exit -1

‎.ci/test_windows.ps1

+2
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,14 @@ conda install brotlipy
5252

5353
conda update -q -y conda
5454
conda create -q -y -n $env:CONDA_ENV `
55+
cffi `
5556
cloudpickle `
5657
joblib `
5758
matplotlib `
5859
numpy `
5960
pandas `
6061
psutil `
62+
pyarrow `
6163
pytest `
6264
"python=$env:PYTHON_VERSION[build=*cpython]" `
6365
python-graphviz `

‎include/LightGBM/arrow.h

+256
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
/*!
2+
* Copyright (c) 2023 Microsoft Corporation. All rights reserved.
3+
* Licensed under the MIT License. See LICENSE file in the project root for license information.
4+
*
5+
* Author: Oliver Borchert
6+
*/
7+
8+
#ifndef LIGHTGBM_ARROW_H_
9+
#define LIGHTGBM_ARROW_H_
10+
11+
#include <algorithm>
12+
#include <cstdint>
13+
#include <functional>
14+
#include <iterator>
15+
#include <limits>
16+
#include <memory>
17+
#include <utility>
18+
#include <vector>
19+
20+
/* -------------------------------------- C DATA INTERFACE ------------------------------------- */
21+
// The C data interface is taken from
22+
// https://arrow.apache.org/docs/format/CDataInterface.html#structure-definitions
23+
// and is available under Apache License 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
24+
25+
#ifdef __cplusplus
26+
extern "C" {
27+
#endif
28+
29+
#define ARROW_FLAG_DICTIONARY_ORDERED 1
30+
#define ARROW_FLAG_NULLABLE 2
31+
#define ARROW_FLAG_MAP_KEYS_SORTED 4
32+
33+
struct ArrowSchema {
  /*
   * NOTE: this declaration is copied from the Arrow C data interface (see the link at the top of
   * this section). Member order and member types must therefore stay exactly as they are.
   */

  /* ---- Description of the array type ---- */
  const char* format;
  const char* name;
  const char* metadata;
  int64_t flags;
  int64_t n_children;
  struct ArrowSchema** children;
  struct ArrowSchema* dictionary;

  /* ---- Callback used to release the schema ---- */
  void (*release)(struct ArrowSchema*);

  /* ---- Producer-specific data, opaque to everybody else ---- */
  void* private_data;
};
48+
49+
struct ArrowArray {
  /*
   * NOTE: this declaration is copied from the Arrow C data interface (see the link at the top of
   * this section). Member order and member types must therefore stay exactly as they are.
   */

  /* ---- Description of the array data ---- */
  int64_t length;
  int64_t null_count;
  int64_t offset;
  int64_t n_buffers;
  int64_t n_children;
  const void** buffers;
  struct ArrowArray** children;
  struct ArrowArray* dictionary;

  /* ---- Callback used to release the array ---- */
  void (*release)(struct ArrowArray*);

  /* ---- Producer-specific data, opaque to everybody else ---- */
  void* private_data;
};
65+
66+
#ifdef __cplusplus
67+
}
68+
#endif
69+
70+
/* --------------------------------------------------------------------------------------------- */
71+
/* CHUNKED ARRAY */
72+
/* --------------------------------------------------------------------------------------------- */
73+
74+
namespace LightGBM {
75+
76+
/**
77+
* @brief Arrow array-like container for a list of Arrow arrays.
78+
*/
79+
class ArrowChunkedArray {
80+
/* List of length `n` for `n` chunks containing the individual Arrow arrays. */
81+
std::vector<const ArrowArray*> chunks_;
82+
/* Schema for all chunks. */
83+
const ArrowSchema* schema_;
84+
/* List of length `n + 1` for `n` chunks containing the offsets for each chunk. */
85+
std::vector<int64_t> chunk_offsets_;
86+
87+
inline void construct_chunk_offsets() {
88+
chunk_offsets_.reserve(chunks_.size() + 1);
89+
chunk_offsets_.emplace_back(0);
90+
for (size_t k = 0; k < chunks_.size(); ++k) {
91+
chunk_offsets_.emplace_back(chunks_[k]->length + chunk_offsets_.back());
92+
}
93+
}
94+
95+
public:
96+
/**
97+
* @brief Construct a new Arrow Chunked Array object.
98+
*
99+
* @param chunks A list with the chunks.
100+
* @param schema The schema for all chunks.
101+
*/
102+
inline ArrowChunkedArray(std::vector<const ArrowArray*> chunks, const ArrowSchema* schema) {
103+
chunks_ = chunks;
104+
schema_ = schema;
105+
construct_chunk_offsets();
106+
}
107+
108+
/**
109+
* @brief Construct a new Arrow Chunked Array object.
110+
*
111+
* @param n_chunks The number of chunks.
112+
* @param chunks A C-style array containing the chunks.
113+
* @param schema The schema for all chunks.
114+
*/
115+
inline ArrowChunkedArray(int64_t n_chunks,
116+
const struct ArrowArray* chunks,
117+
const struct ArrowSchema* schema) {
118+
chunks_.reserve(n_chunks);
119+
for (auto k = 0; k < n_chunks; ++k) {
120+
chunks_.push_back(&chunks[k]);
121+
}
122+
schema_ = schema;
123+
construct_chunk_offsets();
124+
}
125+
126+
/**
127+
* @brief Get the length of the chunked array.
128+
* This method returns the cumulative length of all chunks.
129+
* Complexity: O(1)
130+
*
131+
* @return int64_t The number of elements in the chunked array.
132+
*/
133+
inline int64_t get_length() const { return chunk_offsets_.back(); }
134+
135+
/* ----------------------------------------- ITERATOR ---------------------------------------- */
136+
template <typename T>
137+
class Iterator {
138+
using getter_fn = std::function<T(const ArrowArray*, int64_t)>;
139+
140+
/* Reference to the chunked array that this iterator iterates over. */
141+
const ArrowChunkedArray& array_;
142+
/* Function to fetch the value at a certain index from a single chunk. */
143+
getter_fn get_;
144+
/* The chunk the iterator currently points to. */
145+
int64_t ptr_chunk_;
146+
/* The index inside the current chunk that the iterator points to. */
147+
int64_t ptr_offset_;
148+
149+
public:
150+
using iterator_category = std::random_access_iterator_tag;
151+
using difference_type = int64_t;
152+
using value_type = T;
153+
using pointer = value_type*;
154+
using reference = value_type&;
155+
156+
/**
157+
* @brief Construct a new Iterator object.
158+
*
159+
* @param array Reference to the chunked array to iterator over.
160+
* @param get Function to fetch the value at a certain index from a single chunk.
161+
* @param ptr_chunk The index of the chunk to whose first index the iterator points to.
162+
*/
163+
Iterator(const ArrowChunkedArray& array, getter_fn get, int64_t ptr_chunk);
164+
165+
T operator*() const;
166+
template <typename I>
167+
T operator[](I idx) const;
168+
169+
Iterator<T>& operator++();
170+
Iterator<T>& operator--();
171+
Iterator<T>& operator+=(int64_t c);
172+
173+
template <typename V>
174+
friend bool operator==(const Iterator<V>& a, const Iterator<V>& b);
175+
template <typename V>
176+
friend bool operator!=(const Iterator<V>& a, const Iterator<V>& b);
177+
template <typename V>
178+
friend int64_t operator-(const Iterator<V>& a, const Iterator<V>& b);
179+
};
180+
181+
/**
182+
* @brief Obtain an iterator to the beginning of the chunked array.
183+
*
184+
* @tparam T The value type of the iterator. May be any primitive type.
185+
* @return Iterator<T> The iterator.
186+
*/
187+
template <typename T>
188+
inline Iterator<T> begin() const;
189+
190+
/**
191+
* @brief Obtain an iterator to the beginning of the chunked array.
192+
*
193+
* @tparam T The value type of the iterator. May be any primitive type.
194+
* @return Iterator<T> The iterator.
195+
*/
196+
template <typename T>
197+
inline Iterator<T> end() const;
198+
199+
template <typename V>
200+
friend int64_t operator-(const Iterator<V>& a, const Iterator<V>& b);
201+
};
202+
203+
/**
204+
* @brief Arrow container for a list of chunked arrays.
205+
*/
206+
class ArrowTable {
207+
std::vector<ArrowChunkedArray> columns_;
208+
209+
public:
210+
/**
211+
* @brief Construct a new Arrow Table object.
212+
*
213+
* @param n_chunks The number of chunks.
214+
* @param chunks A C-style array containing the chunks.
215+
* @param schema The schema for all chunks.
216+
*/
217+
inline ArrowTable(int64_t n_chunks, const ArrowArray* chunks, const ArrowSchema* schema) {
218+
columns_.reserve(schema->n_children);
219+
for (int64_t j = 0; j < schema->n_children; ++j) {
220+
std::vector<const ArrowArray*> children_chunks;
221+
children_chunks.reserve(n_chunks);
222+
for (int64_t k = 0; k < n_chunks; ++k) {
223+
children_chunks.push_back(chunks[k].children[j]);
224+
}
225+
columns_.emplace_back(children_chunks, schema->children[j]);
226+
}
227+
}
228+
229+
/**
230+
* @brief Get the number of rows in the table.
231+
*
232+
* @return int64_t The number of rows.
233+
*/
234+
inline int64_t get_num_rows() const { return columns_.front().get_length(); }
235+
236+
/**
237+
* @brief Get the number of columns of this table.
238+
*
239+
* @return int64_t The column count.
240+
*/
241+
inline int64_t get_num_columns() const { return columns_.size(); }
242+
243+
/**
244+
* @brief Get the column at a particular index.
245+
*
246+
* @param idx The index of the column, must me in the range `[0, num_columns)`.
247+
* @return const ArrowChunkedArray& The chunked array for the child at the provided index.
248+
*/
249+
inline const ArrowChunkedArray& get_column(size_t idx) const { return this->columns_[idx]; }
250+
};
251+
252+
} // namespace LightGBM
253+
254+
#include "arrow.tpp"
255+
256+
#endif /* LIGHTGBM_ARROW_H_ */

0 commit comments

Comments
 (0)
Please sign in to comment.