Skip to content

Commit

Permalink
Added xcat (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ouwen committed Jul 7, 2020
1 parent 4b573a8 commit af8d806
Show file tree
Hide file tree
Showing 5 changed files with 118 additions and 0 deletions.
1 change: 1 addition & 0 deletions tensorflow_datasets/image/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,4 @@
from tensorflow_datasets.image_classification.uc_merced import UcMerced
from tensorflow_datasets.image_classification.vgg_face2 import VggFace2
from tensorflow_datasets.image_classification.visual_domain_decathlon import VisualDomainDecathlon
from tensorflow_datasets.image.xcat import Xcat # TODO(xcat) Sort alphabetically
85 changes: 85 additions & 0 deletions tensorflow_datasets/image/xcat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""xcat dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow_datasets.public_api as tfds
import tensorflow as tf
import csv
import os

# TODO(xcat): BibTeX citation
_CITATION = """
"""

# TODO(xcat):
_DESCRIPTION = """
"""


class Xcat(tfds.core.GeneratorBasedBuilder):
"""TODO(xcat): Short description of my dataset."""

VERSION = tfds.core.Version('0.1.0')
MANUAL_DOWNLOAD_INSTRUCTIONS = "Internal Gradient Dataset"
def _info(self):
return tfds.core.DatasetInfo(
builder=self,
description=_DESCRIPTION,
features=tfds.features.FeaturesDict({
'xray': tfds.features.Image(shape=(2048,2048,1), dtype=tf.uint16, encoding_format='png'),
'bse': tfds.features.Image(shape=(2048,2048,1), dtype=tf.uint16, encoding_format='png')
}),
supervised_keys=('xray', 'bse'),
homepage='https://dataset-homepage/',
citation=_CITATION
)

def _split_generators(self, dl_manager):
"""Returns SplitGenerators."""
return [
tfds.core.SplitGenerator(
name=tfds.Split.TRAIN,
gen_kwargs={
'manual_dir': dl_manager.manual_dir,
'file': 'train_body_c7.csv',
'data_dir': 'projections_body_c7'
}
),
tfds.core.SplitGenerator(
name=tfds.Split.TEST,
gen_kwargs={
'manual_dir': dl_manager.manual_dir,
'file': 'test_body_c7.csv',
'data_dir': 'projections_body_c7'
}
)
]

def _generate_examples(self, manual_dir=None, file=None, data_dir=None):
with tf.io.gfile.GFile(os.path.join(manual_dir, file)) as f:
reader = csv.DictReader(f)
for row in reader:
bse_file = row['filename']
xray_file = row['filename'].split('.')
xray_file[2] = 'bone'
xray_file = ".".join(xray_file)

key = row['filename'].split('.')[0]

bse_data = tf.io.read_file(os.path.join(manual_dir, data_dir, bse_file))
xray_data = tf.io.read_file(os.path.join(manual_dir, data_dir, xray_file))
bse = tf.reshape((tf.io.decode_raw(bse_data, tf.float32)), (2048,2048,1))
xray = tf.reshape((tf.io.decode_raw(xray_data, tf.float32)), (2048,2048,1))

bse = (bse - tf.reduce_min(bse))/(tf.reduce_max(bse)-tf.reduce_min(bse))
xray = (xray - tf.reduce_min(xray))/(tf.reduce_max(xray)-tf.reduce_min(xray))

bse = tf.cast(tf.round(bse*65534), tf.uint16)
xray = tf.cast(tf.round(xray*65534), tf.uint16)

yield key, {
'xray': xray.numpy(),
'bse': bse.numpy()
}
28 changes: 28 additions & 0 deletions tensorflow_datasets/image/xcat_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""xcat dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow_datasets.public_api as tfds
from tensorflow_datasets.image import xcat


class XcatTest(tfds.testing.DatasetBuilderTestCase):
# TODO(xcat):
DATASET_CLASS = xcat.Xcat
SPLITS = {
"train": 3, # Number of fake train example
"test": 1, # Number of fake test example
}

# If you are calling `download/download_and_extract` with a dict, like:
# dl_manager.download({'some_key': 'http://a.org/out.txt', ...})
# then the tests needs to provide the fake output paths relative to the
# fake data directory
# DL_EXTRACT_RESULT = {'some_key': 'output_file1.txt', ...}


if __name__ == "__main__":
tfds.testing.test_main()

Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
TODO(xcat): Add fake data in this directory
3 changes: 3 additions & 0 deletions tensorflow_datasets/url_checksums/xcat.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# TODO(xcat): If your dataset downloads files, then the checksums will be
# automatically added here when running the download_and_prepare script
# with --register_checksums.

0 comments on commit af8d806

Please sign in to comment.