In [186]:

import csv
from typing import List,Dict,Tuple
import copy
import json
from collections import OrderedDict
import enum
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Normalize, CenterCrop


In [15]:
class INaturalistSplit(enum.Enum):
  #"""The different split for the iNaturalist dataset."""
    USER_120K = enum.auto()
    GEO_100 = enum.auto()
    GEO_300 = enum.auto()
    GEO_1K = enum.auto()
    GEO_3K = enum.auto()
    GEO_10K = enum.auto()
    GEO_30K = enum.auto()

    def __repr__(self):
        return '<%s.%s>' % (self.__class__.__name__, self.name)

In [18]:
INaturalistSplit.USER_120K

<INaturalistSplit.USER_120K>

In [20]:
repr(INaturalistSplit)

"<enum 'INaturalistSplit'>"

In [21]:
b=3
b.name

AttributeError: 'int' object has no attribute 'name'

In [None]:
class ClientData(object, metaclass=abc.ABCMeta):
    """Object to hold a federated dataset.
    The federated dataset is represented as a list of client ids, and
    a function to look up the local dataset for each client id.
    Note: Cross-device federated learning does not use client IDs or perform any
    tracking of clients. However in simulation experiments using centralized test
    data the experimenter may select specific clients to be processed per round.
    The concept of a client ID is only available at the preprocessing stage when
    preparing input data for the simulation and is not part of the TensorFlow
    Federated core APIs.
    Each client's local dataset is represented as a `tf.data.Dataset`, but
    generally this class (and the corresponding datasets hosted by TFF) can
    easily be consumed by any Python-based ML framework as `numpy` arrays:
    ```python
    import tensorflow as tf
    import tensorflow_federated as tff
    import tensorflow_datasets as tfds
    for client_id in sampled_client_ids[:5]:
    client_local_dataset = tfds.as_numpy(
        emnist_train.create_tf_dataset_for_client(client_id))
    # client_local_dataset is an iterable of structures of numpy arrays
    for example in client_local_dataset:
      print(example)
    ```
    If desiring a manner for constructing ClientData objects for testing purposes,
    please see the `tff.simulation.datasets.TestClientData` class, as it provides
    an easy way to construct toy federated datasets.
    """

    @abc.abstractproperty
    def client_ids(self) -> List[str]:
        """A list of string identifiers for clients in this dataset."""
        pass

    @abc.abstractproperty
    def serializable_dataset_fn(self):
        """A callable accepting a client ID and returning a `tf.data.Dataset`.
        Note that this callable must be traceable by TF, as it will be used in the
        context of a `tf.function`.
        """
        pass

    def create_tf_dataset_for_client(self, client_id: str) -> tf.data.Dataset:
        """Creates a new `tf.data.Dataset` containing the client training examples.
        This function will create a dataset for a given client, given that
        `client_id` is contained in the `client_ids` property of the `ClientData`.
        Unlike `create_dataset`, this method need not be serializable.
        Args:
          client_id: The string client_id for the desired client.
        Returns:
          A `tf.data.Dataset` object.
        """
        if client_id not in self.client_ids:
            raise ValueError(
              'ID [{i}] is not a client in this ClientData. See '
              'property `client_ids` for the list of valid ids.'.format(
                  i=client_id))
        return self.serializable_dataset_fn(client_id)

    @property
    def dataset_computation(self):
        """A `tff.Computation` accepting a client ID, returning a dataset.
        Note: the `dataset_computation` property is intended as a TFF-specific
        performance optimization for distributed execution.
        """
        if (not hasattr(self, '_cached_dataset_computation')) or (
            self._cached_dataset_computation is None):

            @computations.tf_computation(tf.string)
            def dataset_computation(client_id):
                return self.serializable_dataset_fn(client_id)

            self._cached_dataset_computation = dataset_computation
        return self._cached_dataset_computation

    @abc.abstractproperty
    def element_type_structure(self):
    """The element type information of the client datasets.
    Returns:
      A nested structure of `tf.TensorSpec` objects defining the type of the
    elements returned by datasets in this `ClientData` object.
    """
        pass

    def datasets(
      self,
      limit_count: Optional[int] = None,
      seed: Optional[Union[int, Sequence[int]]] = None
    ) -> Iterable[tf.data.Dataset]:
    """Yields the `tf.data.Dataset` for each client in random order.
    This function is intended for use building a static array of client data
    to be provided to the top-level federated computation.
    Args:
      limit_count: Optional, a maximum number of datasets to return.
      seed: Optional, a seed to determine the order in which clients are
        processed in the joined dataset. The seed can be any nonnegative 32-bit
        integer, an array of such integers, or `None`.
    """
    check_numpy_random_seed(seed)
    # Create a copy to prevent the original list being reordered
    client_ids = self.client_ids.copy()
    np.random.RandomState(seed=seed).shuffle(client_ids)
    count = 0
    for client_id in client_ids:
        if limit_count is not None and count >= limit_count:
        return
        count += 1
        dataset = self.create_tf_dataset_for_client(client_id)
        py_typecheck.check_type(dataset, tf.data.Dataset)
        yield dataset

    def create_tf_dataset_from_all_clients(
      self,
      seed: Optional[Union[int, Sequence[int]]] = None) -> tf.data.Dataset:
    """Creates a new `tf.data.Dataset` containing _all_ client examples.
    This function is intended for use training centralized, non-distributed
    models (num_clients=1). This can be useful as a point of comparison
    against federated models.
    Currently, the implementation produces a dataset that contains
    all examples from a single client in order, and so generally additional
    shuffling should be performed.
    Args:
      seed: Optional, a seed to determine the order in which clients are
        processed in the joined dataset. The seed can be any nonnegative 32-bit
        integer, an array of such integers, or `None`.
    Returns:
      A `tf.data.Dataset` object.
    """
    check_numpy_random_seed(seed)
    client_ids = self.client_ids.copy()
    np.random.RandomState(seed=seed).shuffle(client_ids)
    nested_dataset = tf.data.Dataset.from_tensor_slices(client_ids)
    # We apply serializable_dataset_fn here to avoid loading all client datasets
    # in memory, which is slow. Note that tf.data.Dataset.map implicitly wraps
    # the input mapping in a tf.function.
    example_dataset = nested_dataset.flat_map(self.serializable_dataset_fn)
    return example_dataset

  def preprocess(
      self, preprocess_fn: Callable[[tf.data.Dataset],
                                    tf.data.Dataset]) -> 'ClientData':
    """Applies `preprocess_fn` to each client's data.
    Args:
      preprocess_fn: A callable accepting a `tf.data.Dataset` and returning a
        preprocessed `tf.data.Dataset`. This function must be traceable by TF.
    Returns:
      A `tff.simulation.datasets.ClientData`.
    Raises:
      IncompatiblePreprocessFnError: If `preprocess_fn` is a `tff.Computation`.
    """
    py_typecheck.check_callable(preprocess_fn)
    if isinstance(preprocess_fn, computation_base.Computation):
      raise IncompatiblePreprocessFnError()
    return PreprocessClientData(self, preprocess_fn)

  @classmethod
  def from_clients_and_tf_fn(
      cls,
      client_ids: Iterable[str],
      serializable_dataset_fn: Callable[[str], tf.data.Dataset],
  ) -> 'ClientData':
    """Constructs a `ClientData` based on the given function.
    Args:
      client_ids: A non-empty list of strings to use as input to
        `create_dataset_fn`.
      serializable_dataset_fn: A function that takes a client_id from the above
        list, and returns a `tf.data.Dataset`. This function must be
        serializable and usable within the context of a `tf.function` and
        `tff.Computation`.
    Returns:
      A `ClientData` object.
    """
    return ConcreteClientData(client_ids, serializable_dataset_fn)

  @classmethod
  def train_test_client_split(
      cls,
      client_data: 'ClientData',
      num_test_clients: int,
      seed: Optional[Union[int, Sequence[int]]] = None
  ) -> Tuple['ClientData', 'ClientData']:
    """Returns a pair of (train, test) `ClientData`.
    This method partitions the clients of `client_data` into two `ClientData`
    objects with disjoint sets of `ClientData.client_ids`. All clients in the
    test `ClientData` are guaranteed to have non-empty datasets, but the
    training `ClientData` may have clients with no data.
    Note: This method may be expensive, and so it may be useful to avoid calling
    multiple times and holding on to the results.
    Args:
      client_data: The base `ClientData` to split.
      num_test_clients: How many clients to hold out for testing. This can be at
        most len(client_data.client_ids) - 1, since we don't want to produce
        empty `ClientData`.
      seed: Optional seed to fix shuffling of clients before splitting. The seed
        can be any nonnegative 32-bit integer, an array of such integers, or
        `None`.
    Returns:
      A pair (train_client_data, test_client_data), where test_client_data
      has `num_test_clients` selected at random, subject to the constraint they
      each have at least 1 batch in their dataset.
    Raises:
      ValueError: If `num_test_clients` cannot be satistifed by `client_data`,
        or too many clients have empty datasets.
    """
    if num_test_clients <= 0:
      raise ValueError('Please specify num_test_clients > 0.')

    if len(client_data.client_ids) <= num_test_clients:
      raise ValueError('The client_data supplied has only {} clients, but '
                       '{} test clients were requested.'.format(
                           len(client_data.client_ids), num_test_clients))

    check_numpy_random_seed(seed)
    train_client_ids = list(client_data.client_ids)
    np.random.RandomState(seed).shuffle(train_client_ids)
    # These clients will be added back into the training set at the end.
    clients_with_insufficient_batches = []
    test_client_ids = []
    while len(test_client_ids) < num_test_clients:
      if not train_client_ids or (
          # Arbitrarily threshold where "many" (relative to num_test_clients)
          # clients have no data. Note: If needed, we could make this limit
          # configurable.
          len(clients_with_insufficient_batches) > 5 * num_test_clients + 10):

        raise ValueError('Encountered too many clients with no data.')

      client_id = train_client_ids.pop()
      dataset = client_data.create_tf_dataset_for_client(client_id)
      try:
        _ = next(dataset.__iter__())
      except StopIteration:
        logging.warning('Client %s had no data, skipping.', client_id)
        clients_with_insufficient_batches.append(client_id)
        continue

      test_client_ids.append(client_id)

    # Invariant for successful exit of the above loop:
    assert len(test_client_ids) == num_test_clients

    def from_ids(client_ids: Iterable[str]) -> 'ClientData':
      return cls.from_clients_and_tf_fn(client_ids,
                                        client_data.serializable_dataset_fn)

    return (from_ids(train_client_ids + clients_with_insufficient_batches),
            from_ids(test_client_ids))

