In [1]:
import torch
from torch.utils.data import Subset
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
import os
from datetime import datetime

from splits.splits import split_dataset_by_groups
from model.pvtv2 import PVTv2B5ForForgerySegmentation
from tools.dataclass import *
from tools.loss import BinaryCrossEntropyLoss, DiceLoss, FocalLoss
from tools.optimizer import create_optimizer
from tools.scheduler import create_scheduler
from tools.metrics import BinarySegmentationMetrics
from tools.visualize import *

In [2]:
# --- Dataset ---
IMAGE_DIR = 'images'
MASKS_DIR = 'masks'
SPLIT_PATH="splits/grouped_indices.pt"

# --- DataLoader ---
BATCH_SIZE = 4
NUM_WORKERS = 4
SHUFFLE_TRAIN = True
SHUFFLE_VAL = False
PIN_MEMORY = True  # ускоряет передачу на GPU
DROP_LAST = True   # для стабильности batch-norm при малых батчах

# --- Configuration ---
MAX_ITERS = 320000
VAL_INTERVAL = 5000
SAVE_INTERVAL = 5000
LOG_INTERVAL = 100
VISUALIZE_EVERY = 1

In [3]:
# Обучающий датасет с аугментациями и foreground-aware кропами
train_dataset_full = ForgerySegmentationDataset(
    images_dir=IMAGE_DIR,
    masks_dir=MASKS_DIR,
    transform=get_training_augmentation(),
    fg_crop_prob=0.7,           # ← кропы с подделками
    crop_size=(512, 512),
    use_albumentations=True
)

# Валидационный датасет БЕЗ аугментаций, НО С кропами (фиксированный размер)
eval_dataset_full = ForgerySegmentationDataset(
    images_dir=IMAGE_DIR,
    masks_dir=MASKS_DIR,
    transform=get_validation_augmentation(),
    fg_crop_prob=0.0,          
    crop_size=(512, 512),     
    use_albumentations=True
)

In [4]:
# # Разбитие датасета 
# 
# # Получаем индексы один раз (на основе имён файлов — одинаковы в обоих датасетах)
# train_idx, val_idx, test_idx = split_dataset_by_groups(
#     dataset=train_dataset_full,
#     save_path=SPLIT_PATH
# )

In [5]:
# Загружаем сохранённые индексы
split_data = torch.load(SPLIT_PATH)

train_indices = split_data['train_indices']
val_indices = split_data['val_indices']
test_indices = split_data['test_indices']
seed = split_data.get('seed', 'unknown')

print(f"  Train: {len(train_indices)}")
print(f"  Val:   {len(val_indices)}")
print(f"  Test:  {len(test_indices)}")

  Train: 216602
  Val:   5972
  Test:  5925


In [6]:
# Создаём подвыборки
train_dataset = Subset(train_dataset_full, train_indices)
val_dataset = Subset(eval_dataset_full, val_indices)
test_dataset = Subset(eval_dataset_full, test_indices)

In [7]:
item = train_dataset[34534]

print(item['image'].shape)
print(item['mask'].shape)

print(item['image'].dtype)
print(item['mask'].dtype)
print(item['img_path'])
print(item['mask_path'])

ValueError: Image too small: torch.Size([3, 512, 512])

In [8]:
model = PVTv2B5ForForgerySegmentation(img_size=512)
model

