In [None]:
import gradio as gr
import logging
import math
import numpy as np
import ollama
import time
import torch

from diffusers import AutoPipelineForInpainting
from PIL import Image
from segment_anything import SamPredictor, sam_model_registry

In [1]:
from huggingface_hub import hf_hub_download

# Download the Segment Anything Model checkpoint (the lightweight ViT-B variant)
# from the Hugging Face Hub and remember where the file landed on disk.
chkpt_path = hf_hub_download(
    filename="checkpoints/sam_vit_b_01ec64.pth",
    repo_id="ybelkada/segment-anything",
)
print("Checkpoint downloaded to: " + str(chkpt_path))

Checkpoint downloaded to: C:\Users\ADMIN\.cache\huggingface\hub\models--ybelkada--segment-anything\snapshots\7790786db131bcdc639f24a915d9f2c331d843ee\checkpoints\sam_vit_b_01ec64.pth


In [2]:
import os

from huggingface_hub import login

# Authenticate with the Hugging Face Hub.
# The token is never hard-coded: an `hf_token` already present in the notebook
# namespace is honoured first, otherwise the HF_TOKEN environment variable is
# used. Without this lookup the cell raised NameError on a fresh kernel because
# `hf_token` is not defined anywhere in the notebook.
hf_token = globals().get("hf_token") or os.environ.get("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("HF_TOKEN is not set; skipping Hugging Face login.")

In [2]:
# Logging setup: every record goes both to the console and to app.log.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),  # console output
        # Explicit UTF-8: log records contain user prompts (e.g. Cyrillic text),
        # which the platform default encoding may be unable to represent.
        logging.FileHandler("app.log", encoding="utf-8")  # persisted to file
    ]
)
logger = logging.getLogger(__name__)

