# Prepare Corel Dataset for Stable Diffusion

This notebook prepares the Corel dataset for training Stable Diffusion models with LoRA.

**Tasks:**
1. Load and explore the Corel dataset
2. Generate metadata (captions.json) for each class
3. Organize data into appropriate folders
4. Prepare everything for training and generation notebooks

## Instructions for Google Colab

1. Upload your Corel dataset to Google Drive or Colab
2. Mount Google Drive if needed (see the "Mount Google Drive" section below)
3. Adjust paths in the Configuration section if necessary
4. Run all cells in order


In [None]:
# Install required dependencies
# The %pip magic (rather than !pip) installs into the environment of the running kernel.
# NOTE(review): `torch` is imported by the imports/configuration cell but is not
# installed here — presumably it is expected to be preinstalled (e.g. on Colab); confirm.
%pip install -q pillow matplotlib numpy
print("✓ Dependencies installed")


## Mount Google Drive (Optional)

If your dataset is in Google Drive, uncomment the lines in the next cell and run it:


In [None]:
# Uncomment the following lines if you need to mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')
# print("Google Drive mounted")


In [None]:
# NOTE: this is a code cell that contains only section notes. The notes are kept
# as Python comments so the cell executes cleanly under "Run All" (the bare prose
# line was a SyntaxError); converting the cell to Markdown would restore the
# intended headings.
#
# ## 1. Install Dependencies
#
# Install required packages for dataset preparation:
# (handled by the `%pip install` cell at the top of this notebook)
#
# ## 2. Imports and Configuration


In [None]:
import os
import json
import glob
import re
from pathlib import Path
from collections import defaultdict
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import shutil
import torch

# GPU probe — purely informational: every step in this notebook is file
# handling, so running on CPU is perfectly fine.
if not torch.cuda.is_available():
    print("ℹ Running on CPU (this notebook doesn't require GPU)")
else:
    print(f"✓ CUDA available: {torch.cuda.get_device_name(0)}")
    print("  (Note: This notebook processes files, GPU not required)")

# ===== CONFIGURATION =====

# Root of the project tree (a '/content/...' path on Colab, '.' for a local run)
BASE_DIR = Path('/content/drive/MyDrive/nsl_25.2/t2')  # Change to '.' if running locally

# Folder holding the raw Corel images
#   Google Drive example: '/content/drive/MyDrive/path/to/data/corel'
#   Direct Colab upload:  '/content/data/corel'
COREL_DATA_DIR = BASE_DIR.joinpath('data', 'corel')

# Folders this notebook writes into
TRAINING_DATA_DIR = BASE_DIR.joinpath('training_data')
COREL_TRAINING_DIR = TRAINING_DATA_DIR.joinpath('corel')
FIGS_DIR = BASE_DIR.joinpath('figs')
OUTPUT_DIR = BASE_DIR.joinpath('outputs')

# ===== END CONFIGURATION =====

# Make sure every output folder exists (the dataset folder itself is not created)
for _output_dir in (TRAINING_DATA_DIR, COREL_TRAINING_DIR, FIGS_DIR, OUTPUT_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)

# One print call, one argument per output line
print(
    "=" * 60,
    "DIRECTORIES CREATED/VERIFIED",
    "=" * 60,
    f"Corel dataset: {COREL_DATA_DIR}",
    f"Training data: {COREL_TRAINING_DIR}",
    f"Figures: {FIGS_DIR}",
    f"Outputs: {OUTPUT_DIR}",
    "=" * 60,
    sep="\n",
)


## 3. Explore Corel Dataset


In [None]:
# Find all PNG images directly inside the dataset folder (paths kept as strings,
# sorted so every later step sees them in a stable order)
image_files = sorted(glob.glob(str(COREL_DATA_DIR / '*.png')))
print(f"Total images found: {len(image_files)}")

if not image_files:
    # Fail fast: every later cell needs `class_distribution`, so continuing with
    # no images would only surface as a confusing NameError further down.
    raise FileNotFoundError(
        f"No PNG files found in {COREL_DATA_DIR}. "
        "Please check the path configuration in the previous cell."
    )

# Analyze class structure: class number -> list of file names in that class.
# File names are expected to look like '<class>_<example>.png'.
class_distribution = defaultdict(list)
pattern = re.compile(r'^(\d+)_(\d+)\.png$')
skipped_files = []

for img_path in image_files:
    filename = os.path.basename(img_path)
    match = pattern.match(filename)
    if match:
        class_num = int(match.group(1))
        class_distribution[class_num].append(filename)
    else:
        skipped_files.append(filename)

# Make ignored files visible instead of dropping them silently
if skipped_files:
    print(f"WARNING: {len(skipped_files)} PNG file(s) do not match "
          f"'<class>_<example>.png' and were ignored")

