In [1]:
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

tqdm.pandas()

# Import Dataset

In [13]:
# Create df and labels
# File names look like "<label>-<n>.png"; the label is everything before the last "-".
folder_path = "./data/full-dataset/train/clean"
N_SAMPLES = 100  # number of files to process; set to None to use the whole folder

# os.listdir order is filesystem-dependent; sort so the subset below is reproducible
file_list = sorted(os.listdir(folder_path))
df = pd.DataFrame(file_list, columns=["file_name"])
df["original_file_path"] = df["file_name"].apply(lambda x: os.path.join(folder_path, x))
# Strip the extension, then drop the trailing "-<n>" variant suffix (works for -0, -1, ...)
df["label"] = df["file_name"].map(lambda name: os.path.splitext(name)[0].rsplit("-", 1)[0])

# Get subset of dataframe (copy so later column assignments don't target a slice)
if N_SAMPLES is not None:
    df = df.head(N_SAMPLES).copy()

# Preprocessing / Segmentation (Training)

In [14]:
def replace_black_with_surrounding_color_optimized(img):
    """Fill pure-black pixels with the mean colour of their non-black neighbours.

    A pixel is "black" when every channel is 0. Each black pixel is replaced by
    the average of the non-black pixels among its 8 in-bounds neighbours. This is
    a single pass: freshly filled pixels do not feed into other black pixels.
    Black pixels with no non-black neighbour stay black.

    Args:
        img: image array of shape (H, W, C) or (H, W). It is not modified.

    Returns:
        A new array with the same shape and dtype as ``img``.
    """
    # Work on an (H, W, C) view so grayscale and colour images share one code path
    squeeze = img.ndim == 2
    pixels = img[..., np.newaxis] if squeeze else img
    height, width = pixels.shape[:2]

    black_mask = np.all(pixels == 0, axis=-1)
    # "Non-black" must be the exact complement of black_mask: a pixel such as
    # (0, 120, 200) is a valid colour and must contribute to the average.
    non_black_mask = ~black_mask

    # Zero-pad by one pixel instead of using np.roll. Rolling wraps around, so
    # edge pixels would borrow colours from the opposite side of the image.
    # Padded cells are marked invalid and are therefore never counted.
    padded = np.pad(pixels.astype(np.float32), ((1, 1), (1, 1), (0, 0)))
    padded_valid = np.pad(non_black_mask, 1)

    accumulator = np.zeros(pixels.shape, dtype=np.float32)
    surrounding_count = np.zeros((height, width), dtype=np.float32)

    for dy in (-1, 0, 1):
        for dx in (-1, 0, 1):
            if dy == 0 and dx == 0:
                continue  # skip the pixel itself
            rows = slice(1 + dy, 1 + dy + height)
            cols = slice(1 + dx, 1 + dx + width)
            neighbour_valid = padded_valid[rows, cols]
            accumulator += padded[rows, cols] * neighbour_valid[..., np.newaxis]
            surrounding_count += neighbour_valid

    # Avoid division by zero; isolated black pixels keep an all-zero accumulator
    surrounding_count[surrounding_count == 0] = 1

    result = pixels.copy()
    result[black_mask] = (
        accumulator[black_mask] / surrounding_count[black_mask, np.newaxis]
    ).astype(img.dtype)
    return result[..., 0] if squeeze else result

def process_image(image_path, output_folder):
    """Clean one image and save it under ``output_folder`` with the same file name.

    The image is read with ``cv2.imread`` and passed through
    ``replace_black_with_surrounding_color_optimized``. The result is written to
    ``output_folder/<basename of image_path>``.

    Args:
        image_path: path of the source image.
        output_folder: directory for the cleaned copy; created if missing.

    Returns:
        The output path, or None if the image could not be read or written.
    """
    img = cv2.imread(image_path)
    if img is None:
        print(f"Error: Could not load image at {image_path}")
        return None

    cleaned_image = replace_black_with_surrounding_color_optimized(img)

    # exist_ok avoids the check-then-create race of os.path.exists + os.makedirs
    os.makedirs(output_folder, exist_ok=True)

    # Same filename, different folder
    output_path = os.path.join(output_folder, os.path.basename(image_path))

    # cv2.imwrite signals many failures through a False return value instead of raising
    if not cv2.imwrite(output_path, cleaned_image):
        print(f"Error: Could not write image to {output_path}")
        return None

    return output_path

