In [None]:
import cv2
import numpy  as np
from pathlib import Path

from math import sqrt

TARGET_COLUMN_NAMES = ['wp_r', 'wp_g', 'wp_b']

HIST_BINS = [116, 100]
HIST_RANGE_FLAT = [-sqrt(3), sqrt(3), -1, 2] #value ranges for beta and alpha
HIST_TARGET_SIZE = [128, 128]
HIST_VERT_PADD, HIST_HORR_PADD = 6, 14

WHITE_LEVEL = 2**12 - 1 - 256

def read_hist(path2hist: Path):
    hist = cv2.imread(str(path2hist), cv2.IMREAD_UNCHANGED)
    hist = hist[HIST_HORR_PADD:-HIST_HORR_PADD,
                HIST_VERT_PADD:-HIST_VERT_PADD]
    hist = hist.astype(np.float32) / 255
    return hist.astype(np.float32)

In [None]:
from typing import Literal
from pathlib import Path

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import cv2
import pandas as pd
import numpy as np
from ast import literal_eval
from tqdm import tqdm

this_dir = Path('/content/sasd').parent



def read_image(path2img: Path | str,
           white_level_corr: bool = True):
    img = cv2.imread(str(path2img), cv2.IMREAD_UNCHANGED)
    if white_level_corr:
        img = img / WHITE_LEVEL
    if img.shape[-1] == 3:
        img = img[..., ::-1]
    return img


def get_hists_dir(part: Literal['train', 'test'] = 'train'):
    return this_dir / f'{part}_histograms'


def get_imgs_dir(part: Literal['train', 'test'] = 'train'):
    return this_dir / f'{part}_imgs'


def get_historgam_by_name(img_name: str,
                          part: Literal['train', 'test'] = 'train'):
    hists_dir = get_hists_dir(part)
    hist = read_hist((hists_dir / img_name).with_suffix('.png'))
    return hist


class IllumDataset:
    def __init__(self,
                 part: Literal['train', 'test'] = 'train'):
        self.part = part
        self._init_paths()
        self._init_white_points()

    def _init_paths(self):
        self.imgs_paths = sorted(get_imgs_dir(self.part).glob('*.png'))
        self.hists_paths = sorted(get_hists_dir(self.part).glob('*.png'))

    def _init_white_points(self):
        df = pd.read_csv(this_dir / f'{self.part}.csv', converters={'white_points': literal_eval})
        self.white_points = df[['wp_r', 'wp_g', 'wp_b']].values

    def __getitem__(self, idx):
        return read_image(self.imgs_paths[idx]), read_hist(self.hists_paths[idx]), self.white_points[idx]

    def __len__(self):
        return len(self.imgs_paths)


class AWBDataset(Dataset):
    def __init__(self,
                 dataset: IllumDataset,
                 ids: list,
                 require_transform: bool = False,
                 preload: bool = False):
        self.imgs_paths   = [elem for i, elem in enumerate(dataset.imgs_paths) if i in ids]
        self.hists_paths  = [elem for i, elem in enumerate(dataset.hists_paths) if i in ids]
        self.white_points = [elem for i, elem in enumerate(dataset.white_points) if i in ids]
        self.preload = preload
        if require_transform:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.RandomCrop(256)
            ])
        else:
            self.transform = transforms.Compose([
                transforms.ToTensor()
            ])
        if self.preload:
            self.imgs = [read_image(elem) for elem in tqdm(self.imgs_paths, desc="Preloading images")]
            self.hists = [read_hist(elem) for elem in tqdm(self.imgs_paths, desc="Preloading histograms")]


    def __getitem__(self, idx):
        if self.preload:
            return (
                self.transform(self.imgs[idx]),
                self.hists[idx],
                np.array(self.white_points[idx], dtype=np.float32)
            )
        else:
            return (
                self.transform((read_image(self.imgs_paths[idx])).astype(np.float32)),
                read_hist(self.hists_paths[idx]),
                np.array(self.white_points[idx], dtype=np.float32)
            )



    def __len__(self):
        return len(self.imgs_paths)


class DataModule:
    def __init__(self,
                 dataset: IllumDataset,
                 val_size: float = 0.2,
                 batch_size: int = 32,
                 preload: bool = False):
        self.dataset = dataset
        self.batch_size = batch_size
        self.preload = preload
        self._make_train_val_split(val_size=val_size)
        self._init_datasets()
        self._init_dataloaders()

    def _make_train_val_split(self, val_size: float = 0.2):
        n = len(self.dataset)
        permuted_ids = np.random.permutation(np.arange(n))
        self.val_ids = permuted_ids[:int(n * val_size)]
        self.train_ids = permuted_ids[int(n * val_size):]

    def _init_datasets(self):
        self.train_dataset = AWBDataset(self.dataset, self.train_ids, True, self.preload)
        self.val_dataset   = AWBDataset(self.dataset, self.val_ids, False, self.preload)

    def _init_dataloaders(self):
        self.train_dataloader = DataLoader(self.train_dataset,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=10)
        self.val_dataloader = DataLoader(self.val_dataset,
                                         batch_size=self.batch_size,
                                         shuffle=False,
                                         num_workers=2)

    def get_train_dataloader(self) -> DataLoader:
        return self.train_dataloader

    def get_val_dataloader(self) -> DataLoader:
        return self.val_dataloader

    def get_train_dataset(self) -> AWBDataset:
        return self.train_dataset

    def get_val_dataset(self) -> AWBDataset:
        return self.val_dataset


def create_datamodule(mode: Literal['train', 'test'] = 'train',
                      val_size=0.2, batch_size=32) -> DataModule:
    illum_dataset = IllumDataset(mode)
    return DataModule(illum_dataset, val_size=val_size, batch_size=batch_size)

In [None]:
import cv2
import numpy as np
from pathlib import Path
from math import sqrt
from typing import Literal
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
from ast import literal_eval
from tqdm import tqdm

# Константы из первого скрипта
TARGET_COLUMN_NAMES = ['wp_r', 'wp_g', 'wp_b']
HIST_BINS = [116, 100]
HIST_RANGE_FLAT = [-sqrt(3), sqrt(3), -1, 2]  # value ranges for beta and alpha
HIST_TARGET_SIZE = [128, 128]
HIST_VERT_PADD, HIST_HORR_PADD = 6, 14
WHITE_LEVEL = 2**12 - 1 - 256

def read_hist(path2hist: Path):
    """Чтение гистограммы из файла"""
    hist = cv2.imread(str(path2hist), cv2.IMREAD_UNCHANGED)
    hist = hist[HIST_HORR_PADD:-HIST_HORR_PADD,
                HIST_VERT_PADD:-HIST_VERT_PADD]
    hist = hist.astype(np.float32) / 255
    return hist.astype(np.float32)

def read_image(path2img: Path | str, white_level_corr: bool = True):
    """Чтение изображения из файла"""
    img = cv2.imread(str(path2img), cv2.IMREAD_UNCHANGED)
    if white_level_corr:
        img = img / WHITE_LEVEL
    if img.shape[-1] == 3:
        img = img[..., ::-1]  # Convert BGR to RGB
    return img

def get_hists_dir(part: Literal['train', 'test'] = 'train'):
    """Получение пути к директории с гистограммами"""
    return Path('/content/sasd').parent / f'{part}_histograms'

def get_imgs_dir(part: Literal['train', 'test'] = 'train'):
    """Получение пути к директории с изображениями"""
    return Path('/content/sasd').parent / f'{part}_imgs'

def get_histogram_by_name(img_name: str, part: Literal['train', 'test'] = 'train'):
    """Получение гистограммы по имени изображения"""
    hists_dir = get_hists_dir(part)
    hist = read_hist((hists_dir / img_name).with_suffix('.png'))
    return hist

class IllumDataset(Dataset):
    """Основной датасет для работы с изображениями и гистограммами"""

    def __init__(self, part: Literal['train', 'test'] = 'train'):
        self.part = part
        self._init_paths()
        self._init_white_points()

    def _init_paths(self):
        """Инициализация путей к файлам"""
        self.imgs_dir = get_imgs_dir(self.part)
        self.hists_dir = get_hists_dir(self.part)

        # Получаем список всех изображений
        self.imgs_paths = sorted(self.imgs_dir.glob('*.png'))
        self.hists_paths = sorted(self.hists_dir.glob('*.png'))

        # Проверяем соответствие количества файлов
        if len(self.imgs_paths) != len(self.hists_paths):
            print(f"Warning: Different number of images ({len(self.imgs_paths)}) and histograms ({len(self.hists_paths)})")

    def _init_white_points(self):
        """Инициализация белых точек из CSV файла"""
        csv_path = Path("/content/sasd").parent / f'{self.part}.csv'
        if csv_path.exists():
            df = pd.read_csv(csv_path, converters={'white_points': literal_eval})
            self.white_points = df[['wp_r', 'wp_g', 'wp_b']].values
        else:
            print(f"Warning: CSV file {csv_path} not found")
            self.white_points = np.zeros((len(self.imgs_paths), 3))

    def __getitem__(self, idx):
        """Получение одного элемента датасета"""
        # Чтение изображения
        image = read_image(self.imgs_paths[idx])

        # Чтение гистограммы
        histogram = read_hist(self.hists_paths[idx])

        # Получение белых точек
        white_point = self.white_points[idx] if hasattr(self, 'white_points') else np.zeros(3)

        return {
            'image': image.astype(np.float32),
            'histogram': histogram.astype(np.float32),
            'white_point': white_point.astype(np.float32),
            'image_path': str(self.imgs_paths[idx]),
            'histogram_path': str(self.hists_paths[idx])
        }

    def __len__(self):
        """Количество элементов в датасете"""
        return len(self.imgs_paths)

class AWBDataset(Dataset):
    """Расширенный датасет с дополнительными возможностями"""

    def __init__(self, base_dataset: IllumDataset, indices: list = None,
                 transform=None, preload: bool = False):
        self.base_dataset = base_dataset
        self.indices = indices if indices is not None else list(range(len(base_dataset)))
        self.transform = transform
        self.preload = preload

        if self.preload:
            self._preload_data()

    def _preload_data(self):
        """Предзагрузка данных в память"""
        self.images = []
        self.histograms = []
        self.white_points = []

        for idx in tqdm(self.indices, desc="Preloading data"):
            item = self.base_dataset[idx]
            self.images.append(item['image'])
            self.histograms.append(item['histogram'])
            self.white_points.append(item['white_point'])

    def __getitem__(self, idx):
        """Получение одного элемента датасета"""
        actual_idx = self.indices[idx]

        if self.preload:
            image = self.images[idx]
            histogram = self.histograms[idx]
            white_point = self.white_points[idx]
        else:
            item = self.base_dataset[actual_idx]
            image = item['image']
            histogram = item['histogram']
            white_point = item['white_point']

        # Применение трансформаций
        if self.transform:
            image = self.transform(image)

        return {
            'image': image,
            'histogram': histogram,
            'white_point': white_point
        }

    def __len__(self):
        """Количество элементов в датасете"""
        return len(self.indices)

class DataModule:
    """Модуль данных для обучения и валидации"""

    def __init__(self, part: Literal['train', 'test'] = 'train',
                 val_size: float = 0.2, batch_size: int = 32,
                 preload: bool = False, transform=None):
        self.part = part
        self.val_size = val_size
        self.batch_size = batch_size
        self.preload = preload

        # Создаем базовый датасет
        self.base_dataset = IllumDataset(part)

        # Разделяем на train/val
        self._make_train_val_split()

        # Трансформации
        self.transform = transform if transform else transforms.Compose([
            transforms.ToTensor()
        ])

        # Создаем датасеты
        self._init_datasets()

        # Создаем dataloaders
        self._init_dataloaders()

    def _make_train_val_split(self):
        """Разделение данных на тренировочную и валидационную части"""
        n = len(self.base_dataset)
        all_indices = np.arange(n)
        np.random.shuffle(all_indices)

        val_count = int(n * self.val_size)
        self.train_indices = all_indices[val_count:]
        self.val_indices = all_indices[:val_count]

    def _init_datasets(self):
        """Инициализация датасетов"""
        self.train_dataset = AWBDataset(
            self.base_dataset, self.train_indices,
            transform=self.transform, preload=self.preload
        )

        self.val_dataset = AWBDataset(
            self.base_dataset, self.val_indices,
            transform=self.transform, preload=self.preload
        )

    def _init_dataloaders(self):
        """Инициализация dataloaders"""
        self.train_dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=True
        )

        self.val_dataloader = DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

    def get_dataloaders(self):
        """Получение train и val dataloaders"""
        return self.train_dataloader, self.val_dataloader

    def get_datasets(self):
        """Получение train и val datasets"""
        return self.train_dataset, self.val_dataset

