# WaferScan AI: Data Exploration & Processing

This notebook performs the initial data pipeline:
1.  **Loading & Cleaning**: Parsing raw WM-811K data.
2.  **Imbalance Analysis**: Quantifying class distribution.
3.  **Visualization**: Inspecting sample defects.
4.  **Stratified Split**: Creating reproducible Train/Val/Test sets.

In [None]:
import os
import sys
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter
from tqdm import tqdm
from google.colab import drive

# Attach Google Drive to the Colab runtime; all project paths live under it.
drive.mount('/content/drive')

# Project layout: every location below is derived from PROJECT_ROOT.
PROJECT_ROOT = "/content/drive/MyDrive/wafer-hackathon"
RAW_DATA_PATH = f"{PROJECT_ROOT}/data/raw/LSWMD.pkl"
PROCESSED_PATH = f"{PROJECT_ROOT}/data/processed"
DOCS_PATH = f"{PROJECT_ROOT}/docs"

# Create the output folders up front (no-op when they already exist).
for _output_dir in (PROCESSED_PATH, DOCS_PATH):
    os.makedirs(_output_dir, exist_ok=True)

print("Environment Setup Complete.")

## Data Processing Pipeline (Raw -> Processed)

Helper function to load and clean the legacy pickle file.

In [None]:
# Legacy pandas compatibility: register the current ``pd.core.indexes``
# package under the old module name "pandas.indexes", unless an entry with
# that name is already present in sys.modules.
sys.modules.setdefault("pandas.indexes", pd.core.indexes)

def _normalize_failure_label(failure_type, unlabeled_label):
    """Turn a raw ``failureType`` value into a stripped class-name string.

    Returns ``unlabeled_label`` for an empty list/array, and ``None`` when the
    value is of a type that cannot be interpreted as a label.
    """
    if isinstance(failure_type, str):
        return failure_type.strip()

    if isinstance(failure_type, np.ndarray):
        # ``size`` (rather than ``len``) also copes with 0-d and (1, 0) arrays.
        if failure_type.size == 0:
            return unlabeled_label
        return str(failure_type.flat[0]).strip()

    if isinstance(failure_type, list):
        if len(failure_type) == 0:
            return unlabeled_label
        item = failure_type[0]
        # Unwrap one extra nesting level, e.g. [["Center"]].
        if isinstance(item, list) and len(item) > 0:
            item = item[0]
        return str(item).strip()

    return None


def load_wafer_dataset(pkl_path, unlabeled_label="none"):
    """
    Load and clean WM-811K dataset from pickle.
    Extracted from actual_01_data_exploration.ipynb.

    Every record must provide ``waferMap`` and ``failureType`` (as a row from
    ``DataFrame.itertuples()``, a dict, or any object exposing those
    attributes); ``lotName`` and ``waferIndex`` are optional. A record is kept
    only when its wafer map is 2-D and its failure type normalizes to one of
    the nine class names in the class map below.

    Args:
        pkl_path: Path of the raw pickle file.
        unlabeled_label: Class name assigned to records whose ``failureType``
            is an empty list/array. The default ``"none"`` reproduces the
            original behavior; pass ``None`` to skip such records instead.
            NOTE(review): an empty ``failureType`` may mean "not labeled"
            rather than "no defect" in the source data - confirm that folding
            these records into ``none`` is intended.

    Returns:
        Tuple ``(images, labels, metadata)``: a list of 2-D numpy arrays, a
        list of integer class ids, and a DataFrame with one row per kept
        sample and columns ``lotName``, ``waferIndex``, ``failureType``,
        ``mapped_label`` (the columns are present even when nothing is kept).

    Raises:
        FileNotFoundError: ``pkl_path`` does not exist.
        ValueError: The file exists but could not be unpickled.
    """
    if not os.path.exists(pkl_path):
        raise FileNotFoundError(f"File not found: {pkl_path}")

    print(f"Loading raw data from {pkl_path}...")
    # SECURITY: pickle.load can execute arbitrary code; only point this at
    # files from a trusted source.
    try:
        with open(pkl_path, "rb") as f:
            data = pickle.load(f, encoding="latin1")
    except Exception as e:
        # Chain explicitly so the original traceback stays attached.
        raise ValueError(f"Pickle load failed: {e}") from e

    # Class mapping
    class_map = {
        "none": 0, "Center": 1, "Donut": 2, "Edge-Loc": 3,
        "Edge-Ring": 4, "Loc": 5, "Near-full": 6, "Random": 7, "Scratch": 8
    }
    metadata_columns = ["lotName", "waferIndex", "failureType", "mapped_label"]

    images = []
    labels = []
    metadata_rows = []
    skipped = 0
    errors = 0          # subset of `skipped` caused by an exception
    first_error = None  # kept for diagnostics instead of failing silently

    iterator = data.itertuples() if isinstance(data, pd.DataFrame) else data
    total = len(data)

    for record in tqdm(iterator, total=total, desc="Processing"):
        try:
            # Dicts are read by key; namedtuple rows and other objects by
            # attribute, with missing fields treated as absent.
            if isinstance(record, dict):
                wafer_map = record.get("waferMap")
                failure_type = record.get("failureType")
                lot_name = record.get("lotName", "")
                wafer_index = record.get("waferIndex", 0)
            else:
                wafer_map = getattr(record, "waferMap", None)
                failure_type = getattr(record, "failureType", None)
                lot_name = getattr(record, "lotName", "")
                wafer_index = getattr(record, "waferIndex", 0)

            # Validate
            if wafer_map is None or failure_type is None:
                skipped += 1
                continue

            wafer_map = np.array(wafer_map)
            if wafer_map.ndim != 2:
                skipped += 1
                continue

            # Normalize label; None (uninterpretable or unlabeled-and-skipped)
            # is never a key of class_map, so it is dropped here as well.
            f_label = _normalize_failure_label(failure_type, unlabeled_label)
            if f_label not in class_map:
                skipped += 1
                continue

            label_idx = class_map[f_label]
            images.append(wafer_map)
            labels.append(label_idx)
            metadata_rows.append({
                "lotName": lot_name,
                "waferIndex": wafer_index,
                "failureType": f_label,
                "mapped_label": label_idx
            })

        except Exception as exc:
            # Best effort: a malformed record must not abort the whole load,
            # but remember the first failure so it can be reported.
            skipped += 1
            errors += 1
            if first_error is None:
                first_error = repr(exc)
            continue

    print(f"Loaded {len(images)} valid samples. Skipped: {skipped}")
    if errors:
        print(f"  {errors} skipped record(s) raised an exception; first: {first_error}")
    return images, labels, pd.DataFrame(metadata_rows, columns=metadata_columns)