def process_images(image_paths, output_folder):
    """Run ``process_image`` on every path, showing a tqdm progress bar.

    Returns a list aligned with ``image_paths``. Each entry is the cleaned-file
    path, or None where ``process_image`` failed.
    """
    return [process_image(src, output_folder) for src in tqdm(image_paths)]

# Clean every sampled image; the new column holds the cleaned-file path (None on failure)
output_folder = "cleaned_images"
source_paths = df['original_file_path'].tolist()
df['processed_images'] = process_images(source_paths, output_folder)

100%|██████████| 100/100 [00:01<00:00, 74.88it/s]


In [4]:
# Function to process images and extract ROIs
output_char_folder = "extracted_chars"
# exist_ok replaces the separate existence check (no check-then-create race)
os.makedirs(output_char_folder, exist_ok=True)
# A contour is discarded as noise when its bounding box is at most this many
# pixels in BOTH width and height
contour_too_small_threshold = 8

def extract_rois(image_path, roi_size=32, min_size=None):
    """Split a cleaned image into square per-character crops, ordered left to right.

    Any pixel darker than pure white (gray value <= 254) counts as foreground.
    Each external contour's bounding box is cropped from the original image and
    padded with white to a square, keeping the crop centred. The square is then
    resized to ``roi_size`` x ``roi_size``.

    Args:
        image_path: path of the (cleaned) image to segment.
        roi_size: side length in pixels of each returned ROI (default 32).
        min_size: boxes whose width AND height are both <= this are skipped.
            Defaults to the module-level ``contour_too_small_threshold``.

    Returns:
        ``(img_with_contours, rois)``: a copy of the image with contours and kept
        boxes drawn on it, plus the list of ROI arrays. Returns ``(None, [])`` if
        the image cannot be loaded.
    """
    if min_size is None:
        min_size = contour_too_small_threshold

    img = cv2.imread(image_path)
    if img is None:
        print(f"Error: Could not load image at {image_path}")
        return None, []

    # Inverted binary threshold: every non-white pixel becomes foreground (255)
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 254, 255, cv2.THRESH_BINARY_INV)
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    img_with_contours = img.copy()
    cv2.drawContours(img_with_contours, contours, -1, (0, 255, 75), 2)

    # Compute each bounding box once, then order boxes left to right (stable sort)
    boxes = sorted((cv2.boundingRect(c) for c in contours), key=lambda box: box[0])

    rois = []
    for x, y, w, h in boxes:
        if w <= min_size and h <= min_size:
            continue

        # Visual aid only; ROIs are cut from the untouched `img`
        cv2.rectangle(img_with_contours, (x, y), (x + w, y + h), (36, 255, 12), 2)

        roi = img[y:y + h, x:x + w]

        # Pad the short side with white so the crop becomes square and stays centred
        side = max(w, h)
        pad_top = (side - h) // 2
        pad_bottom = side - h - pad_top
        pad_left = (side - w) // 2
        pad_right = side - w - pad_left
        roi_padded = cv2.copyMakeBorder(
            roi, pad_top, pad_bottom, pad_left, pad_right,
            cv2.BORDER_CONSTANT, value=(255, 255, 255),
        )

        rois.append(cv2.resize(roi_padded, (roi_size, roi_size)))

    return img_with_contours, rois

# Apply the ROI extraction to processed images in the DataFrame.
# Each cell holds the (image_with_contours, rois) tuple returned by extract_rois.
# Rows whose cleaning step failed (processed path is None) get (None, []) rather
# than passing None to cv2.imread.
df['rois'] = df['processed_images'].progress_apply(
    lambda path: extract_rois(path) if isinstance(path, str) else (None, [])
)

# Display results
# for index, (contours_image, rois) in enumerate(zip(df['rois'].apply(lambda x: x[0]), df['rois'].apply(lambda x: x[1]))):
#     plt.figure(figsize=(10, 5))

#     # Display image with bounding boxes
#     plt.subplot(1, 2, 1)
#     plt.imshow(cv2.cvtColor(contours_image, cv2.COLOR_BGR2RGB))
#     plt.title("Image with Contours")
#     plt.axis('off')