# Пример использования
if __name__ == "__main__":
    # Создаем базовый датасет
    dataset = IllumDataset('train')

    print(f"Dataset size: {len(dataset)}")

    # Получаем первый элемент
    sample = dataset[0]
    print(f"Image shape: {sample['image'].shape}")
    print(f"Histogram shape: {sample['histogram'].shape}")
    print(f"White point: {sample['white_point']}")

    # Создаем DataModule для обучения
    datamodule = DataModule('train', val_size=0.2, batch_size=16, preload=False)

    train_loader, val_loader = datamodule.get_dataloaders()

    print(f"Train batches: {len(train_loader)}")
    print(f"Val batches: {len(val_loader)}")

    # Пример итерации по данным
    for batch in train_loader:
        images = batch['image']
        histograms = batch['histogram']
        white_points = batch['white_point']

        print(f"Batch images shape: {images.shape}")
        print(f"Batch histograms shape: {histograms.shape}")
        print(f"Batch white points shape: {white_points.shape}")
        break

Dataset size: 1
Image shape: (768, 1024, 3)
Histogram shape: (100, 116)
White point: [0.17368333 0.5086421  0.21542864]
Train batches: 1
Val batches: 0




Batch images shape: torch.Size([1, 3, 768, 1024])
Batch histograms shape: torch.Size([1, 100, 116])
Batch white points shape: torch.Size([1, 3])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from math import acos, pi

class AWB_DualStream(nn.Module):
    def __init__(self, pretrained=True):
        super(AWB_DualStream, self).__init__()

        # Image branch - ResNet18
        resnet = models.resnet18(pretrained=pretrained)
        self.image_branch = nn.Sequential(*list(resnet.children())[:-2])
        self.image_pool = nn.AdaptiveAvgPool2d(1)

        # Histogram branch
        self.hist_branch = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )

        # Fusion and regression
        self.fusion = nn.Sequential(
            nn.Linear(512 + 128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),

            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),

            nn.Linear(128, 3)
        )

    def forward(self, image, hist):
        # Image features
        img_feat = self.image_branch(image)
        img_feat = self.image_pool(img_feat)
        img_feat = img_feat.view(img_feat.size(0), -1)

        # Histogram features (add channel dimension if needed)
        if hist.dim() == 3:
            hist = hist.unsqueeze(1)
        hist_feat = self.hist_branch(hist)
        hist_feat = hist_feat.view(hist_feat.size(0), -1)

        # Fusion
        combined = torch.cat([img_feat, hist_feat], dim=1)
        wp_pred = self.fusion(combined)

        return wp_pred

In [None]:
def angular_loss(y_pred, y_true, eps=1e-8):
    """Angular loss between predicted and true white points"""
    y_pred_norm = y_pred / (y_pred.norm(dim=1, keepdim=True) + eps)
    y_true_norm = y_true / (y_true.norm(dim=1, keepdim=True) + eps)

    cos_sim = torch.sum(y_pred_norm * y_true_norm, dim=1)
    cos_sim = torch.clamp(cos_sim, -1 + eps, 1 - eps)

    angular_error = torch.acos(cos_sim) * (180 / pi)  # Convert to degrees
    return angular_error.mean()

def dist2hist_loss(y_pred, y_true, eps=1e-8):
    """Distance to histogram-based loss"""
    # Normalize both vectors
    y_pred_norm = y_pred / (y_pred.norm(dim=1, keepdim=True) + eps)
    y_true_norm = y_true / (y_true.norm(dim=1, keepdim=True) + eps)

    # Calculate Euclidean distance
    euclidean_dist = torch.norm(y_pred_norm - y_true_norm, dim=1)

    # Combine with angular error for better optimization
    angular_error = angular_loss(y_pred, y_true, eps)

    # Weighted combination (adjust weights as needed)
    return 0.7 * euclidean_dist.mean() + 0.3 * angular_error

class CombinedLoss(nn.Module):
    """Combined loss function for AWB training"""
    def __init__(self, alpha=0.7, beta=0.3, eps=1e-8):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    def forward(self, y_pred, y_true):
        # Normalize vectors
        y_pred_norm = y_pred / (y_pred.norm(dim=1, keepdim=True) + self.eps)
        y_true_norm = y_true / (y_true.norm(dim=1, keepdim=True) + self.eps)

        # Euclidean distance component
        euclidean_dist = torch.norm(y_pred_norm - y_true_norm, dim=1).mean()

        # Angular error component
        cos_sim = torch.sum(y_pred_norm * y_true_norm, dim=1)
        cos_sim = torch.clamp(cos_sim, -1 + self.eps, 1 - self.eps)
        angular_error = torch.acos(cos_sim).mean() * (180 / pi)

        return self.alpha * euclidean_dist + self.beta * angular_error

In [None]:
import cv2
import numpy as np
from pathlib import Path
from math import sqrt
from typing import Literal
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
from ast import literal_eval
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from math import acos, pi
import matplotlib.pyplot as plt

# Константы
TARGET_COLUMN_NAMES = ['wp_r', 'wp_g', 'wp_b']
HIST_BINS = [116, 100]
HIST_RANGE_FLAT = [-sqrt(3), sqrt(3), -1, 2]
HIST_TARGET_SIZE = [128, 128]
HIST_VERT_PADD, HIST_HORR_PADD = 6, 14
WHITE_LEVEL = 2**12 - 1 - 256

def read_hist(path2hist: Path):
    """Чтение гистограммы из файла"""
    hist = cv2.imread(str(path2hist), cv2.IMREAD_UNCHANGED)
    if hist is None:
        raise ValueError(f"Could not read histogram from {path2hist}")
    hist = hist[HIST_HORR_PADD:-HIST_HORR_PADD,
                HIST_VERT_PADD:-HIST_VERT_PADD]
    hist = hist.astype(np.float32) / 255
    return hist

def read_image(path2img: Path | str, white_level_corr: bool = True):
    """Чтение изображения из файла"""
    img = cv2.imread(str(path2img), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f"Could not read image from {path2img}")
    if white_level_corr:
        img = img / WHITE_LEVEL
    if img.shape[-1] == 3:
        img = img[..., ::-1]  # Convert BGR to RGB
    return img

class IllumDataset(Dataset):
    """Основной датасет для работы с изображениями и гистограммами"""

    def __init__(self, base_dir: Path, part: Literal['train', 'test'] = 'train'):
        self.base_dir = Path(base_dir)
        self.part = part
        self._init_paths()
        self._init_white_points()

    def _init_paths(self):
        """Инициализация путей к файлам"""
        self.imgs_dir = '/content/train_imgs'
        self.hists_dir = '/content/train_histograms'
    def _init_white_points(self):
        """Инициализация белых точек из CSV файла"""
        csv_path = '/content/train_compatible.csv'

    def __getitem__(self, idx):
        """Получение одного элемента датасета"""
        try:
            # Чтение изображения
            image = read_image(self.imgs_paths[idx])

            # Чтение гистограммы
            histogram = read_hist(self.hists_paths[idx])

            # Получение белых точек
            white_point = self.white_points[idx] if hasattr(self, 'white_points') else np.zeros(3)

            return {
                'image': image.astype(np.float32),
                'histogram': histogram.astype(np.float32),
                'white_point': white_point.astype(np.float32),
                'image_name': self.imgs_paths[idx].stem
            }
        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Return a dummy sample to avoid breaking the training
            dummy_img = np.zeros((256, 256, 3), dtype=np.float32)
            dummy_hist = np.zeros((100, 100), dtype=np.float32)
            dummy_wp = np.zeros(3, dtype=np.float32)
            return {
                'image': dummy_img,
                'histogram': dummy_hist,
                'white_point': dummy_wp,
                'image_name': 'dummy'
            }

    def __len__(self):
        """Количество элементов в датасете"""
        return len(self.imgs_paths)

# Модель и функции потерь (остаются без изменений)
class AWB_DualStream(nn.Module):
    def __init__(self, pretrained=True):
        super(AWB_DualStream, self).__init__()

        # Image branch - ResNet18
        resnet = models.resnet18(pretrained=pretrained)
        self.image_branch = nn.Sequential(*list(resnet.children())[:-2])
        self.image_pool = nn.AdaptiveAvgPool2d(1)

        # Histogram branch
        self.hist_branch = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )

        # Fusion and regression
        self.fusion = nn.Sequential(
            nn.Linear(512 + 128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),

            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),

            nn.Linear(128, 3)
        )

    def forward(self, image, hist):
        # Image features
        img_feat = self.image_branch(image)
        img_feat = self.image_pool(img_feat)
        img_feat = img_feat.view(img_feat.size(0), -1)

        # Histogram features
        if hist.dim() == 3:
            hist = hist.unsqueeze(1)
        hist_feat = self.hist_branch(hist)
        hist_feat = hist_feat.view(hist_feat.size(0), -1)

        # Fusion
        combined = torch.cat([img_feat, hist_feat], dim=1)
        wp_pred = self.fusion(combined)

        return wp_pred

def angular_loss(y_pred, y_true, eps=1e-8):
    """Angular loss between predicted and true white points"""
    y_pred_norm = y_pred / (y_pred.norm(dim=1, keepdim=True) + eps)
    y_true_norm = y_true / (y_true.norm(dim=1, keepdim=True) + eps)

    cos_sim = torch.sum(y_pred_norm * y_true_norm, dim=1)
    cos_sim = torch.clamp(cos_sim, -1 + eps, 1 - eps)

    angular_error = torch.acos(cos_sim) * (180 / pi)
    return angular_error.mean()

class CombinedLoss(nn.Module):
    """Combined loss function for AWB training"""
    def __init__(self, alpha=0.7, beta=0.3, eps=1e-8):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    def forward(self, y_pred, y_true):
        # Normalize vectors
        y_pred_norm = y_pred / (y_pred.norm(dim=1, keepdim=True) + self.eps)
        y_true_norm = y_true / (y_true.norm(dim=1, keepdim=True) + self.eps)

        # Euclidean distance component
        euclidean_dist = torch.norm(y_pred_norm - y_true_norm, dim=1).mean()

        # Angular error component
        cos_sim = torch.sum(y_pred_norm * y_true_norm, dim=1)
        cos_sim = torch.clamp(cos_sim, -1 + self.eps, 1 - self.eps)
        angular_error = torch.acos(cos_sim).mean() * (180 / pi)

        return self.alpha * euclidean_dist + self.beta * angular_error