In [44]:
def load_data(
    image_dir: str = 'images',
    cache_dir: str = 'cache',
    split: INaturalistSplit = INaturalistSplit.USER_120K):
    #-> Tuple[ClientData, tf.data.Dataset]:#
    """Loads a federated version of the iNaturalist 2017 dataset.
    If the dataset is loaded for the first time, the images for the entire
    iNaturalist 2017 dataset will be downloaded from AWS Open Data Program.
    The dataset is created from the images stored inside the image_dir. Once the
    dataset is created, it will be cached inside the cache directory.
    The `tf.data.Datasets` returned by
    `tff.simulation.datasets.ClientData.create_tf_dataset_for_client` will yield
    `collections.OrderedDict` objects at each iteration, with the following keys
    and values:
    -   `'image/decoded'`: A `tf.Tensor` with `dtype=tf.uint8` that
        corresponds to the pixels of the images.
    -   `'class'`: A `tf.Tensor` with `dtype=tf.int64` and shape [1],
        corresponding to the class label.
    Seven splits of iNaturalist datasets are available. The details of each
    different dataset split can be found in https://arxiv.org/abs/2003.08082.
    For the USER_120K dataset, the images are split by the user id.
    The number of clients for USER_120K is 9275. The training set contains 120.300
    images of 1203 species, and test set contains 35641 images.
    For the GEO_* datasets, the images are splitted by the geo location.
    The number of clients for the GEO_* datasets:
    1. GEO_100: 3607.
    2. GEO_300: 1209.
    3. GEO_1K: 369.
    4: GEO_3K: 136.
    5. GEO_10K: 39.
    6. GEO_30K: 12.
    Args:
    image_dir: (Optional) The directory containing the images downloaded from
              https://github.com/visipedia/inat_comp/tree/master/2017
    cache_dir: (Optional) The directory to cache the created datasets.
    split: (Optional) The split of the dataset, default to be split by users.
    Returns:
    Tuple of (train, test) where the tuple elements are
    a `tff.simulation.datasets.ClientData` and a  `tf.data.Dataset`.
    """
    logging.basicConfig(filename='load_data.log', level=logging.INFO)
    logger = logging.getLogger(LOGGER)
    logger.info('Start to load data.')
    if not os.path.exists(cache_dir):
        logger.info('Creating cache directory.')
        os.mkdir(cache_dir)
    try:
        return _load_data_from_cache(cache_dir, split)
    except Exception:  # pylint: disable=broad-except:
        if not image_dir:
            raise ValueError('image_dir cannot be empty or none.')
        if not os.path.isdir(image_dir):
            logger.error('Image directory %s does not exist', image_dir)
            raise ValueError('%s does not exist or is not a directory' % image_dir)
    logger.info('Start to download the images for the training set.')
    tf.keras.utils.get_file(
        'train_val_images.tar.gz',
        origin=INAT_TRAIN_IMAGE_URL,
        file_hash=INAT_TRAIN_IMAGE_MD5_CHECKSUM,
        hash_algorithm='md5',
        extract=True,
        cache_dir=image_dir)
    logger.info('Finish to download the images for the training set.')
    logger.info('Start to download the images for the testing set.')
    tf.keras.utils.get_file(
        'test2017.tar.gz',
        origin=INAT_TEST_IMAGE_URL,
        file_hash=INAT_TEST_IMAGE_MD5_CHECKSUM,
        hash_algorithm='md5',
        extract=True,
        cache_dir=image_dir)
    logger.info('Finish to download the images for the testing set.')
    return _generate_data_from_image_dir(image_dir, cache_dir, split)

