In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
from sklearn.model_selection import train_test_split

class ImageHistogramDataset(Dataset):
    """Датасет для изображений и соответствующих гистограмм"""

    def __init__(self, images_dir, histograms_dir, csv_path=None, transform=None, img_size=512):
        """
        Args:
            images_dir (str): Путь к папке с изображениями
            histograms_dir (str): Путь к папке с гистограммами
            csv_path (str, optional): Путь к CSV файлу с точками белого
            transform (callable, optional): Трансформы для изображений
            img_size (int): Размер изображения для resize
        """
        self.images_dir = Path(images_dir)
        self.histograms_dir = Path(histograms_dir)
        self.transform = transform
        self.img_size = img_size

        # Проверяем существование папок
        """if not self.images_dir.exists():
            raise ValueError(f"Images directory not found: {images_dir}")
        if not self.histograms_dir.exists():
            raise ValueError(f"Histograms directory not found: {histograms_dir}")"""

        # Получаем список файлов
        self.image_files = self._get_image_files()
        self.histogram_files = self._get_histogram_files()

        # Загружаем white points из CSV если есть
        self.white_points = self._load_white_points(csv_path)

        # Проверяем соответствие файлов
        self._validate_files()

        print(f"Dataset created: {len(self.image_files)} images, {len(self.histogram_files)} histograms")

    def _get_image_files(self):
        """Получаем список изображений"""
        image_extensions = ['.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp']
        image_files = []

        for ext in image_extensions:
            image_files.extend(list(self.images_dir.glob(f'*{ext}')))
            image_files.extend(list(self.images_dir.glob(f'*{ext.upper()}')))

        # Сортируем по имени для consistency
        image_files.sort(key=lambda x: x.name)
        return image_files

    def _get_histogram_files(self):
        """Получаем список гистограмм"""
        histogram_extensions = ['.png', '.jpg', '.jpeg', '.npy']
        histogram_files = []

        for ext in histogram_extensions:
            histogram_files.extend(list(self.histograms_dir.glob(f'*{ext}')))
            histogram_files.extend(list(self.histograms_dir.glob(f'*{ext.upper()}')))

        histogram_files.sort(key=lambda x: x.name)
        return histogram_files

    def _load_white_points(self, csv_path):
        """Загружаем точки белого из CSV"""
        white_points = {}

        if csv_path and os.path.exists(csv_path):
            try:
                df = pd.read_csv(csv_path)
                print(f"Loaded white points from CSV: {len(df)} entries")

                # Определяем колонки
                image_col = None
                wp_cols = ['wp_r', 'wp_g', 'wp_b']

                for col in df.columns:
                    if 'image' in col.lower() or 'path' in col.lower() or 'name' in col.lower():
                        image_col = col
                        break

                if image_col is None:
                    image_col = df.columns[0]

                # Создаем словарь {filename: [wp_r, wp_g, wp_b]}
                for _, row in df.iterrows():
                    filename = Path(row[image_col]).name
                    wp_values = [row.get(col, 32768.0) for col in wp_cols]
                    white_points[filename] = wp_values

            except Exception as e:
                print(f"Error loading CSV: {e}")

        return white_points

    def _validate_files(self):
        """Проверяем соответствие изображений и гистограмм"""
        # Получаем имена файлов без расширений
        image_names = {f.stem for f in self.image_files}
        histogram_names = {f.stem for f in self.histogram_files}

        # Общие файлы
        common_names = image_names & histogram_names
        print(f"Common files: {len(common_names)}")

        # Фильтруем только соответствующие файлы
        self.image_files = [f for f in self.image_files if f.stem in common_names]
        self.histogram_files = [f for f in self.histogram_files if f.stem in common_names]

        # Сортируем чтобы порядок совпадал
        self.image_files.sort(key=lambda x: x.stem)
        self.histogram_files.sort(key=lambda x: x.stem)

    def _read_image(self, image_path):
        """Чтение и обработка изображения"""
        try:
            # Пробуем разные способы чтения
            img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
            if img is None:
                img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)

            if img is None:
                raise ValueError(f"Cannot read image: {image_path}")

            # Конвертируем в float и нормализуем
            if img.dtype == np.uint16:
                img = img.astype(np.float32) / 65535.0
            elif img.dtype == np.uint8:
                img = img.astype(np.float32) / 255.0

            # BGR to RGB
            if img.shape[-1] == 3:
                img = img[..., ::-1]

            # Resize
            if img.shape[0] != self.img_size or img.shape[1] != self.img_size:
                img = cv2.resize(img, (self.img_size, self.img_size))

            return img

        except Exception as e:
            print(f"Error reading image {image_path}: {e}")
            return np.zeros((self.img_size, self.img_size, 3), dtype=np.float32)

    def _read_histogram(self, histogram_path):
        """Чтение гистограммы"""
        try:
            if histogram_path.suffix.lower() in ['.npy']:
                hist = np.load(histogram_path)
            else:
                hist = cv2.imread(str(histogram_path), cv2.IMREAD_GRAYSCALE)
                if hist is None:
                    hist = np.zeros((128, 128), dtype=np.float32)
                else:
                    hist = hist.astype(np.float32) / 255.0

            return hist

        except Exception as e:
            print(f"Error reading histogram {histogram_path}: {e}")
            return np.zeros((128, 128), dtype=np.float32)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        try:
            image_path = self.image_files[idx]
            histogram_path = self.histogram_files[idx]

            # Чтение данных
            image = self._read_image(image_path)
            histogram = self._read_histogram(histogram_path)

            # White point
            filename = image_path.stem
            if filename in self.white_points:
                wp_r, wp_g, wp_b = self.white_points[filename]
            else:
                wp_r, wp_g, wp_b = 32768.0, 32768.0, 32768.0

            # Нормализация white point к [0, 1]
            white_point = np.array([wp_r, wp_g, wp_b], dtype=np.float32) / 65535.0

            # Преобразование в tensor
            image_tensor = torch.from_numpy(image.transpose(2, 0, 1)).float()
            hist_tensor = torch.from_numpy(histogram).float().unsqueeze(0)  # Добавляем channel dimension
            wp_tensor = torch.from_numpy(white_point).float()

            # Применяем трансформы
            if self.transform:
                image_tensor = self.transform(image_tensor)

            return image_tensor, hist_tensor, wp_tensor

        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Возвращаем dummy данные
            dummy_image = torch.randn(3, self.img_size, self.img_size)
            dummy_hist = torch.zeros(1, 128, 128)
            dummy_wp = torch.tensor([0.5, 0.5, 0.5])
            return dummy_image, dummy_hist, dummy_wp

def create_data_loaders(images_dir, histograms_dir, csv_path=None,
                       batch_size=8, val_size=0.2, img_size=512, transform=None):
    """Создает DataLoader'ы для обучения"""

    # Создаем полный датасет
    dataset = ImageHistogramDataset(
        images_dir=images_dir,
        histograms_dir=histograms_dir,
        csv_path=csv_path,
        transform=transform,
        img_size=img_size
    )

    if len(dataset) == 0:
        raise ValueError("No valid samples found in dataset!")

    # Разделение на train/val
    indices = list(range(len(dataset)))
    train_indices, val_indices = train_test_split(
        indices, test_size=val_size, random_state=42, shuffle=True
    )

    # Создаем подвыборки
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    print(f"Train samples: {len(train_indices)}, Val samples: {len(val_indices)}")

    # DataLoaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=min(batch_size, len(train_dataset)),
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=min(batch_size, len(val_dataset)),
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )

    return train_loader, val_loader

# Пример использования
def example_usage():
    """Пример использования датасета"""

    # Пути к вашим папкам
    IMAGES_DIR = "/content/train_imgs"
    HISTOGRAMS_DIR = "/content/train_histograms"
    CSV_PATH = "/content/train.csv"

    try:
        # Создаем DataLoader'ы
        train_loader, val_loader = create_data_loaders(
            images_dir=IMAGES_DIR,
            histograms_dir=HISTOGRAMS_DIR,
            csv_path=CSV_PATH,
            batch_size=8,
            val_size=0.2,
            img_size=512
        )

        print(f"Successfully created data loaders!")
        print(f"Train batches: {len(train_loader)}")
        print(f"Val batches: {len(val_loader)}")

        # Проверяем первый батч
        for images, hists, white_points in train_loader:
            print(f"Images shape: {images.shape}")
            print(f"Histograms shape: {hists.shape}")
            print(f"White points shape: {white_points.shape}")
            print(f"White points range: {white_points.min():.3f} - {white_points.max():.3f}")
            break

        return train_loader, val_loader

    except Exception as e:
        print(f"Error: {e}")
        return None, None

# Утилиты для проверки
def check_folder_structure(images_dir, histograms_dir):
    """Проверяет структуру папок"""
    print("🔍 Checking folder structure...")

    images_dir = Path(images_dir)
    histograms_dir = Path(histograms_dir)

    print(f"Images directory: {images_dir}")
    if images_dir.exists():
        image_files = list(images_dir.glob('*.*'))
        print(f"  Found {len(image_files)} files")
        if image_files:
            print(f"  Example: {image_files[0].name}")
    else:
        print("  ❌ Does not exist!")

    print(f"Histograms directory: {histograms_dir}")
    if histograms_dir.exists():
        hist_files = list(histograms_dir.glob('*.*'))
        print(f"  Found {len(hist_files)} files")
        if hist_files:
            print(f"  Example: {hist_files[0].name}")
    else:
        print("  ❌ Does not exist!")

# Быстрый старт
def quick_start():
    """Быстрое создание датасета"""
    # Замените на ваши пути
    images_dir = input("Enter images directory path: ").strip() or "/content/train_imgs"
    histograms_dir = input("Enter histograms directory path: ").strip() or "/content/train_histograms"
    csv_path = input("Enter CSV path (optional): ").strip() or None

    check_folder_structure(images_dir, histograms_dir)

    return create_data_loaders(images_dir, histograms_dir, csv_path)

if __name__ == "__main__":
    # Запуск примера
    train_loader, val_loader = example_usage()

    # Или быстрый старт
    # train_loader, val_loader = quick_start()

