Skip to content

Commit

Permalink
Merge pull request #3 from david-klindt/patch-2
Browse files Browse the repository at this point in the history
Update dataset.py
  • Loading branch information
ysharma1126 committed Mar 11, 2021
2 parents 3fe5e43 + 2130b24 commit ca9ba64
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions scripts/dataset.py
Expand Up @@ -376,12 +376,20 @@ def __init__(self, path='./data/smallNORB/', download=True,

self.infos = infos[sorted_inds]
self.data = data[sorted_inds].numpy() # is uint8
def sample(self, num, random_state):

def sample_factors(self, num, random_state):
# override super to ignore instance (see https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/data/ground_truth/norb.py#L52)
factors, observations = super().sample(num, random_state)
factors = super().sample_factors(self, num, random_state)
if self.evaluate:
factors = np.concatenate([factors[:, :1], factors[:, 2:]], 1)
return factors

def sample_observations_from_factors(self, factors, random_state):
# override super to ignore instance (see https://github.com/google-research/disentanglement_lib/blob/86a644d4ed35c771560dc3360756363d35477357/disentanglement_lib/data/ground_truth/norb.py#L52)
if self.evaluate:
instances = random_state.randint(0, self.factor_sizes[1], factors[:, :1].shape)
factors = np.concatenate([factors[:, :1], instances, factors[:, 2:]], 1)
observations = super().sample_observations_from_factors(self, factors, random_state)
return factors, observations

def __len__(self):
Expand Down

0 comments on commit ca9ba64

Please sign in to comment.