#     # Display ROIs
#     plt.subplot(1, 2, 2)
#     for i, roi in enumerate(rois):
#         plt.subplot(1, len(rois), i + 1)
#         plt.imshow(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB))
#         plt.title(f"CHAR {i + 1}")
#         plt.axis('off')

#     plt.show()

100%|██████████| 100/100 [00:00<00:00, 340.01it/s]


In [5]:
# df # Debug purposes

Unnamed: 0,file_name,original_file_path,label,processed_images,rois
0,0024miih-0.png,./data/full-dataset/train/clean\0024miih-0.png,0024miih,cleaned_images\0024miih-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
1,002k-0.png,./data/full-dataset/train/clean\002k-0.png,002k,cleaned_images\002k-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
2,006aguv-0.png,./data/full-dataset/train/clean\006aguv-0.png,006aguv,cleaned_images\006aguv-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
3,00hai-0.png,./data/full-dataset/train/clean\00hai-0.png,00hai,cleaned_images\00hai-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
4,00hgi3n7-0.png,./data/full-dataset/train/clean\00hgi3n7-0.png,00hgi3n7,cleaned_images\00hgi3n7-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
...,...,...,...,...,...
95,0gs296j-0.png,./data/full-dataset/train/clean\0gs296j-0.png,0gs296j,cleaned_images\0gs296j-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
96,0gwk-0.png,./data/full-dataset/train/clean\0gwk-0.png,0gwk,cleaned_images\0gwk-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
97,0h4g-0.png,./data/full-dataset/train/clean\0h4g-0.png,0h4g,cleaned_images\0h4g-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."
98,0h7htrjq-0.png,./data/full-dataset/train/clean\0h7htrjq-0.png,0h7htrjq,cleaned_images\0h7htrjq-0.png,"([[[255 255 255], [255 255 255], [255 255 255]..."


In [15]:
# Create a directory for saving character images if it doesn't exist
output_char_folder = "extracted_chars"
os.makedirs(output_char_folder, exist_ok=True)

# Reset the bookkeeping on every run so re-executing this cell starts from zero
mismatch_count = 0
mismatched_images = []  # (image_path, rois) pairs whose ROI count != label length


# Function to save ROIs if the count matches the label length
def save_rois_if_match(image_path, label, rois):
    """Write each ROI to ``output_char_folder`` when the ROI count equals ``len(label)``.

    Files are named ``<image stem>_char_<k>.png``, where k is 1-based in
    left-to-right order. On a count mismatch, the module-level ``mismatch_count``
    is incremented, ``(image_path, rois)`` is appended to ``mismatched_images``,
    and nothing is written.

    Returns:
        The list of written ROI paths. Returns ``[]`` on a mismatch or if any
        write fails.
    """
    global mismatch_count

    if len(rois) != len(label):
        print(f"Mismatch for {image_path}: {len(rois)} ROIs but label length is {len(label)}.")
        mismatch_count += 1
        mismatched_images.append((image_path, rois))
        return []

    # splitext keeps dots inside the stem (split('.')[0] would truncate "a.b-0.png" to "a")
    stem = os.path.splitext(os.path.basename(image_path))[0]
    saved_paths = []
    for i, roi in enumerate(rois, start=1):
        roi_path = os.path.join(output_char_folder, f"{stem}_char_{i}.png")
        # cv2.imwrite signals many failures via a False return value instead of raising
        if not cv2.imwrite(roi_path, roi):
            print(f"Error: Could not write ROI to {roi_path}")
            return []
        saved_paths.append(roi_path)
    return saved_paths

# Records (one per correctly segmented image) for building saved_images_df later
saved_images_data = []

# Iterate through the DataFrame to process each image
print('-----POTENTIAL MISMATCH-----')
for row in df.itertuples(index=False):
    image_path = row.processed_images
    label = row.label
    if not isinstance(image_path, str):
        # process_image returned None for this file, so there is nothing to segment
        print(f"Skipping {row.original_file_path}: no processed image.")
        continue

    _, rois = extract_rois(image_path)

    # Save ROIs only if their count matches the label length
    saved_paths = save_rois_if_match(image_path, label, rois)
    if saved_paths:
        saved_images_data.append({
            'image_path': image_path,
            'label': label,
            'saved_char_paths': saved_paths,
        })

