In [None]:
import gc
from html import escape

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor

from starvector.data.util import process_and_rasterize_svg
 
# Checkpoint identifier handed to AutoModelForCausalLM.from_pretrained by the loading cells below.
model_name = "starvector/starvector-8b-im2svg"

In [None]:
# Target device for both the model weights and the preprocessed image tensor.
# NOTE(review): a second GPU ("cuda:1") is assumed to exist — confirm for your machine.
device = "cuda:1"

# Load the half-precision checkpoint, then move it onto the chosen device.
starvector = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
starvector = starvector.to(device)

# The image processor and tokenizer are read off the loaded wrapper.
processor = starvector.model.processor
tokenizer = starvector.model.svg_transformer.tokenizer

starvector.eval()

# Preprocess one example image into a pixel tensor on the same device as the model.
image_pil = Image.open('assets/examples/sample-18.png')
image = processor(image_pil, return_tensors="pt")['pixel_values'].to(device)
# NOTE(review): squeeze(0) only drops a leading dimension of size 1, so inside
# this branch the shape is left as-is — confirm the intended condition.
if image.shape[0] != 1:
    image = image.squeeze(0)


In [None]:

# Caption-driven generation; this batch also carries the preprocessed image tensor.
batch = {
    "image": image,
    "caption": ["a drawing of a dog wearing a red hat and sunglasses"],
}
raw_svg = starvector.model.generate_text2svg(batch, max_length=4000)
raw_svg  # last expression of the cell: shown as its output

In [None]:

# Same caption as the previous cell, but the batch has no "image" entry.
batch = {
    "caption": ["a drawing of a dog wearing a red hat and sunglasses"],
}
raw_svg = starvector.model.generate_text2svg(batch, max_length=4000)
raw_svg  # last expression of the cell: shown as its output

In [None]:
# Rasterize the SVG produced by the text2svg cell above.
#
# Every other cell in this notebook indexes the generation result with [0]
# before handing it to process_and_rasterize_svg, so do the same here when the
# result is a sequence; anything else is passed through untouched.
# The former mid-cell `raw_svg.shape` expression was dropped: its value was
# never displayed, and it raises AttributeError when the result is a list.
# NOTE(review): the return type of generate_text2svg is not visible here — confirm.
process_and_rasterize_svg(
    raw_svg[0] if isinstance(raw_svg, (list, tuple)) else raw_svg
)

In [None]:
# Image -> SVG for the example image preprocessed earlier.
#
# The most recent `batch` assigned above contains only a caption, whereas every
# other generate_im2svg call in this notebook is given {"image": ...}. Build
# the batch here so the cell does not depend on which cell last set `batch`.
batch = {"image": image}
raw_svg_0 = starvector.generate_im2svg(batch, max_length=4000)[0]

# The helper's result is unpacked as (processed SVG, rasterized image).
svg_0, raster_image = process_and_rasterize_svg(raw_svg_0)
raw_svg_0  # last expression of the cell: shown as its output

In [None]:
# text2svg.py ---------------------------------------------------------------
# Prompt the language model inside StarVector directly with text and save
# whatever it generates to an .svg file.
import os
import textwrap

import torch
from transformers import AutoModelForCausalLM

MODEL_ID = "starvector/starvector-8b-im2svg"
prompt   = "Design a flat-style SVG badge of a rocket circling planet Earth."

# NOTE(review): this rebinds the notebook-level `device` and `tokenizer`, and
# loads a second copy of the checkpoint as `sv` — keep that in mind when
# running other cells afterwards.
device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) load StarVector (outer wrapper); bfloat16 on GPU, float32 on CPU
sv = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
        trust_remote_code=True,
    ).to(device).eval()

# 2) grab the inner language model (the object .generate() is called on) + tokenizer
lm        = sv.model.svg_transformer.transformer
tokenizer = sv.model.svg_transformer.tokenizer

# 3) encode prompt  ➜  generate SVG tokens  ➜  decode
inputs = tokenizer(prompt, return_tensors="pt").to(device)
# Number of prompt tokens, used below to drop the echoed prompt from the output.
prompt_len = inputs["input_ids"].shape[1]

with torch.no_grad():
    out_ids = lm.generate(
        **inputs,
        max_new_tokens=1024,          # bump if your SVGs are long
        do_sample=True,               # required for `temperature` to take effect
        temperature=0.8,
        eos_token_id=tokenizer.eos_token_id,
    )

# Causal-LM generation returns prompt + continuation; decode only the
# continuation so the prompt text does not end up at the top of the .svg file.
svg_code = tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True)

# 4) save / peek
out_file = "rocket_badge.svg"
# Explicit encoding so the result does not depend on the platform default.
with open(out_file, "w", encoding="utf-8") as f:
    f.write(svg_code)

print(textwrap.shorten(svg_code, width=150))
print(f"✓ saved to {os.path.abspath(out_file)}")

