Skip to content

Commit

Permalink
Add input_access utils
Browse files Browse the repository at this point in the history
  • Loading branch information
satyaog committed Dec 11, 2019
1 parent f9e940a commit de30874
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 32 deletions.
41 changes: 41 additions & 0 deletions src/benzina/utils/input_access.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os

import torch


class SubsetSequentialSampler(torch.utils.data.Sampler):
    """Sample elements sequentially from a given sequence of indices.

    The indices are produced in exactly the order they were given,
    without replacement.

    Args:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        # NOTE(review): the base-class constructor is deliberately not
        # invoked, matching the original behaviour -- confirm against the
        # torch versions this package supports.
        self.indices = indices

    def __iter__(self):
        """Return a fresh iterator over the stored indices, in order."""
        return iter(self.indices)

    def __len__(self):
        """Return the number of stored indices."""
        return len(self.indices)


def get_indices_by_names(dataset, filenames):
    """Retrieve the indices of inputs by file names.

    The file ``data.filenames`` located under ``dataset.root`` is read line
    by line; the zero-based position of a line is the index of the input
    named on that line.

    Args:
        dataset (benzina.torch.dataset.Dataset): dataset from which to fetch
            the indices. Only its ``root`` attribute is read here.
        filenames (iterable): the file names to look up.

    Returns:
        list: one entry per requested file name, in the order requested.
            Each entry is the index of that name in ``data.filenames``, or
            ``None`` when the name is not listed. If a name is listed on
            several lines, the index of the last one is returned.
    """
    # Materialize once: the names are consumed twice (to build the lookup
    # set and to build the result), which would silently yield an empty
    # result for a one-shot iterable such as a generator.
    filenames = list(filenames)
    filenames_lookup = set(filenames)
    filenames_indices = {}

    with open(os.path.join(dataset.root, "data.filenames"), 'r') as names:
        for i, name in enumerate(names):
            # Drop the line terminator (and any other trailing whitespace)
            # that follows the file name.
            name = name.rstrip()
            if name in filenames_lookup:
                filenames_indices[name] = i

    return [filenames_indices.get(filename) for filename in filenames]
4 changes: 0 additions & 4 deletions tests/benzina/torch/meson.build
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
#
# A few tests using this executable
#

pytest = find_program('pytest')

test_devsubset = files('test_devsubset.py')
Expand Down
39 changes: 11 additions & 28 deletions tests/benzina/torch/test_devsubset.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,33 @@
import os

import torch

import benzina.torch as bz


class SubsetSequentialSampler(torch.utils.data.Sampler):
    """Sample elements sequentially from a given sequence of indices.

    The indices are produced in exactly the order they were given,
    without replacement.

    Args:
        indices (sequence): a sequence of indices
    """

    def __init__(self, indices):
        self.indices = indices

    def __iter__(self):
        """Return a fresh iterator over the stored indices, in order."""
        return iter(self.indices)

    def __len__(self):
        """Return the number of stored indices."""
        return len(self.indices)
from benzina.utils.input_access import SubsetSequentialSampler, get_indices_by_names


def test_devsubset_pytorch_loading():
dataset_path = os.environ["DATASET_PATH"]

with open("tests/data/devsubset_files", 'r') as devsubset_list:
subset_filenames = []
subset_targets = []
# skip the end line following the filename
subset_targets_map = {line.rstrip()[5:]: int(line[:4]) for line in devsubset_list if line}
for line in devsubset_list:
if not line:
continue
subset_filenames.append(line.rstrip()[5:])
subset_targets.append(int(line[:4]))

subset_filenames = set(subset_targets_map.keys())

with open(os.path.join(dataset_path, "data.filenames"), 'r') as filenames:
subset_indices = []
subset_targets = []
for i, filename in enumerate(filenames):
# skip the end line following the filename
filename = filename.rstrip()
if filename in subset_filenames:
subset_indices.append(i)
subset_targets.append(subset_targets_map[filename])

dataset = bz.ImageNet(dataset_path)
subset_indices = get_indices_by_names(dataset, subset_filenames)

subset_sampler = SubsetSequentialSampler(subset_indices)

subset_loader = bz.DataLoader(dataset,
batch_size=100,
sampler=subset_sampler,
seed=0,
shape=(256,256))
shape=(256, 256))

for start_i, (images, targets) in zip(range(0, len(subset_sampler), 100), subset_loader):
for i, (image, target) in enumerate(zip(images, targets)):
assert image.size() == (3, 256, 256)
Expand Down

0 comments on commit de30874

Please sign in to comment.