Skip to content

Commit

Permalink
Add pytorch test
Browse files Browse the repository at this point in the history
  • Loading branch information
satyaog committed Nov 26, 2019
1 parent df1e69b commit 92fe2b6
Show file tree
Hide file tree
Showing 4 changed files with 3,078 additions and 0 deletions.
9 changes: 9 additions & 0 deletions tests/benzina/torch/meson.build
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#
# Tests for the benzina.torch package, run through pytest
#

# Locate the pytest executable that drives the Python tests.
pytest = find_program('pytest')

# Test module handed to pytest on its command line.
test_devsubset = files('test_devsubset.py')

# The working directory is the source root, so relative paths given to the
# test process are resolved from there.
test(
  'Parse devsubset',
  pytest,
  args: [test_devsubset],
  suite: 'torch',
  workdir: meson.source_root(),
  timeout: 120
)
68 changes: 68 additions & 0 deletions tests/benzina/torch/test_devsubset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import os

import torch

import benzina.torch as bz


class Subset(torch.utils.data.Dataset):
    """View of ``dataset`` restricted to the positions listed in ``indices``.

    Item ``i`` of the subset is item ``indices[i]`` of the wrapped dataset.
    The ``shape``, ``root``, ``_core`` and ``_root`` attributes are forwarded
    to the wrapped dataset.
    """
    # NOTE(review): the forwarded attributes are presumably what
    # benzina's DataLoader reads from a dataset -- confirm against benzina.

    def __init__(self, dataset, indices):
        """Store the wrapped ``dataset`` and the selected ``indices``."""
        self.indices = indices
        self.dataset = dataset

    def __len__(self):
        """Return the number of selected entries."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the wrapped dataset's item at position ``indices[idx]``."""
        source_idx = self.indices[idx]
        return self.dataset[source_idx]

    @property
    def _core(self):
        """Forward ``_core`` from the wrapped dataset."""
        return self.dataset._core

    @property
    def _root(self):
        """Forward ``_root`` from the wrapped dataset."""
        return self.dataset._root

    @property
    def root(self):
        """Public alias of ``_root``."""
        return self._root

    @property
    def shape(self):
        """Forward ``shape`` from the wrapped dataset."""
        return self.dataset.shape


def test_devsubset_pytorch_loading():
    """Load the dev subset through ``bz.DataLoader`` and check every sample.

    The dataset location is read from the ``DATASET_PATH`` environment
    variable (``KeyError`` if unset).  ``tests/data/devsubset_files`` lists
    the expected target and filename of each subset entry; those filenames
    are looked up in ``<DATASET_PATH>/data.filenames`` to get their indices
    in the full dataset.  Each loaded image must be a non-empty 3x256x256
    tensor whose target matches the listed one.
    """
    dataset_path = os.environ["DATASET_PATH"]
    # Single source of truth for the loader's batch size and the stride used
    # to map a batch position back to an index in the subset.
    batch_size = 100

    with open("tests/data/devsubset_files", 'r') as devsubset_list:
        # Each entry reads "<4-char target> <filename>".  Lines read from a
        # file always keep their newline and are therefore truthy, so blank
        # lines have to be filtered on their stripped content.
        subset_targets_map = {line.rstrip()[5:]: int(line[:4])
                              for line in devsubset_list if line.strip()}

    subset_indices = []
    subset_targets = []
    with open(os.path.join(dataset_path, "data.filenames"), 'r') as filenames:
        for i, filename in enumerate(filenames):
            # Drop the line ending that follows the filename.
            filename = filename.rstrip()
            if filename in subset_targets_map:
                subset_indices.append(i)
                subset_targets.append(subset_targets_map[filename])

    # Without any match the loops below would check nothing and the test
    # would pass vacuously.
    assert subset_indices, "no devsubset file found in data.filenames"

    dataset = bz.ImageNet(dataset_path)

    subset = Subset(dataset, subset_indices)

    subset_loader = bz.DataLoader(subset,
                                  batch_size=batch_size,
                                  seed=0,
                                  shape=(256, 256))

    batch_starts = range(0, len(subset), batch_size)
    for start_i, (images, targets) in zip(batch_starts, subset_loader):
        for i, (image, target) in enumerate(zip(images, targets)):
            assert image.size() == (3, 256, 256)
            assert image.sum().item() > 0
            assert target.item() == subset_targets[start_i + i]

0 comments on commit 92fe2b6

Please sign in to comment.