# WaferScan AI: Data Processing & Exploration

This notebook handles the initial data pipeline for the WM-811K dataset. It performs:
1.  **Loading & Cleaning**: Parsing the raw pickle file and removing corrupted entries.
2.  **Label Mapping**: Converting string defect types to integer class IDs (0-8).
3.  **Visualization**: Inspecting class balance and sample defects.
4.  **Stratified Split**: Creating reproducible Train/Val/Test sets.

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive.
# The project root used by the following cells lives under this mount point,
# so this cell has to run before any cell that reads or writes project files.
from google.colab import drive
drive.mount('/content/drive')

## 1. Environment Setup

In [None]:
import os

# Root folder of the project on the mounted Drive; every other path in the
# notebook is derived from it.
PROJECT_ROOT = "/content/drive/MyDrive/wafer-hackathon"

# Directory layout created under PROJECT_ROOT (paths are relative to it).
PROJECT_SUBDIRS = [
    "data/raw",
    "data/processed",
    "models/checkpoints",
    "models/final",
    "models/onnx",
    "models/metrics",
    "docs",
]

# Pure-Python equivalent of `mkdir -p`: creates missing parents, is a no-op for
# folders that already exist, and is unaffected by spaces or shell
# metacharacters in PROJECT_ROOT (unlike unquoted `$PROJECT_ROOT` in a shell call).
for subdir in PROJECT_SUBDIRS:
    os.makedirs(os.path.join(PROJECT_ROOT, subdir), exist_ok=True)

print("Project structure verified.")

In [None]:
# Install dependencies
!pip install numpy==1.26.4 timm==0.9.12 netcal==1.3.5 wandb==0.16.1 pandas==2.0.3

## 2. Data Loader Implementation

Handles legacy pandas compatibility and robustly parses the nested pickle structure.

In [None]:
import pickle
import os
import sys
from collections import Counter
from typing import List, Tuple, Dict, Any

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# Pickles written by old pandas releases reference the module path
# "pandas.indexes"; register an alias to the current index package unless a
# module is already registered under that name.
sys.modules.setdefault("pandas.indexes", pd.core.indexes)


class _SkipRecord(Exception):
    """Internal signal that the current record is unusable; the message is the reason."""


def _read_record_fields(record: Any) -> Tuple[Any, Any, Any, Any]:
    """Return (waferMap, failureType, lotName, waferIndex) from a dict or attribute-style record."""
    if isinstance(record, dict):
        return (
            record.get("waferMap"),
            record.get("failureType"),
            record.get("lotName", ""),
            record.get("waferIndex", 0),
        )
    # Covers the namedtuples produced by DataFrame.itertuples() as well as any
    # other object that exposes the fields as attributes.
    return (
        getattr(record, "waferMap", None),
        getattr(record, "failureType", None),
        getattr(record, "lotName", ""),
        getattr(record, "waferIndex", 0),
    )


def _failure_type_to_label(failure_type: Any, empty_as_none: bool) -> str:
    """Reduce a raw failureType value (ndarray, list, str or bytes) to a stripped label string."""
    if isinstance(failure_type, np.ndarray):
        if failure_type.size == 0:
            if not empty_as_none:
                raise _SkipRecord("empty failureType")
            return "none"
        item = failure_type.flat[0]
    elif isinstance(failure_type, list):
        if len(failure_type) == 0:
            if not empty_as_none:
                raise _SkipRecord("empty failureType")
            return "none"
        item = failure_type[0]
        # One extra level of nesting ([["Center"]]) is unwrapped as well.
        if isinstance(item, list) and len(item) > 0:
            item = item[0]
    elif isinstance(failure_type, (str, bytes)):
        item = failure_type
    else:
        raise _SkipRecord("unsupported failureType type")

    # str(b"Center") would give "b'Center'", which never matches a class name,
    # so byte strings are decoded explicitly.
    if isinstance(item, bytes):
        item = item.decode("utf-8", errors="replace")
    return str(item).strip()


def load_wafer_dataset(
    pkl_path: str,
    empty_as_none: bool = True,
) -> Tuple[List[np.ndarray], List[int], pd.DataFrame]:
    """
    Load and validate WM-811K dataset from pickle file.

    Every record is reduced to a 2-D wafer map and an integer class id (0-8).
    Records that cannot be used are skipped and counted per reason; the
    summary is printed once loading finishes.

    Args:
        pkl_path: Path to the pickle file. The unpickled object may be a
            pandas DataFrame or a sized iterable of dicts / attribute-style
            records exposing ``waferMap`` and ``failureType`` (optionally
            ``lotName`` and ``waferIndex``).
        empty_as_none: When True (default, the historical behaviour) a record
            whose ``failureType`` is an empty array/list is labelled "none".
            When False such records are skipped instead.
            NOTE(review): an empty ``failureType`` may denote an unlabeled
            wafer rather than a defect-free one - confirm against the dataset
            description before relying on the default.

    Returns:
        A tuple ``(images, labels, metadata_df)`` where ``images`` is a list of
        2-D arrays, ``labels`` the matching integer class ids, and
        ``metadata_df`` a DataFrame with the columns ``lotName``,
        ``waferIndex``, ``failureType`` and ``mapped_label`` (one row per kept
        record; the columns are present even when nothing was kept).

    Raises:
        FileNotFoundError: If ``pkl_path`` does not exist.
        ValueError: If the file cannot be unpickled (original error chained).
    """
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"Dataset file not found: {pkl_path}")

    print(f"Loading dataset from {pkl_path}...")
    try:
        # SECURITY: pickle.load can execute arbitrary code; only point this at
        # dataset files obtained from a trusted source.
        with open(pkl_path, "rb") as f:
            data = pickle.load(f, encoding="latin1")
    except Exception as e:
        raise ValueError(f"Failed to load pickle: {e}") from e

    # Define class mapping
    class_map = {
        "none": 0,
        "Center": 1,
        "Donut": 2,
        "Edge-Loc": 3,
        "Edge-Ring": 4,
        "Loc": 5,
        "Near-full": 6,
        "Random": 7,
        "Scratch": 8,
    }
    metadata_columns = ["lotName", "waferIndex", "failureType", "mapped_label"]

    images: List[np.ndarray] = []
    labels: List[int] = []
    metadata_rows: List[Dict[str, Any]] = []
    skip_reasons: Counter = Counter()

    total = len(data)
    # Iterating a DataFrame directly yields column names, so walk its rows.
    iterator = data.itertuples() if isinstance(data, pd.DataFrame) else data

    for record in tqdm(iterator, total=total, desc="Processing wafers"):
        try:
            wafer_map, failure_type, lot_name, wafer_index = _read_record_fields(record)
            if wafer_map is None or failure_type is None:
                raise _SkipRecord("missing waferMap or failureType")

            wafer_map = np.array(wafer_map)
            # A (0, 0) or (0, n) array is 2-D but carries no dies; treat it as corrupted.
            if wafer_map.ndim != 2 or wafer_map.size == 0:
                raise _SkipRecord("waferMap is not a non-empty 2-D array")

            f_label = _failure_type_to_label(failure_type, empty_as_none)
            if f_label not in class_map:
                raise _SkipRecord("unknown failure label")
        except _SkipRecord as skip:
            skip_reasons[str(skip)] += 1
            continue
        except Exception as exc:
            # Best-effort loading: a malformed record must not abort the run,
            # but the error type is tallied so systematic problems stay visible.
            skip_reasons[f"unexpected {type(exc).__name__}"] += 1
            continue

        label_idx = class_map[f_label]
        images.append(wafer_map)
        labels.append(label_idx)
        metadata_rows.append(
            {
                "lotName": lot_name,
                "waferIndex": wafer_index,
                "failureType": f_label,
                "mapped_label": label_idx,
            }
        )

    # Explicit columns keep the schema stable even when no record was kept.
    metadata_df = pd.DataFrame(metadata_rows, columns=metadata_columns)
    valid_count = len(images)
    skipped_count = sum(skip_reasons.values())

    print("-" * 40)
    print(f"Data Loading Complete: {valid_count} valid samples")
    print(f"Skipped samples: {skipped_count}")
    for reason, count in skip_reasons.most_common():
        print(f"  - {reason}: {count}")
    print("-" * 40)

    if valid_count > 0:
        heights = [img.shape[0] for img in images]
        widths = [img.shape[1] for img in images]
        print(f"Image Heights: Min={min(heights)}, Max={max(heights)}")
        print(f"Image Widths:  Min={min(widths)}, Max={max(widths)}")

    return images, labels, metadata_df

## 3. Execution: Load Raw Data

In [None]:
# Parse the raw pickle stored under the project's data/raw folder.
raw_pkl_path = f'{PROJECT_ROOT}/data/raw/LSWMD.pkl'
images, labels, metadata = load_wafer_dataset(raw_pkl_path)

# One-line summaries of what came back from the loader.
for caption, value in (
    ("Loaded:", len(images)),
    ("Metadata shape:", metadata.shape),
    ("Unique labels:", sorted(set(labels))),
    ("Label distribution:", Counter(labels)),
):
    print(caption, value)

## 4. Save Intermediate Processed Data

Saves cleaned data before splitting.

In [None]:
# Folder that receives the cleaned, not-yet-split dataset.
processed_path = f'{PROJECT_ROOT}/data/processed'
os.makedirs(processed_path, exist_ok=True)

# Bundle the three loader outputs into a single dict and pickle it.
cleaned_bundle = {'images': images, 'labels': labels, 'metadata': metadata}
with open(f'{processed_path}/processed_dataset.pkl', 'wb') as out_file:
    pickle.dump(cleaned_bundle, out_file)

print(f"processed_dataset.pkl saved inside {processed_path}/")

## 5. Exploratory Visualization

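- **Class Distribution**: Shows severe class imbalance (96% 'none'). Note that the loader also assigns 'none' to records whose `failureType` is empty, so this share includes them.
- **Sample Defects**: Visual confirmation of label mapping.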

In [None]:
# Figures produced by this notebook are written to the project's docs folder.
docs_path = f'{PROJECT_ROOT}/docs'
os.makedirs(docs_path, exist_ok=True)

# Number of samples per class id, ordered by ascending class id.
label_counts = Counter(labels)
sorted_counts = {class_id: label_counts[class_id] for class_id in sorted(label_counts)}

# Class id -> display name (ids follow the position in this list).
class_names = dict(enumerate([
    'none',
    'Center',
    'Donut',
    'Edge-Loc',
    'Edge-Ring',
    'Loc',
    'Near-full',
    'Random',
    'Scratch',
]))

x = [class_names[class_id] for class_id in sorted_counts]
y = list(sorted_counts.values())

fig, ax = plt.subplots(figsize=(10, 5))
bars = ax.bar(x, y)
ax.set_title('Class Distribution (WM-811K)')
ax.set_xlabel('Failure Type')
ax.set_ylabel('Count')
plt.xticks(rotation=45)

# Write each bar's count just above it.
for rect in bars:
    count = rect.get_height()
    x_center = rect.get_x() + rect.get_width()/2
    ax.text(x_center, count, f'{count:,}', ha='center', va='bottom', fontsize=8)

fig.tight_layout()
fig.savefig(f'{docs_path}/class_distribution.png', dpi=300, bbox_inches='tight')
plt.show()

# Imbalance among the defect classes only (class 0, 'none', is left out).
defect_counts = [count for class_id, count in sorted_counts.items() if class_id != 0]
if defect_counts:
    imbalance_ratio = max(defect_counts) / min(defect_counts)
    print(f"Defect Imbalance Ratio (max/min): {imbalance_ratio:.2f}")

In [None]:
# Remember the position of the first wafer seen for each class id and stop
# scanning once nine distinct classes have been found.
sample_indices = {}
for idx, label in enumerate(labels):
    sample_indices.setdefault(label, idx)
    if len(sample_indices) == 9:
        break

# 3x3 grid: one panel per class id, left blank when the class has no sample.
fig, axes = plt.subplots(3, 3, figsize=(10, 10))
axes = axes.flatten()

for class_id, ax in enumerate(axes):
    first_idx = sample_indices.get(class_id)
    if first_idx is not None:
        ax.imshow(images[first_idx], cmap='gray')
        ax.set_title(class_names[class_id])
    ax.axis('off')

fig.tight_layout()
fig.savefig(f'{docs_path}/sample_defects.png', dpi=300, bbox_inches='tight')
plt.show()

## 6. Stratified Train/Val/Test Split

Splits the data while maintaining class proportions.
- **Train**: 70%
- **Validation**: 15%
- **Test**: 15%

In [None]:
import pickle
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Split configuration, kept in one place so proportions and seed are easy to audit.
SPLIT_SEED = 42
HOLDOUT_FRACTION = 0.30        # share of all samples kept out of training (val + test)
TEST_SHARE_OF_HOLDOUT = 0.50   # share of the hold-out that becomes the test set

# Load existing processed dataset. The path is derived from PROJECT_ROOT so it
# always points at the file written by the save step instead of duplicating
# the absolute Drive path.
processed_path = f'{PROJECT_ROOT}/data/processed/processed_dataset.pkl'

with open(processed_path, 'rb') as f:
    data = pickle.load(f)

all_labels = np.array(data['labels'])
all_indices = np.arange(len(all_labels))

# First split: 70% train, 30% temp
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=HOLDOUT_FRACTION, random_state=SPLIT_SEED)
train_idx, temp_idx = next(sss1.split(all_indices, all_labels))

# Second split: 50% val, 50% test from temp (15% each of total).
# The splitter returns positions *within* temp_idx, so they are mapped back to
# dataset-level indices right after.
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SHARE_OF_HOLDOUT, random_state=SPLIT_SEED)
val_idx, test_idx = next(sss2.split(temp_idx, all_labels[temp_idx]))
val_idx = temp_idx[val_idx]
test_idx = temp_idx[test_idx]

# Sanity check before anything is written: the three index sets must cover
# every sample exactly once (no overlap, nothing missing).
combined_idx = np.sort(np.concatenate([train_idx, val_idx, test_idx]))
assert np.array_equal(combined_idx, all_indices), \
    "Train/val/test indices do not partition the dataset"

print(f'Train size: {len(train_idx)}')
print(f'Validation size: {len(val_idx)}')
print(f'Test size: {len(test_idx)}')

# Add indices to existing dictionary
data['train_indices'] = train_idx
data['val_indices'] = val_idx
data['test_indices'] = test_idx

# Re-save (overwrites the processed file with the same content plus the split indices)
with open(processed_path, 'wb') as f:
    pickle.dump(data, f)

print('processed_dataset.pkl updated with split indices.')