Skip to content

Commit

Permalink
Unsqueeze the first dimension for datatypes when an extension type is included
Browse files Browse the repository at this point in the history
  • Loading branch information
eltoto1219 committed Jul 14, 2020
1 parent 25e5617 commit 02a4b06
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 33 deletions.
56 changes: 27 additions & 29 deletions datasets/lxmert_pretraining_beta/to_arrow_data.py
Expand Up @@ -7,7 +7,6 @@
import time
import json
from tqdm import tqdm
from copy import deepcopy

import nlp
import nlp.features as features
Expand Down Expand Up @@ -241,46 +240,45 @@ def load_obj_tsv(fname, topk=300):
# Feature schema for the LXMERT pretraining dataset.
# Variable-shape numeric fields use Array2D, class-prediction fields use
# bounded ClassLabel vocabularies (400 attributes / 1600 objects — TODO
# confirm these vocabulary sizes against the Visual Genome label files),
# and free-text fields use sequences of strings.
my_features = {
    "image": features.Array2D(dtype="float32"),
    "img_id": nlp.Value("string"),
    "boxes": nlp.features.Array2D(dtype="int32"),
    "img_h": nlp.Value("int32"),
    "img_w": nlp.Value("int32"),
    "labels": nlp.features.Array2D(dtype="int32"),
    "labels_confidence": nlp.features.Array2D(dtype="float32"),
    "num_boxes": nlp.Value("int32"),
    "attrs_id": nlp.features.Sequence(nlp.ClassLabel(num_classes=400)),
    "objs_id": nlp.features.Sequence(nlp.ClassLabel(num_classes=1600)),
    "attrs_confidence": nlp.features.Sequence(nlp.Value("float32")),
    "objs_confidence": nlp.features.Sequence(nlp.Value("float32")),
    "captions": nlp.features.Sequence(nlp.Value("string")),
    "questions": nlp.features.Sequence(nlp.Value("string")),
}


# One example record built from the first row loaded from the TSV
# (`new[0]`); keys must match `my_features` exactly so that
# `my_features.encode_example(ex)` below can validate and encode each field.
# NOTE(review): the TSV row keys differ slightly from the feature names
# ("label" -> "labels", "objects_id" -> "objs_id", etc.) — assumed to be the
# load_obj_tsv naming; confirm against that loader.
ex = {
    "image": new[0]["features"].astype("float32"),
    "img_id": str(new[0]["img_id"]),
    "boxes": new[0]["boxes"],
    "img_h": new[0]["img_h"],
    "img_w": new[0]["img_w"],
    "labels": new[0]["label"],
    "labels_confidence": new[0]["label_conf"],
    "num_boxes": new[0]["num_boxes"],
    "attrs_id": new[0]["attrs_id"],
    "objs_id": new[0]["objects_id"],
    "attrs_confidence": new[0]["attrs_conf"],
    "objs_confidence": new[0]["objects_conf"],
    "captions": new[0]["sent"],
    "questions": new[0]["question"],
}


# Wrap the schema, encode the single example, write it to a temporary Arrow
# file, then load it back as a Dataset to sanity-check the round trip.
my_features = nlp.Features(my_features)
writer = ArrowWriter(data_type=my_features.type, path="/tmp/beta.arrow")
# Examples are (key, record) pairs; only one record here.
my_examples = [(0, ex)]
for key, record in my_examples:
    # encode_example validates each field against the declared feature type.
    example = my_features.encode_example(record)
    writer.write(example)
num_examples, num_bytes = writer.finalize()
dataset = nlp.Dataset.from_file("/tmp/beta.arrow")
print(dataset)
7 changes: 3 additions & 4 deletions src/nlp/arrow_writer.py
Expand Up @@ -5,8 +5,6 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
Expand Down Expand Up @@ -122,8 +120,9 @@ def write_on_file(self):
if ext_cols and self.current_rows:
entries = []
for row in self.current_rows:
row_list = [row[col_name] for col_name in self._sorted_names]
print([(type(x), x) for x in row_list])
row_list = list(
map(lambda x: pa.array([row[x]], self._type[x].type)
if x not in ext_cols else row[x], self._sorted_names))
row = pa.RecordBatch.from_arrays(row_list, schema=self.schema)
row = pa.Table.from_batches([row])
entries.append(row)
Expand Down

0 comments on commit 02a4b06

Please sign in to comment.