# Install dependencies

In [None]:
%%writefile requirements.txt

requests
tqdm>=4.65.0
torch>=2.0.0
torchvision>=0.14.0
opencv-python
numpy>=1.23.0
pandas>=1.3.0
timm>=0.6.13
tensorboard>=2.13.0
albumentations>=1.3.0
scikit-learn>=1.1.0
matplotlib>=3.5.0
transformers>=4.31.0
segmentation-models-pytorch

In [None]:
!pip install -q -r requirements.txt

# Training config

In [None]:
# --- Training settings ---
# Central dictionary of hyper-parameters and paths read by the cells below.
CONFIG = {
    'batch_size': 16,
    'num_workers': 4,
    'num_epochs': 100,
    'learning_rate': 1e-3 ,
    'weight_decay': 1e-5,
    'early_stop_patience': 10,
    'dataroot': './aitex_data/extracted',
    'log_dir': 'runs/segmentation_experiment',
    'resume': False  # Set to True to restore from a checkpoint
}

import torch

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

import os

# Request expandable segments from PyTorch's CUDA caching allocator.
# NOTE(review): this is assigned after torch was imported and CUDA availability was
# queried — presumably the allocator reads the variable lazily; confirm it takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

## Models

In [None]:
import segmentation_models_pytorch as smp
# Registry of model factories keyed by a short name. Every value is a zero-argument
# lambda, so a network is only constructed when the factory is called.
MODELS = {
    "swin-unet": lambda: smp.Unet(
        encoder_name="tu-swinv2_base_window16_256",           # transformer encoder choice
        encoder_weights="imagenet",      # pretrained weights (or None)
        in_channels=3,
        classes=1,                       # single channel for binary segmentation
        activation=None,                 # return raw logits
    ),
    "segformer_mit_b3": lambda: smp.Segformer(
        encoder_name="mit_b3",           # transformer encoder choice
        encoder_weights="imagenet",      # pretrained weights (or None)
        in_channels=3,
        classes=1,                       # single channel for binary segmentation
        activation=None,                 # return raw logits
    ),
  
    "unetplusplus_resnet50": lambda: smp.UnetPlusPlus(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
        activation=None,
    ),
    "deeplabv3plus_resnet101": lambda: smp.DeepLabV3Plus(
        encoder_name="resnet101",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1,
        activation=None,
    ),
}

## Metrics

In [None]:
def compute_dice(preds, targets, smooth=1e-7):
    """Return the Dice coefficient of two tensors as a Python float.

    Both inputs are cast to float and reduced over all elements at once, so the
    score is computed for the whole batch rather than per sample. ``smooth`` is
    added to numerator and denominator, which makes two all-zero inputs score 1.
    """
    p, t = preds.float(), targets.float()
    overlap = (p * t).sum()
    total = p.sum() + t.sum()
    score = (2 * overlap + smooth) / (total + smooth)
    return score.item()

def compute_iou(preds, targets, smooth=1e-7):
    """Return the intersection-over-union of two tensors as a Python float.

    Inputs are cast to float and reduced over every element (whole batch at once).
    ``smooth`` is added to numerator and denominator so empty inputs do not divide
    by zero.
    """
    p, t = preds.float(), targets.float()
    overlap = (p * t).sum()
    combined = p.sum() + t.sum() - overlap
    score = (overlap + smooth) / (combined + smooth)
    return score.item()

def compute_accuracy(preds, targets):
    """Return the fraction of elements where ``preds`` equals ``targets``."""
    matches = (preds == targets).float()
    return matches.mean().item()

def compute_precision(preds, targets, eps=1e-7):
    """Return TP / (TP + FP + eps), counting elements equal to 1 as positive."""
    predicted_pos = preds == 1
    tp = (predicted_pos & (targets == 1)).sum().item()
    fp = (predicted_pos & (targets == 0)).sum().item()
    denominator = tp + fp + eps
    return tp / denominator

def compute_recall(preds, targets, eps=1e-7):
    """Return TP / (TP + FN + eps), counting elements equal to 1 as positive."""
    actual_pos = targets == 1
    tp = ((preds == 1) & actual_pos).sum().item()
    fn = ((preds == 0) & actual_pos).sum().item()
    denominator = tp + fn + eps
    return tp / denominator


## Losses

In [None]:
import math

import torch
import torch.nn.functional as F

class FocalLoss(torch.nn.Module):
    """Binary focal loss computed directly from raw logits.

    For every element ``loss = -alpha * (1 - p_t)**gamma * log(p_t)``, where
    ``p_t`` is the predicted probability of the true class; the mean over all
    elements is returned.

    ``alpha`` scales every element uniformly (it is a global weight here, not a
    per-class balancing factor). ``eps`` bounds ``p_t`` to ``[eps, 1 - eps]``,
    which caps the per-element loss.

    The loss is evaluated in float32 through ``logsigmoid``. Clamping
    half-precision probabilities is not sufficient: ``1 - eps`` rounds to 1.0 in
    float16, so ``1 - p`` can become exactly 0 and ``log`` returns ``-inf``
    under mixed-precision training.
    """

    def __init__(self, gamma=2.0, alpha=0.25, eps=1e-7):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha
        self.eps = eps

    def forward(self, logits, targets):
        """Return the mean focal loss for ``logits`` against binary ``targets``."""
        logits = logits.float()
        # sigmoid(-x) == 1 - sigmoid(x), so flipping the sign of the logit for
        # negative targets gives log(p_t) without ever forming 1 - p.
        signed = torch.where(targets == 1, logits, -logits)
        log_pt = F.logsigmoid(signed)
        if self.eps > 0:
            # Same bounds as clamping the probability to [eps, 1 - eps].
            log_pt = log_pt.clamp(min=math.log(self.eps), max=math.log1p(-self.eps))
        pt = log_pt.exp()
        loss = -self.alpha * (1 - pt).pow(self.gamma) * log_pt
        return loss.mean()

