In [1]:
!pip uninstall -y transformers trl peft accelerate
!pip install -U "transformers==4.45.2" "trl==0.9.4" "peft==0.12.0" "accelerate==0.34.2" "datasets>=2.20.0" safetensors einops lxml defusedxml cairosvg pillow scikit-image


Found existing installation: transformers 4.45.2
Uninstalling transformers-4.45.2:
  Successfully uninstalled transformers-4.45.2
Found existing installation: trl 0.9.4
Uninstalling trl-0.9.4:
  Successfully uninstalled trl-0.9.4
Found existing installation: peft 0.12.0
Uninstalling peft-0.12.0:
  Successfully uninstalled peft-0.12.0
Found existing installation: accelerate 0.34.2
Uninstalling accelerate-0.34.2:
  Successfully uninstalled accelerate-0.34.2
Collecting transformers==4.45.2
  Using cached transformers-4.45.2-py3-none-any.whl.metadata (44 kB)
Collecting trl==0.9.4
  Using cached trl-0.9.4-py3-none-any.whl.metadata (11 kB)
Collecting peft==0.12.0
  Using cached peft-0.12.0-py3-none-any.whl.metadata (13 kB)
Collecting accelerate==0.34.2
  Using cached accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting lxml
  Downloading lxml-6.0.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Using cached transformers-4.45.2-py3-none-any.whl (9.9 MB

In [5]:
import os
from huggingface_hub import login

# === paths ===
CSV_PATH   = "/content/data10000.csv"       # your CSV
OUT_DIR    = "/content/svg_run"             # base output dir
DATA_DIR   = f"{OUT_DIR}/data"
MODEL_DIR  = f"{OUT_DIR}/qwen7b-svg-lora-a100"   # LoRA adapter
MERGED_DIR = f"{OUT_DIR}/qwen7b-svg-merged"      # optional merged full model

os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(MERGED_DIR, exist_ok=True)

# Login (replace token or leave blank to prompt)
login("token")

print("Will read:", CSV_PATH)


Will read: /content/data10000.csv


In [6]:
import pandas as pd
import csv

def load_csv_robust(path):
    try:
        return pd.read_csv(path)
    except Exception as e1:
        return pd.read_csv(
            path,
            engine="python",
            quoting=csv.QUOTE_MINIMAL,
            quotechar='"',
            escapechar="\\",
            on_bad_lines="warn"
        )

df_raw = load_csv_robust(CSV_PATH)
df_raw = df_raw.dropna(subset=["prompt", "svg"])
print("Loaded rows:", len(df_raw))
df_raw.head(2)


Loaded rows: 10014


Unnamed: 0,prompt,svg
0,turquoise tiles with mother-of-pearl inlays,"<svg xmlns=""http://www.w3.org/2000/svg"" width=..."
1,crimson spheres suspended from a silver ring,"<svg xmlns=""http://www.w3.org/2000/svg"" width=..."


In [7]:
import re
from lxml import etree

def canonicalize_svg(svg_text: str) -> str:
    svg_text = re.sub(r"<!--.*?-->", "", str(svg_text), flags=re.S).strip()
    parser = etree.XMLParser(remove_blank_text=True, recover=True)
    root = etree.fromstring(svg_text.encode("utf-8"), parser=parser)
    for el in root.iter():
        if el.attrib:
            el.attrib.update(dict(sorted(el.attrib.items(), key=lambda kv: kv[0])))
    out = etree.tostring(root, encoding="utf-8", xml_declaration=False, pretty_print=False).decode("utf-8")
    return re.sub(r">\s+<", "><", out).strip()

def is_valid_svg(svg_text: str) -> bool:
    try:
        etree.fromstring(svg_text.encode("utf-8"))
        return True
    except Exception:
        return False

clean_rows = []
for _, r in df_raw.iterrows():
    try:
        s = canonicalize_svg(r["svg"])
        if is_valid_svg(s):
            clean_rows.append({"prompt": str(r["prompt"]).strip(), "svg": s})
    except Exception:
        pass

df = pd.DataFrame(clean_rows)
print("After cleaning:", df.shape)


After cleaning: (10014, 2)


In [8]:
from sklearn.model_selection import train_test_split

bins = pd.qcut(df["svg"].str.len(), q=min(10, max(2, len(df)//100)), duplicates="drop")
df["len_bin"] = bins

train_df, tmp_df = train_test_split(df, test_size=0.10, random_state=42, stratify=df["len_bin"])
val_df,   test_df = train_test_split(tmp_df, test_size=0.50, random_state=42, stratify=tmp_df["len_bin"])

for name, d in [("train",train_df), ("val",val_df), ("test",test_df)]:
    print(name, len(d))

for d in (train_df, val_df, test_df):
    d.drop(columns=["len_bin"], inplace=True, errors="ignore")


train 9012
val 501
test 501


In [9]:
import json

SYSTEM = "You translate a natural-language prompt into a valid, minimal SVG. Output only a single <svg> root."

def format_example(prompt, svg):
    user = f"PROMPT:\n{prompt}\n\nREQUIREMENTS:\n- One <svg> root.\n- Valid XML; no comments.\n- Minimal elements."
    return f"<|system|>{SYSTEM}</s><|user|>\n{user}\n</s><|assistant|>\n{svg}"

def to_jsonl(df_in, path):
    with open(path, "w", encoding="utf-8") as f:
        for _, r in df_in.iterrows():
            f.write(json.dumps({"text": format_example(r['prompt'], r['svg'])})+"\n")

TR_PATH = f"{DATA_DIR}/train.jsonl"
VA_PATH = f"{DATA_DIR}/val.jsonl"
TE_PATH = f"{DATA_DIR}/test.jsonl"

to_jsonl(train_df, TR_PATH)
to_jsonl(val_df, VA_PATH)
to_jsonl(test_df, TE_PATH)

print("Saved:", TR_PATH, VA_PATH, TE_PATH)


Saved: /content/svg_run/data/train.jsonl /content/svg_run/data/val.jsonl /content/svg_run/data/test.jsonl


In [10]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

MODEL_ID = "Qwen/Qwen2.5-Coder-7B-Instruct"

ds = load_dataset("json", data_files={"train": TR_PATH, "eval": VA_PATH})
print(ds)

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
tok.pad_token = tok.eos_token

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    low_cpu_mem_usage=True  # no flash-attn to avoid ImportError
)
model.gradient_checkpointing_enable()


Generating train split: 0 examples [00:00, ? examples/s]

Generating eval split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 9012
    })
    eval: Dataset({
        features: ['text'],
        num_rows: 501
    })
})


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [12]:
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