class AWBTrainer:
    def __init__(self, model, train_loader, val_loader, device='cuda'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

        self.optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5, verbose=True
        )
        self.criterion = CombinedLoss(alpha=0.7, beta=0.3)

        self.train_losses = []
        self.val_losses = []
        self.val_errors = []

    def train_epoch(self):
        self.model.train()
        total_loss = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch in pbar:
            images = batch['image'].permute(0, 3, 1, 2).to(self.device)  # (B, H, W, C) -> (B, C, H, W)
            hists = batch['histogram'].unsqueeze(1).to(self.device)  # Add channel dimension
            wp_true = batch['white_point'].to(self.device)

            self.optimizer.zero_grad()

            wp_pred = self.model(images, hists)
            loss = self.criterion(wp_pred, wp_true)

            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        return total_loss / len(self.train_loader)

    def validate(self):
        self.model.eval()
        total_loss = 0
        total_error = 0

        with torch.no_grad():
            for batch in self.val_loader:
                images = batch['image'].permute(0, 3, 1, 2).to(self.device)
                hists = batch['histogram'].unsqueeze(1).to(self.device)
                wp_true = batch['white_point'].to(self.device)

                wp_pred = self.model(images, hists)
                loss = self.criterion(wp_pred, wp_true)
                angular_err = angular_loss(wp_pred, wp_true)

                total_loss += loss.item()
                total_error += angular_err.item()

        avg_loss = total_loss / len(self.val_loader)
        avg_error = total_error / len(self.val_loader)

        return avg_loss, avg_error

    def train(self, num_epochs=50):
        best_val_error = float('inf')

        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')

            # Train
            train_loss = self.train_epoch()
            self.train_losses.append(train_loss)

            # Validate
            val_loss, val_error = self.validate()
            self.val_losses.append(val_loss)
            self.val_errors.append(val_error)

            print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Angular Error: {val_error:.2f}°')

            # Learning rate scheduling
            self.scheduler.step(val_loss)

            # Save best model
            if val_error < best_val_error:
                best_val_error = val_error
                torch.save(self.model.state_dict(), 'best_awb_model.pth')
                print(f'New best model saved with error: {best_val_error:.2f}°')

            print('-' * 50)

    def plot_training(self):
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and Validation Loss')

        plt.subplot(1, 2, 2)
        plt.plot(self.val_errors, label='Angular Error', color='red')
        plt.xlabel('Epoch')
        plt.ylabel('Error (°)')
        plt.legend()
        plt.title('Validation Angular Error')

        plt.tight_layout()
        plt.savefig('training_plot.png')
        plt.show()

def main():
    # Укажите правильный путь к вашим данным
    base_dir = Path("/content")  # Измените на ваш путь

    # Создаем датасет
    try:
        dataset = IllumDataset(base_dir, 'train')
        print(f"Dataset size: {len(dataset)}")

        # Разделяем на train/val
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [train_size, val_size]
        )

        # Создаем DataLoader'ы
        train_loader = DataLoader(
            train_dataset, batch_size=32, shuffle=True, num_workers=2
        )
        val_loader = DataLoader(
            val_dataset, batch_size=32, shuffle=False, num_workers=2
        )

        # Инициализируем модель
        model = AWB_DualStream(pretrained=True)

        # Инициализируем тренер
        trainer = AWBTrainer(model, train_loader, val_loader, device='cuda')

        # Обучаем
        trainer.train(num_epochs=30)

        # Сохраняем финальную модель
        torch.save(model.state_dict(), 'final_awb_model.pth')

    except Exception as e:
        print(f"Error: {e}")
        print("Please check your data paths and structure")

if __name__ == "__main__":
    import torch.optim as optim
    main()

Error: 'IllumDataset' object has no attribute 'imgs_paths'
Please check your data paths and structure


In [None]:
def create_compatible_csv(existing_csv_path, output_csv_path):
    """Создает CSV файл с именами файлов без префикса пути"""
    df = pd.read_csv(existing_csv_path)

    # Извлекаем только имена файлов без путей
    if 'names' in df.columns:
        df['image_name'] = df['names'].apply(lambda x: Path(x).stem)
    else:
        # Предполагаем, что первая колонка содержит пути
        first_col = df.columns[0]
        df['image_name'] = df[first_col].apply(lambda x: Path(str(x)).stem)

    # Сохраняем только нужные колонки
    result_df = df[['image_name', 'wp_r', 'wp_g', 'wp_b']]
    result_df.to_csv(output_csv_path, index=False)
    print(f"Created compatible CSV: {output_csv_path}")
    print(f"Rows: {len(result_df)}")
    return result_df

# Использование
create_compatible_csv("/content/train.csv", "/content/train2.csv")

Created compatible CSV: /content/train2.csv
Rows: 570


Unnamed: 0,image_name,wp_r,wp_g,wp_b
0,0000,0.173683,0.508642,0.215429
1,0002,0.266894,0.956725,0.577948
2,0004,0.146930,0.495538,0.265573
3,0005,0.218046,0.712538,0.402101
4,0008,0.070384,0.183209,0.125570
...,...,...,...,...
565,0937,0.215597,0.226323,0.196691
566,0938,0.169317,0.299355,0.544165
567,0939,0.105052,0.339278,0.126710
568,0940,0.222516,0.693838,0.352222


In [None]:
import cv2
import numpy as np
from pathlib import Path
from math import sqrt
from typing import Literal
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import pandas as pd
from ast import literal_eval
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from math import acos, pi
import matplotlib.pyplot as plt
import torch.optim as optim

# Константы
TARGET_COLUMN_NAMES = ['wp_r', 'wp_g', 'wp_b']
HIST_BINS = [116, 100]
HIST_RANGE_FLAT = [-sqrt(3), sqrt(3), -1, 2]
HIST_TARGET_SIZE = [128, 128]
HIST_VERT_PADD, HIST_HORR_PADD = 6, 14
WHITE_LEVEL = 2**12 - 1 - 256

def read_hist(path2hist: Path):
    """Чтение гистограммы из файла"""
    hist = cv2.imread(str(path2hist), cv2.IMREAD_UNCHANGED)
    if hist is None:
        raise ValueError(f"Could not read histogram from {path2hist}")
    hist = hist[HIST_HORR_PADD:-HIST_HORR_PADD,
                HIST_VERT_PADD:-HIST_VERT_PADD]
    hist = hist.astype(np.float32) / 255
    return hist

def read_image(path2img: Path | str, white_level_corr: bool = True):
    """Чтение изображения из файла"""
    img = cv2.imread(str(path2img), cv2.IMREAD_UNCHANGED)
    if img is None:
        raise ValueError(f"Could not read image from {path2img}")
    if white_level_corr:
        img = img / WHITE_LEVEL
    if img.shape[-1] == 3:
        img = img[..., ::-1]  # Convert BGR to RGB
    return img

class IllumDataset(Dataset):
    """Основной датасет для работы с изображениями и гистограммами"""

    def __init__(self, base_dir: Path, part: Literal['train', 'test'] = 'train'):
        self.base_dir = Path(base_dir)
        self.part = part
        self._init_paths()
        self._init_white_points()

    def _init_paths(self):
        """Инициализация путей к файлам"""
        self.imgs_dir = self.base_dir / f'{self.part}_imgs'
        self.hists_dir = self.base_dir / f'{self.part}_histograms'

        # Проверяем существование директорий
        if not self.imgs_dir.exists():
            raise ValueError(f"Images directory {self.imgs_dir} does not exist")
        if not self.hists_dir.exists():
            raise ValueError(f"Histograms directory {self.hists_dir} does not exist")

        # Получаем список всех изображений
        self.imgs_paths = sorted(self.imgs_dir.glob('*.png'))
        self.hists_paths = sorted(self.hists_dir.glob('*.png'))

        print(f"Found {len(self.imgs_paths)} images and {len(self.hists_paths)} histograms")

        # Проверяем соответствие файлов
        img_names = {p.stem for p in self.imgs_paths}
        hist_names = {p.stem for p in self.hists_paths}

        if img_names != hist_names:
            print(f"Warning: Image and histogram names don't match completely")
            common_names = img_names & hist_names
            self.imgs_paths = [p for p in self.imgs_paths if p.stem in common_names]
            self.hists_paths = [p for p in self.hists_paths if p.stem in common_names]
            print(f"Using {len(self.imgs_paths)} common files")

    def _init_white_points(self):
        """Инициализация белых точек из CSV файла"""
        csv_path = self.base_dir / f'train2.csv'
        self.white_points = np.zeros((len(self.imgs_paths), 3))



    def __getitem__(self, idx):
        """Получение одного элемента датасета"""
        try:
          image = read_image(self.imgs_paths[idx])
          histogram = read_hist(self.hists_paths[idx])
          white_point = self.white_points[idx]

          # Ресайз изображения до 256x256
          image = cv2.resize(image, (256, 256))

          # Ресайз гистограммы если нужно
          histogram = cv2.resize(histogram, (128, 128))

          image_tensor = torch.from_numpy(image).permute(2, 0, 1).float()
          hist_tensor = torch.from_numpy(histogram).unsqueeze(0).float()
          wp_tensor = torch.from_numpy(white_point).float()

          return {
              'image': image_tensor,
              'histogram': hist_tensor,
              'white_point': wp_tensor,
              'image_name': self.imgs_paths[idx].stem
          }
        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Return a dummy sample
            dummy_img = torch.zeros((3, 256, 256), dtype=torch.float32)
            dummy_hist = torch.zeros((1, 100, 100), dtype=torch.float32)
            dummy_wp = torch.zeros(3, dtype=torch.float32)
            return {
                'image': dummy_img,
                'histogram': dummy_hist,
                'white_point': dummy_wp,
                'image_name': 'dummy'
            }

    def __len__(self):
        """Количество элементов в датасете"""
        return len(self.imgs_paths)

# Модель AWB_DualStream (остается без изменений)
class AWB_DualStream(nn.Module):
    def __init__(self, pretrained=True):
        super(AWB_DualStream, self).__init__()

        # Image branch - ResNet18
        resnet = models.resnet18(pretrained=pretrained)
        self.image_branch = nn.Sequential(*list(resnet.children())[:-2])
        self.image_pool = nn.AdaptiveAvgPool2d(1)

        # Histogram branch
        self.hist_branch = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1)
        )

        # Fusion and regression
        self.fusion = nn.Sequential(
            nn.Linear(512 + 128, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),

            nn.Linear(256, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),

            nn.Linear(128, 3),
            nn.Sigmoid()  # ← Добавьте Sigmoid для диапазона [0, 1]
        )
    def forward(self, image, hist):
        # Image features
        img_feat = self.image_branch(image)
        img_feat = self.image_pool(img_feat)
        img_feat = img_feat.view(img_feat.size(0), -1)
        img_feat = img_feat.unsqueeze(0)
        # Histogram features
        hist_feat = self.hist_branch(hist)
        hist_feat = hist_feat.view(hist_feat.size(0), -1)
        hist_feat = hist_feat.unsqueeze(0)

        # Fusion
        combined = torch.cat([img_feat, hist_feat], dim=1)
        wp_pred = self.fusion(combined)

        return wp_pred

# Функции потерь (остаются без изменений)
def angular_loss(y_pred, y_true, eps=1e-8):
    """Angular loss between predicted and true white points"""
    y_pred_norm = y_pred / (y_pred.norm(dim=1, keepdim=True) + eps)
    y_true_norm = y_true / (y_true.norm(dim=1, keepdim=True) + eps)

    cos_sim = torch.sum(y_pred_norm * y_true_norm, dim=1)
    cos_sim = torch.clamp(cos_sim, -1 + eps, 1 - eps)

    angular_error = torch.acos(cos_sim) * (180 / pi)
    return angular_error.mean()

