### Centroidal Distance and Angle Features

In [None]:
import os
import numpy as np
from svgpathtools import svg2paths, Path, Line, CubicBezier
from scipy.spatial.distance import euclidean
from scipy.integrate import quad
import matplotlib.pyplot as plt
import matplotlib.image as mpimg

import cairosvg
import io
from openpyxl import Workbook
from openpyxl.drawing.image import Image
from openpyxl.utils import get_column_letter

from PIL import Image as ImagePil
# from io import BytesIO


def join_svg_paths(svg_file):
    """Merge every path found in an SVG file into one ``Path`` object.

    Parameters:
        svg_file (str): Path to the SVG file handed to ``svg2paths``.

    Returns:
        Path: A single path holding the segments of all parsed paths,
        appended in the order ``svg2paths`` returned them.
    """
    parsed_paths, _unused = svg2paths(svg_file)
    merged = Path()
    for sub_path in parsed_paths:
        # extend() appends the segments of sub_path onto the merged path.
        merged.extend(sub_path)
    return merged

def parse_svg(svg_file, num_samples=200):
    """Sample ``num_samples`` points along the joined path of an SVG file.

    The positions are ``num_samples`` evenly spaced values between 0 and the
    total path length (both ends included); each is divided by the total
    length and passed to ``path.point``.

    Parameters:
        svg_file (str): Path to the SVG file.
        num_samples (int): Number of points to sample.

    Returns:
        numpy.ndarray: One ``(x, y)`` row per sample, taken from the real and
        imaginary parts of the complex point returned by ``path.point``.
    """
    path = join_svg_paths(svg_file)
    total_length = path.length()

    # Normalised positions in [0, 1], computed lazily one sample at a time.
    fractions = (
        distance / total_length
        for distance in np.linspace(0, total_length, num_samples)
    )
    return np.array(
        [(pt.real, pt.imag) for pt in map(path.point, fractions)]
    )

# Function to calculate centroid
def calculate_centroid(points):
    """Return the centroid of ``points`` as an ``(x_mean, y_mean)`` tuple.

    ``points`` is indexed as ``points[:, 0]`` (x) and ``points[:, 1]`` (y);
    the centroid is the arithmetic mean of each column.
    """
    return tuple(np.mean(points[:, column]) for column in (0, 1))

# Function to calculate features r and θ
def calculate_features(points, centroid):
    """Calculate the centroidal distance ``r`` and angle ``θ`` features.

    For every contour point:

    * ``r`` is its Euclidean distance to ``centroid``, scaled so that the
      largest distance is 1.
    * ``θ`` is the angle between the vector from the centroid to the point
      and the local tangent, capped at ``π/2``.  The tangent is the central
      difference ``points[i + 1] - points[i - 1]`` with wrap-around at both
      ends of the array.

    Parameters:
        points (numpy.ndarray): Array indexed as ``points[:, 0]`` (x) and
            ``points[:, 1]`` (y), ordered along the contour.
        centroid (tuple): ``(gx, gy)`` reference point.

    Returns:
        tuple: ``(r, theta)`` as two 1-D arrays with one entry per point.
        If every point coincides with the centroid, ``r`` is returned as all
        zeros instead of being divided by zero.
    """
    gx, gy = centroid
    dx = points[:, 0] - gx
    dy = points[:, 1] - gy
    r = np.sqrt(dx ** 2 + dy ** 2)

    # Central-difference tangent: next point minus previous point, wrapping
    # around so the first and last samples use each other as neighbours.
    tangent = np.roll(points, -1, axis=0) - np.roll(points, 1, axis=0)

    angle = np.abs(
        np.arctan2(dy, dx) - np.arctan2(tangent[:, 1], tangent[:, 0])
    )
    # The raw difference of two arctan2 values lies in [0, 2π]; a value above
    # π describes the same pair of directions measured the long way round
    # (e.g. 170° vs -170° is 20° apart, not 340°), so fold it back to [0, π].
    angle = np.where(angle > np.pi, 2 * np.pi - angle, angle)
    theta = np.minimum(angle, np.pi / 2)

    # Scale r to [0, 1]; skip the division for a degenerate contour whose
    # points all sit on the centroid, which would otherwise produce NaNs.
    max_r = np.max(r)
    if max_r > 0:
        r = r / max_r
    return r, theta