Loaded white points from CSV: 570 entries
Common files: 0
Dataset created: 0 images, 0 histograms
Error: No valid samples found in dataset!


In [None]:
import os
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from torch.utils.data import Dataset
import torch
from tqdm import tqdm
from sklearn.model_selection import train_test_split

class ImageHistogramDataset(Dataset):
    """Датасет для изображений и соответствующих гистограмм"""

    def __init__(self, images_dir, histograms_dir, csv_path=None, transform=None, img_size=512):
        """
        Args:
            images_dir (str): Путь к папке с изображениями
            histograms_dir (str): Путь к папке с гистограммами
            csv_path (str, optional): Путь к CSV файлу с точками белого
            transform (callable, optional): Трансформы для изображений
            img_size (int): Размер изображения для resize
        """
        self.images_dir = Path(images_dir)
        self.histograms_dir = Path(histograms_dir)
        self.transform = transform
        self.img_size = img_size

        # Проверяем существование папок
        if not self.images_dir.exists():
            raise ValueError(f"Images directory not found: {images_dir}")
        if not self.histograms_dir.exists():
            raise ValueError(f"Histograms directory not found: {histograms_dir}")

        # Получаем список файлов
        self.image_files = self._get_image_files()
        self.histogram_files = self._get_histogram_files()

        # Загружаем white points из CSV если есть
        self.white_points = self._load_white_points(csv_path)

        # Проверяем соответствие файлов
        self._validate_files()

        print(f"Dataset created: {len(self.image_files)} images, {len(self.histogram_files)} histograms")

    def _get_image_files(self):
        """Получаем список изображений"""
        image_extensions = ['.png', '.jpg', '.jpeg', '.tiff', '.tif', '.bmp']
        image_files = []

        for ext in image_extensions:
            image_files.extend(list(self.images_dir.glob(f'*{ext}')))
            image_files.extend(list(self.images_dir.glob(f'*{ext.upper()}')))

        # Сортируем по имени для consistency
        image_files.sort(key=lambda x: x.name)
        return image_files

    def _get_histogram_files(self):
        """Получаем список гистограмм"""
        histogram_extensions = ['.png', '.jpg', '.jpeg', '.npy']
        histogram_files = []

        for ext in histogram_extensions:
            histogram_files.extend(list(self.histograms_dir.glob(f'*{ext}')))
            histogram_files.extend(list(self.histograms_dir.glob(f'*{ext.upper()}')))

        histogram_files.sort(key=lambda x: x.name)
        return histogram_files

    def _load_white_points(self, csv_path):
        """Загружаем точки белого из CSV"""
        white_points = {}

        if csv_path and os.path.exists(csv_path):
            try:
                df = pd.read_csv(csv_path)
                print(f"Loaded white points from CSV: {len(df)} entries")

                # Определяем колонки
                image_col = None
                wp_cols = ['wp_r', 'wp_g', 'wp_b']

                for col in df.columns:
                    if 'image' in col.lower() or 'path' in col.lower() or 'name' in col.lower():
                        image_col = col
                        break

                if image_col is None:
                    image_col = df.columns[0]

                # Создаем словарь {filename: [wp_r, wp_g, wp_b]}
                for _, row in df.iterrows():
                    filename = Path(row[image_col]).name
                    wp_values = [row.get(col, 32768.0) for col in wp_cols]
                    white_points[filename] = wp_values

            except Exception as e:
                print(f"Error loading CSV: {e}")

        return white_points

    def _validate_files(self):
        """Проверяем соответствие изображений и гистограмм"""
        # Получаем имена файлов без расширений
        image_names = {f.stem for f in self.image_files}
        histogram_names = {f.stem for f in self.histogram_files}

        # Общие файлы
        common_names = image_names & histogram_names
        print(f"Common files: {len(common_names)}")

        # Фильтруем только соответствующие файлы
        self.image_files = [f for f in self.image_files if f.stem in common_names]
        self.histogram_files = [f for f in self.histogram_files if f.stem in common_names]

        # Сортируем чтобы порядок совпадал
        self.image_files.sort(key=lambda x: x.stem)
        self.histogram_files.sort(key=lambda x: x.stem)

    def _read_image(self, image_path):
        """Чтение и обработка изображения"""
        try:
            # Пробуем разные способы чтения
            img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
            if img is None:
                img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)

            if img is None:
                raise ValueError(f"Cannot read image: {image_path}")

            # Конвертируем в float и нормализуем
            if img.dtype == np.uint16:
                img = img.astype(np.float32) / 65535.0
            elif img.dtype == np.uint8:
                img = img.astype(np.float32) / 255.0

            # BGR to RGB
            if img.shape[-1] == 3:
                img = img[..., ::-1]

            # Resize
            if img.shape[0] != self.img_size or img.shape[1] != self.img_size:
                img = cv2.resize(img, (self.img_size, self.img_size))

            return img

        except Exception as e:
            print(f"Error reading image {image_path}: {e}")
            return np.zeros((self.img_size, self.img_size, 3), dtype=np.float32)

    def _read_histogram(self, histogram_path):
        """Чтение гистограммы"""
        try:
            if histogram_path.suffix.lower() in ['.npy']:
                hist = np.load(histogram_path)
            else:
                hist = cv2.imread(str(histogram_path), cv2.IMREAD_GRAYSCALE)
                if hist is None:
                    hist = np.zeros((128, 128), dtype=np.float32)
                else:
                    hist = hist.astype(np.float32) / 255.0

            return hist

        except Exception as e:
            print(f"Error reading histogram {histogram_path}: {e}")
            return np.zeros((128, 128), dtype=np.float32)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        try:
            image_path = self.image_files[idx]
            histogram_path = self.histogram_files[idx]

            # Чтение данных
            image = self._read_image(image_path)
            histogram = self._read_histogram(histogram_path)

            # White point
            filename = image_path.stem
            if filename in self.white_points:
                wp_r, wp_g, wp_b = self.white_points[filename]
            else:
                wp_r, wp_g, wp_b = 32768.0, 32768.0, 32768.0

            # Нормализация white point к [0, 1]
            white_point = np.array([wp_r, wp_g, wp_b], dtype=np.float32) / 65535.0

            # Преобразование в tensor
            image_tensor = torch.from_numpy(image.transpose(2, 0, 1)).float()
            hist_tensor = torch.from_numpy(histogram).float().unsqueeze(0)  # Добавляем channel dimension
            wp_tensor = torch.from_numpy(white_point).float()

            # Применяем трансформы
            if self.transform:
                image_tensor = self.transform(image_tensor)

            return image_tensor, hist_tensor, wp_tensor

        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Возвращаем dummy данные
            dummy_image = torch.randn(3, self.img_size, self.img_size)
            dummy_hist = torch.zeros(1, 128, 128)
            dummy_wp = torch.tensor([0.5, 0.5, 0.5])
            return dummy_image, dummy_hist, dummy_wp

def create_data_loaders(images_dir, histograms_dir, csv_path=None,
                       batch_size=8, val_size=0.2, img_size=512, transform=None):
    """Создает DataLoader'ы для обучения"""

    # Создаем полный датасет
    dataset = ImageHistogramDataset(
        images_dir=images_dir,
        histograms_dir=histograms_dir,
        csv_path=csv_path,
        transform=transform,
        img_size=img_size
    )

    if len(dataset) == 0:
        raise ValueError("No valid samples found in dataset!")

    # Разделение на train/val
    indices = list(range(len(dataset)))
    train_indices, val_indices = train_test_split(
        indices, test_size=val_size, random_state=42, shuffle=True
    )

    # Создаем подвыборки
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    print(f"Train samples: {len(train_indices)}, Val samples: {len(val_indices)}")

    # DataLoaders
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=min(batch_size, len(train_dataset)),
        shuffle=True,
        num_workers=2,
        pin_memory=True
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=min(batch_size, len(val_dataset)),
        shuffle=False,
        num_workers=1,
        pin_memory=True
    )

    return train_loader, val_loader

# Пример использования
def example_usage():
    """Пример использования датасета"""

    # Пути к вашим папкам
    IMAGES_DIR = "/content/train_imgs"
    HISTOGRAMS_DIR = "/content/екфшт_ршыещпкфьы2"
    CSV_PATH = "/content/train.csv"

    try:
        # Создаем DataLoader'ы
        train_loader, val_loader = create_data_loaders(
            images_dir='/content/train_imgs2',
            histograms_dir='/content/train_histograms',
            csv_path="/content/train.csv",
            batch_size=8,
            val_size=0.2,
            img_size=512
        )

        print(f"Successfully created data loaders!")
        print(f"Train batches: {len(train_loader)}")
        print(f"Val batches: {len(val_loader)}")

        # Проверяем первый батч
        for images, hists, white_points in train_loader:
            print(f"Images shape: {images.shape}")
            print(f"Histograms shape: {hists.shape}")
            print(f"White points shape: {white_points.shape}")
            print(f"White points range: {white_points.min():.3f} - {white_points.max():.3f}")
            break

        return train_loader, val_loader

    except Exception as e:
        print(f"Error: {e}")
        return None, None

# Утилиты для проверки
def check_folder_structure(images_dir, histograms_dir):
    """Проверяет структуру папок"""
    print("🔍 Checking folder structure...")

    images_dir = Path(images_dir)
    histograms_dir = Path(histograms_dir)

    print(f"Images directory: {images_dir}")
    if images_dir.exists():
        image_files = list(images_dir.glob('*.*'))
        print(f"  Found {len(image_files)} files")
        if image_files:
            print(f"  Example: {image_files[0].name}")
    else:
        print("  ❌ Does not exist!")

    print(f"Histograms directory: {histograms_dir}")
    if histograms_dir.exists():
        hist_files = list(histograms_dir.glob('*.*'))
        print(f"  Found {len(hist_files)} files")
        if hist_files:
            print(f"  Example: {hist_files[0].name}")
    else:
        print("  ❌ Does not exist!")

# Быстрый старт
def quick_start():
    """Быстрое создание датасета"""
    # Замените на ваши пути
    images_dir = input("Enter images directory path: ").strip() or "/content/train_imgs"
    histograms_dir = input("Enter histograms directory path: ").strip() or "/content/train_histograms"
    csv_path = input("Enter CSV path (optional): ").strip() or None

    check_folder_structure(images_dir, histograms_dir)

    return create_data_loaders(images_dir, histograms_dir, csv_path)