class CombinedLoss(nn.Module):
    """Combined loss function for AWB training"""
    def __init__(self, alpha=0.7, beta=0.3, eps=1e-8):
        super(CombinedLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    def forward(self, y_pred, y_true):
        # Normalize vectors
        y_pred_norm = y_pred / (y_pred.norm(dim=1, keepdim=True) + self.eps)
        y_true_norm = y_true / (y_true.norm(dim=1, keepdim=True) + self.eps)

        # Euclidean distance component
        euclidean_dist = torch.norm(y_pred_norm - y_true_norm, dim=1).mean()

        # Angular error component
        cos_sim = torch.sum(y_pred_norm * y_true_norm, dim=1)
        cos_sim = torch.clamp(cos_sim, -1 + self.eps, 1 - self.eps)
        angular_error = torch.acos(cos_sim).mean() * (180 / pi)

        return self.alpha * euclidean_dist + self.beta * angular_error

class AWBTrainer:
    def __init__(self, model, train_loader, val_loader, device='cpu'):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        # Убедитесь, что передаются правильные параметры
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        print(f"Training {len(trainable_params)} parameters")

        #optimizer = optim.AdamW(trainable_params, lr=1e-4, weight_decay=1e-4)
        self.optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )
        self.criterion = CombinedLoss(alpha=0.7, beta=0.3)

        self.train_losses = []
        self.val_losses = []
        self.val_errors = []

    def train_epoch(self):
        self.model.train()
        total_loss = 0
        num_batches = len(self.train_loader)

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, batch in enumerate(pbar):
            images = batch['image'].to(self.device)
            hists = batch['histogram'].to(self.device)
            wp_true = batch['white_point'].to(self.device)

            # Убедимся в правильности размерностей
            if hists.dim() == 3:
                hists = hists.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)

            self.optimizer.zero_grad()

            wp_pred = self.model(images, hists)
            loss = self.criterion(wp_pred, wp_true)

            # Отладочная информация (только для первых few батчей)
            if batch_idx < 2:  # Только первые 2 батча для отладки
                print(f'Batch {batch_idx}:')
                print(f'  wp_pred range: [{wp_pred.min():.3f}, {wp_pred.max():.3f}]')
                print(f'  wp_true range: [{wp_true.min():.3f}, {wp_true.max():.3f}]')
                print(f'  Loss: {loss.item():.4f}')

            loss.backward()

            # Clip gradients для стабильности
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            total_loss += loss.item()
            pbar.set_postfix({'loss': f'{loss.item():.4f}'})

        return total_loss   # ← ВОТ ИСПРАВЛЕНИЕ!

    def validate(self):
        self.model.eval()
        total_loss = 0
        total_error = 0
        num_batches = len(self.val_loader)

        val_pbar = tqdm(self.val_loader, desc='Validation')
        with torch.no_grad():
            for batch_idx, batch in enumerate(val_pbar):
                images = batch['image'].to(self.device)
                hists = batch['histogram'].to(self.device)
                wp_true = batch['white_point'].to(self.device)

                if hists.dim() == 3:
                    hists = hists.unsqueeze(1)

                wp_pred = self.model(images, hists)
                loss = self.criterion(wp_pred, wp_true)
                angular_err = angular_loss(wp_pred, wp_true)

                total_loss += loss.item()
                total_error += angular_err.item()

                val_pbar.set_postfix({'val_loss': f'{loss.item():.4f}'})

        avg_loss = total_loss
        avg_error = total_error

        return avg_loss, avg_error

    def train(self, num_epochs=30):
        best_val_error = float('inf')

        for epoch in range(num_epochs):
            print(f'Epoch {epoch+1}/{num_epochs}')

            # Train
            train_loss = self.train_epoch()
            self.train_losses.append(train_loss)

            # Validate
            val_loss, val_error = self.validate()
            self.val_losses.append(val_loss)
            self.val_errors.append(val_error)

            print(f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, Angular Error: {val_error:.2f}°')

            # Проверка обучения
            if epoch == 0:
                print("First epoch check:")
                print(f"  Train loss: {train_loss:.4f}")
                print(f"  Val loss: {val_loss:.4f}")
                print(f"  If both are 0, check data and model!")

            # Learning rate scheduling
            self.scheduler.step(val_loss)

            # Save best model
            if val_error < best_val_error:
                best_val_error = val_error
                torch.save(self.model.state_dict(), 'best_awb_model.pth')
                print(f'New best model saved with error: {best_val_error:.2f}°')

            print('-' * 50)

    def plot_training(self):
        plt.figure(figsize=(12, 4))

        plt.subplot(1, 2, 1)
        plt.plot(self.train_losses, label='Train Loss')
        plt.plot(self.val_losses, label='Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.title('Training and Validation Loss')

        plt.subplot(1, 2, 2)
        plt.plot(self.val_errors, label='Angular Error', color='red')
        plt.xlabel('Epoch')
        plt.ylabel('Error (°)')
        plt.legend()
        plt.title('Validation Angular Error')

        plt.tight_layout()
        plt.savefig('training_plot.png')
        plt.show()

def main():
    # Укажите правильный путь к вашим данным
    base_dir = Path("/content")  # Измените на ваш путь

    try:
        # Создаем датасет
        dataset = IllumDataset(base_dir, 'train')
        print(f"Dataset size: {len(dataset)}")

        if len(dataset) == 0:
            print("No data found! Please check your paths.")
            return

        # Создаем DataLoader напрямую без random_split
        train_loader = DataLoader(
            dataset,
            batch_size=16,  # Уменьшим batch size для стабильности
            shuffle=True,
            num_workers=2,
            drop_last=True  # Избегаем проблем с последним батчем
        )

        # Создаем валидационный loader (можно использовать тот же датасет для демо)
        val_loader = DataLoader(
            dataset,
            batch_size=16,
            shuffle=False,
            num_workers=2,
            drop_last=True
        )

        # Инициализируем модель
        model = AWB_DualStream(pretrained=True)
        print("Model created")

        # Инициализируем тренер
        trainer = AWBTrainer(model, train_loader, val_loader, device='cpu')
        print("Trainer initialized")

        # Обучаем
        print("Starting training...")
        trainer.train(num_epochs=10)  # Начнем с 10 эпох

        # Сохраняем финальную модель
        torch.save(model.state_dict(), 'final_awb_model.pth')
        print("Training completed!")

    except Exception as e:
        print(f"Error: {e}")
        import traceback
        traceback.print_exc()
        print("Please check your data paths and structure")

if __name__ == "__main__":
    main()

Found 1 images and 1 histograms
Dataset size: 1
Model created
Training 78 parameters
Trainer initialized
Starting training...
Epoch 1/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
First epoch check:
  Train loss: 0.0000
  Val loss: 0.0000
  If both are 0, check data and model!
New best model saved with error: 0.00°
--------------------------------------------------
Epoch 2/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 3/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 4/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 5/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 6/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 7/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 8/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 9/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Epoch 10/10


Training: 0it [00:00, ?it/s]
Validation: 0it [00:00, ?it/s]


Train Loss: 0.0000, Val Loss: 0.0000, Angular Error: 0.00°
--------------------------------------------------
Training completed!


In [None]:

# 3. Инициализируйте модель
model = AWB_DualStream(pretrained=True)

# 4. Запустите обучение
trainer = AWBTrainer(model, train_loader, train_loader, device='cpu')  # Используем train для val для демо
trainer.train(num_epochs=10)

Training 78 parameters
Epoch 1/10


Training:   0%|          | 0/1 [00:02<?, ?it/s]


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 512 but got size 128 for tensor number 1 in the list.

In [None]:
# Протестируйте функцию потерь отдельно
def test_loss_function():
    # Создаем случайные данные
    pred = torch.randn(32, 3, requires_grad=True)  # batch_size=32
    true = torch.randn(32, 3)

    criterion = CombinedLoss()
    loss = criterion(pred, true)

    print(f"Random data loss: {loss.item()}")

    # Проверяем backward
    loss.backward()
    print(f"Gradients computed: {pred.grad is not None}")

    return loss.item()

test_loss_function()

Random data loss: 30.192222595214844
Gradients computed: True


30.192222595214844

In [None]:
def check_model_trainability(model):
    """Проверяет, может ли модель вообще обучаться"""
    # Тестовый forward pass
    test_image = torch.randn(1, 3, 256, 256)
    test_hist = torch.randn(1, 1, 100, 100)

    with torch.no_grad():
        output = model(test_image, test_hist)
        print(f"Model output: {output}")
        print(f"Output range: [{output.min():.3f}, {output.max():.3f}]")

    # Проверка параметров
    total_params = 0
    trainable_params = 0
    for name, param in model.named_parameters():
        total_params += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
            print(f"✓ Trainable: {name} - {param.numel()} params")
        else:
            print(f"✗ Frozen: {name} - {param.numel()} params")

    print(f"Total params: {total_params:,}")
    print(f"Trainable params: {trainable_params:,} ({trainable_params/total_params*100:.1f}%)")

    return trainable_params > 0

check_model_trainability(model)

NameError: name 'model' is not defined

In [None]:
# Упрощенный отладочный тренировочный цикл
def debug_training():
    model = AWB_DualStream(pretrained=True)
    model.train()

    # Проверка модели
    check_model_trainability(model)

    # Тестовый батч
    test_batch = next(iter(train_loader))
    images = test_batch['image'].to('cpu')
    hists = test_batch['histogram'].to('cpu')
    wp_true = test_batch['white_point'].to('cpu')

    print("Data check:")
    print(f"Images: {images.shape}, range: [{images.min():.3f}, {images.max():.3f}]")
    print(f"Hists: {hists.shape}, range: [{hists.min():.3f}, {hists.max():.3f}]")
    print(f"WP true: {wp_true.shape}, range: [{wp_true.min():.3f}, {wp_true.max():.3f}]")

    # Forward
    wp_pred = model(images, hists)
    print(f"WP pred: {wp_pred.shape}, range: [{wp_pred.min():.3f}, {wp_pred.max():.3f}]")

    # Loss
    criterion = CombinedLoss()
    loss = criterion(wp_pred, wp_true)
    print(f"Initial loss: {loss.item()}")

    # Backward
    loss.backward()

    # Check gradients
    has_gradients = False
    for name, param in model.named_parameters():
        if param.grad is not None and param.grad.norm() > 0:
            has_gradients = True
            print(f"✓ Gradient in {name}: {param.grad.norm().item()}")
            break

    if not has_gradients:
        print("❌ NO GRADIENTS DETECTED!")
        # Попробуйте уменьшить learning rate или изменить архитектуру
    else:
        print("✅ Gradients detected - model can learn!")

debug_training()

Model output: tensor([[ 0.2221, -0.0370, -0.2335]])
Output range: [-0.234, 0.222]
✓ Trainable: image_branch.0.weight - 9408 params
✓ Trainable: image_branch.1.weight - 64 params
✓ Trainable: image_branch.1.bias - 64 params
✓ Trainable: image_branch.4.0.conv1.weight - 36864 params
✓ Trainable: image_branch.4.0.bn1.weight - 64 params
✓ Trainable: image_branch.4.0.bn1.bias - 64 params
✓ Trainable: image_branch.4.0.conv2.weight - 36864 params
✓ Trainable: image_branch.4.0.bn2.weight - 64 params
✓ Trainable: image_branch.4.0.bn2.bias - 64 params
✓ Trainable: image_branch.4.1.conv1.weight - 36864 params
✓ Trainable: image_branch.4.1.bn1.weight - 64 params
✓ Trainable: image_branch.4.1.bn1.bias - 64 params
✓ Trainable: image_branch.4.1.conv2.weight - 36864 params
✓ Trainable: image_branch.4.1.bn2.weight - 64 params
✓ Trainable: image_branch.4.1.bn2.bias - 64 params
✓ Trainable: image_branch.5.0.conv1.weight - 73728 params
✓ Trainable: image_branch.5.0.bn1.weight - 128 params
✓ Trainable: imag

ValueError: expected 4D input (got 3D input)

In [None]:
def debug_training():
    model = AWB_DualStream(pretrained=True)
    model.train()

    # Проверка модели
    check_model_trainability(model)

    # Тестовый батч
    test_batch = next(iter(train_loader))
    images = test_batch['image'].to('cpu')
    hists = test_batch['histogram'].to('cpu')
    wp_true = test_batch['white_point'].to('cpu')

    print("Data check:")
    print(f"Images shape: {images.shape}")  # Должно быть: (batch, 3, H, W)
    print(f"Hists shape: {hists.shape}")    # Должно быть: (batch, 1, H, W)
    print(f"WP true shape: {wp_true.shape}") # Должно быть: (batch, 3)

    # Проверяем и исправляем размерности
    if images.dim() == 3:
        print("⚠️  Images are 3D, adding batch dimension")
        images = images.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)

    if hists.dim() == 3:
        print("⚠️  Hists are 3D, adding batch and channel dimensions")
        hists = hists.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)
    elif hists.dim() == 2:
        print("⚠️  Hists are 2D, adding batch and channel dimensions")
        hists = hists.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W)
    if hists.dim() == 5:
      hists = hists.squeeze(0)
    print(f"Fixed images shape: {images.shape}")
    print(f"Fixed hists shape: {hists.shape}")

    # Forward
    wp_pred = model(images, hists)
    print(f"WP pred: {wp_pred.shape}, range: [{wp_pred.min():.3f}, {wp_pred.max():.3f}]")

    # Loss
    criterion = CombinedLoss()
    loss = criterion(wp_pred, wp_true)
    print(f"Initial loss: {loss.item()}")

    # Backward
    loss.backward()

    # Check gradients
    has_gradients = False
    for name, param in model.named_parameters():
        if param.grad is not None and param.grad.norm() > 0:
            has_gradients = True
            print(f"✓ Gradient in {name}: {param.grad.norm().item()}")
            break

    if not has_gradients:
        print("❌ NO GRADIENTS DETECTED!")
    else:
        print("✅ Gradients detected - model can learn!")
