In [7]:
!pip install -q lpips scikit-image opencv-python Pillow
!pip install -q transformers torch torchvision torchcodec

In [8]:
!pip install torchcodec



In [13]:
import gc
import io
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import lpips
import numpy as np
import pandas as pd
import torch
import torchcodec
from datasets import load_dataset
from PIL import Image
from tqdm.auto import tqdm

warnings.filterwarnings('ignore')

# Reuse the classes created in the reference_evaluator.ipynb notebook
class LightweightLPIPS:
    """Wrapper around ``lpips.LPIPS`` that reports a similarity instead of a distance.

    The model is put in eval mode with gradients disabled. :meth:`compute`
    returns ``1 - distance`` clamped to ``[0, 1]``, so higher means more alike.
    """

    # Longest image side (in pixels) passed to the network; larger images are downscaled.
    MAX_SIDE = 512

    def __init__(self, net='alex', use_gpu=None):
        """Load the LPIPS network.

        Args:
            net: Backbone name forwarded to ``lpips.LPIPS``.
            use_gpu: ``True``/``False`` to force the device; ``None`` picks CUDA
                when ``torch.cuda.is_available()``.
        """
        if use_gpu is None:
            use_gpu = torch.cuda.is_available()

        self.device = 'cuda' if use_gpu else 'cpu'

        print(f"Loading LPIPS ({net})...")

        self.model = lpips.LPIPS(net=net, verbose=False)
        self.model = self.model.to(self.device)
        self.model.eval()

        # Inference only: make sure no gradients are tracked for the weights.
        for param in self.model.parameters():
            param.requires_grad = False

        if self.device == 'cuda':
            allocated = torch.cuda.memory_allocated() / 1024**2
            print(f"   GPU Memory: {allocated:.1f} MB")

    def compute(self, img1: Image.Image, img2: Image.Image) -> float:
        """Return the LPIPS similarity (``1 - distance``) of two images, in ``[0, 1]``.

        ``img1`` is resampled to the spatial size of ``img2`` when the two
        preprocessed tensors differ in size.
        """
        tensor1 = self._preprocess(img1)
        tensor2 = self._preprocess(img2)

        # Without this, images of different sizes would make the model call fail.
        if tensor1.shape[2:] != tensor2.shape[2:]:
            tensor1 = torch.nn.functional.interpolate(
                tensor1, size=tensor2.shape[2:], mode='bilinear', align_corners=False
            )

        with torch.no_grad():
            distance = self.model(tensor1, tensor2)

        similarity = 1.0 - distance.item()

        if self.device == 'cuda':
            torch.cuda.empty_cache()

        return max(0.0, min(1.0, similarity))

    def _preprocess(self, img: Image.Image) -> torch.Tensor:
        """Convert a PIL image to a ``(1, 3, H, W)`` float tensor scaled to ``[-1, 1]``."""
        # Grayscale / palette / RGBA inputs would otherwise yield a 2-D or
        # 4-channel array and break the permute below or the model itself.
        if img.mode != 'RGB':
            img = img.convert('RGB')

        if max(img.size) > self.MAX_SIDE:
            ratio = self.MAX_SIDE / max(img.size)
            # max(1, ...) keeps very elongated images from collapsing to a 0-pixel side.
            new_size = tuple(max(1, int(dim * ratio)) for dim in img.size)
            img = img.resize(new_size, Image.LANCZOS)

        img_array = np.array(img).astype(np.float32) / 255.0
        tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)

        # [0, 1] -> [-1, 1]
        tensor = tensor * 2 - 1

        return tensor.to(self.device)

    def __del__(self):
        """Best-effort release of cached CUDA memory when the wrapper is destroyed."""
        try:
            if hasattr(self, 'device') and self.device == 'cuda':
                torch.cuda.empty_cache()
        except Exception:
            # During interpreter shutdown module globals may already be torn down.
            pass


