<a href="https://colab.research.google.com/github/dzthai/CS5787_SwiftEdit2_Fashion/blob/main/Milestone1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
# Development outline

# 1. Get SwiftEdit2 working (4.1)
# 2. Prepare / store data (David)
    # Need (4.4): (image, text edit, output image) triplets, probably from the DeepFashion dataset
# 3. Train model (Sona)
# 4. Evaluate model

# Later: Domain specific edits and LoRA Adapters
# Writing outline
# 1.

In [19]:
# Imports

import json
import os
import re
import zipfile
from collections import Counter
from pathlib import Path

import pandas as pd
import requests
from PIL import Image

In [22]:
class MagicBrushFashionLoader:
    """
    Load the MagicBrush dataset and extract its fashion-related edit samples.

    The loader keeps a working directory (``base_dir``) for the files it
    writes and a list of fashion keywords (``fashion_keywords``) that decides
    which edit instructions count as fashion-related.
    """

    # Optional inflection allowed after a keyword, so that "dress" also
    # matches "dresses" and "wear" also matches "wearing".
    _KEYWORD_SUFFIX = r"(?:s|es|ed|ing)?"

    def __init__(self, base_dir="./magicbrush_data"):
        """
        Initialize the MagicBrush loader

        Args:
            base_dir: Directory to store downloaded data (created, together
                with any missing parent directories, if it does not exist)
        """
        self.base_dir = Path(base_dir)
        # parents=True: a nested path such as /content/magicbrush_data must
        # not fail just because the parent directory is missing.
        self.base_dir.mkdir(parents=True, exist_ok=True)

        # Fashion-related keywords for filtering
        self.fashion_keywords = [
            'dress', 'shirt', 'pants', 'jeans', 'coat', 'jacket',
            'skirt', 'blouse', 'sweater', 'hoodie', 'clothing',
            'fashion', 'wear', 'outfit', 'suit', 'tie', 'shoes',
            'boots', 'sneakers', 'hat', 'scarf', 'gloves', 'socks',
            'shorts', 'tank top', 't-shirt', 'blazer', 'cardigan',
            'uniform', 'robe', 'gown', 'vest', 'sleeve', 'collar',
            'button', 'zipper', 'pocket', 'fabric', 'textile',
            'cotton', 'silk', 'denim', 'leather', 'wool'
        ]

    def download_dataset(self):
        """
        Download MagicBrush dataset from HuggingFace
        Note: You may need to manually download if this doesn't work

        Returns:
            The loaded dataset, or None if loading failed
        """
        print("Attempting to download MagicBrush dataset...")
        print("If this fails, please manually download from:")
        print("https://huggingface.co/datasets/osunlp/MagicBrush")

        # Using HuggingFace datasets library
        try:
            from datasets import load_dataset
            dataset = load_dataset("osunlp/MagicBrush")
        except ImportError as e:
            # Only a missing library warrants the installation hint.
            print(f"‚úó Download failed: {e}")
            print("Please install: pip install datasets")
            return None
        except Exception as e:
            # Best-effort helper: report the failure and let the caller decide.
            print(f"‚úó Download failed: {e}")
            return None

        print("‚úì Dataset downloaded successfully!")
        return dataset

    def load_from_huggingface(self):
        """
        Load MagicBrush directly from HuggingFace (recommended)

        Returns:
            The loaded dataset, or None if the library is missing or
            loading failed
        """
        try:
            from datasets import load_dataset
            print("Loading MagicBrush from HuggingFace...")
            dataset = load_dataset("osunlp/MagicBrush")
            print(f"‚úì Loaded {len(dataset['train'])} training samples")
            return dataset
        except ImportError:
            print("Please install: pip install datasets")
            return None
        except Exception as e:
            print(f"Error loading dataset: {e}")
            return None

    @classmethod
    def _keyword_in_text(cls, keyword, text_lower):
        """Return True if keyword occurs in text_lower as a whole word."""
        # A plain substring test would report 'hat' inside "what" or 'tie'
        # inside "cities"; anchoring on word boundaries avoids that.
        pattern = rf"\b{re.escape(keyword)}{cls._KEYWORD_SUFFIX}\b"
        return re.search(pattern, text_lower) is not None

    def is_fashion_related(self, text):
        """
        Check if instruction text is fashion-related

        A keyword counts only when it appears as a whole word (optionally
        followed by a simple inflection such as "-s" or "-ing").

        Args:
            text: Instruction text to check

        Returns:
            bool: True if fashion-related (False for empty or non-string input)
        """
        if not isinstance(text, str) or not text:
            return False
        text_lower = text.lower()
        return any(
            self._keyword_in_text(keyword, text_lower)
            for keyword in self.fashion_keywords
        )

    def _collect_split(self, dataset, split_name):
        """Return the fashion-related samples of one dataset split."""
        collected = []
        for idx, sample in enumerate(dataset[split_name]):
            instruction = sample['instruction']
            if not self.is_fashion_related(instruction):
                continue
            collected.append({
                'id': f"{split_name}_{idx}",
                'source_img': sample['source_img'],
                'target_img': sample['target_img'],
                'instruction': instruction,
                'turn_index': sample.get('turn_index', 0),
                'split': split_name
            })
        return collected

    def filter_fashion_samples(self, dataset):
        """
        Filter dataset for fashion-related samples

        Args:
            dataset: HuggingFace dataset object (a mapping from split name
                to an iterable of samples)

        Returns:
            list: Filtered fashion samples with metadata
        """
        print("\nFiltering for fashion-related images...")
        fashion_samples = []

        # NOTE(review): only 'train' and 'test' are scanned; confirm the hub
        # dataset does not use another name (e.g. 'dev') for its held-out split.
        for split_name in ('train', 'test'):
            if split_name in dataset:
                fashion_samples.extend(self._collect_split(dataset, split_name))

        print(f"‚úì Found {len(fashion_samples)} fashion-related samples")
        return fashion_samples

    def analyze_fashion_data(self, fashion_samples):
        """
        Analyze the filtered fashion dataset

        Args:
            fashion_samples: List of fashion samples

        Returns:
            dict: Statistics about the dataset (total_samples, splits,
                avg_instruction_length, top_keywords, estimated_size_mb)
        """
        print("\n" + "="*60)
        print("FASHION DATASET ANALYSIS")
        print("="*60)

        # Basic statistics
        total_samples = len(fashion_samples)
        print(f"\nüìä Total Fashion Samples: {total_samples}")

        if total_samples == 0:
            # Nothing to average or rank; avoid dividing by zero below.
            return {
                'total_samples': 0,
                'splits': {},
                'avg_instruction_length': 0.0,
                'top_keywords': [],
                'estimated_size_mb': 0.0
            }

        # Split distribution
        splits = Counter(s['split'] for s in fashion_samples)
        print("\nüìÇ Split Distribution:")
        for split, count in splits.items():
            print(f"   {split}: {count} samples ({count/total_samples*100:.1f}%)")

        # Instruction analysis
        instructions = [s['instruction'] for s in fashion_samples]
        avg_length = sum(len(inst.split()) for inst in instructions) / total_samples
        print("\nüìù Instruction Statistics:")
        print(f"   Average length: {avg_length:.1f} words")

        # Keyword frequency (same whole-word rule as is_fashion_related)
        keyword_counts = Counter()
        for instruction in instructions:
            text = instruction.lower()
            for keyword in self.fashion_keywords:
                if self._keyword_in_text(keyword, text):
                    keyword_counts[keyword] += 1

        print("\nüè∑Ô∏è  Top 10 Fashion Keywords:")
        for keyword, count in keyword_counts.most_common(10):
            print(f"   {keyword}: {count} occurrences")

        # Sample instructions
        print("\nüí¨ Sample Instructions:")
        for i, sample in enumerate(fashion_samples[:5]):
            print(f"   {i+1}. \"{sample['instruction']}\"")

        # Storage estimation
        print("\nüíæ Storage Estimation:")
        # Assuming ~0.5 MB (500KB) per image, two images (source + target)
        # per sample, i.e. ~1 MB per sample.
        estimated_size_mb = (total_samples * 2 * 0.5)
        print(f"   Estimated size: ~{estimated_size_mb:.1f} MB")
        print(f"   Estimated size: ~{estimated_size_mb/1024:.2f} GB")

        stats = {
            'total_samples': total_samples,
            'splits': dict(splits),
            'avg_instruction_length': avg_length,
            'top_keywords': keyword_counts.most_common(10),
            'estimated_size_mb': estimated_size_mb
        }

        return stats

    def save_fashion_metadata(self, fashion_samples, output_path="fashion_samples.json"):
        """
        Save filtered fashion samples metadata to JSON

        Args:
            fashion_samples: List of fashion samples
            output_path: Path to save JSON file, relative to base_dir

        Returns:
            Path: Location of the written JSON file
        """
        output_file = self.base_dir / output_path

        # Images can't be serialized directly, so only record whether each
        # sample carries a source/target image.
        serializable_samples = [
            {
                'id': sample['id'],
                'instruction': sample['instruction'],
                'turn_index': sample['turn_index'],
                'split': sample['split'],
                'has_source_img': sample['source_img'] is not None,
                'has_target_img': sample['target_img'] is not None
            }
            for sample in fashion_samples
        ]

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(serializable_samples, f, indent=2)

        print(f"\n‚úì Saved metadata to {output_file}")
        return output_file

    def create_sample_visualization(self, fashion_samples, num_samples=3):
        """
        Create a visualization of sample fashion edits

        Draws one row per sample (source image left, target image right),
        saves the figure as sample_visualization.png inside base_dir and
        shows it. Does nothing beyond printing a notice when matplotlib is
        missing or there are no samples to draw.

        Args:
            fashion_samples: List of fashion samples
            num_samples: Number of samples to visualize (capped at the
                number of available samples)
        """
        print(f"\nüñºÔ∏è  Creating visualization of {num_samples} samples...")

        try:
            import matplotlib.pyplot as plt
        except ImportError:
            print("‚ö†Ô∏è  matplotlib not installed, skipping visualization")
            print("   Install with: pip install matplotlib")
            return

        # Never allocate more rows than there are samples to draw.
        rows = min(num_samples, len(fashion_samples))
        if rows <= 0:
            print("   No samples to visualize, skipping")
            return

        # squeeze=False keeps axes two-dimensional even for a single row.
        fig, axes = plt.subplots(rows, 2, figsize=(10, 5*rows), squeeze=False)

        for i, sample in enumerate(fashion_samples[:rows]):
            # Source image
            axes[i, 0].imshow(sample['source_img'])
            axes[i, 0].set_title(f"Source Image {i+1}")
            axes[i, 0].axis('off')

            # Target image
            axes[i, 1].imshow(sample['target_img'])
            axes[i, 1].set_title(f"Target Image {i+1}\n\"{sample['instruction']}\"",
                                fontsize=9)
            axes[i, 1].axis('off')

        plt.tight_layout()
        viz_path = self.base_dir / "sample_visualization.png"
        plt.savefig(viz_path, dpi=150, bbox_inches='tight')
        print(f"‚úì Saved visualization to {viz_path}")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [23]:
def main():
    """
    Run the full pipeline: load MagicBrush, keep the fashion-related samples,
    analyze them, save their metadata, visualize a few of them and print
    storage / dataset-size recommendations.

    Stops early (after printing a notice) when the dataset cannot be loaded
    or when no fashion samples are found.
    """
    rule = "=" * 60
    # Use Colab's temporary storage (no Drive needed!)
    storage_root = "/content/magicbrush_data"

    print(rule)
    print("MAGICBRUSH FASHION DATASET LOADER")
    print(rule)
    print(f"‚úì Using temporary storage: {storage_root}")
    print("‚ö†Ô∏è  Note: Data will be deleted when runtime disconnects")

    loader = MagicBrushFashionLoader(base_dir=storage_root)

    # Step 1: load the dataset, bail out if that is not possible
    dataset = loader.load_from_huggingface()
    if dataset is None:
        for line in (
            "\n‚ùå Failed to load dataset. Please check your internet connection",
            "   and ensure 'datasets' library is installed.",
        ):
            print(line)
        return

    # Step 2: keep only the fashion-related samples
    fashion_samples = loader.filter_fashion_samples(dataset)
    if not fashion_samples:
        print("\n‚ö†Ô∏è  No fashion samples found. Check your keywords.")
        return

    # Step 3: analyze, persist metadata, visualize
    stats = loader.analyze_fashion_data(fashion_samples)
    loader.save_fashion_metadata(fashion_samples)
    loader.create_sample_visualization(fashion_samples, num_samples=3)

    # Step 4: recommendations
    print("\n" + rule)
    print("RECOMMENDATIONS")
    print(rule)

    # Storage advice: 500 MB is the cut-off for "small enough"
    if stats['estimated_size_mb'] < 500:
        storage_advice = (
            "‚úì Dataset is small enough for Google Drive",
            "  No need for AWS S3 for this project",
        )
    else:
        storage_advice = (
            "‚ö†Ô∏è  Dataset is fairly large",
            "  Consider AWS S3 if you have storage issues",
        )
    for line in storage_advice:
        print(line)

    # Sample-count advice: <100 too few, <500 preliminary, otherwise plenty
    sample_count = stats['total_samples']
    print(f"\nüìà For your milestone, you have {sample_count} fashion samples")
    if sample_count >= 500:
        size_advice = ("   ‚úì Excellent amount for training!",)
    elif sample_count >= 100:
        size_advice = (
            "   ‚úì Good for preliminary results",
            "   ‚úì May want more data for final project",
        )
    else:
        size_advice = (
            "   ‚ö†Ô∏è  This might be too few. Consider:",
            "   1. Expanding keywords",
            "   2. Using full MagicBrush + augmenting with DeepFashion2",
        )
    for line in size_advice:
        print(line)