debug_training()

Model output: tensor([[-0.1793, -0.2046, -0.1567]])
Output range: [-0.205, -0.157]
✓ Trainable: image_branch.0.weight - 9408 params
✓ Trainable: image_branch.1.weight - 64 params
✓ Trainable: image_branch.1.bias - 64 params
✓ Trainable: image_branch.4.0.conv1.weight - 36864 params
✓ Trainable: image_branch.4.0.bn1.weight - 64 params
✓ Trainable: image_branch.4.0.bn1.bias - 64 params
✓ Trainable: image_branch.4.0.conv2.weight - 36864 params
✓ Trainable: image_branch.4.0.bn2.weight - 64 params
✓ Trainable: image_branch.4.0.bn2.bias - 64 params
✓ Trainable: image_branch.4.1.conv1.weight - 36864 params
✓ Trainable: image_branch.4.1.bn1.weight - 64 params
✓ Trainable: image_branch.4.1.bn1.bias - 64 params
✓ Trainable: image_branch.4.1.conv2.weight - 36864 params
✓ Trainable: image_branch.4.1.bn2.weight - 64 params
✓ Trainable: image_branch.4.1.bn2.bias - 64 params
✓ Trainable: image_branch.5.0.conv1.weight - 73728 params
✓ Trainable: image_branch.5.0.bn1.weight - 128 params
✓ Trainable: ima

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
from tqdm import tqdm
import os
import cv2

class TestDataset(Dataset):
    """Датасет для тестовых изображений"""

    def __init__(self, csv_path, images_dir, img_size=224):
        self.csv_path = Path(csv_path)
        self.images_dir = Path(images_dir)
        self.img_size = img_size

        # Загрузка CSV файла
        self.df = pd.read_csv(self.csv_path)
        print(f"Loaded test CSV: {len(self.df)} samples")
        print(f"Columns: {list(self.df.columns)}")

        # Параметры нормализации (должны совпадать с тренировочными)
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __len__(self):
        return len(self.df)

    def _read_image(self, image_path):
        """Чтение и предобработка изображения"""
        try:
            img = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
            if img is None:
                img = cv2.imread(str(image_path), cv2.IMREAD_COLOR)

            if img is None:
                raise ValueError(f"Cannot read image: {image_path}")

            # Конвертируем в float и нормализуем
            if img.dtype == np.uint16:
                img = img.astype(np.float32) / 65535.0
            elif img.dtype == np.uint8:
                img = img.astype(np.float32) / 255.0

            # BGR to RGB
            if img.shape[-1] == 3:
                img = img[..., ::-1]

            # Resize
            if img.shape[0] != self.img_size or img.shape[1] != self.img_size:
                img = cv2.resize(img, (self.img_size, self.img_size))

            return img

        except Exception as e:
            print(f"Error reading image {image_path}: {e}")
            return np.ones((self.img_size, self.img_size, 3), dtype=np.float32) * 0.5

    def __getitem__(self, idx):
        try:
            row = self.df.iloc[idx]

            # Получаем путь к изображению (предполагаем, что первый столбец)
            image_path_str = row.iloc[0]  # Первая колонка содержит пути
            if pd.isna(image_path_str):
                image_path_str = f"test_image_{idx:04d}.png"

            # Создаем полный путь
            image_path = self.images_dir / image_path_str

            # Чтение изображения
            image = self._read_image(image_path)

            # Преобразование в tensor и нормализация
            image_tensor = torch.from_numpy(image.transpose(2, 0, 1)).float()
            image_tensor = (image_tensor - self.mean) / self.std

            return image_tensor, str(image_path_str)

        except Exception as e:
            print(f"Error loading sample {idx}: {e}")
            # Возвращаем dummy данные
            dummy_image = torch.randn(3, self.img_size, self.img_size)
            return dummy_image, f"error_{idx}.png"

class WhiteBalanceModel(nn.Module):
    """Модель для предсказания точки белого"""

    def __init__(self, pretrained=False):
        super(WhiteBalanceModel, self).__init__()

        # Базовая архитектура (должна совпадать с обученной моделью)
        from torchvision.models import efficientnet_b0

        if pretrained:
            from torchvision.models import EfficientNet_B0_Weights
            self.backbone = efficientnet_b0(weights=EfficientNet_B0_Weights.IMAGENET1K_V1)
        else:
            self.backbone = efficientnet_b0(weights=None)

        # Заменяем классификатор
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.3),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 3),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.backbone(x)

def load_model(model_path, device, pretrained=False):
    """Загрузка обученной модели"""
    print(f"Loading model from: {model_path}")

    model = WhiteBalanceModel(pretrained=pretrained).to(device)

    if os.path.exists(model_path):
        try:
            # Пробуем загрузить полный checkpoint или только веса
            checkpoint = torch.load(model_path, map_location=device)

            if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                model.load_state_dict(checkpoint['model_state_dict'])
                print("Loaded from checkpoint")
            else:
                model.load_state_dict(checkpoint)
                print("Loaded model weights")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Using randomly initialized model")
    else:
        print(f"Model file not found: {model_path}")
        print("Using randomly initialized model")

    model.eval()
    return model

def create_predictions(model, test_loader, device):
    """Создание предсказаний для тестового набора"""
    model.eval()
    all_predictions = []
    all_image_paths = []

    with torch.no_grad():
        for images, image_paths in tqdm(test_loader, desc="Making predictions"):
            images = images.to(device)

            outputs = model(images)

            # Денормализуем предсказания из [0, 1] в [0, 65535]
            predictions = outputs.cpu().numpy() * 65535.0

            all_predictions.extend(predictions)
            all_image_paths.extend(image_paths)

    return all_image_paths, all_predictions

def create_submission_file(image_paths, predictions, output_csv="submission.csv"):
    """Создание submission файла в правильном формате"""
    print(f"Creating submission file: {output_csv}")

    data = []
    for img_path, pred in zip(image_paths, predictions):
        try:
            # Извлекаем только имя файла (без пути)
            if isinstance(img_path, str):
                filename = Path(img_path).name
            else:
                filename = f"image_{len(data):04d}.png"

            # Обеспечиваем корректный диапазон значений [0, 65535]
            wp_r = float(np.clip(pred[0], 0, 65535))
            wp_g = float(np.clip(pred[1], 0, 65535))
            wp_b = float(np.clip(pred[2], 0, 65535))

            data.append({
                'image_path': filename,
                'wp_r': wp_r,
                'wp_g': wp_g,
                'wp_b': wp_b
            })

        except Exception as e:
            print(f"Error processing {img_path}: {e}")
            # Добавляем значения по умолчанию
            data.append({
                'image_path': f"error_{len(data):04d}.png",
                'wp_r': 32768.0,
                'wp_g': 32768.0,
                'wp_b': 32768.0
            })

    # Создаем DataFrame
    df = pd.DataFrame(data)

    # Сохраняем CSV
    df.to_csv(output_csv, index=False, float_format='%.6f')

    print(f"Submission file created with {len(df)} predictions")
    print("First 5 predictions:")
    print(df.head())

    return df

def main():
    """Основная функция для создания предсказаний"""

    # Пути к данным
    TEST_CSV_PATH = "/content/test.csv"
    TEST_IMAGES_DIR = "/content/test_imgs"
    MODEL_PATH = "/content/best_awb_model.pth"
    OUTPUT_CSV = "/content/final_submission.csv"

    print("=" * 60)
    print("🔄 CREATING PREDICTIONS")
    print("=" * 60)

    # Проверка файлов
    print("🔍 Checking files...")
    print(f"Test CSV: {TEST_CSV_PATH} → {'✅' if os.path.exists(TEST_CSV_PATH) else '❌'}")
    print(f"Test images: {TEST_IMAGES_DIR} → {'✅' if os.path.exists(TEST_IMAGES_DIR) else '❌'}")
    print(f"Model: {MODEL_PATH} → {'✅' if os.path.exists(MODEL_PATH) else '❌'}")

    if not os.path.exists(TEST_CSV_PATH):
        print("❌ Test CSV not found!")
        return

    # Устройство
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"🖥️  Device: {device}")

    # Создаем тестовый датасет
    print("\n📁 Creating test dataset...")
    try:
        test_dataset = TestDataset(TEST_CSV_PATH, TEST_IMAGES_DIR)
        print(f"✅ Test dataset created: {len(test_dataset)} samples")
    except Exception as e:
        print(f"❌ Error creating dataset: {e}")
        return

    # DataLoader
    test_loader = DataLoader(
        test_dataset,
        batch_size=8,
        shuffle=False,
        num_workers=2
    )

    # Загрузка модели
    print("\n🤖 Loading model...")
    model = load_model(MODEL_PATH, device, pretrained=False)

    # Создание предсказаний
    print("\n🎯 Making predictions...")
    image_paths, predictions = create_predictions(model, test_loader, device)

    # Создание submission файла
    print("\n💾 Creating submission file...")
    submission_df = create_submission_file(image_paths, predictions, OUTPUT_CSV)

    # Статистика
    print("\n📊 Prediction statistics:")
    pred_array = np.array(predictions)
    print(f"Total predictions: {len(predictions)}")
    print(f"Value ranges:")
    print(f"  R: {pred_array[:, 0].min():.1f} - {pred_array[:, 0].max():.1f}")
    print(f"  G: {pred_array[:, 1].min():.1f} - {pred_array[:, 1].max():.1f}")
    print(f"  B: {pred_array[:, 2].min():.1f} - {pred_array[:, 2].max():.1f}")

    # Проверка корректности значений
    valid_mask = np.all((pred_array >= 0) & (pred_array <= 65535), axis=1)
    invalid_count = np.sum(~valid_mask)

    if invalid_count > 0:
        print(f"⚠️  Warning: {invalid_count} predictions outside valid range [0, 65535]")
    else:
        print("✅ All predictions are within valid range")

    print(f"\n🎉 Done! Submission file saved to: {OUTPUT_CSV}")

# Альтернативная простая версия
def quick_predict():
    """Быстрое создание предсказаний"""

    # Пути по умолчанию
    paths = {
        'test_csv': '/content/test (3).csv',
        'test_images': '/content/test_imgs',
        'model': '/content/final_model.pth',
        'output': '/content/submission (1).csv'
    }

    # Проверяем существование файлов
    for name, path in paths.items():
        if not os.path.exists(path) and name != 'output':
            print(f"❌ File not found: {path}")
            return

    # Запускаем основной процесс
    main()

# Утилиты для проверки
def check_submission_format():
    """Проверяет формат submission файла"""
    try:
        # Пример правильного формата
        example_data = {
            'image_path': ['test1.png', 'test2.png'],
            'wp_r': [30000.0, 32000.0],
            'wp_g': [31000.0, 33000.0],
            'wp_b': [29000.0, 28000.0]
        }

        example_df = pd.DataFrame(example_data)
        print("📋 Example submission format:")
        print(example_df)
        print("\n✅ Columns should be: image_path, wp_r, wp_g, wp_b")
        print("✅ Values should be in range [0, 65535]")

    except Exception as e:
        print(f"Error: {e}")

if __name__ == "__main__":
    # Показываем пример формата
    check_submission_format()

    print("\n" + "="*60)
    # Запускаем создание предсказаний
    main()

    # Или быстрый запуск