# Show class distribution
print(f"\nClasses found: {len(class_distribution)}")
print("\nDistribution by class:")
for class_num in sorted(class_distribution.keys()):
    count = len(class_distribution[class_num])
    print(f"  Class {class_num:04d}: {count} images")


## 4. Create classes.txt File


In [None]:
# Use the dataset's own classes.txt when present; otherwise write a generic one
# with one "<class_num> <name>" line per discovered class.
classes_file = COREL_DATA_DIR / 'classes.txt'

if classes_file.exists():
    print(f"Found classes.txt at {classes_file}")
    with open(classes_file, 'r') as f:
        print("\nContent:")
        print(f.read())
elif not class_distribution:
    # Do not leave an empty classes.txt behind: a later run would treat it as
    # a valid class list containing zero classes.
    print("WARNING: classes.txt not found and no classes were discovered; "
          "nothing was written.")
else:
    print("WARNING: classes.txt not found. Creating a generic one...")

    # Generic names - ADJUST according to your actual dataset
    generic_names = [
        "royalguard", "beach", "mountain", "flower",
        "building", "animal", "vehicle", "person"
    ]

    # Names are assigned cyclically: class 1 -> first name, class 2 -> second, ...
    # wrapping around once the list is exhausted.
    class_names = {
        class_num: generic_names[(class_num - 1) % len(generic_names)]
        for class_num in sorted(class_distribution.keys())
    }

    # Save classes.txt (the dict was built in ascending class order)
    with open(classes_file, 'w') as f:
        for class_num in sorted(class_names.keys()):
            f.write(f"{class_num} {class_names[class_num]}\n")

    print(f"Created classes.txt at {classes_file}")


## 5. Generate captions.json for Each Class


In [None]:
def read_classes(classes_file):
    """Parse a classes.txt file into a ``{class number: class name}`` dict.

    Each useful line has the form ``<number> <name>``; the name may contain
    spaces. Blank lines and lines without a name are skipped. When a number
    appears more than once, the last occurrence wins.

    Raises:
        ValueError: if the first token of a line that has a name is not an integer.
    """
    with open(classes_file, 'r') as handle:
        stripped_lines = (raw.strip() for raw in handle)
        # Split once only, so everything after the number stays in the name
        token_pairs = (entry.split(maxsplit=1) for entry in stripped_lines if entry)
        return {
            int(tokens[0]): tokens[1].strip()
            for tokens in token_pairs
            if len(tokens) >= 2
        }

def generate_captions_for_class(class_num, class_name, image_files, output_dir):
    """Build the training folder for one class and write its captions.json.

    Creates ``output_dir / class_XXXX`` (class number zero-padded to 4 digits),
    copies into it every path in ``image_files`` whose base name looks like
    ``<class_num>_<digits>.png`` (files already present are not overwritten),
    and stores ``{file name: "a photo of a <class_name>"}`` as captions.json.

    Returns:
        (number of captioned images, path of the class folder)
    """
    target_dir = output_dir / f"class_{class_num:04d}"
    target_dir.mkdir(parents=True, exist_ok=True)

    name_re = re.compile(r'^(\d+)_(\d+)\.png$')
    captions = {}

    for source in image_files:
        basename = os.path.basename(source)
        hit = name_re.match(basename)

        # Only files named after this class are of interest
        if hit is None or int(hit.group(1)) != class_num:
            continue

        # Copy once; an existing copy is left untouched
        copied = target_dir / basename
        if not copied.exists():
            shutil.copy2(source, copied)

        captions[basename] = f"a photo of a {class_name}"

    # Written even when no image matched (the file then holds an empty object)
    with open(target_dir / 'captions.json', 'w') as out:
        json.dump(captions, out, indent=2)

    return len(captions), target_dir

# Load the class-number -> class-name table
class_mapping = read_classes(classes_file)
print(f"Loaded {len(class_mapping)} classes from classes.txt")

# Build one training folder (images + captions.json) per known class
print("\nGenerating captions.json for each class...")
class_dirs = {}

for class_num in sorted(class_distribution):
    # Classes without a name in classes.txt are reported and skipped
    if class_num not in class_mapping:
        print(f"  WARNING: Class {class_num:04d} not found in classes.txt")
        continue

    class_name = class_mapping[class_num]
    count, class_dir = generate_captions_for_class(
        class_num, class_name, image_files, COREL_TRAINING_DIR
    )
    class_dirs[class_num] = class_dir
    print(f"  Class {class_num:04d} ({class_name}): {count} images -> {class_dir.name}")


## 6. Generate Unified Dataset (All Classes Together)


In [None]:
# Optional: a single folder holding the images of every known class,
# with one shared captions.json
corel_all_dir = COREL_TRAINING_DIR / 'corel_all'
corel_all_dir.mkdir(parents=True, exist_ok=True)