In [None]:
# Rebuild the cleaned dataset from the raw pickle. This cell is unconditional:
# the loader runs every time, whether or not a processed file already exists.
images, labels, metadata = load_wafer_dataset(RAW_DATA_PATH)

# Persist the cleaned samples, labels and metadata as an intermediate artifact.
save_path = f"{PROCESSED_PATH}/processed_dataset.pkl"
processed_payload = {
    'images': images,
    'labels': labels,
    'metadata': metadata
}
with open(save_path, 'wb') as out_file:
    pickle.dump(processed_payload, out_file)

print(f"Intermediate data saved to {save_path}")

### SECTION 1: Load Data
- Load from processed_dataset.pkl (saved in previous step)
- Verify data loaded correctly

In [None]:
# Read back the artifact written by the processing step.
load_path = f"{PROCESSED_PATH}/processed_dataset.pkl"
print(f"Loading from {load_path}...")

with open(load_path, 'rb') as in_file:
    data = pickle.load(in_file)

# Pull the three stored components out of the payload dict.
loaded_images, loaded_labels, loaded_meta = (
    data[key] for key in ('images', 'labels', 'metadata')
)

# Report basic sizes so a mismatch is visible at a glance.
print("Data loaded successfully.")
for summary_line in (
    f"Total Samples: {len(loaded_images)}",
    f"Labels Count: {len(loaded_labels)}",
    f"Metadata Shape: {loaded_meta.shape}",
):
    print(summary_line)

### SECTION 2: Class Distribution
- Create bar chart showing samples per class
- Calculate and display imbalance ratio

In [None]:
# Tally samples per integer label, then order the tally by label id.
label_counts = Counter(loaded_labels)
sorted_counts = {label: label_counts[label] for label in sorted(label_counts)}

# Integer label -> display name.
class_names = dict(enumerate((
    'none', 'Center', 'Donut', 'Edge-Loc',
    'Edge-Ring', 'Loc', 'Near-full', 'Random', 'Scratch'
)))

# Bar positions (class names) and bar heights (sample counts).
x = [class_names[label] for label in sorted_counts]
y = list(sorted_counts.values())

plt.figure(figsize=(10, 5))
bars = plt.bar(x, y, color='skyblue', edgecolor='black')
plt.title('Class Distribution (WM-811K)')
plt.xlabel('Failure Type')
plt.ylabel('Count')
plt.xticks(rotation=45)
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Write each bar's count just above its top edge, centered horizontally.
for bar in bars:
    height = bar.get_height()
    x_center = bar.get_x() + bar.get_width()/2
    plt.text(x_center, height, f'{height:,}',
             ha='center', va='bottom', fontsize=8)

