apacheGH-37175: [MATLAB] Support creating arrow.tabular.RecordBatch instances from a list of `arrow.array.Array` values (apache#37176)

### Rationale for this change

Right now, the only way to construct an `arrow.tabular.RecordBatch` is from a MATLAB `table`:

```matlab
>> t = table([1; 2; 3], ["A"; "B"; "C"], VariableNames=["Numbers", "Letters"])

t =

  3×2 table

    Numbers    Letters
    _______    _______

       1         "A"  
       2         "B"  
       3         "C"  

>> rb = arrow.recordbatch(t)

rb = 

Numbers:   [
    1,
    2,
    3
  ]
Letters:   [
    "A",
    "B",
    "C"
  ]
```

The interface should also support creating `arrow.tabular.RecordBatch` instances from lists of `arrow.array.Array` values.

### What changes are included in this PR?

Added a new static method to `arrow.tabular.RecordBatch` called `fromArrays`. This method accepts a comma-separated list of `arrow.array.Array` values, which it uses to construct an `arrow.tabular.RecordBatch`. It also accepts an optional name-value pair called `ColumnNames`, which can be used to specify the column names in the record batch. If this name-value pair is not supplied, the column names default to `"Column1"`, `"Column2"`, etc.

**Example Usage:**
```matlab
>> a1 = arrow.array([1, 2, 3]);
>> a2 = arrow.array(["A", "B", "C"]);

>> rb1 = arrow.tabular.RecordBatch.fromArrays(a1, a2)

rb1 = 

Column1:   [
    1,
    2,
    3
  ]
Column2:   [
    "A",
    "B",
    "C"
  ]

>> rb2 = arrow.tabular.RecordBatch.fromArrays(a1, a2, ColumnNames=["Numbers", "Letters"])

rb2 = 

Numbers:   [
    1,
    2,
    3
  ]
Letters:   [
    "A",
    "B",
    "C"
  ]
```
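
**Example Error Handling (sketch):** The snippet below is a minimal, non-authoritative illustration of how `fromArrays` is expected to reject invalid input, based on the error identifiers exercised by the tests in this commit. It assumes the Arrow MATLAB interface built from this commit is on the MATLAB path; the variable names `a3`, `lengthErrorID`, and `namesErrorID` are made up for the example.

```matlab
% Sketch only: assumes the Arrow MATLAB interface from this commit is on the path.
a1 = arrow.array([1, 2, 3]);
a2 = arrow.array(["A", "B", "C"]);
a3 = arrow.array([1, 2]);   % deliberately shorter than a2

% Arrays of different lengths should be rejected.
try
    arrow.tabular.RecordBatch.fromArrays(a3, a2);
catch err
    lengthErrorID = string(err.identifier);
end

% ColumnNames should supply exactly one name per array.
try
    arrow.tabular.RecordBatch.fromArrays(a1, a2, ColumnNames=["A", "B", "C"]);
catch err
    namesErrorID = string(err.identifier);
end

disp(lengthErrorID)   % arrow:tabular:UnequalArrayLengths
disp(namesErrorID)    % arrow:tabular:WrongNumberColumnNames
```

Catching the error and comparing its identifier, rather than its message text, mirrors how the test cases in this commit check these failures.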

### Are these changes tested?

Yes.

1. Added new test class `matlab/test/arrow/tabular/tValidateArrayLengths.m`
2. Added new test class `matlab/test/arrow/tabular/tValidateColumnNames.m`
3. Added new test cases to `matlab/test/arrow/tabular/tRecordBatch.m`
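
One possible way to run these tests locally is sketched below. It assumes the current folder is the repository root, that the test files live under `matlab/test/arrow/tabular` (as in this commit's file listing), and that the Arrow MATLAB interface has been built and added to the MATLAB path. The variable name `results` is arbitrary.

```matlab
% Sketch only: run every test class in the tabular test folder.
% Assumes the current folder is the repository root and the Arrow
% MATLAB interface is already built and on the path.
results = runtests("matlab/test/arrow/tabular");

% Summarize the outcome of each test in tabular form.
disp(table(results))
```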

### Are there any user-facing changes?

Yes, users can now create `arrow.tabular.RecordBatch` instances using the static method `arrow.tabular.RecordBatch.fromArrays`.

* Closes: apache#37175

Authored-by: Sarah Gilmore <sgilmore@mathworks.com>
Signed-off-by: Kevin Gurney <kgurney@mathworks.com>
sgilmore10 authored and loicalleyne committed Nov 13, 2023
1 parent 2ffcfbe commit 333f100
Showing 6 changed files with 322 additions and 21 deletions.
39 changes: 39 additions & 0 deletions matlab/src/matlab/+arrow/+tabular/+internal/validateArrayLengths.m
@@ -0,0 +1,39 @@
%VALIDATEARRAYLENGTHS Validates that all arrays in the cell array
%arrowArrays have the same length.

% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version
% 2.0 (the "License"); you may not use this file except in compliance
% with the License. You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.

function validateArrayLengths(arrowArrays)

    numArrays = numel(arrowArrays);

    if numArrays == 0
        return;
    end

    expectedLength = arrowArrays{1}.Length;

    for ii = 2:numel(arrowArrays)
        if arrowArrays{ii}.Length ~= expectedLength
            errid = "arrow:tabular:UnequalArrayLengths";
            msg = compose("Expected all arrays to have a length of %d," + ...
                " but the array at position %d has a length of %d.", ...
                expectedLength, ii, arrowArrays{ii}.Length);
            error(errid, msg);
        end
    end
end

25 changes: 25 additions & 0 deletions matlab/src/matlab/+arrow/+tabular/+internal/validateColumnNames.m
@@ -0,0 +1,25 @@
%VALIDATECOLUMNNAMES Validates that columnNames has the expected number of
%elements.

% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version
% 2.0 (the "License"); you may not use this file except in compliance
% with the License. You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.

function validateColumnNames(columnNames, numColumns)
    if numel(columnNames) ~= numColumns
        errid = "arrow:tabular:WrongNumberColumnNames";
        msg = compose("Expected ColumnNames to have %d values.", numColumns);
        error(errid, msg);
    end
end
25 changes: 25 additions & 0 deletions matlab/src/matlab/+arrow/+tabular/RecordBatch.m
@@ -95,4 +95,29 @@ function displayScalarObject(obj)
            disp(obj.toString());
        end
    end

    methods (Static, Access=public)
        function recordBatch = fromArrays(arrowArrays, opts)
            arguments(Repeating)
                arrowArrays(1, 1) arrow.array.Array
            end
            arguments
                opts.ColumnNames(1, :) string {mustBeNonmissing} = compose("Column%d", 1:numel(arrowArrays))
            end

            import arrow.tabular.internal.validateArrayLengths
            import arrow.tabular.internal.validateColumnNames
            import arrow.tabular.internal.getArrayProxyIDs

            numColumns = numel(arrowArrays);
            validateArrayLengths(arrowArrays);
            validateColumnNames(opts.ColumnNames, numColumns);

            arrayProxyIDs = getArrayProxyIDs(arrowArrays);
            args = struct(ArrayProxyIDs=arrayProxyIDs, ColumnNames=opts.ColumnNames);
            proxyName = "arrow.tabular.proxy.RecordBatch";
            proxy = arrow.internal.proxy.create(proxyName, args);
            recordBatch = arrow.tabular.RecordBatch(proxy);
        end
    end
end
129 changes: 108 additions & 21 deletions matlab/test/arrow/tabular/tRecordBatch.m
@@ -27,28 +27,10 @@ function Basic(tc)

        function SupportedTypes(tc)
            % Create a table with all supported MATLAB types.
            TOriginal = table(int8   ([1, 2, 3]'), ...
                              int16  ([1, 2, 3]'), ...
                              int32  ([1, 2, 3]'), ...
                              int64  ([1, 2, 3]'), ...
                              uint8  ([1, 2, 3]'), ...
                              uint16 ([1, 2, 3]'), ...
                              uint32 ([1, 2, 3]'), ...
                              uint64 ([1, 2, 3]'), ...
                              logical([1, 0, 1]'), ...
                              single ([1, 2, 3]'), ...
                              double ([1, 2, 3]'), ...
                              string (["A", "B", "C"]'), ...
                              datetime(2023, 6, 28) + days(0:2)');
            TOriginal = createTableWithAllSupportedTypes();
            arrowRecordBatch = arrow.recordbatch(TOriginal);
            TConverted = arrowRecordBatch.toMATLAB();
            tc.verifyEqual(TOriginal, TConverted);
            for ii = 1:arrowRecordBatch.NumColumns
                column = arrowRecordBatch.column(ii);
                tc.verifyEqual(column.toMATLAB(), TOriginal{:, ii});
                traits = arrow.type.traits.traits(string(class(TOriginal{:, ii})));
                tc.verifyInstanceOf(column, traits.ArrayClassName);
            end
            expectedColumnNames = compose("Var%d", 1:13);
            tc.verifyRecordBatch(arrowRecordBatch, expectedColumnNames, TOriginal);
        end

        function ToMATLAB(tc)
@@ -135,5 +117,110 @@ function ErrorIfIndexIsNonPositive(tc)
            fcn = @() arrowRecordBatch.column(-1);
            tc.verifyError(fcn, "arrow:badsubscript:NonPositive");
        end

        function FromArraysColumnNamesNotProvided(tc)
            % Verify arrow.tabular.RecordBatch.fromArrays creates the expected
            % RecordBatch when given a comma-separated list of
            % arrow.array.Array values.
            import arrow.tabular.RecordBatch

            TOriginal = createTableWithAllSupportedTypes();

            arrowArrays = cell([1 width(TOriginal)]);
            for ii = 1:width(TOriginal)
                arrowArrays{ii} = arrow.array(TOriginal.(ii));
            end

            arrowRecordBatch = RecordBatch.fromArrays(arrowArrays{:});
            expectedColumnNames = compose("Column%d", 1:13);
            TOriginal.Properties.VariableNames = expectedColumnNames;
            tc.verifyRecordBatch(arrowRecordBatch, expectedColumnNames, TOriginal);
        end

        function FromArraysWithColumnNamesProvided(tc)
            % Verify arrow.tabular.RecordBatch.fromArrays creates the expected
            % RecordBatch when given a comma-separated list of
            % arrow.array.Array values and the ColumnNames nv-pair is provided.
            import arrow.tabular.RecordBatch

            TOriginal = createTableWithAllSupportedTypes();

            arrowArrays = cell([1 width(TOriginal)]);
            for ii = 1:width(TOriginal)
                arrowArrays{ii} = arrow.array(TOriginal.(ii));
            end

            columnNames = string(char(65:77)')';
            arrowRecordBatch = RecordBatch.fromArrays(arrowArrays{:}, ColumnNames=columnNames);
            TOriginal.Properties.VariableNames = columnNames;
            tc.verifyRecordBatch(arrowRecordBatch, columnNames, TOriginal);
        end

        function FromArraysUnequalArrayLengthsError(tc)
            % Verify arrow.tabular.RecordBatch.fromArrays throws an error whose
            % identifier is "arrow:tabular:UnequalArrayLengths" if the arrays
            % provided don't all have the same length.
            import arrow.tabular.RecordBatch

            A1 = arrow.array([1, 2]);
            A2 = arrow.array(["A", "B", "C"]);
            fcn = @() RecordBatch.fromArrays(A1, A2);
            tc.verifyError(fcn, "arrow:tabular:UnequalArrayLengths");
        end

        function FromArraysWrongNumberColumnNamesError(tc)
            % Verify arrow.tabular.RecordBatch.fromArrays throws an error whose
            % identifier is "arrow:tabular:WrongNumberColumnNames" if the
            % ColumnNames provided doesn't have one element per array.
            import arrow.tabular.RecordBatch

            A1 = arrow.array([1, 2]);
            A2 = arrow.array(["A", "B"]);
            fcn = @() RecordBatch.fromArrays(A1, A2, ColumnNames=["A", "B", "C"]);
            tc.verifyError(fcn, "arrow:tabular:WrongNumberColumnNames");
        end

        function FromArraysColumnNamesHasMissingString(tc)
            % Verify arrow.tabular.RecordBatch.fromArrays throws an error whose
            % identifier is "MATLAB:validators:mustBeNonmissing" if the
            % ColumnNames provided has a missing string value.
            import arrow.tabular.RecordBatch

            A1 = arrow.array([1, 2]);
            A2 = arrow.array(["A", "B"]);
            fcn = @() RecordBatch.fromArrays(A1, A2, ColumnNames=["A", missing]);
            tc.verifyError(fcn, "MATLAB:validators:mustBeNonmissing");
        end
    end

    methods
        function verifyRecordBatch(tc, recordBatch, expectedColumnNames, expectedTable)
            tc.verifyEqual(recordBatch.NumColumns, int32(width(expectedTable)));
            tc.verifyEqual(recordBatch.ColumnNames, expectedColumnNames);
            convertedTable = recordBatch.table();
            tc.verifyEqual(convertedTable, expectedTable);
            for ii = 1:recordBatch.NumColumns
                column = recordBatch.column(ii);
                tc.verifyEqual(column.toMATLAB(), expectedTable{:, ii});
                traits = arrow.type.traits.traits(string(class(expectedTable{:, ii})));
                tc.verifyInstanceOf(column, traits.ArrayClassName);
            end
        end
    end
end

function T = createTableWithAllSupportedTypes()
    T = table(int8   ([1, 2, 3]'), ...
              int16  ([1, 2, 3]'), ...
              int32  ([1, 2, 3]'), ...
              int64  ([1, 2, 3]'), ...
              uint8  ([1, 2, 3]'), ...
              uint16 ([1, 2, 3]'), ...
              uint32 ([1, 2, 3]'), ...
              uint64 ([1, 2, 3]'), ...
              logical([1, 0, 1]'), ...
              single ([1, 2, 3]'), ...
              double ([1, 2, 3]'), ...
              string (["A", "B", "C"]'), ...
              datetime(2023, 6, 28) + days(0:2)');
end
69 changes: 69 additions & 0 deletions matlab/test/arrow/tabular/tValidateArrayLengths.m
@@ -0,0 +1,69 @@
%TVALIDATEARRAYLENGTHS Unit tests for
%arrow.tabular.internal.validateArrayLengths.

% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version
% 2.0 (the "License"); you may not use this file except in compliance
% with the License. You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.

classdef tValidateArrayLengths < matlab.unittest.TestCase

    methods(Test)
        function ArraysWithEqualLength(testCase)
            % Verify validateArrayLengths() does not error if all the
            % arrays have the same length.

            import arrow.tabular.internal.validateArrayLengths

            a = arrow.array(["A", "B", "C"]);
            b = arrow.array([true, false, true]);
            c = arrow.array([1, 2, 3]);

            % cell array with one element
            fcn = @() validateArrayLengths({a});
            testCase.verifyWarningFree(fcn);

            % cell array with two elements
            fcn = @() validateArrayLengths({a, b});
            testCase.verifyWarningFree(fcn);

            % cell array with three elements
            fcn = @() validateArrayLengths({a, b, c});
            testCase.verifyWarningFree(fcn);
        end

        function ArraysWithUnequalLengths(testCase)
            % Verify validateArrayLengths() throws an error whose
            % identifier is "arrow:tabular:UnequalArrayLengths" if
            % the arrays do not all have the same length.

            import arrow.tabular.internal.validateArrayLengths

            a = arrow.array(["A", "B", "C"]);
            b = arrow.array([true, false, true, true]);
            c = arrow.array([1, 2, 3]);

            fcn = @() validateArrayLengths({a, b});
            testCase.verifyError(fcn, "arrow:tabular:UnequalArrayLengths");

            fcn = @() validateArrayLengths({b, a});
            testCase.verifyError(fcn, "arrow:tabular:UnequalArrayLengths");

            fcn = @() validateArrayLengths({b, a, c});
            testCase.verifyError(fcn, "arrow:tabular:UnequalArrayLengths");

            fcn = @() validateArrayLengths({a, c, b});
            testCase.verifyError(fcn, "arrow:tabular:UnequalArrayLengths");
        end
    end
end
56 changes: 56 additions & 0 deletions matlab/test/arrow/tabular/tValidateColumnNames.m
@@ -0,0 +1,56 @@
%TVALIDATECOLUMNNAMES Unit tests for
% arrow.tabular.internal.validateColumnNames.

% Licensed to the Apache Software Foundation (ASF) under one or more
% contributor license agreements. See the NOTICE file distributed with
% this work for additional information regarding copyright ownership.
% The ASF licenses this file to you under the Apache License, Version
% 2.0 (the "License"); you may not use this file except in compliance
% with the License. You may obtain a copy of the License at
%
% http://www.apache.org/licenses/LICENSE-2.0
%
% Unless required by applicable law or agreed to in writing, software
% distributed under the License is distributed on an "AS IS" BASIS,
% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
% implied. See the License for the specific language governing
% permissions and limitations under the License.

classdef tValidateColumnNames < matlab.unittest.TestCase

    methods(Test)
        function ValidColumnNames(testCase)
            % Verify validateColumnNames() does not error if the
            % column names array has the expected number of elements.

            import arrow.tabular.internal.validateColumnNames

            columnNames = ["A", "B", "C"];
            fcn = @() validateColumnNames(columnNames, 3);
            testCase.verifyWarningFree(fcn);

            columnNames = string.empty(1, 0);
            fcn = @() validateColumnNames(columnNames, 0);
            testCase.verifyWarningFree(fcn);
        end

        function WrongNumberColumnNames(testCase)
            % Verify validateColumnNames() errors if the column names
            % array provided does not have the correct number of elements.
            % The error thrown should have the identifier
            % "arrow:tabular:WrongNumberColumnNames".

            import arrow.tabular.internal.validateColumnNames

            columnNames = ["A", "B", "C"];
            fcn = @() validateColumnNames(columnNames, 2);
            testCase.verifyError(fcn, "arrow:tabular:WrongNumberColumnNames");

            fcn = @() validateColumnNames(columnNames, 4);
            testCase.verifyError(fcn, "arrow:tabular:WrongNumberColumnNames");

            fcn = @() validateColumnNames(columnNames, 0);
            testCase.verifyError(fcn, "arrow:tabular:WrongNumberColumnNames");
        end
    end
end
