Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions Core/src/test/java/org/tribuo/test/MockMultiOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,11 +231,13 @@ public String toString() {
StringBuilder builder = new StringBuilder();

builder.append("(LabelSet={");
for (MockOutput l : labels) {
builder.append(l.toString());
builder.append(',');
if (labels.size() > 0) {
for (MockOutput l : labels) {
builder.append(l.toString());
builder.append(',');
}
builder.deleteCharAt(builder.length() - 1);
}
builder.deleteCharAt(builder.length()-1);
builder.append('}');
if (!Double.isNaN(score)) {
builder.append(",OverallScore=");
Expand Down
32 changes: 31 additions & 1 deletion Data/src/main/java/org/tribuo/data/csv/CSVLoader.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -68,6 +68,12 @@
* {@link org.tribuo.data.columnar.RowProcessor} to cope with your specific input format.
* <p>
* CSVLoader is thread safe and immutable.
* <p>
* Multi-output responses such as {@code MultiLabel} or {@code Regressor} can be processed in
* two different ways either as a single column of separated values, or multiple columns. If
* there is a single column the value is passed directly to the {@link OutputFactory}. If
* there are multiple response columns then the name of the column is concatenated with the
* value, then a list of the concatenated values is passed to the {@link OutputFactory}.
* @param <T> The type of the output generated.
*/
public class CSVLoader<T extends Output<T>> {
Expand Down Expand Up @@ -139,6 +145,10 @@ public MutableDataset<T> load(Path csvPath, String responseName, String[] header
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The path to load.
* @param responseNames The names of the response variables.
Expand All @@ -154,6 +164,10 @@ public MutableDataset<T> load(Path csvPath, Set<String> responseNames) throws IO
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The path to load.
* @param responseNames The names of the response variables.
Expand Down Expand Up @@ -220,6 +234,10 @@ public DataSource<T> loadDataSource(URL csvPath, String responseName, String[] h
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand All @@ -235,6 +253,10 @@ public DataSource<T> loadDataSource(Path csvPath, Set<String> responseNames) thr
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand All @@ -250,6 +272,10 @@ public DataSource<T> loadDataSource(URL csvPath, Set<String> responseNames) thro
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand All @@ -266,6 +292,10 @@ public DataSource<T> loadDataSource(Path csvPath, Set<String> responseNames, Str
* <p>
* The {@code responseNames} set is traversed in iteration order to emit outputs,
* and should be an ordered set to ensure reproducibility.
* <p>
* If there are multiple elements in {@code responseNames} then the responses are
* processed into the form 'column-name=column-value' before being passed to the
* {@link OutputFactory} for conversion into an {@link Output}.
*
* @param csvPath The csv to load from.
* @param responseNames The names of the response variables.
Expand Down
12 changes: 10 additions & 2 deletions Data/src/test/java/org/tribuo/data/csv/CSVLoaderTest.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -82,13 +82,21 @@ public void testLoadMultiOutput() throws IOException {
assertTrue(data.getExample(1).getOutput().contains("R1"));
assertTrue(data.getExample(1).getOutput().contains("R2"));


//
// Row #2: R1=False and R2=False.
// In this case, the labelSet is empty and the labelString is the empty string.
assertEquals(0, data.getExample(2).getOutput().getLabelSet().size());
assertEquals("", data.getExample(2).getOutput().getLabelString());
assertTrue(data.getExample(2).validateExample());

URL singlePath = CSVLoaderTest.class.getResource("/org/tribuo/data/csv/test-multioutput-singlecolumn.csv");
DataSource<MockMultiOutput> singleSource = loader.loadDataSource(singlePath, "Label");
MutableDataset<MockMultiOutput> singleData = new MutableDataset<>(singleSource);
assertEquals(6, singleData.size());

for (int i = 0; i < 6; i++) {
assertEquals(data.getExample(i).getOutput().getLabelString(), singleData.getExample(i).getOutput().getLabelString());
}
}

@Test
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
A,B,C,D,Label
1,2,3,4,"R1"
6,7,8,9,"R1,R2"
6,7,8,9,
2,5,3,4,"R1"
1,2,5,9,"R2"
0,2,5,9,"R2"
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,13 @@ public String toString() {
StringBuilder builder = new StringBuilder();

builder.append("(LabelSet={");
for (Label l : labels) {
builder.append(l.toString());
builder.append(',');
if (labels.size() > 0) {
for (Label l : labels) {
builder.append(l.toString());
builder.append(',');
}
builder.deleteCharAt(builder.length() - 1);
}
builder.deleteCharAt(builder.length()-1);
builder.append('}');
if (!Double.isNaN(score)) {
builder.append(",OverallScore=");
Expand Down