# Starter Code for Image Processing

In [15]:
import os
import numpy as np
import pandas as pd
from collections import Counter
import keras
from keras import Sequential
from keras import layers
from keras.layers import Conv2D,MaxPool2D,Dense,Flatten,BatchNormalization,Dropout, Input
from keras.optimizers import Adam
from PIL import Image
from tensorflow import data as tf_data
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.preprocessing import image

from sklearn.model_selection import train_test_split

from tqdm import tqdm

In [3]:
# Directory holding the test images (path is relative to the working directory).
img_directory = "/".join(("..", "bttai-ajl-2025", "test", "test"))

In [None]:
# Report how many entries (files and sub-directories) the test folder contains.
test_image_count = len(os.listdir(img_directory))
print("Testing Images: ", test_image_count)

In [None]:
# Delete every regular file in `img_directory` whose leading bytes do not
# contain the JFIF marker, and report how many were removed.
# NOTE(review): this deletes files in place. JPEGs that carry only an EXIF
# (APP1) header and no JFIF (APP0) segment are valid images but would still be
# removed by this check -- confirm that is acceptable before running it on the
# test set.
num_skipped = 0  # number of files deleted by this cell

for fname in os.listdir(img_directory):
    fpath = os.path.join(img_directory, fname)

    # os.listdir also returns sub-directories; opening one would raise.
    if not os.path.isfile(fpath):
        continue

    # `with` closes the handle even if reading fails. The previous
    # try/finally called fobj.close() even when open() itself had failed,
    # so a NameError could hide the real error.
    with open(fpath, "rb") as fobj:
        # BufferedReader.peek may return more bytes than requested (whatever
        # is already buffered), so this search can look past byte 10. A
        # standard JFIF header stores "JFIF" at byte offset 6, i.e. within the
        # first 10 bytes.
        is_jfif = b"JFIF" in fobj.peek(10)

    # Delete only after the handle is closed (on Windows, removing a file
    # that is still open raises an error).
    if not is_jfif:
        num_skipped += 1
        os.remove(fpath)

print(f"Deleted {num_skipped} images.")


In [6]:
def get_image_data(directory):
    """
    Collect image dimensions and pixel values for every JPEG under a directory.

    The directory is walked recursively. Only files whose names end in
    ``.jpg`` or ``.jpeg`` (case-insensitive) are opened. Files that cannot be
    opened or decoded are reported on stdout and skipped.

    Args:
        directory (str): Path to the directory containing image files.

    Returns:
        dimensions (list): One ``(width, height, channels)`` tuple per image
            that could be opened.
        channel_values (numpy.ndarray): Pixel values of every 3-band image,
            stacked into shape ``(num_pixels, 3)``. When no 3-band image is
            found this is an empty ``(0, 3)`` uint8 array, so callers can
            always use ``.size`` and ``[:, i]`` on it.
    """
    dimensions = []
    pixel_blocks = []  # one (height * width, 3) array per 3-band image

    for root, _, files in os.walk(directory):
        for file in files:
            if not file.lower().endswith((".jpg", ".jpeg")):
                continue
            file_path = os.path.join(root, file)
            try:
                with Image.open(file_path) as img:
                    width, height = img.size
                    channels = len(img.getbands())
                    dimensions.append((width, height, channels))

                    # A band count of 3 is used as a proxy for RGB; any other
                    # 3-band PIL mode would be collected here as well.
                    if channels == 3:
                        pixels = np.array(img)
                        pixel_blocks.append(pixels.reshape(-1, 3))
            except Exception as e:
                # Best-effort scan: report the bad file and keep going.
                print(f"Could not process file {file_path}: {e}")

    if pixel_blocks:
        channel_values = np.concatenate(pixel_blocks, axis=0)
    else:
        # Always return an ndarray (never a bare list) so the type is stable.
        channel_values = np.empty((0, 3), dtype=np.uint8)

    return dimensions, channel_values