# Function to calculate similarity
def calculate_similarity(r1, theta1, r2, theta2):
    """Calculate the similarity between two objects from their features.

    The score is the sum of two terms:

    * ``r`` term: ``1 - euclidean(r1, r2) / sqrt(len(r1))``.
    * ``θ`` term: cosine similarity of ``theta1`` and ``theta2``.

    Parameters:
        r1, r2: Distance feature vectors of equal length.
        theta1, theta2: Angle feature vectors of equal length.

    Returns:
        float: ``r`` term plus ``θ`` term.  When either theta vector has zero
        norm the cosine similarity is undefined; the ``θ`` term is then taken
        as 0.0 rather than NaN.
    """
    r_similarity = 1 - euclidean(r1, r2) / np.sqrt(len(r1))

    norm_product = np.linalg.norm(theta1) * np.linalg.norm(theta2)
    if norm_product == 0:
        # An all-zero theta vector has no direction, so 0/0 would give NaN
        # and silently poison any later comparison of scores.
        theta_similarity = 0.0
    else:
        theta_similarity = np.dot(theta1, theta2) / norm_product

    return r_similarity + theta_similarity

# Function to handle clockwise/anticlockwise issue
def ensure_consistent_direction(points):
    """Return ``points`` in a consistent traversal direction.

    The shoelace sum ``Σ (x_i * y_{i+1} - x_{i+1} * y_i)`` is accumulated
    over consecutive points, with the last point paired with the first (the
    original docstring refers to this as Eq.(5)).  A negative sum means the
    contour runs the other way round, so the order of the points is reversed;
    otherwise the input is returned unchanged.
    """
    # successors[i] is points[(i + 1) % len(points)].
    successors = np.roll(points, -1, axis=0)

    signed_sum = 0
    for current, following in zip(points, successors):
        signed_sum += current[0] * following[1] - following[0] * current[1]

    if signed_sum < 0:
        return points[::-1]
    return points

# Function to calculate similarity with starting point alignment
def calculate_max_similarity(r1, theta1, r2, theta2):
    """Return the best similarity over all cyclic shifts of the second shape.

    The second feature pair is rolled by every offset from 0 to
    ``len(r1) - 1`` so that the result does not depend on where the contour
    sampling started.  The result is never below 0, which is also the value
    returned when ``r1`` is empty.
    """
    shifted_scores = [
        calculate_similarity(
            r1, theta1, np.roll(r2, offset), np.roll(theta2, offset)
        )
        for offset in range(len(r1))
    ]
    # The leading 0 is the floor of the search, as in a running maximum
    # that starts at 0.
    return max([0, *shifted_scores])

# Main Function
def svg_shape_similarity(query_svg, svg_store_path):
    """
    Rank every SVG in a directory by shape similarity to a query SVG.

    Each file is sampled into points, put into a consistent direction, and
    reduced to (r, θ) features around its centroid; the query features are
    then compared with each stored file using the shift-invariant maximum
    similarity.

    Parameters:
        query_svg (str): Path to the query SVG file.
        svg_store_path (str): Path to the directory containing SVG files.
            Only entries whose name ends with '.svg' are considered.

    Returns:
        list: ``(filename, similarity)`` tuples, highest similarity first.
    """
    def _features(svg):
        # Shared pipeline: sample -> orient -> centroid -> (r, theta).
        pts = ensure_consistent_direction(parse_svg(svg))
        return calculate_features(pts, calculate_centroid(pts))

    query_r, query_theta = _features(query_svg)

    scored = []
    for name in os.listdir(svg_store_path):
        if not name.endswith('.svg'):
            continue
        target_r, target_theta = _features(os.path.join(svg_store_path, name))
        score = calculate_max_similarity(
            query_r, query_theta, target_r, target_theta
        )
        scored.append((name, score))

    # Best match first.
    return sorted(scored, key=lambda entry: entry[1], reverse=True)




# Helper Functions for Visualization and Excel Sheet Creation

def convert_svg_to_png(svg_path):
    """
    Render an SVG file to an in-memory PNG and open it as a PIL image.

    Parameters:
        svg_path (str): Path to the SVG file.

    Returns:
        PIL.Image: The opened image, or None if rendering or opening failed
        (the error is printed, not raised).
    """
    try:
        buffer = io.BytesIO(cairosvg.svg2png(url=svg_path))
        image = ImagePil.open(buffer)
    except Exception as e:
        # Best effort: report the failure and let the caller handle None.
        print(f"Error converting {svg_path} to PNG: {e}")
        return None
    return image

def save_top_5_results(query_svg, results, svg_store_path, output_folder_param):
    """
    Save a one-row figure of the top 5 retrieved SVG images with their scores.

    The figure is written as ``<query file name without extension>.png``
    inside ``output_folder_param``, which is created if it does not exist.

    Parameters:
        query_svg (str): Path to the query SVG file.
        results (list): Ranked list of (SVG filename, similarity score).
        svg_store_path (str): Path to the directory containing SVG files.
        output_folder_param (str): Directory the PNG is saved into.
    """
    best_matches = results[:5]

    # The output file is named after the query file.
    query_name = os.path.splitext(os.path.basename(query_svg))[0]
    os.makedirs(output_folder_param, exist_ok=True)

    fig, axes = plt.subplots(1, 5, figsize=(15, 5))

    for ax, (svg, score) in zip(axes, best_matches):
        # normpath keeps the joined path in a consistent format.
        rendered = convert_svg_to_png(
            os.path.normpath(os.path.join(svg_store_path, svg))
        )
        if rendered:
            ax.imshow(rendered)
            ax.set_title(f"{svg}\nScore: {score:.4f}")
        else:
            ax.text(0.5, 0.5, "Error Loading", fontsize=12, ha='center', va='center')
        ax.axis("off")

    # Blank out panels left over when fewer than 5 results were returned.
    for spare_ax in axes[len(best_matches):]:
        spare_ax.axis("off")

    plt.tight_layout()
    save_path = os.path.join(output_folder_param, f"{query_name}.png")
    plt.savefig(save_path, bbox_inches='tight')
    plt.close()
    print(f"Saved plot to {save_path}")

