# In [None]:
import pandas as pd
import numpy as np
import pyarrow.parquet as pq
import os
import matplotlib.pyplot as plt
from PIL import Image

def load_parquet_data(sample_size=None, parquet_file=None, random_state=42):
    """
    Load the merged parquet file and return its columns as NumPy arrays.

    Parameters:
    sample_size (int, optional): If provided and smaller than the number of
        rows, only this many rows are drawn at random (via
        ``DataFrame.sample``). If it is >= the number of rows, every row is
        kept.
    parquet_file (str or os.PathLike, optional): Path of the parquet file to
        read. When omitted, falls back to the original project location
        ``/home/work/workspace_ai/Artificlass/data_process/data/
        top6_styles_merged.parquet`` so existing callers keep working.
    random_state (int, optional): Seed forwarded to ``DataFrame.sample`` so
        the random subsample is reproducible. Defaults to 42, the value that
        was previously hard-coded.

    Returns:
    Tuple of (images, artists, styles, titles), each a NumPy array with one
    entry per loaded row, taken from the 'image', 'artist', 'style' and
    'title' columns respectively.
    """
    if parquet_file is None:
        # Backward-compatible default; pass ``parquet_file`` explicitly to use
        # this loader anywhere other than the original workstation.
        current_path = '/home/work/workspace_ai/Artificlass/data_process'
        parquet_file = os.path.join(current_path, 'data', 'top6_styles_merged.parquet')

    print(f"Loading data from {parquet_file}")

    # Read the whole parquet file into a pandas DataFrame.
    table = pq.read_table(parquet_file)
    df = table.to_pandas()

    # Optionally down-sample; skipped when sample_size would keep every row.
    if sample_size is not None and sample_size < len(df):
        df = df.sample(sample_size, random_state=random_state)
        print(f"Sampled {sample_size} examples from dataset")

    # Each 'image' cell is converted to an array and the results are packed
    # into one array. That is only a dense N-D array when all images share a
    # shape — NOTE(review): uniform image shape is assumed, confirm upstream.
    images = np.array([np.array(img) for img in df['image']])
    artists = df['artist'].to_numpy()
    styles = df['style'].to_numpy()
    titles = df['title'].to_numpy()

    print(f"Loaded {len(images)} images")
    print(f"Unique styles: {np.unique(styles)}")
    print(f"Images shape: {images.shape}")

    return images, artists, styles, titles

def display_sample_images(images, styles, num_samples=5, random_state=None):
    """
    Display randomly chosen images in a single column, each titled with its style.

    Parameters:
    images (array-like): Images indexed along the first axis. Each one is
        transposed with ``(1, 2, 0)`` before plotting, i.e. it is treated as
        channel-first (C, H, W). NOTE(review): confirm the stored images
        really are channel-first; channel-last data would be scrambled.
    styles (array-like): Style label for each image, aligned with ``images``.
    num_samples (int): How many images to show. Capped at ``len(images)`` so
        that a small dataset does not make sampling without replacement fail.
    random_state (int, optional): Seed for reproducible selection. When None
        (the default) the global ``np.random`` state is used, as before.

    Raises:
    ValueError: If ``num_samples`` is negative.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")

    # Sampling without replacement cannot draw more items than exist.
    num_samples = min(num_samples, len(images))
    if num_samples == 0:
        return  # nothing to plot; avoid creating an empty, zero-height figure

    chooser = np.random if random_state is None else np.random.default_rng(random_state)
    indices = chooser.choice(len(images), num_samples, replace=False)

    plt.figure(figsize=(15, 3 * num_samples))
    for i, idx in enumerate(indices):
        img = np.asarray(images[idx]).transpose(1, 2, 0)  # (C,H,W) -> (H,W,C) for imshow
        plt.subplot(num_samples, 1, i + 1)
        plt.imshow(img)
        plt.title(f"Style: {styles[idx]}")
        plt.axis('off')
    plt.tight_layout()
    plt.show()


# In [None]:
# Load a small sample for demonstration
if __name__ == "__main__":
    images, artists, styles, titles = load_parquet_data(sample_size=20)

    # Display some sample images
    display_sample_images(images, styles, num_samples=5)

    # Count how many sampled images belong to each style. np.unique returns
    # the styles in sorted order together with their occurrence counts.
    unique_styles, counts = np.unique(styles, return_counts=True)
    style_counts = dict(zip(unique_styles, counts))
    print("\nStyle distribution in sample:")
    for style, count in style_counts.items():
        print(f"{style}: {count} images")