class ReferenceBasedEvaluator:
    """Scores a generated image against its ground-truth image.

    Metrics: LPIPS similarity, SSIM, PSNR, optional CLIP image-embedding cosine
    similarity and colour-histogram correlation. :meth:`evaluate_single` runs
    them and blends a weighted ``combined_score``.

    Every ``compute_*`` method is best-effort: on failure it prints a message
    and returns ``0.0`` instead of raising.
    """

    def __init__(
        self,
        use_lpips: bool = True,
        use_ssim: bool = True,
        use_clip: bool = False,  # Off by default to save memory
        lpips_net: str = 'alex',
        device: Optional[str] = None
    ):
        """Load the requested models.

        Args:
            use_lpips: Load the LPIPS model (via ``LightweightLPIPS``).
            use_ssim: Enable SSIM in :meth:`compute_ssim`.
            use_clip: Load CLIP; it is only loaded when the device is ``'cuda'``.
            lpips_net: Backbone name passed to the LPIPS wrapper.
            device: ``'cuda'`` / ``'cpu'``; ``None`` auto-detects CUDA.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        print("Initializing Reference-Based Evaluator")
        print(f"Device: {self.device}")
        print(f"LPIPS: {use_lpips}")
        print(f"SSIM: {use_ssim}")
        print(f"CLIP: {use_clip}")
        print()

        self.lpips_model = None
        self.clip_model = None
        self.clip_processor = None
        self.use_ssim = use_ssim

        if use_lpips:
            try:
                self.lpips_model = LightweightLPIPS(
                    net=lpips_net,
                    use_gpu=(self.device == 'cuda')
                )
            except Exception as e:
                print(f" LPIPS loading failed: {e}")

        if use_clip and self.device == 'cuda':
            try:
                from transformers import CLIPModel, CLIPProcessor

                self.clip_model = CLIPModel.from_pretrained(
                    "openai/clip-vit-base-patch32"
                ).to(self.device)
                self.clip_processor = CLIPProcessor.from_pretrained(
                    "openai/clip-vit-base-patch32"
                )
                self.clip_model.eval()

            except Exception as e:
                print(f"CLIP loading failed: {e}")
                self.clip_model = None
        elif use_clip:
            # Previously this case was ignored without any message.
            print(f"CLIP requested but not loaded: device is '{self.device}', CLIP is only loaded on 'cuda'")

        print(" Evaluator initialized")

    def compute_lpips(
        self,
        generated: Image.Image,
        ground_truth: Image.Image
    ) -> float:
        """Return LPIPS similarity in ``[0, 1]``; ``0.0`` if LPIPS is unavailable or fails."""
        if self.lpips_model is None:
            return 0.0

        try:
            # Normalise mode and size up front so grayscale/RGBA inputs and
            # differently sized pairs are scored instead of failing to 0.0.
            generated = generated.convert('RGB')
            ground_truth = ground_truth.convert('RGB')
            if generated.size != ground_truth.size:
                generated = generated.resize(ground_truth.size, Image.LANCZOS)

            return self.lpips_model.compute(generated, ground_truth)
        except Exception as e:
            print(f" LPIPS computation failed: {e}")
            return 0.0

    def compute_ssim(
        self,
        generated: Image.Image,
        ground_truth: Image.Image
    ) -> float:
        """Return SSIM over RGB channels; ``0.0`` if SSIM is disabled or fails."""
        if not self.use_ssim:
            return 0.0

        try:
            from skimage.metrics import structural_similarity as ssim

            gen_array = np.array(generated.convert('RGB'))
            gt_array = np.array(ground_truth.convert('RGB'))

            if gen_array.shape != gt_array.shape:
                from skimage.transform import resize
                # uint8 input is rescaled to [0, 1] by resize, hence the * 255.
                gen_array = resize(gen_array, gt_array.shape, anti_aliasing=True)
                # Round rather than truncate to avoid a systematic darkening bias.
                gen_array = np.clip(np.rint(gen_array * 255), 0, 255).astype(np.uint8)

            score = ssim(
                gen_array,
                gt_array,
                channel_axis=2,
                data_range=255
            )

            return float(score)
        except Exception as e:
            print(f"SSIM computation failed: {e}")
            return 0.0

    def compute_psnr(
        self,
        generated: Image.Image,
        ground_truth: Image.Image
    ) -> float:
        """Return PSNR in dB for 8-bit RGB; ``100.0`` for identical images, ``0.0`` on failure."""
        try:
            gen_array = np.array(generated.convert('RGB')).astype(np.float64)
            gt_array = np.array(ground_truth.convert('RGB')).astype(np.float64)

            if gen_array.shape != gt_array.shape:
                from skimage.transform import resize
                # The array is already float in [0, 255]; preserve_range keeps it
                # there. The old "* 255" after resize inflated values ~255x and
                # produced a meaningless PSNR for mismatched sizes.
                gen_array = resize(
                    gen_array, gt_array.shape,
                    anti_aliasing=True, preserve_range=True
                )

            mse = np.mean((gen_array - gt_array) ** 2)

            if mse == 0:
                return 100.0

            max_pixel = 255.0
            psnr = 20 * np.log10(max_pixel / np.sqrt(mse))

            return float(psnr)

        except Exception as e:
            print(f" PSNR computation failed: {e}")
            return 0.0

    def compute_clip_similarity(
        self,
        generated: Image.Image,
        ground_truth: Image.Image
    ) -> float:
        """Return cosine similarity of CLIP image embeddings; ``0.0`` if CLIP is unavailable or fails."""
        if self.clip_model is None:
            return 0.0

        try:
            inputs = self.clip_processor(
                images=[generated, ground_truth],
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                image_features = self.clip_model.get_image_features(**inputs)

                # Normalize
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

                # Cosine similarity
                similarity = torch.cosine_similarity(
                    image_features[0:1],
                    image_features[1:2]
                ).item()

            # Cleanup
            if self.device == 'cuda':
                torch.cuda.empty_cache()

            return float(similarity)

        except Exception as e:
            print(f" CLIP computation failed: {e}")
            return 0.0

    def compute_color_similarity(
        self,
        generated: Image.Image,
        ground_truth: Image.Image
    ) -> float:
        """Return the correlation of 8x8x8 RGB histograms of the two images; ``0.0`` on failure."""
        try:
            import cv2

            gen_array = np.array(generated.convert('RGB'))
            gt_array = np.array(ground_truth.convert('RGB'))

            if gen_array.shape != gt_array.shape:
                # cv2.resize takes (width, height)
                gen_array = cv2.resize(gen_array, (gt_array.shape[1], gt_array.shape[0]))

            hist_gen = cv2.calcHist(
                [gen_array], [0, 1, 2], None,
                [8, 8, 8], [0, 256, 0, 256, 0, 256]
            )
            hist_gt = cv2.calcHist(
                [gt_array], [0, 1, 2], None,
                [8, 8, 8], [0, 256, 0, 256, 0, 256]
            )

            hist_gen = cv2.normalize(hist_gen, hist_gen).flatten()
            hist_gt = cv2.normalize(hist_gt, hist_gt).flatten()

            correlation = cv2.compareHist(
                hist_gen.reshape(-1, 1),
                hist_gt.reshape(-1, 1),
                cv2.HISTCMP_CORREL
            )

            return float(correlation)

        except Exception as e:
            print(f"Color similarity computation failed: {e}")
            return 0.0

    def evaluate_single(
        self,
        generated: Image.Image,
        ground_truth: Image.Image,
        weights: Optional[Dict[str, float]] = None,
        verbose: bool = False
    ) -> Dict[str, float]:
        """Compute all metrics for one image pair plus a weighted ``combined_score``.

        Args:
            generated: Image produced by the model.
            ground_truth: Reference image.
            weights: Mapping of metric name (``'lpips'``, ``'ssim'``, ``'psnr'``,
                ``'clip'``, ``'color'``) to weight. Defaults to
                lpips 0.5 / ssim 0.3 / clip 0.0 / color 0.2.
            verbose: Print a short progress marker.

        Returns:
            Dict with ``lpips_similarity``, ``ssim``, ``psnr``,
            ``color_similarity``, ``clip_similarity`` (only when CLIP is loaded)
            and ``combined_score``. The combined score is the weighted mean over
            the metrics that are named in ``weights`` and actually available.
        """
        if verbose:
            print("Evaluating...", end=" ")

        results = {}

        if weights is None:
            weights = {
                'lpips': 0.50,
                'ssim': 0.30,
                'clip': 0.0,
                'color': 0.20
            }

        results['lpips_similarity'] = self.compute_lpips(generated, ground_truth)
        results['ssim'] = self.compute_ssim(generated, ground_truth)
        results['psnr'] = self.compute_psnr(generated, ground_truth)

        if self.clip_model is not None:
            results['clip_similarity'] = self.compute_clip_similarity(
                generated, ground_truth
            )

        results['color_similarity'] = self.compute_color_similarity(
            generated, ground_truth
        )

        # 'ssim' and 'psnr' are stored under their own names; every other metric
        # uses the '<name>_similarity' key. ('psnr' used to be looked up as
        # 'psnr_similarity', which never exists, so its weight was dropped.)
        plain_keys = {'ssim', 'psnr'}

        # Metrics that were switched off or failed to load report a placeholder
        # 0.0; leave them out of the average instead of letting them lower it.
        unavailable = set()
        if self.lpips_model is None:
            unavailable.add('lpips')
        if not self.use_ssim:
            unavailable.add('ssim')

        combined = 0.0
        total_weight = 0.0

        for metric, weight in weights.items():
            if metric in unavailable:
                continue

            metric_key = metric if metric in plain_keys else f'{metric}_similarity'
            if metric_key not in results:
                continue

            value = results[metric_key]

            if metric == 'psnr':
                # Map 20 dB -> 0 and 40 dB -> 1 so PSNR is on the same scale as the others.
                value = min(1.0, max(0.0, (value - 20) / 20))

            combined += value * weight
            total_weight += weight

        results['combined_score'] = combined / total_weight if total_weight > 0 else 0.0
        return results


In [16]:
# Add a class for running the evaluation over a whole dataset
class LightweightReferenceEvaluator:
    """Reference-based image scorer used for dataset-level evaluation.

    Computes LPIPS similarity, SSIM, PSNR and normalised MSE for a
    (generated, ground-truth) pair. The generated image is resized to the
    ground-truth shape when they differ. Each ``compute_*`` method is
    best-effort: it prints the error and returns a fallback value instead of
    raising.
    """

    def __init__(
        self,
        use_lpips: bool = True,
        use_ssim: bool = True,
        use_psnr: bool = True,
        lpips_net: str = 'alex',
        device: Optional[str] = None
    ):
        """Configure enabled metrics and load LPIPS if requested.

        Args:
            use_lpips: Load the LPIPS network.
            use_ssim: Enable :meth:`compute_ssim`.
            use_psnr: Enable :meth:`compute_psnr`.
            lpips_net: Backbone name passed to ``lpips.LPIPS``.
            device: ``'cuda'`` / ``'cpu'``; ``None`` auto-detects CUDA.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device

        self.use_ssim = use_ssim
        self.use_psnr = use_psnr
        self.lpips_model = None

        if use_lpips:
            try:
                self.lpips_model = lpips.LPIPS(net=lpips_net, verbose=False)
                self.lpips_model = self.lpips_model.to(self.device)
                self.lpips_model.eval()

                # Inference only: no gradients needed.
                for param in self.lpips_model.parameters():
                    param.requires_grad = False

                if self.device == 'cuda':
                    mem = torch.cuda.memory_allocated() / 1024**2
                    print(f"  GPU Memory: {mem:.1f} MB")
            except Exception as e:
                print(f"LPIPS loading failed: {e}")
                self.lpips_model = None

    def _to_tensor(self, img: Image.Image, max_size: int = 512) -> torch.Tensor:
        """Convert a PIL image to a ``(1, 3, H, W)`` float tensor in ``[-1, 1]``.

        Images whose longest side exceeds ``max_size`` are downscaled first.
        """
        if img.mode != 'RGB':
            img = img.convert('RGB')

        if max(img.size) > max_size:
            ratio = max_size / max(img.size)
            # max(1, ...) keeps very elongated images from collapsing to a 0-pixel side.
            new_size = tuple(max(1, int(dim * ratio)) for dim in img.size)
            img = img.resize(new_size, Image.LANCZOS)

        img_array = np.array(img).astype(np.float32) / 255.0
        tensor = torch.from_numpy(img_array).permute(2, 0, 1).unsqueeze(0)
        tensor = tensor * 2 - 1  # Normalize to [-1, 1]

        return tensor.to(self.device)

    @staticmethod
    def _aligned_arrays(
        generated: Image.Image,
        ground_truth: Image.Image
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Return both images as float64 RGB arrays in [0, 255] with identical shape.

        The generated image is resampled to the ground-truth shape when needed.
        """
        gen_array = np.array(generated.convert('RGB')).astype(np.float64)
        gt_array = np.array(ground_truth.convert('RGB')).astype(np.float64)

        if gen_array.shape != gt_array.shape:
            # Imported lazily so same-size pairs do not need scikit-image here.
            from skimage.transform import resize
            gen_array = resize(
                gen_array, gt_array.shape,
                anti_aliasing=True, preserve_range=True
            )

        return gen_array, gt_array

    def compute_lpips(self, generated: Image.Image, ground_truth: Image.Image) -> float:
        """Return LPIPS similarity (``1 - distance``) clamped to ``[0, 1]``; ``0.0`` if unavailable or on error."""
        if self.lpips_model is None:
            return 0.0

        try:
            tensor1 = self._to_tensor(generated)
            tensor2 = self._to_tensor(ground_truth)

            # Bring the generated tensor to the ground-truth spatial size before comparing.
            if tensor1.shape != tensor2.shape:
                tensor1 = torch.nn.functional.interpolate(
                    tensor1, size=tensor2.shape[2:], mode='bilinear', align_corners=False
                )

            with torch.no_grad():
                distance = self.lpips_model(tensor1, tensor2)

            similarity = 1.0 - distance.item()

            if self.device == 'cuda':
                torch.cuda.empty_cache()

            return max(0.0, min(1.0, similarity))

        except Exception as e:
            print(f" LPIPS error: {e}")
            return 0.0

    def compute_ssim(self, generated: Image.Image, ground_truth: Image.Image) -> float:
        """Return SSIM over RGB channels; ``0.0`` if SSIM is disabled or on error."""
        if not self.use_ssim:
            return 0.0

        try:
            from skimage.metrics import structural_similarity as ssim

            gen_array, gt_array = self._aligned_arrays(generated, ground_truth)

            # Round (not truncate) resized values back to 8-bit; for same-size
            # inputs the values are already whole numbers, so this is lossless.
            gen_u8 = np.clip(np.rint(gen_array), 0, 255).astype(np.uint8)
            gt_u8 = gt_array.astype(np.uint8)

            score = ssim(gen_u8, gt_u8, channel_axis=2, data_range=255)
            return float(score)

        except Exception as e:
            print(f" SSIM error: {e}")
            return 0.0

    def compute_psnr(self, generated: Image.Image, ground_truth: Image.Image) -> float:
        """Return PSNR in dB for 8-bit RGB; ``100.0`` for identical images, ``0.0`` if disabled or on error."""
        if not self.use_psnr:
            return 0.0

        try:
            gen_array, gt_array = self._aligned_arrays(generated, ground_truth)

            mse = np.mean((gen_array - gt_array) ** 2)

            if mse == 0:
                return 100.0

            psnr = 20 * np.log10(255.0 / np.sqrt(mse))
            return float(psnr)

        except Exception as e:
            # Labelled like the sibling metrics so the failing metric is identifiable.
            print(f" PSNR error: {e}")
            return 0.0

    def compute_mse(self, generated: Image.Image, ground_truth: Image.Image) -> float:
        """Return MSE divided by ``255**2`` (0 = identical); ``1.0`` on error."""
        try:
            gen_array, gt_array = self._aligned_arrays(generated, ground_truth)

            mse = np.mean((gen_array - gt_array) ** 2)
            normalized_mse = mse / (255.0 ** 2)

            return float(normalized_mse)

        except Exception as e:
            print(f"MSE error: {e}")
            return 1.0

    def evaluate_single(
        self,
        generated: Image.Image,
        ground_truth: Image.Image,
        verbose: bool = False
    ) -> Dict[str, float]:
        """Compute every metric for one pair and a weighted ``combined_score``.

        Args:
            generated: Image produced by the model.
            ground_truth: Reference image.
            verbose: Accepted for interface compatibility; currently unused.

        Returns:
            Dict with ``lpips_similarity``, ``ssim``, ``psnr``, ``mse`` and
            ``combined_score``. The combined score weights LPIPS 0.4, SSIM 0.4
            and normalised PSNR 0.2, renormalised over the metrics that are
            enabled so a disabled metric does not lower the score.
        """
        results = {}

        results['lpips_similarity'] = self.compute_lpips(generated, ground_truth)
        results['ssim'] = self.compute_ssim(generated, ground_truth)
        results['psnr'] = self.compute_psnr(generated, ground_truth)
        results['mse'] = self.compute_mse(generated, ground_truth)

        # Map 15 dB -> 0 and 40 dB -> 1.
        psnr_normalized = min(1.0, max(0.0, (results['psnr'] - 15) / 25))

        weighted_parts = []
        if self.lpips_model is not None:
            weighted_parts.append((results['lpips_similarity'], 0.4))
        if self.use_ssim:
            weighted_parts.append((results['ssim'], 0.4))
        if self.use_psnr:
            weighted_parts.append((psnr_normalized, 0.2))

        # Plain left-to-right accumulation: with all metrics enabled the weights
        # sum to exactly 1.0 and the score equals the fixed-weight formula.
        weighted_sum = 0.0
        weight_total = 0.0
        for value, weight in weighted_parts:
            weighted_sum += value * weight
            weight_total += weight

        results['combined_score'] = weighted_sum / weight_total if weight_total > 0 else 0.0

        return results

    def cleanup(self):
        """Drop the LPIPS model, clear the CUDA cache on GPU, and run the garbage collector."""
        self.lpips_model = None

        if self.device == 'cuda':
            torch.cuda.empty_cache()

        gc.collect()


def _resolve_image_id(item: dict, id_col: str):
    """Return the first usable ID under ``id_col`` or the fallback keys, else ``None``."""
    for key in (id_col, 'IMAGE_ID', 'image_id', 'id'):
        candidate = item.get(key)
        # Only None and the empty string count as missing; a numeric ID of 0 is valid.
        if candidate is None or (isinstance(candidate, str) and candidate == ''):
            continue
        return candidate
    return None


def evaluate_with_reference(
    result_dataset: str,
    ground_truth_dataset: str,
    result_image_col: str = "result_image",
    ground_truth_col: str = "OUTPUT_IMG",
    id_col: str = "IMAGE_ID",
    load_samples: Optional[int] = None,
    max_samples: Optional[int] = None,
    save_path: Optional[str] = None,
    save_every: int = 10,
    resume_from: Optional[str] = None,
    use_lpips: bool = True,
    use_ssim: bool = True,
    use_psnr: bool = True
) -> pd.DataFrame:
    """Score a dataset of generated images against a ground-truth dataset.

    Both datasets are streamed from their ``train`` split. Ground-truth records
    are indexed by ID, then each result record with a matching ID is scored
    with ``LightweightReferenceEvaluator``.

    Args:
        result_dataset: Dataset name holding the generated images.
        ground_truth_dataset: Dataset name holding the reference images.
        result_image_col: Image column in the result dataset.
        ground_truth_col: Image column in the ground-truth dataset.
        id_col: ID column; ``IMAGE_ID`` / ``image_id`` / ``id`` are tried as fallbacks.
        load_samples: Maximum number of records read from each stream (``None`` = all).
        max_samples: Stop after this many successfully scored pairs (``None`` = no limit).
        save_path: CSV path for checkpoints and the final table.
        save_every: Write a checkpoint every N scored pairs; values <= 0 disable checkpoints.
        resume_from: Existing results CSV; rows are kept and their ``IMAGE_ID`` values skipped.
        use_lpips, use_ssim, use_psnr: Metric switches forwarded to the evaluator.

    Returns:
        DataFrame with one row per scored (or failed) image. Columns are
        ``IMAGE_ID``, ``lpips_similarity``, ``ssim``, ``psnr``, ``mse``,
        ``combined_score`` and ``error``; an empty run still has these columns.
    """
    result_columns = [
        'IMAGE_ID', 'lpips_similarity', 'ssim', 'psnr',
        'mse', 'combined_score', 'error'
    ]

    evaluator = LightweightReferenceEvaluator(
        use_lpips=use_lpips,
        use_ssim=use_ssim,
        use_psnr=use_psnr
    )

    processed_ids = set()
    results = []

    # try/finally so the model is released even if streaming or saving raises.
    try:
        if resume_from and Path(resume_from).exists():
            existing_df = pd.read_csv(resume_from)
            results = existing_df.to_dict('records')
            # The CSV round-trip may change the ID dtype (e.g. "123" -> 123),
            # so compare IDs as strings.
            processed_ids = set(existing_df['IMAGE_ID'].astype(str))

        gt_stream = load_dataset(ground_truth_dataset, split='train', streaming=True)

        # NOTE: whole records are kept in memory; bound this with load_samples
        # for large ground-truth datasets.
        gt_dict = {}
        for idx, item in enumerate(gt_stream):
            image_id = _resolve_image_id(item, id_col)
            if image_id is not None:
                gt_dict[image_id] = item
            if load_samples and idx >= load_samples - 1:
                break

        result_stream = load_dataset(result_dataset, split='train', streaming=True)

        processed_count = 0
        checkpoint_enabled = bool(save_path) and bool(save_every) and save_every > 0

        for idx, result_item in enumerate(tqdm(result_stream, desc="Evaluating")):

            if load_samples and idx >= load_samples:
                break

            if max_samples and processed_count >= max_samples:
                break

            image_id = _resolve_image_id(result_item, id_col)

            # Skip records without an ID, already-scored ones, and ones with no reference.
            if image_id is None:
                continue
            if str(image_id) in processed_ids:
                continue
            if image_id not in gt_dict:
                continue

            gt_item = gt_dict[image_id]

            try:
                result_img = _get_image(result_item, result_image_col)
                gt_img = _get_image(gt_item, ground_truth_col)

                if result_img is None or gt_img is None:
                    continue

                scores = evaluator.evaluate_single(result_img, gt_img)

                result_entry = {
                    'IMAGE_ID': image_id,
                    'lpips_similarity': scores['lpips_similarity'],
                    'ssim': scores['ssim'],
                    'psnr': scores['psnr'],
                    'mse': scores['mse'],
                    'combined_score': scores['combined_score'],
                    'error': ''
                }
            except Exception as e:
                tqdm.write(f"   {image_id}: {e}")
                results.append({
                    'IMAGE_ID': image_id,
                    'lpips_similarity': 0.0,
                    'ssim': 0.0,
                    'psnr': 0.0,
                    'mse': 1.0,
                    'combined_score': 0.0,
                    'error': str(e)
                })
            else:
                # Bookkeeping lives outside the per-image try: a failed checkpoint
                # write must not be logged as a scoring error for an image that
                # was already scored and appended.
                results.append(result_entry)
                processed_count += 1

                if checkpoint_enabled and processed_count % save_every == 0:
                    pd.DataFrame(results).to_csv(save_path, index=False)

                if processed_count % 20 == 0:
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
    finally:
        evaluator.cleanup()

    # Keep the schema stable when nothing was scored so callers can still select columns.
    if results:
        results_df = pd.DataFrame(results)
    else:
        results_df = pd.DataFrame(columns=result_columns)

    # Save the final table
    if save_path:
        results_df.to_csv(save_path, index=False)
        print(f"{save_path}")

    return results_df


def _get_image(item: dict, col_name: str) -> Optional[Image.Image]:
    """Extract an RGB PIL image from a dataset record.

    The value is looked up under ``col_name``, its lower-cased form, ``'image'``
    and ``'img'`` (first non-empty hit wins). Supported value types are a PIL
    image, a dict with ``'bytes'`` and/or ``'path'``, and raw bytes.

    Returns:
        The image converted to RGB, or ``None`` when no value is found, the
        type is unsupported, or loading fails (the error is printed).
    """
    img = None
    for key in (col_name, col_name.lower(), 'image', 'img'):
        candidate = item.get(key)
        # Explicit checks instead of an `or` chain: truth-testing an array-like
        # value raises, while empty bytes/dicts should still fall through.
        if candidate is None:
            continue
        if isinstance(candidate, (bytes, bytearray, dict)) and not candidate:
            continue
        img = candidate
        break

    if img is None:
        return None

    try:
        if isinstance(img, Image.Image):
            return img.convert('RGB')

        if isinstance(img, dict):
            # A dict may carry 'bytes': None together with a usable 'path'
            # (or the other way round), so test the values, not key presence.
            raw = img.get('bytes')
            if raw:
                with Image.open(io.BytesIO(raw)) as opened:
                    return opened.convert('RGB')

            path = img.get('path')
            if path:
                with Image.open(path) as opened:
                    return opened.convert('RGB')

            return None

        if isinstance(img, (bytes, bytearray)):
            with Image.open(io.BytesIO(img)) as opened:
                return opened.convert('RGB')

        return None

    except Exception as e:
        print(f"Image loading error: {e}")
        return None


def compare_models_with_reference(
    result_datasets: Dict[str, str],
    ground_truth_dataset: str,
    ground_truth_col: str = "OUTPUT_IMG",
    **kwargs
) -> pd.DataFrame:
    """Evaluate several result datasets against one ground truth and summarise them.

    Args:
        result_datasets: Mapping of model name -> result dataset name.
        ground_truth_dataset: Dataset name holding the reference images.
        ground_truth_col: Image column in the ground-truth dataset.
        **kwargs: Forwarded to ``evaluate_with_reference``. ``save_path`` is set
            here to ``ref_results_<model>.csv`` and must not be passed.

    Returns:
        All per-image rows concatenated, with an extra ``model`` column. An
        empty ``result_datasets`` yields an empty frame with the usual columns.
        A per-model mean/std summary is printed when metric columns exist.
    """
    metric_columns = ['lpips_similarity', 'ssim', 'psnr', 'combined_score']

    if not result_datasets:
        # pd.concat([]) raises ValueError; return an empty, well-formed table instead.
        return pd.DataFrame(
            columns=['IMAGE_ID', 'lpips_similarity', 'ssim', 'psnr',
                     'mse', 'combined_score', 'error', 'model']
        )

    all_results = []

    for model_name, result_dataset in result_datasets.items():

        df = evaluate_with_reference(
            result_dataset=result_dataset,
            ground_truth_dataset=ground_truth_dataset,
            ground_truth_col=ground_truth_col,
            save_path=f"ref_results_{model_name}.csv",
            **kwargs
        )

        df['model'] = model_name
        all_results.append(df)

        # Release memory between models.
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    combined_df = pd.concat(all_results, ignore_index=True)

    # Aggregate only columns that exist: a run where nothing was scored can
    # come back without metric columns, which would make agg() raise KeyError.
    available = [col for col in metric_columns if col in combined_df.columns]
    if available and not combined_df.empty:
        summary = combined_df.groupby('model').agg(
            {col: ['mean', 'std'] for col in available}
        ).round(4)

        print(summary)

    return combined_df



In [17]:
import torchcodec

# Settings for this run: up to 300 records are read from each stream and
# scoring stops after 10 pairs, with a CSV checkpoint every 5 pairs.
eval_settings = dict(
    result_dataset="gab1k/mmm_project_with_ru_intermdata",
    ground_truth_dataset="arood0/mmm_project_with_audio_ru_final",
    result_image_col="result_image",
    ground_truth_col="OUTPUT_IMG",
    load_samples=300,
    max_samples=10,
    save_path="reference_eval_results.csv",
    save_every=5,
)
results = evaluate_with_reference(**eval_settings)

# Columns shown in the preview table
preview_columns = ['IMAGE_ID', 'lpips_similarity', 'ssim', 'psnr', 'combined_score']
print("\nResults preview:")
print(results[preview_columns].head(10))


Evaluating: 10it [00:20,  2.06s/it]


reference_eval_results.csv

Results preview:
      IMAGE_ID  lpips_similarity      ssim       psnr  combined_score
0  BEwrVP6o0yQ          0.753962  0.461809  17.756203        0.508358
1  BdH23PovaTA          0.752320  0.823275  14.035430        0.630238
2  bFdWsiXblVw          0.871372  0.775359  22.478411        0.718520
3  bfF-9S0ktP8          0.801408  0.848465  23.605118        0.728790
4  bfWJOx132As          0.906186  0.872151  21.662621        0.764636
5  BfO20utCi0I          0.549404  0.497361  14.688648        0.418706
6  Bfp634LE8Cc          0.731500  0.583868  19.044910        0.558506
7  Bg0Geue-cY8          0.832323  0.809618  17.144528        0.673933
8  Bgae-sqbe_g          0.571017  0.608394  11.402190        0.471764
9  bgDSZ-w2gIM          0.676759  0.748714  11.863627        0.570189


In [20]:
# Mean of each reported metric over the evaluated samples
for metric_name in ("lpips_similarity", "ssim", "psnr", "combined_score"):
    print(metric_name, results[metric_name].mean())

lpips_similarity 0.7446252033114433
ssim 0.702901249933613
psnr 17.36816850752275
combined_score 0.6043640141488593


Хочется сделать какой-то вывод, но пока рано: у нас нет других моделей, с которыми можно было бы сравниться и понять, кто лучше.
