### Earth Mover's Distance (EMD) for Shape Similarity

In [5]:
import numpy as np
from sklearn.decomposition import PCA
from scipy.spatial.distance import euclidean
from svgpathtools import svg2paths, Path
import matplotlib.pyplot as plt

import cairosvg
import io


import os
from openpyxl import Workbook
from openpyxl.drawing.image import Image
from openpyxl.utils import get_column_letter

from PIL import Image as ImagePil


# SVG to Paths

def join_svg_paths(svg_file):
    """
    Load an SVG file and merge every path it contains into one path.

    The paths are taken in the order ``svg2paths`` returns them and are
    appended, one after another, to a single initially empty ``Path``.

    Parameters:
        svg_file (str): Path to the SVG file.

    Returns:
        Path: One path object holding the contents of all paths in the SVG.
    """
    # svg2paths also returns per-path attributes; only the geometry is needed here.
    svg_paths, _attributes = svg2paths(svg_file)

    merged = Path()
    for sub_path in svg_paths:
        merged.extend(sub_path)
    return merged

def parse_svg(svg_file, num_samples=250):
    """
    Samples points along the combined path of an SVG file.

    All paths of the file are merged with ``join_svg_paths`` and the merged
    path is evaluated at ``num_samples`` evenly spaced fractions of its total
    length (0 and 1 included).

    Parameters:
        svg_file (str): Path to the SVG file.
        num_samples (int): Number of points to sample along the path.

    Returns:
        np.ndarray: A 2D array of sampled points of shape (num_samples, 2).

    Raises:
        ValueError: If the SVG contains no path geometry or its combined
            path has zero length, so there is nothing to sample.
    """
    path = join_svg_paths(svg_file)

    # Guard first: an empty or zero-length path would otherwise fail deep inside
    # the sampling loop (or silently produce NaN fractions).
    total_length = path.length() if len(path) > 0 else 0.0
    if not total_length > 0:
        raise ValueError(f"No drawable path geometry found in {svg_file}")

    sample_distances = np.linspace(0, total_length, num_samples)

    # NOTE(review): Path.point takes a fraction in [0, 1]; whether equal steps of
    # that fraction are exactly equal arc-length steps depends on svgpathtools'
    # parameterisation -- confirm if strict uniform spacing is required.
    sampled_points = [
        (point.real, point.imag)
        for point in (path.point(distance / total_length) for distance in sample_distances)
    ]

    # reshape keeps the (n, 2) contract even when num_samples is 0.
    return np.array(sampled_points, dtype=float).reshape(-1, 2)


# Normalization Functions
def center_shape(points):
    """
    Translate a point set so that its centroid lies at the origin.

    Parameters:
        points (np.ndarray): Point coordinates, one point per row.

    Returns:
        np.ndarray: ``points`` with the per-column mean subtracted.
    """
    return points - np.mean(points, axis=0)

def scale_to_unit_size(points):
    """
    Uniformly scale a point set so its largest bounding-box side has length 1.

    The same factor is applied to every coordinate, so the aspect ratio is
    preserved. The points are not translated.

    Parameters:
        points (array-like): Point coordinates, one point per row.

    Returns:
        np.ndarray: The scaled points. If the bounding box has no extent
            (e.g. a single point or identical points) the points are returned
            unscaled instead of dividing by zero.
    """
    points = np.asarray(points)
    extents = np.max(points, axis=0) - np.min(points, axis=0)
    largest_extent = np.max(extents)

    # `not (x > 0)` also catches NaN extents, not just zero.
    if not largest_extent > 0:
        return points * 1.0

    scale_factor = 1 / largest_extent
    return points * scale_factor

def align_orientation(points):
    """
    Rotate a point set so its first principal component lies along the x-axis.

    Parameters:
        points (np.ndarray): Point coordinates of shape (N, 2).

    Returns:
        np.ndarray: The rotated points, same shape as the input.
    """
    # Direction of greatest variance, as reported by PCA.
    # NOTE(review): the sign of this axis is chosen by PCA, so a shape and its
    # 180-degree rotation may not end up in the same pose -- confirm whether
    # that matters for the rotated-query experiments.
    main_axis = PCA(n_components=2).fit(points).components_[0]
    angle = np.arctan2(main_axis[1], main_axis[0])

    # Rotating by -angle brings the principal axis onto the positive x direction.
    cos_a = np.cos(-angle)
    sin_a = np.sin(-angle)
    rotation_matrix = np.array([[cos_a, -sin_a], [sin_a, cos_a]])

    # Points are row vectors, hence the transpose.
    return points @ rotation_matrix.T

def normalize_shape(points):
    """
    Normalize a point set for translation, rotation and scale.

    The steps run in a fixed order: center on the centroid, rotate onto the
    principal axis, then scale. Scaling comes last so that the bounding box
    used for the scale factor is measured in the rotated frame.

    Parameters:
        points (np.ndarray): Point coordinates of shape (N, 2).

    Returns:
        np.ndarray: The normalized points.
    """
    centered = center_shape(points)
    aligned = align_orientation(centered)
    return scale_to_unit_size(aligned)