if __name__ == "__main__":
    # Запуск примера
    train_loader, val_loader = example_usage()

    # Или быстрый старт
    # train_loader, val_loader = quick_start()

Loaded white points from CSV: 570 entries
Common files: 84
Dataset created: 84 images, 84 histograms
Train samples: 67, Val samples: 17
Successfully created data loaders!
Train batches: 9
Val batches: 3




Images shape: torch.Size([8, 3, 512, 512])
Histograms shape: torch.Size([8, 1, 128, 128])
White points shape: torch.Size([8, 3])
White points range: 0.500 - 0.500


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import os
import json
import matplotlib.pyplot as plt
from datetime import datetime

class Dist2HistLoss(nn.Module):
    """Функция потерь, соответствующая метрике Dist2Hist"""

    def __init__(self, alpha=0.6, beta=0.3, gamma=0.1, eps=1e-7):
        """
        Args:
            alpha: вес угловой ошибки (основной компонент Dist2Hist)
            beta: вес хроматического расстояния
            gamma: вес MSE для стабилизации
            eps: маленькое значение для численной стабильности
        """
        super(Dist2HistLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.eps = eps

    def angular_loss(self, pred, target):
        """Угловая ошибка между векторами"""
        # Нормализуем векторы
        pred_norm = pred / (torch.norm(pred, dim=1, keepdim=True) + self.eps)
        target_norm = target / (torch.norm(target, dim=1, keepdim=True) + self.eps)

        # Косинусная схожесть
        cosine_sim = torch.sum(pred_norm * target_norm, dim=1)
        cosine_sim = torch.clamp(cosine_sim, -1 + self.eps, 1 - self.eps)

        # Угол в радианах
        angle = torch.acos(cosine_sim)
        return torch.mean(angle)

    def chromatic_loss(self, pred, target):
        """Расстояние в хроматическом пространстве"""
        # Преобразование RGB в хроматические координаты (α, β)
        pred_alpha = pred[:, 0] / (torch.sum(pred, dim=1) + self.eps)
        pred_beta = pred[:, 1] / (torch.sum(pred, dim=1) + self.eps)

        target_alpha = target[:, 0] / (torch.sum(target, dim=1) + self.eps)
        target_beta = target[:, 1] / (torch.sum(target, dim=1) + self.eps)

        # Евклидово расстояние в хроматическом пространстве
        alpha_diff = pred_alpha - target_alpha
        beta_diff = pred_beta - target_beta
        chroma_dist = torch.sqrt(alpha_diff**2 + beta_diff**2 + self.eps)

        return torch.mean(chroma_dist)

    def mse_loss(self, pred, target):
        """MSE для стабилизации обучения"""
        return nn.MSELoss()(pred, target)

    def forward(self, pred, target):
        """
        Вычисление общей потери

        Args:
            pred: предсказанные точки белого [batch_size, 3] в диапазоне [0, 1]
            target: истинные точки белого [batch_size, 3] в диапазоне [0, 1]

        Returns:
            Комбинированная потеря, соответствующая метрике Dist2Hist
        """
        angular = self.angular_loss(pred, target)
        chroma = self.chromatic_loss(pred, target)
        mse = self.mse_loss(pred, target)

        # Комбинируем с весами, соответствующими метрике Dist2Hist
        total_loss = (self.alpha * angular +
                     self.beta * chroma +
                     self.gamma * mse)

        return total_loss

class WhiteBalanceModel(nn.Module):
    """Модель для предсказания точки белого"""

    def __init__(self, pretrained=False):
        super(WhiteBalanceModel, self).__init__()

        # Используем EfficientNet как энкодер
        if pretrained:
            from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
            self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        else:
            from torchvision.models import efficientnet_b0
            self.backbone = efficientnet_b0(weights=None)

        # Заменяем классификатор на регрессионную голову
        in_features = self.backbone.classifier[1].in_features

        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 3),  # 3 выхода для RGB
            nn.Sigmoid()  # Выход в диапазоне [0, 1]
        )

    def forward(self, x):
        return self.backbone(x)