In [49]:
def _create_train_data_files(image_path_map: Dict[str, str], cache_dir: str,
                             split: INaturalistSplit, train_path: str):
    
    '''
    Create the train data and persist it into a separate file per user.
    Args:
    image_path_map: The dictionary containing the image id to image path
      mapping.
    cache_dir: The directory containing the created datasets.
    split: The split of the federated iNaturalist 2017 dataset.
    train_path: The path to the mapping file for training data.
    '''
    logger = logging.getLogger(LOGGER)

    mapping_table = utils.read_csv(train_path)
    user_id_col = split.name.lower()
    expected_cols = [user_id_col, 'image_id', 'class']
    if not all(col in mapping_table[0].keys() for col in expected_cols):
        logger.error('%s has wrong format.', train_path)
        raise ValueError(
            'The mapping file must contain the user_id for the chosen split, image_id and class columns. '
            'The existing columns are %s' % ','.join(mapping_table[0].keys()))
    cache_dir = os.path.join(cache_dir, split.name.lower(), TRAIN_SUB_DIR)
    if not os.path.exists(cache_dir):
        logger.info('Creating cache directory for training data.')
        os.makedirs(cache_dir)
    mapping_per_user = collections.defaultdict(list)
    for row in mapping_table:
        user_id = row[user_id_col]
        if user_id != 'NA':
            mapping_per_user[user_id].append(row)
    for user_id, data in mapping_per_user.items():
        examples = _create_dataset_with_mapping(image_path_map, data)
    with tf.io.TFRecordWriter(os.path.join(cache_dir, str(user_id))) as writer:
        for example in examples:
            writer.write(example.SerializeToString())
        logger.info('Created tfrecord file for user %s with %d examples, at %s',
                  user_id, len(examples), cache_dir)

In [50]:
def _create_train_data_files(image_path_map = "/MD1400/jinkyu/train_val_images", cache_dir: str,
                             split, train_path: str)

SyntaxError: invalid syntax (2114374364.py, line 2)

In [27]:
import tensorflow as tff

ModuleNotFoundError: No module named 'tensorflow'

In [106]:
import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from PIL import Image

from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
from torchvision.datasets.vision import VisionDataset

CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]

DATASET_URLS = {
    "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
    "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
    "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
    "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
    "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
    "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
}

DATASET_MD5 = {
    "2017": "7c784ea5e424efaec655bd392f87301f",
    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
    "2021_train": "38a7bb733f7a09214d44293460ec0021",
    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
}