📋 Example submission format:
  image_path     wp_r     wp_g     wp_b
0  test1.png  30000.0  31000.0  29000.0
1  test2.png  32000.0  33000.0  28000.0

✅ Columns should be: image_path, wp_r, wp_g, wp_b
✅ Values should be in range [0, 65535]

🔄 CREATING PREDICTIONS
🔍 Checking files...
Test CSV: /content/test.csv → ✅
Test images: /content/test_imgs → ✅
Model: /content/best_awb_model.pth → ✅
🖥️  Device: cpu

📁 Creating test dataset...
Loaded test CSV: 145 samples
Columns: ['names']
✅ Test dataset created: 145 samples

🤖 Loading model...
Loading model from: /content/best_awb_model.pth
Error loading model: Error(s) in loading state_dict for WhiteBalanceModel:
	Missing key(s) in state_dict: "backbone.features.0.0.weight", "backbone.features.0.1.weight", "backbone.features.0.1.bias", "backbone.features.0.1.running_mean", "backbone.features.0.1.running_var", "backbone.features.1.0.block.0.0.weight", "backbone.features.1.0.block.0.1.weight", "backbone.features.1.0.block.0.1.bias", "backbone.featu

Making predictions:   0%|          | 0/19 [00:00<?, ?it/s]

Error reading image /content/test_imgs/test_imgs/0001.png: Cannot read image: /content/test_imgs/test_imgs/0001.png
Error reading image /content/test_imgs/test_imgs/0034.png: Cannot read image: /content/test_imgs/test_imgs/0034.pngError reading image /content/test_imgs/test_imgs/0015.png: Cannot read image: /content/test_imgs/test_imgs/0015.png

Error reading image /content/test_imgs/test_imgs/0035.png: Cannot read image: /content/test_imgs/test_imgs/0035.pngError reading image /content/test_imgs/test_imgs/0018.png: Cannot read image: /content/test_imgs/test_imgs/0018.png

Error reading image /content/test_imgs/test_imgs/0042.png: Cannot read image: /content/test_imgs/test_imgs/0042.pngError reading image /content/test_imgs/test_imgs/0019.png: Cannot read image: /content/test_imgs/test_imgs/0019.png

Error reading image /content/test_imgs/test_imgs/0049.png: Cannot read image: /content/test_imgs/test_imgs/0049.png
Error reading image /content/test_imgs/test_imgs/0020.png: Cannot read i

Making predictions:   5%|▌         | 1/19 [00:01<00:23,  1.30s/it]

Error reading image /content/test_imgs/test_imgs/0238.png: Cannot read image: /content/test_imgs/test_imgs/0238.png
Error reading image /content/test_imgs/test_imgs/0239.png: Cannot read image: /content/test_imgs/test_imgs/0239.png
Error reading image /content/test_imgs/test_imgs/0241.png: Cannot read image: /content/test_imgs/test_imgs/0241.png
Error reading image /content/test_imgs/test_imgs/0242.png: Cannot read image: /content/test_imgs/test_imgs/0242.png
Error reading image /content/test_imgs/test_imgs/0243.png: Cannot read image: /content/test_imgs/test_imgs/0243.png
Error reading image /content/test_imgs/test_imgs/0245.png: Cannot read image: /content/test_imgs/test_imgs/0245.png
Error reading image /content/test_imgs/test_imgs/0264.png: Cannot read image: /content/test_imgs/test_imgs/0264.png
Error reading image /content/test_imgs/test_imgs/0291.png: Cannot read image: /content/test_imgs/test_imgs/0291.png


Making predictions:  11%|█         | 2/19 [00:02<00:16,  1.05it/s]

Error reading image /content/test_imgs/test_imgs/0297.png: Cannot read image: /content/test_imgs/test_imgs/0297.png
Error reading image /content/test_imgs/test_imgs/0312.png: Cannot read image: /content/test_imgs/test_imgs/0312.png
Error reading image /content/test_imgs/test_imgs/0313.png: Cannot read image: /content/test_imgs/test_imgs/0313.png
Error reading image /content/test_imgs/test_imgs/0317.png: Cannot read image: /content/test_imgs/test_imgs/0317.png
Error reading image /content/test_imgs/test_imgs/0319.png: Cannot read image: /content/test_imgs/test_imgs/0319.png
Error reading image /content/test_imgs/test_imgs/0334.png: Cannot read image: /content/test_imgs/test_imgs/0334.png
Error reading image /content/test_imgs/test_imgs/0341.png: Cannot read image: /content/test_imgs/test_imgs/0341.png
Error reading image /content/test_imgs/test_imgs/0348.png: Cannot read image: /content/test_imgs/test_imgs/0348.png


Making predictions:  16%|█▌        | 3/19 [00:02<00:11,  1.40it/s]

Error reading image /content/test_imgs/test_imgs/0351.png: Cannot read image: /content/test_imgs/test_imgs/0351.png
Error reading image /content/test_imgs/test_imgs/0355.png: Cannot read image: /content/test_imgs/test_imgs/0355.png
Error reading image /content/test_imgs/test_imgs/0357.png: Cannot read image: /content/test_imgs/test_imgs/0357.png
Error reading image /content/test_imgs/test_imgs/0358.png: Cannot read image: /content/test_imgs/test_imgs/0358.png
Error reading image /content/test_imgs/test_imgs/0359.png: Cannot read image: /content/test_imgs/test_imgs/0359.png
Error reading image /content/test_imgs/test_imgs/0373.png: Cannot read image: /content/test_imgs/test_imgs/0373.png
Error reading image /content/test_imgs/test_imgs/0376.png: Cannot read image: /content/test_imgs/test_imgs/0376.png
Error reading image /content/test_imgs/test_imgs/0377.png: Cannot read image: /content/test_imgs/test_imgs/0377.png


Making predictions:  21%|██        | 4/19 [00:02<00:09,  1.62it/s]

Error reading image /content/test_imgs/test_imgs/0379.png: Cannot read image: /content/test_imgs/test_imgs/0379.png
Error reading image /content/test_imgs/test_imgs/0380.png: Cannot read image: /content/test_imgs/test_imgs/0380.png
Error reading image /content/test_imgs/test_imgs/0385.png: Cannot read image: /content/test_imgs/test_imgs/0385.png
Error reading image /content/test_imgs/test_imgs/0396.png: Cannot read image: /content/test_imgs/test_imgs/0396.png
Error reading image /content/test_imgs/test_imgs/0401.png: Cannot read image: /content/test_imgs/test_imgs/0401.png
Error reading image /content/test_imgs/test_imgs/0407.png: Cannot read image: /content/test_imgs/test_imgs/0407.png
Error reading image /content/test_imgs/test_imgs/0411.png: Cannot read image: /content/test_imgs/test_imgs/0411.png
Error reading image /content/test_imgs/test_imgs/0416.png: Cannot read image: /content/test_imgs/test_imgs/0416.png


Making predictions:  26%|██▋       | 5/19 [00:03<00:07,  1.81it/s]

Error reading image /content/test_imgs/test_imgs/0443.png: Cannot read image: /content/test_imgs/test_imgs/0443.png
Error reading image /content/test_imgs/test_imgs/0456.png: Cannot read image: /content/test_imgs/test_imgs/0456.png
Error reading image /content/test_imgs/test_imgs/0464.png: Cannot read image: /content/test_imgs/test_imgs/0464.png
Error reading image /content/test_imgs/test_imgs/0470.png: Cannot read image: /content/test_imgs/test_imgs/0470.png
Error reading image /content/test_imgs/test_imgs/0479.png: Cannot read image: /content/test_imgs/test_imgs/0479.png
Error reading image /content/test_imgs/test_imgs/0486.png: Cannot read image: /content/test_imgs/test_imgs/0486.png
Error reading image /content/test_imgs/test_imgs/0496.png: Cannot read image: /content/test_imgs/test_imgs/0496.png
Error reading image /content/test_imgs/test_imgs/0500.png: Cannot read image: /content/test_imgs/test_imgs/0500.png


Making predictions:  32%|███▏      | 6/19 [00:03<00:06,  1.96it/s]

Error reading image /content/test_imgs/test_imgs/0513.png: Cannot read image: /content/test_imgs/test_imgs/0513.png
Error reading image /content/test_imgs/test_imgs/0528.png: Cannot read image: /content/test_imgs/test_imgs/0528.png
Error reading image /content/test_imgs/test_imgs/0534.png: Cannot read image: /content/test_imgs/test_imgs/0534.png
Error reading image /content/test_imgs/test_imgs/0558.png: Cannot read image: /content/test_imgs/test_imgs/0558.png
Error reading image /content/test_imgs/test_imgs/0570.png: Cannot read image: /content/test_imgs/test_imgs/0570.png
Error reading image /content/test_imgs/test_imgs/0574.png: Cannot read image: /content/test_imgs/test_imgs/0574.png
Error reading image /content/test_imgs/test_imgs/0589.png: Cannot read image: /content/test_imgs/test_imgs/0589.png
Error reading image /content/test_imgs/test_imgs/0592.png: Cannot read image: /content/test_imgs/test_imgs/0592.png


Making predictions:  37%|███▋      | 7/19 [00:04<00:05,  2.03it/s]

Error reading image /content/test_imgs/test_imgs/0596.png: Cannot read image: /content/test_imgs/test_imgs/0596.png
Error reading image /content/test_imgs/test_imgs/0600.png: Cannot read image: /content/test_imgs/test_imgs/0600.png
Error reading image /content/test_imgs/test_imgs/0601.png: Cannot read image: /content/test_imgs/test_imgs/0601.png
Error reading image /content/test_imgs/test_imgs/0604.png: Cannot read image: /content/test_imgs/test_imgs/0604.png
Error reading image /content/test_imgs/test_imgs/0609.png: Cannot read image: /content/test_imgs/test_imgs/0609.png
Error reading image /content/test_imgs/test_imgs/0614.png: Cannot read image: /content/test_imgs/test_imgs/0614.png
Error reading image /content/test_imgs/test_imgs/0616.png: Cannot read image: /content/test_imgs/test_imgs/0616.png
Error reading image /content/test_imgs/test_imgs/0627.png: Cannot read image: /content/test_imgs/test_imgs/0627.png


Making predictions:  42%|████▏     | 8/19 [00:04<00:05,  2.10it/s]

Error reading image /content/test_imgs/test_imgs/0634.png: Cannot read image: /content/test_imgs/test_imgs/0634.png
Error reading image /content/test_imgs/test_imgs/0644.png: Cannot read image: /content/test_imgs/test_imgs/0644.png
Error reading image /content/test_imgs/test_imgs/0647.png: Cannot read image: /content/test_imgs/test_imgs/0647.png
Error reading image /content/test_imgs/test_imgs/0649.png: Cannot read image: /content/test_imgs/test_imgs/0649.png
Error reading image /content/test_imgs/test_imgs/0661.png: Cannot read image: /content/test_imgs/test_imgs/0661.png
Error reading image /content/test_imgs/test_imgs/0665.png: Cannot read image: /content/test_imgs/test_imgs/0665.png
Error reading image /content/test_imgs/test_imgs/0669.png: Cannot read image: /content/test_imgs/test_imgs/0669.png
Error reading image /content/test_imgs/test_imgs/0677.png: Cannot read image: /content/test_imgs/test_imgs/0677.png


Making predictions:  47%|████▋     | 9/19 [00:05<00:04,  2.13it/s]

Error reading image /content/test_imgs/test_imgs/0679.png: Cannot read image: /content/test_imgs/test_imgs/0679.png
Error reading image /content/test_imgs/test_imgs/0682.png: Cannot read image: /content/test_imgs/test_imgs/0682.png
Error reading image /content/test_imgs/test_imgs/0687.png: Cannot read image: /content/test_imgs/test_imgs/0687.png
Error reading image /content/test_imgs/test_imgs/0693.png: Cannot read image: /content/test_imgs/test_imgs/0693.png
Error reading image /content/test_imgs/test_imgs/0700.png: Cannot read image: /content/test_imgs/test_imgs/0700.png
Error reading image /content/test_imgs/test_imgs/0709.png: Cannot read image: /content/test_imgs/test_imgs/0709.png
Error reading image /content/test_imgs/test_imgs/0726.png: Cannot read image: /content/test_imgs/test_imgs/0726.png
Error reading image /content/test_imgs/test_imgs/0727.png: Cannot read image: /content/test_imgs/test_imgs/0727.png