In [None]:

# Process all PNGs in examples directory
import os
from IPython.display import display, HTML
import base64
from io import BytesIO

def image_to_base64(img):
    """Save ``img`` as PNG into memory and return those bytes as base64 text.

    ``img`` only needs a ``save(fp, format=...)`` method; the returned value is
    a ``str`` suitable for embedding in a ``data:image/png;base64,`` URI.
    """
    png_buffer = BytesIO()
    img.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()

# Convert every PNG in the examples directory to SVG and show the results
# (original, rasterized SVG, SVG source) side by side in an HTML table.

# Clear CUDA cache before starting
torch.cuda.empty_cache()
gc.collect()

# Inputs have to sit on the same device as the model weights. Read it from the
# model instead of calling .cuda(), which always targets the default GPU.
model_device = next(starvector.parameters()).device

# Get all PNG files
example_dir = 'assets/examples'
png_files = sorted(f for f in os.listdir(example_dir) if f.endswith('.png'))
print(f"Found {len(png_files)} PNG files to process")

# Create HTML table
html = """
<table style="border-collapse: collapse;">
    <tr>
        <th style="padding: 10px; border: 1px solid black;">Original PNG</th>
        <th style="padding: 10px; border: 1px solid black;">Generated SVG (Rasterized)</th>
        <th style="padding: 10px; border: 1px solid black;">SVG Code</th>
    </tr>
"""

# Process each PNG
for i, png_file in enumerate(png_files, 1):
    print(f"\nProcessing {png_file} ({i}/{len(png_files)})...")
    # File names are interpolated into markup, so escape them once up front.
    safe_name = escape(png_file)

    try:
        # Load and process image
        image_path = os.path.join(example_dir, png_file)
        image_pil = Image.open(image_path)

        # Generate SVG
        print("  Generating SVG...")
        image = processor(image_pil, return_tensors="pt")['pixel_values'].to(model_device)
        # NOTE(review): squeeze(0) only drops a leading dimension of size 1, so
        # inside this branch the shape is left as-is — confirm the intended condition.
        if image.shape[0] != 1:
            image = image.squeeze(0)
        batch = {"image": image}

        raw_svg = starvector.generate_im2svg(batch, max_length=4000)[0]
        print("  Rasterizing SVG...")
        svg, raster_image = process_and_rasterize_svg(raw_svg)

        # Convert images to base64
        orig_b64 = image_to_base64(image_pil)
        raster_b64 = image_to_base64(raster_image)

        # Add to HTML table. escape() also handles '&', which a plain
        # '<' / '>' replacement leaves untouched.
        html += f"""
        <tr>
            <td style="padding: 10px; border: 1px solid black;">
                <img src="data:image/png;base64,{orig_b64}" style="max-width: 300px;">
                <br>
                <small>{safe_name}</small>
            </td>
            <td style="padding: 10px; border: 1px solid black;">
                <img src="data:image/png;base64,{raster_b64}" style="max-width: 300px;">
            </td>
            <td style="padding: 10px; border: 1px solid black;">
                <pre style="text-align: left; max-height: 200px; overflow: auto; white-space: pre-wrap; word-wrap: break-word;">
                    {escape(svg)}
                </pre>
            </td>
        </tr>
        """

    except Exception as e:
        # Best effort: report the failure in the table and move on to the next file.
        print(f"Error processing {png_file}: {str(e)}")
        html += f"""
        <tr>
            <td colspan="3" style="padding: 10px; border: 1px solid black;">
                Error processing {safe_name}: {escape(str(e))}
            </td>
        </tr>
        """

    finally:
        # Clear CUDA cache after each image, including the ones that failed
        torch.cuda.empty_cache()
        gc.collect()

html += "</table>"

# Display the table
display(HTML(html))

In [None]:
# Reload the checkpoint in half precision and run one image -> SVG round trip
# on the default CUDA device.
starvector = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
processor = starvector.model.processor
tokenizer = starvector.model.svg_transformer.tokenizer

# Move the weights to the GPU and switch to evaluation mode.
starvector.cuda()
starvector.eval()

image_pil = Image.open('assets/examples/sample-18.png')

# Preprocess into a pixel tensor on the GPU.
image = processor(image_pil, return_tensors="pt")['pixel_values'].cuda()
# NOTE(review): squeeze(0) only drops a leading dimension of size 1, so inside
# this branch the shape is left as-is — confirm the intended condition.
if image.shape[0] != 1:
    image = image.squeeze(0)
batch = {"image": image}

# Take the first generated SVG, then unpack (processed SVG, rasterized image).
raw_svg = starvector.generate_im2svg(batch, max_length=4000)[0]
svg, raster_image = process_and_rasterize_svg(raw_svg)

In [None]:
# Exploratory: list the attribute names available on the loaded model wrapper.
# NOTE(review): looks like leftover debugging — consider removing before sharing.
dir(starvector)