In [2]:
import os

os.mkdir("/content/tokenizer_svg_extended")
os.mkdir("/content/codellama_svg_qlora")

In [3]:
!pip uninstall -y transformers trl peft accelerate
!pip install -U "transformers==4.45.2" "trl==0.9.4" "peft==0.12.0" "accelerate==0.34.2" "datasets>=2.20.0" safetensors einops lxml defusedxml cairosvg pillow scikit-image

Found existing installation: transformers 4.55.2
Uninstalling transformers-4.55.2:
  Successfully uninstalled transformers-4.55.2
[0mFound existing installation: peft 0.17.0
Uninstalling peft-0.17.0:
  Successfully uninstalled peft-0.17.0
Found existing installation: accelerate 1.10.0
Uninstalling accelerate-1.10.0:
  Successfully uninstalled accelerate-1.10.0
Collecting transformers==4.45.2
  Downloading transformers-4.45.2-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m4.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting trl==0.9.4
  Downloading trl-0.9.4-py3-none-any.whl.metadata (11 kB)
Collecting peft==0.12.0
  Downloading peft-0.12.0-py3-none-any.whl.metadata (13 kB)
Collecting accelerate==0.34.2
  Downloading accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting lxml
  Downloading lxml-6.0.1-cp312-cp312-manylinux_2_26_x86_64.manylinux_2_28_x86_64.whl.metadata (3.8 kB)
Collecting cairosvg
  Downloadi

In [5]:
import trl

print(trl.__version__)

0.9.4


In [8]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

In [11]:
# Load base model and tokenizer
base_model = AutoModelForCausalLM.from_pretrained("meta-llama/CodeLlama-7b-Instruct-hf", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained("/content/tokenizer_svg_extended")

# adjust the base model
base_model.resize_token_embeddings(len(tokenizer))

# Load LoRA weights
lora_model = PeftModel.from_pretrained(base_model, "/content/codellama_svg_qlora")

# Merge LoRA into the base model
merged_model = lora_model.merge_and_unload()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [12]:
import math
from datasets import Dataset
from tqdm import tqdm

In [13]:
merged_model.eval().to("cuda" if torch.cuda.is_available() else "cpu")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32026, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (no

In [14]:
import pandas as pd

In [39]:
# load evaluation data
dataset = pd.read_csv("/content/data10k.csv")
dataset = dataset[['description', 'svg']]

# Create the text column
dataset["text"] = dataset.apply(lambda row: f"Given the following description: {row['description']}, generate the corresponding SVG string.\n{row['svg']}",
                                axis=1)
SEED = 42
ds_all = Dataset.from_pandas(dataset[["text"]], preserve_index=False)
training, temp = ds_all.train_test_split(test_size=0.2, seed=SEED).values()
validation, testing = temp.train_test_split(test_size=0.5, seed=SEED).values()
print(f"training: {len(training)}, validation: {len(validation)}, testing: {len(testing)}")

training: 8011, validation: 1001, testing: 1002


In [40]:
# Tokenize the evaluation data
def tokenize(example):
    tokenized = tokenizer(
        example["text"],
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=1024,
    )
    example["input_ids"] = tokenized["input_ids"][0].tolist()
    example["attention_mask"] = tokenized["attention_mask"][0].tolist()
    return example

tokenized_dataset = testing.map(tokenize)

Map:   0%|          | 0/1002 [00:00<?, ? examples/s]

In [18]:
# Compute perplexity
losses = []

for example in tqdm(tokenized_dataset):
    input_ids = torch.tensor(example["input_ids"], dtype=torch.long).unsqueeze(0).to(merged_model.device)
    attention_mask = torch.tensor(example["attention_mask"], dtype=torch.long).unsqueeze(0).to(merged_model.device)
    labels = input_ids.clone()

    with torch.no_grad():
        # get metric values
        outputs = merged_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        losses.append(loss.item())


avg_loss = sum(losses) / len(losses)
perplexity = math.exp(avg_loss)

print(f"\nAverage Loss: {avg_loss:.4f}")
print(f"Perplexity:   {perplexity:.4f}")

100%|██████████| 1002/1002 [01:25<00:00, 11.73it/s]


Average Loss: 1.2130
Perplexity:   3.3636





In [41]:
for example in tqdm(tokenized_dataset.select(range(50))):
    input_ids = torch.tensor(example["input_ids"], dtype=torch.long).unsqueeze(0).to(merged_model.device)
    attention_mask = torch.tensor(example["attention_mask"], dtype=torch.long).unsqueeze(0).to(merged_model.device)
    labels = input_ids.clone()

    with torch.no_grad():
        generated_tokens = merged_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=6400,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
        prompt_len = input_ids.shape[1]
        output_tokens = generated_tokens[0, prompt_len:]
        generated_svg_string = tokenizer.decode(output_tokens, skip_special_tokens=True)
        print(generated_svg_string)
        break

  0%|          | 0/50 [03:38<?, ?it/s]

91" />
<polygonpoints="0.0,252.0 307.0,274.0"  #f6e9dc" />
<polygon>0.0,262.0 108.0,252.0"#d08f91" />
<polygon>0.0,295.0 255.0,318.0 108.0,299.0 188.0,252.0" #d08f91" />
<polygon>0.0,368.0 301.0,368.0 186.0,368.0 109.0,367.0" #f6e9dc" />
<polygon>53.0,0.0 53.0,104.0" #d08f91" />
<polygon>0.0,383.0 53.0,397.0" #f6e9dc" />
<polygon>111.0,414.0 249.0,449.0 412.0,416.0 452.0,441.0 532.0,441.0" #1f9891" />
<polygon>247.0,468.0 307.0,372.0 450.0,372.0 510.0,361.0 540.0,361.0" #f6e9dc" />
<polygon>451.0,518.0 527.0,552.0 529.0,330.0 549.0,330.0 496.0,496.0" #f6e9dc" />
<polygon>242.0,0.0 249.0,66.0" #d08f91" />
<polygon>0.0,356.0 85.0,356.0 140.0,413.0 110.0,158.0 108.0,252.0 158.0,252.0 161.0,356.0 197.0,356.0" #d08f91" />
<polygon>383.0,397.0 307.0,397.0 453.0,453.0 350.0,450.0 306.0,306.0 251.0,452.0" #1f9891" />
<polygon>250.0,0.0 252.0,50.0" #d08f91" />
<polygon>0.0,396.0 106.0,392.0 108.0,344.0 331.0,318.0 330.0,412.0" #f6e9dc" />
<polygon>85.0,96.0 85.0,251.0" #1f9891" />
<polygon>359.




In [None]:
# format prompt
description = "a cat on a beach"
prompt = f"Given the following description: {description}, generate the corresponding SVG string.\n"

# tokenize prompt
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(merged_model.device)
attention_mask = tokenizer(prompt, return_tensors="pt").attention_mask.to(merged_model.device)

with torch.no_grad():
  generated_tokens = merged_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=6400,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.eos_token_id,
        )
  prompt_len = input_ids.shape[1]
  output_tokens = generated_tokens[0, prompt_len:]
  generated_svg_string = tokenizer.decode(output_tokens, skip_special_tokens=True)

print(generated_svg_string)



\begin{code}
cat_on_beach() {
  cat <<EOF
  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 100 100">
    <ellipse cx="50" cy="50" rx="40" ry="35" fill="#d38e4d"/>
    <path d="M50 40L62 52 70 60" stroke="#c08006" stroke-width="3"/>
    <path d="M50 40L78 52 70 60" stroke="#c08006" stroke-width="3"/>
    <path d="M50 50L67 70 70 60" stroke="#c08006" stroke-width="3"/>
    <path d="M50 50L81 70 70 60" stroke="#c08006" stroke-width="3"/>
    <path d="M50 60L67 78 70 60" stroke="#c08006" stroke-width="3"/>
    <path d="M50 60L81 78 70 60" stroke="#c08006" stroke-width="3"/>
  </svg>
EOF
}

cat_on_beach | xclip -selection clipboard
\end{code}

The only real drawback to this solution is that xclip copies to the clipboard the text string of the SVG. This might be a problem if the SVG is large or if it needs to be shared using another medium such as file transfer or email. An alternative to xclip, especially in cases of large data, is the "clip" command.

\begin{code}
cat_on_beach() {
