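This notebook implements a submission using the Gemma 2 9B IT model, plus helper code that ensures the generated SVGs conform to the submission requirements. (See the [Evaluation](https://www.kaggle.com/competitions/drawing-with-llms/overview/evaluation) page for details on the submission requirements.) For each description, the model first drafts a plan, writes an SVG from that plan, then critiques and rewrites the SVG once. Finally, the result is sanitized against the allowed elements and attributes.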

To use this notebook interactively, you'll need to install some dependencies. First, *turn on* the Internet under **Session options** to the right. Then select the **Add-ons->Install Dependencies** menu above and click *Run*. A console should pop up with a running `pip` command. Wait for the dependencies to finish installing and then *turn off* the Internet before submitting.

In [1]:
#| default_exp core

In [2]:
#| export
class PromptsMixin:
    def __init__(self):
        self.planner_template = """\
You are an expert SVG writer. You will be given a textual representation of an image.
You will make a plan to generate an SVG that best represents the text, while respecting constraints below. 
Your plan will be used by downstream agents to generate the actual SVG.

<constraints>
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
</constraints>

Output MAX 10 bullet points on the elements you plan to use. 
DO NOT use markdown formatting except for the bullet points. NO inline code, NO emphasis, NO italics, etc
Your plan should include elements for the description AND their specific SVG constructs.

<text>{}</text>
"""

        self.writer_template = """\
You are an expert SVG coder. You will be given: 
1. A textual representation of an image
2. A plan to draw that image

Use both to write SVG while respecting the following constraints.

<constraints>
* **Allowed Elements:** `svg`, `path`, `circle`, `rect`, `ellipse`, `line`, `polyline`, `polygon`, `g`, `linearGradient`, `radialGradient`, `stop`, `defs`
* **Allowed Attributes:** `viewBox`, `width`, `height`, `fill`, `stroke`, `stroke-width`, `d`, `cx`, `cy`, `r`, `x`, `y`, `rx`, `ry`, `x1`, `y1`, `x2`, `y2`, `points`, `transform`, `opacity`
</constraints>

You are encouraged to utilize comments to structure your output.
Focus on a clear and concise representation of the input description within the given limitations. 
Always give the complete SVG code with nothing omitted. Never use an ellipsis.

## Output Format:
```svg
<svg viewBox="0 0 256 256" width="256" height="256">
  <circle cx="50" cy="50" r="40" fill="red"/>
  <rect x="30" y="30" width="40" height="40" fill="blue"/>
</svg>
```

<text>{text}</text>
<plan>{plan}</plan>
"""

        self.reflection_user_prompt = """\
Don't write SVG yet. Review and create actionable recommendations to improve the outputted SVG \
while respecting the constraints.

## Bad Examples:
### Not specific
- The SVG has an invalid element.
- The SVG is illegal and cannot render.

### Not actionable
- The SVG does not make sense.

Prioritize! Output MAX 10 bullet points in order of importance. 
"""

In [3]:
#| export
import concurrent
import concurrent.futures
import io
import logging
import re

import cairosvg
import kagglehub
import re2
import torch
from lxml import etree
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from transformers import pipeline

# Competition metric package; provides `SVGConstraints`, whose allow-lists
# `Model.enforce_constraints` checks elements and attributes against.
svg_constraints = kagglehub.package_import('metric/svg-constraints')

# NOTE(review): not referenced by the code shown in this notebook; model
# placement is handled by `device_map="auto"` when the model is loaded.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sampling settings shared by every generation call.
TEMPERATURE, TOP_P = 0.6, 0.9

class Model(PromptsMixin):
    """Text-to-SVG model built on Gemma 2 9B IT, loaded with 4-bit quantization.

    For each description, ``predict`` runs plan -> write (with critique/rewrite
    turns) -> sanitize -> render check. If an exception or timeout occurs, it
    falls back to ``self.default_svg``.
    """

    def __init__(self):
        super().__init__()
        MODEL_NAME = "google/gemma-2/Transformers/gemma-2-9b-it/2"
        model_path = kagglehub.model_download(MODEL_NAME)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        # NF4 4-bit weights with double quantization; compute runs in fp16.
        quantization_config_4_bit = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.float16,
        )

        llm = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            # torch_dtype=torch.float16,  # half precision
            quantization_config=quantization_config_4_bit,
            # NOTE(review): assumes two visible GPUs (ids 0 and 1); adjust the
            # keys if the accelerator layout differs.
            max_memory={0: "15GiB", 1: "15GiB"},
            offload_folder="offload",  # directory for any weights offloaded to disk
        )
        # This max_new_tokens is only a default; every call in `_chat` passes
        # its own value.
        self.pipe = pipeline(
            "text-generation",
            model=llm,
            tokenizer=tokenizer,
            trust_remote_code=True,
            max_new_tokens=200,
        )

        # Returned whenever generation, sanitizing or rendering fails.
        self.default_svg = """<svg width="256" height="256" viewBox="0 0 256 256"><circle cx="50" cy="50" r="40" fill="red" /></svg>"""
        self.constraints = svg_constraints.SVGConstraints()
        # Wall-clock budget for a single `predict` call, in seconds.
        self.timeout_seconds = 480

    def _chat(self, messages: list, max_new_tokens: int) -> str:
        """Sample one assistant reply for ``messages`` and return its stripped text.

        Uses the module-level TEMPERATURE / TOP_P settings. The last entry of
        the pipeline's ``generated_text`` is read as the new assistant turn.
        """
        outputs = self.pipe(
            messages,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=TEMPERATURE,
            top_p=TOP_P,
        )
        return outputs[0]["generated_text"][-1]["content"].strip()

    def planner_block(self, text) -> str:
        """Ask the model for a bullet-point drawing plan for the description ``text``."""
        messages = [
            {"role": "user", "content": self.planner_template.format(text)}
        ]
        assistant_response = self._chat(messages, max_new_tokens=512)
        logging.debug("Planner: %s", assistant_response)
        return assistant_response

    def writer_block(self,
                     text: str,
                     plan_text: str,
                     revision_turns=0
    ) -> str:
        """Write an SVG for ``text`` following ``plan_text``, then refine it.

        Each of the ``revision_turns`` rounds extends the same conversation
        twice. First the model critiques its latest answer
        (``reflection_user_prompt``); then it is asked to rewrite the SVG.

        Returns
        -------
        str
            The raw text of the final writer reply. It may contain prose or
            markdown around the ``<svg>`` markup.
        """
        messages = [
            {"role": "user", "content": self.writer_template.format(text=text, plan=plan_text)},
        ]
        assistant_response = self._chat(messages, max_new_tokens=2048)
        logging.debug("Writer: %s", assistant_response)

        # `or 0` keeps a falsy revision_turns (e.g. None) meaning "no revisions".
        for revision_idx in range(revision_turns or 0):
            logging.debug("Revision turn: %s", revision_idx + 1)
            messages.extend([
                {"role": "assistant", "content": assistant_response},
                {"role": "user", "content": self.reflection_user_prompt},
            ])
            critic_response = self._chat(messages, max_new_tokens=1024)
            logging.debug("Critic: %s", critic_response)

            messages.extend([
                {"role": "assistant", "content": critic_response},
                {"role": "user", "content": "Looks good. Rewrite the SVG now please."},
            ])
            # NOTE(review): the rewrite gets half the first draft's token budget.
            # A truncated reply loses its closing </svg>, and `predict` then
            # falls back to the default SVG.
            assistant_response = self._chat(messages, max_new_tokens=1024)
            logging.debug("Writer: %s", assistant_response)

        return assistant_response

    # You could try increasing `max_new_tokens`
    def predict(self, description: str) -> str:
        """Return a sanitized SVG for ``description`` within ``self.timeout_seconds``.

        Exceptions and timeouts are logged, and ``self.default_svg`` is returned
        instead.
        """
        def generate_svg():
            try:
                plan = self.planner_block(description)
                svg_response = self.writer_block(
                    description,
                    plan,
                    revision_turns=1
                )

                # If several <svg> blocks appear in the reply, the last is used.
                matches = re.findall(r"<svg.*?</svg>", svg_response, re.DOTALL | re.IGNORECASE)
                if not matches:
                    return self.default_svg

                svg = self.enforce_constraints(matches[-1])
                logging.debug('Processed SVG: %s', svg)
                # Ensure the generated code can be converted by cairosvg
                cairosvg.svg2png(bytestring=svg.encode('utf-8'))
                return svg
            except Exception as e:
                logging.error('Exception during SVG generation: %s', e)
                return self.default_svg

        # Run generation in a worker thread so the timeout can be enforced.
        # Shut the executor down with wait=False. Leaving a `with` block would
        # call shutdown(wait=True), which blocks until a timed-out generation
        # finishes and defeats the timeout. Python threads cannot be killed, so
        # a timed-out generation keeps running in the background until done.
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(generate_svg)
            return future.result(timeout=self.timeout_seconds)
        except concurrent.futures.TimeoutError:
            logging.warning("Prediction timed out after %s seconds.", self.timeout_seconds)
            return self.default_svg
        except Exception as e:
            logging.error('An unexpected error occurred: %s', e)
            return self.default_svg
        finally:
            executor.shutdown(wait=False)

    def enforce_constraints(self, svg_string: str) -> str:
        """Enforces constraints on an SVG string, removing disallowed elements
        and attributes.

        Rules applied:
        - Elements whose tag is not in ``self.constraints.allowed_elements``
          are removed together with their subtrees.
        - Attributes allowed neither for the element nor under ``'common'``
          are deleted.
        - ``href`` values not starting with ``#`` are deleted.
        - ``<path>`` elements with a missing or malformed ``d`` are removed.

        Parameters
        ----------
        svg_string : str
            The SVG string to process.

        Returns
        -------
        str
            The processed SVG string, or ``self.default_svg`` if the input
            cannot be parsed or the result cannot be serialized.
        """
        logging.info('Sanitizing SVG...')

        try:
            # Comments and processing instructions carry no drawing content.
            # Stripping PIs also keeps non-element nodes out of `root.iter()`
            # below, where `etree.QName(element.tag)` expects a real tag.
            parser = etree.XMLParser(
                remove_blank_text=True, remove_comments=True, remove_pis=True
            )
            root = etree.fromstring(svg_string, parser=parser)
        except etree.ParseError as e:
            logging.error('SVG Parse Error: %s. Returning default SVG.', e)
            return self.default_svg

        # Compiled once per call rather than once per <path>. Numbers need a
        # leading digit and whitespace/comma separators, so some spec-valid
        # compact forms (e.g. "M10-5", "M.5") are rejected and their paths dropped.
        path_regex = re2.compile(
            r'^'  # Start of string
            r'(?:'  # Non-capturing group for each command + numbers block
            r'[MmZzLlHhVvCcSsQqTtAa]'  # Valid SVG path commands (adjusted to exclude extra letters)
            r'\s*'  # Optional whitespace after command
            r'(?:'  # Non-capturing group for optional numbers
            r'-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?'  # First number
            r'(?:[\s,]+-?\d+(?:\.\d+)?(?:[Ee][+-]?\d+)?)*'  # Subsequent numbers with mandatory separator(s)
            r')?'  # Numbers are optional (e.g. for Z command)
            r'\s*'  # Optional whitespace after numbers/command block
            r')+'  # One or more command blocks
            r'\s*'  # Optional trailing whitespace
            r'$'  # End of string
        )

        elements_to_remove = []
        for element in root.iter():
            tag_name = etree.QName(element.tag).localname

            # Remove disallowed elements
            if tag_name not in self.constraints.allowed_elements:
                elements_to_remove.append(element)
                continue  # Skip attribute checks for removed elements

            # Remove disallowed attributes
            allowed_for_tag = self.constraints.allowed_elements[tag_name]
            allowed_common = self.constraints.allowed_elements['common']
            attrs_to_remove = [
                attr
                for attr in element.attrib
                if etree.QName(attr).localname not in allowed_for_tag
                and etree.QName(attr).localname not in allowed_common
            ]
            for attr in attrs_to_remove:
                logging.debug(
                    'Attribute "%s" for element "%s" not allowed. Removing.',
                    attr,
                    tag_name,
                )
                del element.attrib[attr]

            # Check and remove invalid href attributes; iterate over a snapshot
            # because attributes are deleted inside the loop.
            for attr, value in list(element.attrib.items()):
                if etree.QName(attr).localname == 'href' and not value.startswith('#'):
                    logging.debug(
                        'Removing invalid href attribute in element "%s".', tag_name
                    )
                    del element.attrib[attr]

            # Validate path elements to help ensure SVG conversion
            if tag_name == 'path':
                d_attribute = element.get('d')
                if not d_attribute:
                    logging.warning('Path element is missing "d" attribute. Removing path.')
                    elements_to_remove.append(element)
                    continue  # Skip further checks for this removed element
                if not path_regex.match(d_attribute):
                    logging.warning(
                        'Path element has malformed "d" attribute format. Removing path.'
                    )
                    elements_to_remove.append(element)
                    continue
                logging.debug('Path element "d" attribute validated (regex check).')

        # Remove elements marked for removal (the root has no parent and is kept).
        for element in elements_to_remove:
            parent = element.getparent()
            if parent is not None:
                parent.remove(element)
                logging.debug('Removed element: %s', element.tag)

        try:
            return etree.tostring(root, encoding='unicode')
        except ValueError as e:
            logging.error(
                'SVG could not be sanitized to meet constraints: %s', e
            )
            return self.default_svg

The `kaggle_evaluation.test(Model)` cell below tests the model in a local mock-up of this competition's evaluation pipeline. It runs the model on a sample of 15 instances defined in the `test.csv` file in the `kaggle_evaluation` package folder.

Alternatively, you can use the `generate()` function at the end of this notebook. It runs the model on the first `N_SAMPLES` rows of `train.csv` and shows the generated images along with some debugging info. `generate()` sets the logging level to `DEBUG`; change it to `INFO` if you just want to see the images.

In [4]:
import polars as pl
# Show up to 100 characters per string cell so descriptions are not truncated.
pl.Config.set_fmt_str_lengths(100)
train = pl.read_csv('/kaggle/input/drawing-with-llms/train.csv')
# Last expression: rendered as the table below.
train

id,description
str,str
"""04c411""","""a starlit night over snow-covered peaks"""
"""215136""","""black and white checkered pants"""
"""3e2bc6""","""crimson rectangles forming a chaotic grid"""
"""61d7a8""","""burgundy corduroy pants with patch pockets and silver buttons"""
"""6f2ca7""","""orange corduroy overalls"""
…,…
"""bf3306""","""magenta trapezoids layered on a transluscent silver sheet"""
"""e2240f""","""gray wool coat with a faux fur collar"""
"""f02e39""","""a purple forest at dusk"""
"""f6790a""","""purple pyramids spiraling around a bronze cone"""


In [5]:
# Instance used by `generate()` at the end of the notebook.
# NOTE(review): `kaggle_evaluation.test(Model)` in the next cell constructs its
# own Model, so a second copy of the weights is loaded while this one is alive.
model = Model()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0


In [6]:
import kaggle_evaluation

# INFO level: the "Sanitizing SVG..." lines are shown, DEBUG transcripts are not.
# force=True replaces any root handlers configured earlier.
logging.basicConfig(level=logging.INFO, force=True)
# Runs the local mock of the evaluation pipeline; it instantiates `Model` itself.
kaggle_evaluation.test(Model)

Creating Model instance...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0


Running inference tests...


The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...
INFO:root:Sanitizing SVG...


Wrote test submission file to "/tmp/kaggle-evaluation-submission-6lx437_7.csv".
Success!


In [7]:
N_SAMPLES = 3

def generate(n_samples=None) -> None:
    """Run the global `model` on the first training descriptions and display the SVGs.

    Parameters
    ----------
    n_samples : int, optional
        Number of rows of train.csv to run. Defaults to the value of the
        module-level N_SAMPLES at call time.

    Side effects: reconfigures the root logger to DEBUG (force=True) so the
    planner/writer/critic transcripts logged by `Model` are shown. Prints the
    prediction time for each description.
    """
    import time

    import polars as pl
    from IPython.display import SVG, display

    if n_samples is None:
        n_samples = N_SAMPLES

    logging.basicConfig(level=logging.DEBUG, force=True)

    train = pl.read_csv('/kaggle/input/drawing-with-llms/train.csv')[:n_samples]
    display(train.head())
    for desc in train.get_column('description'):
        # perf_counter is monotonic, unlike time.time, so it is safe for durations.
        start_time = time.perf_counter()
        svg = model.predict(desc)
        elapsed_time = time.perf_counter() - start_time
        print(f"Prediction time for description '{desc[:20]}...': {elapsed_time:.4f} seconds")

        try:
            display(SVG(svg))
        except Exception as e:
            print(e)

# Uncomment and run the line below to see some generated images
# generate()