class INaturalist(VisionDataset):
    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.

    Args:
        root (string): Root directory of dataset where the image files are stored.
            This class does not require/use annotation files.
        version (string, optional): Which version of the dataset to download/use. One of
            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
            Default: `2021_train`.
        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:

            - ``full``: the full category (species)
            - ``kingdom``: e.g. "Animalia"
            - ``phylum``: e.g. "Arthropoda"
            - ``class``: e.g. "Insecta"
            - ``order``: e.g. "Coleoptera"
            - ``family``: e.g. "Cleridae"
            - ``genus``: e.g. "Trichodes"

            for 2017-2019 versions, one of:

            - ``full``: the full (numeric) category
            - ``super``: the super category, e.g. "Amphibians"

            Can also be a list to output a tuple with all specified target types.
            Defaults to ``full``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        version: str = "2021_train",
        target_type: Union[List[str], str] = "full",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
    ) -> None:
        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)

        os.makedirs(root, exist_ok=True)
        if download:
            self.download()
        print("self.root",self.root)
        print(self._check_integrity())
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")

        self.all_categories: List[str] = []

        # map: category type -> name of category -> index
        self.categories_index: Dict[str, Dict[str, int]] = {}

        # list indexed by category id, containing mapping from category type -> index
        self.categories_map: List[Dict[str, int]] = []

        if not isinstance(target_type, list):
            target_type = [target_type]
        if self.version[:4] == "2021":
            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
            self._init_2021()
        else:
            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
            self._init_pre2021()

        # index of all files: (full category id, filename)
        self.index: List[Tuple[int, str]] = []

        for dir_index, dir_name in enumerate(self.all_categories):
            files = os.listdir(os.path.join(self.root, dir_name))
            for fname in files:
                self.index.append((dir_index, fname))

    def _init_2021(self) -> None:
        """Initialize based on 2021 layout"""

        self.all_categories = sorted(os.listdir(self.root))

        # map: category type -> name of category -> index
        self.categories_index = {k: {} for k in CATEGORIES_2021}

        for dir_index, dir_name in enumerate(self.all_categories):
            pieces = dir_name.split("_")
            if len(pieces) != 8:
                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
            if pieces[0] != f"{dir_index:05d}":
                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
            cat_map = {}
            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
                if name in self.categories_index[cat]:
                    cat_id = self.categories_index[cat][name]
                else:
                    cat_id = len(self.categories_index[cat])
                    self.categories_index[cat][name] = cat_id
                cat_map[cat] = cat_id
            self.categories_map.append(cat_map)

    def _init_pre2021(self) -> None:
        """Initialize based on 2017-2019 layout"""

        # map: category type -> name of category -> index
        self.categories_index = {"super": {}}

        cat_index = 0
        super_categories = sorted(os.listdir(self.root))
        for sindex, scat in enumerate(super_categories):
            self.categories_index["super"][scat] = sindex
            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
            for subcat in subcategories:
                if self.version == "2017":
                    # this version does not use ids as directory names
                    subcat_i = cat_index
                    cat_index += 1
                else:
                    try:
                        subcat_i = int(subcat)
                    except ValueError:
                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
                if subcat_i >= len(self.categories_map):
                    old_len = len(self.categories_map)
                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
                if self.categories_map[subcat_i]:
                    raise RuntimeError(f"Duplicate category {subcat}")
                self.categories_map[subcat_i] = {"super": sindex}
                self.all_categories[subcat_i] = os.path.join(scat, subcat)

        # validate the dictionary
        for cindex, c in enumerate(self.categories_map):
            if not c:
                raise RuntimeError(f"Missing category {cindex}")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """

        cat_id, fname = self.index[index]
        img = Image.open(os.path.join(self.root, self.all_categories[cat_id], fname))
        print("fname",fname)
        target: Any = []
        for t in self.target_type:
            if t == "full":
                target.append(cat_id)
            else:
                target.append(self.categories_map[cat_id][t])
        target = tuple(target) if len(target) > 1 else target[0]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


    def __len__(self) -> int:
        return len(self.index)

    def category_name(self, category_type: str, category_id: int) -> str:
        """
        Args:
            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
            category_id(int): an index (class id) from this category

        Returns:
            the name of the category
        """
        if category_type == "full":
            return self.all_categories[category_id]
        else:
            if category_type not in self.categories_index:
                raise ValueError(f"Invalid category type '{category_type}'")
            else:
                for name, id in self.categories_index[category_type].items():
                    if id == category_id:
                        return name
                raise ValueError(f"Invalid category id {category_id} for {category_type}")


    def _check_integrity(self) -> bool:
        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0

    def download(self) -> None:
        if self._check_integrity():
            raise RuntimeError(
                f"The directory {self.root} already exists. "
                f"If you want to re-download or re-extract the images, delete the directory."
            )

        base_root = os.path.dirname(self.root)

        download_and_extract_archive(
            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
        )

        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
        if not os.path.exists(orig_dir_name):
            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
        os.rename(orig_dir_name, self.root)
        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")


In [74]:
import torchvision

In [107]:
trainset = INaturalist(root = '/MD1400/jinkyu/train_val_images',version = '2017', download = False)

self.root /MD1400/jinkyu/train_val_images/2017
True


In [108]:
os.path.exists('/MD1400/jinkyu/train_val_images') and len(os.listdir('/MD1400/jinkyu/train_val_images'))>0

True

In [109]:
os.path.exists('/MD1400/jinkyu/train_val_images')

True

In [80]:
print(trainset.index)

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



In [81]:
len(trainset.index)

675170

In [84]:
print(trainset.index[0])

(0, '38a37064e7ba2cba7e24c7abfcee9702.jpg')


In [88]:
trainset.all_categories[0]

'Actinopterygii/Abudefduf saxatilis'

In [83]:
print(trainset[0])

(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=800x518 at 0x7F073C425C10>, 0)


In [92]:
print(len(trainset.all_categories))

5089


In [94]:
print(trainset.all_categories[1054])

Aves/Rhipidura fuliginosa


In [101]:
print(trainset[60132])

fname aeed4898f1d933ec73f36558c3bf49c0.jpg
(<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=800x528 at 0x7F0745744550>, 393)


In [102]:
def _generate_image_map(image_dir: str) -> Dict[str, str]:
    """Create an dictionary with key as image id, value as path to the image file.
    Args:
    image_dir: The directory containing all the images.
    Returns:
    The dictionary containing the image id to image file path mapping.
    """
    image_map = {}
    for root, _, files in os.walk(image_dir):
        for f in files:
            if f.endswith('.jpg'):
                image_id = f.rstrip('.jpg')
                image_path = os.path.join(root, f)
                image_map[image_id] = image_path
    return image_map

In [103]:
image_map = _generate_image_map('/MD1400/jinkyu/train_val_images')

In [104]:
image_map['b6f3b1c50816b4eab7e4dfb693a669d2']

'/MD1400/jinkyu/train_val_images/2017/Plantae/Nymphaea odorata/b6f3b1c50816b4eab7e4dfb693a669d2.jpg'

In [112]:
file_path = "./iNaturalist_client.json"
with open(file_path,'r') as f:
    client_data = json.load(f)

In [115]:
list(client_data.items())[1]

('Adam_Heathcote',
 {'client_idx': 1,
  'client_data': [{'image_id': 'c2cfd0484188da465b09e0ff962de74d',
    'class': '1001',
    'label': '9346'},
   {'image_id': '7db1ea38b4d361fcaf37a477acc80cfb',
    'class': '1110',
    'label': '57458'},
   {'image_id': '1adc627e90a6ed9349074c0bc528d6b7',
    'class': '1041',
    'label': '6433'},
   {'image_id': 'fdef4f285a898ccc73618d6c76f872a5',
    'class': '982',
    'label': '118078'},
   {'image_id': 'da4771d82e809c84280ac77e68d8c824',
    'class': '1110',
    'label': '57458'},
   {'image_id': '2c984f32cfd3572539e013a01482be43',
    'class': '1087',
    'label': '84549'},
   {'image_id': '05ee245c5c20fb44ebca20552d1579db',
    'class': '832',
    'label': '12727'},
   {'image_id': '5e9fd154e020e2cc3178961d753a2ebe',
    'class': '1024',
    'label': '52821'},
   {'image_id': '2cbef67f8963c2871fae18c411230752',
    'class': '207',
    'label': '9424'},
   {'image_id': 'de01b7fd20dfe9ae5e4e9129aca96867',
    'class': '957',
    'label': '47

In [116]:
filepath = "./iNaturalist_client_idx.txt"
dataset = {}
with open(filepath) as f:
    for idx, line in enumerate(f):
        dataset = eval(line)

In [117]:
dataset[1]

[4707,
 15026,
 16102,
 20355,
 26947,
 31606,
 31645,
 31719,
 35236,
 50474,
 52116,
 53646,
 53788,
 57782,
 58900,
 61756,
 73053,
 75472,
 81102,
 82197,
 82974,
 85952,
 87984,
 97589,
 98691,
 99735,
 101126,
 103636,
 110651,
 116355]

In [203]:
import os
import os.path
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

from PIL import Image

from torchvision.datasets.utils import download_and_extract_archive, verify_str_arg
from torchvision.datasets.vision import VisionDataset

CATEGORIES_2021 = ["kingdom", "phylum", "class", "order", "family", "genus"]

DATASET_URLS = {
    "2017": "https://ml-inat-competition-datasets.s3.amazonaws.com/2017/train_val_images.tar.gz",
    "2018": "https://ml-inat-competition-datasets.s3.amazonaws.com/2018/train_val2018.tar.gz",
    "2019": "https://ml-inat-competition-datasets.s3.amazonaws.com/2019/train_val2019.tar.gz",
    "2021_train": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train.tar.gz",
    "2021_train_mini": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/train_mini.tar.gz",
    "2021_valid": "https://ml-inat-competition-datasets.s3.amazonaws.com/2021/val.tar.gz",
}

DATASET_MD5 = {
    "2017": "7c784ea5e424efaec655bd392f87301f",
    "2018": "b1c6952ce38f31868cc50ea72d066cc3",
    "2019": "c60a6e2962c9b8ccbd458d12c8582644",
    "2021_train": "38a7bb733f7a09214d44293460ec0021",
    "2021_train_mini": "db6ed8330e634445efc8fec83ae81442",
    "2021_valid": "f6f6e0e242e3d4c9569ba56400938afc",
}


class INaturalist(VisionDataset):
    """`iNaturalist <https://github.com/visipedia/inat_comp>`_ Dataset.

    Args:
        root (string): Root directory of dataset where the image files are stored.
            This class does not require/use annotation files.
        version (string, optional): Which version of the dataset to download/use. One of
            '2017', '2018', '2019', '2021_train', '2021_train_mini', '2021_valid'.
            Default: `2021_train`.
        target_type (string or list, optional): Type of target to use, for 2021 versions, one of:

            - ``full``: the full category (species)
            - ``kingdom``: e.g. "Animalia"
            - ``phylum``: e.g. "Arthropoda"
            - ``class``: e.g. "Insecta"
            - ``order``: e.g. "Coleoptera"
            - ``family``: e.g. "Cleridae"
            - ``genus``: e.g. "Trichodes"

            for 2017-2019 versions, one of:

            - ``full``: the full (numeric) category
            - ``super``: the super category, e.g. "Amphibians"

            Can also be a list to output a tuple with all specified target types.
            Defaults to ``full``.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If dataset is already downloaded, it is not
            downloaded again.
    """

    def __init__(
        self,
        root: str,
        csv_path: str,
        version: str = "2021_train",
        target_type: Union[List[str], str] = "full",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        
    ) -> None:
        self.version = verify_str_arg(version, "version", DATASET_URLS.keys())

        super().__init__(os.path.join(root, version), transform=transform, target_transform=target_transform)

        os.makedirs(root, exist_ok=True)
        if download:
            self.download()
        #print(self.root)
        #print(self._check_integrity())
        if not self._check_integrity():
            raise RuntimeError("Dataset not found or corrupted. You can use download=True to download it")
            
          

        '''
        self.all_categories: List[str] = []

        # map: category type -> name of category -> index
        self.categories_index: Dict[str, Dict[str, int]] = {}

        # list indexed by category id, containing mapping from category type -> index
        self.categories_map: List[Dict[str, int]] = []

        if not isinstance(target_type, list):
            target_type = [target_type]
        if self.version[:4] == "2021":
            self.target_type = [verify_str_arg(t, "target_type", ("full", *CATEGORIES_2021)) for t in target_type]
            self._init_2021()
        else:
            self.target_type = [verify_str_arg(t, "target_type", ("full", "super")) for t in target_type]
            self._init_pre2021()

        # index of all files: (full category id, filename)
        self.index: List[Tuple[int, str]] = []

        for dir_index, dir_name in enumerate(self.all_categories):
            files = os.listdir(os.path.join(self.root, dir_name))
            for fname in files:
                self.index.append((dir_index, fname))
        '''
        self.image_map = self._generate_image_map(self.root)
        self.split_csv = self.read_csv(csv_path)
        #self.client_data = _read_client_data(client_data_path)
        
    def _generate_image_map(self,image_dir: str) -> Dict[str, str]:
        """Create an dictionary with key as image id, value as path to the image file.
        Args:
        image_dir: The directory containing all the images.
        Returns:
        The dictionary containing the image id to image file path mapping.
        """
        image_map = {}
        for root, _, files in os.walk(image_dir):
            for f in files:
                if f.endswith('.jpg'):
                    image_id = f.rstrip('.jpg')
                    image_path = os.path.join(root, f)
                    image_map[image_id] = image_path
        return image_map        
    '''  
    def _read_client_data(file_path: str) -> json:
        with open(file_path,'r') as f:
            client_data = json.load(f)    
        return client_data
    '''
    def read_csv(self,path: str) -> List[Dict[str, str]]:
        """Reads a csv file, and returns the content inside a list of dictionaries.
        Args:
        path: The path to the csv file.
        Returns:
        A list of dictionaries. Each row in the csv file will be a list entry. The
        dictionary is keyed by the column names.
        """
        with open(path, 'r') as f:
            return list(csv.DictReader(f))
    def _init_2021(self) -> None:
        """Initialize based on 2021 layout"""

        self.all_categories = sorted(os.listdir(self.root))

        # map: category type -> name of category -> index
        self.categories_index = {k: {} for k in CATEGORIES_2021}

        for dir_index, dir_name in enumerate(self.all_categories):
            pieces = dir_name.split("_")
            if len(pieces) != 8:
                raise RuntimeError(f"Unexpected category name {dir_name}, wrong number of pieces")
            if pieces[0] != f"{dir_index:05d}":
                raise RuntimeError(f"Unexpected category id {pieces[0]}, expecting {dir_index:05d}")
            cat_map = {}
            for cat, name in zip(CATEGORIES_2021, pieces[1:7]):
                if name in self.categories_index[cat]:
                    cat_id = self.categories_index[cat][name]
                else:
                    cat_id = len(self.categories_index[cat])
                    self.categories_index[cat][name] = cat_id
                cat_map[cat] = cat_id
            self.categories_map.append(cat_map)

          
            
    def _init_pre2021(self) -> None:
        """Initialize based on 2017-2019 layout"""

        # map: category type -> name of category -> index
        self.categories_index = {"super": {}}

        cat_index = 0
        super_categories = sorted(os.listdir(self.root))
        for sindex, scat in enumerate(super_categories):
            self.categories_index["super"][scat] = sindex
            subcategories = sorted(os.listdir(os.path.join(self.root, scat)))
            for subcat in subcategories:
                if self.version == "2017":
                    # this version does not use ids as directory names
                    subcat_i = cat_index
                    cat_index += 1
                else:
                    try:
                        subcat_i = int(subcat)
                    except ValueError:
                        raise RuntimeError(f"Unexpected non-numeric dir name: {subcat}")
                if subcat_i >= len(self.categories_map):
                    old_len = len(self.categories_map)
                    self.categories_map.extend([{}] * (subcat_i - old_len + 1))
                    self.all_categories.extend([""] * (subcat_i - old_len + 1))
                if self.categories_map[subcat_i]:
                    raise RuntimeError(f"Duplicate category {subcat}")
                self.categories_map[subcat_i] = {"super": sindex}
                self.all_categories[subcat_i] = os.path.join(scat, subcat)

        # validate the dictionary
        for cindex, c in enumerate(self.categories_map):
            if not c:
                raise RuntimeError(f"Missing category {cindex}")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index

        Returns:
            tuple: (image, target) where the type of target specified by target_type.
        """

        this_data = self.split_csv[index]
        img = Image.open(self.image_map[this_data['image_id']])
        img = img.convert('RGB')
        #img = pil_loader(self.image_map[this_data['image_id']])
        target = int(this_data['class'])
        #print("fname",fname)
        '''
        target: Any = []

        
        for t in self.target_type:
            if t == "full":
                target.append(cat_id)
            else:
                target.append(self.categories_map[cat_id][t])
        
        target = tuple(target) if len(target) > 1 else target[0]
        '''
        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target


    def __len__(self) -> int:
        return len(self.split_csv)

    def category_name(self, category_type: str, category_id: int) -> str:
        """
        Args:
            category_type(str): one of "full", "kingdom", "phylum", "class", "order", "family", "genus" or "super"
            category_id(int): an index (class id) from this category

        Returns:
            the name of the category
        """
        if category_type == "full":
            return self.all_categories[category_id]
        else:
            if category_type not in self.categories_index:
                raise ValueError(f"Invalid category type '{category_type}'")
            else:
                for name, id in self.categories_index[category_type].items():
                    if id == category_id:
                        return name
                raise ValueError(f"Invalid category id {category_id} for {category_type}")


    def _check_integrity(self) -> bool:
        return os.path.exists(self.root) and len(os.listdir(self.root)) > 0

    def download(self) -> None:
        if self._check_integrity():
            raise RuntimeError(
                f"The directory {self.root} already exists. "
                f"If you want to re-download or re-extract the images, delete the directory."
            )

        base_root = os.path.dirname(self.root)

        download_and_extract_archive(
            DATASET_URLS[self.version], base_root, filename=f"{self.version}.tgz", md5=DATASET_MD5[self.version]
        )

        orig_dir_name = os.path.join(base_root, os.path.basename(DATASET_URLS[self.version]).rstrip(".tar.gz"))
        if not os.path.exists(orig_dir_name):
            raise RuntimeError(f"Unable to find downloaded files at {orig_dir_name}")
        os.rename(orig_dir_name, self.root)
        print(f"Dataset version '{self.version}' has been downloaded and prepared for use")



