From 02a4b069aae74a852b634b714f9e50fd802d1a30 Mon Sep 17 00:00:00 2001 From: antonio Date: Tue, 14 Jul 2020 02:29:26 -0400 Subject: [PATCH] unsqueezed first dim for datatypes when extension type is included --- .../lxmert_pretraining_beta/to_arrow_data.py | 56 +++++++++---------- src/nlp/arrow_writer.py | 7 +-- 2 files changed, 30 insertions(+), 33 deletions(-) diff --git a/datasets/lxmert_pretraining_beta/to_arrow_data.py b/datasets/lxmert_pretraining_beta/to_arrow_data.py index 57e6e6f8c5df..5faccb00045d 100644 --- a/datasets/lxmert_pretraining_beta/to_arrow_data.py +++ b/datasets/lxmert_pretraining_beta/to_arrow_data.py @@ -7,7 +7,6 @@ import time import json from tqdm import tqdm -from copy import deepcopy import nlp import nlp.features as features @@ -241,46 +240,45 @@ def load_obj_tsv(fname, topk=300): my_features = { "image": features.Array2D(dtype="float32"), "img_id": nlp.Value("string"), - # "boxes": nlp.features.MultiArray(shape=(36, 4), dtype="int32"), - # "img_h": nlp.Value("int32"), - # "img_w": nlp.Value("int32"), - # "labels": nlp.features.MultiArray(shape=(-1, -1), dtype="int32"), - # "labels_confidence": nlp.features.MultiArray(shape=(-1, -1), dtype="float32"), - # "num_boxes": nlp.Value("int32"), - # "attrs_id": nlp.features.Sequence(nlp.ClassLabel(num_classes=400)), - # "objs_id": nlp.features.Sequence(nlp.ClassLabel(num_classes=1600)), - # "attrs_confidence": nlp.features.Sequence(nlp.Value("float32")), - # "objs_confidence": nlp.features.Sequence(nlp.Value("float32")), - # "captions": nlp.features.Sequence(nlp.Value("string")), - # "questions": nlp.features.Sequence(nlp.Value("string")), + "boxes": nlp.features.Array2D(dtype="int32"), + "img_h": nlp.Value("int32"), + "img_w": nlp.Value("int32"), + "labels": nlp.features.Array2D(dtype="int32"), + "labels_confidence": nlp.features.Array2D(dtype="float32"), + "num_boxes": nlp.Value("int32"), + "attrs_id": nlp.features.Sequence(nlp.ClassLabel(num_classes=400)), + "objs_id": nlp.features.Sequence(nlp.ClassLabel(num_classes=1600)), + "attrs_confidence": nlp.features.Sequence(nlp.Value("float32")), + "objs_confidence": nlp.features.Sequence(nlp.Value("float32")), + "captions": nlp.features.Sequence(nlp.Value("string")), + "questions": nlp.features.Sequence(nlp.Value("string")), } ex = { - "image": deepcopy(new[0]["features"].astype("float32")), - # "img_id": deepcopy(str(new[0]["img_id"])), - "img_id": "12" - # "boxes": new[0]["boxes"], - # "img_h": new[0]["img_h"], - # "img_w": new[0]["img_w"], - # "labels": (new[0]["label"], (-1, -1), "int32"), - # "labels_confidence": (new[0]["label_conf"], (-1, -1), "int32"), - # "num_boxes": new[0]["num_boxes"], - # "attrs_id": new[0]["attrs_id"], - # "objs_id": new[0]["objects_id"], - # "attrs_confidence": new[0]["attrs_conf"], - # "objs_confidence": new[0]["objects_conf"], - # "captions": new[0]["sent"], - # "questions": new[0]["question"], + "image": new[0]["features"].astype("float32"), + "img_id": str(new[0]["img_id"]), + "boxes": new[0]["boxes"], + "img_h": new[0]["img_h"], + "img_w": new[0]["img_w"], + "labels": new[0]["label"], + "labels_confidence": new[0]["label_conf"], + "num_boxes": new[0]["num_boxes"], + "attrs_id": new[0]["attrs_id"], + "objs_id": new[0]["objects_id"], + "attrs_confidence": new[0]["attrs_conf"], + "objs_confidence": new[0]["objects_conf"], + "captions": new[0]["sent"], + "questions": new[0]["question"], } my_features = nlp.Features(my_features) writer = ArrowWriter(data_type=my_features.type, path="/tmp/beta.arrow") my_examples = [(0, ex), ] -print("HI", str(new[0]["img_id"])) for key, record in my_examples: example = my_features.encode_example(record) writer.write(example) num_examples, num_bytes = writer.finalize() dataset = nlp.Dataset.from_file("/tmp/beta.arrow") +print(dataset) diff --git a/src/nlp/arrow_writer.py b/src/nlp/arrow_writer.py index f51b01ba78cd..a35f6f055978 100644 --- a/src/nlp/arrow_writer.py +++ b/src/nlp/arrow_writer.py @@ -5,8 +5,6 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 -# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -122,8 +120,9 @@ def write_on_file(self): if ext_cols and self.current_rows: entries = [] for row in self.current_rows: - row_list = [row[col_name] for col_name in self._sorted_names] - print([(type(x), x) for x in row_list]) + row_list = list( + map(lambda x: pa.array([row[x]], self._type[x].type) + if x not in ext_cols else row[x], self._sorted_names)) row = pa.RecordBatch.from_arrays(row_list, schema=self.schema) row = pa.Table.from_batches([row]) entries.append(row)