# Save the figure to the docs folder before showing it.
plt.tight_layout()
save_fig_path = f"{DOCS_PATH}/class_distribution.png"
plt.savefig(save_fig_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"Figure saved: {save_fig_path}")

### SECTION 3: Sample Visualization (3×3 grid)
- Show one example from each of the 9 classes (8 defect types plus `none`)

In [None]:
# Record the first dataset index at which each label occurs; stop scanning as
# soon as nine distinct labels have been seen.
sample_indices = {}
for idx, label in enumerate(loaded_labels):
    sample_indices.setdefault(label, idx)
    if len(sample_indices) == 9:
        break

fig, axes = plt.subplots(3, 3, figsize=(10, 10))
axes = axes.flatten()

# One panel per class id; a class without any sample leaves its panel blank.
for class_id, ax in enumerate(axes):
    idx = sample_indices.get(class_id)
    if idx is not None:
        img = loaded_images[idx]
        ax.imshow(img, cmap='gray')
        ax.set_title(class_names[class_id])
    ax.axis('off')

# Save the grid to the docs folder before showing it.
plt.tight_layout()
save_sample_path = f"{DOCS_PATH}/sample_defects.png"
plt.savefig(save_sample_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"Figure saved: {save_sample_path}")

### SECTION 4: Class Imbalance Analysis
- Print imbalance ratio
- Note on class weighting

In [None]:
# Counts for every class except label 0 ('none'); the ratio below therefore
# compares the largest and smallest defect classes only.
defect_counts = [count for label, count in sorted_counts.items() if label != 0]

if not defect_counts:
    print("No defects found in loaded subset.")
else:
    max_count, min_count = max(defect_counts), min(defect_counts)
    ratio = max_count / min_count
    print(f"Max Defect Count: {max_count}")
    print(f"Min Defect Count: {min_count}")
    print(f"Imbalance Ratio: {ratio:.2f}")
    print("NOTE: This requires class weighting during training to handle the imbalance.")

### SECTION 5: Image Statistics
- Sample 1000 random images
- Count dimension occurrences

In [None]:
# Fixed seed so the same subset of images is inspected on every run.
np.random.seed(42)
num_samples = min(1000, len(loaded_images))
indices = np.random.choice(len(loaded_images), num_samples, replace=False)

# (height, width) of each sampled wafer map, then a tally per distinct size.
sampled_images = (loaded_images[i] for i in indices)
sizes = [(img.shape[0], img.shape[1]) for img in sampled_images]
size_counts = Counter(sizes)

print(f"Analyzed {num_samples} random variants.")
print(f"Total unique sizes found: {len(size_counts)}")
print("Top 5 most common sizes (H, W):")
for size, count in size_counts.most_common(5):
    print(f"  {size}: {count}")

### SECTION 6: Stratified Split
- Create Train (70%) / Val (15%) / Test (15%) splits
- Save indices to processed_dataset.pkl

In [None]:
from sklearn.model_selection import StratifiedShuffleSplit

all_labels = np.array(loaded_labels)
all_indices = np.arange(len(all_labels))
n_total = len(all_labels)

# Stage 1: hold out 30% of the samples, stratified by label; the rest is Train.
# Using actual_01_data_exploration logic
sss1 = StratifiedShuffleSplit(n_splits=1, test_size=0.30, random_state=42)
train_idx, temp_idx = next(sss1.split(all_indices, all_labels))

# Stage 2: halve the held-out 30% into Val and Test (15% each), stratified again.
# The splitter yields positions within temp_idx, so map them back to dataset indices.
sss2 = StratifiedShuffleSplit(n_splits=1, test_size=0.50, random_state=42)
val_sub_idx, test_sub_idx = next(sss2.split(temp_idx, all_labels[temp_idx]))
val_idx, test_idx = temp_idx[val_sub_idx], temp_idx[test_sub_idx]

print(f"Train size: {len(train_idx)} ({len(train_idx)/n_total:.1%})")
print(f"Val size:   {len(val_idx)} ({len(val_idx)/n_total:.1%})")
print(f"Test size:  {len(test_idx)} ({len(test_idx)/n_total:.1%})")

# Re-read the processed pickle, attach the three index arrays, and write it back
# to the same path.
with open(load_path, 'rb') as in_file:
    data = pickle.load(in_file)

data.update(
    train_indices=train_idx,
    val_indices=val_idx,
    test_indices=test_idx,
)

with open(load_path, 'wb') as out_file:
    pickle.dump(data, out_file)

print(f"Processed dataset updated with stratified splits at: {load_path}")