all_captions = {}
pattern = re.compile(r'^(\d+)_(\d+)\.png$')

print("\nCreating unified dataset (all classes)...")
for img_path in image_files:
    filename = os.path.basename(img_path)
    match = pattern.match(filename)
    if not match:
        continue

    # Images whose class has no name in classes.txt are left out
    class_num = int(match.group(1))
    if class_num not in class_mapping:
        continue
    class_name = class_mapping[class_num]

    # Copy once; an existing copy is left untouched
    dest_path = corel_all_dir / filename
    if not dest_path.exists():
        shutil.copy2(img_path, dest_path)

    caption = f"a photo of a {class_name}"
    all_captions[filename] = caption

# Save unified captions.json
captions_file = corel_all_dir / 'captions.json'
with open(captions_file, 'w') as f:
    json.dump(all_captions, f, indent=2)

print(
    f"Unified dataset created: {corel_all_dir}",
    f"  Total images: {len(all_captions)}",
    f"  Captions saved to: {captions_file}",
    sep="\n",
)


## 7. Visualize Dataset Samples


In [None]:
# Visualize a few samples from each class: one row per class,
# up to `n_samples_per_class` images per row.
n_samples_per_class = 3
n_classes = len(class_distribution)

if n_classes == 0:
    # plt.subplots() rejects a grid with zero rows, so there is nothing to draw
    print("No classes to visualize; skipping the sample figure.")
else:
    # squeeze=False always yields a 2-D array of axes, even for a single class
    fig, axes = plt.subplots(
        n_classes,
        n_samples_per_class,
        figsize=(15, 5 * n_classes),
        squeeze=False,
    )

    # Hide every frame up front so slots without an image (class with fewer
    # samples, or a missing file) stay blank instead of showing empty ticks.
    for ax in axes.ravel():
        ax.axis('off')

    for class_idx, class_num in enumerate(sorted(class_distribution.keys())):
        class_images = class_distribution[class_num][:n_samples_per_class]
        class_name = class_mapping.get(class_num, f"class_{class_num}")

        for img_idx, img_filename in enumerate(class_images):
            img_path = COREL_DATA_DIR / img_filename
            if not img_path.exists():
                continue

            # convert() returns a fully loaded copy, so the file handle can be
            # released right away.
            with Image.open(img_path) as img_file:
                img = img_file.convert('RGB')

            ax = axes[class_idx, img_idx]
            ax.imshow(img)
            ax.set_title(f"{class_name}\n{img_filename}", fontsize=10)

    plt.suptitle('Corel Dataset Samples by Class', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(FIGS_DIR / 'corel_dataset_samples.png', dpi=150, bbox_inches='tight')
    print(f"Visualization saved to: {FIGS_DIR / 'corel_dataset_samples.png'}")
    plt.show()


## 8. Summary and Next Steps


In [None]:
# --- What was produced ---------------------------------------------------
print(
    "=" * 60,
    "PREPARATION SUMMARY",
    "=" * 60,
    "\nCorel dataset processed:",
    f"  - Total images: {len(image_files)}",
    f"  - Classes found: {len(class_distribution)}",
    "\nStructure created:",
    f"  - Unified dataset: {corel_all_dir}",
    f"    -> {len(all_captions)} images with captions.json",
    "\n  - Per-class datasets:",
    sep="\n",
)
for class_num in sorted(class_dirs):
    class_dir = class_dirs[class_num]
    class_name = class_mapping.get(class_num, 'unknown')
    # Count the PNG files actually present in the class folder
    img_count = sum(1 for _ in class_dir.glob('*.png'))
    print(f"    -> class_{class_num:04d} ({class_name}): {img_count} images")

# --- Where to go from here -----------------------------------------------
print(
    "\n" + "=" * 60,
    "NEXT STEPS",
    "=" * 60,
    "\n1. To train LoRA with ALL classes:",
    "   Use notebook: 6-train-lora-corel.ipynb",
    f"   Set train_data_dir = '{corel_all_dir}'",
    "\n2. To train LoRA per class (recommended if few samples per class):",
    sep="\n",
)
# Only the first three class folders are listed; the rest are summarized
for class_num in sorted(class_dirs)[:3]:
    class_dir = class_dirs[class_num]
    print("   Use notebook: 6-train-lora-corel.ipynb")
    print(f"   Set train_data_dir = '{class_dir}'")
if len(class_dirs) > 3:
    print(f"   ... and {len(class_dirs) - 3} more classes")

print(
    "\n3. To generate images with trained LoRA:",
    "   Use notebook: 7-generate-lora-corel.ipynb",
    "=" * 60,
    sep="\n",
)
