-
Notifications
You must be signed in to change notification settings - Fork 8
/
ExampleWriteSupport.java
57 lines (48 loc) · 1.56 KB
/
ExampleWriteSupport.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
package me.lyh.parquet.tensorflow;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.hadoop.api.WriteSupport;
import org.apache.parquet.io.api.RecordConsumer;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.MessageTypeParser;
import org.tensorflow.proto.example.Example;
import org.tensorflow.proto.example.Features;
import java.util.Collections;
/**
 * Parquet {@link WriteSupport} that writes TensorFlow {@link Example} records.
 *
 * <p>The {@link Schema} that drives the write is either passed to the constructor or, when
 * absent, resolved in {@link #init(Configuration)} from the configuration entry stored under
 * {@link ExampleParquetOutputFormat#SCHEMA_KEY}.
 */
public class ExampleWriteSupport extends WriteSupport<Example> {
  /** Schema used for writing; may be {@code null} until {@link #init(Configuration)} runs. */
  private Schema schema;

  /** Consumer handed over in {@link #prepareForWrite(RecordConsumer)}. */
  private RecordConsumer recordConsumer;

  /**
   * Creates an instance without a schema; {@link #init(Configuration)} then reads the schema
   * from the configuration.
   */
  public ExampleWriteSupport() {}

  /**
   * Creates an instance that writes with the given schema.
   *
   * @param schema schema to write with; if {@code null}, {@link #init(Configuration)} falls back
   *     to the configuration
   */
  public ExampleWriteSupport(Schema schema) {
    this.schema = schema;
  }

  /**
   * Resolves the Parquet message type to write with.
   *
   * <p>If no schema was supplied at construction time, the schema string stored under
   * {@link ExampleParquetOutputFormat#SCHEMA_KEY} is parsed into a {@link MessageType} and the
   * {@link Schema} is derived from it via {@code Schema.fromParquet}. Otherwise the message type
   * is obtained from the existing schema via {@code Schema.toParquet}.
   *
   * @param configuration configuration consulted only when no schema was supplied
   * @return a write context holding the message type and an empty extra-metadata map
   * @throws IllegalStateException if no schema was supplied and the configuration has no value
   *     for {@link ExampleParquetOutputFormat#SCHEMA_KEY}
   */
  @Override
  public WriteContext init(Configuration configuration) {
    MessageType messageType;
    if (schema == null) {
      String schemaString = configuration.get(ExampleParquetOutputFormat.SCHEMA_KEY);
      if (schemaString == null) {
        // Report the missing setting by name rather than handing null to the schema parser.
        throw new IllegalStateException(
            "No schema provided and configuration key '"
                + ExampleParquetOutputFormat.SCHEMA_KEY
                + "' is not set");
      }
      messageType = MessageTypeParser.parseMessageType(schemaString);
      schema = Schema.fromParquet(messageType);
    } else {
      messageType = schema.toParquet();
    }
    return new WriteContext(messageType, Collections.emptyMap());
  }

  /**
   * Returns the name of this write support.
   *
   * @return the constant string {@code "example"}
   */
  @Override
  public String getName() {
    return "example";
  }

  /**
   * Stores the consumer that subsequent {@link #write(Example)} calls emit to.
   *
   * @param recordConsumer consumer to write records to
   */
  @Override
  public void prepareForWrite(RecordConsumer recordConsumer) {
    this.recordConsumer = recordConsumer;
  }

  /**
   * Writes one record as a single message.
   *
   * <p>Each schema field is asked, in the order returned by {@code schema.getFields()}, to write
   * itself from the record's features; the field's position in that order is passed as its
   * index.
   *
   * @param record example whose features are written
   */
  @Override
  public void write(Example record) {
    recordConsumer.startMessage();
    Features features = record.getFeatures();
    int index = 0;
    for (Schema.Field field : schema.getFields()) {
      field.write(index, recordConsumer, features);
      index++;
    }
    recordConsumer.endMessage();
  }
}