lora = LoraConfig(
    r=16, lora_alpha=32, lora_dropout=0.05, bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]
)

args = SFTConfig(
    output_dir=MODEL_DIR,
    bf16=True,
    learning_rate=2e-4,
    num_train_epochs=1.5,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32,
    max_seq_length=4096,
    packing=True,                  # keep packing
    warmup_ratio=0.03,
    logging_steps=50,
    eval_strategy="steps", eval_steps=400,
    save_steps=400, save_total_limit=2,
    max_grad_norm=0.3,
    optim="adamw_torch",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tok,
    peft_config=lora,
    train_dataset=ds["train"],
    eval_dataset=ds["eval"],
    args=args,
    dataset_text_field="text",     # ← tell TRL which column to use
)



Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]

In [14]:
import os
os.environ["WANDB_MODE"] = "disabled"


In [15]:
trainer.train()
trainer.save_model(MODEL_DIR)
tok.save_pretrained(MODEL_DIR)
print("LoRA adapter saved to:", MODEL_DIR)


Step,Training Loss,Validation Loss


LoRA adapter saved to: /content/svg_run/qwen7b-svg-lora-a100


In [16]:
from peft import PeftModel

base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, torch_dtype=torch.bfloat16, device_map="cpu", low_cpu_mem_usage=True
)
lora_model = PeftModel.from_pretrained(base, MODEL_DIR)
merged = lora_model.merge_and_unload()

