# Dataset Analysis and Pipeline Validation
## Image-Based Product Category Classification

This notebook covers the initial analysis of the raw dataset and validates the effectiveness of our automated data preparation pipeline.

The goals of this notebook are to:
- Inspect the raw dataset and its original category distribution.
- **Validate the new category structure** created by our NLP-based merging script.
- **Analyze the balance of the final, merged dataset**.
- Verify data quality, such as the validity of image URLs.

In [None]:
# Import tools needed for data analysis and visualization
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import requests
from concurrent.futures import ThreadPoolExecutor

plt.style.use("default")

## 1. Loading the Dataset

The dataset is provided as two CSV files: `products.csv` contains the product title, image URL, and category ID of each product, and `categories.csv` contains the category names.
Due to its size, the dataset is hosted externally and downloaded using a dedicated script.

In [None]:
products_df = pd.read_csv("../data/raw/products.csv")
categories_df = pd.read_csv("../data/raw/categories.csv")
products_df.head()

In [None]:
print(f"Number of samples: {len(products_df)}")
print(f"Number of columns: {len(products_df.columns)}")
products_df.columns

- The dataset contains product titles, image URLs, and their corresponding category IDs.
- The category labels are stored in a separate CSV file.
- The following analysis focuses on cleaning and refining these fields.

## 2. Raw Category Inspection

We first examine the distribution of all available categories to understand their frequency and suitability for image-based classification.

In [None]:
# Number of products per raw category ID
raw_category_counts = products_df["category_id"].value_counts()
print(categories_df[["id", "category_name"]].head())

print(f"\nNumber of raw categories: {raw_category_counts.size}")
print(f"Min frequency of a category: {raw_category_counts.min()}")
print(f"Max frequency of a category: {raw_category_counts.max()}")
print(f"Mean frequency of a category: {int(raw_category_counts.mean())}")
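To see which categories are the rare ones, the raw counts can be joined to their names. This is a minimal sketch that assumes only the `category_id`, `id`, and `category_name` columns used above; `named_category_counts` is a hypothetical helper and the toy frames are illustrative, not real data.

```python
import pandas as pd


def named_category_counts(products_df, categories_df):
    """Product count per raw category, joined to its name, rarest first.

    Assumes products_df has a `category_id` column and categories_df has
    `id` and `category_name` columns, as in the cells above.
    """
    counts = (
        products_df["category_id"]
        .value_counts()
        .rename_axis("id")
        .reset_index(name="count")
    )
    named = counts.merge(categories_df[["id", "category_name"]], on="id", how="left")
    return named.sort_values("count", kind="stable").reset_index(drop=True)


# Toy example (hypothetical data)
products = pd.DataFrame({"category_id": [1, 1, 1, 2, 3, 3]})
categories = pd.DataFrame({"id": [1, 2, 3], "category_name": ["Toys", "Books", "Shoes"]})
result = named_category_counts(products, categories)
print(result)
```

On the real data, `named_category_counts(products_df, categories_df).head(10)` would list the ten smallest categories.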

## 3. Automated Category Merging

The original dataset contains 248 categories, many of which are either too specific, too broad, or semantically very similar. A manual approach to filtering or merging is not scalable or reproducible.

Some categories also contain very few products, which can unbalance the dataset.

To solve this, we implemented an automated pipeline that uses **NLP (Sentence-Transformers) and Clustering (K-Means)** to group the 248 original categories into a more manageable and balanced set of around 50 new categories. 
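The merging script lives outside this notebook, so the following is only a minimal sketch of the technique under stated assumptions: category names are assumed to be already embedded (for example with a Sentence-Transformers model), and a plain NumPy implementation of Lloyd's k-means stands in for whatever clustering library the real pipeline uses. `kmeans` and `merge_categories` are hypothetical helpers.

```python
import numpy as np
import pandas as pd


def kmeans(x, k, n_iter=100, seed=0):
    """Lloyd's k-means on the rows of x; returns one cluster label per row."""
    rng = np.random.default_rng(seed)
    # Initialise centers with k distinct data points.
    centers = x[rng.choice(len(x), size=k, replace=False)]
    for _ in range(n_iter):
        # Assignment step: nearest center by squared Euclidean distance.
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
        labels = dists.argmin(axis=1)
        # Update step: move each center to the mean of its members
        # (an empty cluster keeps its previous center).
        new_centers = np.array([
            x[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def merge_categories(category_ids, embeddings, n_clusters=50, seed=0):
    """Map each original category ID to a merged cluster ID (hypothetical helper)."""
    emb = np.asarray(embeddings, dtype=float)
    # Unit-normalise: squared distance between two embeddings becomes
    # 2 - 2 * cosine similarity, so clusters follow semantic similarity.
    emb = emb / np.linalg.norm(emb, axis=1, keepdims=True)
    labels = kmeans(emb, n_clusters, seed=seed)
    return pd.DataFrame({
        "category_id": list(category_ids),
        "merged_category_id": labels,
    })
```

With the `sentence-transformers` package installed, embeddings could come from something like `SentenceTransformer("all-MiniLM-L6-v2").encode(categories_df["category_name"].tolist())`; the model name is an assumption, not necessarily the one our script uses.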

The following steps analyze the result of that pipeline.

First, we load the final processed dataset and the mapping file generated by our scripts.

In [None]:
# Load the results of the new data preparation pipeline
products_cleaned_df = pd.read_csv("../data/processed/products_cleaned.csv")
mapping_df = pd.read_csv("../data/processed/category_mapping.csv")

print(f"Loaded {len(products_cleaned_df)} cleaned products.")
print(f"Loaded mapping for {len(mapping_df)} original categories into {mapping_df['merged_category_id'].nunique()} new categories.")

To make the dataset more balanced and learnable, categories with fewer than 100 samples were removed.
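The filtering rule fits in a couple of lines of pandas. This is a sketch assuming the cleaned data's `merged_category_id` column; `drop_rare_categories` is a hypothetical helper, and the actual pipeline may apply the threshold at a different step.

```python
import pandas as pd


def drop_rare_categories(df, label_col="merged_category_id", min_samples=100):
    """Keep only rows whose category has at least `min_samples` products."""
    # Per-row size of the row's category
    counts = df[label_col].map(df[label_col].value_counts())
    return df[counts >= min_samples].reset_index(drop=True)
```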

The remaining categories form the final category set.
They are intended to be visually distinguishable and suitable for classification; Section 6 checks this qualitatively with example images.


## 4. Dataset Balance Analysis

We analyze the number of samples per category to assess dataset balance.

In [None]:
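# Getting ready for visualization: product count per merged category
category_counts_df = (
    products_cleaned_df["merged_category_id"]
    .value_counts()
    .rename_axis("merged_category_id")
    .reset_index(name="count")
)

# mapping_df has one row per *original* category, so keep a single name per
# merged category; otherwise the left merge would duplicate rows (and bars).
merged_names = mapping_df[["merged_category_id", "category_name"]].drop_duplicates(
    subset="merged_category_id"
)
category_counts_df = category_counts_df.merge(
    merged_names,
    on="merged_category_id",
    how="left"
)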

In [None]:
plt.figure(figsize=(15, 6))

plt.bar(
    category_counts_df["category_name"],
    category_counts_df["count"]
)

plt.xticks(rotation=90)
plt.xlabel("Category")
plt.ylabel("Number of Products")
plt.title("Category Distribution")

plt.tight_layout()
plt.show()

The chart shows some imbalance across categories.
This imbalance will be addressed in later phases through data augmentation and training strategies.
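One way to quantify the imbalance, and one common training-side remedy, is sketched below: the ratio between the largest and smallest category, plus "balanced" inverse-frequency class weights. This is an illustration under assumptions; the notebook does not specify which training strategies are used later, and `imbalance_summary` is a hypothetical helper.

```python
import pandas as pd


def imbalance_summary(counts):
    """Imbalance ratio and 'balanced' inverse-frequency weights per class.

    weight_c = n_samples / (n_classes * count_c), so a class of average
    size gets weight 1 and rarer classes get proportionally more.
    """
    counts = pd.Series(counts).astype(float)
    weights = counts.sum() / (len(counts) * counts)
    ratio = counts.max() / counts.min()
    return ratio, weights
```

Usage: `ratio, class_weights = imbalance_summary(category_counts_df.set_index("merged_category_id")["count"])`. If weighting were adopted in PyTorch, the weights (ordered by class index) could be passed as the `weight` argument of `torch.nn.CrossEntropyLoss`.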

## 5. Image URL Quality Check

To assess data quality, we verify whether image URLs are reachable.
Due to dataset size, this check is performed on a random subset.


In [None]:
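def is_url_valid(url):
    """Return True if the URL answers with HTTP 200, False on any request error."""
    try:
        # stream=True defers downloading the body; the context manager
        # closes the connection without reading the image bytes.
        with requests.get(
            url,
            timeout=5,
            stream=True,
            allow_redirects=True,
            headers={"User-Agent": "Mozilla/5.0"},
        ) as r:
            return r.status_code == 200
    except requests.RequestException:
        return False


def check_urls_parallel(urls, max_workers=40):
    """Check many URLs concurrently; results keep the input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        return list(tqdm(executor.map(is_url_valid, urls), total=len(urls)))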

sample_df = products_cleaned_df.sample(200, random_state=42).copy()

sample_df["valid"] = check_urls_parallel(
    sample_df["imgUrl"].tolist(),
    max_workers=20
)

sample_df["valid"].value_counts()

The sample-based check indicates that most image URLs are reachable.
Invalid URLs are handled automatically during the image download and training stage, where missing images are skipped.
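With only 200 sampled URLs, the observed reachability rate is itself an estimate. A Wilson score interval gives a rough range for the dataset-wide rate. This is a sketch added for illustration (`wilson_interval` is a hypothetical helper), assuming the 200 URLs are a simple random sample.

```python
import math


def wilson_interval(successes, n, z=1.96):
    """Wilson score interval for a binomial proportion (z=1.96 is about 95%)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = (z / denom) * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2))
    return center - half, center + half
```

Usage: `wilson_interval(int(sample_df["valid"].sum()), len(sample_df))` returns the lower and upper bound for the share of reachable URLs.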

## 6. Example Images from Each Category

To better understand the visual characteristics of each category and verify that
categories are visually distinguishable, we display a small number of example
images from each category.

For each category, two sample images (`IMAGES_PER_CATEGORY = 2`) are randomly selected and visualized.

In [None]:
import os
import time
import hashlib
import requests
from PIL import Image
from io import BytesIO

# Base cache directory
IMAGE_CACHE_DIR = "../data/raw/images"
os.makedirs(IMAGE_CACHE_DIR, exist_ok=True)

def _url_to_filename(url):
    """
    Create a deterministic filename from a URL using hashing.
    """
    return hashlib.md5(url.encode("utf-8")).hexdigest() + ".jpg"


def load_image_from_url(
    url,
    timeout=5,
    retries=3,
    backoff=1.0,
    cache_dir=IMAGE_CACHE_DIR
):
    """
    Fetch an image from a URL with caching and retries.

    If the image was previously downloaded, it is loaded from disk
    instead of downloading again.

    Args:
        url (str): Image URL
        timeout (int): Request timeout in seconds
        retries (int): Number of retry attempts
        backoff (float): Seconds to wait between retries
        cache_dir (str): Directory for cached images

    Returns:
        PIL.Image or None
    """
    filename = _url_to_filename(url)
    filepath = os.path.join(cache_dir, filename)

    # 1 -> Load from cache if exists
    if os.path.exists(filepath):
        try:
            return Image.open(filepath).convert("RGB")
        except Exception:
            # Corrupted cache → remove and re-download
            os.remove(filepath)

    # 2 -> Download with retries
    for attempt in range(1, retries + 1):
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()

            # Ensure it's an image
            if "image" not in response.headers.get("Content-Type", ""):
                raise ValueError("URL did not return an image")

            img = Image.open(BytesIO(response.content)).convert("RGB")

            # Save to cache
            img.save(filepath, format="JPEG", quality=90)

            return img

        except Exception:
            if attempt < retries:
                time.sleep(backoff)
            else:
                return None

In [None]:
def fetch_image_with_info(row):
    """
    Fetch image and return metadata (Category ID, Image Object, Product Title).
    """
    img = load_image_from_url(row["imgUrl"])
    # Return a tuple: (category_id, image_obj, product_title)
    return row["merged_category_id"], img, row["title"]

In [None]:
id_to_name = dict(zip(mapping_df["merged_category_id"], mapping_df["category_name"]))
# Gather Data
IMAGES_PER_CATEGORY = 2
rows = []
categories = products_cleaned_df["merged_category_id"].unique()

for category in categories:
    cat_df = products_cleaned_df[products_cleaned_df["merged_category_id"] == category]
    # Sample data
    sampled = cat_df.sample(min(IMAGES_PER_CATEGORY, len(cat_df)), random_state=42)
    rows.extend(sampled.to_dict("records"))

# Download in Parallel
images_with_meta = []
with ThreadPoolExecutor(max_workers=20) as executor:
    for result in tqdm(executor.map(fetch_image_with_info, rows), total=len(rows), desc="Downloading"):
        images_with_meta.append(result)

In [None]:
import textwrap

# One row per category, one column per sampled image
n_rows = len(categories)
n_cols = IMAGES_PER_CATEGORY

fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(n_cols * 4, n_rows * 3),  # extra width leaves room for titles
    squeeze=False,  # always return a 2-D array of axes
)

# images_with_meta follows the order of `rows`: IMAGES_PER_CATEGORY entries
# per category (every category has >= 100 products), in `categories` order.
idx = 0
for row_idx, category_id in enumerate(categories):
    for col_idx in range(IMAGES_PER_CATEGORY):
        ax = axes[row_idx, col_idx]
        img_cat_id, img, title = images_with_meta[idx]

        if img is not None:
            ax.imshow(img)
        else:
            ax.text(0.5, 0.5, "Image\nNot Available",
                    ha="center", va="center", transform=ax.transAxes)

        # Hide ticks and frame but keep the axis itself: ax.axis("off")
        # would also hide the category label set with set_ylabel below.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

        # Product title above each image (wrapped, at most 100 characters)
        wrapped_title = "\n".join(textwrap.wrap(str(title), width=30))[:100]
        ax.set_title(wrapped_title, fontsize=9)

        # Category name as the label of the first image in each row
        if col_idx == 0:
            cat_name = id_to_name.get(category_id, f"Cat: {category_id}")
            wrapped_cat_name = "\n".join(textwrap.wrap(str(cat_name), width=15))
            ax.set_ylabel(
                wrapped_cat_name,
                rotation=0,
                labelpad=60,
                fontsize=11,
                fontweight="bold",
                ha="right",
                va="center",
            )

        idx += 1

plt.tight_layout()
plt.show()

The example images suggest that the selected categories are visually distinct
and suitable for image-based classification. With only two images per category this
is a qualitative check, but it supports the feasibility of the proposed learning task.

## 7. Image Dimension Distribution

Images collected from online marketplaces often vary significantly in resolution.
To analyze this variability and justify image resizing during preprocessing,
we examine the distribution of image widths and heights.

A random sample of 1000 images is used for this analysis.

In [None]:
# Sample images
SAMPLE_SIZE = 1000
sample_df = products_cleaned_df.sample(
    min(SAMPLE_SIZE, len(products_cleaned_df)),
    random_state=42
)

def get_image_size(url):
    img = load_image_from_url(url)
    if img is not None:
        return img.size  # (width, height)
    return None

sizes = []

with ThreadPoolExecutor(max_workers=20) as executor:
    for size in tqdm(
        executor.map(get_image_size, sample_df["imgUrl"]),
        total=len(sample_df),
        desc="Fetching image dimensions"
    ):
        if size is not None:
            sizes.append(size)

print(f"Collected dimensions for {len(sizes)} images")

In [None]:
widths, heights = zip(*sizes)

plt.figure(figsize=(6, 6))
plt.scatter(widths, heights, alpha=0.4, s=10)
plt.xlabel("Image Width (pixels)")
plt.ylabel("Image Height (pixels)")
plt.title("Image Width vs Height Distribution")
plt.grid(True, linestyle="--", alpha=0.3)
plt.tight_layout()
plt.show()

The scatter plot shows a wide variation in image dimensions across the dataset.
Both width and height span a large range of values, indicating that images are
not captured or stored at a consistent resolution.

This variability motivates the use of image resizing as a necessary
preprocessing step before model training to ensure consistent input dimensions.
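To put numbers on the spread shown in the scatter plot, the collected `sizes` list can be summarized directly. This is a minimal sketch: `summarize_dimensions` is a hypothetical helper, the 2% tolerance for counting an image as square is an arbitrary choice, and the 224-pixel threshold mirrors the resize target used in the training pipeline.

```python
import numpy as np


def summarize_dimensions(sizes):
    """Summarize a list of (width, height) pairs in pixels."""
    arr = np.asarray(sizes, dtype=float)
    widths, heights = arr[:, 0], arr[:, 1]
    aspect = widths / heights
    return {
        "n_images": len(arr),
        "width_p5_p50_p95": np.percentile(widths, [5, 50, 95]).tolist(),
        "height_p5_p50_p95": np.percentile(heights, [5, 50, 95]).tolist(),
        # Share of roughly square images (aspect ratio within 2% of 1)
        "share_square": float(np.mean(np.isclose(aspect, 1.0, atol=0.02))),
        # Share of images that a 224x224 resize would upscale in some dimension
        "share_below_224": float(np.mean((widths < 224) | (heights < 224))),
    }
```

Calling `summarize_dimensions(sizes)` on the sample above shows how many images a fixed `224x224` resize would upscale and how many non-square images it would stretch.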

## 8. Average Image Color per Category
To get a high-level visual sense of the image data, we compute the average color of images within each category. This can reveal dominant color biases in the dataset (e.g., electronics being mostly black or silver) and helps justify the use of color-based data augmentation. We sample 20 images from each category to compute a representative average.

In [None]:
def get_average_color(url):
    """Fetches an image and computes its average RGB color."""
    try:
        img = load_image_from_url(url)
        if img:
            img.thumbnail((50, 50))
            avg_color = np.array(img).mean(axis=(0, 1))
            return avg_color
    except Exception:
        return None
    return None

In [None]:
SAMPLES_PER_CATEGORY = 20
category_avg_colors = []

unique_categories = products_cleaned_df["merged_category_id"].unique()

print(f"Sampling {SAMPLES_PER_CATEGORY} images for each of the {len(unique_categories)} categories...")

def process_category_color(category_id):
    """
    Fetches samples for a specific category and calculates the overall mean color.
    """
    cat_df = products_cleaned_df[products_cleaned_df["merged_category_id"] == category_id]
    
    sample_size = min(SAMPLES_PER_CATEGORY, len(cat_df))
    sampled_urls = cat_df.sample(n=sample_size, random_state=42)["imgUrl"].tolist()
    
    image_colors = []
    with ThreadPoolExecutor(max_workers=10) as executor:
        results = executor.map(get_average_color, sampled_urls)
        
        for color in results:
            if color is not None:
                image_colors.append(color)
    
    if image_colors:
        final_color = np.mean(np.array(image_colors), axis=0)
        
        cat_name_row = mapping_df[mapping_df["merged_category_id"] == category_id]
        cat_name = cat_name_row["category_name"].iloc[0] if not cat_name_row.empty else f"Cat {category_id}"
        
        return {
            "category_id": category_id,
            "category_name": cat_name,
            "rgb_color": final_color,
            "normalized_color": final_color / 255.0  # Normalized
        }
    return None

for cat_id in tqdm(unique_categories, desc="Processing Categories"):
    result = process_category_color(cat_id)
    if result:
        category_avg_colors.append(result)

colors_df = pd.DataFrame(category_avg_colors)

In [None]:
# Visualization
if not colors_df.empty:
    plt.figure(figsize=(15, 6))
    
    bars = plt.bar(
        colors_df["category_name"], 
        [1] * len(colors_df), 
        color=colors_df["normalized_color"],
        edgecolor="gray"
    )
    
    plt.xticks(rotation=90)
    plt.yticks([])
    plt.xlabel("Category")
    plt.title(f"Average Color Profile per Category (n={SAMPLES_PER_CATEGORY})")
    plt.tight_layout()
    plt.show()
else:
    print("No color data extracted.")

Since the images in our dataset consistently feature **pure white backgrounds**, we keep them as-is rather than training a background-segmentation or mask-generation model. Our reasoning is twofold:

1. **High Signal Clarity:** The white background provides a high contrast ratio, allowing the model to easily identify product contours and textures without interference from environmental "noise."
2. **Robustness via Augmentation:** To keep the model from overfitting to this "studio setting" (and failing on real-world photos), we apply random brightness shifts, rotations, and color jitter to simulate real-world variability. This is computationally cheaper than segmentation while still pushing the model to learn invariant product features rather than relying on background-to-object contrast.

### Analysis of Average Image Color

The visualization above displays the **mean RGB value** for a sample of images within each category. This analysis serves as a validation tool for our dataset:

* **Consistency Check:** A very light average color across all categories confirms the dominance of the white background, consistent with the observation above.
* **Inter-Category Variance:** Subtle differences in the average color (e.g., a darker average for a category of dark, frame-filling products than for one of small, light-colored items) reflect how much of the frame the products typically fill and how dark they are.
* **Model Bias Awareness:** Since the average color is heavily skewed toward white (RGB ~255, 255, 255), background pixels dominate every input. Normalization and augmentation must therefore be chosen so that the actual product features (the "minority" of pixels) are not washed out during training.
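The effect of normalization on the background can be checked with a little arithmetic: `ToTensor` scales pixels to [0, 1] and `Normalize` then computes (x - mean) / std per channel with the ImageNet statistics listed in the training pipeline, so a pure-white pixel lands at the maximum value each channel can take. A small sketch (the helper name `imagenet_normalize` is hypothetical):

```python
import numpy as np

# ImageNet statistics from the training pipeline (per RGB channel)
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])


def imagenet_normalize(rgb_255):
    """Map an 8-bit RGB triple to its value after ToTensor + Normalize."""
    return (np.asarray(rgb_255, dtype=float) / 255.0 - MEAN) / STD


white = imagenet_normalize([255, 255, 255])
print(white.round(2))  # [2.25 2.43 2.64]
```

Every background pixel therefore sits at roughly +2.2 to +2.6 on each channel, so the variation between images comes almost entirely from the product pixels.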

## 9. Training Pipeline
At a high level, the training phase applies the following preprocessing and augmentation steps:
- **Resize:** to a uniform `224x224` pixels.
- **RandomHorizontalFlip:** p=0.5.
- **RandomRotation:** 15 degrees.
- **ColorJitter:** brightness=0.2, contrast=0.2, saturation=0.2.
- **ToTensor:** Convert PIL Image to PyTorch Tensor.
- **Normalize:** Use standard ImageNet mean `[0.485, 0.456, 0.406]` and std `[0.229, 0.224, 0.225]`.

For validation and testing, only the deterministic steps are applied:
- **Resize:** `224x224` pixels.
- **ToTensor:** Convert to Tensor.
- **Normalize:** Use the same mean and std as the training pipeline.

## Summary

In this notebook, we performed a full cycle of data analysis and pipeline validation:
- We began by inspecting the **raw dataset** and its original, imbalanced category structure.
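- We then loaded the results of our **automated NLP-based category merging pipeline**, together with the mapping from original to merged categories.
- We **visualized the final dataset's balance**, confirming that the roughly 50 merged categories are much more evenly distributed, although some imbalance remains.
- We verified the reachability of image URLs on a random sample.
- We displayed example images from each category, examined the distribution of image dimensions, and computed the average color per category.
- We outlined the preprocessing and augmentation applied during training and evaluation.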

The resulting dataset, `products_cleaned.csv`, is now validated and ready to serve as the foundation for model training.