diff --git a/deeplearning2020/datasets.py b/deeplearning2020/datasets.py index 3acb0ec..a3f7958 100644 --- a/deeplearning2020/datasets.py +++ b/deeplearning2020/datasets.py @@ -62,7 +62,7 @@ def __init__(self, dataset: str, file_path: str = None) -> None: print(f"Loaded {self.image_count} images") self.raw_class_names = [ - item for item in tf.io.gfile.listdir(data_dir) if item != "LICENSE.txt" + item.strip("/") for item in tf.io.gfile.listdir(data_dir) if item != "LICENSE.txt" ] self.raw_class_names.sort()