In [1]:
!pip install transformers torch datasets accelerate scikit-learn

import xml.etree.ElementTree as ET
from tqdm import tqdm
import json
import torch
import numpy as np
import pandas as pd
from datasets import load_dataset, Dataset
from transformers import (
    Trainer,
    TrainingArguments,
    AutoTokenizer,
    AutoModelForCausalLM,
    DataCollatorForLanguageModeling,
)

Defaulting to user installation because normal site-packages is not writeable


In [2]:
# Pull the text-to-SVG corpus and keep only the 10,000 training rows with the
# longest SVG source. The helper length column exists only for the sort and is
# dropped again before the frame is used.
dataset = load_dataset("starvector/text2svg-stack")
df = (
    pd.DataFrame(dataset["train"])
    .assign(svg_len=lambda frame: frame["Svg"].apply(len))
    .sort_values("svg_len", ascending=False)
    .head(10000)
    .reset_index(drop=True)
    .drop(columns=["svg_len"])
)

In [3]:
# Shuffle the rows with a fixed seed, then take the first 8000 rows for
# training and leave everything after them for validation.
n_train = 8000
df = df.sample(frac=1, random_state=42).reset_index(drop=True)
train_df, val_df = df.iloc[:n_train], df.iloc[n_train:]

In [4]:
# Checkpoint to fine-tune. Change this to "Qwen/Qwen2.5-Coder-7B" or
# "Qwen/Qwen2.5-Coder-14B" to run the larger variants.
model_name = "Qwen/Qwen2.5-Coder-3B"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def normalize_svg(svg_string):
    """Return a canonical text form of an SVG string for equality comparison.

    Well-formed XML is parsed and re-serialized by ElementTree, so two inputs
    that parse to the same tree serialize to the same string. Input that does
    not parse as XML is returned with leading/trailing whitespace removed.
    """
    try:
        root = ET.fromstring(svg_string)
    except ET.ParseError:
        # Not well-formed XML: fall back to a plain whitespace-trimmed comparison key.
        return svg_string.strip()
    return ET.tostring(root, encoding="unicode")

def create_prompt(description):
    """Build the instruction text that asks for SVG code matching ``description``.

    The result is a fixed header, four numbered rules, a blank line, and a
    final ``Description: ...`` line containing the given description.
    """
    prompt_lines = (
        "Convert this SVG description to valid SVG XML code. Follow these rules:",
        "1. Use proper XML syntax with self-closing tags where appropriate",
        '2. Include xmlns="http://www.w3.org/2000/svg" in your code',
        "3. Add viewBox when necessary",
        "4. Close all tags properly",
        "",
        f"Description: {description}",
    )
    return "\n".join(prompt_lines)

def tokenize(row):
    """Render one dataset row as a chat transcript and encode it for training.

    ``row["caption_cogvlm"]`` is wrapped by ``create_prompt`` and used as the
    user turn; ``row["Svg"]`` is the assistant turn. The rendered transcript is
    truncated or padded to exactly 2048 tokens, and ``labels`` is set to a copy
    of ``input_ids``.
    """
    chat = [
        {"role": "system", "content": "You are an expert SVG coder."},
        {"role": "user", "content": create_prompt(row["caption_cogvlm"])},
        {"role": "assistant", "content": row["Svg"]},
    ]
    # Render to a plain string first so the tokenizer call below controls
    # truncation and padding.
    rendered = tokenizer.apply_chat_template(chat, tokenize=False)

    encoded = tokenizer(
        rendered,
        max_length=2048,
        truncation=True,
        padding="max_length",
    )
    encoded["labels"] = list(encoded["input_ids"])
    return encoded

In [5]:
def _tokenize_split(split):
    """Apply ``tokenize`` to every row of ``split`` and drop its original columns."""
    return split.map(tokenize, remove_columns=split.column_names, num_proc=4)


# Wrap each pandas split as a Hugging Face Dataset, then tokenize it.
train_hf_ds = Dataset.from_pandas(train_df)
val_hf_ds = Dataset.from_pandas(val_df)
train_ds = _tokenize_split(train_hf_ds)
val_ds = _tokenize_split(val_hf_ds)

Map (num_proc=4):   0%|          | 0/9000 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/1000 [00:00<?, ? examples/s]

In [None]:
# Load model
# Load the base checkpoint with bfloat16 weights.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Training configuration.
# - Run length is set by max_steps alone. The previous num_train_epochs=2 was
#   removed because a positive max_steps overrides it, so it never took effect
#   and only made the configuration misleading.
# - No evaluation runs during training. That is the default, so the explicit
#   evaluation_strategy="no" was dropped: the keyword was renamed to
#   eval_strategy in newer transformers releases, and omitting it behaves the
#   same on every version.
training_args = TrainingArguments(
    output_dir="./results_svg_qwen",
    max_steps=500,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    logging_dir="./logs_svg_qwen",
    logging_steps=1,
    report_to="none",
    bf16=True,
    learning_rate=5e-5,
)

# Build the trainer. The collator is configured with mlm=False (causal LM
# objective rather than masked LM).
# NOTE(review): newer transformers releases rename `tokenizer=` to
# `processing_class=` - confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

# Train the model and report the loss summarised by the trainer.
train_result = trainer.train()

avg_train_loss = train_result.training_loss
print(f"\nAverage Training Loss: {avg_train_loss:.4f}")


In [None]:
# Evaluate accuracy
# Evaluate exact-match accuracy (after SVG normalization) on the first 100
# validation rows, and dump every prediction to a JSON file for inspection.
results = []

model.eval()
correct = 0

# Only this subset is scored, so it - not the whole validation frame - must be
# the accuracy denominator.
eval_records = val_df.iloc[:100].to_dict("records")
total = len(eval_records)

for example in tqdm(eval_records, desc="Evaluating"):
    description = example["caption_cogvlm"]
    true_svg = example["Svg"]

    # Use the same system prompt as the training examples built in `tokenize`,
    # so the model is evaluated on the prompt format it was fine-tuned on.
    messages = [
        {"role": "system", "content": "You are an expert SVG coder."},
        {"role": "user", "content": create_prompt(description)},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    with torch.no_grad():
        # Deterministic decoding (do_sample=False). The sampling-only knobs
        # temperature/top_p were removed: they are ignored when sampling is
        # disabled and only trigger warnings.
        output_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            repetition_penalty=1.1,
            do_sample=False,
        )

    # Drop the prompt tokens so only the newly generated continuation is decoded.
    prompt_len = inputs["input_ids"].shape[1]
    decoded_svg = tokenizer.batch_decode(
        output_ids[:, prompt_len:],
        skip_special_tokens=True,
    )[0]

    is_correct = normalize_svg(decoded_svg) == normalize_svg(true_svg)
    if is_correct:
        correct += 1

    results.append({
        "description": description,
        "prediction": decoded_svg,
        "ground_truth": true_svg,
        "match": is_correct
    })

    # Release per-example tensors before the next generation call.
    del inputs, output_ids
    torch.cuda.empty_cache()

# Output to JSON file
with open("svg_predictions.json", "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

# Guard against an empty evaluation set to avoid ZeroDivisionError.
accuracy = correct / total if total else 0.0
print(f"\nValidation Accuracy: {accuracy:.4f}")