merged.save_pretrained(MERGED_DIR, safe_serialization=True)
tok.save_pretrained(MERGED_DIR)
print("Merged model saved to:", MERGED_DIR)


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Merged model saved to: /content/svg_run/qwen7b-svg-merged


In [20]:
import torch, time

# Make sure we're on GPU
def where(m):
    try: return next(m.parameters()).device
    except StopIteration: return torch.device("cpu")

print("model device:", where(model))
if 'merged' in globals():
    print("merged device:", where(merged))

# Speed helpers
model.eval()
if 'merged' in globals(): merged.eval()
torch.set_grad_enabled(False)
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


model device: cuda:0
merged device: cpu


In [21]:
from tqdm.auto import tqdm
import numpy as np
import io, time
import cairosvg
from PIL import Image
from skimage.metrics import structural_similarity as ssim

def rasterize_svg(svg_code, size=(256, 256)):
    png_bytes = cairosvg.svg2png(bytestring=svg_code.encode('utf-8'),
                                 output_width=size[0], output_height=size[1])
    img = Image.open(io.BytesIO(png_bytes)).convert("L").resize(size)
    return np.array(img, dtype=np.float32) / 255.0

# lighter generation for speed check
MAX_NEW = 800         # ↓ from 2500
TEMP    = 0.3
TOP_P   = 0.9
USE_MERGED = True     # set False if you didn't run the merge cell

# shallow copy 20 rows to test
subset = test_df.head(20).copy()

valid_count = 0
ssim_scores = []
times = []

for _, row in tqdm(subset.iterrows(), total=len(subset)):
    t0 = time.perf_counter()
    prompt = row["prompt"]
    gt_svg = row["svg"]

    # one-pass generation (no retry for speed)
    m = merged if USE_MERGED and 'merged' in globals() else model
    sys_msg = "Translate the prompt into a single valid SVG that follows the constraints. Output ONLY the <svg>…</svg>."
    rules = "- One <svg> root\n- Only allowed elements/attributes; no external href; no comments\n- Keep minimal and under size limit"
    prefix = f"<|system|>{sys_msg}\n{rules}</s><|user|>\n{prompt}\n</s><|assistant|>\n"
    inputs = tok(prefix, return_tensors="pt").to(next(m.parameters()).device)

    with torch.inference_mode():
        out = m.generate(**inputs, do_sample=True, temperature=TEMP, top_p=TOP_P, max_new_tokens=MAX_NEW)
    pred = tok.decode(out[0], skip_special_tokens=True)
    # slice to first complete SVG
    if "<svg" in pred: pred = "<svg" + pred.split("<svg",1)[1]
    if "</svg>" in pred: pred = pred.split("</svg>",1)[0] + "</svg>"

    # constraints (no retry for speed)
    try:
        validator.validate_svg(pred)
        valid_count += 1
    except Exception:
        pass

    # SSIM (optional but still quick)
    try:
        gt_img = rasterize_svg(gt_svg)
        pr_img = rasterize_svg(pred)
        ssim_scores.append(ssim(gt_img, pr_img, data_range=1.0))
    except Exception:
        ssim_scores.append(0.0)

    times.append(time.perf_counter() - t0)

print(f"\nAvg sec/sample: {np.mean(times):.2f}  |  p90: {np.percentile(times,90):.2f}  |  fastest: {np.min(times):.2f}")
print(f"Valid SVG rate (mini): {valid_count/len(subset):.2%}")
print(f"Mean SSIM (mini): {np.mean(ssim_scores):.4f}")


  0%|          | 0/20 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
