# Data Exploration for an Image Dataset

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

In [None]:
# Load image dataset
train_dir = 'data/train_data/'
test_dir = 'data/test_data/'
train_labels = pd.read_csv('data/train_labels.csv')

# Display one sample image from each class.
# drop_duplicates keeps the first row seen for every label, so each class is
# resolved in a single pass instead of re-filtering the whole table per class.
classes = train_labels['label'].unique()
first_per_class = train_labels.drop_duplicates(subset='label')
for cls, sample_name in zip(first_per_class['label'], first_per_class['sample_index']):
    sample_image_path = train_dir + sample_name
    # Pillow leaves the underlying file open after Image.open; the context
    # manager closes it once the pixels have been handed to matplotlib.
    with Image.open(sample_image_path) as img:
        plt.imshow(img)
    plt.title(f'Class: {cls}')
    plt.axis('off')
    plt.show()

In [None]:
# Remove contaminated images from training data
# NOTE: this cell deletes files from train_dir in place; it is safe to re-run
# (already-removed files are skipped) but the deletion cannot be undone.
import shutil
import os

# Parse the contaminated indices from the text file (one integer per line;
# anything else is ignored).
# isdecimal() is used instead of isdigit() because isdigit() also accepts
# characters such as superscripts that int() cannot parse.
# dict.fromkeys drops repeated entries while keeping file order, so the
# reported count is the number of distinct samples.
with open('shrek_and_slimes.txt', 'r') as f:
    stripped_lines = (line.strip() for line in f)
    contaminated_indices = list(dict.fromkeys(
        int(entry) for entry in stripped_lines if entry.isdecimal()
    ))

print(f"Found {len(contaminated_indices)} contaminated samples to remove")

# Remove corresponding image and mask files
removed_count = 0
for idx in contaminated_indices:
    for prefix in ('img', 'mask'):
        file_path = os.path.join(train_dir, f'{prefix}_{idx:04d}.png')
        try:
            os.remove(file_path)
        except FileNotFoundError:
            # Either the sample has no such file or an earlier run removed it.
            continue
        removed_count += 1

print(f"Removed {removed_count} files from {train_dir}")

# Update train_labels by removing contaminated indices.
# The first run of digits in each sample_index is taken as its numeric id.
# to_numeric(errors='coerce') turns entries without any digits into NaN, which
# never matches a contaminated index, so such rows are kept instead of raising.
sample_numbers = pd.to_numeric(
    train_labels['sample_index'].str.extract(r'(\d+)', expand=False),
    errors='coerce',
)
train_labels = train_labels[~sample_numbers.isin(contaminated_indices)]
print(f"Training labels updated: {len(train_labels)} samples remaining")

In [None]:
# Class distribution
plt.figure(figsize=(10, 6))
# Order the bars from the most to the least frequent class.
class_order = train_labels['label'].value_counts().index
sns.countplot(data=train_labels, x='label', order=class_order)
plt.title('Class Distribution')
plt.xticks(rotation=45)
plt.tight_layout()
# Explicit show() matches the other plotting cells and avoids echoing the
# Axes repr as the cell's output.
plt.show()

In [None]:
# Show image pixel sizes throughout dataset
sizes = []
for idx in train_labels['sample_index']:
    # Only the header is needed for .size; the context manager closes the file
    # straight away instead of leaving one open handle per image.
    with Image.open(train_dir + idx) as img:
        sizes.append(img.size)

# One row per training image, in the same order as train_labels.
sizes_df = pd.DataFrame(sizes, columns=['width', 'height'])
print(sizes_df.describe())

In [None]:
# Analyze image size distribution in train and test sets
import os


def _image_size(path):
    """Return the (width, height) of the image stored at *path*."""
    # The context manager closes the file as soon as the size has been read.
    with Image.open(path) as img:
        return img.size


def _report_size_variation(sizes_df, split_name):
    """Print the most common image size of a split and how many images differ.

    Returns a ``(most_common_size, non_standard_count)`` tuple, or
    ``(None, 0)`` without printing anything when *sizes_df* has no rows.
    """
    if sizes_df.empty:
        return None, 0
    # Count whole (width, height) rows. DataFrame.mode() works per column and
    # can combine the most common width with the most common height into a
    # size that no image actually has.
    width, height = sizes_df.value_counts().idxmax()
    most_common = (int(width), int(height))
    differs = (sizes_df['width'] != most_common[0]) | (sizes_df['height'] != most_common[1])
    non_standard = int(differs.sum())
    print(f"Most common size in {split_name}: {most_common}")
    print(f"Images with non-standard size in {split_name}: {non_standard} ({non_standard/len(sizes_df)*100:.2f}%)")
    return most_common, non_standard


