In [None]:
import os
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import cv2

class AnimalDataset(Dataset):
    """Image-classification dataset laid out as ``root/<class_name>/<image_file>``.

    Each immediate sub-directory of ``root`` is one class and every regular
    file inside it is one sample. Class indices are assigned in sorted order
    of the directory names, so separately constructed datasets (e.g. a train
    and a test split) get the same name -> index mapping regardless of the
    order in which the filesystem lists entries.

    Args:
        root: Directory containing one sub-directory per class.
        train: Whether this is the training split. Stored as ``self.train``;
            not otherwise used by this class.
        transform: Optional callable applied to the decoded image in
            ``__getitem__``. ``None`` (default) returns the image unchanged.
    """

    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.train = train
        self.transform = transform

        # os.listdir returns entries in arbitrary order; sort so that every
        # dataset built from the same class folders assigns identical labels.
        # Only directories are classes -- stray files such as .DS_Store are
        # skipped instead of crashing the inner os.listdir call.
        self.classes = sorted(
            entry
            for entry in os.listdir(root)
            if os.path.isdir(os.path.join(root, entry))
        )
        self.classname_to_index = {
            class_name: index for index, class_name in enumerate(self.classes)
        }

        self.image_paths = []
        self.labels = []
        for class_name in self.classes:
            class_path = os.path.join(root, class_name)
            label = self.classname_to_index[class_name]
            # Sorted for a reproducible sample order across runs/platforms.
            for image_name in sorted(os.listdir(class_path)):
                image_path = os.path.join(class_path, image_name)
                if not os.path.isfile(image_path):
                    # Nested directories are not samples.
                    continue
                self.image_paths.append(image_path)
                self.labels.append(label)

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is decoded with ``cv2.imread`` using its default colour mode
        (a BGR ``uint8`` array of shape (H, W, 3)). It is then passed through
        ``self.transform`` when one is set.

        Raises:
            OSError: if OpenCV cannot read or decode the file.
        """
        image_path = self.image_paths[index]
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise OSError(f"Could not read image: {image_path}")
        if self.transform is not None:
            image = self.transform(image)
        label = self.labels[index]

        return image, label

train_path = "animals/train"
test_path = "animals/test"

train_data = AnimalDataset(root=train_path, train=True)
test_data = AnimalDataset(root=test_path, train=False)

# Labels are only comparable across splits if both map class names to the
# same indices; fail loudly instead of training/evaluating on mismatched labels.
if train_data.classname_to_index != test_data.classname_to_index:
    raise ValueError(
        "Train and test splits have different class -> index mappings: "
        f"{train_data.classname_to_index} vs {test_data.classname_to_index}"
    )

print(len(train_data))
print(len(test_data))

# Inspect a single sample: array shape, raw pixel values and integer label.
image, label = train_data[101]
print(image.shape)
print(image)
print(label)

# Display the sample in a native OpenCV window. waitKey(0) blocks until a key
# is pressed; destroy the window afterwards so it does not linger.
cv2.imshow("test", image)
cv2.waitKey(0)
cv2.destroyAllWindows()

23699
2596
(640, 489, 3)
[[[ 45 101  90]
  [ 41  97  86]
  [ 40  96  85]
  ...
  [ 41 118 104]
  [ 37 114 100]
  [ 40 115 101]]

 [[ 47 103  92]
  [ 43  99  88]
  [ 42  98  87]
  ...
  [ 39 118 104]
  [ 37 114 100]
  [ 43 118 104]]

 [[ 49 105  94]
  [ 45 101  90]
  [ 44 100  89]
  ...
  [ 38 117 103]
  [ 37 114 100]
  [ 41 116 102]]

 ...

 [[ 28  84  65]
  [ 26  85  65]
  [ 27  86  66]
  ...
  [228 254 248]
  [125 159 149]
  [ 58  98  87]]

 [[ 30  89  69]
  [ 29  88  68]
  [ 31  90  70]
  ...
  [231 252 249]
  [187 215 209]
  [ 83 117 110]]

 [[ 32  91  71]
  [ 32  91  71]
  [ 31  92  72]
  ...
  [230 249 246]
  [215 240 236]
  [129 160 153]]]
0
(316, 640, 3)
[[[ 75  94 129]
  [ 75  94 129]
  [ 75  94 129]
  ...
  [122  71  15]
  [121  70  14]
  [120  69  13]]

 [[ 75  94 129]
  [ 75  94 129]
  [ 75  94 129]
  ...
  [121  70  14]
  [121  70  14]
  [120  69  13]]

 [[ 77  93 129]
  [ 77  93 129]
  [ 77  93 129]
  ...
  [121  71  13]
  [121  70  14]
  [120  69  13]]

 ...

 [[ 48  78 

-1