Skip to content

Commit

Permalink
Merge pull request #2044 from deepchem/duplicate_balancing
Browse files Browse the repository at this point in the history
Duplicate Balancing Transformer
  • Loading branch information
Bharath Ramsundar committed Jul 25, 2020
2 parents f34c9ce + 1c9474d commit 59ddc9e
Show file tree
Hide file tree
Showing 20 changed files with 2,180 additions and 1,460 deletions.
16 changes: 9 additions & 7 deletions deepchem/data/datasets.py
Expand Up @@ -831,8 +831,9 @@ def transform(self, transformer: "dc.trans.Transformer",
-------
a newly constructed Dataset object
"""
- newx, newy, neww = transformer.transform_array(self._X, self._y, self._w)
- return NumpyDataset(newx, newy, neww, self._ids[:])
+ newx, newy, neww, newids = transformer.transform_array(
+ self._X, self._y, self._w, self._ids)
+ return NumpyDataset(newx, newy, neww, newids)

def select(self, indices: Sequence[int],
select_dir: str = None) -> "NumpyDataset":
Expand Down Expand Up @@ -1402,8 +1403,8 @@ def generator():
for shard_num, row in self.metadata_df.iterrows():
logger.info("Transforming shard %d/%d" % (shard_num, n_shards))
X, y, w, ids = self.get_shard(shard_num)
- newx, newy, neww = transformer.transform_array(X, y, w)
- yield (newx, newy, neww, ids)
+ newx, newy, neww, newids = transformer.transform_array(X, y, w, ids)
+ yield (newx, newy, neww, newids)

dataset = DiskDataset.create_dataset(
generator(), data_dir=out_dir, tasks=tasks)
Expand All @@ -1420,7 +1421,7 @@ def _transform_shard(transformer: "dc.trans.Transformer", shard_num: int,
y = None if y_file is None else np.array(load_from_disk(y_file))
w = None if w_file is None else np.array(load_from_disk(w_file))
ids = np.array(load_from_disk(ids_file))
- X, y, w = transformer.transform_array(X, y, w)
+ X, y, w, ids = transformer.transform_array(X, y, w, ids)
basename = "shard-%d" % shard_num
return DiskDataset.write_data_to_disk(out_dir, basename, tasks, X, y, w,
ids)
Expand Down Expand Up @@ -2150,8 +2151,9 @@ def transform(self, transformer: "dc.trans.Transformer",
-------
a newly constructed Dataset object
"""
- newx, newy, neww = transformer.transform_array(self.X, self.y, self.w)
- return NumpyDataset(newx, newy, neww, self.ids[:])
+ newx, newy, neww, newids = transformer.transform_array(
+ self.X, self.y, self.w, self.ids)
+ return NumpyDataset(newx, newy, neww, newids)

def select(self, indices: Sequence[int],
select_dir: str = None) -> "ImageDataset":
Expand Down

0 comments on commit 59ddc9e

Please sign in to comment.