if __name__ == "__main__":
    # Remind the user which packages must be installed, then run the pipeline.
    for banner_line in (
        "Required packages:",
        "  pip install datasets pillow pandas matplotlib",
        "\n",
    ):
        print(banner_line)

    main()

Required packages:
  pip install datasets pillow pandas matplotlib


MAGICBRUSH FASHION DATASET LOADER
‚úì Using temporary storage: /content/magicbrush_data
‚ö†Ô∏è  Note: Data will be deleted when runtime disconnects
Loading MagicBrush from HuggingFace...
Error loading dataset: [Errno 107] Transport endpoint is not connected: 'osunlp/MagicBrush/state.json'

‚ùå Failed to load dataset. Please check your internet connection
   and ensure 'datasets' library is installed.


In [None]:
# 2. Prepare / Store Data (David)

#test
class FashionEditDataset(Dataset):
    def __init__(self, magicbrush_path, deepfashion_path=None):
        # Load MagicBrush data
        self.data = self.load_magicbrush(magicbrush_path)

        # Optionally augment with DeepFashion2
        if deepfashion_path:
            self.data.extend(self.load_deepfashion(deepfashion_path))

    def load_magicbrush(self, path):
        # MagicBrush format is already (source, instruction, target)
        samples = []
        with open(os.path.join(path, 'annotations.json')) as f:
            data = json.load(f)
            for item in data:
                if self.is_fashion_related(item['instruction']):
                    samples.append({
                        'original': item['source_img'],
                        'edited': item['target_img'],
                        'prompt': item['instruction']
                    })
        return samples