In [32]:
!pip install split-folders imutils

Collecting imutils
  Downloading imutils-0.5.4.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25ldone
[?25hBuilding wheels for collected packages: imutils
  Building wheel for imutils (setup.py) ... [?25ldone
[?25h  Created wheel for imutils: filename=imutils-0.5.4-py3-none-any.whl size=25834 sha256=ea006b4ddf52962a70da3afe636e9559f76e17788b2218334b4a1c0647a88cd9
  Stored in directory: /root/.cache/pip/wheels/85/cf/3a/e265e975a1e7c7e54eb3692d6aa4e2e7d6a3945d29da46f2d7
Successfully built imutils
Installing collected packages: imutils
Successfully installed imutils-0.5.4


In [38]:
# imports
import os
import splitfolders
import tensorflow as tf
from typing import Tuple
from string import digits
from imutils import paths
from random import choices
from tensorflow.data import AUTOTUNE
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import GRU, Dense, Bidirectional, InputLayer, Reshape

In [107]:
class Dataset:
    def __init__(self, path: str, output: str | None = None, ratio: Tuple[float, float, float] = (.70, .15, .15), split: bool = True, image_size: Tuple[int, int] = (224,224), batch_size: int = 8):
        self.path = path
        self.ratio = ratio
        
        self.output_path = output if output else f"./data-{''.join(choices(digits, k=5))}"
        
        self.train = os.path.sep.join([self.output_path, 'train'])
        self.val = os.path.sep.join([self.output_path, 'val'])
        self.test = os.path.sep.join([self.output_path, 'test'])
        
        self.image_size = image_size
        self.batch_size = batch_size
        
        if split:
            self.__split_folder()
        
    def __split_folder(self):
        splitfolders.ratio(self.path, output=self.output_path, seed=42, ratio=self.ratio, group_prefix=None)
    
    def generate(self) -> Tuple:
        label_classes = [name for name in os.listdir(self.path) if os.path.isdir(os.path.join(self.path, name))]
        
        def load_images(image_path):
            image = tf.io.read_file(image_path)
            image = tf.image.decode_png(image, channels=3)
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
            image = tf.image.resize(image, self.image_size)

            label = tf.strings.split(image_path, os.path.sep)[-2]
            label = tf.cast(tf.equal(label_classes, label), tf.int32)
    
            return (image, label)
        
        train_paths = list(paths.list_images(self.train))
        train_ds = tf.data.Dataset.from_tensor_slices(train_paths)
        train_ds = (train_ds
            .shuffle(len(train_paths))
            .map(load_images, num_parallel_calls=AUTOTUNE)
            .cache()
            .batch(self.batch_size)
            .prefetch(AUTOTUNE)
        )
        
        val_paths = list(paths.list_images(self.val))
        val_ds = tf.data.Dataset.from_tensor_slices(val_paths)
        val_ds = (val_ds
            .map(load_images, num_parallel_calls=AUTOTUNE)
            .cache()
            .batch(self.batch_size)
            .prefetch(AUTOTUNE)
        )
        
        test_paths = list(paths.list_images(self.test))
        test_ds = tf.data.Dataset.from_tensor_slices(test_paths)
        test_ds = (test_ds
            .map(load_images, num_parallel_calls=AUTOTUNE)
            .cache()
            .batch(self.batch_size)
            .prefetch(AUTOTUNE)
        )
        
        return train_ds, val_ds, test_ds

In [108]:
train, val, test = Dataset(path='/kaggle/input/musicgenreprediction/Data/images_augmented/stft').generate()

In [110]:
for images, labels in train.take(2):
  img, label = images[0], labels[0]
  display([img, label])

[<tf.Tensor: shape=(224, 224, 3), dtype=float32, numpy=
 array([[[0.21568629, 0.35686275, 0.5529412 ],
         [0.2509804 , 0.27450982, 0.53333336],
         [0.25154063, 0.264846  , 0.52605045],
         ...,
         [0.26666668, 0.00392157, 0.32941177],
         [0.26876736, 0.01347782, 0.338968  ],
         [0.2784314 , 0.05744048, 0.3829307 ]],
 
        [[0.21568629, 0.35686275, 0.5529412 ],
         [0.2509804 , 0.27450982, 0.53333336],
         [0.25154063, 0.264846  , 0.52605045],
         ...,
         [0.26666668, 0.00392157, 0.32941177],
         [0.26876736, 0.014425  , 0.33991522],
         [0.2784314 , 0.0627451 , 0.38823533]],
 
        [[0.21568629, 0.35443804, 0.5505165 ],
         [0.25340512, 0.26481095, 0.53090864],
         [0.25387874, 0.2554935 , 0.52371234],
         ...,
         [0.26666668, 0.00392157, 0.32941177],
         [0.26876736, 0.01652569, 0.34131566],
         [0.2784314 , 0.07450981, 0.39607847]],
 
        ...,
 
        [[0.16078432, 0.48235297

[<tf.Tensor: shape=(224, 224, 3), dtype=float32, numpy=
 array([[[0.14394258, 0.51857495, 0.5555935 ],
         [0.13863796, 0.5357581 , 0.5555935 ],
         [0.1385473 , 0.53593534, 0.5555482 ],
         ...,
         [0.26064426, 0.7357406 , 0.4460347 ],
         [0.2511521 , 0.72399926, 0.45351058],
         [0.20748426, 0.6699843 , 0.4879027 ]],
 
        [[0.14856443, 0.50702035, 0.556749  ],
         [0.2247199 , 0.33578435, 0.5491334 ],
         [0.22104001, 0.34505615, 0.54899335],
         ...,
         [0.15237221, 0.6778624 , 0.5062238 ],
         [0.15105303, 0.67583203, 0.5071069 ],
         [0.14498425, 0.6664916 , 0.5111695 ]],
 
        [[0.15779062, 0.4844188 , 0.5568628 ],
         [0.16078432, 0.4784314 , 0.5568628 ],
         [0.1608506 , 0.47831917, 0.5568628 ],
         ...,
         [0.19263828, 0.6986695 , 0.4822917 ],
         [0.18369006, 0.6864342 , 0.48976445],
         [0.14252451, 0.63014704, 0.5241422 ]],
 
        ...,
 
        [[0.12156864, 0.59607846