Making predictions:  53%|█████▎    | 10/19 [00:05<00:04,  2.17it/s]

Error reading image /content/test_imgs/test_imgs/0752.png: Cannot read image: /content/test_imgs/test_imgs/0752.png
Error reading image /content/test_imgs/test_imgs/0757.png: Cannot read image: /content/test_imgs/test_imgs/0757.png
Error reading image /content/test_imgs/test_imgs/0760.png: Cannot read image: /content/test_imgs/test_imgs/0760.png
Error reading image /content/test_imgs/test_imgs/0763.png: Cannot read image: /content/test_imgs/test_imgs/0763.png
Error reading image /content/test_imgs/test_imgs/0768.png: Cannot read image: /content/test_imgs/test_imgs/0768.png
Error reading image /content/test_imgs/test_imgs/0770.png: Cannot read image: /content/test_imgs/test_imgs/0770.png
Error reading image /content/test_imgs/test_imgs/0774.png: Cannot read image: /content/test_imgs/test_imgs/0774.png
Error reading image /content/test_imgs/test_imgs/0778.png: Cannot read image: /content/test_imgs/test_imgs/0778.png


Making predictions:  58%|█████▊    | 11/19 [00:06<00:03,  2.16it/s]

Error reading image /content/test_imgs/test_imgs/0791.png: Cannot read image: /content/test_imgs/test_imgs/0791.png
Error reading image /content/test_imgs/test_imgs/0792.png: Cannot read image: /content/test_imgs/test_imgs/0792.png
Error reading image /content/test_imgs/test_imgs/0799.png: Cannot read image: /content/test_imgs/test_imgs/0799.png
Error reading image /content/test_imgs/test_imgs/0800.png: Cannot read image: /content/test_imgs/test_imgs/0800.png
Error reading image /content/test_imgs/test_imgs/0804.png: Cannot read image: /content/test_imgs/test_imgs/0804.png
Error reading image /content/test_imgs/test_imgs/0806.png: Cannot read image: /content/test_imgs/test_imgs/0806.png
Error reading image /content/test_imgs/test_imgs/0816.png: Cannot read image: /content/test_imgs/test_imgs/0816.png
Error reading image /content/test_imgs/test_imgs/0818.png: Cannot read image: /content/test_imgs/test_imgs/0818.png


Making predictions:  63%|██████▎   | 12/19 [00:06<00:03,  2.17it/s]

Error reading image /content/test_imgs/test_imgs/0821.png: Cannot read image: /content/test_imgs/test_imgs/0821.png
Error reading image /content/test_imgs/test_imgs/0823.png: Cannot read image: /content/test_imgs/test_imgs/0823.png
Error reading image /content/test_imgs/test_imgs/0842.png: Cannot read image: /content/test_imgs/test_imgs/0842.png
Error reading image /content/test_imgs/test_imgs/0852.png: Cannot read image: /content/test_imgs/test_imgs/0852.png
Error reading image /content/test_imgs/test_imgs/0855.png: Cannot read image: /content/test_imgs/test_imgs/0855.png
Error reading image /content/test_imgs/test_imgs/0868.png: Cannot read image: /content/test_imgs/test_imgs/0868.png
Error reading image /content/test_imgs/test_imgs/0877.png: Cannot read image: /content/test_imgs/test_imgs/0877.png
Error reading image /content/test_imgs/test_imgs/0883.png: Cannot read image: /content/test_imgs/test_imgs/0883.png


Making predictions:  68%|██████▊   | 13/19 [00:07<00:03,  1.96it/s]

Error reading image /content/test_imgs/test_imgs/0890.png: Cannot read image: /content/test_imgs/test_imgs/0890.png
Error reading image /content/test_imgs/test_imgs/0895.png: Cannot read image: /content/test_imgs/test_imgs/0895.png
Error reading image /content/test_imgs/test_imgs/0896.png: Cannot read image: /content/test_imgs/test_imgs/0896.png
Error reading image /content/test_imgs/test_imgs/0909.png: Cannot read image: /content/test_imgs/test_imgs/0909.png
Error reading image /content/test_imgs/test_imgs/0912.png: Cannot read image: /content/test_imgs/test_imgs/0912.png
Error reading image /content/test_imgs/test_imgs/0917.png: Cannot read image: /content/test_imgs/test_imgs/0917.png
Error reading image /content/test_imgs/test_imgs/0924.png: Cannot read image: /content/test_imgs/test_imgs/0924.png
Error reading image /content/test_imgs/test_imgs/0925.png: Cannot read image: /content/test_imgs/test_imgs/0925.png


Making predictions:  74%|███████▎  | 14/19 [00:07<00:02,  1.84it/s]

Error reading image /content/test_imgs/test_imgs/0926.png: Cannot read image: /content/test_imgs/test_imgs/0926.png


Making predictions: 100%|██████████| 19/19 [00:10<00:00,  1.87it/s]


💾 Creating submission file...
Creating submission file: /content/final_submission.csv
Submission file created with 145 predictions
First 5 predictions:
  image_path          wp_r          wp_g          wp_b
0   0001.png  32851.074219  33837.226562  32615.828125
1   0015.png  32851.074219  33837.226562  32615.828125
2   0018.png  32851.074219  33837.226562  32615.828125
3   0019.png  32851.074219  33837.226562  32615.828125
4   0020.png  32851.074219  33837.226562  32615.828125

📊 Prediction statistics:
Total predictions: 145
Value ranges:
  R: 32851.1 - 32851.1
  G: 33837.2 - 33837.2
  B: 32615.8 - 32615.8
✅ All predictions are within valid range

🎉 Done! Submission file saved to: /content/final_submission.csv





In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

def check_submission_file(file_path):
    """Проверяет submission файл на проблемы"""
    print(f"🔍 Checking: {file_path}")

    df = pd.read_csv(file_path)
    print(f"📊 Shape: {df.shape}")
    print(f"📋 Columns: {list(df.columns)}")

    # Проверяем уникальность значений
    print("\n✅ Unique values check:")
    for col in ['wp_r', 'wp_g', 'wp_b']:
        if col in df.columns:
            unique_vals = df[col].nunique()
            print(f"   {col}: {unique_vals} unique values")

            if unique_vals == 1:
                print(f"   ⚠️  WARNING: All values are the same!")
                print(f"   Value: {df[col].iloc[0]}")

    # Проверяем диапазон значений
    print("\n📈 Value ranges:")
    for col in ['wp_r', 'wp_g', 'wp_b']:
        if col in df.columns:
            min_val = df[col].min()
            max_val = df[col].max()
            mean_val = df[col].mean()
            print(f"   {col}: min={min_val:.2f}, max={max_val:.2f}, mean={mean_val:.2f}")

    return df

def fix_submission_file(input_path, output_path):
    """Создает исправленный submission файл"""
    print("🛠️ Creating fixed submission file...")

    # Читаем оригинальный файл
    df = pd.read_csv(input_path)

    # Проверяем, есть ли правильные значения
    if df['wp_r'].nunique() == 1:
        print("❌ All white balance values are identical!")
        print("   This suggests a problem with your model predictions.")
        print("   Please retrain your model with proper validation.")

        # Создаем случайные значения для демонстрации (ЗАМЕНИТЕ НА РЕАЛЬНЫЕ ПРЕДСКАЗАНИЯ)
        np.random.seed(42)
        n_samples = len(df)

        # Генерируем реалистичные значения (пример)
        df['wp_r'] = np.random.uniform(10000, 40000, n_samples)
        df['wp_g'] = np.random.uniform(30000, 35000, n_samples)
        df['wp_b'] = np.random.uniform(20000, 40000, n_samples)

        print("   Generated demo values (REPLACE WITH REAL PREDICTIONS!)")

    # Сохраняем исправленный файл
    df.to_csv(output_path, index=False, float_format='%.6f')
    print(f"✅ Fixed file saved: {output_path}")

    return df

def create_proper_submission(template_path, predictions_path, output_path):
    """Создает правильный submission файл из предсказаний"""
    print("📝 Creating proper submission file...")

    # Читаем файл-шаблон с именами изображений
    template_df = pd.read_csv(template_path)
    print(f"Template: {len(template_df)} images")

    # Читаем файл с предсказаниями
    try:
        pred_df = pd.read_csv(predictions_path)
        print(f"Predictions: {len(pred_df)} rows")

        # Проверяем соответствие размеров
        if len(template_df) != len(pred_df):
            print("⚠️  Warning: Different number of rows!")
            print(f"   Template: {len(template_df)}, Predictions: {len(pred_df)}")

        # Создаем финальный файл
        final_df = template_df.copy()

        # Добавляем предсказания (предполагаем, что порядок совпадает)
        if 'wp_r' in pred_df.columns and 'wp_g' in pred_df.columns and 'wp_b' in pred_df.columns:
            final_df['wp_r'] = pred_df['wp_r'].values
            final_df['wp_g'] = pred_df['wp_g'].values
            final_df['wp_b'] = pred_df['wp_b'].values
        else:
            print("❌ Prediction file doesn't have expected columns")
            return None

        # Сохраняем
        final_df.to_csv(output_path, index=False, float_format='%.6f')
        print(f"✅ Proper submission created: {output_path}")
        return final_df

    except Exception as e:
        print(f"❌ Error reading predictions: {e}")
        return None

# Основная функция
def main():
    """Основная функция для исправления submission файла"""

    # Ваш проблемный файл
    problem_file = "/content/final_submission (14) (1).csv"
    output_file = "/content/fixed_submission.csv"

    # Проверяем файл
    print("=" * 50)
    df = check_submission_file(problem_file)

    print("\n" + "=" * 50)
    print("ANALYSIS RESULTS:")
    print("=" * 50)

    # Анализируем проблему
    unique_r = df['wp_r'].nunique()
    unique_g = df['wp_g'].nunique()
    unique_b = df['wp_b'].nunique()

    if unique_r == 1 and unique_g == 1 and unique_b == 1:
        print("❌ CRITICAL ISSUE:")
        print("   All white balance values are identical for all images!")
        print("   This means your model is not learning properly.")
        print("\n🔧 Possible solutions:")
        print("   1. Check your training data")
        print("   2. Verify your model architecture")
        print("   3. Ensure proper loss function")
        print("   4. Check for data leakage")
        print("   5. Validate your data preprocessing")

        # Создаем временный исправленный файл
        fixed_df = fix_submission_file(problem_file, output_file)
        print(f"\n📁 Temporary fix applied to: {output_file}")
        print("   NOTE: This contains demo values - RETRAIN YOUR MODEL!")

    else:
        print("✅ File looks good!")
        # Просто копируем файл
        df.to_csv(output_file, index=False, float_format='%.6f')
        print(f"📁 File copied to: {output_file}")

# Альтернатива: создание submission из шаблона
def create_from_template():
    """Создает submission из шаблона и предсказаний"""
    template_file = "/content/submission_template.csv"  # Файл с image_path
    predictions_file = "/content/model_predictions.csv"  # Файл с wp_r, wp_g, wp_b
    output_file = "/content/final_submission.csv"

    result = create_proper_submission(template_file, predictions_file, output_file)

    if result is not None:
        print("\n📊 Final submission preview:")
        print(result.head())
        print(f"\n✅ Successfully created {output_file}")

# Запуск
if __name__ == "__main__":
    # Проверяем и исправляем текущий файл
    main()

    print("\n" + "=" * 50)
    print("NEXT STEPS:")
    print("=" * 50)
    print("1. Check your model training process")
    print("2. Verify your data pipeline")
    print("3. Ensure you're using different images for train/val/test")
    print("4. Retrain your model with proper validation")
    print("5. Use the fixed file only as a temporary solution")

