In [6]:
from balancers.configuration_reader import ConfigurationReader
from balancers.augmentation import AugmentationBalancer
from balancers.adasyn import ImageADASYN
from balancers.smote import ImageSMOTE
from balancers.autoencoder import AEBalancer
from balancers.dgan import DGANBalancer
from balancers.annotations import Annotations

# Dataset locations used by the rest of this cell: one folder is handed to the
# balancers as input, the other two are handed on as output destinations for
# the balanced images and their annotations.
path_to_input_imbalanced_images_folder = (
    "./maritime-flags-dataset/imbalanced_flags/images"
)
path_to_output_balanced_images_folder = (
    "./maritime-flags-dataset/balanced_flags/images/"
)
path_to_output_balanced_images_annotations_folder = (
    "./maritime-flags-dataset/balanced_flags/labels/"
)

# 1. Reading configuration.
# The reader is loaded from the JSON file next to the notebook and then asked
# to print itself, so the parameters in use are visible in the cell output.
configuration_reader = ConfigurationReader()
configuration_reader.read("./configuration.json")
configuration_reader.print()

# 2. Selecting mode for data balancing.
# `mode` is matched against "AUGMENTATION", "ADASYN", "SMOTE", "DGAN" and "AE"
# in step 3; `debug` is forwarded to the balancers that accept it.
mode, debug = "AE", True

# 3. Starting data balancing.
match mode:
    case "AUGMENTATION":
        print(f"======= AUGMENTATION =======\n")
        aug_balancer = AugmentationBalancer()
        aug_balancer.fit(
            path_to_input_image_folder=path_to_input_imbalanced_images_folder, 
            debug=debug
        )
        aug_balancer.balance(
            path_to_output_image_folder=path_to_output_balanced_images_folder, 
            debug=debug
        )
        print(f"============================\n")
    case "ADASYN":
        print(f"========== ADASYN ==========\n")
        adasyn_balancer = ImageADASYN()
        adasyn_balancer.fit(
            path_to_input_image_folder=path_to_input_imbalanced_images_folder,
            width_of_image=configuration_reader.width_of_image,
            height_of_image=configuration_reader.height_of_image
        )
        adasyn_balancer.balance(
            path_to_output_image_folder=path_to_output_balanced_images_folder,
            number_of_neighbors = configuration_reader.number_of_neighbors
        )
        print(f"===========================\n")
    case "SMOTE":
        print(f"========== SMOTE ==========\n")
        smote_balancer = ImageSMOTE()
        smote_balancer.fit(
            path_to_input_image_folder=path_to_input_imbalanced_images_folder, 
            width_of_image=configuration_reader.width_of_image, 
            height_of_image=configuration_reader.height_of_image
        )
        smote_balancer.balance(
            path_to_output_image_folder=path_to_output_balanced_images_folder
        )
        print(f"===========================\n")
    case "DGAN":
        print(f"========== DGAN ==========\n")
        dgan_balancer = DGANBalancer()
        dgan_balancer.fit(
            path_to_input_image_folder=path_to_input_imbalanced_images_folder,
            latent_dimension=configuration_reader.latent_dimension, 
            learning_rate=configuration_reader.learning_rate, 
            beta_01=configuration_reader.beta,
            batch_size=configuration_reader.batch_size, 
            number_of_epochs=configuration_reader.number_of_epochs, 
            delta=configuration_reader.delta
        )
        dgan_balancer.balance(
            path_to_output_image_folder=path_to_output_balanced_images_folder
        )
        print(f"===========================\n")     
    case "AE":
        print(f"======= Autoencoder =======\n")
        ae_balancer = AEBalancer()
        ae_balancer.fit(
            path_to_input_image_folder=path_to_input_imbalanced_images_folder, 
            batch_size=configuration_reader.batch_size, 
            number_of_epochs=configuration_reader.number_of_epochs, 
            delta=configuration_reader.delta
        )
        ae_balancer.balance(
            path_to_output_image_folder=path_to_output_balanced_images_folder,
            debug=debug
        )
        print(f"===========================\n")

# 4. Creating annotations for new/balanced images.
# The annotator is pointed at the folder the balancer wrote its images to and
# at the separate labels folder given as the annotation destination.
annotator = Annotations()
annotator.annotate(
    path_to_output_annotations=path_to_output_balanced_images_annotations_folder,
    path_to_input_images=path_to_output_balanced_images_folder,
)

Path to input imbalanced images folder: ./maritime-flags-dataset/imbalanced_flags/
Path to output balanced images folder: ./maritime-flags-dataset/balanced_flags/
Size of image: (128, 128, 3)
Batch size: 128
Number of epochs: 1
Latent dimension: 100
Learning rate: 0.002
Beta: 0.5

B: 200, Autoencoder(
  (encoder): Encoder(
    (conv2d_01): Conv2d(3, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (relu_01): ReLU()
    (conv2d_02): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (relu_02): ReLU()
    (conv2d_03): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (relu_03): ReLU()
    (flatten_04): Flatten(start_dim=1, end_dim=-1)
    (linear_04): Linear(in_features=65536, out_features=1024, bias=True)
    (relu_04): ReLU()
  )
  (decoder): Decoder(
    (linear_01): Linear(in_features=1024, out_features=65536, bias=True)
    (relu_01): ReLU()
    (unflatten_01): Unflatten(dim=1, unflattened_size=(256, 16, 16))
    (conv_transpose2d_

: 