PVTv2B5ForForgerySegmentation(
  (backbone): pvt_v2_b5(
    (patch_embed1): OverlapPatchEmbed(
      (proj): Conv2d(3, 64, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed2): OverlapPatchEmbed(
      (proj): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed3): OverlapPatchEmbed(
      (proj): Conv2d(128, 320, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
    )
    (patch_embed4): OverlapPatchEmbed(
      (proj): Conv2d(320, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (block1): ModuleList(
      (0): Block(
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (q): Linear(in_fea

In [9]:
# Загружаем чекпоинт
checkpoint_path = "model/pvt_v2_b5.pth"
state_dict = torch.load(checkpoint_path, map_location="cpu")

# Удаляем классификационную голову (не нужна для сегментации)
keys_to_remove = [k for k in state_dict.keys() if k.startswith('head')]
for k in keys_to_remove:
    del state_dict[k]

# Загружаем в backbone
missing_keys, unexpected_keys = model.backbone.load_state_dict(state_dict, strict=False)

# Проверяем, что всё ок
if len(unexpected_keys) == 0 and all('head' not in k for k in missing_keys):
    print("Предобученные веса PVTv2-B5 успешно загружены!")
    if missing_keys:
        print(f"Не загружены ключи (ожидаемо для head): {missing_keys}")
else:
    print("Ошибка при загрузке весов:")
    print("Unexpected keys:", unexpected_keys)
    print("Missing keys:", missing_keys)

Предобученные веса PVTv2-B5 успешно загружены!


In [10]:
bce_loss_fn = BinaryCrossEntropyLoss(loss_weight=1.0, avg_non_ignore=True)
dice_loss_fn = DiceLoss(loss_weight=1.0, use_sigmoid=True)
focal_loss_fn = FocalLoss(loss_weight=1.0)  # опционально

optimizer = create_optimizer(
    model,
    lr=6e-5,
    weight_decay=0.01,
    head_lr_mult=10.0
)

# Планировщик
scheduler = create_scheduler(
    optimizer,
    warmup_iters=1500,
    total_iters=320000,
    min_lr=0.0,
    power=1.0
)

Параметры разбиты на группы:
   Backbone: 746 параметров
   Norm:     322 параметров
   Head:     30 параметров


In [17]:
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=SHUFFLE_TRAIN,
    num_workers=0,
    pin_memory=PIN_MEMORY,
    drop_last=DROP_LAST
)

val_loader = DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=SHUFFLE_VAL,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    drop_last=False  # на валидации лучше сохранять все примеры
)

In [18]:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = f"runs/forgery_pvtv2_b5_{timestamp}"
checkpoint_dir = f"{run_dir}/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
writer = SummaryWriter(log_dir=run_dir)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
train_iter = iter(train_loader)
metrics_val = BinarySegmentationMetrics(threshold=0.5)

best_iou = 0.0

In [19]:
for iter_idx in tqdm(range(1, MAX_ITERS + 1), desc="Training"):

    # --- Обучение ---
    model.train()
    optimizer.zero_grad()

    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        batch = next(train_iter)

    # МАСКА: float32 + канал → [B, 1, H, W]
    images = batch['image'].to(device, non_blocking=True)
    masks = batch['mask'].to(device, non_blocking=True).float().unsqueeze(1)

    pred = model(images)
    loss = 0.5 * bce_loss_fn(pred, masks) + 1.0 * dice_loss_fn(pred, masks) + 0.8 * focal_loss_fn(pred, masks)

    loss.backward()
    optimizer.step()
    scheduler.step()

    # --- Логирование ---
    if iter_idx % LOG_INTERVAL == 0:
        writer.add_scalar('Train/Loss', loss.item(), iter_idx)
        writer.add_scalar('Train/LR', optimizer.param_groups[0]['lr'], iter_idx)

    # --- Валидация ---
    if iter_idx % VAL_INTERVAL == 0:
        val_metrics = validate_epoch(
            model, val_loader, metrics_val, device,
            writer=writer,
            epoch=iter_idx // VAL_INTERVAL,
            visualize_every=VISUALIZE_EVERY
        )

        for name, value in val_metrics.items():
            writer.add_scalar(f'Val/{name}', value, iter_idx)

        print(f"\n[Iter {iter_idx}] Val — " + 
              " | ".join(f"{k}: {v:.4f}" for k, v in val_metrics.items()))

        # --- Сохранение ЛУЧШЕЙ модели ---
        current_iou = val_metrics['IoU_forgery']
        if current_iou > best_iou:
            best_iou = current_iou
            best_path = f"{checkpoint_dir}/best_model_iou_{current_iou:.4f}_iter_{iter_idx}.pth"
            torch.save(model.state_dict(), best_path)
            print(f"Новая лучшая модель сохранена: {os.path.basename(best_path)}")

    # --- Сохранение ПОСЛЕДНЕЙ модели (каждые 10k) ---
    if iter_idx % SAVE_INTERVAL == 0:
        last_path = f"{checkpoint_dir}/last_model_iter_{iter_idx}.pth"
        torch.save(model.state_dict(), last_path)

# --- Финальное сохранение последней модели ---
final_path = f"{checkpoint_dir}/last_model_final.pth"
torch.save(model.state_dict(), final_path)
writer.close()

Training:   0%|          | 0/320000 [00:00<?, ?it/s]


AcceleratorError: CUDA error: an illegal memory access was encountered
Search for `cudaErrorIllegalAddress' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
