In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.chdir("..")

In [3]:
from deepsvg.svglib.svg import SVG

from deepsvg import utils
from deepsvg.difflib.tensor import SVGTensor
from deepsvg.svglib.utils import to_gif
from deepsvg.svglib.geom import Bbox
from deepsvg.svgtensor_dataset import SVGTensorDataset, load_dataset
from deepsvg.utils.utils import batchify, linear

import torch
import numpy as np

# DeepSVG latent space operations

In [4]:
device = torch.device("cuda:0"if torch.cuda.is_available() else "cpu") 

Load the pretrained model and dataset

In [5]:
pretrained_path = "./pretrained/hierarchical_ordered.pth.tar"
from configs.deepsvg.hierarchical_ordered import Config

try:
    state_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))  # Use 'cuda' if you have GPU
    print("File loaded successfully.")
    print(state_dict.keys())  # Prints the keys to verify the content
except Exception as e:
    print(f"Error loading the file: {e}")

cfg = Config()
model = cfg.make_model().to(device)
utils.load_model(pretrained_path, model)
model.eval();

File loaded successfully.
dict_keys(['model'])


In [6]:
dataset = load_dataset(cfg)

In [7]:
def load_svg(filename):
    svg = SVG.load_svg(filename)
    svg.canonicalize()
    svg.normalize()
    svg.zoom(0.9)
    svg = svg.simplify_heuristic()
    svg =svg.numericalize(256)
    return svg

In [8]:
def encode(data):
    model_args = batchify((data[key] for key in cfg.model_args), device)
    with torch.no_grad():
        z = model(*model_args, encode_mode=True)
        return z
    
def encode_svg(svg):
    data = dataset.get(svg=svg)
    return encode(data)


In [9]:
# Flask and CORS
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS

# Standard libraries
import os
import uuid
import base64
import io

# Image processing
from PIL import Image as ImagePil
from IPython.display import display
import cairosvg

# Math & Data
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from scipy.spatial import distance_matrix
from scipy.spatial.distance import cosine, euclidean, directed_hausdorff
from scipy.optimize import linear_sum_assignment
from scipy.spatial import procrustes

# SVG tools
from svgpathtools import svg2paths, Path


## DeepSVG- Lookup

Milvus setup

In [10]:
import os
from pymilvus import connections
from pymilvus import FieldSchema, DataType, CollectionSchema, Collection

# Connect to Zilliz Cloud
ENDPOINT = "https://in03-754f3454a65e40f.serverless.gcp-us-west1.cloud.zilliz.com"
TOKEN = "2b830a69fb087e580f904877ff816ff1477e67a38c091fc6b8c9c75d3992a458cc2deb681d3ae18dd91900855e1c538013080bf5"
connections.connect(uri=ENDPOINT, token=TOKEN)
# print("Connected to Zilliz Cloud!")

# Load existing collection
collection_name = "fyp_project"
collection = Collection(name=collection_name)

import torch  # Assuming PyTorch for tensor operations


Mongodb Setup

In [11]:
import pymongo
# MongoDB setup
try:
    mongo_client = pymongo.MongoClient("mongodb://localhost:27017/", serverSelectionTimeoutMS=3000)
    mongo_client.server_info()  # Force connection check
    mongo_db = mongo_client["logoDB"]
    mongo_collection = mongo_db["logos"]
    print("✅ MongoDB connection established successfully.")
except pymongo.errors.ServerSelectionTimeoutError as err:
    print(f"❌ Failed to connect to MongoDB: {err}")


✅ MongoDB connection established successfully.


Milvus schema setup

In [None]:
# from pymilvus import MilvusClient, DataType
# import numpy as np

# client = MilvusClient(uri="https://in03-754f3454a65e40f.serverless.gcp-us-west1.cloud.zilliz.com", token="2b830a69fb087e580f904877ff816ff1477e67a38c091fc6b8c9c75d3992a458cc2deb681d3ae18dd91900855e1c538013080bf5")

# schema = client.create_schema(enable_dynamic_field=True, description="")
# schema.add_field(field_name="Auto_id", datatype=DataType.INT64, description="The Primary Key", is_primary=True, auto_id=True)
# schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=256)

# index_params = client.prepare_index_params()
# index_params.add_index(field_name="vector", metric_type="COSINE", index_type="AUTOINDEX")


# client.create_collection(collection_name="fyp_project", schema=schema, index_params=index_params)


## Register Logo- Print the data passed from the front end

In [12]:
app = Flask(__name__)
CORS(app)

DATASET_DIR = "./dataset/Dataset_simplified"
PNG_OUTPUT_DIR = "./dataset/rendered_pngs"
TEMP_DIR = "./temp"
os.makedirs(DATASET_DIR, exist_ok=True)
os.makedirs(PNG_OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)

# MongoDB setup
mongo_client = pymongo.MongoClient("mongodb://localhost:27017/")
mongo_db = mongo_client["logoDB"]
mongo_collection = mongo_db["logos"]

