In [None]:
import os
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms
from sklearn.model_selection import train_test_split


# Define the classes
classes = ['cherry', 'strawberry', 'tomato']
data_dir = './train_data'

# Maps each class name to the list of PIL images loaded for it
data = {}

# File names to skip while loading (matched against the bare file name,
# not the full path)
excluded_images = {
    'cherry_0055.jpg',
    'cherry_0105.jpg',
    'cherry_0147.jpg',
    'strawberry_0931.jpg',
    'tomato_0087.jpg'
}

for class_name in classes:
    # Each class is read from its own sub-directory of data_dir.
    class_dir = os.path.join(data_dir, class_name)
    images = []

    # os.listdir returns entries in arbitrary, platform-dependent order;
    # sort so the per-class image order is reproducible across runs.
    for file_name in sorted(os.listdir(class_dir)):
        # Only .jpg files are loaded, minus the explicitly excluded ones.
        if not file_name.endswith('.jpg') or file_name in excluded_images:
            continue
        file_path = os.path.join(class_dir, file_name)

        # Image.open is lazy and keeps the file open until the pixels are
        # read, so holding many un-loaded images can exhaust the open-file
        # limit. Open the file ourselves, force the decode with load(), and
        # let the `with` close the handle immediately afterwards.
        with open(file_path, 'rb') as fh:
            img = Image.open(fh)
            img.load()
        images.append(img)

    # Store images for this class
    data[class_name] = images


# Report how many images were loaded for each class
for class_name in classes:
    print(f'Loaded {len(data[class_name])} images from {class_name} class.')

In [None]:
import numpy as np
from collections import defaultdict

def detect_and_filter_rgb_outliers(image_data, thresholds):
    """Split colour images into kept images and outliers by per-channel mean.

    Every image is converted to a NumPy array. Single-channel (grayscale)
    images are dropped and only counted. For every other image the mean
    intensity of each channel is compared with ``thresholds``. If any channel
    mean is strictly below its low bound or strictly above its high bound,
    the image is an outlier; otherwise it is kept under its class.

    Args:
        image_data: Mapping of class name -> list of images (PIL images or
            anything ``np.array`` accepts).
        thresholds: Sequence of ``(low, high)`` pairs, one per channel, in
            channel order (e.g. R, G, B).

    Returns:
        ``(filtered_data, outliers)``. ``filtered_data`` is a dict of class
        name -> kept images (classes with no kept image are absent), and
        ``outliers`` is a flat list of rejected images.

    Raises:
        ValueError: if a non-grayscale image has fewer channels than there
            are thresholds and cannot be converted to RGB.
    """
    n_channels = len(thresholds)
    # Build the bound vectors once rather than once per image.
    low = np.array([t[0] for t in thresholds], dtype=float)
    high = np.array([t[1] for t in thresholds], dtype=float)

    filtered_data = defaultdict(list)
    outliers = []
    grayscale_count = 0
    total_input_images = sum(len(images) for images in image_data.values())

    for class_name, images in image_data.items():
        for img in images:
            img_np = np.array(img)  # Convert image to NumPy array

            # (H, W) or (H, W, 1): a single channel, i.e. grayscale.
            if img_np.ndim == 2 or (img_np.ndim == 3 and img_np.shape[2] == 1):
                grayscale_count += 1
                continue

            # Modes such as RGBA, CMYK or LA have a channel count that does
            # not match the thresholds and would fail to broadcast. Compute
            # the statistics on an RGB conversion when the image supports it;
            # the stored image itself is left untouched.
            if img_np.shape[-1] != n_channels and hasattr(img, 'convert'):
                img_np = np.asarray(img.convert('RGB'))

            # Extra trailing channels (e.g. alpha) are ignored.
            mean_channels = np.mean(img_np[..., :n_channels], axis=(0, 1))
            if mean_channels.shape != low.shape:
                raise ValueError(
                    f"Image with shape {img_np.shape} has fewer channels "
                    f"than the {n_channels} thresholds"
                )

            # Outlier if any channel mean is outside its (low, high) range
            condition = (mean_channels < low) | (mean_channels > high)
            if np.any(condition):
                outliers.append(img)
            else:
                filtered_data[class_name].append(img)

    kept_count = sum(len(images) for images in filtered_data.values())
    rgb_count = kept_count + len(outliers)
    # Every input image is processed: it is either skipped as grayscale,
    # flagged as an outlier, or kept.
    total_processed_images = rgb_count + grayscale_count

    print(f"Input images: {total_input_images}")
    print(f"Processed images: {total_processed_images}")
    print(f"Removed Grayscale images: {grayscale_count}")
    print(f"RGB images: {rgb_count}")
    print(f"Outliers: {len(outliers)}")
    print(f"Images in filtered_data: {kept_count}")

    return dict(filtered_data), outliers

# Define channel-specific thresholds based on the distributions
# Per-channel (low, high) bounds on the mean pixel intensity, chosen from the
# observed channel distributions. An RGB image whose mean for any channel
# falls outside its range is treated as an outlier.
thresholds = [
    (27, 238),  # Red channel (low, high)
    (14, 220),  # Green channel (low, high)
    (8, 218)    # Blue channel (low, high)
]

# Filter the class -> images mapping built by the loading cell (`data`).
# `filtered_data` does not exist until this call returns, so it cannot be
# the input.
filtered_data, rgb_outliers = detect_and_filter_rgb_outliers(data, thresholds)
print(f'\nFound {len(rgb_outliers)} potential RGB channel-based outliers out of {sum(len(images) for images in filtered_data.values()) + len(rgb_outliers)} total images.')
print(f'Filtered data contains {sum(len(images) for images in filtered_data.values())} images after RGB channel-based filtering.')