print(f'TOTAL MISMATCH COUNT: {mismatch_count}')

# Uncomment to view missegmented images
# if mismatched_images:
#     print("Displaying mismatched images:")
#     for image_path, rois in mismatched_images:
#         img = cv2.imread(image_path)
#         if img is not None:
#             if rois:
#                 plt.figure(figsize=(12, 6))
#                 for i, roi in enumerate(rois):
#                     plt.subplot(1, len(rois), i + 1)
#                     plt.imshow(cv2.cvtColor(roi, cv2.COLOR_BGR2RGB))
#                     plt.title(f"ROI {i + 1}\n{roi.size}")
#                     plt.axis('off')
#                 plt.show()
#         else:
#             print(f"Could not load image: {image_path}")

-----POTENTIAL MISMATCH-----
Mismatch for cleaned_images\010ud-0.png: 3 ROIs but label length is 5.
Mismatch for cleaned_images\02cc-0.png: 2 ROIs but label length is 4.
Mismatch for cleaned_images\0axhfa-0.png: 4 ROIs but label length is 6.
Mismatch for cleaned_images\0duo1r-0.png: 3 ROIs but label length is 6.
Mismatch for cleaned_images\0fn5eioh-0.png: 9 ROIs but label length is 8.
Mismatch for cleaned_images\0h7htrjq-0.png: 7 ROIs but label length is 8.
TOTAL MISMATCH COUNT: 6


In [16]:
# Explicit columns keep the schema even when no image matched (empty list of records)
saved_images_df = pd.DataFrame(
    saved_images_data, columns=['image_path', 'label', 'saved_char_paths']
)
saved_images_df

Unnamed: 0,image_path,label,saved_char_paths
0,cleaned_images\0024miih-0.png,0024miih,"[extracted_chars\0024miih-0_char_1.png, extrac..."
1,cleaned_images\002k-0.png,002k,"[extracted_chars\002k-0_char_1.png, extracted_..."
2,cleaned_images\006aguv-0.png,006aguv,"[extracted_chars\006aguv-0_char_1.png, extract..."
3,cleaned_images\00hai-0.png,00hai,"[extracted_chars\00hai-0_char_1.png, extracted..."
4,cleaned_images\00hgi3n7-0.png,00hgi3n7,"[extracted_chars\00hgi3n7-0_char_1.png, extrac..."
...,...,...,...
89,cleaned_images\0grcaqy-0.png,0grcaqy,"[extracted_chars\0grcaqy-0_char_1.png, extract..."
90,cleaned_images\0gs296j-0.png,0gs296j,"[extracted_chars\0gs296j-0_char_1.png, extract..."
91,cleaned_images\0gwk-0.png,0gwk,"[extracted_chars\0gwk-0_char_1.png, extracted_..."
92,cleaned_images\0h4g-0.png,0h4g,"[extracted_chars\0h4g-0_char_1.png, extracted_..."


In [None]:
from PIL import Image

# Sanity check (see if characters were cropped) -- optional, no need to run every time
def print_image_size(image_path):
    """Return the size of the image at ``image_path`` as a "WIDTHxHEIGHT" string.

    Despite its name, the function prints nothing; the string is used in plot titles.
    """
    with Image.open(image_path) as im:
        return "{}x{}".format(*im.size)  # PIL's size is (width, height)

# One figure per image: each saved character crop, titled with its label character and size
for _, row in saved_images_df.iterrows():
    char_paths = row['saved_char_paths']
    # squeeze=False keeps `axes` 2-D even for a single-character label
    fig, axes = plt.subplots(1, len(char_paths), figsize=(10, 5), squeeze=False)
    for ax, char, char_path in zip(axes[0], row['label'], char_paths):
        char_img = cv2.imread(char_path)
        if char_img is None:
            ax.set_title(f"{char}\nunreadable")
        else:
            ax.imshow(cv2.cvtColor(char_img, cv2.COLOR_BGR2RGB))
            ax.set_title(f"{char}\n{print_image_size(char_path)}")
        ax.axis('off')
    plt.show()
    plt.close(fig)  # release memory when many figures are produced in the loop

In [None]:
# Segmented images are in saved_images_df