In [7]:
def plot_distributions(dimensions, channel_values):
    """
    Plot the distribution of image sizes, channel counts and pixel values.

    Draws up to three panels in one figure:
    - top-left: a histogram of widths and heights;
    - top-right: a bar chart of channel counts;
    - bottom row (only when pixel data is available): per-channel
      pixel-value histograms.

    Args:
        dimensions (list): ``(width, height, channels)`` tuples, e.g. the first
            value returned by ``get_image_data``.
        channel_values (array-like): Pixel values of shape ``(num_pixels, 3)``,
            columns plotted as red, green, blue. An empty list or empty array
            is accepted and simply omits the bottom panel.
    """
    widths = [dim[0] for dim in dimensions]
    heights = [dim[1] for dim in dimensions]
    channels = [dim[2] for dim in dimensions]

    # Accept any array-like (notably an empty list) so that `.size` and the
    # column slicing below work however the caller built it.
    channel_values = np.asarray(channel_values)

    plt.figure(figsize=(14, 8))

    # Top-left: widths and heights overlaid on one histogram.
    plt.subplot(2, 2, 1)
    plt.hist(widths, bins=30, alpha=0.7, label='Widths')
    plt.hist(heights, bins=30, alpha=0.7, label='Heights')
    plt.title('Width and Height Distribution')
    plt.xlabel('Pixels')
    plt.ylabel('Frequency')
    plt.legend()

    # Top-right: number of images per channel count.
    plt.subplot(2, 2, 2)
    channel_counts = Counter(channels)
    # Pass plain lists instead of dict views, and tick only the integer
    # channel counts that actually occur.
    plt.bar(list(channel_counts.keys()), list(channel_counts.values()),
            color='orange', alpha=0.7)
    plt.xticks(sorted(channel_counts))
    plt.title('Channel Count Distribution')
    plt.xlabel('Number of Channels')
    plt.ylabel('Frequency')

    # Bottom row (spans both columns): one histogram per colour channel.
    if channel_values.size > 0:
        plt.subplot(2, 1, 2)
        colors = ['Red', 'Green', 'Blue']
        for i, color in enumerate(colors):
            plt.hist(channel_values[:, i], bins=50, alpha=0.7,
                     label=f'{color} Channel', color=color.lower())
        plt.title('RGB Channel Value Distribution')
        plt.xlabel('Pixel Value')
        plt.ylabel('Frequency')
        plt.legend()

    plt.tight_layout()
    plt.show()

In [None]:
# Scan the test images, then plot their size / channel / pixel-value statistics.
dims, channel_val = get_image_data(directory=img_directory)
plot_distributions(dimensions=dims, channel_values=channel_val)

In [16]:
# Paths of all JPEG files in the test folder. They are sorted because
# os.listdir returns entries in arbitrary order; sorting makes the dataset
# order (and anything predicted from it) reproducible.
image_files = sorted(
    os.path.join(img_directory, fname)
    for fname in os.listdir(img_directory)
    if fname.lower().endswith(('.jpg', '.jpeg'))
)
# One string element (a file path) per image.
dataset = tf.data.Dataset.from_tensor_slices(image_files)

def load_and_preprocess_image(file_path, image_size=(256, 256)):
    """
    Read a JPEG from disk, decode it to 3 channels and resize it.

    Args:
        file_path: String (or scalar string tensor) holding the image path.
        image_size: Target ``(height, width)``; defaults to ``(256, 256)``.

    Returns:
        A float32 tensor of shape ``(height, width, 3)``. Pixel values are not
        rescaled, so they remain on the 0-255 scale of the decoded JPEG.
    """
    raw_bytes = tf.io.read_file(file_path)
    # channels=3 forces RGB output even for grayscale JPEGs.
    img = tf.image.decode_jpeg(raw_bytes, channels=3)
    # Resizing to a fixed size does not preserve the aspect ratio.
    img = tf.image.resize(img, image_size)
    return img

# Decode and resize every path, then group the images into batches of 32.
dataset = dataset.map(load_and_preprocess_image).batch(32)

In [None]:
def display_batch(dataset, num_images=25):
    """
    Show images from the first batch of ``dataset`` in a near-square grid.

    Args:
        dataset: A batched dataset whose elements are image tensors. Values are
            expected on a 0-255 scale (as produced by
            ``load_and_preprocess_image``), because they are cast to uint8 for
            display.
        num_images: Maximum number of images to show. It is capped at the size
            of the first batch. With the default of 25 the grid is 5x5.
    """
    plt.figure(figsize=(10, 10))

    for images in dataset.take(1):  # only the first batch
        # Never index past the end of a short batch.
        count = min(num_images, len(images))
        # Size the grid from the number of images instead of a fixed 5x5, so
        # num_images > 25 no longer overflows plt.subplot.
        cols = max(1, int(np.ceil(np.sqrt(count))))
        rows = max(1, int(np.ceil(count / cols)))
        for i in range(count):
            plt.subplot(rows, cols, i + 1)
            # Resized images are float32; cast to uint8 so imshow treats the
            # values as 0-255.
            plt.imshow(images[i].numpy().astype("uint8"))
            plt.axis("off")
    plt.show()

# Preview the first batch of test images: 25 images in a 5x5 grid.
display_batch(dataset, num_images=25)