In [1]:
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "75f9bbba-0a24-4703-9365-101dbbce35ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the autoreload extension\n",
    "%load_ext autoreload\n",
    "\n",
    "# Set autoreload mode to 2\n",
    "%autoreload 2\n",
    "\n",
    "import base64\n",
    "from getpass import getpass\n",
    "from PIL import Image\n",
    "\n",
    "from io import BytesIO\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from PIL import Image, ImageDraw, ImageFont\n",
    "from IPython.display import display, SVG\n",
    "import json\n",
    "import re\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "import anthropic\n",
    "import torchvision.transforms as transforms\n",
    "import xml.etree.ElementTree as ET\n",
    "import time\n",
    "import ast\n",
    "import cairosvg\n",
    "\n",
    "import sys\n",
    "sys.path.append(\"..\") # Adds higher directory to python modules path.\n",
    "from prompts import sketch_first_prompt, idea_system_prompt, gt_example\n",
    "import utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "45fcb49b-87c5-468c-ace9-8ff1c47951a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generation settings read as globals by the helper functions below\n",
    "model = \"claude-3-5-sonnet-20240620\"\n",
    "gen_mode = \"generation\"\n",
    "max_tokens = 3000\n",
    "\n",
    "# Load variables from a .env file, then build the API client from ANTHROPIC_API_KEY\n",
    "load_dotenv()\n",
    "claude_key = os.getenv(\"ANTHROPIC_API_KEY\")\n",
    "client = anthropic.Anthropic(api_key=claude_key)\n",
    "\n",
    "# Initial canvas and cell-to-pixel map as returned by utils.create_grid_image (helper not shown here)\n",
    "res = 50\n",
    "init_canvas, cells_to_pixels_map = utils.create_grid_image(res=res, cell_size=12, header_size=12)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "578fe0a2-b2db-4eac-8fb3-4d8e78b9afc7",
   "metadata": {},
   "source": [
    "# Utils"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "09ae2c08-e378-40dc-8428-89a1be3d0f5f",
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_sketch_data(path_to_data, object_to_edit, cache=False):\n",
    "    \"\"\"Load a rendered sketch image together with its logged conversation.\n",
    "\n",
    "    Args:\n",
    "        path_to_data: Root directory containing one sub-directory per object.\n",
    "        object_to_edit: Name of the object sub-directory to read.\n",
    "        cache: When truthy, the first log entry's content is indexed as a list of\n",
    "            parts and the text of its first part is used; otherwise the content\n",
    "            is returned unchanged.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (sketch image, system prompt, log entries after the first one,\n",
    "        text of the first part of the last log entry).\n",
    "    \"\"\"\n",
    "    path_to_json = f\"{path_to_data}/{object_to_edit}/experiment_log.json\"\n",
    "    with open(path_to_json, 'r') as log_file:\n",
    "        log_entries = json.load(log_file)\n",
    "\n",
    "    first_content = log_entries[0][\"content\"]\n",
    "    system_prompt = first_content[0][\"text\"] if cache else first_content\n",
    "    last_reply_text = log_entries[-1]['content'][0]['text']\n",
    "    later_entries = log_entries[1:]\n",
    "\n",
    "    path_to_sketch_im = f\"{path_to_data}/{object_to_edit}/output_{object_to_edit}_canvas.png\"\n",
    "    rendered = Image.open(path_to_sketch_im)\n",
    "    return rendered, system_prompt, later_entries, last_reply_text\n",
    "\n",
    "def call_llm(system_message, other_msg, cache, additional_args):\n",
    "    \"\"\"Send one request through the module-level client and return the raw response.\n",
    "\n",
    "    Args:\n",
    "        system_message: Value passed as the request's system argument.\n",
    "        other_msg: Conversation messages, without the system prompt.\n",
    "        cache: When truthy, the request goes through client.beta.prompt_caching;\n",
    "            otherwise through client.messages.\n",
    "        additional_args: Extra keyword arguments forwarded to the create call.\n",
    "\n",
    "    Returns:\n",
    "        The response object returned by the client.\n",
    "    \"\"\"\n",
    "    # Only the endpoint differs between the two modes; the arguments are the same.\n",
    "    if cache:\n",
    "        create_message = client.beta.prompt_caching.messages.create\n",
    "    else:\n",
    "        create_message = client.messages.create\n",
    "\n",
    "    return create_message(\n",
    "        model=model,\n",
    "        max_tokens=max_tokens,\n",
    "        system=system_message,\n",
    "        messages=other_msg,\n",
    "        **additional_args\n",
    "    )\n",
    "    \n",
    "def define_input_to_llm(msg_history, init_canvas_str, msg, cache):\n",
    "    \"\"\"Return the message list for the next request: the history plus a new user turn.\n",
    "\n",
    "    Args:\n",
    "        msg_history: Earlier conversation messages, without the system prompt.\n",
    "        init_canvas_str: Base64 image data to attach, or None for a text-only turn.\n",
    "        msg: Text of the new user message.\n",
    "        cache: When truthy, the text part gets an ephemeral cache_control entry.\n",
    "\n",
    "    Returns:\n",
    "        A new list made of msg_history followed by the new user message;\n",
    "        msg_history itself is not modified.\n",
    "    \"\"\"\n",
    "    text_part = {\"type\": \"text\", \"text\": msg}\n",
    "    if cache:\n",
    "        text_part[\"cache_control\"] = {\"type\": \"ephemeral\"}\n",
    "\n",
    "    # Claude best practice is image-then-text\n",
    "    if init_canvas_str is None:\n",
    "        parts = [text_part]\n",
    "    else:\n",
    "        image_source = {\"type\": \"base64\", \"media_type\": \"image/jpeg\", \"data\": init_canvas_str}\n",
    "        parts = [{\"type\": \"image\", \"source\": image_source}, text_part]\n",
    "\n",
    "    return msg_history + [{\"role\": \"user\", \"content\": parts}]\n",
    "\n",
    "def get_response_from_llm(\n",
    "    msg,\n",
    "    system_message,\n",
    "    msg_history=None,\n",
    "    init_canvas_str=None,\n",
    "    prefill_msg=None,\n",
    "    seed_mode=\"stochastic\",\n",
    "    stop_sequences=None,\n",
    "    gen_mode=\"generation\",\n",
    "    cache=True,\n",
    "    path2save=None\n",
    "):\n",
    "    \"\"\"Send one user turn to the model and return its reply plus the updated history.\n",
    "\n",
    "    Args:\n",
    "        msg: Text of the new user message.\n",
    "        system_message: System prompt text.\n",
    "        msg_history: Earlier messages without the system prompt; None means no history.\n",
    "        init_canvas_str: Optional base64 image data placed before the text.\n",
    "        prefill_msg: Optional assistant prefill; only used when gen_mode is 'completion'.\n",
    "        seed_mode: 'deterministic' sets temperature=0.0 and top_k=1; any other value\n",
    "            leaves the sampling parameters unset.\n",
    "        stop_sequences: A single stop string or a list of them; when empty or None,\n",
    "            the reply stops at '</answer>'.\n",
    "        gen_mode: 'completion' enables the assistant prefill.\n",
    "        cache: When truthy, the system prompt and the new user text are tagged with an\n",
    "            ephemeral cache_control entry and the prompt-caching endpoint is used.\n",
    "        path2save: Directory in which experiment_log.json is written; None skips saving.\n",
    "\n",
    "    Returns:\n",
    "        Tuple (content, new_msg_history): the reply text (prefixed with the prefill when\n",
    "        one was sent) and the history including the new user and assistant turns.\n",
    "    \"\"\"\n",
    "    additional_args = {}\n",
    "    if seed_mode == \"deterministic\":\n",
    "        additional_args[\"temperature\"] = 0.0\n",
    "        additional_args[\"top_k\"] = 1\n",
    "\n",
    "    # The API expects a flat list of strings, so wrap a single string but not a list.\n",
    "    if not stop_sequences:\n",
    "        additional_args[\"stop_sequences\"] = [\"</answer>\"]\n",
    "    elif isinstance(stop_sequences, str):\n",
    "        additional_args[\"stop_sequences\"] = [stop_sequences]\n",
    "    else:\n",
    "        additional_args[\"stop_sequences\"] = list(stop_sequences)\n",
    "\n",
    "    if cache:\n",
    "        system_message = [{\n",
    "            \"type\": \"text\",\n",
    "            \"text\": system_message,\n",
    "            \"cache_control\": {\"type\": \"ephemeral\"}\n",
    "        }]\n",
    "\n",
    "    # other_msg holds every message except the system prompt, ending with the new user turn\n",
    "    history = msg_history if msg_history is not None else []\n",
    "    other_msg = define_input_to_llm(history, init_canvas_str, msg, cache)\n",
    "\n",
    "    # The prefill is sent as a trailing assistant turn but kept out of other_msg, so the\n",
    "    # returned history never loses the user turn when no prefill was given.\n",
    "    use_prefill = gen_mode == \"completion\" and bool(prefill_msg)\n",
    "    request_msgs = other_msg\n",
    "    if use_prefill:\n",
    "        request_msgs = other_msg + [{\"role\": \"assistant\", \"content\": f\"{prefill_msg}\"}]\n",
    "\n",
    "    response = call_llm(system_message, request_msgs, cache, additional_args)\n",
    "\n",
    "    # A reply with no content blocks is treated as empty text instead of raising IndexError.\n",
    "    content = response.content[0].text if response.content else \"\"\n",
    "    if use_prefill:\n",
    "        # The model continues the prefill, so the full reply is prefill + continuation.\n",
    "        content = f\"{prefill_msg}{content}\"\n",
    "\n",
    "    # Built unconditionally so the return value is defined even when nothing is saved.\n",
    "    new_msg_history = other_msg + [\n",
    "        {\n",
    "            \"role\": \"assistant\",\n",
    "            \"content\": [\n",
    "                {\n",
    "                    \"type\": \"text\",\n",
    "                    \"text\": content,\n",
    "                }\n",
    "            ],\n",
    "        }\n",
    "    ]\n",
    "\n",
    "    # saves to json\n",
    "    if path2save is not None:\n",
    "        system_message_json = [{\"role\": \"system\", \"content\": system_message}]\n",
    "        with open(f\"{path2save}/experiment_log.json\", 'w') as json_file:\n",
    "            json.dump(system_message_json + new_msg_history, json_file, indent=4)\n",
    "        print(f\"Data has been saved to [{path2save}/experiment_log.json]\")\n",
    "\n",
    "    return content, new_msg_history\n",
    "\n",
    "\n",
    "def save_sketch(model_strokes_svg, output_path, add_object, init_canvas):\n",
    "    \"\"\"Write the sketch SVG to disk, rasterize it, and optionally composite it on a canvas.\n",
    "\n",
    "    Args:\n",
    "        model_strokes_svg: SVG markup to write.\n",
    "        output_path: Existing directory that receives the output files.\n",
    "        add_object: Name used in the output file names (output_<name>.svg/.png).\n",
    "        init_canvas: Background image to paste the sketch onto, or None to skip that step.\n",
    "\n",
    "    Returns:\n",
    "        The composited image when init_canvas is given, otherwise None.\n",
    "    \"\"\"\n",
    "    svg_path = f\"{output_path}/output_{add_object}.svg\"\n",
    "    with open(svg_path, \"w\") as svg_file:\n",
    "        svg_file.write(model_strokes_svg)\n",
    "\n",
    "    # save the result also without the canvas background\n",
    "    cairosvg.svg2png(url=svg_path, write_to=f\"{output_path}/output_{add_object}.png\", background_color=\"white\")\n",
    "\n",
    "    if init_canvas is None:\n",
    "        return None\n",
    "\n",
    "    # save the result as png on the canvas background\n",
    "    output_png_path = f\"{output_path}/output_{add_object}_canvas.png\"\n",
    "    cairosvg.svg2png(url=svg_path, write_to=output_png_path)\n",
    "    init_canvas_copy = init_canvas.copy()\n",
    "    # Open the rendering once and close it before the same path is overwritten below.\n",
    "    with Image.open(output_png_path) as foreground:\n",
    "        # The rendering doubles as the paste mask (assumes the PNG has an alpha channel).\n",
    "        init_canvas_copy.paste(foreground, (0, 0), foreground)\n",
    "    init_canvas_copy.save(output_png_path)\n",
    "    return init_canvas_copy"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3db0a1f2-0e1c-43fb-a611-690fe9100f70",
   "metadata": {},
   "source": [
    "# Edit"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e83e40ec-30da-451a-a1bf-5e6ad13605b8",
   "metadata": {},
   "source": [
    "### Edit mode - Add\n",
    "Adds content to an existing sketch (note that the prompts are adjusted to the type of edit)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "da19d0f1-ae00-41a7-a072-f4e275b676f3",
   "metadata": {},
   "outputs": [],
   "source": [
    "def edit_sketch_in_chat_add(path_to_data, object_to_edit, add_objects, reflection_prompt, cache, seed_mode=\"deterministic\"):\n",
    "    output_path = f\"{path_to_data}/{object_to_edit}/editing_add\"\n",
    "    if not os.path.exists(output_path):\n",
    "        os.makedirs(output_path)\n",
    "\n",
    "    # Load sketch data\n",
    "    sketch_rendered, system_prompt, msg_history, assitant_prompt = load_sketch_data(path_to_data, object_to_edit, cache)\n",
    "    with open(f\"{output_path}/experiment_log.json\", 'w') as json_file:\n",
    "        system_message_json = [{\"role\": \&q...

SyntaxError: unterminated string literal (detected at line 245); perhaps you escaped the end quote? (2122788854.py, line 245)