In [None]:
!pip install -q torch torchvision torchaudio transformers

In [1]:
!pip install torchcodec

Collecting torchcodec
  Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (11 kB)
Downloading torchcodec-0.9.1-cp312-cp312-manylinux_2_28_x86_64.whl (2.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m35.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchcodec
Successfully installed torchcodec-0.9.1


In [None]:
!pip install google-genai



In [2]:
import torch
import numpy as np
from PIL import Image
from typing import Dict, List, Tuple, Optional
import warnings

import tempfile
import cv2
from transformers import CLIPModel, CLIPProcessor
import json
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from PIL import Image
import cv2
import torch
from transformers import CLIPModel, CLIPProcessor
from transformers import AutoProcessor, LlavaForConditionalGeneration
import librosa

import io
import os
import time
import json
from pathlib import Path
from typing import Tuple
from PIL import Image
from tqdm import tqdm
import pandas as pd

from datasets import load_dataset
import torchcodec

warnings.filterwarnings('ignore')

In [2]:
import io
import json
import time
from typing import Tuple

import pandas as pd
from PIL import Image
from tqdm import tqdm
import soundfile as sf

from datasets import load_dataset
from google import genai
from google.genai import types


EVAL_PROMPT = """
You are a strict but nuanced evaluator of an image editing system.

INPUT:
1. AUDIO: Spoken instruction in Russian describing an edit
2. IMAGE BEFORE: Original image
3. IMAGE AFTER: Edited image

TASK:
Evaluate how well the edit matches the instruction.

IMPORTANT RULES:
- Do NOT treat this as a binary yes/no task
- Most real edits are imperfect
- Scores of exactly 0.0 or 1.0 should be VERY RARE
- Prefer values like 0.55, 0.68, 0.82, etc.

STEP 1: Understand the instruction
- Transcribe the audio in Russian
- Describe what change was requested

STEP 2: Compare BEFORE vs AFTER
- What changes actually happened?
- What parts of the image stayed the same?
- Are there artifacts or quality issues?

STEP 3: Score multiple aspects independently

Use the following scale for EACH sub-score:
- 0.0–0.2: Not present or completely wrong
- 0.3–0.4: Weak or incorrect
- 0.5–0.6: Partially correct
- 0.7–0.8: Mostly correct with issues
- 0.9–1.0: Almost perfect (rare)

SUB-SCORES:
- instruction_match: How well the edit matches the requested change
- visual_correctness: Is the change visually correct and localized
- preservation: Were unrelated regions preserved
- artifact_free: Lack of artifacts, realism, natural look

STEP 4: Compute final score
- alignment_score = average of the four sub-scores
- Round to 2 decimal places

RESPOND ONLY IN JSON:
{
  "transcription": "",
  "instruction_understood": "",
  "instruction_execution": "not_attempted | partial | complete",
  "scores": {
    "instruction_match": 0.0-1.0,
    "visual_correctness": 0.0-1.0,
    "preservation": 0.0-1.0,
    "artifact_free": 0.0-1.0
  },
  "alignment_score": 0.0-1.0,
  "changes_detected": "",
  "quality_assessment": "poor | acceptable | good",
  "issues": [],
  "explanation": ""
}
"""


class GeminiEvaluator:
    def __init__(
        self,
        api_key: str,
        model: str = "gemini-2.5-flash",
        delay_between_requests: float = 10.0,
    ):
        self.client = genai.Client(api_key=api_key)
        self.model = model
        self.delay = delay_between_requests

    @staticmethod
    def _image_to_bytes(img: Image.Image) -> bytes:
        if img.mode != "RGB":
            img = img.convert("RGB")
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=85)
        return buf.getvalue()

    def evaluate(
        self,
        audio_bytes: bytes,
        image_before: Image.Image,
        image_after: Image.Image,
        audio_mime: str,
    ) -> Tuple[float, dict]:

        parts = [
            types.Part.from_bytes(
                data=audio_bytes,
                mime_type=audio_mime,
            ),
            types.Part.from_bytes(
                data=self._image_to_bytes(image_before),
                mime_type="image/jpeg",
            ),
            types.Part.from_bytes(
                data=self._image_to_bytes(image_after),
                mime_type="image/jpeg",
            ),
            EVAL_PROMPT,
        ]

        response = self.client.models.generate_content(
            model=self.model,
            contents=parts,
        )

        text = (
            response.text.strip()
            .removeprefix("```json")
            .removesuffix("```")
            .strip()
        )

        result = json.loads(text)
        score = float(result.get("alignment_score", 0.5))

        time.sleep(self.delay)
        return score, result




def extract_audio(audio) -> Tuple[bytes, str]:
    """
    Универсально извлекает аудио из HF Audio / torchcodec AudioDecoder / bytes
    """
    if isinstance(audio, bytes):
        return audio, "audio/mpeg"

    if isinstance(audio, dict):
        if audio.get("bytes"):
            path = audio.get("path", "")
            mime = "audio/mpeg" if path.endswith(".mp3") else "audio/wav"
            return audio["bytes"], mime

        if "array" in audio:
            buf = io.BytesIO()
            sf.write(
                buf,
                audio["array"],
                audio.get("sampling_rate", 16000),
                format="WAV",
            )
            return buf.getvalue(), "audio/wav"

        if audio.get("path"):
            with open(audio["path"], "rb") as f:
                return f.read(), "audio/wav"

    if hasattr(audio, "get_all_samples"):
        samples = audio.get_all_samples()
        buf = io.BytesIO()
        sf.write(buf, samples.data.numpy().flatten(), samples.sample_rate, format="WAV")
        return buf.getvalue(), "audio/wav"

    raise ValueError(f"Unknown audio format: {type(audio)}")