from tqdm.auto import tqdm
import numpy as np, time

MAX_NEW = 1200   # balanced; raise only if outputs truncate a lot
TEMP    = 0.3
TOP_P   = 0.9
USE_MERGED = True

valid = 0
ssim_scores = []
start = time.perf_counter()
per_item = []

for i, (_, row) in enumerate(tqdm(test_df.iterrows(), total=len(test_df))):
    t0 = time.perf_counter()
    prompt, gt_svg = row["prompt"], row["svg"]

    m = merged if USE_MERGED and 'merged' in globals() else model
    sys_msg = "Translate the prompt into a single valid SVG that follows the constraints. Output ONLY the <svg>…</svg>."
    rules = "- One <svg> root\n- Only allowed elements/attributes; no external href; no comments\n- Keep minimal and under size limit"
    prefix = f"<|system|>{sys_msg}\n{rules}</s><|user|>\n{prompt}\n</s><|assistant|>\n"
    inputs = tok(prefix, return_tensors="pt").to(next(m.parameters()).device)

    with torch.inference_mode():
        out = m.generate(**inputs, do_sample=True, temperature=TEMP, top_p=TOP_P, max_new_tokens=MAX_NEW)
    pred = tok.decode(out[0], skip_special_tokens=True)
    if "<svg" in pred: pred = "<svg" + pred.split("<svg",1)[1]
    if "</svg>" in pred: pred = pred.split("</svg>",1)[0] + "</svg>"

    try:
        validator.validate_svg(pred)
        valid += 1
    except Exception:
        pass

    try:
        gt_img = rasterize_svg(gt_svg)
        pr_img = rasterize_svg(pred)
        ssim_scores.append(ssim(gt_img, pr_img, data_range=1.0))
    except Exception:
        ssim_scores.append(0.0)

    dt = time.perf_counter() - t0
    per_item.append(dt)

    # simple ETA print every 10 items
    if (i+1) % 10 == 0:
        avg = np.mean(per_item)
        rem = (len(test_df) - (i+1)) * avg
        print(f"Processed {i+1}/{len(test_df)} | avg {avg:.2f}s | ETA ~ {rem/60:.1f} min")

print(f"\nValid SVG rate: {valid/len(test_df):.2%}")
print(f"Mean SSIM: {np.mean(ssim_scores):.4f} | Median SSIM: {np.median(ssim_scores):.4f}")
print(f"Total time: {(time.perf_counter()-start)/60:.1f} min")


In [22]:
from google.colab import files

# Zip everything in OUT_DIR
!zip -r svg_run_all.zip "{OUT_DIR}"
files.download("svg_run_all.zip")


  adding: content/svg_run/ (stored 0%)
  adding: content/svg_run/data/ (stored 0%)
  adding: content/svg_run/data/test.jsonl (deflated 78%)
  adding: content/svg_run/data/val.jsonl (deflated 78%)
  adding: content/svg_run/data/train.jsonl (deflated 78%)
  adding: content/svg_run/qwen7b-svg-merged/ (stored 0%)
  adding: content/svg_run/qwen7b-svg-merged/model.safetensors.index.json (deflated 95%)
  adding: content/svg_run/qwen7b-svg-merged/model-00004-of-00004.safetensors (deflated 21%)
  adding: content/svg_run/qwen7b-svg-merged/config.json (deflated 47%)
  adding: content/svg_run/qwen7b-svg-merged/vocab.json (deflated 61%)
  adding: content/svg_run/qwen7b-svg-merged/model-00003-of-00004.safetensors (deflated 21%)
  adding: content/svg_run/qwen7b-svg-merged/generation_config.json (deflated 39%)
  adding: content/svg_run/qwen7b-svg-merged/tokenizer.json (deflated 81%)
  adding: content/svg_run/qwen7b-svg-merged/model-00001-of-00004.safetensors (deflated 21%)
  adding: content/svg_run/qw

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>