In [3]:
# Model loading
def load_models():
    """Load the SDXL inpainting pipeline and the SAM point predictor.

    Both models are placed on CUDA when it is available and on the CPU
    otherwise. Half precision is only requested on CUDA; the CPU fallback
    uses float32.

    Returns:
        tuple: ``(inpaint_pipe, sam_predictor)``.
    """
    import os  # local import: only needed for the checkpoint lookup below

    logger.info("Loading models...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    # Stable Diffusion XL Inpainting
    start_time = time.time()
    inpaint_pipe = AutoPipelineForInpainting.from_pretrained(
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        torch_dtype=dtype
    ).to(device)
    logger.info(f"Inpainting model loaded in {time.time() - start_time:.2f} seconds")

    # Segment Anything Model (SAM)
    sam_checkpoint = "./checkpoints/sam_vit_b_01ec64.pth"
    if not os.path.isfile(sam_checkpoint):
        # Fall back to the file fetched by the download cell (if it was run),
        # so the notebook also works without a local ./checkpoints copy.
        sam_checkpoint = globals().get("chkpt_path", sam_checkpoint)
    start_time = time.time()
    sam = sam_model_registry["vit_b"](checkpoint=sam_checkpoint)
    # Run SAM on the same device as the inpainting pipeline instead of
    # leaving it on the CPU.
    sam.to(device=device)
    sam_predictor = SamPredictor(sam)
    logger.info(f"SAM model loaded in {time.time() - start_time:.2f} seconds")

    return inpaint_pipe, sam_predictor

inpaint_pipe, sam_predictor = load_models()

2025-07-09 13:06:45,435 - INFO - Loading models...


Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

The config attributes {'decay': 0.9999, 'inv_gamma': 1.0, 'min_decay': 0.0, 'optimization_step': 37000, 'power': 0.6666666666666666, 'update_after_step': 0, 'use_ema_warmup': False} were passed to UNet2DConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.
2025-07-09 13:06:55,083 - INFO - Inpainting model loaded in 9.65 seconds
2025-07-09 13:06:55,591 - INFO - SAM model loaded in 0.51 seconds


In [9]:
# Compute a processing size that is a multiple of 64
def get_compatible_size(width, height, max_size=2048, multiple=64):
    """Return a square ``(size, size)`` processing resolution.

    The longer input side is rounded up to the next multiple of ``multiple``
    and then capped so that it never exceeds ``max_size``.

    Args:
        width: Source image width in pixels.
        height: Source image height in pixels.
        max_size: Upper bound for the returned side length. The default of
            2048 was chosen by the author as a limit for 12 GB of VRAM.
        multiple: The returned side length is a multiple of this value.

    Returns:
        tuple[int, int]: ``(target_size, target_size)``; the result is always
        square and at least ``multiple``.

    Raises:
        ValueError: If ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError("multiple must be a positive integer")

    # Largest allowed side that is still a multiple of `multiple`.
    cap = max(multiple, (max_size // multiple) * multiple)
    target_size = math.ceil(max(width, height) / multiple) * multiple
    # Clamp into [multiple, cap] so degenerate inputs never yield size 0.
    target_size = max(multiple, min(target_size, cap))
    return target_size, target_size

# Enhance the user prompt with the gemma3n model served by Ollama
def enhance_prompt(user_prompt, temperature=0.7, max_tokens=100):
    """Ask the local Ollama ``gemma3n`` model to expand a short user prompt.

    Args:
        user_prompt: Free-form description typed by the user.
        temperature: Sampling temperature forwarded to Ollama. The default
            matches the initial value of the UI slider.
        max_tokens: Maximum number of tokens to generate. It is forwarded as
            Ollama's ``num_predict`` option; the previously used
            ``max_tokens`` key is not an Ollama option, so the limit had no
            effect. The default matches the initial value of the UI slider.

    Returns:
        str: The model response with surrounding whitespace stripped.

    Raises:
        gr.Error: If the Ollama call fails for any reason.
    """
    logger.info(f"Enhancing prompt: {user_prompt}")
    start_time = time.time()
    
    torch.cuda.empty_cache()
    instruction = (
        f"""You are an expert in generating detailed prompts for image editing, specializing in aircraft liveries. 
        Take the following user prompt and enhance it by adding specific details like texture, style, and aviation context. 
        Make it concise, realistic, and suitable for Stable Diffusion inpainting.
        Also get me negative prompt because it's important for inpainting.
        User prompt: {user_prompt}"""
    )

    try:
        response = ollama.generate(
            model="gemma3n",
            prompt=instruction,
            options={
                "temperature": float(temperature),
                # Sliders may deliver floats; num_predict must be an integer.
                "num_predict": int(max_tokens)
            }
        )
        enhanced_prompt = response["response"].strip()
        
        logger.info(f"Enhanced prompt: {enhanced_prompt}")
        logger.info(f"Prompt enhancement completed in {time.time() - start_time:.2f} seconds")
        return enhanced_prompt
    except Exception as e:
        logger.error(f"Error in Ollama prompt enhancement: {str(e)}")
        # Surface the failure in the Gradio UI while keeping the original cause.
        raise gr.Error(f"Failed to enhance prompt with Ollama: {str(e)}") from e

# Point-prompted segmentation with SAM (supports several points)
def segment_aircraft(image, click_points):
    """Segment the region indicated by one or more foreground points.

    Args:
        image: PIL image (or array-like) to segment. PIL images are converted
            to RGB first so that RGBA or greyscale uploads still produce the
            three-channel array handed to the predictor.
        click_points: Sequence of ``[x, y]`` coordinates in the coordinate
            system of ``image``; every point is treated as a positive prompt.

    Returns:
        PIL.Image.Image: Mask image with 255 inside the selection and 0 outside.

    Raises:
        gr.Error: If ``click_points`` is empty.
    """
    logger.info(f"Starting segmentation with points: {click_points}")
    start_time = time.time()
    
    if not click_points:
        logger.error("No points provided for SAM")
        raise gr.Error("Please click on the image to select at least one point")
    
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image_np = np.array(image)
    sam_predictor.set_image(image_np)
    
    # Prepare coordinates and labels
    point_coords = np.array(click_points)
    point_labels = np.ones(len(click_points))  # all points are positive
    
    masks, _, _ = sam_predictor.predict(
        point_coords=point_coords,
        point_labels=point_labels,
        multimask_output=False
    )
    
    # Collapse the leading mask axis with a logical OR
    combined_mask = np.any(masks, axis=0).astype(np.uint8) * 255
    mask_image = Image.fromarray(combined_mask)
    
    logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
    return mask_image

# Handle a click on the uploaded image
def get_click_coordinates(image, evt: gr.SelectData, click_points):
    """Record the clicked pixel and report the updated point list.

    Args:
        image: The currently displayed PIL image, or ``None``.
        evt: Gradio select event; ``evt.index`` carries the clicked position.
        click_points: Points collected so far (may be ``None`` or empty).

    Returns:
        tuple: ``(image_size, click_points, status_text)``. When no image is
        present the result is ``(None, [], "Please upload an image")``.
    """
    if image is None:
        logger.error("No image provided")
        return None, [], "Please upload an image"

    px = evt.index[0]
    py = evt.index[1]
    img_w, img_h = image.size

    # Start a fresh list when nothing has been selected yet, then add the point
    if not click_points:
        click_points = []
    click_points.append([px, py])

    logger.info(f"Image clicked at coordinates ({px}, {py}) on size ({img_w}, {img_h})")
    logger.info(f"Current click points: {click_points}")

    status_text = f"Selected points: {click_points}"
    return (img_w, img_h), click_points, status_text

# Reset the list of selected points
def clear_click_points():
    """Return an empty point list together with the matching status text."""
    logger.info("Clearing click points")
    cleared_points = []
    status_text = "No points selected"
    return cleared_points, status_text

# Preview the mask in SAM mode
def preview_mask(image, click_points, orig_size):
    """Run SAM on the selected points and return the mask at image size.

    Args:
        image: PIL image currently loaded in the UI.
        click_points: ``[x, y]`` points in the coordinates of ``image``.
        orig_size: Size captured at click time. Accepted for interface
            compatibility only: the stored value can be stale after a new
            image is uploaded, so the size is always read from ``image``.

    Returns:
        PIL.Image.Image: The mask resized back to the size of ``image``.

    Raises:
        gr.Error: If no image is loaded or no point has been selected.
    """
    logger.info(f"Previewing mask with points: {click_points}")
    if image is None:
        logger.error("No image provided")
        raise gr.Error("Please upload an image")
    if not click_points:
        logger.error("No points provided for SAM")
        raise gr.Error("Please click on the image to select at least one point")
    
    orig_width, orig_height = image.size
    proc_width, proc_height = get_compatible_size(orig_width, orig_height)
    proc_image = image.resize((proc_width, proc_height))
    
    # Map the clicked pixels into the coordinate system of the resized image
    scale_x = proc_width / orig_width
    scale_y = proc_height / orig_height
    adjusted_points = [[x * scale_x, y * scale_y] for x, y in click_points]
    logger.info(f"Adjusted coordinates for SAM preview: {adjusted_points}")
    
    mask = segment_aircraft(proc_image, adjusted_points)
    mask = mask.resize((orig_width, orig_height))
    return mask

# Derive the drawn mask from the editor layers
def extract_mask_from_composite(background, composite):
    """Build a binary mask of the pixels where ``composite`` differs from ``background``.

    Args:
        background: Untouched image (PIL image or numpy array).
        composite: Image including the user's strokes (PIL image or numpy array).

    Returns:
        PIL.Image.Image: Mode ``"L"`` mask, 255 where any RGB channel differs.

    Raises:
        gr.Error: If either input is missing, if the two images differ in
            size, or if no pixel differs (nothing was drawn).
    """
    logger.info("Extracting mask from composite and background")
    start_time = time.time()
    
    if background is None or composite is None:
        logger.error("Missing background or composite image for mask extraction")
        raise gr.Error("Please load an image into the mask editor and draw on it.")
    
    # Convert numpy arrays to PIL images when needed
    if isinstance(background, np.ndarray):
        background = Image.fromarray(background)
    if isinstance(composite, np.ndarray):
        composite = Image.fromarray(composite)
    
    bg_array = np.array(background.convert("RGB"))
    comp_array = np.array(composite.convert("RGB"))
    
    # Make sure both layers have the same dimensions
    if bg_array.shape != comp_array.shape:
        logger.error(f"Shape mismatch: background {bg_array.shape}, composite {comp_array.shape}")
        raise gr.Error("Background and composite images have different sizes.")
    
    # A pixel belongs to the mask when any channel changed. Comparing directly
    # avoids subtracting uint8 arrays, which wraps around instead of going negative.
    mask_array = np.any(comp_array != bg_array, axis=2).astype(np.uint8) * 255
    
    # Reject an empty mask
    if mask_array.max() == 0:
        logger.error("Extracted mask is empty (no differences between composite and background)")
        raise gr.Error("The drawn mask is empty. Please draw a white or colored area on the image.")
    
    mask = Image.fromarray(mask_array).convert("L")
    logger.info(f"Mask extraction completed in {time.time() - start_time:.2f} seconds")
    logger.info(f"Extracted mask shape: {mask_array.shape}, max value: {mask_array.max()}")
    return mask

# Main image editing entry point
def edit_aircraft(
        image,
        click_points,
        orig_size,
        manual_mask, mask_mode,
        user_prompt, negative_prompt,
        enhanced_prompt,
        temperature, max_tokens, num_inference_steps, guidance_scale, blur_factor
    ):
    """Inpaint the selected region of ``image`` according to the prompt.

    Args:
        image: PIL image to edit.
        click_points: ``[x, y]`` points used in "Point-based (SAM)" mode.
        orig_size: Size captured at click time. Accepted for interface
            compatibility only: the stored value can be stale after a new
            image is uploaded, so the size is always read from ``image``.
        manual_mask: Editor value with ``"background"`` and ``"composite"``
            entries, used in manual mode.
        mask_mode: ``"Point-based (SAM)"`` selects SAM; any other value uses
            the manually drawn mask.
        user_prompt: Raw prompt typed by the user.
        negative_prompt: Negative prompt; a built-in default is used when empty.
        enhanced_prompt: Enhanced prompt; preferred over ``user_prompt`` when
            non-empty.
        temperature: Only logged here.
        max_tokens: Only logged here.
        num_inference_steps: Number of diffusion steps.
        guidance_scale: Guidance scale passed to the pipeline.
        blur_factor: Blur factor applied to the mask before inpainting.

    Returns:
        tuple: ``(edited_image, display_mask, enhanced_prompt)``; both images
        are resized to the size of ``image``.

    Raises:
        gr.Error: If the image is missing, no point is selected in SAM mode,
            or no mask was drawn in manual mode.
    """
    # Messages are built from adjacent f-strings; backslash continuations
    # inside a literal embed the source indentation into every log record.
    logger.info(
        f"Starting image editing with user prompt: {user_prompt}, "
        f"enhanced prompt: {enhanced_prompt}, mask mode: {mask_mode}"
    )
    # Only log whether a manual mask exists, not the image arrays themselves.
    logger.info(
        f"Click points: {click_points}, "
        f"manual mask provided: {manual_mask is not None}"
    )
    logger.info(
        f"Hyperparameters: temperature={temperature}, "
        f"max_tokens={max_tokens}, num_inference_steps={num_inference_steps}, "
        f"guidance_scale={guidance_scale}, blur_factor={blur_factor}"
    )
    start_time = time.time()

    if image is None:
        logger.error("No image provided")
        # gr.Error (as used by the other handlers) shows the message in the UI.
        raise gr.Error("Please upload an image")

    orig_width, orig_height = image.size
    logger.info(f"Original image size: {orig_width}x{orig_height}")

    # Prefer the enhanced prompt, fall back to the user prompt, then to a default
    final_prompt = enhanced_prompt if enhanced_prompt else user_prompt
    if not final_prompt:
        logger.warning("No prompt provided, using default")
        final_prompt = "realistic aircraft livery, high detail"
    logger.info(f"Using prompt for inpainting: {final_prompt}")

    proc_width, proc_height = get_compatible_size(orig_width, orig_height)
    logger.info(f"Processing image at size: {proc_width}x{proc_height}")

    proc_image = image.resize((proc_width, proc_height))

    mask_start = time.time()
    if mask_mode == "Point-based (SAM)":
        if not click_points:
            logger.error("No points selected for SAM")
            raise gr.Error("Please click on the image to select at least one point")

        # Map the clicked pixels into the coordinate system of the resized image
        scale_x = proc_width / orig_width
        scale_y = proc_height / orig_height
        adjusted_points = [[x * scale_x, y * scale_y] for x, y in click_points]
        logger.info(f"Adjusted coordinates for SAM: {adjusted_points}")

        mask = segment_aircraft(proc_image, adjusted_points)
        display_mask = mask  # mask shown in the UI
    else:  # Manual (Sketch)
        if (
            not manual_mask
            or manual_mask.get("background") is None
            or manual_mask.get("composite") is None
        ):
            logger.error("No manual mask provided")
            raise gr.Error("Please load an image into the mask editor and draw a mask")
        # Derive the mask from the difference between composite and background
        mask = extract_mask_from_composite(manual_mask["background"], manual_mask["composite"])
        display_mask = mask  # mask shown in the UI
        mask = mask.resize((proc_width, proc_height))  # mask used for inpainting
        logger.info("Using manually drawn mask")

    # Blur the mask for smoother transitions at its edges
    mask = inpaint_pipe.mask_processor.blur(mask, blur_factor=blur_factor)
    logger.info(f"Applied blur to mask with blur_factor={blur_factor}")

    logger.info(f"Mask generated in {time.time() - mask_start:.2f} seconds")

    logger.info("Starting inpainting...")
    inpaint_start = time.time()
    torch.cuda.empty_cache()  # release cached GPU memory

    def progress_callback(step, timestep, latents):
        logger.info(f"Inpainting step {step}/{num_inference_steps}")

    # NOTE(review): `callback`/`callback_steps` appear to be the legacy diffusers
    # progress API; confirm against the installed version before upgrading.
    edited_image = inpaint_pipe(
        prompt=final_prompt,  # was enhanced_prompt, which may be empty
        image=proc_image,
        mask_image=mask,
        guidance_scale=guidance_scale,
        num_inference_steps=int(num_inference_steps),
        negative_prompt=negative_prompt if negative_prompt \
            else "blurry, low quality, distorted, unrealistic, cartoonish, low detail",
        callback=progress_callback,
        callback_steps=10  # log progress every 10 steps
    ).images[0]
    logger.info(f"Inpainting completed in {time.time() - inpaint_start:.2f} seconds")

    # Scale the result and the displayed mask back to the original size
    edited_image = edited_image.resize((orig_width, orig_height))
    display_mask = display_mask.resize((orig_width, orig_height))
    logger.info(f"Resized output and mask to original size: {orig_width}x{orig_height}")

    logger.info(f"Total editing time: {time.time() - start_time:.2f} seconds")
    return edited_image, display_mask, enhanced_prompt

In [None]:
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Aircraft Paint Editor")
    gr.Markdown("Upload an aircraft image, choose a mask mode (point-based or manual), \
                click or draw on the image, and describe the new paint or livery.")
    gr.Markdown("Example prompts: 'Glossy black fuselage with gold horizontal stripes, \
                realistic aviation paint' or 'Military green camouflage with gray accents, matte finish'")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Aircraft Image", type="pil", interactive=True)
            mask_mode = gr.Radio(
                choices=["Point-based (SAM)", "Manual (Sketch)"],
                label="Mask Selection Mode",
                value="Manual (Sketch)"
            )
            click_coords = gr.Textbox(label="Selected Point Coordinates (for SAM)", value="No points selected")
            clear_points_button = gr.Button("Clear Points")
            preview_mask_button = gr.Button("Preview Mask", visible=False)
            manual_mask_input = gr.ImageEditor(label="Draw Mask", interactive=True, visible=True)
            user_prompt = gr.Textbox(label="Describe new paint/livery (e.g., 'blue fuselage with red stripes')")
            negative_prompt = gr.Textbox(label="Negative Prompt (e.g., 'blurry, low quality, unrealistic')")
            enhanced_prompt = gr.Textbox(label="Enhanced Prompt (editable)", interactive=True)
            with gr.Group():
                gr.Markdown("### Prompt Enhancement Parameters")
                temperature = gr.Slider(label="Temperature (Gemma creativity)", 
                                        minimum=0.1, maximum=1.5, step=0.1, value=0.7)
                max_tokens = gr.Slider(label="Max Tokens (Gemma prompt length)", 
                                       minimum=50, maximum=500, step=10, value=100)
            with gr.Group():
                gr.Markdown("### Inpainting Parameters")
                num_inference_steps = gr.Slider(label="Inference Steps (Diffusion quality)", 
                                                minimum=20, maximum=100, step=1, value=30)
                guidance_scale = gr.Slider(label="Guidance Scale (Prompt adherence)", 
                                           minimum=5, maximum=20, step=1, value=10)
                blur_factor = gr.Slider(label="Mask Blur Factor (Edge smoothness)", 
                                        minimum=0, maximum=100, step=1, value=50)
            enhance_button = gr.Button("Enhance Prompt")
            edit_button = gr.Button("Edit")
        with gr.Column():
            mask_output = gr.Image(label="Generated Mask")
            edited_output = gr.Image(label="Edited Aircraft")
    
    # Per-session state
    click_points_state = gr.State(value=[])  # list of clicked coordinates
    orig_size_state = gr.State()
    
    # Show the sketch editor in manual mode and the preview button in SAM mode
    def update_mask_input_visibility(mode):
        """Return visibility updates for the mask editor and the preview button."""
        is_manual = mode == "Manual (Sketch)"
        return gr.update(visible=is_manual), gr.update(visible=not is_manual)
    
    mask_mode.change(
        fn=update_mask_input_visibility,
        inputs=mask_mode,
        outputs=[manual_mask_input, preview_mask_button]
    )
    
    # Record clicks on the image
    image_input.select(
        fn=get_click_coordinates,
        inputs=[image_input, click_points_state],
        outputs=[orig_size_state, click_points_state, click_coords]
    )
    
    # Clear the selected points
    clear_points_button.click(
        fn=clear_click_points,
        inputs=None,
        outputs=[click_points_state, click_coords]
    )
    
    # Mask preview
    preview_mask_button.click(
        fn=preview_mask,
        inputs=[image_input, click_points_state, orig_size_state],
        outputs=mask_output
    )

    # Prompt enhancement: pass both Gemma sliders along with the prompt so the
    # handler receives every argument it declares and the sliders take effect.
    enhance_button.click(
        fn=enhance_prompt,
        inputs=[user_prompt, temperature, max_tokens],
        outputs=enhanced_prompt
    )
    
    # Run the edit
    edit_button.click(
        fn=edit_aircraft,
        inputs=[
            image_input, click_points_state, orig_size_state, 
            manual_mask_input, mask_mode, 
            user_prompt, negative_prompt, enhanced_prompt,
            temperature, max_tokens,
            num_inference_steps, guidance_scale, blur_factor
        ],
        outputs=[edited_output, mask_output, enhanced_prompt]
    )

# Launch the interface
logger.info("Launching Gradio interface...")
demo.launch()

2025-07-09 14:21:51,097 - INFO - Launching Gradio interface...
2025-07-09 14:21:51,325 - INFO - HTTP Request: GET http://127.0.0.1:7865/gradio_api/startup-events "HTTP/1.1 200 OK"
2025-07-09 14:21:51,357 - INFO - HTTP Request: HEAD http://127.0.0.1:7865/ "HTTP/1.1 200 OK"


* Running on local URL:  http://127.0.0.1:7865
* To create a public link, set `share=True` in `launch()`.




2025-07-09 14:21:52,252 - INFO - HTTP Request: GET https://api.gradio.app/pkg-version "HTTP/1.1 200 OK"
2025-07-09 14:23:09,941 - INFO - Image clicked at coordinates (381, 279) on size (928, 448)
2025-07-09 14:23:09,943 - INFO - Current click points: [[381, 279]]
2025-07-09 14:23:12,431 - INFO - Previewing mask with points: [[381, 279]]
2025-07-09 14:23:12,447 - INFO - Adjusted coordinates for SAM preview: [[394.1379310344828, 597.8571428571429]]
2025-07-09 14:23:12,450 - INFO - Starting segmentation with points: [[394.1379310344828, 597.8571428571429]]
2025-07-09 14:23:23,416 - INFO - Segmentation completed in 10.96 seconds
2025-07-09 14:23:29,790 - INFO - Image clicked at coordinates (233, 176) on size (928, 448)
2025-07-09 14:23:29,791 - INFO - Current click points: [[381, 279], [233, 176]]
2025-07-09 14:24:02,617 - INFO - Enhancing prompt: Раскрась мне фюзеляж вертолёта
2025-07-09 14:25:08,320 - INFO - HTTP Request: POST http://127.0.0.1:11434/api/generate "HTTP/1.1 200 OK"
2025-07

  0%|          | 0/29 [00:00<?, ?it/s]

2025-07-09 14:26:29,266 - INFO - Inpainting step 0/30
2025-07-09 14:26:36,278 - INFO - Inpainting step 10/30
2025-07-09 14:26:43,306 - INFO - Inpainting step 20/30
2025-07-09 14:26:50,784 - INFO - Inpainting completed in 23.70 seconds
2025-07-09 14:26:50,812 - INFO - Resized output and mask to original size: 928x448
2025-07-09 14:26:50,813 - INFO - Total editing time: 34.76 seconds
2025-07-09 14:42:30,434 - INFO - Starting image editing with user prompt:                 Раскрась мне фюзеляж вертолёта, enhanced prompt:                 Detailed helicopter fuselage painting, modern military transport helicopter, camouflage scheme (digital camouflage pattern, woodland variant), highly realistic, weathered texture, subtle wear and tear, metallic paint sheen, intricate panel lines, sharp focus, dramatic lighting, volumetric light, photorealistic, 8k, octane render, detailed rivets,  high-resolution,  aircraft livery,  detailed engine intakes,  visible landing gear attachment points, mask mod

  0%|          | 0/29 [00:00<?, ?it/s]

2025-07-09 14:42:43,698 - INFO - Inpainting step 0/30
2025-07-09 14:42:50,951 - INFO - Inpainting step 10/30
2025-07-09 14:42:58,183 - INFO - Inpainting step 20/30
2025-07-09 14:43:05,887 - INFO - Inpainting completed in 24.38 seconds
2025-07-09 14:43:05,915 - INFO - Resized output and mask to original size: 928x448
2025-07-09 14:43:05,915 - INFO - Total editing time: 35.48 seconds