In [204]:
transforms = Compose([CenterCrop((224, 224)),
                      ToTensor(),
                      Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
                      ])

In [205]:
trainset = INaturalist(root = '/MD1400/jinkyu/train_val_images',csv_path = './inaturalist-user-120k/federated_train_user_120k.csv',version = '2017', download = False, transform = transforms)

In [206]:
trainset[0]

NameError: name 'pil_loader' is not defined

In [None]:
print(trainset.split_csv[0])

In [None]:
print(trainset.image_map['e619826038358b817d926d869a0bbcf1'])

In [None]:
dl = DataLoader(trainset, batch_size=5, shuffle=False)

In [None]:
for i,x in enumerate(dl):
    if i<3:
        print("This is batch ",i)
        print(x)
        print(len(x[1]))
    else:
        break

In [171]:
len(trainset.split_csv)

120300

In [172]:
from torch.utils.data import  Dataset
import torch
class DatasetSplit(Dataset):
    """An abstract Dataset class wrapped around Pytorch Dataset class.
    """

    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = [int(i) for i in idxs]

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return torch.tensor(image), torch.tensor(label)

In [173]:
dl = DataLoader(DatasetSplit(trainset, dataset[1]), batch_size=5, shuffle=False)

In [177]:
z=DatasetSplit(trainset, dataset[1])