class DiceLoss(torch.nn.Module):
    """Soft Dice loss on sigmoid probabilities.

    A Dice score is computed per sample by summing over dimensions 1, 2 and 3
    (so 4-D inputs are expected), and ``1 - mean(score)`` is returned. ``eps``
    is added to numerator and denominator to avoid division by zero.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        reduce_dims = (1, 2, 3)
        probs = logits.sigmoid()
        overlap = (probs * targets).sum(reduce_dims)
        denom = probs.sum(reduce_dims) + targets.sum(reduce_dims)
        per_sample = (2 * overlap + self.eps) / (denom + self.eps)
        return 1 - per_sample.mean()

class FDLoss(torch.nn.Module):
    """Weighted sum of focal and Dice loss: L = w_f * Focal + (1 - w_f) * Dice."""

    def __init__(self, wf=0.1, gamma=2.0, alpha=0.25):
        super().__init__()
        self.wf = wf
        self.focal = FocalLoss(gamma=gamma, alpha=alpha)
        self.dice = DiceLoss()

    def forward(self, logits, targets):
        focal_term = self.wf * self.focal(logits, targets)
        dice_term = (1.0 - self.wf) * self.dice(logits, targets)
        return focal_term + dice_term

class BCEDiceLoss(torch.nn.Module):
    """Weighted sum of BCE-with-logits and Dice loss.

    Returns ``bce_weight * BCE + dice_weight * Dice``; the two weights are used
    as given and are not normalised.
    """

    def __init__(self, bce_weight=0.8, dice_weight=0.2):
        super().__init__()
        self.bce = torch.nn.BCEWithLogitsLoss()
        self.dice = DiceLoss()
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, logits, targets):
        bce_term = self.bce_weight * self.bce(logits, targets)
        dice_term = self.dice_weight * self.dice(logits, targets)
        return bce_term + dice_term



# Prepare data

## download_dataset

In [None]:
import requests
from tqdm import tqdm

import zipfile
from pathlib import Path

url = "https://www.kaggle.com/api/v1/datasets/download/nexuswho/aitex-fabric-image-database"
output_dir = Path("./aitex_data")
output_dir.mkdir(exist_ok=True)

zip_path = output_dir / "aitex.zip"

# Skip the download when the archive is already present.
if zip_path.exists():
    print(f"[INFO] Архив уже существует по пути: {zip_path}")
else:
    print(f"[INFO] Скачиваем архив из {url}...")
    # Stream into a temporary ".part" file and rename it only after the last chunk
    # was written. An interrupted download then never leaves a truncated aitex.zip
    # behind that the existence check above would accept as complete.
    part_path = zip_path.with_name(zip_path.name + ".part")
    # timeout=(connect, read) in seconds so a stalled connection cannot hang forever;
    # the context manager releases the connection even if writing fails.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            raise Exception(f"Ошибка при скачивании: статус {response.status_code}")
        with open(part_path, "wb") as f:
            for chunk in tqdm(response.iter_content(chunk_size=8192)):
                f.write(chunk)
    part_path.replace(zip_path)
    print("[INFO] Скачивание завершено.")

# Unpack the archive once; an existing target directory is taken as already extracted.
extract_dir = output_dir / "extracted"
if not extract_dir.exists():
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)
    print(f"[INFO] Архив успешно распакован в {extract_dir}")
else:
    print(f"[INFO] Архив уже был распакован в {extract_dir}")


## remove_images_without_masks

In [None]:
import os

def remove_images_without_masks(image_dir, mask_dir, image_suffix=".png", mask_suffix="_mask.png"):
    """
    Delete every image in ``image_dir`` that has no matching mask in ``mask_dir``.

    Only files whose name ends with ``image_suffix`` are considered. The expected
    mask name is the image name without its extension plus ``mask_suffix``.
    Each deletion and the final count are printed. Deletion is permanent.
    """
    removed = 0
    candidates = [name for name in os.listdir(image_dir) if name.endswith(image_suffix)]
    for img_name in candidates:
        stem, _ = os.path.splitext(img_name)
        mask_name = stem + mask_suffix
        if os.path.exists(os.path.join(mask_dir, mask_name)):
            continue
        img_path = os.path.join(image_dir, img_name)
        print(f"Удаляется {img_path} (маска {mask_name} не найдена)")
        os.remove(img_path)
        removed += 1
    print(f"Удалено изображений без масок: {removed}")

# Example call: drop defect images that have no mask in the extracted dataset.
# This deletes files on disk.
remove_images_without_masks(
    image_dir="./aitex_data/extracted/Defect_images",
    mask_dir="./aitex_data/extracted/Mask_images"
)


## dataset_stats

In [None]:
from pathlib import Path
import cv2
import numpy as np

SRC_MSK_DIR = Path("./aitex_data/extracted/Mask_images")

# Collect the number of defect (non-zero) pixels of every readable, non-empty mask.
pixels_list = []
for mask_path in SRC_MSK_DIR.glob("*.png"):
    msk = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if msk is None:
        continue
    num = int(np.count_nonzero(msk > 0))
    if num > 0:
        pixels_list.append(num)

# Same defaults as a running min/max: None / 0 when no mask contained a defect.
min_pixels = min(pixels_list) if pixels_list else None
max_pixels = max(pixels_list, default=0)

print(f"Минимальное количество пикселей дефекта в одной маске: {min_pixels}")
print(f"Максимальное: {max_pixels}")
print(f"Медианное: {np.median(pixels_list)}")
print(f"Гистограмма по всем изображениям:")
import matplotlib.pyplot as plt
plt.hist(pixels_list, bins=30)
plt.xlabel("Дефектных пикселей на маске")
plt.ylabel("Частота")
plt.show()


# Preprocess data

## slice_to_patches

In [None]:
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Defect labels ---
# Maps the zero-padded defect code string found in the source file names to a
# human-readable name; '000' is the "no defect" class.
DEFECT_LABELS = {
    '000': 'No defect',
    '002': 'Broken end',
    '006': 'Broken yarn',
    '010': 'Broken pick',
    '016': 'Weft curling',
    '019': 'Fuzzyball',
    '022': 'Cut selvage',
    '023': 'Crease',
    '025': 'Warp ball',
    '027': 'Knots',
    '029': 'Contamination',
    '030': 'Nep',
    '036': 'Weft crack'
}

# --- Slicing parameters ---
SRC_IMG_DIR = Path("./aitex_data/extracted/Defect_images")
SRC_MSK_DIR = Path("./aitex_data/extracted/Mask_images")
DST_IMG_DIR = Path("./aitex_patches/images")
DST_MSK_DIR = Path("./aitex_patches/masks")

# Patch size and sliding-window step in pixels.
PATCH_W = PATCH_H = 256
STRIDE_W = STRIDE_H = 64
# MIN_DEFECT_FRAC = 0.005       # minimum defect fraction
# Minimum number of mask pixels for a patch to count as defective.
MIN_DEFECT_PIXELS = 8
# Probability of keeping a patch that is not considered defective.
KEEP_NEG = 0.05

DST_IMG_DIR.mkdir(parents=True, exist_ok=True)
DST_MSK_DIR.mkdir(parents=True, exist_ok=True)

# One (file, defect_code, defect_label) tuple per saved patch.
rows = []
PATCH_AREA = PATCH_W * PATCH_H

def has_large_defect(mask, min_size=20):
    """Return True if ``mask`` has a connected non-zero region of at least ``min_size`` pixels."""
    binary = (mask > 0).astype(np.uint8)
    num_labels, _, stats, _ = cv2.connectedComponentsWithStats(binary)
    # Label 0 is the background component, so real regions start at 1.
    return any(
        stats[label, cv2.CC_STAT_AREA] >= min_size
        for label in range(1, num_labels)
    )

# Dedicated, seeded generator for the negative-patch lottery so the set of kept
# clean patches is identical on every run, in line with the seeded sampling and
# splitting done by the later cells.
neg_rng = np.random.default_rng(42)

for img_path in tqdm(sorted(SRC_IMG_DIR.glob("*.png")), desc="Cropping AITEX (grid)"):
    mask_path = SRC_MSK_DIR / img_path.name.replace(".png", "_mask.png")
    img = cv2.imread(str(img_path))
    msk = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

    # cv2.imread returns None for a missing or unreadable file.
    if img is None or msk is None:
        print(f"❌  Пропуск {img_path.name} (файла нет)")
        continue

    msk_bin = (msk > 0).astype(np.uint8)
    # The defect code is the second underscore-separated field of the file stem.
    orig_defect_code_str = img_path.stem.split('_')[1]
    orig_defect_code = int(orig_defect_code_str)
    orig_defect_label = DEFECT_LABELS.get(orig_defect_code_str, "Unknown")

    # Sliding window over rows (y) and columns (x); only full-size patches are taken.
    for y in range(0, img.shape[0] - PATCH_H + 1, STRIDE_H):
        for x in range(0, img.shape[1] - PATCH_W + 1, STRIDE_W):
            img_crop = img[y:y+PATCH_H, x:x+PATCH_W]
            msk_crop = msk_bin[y:y+PATCH_H, x:x+PATCH_W]
            pos_pix = int(msk_crop.sum())

            is_defective = pos_pix >= MIN_DEFECT_PIXELS

            patch_defect_code = orig_defect_code
            patch_defect_label = orig_defect_label

            # A patch with too few defect pixels, or without a connected defect
            # region of at least 20 pixels, is treated as clean: it is kept with
            # probability KEEP_NEG and relabelled as "No defect".
            # NOTE(review): such a patch keeps whatever few positive pixels its
            # mask crop contains — confirm that this is intended.
            if not is_defective or not has_large_defect(msk_crop, min_size=20):
                if neg_rng.random() > KEEP_NEG:
                    continue
                patch_defect_code = 0
                patch_defect_label = DEFECT_LABELS['000']

            suffix = f"x{x:04d}_y{y:04d}"
            fname  = f"{img_path.stem}_{suffix}.png"
            cv2.imwrite(str(DST_IMG_DIR / fname), img_crop)
            # Masks are stored as 0/255 images.
            cv2.imwrite(str(DST_MSK_DIR / fname), msk_crop * 255)
            rows.append((fname, patch_defect_code, patch_defect_label))

# --- Save the CSV with patch labels ---
label_path = Path("./aitex_patches/patch_labels.csv")
label_df = pd.DataFrame(rows, columns=["file", "defect_code", "defect_label"])
label_df.to_csv(label_path, index=False)
print("📝  Saved", len(rows), "patch labels →", label_path)
print("✅  Нарезка патчей завершена.")


## Undersampling

In [None]:
import pandas as pd

# Load the per-patch labels and separate defective from clean patches.
df = pd.read_csv('./aitex_patches/patch_labels.csv')
clean_mask = df['defect_label'] == 'No defect'
has_defect = df[~clean_mask]
no_defect = df[clean_mask]

print(f"Дефектных патчей: {len(has_defect)}")
print(f"Чистых патчей: {len(no_defect)}")

# Target number of clean patches per defective patch (1.0 means 1:1), capped by
# the number of clean patches that actually exist.
desired_ratio = 1.0
n_defect = len(has_defect)
n_no_defect = min(int(n_defect * desired_ratio), len(no_defect))

# Seeded subsample of the clean patches, then a seeded shuffle of the union.
no_defect_sampled = no_defect.sample(n=n_no_defect, random_state=42)
df_balanced = pd.concat([has_defect, no_defect_sampled]).sample(frac=1, random_state=42)

balanced_label_path = './aitex_patches/patch_labels_balanced.csv'
df_balanced.to_csv(balanced_label_path, index=False)
print(f"Balanced CSV saved: {balanced_label_path}")

# --- Per-class summary of the balanced set ---
summary = (
    df_balanced['defect_label']
    .value_counts()
    .rename_axis('defect_label')
    .reset_index(name='num_patches')
)
summary['percentage'] = (summary['num_patches'] / summary['num_patches'].sum() * 100).round(2)

print("\n=== Patch distribution (balanced) ===")
print(summary.to_string(index=False))

# The later cells read the balanced CSV through this name.
PATCH_LABEL_PATH = balanced_label_path

## Visualize data after processing

In [None]:
import random
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
import pandas as pd

def visualize_patches_with_masks_and_labels(
    img_dir,
    mask_dir,
    csv_path,
    min_defect_pixels=MIN_DEFECT_PIXELS,
    n_pos=6,
    n_neg=6
):
    """
    Show a random selection of defective and clean patches in a 3-row grid:
    - original image (RGB)
    - mask on its own
    - mask overlaid on the image
    A patch counts as defective when its mask has at least ``min_defect_pixels``
    non-zero pixels. The title of each column shows the status (DEFECT/CLEAN),
    the defect label from the CSV and the file name.
    """
    img_root, mask_root = Path(img_dir), Path(mask_dir)
    table = pd.read_csv(csv_path)

    # Split the listed patches by the number of positive mask pixels.
    with_defect, without_defect = [], []
    for _, rec in table.iterrows():
        patch_mask = cv2.imread(str(mask_root / rec['file']), cv2.IMREAD_GRAYSCALE)
        if patch_mask is None:
            continue
        entry = (rec['file'], rec['defect_label'])
        if (patch_mask > 0).sum() >= min_defect_pixels:
            with_defect.append(entry)
        else:
            without_defect.append(entry)

    chosen_pos = random.sample(with_defect, min(n_pos, len(with_defect)))
    chosen_neg = random.sample(without_defect, min(n_neg, len(without_defect)))
    all_samples = [(name, "DEFECT", label) for name, label in chosen_pos]
    all_samples += [(name, "CLEAN", label) for name, label in chosen_neg]

    n_cols = len(all_samples)
    plt.figure(figsize=(n_cols * 4, 10))
    for col, (fname, status, defect_label) in enumerate(all_samples, start=1):
        img_rgb = cv2.cvtColor(cv2.imread(str(img_root / fname)), cv2.COLOR_BGR2RGB)
        mask = cv2.imread(str(mask_root / fname), cv2.IMREAD_GRAYSCALE)

        # Row 1: source image
        plt.subplot(3, n_cols, col)
        plt.imshow(img_rgb)
        plt.title(f"{status}\n{defect_label}\n{fname}", fontsize=9)
        plt.axis('off')

        # Row 2: mask on its own
        plt.subplot(3, n_cols, n_cols + col)
        plt.imshow(mask, cmap='gray')
        plt.title("Mask", fontsize=9)
        plt.axis('off')

        # Row 3: mask overlay
        plt.subplot(3, n_cols, 2 * n_cols + col)
        plt.imshow(img_rgb)
        plt.imshow(mask, cmap='Reds', alpha=0.5)
        plt.title("Overlay", fontsize=9)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# --- Run the visualisation on the balanced patch set ---
visualize_patches_with_masks_and_labels(
    img_dir=DST_IMG_DIR,
    mask_dir=DST_MSK_DIR,
    csv_path=PATCH_LABEL_PATH,
    min_defect_pixels=MIN_DEFECT_PIXELS,  # key point: same pixel threshold as used when slicing
    n_pos=6,
    n_neg=2
)


# Create Dataset

## augmentations

In [None]:
# ───────── transforms.py (или любая ваша ячейка) ─────────
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Per-channel normalisation statistics (the standard ImageNet values) passed to
# A.Normalize by the transforms defined in this cell.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD  = (0.229, 0.224, 0.225)

def get_strong_train_transform(image_size=(256, 256)):
    """Build the heavy augmentation pipeline.

    The pipeline resizes to ``image_size``, applies random geometric transforms
    (flips, 90-degree rotations, affine, elastic and grid distortion), photometric
    transforms (brightness/contrast, hue/saturation, noise, blur), coarse dropout,
    then normalises with the shared ``_IMAGENET_MEAN`` / ``_IMAGENET_STD`` constants
    and converts to tensors.

    Args:
        image_size: ``(height, width)`` forwarded to ``A.Resize``.

    Returns:
        An ``A.Compose`` to be called as ``transform(image=..., mask=...)``.
    """
    return A.Compose(
        [
            A.Resize(*image_size),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.Affine(scale=(0.9, 1.1),
                     translate_percent=0.1,
                     rotate=(-25, 25),
                     p=0.7),
            A.ElasticTransform(alpha=1, sigma=50, p=0.3),
            A.GridDistortion(num_steps=5, distort_limit=0.2, p=0.2),
            A.RandomBrightnessContrast(p=0.5),
            A.HueSaturationValue(p=0.3),
            A.GaussNoise(p=0.4),
            A.GaussianBlur(p=0.3),

            # Explicit min_* values are given as well (otherwise Albumentations warns).
            # NOTE(review): newer Albumentations releases appear to have renamed these
            # keyword arguments — confirm against the installed version.
            A.CoarseDropout(
                max_holes=4, min_holes=1,
                max_height=16, min_height=4,
                max_width=16,  min_width=4,
                fill_value=0, p=0.3
            ),

            # Same constants as the weak transform, so the two pipelines cannot
            # drift apart.
            A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
            ToTensorV2(),
        ],
        additional_targets={"mask": "mask"},
    )

def get_weak_train_transform(image_size=(256, 256)):
    """Build the light pipeline: resize, random horizontal flip, normalise, to tensor.

    ``image_size`` is forwarded to ``A.Resize``. The returned ``A.Compose`` is
    called as ``transform(image=..., mask=...)``.
    """
    steps = [
        A.Resize(*image_size),
        A.HorizontalFlip(p=0.5),
        A.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
        ToTensorV2(),
    ]
    return A.Compose(steps, additional_targets={"mask": "mask"})


## dataset_class

In [None]:
from pathlib import Path
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset


class AITEXPatchDataset(Dataset):
    """
    Возвращает кортеж:
        (image_tensor, mask_tensor, defect_code:int, defect_label:str)
    """

    def __init__(
        self,
        img_dir,                       # ./aitex_patches/images
        msk_dir,                       # ./aitex_patches/masks
        file_list,                     # список имён файлов *.png
        label_map,                     # {fname: (code:int, label:str)}
        strong_transform,              # albumentations Compose (defect)
        weak_transform,                # albumentations Compose (clean)
        *,
        enable_copy_paste: bool = False,
        copy_paste_classwise: bool = False,
        class_copy_paste_probs: dict[int, float] | None = None,
        copy_paste_prob: float = 0.2,          # ← единая вероятность CP
        random_state: int = 42,
    ):
        self.img_dir = Path(img_dir)
        self.msk_dir = Path(msk_dir)
        self.file_list = file_list
        self.label_map = label_map
        self.strong_transform = strong_transform
        self.weak_transform = weak_transform

        self.enable_copy_paste = enable_copy_paste
        self.copy_paste_classwise = copy_paste_classwise
        self.class_copy_paste_probs = class_copy_paste_probs or {}
        self.copy_paste_prob = copy_paste_prob
        self.rng = np.random.RandomState(random_state)

        # для class-aware CP
        self.class_to_files: dict[int, list[str]] = {}
        for f in self.file_list:
            code = self.label_map[f][0]
            if code != 0:
                self.class_to_files.setdefault(code, []).append(f)

        self.defect_files = [f for f in self.file_list if self.label_map[f][0] != 0]

    # --------------------------------------------------------------------- len --
    def __len__(self) -> int:
        return len(self.file_list)

    # -------------------------------------------------------------------- getitem
    def __getitem__(self, idx: int):
        fname = self.file_list[idx]

        # ------- загрузка -------------------------------------------------------
        img = cv2.cvtColor(cv2.imread(str(self.img_dir / fname)), cv2.COLOR_BGR2RGB)
        msk = cv2.imread(str(self.msk_dir / fname), cv2.IMREAD_GRAYSCALE)
        msk = (msk > 0).astype(np.uint8)

        defect_code, defect_label = self.label_map[fname]

        # ------- probabilistic Copy-Paste для чистых ---------------------------
        if (
            self.enable_copy_paste
            and defect_code == 0
            and len(self.defect_files) > 0
            and self.rng.rand() < self.copy_paste_prob          # !!! вероятность
        ):
            # выбор «донора»
            if self.copy_paste_classwise and self.class_copy_paste_probs:
                codes, probs = zip(*self.class_copy_paste_probs.items())
                sel_code = self.rng.choice(codes, p=probs)
                pool = self.class_to_files.get(sel_code, self.defect_files)
            else:
                pool = self.defect_files

            paste_fname = self.rng.choice(pool)

            # вставка дефекта
            paste_img = cv2.cvtColor(cv2.imread(str(self.img_dir / paste_fname)), cv2.COLOR_BGR2RGB)
            paste_msk = cv2.imread(str(self.msk_dir / paste_fname), cv2.IMREAD_GRAYSCALE) > 0

            img[paste_msk] = paste_img[paste_msk]
            msk[paste_msk] = 1
            defect_code, defect_label = self.label_map[paste_fname]

        # ------- аугментации ----------------------------------------------------
        if defect_code != 0:
            aug = self.strong_transform(image=img, mask=msk)
        else:
            aug = self.weak_transform(image=img, mask=msk)

        img_t, msk_t = aug["image"], aug["mask"]
        return img_t, msk_t, defect_code, defect_label


## dataset creation helpers

In [None]:
# ───────── prepare_patch_datasets (обновлённая версия) ─────────
from pathlib import Path
from sklearn.model_selection import train_test_split
from collections import Counter
import numpy as np
import pandas as pd

def get_class_copy_paste_probs(train_files, label_map):
    """Compute class sampling probabilities that favour under-represented defects.

    Defect classes (code != 0) are counted over ``train_files``. Every class with
    fewer patches than the median count gets a probability proportional to its
    deficit ``median - count``; classes at or above the median get none.

    Args:
        train_files: iterable of file names present in ``label_map``.
        label_map: ``{file_name: (defect_code, defect_label)}``.

    Returns:
        ``(probs, median, counts)`` where ``probs`` maps class code to probability
        (empty when no class is below the median), ``median`` is the integer
        median count (at least 1) and ``counts`` is the per-class ``Counter``.
        With no defect patches at all the result is ``({}, 1, Counter())``.
    """
    codes = (label_map[f][0] for f in train_files)
    cnt = Counter(code for code in codes if code != 0)
    if not cnt:
        # np.median([]) is NaN and int(NaN) raises, so handle this case explicitly.
        return {}, 1, cnt

    med = int(np.median(list(cnt.values()))) or 1
    deficit = {c: med - n for c, n in cnt.items() if n < med}
    total = sum(deficit.values())
    probs = {c: d / total for c, d in deficit.items()} if total else {}
    return probs, med, cnt

def prepare_patch_datasets(
    patch_root          = "./aitex_patches",
    patch_label_path    = "./aitex_patches/patch_labels_balanced.csv",
    *,
    test_size           = 0.05,
    val_size            = 0.10,
    random_state        = 42,
    image_size          = (256, 256),          # (PATCH_H, PATCH_W)
    copy_paste_prob     = 0.8,
    strong_tr           = None,
    weak_tr             = None,
):
    """Create the train, validation and test datasets from the patch label CSV.

    The files listed in the CSV are split into train/val/test. Both splits are
    stratified on defective vs. clean, so all three subsets keep the same ratio.
    The training set uses strong/weak augmentation with class-aware copy-paste;
    validation and test use the weak transform only and no copy-paste.

    Args:
        patch_root: directory containing the ``images`` and ``masks`` folders.
        patch_label_path: CSV with ``file``, ``defect_code``, ``defect_label`` columns.
        test_size: fraction of all patches used for the test set.
        val_size: fraction of all patches used for the validation set.
        random_state: seed for both splits and the training dataset.
        image_size: size forwarded to the default transforms.
        copy_paste_prob: copy-paste probability for clean training patches.
        strong_tr: transform for defective patches; built by default when None.
        weak_tr: transform for clean patches; built by default when None.

    Returns:
        ``(train_ds, val_ds, test_ds)``.
    """
    img_dir = Path(patch_root) / "images"
    msk_dir = Path(patch_root) / "masks"

    df = pd.read_csv(patch_label_path)
    label_map = {
        fname: (int(code), label)
        for fname, code, label in zip(df["file"], df["defect_code"], df["defect_label"])
    }

    def _binary_strata(names):
        # 1 for a defective patch, 0 for a clean one.
        return [1 if label_map[f][0] else 0 for f in names]

    files = sorted(label_map)
    train_val, test = train_test_split(files, test_size=test_size,
                                       stratify=_binary_strata(files),
                                       random_state=random_state, shuffle=True)
    # val_size is a fraction of the full set, so rescale it to the remainder.
    # This split is stratified as well, like the first one.
    train, val = train_test_split(train_val, test_size=val_size/(1-test_size),
                                  stratify=_binary_strata(train_val),
                                  random_state=random_state, shuffle=True)

    # --- augmentations --------------------------------------------------------
    if strong_tr is None:
        strong_tr = get_strong_train_transform(image_size)
    if weak_tr is None:
        weak_tr = get_weak_train_transform(image_size)

    # --- class-aware copy-paste probabilities ---------------------------------
    cp_probs, median, counts = get_class_copy_paste_probs(train, label_map)
    print(f"Copy-Paste up to median={median}.  Counts: {dict(counts)}")

    train_ds = AITEXPatchDataset(
        img_dir, msk_dir, train, label_map,
        strong_transform=strong_tr,
        weak_transform=weak_tr,
        enable_copy_paste=True,
        copy_paste_classwise=True,
        class_copy_paste_probs=cp_probs,
        copy_paste_prob=copy_paste_prob,
        random_state=random_state,
    )
    # Evaluation sets: weak transform for every patch and no copy-paste.
    val_ds = AITEXPatchDataset(img_dir, msk_dir, val, label_map,
                               strong_transform=weak_tr, weak_transform=weak_tr,
                               enable_copy_paste=False)
    test_ds = AITEXPatchDataset(img_dir, msk_dir, test, label_map,
                                strong_transform=weak_tr, weak_transform=weak_tr,
                                enable_copy_paste=False)
    return train_ds, val_ds, test_ds


## dataloader creation helper

In [None]:
from torch.utils.data import DataLoader

def get_dataloaders(train_ds, val_ds, test_ds, batch_size=32, num_workers=2, pin_memory=True):
    """Wrap the three datasets in DataLoaders that share batch and worker settings.

    Only the training loader shuffles; no loader drops the last incomplete batch.
    Returns ``(train_loader, val_loader, test_loader)``.
    """
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_ds, shuffle=True, drop_last=False, **shared)
    val_loader = DataLoader(val_ds, shuffle=False, **shared)
    test_loader = DataLoader(test_ds, shuffle=False, **shared)
    return train_loader, val_loader, test_loader


## Initialize dataloaders

In [None]:
# Build the three datasets from the balanced label CSV, using the patch size
# chosen when the patches were sliced.
train_ds, val_ds, test_ds = prepare_patch_datasets(
    patch_label_path=PATCH_LABEL_PATH,
    image_size=(PATCH_H, PATCH_W)
)

# Batch size and worker count come from the training CONFIG.
train_loader, val_loader, test_loader = get_dataloaders(
    train_ds, val_ds, test_ds,
    batch_size=CONFIG['batch_size'], num_workers=CONFIG['num_workers'], pin_memory=True
)

print(f"Train samples: {len(train_ds)}")
print(f"Validation samples: {len(val_ds)}")
print(f"Test samples: {len(test_ds)}")

import pandas as pd

def print_class_distribution_from_dataset(dataset, label_names=None, title="Train set"):
    """
    Print a table with the number and percentage of patches per class.

    Every item of ``dataset`` is fetched once, so for a dataset with random
    augmentation the table reflects the labels produced during this pass.

    - dataset: indexable, sized collection whose items are
      ``(img, mask, code, label)`` or ``(img, mask, code)``; in the second case
      ``str(code)`` is used as the label.
    - label_names: optional dict with display names per class. Keys may be the
      codes themselves or strings convertible to ``int`` (for example '002' for
      code 2, as in DEFECT_LABELS). Codes without an entry keep the dataset label.
    - title: heading for the printed table.
    """
    codes = []
    labels = []
    for i in range(len(dataset)):
        # Fetch each item exactly once and branch on its length instead of
        # catching a failed unpack (which would hide real errors and load twice).
        sample = dataset[i]
        code = sample[2]
        label = sample[3] if len(sample) > 3 else str(code)
        codes.append(code)
        labels.append(label)

    df = pd.DataFrame({'class_code': codes, 'class_label': labels})
    summary = df.value_counts(['class_code', 'class_label']).reset_index(name="num_patches")
    summary['percentage'] = (summary['num_patches'] / summary['num_patches'].sum() * 100).round(2)
    if label_names:
        # Integer codes never equal zero-padded string keys, so register an
        # integer alias for every key that can be converted.
        lookup = dict(label_names)
        for key, name in label_names.items():
            try:
                lookup.setdefault(int(key), name)
            except (TypeError, ValueError):
                continue
        summary['class_label'] = [
            lookup.get(code, label)
            for code, label in zip(summary['class_code'].tolist(),
                                   summary['class_label'].tolist())
        ]

    print(f"\n=== Patch distribution in {title} ===")
    print(summary.sort_values("num_patches", ascending=False).to_string(index=False))

# Class distribution of the training set. This indexes every item of train_ds once,
# so each patch is loaded and augmented while the table is built.
print_class_distribution_from_dataset(train_ds, label_names=DEFECT_LABELS, title="Train set")

## visualize dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

def visualize_dataset_grid(dataset, num_samples=6):
    """
    Plot random dataset items in a 3-row grid, one column per item:
    1. de-normalised image (RGB)
    2. mask on its own
    3. mask overlaid on the image
    The column title shows the defect label and code.

    ``num_samples`` is capped at ``len(dataset)``. When nothing can be shown
    (empty dataset or a non-positive ``num_samples``) a note is printed and the
    function returns without plotting.
    """
    # Sampling without replacement raises if more items are requested than exist.
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("visualize_dataset_grid: nothing to show")
        return

    indices = np.random.choice(len(dataset), num_samples, replace=False)
    plt.figure(figsize=(num_samples * 4, 10))

    for col, idx in enumerate(indices):
        sample = dataset[idx]
        if len(sample) == 4:
            image, mask, code, label_str = sample
            title = f"{label_str} (code {code})"
        else:
            image, mask, code = sample
            title = f"Class {code}"

        # Channels-first tensor -> HxWx3 array, then undo the normalisation
        # (x * std + mean) and clip into the displayable [0, 1] range.
        rgb = image[:3].permute(1, 2, 0).cpu().numpy()
        rgb = (rgb * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
        mask_np = mask.cpu().numpy()

        # 1. Source image (top row)
        plt.subplot(3, num_samples, 1 + col)
        plt.imshow(rgb)
        plt.title(title, fontsize=10)
        plt.axis("off")

        # 2. Mask (middle row)
        plt.subplot(3, num_samples, 1 + num_samples + col)
        plt.imshow(mask_np, cmap="gray")
        plt.title("Mask", fontsize=10)
        plt.axis("off")

        # 3. Overlay (bottom row)
        plt.subplot(3, num_samples, 1 + 2 * num_samples + col)
        plt.imshow(rgb)
        plt.imshow(mask_np, cmap='Reds', alpha=0.5)
        plt.title("Overlay", fontsize=10)
        plt.axis("off")

    plt.tight_layout()
    plt.show()


# Show five random patches from the test set.
visualize_dataset_grid(test_ds, num_samples=5)

# Training

## libraries

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import torch.nn.functional as F

#### train_epoch

In [None]:
import torch
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import torch.nn as nn

def train_epoch(model, train_loader, optimizer, loss_fn, scaler, device, epoch, model_name, log_interval=10):
    """Train ``model`` for one pass over ``train_loader`` and return averaged metrics.

    On a CUDA device the forward pass runs under autocast and the optimizer step
    goes through ``scaler``; on any other device a plain backward/step is used.
    Gradients are clipped to a norm of 1.0 in both cases. Metrics are computed per
    batch on predictions thresholded at 0.5 and averaged over batches.

    Returns a dict with the keys 'loss', 'dice', 'iou', 'accuracy', 'precision'
    and 'recall'. ``log_interval`` is accepted but not used.
    """
    use_amp = device.type == 'cuda'
    model.train()
    loss_sum = 0.0

    # Per-batch metric values, in the order they appear in the result.
    tracked = {'dice': [], 'iou': [], 'accuracy': [], 'precision': [], 'recall': []}

    for images, masks, _, _ in tqdm(train_loader, desc=f"Train {epoch}"):
        images = images.to(device)
        # Add a channel dimension so the mask matches the model output.
        masks = masks.unsqueeze(1).float().to(device)

        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(images)
            loss = loss_fn(outputs, masks)

        if use_amp:
            scaler.scale(loss).backward()
            # Unscale first so clipping sees the true gradient norm.
            scaler.unscale_(optimizer)
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

        loss_sum += loss.item()

        # Batch metrics
        with torch.no_grad():
            preds = (torch.sigmoid(outputs) > 0.5).float()
            tracked['dice'].append(compute_dice(preds, masks))
            tracked['iou'].append(compute_iou(preds, masks))
            tracked['accuracy'].append(compute_accuracy(preds, masks))
            tracked['precision'].append(compute_precision(preds, masks))
            tracked['recall'].append(compute_recall(preds, masks))

    metrics = {'loss': loss_sum / len(train_loader)}
    for name, values in tracked.items():
        metrics[name] = np.mean(values)
    print(f"[{model_name}] Train epoch {epoch}: {metrics}")
    return metrics


#### validate_epoch

In [None]:
def validate_epoch(model, val_loader, device, epoch, model_name, threshold=0.5):
    """Evaluate ``model`` on ``val_loader`` without gradients and return averaged metrics.

    Batches whose masks contain no positive pixel are skipped entirely. The loss
    reported is BCE-with-logits, and predictions are thresholded at ``threshold``.
    Every metric is 0.0 when no batch was evaluated.

    Returns a dict with the keys 'loss', 'dice', 'iou', 'accuracy', 'precision'
    and 'recall'.
    """
    model.eval()
    criterion = nn.BCEWithLogitsLoss()
    tracked = {'dice': [], 'iou': [], 'accuracy': [], 'precision': [], 'recall': []}
    loss_sum, batches = 0.0, 0

    with torch.no_grad():
        for images, masks, _, _ in tqdm(val_loader, desc=f"Val  {epoch}"):
            images = images.to(device)
            masks = masks.unsqueeze(1).float().to(device)
            if masks.sum().item() == 0:
                continue
            outputs = model(images)
            preds = (torch.sigmoid(outputs) > threshold).float()
            # Metrics
            tracked['dice'].append(compute_dice(preds, masks))
            tracked['iou'].append(compute_iou(preds, masks))
            tracked['accuracy'].append(compute_accuracy(preds, masks))
            tracked['precision'].append(compute_precision(preds, masks))
            tracked['recall'].append(compute_recall(preds, masks))
            loss_sum += criterion(outputs, masks).item()
            batches += 1

    metrics = {'loss': loss_sum / batches if batches > 0 else 0.0}
    for name, values in tracked.items():
        metrics[name] = np.mean(values) if values else 0.0
    print(f"[{model_name}] Val epoch {epoch}: {metrics}")
    return metrics


### train_loop

In [None]:
def train_loop(
    model, optimizer, scheduler, train_loader, val_loader,
    loss_fn, device, scaler,
    num_epochs=50, early_stop_patience=5, model_name='model'
):
    """Train with validation-Dice based checkpointing and early stopping.

    After every epoch the scheduler is stepped with the validation Dice.
    Whenever the validation Dice improves on the best value so far, the
    model weights are written to ``best_{model_name}.pth``. Training stops
    once ``early_stop_patience`` consecutive epochs bring no improvement.

    Args:
        model, optimizer, loss_fn, scaler, device: forwarded to ``train_epoch``.
        scheduler: stepped with the validation Dice after each epoch.
        train_loader, val_loader: data loaders for the two phases.
        num_epochs: maximum number of epochs.
        early_stop_patience: non-improving epochs tolerated before stopping.
        model_name: used in log lines and in the checkpoint file name.

    Returns:
        ``(train_history, val_history)``: lists with one metrics dict per
        completed epoch.
    """
    best_dice = 0.0
    stale_epochs = 0
    train_history, val_history = [], []

    for epoch in range(num_epochs):
        train_metrics = train_epoch(model, train_loader, optimizer, loss_fn, scaler, device, epoch, model_name)
        val_metrics = validate_epoch(model, val_loader, device, epoch, model_name)

        scheduler.step(val_metrics['dice'])

        train_history.append(train_metrics)
        val_history.append(val_metrics)

        # Log the summary before the early-stopping check so the epoch that
        # triggers the stop is reported as well.
        print(f"Epoch {epoch}: Train Dice={train_metrics['dice']:.4f}  Val Dice={val_metrics['dice']:.4f}")

        # Early stopping
        if val_metrics['dice'] > best_dice:
            best_dice = val_metrics['dice']
            stale_epochs = 0
            torch.save(model.state_dict(), f"best_{model_name}.pth")
            continue

        stale_epochs += 1
        if stale_epochs >= early_stop_patience:
            print(f"Early stopping at epoch {epoch}")
            break

    print(f"Best Val Dice: {best_dice:.4f}")
    return train_history, val_history


### Plot metrics

In [None]:
import matplotlib.pyplot as plt

def plot_metrics(train_history, val_history, model_name, save_path=None):
    """Draw train/validation curves for every tracked metric on a 2x3 grid.

    Args:
        train_history: list of per-epoch metric dicts from training.
        val_history: list of per-epoch metric dicts from validation.
        model_name: prefix for each subplot title.
        save_path: when truthy, the figure is also written to this path.
    """
    metric_names = ('loss', 'dice', 'iou', 'accuracy', 'precision', 'recall')
    # Epochs are shown 1-based on the x axis.
    x_axis = np.arange(1, len(train_history) + 1)

    plt.figure(figsize=(18, 10))
    for position, metric in enumerate(metric_names, start=1):
        train_series = [record[metric] for record in train_history]
        val_series = [record[metric] for record in val_history]

        panel = plt.subplot(2, 3, position)
        panel.plot(x_axis, train_series, label=f"Train {metric}")
        panel.plot(x_axis, val_series, label=f"Val {metric}")
        panel.set_title(f"{model_name}: {metric}")
        panel.set_xlabel("Epoch")
        panel.legend()
        panel.grid()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


## train_model

In [None]:
import segmentation_models_pytorch as smp


def train_single_model(model_name, model_fn, train_loader, val_loader, config):
    """Build one model with its optimizer, scheduler and loss, then train it.

    Args:
        model_name: key used for the checkpoint file and the TensorBoard dir.
        model_fn: zero-argument callable that returns a fresh model.
        train_loader, val_loader: data loaders passed on to ``train_loop``.
        config: dict providing 'learning_rate', 'weight_decay', 'log_dir',
            'num_epochs', 'early_stop_patience' and optionally 'resume'.

    Returns:
        ``(train_hist, val_hist)`` as returned by ``train_loop``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # --- Model & Optim ----
    model = model_fn().to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )
    # mode='max' because the scheduler is stepped with the validation Dice.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=3
    )

    # --- Losses ---
    # Alternatives kept for reference:
    # loss_fn = FDLoss(wf=0.1, gamma=2.0, alpha=0.25).to(device)
    # loss_fn = 0.8 * torch.nn.BCEWithLogitsLoss()+ 0.2 * DiceLoss()
    loss_fn = BCEDiceLoss(bce_weight=0.8, dice_weight=0.2).to(device)

    # --- AMP scaler (CUDA only) ---
    scaler = GradScaler() if device.type == 'cuda' else None

    # --- TensorBoard ---
    writer = SummaryWriter(log_dir=os.path.join(config['log_dir'], model_name))
    try:
        # --- Resume option ---
        checkpoint_path = f"best_{model_name}.pth"
        if config.get('resume') and os.path.exists(checkpoint_path):
            print(f"Resuming {model_name} from checkpoint…")
            model.load_state_dict(torch.load(checkpoint_path, map_location=device))

        # --- Train ---
        train_hist, val_hist = train_loop(
            model, optimizer, scheduler, train_loader, val_loader,
            loss_fn, device, scaler,
            num_epochs=config['num_epochs'],
            early_stop_patience=config['early_stop_patience'],
            model_name=model_name
        )
    finally:
        # Release the event-file handle even if training raises.
        writer.close()
    return train_hist, val_hist


