In [1]:
from rich import print
import os
from textwrap import dedent
import numpy as np
from dotenv import load_dotenv
import logging
import requests
from PIL import Image

load_dotenv()
%load_ext rich

In [2]:
import supabase

sc_client = supabase.create_client(
    os.getenv("SUPABASE_URL"), os.getenv("SUPABASE_ANON_KEY")
)

In [5]:
from io import BytesIO

result = (
    sc_client.table("prompts")
    .insert(
        {
            "feature_input": "elephant",
            "feature": "fire",
            "strength": 0.5,
            "generated_prompt": "elephant with fire",
        }
    )
    .execute()
)
prompt_id = result.data[0]["id"]

# Save to storage
storage_path = f"images/{prompt_id}.png"
response = requests.get(
    "https://fal.media/files/elephant/kFmVX6XRiz0myKvPtU4Az.png", stream=True
)
image = Image.open(response.raw)
image.save(storage_path)


sc_client.storage.from_("images").upload(storage_path, storage_path)

# Get public URL and update record
public_url = sc_client.storage.from_("images").get_public_url(storage_path)
print(public_url)
# sc_client.table("prompts").update({"image_url": public_url}).eq(
#     "id", prompt_id
# ).execute()


In [40]:
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.addHandler(handler)

In [55]:
import goodfire
import fal_client

In [59]:
def generate_image(prompt):
    result = fal_client.subscribe(
        "fal-ai/flux/schnell",
        arguments={
            "prompt": prompt,
            "image_size": "landscape_4_3",
            "seed": 42,
            "enable_safety_checker": False,
        },
        with_logs=True,
    )

    return result

In [25]:
goodfire_client = goodfire.Client(api_key=os.getenv("GOODFIRE_API_KEY"))


def get_artifacts(topic, feature_description):
    variant = goodfire.Variant("meta-llama/Llama-3.3-70B-Instruct")

    feature = goodfire_client.features.search(
        feature_description, model=variant, top_k=5
    )

    prompt = [
        {
            "role": "user",
            "content": dedent(f"""Design a prompt for the following: "{topic}"
    Do not generate anything else."""),
        }
    ]

    return variant, feature, prompt


In [21]:
variant, feature, prompt = get_artifacts("impossible architecture on a gas giant moon", "fire")

In [73]:
def get_image_prompt(variant, feature, prompt):
    strength_prompts = {}
    for ix, strength in enumerate(np.linspace(-0.5, 0.5, 10)):
        logger.info(
            f"Prompt: {prompt[0]['content']} | Feature: {feature[0]} | Strength: {strength}"
        )
        variant.set(feature[0], strength)

        response = goodfire_client.chat.completions.create(
            messages=prompt,
            model=variant,
            seed=42,
            max_completion_tokens=2048,
            temperature=0,
        )

        strength_prompts[ix] = response.choices[0].message["content"]
        logger.info(f"Generated prompt: {strength_prompts[ix]}")

    return strength_prompts


In [74]:
prompt_list = get_image_prompt(variant, feature, prompt)

2024-12-24 22:08:45,532 - __main__ - INFO - Prompt: Design a prompt for the following: "impossible architecture on a gas giant moon"
    Do not generate anything else. | Feature: Feature("Setting things on fire or descriptions of ignition") | Strength: -0.5
2024-12-24 22:08:48,861 - __main__ - INFO - Generated prompt: Design a habitable, yet mind-bendingly impossible, architectural structure on one of Jupiter's moons, such as Europa or Ganymede, where the laws of physics are pushed to the limit and the boundaries of reality are tested, incorporating the moon's unique features like subsurface oceans, cryovolcanic landscapes, and intense radiation belts.
2024-12-24 22:08:48,862 - __main__ - INFO - Prompt: Design a prompt for the following: "impossible architecture on a gas giant moon"
    Do not generate anything else. | Feature: Feature("Setting things on fire or descriptions of ignition") | Strength: -0.3888888888888889
2024-12-24 22:08:51,604 - __main__ - INFO - Generated prompt: Desi

In [82]:
def save_images(prompt_list):
    for strength, prompt in prompt_list.items():
        image = generate_image(prompt)
        image_url = image["images"][0]["url"]

        image = Image.open(requests.get(image_url, stream=True).raw)
        image.save(f"{strength}.png")

In [83]:
save_images(prompt_list)