def build_feature_store_normalized_points(svg_store_directory):
    """
    Build a lookup of normalized contour points for every SVG in a directory.

    Parameters:
        svg_store_directory (str): Directory scanned (non-recursively) for
            files whose name ends with ".svg".

    Returns:
        dict: Maps each SVG file name to its normalized sampled points.
    """
    return {
        entry: normalize_shape(parse_svg(os.path.join(svg_store_directory, entry)))
        for entry in os.listdir(svg_store_directory)
        if entry.endswith(".svg")
    }




In [6]:
from scipy.spatial import distance_matrix
from scipy.optimize import linear_sum_assignment

def compute_emd_similarity(shape1, shape2):
    """
    Score how similar two point sets are via an optimal one-to-one matching.

    Every point of one shape is paired with a distinct point of the other so
    that the summed Euclidean distance is minimal; the mean distance of the
    matched pairs is used as the Earth Mover's Distance (EMD). When the sets
    differ in size, only as many pairs as the smaller set has points are
    matched.

    Parameters:
        shape1, shape2: Arrays of shape (N, 2) representing contour points.

    Returns:
        float: ``1 / (1 + emd)``; 1.0 for identical point sets, approaching 0
            as the shapes grow further apart.
    """
    # Cost of moving each point of shape1 onto each point of shape2.
    pairwise_costs = distance_matrix(shape1, shape2)

    # Minimum-cost assignment (Hungarian-style solver).
    matched_rows, matched_cols = linear_sum_assignment(pairwise_costs)

    # Average cost per matched pair.
    matched_costs = pairwise_costs[matched_rows, matched_cols]
    mean_cost = matched_costs.sum() / len(matched_rows)

    # Map distance in [0, inf) to similarity in (0, 1].
    return 1 / (1 + mean_cost)

def rank_by_similarity_emd(normalized_query, feature_store):
    """
    Order the stored shapes from most to least similar to the query.

    Parameters:
        normalized_query: Normalized contour points of the query shape.
        feature_store (dict): Maps file names to normalized contour points.

    Returns:
        list: ``(file_name, similarity)`` tuples, highest similarity first.
    """
    scored = [
        (candidate_name, compute_emd_similarity(normalized_query, candidate_points))
        for candidate_name, candidate_points in feature_store.items()
    ]
    # Higher score means more similar, so sort descending.
    scored.sort(key=lambda entry: entry[1], reverse=True)
    return scored


In [None]:
# Creating visualizations and Excel sheets

def convert_svg_to_png(svg_path):
    """
    Render an SVG file to an in-memory PNG image for visualization.

    This is a best-effort helper: any failure during rendering or decoding
    is printed and reported to the caller as ``None`` rather than raised.

    Parameters:
        svg_path (str): Path to the SVG file.

    Returns:
        PIL.Image or None: The rendered image, or None if conversion failed.
    """
    try:
        # cairosvg returns the PNG as bytes; wrap them so PIL can read them.
        png_buffer = io.BytesIO(cairosvg.svg2png(url=svg_path))
        rendered = ImagePil.open(png_buffer)
    except Exception as e:
        print(f"Error converting {svg_path} to PNG: {e}")
        rendered = None
    return rendered


def save_top_5_results(query_svg, results, svg_store_path, output_folder_param):
    """
    Save a one-row figure showing the five best matches for a query SVG.

    Parameters:
        query_svg (str): Path of the query SVG; its base name (without
            extension) names the output PNG.
        results (list): Ranked ``(file_name, score)`` tuples, best first.
        svg_store_path (str): Directory containing the ranked SVG files.
        output_folder_param (str): Directory the PNG is written to; created
            if it does not exist.
    """
    best_matches = results[:5]
    query_name, _extension = os.path.splitext(os.path.basename(query_svg))
    os.makedirs(output_folder_param, exist_ok=True)

    fig, axes = plt.subplots(1, 5, figsize=(15, 5))

    # One panel per match: the rendered SVG with its score as the title.
    for panel, (match_name, match_score) in zip(axes, best_matches):
        match_path = os.path.normpath(os.path.join(svg_store_path, match_name))
        rendered = convert_svg_to_png(match_path)
        if rendered:
            panel.imshow(rendered)
            panel.set_title(f"Score: {match_score:.4f}", fontsize=26.5)
        else:
            panel.text(0.5, 0.5, "Error Loading", fontsize=12, ha='center', va='center')
        panel.axis("off")

    # Hide the panels left over when fewer than five results were supplied.
    for unused_panel in axes[len(best_matches):]:
        unused_panel.axis("off")

    plt.tight_layout()
    save_path = os.path.join(output_folder_param, f"{query_name}.png")
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
    print(f"Saved plot to {save_path}")

