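# Greek Motif Dataset Exploration

This notebook explores the processed Greek traditional motif dataset:
- Dataset statistics and distribution
- Visual samples by region
- Geometric feature analysis
- Color analysis
- Embedding visualization (if available)
- Quality checks
- Summary and recommendations

Prerequisites: run Phase 1 preprocessing first (it produces `data/processed/metadata.csv`). The embedding section additionally needs the Phase 2 output `data/embeddings/embeddings.npz`; it is skipped when that file is absent.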

In [None]:
import sys
from collections import Counter
from pathlib import Path

# Add project root to path so project-local packages can be imported here.
# The root is taken as the parent of the current working directory, i.e. this
# assumes the notebook is launched one level below the repo root — TODO confirm.
project_root = Path().resolve().parent
# Guard so re-running this cell does not keep prepending duplicate entries.
if not sys.path or sys.path[0] != str(project_root):
    sys.path.insert(0, str(project_root))

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import torch
from torchvision import transforms

# Configure plotting: global matplotlib style and seaborn colour cycle shared
# by every figure in this notebook.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
# IPython magics: render figures inline in the notebook; 'retina' requests
# high-resolution PNG output for HiDPI screens.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

print("✓ Imports successful")

## 1. Load Dataset Metadata

In [None]:
# Load the per-image metadata written by Phase 1 preprocessing.
# Every later cell depends on `df`, so stop here with a clear error instead of
# letting the rest of the notebook fail with a confusing NameError.
metadata_path = project_root / "data" / "processed" / "metadata.csv"

if not metadata_path.exists():
    raise FileNotFoundError(
        f"❌ Metadata not found at {metadata_path}\n"
        "Please run Phase 1 preprocessing first."
    )

df = pd.read_csv(metadata_path)
print(f"✓ Loaded {len(df)} samples")
print(f"\nColumns: {list(df.columns)}")
# Last expression of the cell, so Jupyter renders the preview table.
# (Inside an if/else block, as before, it was silently discarded.)
df.head()

## 2. Dataset Statistics

In [None]:
# Headline numbers: how many images there are and how they split across regions.
region_series = df['region']
divider = "=" * 60

print("Dataset Overview:")
print(divider)
print(f"Total images: {len(df)}")
print(f"Number of regions: {region_series.nunique()}")
print(f"Regions: {sorted(region_series.unique())}")
print("\nSample distribution by region:")
print(region_series.value_counts())

In [None]:
# Visualize region distribution: one bar per region, largest first
# (value_counts sorts descending).
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

region_counts = df['region'].value_counts()
region_counts.plot(kind='bar', ax=ax, color='steelblue', edgecolor='black')

ax.set_title('Sample Distribution by Region', fontsize=16, fontweight='bold')
ax.set_xlabel('Region', fontsize=12)
ax.set_ylabel('Number of Samples', fontsize=12)
ax.grid(axis='y', alpha=0.3)

# Count labels on top of bars. bar_label offsets in points, so the gap is the
# same whatever the y-scale (a fixed data-unit offset like `v + 2` floats far
# above small counts and vanishes for large ones).
ax.bar_label(ax.containers[0], padding=2, fontweight='bold')
# Leave headroom above the tallest bar so its label is not clipped.
ax.margins(y=0.1)

plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
plt.show()

## 3. Geometric Feature Analysis

In [None]:
# Geometric features are optional metadata columns; later cells only analyse
# them when every one of them is present.
geometric_cols = ['vertical_symmetry', 'horizontal_symmetry', 'edge_density']
has_geometric = set(geometric_cols).issubset(df.columns)

if not has_geometric:
    print("❌ Geometric features not found in metadata")
else:
    print("Geometric Feature Statistics:")
    print("=" * 60)
    print(df[geometric_cols].describe())

In [None]:
if has_geometric:
    # One histogram per geometric feature: (column, title, x-label, colour).
    # A single spec list replaces three copy-pasted panel blocks.
    geometric_panels = [
        ('vertical_symmetry', 'Vertical Symmetry Distribution', 'Symmetry Score', 'skyblue'),
        ('horizontal_symmetry', 'Horizontal Symmetry Distribution', 'Symmetry Score', 'lightcoral'),
        ('edge_density', 'Edge Density Distribution', 'Edge Density', 'lightgreen'),
    ]

    fig, axes = plt.subplots(1, len(geometric_panels), figsize=(5 * len(geometric_panels), 4))

    for panel_ax, (col, title, xlabel, color) in zip(axes, geometric_panels):
        # Drop NaNs explicitly so missing feature values never affect the binning.
        panel_ax.hist(df[col].dropna(), bins=30, color=color, edgecolor='black')
        panel_ax.set_title(title)
        panel_ax.set_xlabel(xlabel)
        panel_ax.set_ylabel('Frequency')
        panel_ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

## 4. Visual Samples by Region

In [None]:
def show_samples_by_region(region_name, n_samples=6, n_cols=3, random_state=42):
    """
    Display a grid of randomly chosen sample images from one region.

    Args:
        region_name: Value of the ``region`` column to filter ``df`` on.
        n_samples: Maximum number of images to show; capped at the number of
            images available for the region.
        n_cols: Number of columns in the image grid. Rows are added as
            needed so every sampled image gets a panel (defaults give 2x3).
        random_state: Seed passed to ``DataFrame.sample`` so reruns show the
            same images.

    Reads the notebook-level ``df`` and ``project_root``; returns None.
    Prints a message and returns early when the region has no rows.
    """
    region_df = df[df['region'] == region_name]

    if len(region_df) == 0:
        print(f"No samples found for region: {region_name}")
        return

    n_samples = min(n_samples, len(region_df))
    samples = region_df.sample(n=n_samples, random_state=random_state)

    # Size the grid to the sample count (ceil division) instead of a fixed 2x3,
    # which silently dropped every sample beyond the sixth.
    n_rows = max(1, -(-n_samples // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False)
    axes = axes.flatten()

    for ax, (_, row) in zip(axes, samples.iterrows()):
        img_path = project_root / row['image_path']
        if img_path.exists():
            # imshow copies the pixels, so the file can be closed right after.
            with Image.open(img_path) as img:
                ax.imshow(img)
            # Add filename as title
            ax.set_title(row['filename'], fontsize=8)
        else:
            ax.text(0.5, 0.5, 'Image not found', ha='center', va='center')
        ax.axis('off')

    # Hide grid cells left over after the last sample.
    for ax in axes[n_samples:]:
        ax.axis('off')

    fig.suptitle(f'Sample Motifs from {region_name} (n={len(region_df)} total)',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Alphabetical list of regions; the next cell previews a subset of them.
unique_regions = df['region'].unique()
regions = sorted(unique_regions)
print(f"Showing samples from {len(regions)} regions...\n")

In [None]:
# Preview only a handful of regions (alphabetical order) to keep the output
# short; raise MAX_REGIONS_TO_PREVIEW to browse more of them.
MAX_REGIONS_TO_PREVIEW = 3
for region in regions[:MAX_REGIONS_TO_PREVIEW]:
    show_samples_by_region(region)

## 5. Color Analysis

In [None]:
def analyze_color_distribution(n_samples=100, random_state=42):
    """
    Plot and summarise the per-image mean RGB values of a random image sample.

    Args:
        n_samples: Number of metadata rows to sample (capped at ``len(df)``).
        random_state: Seed for ``DataFrame.sample`` so reruns pick the same rows.

    Rows whose image file is missing are skipped. The reported ``n`` is the
    number of images actually analysed, not the number of rows sampled.
    Reads the notebook-level ``df`` and ``project_root``; returns None.
    """
    samples = df.sample(n=min(n_samples, len(df)), random_state=random_state)

    channel_means = []  # one (R, G, B) mean triple per analysed image
    for _, row in samples.iterrows():
        img_path = project_root / row['image_path']
        if not img_path.exists():
            continue
        # `with` closes the file handle; convert() yields an in-memory copy.
        with Image.open(img_path) as img:
            rgb = np.asarray(img.convert('RGB'), dtype=np.float64)
        # Average over all pixels, separately per channel -> shape (3,).
        channel_means.append(rgb.reshape(-1, 3).mean(axis=0))

    n_analyzed = len(channel_means)
    if n_analyzed == 0:
        # Avoid NaN statistics and an empty plot when no image could be read.
        print("❌ No image files found for the sampled rows; nothing to analyse.")
        return

    r_values, g_values, b_values = np.array(channel_means).T

    # Plot color distributions
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    ax.hist(r_values, bins=30, alpha=0.5, label='Red', color='red', edgecolor='black')
    ax.hist(g_values, bins=30, alpha=0.5, label='Green', color='green', edgecolor='black')
    ax.hist(b_values, bins=30, alpha=0.5, label='Blue', color='blue', edgecolor='black')

    ax.set_title(f'RGB Channel Distribution (n={n_analyzed} samples)', fontsize=14, fontweight='bold')
    ax.set_xlabel('Average Pixel Value', fontsize=12)
    ax.set_ylabel('Frequency', fontsize=12)
    ax.legend()
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

    print(f"Color Statistics (n={n_analyzed} samples):")
    print("=" * 60)
    print(f"Red   - Mean: {np.mean(r_values):.2f}, Std: {np.std(r_values):.2f}")
    print(f"Green - Mean: {np.mean(g_values):.2f}, Std: {np.std(g_values):.2f}")
    print(f"Blue  - Mean: {np.mean(b_values):.2f}, Std: {np.std(b_values):.2f}")

analyze_color_distribution()

## 6. Embedding Exploration (if available)

In [None]:
# Embeddings are produced by Phase 2; the PCA cell below is skipped without them.
embeddings_path = project_root / "data" / "embeddings" / "embeddings.npz"

if not embeddings_path.exists():
    print("❌ Embeddings not found. Run Phase 2 to create embeddings.")
    HAS_EMBEDDINGS = False
else:
    print("✓ Loading embeddings...")
    embeddings = np.load(embeddings_path)
    embedding_keys = list(embeddings.keys())
    print(f"\nAvailable embeddings: {embedding_keys}")

    for key in embedding_keys:
        print(f"  {key}: shape = {embeddings[key].shape}")

    HAS_EMBEDDINGS = True

In [None]:
if HAS_EMBEDDINGS:
    from sklearn.decomposition import PCA

    # Project the 'combined' embeddings to 2-D with PCA and colour points by region.
    emb_metadata_path = project_root / "data" / "embeddings" / "embeddings_metadata.csv"

    if 'combined' not in embeddings.files:
        print(f"❌ 'combined' embeddings not found; available: {embeddings.files}")
    elif not emb_metadata_path.exists():
        # Region labels come from this file; without it there is nothing to colour by.
        print(f"❌ Embedding metadata not found at {emb_metadata_path}")
    else:
        combined_emb = embeddings['combined']
        print(f"Combined embedding shape: {combined_emb.shape}")

        emb_df = pd.read_csv(emb_metadata_path)
        # Row i of emb_df is assumed to describe row i of combined_emb; a length
        # mismatch means the files are out of sync and the region colours would be wrong.
        if len(emb_df) != len(combined_emb):
            raise ValueError(
                f"Embedding metadata has {len(emb_df)} rows but 'combined' has "
                f"{len(combined_emb)} rows; regenerate the embeddings."
            )

        # PCA visualization
        print("\nRunning PCA...")
        pca = PCA(n_components=2)
        pca_result = pca.fit_transform(combined_emb)

        fig, ax = plt.subplots(1, 1, figsize=(12, 8))

        region_labels = emb_df['region'].to_numpy()
        region_names = sorted(emb_df['region'].dropna().unique())
        # sns.set_palette("husl") without n_colors gives a short (6-colour) cycle,
        # so colours would repeat across regions; request one colour per region.
        region_colors = sns.color_palette("husl", n_colors=len(region_names))
        for region, color in zip(region_names, region_colors):
            mask = region_labels == region
            ax.scatter(pca_result[mask, 0], pca_result[mask, 1],
                       label=region, color=color, alpha=0.6, s=50)

        ax.set_title('PCA Projection of Combined Embeddings', fontsize=14, fontweight='bold')
        ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.2%} variance)', fontsize=12)
        ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.2%} variance)', fontsize=12)
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(alpha=0.3)

        plt.tight_layout()
        plt.show()

        print("\nPCA Explained Variance:")
        print(f"  PC1: {pca.explained_variance_ratio_[0]:.2%}")
        print(f"  PC2: {pca.explained_variance_ratio_[1]:.2%}")
        print(f"  Total: {pca.explained_variance_ratio_[:2].sum():.2%}")

## 7. Data Quality Checks

In [None]:
print("Data Quality Checks:")
print("=" * 60)

# 1. Missing values per column
print("\n1. Missing Values:")
missing = df.isnull().sum()
if missing.sum() == 0:
    print("  ✓ No missing values found")
else:
    print("  ❌ Missing values detected:")
    print(missing[missing > 0])

# 2. Duplicate filenames
print("\n2. Duplicate Filenames:")
duplicates = df['filename'].duplicated().sum()
if duplicates == 0:
    print("  ✓ No duplicate filenames")
else:
    print(f"  ❌ {duplicates} duplicate filenames found")

# 3. Every referenced image exists on disk. A missing (NaN) image_path counts
# as a missing file; `project_root / nan` would otherwise raise TypeError.
print("\n3. Image File Existence:")
missing_files = sum(
    1
    for rel_path in df['image_path']
    if pd.isna(rel_path) or not (project_root / str(rel_path)).exists()
)

if missing_files == 0:
    print(f"  ✓ All {len(df)} image files found")
else:
    print(f"  ❌ {missing_files} image files not found")

# 4. Image dimensions on a small, seeded sample of rows
n_dim_samples = min(20, len(df))
print(f"\n4. Image Dimensions (sampling {n_dim_samples} images):")
dim_counts = Counter()  # (width, height) -> number of sampled images
for rel_path in df.sample(n=n_dim_samples, random_state=42)['image_path']:
    if pd.isna(rel_path):
        continue
    img_path = project_root / str(rel_path)
    if img_path.exists():
        # `with` releases the file handle once the size has been read.
        with Image.open(img_path) as img:
            dim_counts[img.size] += 1

if dim_counts:
    print(f"  Found {len(dim_counts)} unique dimensions:")
    for (width, height), count in dim_counts.most_common():
        print(f"    {width}x{height}: {count} images")

print("\n" + "=" * 60)
print("✓ Data quality check complete!")

## 8. Summary & Recommendations

In [None]:
# Max/min region-count ratio above which class rebalancing is recommended.
IMBALANCE_WARN_RATIO = 3

# Computed once; value_counts() is sorted descending, so [0] is the largest
# region and [-1] the smallest.
region_counts = df['region'].value_counts()

print("Dataset Summary:")
print("=" * 60)
print(f"Total samples: {len(df)}")
print(f"Number of regions: {df['region'].nunique()}")

if region_counts.empty:
    # Guard: indexing [0] / [-1] on an empty Series would raise IndexError.
    print("\n❌ No region labels found; cannot assess class balance.")
else:
    max_count = region_counts.iloc[0]
    min_count = region_counts.iloc[-1]
    print(f"\nRegion with most samples: {region_counts.index[0]} ({max_count} samples)")
    print(f"Region with least samples: {region_counts.index[-1]} ({min_count} samples)")

    # Check if dataset is balanced (min_count >= 1 because value_counts omits zeros)
    imbalance_ratio = max_count / min_count

    print(f"\nClass imbalance ratio: {imbalance_ratio:.2f}x")
    if imbalance_ratio > IMBALANCE_WARN_RATIO:
        print("  ⚠️  High class imbalance detected. Consider:")
        print("     - Using weighted sampling during training")
        print("     - Data augmentation for under-represented regions")
    else:
        print("  ✓ Dataset is reasonably balanced")

print("\n" + "=" * 60)
print("Ready to proceed with training!")
print("Next steps:")
print("  1. Run Phase 2 (symbolic analysis) if not done yet")
print("  2. Check training configuration in configs/stylegan3_greek_simple.yaml")
print("  3. Start training with: python scripts/train_gan.py")