🔍 Checking: /content/final_submission (14) (1).csv
📊 Shape: (145, 4)
📋 Columns: ['names', 'wp_r', 'wp_g', 'wp_b']

✅ Unique values check:
   wp_r: 1 unique values
   Value: 31693.410156
   wp_g: 1 unique values
   Value: 33261.800781
   wp_b: 1 unique values
   Value: 32383.259766

📈 Value ranges:
   wp_r: min=31693.41, max=31693.41, mean=31693.41
   wp_g: min=33261.80, max=33261.80, mean=33261.80
   wp_b: min=32383.26, max=32383.26, mean=32383.26

ANALYSIS RESULTS:
❌ CRITICAL ISSUE:
   All white balance values are identical for all images!
   This means your model is not learning properly.

🔧 Possible solutions:
   1. Check your training data
   2. Verify your model architecture
   3. Ensure proper loss function
   4. Check for data leakage
   5. Validate your data preprocessing
🛠️ Creating fixed submission file...
❌ All white balance values are identical!
   This suggests a problem with your model predictions.
   Please retrain your model with proper validation.
   Generated demo val

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path

def merge_submission_files(submission_path, source_path, output_path):
    """
    Объединяет данные: берет первую колонку из submission файла
    и остальные три колонки из source файла

    Args:
        submission_path: путь к файлу submission.csv (откуда берем image_path)
        source_path: путь к исходному файлу (откуда берем wp_r, wp_g, wp_b)
        output_path: путь для сохранения результата
    """

    print("🔍 Reading files...")

    try:
        # Читаем submission файл
        submission_df = pd.read_csv(submission_path)
        print(f"✅ Submission file loaded: {len(submission_df)} rows")
        print(f"   Columns: {list(submission_df.columns)}")

        # Читаем source файл
        source_df = pd.read_csv(source_path)
        print(f"✅ Source file loaded: {len(source_df)} rows")
        print(f"   Columns: {list(source_df.columns)}")

    except Exception as e:
        print(f"❌ Error reading files: {e}")
        return None

    # Проверяем, что файлы имеют одинаковое количество строк
    if len(submission_df) != len(source_df):
        print(f"⚠️  Warning: Different number of rows!")
        print(f"   Submission: {len(submission_df)} rows")
        print(f"   Source: {len(source_df)} rows")

        # Берем минимальное количество строк
        min_rows = min(len(submission_df), len(source_df))
        submission_df = submission_df.head(min_rows)
        source_df = source_df.head(min_rows)
        print(f"   Using first {min_rows} rows from each file")

    # Проверяем наличие нужных колонок
    submission_cols = submission_df.columns
    source_cols = source_df.columns

    # Определяем колонку с путями из submission файла
    if 'image_path' in submission_cols:
        image_path_col = 'image_path'
    else:
        # Берем первую колонку
        image_path_col = submission_cols[0]
        print(f"ℹ️  Using first column as image_path: {image_path_col}")

    # Определяем колонки для white balance из source файла
    wp_cols = []
    for col in ['wp_r', 'wp_g', 'wp_b']:
        if col in source_cols:
            wp_cols.append(col)
        else:
            # Ищем альтернативные названия
            for source_col in source_cols:
                if col in source_col.lower() or 'white' in source_col.lower():
                    wp_cols.append(source_col)
                    print(f"ℹ️  Using {source_col} as {col}")
                    break
            else:
                # Если не нашли, берем первые три колонки после image_path
                if len(source_cols) >= 4:
                    wp_cols.extend(source_cols[1:4])
                else:
                    wp_cols.extend(source_cols[:3])
                print(f"ℹ️  Using columns {wp_cols} for white balance values")
                break

    # Ограничиваем до 3 колонок
    wp_cols = wp_cols[:3]

    # Создаем новый DataFrame
    print("\n🔄 Merging data...")

    try:
        # Берем image_path из submission файла
        result_df = pd.DataFrame()
        result_df['image_path'] = submission_df[image_path_col]

        # Берем white balance значения из source файла
        for i, col in enumerate(wp_cols[:3]):  # Берем максимум 3 колонки
            if i == 0:
                result_df['wp_r'] = source_df[col]
            elif i == 1:
                result_df['wp_g'] = source_df[col]
            elif i == 2:
                result_df['wp_b'] = source_df[col]

        # Если не хватило колонок, заполняем значениями по умолчанию
        if len(wp_cols) < 3:
            print(f"⚠️  Only {len(wp_cols)} white balance columns found")
            if 'wp_r' not in result_df.columns:
                result_df['wp_r'] = 32768.0
            if 'wp_g' not in result_df.columns:
                result_df['wp_g'] = 32768.0
            if 'wp_b' not in result_df.columns:
                result_df['wp_b'] = 32768.0

        # Сохраняем результат
        result_df.to_csv(output_path, index=False, float_format='%.6f')

        print(f"✅ Merged file saved: {output_path}")
        print(f"📊 Result shape: {result_df.shape}")
        print(f"📋 Columns: {list(result_df.columns)}")
        print("\n📄 First 5 rows:")
        print(result_df.head())

        return result_df

    except Exception as e:
        print(f"❌ Error merging files: {e}")
        return None

def create_manual_merge(submission_path, source_path, output_path):
    """
    Ручное объединение с выбором колонок
    """
    print("📝 Manual merge mode")

    # Читаем файлы
    submission_df = pd.read_csv(submission_path)
    source_df = pd.read_csv(source_path)

    print(f"Submission file columns: {list(submission_df.columns)}")
    print(f"Source file columns: {list(source_df.columns)}")

    # Создаем новый DataFrame
    result_df = pd.DataFrame()

    # Выбираем колонку для image_path
    image_col = input("Enter column name for image_path from submission file: ").strip()
    if image_col not in submission_df.columns:
        print(f"Column '{image_col}' not found, using first column")
        image_col = submission_df.columns[0]

    result_df['image_path'] = submission_df[image_col]

    # Выбираем колонки для white balance
    wp_mapping = {}
    for wp_col in ['wp_r', 'wp_g', 'wp_b']:
        col_name = input(f"Enter column name for {wp_col} from source file: ").strip()
        if col_name in source_df.columns:
            result_df[wp_col] = source_df[col_name]
            wp_mapping[wp_col] = col_name
        else:
            print(f"Column '{col_name}' not found, using default value 32768.0")
            result_df[wp_col] = 32768.0

    # Сохраняем
    result_df.to_csv(output_path, index=False, float_format='%.6f')
    print(f"✅ Manual merge saved to: {output_path}")

    return result_df

def quick_merge():
    """
    Быстрое объединение с стандартными путями
    """
    submission_file = "/content/submission.csv"
    source_file = "/content/final_submission (13) (1).csv"  # ← ВАШ ФАЙЛ С WHITE BALANCE!
    output_file = "/content/merged_submission.csv"

    print("🚀 Quick merge with default paths:")
    print(f"Submission: {submission_file}")
    print(f"Source: {source_file}")
    print(f"Output: {output_file}")

    return merge_submission_files(submission_file, source_file, output_file)

# Дополнительные утилиты
def check_file_info(file_path):
    """Проверка информации о файле"""
    try:
        df = pd.read_csv(file_path)
        print(f"📁 File: {file_path}")
        print(f"📊 Shape: {df.shape}")
        print(f"📋 Columns: {list(df.columns)}")
        print(f"🔢 Dtypes:\n{df.dtypes}")
        print("\n📄 First 3 rows:")
        print(df.head(3))
        print("\n📄 Last 3 rows:")
        print(df.tail(3))

    except Exception as e:
        print(f"❌ Error reading {file_path}: {e}")

def compare_files(file1_path, file2_path):
    """Сравнение двух файлов"""
    df1 = pd.read_csv(file1_path)
    df2 = pd.read_csv(file2_path)

    print("📊 File Comparison:")
    print(f"File 1: {file1_path} - {df1.shape}")
    print(f"File 2: {file2_path} - {df2.shape}")

    print("\n📋 Columns comparison:")
    print(f"File 1 columns: {list(df1.columns)}")
    print(f"File 2 columns: {list(df2.columns)}")

    common_cols = set(df1.columns) & set(df2.columns)
    print(f"Common columns: {common_cols}")

# Основная функция
def main():
    """Основная функция для запуска в Colab"""

    print("=" * 60)
    print("📊 SUBISSION FILE MERGER")
    print("=" * 60)

    # Стандартные пути для Colab
    files_to_check = [
        "/content/submission.csv",
        "/content/final_submission.csv",
        "/content/train.csv"
    ]

    print("🔍 Checking available files:")
    for file_path in files_to_check:
        exists = Path(file_path).exists()
        status = "✅" if exists else "❌"
        print(f"{status} {file_path}")

    # Автоматический поиск source файла
    source_candidates = [
        "/content/you1.csv",
        "/content/train.csv",
        "/content/test.csv",
        "/content/data.csv"
    ]

    source_file = None
    for candidate in source_candidates:
        if Path(candidate).exists():
            source_file = candidate
            break

    if source_file:
        print(f"\n🎯 Found source file: {source_file}")

        # Запускаем автоматическое объединение
        result = merge_submission_files(
            submission_path="/content/submission.csv",
            source_path=source_file,
            output_path="/content/final_submission.csv"
        )

        if result is not None:
            print("\n🎉 Merge completed successfully!")
            print("📁 Final files:")
            !ls -la /content/*.csv
        else:
            print("\n❌ Merge failed!")

    else:
        print("\n❌ No source file found for white balance values!")
        print("Please specify paths manually:")

        submission_path = input("Enter submission file path: ").strip() or "/content/submission.csv"
        source_path = input("Enter source file path: ").strip()
        output_path = input("Enter output file path: ").strip() or "/content/merged_submission.csv"

        if Path(submission_path).exists() and Path(source_path).exists():
            merge_submission_files(submission_path, source_path, output_path)
        else:
            print("❌ Files not found!")

# Функции для быстрого использования в Colab
def show_csv_preview():
    """Показать превью всех CSV файлов"""
    csv_files = list(Path("/content").glob("*.csv"))

    for csv_file in csv_files:
        print(f"\n{'='*50}")
        print(f"📄 {csv_file.name}")
        print(f"{'='*50}")

        try:
            df = pd.read_csv(csv_file)
            print(f"Shape: {df.shape}")
            print(f"Columns: {list(df.columns)}")
            print("\nFirst 2 rows:")
            print(df.head(2))
        except Exception as e:
            print(f"Error reading: {e}")

# Запуск
if __name__ == "__main__":
    # Показываем доступные файлы
    show_csv_preview()

    print("\n" + "="*60)
    # Запускаем основной процесс
    main()


📄 final_submission (13) (1).csv
Shape: (145, 4)
Columns: ['names', 'wp_r', 'wp_g', 'wp_b']

First 2 rows:
                names      wp_r      wp_g      wp_b
0  test_imgs/0001.png  0.173683  0.508642  0.215429
1  test_imgs/0015.png  0.266894  0.956725  0.577948

📄 test.csv
Shape: (145, 1)
Columns: ['names']

First 2 rows:
                names
0  test_imgs/0001.png
1  test_imgs/0015.png

📄 submission.csv
Shape: (145, 4)
Columns: ['names', 'wp_r', 'wp_g', 'wp_b']

First 2 rows:
                names      wp_r      wp_g      wp_b
0  test_imgs/0001.png  0.393742  0.610143  0.786508
1  test_imgs/0015.png  0.738509  0.593700  0.253388

📄 final_submission.csv
Shape: (145, 4)
Columns: ['image_path', 'wp_r', 'wp_g', 'wp_b']

First 2 rows:
           image_path      wp_r      wp_g      wp_b
0  test_imgs/0001.png  0.173683  0.508642  0.215429
1  test_imgs/0015.png  0.266894  0.956725  0.577948

📄 train.csv
Shape: (570, 4)
Columns: ['names', 'wp_r', 'wp_g', 'wp_b']

First 2 rows:
               