def save_results_to_excel(svg_directory, visualization_dir, excel_output):
    """
    Write an Excel sheet with one row per SVG file in ``svg_directory``.

    Each row holds the filename, a rendered thumbnail of the SVG, the SVG
    XML text, and, when present, the pre-generated visualization PNG.  The
    visualization is looked up in ``visualization_dir`` as
    ``<file name without extension>.png``, the same naming used when the
    visualizations are saved.

    Parameters:
        svg_directory (str): Directory scanned for files ending in '.svg'.
        visualization_dir (str): Directory holding the visualization PNGs;
            created if it does not exist.
        excel_output (str): Path the workbook is saved to.
    """
    # Ensure output directory exists
    os.makedirs(visualization_dir, exist_ok=True)

    # Create a new workbook and sheet
    wb = Workbook()
    ws = wb.active
    ws.title = "Similarity Rankings"

    # Add headers to the Excel sheet
    ws.append(["Filename", "Image", "SVG XML Content", "Visualization Resulted Similarity"])

    # Column B holds the thumbnails; its width is the same for every row.
    ws.column_dimensions[get_column_letter(2)].width = 15

    # Row number to start inserting images and data
    row_number = 2  # Start at row 2 since row 1 is for headers

    # Iterate over all SVG files
    for svg_file in os.listdir(svg_directory):
        if not svg_file.endswith(".svg"):
            continue

        svg_path = os.path.join(svg_directory, svg_file)

        # Read as UTF-8 explicitly: the text is re-encoded as UTF-8 for the
        # renderer below, so decoding with a platform-dependent default
        # could corrupt non-ASCII content.
        with open(svg_path, 'r', encoding='utf-8') as file:
            svg_content = file.read()

        # Convert SVG to PNG
        png_data = cairosvg.svg2png(bytestring=svg_content.encode('utf-8'))
        img = Image(io.BytesIO(png_data))

        # Resize the image
        img.width = 80
        img.height = 80

        # Insert original SVG image
        image_cell = f"B{row_number}"
        ws.add_image(img, image_cell)
        ws.row_dimensions[row_number].height = 80

        # Load the pre-generated visualization PNG.  Only the final extension
        # is swapped, so a name with '.svg' elsewhere in it still maps to the
        # file that was saved for it.
        visualization_filename = os.path.splitext(svg_file)[0] + ".png"
        visualization_path = os.path.join(visualization_dir, visualization_filename)

        if os.path.exists(visualization_path):
            visualization_img = Image(visualization_path)
            visualization_img.width = 350
            visualization_img.height = 80
            visualization_cell = f"D{row_number}"
            ws.add_image(visualization_img, visualization_cell)
            visualization_status = f"Visualization in {visualization_cell}"
        else:
            visualization_status = "No visualization found"

        # Add the row to Excel
        ws.append([svg_file, f"Image in {image_cell}", svg_content, visualization_status])

        # Move to the next row
        row_number += 1

    # Save the Excel file
    wb.save(excel_output)

    print(f"Excel file saved to {excel_output}")





In [None]:

if __name__ == "__main__":
    # Input/output locations for the batch run.
    query_store_directory = "../dataset/Test_Dataset_simplified_2"
    svg_store_directory = "../dataset/train_Dataset_simplified_2"
    result_visualizations_directory = "../dataset/visualizations_test_dataset_simplified_2"
    output_excel_file = "similarity_rankings_test_dataset_simplified_svgs.xlsx"

    # Rank the store against every query SVG and save a top-5 figure for each.
    for svg_file in os.listdir(query_store_directory):
        if not svg_file.endswith('.svg'):
            continue
        query_svg_path = os.path.join(query_store_directory, svg_file)
        ranked_similarities = svg_shape_similarity(query_svg_path, svg_store_directory)
        save_top_5_results(
            query_svg_path,
            ranked_similarities,
            svg_store_directory,
            result_visualizations_directory,
        )

    # Collect the query SVGs and their saved figures into one workbook.
    save_results_to_excel(
        query_store_directory,
        result_visualizations_directory,
        output_excel_file,
    )