In [1]:
import os

# Set before the library imports in the next cell are executed.
# Presumably meant to silence tqdm progress bars — TODO confirm which library
# actually reads DISABLE_TQDM.
os.environ["DISABLE_TQDM"] = "1"

In [2]:
import numpy as np
import uuid
import pandas as pd
import numpy as np
from diffusers import DiffusionPipeline
import torch
from PIL import Image
import cv2
import xml.etree.ElementTree as ET

  warn("The installed version of bitsandbytes was compiled without GPU support. "


'NoneType' object has no attribute 'cadam32bit_grad_fp32'


In [3]:
# Presumably turns off Accelerate's progress bars.
# NOTE(review): this runs after the libraries were already imported in an
# earlier cell — confirm the variable is still honoured when set this late.
os.environ["ACCELERATE_DISABLE_PROGRESS_BAR"] = "1"

In [4]:
# NOTE: from here on the name `logging` refers to diffusers' logging helper,
# not the standard-library `logging` module.
from diffusers.utils import logging
# Restrict diffusers' own log output to error-level messages.
logging.set_verbosity_error()

SVG Validation:

In [5]:
# SVG validation function
# Maximum accepted size of the SVG text, measured in UTF-8 encoded bytes.
MAX_SVG_SIZE = 10000
# Element names (namespace stripped) that may appear anywhere in the document.
ALLOWED_TAGS = {
    "svg", "g", "defs", "symbol", "use", "marker", "pattern", "linearGradient", "radialGradient",
    "stop", "filter", "feBlend", "feColorMatrix", "feComposite", "feFlood", "feGaussianBlur",
    "feMerge", "feMergeNode", "feOffset", "feTurbulence", "path", "rect", "circle", "ellipse",
    "line", "polyline", "polygon"
}
# Substrings that must not occur (case-insensitively) in any attribute value.
DISALLOWED_ATTRS = ["data:", ";base64", "http://", "https://"]

def is_valid_svg(svg_code):
    """
    Check an SVG string against the size, tag and attribute restrictions above.

    The string is rejected when its UTF-8 encoding is larger than MAX_SVG_SIZE,
    when it is not well-formed XML, when the root element is not <svg>, when any
    element is outside ALLOWED_TAGS, or when any attribute value contains one of
    the DISALLOWED_ATTRS substrings. The reason for a rejection is printed.

    Args:
        svg_code (str): SVG document text

    Returns:
        bool: True if every check passes, False otherwise (never raises)
    """
    try:
        svg_size = len(svg_code.encode('utf-8'))
        if svg_size > MAX_SVG_SIZE:
            print('Exceeds max SVG string length. Current length:', svg_size)
            return False
        root = ET.fromstring(svg_code)
        # ElementTree reports namespaced tags as "{namespace}name"; keep only the name.
        if root.tag.split('}')[-1] != "svg":
            print('SVG tag not specified.')
            return False
        # root.iter() yields the root itself first, then every descendant.
        for elem in root.iter():
            tag = elem.tag.split('}')[-1]
            if tag not in ALLOWED_TAGS:
                print('Tag not allowed.')
                return False
            for val in elem.attrib.values():
                lowered = val.lower()
                if any(x in lowered for x in DISALLOWED_ATTRS):
                    print('Attribute not allowed.')
                    return False
        return True
    except Exception as exc:
        # Still best-effort (malformed XML, non-string input, ...), but say why
        # instead of failing silently like the other rejection paths do not.
        print('Could not validate SVG:', exc)
        return False

Convert a bitmap to a size-limited SVG:

In [6]:
def compress_hex_color(hex_color):
    """
    Convert a hex color to its shortest possible representation.

    A "#rrggbb" color whose three channels are each a doubled hex digit
    (e.g. "#aabbcc") is returned as the 3-digit form "#abc" in lowercase.
    Anything else is returned unchanged, including input whose channels cannot
    be parsed as two hex digits each, such as an already-short "#fff".

    Args:
        hex_color (str): Color string, normally in "#rrggbb" form

    Returns:
        str: Shortened color, or the input unchanged
    """
    try:
        r, g, b = (int(hex_color[i:i + 2], 16) for i in (1, 3, 5))
    except ValueError:
        # Not a full 6-digit color (a slice was empty or not hex): nothing to compress.
        return hex_color
    # A channel is a doubled digit (0x00, 0x11, ..., 0xff) exactly when it is a multiple of 17.
    if r % 17 == 0 and g % 17 == 0 and b % 17 == 0:
        return f'#{r // 17:x}{g // 17:x}{b // 17:x}'
    return hex_color

def extract_features_by_scale(img_np, num_colors=16):
    """
    Extract polygon features from an image, ranked by importance.

    The image is quantized to `num_colors` colors with k-means. For every
    quantized color the external contours of its mask are traced, contours with
    an area below 20 are dropped, and each remaining contour is simplified to a
    polygon.

    Args:
        img_np (np.ndarray): Input image, either 2-D grayscale or H x W x C.
            For C > 3 only the first three channels are used (alpha is dropped).
            Assumes a dtype OpenCV accepts, e.g. uint8 — TODO confirm for callers.
        num_colors (int): Number of colors to quantize to

    Returns:
        list: One dict per polygon, sorted by 'importance' (highest first), with
            keys 'points' (space-separated "x,y" string), 'color' (hex string),
            'area', 'importance', 'point_count' and 'original_contour'.
    """
    # Normalise to a 3-channel image. Slicing drops an alpha channel so that the
    # reshape(-1, 3) below cannot mix channels of neighbouring pixels.
    if len(img_np.shape) == 3 and img_np.shape[2] > 1:
        img_rgb = img_np[:, :, :3]
    else:
        img_rgb = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)

    height, width = img_rgb.shape[:2]

    # Perform color quantization
    pixels = img_rgb.reshape(-1, 3).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 100, 0.2)
    _, labels, centers = cv2.kmeans(pixels, num_colors, None, criteria, 10, cv2.KMEANS_RANDOM_CENTERS)
    # Quantized image
    palette = centers.astype(np.uint8)
    quantized = palette[labels.flatten()].reshape(img_rgb.shape)

    # Sort colors by frequency. `sorted_indices` are positions within
    # `unique_labels`, so they must be mapped back to cluster ids before indexing
    # the palette; the two only coincide when every cluster received pixels.
    unique_labels, counts = np.unique(labels, return_counts=True)
    sorted_indices = np.argsort(-counts)
    sorted_colors = [palette[unique_labels[i]] for i in sorted_indices]

    # Center point for importance calculations
    center_x, center_y = width / 2, height / 2

    hierarchical_features = []
    for color in sorted_colors:
        # Mask of the pixels that have exactly this quantized color
        color_mask = cv2.inRange(quantized, color, color)

        # Outer contours only, largest first
        contours, _ = cv2.findContours(color_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours = sorted(contours, key=cv2.contourArea, reverse=True)

        # Convert RGB to compressed hex
        hex_color = compress_hex_color(f'#{color[0]:02x}{color[1]:02x}{color[2]:02x}')

        for contour in contours:
            # Skip tiny contours
            area = cv2.contourArea(contour)
            if area < 20:
                continue

            # Contour centroid from image moments; skip degenerate contours
            m = cv2.moments(contour)
            if m["m00"] == 0:
                continue

            cx = int(m["m10"] / m["m00"])
            cy = int(m["m01"] / m["m00"])

            # Distance from image center, each axis normalized by the image size
            dist_from_center = np.sqrt(((cx - center_x) / width)**2 + ((cy - center_y) / height)**2)

            # Simplify contour with a tolerance of 2% of its perimeter
            epsilon = 0.02 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            # Generate points string
            points = " ".join([f"{pt[0][0]:.1f},{pt[0][1]:.1f}" for pt in approx])

            # Importance grows with area and closeness to the center and
            # shrinks with the number of polygon points.
            importance = (
                area *
                (1 - dist_from_center) *
                (1 / (len(approx) + 1))
            )

            hierarchical_features.append({
                'points': points,
                'color': hex_color,
                'area': area,
                'importance': importance,
                'point_count': len(approx),
                # Despite the key name this is the simplified polygon, not the raw contour.
                'original_contour': approx
            })

    # One stable sort over all features. A separate per-color sort beforehand
    # would not change the result, because equal-importance entries keep their
    # insertion order either way.
    hierarchical_features.sort(key=lambda x: x['importance'], reverse=True)

    return hierarchical_features

def simplify_polygon(points_str, simplification_level):
    """
    Shrink the textual form of a polygon's point list.

    Args:
        points_str (str): Whitespace-separated "x,y" coordinate pairs
        simplification_level (int): 0 returns the input untouched, 1 rewrites
            every coordinate with one decimal place, 2 rewrites every coordinate
            as a whole number, 3 does the same as 2 and, for polygons with more
            than 4 points, also thins out the point list. Any other value
            returns the input untouched.

    Returns:
        str: The simplified points string
    """
    def _render(pairs, spec):
        # Re-format the first two comma-separated fields of every pair with `spec`.
        rendered = []
        for pair in pairs:
            coords = pair.split(',')
            rendered.append(f"{float(coords[0]):{spec}},{float(coords[1]):{spec}}")
        return " ".join(rendered)

    if simplification_level == 0:
        return points_str

    pairs = points_str.split()

    if simplification_level == 1:
        return _render(pairs, ".1f")
    if simplification_level == 2:
        return _render(pairs, ".0f")
    if simplification_level != 3:
        return points_str

    # Level 3 from here on.
    if len(pairs) <= 4:
        # Too few points to drop any; only round.
        return _render(pairs, ".0f")

    # Stride through the points (stride is 1 for five points, otherwise 2).
    stride = min(2, len(pairs) // 3)
    kept = pairs[::stride]
    if len(kept) < 3:
        kept = pairs[:3]
    # Always retain the final point of the original list.
    if pairs[-1] not in kept:
        kept.append(pairs[-1])
    return _render(kept, ".0f")
def bitmap_to_svg_layered(image, max_size_bytes=10000, resize=True, target_size=(384, 384),
                         adaptive_fill=True, num_colors=None):
    """
    Convert a bitmap to an SVG made of filled polygons, within a byte budget.

    The SVG consists of a background rectangle in the image's mean color
    followed by one <polygon> per feature returned by
    extract_features_by_scale, most important first.

    Args:
        image: Input image (PIL.Image)
        max_size_bytes (int): Maximum size of the SVG text in UTF-8 bytes
        resize (bool): Whether to resize the image before processing
        target_size (tuple): Target size for resizing (width, height)
        adaptive_fill (bool): If False, polygons are added in order until the
            first one that does not fit. If True, polygons that do not fit are
            skipped and retried afterwards with simplification levels 1 to 3.
        num_colors (int): Number of colors to quantize to; if None it is chosen
            from the number of pixels being processed (8, 12 or 16)

    Returns:
        str: SVG representation. In adaptive mode, if the result would still be
            larger than max_size_bytes, a bare background-only SVG is returned.
    """
    def _polygon_element(points, color):
        # One polygon per line; the exact text matters because its length is budgeted.
        return f'<polygon points="{points}" fill="{color}" />\n'

    # Adaptive color selection: more colors for larger images
    if num_colors is None:
        if resize:
            pixel_count = target_size[0] * target_size[1]
        else:
            pixel_count = image.size[0] * image.size[1]

        if pixel_count < 65536:  # 256x256
            num_colors = 8
        elif pixel_count < 262144:  # 512x512
            num_colors = 12
        else:
            num_colors = 16

    # Remember the input size before any resizing; it becomes the SVG's width/height.
    original_size = image.size
    if resize:
        image = image.resize(target_size, Image.LANCZOS)

    # Convert to numpy array
    img_np = np.array(image)

    # Dimensions of the (possibly resized) image that the polygons are traced on
    height, width = img_np.shape[:2]

    # Background: mean color for 3-channel images, white for anything else
    if len(img_np.shape) == 3 and img_np.shape[2] == 3:
        avg_bg_color = np.mean(img_np, axis=(0, 1)).astype(int)
        bg_hex_color = compress_hex_color(f'#{avg_bg_color[0]:02x}{avg_bg_color[1]:02x}{avg_bg_color[2]:02x}')
    else:
        bg_hex_color = '#fff'

    # width/height carry the original dimensions, while the viewBox uses the
    # processed dimensions so polygon coordinates scale to the original size.
    orig_width, orig_height = original_size
    svg_header = f'<svg xmlns="http://www.w3.org/2000/svg" width="{orig_width}" height="{orig_height}" viewBox="0 0 {width} {height}">\n'
    svg_bg = f'<rect width="{width}" height="{height}" fill="{bg_hex_color}"/>\n'
    svg_base = svg_header + svg_bg
    svg_footer = '</svg>'

    # Running size of the document, footer included. UTF-8 lengths are additive,
    # so tracking the sum is equivalent to re-encoding the whole string each time.
    bytes_used = len((svg_base + svg_footer).encode('utf-8'))

    # Extract hierarchical features
    features = extract_features_by_scale(img_np, num_colors=num_colors)

    parts = [svg_base]

    # Non-adaptive mode: add features in order and stop at the first that does not fit
    if not adaptive_fill:
        for feature in features:
            element = _polygon_element(feature["points"], feature["color"])
            element_size = len(element.encode('utf-8'))
            if bytes_used + element_size > max_size_bytes:
                break
            parts.append(element)
            bytes_used += element_size
        parts.append(svg_footer)
        return "".join(parts)

    # Pass 1: add every feature that fits at original quality, skipping (not
    # stopping at) the ones that do not fit.
    pending = []
    for feature in features:
        element = _polygon_element(feature["points"], feature["color"])
        element_size = len(element.encode('utf-8'))
        if bytes_used + element_size <= max_size_bytes:
            parts.append(element)
            bytes_used += element_size
        else:
            pending.append(feature)

    # Pass 2: retry the skipped features with simplification levels 1-3.
    # Simplified markup is only built for features that are still pending.
    for level in range(1, 4):
        still_pending = []
        for feature in pending:
            element = _polygon_element(simplify_polygon(feature["points"], level), feature["color"])
            element_size = len(element.encode('utf-8'))
            if bytes_used + element_size <= max_size_bytes:
                parts.append(element)
                bytes_used += element_size
            else:
                still_pending.append(feature)
        pending = still_pending

    # Finalize SVG
    parts.append(svg_footer)
    svg = "".join(parts)

    # Double check we didn't exceed the limit (only possible when the header and
    # background alone are already over budget).
    if len(svg.encode('utf-8')) > max_size_bytes:
        return f'<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {width} {height}"><rect width="{width}" height="{height}" fill="{bg_hex_color}"/></svg>'

    return svg

Generate and compress SVG:

In [7]:
# Load the table of prompts; the generation loop below reads its "prompt" column.
# NOTE(review): absolute, machine-specific path — this cell only runs as-is on
# the author's machine.
aug_train = pd.read_csv("/Users/jianglimeng/Desktop/projects/augmented_train.csv")

In [8]:
# Use a GPU if one is available, otherwise fall back to CPU.
# CUDA is tried first, then Apple's MPS backend.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# NOTE(review): the recorded output of this cell still shows a
# "Loading pipeline components" progress bar, so `disable_progress_bar=True`
# does not appear to silence it — confirm whether the kwarg has any effect.
# NOTE(review): half precision is kept even on the CPU fallback — confirm that
# the pipeline actually runs in float16 on CPU before relying on it.
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/sdxl-turbo",
    torch_dtype=torch.float16,   # Leave this as float16
    variant="fp16",
    disable_progress_bar=True
)
pipe.to(device)

# Hide the per-image denoising progress bar during generation.
pipe.set_progress_bar_config(disable=True)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [10]:
# Generate one image per prompt, convert it to SVG and save the valid ones.
OUTPUT_DIR = "/Users/jianglimeng/Desktop/projects/outputs/svg6000"
# Create the target folder up front so the first write cannot fail on a missing directory.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Row labels whose SVG failed validation and was therefore not written.
skipped_ids = []

# NOTE(review): only rows at positions 4026..6000 are processed — presumably a
# resumed batch; confirm the range before re-running.
for idx, row in aug_train.iloc[4026:6001].iterrows():
    prompt = row["prompt"]
    styled_prompt = f"{prompt}, minimal geometric illustration, flat design, simple color, no details, 2D vector style"

    # NOTE(review): the pipeline's default step count and guidance scale are
    # used here. SDXL-Turbo is documented for very few steps with guidance
    # disabled; left unchanged so new images stay comparable with the ones
    # already generated — confirm whether that is intended.
    image = pipe(prompt=styled_prompt, height=256, width=256).images[0]
    svg_content = bitmap_to_svg_layered(image)
    if is_valid_svg(svg_content):
        # Written as UTF-8 to match the encoding the reader cell opens the files with.
        with open(os.path.join(OUTPUT_DIR, f"{idx}.svg"), "w", encoding="utf-8") as f:
            f.write(svg_content)
    else:
        skipped_ids.append(idx)

    if idx % 200 == 0:
        print(f"This is the {idx}th compressed image.")

if skipped_ids:
    print(f"Skipped {len(skipped_ids)} invalid SVG(s): {skipped_ids}")

This is the 4200th compressed image.
This is the 4400th compressed image.
This is the 4600th compressed image.
This is the 4800th compressed image.
This is the 5000th compressed image.
This is the 5200th compressed image.
This is the 5400th compressed image.
This is the 5600th compressed image.
This is the 5800th compressed image.
This is the 6000th compressed image.


Read the SVG strings and save them as a CSV file:

In [19]:
# Start from an empty frame with the two output columns, 'id' and 'svg'.
svg_df = pd.DataFrame(data={'id': [], 'svg': []})

In [20]:
# NOTE(review): this reads .../svg10000 while the generation loop in this
# notebook writes to .../svg6000 — confirm the intended directory.
directory_path = "/Users/jianglimeng/Desktop/projects/outputs/svg10000"

# Collect the rows in a plain list and build the frame once at the end.
# Appending with svg_df.loc[len(svg_df)] grows the frame one row at a time and
# adds duplicate rows every time the cell is re-run.
svg_records = []
for root, dirs, files in os.walk(directory_path):
    for fname in sorted(files):
        # if current file is not an svg, skip
        if not fname.endswith('.svg'):
            continue
        full_file_path = os.path.join(root, fname)
        with open(full_file_path, "r", encoding="utf-8") as f:
            svg_string = f.read()
        # The id is the file name without its ".svg" suffix.
        svg_records.append((fname[:-4], svg_string))

svg_df = pd.DataFrame(svg_records, columns=['id', 'svg'])

In [21]:
# Sanity check: number of SVG files collected into svg_df.
print(len(svg_df))

1999


In [22]:
# Persist the collected SVGs as CSV.
# NOTE(review): no `index=` argument is passed, so with pandas' default the row
# index is presumably written as an extra leading column — confirm that is intended.
svg_df.to_csv("/Users/jianglimeng/Desktop/projects/outputs/svg10000.csv")