# Get all sizes from train set (one entry per row of train_labels, same order)
train_sizes = [_image_size(os.path.join(train_dir, name)) for name in train_labels['sample_index']]

# Get all sizes from test set; sub-directories are skipped because they cannot
# be opened as images, and the listing is sorted for reproducible output.
test_files = sorted(
    name for name in os.listdir(test_dir)
    if os.path.isfile(os.path.join(test_dir, name))
)
test_sizes = [_image_size(os.path.join(test_dir, name)) for name in test_files]

# Create DataFrames
train_sizes_df = pd.DataFrame(train_sizes, columns=['width', 'height'])
test_sizes_df = pd.DataFrame(test_sizes, columns=['width', 'height'])

print("=== TRAIN SET IMAGE SIZES ===")
print(train_sizes_df.describe())
print(f"\nTotal train images: {len(train_sizes_df)}")
print(f"\nUnique sizes in train set: {train_sizes_df.value_counts()}")

print("\n=== TEST SET IMAGE SIZES ===")
print(test_sizes_df.describe())
print(f"\nTotal test images: {len(test_sizes_df)}")
print(f"\nUnique sizes in test set: {test_sizes_df.value_counts()}")

# Find the smallest dimensions
min_width_train = train_sizes_df['width'].min()
min_height_train = train_sizes_df['height'].min()
min_width_test = test_sizes_df['width'].min()
min_height_test = test_sizes_df['height'].min()

# The overall minimum is taken over the raw size tuples so that an empty split
# (whose per-split minimum is NaN) cannot distort the comparison.
all_sizes = train_sizes + test_sizes
overall_min_width = min((w for w, _ in all_sizes), default=None)
overall_min_height = min((h for _, h in all_sizes), default=None)

print("\n=== MINIMUM DIMENSIONS ===")
print(f"Train - Min width: {min_width_train}, Min height: {min_height_train}")
print(f"Test - Min width: {min_width_test}, Min height: {min_height_test}")
print(f"Overall minimum - Width: {overall_min_width}, Height: {overall_min_height}")

# Count how many images have non-standard sizes
print("\n=== SIZE VARIATION ANALYSIS ===")
most_common_size, non_standard_train = _report_size_variation(train_sizes_df, "train")
most_common_size_test, non_standard_test = _report_size_variation(test_sizes_df, "test")

In [None]:
# Relate image dimensions to class labels.
# Labels are attached positionally, so train_sizes_df must hold one row per
# row of train_labels, in the same order.
train_sizes_with_labels = (
    train_sizes_df
    .assign(label=train_labels['label'].values)
    .assign(area=lambda frame: frame['width'] * frame['height'])
)

print("=== DIMENSIONS BY LABEL ===")
per_label_summary = train_sizes_with_labels.groupby('label')[['width', 'height', 'area']].describe()
print(per_label_summary)

# One box plot per measurement, side by side
plt.figure(figsize=(14, 5))
for panel, measurement in enumerate(('width', 'height', 'area'), start=1):
    plt.subplot(1, 3, panel)
    sns.boxplot(data=train_sizes_with_labels, x='label', y=measurement)
    plt.title(f'{measurement.capitalize()} Distribution by Label')
    plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Per-label breakdown: count, dimension ranges and distinct sizes
print("\n=== SIZE STATISTICS BY LABEL ===")
for cls in sorted(train_sizes_with_labels['label'].unique()):
    in_class = train_sizes_with_labels['label'] == cls
    subset = train_sizes_with_labels[in_class]
    size_counts = subset[['width', 'height']].value_counts()
    report_lines = [
        f"\nLabel {cls}:",
        f"  Count: {len(subset)}",
        f"  Width range: {subset['width'].min()} - {subset['width'].max()}",
        f"  Height range: {subset['height'].min()} - {subset['height'].max()}",
        f"  Unique sizes: {len(size_counts)}",
    ]
    print("\n".join(report_lines))
    # The full size table is listed only when it is short enough to read.
    if len(size_counts) <= 5:
        print(f"  All sizes:\n{size_counts}")