In [2]:
import os
import torch
import albumentations as A
from torch.utils.data import DataLoader
from pycocotools.coco import COCO
from segmentation_models_pytorch import Unet
from segmentation_models_pytorch.losses import DiceLoss
from tqdm import tqdm
import matplotlib.pyplot as plt

# 1. Настройки
DATASET_DIR = "dataset"
ANNOTATIONS_PATH = os.path.join(DATASET_DIR, "result.json")
IMAGES_DIR = os.path.join(DATASET_DIR, "images")
BATCH_SIZE = 8
EPOCHS = 10
LEARNING_RATE = 1e-4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# 2. Кастомный датасет
class COCOSegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.coco = COCO(annotation_file)
        self.image_ids = self.coco.getImgIds()
        
        # Фильтруем категории (предполагаем, что category_id 1 - дорога, 2 - сугробы)
        self.cat_ids = self.coco.getCatIds(catNms=['road', 'snow'])
        self.load_category_mapping()

    def load_category_mapping(self):
        """Создаем отображение оригинальных category_id в индексы классов"""
        self.category_map = {cat_id: idx for idx, cat_id in enumerate(self.cat_ids)}
        
    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_info = self.coco.loadImgs(self.image_ids[idx])[0]
        img_path = os.path.join(self.root_dir, img_info['file_name'])
        image = plt.imread(img_path)
        
        # Создаем маску
        ann_ids = self.coco.getAnnIds(imgIds=img_info['id'], catIds=self.cat_ids)
        anns = self.coco.loadAnns(ann_ids)
        
        mask = np.zeros((img_info['height'], img_info['width']), dtype=np.uint8)
        for ann in anns:
            class_id = self.category_map[ann['category_id']]
            mask = np.maximum(self.coco.annToMask(ann) * class_id, mask)
        
        # Применяем аугментации
        if self.transform:
            transformed = self.transform(image=image, mask=mask)
            image = transformed['image']
            mask = transformed['mask']
        
        # Преобразуем в тензоры
        image = torch.from_numpy(image).permute(2, 0, 1).float() / 255.0
        mask = torch.from_numpy(mask).long()
        
        return image, mask

# 3. Аугментации
transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Rotate(limit=10, p=0.3),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# 4. Загрузка данных
train_dataset = COCOSegmentationDataset(
    root_dir=IMAGES_DIR,
    annotation_file=ANNOTATIONS_PATH,
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0
)

# 5. Модель
model = Unet(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=2,
    activation=None  # softmax будет внутри loss
).to(DEVICE)

# 6. Функция потерь и оптимизатор
criterion = DiceLoss(mode='multiclass')
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# 7. Цикл обучения
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(loader, desc="Training"):
        images = images.to(device)
        masks = masks.to(device)
        
        # Обнуляем градиенты
        optimizer.zero_grad()
        
        # Прямой проход
        outputs = model(images)
        
        # Вычисляем loss
        loss = criterion(outputs, masks)
        running_loss += loss.item()
        
        # Обратный проход
        loss.backward()
        optimizer.step()
    
    return running_loss / len(loader)

# 8. Валидация
def validate(model, loader, criterion, device):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Validation"):
            images = images.to(device)
            masks = masks.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
    
    return val_loss / len(loader)

# 9. Основной цикл
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    train_loss = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    print(f"Train Loss: {train_loss:.4f}")
    
    # Сохранение чекпоинта
    torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")

# 10. Визуализация результатов
def visualize_prediction(model, dataloader, device, num_samples=3):
    model.eval()
    with torch.no_grad():
        for i, (images, masks) in enumerate(dataloader):
            if i >= num_samples:
                break
            images = images.to(device)
            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            
            # Отображаем оригинал, маску и предсказание
            plt.figure(figsize=(15, 5))
            plt.subplot(1, 3, 1)
            plt.title("Image")
            plt.imshow(images[0].permute(1, 2, 0).cpu().numpy())
            
            plt.subplot(1, 3, 2)
            plt.title("True Mask")
            plt.imshow(masks[0].numpy())
            
            plt.subplot(1, 3, 3)
            plt.title("Predicted Mask")
            plt.imshow(preds[0])
            plt.show()

# Визуализируем предсказания на тренировочных данных
visualize_prediction(model, train_loader, DEVICE)

loading annotations into memory...
Done (t=0.03s)
creating index...
index created!
Epoch 1/10


Training:   0%|          | 0/24 [00:00<?, ?it/s]


FileNotFoundError: [Errno 2] No such file or directory: 'dataset\\images\\..\\..\\label-studio\\label-studio\\media\\upload\\2\\31ce81d4-1200x686.jpg'