def load_and_encode(svg_path):
    try:
        svg = load_svg(svg_path)
        vector = encode_svg(svg)
        embedding_array = vector.flatten().numpy()
        if embedding_array.shape[0] != 256:
            raise ValueError(f"Embedding dimension is {embedding_array.shape[0]}, expected 256")
        return embedding_array.tolist()
    except Exception as e:
        print(f"Encoding failed for {svg_path}: {e}")
        return None

@app.route('/api/register-logo', methods=['POST'])
def register_logo():
    if 'logo' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400

    file = request.files['logo']
    if not file or file.filename == '' or not file.filename.endswith('.svg'):
        return jsonify({'error': 'Invalid or missing SVG file'}), 400

    # Unique ID and paths
    logo_id = str(uuid.uuid4())
    svg_filename = f"{logo_id}.svg"
    png_filename = f"{logo_id}.png"
    svg_path = os.path.join(DATASET_DIR, svg_filename)
    png_path = os.path.join(PNG_OUTPUT_DIR, png_filename)

    # Save SVG
    file.save(svg_path)

    # Encode SVG
    embedding = load_and_encode(svg_path)
    if embedding is None:
        return jsonify({'error': 'Failed to encode SVG'}), 500

    # Insert into Milvus
    mr = collection.insert([ [embedding] ])
    print(mr)
    milvus_id = mr.primary_keys[0]

    # Convert to PNG
    # cairosvg.svg2png(file_obj=open(svg_path, "rb"), write_to=png_path)

    # Read SVG file content
    with open(svg_path, 'r', encoding='utf-8') as svg_file:
        svg_content = svg_file.read()

    # Save metadata and SVG content in MongoDB
    mongo_record = {
        "logo_id": logo_id,
        "svg_content": svg_content,
        "milvus_id": milvus_id,
        "file_name": file.filename
    }

    mongo_collection.insert_one(mongo_record)

    return jsonify({
        "message": "Logo registered successfully",
        "logo_id": logo_id,
        "milvus_id": milvus_id
    })

# Serve static PNGs
@app.route('/static/<filename>')
def serve_png(filename):
    return send_from_directory(PNG_OUTPUT_DIR, filename)

app.run(port=5000, debug=True, use_reloader=False)


 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [21/Apr/2025 01:25:55] "POST /api/register-logo HTTP/1.1" 200 -


(insert count: 1, delete count: 0, upsert count: 0, timestamp: 457488192041385987, success count: 1, err count: 0)


127.0.0.1 - - [21/Apr/2025 01:26:10] "POST /api/lookup-logo HTTP/1.1" 404 -
127.0.0.1 - - [21/Apr/2025 01:26:19] "POST /api/lookup-logo HTTP/1.1" 404 -


In [12]:
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
import os, uuid
import cairosvg
import pymongo
from pymilvus import Collection, connections

app = Flask(__name__)
CORS(app)

PNG_OUTPUT_DIR = "./dataset/rendered_pngs"
TEMP_DIR = "./temp"
os.makedirs(PNG_OUTPUT_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)


def load_and_encode(svg_path):
    try:
        svg = load_svg(svg_path)
        vector = encode_svg(svg)
        embedding_array = vector.flatten().numpy()
        if embedding_array.shape[0] != 256:
            raise ValueError(f"Embedding dimension is {embedding_array.shape[0]}, expected 256")
        return embedding_array.tolist()
    except Exception as e:
        print(f"Encoding failed for {svg_path}: {e}")
        return None

@app.route('/api/lookup-logo', methods=['POST'])
def lookup_logo():
    if 'logo' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400

    file = request.files['logo']
    print("delete1")
    if not file or file.filename == '' or not file.filename.endswith('.svg'):
        return jsonify({'error': 'Invalid file'}), 400

    # Save uploaded file temporarily
    temp_id = str(uuid.uuid4())
    temp_path = f"./temp/{temp_id}.svg"
    file.save(temp_path)

    target_vector = load_and_encode(temp_path)
    print("delete11")
    os.remove(temp_path)

    if target_vector is None:
        return jsonify({'error': 'SVG encoding failed'}), 500

    # Milvus search
    print("delete111")
    results = collection.search(
        data=[target_vector],
        anns_field="vector",
        param={"metric_type": "COSINE"},
        limit=5,
        output_fields=["milvus_id"]
    )

    print("Search results (raw):", results)

    matches = []
    for hit in results[0]:
        try:
            milvus_id = hit.id  # or hit.entity.get("milvus_id") depending on how it's stored
            mongo_doc = mongo_collection.find_one({"milvus_id": milvus_id})

            if not mongo_doc:
                print(f"⚠️ No MongoDB entry for milvus_id {milvus_id}")
                continue

            # Extract SVG content
            svg_content = mongo_doc.get("svg_content")
            if not svg_content:
                continue

            # Generate unique PNG filename
            png_name = f"{milvus_id}.png"
            png_path = os.path.join(PNG_OUTPUT_DIR, png_name)

            # Save converted PNG if not exists
            if not os.path.exists(png_path):
                with open(png_path, "wb") as f:
                    cairosvg.svg2png(bytestring=svg_content.encode('utf-8'), write_to=f)

            matches.append({
                "logoUrl": f"http://localhost:5000/static/{png_name}",
                "companyUrl": f"https://example.com/brand/{mongo_doc.get('logo_id')}",
                "score": round(1 - hit.distance, 4)
            })

        except Exception as e:
            print(f"Error processing result: {e}")

    return jsonify({'matches': matches})


