In [2]:
import os
from dotenv import load_dotenv


# Load environment variables through python-dotenv before reading the key.
load_dotenv()

# API key for the Gemini client; os.getenv returns None when the variable is unset.
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")

In [3]:
from google import genai
from google.genai import types

# Gemini API client used by the generation cell, authenticated with GOOGLE_API_KEY.
client = genai.Client(api_key=GOOGLE_API_KEY)

In [4]:
# Model identifier passed to client.models.generate_content.
MODEL_ID = "gemini-2.5-flash-image-preview"

from IPython.display import display, Markdown, Image
import pathlib

def display_response(response):
    """Display every part of a model response, as Markdown text or as an image.

    Args:
        response: Object exposing a ``parts`` attribute. Each part is expected
            to provide ``text`` and ``as_image()``.

    A part with non-empty ``text`` is rendered as Markdown; otherwise, if
    ``as_image()`` returns something truthy, that image is displayed.
    Nothing is displayed when the response carries no parts.
    """
    # ``parts`` may be None when the response has no content (for example a
    # blocked or empty candidate); treat that as "nothing to show" rather than
    # failing with "TypeError: 'NoneType' object is not iterable".
    for part in response.parts or []:
        if part.text:
            display(Markdown(part.text))
        elif image := part.as_image():
            display(image)
            # image.show() if not in a notebook

def save_image(response, path):
    """Save the image contained in a model response to ``path``.

    Args:
        response: Object exposing a ``parts`` attribute. Each part is expected
            to provide ``as_image()``.
        path: Destination passed unchanged to the image's ``save`` method.

    Every image part is written to the same ``path``, so when the response
    holds several images only the last one remains on disk. Nothing is
    written when the response carries no parts.
    """
    # ``parts`` may be None when the response has no content; iterate over an
    # empty list in that case instead of raising TypeError.
    for part in response.parts or []:
        if image := part.as_image():
            image.save(path)


In [44]:
from PIL import Image

from PIL import Image
import numpy as np
from sklearn.cluster import KMeans


def get_k_representatives(
    pixels: np.ndarray, k: int = 5, sort: bool = True, random_state=None
) -> np.ndarray:
    """Cluster the given samples with k-means and return the cluster centroids.

    Args:
        pixels: 2-D array with one sample per row. The caller in this notebook
            passes an ``(N, 3)`` array of RGB values.
        k: Number of clusters, i.e. number of representative colors returned.
        sort: When True, order the centroids by ascending Euclidean norm
            (distance to the origin).
        random_state: Forwarded to ``KMeans``. ``None`` (the default) keeps the
            unseeded behavior; pass an int to get the same centroids on every run.

    Returns:
        Array with one centroid per row (``k`` rows), sorted if ``sort`` is True.
    """
    k_means_cluster = KMeans(n_clusters=k, n_init=5, random_state=random_state).fit(pixels)
    centroids = k_means_cluster.cluster_centers_

    # Sort the centroids by their proximity to the origin
    if sort:
        centroids_distances = np.linalg.norm(centroids, axis=1)
        sorted_centroids_indices = np.argsort(centroids_distances)
        return centroids[sorted_centroids_indices]

    return centroids


def image_to_flat_rgb_array(img: Image.Image):
    """Return the image's pixels as a 2-D array with one (r, g, b) row per pixel.

    Images that are not already in "RGB" mode are converted first.
    """
    rgb_img = img if img.mode == "RGB" else img.convert("RGB")

    # np.array(...) yields an H x W x 3 array; collapse the two spatial axes.
    return np.array(rgb_img).reshape(-1, 3)


def get_base_img(height: int, width: int) -> Image.Image:
    """Create a black RGB image of the requested height and width."""
    # PIL expects the size as (width, height).
    size = (width, height)
    return Image.new("RGB", size, color="black")


def build_color_palette_img(dims: tuple, colors: np.ndarray):
    """Render a palette as an image made of equal-width vertical color bands.

    Args:
        dims: ``(height, width)`` of the area the palette should cover.
        colors: Array with one RGB color per row; each channel value is
            truncated to ``int`` before being drawn.

    Returns:
        An RGB image of height ``dims[0]`` and width
        ``n_colors * (dims[1] // n_colors)``. Columns left over by the integer
        division are dropped, so the result can be slightly narrower than
        ``dims[1]``.

    Raises:
        ZeroDivisionError: If ``colors`` has no rows.
    """
    height = dims[0]
    n_colors = colors.shape[0]
    slice_width = dims[1] // n_colors
    base_img = get_base_img(height=height, width=n_colors * slice_width)

    # Degenerate sizes leave nothing to paint; return the empty canvas as-is.
    if slice_width == 0 or height == 0:
        return base_img

    for idx, color in enumerate(colors):
        pixel = tuple(map(int, color))
        left = idx * slice_width
        # Fill the whole band with a single paste call instead of setting
        # every pixel from Python, which is very slow on photo-sized images.
        base_img.paste(pixel, (left, 0, left + slice_width, height))

    return base_img


def generate_color_palette(reference: Image.Image, target: Image.Image, k: int = 8):
    """Build a palette image from the representative colors of ``reference``.

    Args:
        reference: Image whose pixels are clustered to pick the palette colors.
        target: Image used only for its height and width, which determine the
            size of the palette image.
        k: Number of representative colors (palette bands) to extract.
            Defaults to 8, the value that was previously hard-coded.

    Returns:
        The palette image produced by ``build_color_palette_img``.
    """
    flatten_img = image_to_flat_rgb_array(reference)
    k_representative_colors = get_k_representatives(flatten_img, k=k, sort=True)

    return build_color_palette_img(
        (target.height, target.width), colors=k_representative_colors
    )

In [None]:
from pathlib import Path

# Directory holding the images opened by this cell and the cells below.
base_path = Path("images")
# The same file is used both as the color source and as the target whose
# height/width size the palette image.
img = Image.open(base_path / 'matrix.jpg')
target = Image.open(base_path / 'matrix.jpg')

# Print the source image size as (height, width).
print(img.height, img.width)

# Preview the palette extracted from the source image.
display(generate_color_palette(img, target))

In [None]:
# Repeat the palette preview with output.png as the color source;
# matrix.jpg still supplies the height/width used to size the palette image.
img = Image.open(base_path / 'output.png')
target = Image.open(base_path / 'matrix.jpg')

# Print the source image size as (height, width).
print(img.height, img.width)

# Preview the palette extracted from the source image.
display(generate_color_palette(img, target))

In [None]:
# Instruction sent to the model together with the target and palette images.
text_prompt = (
    "Recolor the target image using the provided color palette as reference."
    " Apply the palette consistently across the entire image."
    " Preserve details, textures, and natural shading while maintaining overall harmony and balance."
)

# The palette colors come from `reference`; `target` is the image to recolor
# and also determines the size of the palette image.
reference = Image.open(base_path / 'matrix.jpg')
target = Image.open(base_path / 'target2.jpeg')
palette_img = generate_color_palette(reference, target)

print("color palette extracted...")


response = client.models.generate_content(
    model=MODEL_ID,
    contents=[
        text_prompt,
        # The order matters: target image first, then the palette image
        # (observed empirically; the reason is unknown).
        target,
        palette_img,
    ],
    config=types.GenerateContentConfig(
        response_modalities=['Text', 'Image'],
        top_p=0.7,
    ),
)

display_response(response)