def to_pil(img) -> Image.Image:
    if isinstance(img, Image.Image):
        return img
    if isinstance(img, dict) and "bytes" in img:
        return Image.open(io.BytesIO(img["bytes"]))
    if isinstance(img, bytes):
        return Image.open(io.BytesIO(img))
    raise ValueError(f"Unknown image format: {type(img)}")


def process_sample(
    evaluator: GeminiEvaluator,
    image_id: str,
    source_item: dict,
    result_item: dict,
) -> dict:

    image_before = to_pil(
        source_item.get("INPUT_IMG")
        or source_item.get("input_img")
        or source_item.get("image")
    )

    image_after = to_pil(
        result_item.get("result_image")
        or result_item.get("output_image")
        or result_item.get("image")
    )

    audio_bytes, audio_mime = extract_audio(source_item["audio"])

    score, details = evaluator.evaluate(
        audio_bytes,
        image_before,
        image_after,
        audio_mime,
    )

    return {
        "IMAGE_ID": image_id,
        "score": score,
        **details,
    }


def evaluate_subset(
    source_dataset: str,
    result_dataset: str,
    gemini_api_key: str,
    num_samples: int = 10,
    save_path: str = "evaluation_results.csv",
):

    print("Streaming output dataset...")
    result_stream = load_dataset(
        result_dataset,
        split="train",
        streaming=True,
    )

    target_ids = []
    result_subset = []

    for item in result_stream:
        if "IMAGE_ID" in item:
            target_ids.append(item["IMAGE_ID"])
            result_subset.append(item)
        if len(target_ids) >= num_samples:
            break

    print(f"Collected {len(target_ids)} IMAGE_IDs")

    print("Streaming source dataset...")
    source_stream = load_dataset(
        source_dataset,
        split="train",
        streaming=True,
    )

    source_by_id = {}

    for item in source_stream:
        image_id = item.get("IMAGE_ID")
        if image_id in target_ids:
            source_by_id[image_id] = item
        if len(source_by_id) >= len(target_ids):
            break

    evaluator = GeminiEvaluator(api_key=gemini_api_key)

    rows = []

    for item in tqdm(result_subset, desc="Evaluating"):
        image_id = item["IMAGE_ID"]

        try:
            row = process_sample(
                evaluator,
                image_id,
                source_by_id[image_id],
                item,
            )
            rows.append(row)

        except Exception as e:
            rows.append({
                "IMAGE_ID": image_id,
                "score": 0.0,
                "error": str(e),
            })

    df = pd.DataFrame(rows)
    df.to_csv(save_path, index=False)
    print(f"Saved results to {save_path}")

    return df


if __name__ == "__main__":

    GEMINI_API_KEY = ""

    df = evaluate_subset(
        source_dataset="arood0/mmm_project_with_audio_ru_final",
        result_dataset="gab1k/mmm_project_gigaam",
        gemini_api_key=GEMINI_API_KEY,
        num_samples=10,
        save_path="mmm_project_gigaam.csv",
    )

    print(df[["IMAGE_ID", "score"]])


Streaming output dataset...
Collected 10 IMAGE_IDs
Streaming source dataset...


Evaluating: 100%|██████████| 10/10 [02:52<00:00, 17.25s/it]

Saved results to mmm_project_gigaam.csv
       IMAGE_ID  score
0   -bH_SxERgTA   0.93
1  000000072902   0.10
2  000000074945   0.88
3  000000063958   0.50
4  000000064314   0.41
5  000000065170   0.05
6  000000065057   0.50
7  000000065209   0.23
8  000000067663   0.00
9  000000067854   0.00





In [3]:
if __name__ == "__main__":

    GEMINI_API_KEY = ""

    df = evaluate_subset(
        source_dataset="arood0/mmm_project_with_audio_ru_final",
        result_dataset="gab1k/mmm_project_parakeet",
        gemini_api_key=GEMINI_API_KEY,
        num_samples=10,
        save_path="mmm_project_parakeet.csv",
    )

    print(df[["IMAGE_ID", "score"]])

Streaming output dataset...


README.md:   0%|          | 0.00/324 [00:00<?, ?B/s]

Collected 10 IMAGE_IDs
Streaming source dataset...


Evaluating: 100%|██████████| 10/10 [03:01<00:00, 18.10s/it]

Saved results to mmm_project_parakeet.csv
       IMAGE_ID  score
0   -bH_SxERgTA   0.00
1  000000072902   0.50
2  000000074945   0.50
3  000000063958   0.50
4  000000064314   0.00
5  000000065170   0.00
6  000000065057   0.28
7  000000065209   0.64
8  000000067663   0.50
9  000000067854   0.20





In [4]:
if __name__ == "__main__":

    GEMINI_API_KEY = ""

    df = evaluate_subset(
        source_dataset="arood0/mmm_project_with_audio_ru_final",
        result_dataset="gab1k/mmm_project_tone",
        gemini_api_key=GEMINI_API_KEY,
        num_samples=10,
        save_path="mmm_project_tone.csv",
    )

    print(df[["IMAGE_ID", "score"]])

Streaming output dataset...


README.md:   0%|          | 0.00/324 [00:00<?, ?B/s]

Collected 10 IMAGE_IDs
Streaming source dataset...


Evaluating: 100%|██████████| 10/10 [00:27<00:00,  2.74s/it]

Saved results to mmm_project_tone.csv
       IMAGE_ID  score
0   -bH_SxERgTA   0.00
1  000000072902   0.03
2  000000074945   0.00
3  000000063958   0.00
4  000000064314   0.00
5  000000065170   0.00
6  000000065057   0.00
7  000000065209   0.00
8  000000067663   0.00
9  000000067854   0.00



