# 💫 StarVector: Generating SVG Code from Images and Text

<div align="center">
  <img src="https://raw.githubusercontent.com/joanrod/star-vector/main/assets/starvector-xyz.png" alt="starvector" style="width: 600px; display: block; margin-left: auto; margin-right: auto;"/>
</div>

This notebook demonstrates [StarVector](https://github.com/joanrod/star-vector), a multimodal vision-language model for Scalable Vector Graphics (SVG) generation. It supports two tasks: image-to-SVG (vectorization) and text-to-SVG generation.

## Key Features
- **Image to SVG**: Convert raster images to SVG code
- **Text to SVG**: Generate SVG code from text descriptions
- **Semantic Understanding**: Understands image semantics to produce accurate SVG primitives

[![Paper](https://img.shields.io/badge/arXiv-StarVector-red?logo=arxiv)](https://arxiv.org/abs/2312.11556)
[![Website](https://img.shields.io/badge/🌎_Website-starvector.github.io-blue.svg)](https://starvector.github.io/)
[![HF Model-1B](https://img.shields.io/badge/%F0%9F%A4%97%20_Model-StarVector--1B-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/starvector/starvector-1b-im2svg)
[![HF Model-8B](https://img.shields.io/badge/%F0%9F%A4%97%20_Model-StarVector--8B-ffc107?color=ffc107&logoColor=white)](https://huggingface.co/starvector/starvector-8b-im2svg)

## 1. Setup Environment

First, we'll install the necessary packages. This includes cloning the StarVector repository and installing its dependencies.

In [None]:
# Check if running in Google Colab
import sys
IN_COLAB = 'google.colab' in sys.modules
print(f"Running in Google Colab: {IN_COLAB}")

# Install dependencies
!pip install torch==2.5.1 torchvision==0.20.1 transformers==4.49.0 tokenizers==0.21.1 sentencepiece==0.2.0 accelerate pydantic==2.10 markdown2 numpy scikit-learn==1.2.2 gradio==3.36.1 gradio_client==0.2.9 requests httpx==0.24.0 uvicorn fastapi svgpathtools==1.6.1 seaborn==0.12.2 cairosvg beautifulsoup4 webcolors tqdm omegaconf open-clip-torch noise datasets scikit-image lxml Pillow protobuf

In [None]:
# Clone StarVector repository if in Colab
if IN_COLAB:
    !git clone https://github.com/joanrod/star-vector.git
    %cd star-vector
    !pip install -e .
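
Before loading the model, it can help to confirm that the key packages were actually installed. The sketch below uses only the standard library; `missing_packages` is a hypothetical helper name, and the package list is just a subset of the install command above.

```python
from importlib import metadata

def missing_packages(names):
    """Return the distribution names from `names` that are not installed."""
    missing = []
    for name in names:
        try:
            metadata.version(name)
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing

# A few of the packages listed in the install cell above.
required = ["torch", "transformers", "svgpathtools", "cairosvg"]
print("Missing packages:", missing_packages(required) or "none")
```

If the list is not empty, re-running the install cell (or restarting the runtime after installation) is usually the first thing to try.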

## 2. Load StarVector Models

StarVector comes in two model sizes:
- **StarVector-1B**: Smaller model, faster inference
- **StarVector-8B**: Larger model, higher quality results

You can choose the model size based on your needs.
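
As a rough guide to which size fits your hardware, the weights alone need about (number of parameters) × (bytes per parameter) of memory. The sketch below is back-of-envelope arithmetic only: it assumes the nominal 1B and 8B parameter counts suggested by the model names, ignores activations and generation caches, and `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(num_params, bytes_per_param):
    """Approximate memory needed to hold the model weights, in GiB."""
    return num_params * bytes_per_param / 2**30

# Nominal parameter counts taken from the model names (assumption).
models = {"StarVector-1B": 1e9, "StarVector-8B": 8e9}

for name, params in models.items():
    fp16 = weight_memory_gib(params, 2)  # float16: 2 bytes per parameter
    fp32 = weight_memory_gib(params, 4)  # float32: 4 bytes per parameter
    print(f"{name}: ~{fp16:.1f} GiB in float16, ~{fp32:.1f} GiB in float32")
```

Actual usage during generation is higher than these figures, so treat them as a lower bound.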

In [None]:
import torch
from PIL import Image
from IPython.display import SVG, display, HTML
import matplotlib.pyplot as plt
import os
import time

# Default: the 1B model. To load the 8B model instead, comment out the
# first line below and uncomment the second.
model_name = "starvector/starvector-1b-im2svg"
# model_name = "starvector/starvector-8b-im2svg"

print(f"Loading {model_name}...")
start_time = time.time()

# Use half precision on GPU to reduce memory use; keep float32 on CPU,
# where half precision is not needed if enough RAM is available
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
print(f"Using device: {device} with dtype: {dtype}")

# Load StarVector
from starvector.model.starvector_arch import StarVectorForCausalLM
from starvector.data.util import process_and_rasterize_svg

starvector = StarVectorForCausalLM.from_pretrained(model_name, torch_dtype=dtype)
starvector.to(device)
starvector.eval()

print(f"Model loaded in {time.time() - start_time:.2f} seconds")

## 3. Image-to-SVG (Vectorization)

Vectorization is the process of converting raster images (like PNG, JPEG) to SVG format. Let's start with a sample image from the StarVector repository.
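
To make the target format concrete: an SVG file is an XML document whose elements are drawing primitives. The hand-written example below (not model output) describes a red circle on a white square and uses the standard library to list its primitives.

```python
import xml.etree.ElementTree as ET

# A hand-written SVG, for illustration only (not produced by StarVector).
svg_text = """<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100">
  <rect x="0" y="0" width="100" height="100" fill="white"/>
  <circle cx="50" cy="50" r="30" fill="red"/>
</svg>"""

root = ET.fromstring(svg_text)

# ElementTree prefixes tags with the namespace, e.g. "{http://...}circle".
primitives = [child.tag.split("}")[-1] for child in root]
print(primitives)
```

A raster image of the same drawing would store thousands of pixel values; the SVG stores two shapes and their attributes, which is what the model is asked to predict.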

In [None]:
# Load a sample image from the local repository, downloading it from GitHub as a fallback
def get_sample_image(sample_num=0):
    filename = f"assets/examples/sample-{sample_num}.png"

    # Look in the current directory (inside the repo) and in a cloned copy below it
    for local_path in (filename, os.path.join("star-vector", filename)):
        if os.path.exists(local_path):
            return Image.open(local_path)

    # Otherwise download the sample image from GitHub
    import requests
    from io import BytesIO
    url = f"https://raw.githubusercontent.com/joanrod/star-vector/main/{filename}"
    try:
        response = requests.get(url, timeout=30)
        response.raise_for_status()  # fail clearly on a missing sample number
    except requests.RequestException as e:
        print(f"Could not download sample image: {e}")
        return None
    return Image.open(BytesIO(response.content))

# Load sample image
sample_num = 18 # Try different numbers (0, 1, 4, 6, 7, 15, 16, 17, 18)
image_pil = get_sample_image(sample_num)
if image_pil is None:
    print("Sample image not found. Please upload your own image.")
else:
    # Display the image
    plt.figure(figsize=(5, 5))
    plt.imshow(image_pil)
    plt.axis('off')
    plt.title("Input Image")
    plt.show()

In [None]:
# Upload your own image (alternative to using sample images)
if IN_COLAB:
    from google.colab import files
    print("Upload an image for vectorization:")
    uploaded = files.upload()
    if uploaded:
        image_name = list(uploaded.keys())[0]
        image_pil = Image.open(image_name).convert('RGB')
        
        # Display the uploaded image
        plt.figure(figsize=(5, 5))
        plt.imshow(image_pil)
        plt.axis('off')
        plt.title("Uploaded Image")
        plt.show()

In [None]:
# Process the image and generate SVG
if image_pil is not None:
    # Ensure the image is RGB
    image_pil = image_pil.convert('RGB')
    
    # Process the image
    print("Processing image...")
    image = starvector.process_images([image_pil])[0].to(device)
    batch = {"image": image}
    
    # Generate SVG
    print("Generating SVG code...")
    start_time = time.time()
    raw_svg = starvector.generate_im2svg(
        batch, 
        max_length=4000, 
        temperature=1.0,          # Controls randomness: lower is more deterministic
        length_penalty=-1,        # Beam-search length exponent in Hugging Face generation; negative values favor shorter outputs
        repetition_penalty=3.0    # Values above 1 discourage repeating tokens
    )[0]
    print(f"SVG generated in {time.time() - start_time:.2f} seconds")
    
    # Process and rasterize SVG
    svg, raster_image = process_and_rasterize_svg(raw_svg)
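
The sampling parameters above are easier to reason about with a small numeric example. The sketch below is a simplified pure-Python illustration, not StarVector's code. It assumes the parameters behave as in Hugging Face Transformers generation, where temperature divides the logits before the softmax and the repetition penalty divides positive logits (and multiplies negative ones) of tokens that were already generated. The logits are made up.

```python
import math

def softmax(logits, temperature=1.0):
    """Convert logits to probabilities after scaling by 1/temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def apply_repetition_penalty(logits, seen_token_ids, penalty):
    """Make already-generated tokens less likely (for penalty > 1)."""
    out = list(logits)
    for i in set(seen_token_ids):
        out[i] = out[i] / penalty if out[i] > 0 else out[i] * penalty
    return out

logits = [2.0, 1.0, -1.0]  # made-up scores for a 3-token vocabulary

for t in (1.0, 0.5):
    probs = softmax(logits, temperature=t)
    print(f"temperature={t}:", [round(p, 3) for p in probs])

penalized = apply_repetition_penalty(logits, seen_token_ids=[0, 2], penalty=3.0)
print("after repetition penalty:", [round(x, 3) for x in penalized])
```

Lowering the temperature concentrates probability on the top-scoring token, while the penalty pushes the scores of tokens 0 and 2 down and leaves token 1 untouched.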

In [None]:
# Display results: Input Image, SVG, and Rasterized SVG
if 'svg' in locals():
    # Create a 1x3 subplot for comparison
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    # Display original image
    axes[0].imshow(image_pil)
    axes[0].set_title("Original Image")
    axes[0].axis('off')
    
    # Display generated SVG (rasterized for matplotlib)
    axes[1].imshow(raster_image)
    axes[1].set_title("Generated SVG (Rasterized)")
    axes[1].axis('off')
    
    # Placeholder panel: the SVG itself is rendered below the figure
    axes[2].text(0.5, 0.5, "SVG rendered below the plot\n(see the output after this figure)", 
                ha='center', va='center', fontsize=12)
    axes[2].set_title("SVG Code")
    axes[2].axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # Display the SVG code and interactive SVG
    print("SVG Code Preview (first 500 characters):")
    print(svg[:500] + ("..." if len(svg) > 500 else ""))
    print(f"\nTotal SVG length: {len(svg)} characters")
    
    # Display interactive SVG
    display(HTML("<h3>Interactive SVG Result:</h3>"))
    display(SVG(svg))

In [None]:
# Save the SVG file
if 'svg' in locals() and IN_COLAB:
    # Save SVG to a file
    output_filename = "starvector_output.svg"
    with open(output_filename, "w", encoding="utf-8") as f:
        f.write(svg)
    print(f"SVG saved to {output_filename}")
    
    # Download the SVG file
    from google.colab import files
    files.download(output_filename)
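
Generated code can be cut off when it reaches `max_length`, so it may be worth checking that the SVG is well-formed XML before saving or sharing it. A minimal sketch using only the standard library; `is_well_formed_svg` is a hypothetical helper, and it checks XML syntax and the root tag only, not whether the drawing looks right.

```python
import xml.etree.ElementTree as ET

def is_well_formed_svg(svg_text):
    """Return True if svg_text parses as XML with an <svg> root element."""
    try:
        root = ET.fromstring(svg_text)
    except ET.ParseError:
        return False
    # Drop a namespace prefix such as "{http://www.w3.org/2000/svg}".
    return root.tag.split("}")[-1] == "svg"

complete = '<svg xmlns="http://www.w3.org/2000/svg"><circle r="5"/></svg>'
truncated = '<svg xmlns="http://www.w3.org/2000/svg"><path d="M0 0 L10'

print(is_well_formed_svg(complete))
print(is_well_formed_svg(truncated))
```

In the notebook, the same check could be applied to the `svg` variable before writing it to disk.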

## 4. Text-to-SVG Generation (Experimental)

Generate SVG code from text descriptions. Note that this feature is still experimental.

In [None]:
# Try loading a text2svg model if available
try:
    from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
    
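    # A dedicated text2svg checkpoint might be different or not yet publicly available,
    # so this reuses the 8B im2svg checkpoint for now. Note that it needs far more
    # memory than the 1B model loaded earlier.
    text2svg_model_name = "starvector/starvector-8b-im2svg"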
    
    # Use the HuggingFace interface for loading
    starvector_text2svg = AutoModelForCausalLM.from_pretrained(
        text2svg_model_name, 
        torch_dtype=dtype,
        trust_remote_code=True
    )
    processor = starvector_text2svg.model.processor
    tokenizer = starvector_text2svg.model.svg_transformer.tokenizer
    
    starvector_text2svg.to(device)
    starvector_text2svg.eval()
    
    text2svg_available = True
    print(f"Text2SVG model loaded: {text2svg_model_name}")
except Exception as e:
    text2svg_available = False
    print(f"Text2SVG model could not be loaded: {e}")
    print("Text2SVG functionality may not be available yet in the public models.")

In [None]:
# Text-to-SVG generation (if available)
if text2svg_available:
    # Example prompts for text2svg
    text_prompts = [
        "A simple red circle",
        "A blue square inside a yellow circle",
        "An icon of a house with a chimney",
        "A line chart showing increasing trend",
        "A minimalist logo with three horizontal lines"  
    ]
    
    # Let user choose or enter custom prompt
    print("Choose a sample prompt or enter your own:")
    for i, prompt in enumerate(text_prompts):
        print(f"{i+1}. {prompt}")
    
    if IN_COLAB:
        # Custom prompt input
        custom_prompt = input("Enter a number 1-5 to use a sample prompt, or enter your own text prompt: ")
        
        try:
            prompt_idx = int(custom_prompt) - 1
            if 0 <= prompt_idx < len(text_prompts):
                text_prompt = text_prompts[prompt_idx]
            else:
                text_prompt = custom_prompt
        except ValueError:
            text_prompt = custom_prompt
        
        print(f"\nGenerating SVG for: \"{text_prompt}\"")
        
        try:
            # Generate SVG from text
            # Note: This is experimental and might need adjustments based on the actual API
            raw_svg = starvector_text2svg.generate_text2svg(
                text_prompt, 
                max_length=4000,
                temperature=1.0,
                length_penalty=-1,
                repetition_penalty=3.0
            )[0]
            
            # Process and display SVG
            svg, raster_image = process_and_rasterize_svg(raw_svg)
            
            plt.figure(figsize=(5, 5))
            plt.imshow(raster_image)
            plt.title(f"Generated SVG for: {text_prompt}")
            plt.axis('off')
            plt.show()
            
            # Display interactive SVG
            display(HTML("<h3>Interactive SVG Result:</h3>"))
            display(SVG(svg))
            
            # Save and download SVG
            from google.colab import files  # this branch only runs in Colab
            output_filename = "starvector_text2svg_output.svg"
            with open(output_filename, "w") as f:
                f.write(svg)
            print(f"SVG saved to {output_filename}")
            files.download(output_filename)
            
        except Exception as e:
            print(f"Error generating SVG from text: {e}")
            print("The text2svg functionality may still be experimental.")
else:
    print("Text2SVG functionality is not available with the current model.")
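
The prompt-selection logic in the cell above (a number picks a sample prompt, anything else is treated as a custom prompt) can be factored into a small function that is easy to test outside Colab. `resolve_prompt` is a hypothetical name, not part of StarVector.

```python
def resolve_prompt(user_input, sample_prompts):
    """Map "1".."N" to a sample prompt; treat any other input as a custom prompt."""
    text = user_input.strip()
    try:
        index = int(text) - 1
    except ValueError:
        return text
    if 0 <= index < len(sample_prompts):
        return sample_prompts[index]
    return text  # out-of-range numbers are used verbatim, as in the cell above

samples = ["A simple red circle", "A blue square inside a yellow circle"]
print(resolve_prompt("2", samples))
print(resolve_prompt("A green triangle", samples))
```

Keeping this logic separate from `input()` also makes it usable when the notebook runs outside Colab.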

## 5. Create a Simple Gradio Interface (Optional)

A simple web interface to try the model interactively.

In [None]:
# Create a Gradio interface if in Colab
if IN_COLAB:
    try:
        import gradio as gr
        from io import BytesIO
        import base64
        
        def image_to_svg(input_image):
            if input_image is None:
                return None, None, "Please upload an image"
                
            # Process the image
            input_image = input_image.convert('RGB')
            image = starvector.process_images([input_image])[0].to(device)
            batch = {"image": image}
            
            # Generate SVG
            raw_svg = starvector.generate_im2svg(
                batch, 
                max_length=4000, 
                temperature=1.0,
                length_penalty=-1,
                repetition_penalty=3.0
            )[0]
            
            # Process and rasterize SVG
            svg, raster_image = process_and_rasterize_svg(raw_svg)
            
            # Return the rasterized preview, the SVG source, and a status message
            return raster_image, svg, "SVG generated successfully"
        
        # Create the interface
        with gr.Blocks() as demo:
            gr.Markdown("# 💫 StarVector: Image to SVG Converter")
            
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Upload Image", type="pil")
                    submit_btn = gr.Button("Convert to SVG")
                    
                with gr.Column():
                    output_image = gr.Image(label="Rasterized SVG")
                    # gr.Code has no "svg" language option; "html" highlights the markup
                    output_svg = gr.Code(language="html", label="SVG Code")
                    output_msg = gr.Textbox(label="Status")
            
            submit_btn.click(
                fn=image_to_svg,
                inputs=[input_image],
                outputs=[output_image, output_svg, output_msg]
            )
            
            gr.Markdown(
                """### Notes:
                - StarVector works best with logos, icons, diagrams, and clean graphics
                - Not designed for photographs or complex illustrations
                - The model converts the image to vector graphics (SVG code)
                """
            )
        
        # Launch the interface
        demo.launch(share=True, debug=False)
        
    except Exception as e:
        print(f"Error creating Gradio interface: {e}")
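
If you want to show the vector output itself rather than its rasterized version, one common approach is to embed the SVG in an `<img>` tag as a base64 data URI, which could then be passed to an HTML component. A minimal sketch; `svg_to_img_tag` is a hypothetical helper and is not used by the interface above.

```python
import base64

def svg_to_img_tag(svg_text, width=300):
    """Wrap SVG markup in an <img> tag using a base64 data URI."""
    encoded = base64.b64encode(svg_text.encode("utf-8")).decode("ascii")
    return f'<img src="data:image/svg+xml;base64,{encoded}" width="{width}"/>'

svg_text = (
    '<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 10 10">'
    '<circle cx="5" cy="5" r="4" fill="red"/></svg>'
)
html = svg_to_img_tag(svg_text)
print(html[:60] + "...")
```

Loading the SVG through `<img>` rather than inlining the markup also means that any script embedded in the generated SVG is not executed by the browser.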

## 6. Clean Up Resources

Free up GPU memory when done.

In [None]:
# Clean up resources
import gc

if device == "cuda":
    if 'starvector' in locals():
        del starvector
    if 'starvector_text2svg' in locals():
        del starvector_text2svg
    gc.collect()  # collect unreferenced Python objects before clearing the CUDA cache
    torch.cuda.empty_cache()
    print("GPU memory cleared")
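
`torch.cuda.empty_cache()` can only release cached memory that no tensor is using, which is why the cell deletes the model variables first. The standard-library sketch below illustrates the underlying idea with a stand-in object: `FakeModel` is hypothetical, and a weak reference is used to observe when the object is actually freed.

```python
import gc
import weakref

class FakeModel:
    """Stand-in for a large model object (illustration only)."""
    def __init__(self):
        self.weights = [0.0] * 1_000_000

model = FakeModel()
probe = weakref.ref(model)  # lets us observe the object without keeping it alive

print("alive before del:", probe() is not None)

del model      # drop the last strong reference
gc.collect()   # also collect anything kept alive only by reference cycles

print("alive after del:", probe() is not None)
```

If another variable still points to the model (for example a second name assigned in an earlier cell), the object stays alive and its memory cannot be reclaimed.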

## About StarVector

StarVector is a multimodal vision-language model that generates SVG code from images and text. It was developed by researchers at [Element AI / ServiceNow](https://www.servicenow.com/research/).

If you find this project useful for your research or applications, please cite:

```
@misc{rodriguez2024starvector,
      title={StarVector: Generating Scalable Vector Graphics Code from Images and Text}, 
      author={Juan A. Rodriguez and Abhay Puri and Shubham Agarwal and Issam H. Laradji and Pau Rodriguez and Sai Rajeswar and David Vazquez and Christopher Pal and Marco Pedersoli},
      year={2024},
      eprint={2312.11556},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2312.11556}, 
}
```

### Resources
- [GitHub Repository](https://github.com/joanrod/star-vector)
- [Project Website](https://starvector.github.io/)
- [Paper](https://arxiv.org/abs/2312.11556)