def save_results_to_excel(svg_directory, visualization_dir, excel_output):
    """
    Write an Excel workbook listing each SVG next to its ranking visualization.

    For every ``.svg`` file in ``svg_directory`` one row is added containing
    the file name, an 80x80 thumbnail rendered from the SVG, the raw SVG
    markup, and (if present) the pre-generated ``<name>.png`` from
    ``visualization_dir``.

    Parameters:
        svg_directory (str): Directory containing the SVG files to list.
        visualization_dir (str): Directory holding the visualization PNGs,
            named after the SVG's base name (as written by
            ``save_top_5_results``). Created if missing.
        excel_output (str): Path of the ``.xlsx`` file to write. Its parent
            directory is created if missing.
    """
    # Ensure output directories exist (the workbook's parent folder included,
    # otherwise wb.save fails only after all the rendering work is done).
    os.makedirs(visualization_dir, exist_ok=True)
    excel_parent = os.path.dirname(excel_output)
    if excel_parent:
        os.makedirs(excel_parent, exist_ok=True)

    # Create a new workbook and sheet
    wb = Workbook()
    ws = wb.active
    ws.title = "Similarity Rankings"

    # Add headers to the Excel sheet
    ws.append(["Filename", "Image", "SVG XML Content", "Visualization Resulted Similarity"])

    # Column B holds the thumbnails; its width is row-independent, so set it once.
    ws.column_dimensions[get_column_letter(2)].width = 15

    # Row number to start inserting images and data
    row_number = 2  # Start at row 2 since row 1 is for headers

    # Iterate over all SVG files
    for svg_file in os.listdir(svg_directory):
        if not svg_file.endswith(".svg"):
            continue

        svg_path = os.path.join(svg_directory, svg_file)

        # Read SVG content. Decode explicitly as UTF-8: the text is re-encoded
        # as UTF-8 below, so relying on the platform default encoding would
        # corrupt non-ASCII characters (or fail) on e.g. Windows.
        with open(svg_path, 'r', encoding='utf-8') as file:
            svg_content = file.read()

        # Convert SVG to PNG
        png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
        img = Image(io.BytesIO(png_data))

        # Resize the image
        img.width = 80
        img.height = 80

        # Insert original SVG image
        image_cell = f"B{row_number}"
        ws.add_image(img, image_cell)
        ws.row_dimensions[row_number].height = 80

        # Load the pre-generated visualization PNG. Use splitext so the name is
        # derived exactly like save_top_5_results does (only the extension is
        # swapped, not every ".svg" occurrence in the name).
        visualization_filename = os.path.splitext(svg_file)[0] + ".png"
        visualization_path = os.path.join(visualization_dir, visualization_filename)

        if os.path.exists(visualization_path):
            visualization_img = Image(visualization_path)
            visualization_img.width = 350
            visualization_img.height = 80
            visualization_cell = f"D{row_number}"
            ws.add_image(visualization_img, visualization_cell)
            visualization_status = f"Visualization in {visualization_cell}"
        else:
            visualization_status = "No visualization found"

        # Add the row to Excel
        ws.append([svg_file, f"Image in {image_cell}", svg_content, visualization_status])

        # Move to the next row
        row_number += 1

    # Save the Excel file
    wb.save(excel_output)

    print(f"Excel file saved to {excel_output}")

: 

In [None]:
if __name__ == "__main__":
    # Dataset locations for this run.
    # Previously used query sets: "../dataset/query_classified/correct_10",
    # "../dataset/query_Dataset_simplified".
    # Previously used store: "../dataset/registered_Dataset_simplified".
    query_store_directory = "../dataset/rotated_queries"
    svg_store_directory = "../dataset/final-collection"
    result_visualizations_directory = "../dataset/visualizations/emd_rotated_fc"
    output_excel_file = "../dataset/excel/emd_rotated_fc.xlsx"

    # Normalize every stored SVG once; each query is then compared against this store.
    feature_store = build_feature_store_normalized_points(svg_store_directory)

    for svg_file in os.listdir(query_store_directory):
        if not svg_file.endswith('.svg'):
            continue

        # Rank the whole store against this query and save the top-5 plot.
        query_svg_path = os.path.join(query_store_directory, svg_file)
        query_points = parse_svg(query_svg_path)
        normalized_query = normalize_shape(query_points)
        ranked_similarities = rank_by_similarity_emd(normalized_query, feature_store)
        save_top_5_results(
            query_svg_path,
            ranked_similarities,
            svg_store_directory,
            result_visualizations_directory,
        )

    # Collect every query and its saved plot into a single workbook.
    save_results_to_excel(query_store_directory, result_visualizations_directory, output_excel_file)

Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_0.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_10.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_180.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_20.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_270.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_30.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_40.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_50.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_60.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\airbnb_scaled_rotated_90.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\apple_scaled_rotated_0.png
Saved plot to ../dataset/visualizations/emd_rotated_fc\apple_scaled_rotated_1