def setup_device():
    """Настройка устройства (GPU/CPU)"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    return device

def train_epoch(model, train_loader, criterion, optimizer, device, epoch):
    """Одна эпоха обучения"""
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch} Training")

    for batch_idx, (images, _, white_points) in enumerate(progress_bar):
        images = images.to(device)
        white_points = white_points.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, white_points)

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Обновление progress bar
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.6f}',
            'Avg Loss': f'{total_loss/(batch_idx+1):.6f}'
        })

    return total_loss / len(train_loader)

def validate_epoch(model, val_loader, criterion, device, epoch):
    """Валидация"""
    model.eval()
    total_loss = 0.0
    progress_bar = tqdm(val_loader, desc=f"Epoch {epoch} Validation")

    with torch.no_grad():
        for batch_idx, (images, _, white_points) in enumerate(progress_bar):
            images = images.to(device)
            white_points = white_points.to(device)

            outputs = model(images)
            loss = criterion(outputs, white_points)

            total_loss += loss.item()

            progress_bar.set_postfix({
                'Val Loss': f'{loss.item():.6f}',
                'Avg Val Loss': f'{total_loss/(batch_idx+1):.6f}'
            })

    return total_loss / len(val_loader)

def save_checkpoint(model, optimizer, epoch, loss, path):
    """Сохранение чекпоинта"""
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
        'timestamp': datetime.now().isoformat()
    }
    torch.save(checkpoint, path)
    print(f"Checkpoint saved: {path}")

def load_checkpoint(model, optimizer, checkpoint_path, device):
    """Загрузка чекпоинта"""
    if os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print(f"Checkpoint loaded from epoch {checkpoint['epoch']}")
        return checkpoint['epoch'], checkpoint['loss']
    return 0, float('inf')

def create_dummy_dataloaders(batch_size=8, img_size=224, num_samples=100):
    """Создание dummy dataloaders для тестирования"""
    from torch.utils.data import TensorDataset, DataLoader

    # Создаем dummy данные
    dummy_images = torch.randn(num_samples, 3, img_size, img_size)
    dummy_hists = torch.randn(num_samples, 1, 128, 128)  # Гистограммы
    dummy_white_points = torch.rand(num_samples, 3)  # Точки белого в [0, 1]

    # Разделение на train/val
    split = int(0.8 * num_samples)

    train_dataset = TensorDataset(dummy_images[:split], dummy_hists[:split], dummy_white_points[:split])
    val_dataset = TensorDataset(dummy_images[split:], dummy_hists[split:], dummy_white_points[split:])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    print(f"Dummy data: {num_samples} samples")
    return train_loader, val_loader

def main():
    """Основная функция обучения"""
    # Конфигурация
    config = {
        'batch_size': 1,
        'learning_rate': 1e-4,
        'num_epochs': 33,
        'val_size': 0.2,
        'weight_decay': 1e-5,
        'pretrained': True,
        'checkpoint_path': 'best_model.pth',
        'results_dir': 'results',
        'img_size': 224
    }

    # Создание директории для результатов
    os.makedirs(config['results_dir'], exist_ok=True)

    # Сохранение конфигурации
    with open(f"{config['results_dir']}/config.json", 'w') as f:
        json.dump(config, f, indent=4)

    # Настройка устройства
    device = setup_device()

    # Создание DataLoaders
    print("Creating data loaders...")
    try:
        # Замените на вашу функцию создания датасета


        train_loader, val_loader = create_data_loaders(
            images_dir="/content/train_imgs2",
            histograms_dir="/content/train_histograms",
            csv_path="/content/train.csv",
            batch_size=config['batch_size'],
            val_size=config['val_size'],
            img_size=config['img_size']
        )

    except Exception as e:
        print(f"Error creating data loaders: {e}")
        print("Using dummy data for testing...")
        train_loader, val_loader = create_dummy_dataloaders(
            batch_size=config['batch_size'],
            img_size=config['img_size']
        )

    print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

    # Инициализация модели
    print("Initializing model...")
    model = WhiteBalanceModel(pretrained=config['pretrained']).to(device)

    # Функция потерь - ОСНОВНОЕ ИЗМЕНЕНИЕ
    criterion = Dist2HistLoss(alpha=0.6, beta=0.3, gamma=0.1)

    # Оптимизатор
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5
    )

    # Загрузка чекпоинта если существует
    start_epoch, best_loss = load_checkpoint(
        model, optimizer, config['checkpoint_path'], device
    )

    # Обучение
    train_losses = []
    val_losses = []
    learning_rates = []

    print("Starting training with Dist2Hist loss...")
    for epoch in range(start_epoch, config['num_epochs']):
        print(f"\nEpoch {epoch+1}/{config['num_epochs']}")
        print("-" * 50)

        # Обучение
        train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch+1)
        train_losses.append(train_loss)

        # Валидация
        val_loss = validate_epoch(model, val_loader, criterion, device, epoch+1)
        val_losses.append(val_loss)

        # Обновление scheduler
        scheduler.step(val_loss)
        learning_rates.append(optimizer.param_groups[0]['lr'])

        # Сохранение лучшей модели
        if val_loss < best_loss:
            best_loss = val_loss
            save_checkpoint(
                model, optimizer, epoch, val_loss,
                f"{config['results_dir']}/{config['checkpoint_path']}"
            )

        # Сохранение истории обучения
        history = {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'learning_rates': learning_rates,
            'best_val_loss': best_loss
        }

        with open(f"{config['results_dir']}/training_history.json", 'w') as f:
            json.dump(history, f, indent=4)

        # Визуализация прогресса
        if epoch % 5 == 0 or epoch == config['num_epochs'] - 1:
            plt.figure(figsize=(15, 5))

            plt.subplot(1, 3, 1)
            plt.plot(train_losses, label='Train Loss', marker='o')
            plt.plot(val_losses, label='Validation Loss', marker='s')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training Progress')
            plt.legend()
            plt.grid(True)

            plt.subplot(1, 3, 2)
            plt.plot(learning_rates, label='Learning Rate', marker='^', color='red')
            plt.xlabel('Epoch')
            plt.ylabel('Learning Rate')
            plt.title('Learning Rate Schedule')
            plt.legend()
            plt.grid(True)
            plt.yscale('log')

            plt.subplot(1, 3, 3)
            # Показываем компоненты потерь
            plt.bar(['Angular', 'Chromatic', 'MSE'],
                   [criterion.alpha, criterion.beta, criterion.gamma],
                   color=['blue', 'green', 'orange'])
            plt.title('Loss Components Weights')
            plt.ylabel('Weight')

            plt.tight_layout()
            plt.savefig(f"{config['results_dir']}/training_progress_epoch_{epoch+1}.png")
            plt.close()

        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}, LR = {optimizer.param_groups[0]['lr']:.2e}")

    print("Training completed!")
    print(f"Best validation loss: {best_loss:.6f}")

    # Финальное сохранение модели
    torch.save(model.state_dict(), f"{config['results_dir']}/final_model.pth")
    print(f"Final model saved to {config['results_dir']}/final_model.pth")

    return model, train_losses, val_losses

# Дополнительные утилиты
def test_loss_function():
    """Тестирование функции потерь"""
    criterion = Dist2HistLoss()

    # Тестовые данные
    batch_size = 4
    pred = torch.rand(batch_size, 3)  # Предсказания в [0, 1]
    target = torch.rand(batch_size, 3)  # Цели в [0, 1]

    loss = criterion(pred, target)
    print(f"Test loss: {loss.item():.6f}")

    # Компоненты потерь
    angular = criterion.angular_loss(pred, target)
    chroma = criterion.chromatic_loss(pred, target)
    mse = criterion.mse_loss(pred, target)

    print(f"Angular: {angular.item():.6f}")
    print(f"Chromatic: {chroma.item():.6f}")
    print(f"MSE: {mse.item():.6f}")

if __name__ == "__main__":
    # Тестирование функции потерь
    print("Testing loss function...")
    test_loss_function()

    print("\n" + "="*60)
    # Запуск обучения
    model, train_losses, val_losses = main()

Testing loss function...
Test loss: 0.550855
Angular: 0.711987
Chromatic: 0.345134
MSE: 0.201232

Using device: cpu
Creating data loaders...
Loaded white points from CSV: 570 entries
Common files: 84
Dataset created: 84 images, 84 histograms
Train samples: 67, Val samples: 17
Train batches: 67, Val batches: 17
Initializing model...
Starting training with Dist2Hist loss...

Epoch 1/33
--------------------------------------------------


Epoch 1 Training: 100%|██████████| 67/67 [00:22<00:00,  3.02it/s, Loss=0.001672, Avg Loss=0.003866]
Epoch 1 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.64it/s, Val Loss=0.025495, Avg Val Loss=0.010755]


Checkpoint saved: results/best_model.pth
Epoch 1: Train Loss = 0.003866, Val Loss = 0.010755, LR = 1.00e-04

Epoch 2/33
--------------------------------------------------


Epoch 2 Training: 100%|██████████| 67/67 [00:22<00:00,  2.99it/s, Loss=0.002720, Avg Loss=0.003218]
Epoch 2 Validation: 100%|██████████| 17/17 [00:01<00:00, 10.07it/s, Val Loss=0.006048, Avg Val Loss=0.005724]


Checkpoint saved: results/best_model.pth
Epoch 2: Train Loss = 0.003218, Val Loss = 0.005724, LR = 1.00e-04

Epoch 3/33
--------------------------------------------------


Epoch 3 Training: 100%|██████████| 67/67 [00:22<00:00,  3.04it/s, Loss=0.001049, Avg Loss=0.002210]
Epoch 3 Validation: 100%|██████████| 17/17 [00:01<00:00, 10.03it/s, Val Loss=0.006263, Avg Val Loss=0.005767]


Epoch 3: Train Loss = 0.002210, Val Loss = 0.005767, LR = 1.00e-04

Epoch 4/33
--------------------------------------------------


Epoch 4 Training: 100%|██████████| 67/67 [00:21<00:00,  3.05it/s, Loss=0.000611, Avg Loss=0.001804]
Epoch 4 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.31it/s, Val Loss=0.005970, Avg Val Loss=0.005604]


Checkpoint saved: results/best_model.pth
Epoch 4: Train Loss = 0.001804, Val Loss = 0.005604, LR = 1.00e-04

Epoch 5/33
--------------------------------------------------


Epoch 5 Training: 100%|██████████| 67/67 [00:21<00:00,  3.14it/s, Loss=0.002091, Avg Loss=0.001616]
Epoch 5 Validation: 100%|██████████| 17/17 [00:02<00:00,  6.25it/s, Val Loss=0.003224, Avg Val Loss=0.005333]


Checkpoint saved: results/best_model.pth
Epoch 5: Train Loss = 0.001616, Val Loss = 0.005333, LR = 1.00e-04

Epoch 6/33
--------------------------------------------------


Epoch 6 Training: 100%|██████████| 67/67 [00:21<00:00,  3.15it/s, Loss=0.000569, Avg Loss=0.001282]
Epoch 6 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.82it/s, Val Loss=0.000667, Avg Val Loss=0.003719]


Checkpoint saved: results/best_model.pth
Epoch 6: Train Loss = 0.001282, Val Loss = 0.003719, LR = 1.00e-04

Epoch 7/33
--------------------------------------------------


Epoch 7 Training: 100%|██████████| 67/67 [00:22<00:00,  2.97it/s, Loss=0.001374, Avg Loss=0.001093]
Epoch 7 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.62it/s, Val Loss=0.001377, Avg Val Loss=0.003056]


Checkpoint saved: results/best_model.pth
Epoch 7: Train Loss = 0.001093, Val Loss = 0.003056, LR = 1.00e-04

Epoch 8/33
--------------------------------------------------


Epoch 8 Training: 100%|██████████| 67/67 [00:22<00:00,  2.95it/s, Loss=0.000668, Avg Loss=0.001147]
Epoch 8 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.65it/s, Val Loss=0.003698, Avg Val Loss=0.002809]


Checkpoint saved: results/best_model.pth
Epoch 8: Train Loss = 0.001147, Val Loss = 0.002809, LR = 1.00e-04

Epoch 9/33
--------------------------------------------------


Epoch 9 Training: 100%|██████████| 67/67 [00:22<00:00,  2.94it/s, Loss=0.000821, Avg Loss=0.001031]
Epoch 9 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.64it/s, Val Loss=0.001262, Avg Val Loss=0.002690]


Checkpoint saved: results/best_model.pth
Epoch 9: Train Loss = 0.001031, Val Loss = 0.002690, LR = 1.00e-04

Epoch 10/33
--------------------------------------------------


Epoch 10 Training: 100%|██████████| 67/67 [00:25<00:00,  2.65it/s, Loss=0.000795, Avg Loss=0.000936]
Epoch 10 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.42it/s, Val Loss=0.001244, Avg Val Loss=0.002800]


Epoch 10: Train Loss = 0.000936, Val Loss = 0.002800, LR = 1.00e-04

Epoch 11/33
--------------------------------------------------


Epoch 11 Training: 100%|██████████| 67/67 [00:22<00:00,  2.92it/s, Loss=0.000515, Avg Loss=0.000906]
Epoch 11 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.61it/s, Val Loss=0.000915, Avg Val Loss=0.001725]


Checkpoint saved: results/best_model.pth
Epoch 11: Train Loss = 0.000906, Val Loss = 0.001725, LR = 1.00e-04

Epoch 12/33
--------------------------------------------------


Epoch 12 Training: 100%|██████████| 67/67 [00:22<00:00,  2.95it/s, Loss=0.000832, Avg Loss=0.000785]
Epoch 12 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.59it/s, Val Loss=0.002603, Avg Val Loss=0.001811]


Epoch 12: Train Loss = 0.000785, Val Loss = 0.001811, LR = 1.00e-04

Epoch 13/33
--------------------------------------------------


Epoch 13 Training: 100%|██████████| 67/67 [00:22<00:00,  2.99it/s, Loss=0.000611, Avg Loss=0.000742]
Epoch 13 Validation: 100%|██████████| 17/17 [00:02<00:00,  7.54it/s, Val Loss=0.003691, Avg Val Loss=0.002326]


Epoch 13: Train Loss = 0.000742, Val Loss = 0.002326, LR = 1.00e-04

Epoch 14/33
--------------------------------------------------


Epoch 14 Training: 100%|██████████| 67/67 [00:21<00:00,  3.09it/s, Loss=0.000704, Avg Loss=0.000704]
Epoch 14 Validation: 100%|██████████| 17/17 [00:02<00:00,  6.79it/s, Val Loss=0.008058, Avg Val Loss=0.002152]


Epoch 14: Train Loss = 0.000704, Val Loss = 0.002152, LR = 1.00e-04

Epoch 15/33
--------------------------------------------------


Epoch 15 Training: 100%|██████████| 67/67 [00:22<00:00,  3.00it/s, Loss=0.000701, Avg Loss=0.000718]
Epoch 15 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.61it/s, Val Loss=0.001208, Avg Val Loss=0.001826]


Epoch 15: Train Loss = 0.000718, Val Loss = 0.001826, LR = 1.00e-04

Epoch 16/33
--------------------------------------------------


Epoch 16 Training: 100%|██████████| 67/67 [00:22<00:00,  2.94it/s, Loss=0.000510, Avg Loss=0.000698]
Epoch 16 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.76it/s, Val Loss=0.003908, Avg Val Loss=0.001812]


Epoch 16: Train Loss = 0.000698, Val Loss = 0.001812, LR = 1.00e-04

Epoch 17/33
--------------------------------------------------


Epoch 17 Training: 100%|██████████| 67/67 [00:22<00:00,  2.93it/s, Loss=0.000675, Avg Loss=0.000641]
Epoch 17 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.79it/s, Val Loss=0.004914, Avg Val Loss=0.001272]


Checkpoint saved: results/best_model.pth
Epoch 17: Train Loss = 0.000641, Val Loss = 0.001272, LR = 1.00e-04

Epoch 18/33
--------------------------------------------------


Epoch 18 Training: 100%|██████████| 67/67 [00:22<00:00,  2.95it/s, Loss=0.000558, Avg Loss=0.000636]
Epoch 18 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.63it/s, Val Loss=0.004085, Avg Val Loss=0.001756]


Epoch 18: Train Loss = 0.000636, Val Loss = 0.001756, LR = 1.00e-04

Epoch 19/33
--------------------------------------------------


Epoch 19 Training: 100%|██████████| 67/67 [00:22<00:00,  2.94it/s, Loss=0.000732, Avg Loss=0.000651]
Epoch 19 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.60it/s, Val Loss=0.005045, Avg Val Loss=0.001554]


Epoch 19: Train Loss = 0.000651, Val Loss = 0.001554, LR = 1.00e-04

Epoch 20/33
--------------------------------------------------


Epoch 20 Training: 100%|██████████| 67/67 [00:22<00:00,  3.02it/s, Loss=0.000574, Avg Loss=0.000623]
Epoch 20 Validation: 100%|██████████| 17/17 [00:02<00:00,  7.59it/s, Val Loss=0.003757, Avg Val Loss=0.001651]


Epoch 20: Train Loss = 0.000623, Val Loss = 0.001651, LR = 1.00e-04

Epoch 21/33
--------------------------------------------------


Epoch 21 Training: 100%|██████████| 67/67 [00:21<00:00,  3.09it/s, Loss=0.000657, Avg Loss=0.000615]
Epoch 21 Validation: 100%|██████████| 17/17 [00:02<00:00,  7.03it/s, Val Loss=0.003861, Avg Val Loss=0.001658]


Epoch 21: Train Loss = 0.000615, Val Loss = 0.001658, LR = 1.00e-04

Epoch 22/33
--------------------------------------------------


Epoch 22 Training: 100%|██████████| 67/67 [00:21<00:00,  3.07it/s, Loss=0.000684, Avg Loss=0.000613]
Epoch 22 Validation: 100%|██████████| 17/17 [00:01<00:00,  9.28it/s, Val Loss=0.001738, Avg Val Loss=0.001375]


Epoch 22: Train Loss = 0.000613, Val Loss = 0.001375, LR = 1.00e-04

Epoch 23/33
--------------------------------------------------


Epoch 23 Training:  76%|███████▌  | 51/67 [00:18<00:05,  2.77it/s, Loss=0.000622, Avg Loss=0.000612]


KeyboardInterrupt: 

In [None]:
torch.save(model.state_dict(), "final_model.pth")

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import os
import cv2

class TestDataset(Dataset):
    """Датасет для тестовых изображений"""

    def __init__(self, csv_path, images_dir, img_size=224):
        self.csv_path = Path(csv_path)
        self.images_dir = Path(images_dir)
        self.img_size = img_size

        # Загрузка CSV файла
        self.df = pd.read_csv(self.csv_path)
        print(f"Loaded test CSV: {len(self.df)} samples")
        print(f"Columns: {list(self.df.columns)}")

        # Параметры нормализации (должны совпадать с тренировочными)
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __len__(self):
        return len(self.df)

    def _read_image(self, image_path):
        """Чтение и предобработка изображения"""
        try:
            img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
            if img is None:
                img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)

            if img is None:
                raise ValueError(f"Cannot read image: {image_path}")

            # Конвертируем в float и нормализуем
            if img.dtype == np.uint16:
                img = img.astype(np.float32) / 65535.0
            elif img.dtype == np.uint8:
                img = img.astype(np.float32) / 255.0

            # BGR to RGB
            if img.shape[-1] == 3:
                img = img[..., ::-1]

            # Resize
            if img.shape[0] != self.img_size or img.shape[1] != self.img_size:
                img = cv2.resize(img, (self.img_size, self.img_size))

            return img

        except Exception as e:
            print(f"Error reading image {image_path}: {e}")
            return np.ones((self.img_size, self.img_size, 3), dtype=np.float32) * 0.5

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]

            # Получаем путь к изображению (предполагаем, что первый столбец)
            image_path_str = row.iloc[0]  # Первая колонка содержит пути
            if pd.isna(image_path_str):
                image_path_str = f"test_image_{idx:04d}.png"

            # Создаем полный путь
            image_path = self.images_dir / image_path_str

            # Чтение изображения
            image = self._read_image(image_path)

            # Преобразование в tensor и нормализация
            image_tensor = torch.from_numpy(image.transpose(2, 0, 1)).float()
            image_tensor = (image_tensor - self.mean) / self.std

            return image_tensor, str(image_path_str)

        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Возвращаем dummy данные
            dummy_image = torch.randn(3, self.img_size, self.img_size)
            return dummy_image, f"error_{idx}.png"

class WhiteBalanceModel(nn.Module):
    """Модель для предсказания точки белого"""

    def __init__(self, pretrained=False):
        super(WhiteBalanceModel, self).__init__()

        # Базовая архитектура (должна совпадать с обученной моделью)
        from torchvision.models import efficientnet_b0

        if pretrained:
            from torchvision.models import EfficientNet_B0_Weights
            self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        else:
            self.backbone = efficientnet_b0(weights=None)

        # Заменяем классификатор
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.backbone(x)

def load_model(model_path, device, pretrained=False):
    """Загрузка обученной модели"""
    print(f"Loading model from: {model_path}")

    model = WhiteBalanceModel(pretrained=pretrained).to(device)

    if os.path.exists(model_path):
        try:
            # Пробуем загрузить полный checkpoint или только веса
            checkpoint = torch.load(model_path, map_location=device)

            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print("Loaded from checkpoint")
            else:
                model.load_state_dict(checkpoint)
                print("Loaded model weights")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Using randomly initialized model")
    else:
        print(f"Model file not found: {model_path}")
        print("Using randomly initialized model")

    model.eval()
    return model

def create_predictions(model, test_loader, device):
    """Создание предсказаний для тестового набора"""
    model.eval()
    all_predictions = []
    all_image_paths = []

    with torch.no_grad():
        for images, image_paths in tqdm(test_loader, desc="Making predictions"):
            images = images.to(device)

            outputs = model(images)

            # Денормализуем предсказания из [0, 1] в [0, 65535]
            predictions = outputs.cpu().numpy() * 65535.0

            all_predictions.extend(predictions)
            all_image_paths.extend(image_paths)

    return all_image_paths, all_predictions

def create_submission_file(image_paths, predictions, output_csv="submission.csv"):
    """Создание submission файла в правильном формате"""
    print(f"Creating submission file: {output_csv}")

    data = []
    for img_path, pred in zip(image_paths, predictions):
        try:
            # Извлекаем только имя файла (без пути)
            if isinstance(img_path, str):
                filename = Path(img_path).name
            else:
                filename = f"image_{len(data):04d}.png"

            # Обеспечиваем корректный диапазон значений [0, 65535]
            wp_r = float(np.clip(pred[0], 0, 65535))
            wp_g = float(np.clip(pred[1], 0, 65535))
            wp_b = float(np.clip(pred[2], 0, 65535))

            data.append({
                'image_path': filename,
                'wp_r': wp_r,
                'wp_g': wp_g,
                'wp_b': wp_b
            })

        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            # Добавляем значения по умолчанию
            data.append({
                'image_path': f"error_{len(data):04d}.png",
                'wp_r': 32768.0,
                'wp_g': 32768.0,
                'wp_b': 32768.0
            })

    # Создаем DataFrame
    df = pd.DataFrame(data)

    # Сохраняем CSV
    df.to_csv(output_csv, index=False, float_format='%.6f')

    print(f"Submission file created with {len(df)} predictions")
    print("First 5 predictions:")
    print(df.head())

    return df

def main():
    """Основная функция для создания предсказаний"""

    # Пути к данным
    TEST_CSV_PATH = "/content/test.csv"
    TEST_IMAGES_DIR = "/content/test_imgs2"
    MODEL_PATH = "/content/final_model.pth"
    OUTPUT_CSV = "/content/final_submission.csv"

    print("=" * 60)
    print("🔄 CREATING PREDICTIONS")
    print("=" * 60)

    # Проверка файлов
    print("🔍 Checking files...")
    print(f"Test CSV: {TEST_CSV_PATH} → {'✅' if os.path.exists(TEST_CSV_PATH) else '❌'}")
    print(f"Test images: {TEST_IMAGES_DIR} → {'✅' if os.path.exists(TEST_IMAGES_DIR) else '❌'}")
    print(f"Model: {MODEL_PATH} → {'✅' if os.path.exists(MODEL_PATH) else '❌'}")

    if not os.path.exists(TEST_CSV_PATH):
        print("❌ Test CSV not found!")
        return

    # Устройство
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🖥️  Device: {device}")

    # Создаем тестовый датасет
    print("\n📁 Creating test dataset...")
    try:
        test_dataset = TestDataset(TEST_CSV_PATH, TEST_IMAGES_DIR)
        print(f"✅ Test dataset created: {len(test_dataset)} samples")
    except Exception as e:
        print(f"❌ Error creating dataset: {e}")
        return

    # DataLoader
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=2
    )

    # Загрузка модели
    print("\n🤖 Loading model...")
    model = load_model(MODEL_PATH, device, pretrained=False)

    # Создание предсказаний
    print("\n🎯 Making predictions...")
    image_paths, predictions = create_predictions(model, test_loader, device)

    # Создание submission файла
    print("\n💾 Creating submission file...")
    submission_df = create_submission_file(image_paths, predictions, OUTPUT_CSV)

    # Статистика
    print("\n📊 Prediction statistics:")
    pred_array = np.array(predictions)
    print(f"Total predictions: {len(predictions)}")
    print(f"Value ranges:")
    print(f"  R: {pred_array[:, 0].min():.1f} - {pred_array[:, 0].max():.1f}")
    print(f"  G: {pred_array[:, 1].min():.1f} - {pred_array[:, 1].max():.1f}")
    print(f"  B: {pred_array[:, 2].min():.1f} - {pred_array[:, 2].max():.1f}")

    # Проверка корректности значений
    valid_mask = np.all((pred_array >= 0) & (pred_array <= 65535), axis=1)
    invalid_count = np.sum(~valid_mask)

    if invalid_count > 0:
        print(f"⚠️  Warning: {invalid_count} predictions outside valid range [0, 65535]")
    else:
        print("✅ All predictions are within valid range")

    print(f"\n🎉 Done! Submission file saved to: {OUTPUT_CSV}")

# Альтернативная простая версия
def quick_predict():
    """Быстрое создание предсказаний"""

    # Пути по умолчанию
    paths = {
        'test_csv': '/content/test (3).csv',
        'test_images': '/content/test_imgs',
        'model': '/content/final_model.pth',
        'output': '/content/submission (1).csv'
    }

    # Проверяем существование файлов
    for name, path in paths.items():
        if not os.path.exists(path) and name != 'output':
            print(f"❌ File not found: {path}")
            return

    # Запускаем основной процесс
    main()

# Утилиты для проверки
def check_submission_format():
    """Проверяет формат submission файла"""
    try:
        # Пример правильного формата
        example_data = {
            'image_path': ['test1.png', 'test2.png'],
            'wp_r': [30000.0, 32000.0],
            'wp_g': [31000.0, 33000.0],
            'wp_b': [29000.0, 28000.0]
        }

        example_df = pd.DataFrame(example_data)
        print("📋 Example submission format:")
        print(example_df)
        print("\n✅ Columns should be: image_path, wp_r, wp_g, wp_b")
        print("✅ Values should be in range [0, 65535]")

    except Exception as e:
        print(f"Error: {e}")

if __name__ == "__main__":
    # Показываем пример формата
    check_submission_format()

    print("\n" + "="*60)
    # Запускаем создание предсказаний
    main()

    # Или быстрый запуск
    # quick_predict()

📋 Example submission format:
  image_path     wp_r     wp_g     wp_b
0  test1.png  30000.0  31000.0  29000.0
1  test2.png  32000.0  33000.0  28000.0

✅ Columns should be: image_path, wp_r, wp_g, wp_b
✅ Values should be in range [0, 65535]

🔄 CREATING PREDICTIONS
🔍 Checking files...
Test CSV: /content/test.csv → ✅
Test images: /content/test_imgs2 → ❌
Model: /content/final_model.pth → ✅
🖥️  Device: cpu

📁 Creating test dataset...
Loaded test CSV: 145 samples
Columns: ['names']
✅ Test dataset created: 145 samples

🤖 Loading model...
Loading model from: /content/final_model.pth
Loaded model weights

🎯 Making predictions...


Making predictions:   0%|          | 0/19 [00:00<?, ?it/s]

Error reading image /content/test_imgs2/test_imgs/0001.png: Cannot read image: /content/test_imgs2/test_imgs/0001.png
Error reading image /content/test_imgs2/test_imgs/0015.png: Cannot read image: /content/test_imgs2/test_imgs/0015.png
Error reading image /content/test_imgs2/test_imgs/0034.png: Cannot read image: /content/test_imgs2/test_imgs/0034.png
Error reading image /content/test_imgs2/test_imgs/0018.png: Cannot read image: /content/test_imgs2/test_imgs/0018.png
Error reading image /content/test_imgs2/test_imgs/0035.png: Cannot read image: /content/test_imgs2/test_imgs/0035.png
Error reading image /content/test_imgs2/test_imgs/0019.png: Cannot read image: /content/test_imgs2/test_imgs/0019.png
Error reading image /content/test_imgs2/test_imgs/0042.png: Cannot read image: /content/test_imgs2/test_imgs/0042.png
Error reading image /content/test_imgs2/test_imgs/0020.png: Cannot read image: /content/test_imgs2/test_imgs/0020.png
Error reading image /content/test_imgs2/test_imgs/0049.p

Making predictions:   5%|▌         | 1/19 [00:01<00:19,  1.06s/it]

Error reading image /content/test_imgs2/test_imgs/0238.png: Cannot read image: /content/test_imgs2/test_imgs/0238.png
Error reading image /content/test_imgs2/test_imgs/0239.png: Cannot read image: /content/test_imgs2/test_imgs/0239.png
Error reading image /content/test_imgs2/test_imgs/0241.png: Cannot read image: /content/test_imgs2/test_imgs/0241.png
Error reading image /content/test_imgs2/test_imgs/0242.png: Cannot read image: /content/test_imgs2/test_imgs/0242.png
Error reading image /content/test_imgs2/test_imgs/0243.png: Cannot read image: /content/test_imgs2/test_imgs/0243.png
Error reading image /content/test_imgs2/test_imgs/0245.png: Cannot read image: /content/test_imgs2/test_imgs/0245.png
Error reading image /content/test_imgs2/test_imgs/0264.png: Cannot read image: /content/test_imgs2/test_imgs/0264.png
Error reading image /content/test_imgs2/test_imgs/0291.png: Cannot read image: /content/test_imgs2/test_imgs/0291.png


Making predictions:  11%|█         | 2/19 [00:01<00:15,  1.12it/s]

Error reading image /content/test_imgs2/test_imgs/0297.png: Cannot read image: /content/test_imgs2/test_imgs/0297.png
Error reading image /content/test_imgs2/test_imgs/0312.png: Cannot read image: /content/test_imgs2/test_imgs/0312.png
Error reading image /content/test_imgs2/test_imgs/0313.png: Cannot read image: /content/test_imgs2/test_imgs/0313.png
Error reading image /content/test_imgs2/test_imgs/0317.png: Cannot read image: /content/test_imgs2/test_imgs/0317.png
Error reading image /content/test_imgs2/test_imgs/0319.png: Cannot read image: /content/test_imgs2/test_imgs/0319.png
Error reading image /content/test_imgs2/test_imgs/0334.png: Cannot read image: /content/test_imgs2/test_imgs/0334.png
Error reading image /content/test_imgs2/test_imgs/0341.png: Cannot read image: /content/test_imgs2/test_imgs/0341.png
Error reading image /content/test_imgs2/test_imgs/0348.png: Cannot read image: /content/test_imgs2/test_imgs/0348.png


Making predictions:  16%|█▌        | 3/19 [00:02<00:13,  1.19it/s]

Error reading image /content/test_imgs2/test_imgs/0351.png: Cannot read image: /content/test_imgs2/test_imgs/0351.png
Error reading image /content/test_imgs2/test_imgs/0355.png: Cannot read image: /content/test_imgs2/test_imgs/0355.png
Error reading image /content/test_imgs2/test_imgs/0357.png: Cannot read image: /content/test_imgs2/test_imgs/0357.png
Error reading image /content/test_imgs2/test_imgs/0358.png: Cannot read image: /content/test_imgs2/test_imgs/0358.png
Error reading image /content/test_imgs2/test_imgs/0359.png: Cannot read image: /content/test_imgs2/test_imgs/0359.png
Error reading image /content/test_imgs2/test_imgs/0373.png: Cannot read image: /content/test_imgs2/test_imgs/0373.png
Error reading image /content/test_imgs2/test_imgs/0376.png: Cannot read image: /content/test_imgs2/test_imgs/0376.png
Error reading image /content/test_imgs2/test_imgs/0377.png: Cannot read image: /content/test_imgs2/test_imgs/0377.png


Making predictions:  21%|██        | 4/19 [00:03<00:12,  1.20it/s]

Error reading image /content/test_imgs2/test_imgs/0379.png: Cannot read image: /content/test_imgs2/test_imgs/0379.png
Error reading image /content/test_imgs2/test_imgs/0380.png: Cannot read image: /content/test_imgs2/test_imgs/0380.png
Error reading image /content/test_imgs2/test_imgs/0385.png: Cannot read image: /content/test_imgs2/test_imgs/0385.png
Error reading image /content/test_imgs2/test_imgs/0396.png: Cannot read image: /content/test_imgs2/test_imgs/0396.png
Error reading image /content/test_imgs2/test_imgs/0401.png: Cannot read image: /content/test_imgs2/test_imgs/0401.png
Error reading image /content/test_imgs2/test_imgs/0407.png: Cannot read image: /content/test_imgs2/test_imgs/0407.png
Error reading image /content/test_imgs2/test_imgs/0411.png: Cannot read image: /content/test_imgs2/test_imgs/0411.png
Error reading image /content/test_imgs2/test_imgs/0416.png: Cannot read image: /content/test_imgs2/test_imgs/0416.png


Making predictions:  26%|██▋       | 5/19 [00:04<00:11,  1.23it/s]

Error reading image /content/test_imgs2/test_imgs/0443.png: Cannot read image: /content/test_imgs2/test_imgs/0443.png
Error reading image /content/test_imgs2/test_imgs/0456.png: Cannot read image: /content/test_imgs2/test_imgs/0456.png
Error reading image /content/test_imgs2/test_imgs/0464.png: Cannot read image: /content/test_imgs2/test_imgs/0464.png
Error reading image /content/test_imgs2/test_imgs/0470.png: Cannot read image: /content/test_imgs2/test_imgs/0470.png
Error reading image /content/test_imgs2/test_imgs/0479.png: Cannot read image: /content/test_imgs2/test_imgs/0479.png
Error reading image /content/test_imgs2/test_imgs/0486.png: Cannot read image: /content/test_imgs2/test_imgs/0486.png
Error reading image /content/test_imgs2/test_imgs/0496.png: Cannot read image: /content/test_imgs2/test_imgs/0496.png
Error reading image /content/test_imgs2/test_imgs/0500.png: Cannot read image: /content/test_imgs2/test_imgs/0500.png


Making predictions:  32%|███▏      | 6/19 [00:04<00:09,  1.34it/s]

Error reading image /content/test_imgs2/test_imgs/0513.png: Cannot read image: /content/test_imgs2/test_imgs/0513.png
Error reading image /content/test_imgs2/test_imgs/0528.png: Cannot read image: /content/test_imgs2/test_imgs/0528.png
Error reading image /content/test_imgs2/test_imgs/0534.png: Cannot read image: /content/test_imgs2/test_imgs/0534.png
Error reading image /content/test_imgs2/test_imgs/0558.png: Cannot read image: /content/test_imgs2/test_imgs/0558.png
Error reading image /content/test_imgs2/test_imgs/0570.png: Cannot read image: /content/test_imgs2/test_imgs/0570.png
Error reading image /content/test_imgs2/test_imgs/0574.png: Cannot read image: /content/test_imgs2/test_imgs/0574.png
Error reading image /content/test_imgs2/test_imgs/0589.png: Cannot read image: /content/test_imgs2/test_imgs/0589.png
Error reading image /content/test_imgs2/test_imgs/0592.png: Cannot read image: /content/test_imgs2/test_imgs/0592.png


Making predictions:  37%|███▋      | 7/19 [00:05<00:08,  1.46it/s]

Error reading image /content/test_imgs2/test_imgs/0596.png: Cannot read image: /content/test_imgs2/test_imgs/0596.png
Error reading image /content/test_imgs2/test_imgs/0600.png: Cannot read image: /content/test_imgs2/test_imgs/0600.png
Error reading image /content/test_imgs2/test_imgs/0601.png: Cannot read image: /content/test_imgs2/test_imgs/0601.png
Error reading image /content/test_imgs2/test_imgs/0604.png: Cannot read image: /content/test_imgs2/test_imgs/0604.png
Error reading image /content/test_imgs2/test_imgs/0609.png: Cannot read image: /content/test_imgs2/test_imgs/0609.png
Error reading image /content/test_imgs2/test_imgs/0614.png: Cannot read image: /content/test_imgs2/test_imgs/0614.png
Error reading image /content/test_imgs2/test_imgs/0616.png: Cannot read image: /content/test_imgs2/test_imgs/0616.png
Error reading image /content/test_imgs2/test_imgs/0627.png: Cannot read image: /content/test_imgs2/test_imgs/0627.png


Making predictions:  42%|████▏     | 8/19 [00:05<00:07,  1.53it/s]

Error reading image /content/test_imgs2/test_imgs/0634.png: Cannot read image: /content/test_imgs2/test_imgs/0634.png
Error reading image /content/test_imgs2/test_imgs/0644.png: Cannot read image: /content/test_imgs2/test_imgs/0644.png
Error reading image /content/test_imgs2/test_imgs/0647.png: Cannot read image: /content/test_imgs2/test_imgs/0647.png
Error reading image /content/test_imgs2/test_imgs/0649.png: Cannot read image: /content/test_imgs2/test_imgs/0649.png
Error reading image /content/test_imgs2/test_imgs/0661.png: Cannot read image: /content/test_imgs2/test_imgs/0661.png
Error reading image /content/test_imgs2/test_imgs/0665.png: Cannot read image: /content/test_imgs2/test_imgs/0665.png
Error reading image /content/test_imgs2/test_imgs/0669.png: Cannot read image: /content/test_imgs2/test_imgs/0669.png
Error reading image /content/test_imgs2/test_imgs/0677.png: Cannot read image: /content/test_imgs2/test_imgs/0677.png


Making predictions:  47%|████▋     | 9/19 [00:06<00:06,  1.60it/s]

Error reading image /content/test_imgs2/test_imgs/0679.png: Cannot read image: /content/test_imgs2/test_imgs/0679.png
Error reading image /content/test_imgs2/test_imgs/0682.png: Cannot read image: /content/test_imgs2/test_imgs/0682.png
Error reading image /content/test_imgs2/test_imgs/0687.png: Cannot read image: /content/test_imgs2/test_imgs/0687.png
Error reading image /content/test_imgs2/test_imgs/0693.png: Cannot read image: /content/test_imgs2/test_imgs/0693.png
Error reading image /content/test_imgs2/test_imgs/0700.png: Cannot read image: /content/test_imgs2/test_imgs/0700.png
Error reading image /content/test_imgs2/test_imgs/0709.png: Cannot read image: /content/test_imgs2/test_imgs/0709.png
Error reading image /content/test_imgs2/test_imgs/0726.png: Cannot read image: /content/test_imgs2/test_imgs/0726.png
Error reading image /content/test_imgs2/test_imgs/0727.png: Cannot read image: /content/test_imgs2/test_imgs/0727.png


Making predictions:  53%|█████▎    | 10/19 [00:07<00:05,  1.63it/s]

Error reading image /content/test_imgs2/test_imgs/0752.png: Cannot read image: /content/test_imgs2/test_imgs/0752.png
Error reading image /content/test_imgs2/test_imgs/0757.png: Cannot read image: /content/test_imgs2/test_imgs/0757.png
Error reading image /content/test_imgs2/test_imgs/0760.png: Cannot read image: /content/test_imgs2/test_imgs/0760.png
Error reading image /content/test_imgs2/test_imgs/0763.png: Cannot read image: /content/test_imgs2/test_imgs/0763.png
Error reading image /content/test_imgs2/test_imgs/0768.png: Cannot read image: /content/test_imgs2/test_imgs/0768.png
Error reading image /content/test_imgs2/test_imgs/0770.png: Cannot read image: /content/test_imgs2/test_imgs/0770.png
Error reading image /content/test_imgs2/test_imgs/0774.png: Cannot read image: /content/test_imgs2/test_imgs/0774.png
Error reading image /content/test_imgs2/test_imgs/0778.png: Cannot read image: /content/test_imgs2/test_imgs/0778.png


Making predictions:  58%|█████▊    | 11/19 [00:07<00:04,  1.66it/s]

Error reading image /content/test_imgs2/test_imgs/0791.png: Cannot read image: /content/test_imgs2/test_imgs/0791.png
Error reading image /content/test_imgs2/test_imgs/0792.png: Cannot read image: /content/test_imgs2/test_imgs/0792.png
Error reading image /content/test_imgs2/test_imgs/0799.png: Cannot read image: /content/test_imgs2/test_imgs/0799.png
Error reading image /content/test_imgs2/test_imgs/0800.png: Cannot read image: /content/test_imgs2/test_imgs/0800.png
Error reading image /content/test_imgs2/test_imgs/0804.png: Cannot read image: /content/test_imgs2/test_imgs/0804.png
Error reading image /content/test_imgs2/test_imgs/0806.png: Cannot read image: /content/test_imgs2/test_imgs/0806.png
Error reading image /content/test_imgs2/test_imgs/0816.png: Cannot read image: /content/test_imgs2/test_imgs/0816.png
Error reading image /content/test_imgs2/test_imgs/0818.png: Cannot read image: /content/test_imgs2/test_imgs/0818.png


Making predictions:  63%|██████▎   | 12/19 [00:08<00:04,  1.68it/s]

Error reading image /content/test_imgs2/test_imgs/0821.png: Cannot read image: /content/test_imgs2/test_imgs/0821.png
Error reading image /content/test_imgs2/test_imgs/0823.png: Cannot read image: /content/test_imgs2/test_imgs/0823.png
Error reading image /content/test_imgs2/test_imgs/0842.png: Cannot read image: /content/test_imgs2/test_imgs/0842.png
Error reading image /content/test_imgs2/test_imgs/0852.png: Cannot read image: /content/test_imgs2/test_imgs/0852.png
Error reading image /content/test_imgs2/test_imgs/0855.png: Cannot read image: /content/test_imgs2/test_imgs/0855.png
Error reading image /content/test_imgs2/test_imgs/0868.png: Cannot read image: /content/test_imgs2/test_imgs/0868.png
Error reading image /content/test_imgs2/test_imgs/0877.png: Cannot read image: /content/test_imgs2/test_imgs/0877.png
Error reading image /content/test_imgs2/test_imgs/0883.png: Cannot read image: /content/test_imgs2/test_imgs/0883.png


Making predictions:  68%|██████▊   | 13/19 [00:08<00:03,  1.69it/s]

Error reading image /content/test_imgs2/test_imgs/0890.png: Cannot read image: /content/test_imgs2/test_imgs/0890.png
Error reading image /content/test_imgs2/test_imgs/0895.png: Cannot read image: /content/test_imgs2/test_imgs/0895.png
Error reading image /content/test_imgs2/test_imgs/0896.png: Cannot read image: /content/test_imgs2/test_imgs/0896.png
Error reading image /content/test_imgs2/test_imgs/0909.png: Cannot read image: /content/test_imgs2/test_imgs/0909.png
Error reading image /content/test_imgs2/test_imgs/0912.png: Cannot read image: /content/test_imgs2/test_imgs/0912.png
Error reading image /content/test_imgs2/test_imgs/0917.png: Cannot read image: /content/test_imgs2/test_imgs/0917.png
Error reading image /content/test_imgs2/test_imgs/0924.png: Cannot read image: /content/test_imgs2/test_imgs/0924.png
Error reading image /content/test_imgs2/test_imgs/0925.png: Cannot read image: /content/test_imgs2/test_imgs/0925.png


Making predictions:  74%|███████▎  | 14/19 [00:09<00:02,  1.69it/s]

Error reading image /content/test_imgs2/test_imgs/0926.png: Cannot read image: /content/test_imgs2/test_imgs/0926.png


Making predictions: 100%|██████████| 19/19 [00:11<00:00,  1.61it/s]


💾 Creating submission file...
Creating submission file: /content/final_submission.csv
Submission file created with 145 predictions
First 5 predictions:
  image_path     wp_r  wp_g  wp_b
0   0001.png  65535.0   0.0   0.0
1   0015.png  65535.0   0.0   0.0
2   0018.png  65535.0   0.0   0.0
3   0019.png  65535.0   0.0   0.0
4   0020.png  65535.0   0.0   0.0

📊 Prediction statistics:
Total predictions: 145
Value ranges:
  R: 65535.0 - 65535.0
  G: 0.0 - 0.0
  B: 0.0 - 0.0
✅ All predictions are within valid range

🎉 Done! Submission file saved to: /content/final_submission.csv





In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

def merge_submission_files(submission_path, source_path, output_path):
    """
    Объединяет данные: берет первую колонку из submission файла
    и остальные три колонки из source файла

    Args:
        submission_path: путь к файлу submission.csv (откуда берем image_path)
        source_path: путь к исходному файлу (откуда берем wp_r, wp_g, wp_b)
        output_path: путь для сохранения результата
    """

    print("🔍 Reading files...")

    try:
        # Читаем submission файл
        submission_df = pd.read_csv(submission_path)
        print(f"✅ Submission file loaded: {len(submission_df)} rows")
        print(f"   Columns: {list(submission_df.columns)}")

        # Читаем source файл
        source_df = pd.read_csv(source_path)
        print(f"✅ Source file loaded: {len(source_df)} rows")
        print(f"   Columns: {list(source_df.columns)}")

    except Exception as e:
        print(f"❌ Error reading files: {e}")
        return None

    # Проверяем, что файлы имеют одинаковое количество строк
    if len(submission_df) != len(source_df):
        print(f"⚠️  Warning: Different number of rows!")
        print(f"   Submission: {len(submission_df)} rows")
        print(f"   Source: {len(source_df)} rows")

        # Берем минимальное количество строк
        min_rows = min(len(submission_df), len(source_df))
        submission_df = submission_df.head(min_rows)
        source_df = source_df.head(min_rows)
        print(f"   Using first {min_rows} rows from each file")

    # Проверяем наличие нужных колонок
    submission_cols = submission_df.columns
    source_cols = source_df.columns

    # Определяем колонку с путями из submission файла
    if 'image_path' in submission_cols:
        image_path_col = 'image_path'
    else:
        # Берем первую колонку
        image_path_col = submission_cols[0]
        print(f"ℹ️  Using first column as image_path: {image_path_col}")

    # Определяем колонки для white balance из source файла
    wp_cols = []
    for col in ['wp_r', 'wp_g', 'wp_b']:
        if col in source_cols:
            wp_cols.append(col)
        else:
            # Ищем альтернативные названия
            for source_col in source_cols:
                if col in source_col.lower() or 'white' in source_col.lower():
                    wp_cols.append(source_col)
                    print(f"ℹ️  Using {source_col} as {col}")
                    break
            else:
                # Если не нашли, берем первые три колонки после image_path
                if len(source_cols) >= 4:
                    wp_cols.extend(source_cols[1:4])
                else:
                    wp_cols.extend(source_cols[:3])
                print(f"ℹ️  Using columns {wp_cols} for white balance values")
                break

    # Ограничиваем до 3 колонок
    wp_cols = wp_cols[:3]

    # Создаем новый DataFrame
    print("\n🔄 Merging data...")

    try:
        # Берем image_path из submission файла
        result_df = pd.DataFrame()
        result_df['image_path'] = submission_df[image_path_col]

        # Берем white balance значения из source файла
        for i, col in enumerate(wp_cols[:3]):  # Берем максимум 3 колонки
            if i == 0:
                result_df['wp_r'] = source_df[col]
            elif i == 1:
                result_df['wp_g'] = source_df[col]
            elif i == 2:
                result_df['wp_b'] = source_df[col]

        # Если не хватило колонок, заполняем значениями по умолчанию
        if len(wp_cols) < 3:
            print(f"⚠️  Only {len(wp_cols)} white balance columns found")
            if 'wp_r' not in result_df.columns:
                result_df['wp_r'] = 32768.0
            if 'wp_g' not in result_df.columns:
                result_df['wp_g'] = 32768.0
            if 'wp_b' not in result_df.columns:
                result_df['wp_b'] = 32768.0

        # Сохраняем результат
        result_df.to_csv(output_path, index=False, float_format='%.6f')

        print(f"✅ Merged file saved: {output_path}")
        print(f"📊 Result shape: {result_df.shape}")
        print(f"📋 Columns: {list(result_df.columns)}")
        print("\n📄 First 5 rows:")
        print(result_df.head())

        return result_df

    except Exception as e:
        print(f"❌ Error merging files: {e}")
        return None

def create_manual_merge(submission_path, source_path, output_path):
    """
    Ручное объединение с выбором колонок
    """
    print("📝 Manual merge mode")

    # Читаем файлы
    submission_df = pd.read_csv(submission_path)
    source_df = pd.read_csv(source_path)

    print(f"Submission file columns: {list(submission_df.columns)}")
    print(f"Source file columns: {list(source_df.columns)}")

    # Создаем новый DataFrame
    result_df = pd.DataFrame()

    # Выбираем колонку для image_path
    image_col = input("Enter column name for image_path from submission file: ").strip()
    if image_col not in submission_df.columns:
        print(f"Column '{image_col}' not found, using first column")
        image_col = submission_df.columns[0]

    result_df['image_path'] = submission_df[image_col]

    # Выбираем колонки для white balance
    wp_mapping = {}
    for wp_col in ['wp_r', 'wp_g', 'wp_b']:
        col_name = input(f"Enter column name for {wp_col} from source file: ").strip()
        if col_name in source_df.columns:
            result_df[wp_col] = source_df[col_name]
            wp_mapping[wp_col] = col_name
        else:
            print(f"Column '{col_name}' not found, using default value 32768.0")
            result_df[wp_col] = 32768.0

    # Сохраняем
    result_df.to_csv(output_path, index=False, float_format='%.6f')
    print(f"✅ Manual merge saved to: {output_path}")

    return result_df

def quick_merge():
    """
    Быстрое объединение с стандартными путями
    """
    submission_file = "/content/submission.csv"
    source_file = "/content/submission.csv"  # или ваш файл с white balance значениями
    output_file = "/content/merged_submission.csv"

    print("🚀 Quick merge with default paths:")
    print(f"Submission: {submission_file}")
    print(f"Source: {source_file}")
    print(f"Output: {output_file}")

    return merge_submission_files(submission_file, source_file, output_file)

# Дополнительные утилиты
def check_file_info(file_path):
    """Проверка информации о файле"""
    try:
        df = pd.read_csv(file_path)
        print(f"📁 File: {file_path}")
        print(f"📊 Shape: {df.shape}")
        print(f"📋 Columns: {list(df.columns)}")
        print(f"🔢 Dtypes:\n{df.dtypes}")
        print("\n📄 First 3 rows:")
        print(df.head(3))
        print("\n📄 Last 3 rows:")
        print(df.tail(3))

    except Exception as e:
        print(f"❌ Error reading {file_path}: {e}")

def compare_files(file1_path, file2_path):
    """Сравнение двух файлов"""
    df1 = pd.read_csv(file1_path)
    df2 = pd.read_csv(file2_path)

    print("📊 File Comparison:")
    print(f"File 1: {file1_path} - {df1.shape}")
    print(f"File 2: {file2_path} - {df2.shape}")

    print("\n📋 Columns comparison:")
    print(f"File 1 columns: {list(df1.columns)}")
    print(f"File 2 columns: {list(df2.columns)}")

    common_cols = set(df1.columns) & set(df2.columns)
    print(f"Common columns: {common_cols}")

# Основная функция
def main():
    """Основная функция для запуска в Colab"""

    print("=" * 60)
    print("📊 SUBISSION FILE MERGER")
    print("=" * 60)

    # Стандартные пути для Colab
    files_to_check = [
        "/content/submission.csv",
        "/content/you1.csv",
        "/content/train.csv"
    ]

    print("🔍 Checking available files:")
    for file_path in files_to_check:
        exists = Path(file_path).exists()
        status = "✅" if exists else "❌"
        print(f"{status} {file_path}")

    # Автоматический поиск source файла
    source_candidates = [
        "/content/you1.csv",
        "/content/train.csv",
        "/content/test.csv",
        "/content/data.csv"
    ]

    source_file = None
    for candidate in source_candidates:
        if Path(candidate).exists():
            source_file = candidate
            break

    if source_file:
        print(f"\n🎯 Found source file: {source_file}")

        # Запускаем автоматическое объединение
        result = merge_submission_files(
            submission_path="/content/submission.csv",
            source_path=source_file,
            output_path="/content/final_submission.csv"
        )

        if result is not None:
            print("\n🎉 Merge completed successfully!")
            print("📁 Final files:")
            !ls -la /content/*.csv
        else:
            print("\n❌ Merge failed!")

    else:
        print("\n❌ No source file found for white balance values!")
        print("Please specify paths manually:")

        submission_path = input("Enter submission file path: ").strip() or "/content/submission.csv"
        source_path = input("Enter source file path: ").strip()
        output_path = input("Enter output file path: ").strip() or "/content/merged_submission.csv"

        if Path(submission_path).exists() and Path(source_path).exists():
            merge_submission_files(submission_path, source_path, output_path)
        else:
            print("❌ Files not found!")

# Функции для быстрого использования в Colab
def show_csv_preview():
    """Показать превью всех CSV файлов"""
    csv_files = list(Path("/content").glob("*.csv"))

    for csv_file in csv_files:
        print(f"\n{'='*50}")
        print(f"📄 {csv_file.name}")
        print(f"{'='*50}")

        try:
            df = pd.read_csv(csv_file)
            print(f"Shape: {df.shape}")
            print(f"Columns: {list(df.columns)}")
            print("\nFirst 2 rows:")
            print(df.head(2))
        except Exception as e:
            print(f"Error reading: {e}")

# Запуск
if __name__ == "__main__":
    # Показываем доступные файлы
    show_csv_preview()

    print("\n" + "="*60)
    # Запускаем основной процесс
    main()


📄 test.csv
Shape: (145, 1)
Columns: ['names']

First 2 rows:
                names
0  test_imgs/0001.png
1  test_imgs/0015.png

📄 submission.csv
Shape: (145, 4)
Columns: ['names', 'wp_r', 'wp_g', 'wp_b']

First 2 rows:
                names      wp_r      wp_g      wp_b
0  test_imgs/0001.png  0.393742  0.610143  0.786508
1  test_imgs/0015.png  0.738509  0.593700  0.253388

📄 final_submission.csv
Shape: (145, 4)
Columns: ['image_path', 'wp_r', 'wp_g', 'wp_b']

First 2 rows:
  image_path     wp_r  wp_g  wp_b
0   0001.png  65535.0   0.0   0.0
1   0015.png  65535.0   0.0   0.0

📄 train.csv
Shape: (570, 4)
Columns: ['names', 'wp_r', 'wp_g', 'wp_b']

First 2 rows:
                 names      wp_r      wp_g      wp_b
0  train_imgs/0000.png  0.173683  0.508642  0.215429
1  train_imgs/0002.png  0.266894  0.956725  0.577948

📊 SUBISSION FILE MERGER
🔍 Checking available files:
✅ /content/submission.csv
❌ /content/you1.csv
✅ /content/train.csv

🎯 Found source file: /content/train.csv
🔍 Reading fi