Skip to content

Commit

Permalink
bugfixes in dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Dicarlo-Cox committed Apr 20, 2012
1 parent 9b7bb8a commit 1ac1e0f
Showing 1 changed file with 24 additions and 37 deletions.
61 changes: 24 additions & 37 deletions genthor/datasets.py
Expand Up @@ -80,21 +80,21 @@ def _get_meta(self):
print('Generating meta for %s' % model)
for _ind in range(n_ex_per_model):
l = stochastic.sample(template, rng)
l['modelname'] = model
l['obj'] = model
l['category'] = model_categories[model][0]
l['id'] = get_image_id(l)
rec = (l['bgname'],
float(l['bgphi']),
float(l['bgpsi']),
float(l['bgscale']),
l['category'],
l['modelname'],
l['obj'],
float(l['ryz']),
float(l['rxz']),
float(l['rxy']),
float(l['ty']),
float(l['tz']),
float(l['scale']),
float(l['s']),
tname,
l['id'])
latents.append(rec)
Expand All @@ -103,18 +103,18 @@ def _get_meta(self):
'bgpsi',
'bgscale',
'category',
'modelname',
'obj',
'ryz',
'rxz',
'rxy',
'ty',
'tz',
'scale',
's',
'tname',
'id'])
return meta

def get_images(self, dtype, preproc):
def get_images(self, preproc):
name = self.specific_name
basedir = self.home()
cache_file = os.path.join(basedir, name)
Expand Down Expand Up @@ -252,7 +252,7 @@ class GenerativeDataset1(GenerativeDatasetBase):
'bgscale': 1.,
'bgpsi': 0,
'bgphi': uniform(-180.0, 180.),
'scale': 1,
's': 1,
'ty': 0,
'tz': 0,
'ryz': 0,
Expand All @@ -261,40 +261,26 @@ class GenerativeDataset1(GenerativeDatasetBase):
}
},
{'n_ex_per_model': 50,
'name': 'translation',
'name': 'translation_scale',
'template': {'bgname': choice(good_backgrounds),
'bgscale': 1.,
'bgpsi': 0,
'bgphi': uniform(-180.0, 180.),
'scale': 1,
's': loguniform(np.log(2./3), np.log(2.)),
'ty': uniform(-1.0, 1.0),
'tz': uniform(-1.0, 1.0),
'ryz': 0,
'rxy': 0,
'rxz': 0,
}
},
{'n_ex_per_model': 50,
'name': 'scale',
'template': {'bgname': choice(good_backgrounds),
'bgscale': 1.,
'bgpsi': 0,
'bgphi': uniform(-180.0, 180.),
'scale': loguniform(np.log(2./3), np.log(2.)),
'ty': 0,
'tz': 0,
'ryz': 0,
'rxy': 0,
'rxz': 0,
}
},
{'n_ex_per_model': 50,
{'n_ex_per_model': 30,
'name': 'rotation',
'template': {'bgname': choice(good_backgrounds),
'bgscale': 1.,
'bgpsi': 0,
'bgphi': uniform(-180.0, 180.),
'scale': 1,
's': 1,
'ty': 0,
'tz': 0,
'ryz': uniform(-180., 180.),
Expand All @@ -303,19 +289,19 @@ class GenerativeDataset1(GenerativeDatasetBase):
}
},
{'n_ex_per_model': 100,
'name': 'var_all',
'name': 'var1',
'template': {'bgname': choice(good_backgrounds),
'bgscale': 1.,
'bgpsi': 0,
'bgphi': uniform(-180.0, 180.),
'scale': loguniform(np.log(2./3), np.log(2.)),
's': loguniform(np.log(2./3), np.log(2.)),
'ty': uniform(-1.0, 1.0),
'tz': uniform(-1.0, 1.0),
'ryz': uniform(-180., 180.),
'rxy': uniform(-180., 180.),
'rxz': uniform(-180., 180.),
}
}]
}]
specific_name = 'GenerativeDataset1'


Expand All @@ -334,7 +320,7 @@ class GenerativeDatasetTest(GenerativeDataset1):
'bgscale': 1.,
'bgpsi': 0,
'bgphi': uniform(-180.0, 180.),
'scale': loguniform(np.log(2./3), np.log(2.)),
's': loguniform(np.log(2./3), np.log(2.)),
'ty': uniform(-1.0, 1.0),
'tz': uniform(-1.0, 1.0),
'ryz': uniform(-180., 180.),
Expand All @@ -347,7 +333,7 @@ class GenerativeDatasetTest(GenerativeDataset1):

class ImgRendererResizer(object):
def __init__(self, model_root, bg_root, preproc, lbase, output):
self._shape = preproc['size']
self._shape = tuple(preproc['size'])
self._ndim = len(self._shape)
self._dtype = preproc['dtype']
self.mode = preproc['mode']
Expand All @@ -368,9 +354,9 @@ def rval_getattr(self, attr, objs):

def __call__(self, m):
modelpath = os.path.join(self.model_root,
m['modelname'], m['modelname'] + '.bam')
m['obj'], m['obj'] + '.bam')
bgpath = os.path.join(self.bg_root, m['bgname'])
scale = [m['scale']]
scale = [m['s']]
pos = [m['ty'], m['tz']]
hpr = [m['ryz'], m['rxz'], m['rxy']]
bgscale = [m['bgscale']]
Expand Down Expand Up @@ -434,7 +420,7 @@ def _get_meta(self):
'bgp',
'bgscale',
'category',
'model_id',
'obj',
'ryz',
'rxz',
'rxy',
Expand All @@ -448,11 +434,12 @@ def _get_meta(self):
def filenames(self):
return self.meta['filename']

def get_images(self, dtype, preproc):
def get_images(self, preproc):
self.fetch()
size = tuple(preproc['size'])
normalize = preproc['global_normalize']
mode = preproc['mode']
dtype = preproc['dtype']
return larray.lmap(ImgLoaderResizer(inshape=(256, 256),
shape=size,
dtype=dtype,
Expand All @@ -473,7 +460,7 @@ def __init__(self,
mode='RGB',
crop=None,
mask=None):
self.inshape = inshape
self.inshape = tuple(inshape)
assert len(shape) == 2
shape = tuple(shape)
if crop is None:
Expand Down Expand Up @@ -543,7 +530,7 @@ def test_training_dataset():
dataset = TrainingDataset()
meta = dataset.meta
assert len(meta) == 11000
agg = meta[['model_id', 'category']].aggregate(['category'],
agg = meta[['obj', 'category']].aggregate(['category'],
AggFunc=lambda x: len(x))
assert agg.tolist() == [('boats', 1000),
('buildings', 1000),
Expand All @@ -557,7 +544,7 @@ def test_training_dataset():
('reptiles', 1000),
('table', 1000)]

agg2 = meta[['model_id', 'category']].aggregate(['category'],
agg2 = meta[['obj', 'category']].aggregate(['category'],
AggFunc=lambda x : len(np.unique(x)))
assert agg2.tolist() == [('boats', 10),
('buildings', 10),
Expand Down

0 comments on commit 1ac1e0f

Please sign in to comment.