Here we preprocess the images, i.e. normalize them, and also augment the data (we might use the augmented set later on).<br>
The results are saved in the NormalizedData and AugmentedData directories, respectively.

In [1]:
import os
from PIL import Image, ImageDraw
import numpy as np
from torchvision import transforms
import mlflow
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import glob
import torch.optim as optim
import torch.onnx

In [None]:
# Where the raw dataset lives and where the processed copies are written
base_path = "Data"
vegetable_images_path, holed_images_path = (
    os.path.join(base_path, folder) for folder in ("VegetableImages", "HoledImages")
)
output_normalized_path, output_augmented_path = "NormalizedData", "AugmentedData"

# "Normalization" pipeline: only converts the image to a tensor with values in [0, 1].
# A transforms.Normalize(mean=[0.5], std=[0.5]) step (rescaling to [-1, 1]) is
# currently disabled and therefore not part of this pipeline.
normalize_transform = transforms.Compose([transforms.ToTensor()])

# Augmentation pipeline: random horizontal flip, rotation of up to 15 degrees and
# brightness/contrast jitter, followed by the conversion to a tensor.
augmentation_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
    ]
)

# Function to process images
def preprocess_images(input_path, output_path, transform, process_type="normalization"):
    """Apply ``transform`` to every image under ``input_path`` and save the result.

    The directory tree below ``input_path`` is walked recursively and mirrored
    below ``output_path``: each processed image is written to the same relative
    path (and file name) it had in the input tree.

    Args:
        input_path: Root directory that is searched for images.
        output_path: Root directory that receives the processed images;
            missing sub-directories are created on demand.
        transform: Callable that receives an RGB ``PIL.Image`` and returns an
            object with a ``.numpy()`` method yielding a channels-first array.
            Values are scaled by 255 before saving, so they are expected to
            lie in [0, 1].
        process_type: Label used only in the progress message that is printed
            for every saved file.

    Returns:
        None. Results are written to disk and progress is printed.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')  # Add supported image formats

    for root, _, files in os.walk(input_path):
        for file in files:
            # Compare case-insensitively so files such as "IMG_001.JPG" are
            # not silently skipped.
            if not file.lower().endswith(image_extensions):
                continue

            img_path = os.path.join(root, file)
            # The context manager closes the underlying file handle;
            # convert() returns an independent, fully loaded copy.
            with Image.open(img_path) as source_img:
                img = source_img.convert("RGB")
            processed_img = transform(img)

            # Channels-first tensor -> channels-last numpy array for saving
            np_img = processed_img.numpy().transpose(1, 2, 0)

            # Round instead of truncating so float error cannot shift a pixel
            # down by one level, and clip so out-of-range values cannot wrap
            # around when cast to uint8.
            pixels = np.clip(np.rint(np_img * 255), 0, 255).astype('uint8')

            # Save image in the corresponding output folder
            save_path = os.path.join(output_path, os.path.relpath(img_path, input_path))
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            Image.fromarray(pixels).save(save_path)
            print(f"{process_type.capitalize()} processed: {save_path}")

# MLflow Integration
# Runs all preprocessing passes inside a single MLflow run, logging the
# parameters, the produced image folders (as artifacts) and the elapsed time.
mlflow.set_experiment("Image Inpainting Preprocessing")

# Input datasets: (folder name, source path)
datasets = [
    ("VegetableImages", vegetable_images_path),
    ("HoledImages", holed_images_path),
]

# Processing passes:
# (progress verb, artifact prefix, output root, transform, process_type)
passes = [
    ("Normalizing", "Normalized", output_normalized_path, normalize_transform, "normalization"),
    ("Augmenting", "Augmented", output_augmented_path, augmentation_transforms, "augmentation"),
]

with mlflow.start_run(run_name="Preprocessing with better normalization") as run:
    start_time = time.time()

    # Log parameters for normalization and augmentation.
    # normalize_transform only applies ToTensor(), so the saved data is in
    # [0, 1]; the logged value must reflect that rather than [-1, 1].
    mlflow.log_param("Normalization", "[0, 1]")
    mlflow.log_param("Augmentation", "RandomHorizontalFlip, RandomRotation, ColorJitter")

    # NOTE(review): each augmentation pass draws its random flip/rotation/jitter
    # independently, so an augmented HoledImages file will generally not share
    # the geometry of its VegetableImages counterpart - confirm this is intended
    # if the two folders are meant to stay pixel-aligned.
    for verb, artifact_prefix, output_root, transform, process_type in passes:
        for dataset_name, dataset_path in datasets:
            print(f"{verb} {dataset_name}...")
            destination = os.path.join(output_root, dataset_name)
            preprocess_images(dataset_path, destination, transform, process_type=process_type)
            mlflow.log_artifacts(destination, artifact_path=f"{artifact_prefix}{dataset_name}")

    # Log time taken for preprocessing
    total_time = time.time() - start_time
    mlflow.log_metric("Preprocessing_Time_(seconds)", total_time)
    print(f"Preprocessing completed in {total_time:.2f} seconds.")