## Run training (all models)

In [None]:
import gc
import torch

def train_all_models(models_dict, train_loader, val_loader, config):
    """Train every model in ``models_dict`` in turn and plot its metric curves.

    Args:
        models_dict: mapping of model name to a zero-argument model factory.
        train_loader, val_loader: loaders shared by all models.
        config: training configuration forwarded to ``train_single_model``.

    The CUDA cache is emptied and the garbage collector is run before and
    after each model. Each plot is saved as ``metrics_{name}.png``.
    """

    def _release_memory():
        torch.cuda.empty_cache()
        gc.collect()

    for model_key, build_model in models_dict.items():
        print(f"==== Training model: {model_key} ====")

        # Start each model from a clean memory state.
        _release_memory()

        train_hist, val_hist = train_single_model(
            model_key, build_model, train_loader, val_loader, config
        )
        plot_metrics(
            train_hist, val_hist,
            model_name=model_key,
            save_path=f"metrics_{model_key}.png",
        )

        # Free memory again once the model is done (matters for large models).
        _release_memory()


# Train every model registered in MODELS with the shared loaders and CONFIG.
train_all_models(MODELS, train_loader, val_loader, CONFIG)

## Run training

# Test

### tta_predict

In [None]:

# --- Функция TTA предсказания ---
def tta_predict(model, images):
    """
    Simple horizontal test-time augmentation.

    Runs the model on the batch and on a copy mirrored along dim 3, mirrors
    the second output back, and returns the mean of the two raw outputs.
    The model is switched to eval mode and no gradients are tracked.
    """
    model.eval()
    with torch.no_grad():
        plain_out = model(images)
        mirrored_out = model(images.flip(dims=[3]))  # horizontally mirrored input
        restored_out = mirrored_out.flip(dims=[3])   # undo the mirroring on the output
        return (plain_out + restored_out) / 2