In [178]:
print(z.idxs)

[4707, 15026, 16102, 20355, 26947, 31606, 31645, 31719, 35236, 50474, 52116, 53646, 53788, 57782, 58900, 61756, 73053, 75472, 81102, 82197, 82974, 85952, 87984, 97589, 98691, 99735, 101126, 103636, 110651, 116355]


In [175]:
print(dataset[1])

[4707, 15026, 16102, 20355, 26947, 31606, 31645, 31719, 35236, 50474, 52116, 53646, 53788, 57782, 58900, 61756, 73053, 75472, 81102, 82197, 82974, 85952, 87984, 97589, 98691, 99735, 101126, 103636, 110651, 116355]


In [174]:
for i,x in enumerate(dl):
    if i<3:
        print("This is batch ",i)
        print(x)
        print(len(x[1]))
    else:
        break

  app.launch_new_instance()


This is batch  0
[tensor([[[[-0.2684, -0.1314,  0.2624,  ..., -0.6452, -0.7137, -0.5596],
          [-0.3027, -0.0629,  0.0912,  ..., -0.5253, -0.5938, -0.4911],
          [-0.2684, -0.1828, -0.1314,  ..., -0.3541, -0.4739, -0.4054],
          ...,
          [-0.2856, -0.2856, -0.2856,  ..., -0.2513, -0.2513, -0.0972],
          [-0.2856, -0.2856, -0.2856,  ..., -0.2513, -0.1828, -0.0458],
          [-0.2856, -0.2684, -0.2684,  ..., -0.2171, -0.1999, -0.1143]],

         [[ 0.7129,  0.7829,  0.9930,  ...,  0.0126, -0.1099, -0.0049],
          [ 0.6254,  0.8354,  0.8179,  ...,  0.1527,  0.0826,  0.2052],
          [ 0.6254,  0.6604,  0.5903,  ...,  0.3452,  0.3277,  0.4503],
          ...,
          [ 0.6604,  0.6604,  0.6604,  ...,  0.6254,  0.6429,  0.6429],
          [ 0.6604,  0.6604,  0.6604,  ...,  0.6254,  0.6604,  0.6954],
          [ 0.6604,  0.6779,  0.6779,  ...,  0.6254,  0.6429,  0.6254]],

         [[ 1.9603,  1.7163,  1.1934,  ...,  0.5485, -0.0964, -0.1487],
          [ 

This is batch  2
[tensor([[[[ 0.7077,  0.7248,  0.6906,  ..., -0.0629, -0.1657, -0.2513],
          [ 0.6221,  0.4851,  0.4337,  ...,  0.0569, -0.1143, -0.2342],
          [ 0.4337,  0.3309,  0.0912,  ...,  0.3481,  0.1768, -0.2171],
          ...,
          [-0.2342, -0.2513, -0.2684,  ...,  0.0227,  0.1083,  0.0912],
          [-0.3369, -0.4054, -0.4568,  ..., -0.1486, -0.1657, -0.1143],
          [-0.4739, -0.5938, -0.6794,  ..., -0.1657, -0.2171, -0.0972]],

         [[ 0.7654,  0.7479,  0.6779,  ...,  0.7479,  0.6254,  0.5028],
          [ 0.5553,  0.4503,  0.3978,  ...,  0.8704,  0.6954,  0.5553],
          [ 0.2752,  0.2227,  0.0651,  ...,  1.2031,  0.9930,  0.5728],
          ...,
          [-0.3200, -0.3025, -0.2500,  ...,  0.8179,  0.8704,  0.9230],
          [-0.4426, -0.4426, -0.4951,  ...,  0.7654,  0.7129,  0.7829],
          [-0.5651, -0.5826, -0.6176,  ...,  0.7479,  0.6954,  0.8179]],

         [[ 0.6531,  0.6879,  0.6356,  ...,  0.1825,  0.0256, -0.0790],
          [ 

In [179]:
print(trainset[4707])

(tensor([[[-0.2684, -0.1314,  0.2624,  ..., -0.6452, -0.7137, -0.5596],
         [-0.3027, -0.0629,  0.0912,  ..., -0.5253, -0.5938, -0.4911],
         [-0.2684, -0.1828, -0.1314,  ..., -0.3541, -0.4739, -0.4054],
         ...,
         [-0.2856, -0.2856, -0.2856,  ..., -0.2513, -0.2513, -0.0972],
         [-0.2856, -0.2856, -0.2856,  ..., -0.2513, -0.1828, -0.0458],
         [-0.2856, -0.2684, -0.2684,  ..., -0.2171, -0.1999, -0.1143]],

        [[ 0.7129,  0.7829,  0.9930,  ...,  0.0126, -0.1099, -0.0049],
         [ 0.6254,  0.8354,  0.8179,  ...,  0.1527,  0.0826,  0.2052],
         [ 0.6254,  0.6604,  0.5903,  ...,  0.3452,  0.3277,  0.4503],
         ...,
         [ 0.6604,  0.6604,  0.6604,  ...,  0.6254,  0.6429,  0.6429],
         [ 0.6604,  0.6604,  0.6604,  ...,  0.6254,  0.6604,  0.6954],
         [ 0.6604,  0.6779,  0.6779,  ...,  0.6254,  0.6429,  0.6254]],

        [[ 1.9603,  1.7163,  1.1934,  ...,  0.5485, -0.0964, -0.1487],
         [ 1.8905,  1.8383,  1.1934,  ...,  

In [180]:
1+3

4

In [181]:
print(type(dataset))

<class 'dict'>


In [182]:
print(len(dataset.keys()))

9275


In [185]:
print(trainset[0][0].shape)

torch.Size([3, 224, 224])


In [188]:
try:
    read_csv('dd')
except:
    print('error')

error


In [189]:
testset = INaturalist(root = '/MD1400/jinkyu/test2017',csv_path = '/MD1400/jinkyu/inaturalist-user-120k/test.csv',version = '2017', download = False, transform = transforms)

In [194]:
testset[1]

KeyError: 'e266241cfaf647e67b02af066df3e13e'

In [192]:
print(len((testset.image_map)))

182707


In [193]:
print(len((trainset.image_map)))

675170


In [196]:
print(testset.split_csv[:10])

[OrderedDict([('image_id', 'c328b4e68bab48659f5627e9e1aa430d'), ('class', '388'), ('label', '11901')]), OrderedDict([('image_id', 'e266241cfaf647e67b02af066df3e13e'), ('class', '209'), ('label', '47188')]), OrderedDict([('image_id', 'e0adcc71c0542bc912c2d380e8dd759f'), ('class', '703'), ('label', '49651')]), OrderedDict([('image_id', 'e49ae0faae685701d9509c6d7d3dd14a'), ('class', '149'), ('label', '55851')]), OrderedDict([('image_id', '25b6401f8db3ae8696da29bc45dc3529'), ('class', '922'), ('label', '11867')]), OrderedDict([('image_id', '7ab01c6ebd426727bd68baf3097a84c9'), ('class', '896'), ('label', '49882')]), OrderedDict([('image_id', '6989f90b44760ece405b6bca69f779b3'), ('class', '1187'), ('label', '5112')]), OrderedDict([('image_id', '1317028067092050235e53e8cc053192'), ('class', '825'), ('label', '55990')]), OrderedDict([('image_id', '90a774bc5229cbbfdef974722ea32d8a'), ('class', '974'), ('label', '52136')]), OrderedDict([('image_id', '0f3a106eb2565836a5e9cd9ffa367767'), ('class',

In [197]:
#### ???? root를 test가 아니라 train_val_images에서 찾아야 하네 testset도 ㅋㅋㅋ

In [198]:
testset = INaturalist(root = '/MD1400/jinkyu/train_val_images',csv_path = '/MD1400/jinkyu/inaturalist-user-120k/test.csv',version = '2017', download = False, transform = transforms)

In [199]:
testset[1]

(tensor([[[-1.7583, -1.7240, -1.6898,  ...,  1.9749,  1.6153,  1.2899],
          [-1.7240, -1.7240, -1.7069,  ...,  1.4440,  0.9817,  1.2043],
          [-1.7583, -1.7754, -1.7754,  ...,  1.1700,  1.0502,  1.5468],
          ...,
          [ 1.4440,  1.6324,  1.8893,  ...,  1.5125,  2.1633,  1.5468],
          [ 1.3242,  1.4612,  1.7523,  ...,  2.2147,  1.7865,  1.6667],
          [ 1.3070,  1.3755,  1.6324,  ...,  2.2147,  2.0092,  1.9407]],
 
         [[-1.6856, -1.6506, -1.6155,  ...,  1.8158,  1.2556,  0.8004],
          [-1.6506, -1.6506, -1.6331,  ...,  1.3256,  0.5903,  0.6254],
          [-1.6331, -1.6506, -1.6506,  ...,  0.9405,  0.5203,  0.8529],
          ...,
          [ 0.5378,  0.7829,  1.1155,  ...,  1.3256,  1.8508,  0.9930],
          [ 0.5028,  0.6779,  1.0105,  ...,  2.2360,  1.7458,  1.5007],
          [ 0.5553,  0.6254,  0.8704,  ...,  2.2535,  2.1660,  2.0784]],
 
         [[-1.5256, -1.4907, -1.4559,  ...,  1.0191,  0.6008,  0.2348],
          [-1.4907, -1.4907,

In [202]:
for i, data in enumerate(dl, 0): 
    print(i)
    print(data[0].shape)

  app.launch_new_instance()


0
torch.Size([5, 3, 224, 224])
1
torch.Size([5, 3, 224, 224])
2
torch.Size([5, 3, 224, 224])
3
torch.Size([5, 3, 224, 224])
4
torch.Size([5, 3, 224, 224])
5
torch.Size([5, 3, 224, 224])