@app.route('/static/<filename>')
def serve_png(filename):
    return send_from_directory(PNG_OUTPUT_DIR, filename)


app.run(port=5000, debug=True, use_reloader=False)


 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit


delete1
delete11
delete111


127.0.0.1 - - [21/Apr/2025 01:30:24] "POST /api/lookup-logo HTTP/1.1" 200 -


Search results (raw): ["['id: 456575839311535064, distance: 0.5963233709335327, entity: {}', 'id: 456575839304388171, distance: 0.5925104022026062, entity: {}', 'id: 456575839309966379, distance: 0.4008810222148895, entity: {}', 'id: 456575839311510024, distance: 0.3864792585372925, entity: {}', 'id: 456575839309966488, distance: 0.168972447514534, entity: {}']"]


127.0.0.1 - - [21/Apr/2025 01:30:25] "GET /static/456575839311535064.png HTTP/1.1" 200 -
127.0.0.1 - - [21/Apr/2025 01:30:25] "GET /static/456575839304388171.png HTTP/1.1" 304 -
127.0.0.1 - - [21/Apr/2025 01:30:25] "GET /static/456575839309966379.png HTTP/1.1" 304 -
127.0.0.1 - - [21/Apr/2025 01:30:25] "GET /static/456575839311510024.png HTTP/1.1" 304 -
127.0.0.1 - - [21/Apr/2025 01:30:25] "GET /static/456575839309966488.png HTTP/1.1" 304 -


In [None]:
# app = Flask(__name__)
# CORS(app) 

# DATASET_DIR = "./dataset/Registered_Dataset_simplified"
# PNG_OUTPUT_DIR = "./dataset/rendered_pngs"

# def load_and_encode(svg_path):
#     try:
#         svg = load_svg(svg_path)
#         vector = encode_svg(svg)
#         embedding_array = vector.flatten().numpy()  # Convert to NumPy array
#         if embedding_array.shape[0] != 256:
#             raise ValueError(f"Embedding dimension is {embedding_array.shape[0]}, expected 256")
#         return embedding_array.tolist()  # Convert to Python list
#     except Exception as e:
#         print(f"Encoding failed for {svg_path}: {e}")
#         return None

# @app.route('/api/lookup-logo', methods=['POST'])
# def lookup_logo():
#     if 'logo' not in request.files:
#         return jsonify({'error': 'No file uploaded'}), 400

#     file = request.files['logo']
#     if not file or file.filename == '':
#         return jsonify({'error': 'Invalid file'}), 400

#     if not file.filename.endswith('.svg'):
#         return jsonify({'error': 'Only SVG files are allowed'}), 400

#     # Save uploaded SVG temporarily
#     temp_id = str(uuid.uuid4())
#     temp_path = f"./temp/{temp_id}.svg"
#     os.makedirs("./temp", exist_ok=True)
#     file.save(temp_path)

#     target_vector = load_and_encode(temp_path)
#     if target_vector is None:
#         return jsonify({'error': 'SVG encoding failed'}), 500

#      # Search for visually similar images
#     results = collection.search(
#         data=[target_vector],
#         anns_field="vector",  # field name for the vector
#         param={"metric_type": "COSINE"},  # Specify the metric type
#         limit=5, 
#         output_fields=["image_path"],  # Return image paths
#     )
#     print("Search results (raw):", results)

    
#     matches = []
#     for hit in results[0]:  # results[0] is the list of hits for our one query
#         try:
#             image_path = hit.entity.get('image_path')
#             logo_filename = os.path.basename(image_path)
#             png_name = logo_filename.replace('.svg', '.png')
#             png_path = os.path.join(PNG_OUTPUT_DIR, png_name)

#             # Convert SVG to PNG if it doesn't exist
#             if not os.path.exists(png_path):
#                 os.makedirs(PNG_OUTPUT_DIR, exist_ok=True)
#                 cairosvg.svg2png(file_obj=open(image_path, "rb"), write_to=png_path)

#             match = {
#                 "logoUrl": f"http://localhost:5000/static/{png_name}",
#                 "companyUrl": f"https://example.com/brand/{logo_filename.replace('.svg', '')}",
#                 "score": round(1 - hit.distance, 4)  # convert cosine distance to similarity
#             }
#             matches.append(match)
#             print("Parsed match:", match)
#         except Exception as e:
#             print(f"Error processing hit: {hit}, error: {e}")


#     os.remove(temp_path)  # Clean up
#     print(matches)
#     return jsonify({'matches': matches})

# # Serve static images
# @app.route('/static/<filename>')
# def serve_png(filename):
#     return send_from_directory(PNG_OUTPUT_DIR, filename)

# app.run(port=5000, debug=True, use_reloader=False)