### visualize_pred

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def _hide_axes_frame(ax):
    """Hide ticks and spines while keeping axis labels visible.

    ``axis("off")`` would also hide the y-label used as a row caption.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def visualize_predictions_tta_with_label(
    model,
    dataloader,
    device,
    num_samples=10,
    imagenet_norm=True,
    model_title=None  # name of the model, shown as the figure title
):
    """
    Show up to ``num_samples`` patches side by side; rows from top to bottom:
    1. RGB image
    2. Ground truth mask
    3. Prediction (TTA, thresholded at 0.5)

    Each column title carries the class code and class string when the batch
    provides them (third and fourth batch elements). ``model_title`` is the
    overall figure title. With ``imagenet_norm=True`` the ImageNet mean/std
    normalization is undone before the image is displayed.
    """
    model = model.to(device)
    model.eval()
    shown = 0

    images_list = []
    true_masks_list = []
    pred_masks_list = []
    titles = []

    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])

    with torch.no_grad():
        for batch in dataloader:
            images, masks, *meta = batch
            images = images.to(device)
            masks = masks.unsqueeze(1).float().to(device)

            outputs = tta_predict(model, images)
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()

            batch_size = images.size(0)
            label_codes = meta[0] if len(meta) > 0 else [None] * batch_size
            label_strs = meta[1] if len(meta) > 1 else [None] * batch_size

            for i in range(batch_size):
                if shown >= num_samples:
                    break

                img = images[i].cpu().permute(1, 2, 0).numpy()
                if imagenet_norm:
                    img = (img * imagenet_std + imagenet_mean).clip(0, 1)

                # Column title with the class code and class string.
                label_code = label_codes[i]
                label_str = label_strs[i]
                if label_str is not None and label_code is not None:
                    title = f"Code: {label_code}, {label_str}"
                elif label_code is not None:
                    title = f"Code: {label_code}"
                elif label_str is not None:
                    title = f"{label_str}"
                else:
                    title = "Image"

                images_list.append(img)
                true_masks_list.append(masks[i].cpu().squeeze().numpy())
                pred_masks_list.append(preds[i].cpu().squeeze().numpy())
                titles.append(title)

                shown += 1
            if shown >= num_samples:
                break

    # Layout: one column per sample; rows are Image, GT, Prediction.
    plt.figure(figsize=(num_samples * 3, 10))
    if model_title is not None:
        plt.suptitle(model_title, fontsize=18, y=1.02)

    rows = (
        (images_list, None, None),
        (true_masks_list, "gray", "GT Mask"),
        (pred_masks_list, "gray", "Prediction"),
    )
    for col in range(shown):
        for row, (data, cmap, row_label) in enumerate(rows):
            ax = plt.subplot(3, num_samples, row * num_samples + col + 1)
            ax.imshow(data[col], cmap=cmap)
            if row == 0:
                ax.set_title(titles[col], fontsize=9)
            if col == 0 and row_label is not None:
                ax.set_ylabel(row_label, fontsize=12)
            # Not axis("off"): that would hide the row caption set above.
            _hide_axes_frame(ax)

    plt.tight_layout()
    plt.show()


In [None]:
import torch
import gc

# Visualize TTA predictions on the test loader for every model's best checkpoint.
model_names = list(MODELS.keys())

for model_name in model_names:
    checkpoint_path = f'best_{model_name}.pth'
    print(f"\n--- Model: {model_name} ---")
    # Bind the name up front: if building the model fails, the cleanup in
    # `finally` must not raise NameError and mask the reported error.
    test_model = None
    try:
        test_model = MODELS[model_name]().to(device)
        test_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        test_model.eval()

        visualize_predictions_tta_with_label(
            test_model,
            dataloader=test_loader,
            device=device,
            num_samples=5,
            model_title=f"Model: {model_name}"
        )
    except Exception as e:
        # Best effort: report the failure and continue with the next model.
        print(f"❌ Could not evaluate {model_name}: {e}")
    finally:
        # === Free memory ===
        del test_model
        # Collect first so the CUDA cache can actually release the freed blocks.
        gc.collect()
        torch.cuda.empty_cache()


## Sanity checks

In [None]:
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# 1. Find the patch with the largest defect area (highest mask mean)
max_defect_idx = None
max_defect_mean = 0
for i in range(len(train_ds)):
    img, mask, code, label = train_ds[i]
    m = mask.float().mean().item()
    if m > max_defect_mean:
        max_defect_mean = m
        max_defect_idx = i

# Fail with a clear message instead of crashing on train_ds[None] below.
if max_defect_idx is None:
    raise ValueError("No patch with a non-empty defect mask found in train_ds")

print(f"Max defect mean: {max_defect_mean:.4f} at index {max_defect_idx}")

# 2. Build a dataset made of 8 copies of that single patch
img, mask, *_ = train_ds[max_defect_idx]
N = 8
single_ds = [(img, mask)] * N
single_loader = DataLoader(single_ds, batch_size=N, shuffle=True)

# 3. Run the overfit sanity check for every model
for model_name, model_fn in MODELS.items():
    print(f"\n==== Sanity check: {model_name} ====")
    # Re-initialize the model
    model = model_fn().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = FDLoss(wf=0.1, gamma=2.0, alpha=0.25)  # or your main loss

    # Quick training on the single patch
    for epoch in range(20):
        model.train()
        for img_batch, mask_batch in single_loader:
            img_batch = img_batch.to(device)
            mask_batch = mask_batch.unsqueeze(1).float().to(device)
            optimizer.zero_grad()
            out = model(img_batch)
            loss = loss_fn(out, mask_batch)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

        # Compare mean prediction with mean ground truth on the same batch.
        model.eval()
        with torch.no_grad():
            pred = torch.sigmoid(model(img_batch)).cpu().numpy()
            mask_np = mask_batch.cpu().numpy()
            print("GT mean:", mask_np.mean(), "Pred mean:", pred.mean())

    # Visualize the first patch (prediction from the last epoch)
    plt.figure(figsize=(12, 4))
    plt.suptitle(f"Sanity Check: {model_name}")
    plt.subplot(1, 3, 1)
    plt.imshow(img[:3].permute(1, 2, 0).cpu().numpy())
    plt.title("Image")
    plt.axis("off")
    plt.subplot(1, 3, 2)
    plt.imshow(mask.squeeze().cpu().numpy(), cmap='gray')
    plt.title("Mask")
    plt.axis("off")
    plt.subplot(1, 3, 3)
    plt.imshow(pred[0].squeeze(), cmap='gray')
    plt.title("Prediction")
    plt.axis("off")
    plt.show()


In [None]:
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

# 1. Find a clean patch in train_ds (code == 0 and a completely empty mask)
clean_idx = None
for i in range(len(train_ds)):
    img, mask, code, label = train_ds[i]
    if code == 0 and mask.float().mean().item() == 0.0:
        clean_idx = i
        break

if clean_idx is None:
    raise ValueError("Чистый патч не найден!")
print(f"Чистый патч найден: index {clean_idx}")

# 2. Build a dataset made of 8 copies of that single patch (for BatchNorm sanity)
img, mask, *_ = train_ds[clean_idx]
N = 8
single_clean_ds = [(img, mask)] * N
single_clean_loader = DataLoader(single_clean_ds, batch_size=N, shuffle=True)

# 3. Run the sanity check for every model
for model_name, model_fn in MODELS.items():
    print(f"\n==== Sanity check on CLEAN: {model_name} ====")
    # Re-initialize the model
    model = model_fn().to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    loss_fn = FDLoss(wf=0.1, gamma=2.0, alpha=0.25)  # or your main loss

    # Quick training on the single patch
    for epoch in range(10):  # 10 epochs are usually enough for a clean patch
        model.train()
        for img_batch, mask_batch in single_clean_loader:
            img_batch = img_batch.to(device)
            mask_batch = mask_batch.unsqueeze(1).float().to(device)
            optimizer.zero_grad()
            out = model(img_batch)
            loss = loss_fn(out, mask_batch)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch}: Loss = {loss.item():.4f}")

        # Compare mean prediction with mean ground truth on the same batch.
        model.eval()
        with torch.no_grad():
            pred = torch.sigmoid(model(img_batch)).cpu().numpy()
            mask_np = mask_batch.cpu().numpy()
            print("GT mean:", mask_np.mean(), "Pred mean:", pred.mean())

    # 4. Visualize the first patch (first element of the batch)
    plt.figure(figsize=(12, 4))
    plt.suptitle(f"Sanity Check (CLEAN): {model_name}")
    plt.subplot(1, 3, 1)
    plt.imshow(img[:3].permute(1, 2, 0).cpu().numpy())
    plt.title("Image")
    plt.axis("off")
    plt.subplot(1, 3, 2)
    plt.imshow(mask.squeeze().cpu().numpy(), cmap='gray')
    plt.title("Mask")
    plt.axis("off")
    plt.subplot(1, 3, 3)
    plt.imshow(pred[0].squeeze(), cmap='gray')
    plt.title("Prediction")
    plt.axis("off")
    plt.show()


# Inference

In [None]:
import numpy as np
import torch
import cv2
from pathlib import Path
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt

def _window_starts(length, window, stride):
    """Return sliding-window start offsets that cover ``[0, length)`` fully.

    A final window aligned to the end is appended when the regular stride
    would leave the border uncovered. Returns ``[]`` if ``length < window``.
    """
    if length < window:
        return []
    starts = list(range(0, length - window + 1, stride))
    last_start = length - window
    if starts[-1] != last_start:
        starts.append(last_start)
    return starts


def infer_full_image_with_models(
    image,
    models_dict,
    preprocess,
    patch_h=256,
    patch_w=256,
    stride_h=64,
    stride_w=64,
    device='cuda',
    threshold=0.5,
    model_names=None,
    config=None
):
    """Predict a full-image binary mask with each model via sliding windows.

    The image is cut into overlapping patches, each patch goes through
    ``preprocess`` and the model, and the per-patch probabilities are
    averaged where patches overlap before thresholding.

    Args:
        image: array of shape (H, W) or (H, W, C).
        models_dict: mapping of model name to a zero-argument model factory.
        preprocess: callable turning one patch into a tensor for the model.
        patch_h, patch_w: patch size; the model output for a patch must have
            this spatial size.
        stride_h, stride_w: step between neighbouring patches.
        device: device used for inference.
        threshold: cut-off applied to the averaged probabilities.
        model_names: subset of ``models_dict`` keys to run (default: all).
        config: optional dict; ``config["batch_size"]`` sets the inference
            batch size (default 8).

    Returns:
        dict mapping each model name to a uint8 mask of shape (H, W).

    Raises:
        ValueError: if the image is smaller than one patch.

    Weights are loaded from ``best_{model_name}.pth``. A failed load is
    reported and inference continues with the factory's initial weights.
    """
    H, W = image.shape[:2]
    if model_names is None:
        model_names = list(models_dict.keys())
    batch_size = config["batch_size"] if config and "batch_size" in config else 8

    patch_coords = [
        (y, x)
        for y in _window_starts(H, patch_h, stride_h)
        for x in _window_starts(W, patch_w, stride_w)
    ]
    if not patch_coords:
        raise ValueError(
            f"Image of size {H}x{W} is smaller than the patch size {patch_h}x{patch_w}"
        )
    patch_batch = torch.stack(
        [preprocess(image[y:y + patch_h, x:x + patch_w]) for y, x in patch_coords]
    )  # [N, C, H, W]

    masks_dict = {}
    for model_name in model_names:
        torch.cuda.empty_cache()
        model = models_dict[model_name]().to(device)
        model.eval()
        try:
            model.load_state_dict(torch.load(f'best_{model_name}.pth', map_location=device))
        except Exception as e:
            # Deliberate best effort: keep going with the freshly built weights.
            print(f"[{model_name}] checkpoint not loaded: {e}")

        preds = []
        with torch.no_grad():
            for i in range(0, len(patch_batch), batch_size):
                batch = patch_batch[i:i + batch_size].to(device)
                prob = torch.sigmoid(model(batch))
                if prob.ndim == 4:
                    # Take channel 0 explicitly. A bare squeeze() would also
                    # drop the batch axis for a batch of one patch and break
                    # the concatenation below.
                    prob = prob[:, 0]
                preds.append(prob.cpu().numpy())
                del batch, prob
                torch.cuda.empty_cache()
        preds = np.concatenate(preds, axis=0)  # [N, patch_h, patch_w]

        # Assemble the mask: sum the probabilities and divide by the coverage.
        mask = np.zeros((H, W), dtype=np.float32)
        count = np.zeros((H, W), dtype=np.float32)
        for patch_pred, (y, x) in zip(preds, patch_coords):
            mask[y:y + patch_h, x:x + patch_w] += patch_pred
            count[y:y + patch_h, x:x + patch_w] += 1
        mask /= np.maximum(count, 1)
        masks_dict[model_name] = (mask > threshold).astype(np.uint8)
        del model
        torch.cuda.empty_cache()

    del patch_batch
    torch.cuda.empty_cache()
    return masks_dict


In [None]:
def show_result_column(img, mask, model_name='Model'):
    """Show the original image, the mask and a mask overlay stacked vertically.

    Args:
        img: image to display.
        mask: mask to display, also drawn semi-transparently over the image.
        model_name: prefix for the title of the first panel.
    """
    # Each panel: (title, layers), where a layer is (data, imshow kwargs).
    panels = (
        (f"{model_name}: Original", ((img, {}),)),
        ("Mask", ((mask, {"cmap": 'gray'}),)),
        ("Overlay", ((img, {}), (mask, {"alpha": 0.4, "cmap": 'Reds'}))),
    )

    plt.figure(figsize=(6, 12))
    for position, (heading, layers) in enumerate(panels, start=1):
        plt.subplot(3, 1, position)
        for data, options in layers:
            plt.imshow(data, **options)
        plt.title(heading)
        plt.axis('off')
    plt.tight_layout()
    plt.show()
def show_multi_model_results(img, masks_dict):
    """Show one column per model: original image, mask, and mask overlay.

    Args:
        img: image shared by all columns.
        masks_dict: mapping of model name to its predicted mask.
    """
    total = len(masks_dict)
    plt.figure(figsize=(4 * total, 12))
    for column, model_key in enumerate(masks_dict):
        model_mask = masks_dict[model_key]
        # Rows from top to bottom; each layer is (data, imshow kwargs).
        layout = (
            (f"{model_key}\nOriginal", ((img, {}),)),
            ("Mask", ((model_mask, {"cmap": "gray"}),)),
            ("Overlay", ((img, {}), (model_mask, {"cmap": 'Reds', "alpha": 0.4}))),
        )
        for row, (heading, layers) in enumerate(layout):
            plt.subplot(3, total, 1 + row * total + column)
            for data, options in layers:
                plt.imshow(data, **options)
            plt.title(heading, fontsize=10)
            plt.axis("off")
    plt.tight_layout()
    plt.show()


In [None]:
import matplotlib.pyplot as plt
import os

def save_and_show_image(img, fname, title=None):
    """Display ``img`` and save the rendered figure to ``fname``.

    Args:
        img: 2-D (shown with the gray colormap) or 3-D image array.
        fname: output path passed to ``plt.savefig``.
        title: optional figure title.
    """
    # Width follows the image width (1 inch per 32 px, at most 16 inches).
    # Clamp to at least 1 inch: images narrower than 32 px would otherwise
    # request a zero-width figure, which matplotlib rejects.
    fig_width = max(1, min(16, img.shape[1] // 32))
    plt.figure(figsize=(fig_width, 6))
    plt.imshow(img, cmap=None if img.ndim == 3 else 'gray')
    if title:
        plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.savefig(fname, bbox_inches='tight', pad_inches=0.1)
    plt.show()

def visualize_and_save_results(img, mask, model_name, save_dir="./inference_results", idx=0):
    """Show and save the original image, the mask and the overlay.

    Three PNG files are written into ``save_dir`` (created if missing):
    ``{model_name}_orig_{idx}.png``, ``{model_name}_mask_{idx}.png`` and
    ``{model_name}_overlay_{idx}.png``.

    Args:
        img: original image.
        mask: predicted mask.
        model_name: used in titles and file names.
        save_dir: output directory.
        idx: numeric suffix for the file names.
    """
    os.makedirs(save_dir, exist_ok=True)

    # 1. Original
    fname_orig = os.path.join(save_dir, f"{model_name}_orig_{idx}.png")
    save_and_show_image(img, fname_orig, title=f"{model_name} Original")
    print(f"Saved: {fname_orig}")

    # 2. Mask
    fname_mask = os.path.join(save_dir, f"{model_name}_mask_{idx}.png")
    save_and_show_image(mask, fname_mask, title=f"{model_name} Mask")
    print(f"Saved: {fname_mask}")

    # 3. Overlay
    # Clamp the width to at least 1 inch: an image narrower than 32 px would
    # otherwise request a zero-width figure, which matplotlib rejects.
    fig_width = max(1, min(16, img.shape[1] // 32))
    plt.figure(figsize=(fig_width, 6))
    plt.imshow(img)
    plt.imshow(mask, alpha=0.4, cmap='Reds')
    plt.title(f"{model_name} Overlay")
    plt.axis('off')
    plt.tight_layout()
    fname_overlay = os.path.join(save_dir, f"{model_name}_overlay_{idx}.png")
    plt.savefig(fname_overlay, bbox_inches='tight', pad_inches=0.1)
    plt.show()
    print(f"Saved: {fname_overlay}")


In [None]:

# Per-patch preprocessing: resize to the patch size, ImageNet normalization,
# conversion to a tensor.
preprocess_albu = A.Compose([
    A.Resize(PATCH_H, PATCH_W),  # matches the configured patch size
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])


def preprocess_patch(patch):
    """Apply ``preprocess_albu`` to one image patch and return the tensor."""
    return preprocess_albu(image=patch)['image']


# --- Select the image ---
IMG_DIR = Path('./aitex_data/extracted/Defect_images')
img_files = sorted(IMG_DIR.glob('*.png'))
if not img_files:
    # Clear message instead of a bare IndexError on img_files[0].
    raise FileNotFoundError(f"No PNG images found in {IMG_DIR}")
img_path = img_files[0]

img = cv2.imread(str(img_path))
if img is None:
    # cv2.imread signals an unreadable file by returning None, not by raising.
    raise OSError(f"Could not read image: {img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# --- Inference ---
result_masks = infer_full_image_with_models(
    img,
    models_dict=MODELS,
    preprocess=preprocess_patch,
    patch_h=PATCH_H,
    patch_w=PATCH_W,
    stride_h=STRIDE_H,
    stride_w=STRIDE_W,
    device=device,
    config=CONFIG
)

for i, (model_name, mask) in enumerate(result_masks.items()):
    visualize_and_save_results(img, mask, model_name, save_dir="./inference_results", idx=i)
