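This notebook is adapted from the following Kaggle notebook:  
https://www.kaggle.com/code/takusid/score-0-82-starter-notebook-use-flux1

For each prompt it generates a small image with FLUX.1-dev, vectorizes it to SVG with the ToSVG ComfyUI node, and strips any SVG elements or attributes not allowed by the `metric/svg-constraints` whitelist.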

In [1]:
#| default_exp core

In [2]:
#| export
import sys
import traceback
import torch
import numpy as np
from diffusers import FluxPipeline
import kagglehub
# Fetch the KComfyUI dataset and expose its "KComfyUI" subdirectory on sys.path,
# so that `custom_nodes.ToSVG.svgnode` (imported below) can be resolved.
# NOTE(review): "paht" is a typo for "path"; the name is kept because `#| export`
# publishes it as a module attribute.
kcomfyui_paht = kagglehub.dataset_download('takusid/kcomfyui') + "/KComfyUI"
sys.path.append(kcomfyui_paht)
from lxml import etree
from tqdm import tqdm

from custom_nodes.ToSVG.svgnode import ConvertRasterToVectorColor

class Model:
    """Text prompt -> FLUX.1-dev raster image -> SVG string.

    `predict` generates an image for the prompt, vectorizes it with the ToSVG
    ComfyUI node, filters it through the svg-constraints whitelist and returns
    the SVG markup. Any failure, or a result longer than `MAX_SVG_LENGTH`
    characters, falls back to `default_svg`.
    """

    # Appended to every prompt to steer FLUX towards icon/logo-like images.
    PROMPT_SUFFIX = ". Icon Style. Logo"
    # SVG strings longer than this many characters are replaced by `default_svg`.
    MAX_SVG_LENGTH = 10000

    def __init__(self):
        # Download FLUX.1-dev and load it in bfloat16 with a "balanced" device map,
        # capping memory per device (GPU 0, GPU 1 and CPU).
        flux_path = kagglehub.model_download('jwadow/flux.1-dev/other/default/5')
        self.pipe = FluxPipeline.from_pretrained(
            flux_path,
            torch_dtype=torch.bfloat16,
            device_map="balanced",
            max_memory={0: "15GiB", 1: "15GiB", "cpu": "20GiB"},
        )

        # Fallback SVG returned whenever generation, vectorization or validation fails.
        self.default_svg = '<svg width="100" height="100" viewBox="0 0 100 100"><circle cx="50" cy="50" r="40" fill="red" /></svg>'

        # Competition-provided whitelist. `allowed_elements` is used in
        # `enforce_constraints` as a mapping tag -> allowed attribute names,
        # with an extra 'common' entry for attributes allowed on every tag.
        svg_constraints = kagglehub.package_import('metric/svg-constraints')
        self.constraints = svg_constraints.SVGConstraints()

    def predict(self, prompt: str) -> str:
        """Generate an SVG for `prompt`.

        Args:
            prompt: Free-text description of the desired image.

        Returns:
            SVG markup filtered by `enforce_constraints`, or `self.default_svg`
            if any step raises an `Exception` or the filtered SVG is longer than
            `MAX_SVG_LENGTH` characters.
        """
        try:
            image = self._generate_image(prompt + self.PROMPT_SUFFIX)
            svg = self.enforce_constraints(self._vectorize(image))
            if len(svg) > self.MAX_SVG_LENGTH:
                svg = self.default_svg
        except Exception:
            # `Exception` rather than a bare `except:` so KeyboardInterrupt and
            # SystemExit still stop the run; ordinary failures degrade to the
            # fallback SVG, with the traceback printed for debugging.
            traceback.print_exc()
            svg = self.default_svg
        return svg

    def _generate_image(self, prompt: str):
        """Run the FLUX pipeline once for `prompt` and return its first image."""
        result = self.pipe(
            prompt,
            height=152,
            width=152,
            guidance_scale=3.5,
            num_inference_steps=7,
            max_sequence_length=256,
            # Fixed seed so repeated calls use the same sampling noise.
            generator=torch.Generator("cpu").manual_seed(0),
        )
        return result.images[0]

    def _vectorize(self, image) -> str:
        """Convert a generated image into an SVG string with the ToSVG node.

        `image` is presumably a PIL image (the pipeline's default output type) —
        TODO confirm. It is converted to a float tensor scaled by 1/255 with a
        leading batch dimension added.
        """
        pixels = torch.from_numpy(np.array(image) / 255).unsqueeze(0)
        # Tuning knobs passed straight through to the ToSVG node.
        # NOTE(review): these look like vtracer-style options (speckle filter,
        # colour quantization, curve fitting); confirm their semantics in the
        # node's documentation before retuning.
        svg_output = ConvertRasterToVectorColor().convert_to_svg(
            image=pixels,
            hierarchical="cutout",
            mode="polygon",
            filter_speckle=4,
            color_precision=8,
            layer_difference=64,
            corner_threshold=60,
            length_threshold=4,
            max_iterations=3,
            splice_threshold=180,
            path_precision=2,
        )
        return svg_output[0][0]

    def enforce_constraints(self, svg_string: str) -> str:
        """Remove SVG elements and attributes not whitelisted by `self.constraints`.

        Args:
            svg_string: SVG markup to sanitize.

        Returns:
            The re-serialized SVG with disallowed elements (with their subtrees)
            and disallowed attributes removed, or `self.default_svg` if the input
            cannot be parsed or the result cannot be serialized. A disallowed
            root element cannot be detached and is left in place.
        """
        try:
            # Comments and processing instructions are dropped at parse time.
            parser = etree.XMLParser(
                remove_blank_text=True, remove_comments=True, remove_pis=True
            )
            root = etree.fromstring(svg_string.encode('utf-8'), parser=parser)
        except etree.ParseError:
            return self.default_svg

        allowed_elements = self.constraints.allowed_elements
        common_attrs = allowed_elements['common']

        elements_to_remove = []
        # Only real elements: other node types (e.g. entities) have a factory
        # function as `.tag`, which `etree.QName` cannot parse.
        for element in root.iter(etree.Element):
            tag_name = etree.QName(element.tag).localname

            if tag_name not in allowed_elements:
                # Removal is deferred so the tree is not mutated mid-iteration.
                elements_to_remove.append(element)
                continue

            element_attrs = allowed_elements[tag_name]
            # Iterate over a snapshot of the keys because attributes are deleted in the loop.
            for attr in list(element.attrib):
                attr_name = etree.QName(attr).localname
                if attr_name not in element_attrs and attr_name not in common_attrs:
                    print(f'Attribute "{attr}" for element "{tag_name}" not allowed. Removing.')
                    del element.attrib[attr]

        for element in elements_to_remove:
            parent = element.getparent()
            if parent is not None:
                parent.remove(element)

        try:
            return etree.tostring(root, encoding='unicode')
        except ValueError:
            return self.default_svg

In [3]:
%%time
# Run the local kaggle_evaluation harness end to end. The Model *class* (not an
# instance) is passed in; the harness constructs it and runs inference tests.
# `%%time` reports the total cost, which includes downloading and loading FLUX.
import kaggle_evaluation

kaggle_evaluation.test(Model)

Creating Model instance...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers


Running inference tests...
Wrote test submission file to "/tmp/kaggle-evaluation-submission-ve44kx2q.csv".
Success!
CPU times: user 7.43 s, sys: 4.73 s, total: 12.2 s
Wall time: 2min
