# Install dependencies

In [None]:
%%writefile requirements.txt
requests
tqdm>=4.65.0
torch>=2.0.0
# torchvision>=0.14.0
opencv-python
numpy>=1.23.0
pandas>=1.5.0
timm>=0.6.13
# Needed by `torch.utils.tensorboard.SummaryWriter`, which the training cells import.
tensorboard>=2.13.0
albumentations>=1.4.0
# segmentation-models-pytorch>=0.3.3
scikit-learn>=1.1.0
matplotlib>=3.5.0
# transformers>=4.31.0


In [None]:
# `%pip` installs into the environment of the running kernel
# (a bare `! pip` may resolve to a different interpreter on PATH).
%pip install -U -q pip
%pip install -q -r requirements.txt

# Training config

In [None]:
# --- Training settings for the classification experiment ---
# Central place for the hyper-parameters and paths read by the cells below.
CONFIG = {
    'batch_size': 64,
    'num_workers': 2,
    'num_epochs': 100,
    'learning_rate': 1e-3,
    'weight_decay': 1e-3,
    'early_stop_patience': 10,
    'dataroot': './aitex_data/extracted',
    'log_dir': 'runs/classification_experiment',   # changed for the classification task
    'resume': False
}

import torch
import os

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Allocator option passed to PyTorch through its environment variable.
# NOTE(review): this is set after `import torch`; presumably it is still picked up
# before the first CUDA allocation — confirm, or move it above the import.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


## Models

In [None]:
import timm

# Note: for classification timm expects `num_classes` (not `classes`).
NUM_CLASSES = 12 + 1  # set to the real value (usually ~12 defect types in AITEX) + 1 "no defect" class

# Registry of model factories: name -> zero-argument callable that builds a fresh model.
MODELS = {
    "beit_base_patch16_224": lambda: timm.create_model(
        "beit_base_patch16_224",
        pretrained=True,
        num_classes=NUM_CLASSES,
        in_chans=3,
        drop_rate=0.3,          # <--- dropout after the MLP (per the original note; verify against timm)
        drop_path_rate=0.1      # <--- DropPath between blocks
    ),
    # "convnext_base": lambda: timm.create_model(
    #     "convnext_base",
    #     pretrained=True,
    #     num_classes=NUM_CLASSES,
    #     in_chans=3,
    #     drop_path_rate=0.1      # DropPath only
    # ),
    # "resnet": lambda: timm.create_model(
    #     "resnet50",
    #     pretrained=True,
    #     num_classes=NUM_CLASSES,
    #     in_chans=3
    #     # Dropout is not used
    # ),
}


In [None]:
# import torch
# import gc
# from tqdm import tqdm

# def find_max_batch_size(
#     model_fn,
#     device='cuda',
#     image_size=(224, 224),
#     max_test=512,
#     num_classes=13,
#     step=4
# ):
#     print(f"\n=== Поиск максимального batch_size для {model_fn.__name__ if hasattr(model_fn, '__name__') else model_fn} ===")
#     batch_size = step
#     last_ok = 0
#     model = model_fn().to(device)
#     model.eval()
#     torch.cuda.empty_cache()
#     gc.collect()

#     tried = []
#     total_attempts = (max_test // step)
#     pbar = tqdm(total=total_attempts, desc='Подбор batch_size', ncols=100)
#     while batch_size <= max_test:
#         try:
#             dummy = torch.randn(batch_size, 3, image_size[0], image_size[1]).to(device)
#             with torch.no_grad():
#                 out = model(dummy)
#             last_ok = batch_size
#             tried.append(batch_size)
#             batch_size += step
#             del dummy, out
#             torch.cuda.empty_cache()
#             gc.collect()
#             pbar.update(1)
#         except RuntimeError as e:
#             pbar.close()
#             if 'out of memory' in str(e):
#                 print(f"\nOOM at batch_size={batch_size}. Last OK: {last_ok}")
#                 break
#             else:
#                 print(f"\nError at batch_size={batch_size}: {e}")
#                 break
#     pbar.close()
#     del model
#     torch.cuda.empty_cache()
#     gc.collect()
#     print(f"Максимальный batch_size: {last_ok} (для image_size={image_size})")
#     return last_ok

# # Пример: для всех моделей
# for name, model_fn in MODELS.items():
#     print(f"\n=== {name} ===")
#     max_bs = find_max_batch_size(model_fn, device='cuda', image_size=(224,224), step=8)
#     print(f"Max batch_size for {name}: {max_bs}\n")


## Metrics

In [None]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, top_k_accuracy_score
)
import numpy as np

def compute_classification_metrics(y_true, y_pred, y_prob, num_classes, top_k=3):
    """
    Compute a dictionary of classification metrics.

    Args:
        y_true: ground-truth class indices.
        y_pred: predicted class indices.
        y_prob: per-class scores passed to ``top_k_accuracy_score`` together with
            ``labels=range(num_classes)``; pass ``None`` to skip the top-k metrics.
        num_classes: total number of classes (defines the label set for top-k).
        top_k: ``k`` for the additional top-k accuracy. It is clipped to
            ``[1, num_classes]``; the metric is stored under ``f"top{top_k}"``
            (``"top3"`` with the default).

    Returns:
        dict with ``accuracy``, macro/micro ``precision``/``recall``/``f1``,
        ``confusion_matrix`` and, when ``y_prob`` is given, ``top1`` and ``top{top_k}``.
    """
    metrics = {"accuracy": accuracy_score(y_true, y_pred)}

    # Same key order as before: all macro metrics first, then all micro metrics.
    scorers = (("precision", precision_score), ("recall", recall_score), ("f1", f1_score))
    for average in ("macro", "micro"):
        for prefix, scorer in scorers:
            metrics[f"{prefix}_{average}"] = scorer(
                y_true, y_pred, average=average, zero_division=0
            )

    metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred)

    if y_prob is not None:
        class_ids = list(range(num_classes))
        metrics["top1"] = top_k_accuracy_score(y_true, y_prob, k=1, labels=class_ids)
        # Honour the `top_k` argument (it used to be ignored in favour of a literal 3).
        effective_k = max(1, min(int(top_k), num_classes))
        metrics[f"top{top_k}"] = top_k_accuracy_score(
            y_true, y_prob, k=effective_k, labels=class_ids
        )
    return metrics


# Prepare data

## download_dataset

In [None]:
import requests
from tqdm import tqdm

import zipfile
from pathlib import Path

url = "https://www.kaggle.com/api/v1/datasets/download/nexuswho/aitex-fabric-image-database"
output_dir = Path("./aitex_data")
output_dir.mkdir(exist_ok=True)

zip_path = output_dir / "aitex.zip"

# Reuse the archive only if it is a readable zip file: an interrupted download
# (or an HTML error page saved under that name) must not be treated as valid.
if zip_path.exists() and zipfile.is_zipfile(zip_path):
    print(f"[INFO] Архив уже существует по пути: {zip_path}")
else:
    print(f"[INFO] Скачиваем архив из {url}...")
    # Download into a temporary ".part" file and rename it on success, so that a
    # failed run never leaves a truncated `aitex.zip` behind.
    partial_path = zip_path.with_name(zip_path.name + ".part")
    # (connect timeout, read timeout) in seconds — avoids hanging forever.
    with requests.get(url, stream=True, timeout=(10, 60)) as response:
        if response.status_code != 200:
            raise Exception(f"Ошибка при скачивании: статус {response.status_code}")
        total_bytes = int(response.headers.get("content-length", 0)) or None
        with open(partial_path, "wb") as f, tqdm(total=total_bytes, unit="B", unit_scale=True) as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                pbar.update(len(chunk))
    partial_path.replace(zip_path)
    print("[INFO] Скачивание завершено.")

# Unpack the archive (skipped when the target directory already exists).
extract_dir = output_dir / "extracted"
if not extract_dir.exists():
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)
    print(f"[INFO] Архив успешно распакован в {extract_dir}")
else:
    print(f"[INFO] Архив уже был распакован в {extract_dir}")


## remove_images_without_masks

In [None]:
import os

def remove_images_without_masks(image_dir, mask_dir, image_suffix=".png", mask_suffix="_mask.png"):
    """
    Delete every image in `image_dir` that has no matching mask in `mask_dir`.

    Only files whose name ends with `image_suffix` are considered. The expected
    mask name is the image name without its extension plus `mask_suffix`.
    Each deletion and the final count are printed; nothing is returned.
    """
    orphan_count = 0
    image_names = (name for name in os.listdir(image_dir) if name.endswith(image_suffix))
    for file_name in image_names:
        stem, _extension = os.path.splitext(file_name)
        expected_mask = stem + mask_suffix
        if os.path.exists(os.path.join(mask_dir, expected_mask)):
            # The mask is present, so the image stays.
            continue
        target = os.path.join(image_dir, file_name)
        print(f"Удаляется {target} (маска {expected_mask} не найдена)")
        os.remove(target)
        orphan_count += 1
    print(f"Удалено изображений без масок: {orphan_count}")

# Example call.
# NOTE: this permanently deletes image files from the extracted dataset directory
# (the function calls `os.remove` on every image without a mask).
remove_images_without_masks(
    image_dir="./aitex_data/extracted/Defect_images",
    mask_dir="./aitex_data/extracted/Mask_images"
)


import os
import random
import cv2
import matplotlib.pyplot as plt

def show_and_save_random_image_with_mask(
    image_dir, mask_dir,
    image_suffix=".png", mask_suffix="_mask.png",
    save_dir="./random_samples"
):
    """
    Display one random image with its mask and save both to `save_dir`.

    The image and the mask are shown stacked vertically (two rows) and written
    as two separate files. The function prints a message and returns early when
    the directory has no images, when the mask file is missing, or when either
    file cannot be decoded by OpenCV.
    """
    os.makedirs(save_dir, exist_ok=True)
    image_files = [f for f in os.listdir(image_dir) if f.endswith(image_suffix)]
    if not image_files:
        print("Нет изображений в директории.")
        return

    img_name = random.choice(image_files)
    base_name = os.path.splitext(img_name)[0]
    mask_name = base_name + mask_suffix

    img_path = os.path.join(image_dir, img_name)
    mask_path = os.path.join(mask_dir, mask_name)

    if not os.path.exists(mask_path):
        print(f"Маска для {img_name} не найдена: {mask_path}")
        return

    # `cv2.imread` returns None (it does not raise) for unreadable files;
    # without these checks the failure would surface later inside `cvtColor`/`imshow`.
    img = cv2.imread(img_path)
    if img is None:
        print(f"Не удалось прочитать изображение: {img_path}")
        return
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        print(f"Не удалось прочитать маску: {mask_path}")
        return
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Vertical layout: image on top, mask below.
    plt.figure(figsize=(6, 8))
    plt.subplot(2, 1, 1)
    plt.imshow(img_rgb)
    plt.title(f"Изображение: {img_name}")
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.imshow(mask, cmap='gray')
    plt.title(f"Маска: {mask_name}")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Save the two files separately (the image is written in its original BGR form).
    img_save_path = os.path.join(save_dir, f"{base_name}_image.png")
    mask_save_path = os.path.join(save_dir, f"{base_name}_mask.png")
    cv2.imwrite(img_save_path, img)
    cv2.imwrite(mask_save_path, mask)
    print(f"Сохранено изображение: {img_save_path}")
    print(f"Сохранена маска: {mask_save_path}")

# Example call: show and save one random defect image together with its mask.
show_and_save_random_image_with_mask(
    image_dir="./aitex_data/extracted/Defect_images",
    mask_dir="./aitex_data/extracted/Mask_images"
)



## dataset_stats

In [None]:
from pathlib import Path
import cv2
import numpy as np

# Directory with the original (full-size) defect masks.
SRC_MSK_DIR = Path("./aitex_data/extracted/Mask_images")

# Running statistics over the number of non-zero (defect) pixels per mask.
min_pixels = None
max_pixels = 0
pixels_list = []

for mask_path in SRC_MSK_DIR.glob("*.png"):
    msk = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if msk is None:
        # Unreadable file — skip it.
        continue
    num = int((msk > 0).sum())
    # Masks without any defect pixel are excluded from the statistics.
    if num > 0:
        pixels_list.append(num)
        if min_pixels is None or num < min_pixels:
            min_pixels = num
        if num > max_pixels:
            max_pixels = num

# Report min / max / median and plot a histogram of defect sizes.
# NOTE(review): if no mask contains defect pixels, `min_pixels` stays None and the
# median of an empty list is NaN.
print(f"Минимальное количество пикселей дефекта в одной маске: {min_pixels}")
print(f"Максимальное: {max_pixels}")
print(f"Медианное: {np.median(pixels_list)}")
print(f"Гистограмма по всем изображениям:")
import matplotlib.pyplot as plt
plt.hist(pixels_list, bins=30)
plt.xlabel("Дефектных пикселей на маске")
plt.ylabel("Частота")
plt.show()


# Preprocess data

## slice_to_patches

In [None]:
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Defect labels: code taken from the image file name -> human-readable name ---
DEFECT_LABELS = {
    '000': 'No defect',
    '002': 'Broken end',
    '006': 'Broken yarn',
    '010': 'Broken pick',
    '016': 'Weft curling',
    '019': 'Fuzzyball',
    '022': 'Cut selvage',
    '023': 'Crease',
    '025': 'Warp ball',
    '027': 'Knots',
    '029': 'Contamination',
    '030': 'Nep',
    '036': 'Weft crack'
}

# --- Slicing parameters ---
SRC_IMG_DIR = Path("./aitex_data/extracted/Defect_images")
SRC_MSK_DIR = Path("./aitex_data/extracted/Mask_images")
DST_IMG_DIR = Path("./aitex_patches/images")
DST_MSK_DIR = Path("./aitex_patches/masks")

# Patch size and grid stride in pixels (stride = 32, so neighbouring patches overlap).
PATCH_W = PATCH_H = 224
STRIDE_W = STRIDE_H = (256 - 224)
# MIN_DEFECT_FRAC = 0.005       # minimum defect fraction
# A patch needs at least this many mask pixels to count as defective.
MIN_DEFECT_PIXELS = 9 
# Probability of keeping a patch that is classified as non-defective.
KEEP_NEG = 0.05

DST_IMG_DIR.mkdir(parents=True, exist_ok=True)
DST_MSK_DIR.mkdir(parents=True, exist_ok=True)

# Accumulates (file, defect_code, defect_label) tuples for the label CSV.
rows = []
PATCH_AREA = PATCH_W * PATCH_H

def has_large_defect(mask, min_size=20):
    """
    Return True if `mask` has a connected non-zero region of at least `min_size` pixels.

    The mask is binarised (`> 0`) and labelled with OpenCV connected components;
    label 0 is the background and is ignored.
    """
    binary = (mask > 0).astype(np.uint8)
    label_count, _labels, component_stats, _centroids = cv2.connectedComponentsWithStats(binary)
    # Row 0 of the stats table describes the background, so start from row 1.
    return any(
        component_stats[component, cv2.CC_STAT_AREA] >= min_size
        for component in range(1, label_count)
    )

# Seeded generator: the random sub-sampling of clean patches (and therefore the
# generated dataset and label CSV) is reproducible across kernel restarts.
patch_rng = np.random.default_rng(42)

for img_path in tqdm(sorted(SRC_IMG_DIR.glob("*.png")), desc="Cropping AITEX (grid)"):
    # Mask name = image stem + "_mask.png".
    mask_path = SRC_MSK_DIR / f"{img_path.stem}_mask.png"
    img = cv2.imread(str(img_path))
    msk = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

    if img is None or msk is None:
        print(f"❌  Пропуск {img_path.name} (файла нет)")
        continue

    # The defect code is the second "_"-separated field of the file name.
    # Skip files that do not follow this pattern instead of crashing mid-run.
    name_parts = img_path.stem.split('_')
    if len(name_parts) < 2 or not name_parts[1].isdigit():
        print(f"❌  Пропуск {img_path.name} (не удалось определить код дефекта)")
        continue

    msk_bin = (msk > 0).astype(np.uint8)
    orig_defect_code_str = name_parts[1]
    orig_defect_code = int(orig_defect_code_str)
    orig_defect_label = DEFECT_LABELS.get(orig_defect_code_str, "Unknown")

    # Two nested loops: over y and over x (regular grid with the configured stride).
    for y in range(0, img.shape[0] - PATCH_H + 1, STRIDE_H):
        for x in range(0, img.shape[1] - PATCH_W + 1, STRIDE_W):
            img_crop = img[y:y+PATCH_H, x:x+PATCH_W]
            msk_crop = msk_bin[y:y+PATCH_H, x:x+PATCH_W]
            pos_pix = int(msk_crop.sum())

            # A patch is defective when it has enough mask pixels ...
            is_defective = pos_pix >= MIN_DEFECT_PIXELS

            patch_defect_code = orig_defect_code
            patch_defect_label = orig_defect_label

            # ... and at least one sufficiently large connected region.
            # Otherwise the patch is "clean": keep it only with probability
            # KEEP_NEG and relabel it as "No defect".
            if not is_defective or not has_large_defect(msk_crop, min_size=20):
                if patch_rng.random() > KEEP_NEG:
                    continue
                patch_defect_code = 0
                patch_defect_label = DEFECT_LABELS['000']

            suffix = f"x{x:04d}_y{y:04d}"
            fname  = f"{img_path.stem}_{suffix}.png"
            cv2.imwrite(str(DST_IMG_DIR / fname), img_crop)
            cv2.imwrite(str(DST_MSK_DIR / fname), msk_crop * 255)
            rows.append((fname, patch_defect_code, patch_defect_label))

# --- Save the CSV with patch labels ---
label_path = Path("./aitex_patches/patch_labels.csv")
label_df = pd.DataFrame(rows, columns=["file", "defect_code", "defect_label"])
label_df.to_csv(label_path, index=False)
print("📝  Saved", len(rows), "patch labels →", label_path)
print("✅  Нарезка патчей завершена.")


## Undersampling

In [None]:
import pandas as pd

# Load the patch labels and split them into defective / clean patches.
df = pd.read_csv('./aitex_patches/patch_labels.csv')
has_defect = df[df['defect_label'] != 'No defect']
no_defect  = df[df['defect_label'] == 'No defect']

print(f"Дефектных патчей: {len(has_defect)}")
print(f"Чистых патчей: {len(no_defect)}")

# Target ratio of clean to defective patches.
desired_ratio = 1.0  # e.g. 1:1
n_defect = len(has_defect)
# Never request more clean patches than are available.
n_no_defect = min(int(n_defect * desired_ratio), len(no_defect))

# Undersample the clean patches, then shuffle the combined table (fixed seed).
no_defect_sampled = no_defect.sample(n=n_no_defect, random_state=42)
df_balanced = pd.concat([has_defect, no_defect_sampled]).sample(frac=1, random_state=42)

balanced_label_path = './aitex_patches/patch_labels_balanced.csv'
df_balanced.to_csv(balanced_label_path, index=False)
print(f"Balanced CSV saved: {balanced_label_path}")

# --- Summary of the balanced distribution per defect label ---
summary = df_balanced['defect_label'].value_counts().reset_index()
summary.columns = ['defect_label', 'num_patches']
summary['percentage'] = (summary['num_patches'] / summary['num_patches'].sum() * 100).round(2)

print("\n=== Patch distribution (balanced) ===")
print(summary.to_string(index=False))

# Label file used by the following cells.
PATCH_LABEL_PATH = balanced_label_path

## Visualize data after processing

In [None]:
import random
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
import pandas as pd

def visualize_patches_with_masks_and_labels(
    img_dir,
    mask_dir,
    csv_path,
    min_defect_pixels=MIN_DEFECT_PIXELS,
    n_pos=6,
    n_neg=6
):
    """
    Show random patches with and without a defect as a 3-row grid.

    For every selected patch one column is drawn with:
    - row 1: the original patch (RGB), titled with status (DEFECT/CLEAN),
      the defect label and the file name;
    - row 2: the mask on its own;
    - row 3: the mask overlaid on the patch.

    A patch counts as DEFECT when its mask has at least `min_defect_pixels`
    non-zero pixels. Up to `n_pos` defective and `n_neg` clean patches are
    sampled from the files listed in `csv_path`.
    """
    labels_df = pd.read_csv(csv_path)
    img_root, mask_root = Path(img_dir), Path(mask_dir)
    with_defect, without_defect = [], []

    # Sort every listed patch into the defective / clean bucket by its mask.
    for patch_file, patch_label in zip(labels_df['file'], labels_df['defect_label']):
        patch_mask = cv2.imread(str(mask_root / patch_file), cv2.IMREAD_GRAYSCALE)
        if patch_mask is None:
            continue
        is_positive = (patch_mask > 0).sum() >= min_defect_pixels
        (with_defect if is_positive else without_defect).append((patch_file, patch_label))

    # Defective patches are sampled first, then the clean ones.
    chosen = [
        (fname, "DEFECT", label)
        for fname, label in random.sample(with_defect, min(n_pos, len(with_defect)))
    ]
    chosen += [
        (fname, "CLEAN", label)
        for fname, label in random.sample(without_defect, min(n_neg, len(without_defect)))
    ]

    n_cols = len(chosen)
    plt.figure(figsize=(n_cols * 4, 10))
    for column, (fname, status, defect_label) in enumerate(chosen, start=1):
        img = cv2.imread(str(img_root / fname))
        mask = cv2.imread(str(mask_root / fname), cv2.IMREAD_GRAYSCALE)
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Row 1: source patch.
        plt.subplot(3, n_cols, column)
        plt.imshow(img_rgb)
        plt.title(f"{status}\n{defect_label}\n{fname}", fontsize=9)
        plt.axis('off')

        # Row 2: mask only.
        plt.subplot(3, n_cols, n_cols + column)
        plt.imshow(mask, cmap='gray')
        plt.title("Mask", fontsize=9)
        plt.axis('off')

        # Row 3: mask overlay.
        plt.subplot(3, n_cols, 2 * n_cols + column)
        plt.imshow(img_rgb)
        plt.imshow(mask, cmap='Reds', alpha=0.5)
        plt.title("Overlay", fontsize=9)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# --- Run the visualization on the balanced patch set ---
visualize_patches_with_masks_and_labels(
    img_dir=DST_IMG_DIR,
    mask_dir=DST_MSK_DIR,
    csv_path=PATCH_LABEL_PATH,
    min_defect_pixels=MIN_DEFECT_PIXELS,  # key difference: same threshold as used when slicing
    n_pos=3,
    n_neg=1
)


In [None]:
import pandas as pd

# Paths to the data
PATCH_LABEL_PATH = './aitex_patches/patch_labels_balanced.csv'  # or your own final label file

# Read the labels and preview the first rows
df = pd.read_csv(PATCH_LABEL_PATH)
print(df.head())


# Create Dataset

## augmentations

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_strong_classification_augmentations(image_size=(256, 256)):
    """
    Build the strong augmentation pipeline used for training.

    The image is resized to 110% of `image_size`, randomly cropped back to
    `image_size`, passed through geometric, photometric and occlusion
    augmentations, then normalised and converted to a tensor.
    Only parameters supported by the installed albumentations are passed,
    which avoids deprecation warnings.
    """
    enlarged = (int(image_size[0]*1.1), int(image_size[1]*1.1))

    # Spatial transforms: oversize + crop, flips, rotations, affine/elastic warps.
    geometric = [
        A.Resize(*enlarged),
        A.RandomCrop(*image_size, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Affine(
            scale=(0.85, 1.15),
            translate_percent=0.15,
            rotate=(-30, 30),
            shear=(-12, 12),
            p=0.7
        ),
        A.ElasticTransform(p=0.25),  # unsupported alpha_affine argument removed
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.2),
    ]

    # Colour / intensity transforms, noise and blur.
    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
        A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=20, val_shift_limit=15, p=0.4),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
        A.GaussNoise(p=0.4),          # library-default noise parameters
        A.GaussianBlur(blur_limit=(3, 9), p=0.3),
    ]

    # Occlusion: CoarseDropout with default parameters, plus GridMask when the
    # installed albumentations provides it.
    occlusion = [A.CoarseDropout(p=0.4)]
    if hasattr(A, 'GridMask'):
        occlusion.append(A.GridMask(num_grid=(3, 7), rotate=15, p=0.3))

    # Normalisation and conversion to a tensor always come last.
    finalize = [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2()
    ]

    return A.Compose(geometric + photometric + occlusion + finalize)


def get_val_classification_augmentations(image_size=(256, 256)):
    """
    Build the deterministic pipeline for validation/test:
    resize to `image_size`, normalise, convert to a tensor.
    """
    resize = A.Resize(*image_size)
    normalize = A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    pipeline = [resize, normalize, ToTensorV2()]
    return A.Compose(pipeline)


## dataset_class

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# --- твой кастомный класс PatchDataset с copy-paste ---
import random
import cv2
from torch.utils.data import Dataset

class PatchDataset(Dataset):
    """
    Patch classification dataset with class-aware copy-paste augmentation.

    Each item is ``(img_tensor, class_idx)`` where ``class_idx`` indexes
    ``label2idx`` (label name -> index).

    Args:
        df: table with the columns ``file``, ``defect_code`` and ``defect_label``.
        img_dir: directory that contains the patch images.
        transform: callable invoked as ``transform(image=img)['image']``.
        copy_paste_prob: probability of replacing a clean patch (code 0) with
            content from a defective donor patch.
        class_copy_paste: if True, donors are drawn preferentially from defect
            classes that have fewer samples than the median class.
        random_state: seed of the internal ``RandomState``.
        label2idx: optional label-name -> index mapping. Pass the same mapping
            to every split so that class indices agree between train/val/test;
            when omitted it is built from the labels present in ``df``.

    NOTE(review): with ``num_workers > 0`` every DataLoader worker gets its own
    copy of ``self.rng`` in the same state — confirm whether per-worker seeding
    is needed.
    """
    def __init__(
        self,
        df: pd.DataFrame,
        img_dir: str,
        transform,
        copy_paste_prob: float = 0.3,
        class_copy_paste: bool = True,
        random_state: int = 42,
        label2idx: dict = None
    ):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.copy_paste_prob = copy_paste_prob
        self.class_copy_paste = class_copy_paste
        self.rng = np.random.RandomState(random_state)

        # label name <-> class index
        if label2idx is None:
            labels = sorted(self.df['defect_label'].unique())
            label2idx = {lbl: i for i, lbl in enumerate(labels)}
        self.label2idx = dict(label2idx)
        self.idx2label = {i: lbl for lbl, i in self.label2idx.items()}

        codes = self.df['defect_code'].values
        files = self.df['file'].values
        # defect code -> label name; needed to label a patch after copy-paste.
        self.code2label = {
            int(c): lbl for c, lbl in zip(codes, self.df['defect_label'].values)
        }

        # Pool of defective files and each class' deficit relative to the median
        # class size; the deficits become the donor-class sampling probabilities.
        defect_mask = codes != 0
        defect_codes = codes[defect_mask]
        defect_files = files[defect_mask]

        cnt = pd.Series(defect_codes).value_counts().to_dict()
        if cnt:
            median = int(np.median(list(cnt.values()))) or 1
            deficit = {c: max(0, median - n) for c, n in cnt.items()}
        else:
            # No defective rows at all: the median of an empty list is NaN.
            deficit = {}
        total_def = sum(deficit.values())
        self.class_probs = {c: d/total_def for c, d in deficit.items()} if total_def>0 else {}
        self.defect_pool = list(zip(defect_files, defect_codes))

        # Per-class view of the pool so a donor lookup does not rescan all files.
        self._pool_by_code = {}
        for f, c in self.defect_pool:
            self._pool_by_code.setdefault(int(c), []).append((f, c))

    def __len__(self):
        return len(self.df)

    def _read_rgb(self, fname):
        """Read ``img_dir / fname`` as an RGB array; raise if it cannot be decoded."""
        path = self.img_dir / fname
        bgr = cv2.imread(str(path))
        if bgr is None:
            # cv2.imread signals failure with None rather than an exception.
            raise FileNotFoundError(f"Cannot read patch image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def _pick_donor(self):
        """Return ``(file, code)`` of a randomly chosen defective donor patch."""
        if self.class_copy_paste and self.class_probs:
            codes, probs = zip(*self.class_probs.items())
            sel_code = int(self.rng.choice(codes, p=probs))
            candidates = self._pool_by_code[sel_code]
        else:
            candidates = self.defect_pool
        donor_file, donor_code = candidates[self.rng.randint(len(candidates))]
        return donor_file, int(donor_code)

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        code = int(row['defect_code'])
        label_name = row['defect_label']
        img = self._read_rgb(row['file'])

        # Copy-paste for clean patches.
        if code == 0 and self.defect_pool and self.rng.rand() < self.copy_paste_prob:
            donor_file, code = self._pick_donor()
            donor_img = self._read_rgb(donor_file)
            # Simple masked paste: take the donor wherever it differs from the
            # clean patch. (Where the two agree the pixel is identical anyway,
            # so the result has the same content as the donor patch.)
            changed = (donor_img != img).any(axis=2)
            img = np.where(changed[..., None], donor_img, img)
            # The sample now shows the donor's defect, so it must carry the
            # donor's label instead of the original "clean" label.
            label_name = self.code2label[code]

        # Augmentations and normalisation.
        img = self.transform(image=img)['image']
        label = self.label2idx[label_name]
        return img, label



## dataset creation helpers

In [None]:
def upsample_train_to_median(train_df, label_col='defect_code', random_state=42):
    """
    Upsample every class that is smaller than the median class size.

    Rows of under-represented classes are sampled with replacement until the
    class reaches the median count. If anything was added, the result is
    shuffled and re-indexed; otherwise `train_df` is returned unchanged.
    """
    class_sizes = train_df[label_col].value_counts()
    target = int(class_sizes.median())

    # One batch of extra rows per class below the target size.
    extra_rows = [
        train_df[train_df[label_col] == class_value].sample(
            n=target - size, replace=True, random_state=random_state
        )
        for class_value, size in class_sizes.items()
        if size < target
    ]
    if not extra_rows:
        return train_df

    combined = pd.concat([train_df] + extra_rows)
    return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)


In [None]:

# --- prepare_datasets для классификации патчей ---
def prepare_datasets(
    patch_label_path: str,
    img_dir: str,
    test_size: float = 0.05,
    val_size: float = 0.1,
    batch_size: int = 16,
    num_workers: int = 4,
    random_state: int = 42,
    image_size: tuple = (224, 224),
    train_aug_fn=None,
    val_aug_fn=None,
    copy_paste_prob: float = 0.3
) -> tuple:
    """
    Build the train / val / test ``PatchDataset`` objects.

    Args:
        patch_label_path: CSV with the columns ``file``, ``defect_code``, ``defect_label``.
        img_dir: directory with the patch images.
        test_size: fraction of all patches held out for the test set.
        val_size: fraction of the remaining patches used for validation.
        batch_size, num_workers: accepted for call-site compatibility; not used here.
        random_state: seed for the splits and for the train copy-paste RNG.
        image_size: size passed to the default augmentation factories.
        train_aug_fn, val_aug_fn: ready-made transforms; when omitted the default
            strong / validation pipelines are created for ``image_size``.
        copy_paste_prob: copy-paste probability for the train set (0 for val/test).

    Returns:
        ``(train_ds, val_ds, test_ds)``; all three share one label -> index mapping.

    NOTE(review): the split is done per patch. The slicing cell uses a stride
    smaller than the patch size, so overlapping patches of one source image can
    end up in different splits — consider a group split by source image.
    """
    df = pd.read_csv(patch_label_path)

    # Stratified split. The frames themselves are split (not just the file
    # names), so the `stratify` labels are always row-aligned with the data.
    trainval_df, test_df = train_test_split(
        df, test_size=test_size,
        stratify=df['defect_code'], random_state=random_state
    )
    train_df, val_df = train_test_split(
        trainval_df, test_size=val_size,
        stratify=trainval_df['defect_code'], random_state=random_state
    )
    # Restore the original row order inside each split, then re-index.
    train_df = train_df.sort_index().reset_index(drop=True)
    val_df = val_df.sort_index().reset_index(drop=True)
    test_df = test_df.sort_index().reset_index(drop=True)

    # Optional upsampling to the median class size (currently disabled).
    # train_df = upsample_train_to_median(train_df, label_col='defect_code', random_state=random_state)
    print("Train class distribution")
    print(train_df['defect_code'].value_counts())

    # transforms
    train_transform = train_aug_fn or get_strong_classification_augmentations(image_size)
    val_transform = val_aug_fn or get_val_classification_augmentations(image_size)

    # datasets
    train_ds = PatchDataset(
        train_df, img_dir,
        transform=train_transform,
        copy_paste_prob=copy_paste_prob,
        class_copy_paste=True,
        random_state=random_state
    )
    val_ds = PatchDataset(
        val_df, img_dir,
        transform=val_transform,
        copy_paste_prob=0.0
    )
    test_ds = PatchDataset(
        test_df, img_dir,
        transform=val_transform,
        copy_paste_prob=0.0
    )

    # Each PatchDataset derives its mapping from the labels present in its own
    # split; a class missing from one split would shift the indices. Install one
    # mapping built from the full label file on all three datasets.
    shared_label2idx = {
        lbl: i for i, lbl in enumerate(sorted(df['defect_label'].unique()))
    }
    for ds in (train_ds, val_ds, test_ds):
        ds.label2idx = dict(shared_label2idx)
        ds.idx2label = {i: lbl for lbl, i in shared_label2idx.items()}

    return train_ds, val_ds, test_ds


## dataloader creation helper

In [None]:

def get_classification_dataloaders(
    train_ds, val_ds, test_ds,
    batch_size: int = 16,
    num_workers: int = 4
) -> tuple:
    """
    Wrap the three datasets into DataLoaders.

    The train loader shuffles and drops the last incomplete batch; the
    validation and test loaders keep the dataset order and every sample.
    All loaders use pinned memory.
    """
    # Options common to all three loaders.
    common = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **common)
    val_loader, test_loader = (
        DataLoader(eval_ds, shuffle=False, **common) for eval_ds in (val_ds, test_ds)
    )
    return train_loader, val_loader, test_loader



## Initialize dataloaders

In [None]:

# === Example usage: build the datasets and loaders from the balanced patch set ===
PATCH_LABEL_PATH = './aitex_patches/patch_labels_balanced.csv'
IMG_DIR = './aitex_patches/images'
MSK_DIR = './aitex_patches/masks'

train_ds, val_ds, test_ds = prepare_datasets(
    patch_label_path=PATCH_LABEL_PATH,
    img_dir=IMG_DIR,
    # msk_dir=MSK_DIR,
    test_size=0.05,
    val_size=0.1,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    random_state=42,
    image_size=(PATCH_H, PATCH_W),
    train_aug_fn=get_strong_classification_augmentations((PATCH_H, PATCH_W)),
    val_aug_fn=get_val_classification_augmentations((PATCH_H, PATCH_W)),
    copy_paste_prob=0.8
)

# Show the label <-> index mapping of the train set.
print("label2idx:", train_ds.label2idx)
print("idx2label:", train_ds.idx2label)
print("Всего классов:", len(train_ds.label2idx))

train_loader, val_loader, test_loader = get_classification_dataloaders(
    train_ds, val_ds, test_ds,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers']
)


In [None]:

import numpy as np
import torch
from collections import Counter
import torch.nn as nn

def get_weighted_loss(train_ds, smoothing=0.1, device='cuda'):
    """
    Build a class-weighted ``CrossEntropyLoss`` with label smoothing.

    Weights are inverse class frequencies, ``total / (num_classes * count)``,
    rescaled so that they sum to ``num_classes``. A class without samples gets
    weight 0 instead of raising a division error.

    Args:
        train_ds: dataset exposing ``label2idx``. When it also has a ``df`` with a
            ``defect_label`` column, labels are counted from that table without
            loading any image; otherwise the dataset is iterated as
            ``(sample, label)`` pairs.
        smoothing: value passed as ``label_smoothing``.
        device: device on which the weight tensor is placed.

    Returns:
        The configured ``nn.CrossEntropyLoss``.
    """
    num_classes = len(train_ds.label2idx)

    # Count labels from the label table when available: iterating the dataset
    # would decode and augment every image just to read its label. These are
    # the stored labels (stochastic augmentation is not reflected).
    label_table = getattr(train_ds, 'df', None)
    if label_table is not None and 'defect_label' in label_table.columns:
        train_labels = label_table['defect_label'].map(train_ds.label2idx).tolist()
    else:
        train_labels = [label for _, label in train_ds]

    class_counts = Counter(train_labels)
    total_samples = len(train_labels)

    # Inverse-frequency weights; classes with no samples keep weight 0.
    class_weights = np.zeros(num_classes)
    for cls_idx in range(num_classes):
        cls_count = class_counts.get(cls_idx, 0)
        if cls_count > 0:
            class_weights[cls_idx] = total_samples / (num_classes * cls_count)

    # Normalise so that the weights sum to num_classes.
    weight_sum = class_weights.sum()
    if weight_sum > 0:
        class_weights = class_weights / weight_sum * num_classes
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

    print(f"Class weights: {class_weights}")

    loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=smoothing)
    return loss_fn

import pandas as pd

def print_class_distribution_from_dataset(dataset, label_names=None, title="Train set"):
    """
    Print a table with the number and share of samples per class.

    The dataset is read item by item. If it has an `idx2label` attribute, the
    second element of each item is translated through it; otherwise the last
    element of the item is used as the class. Items that fail to load are
    counted as "unknown".

    - dataset: a Dataset instance (classification or segmentation style items)
    - label_names: optional dict used to replace class values with display names
    - title: heading printed above the table
    """
    def _class_of(position):
        # Any failure while reading or decoding an item falls back to "unknown".
        try:
            if hasattr(dataset, "idx2label"):
                # (img, label) or (img, mask, label): index -> class string
                _, class_idx = dataset[position][:2]
                return dataset.idx2label[class_idx]
            # e.g. (img, mask, code, label): the last element is the class
            *_, trailing = dataset[position]
            return trailing
        except Exception:
            return "unknown"

    collected = [_class_of(position) for position in range(len(dataset))]

    table = pd.DataFrame({'class_label': collected})['class_label'].value_counts().reset_index()
    table.columns = ['class_label', 'num_patches']
    share = table['num_patches'] / table['num_patches'].sum() * 100
    table['percentage'] = share.round(2)
    if label_names:
        table['class_label'] = table['class_label'].map(label_names).fillna(table['class_label'])
    print(f"\n=== Patch distribution in {title} ===")
    print(table.to_string(index=False))


print_class_distribution_from_dataset(train_ds, label_names=DEFECT_LABELS, title="Train set")

## visualize dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

def visualize_classification_dataset_grid(dataset, num_samples=6, label_names=None):
    """
    Show random samples of a classification dataset in a single row.

    - One column per sample (at most `num_samples`, capped at the dataset size).
    - The title shows the class code and a readable class name. For
      `(img, code)` samples the name is resolved through `dataset.idx2label`
      when the dataset has one; `label_names` (if given) can override it,
      looked up by `str(code)`, by `code`, or by the label itself.
    - Tensor images are de-normalised with the ImageNet mean/std before display.

    Prints a message and returns without plotting when the dataset is empty.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("Dataset is empty - nothing to visualize.")
        return

    idx2label = getattr(dataset, "idx2label", None) or {}

    def _label_for(code):
        # Translate a class index into its label name; fall back to the raw code.
        try:
            return idx2label[int(code)]
        except (KeyError, TypeError, ValueError):
            return str(code)

    indices = np.random.choice(len(dataset), num_samples, replace=False)
    plt.figure(figsize=(num_samples * 4, 4))

    for col, idx in enumerate(indices):
        sample = dataset[idx]
        # Expected: (img, class_code, class_label) or (img, class_code)
        if len(sample) == 3:
            image, code, label = sample
        elif len(sample) == 2:
            image, code = sample
            label = _label_for(code)
        else:
            image = sample[0]
            code = None
            label = "unknown"

        # Readable class name: explicit `label_names` first, then the label.
        class_name = str(label)
        if label_names is not None:
            for key in (str(code), code, label):
                try:
                    if key in label_names:
                        class_name = label_names[key]
                        break
                except TypeError:
                    # Unhashable key (e.g. an array-like code): try the next one.
                    continue

        title = f"Code: {code}\nLabel: {class_name}"

        # De-normalise tensors (channels-first -> channels-last, undo mean/std).
        if isinstance(image, torch.Tensor):
            rgb = image[:3].permute(1, 2, 0).cpu().numpy()
            rgb = (rgb * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
        else:
            rgb = image

        plt.subplot(1, num_samples, col + 1)
        plt.imshow(rgb)
        plt.title(title, fontsize=11)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

visualize_classification_dataset_grid(train_ds, num_samples=8, label_names=DEFECT_LABELS)

# Training

## libraries

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
from timm.data.mixup import Mixup

# --- MixUp + CutMix configuration ---
mixup_fn = Mixup(
    mixup_alpha=0.4,        # MixUp strength — a typical value
    cutmix_alpha=1.0,       # CutMix is enabled as well (typically 0.5-1.0; tune as needed)
    label_smoothing=0.1,    # label smoothing for the mixed targets; if the loss already smooths labels, consider lowering it
    num_classes=NUM_CLASSES
)


#### train_epoch

In [None]:
import numpy as np
import torch
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from timm.utils import ModelEmaV2


def train_epoch(
    model, train_loader, optimizer, loss_fn, scaler,
    device, epoch, model_name, num_classes,
    mixup_fn=None, ema=None
):
    """
    Run one training epoch with optional MixUp/CutMix and EMA updates.

    Args:
        model: network to train; switched to train mode here.
        train_loader: iterable of batches; only the first two elements of
            each batch (images, labels) are used.
        optimizer: optimizer stepped once per batch through ``scaler``.
        loss_fn: callable ``loss_fn(outputs, labels)``; it receives soft
            (2-D) labels for batches that went through ``mixup_fn`` and the
            loader's original labels otherwise.
        scaler: AMP gradient scaler (scale / unscale_ / step / update).
        device: torch.device; autocast is enabled only when it is CUDA.
        epoch: epoch index, used for the progress bar and the log line.
        model_name: label used in the log line.
        num_classes: forwarded to ``compute_classification_metrics``.
        mixup_fn: optional callable ``(images, labels) -> (images, labels)``.
        ema: optional object with ``update(model)``, called after every step.

    Returns:
        dict produced by ``compute_classification_metrics`` with an extra
        ``'loss'`` entry holding the mean per-batch training loss.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    running_loss = 0.0
    all_preds, all_probs, all_targets = [], [], []

    for batch in tqdm(train_loader, desc=f"Train {epoch}"):
        images, labels = batch[:2]
        images = images.to(device)
        labels = labels.to(device)

        # MIXUP / CUTMIX
        # timm's Mixup asserts an even batch size, so an odd-sized batch
        # (typically the last one of the epoch) is trained without mixing
        # instead of aborting the whole run.
        if mixup_fn is not None and images.size(0) % 2 == 0:
            images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()
        with autocast(enabled=(device.type == 'cuda')):
            outputs = model(images)
            loss = loss_fn(outputs, labels)

        # backward + step; gradients are unscaled first so that clipping
        # operates on their true magnitudes
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        # EMA update
        if ema is not None:
            ema.update(model)

        running_loss += loss.item()

        # Accumulate predictions. Logits may be half precision under
        # autocast, so softmax is computed in float32.
        probs = torch.softmax(outputs.detach().float(), dim=1).cpu().numpy()
        all_probs.append(probs)
        all_preds.append(np.argmax(probs, axis=1))
        if labels.ndim > 1:
            # soft (mixed) targets of any float dtype: score against the dominant class
            all_targets.append(labels.argmax(dim=1).cpu().numpy())
        else:
            all_targets.append(labels.detach().cpu().numpy())

    if not all_targets:
        raise ValueError("train_loader produced no batches")

    y_true = np.concatenate(all_targets)
    y_pred = np.concatenate(all_preds)
    y_prob = np.concatenate(all_probs)

    metrics = compute_classification_metrics(y_true, y_pred, y_prob, num_classes)
    metrics['loss'] = running_loss / len(train_loader)
    print(f"[{model_name}] Train epoch {epoch}: {metrics}")
    return metrics


#### validate_epoch

In [None]:


def validate_epoch(
    model, val_loader, device, epoch, model_name, num_classes
):
    """
    Evaluate ``model`` over ``val_loader`` once, without gradient tracking.

    The model is put into eval mode. The reported loss is the mean per-batch
    value of a plain (unweighted) ``torch.nn.CrossEntropyLoss``.

    Returns:
        dict from ``compute_classification_metrics`` plus a ``'loss'`` entry.
    """
    criterion = torch.nn.CrossEntropyLoss()
    loss_sum = 0.0
    pred_chunks, prob_chunks, target_chunks = [], [], []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Val  {epoch}"):
            # only the first two batch elements (images, labels) are consumed
            inputs, targets = batch[:2]
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()

            batch_probs = torch.softmax(logits.cpu(), dim=1).numpy()
            pred_chunks.append(np.argmax(batch_probs, axis=1))
            prob_chunks.append(batch_probs)
            target_chunks.append(targets.cpu().numpy())

    metrics = compute_classification_metrics(
        np.concatenate(target_chunks),
        np.concatenate(pred_chunks),
        np.concatenate(prob_chunks),
        num_classes,
    )
    metrics['loss'] = loss_sum / len(val_loader)
    print(f"[{model_name}] Val epoch {epoch}: {metrics}")
    return metrics



### train_loop

In [None]:

def train_loop(
    model, optimizer, scheduler,
    train_loader, val_loader,
    loss_fn, device, scaler,
    num_classes,
    num_epochs=50, early_stop_patience=5,
    model_name='model', mixup_fn=None,
    ema_decay: float = 0.9999
):
    """
    Main training loop with EMA, LR scheduling and early stopping on macro-F1.

    Each epoch trains ``model`` via ``train_epoch`` and validates the EMA copy
    (``ema.module``) via ``validate_epoch``. Whenever the validation macro-F1
    improves, the raw model weights are written to ``best_{model_name}.pth``
    and the EMA wrapper state to ``ema_{model_name}.pth``. Training stops
    after ``early_stop_patience`` consecutive epochs without improvement.

    The scheduler is stepped once per epoch: ``ReduceLROnPlateau`` receives
    the validation macro-F1, any other scheduler is stepped without
    arguments (for those, the positional argument of ``step`` is the
    deprecated ``epoch``, not a metric). ``scheduler`` may be ``None``.
    NOTE(review): a scheduler whose length is expressed in optimizer steps
    (batches) still only advances once per epoch here — confirm that its
    configuration matches this cadence.

    Returns:
        (train_history, val_history): lists of per-epoch metric dicts.
    """
    # EMA initialisation
    ema = ModelEmaV2(model, decay=ema_decay, device=device)
    best_f1 = 0.0
    counter = 0
    train_history, val_history = [], []

    for epoch in range(num_epochs):
        # train
        train_metrics = train_epoch(
            model, train_loader, optimizer, loss_fn, scaler,
            device, epoch, model_name, num_classes,
            mixup_fn=mixup_fn, ema=ema
        )
        # validate the EMA model
        val_metrics = validate_epoch(
            ema.module, val_loader, device, epoch, model_name, num_classes
        )
        val_f1 = val_metrics['f1_macro']

        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_f1)
            else:
                scheduler.step()

        train_history.append(train_metrics)
        val_history.append(val_metrics)

        # Logged before the early-stop check so the final epoch is reported too.
        print(f"Epoch {epoch}: Train F1(macro)={train_metrics['f1_macro']:.4f}  Val F1(macro)={val_f1:.4f}")

        # early stopping + checkpointing of the best epoch
        if val_f1 > best_f1:
            best_f1 = val_f1
            counter = 0
            torch.save(model.state_dict(), f"best_{model_name}.pth")
            torch.save(ema.state_dict(), f"ema_{model_name}.pth")
        else:
            counter += 1
            if counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch}")
                break

    print(f"Best Val F1(macro): {best_f1:.4f}")
    return train_history, val_history


### Plot metrics

In [None]:
import math
import matplotlib.pyplot as plt

def plot_metrics(train_history, val_history, model_name, save_path=None):
    """
    Plot train vs. validation curves, one subplot per tracked metric.

    Args:
        train_history: list of per-epoch metric dicts from training.
        val_history: list of per-epoch metric dicts from validation.
        model_name: prefix used in every subplot title.
        save_path: if truthy, the figure is also written to this path.

    A metric missing from an epoch's dict is plotted as 0.
    """
    metric_names = ['loss', 'accuracy', 'precision_macro', 'recall_macro', 'f1_macro', 'top1', 'top3']
    grid_cols = 3  # 4 also works and gives a wider layout
    grid_rows = math.ceil(len(metric_names) / grid_cols)

    plt.figure(figsize=(grid_cols * 6, grid_rows * 4))
    for position, key in enumerate(metric_names, start=1):
        train_curve = [epoch_metrics.get(key, 0) for epoch_metrics in train_history]
        val_curve = [epoch_metrics.get(key, 0) for epoch_metrics in val_history]

        plt.subplot(grid_rows, grid_cols, position)
        plt.plot(train_curve, label=f"Train {key}")
        plt.plot(val_curve, label=f"Val {key}")
        plt.title(f"{model_name}: {key}")
        plt.xlabel("Epoch")
        plt.legend()
        plt.grid()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


In [None]:


def train_single_model(
    model_name, model_fn,
    train_loader, val_loader,
    config, num_classes,
    mixup_fn=None
):
    """
    Build one model with its optimizer, LR schedule, loss and AMP scaler,
    then hand everything to ``train_loop``.

    Args:
        model_name: name forwarded to ``train_loop``.
        model_fn: zero-argument factory returning the model.
        train_loader, val_loader: data loaders forwarded to ``train_loop``.
        config: dict with 'learning_rate', 'weight_decay', 'num_epochs' and
            'early_stop_patience'; optional 'smoothing' (default 0.1) and
            'ema_decay' (default 0.9999).
        num_classes: forwarded to ``train_loop``.
        mixup_fn: optional batch augmentation forwarded to ``train_loop``.

    Returns:
        whatever ``train_loop`` returns.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = model_fn().to(device)

    peak_lr = config['learning_rate']
    opt = torch.optim.AdamW(
        net.parameters(), lr=peak_lr, weight_decay=config['weight_decay']
    )

    # NOTE(review): total_steps is counted in batches (loader length x epochs);
    # confirm the training loop advances this scheduler once per batch.
    one_cycle = torch.optim.lr_scheduler.OneCycleLR(
        opt,
        max_lr=peak_lr,
        total_steps=len(train_loader) * config['num_epochs'],
        pct_start=0.1,
        anneal_strategy='cos',
        div_factor=25.0,
    )

    # Weighted loss with label smoothing.
    # NOTE(review): built from the notebook-level `train_ds`, not from
    # `train_loader` — presumably the same dataset; verify.
    criterion = get_weighted_loss(train_ds, smoothing=config.get('smoothing', 0.1), device=device)
    amp_scaler = GradScaler(enabled=(device.type == 'cuda'))

    return train_loop(
        net, opt, one_cycle,
        train_loader, val_loader,
        criterion, device, amp_scaler,
        num_classes,
        num_epochs=config['num_epochs'],
        early_stop_patience=config['early_stop_patience'],
        model_name=model_name,
        mixup_fn=mixup_fn,
        ema_decay=config.get('ema_decay', 0.9999),
    )


## train_model

## run training 

In [None]:
import gc

def train_all_models(models_dict, train_loader, val_loader, config, num_classes, mixup_fn=None):
    """
    Train every model factory in ``models_dict`` in turn and plot its curves.

    For each ``name -> factory`` entry the model is trained with
    ``train_single_model`` and its metric curves are plotted and saved to
    ``metrics_{name}.png``. The CUDA cache is emptied and the garbage
    collector is run before and after each model.
    """
    for model_key, build_model in models_dict.items():
        print(f"==== Training model: {model_key} ====")
        torch.cuda.empty_cache()
        gc.collect()

        histories = train_single_model(
            model_key, build_model, train_loader, val_loader, config, num_classes, mixup_fn=mixup_fn
        )
        train_hist, val_hist = histories
        plot_metrics(train_hist, val_hist, model_name=model_key, save_path=f"metrics_{model_key}.png")

        torch.cuda.empty_cache()
        gc.collect()

train_all_models(MODELS, train_loader, val_loader, CONFIG, NUM_CLASSES, mixup_fn=mixup_fn)

## Run training

# Test

### tta_predict

In [None]:
import torch

def tta_predict_classification(model, images):
    """
    Test-time augmentation for classification: original + horizontal flip.

    The model is switched to eval mode and run, without gradient tracking,
    on ``images`` and on ``images`` flipped along dimension 3. The two
    outputs are averaged element-wise; no "un-flipping" is needed because
    the outputs are per-image class scores, not spatial maps.

    Returns:
        the mean of the two model outputs.
    """
    model.eval()
    with torch.no_grad():
        logits_plain = model(images)
        logits_mirrored = model(images.flip(3))
        averaged = (logits_plain + logits_mirrored) / 2
    return averaged


### visualize_pred

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_classification_predictions_tta(
    model,
    dataloader,
    device,
    label_names=None,
    num_samples=10,
    imagenet_norm=True,
    model_title=None
):
    """
    Show the first ``num_samples`` images yielded by ``dataloader`` with
    their ground-truth and TTA-predicted labels as titles.

    Args:
        model: classifier; moved to ``device`` and put into eval mode.
        dataloader: iterable of batches; only the first two elements of each
            batch (images, labels) are used.
        device: device used for inference.
        label_names: optional mapping ``class index -> display name``;
            indices missing from it are shown as numbers.
        num_samples: maximum number of images to display.
        imagenet_norm: if True, undo ImageNet mean/std normalisation before
            display and clip to [0, 1].
        model_title: optional figure-level title.
    """
    model = model.to(device)
    model.eval()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    def _display_name(class_idx):
        # plain int so the lookup works with ordinary {int: str} mappings
        class_idx = int(class_idx)
        if label_names:
            return label_names.get(class_idx, str(class_idx))
        return str(class_idx)

    images_list, titles = [], []
    with torch.no_grad():
        for batch in dataloader:
            images, labels = batch[:2]
            images = images.to(device)
            labels = labels.cpu().numpy()

            # TTA prediction; argmax over logits equals argmax over softmax
            logits = tta_predict_classification(model, images)
            preds = logits.argmax(dim=1).cpu().numpy()

            remaining = num_samples - len(images_list)
            for i in range(min(remaining, images.size(0))):
                img = images[i].cpu().permute(1, 2, 0).numpy()
                if imagenet_norm:
                    img = (img * std + mean).clip(0, 1)
                images_list.append(img)
                titles.append(
                    f"GT: {_display_name(labels[i])}\nPred: {_display_name(preds[i])}"
                )
            if len(images_list) >= num_samples:
                break

    # Plot a single row; slots beyond the collected images stay empty
    plt.figure(figsize=(num_samples * 3, 4))
    if model_title is not None:
        plt.suptitle(model_title, fontsize=16, y=1.08)
    for col, (img, title) in enumerate(zip(images_list, titles)):
        plt.subplot(1, num_samples, col + 1)
        plt.imshow(img)
        plt.title(title, fontsize=10)
        plt.axis("off")
    plt.tight_layout()
    plt.show()


In [None]:
import gc

idx2label = train_ds.idx2label  # or your own {int: str} mapping

# Visual spot-check of every trained model on the test loader.
for model_name in MODELS:
    checkpoint_path = f'best_{model_name}.pth'
    print(f"\n--- Model: {model_name} ---")
    # Bound before `try` so the cleanup in `finally` cannot raise NameError
    # (and hide the reported error) when building the model itself fails.
    test_model = None
    try:
        test_model = MODELS[model_name]().to(device)
        test_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        test_model.eval()

        visualize_classification_predictions_tta(
            test_model,
            dataloader=test_loader,
            device=device,
            label_names=idx2label,
            num_samples=5,
            model_title=f"Model: {model_name}"
        )
    except Exception as e:
        # best effort: report the failure and continue with the next model
        print(f"❌ Could not evaluate {model_name}: {e}")
    finally:
        del test_model
        torch.cuda.empty_cache()
        gc.collect()


## Sanity checks

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

def classification_sanity_check(
    train_ds, MODELS, device, target_label=None, num_epochs=20, N=8, idx2label=None
):
    """
    Sanity check: every model should quickly overfit a single patch.

    Args:
        train_ds: dataset whose items unpack as ``(image, label)``.
        MODELS: mapping ``name -> zero-argument model factory``.
        device: device to train on.
        target_label: class name to look for (e.g. "No defect" or
            "Broken end"); ``None`` takes the first patch of the dataset.
            Matching by name requires ``idx2label``.
        num_epochs: number of passes over the N-copy batch.
        N: number of copies of the patch in the batch.
        idx2label: optional mapping ``label -> class name`` for matching
            and display.

    Raises:
        RuntimeError: if no patch satisfies the selection criterion.
    """
    # 1. Find a patch of the requested class. The matching sample itself is
    #    kept (no second dataset access), so the patch that was checked is
    #    exactly the one that is trained on.
    found = None
    for i in range(len(train_ds)):
        img, label = train_ds[i]
        if (target_label is None) or (idx2label and idx2label[label] == target_label):
            found = (img, label)
            break
    if found is None:
        raise RuntimeError("Не найден подходящий патч для sanity check!")
    img, label = found
    class_name = idx2label[label] if idx2label else label

    print(f"Sanity check: патч класса '{class_name}' (индекс {label})")

    # 2. Dataset made of N copies of that patch, served as one batch
    single_ds = [(img, label)] * N
    single_loader = DataLoader(single_ds, batch_size=N, shuffle=True)

    # 3. Check every model
    for model_name, model_fn in MODELS.items():
        print(f"\n==== Sanity check: {model_name} ====")
        model = model_fn().to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        loss_fn = torch.nn.CrossEntropyLoss()
        losses, accs = [], []

        for epoch in range(num_epochs):
            model.train()
            for img_batch, label_batch in single_loader:
                img_batch = img_batch.to(device)
                label_batch = label_batch.to(device)
                optimizer.zero_grad()
                output = model(img_batch)
                loss = loss_fn(output, label_batch)
                loss.backward()
                optimizer.step()

            # Evaluate on the same batch: accuracy and loss dynamics
            model.eval()
            with torch.no_grad():
                output = model(img_batch)
                probs = torch.softmax(output, dim=1)
                pred_class = probs.argmax(dim=1).cpu().numpy()
                gt_class = label_batch.cpu().numpy()
                acc = (pred_class == gt_class).mean()
            losses.append(loss.item())
            accs.append(acc)
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Acc = {acc:.3f}, GT: {idx2label[gt_class[0]] if idx2label else gt_class[0]}, Pred: {idx2label[pred_class[0]] if idx2label else pred_class[0]}")

        # Loss / accuracy curves
        fig, axs = plt.subplots(1, 2, figsize=(10,4))
        axs[0].plot(losses, label="Loss")
        axs[0].set_title(f"{model_name} Loss")
        axs[1].plot(accs, label="Acc")
        axs[1].set_title(f"{model_name} Accuracy")
        for ax in axs: ax.grid(); ax.set_xlabel('Epoch')
        plt.suptitle(f"Sanity Check: {model_name} — класс '{class_name}'")
        plt.show()

        # Show the patch itself (ImageNet normalisation undone)
        img_np = img[:3].permute(1,2,0).cpu().numpy() if isinstance(img, torch.Tensor) else img
        img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0,1)
        plt.figure(figsize=(3,3))
        plt.imshow(img_np)
        plt.title(f"Class: {class_name}")
        plt.axis("off")
        plt.show()

        # Release this model before the next one is built so two models
        # never sit on the device at the same time.
        del model, optimizer
        torch.cuda.empty_cache()

# === Example usage ===
# Any class (the first patch encountered):
classification_sanity_check(
    train_ds, MODELS, device,
    target_label=None,     # or e.g. "No defect"
    num_epochs=20, N=8,
    idx2label=train_ds.idx2label
)

# A "clean" (defect-free) patch, if needed:
classification_sanity_check(
    train_ds, MODELS, device,
    target_label="No defect",
    num_epochs=20, N=8,
    idx2label=train_ds.idx2label
)


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

def classification_sanity_check(train_ds, MODELS, device, target_label=None, num_epochs=20, N=8, idx2label=None):
    """
    Sanity check: every model must be able to fit one and the same patch.

    Args:
        train_ds: dataset whose items unpack as ``(image, label)``.
        MODELS: mapping ``name -> zero-argument model factory``.
        device: device to train on.
        target_label: if given, the first patch whose class name equals it
            is used (e.g. 'No defect' for a clean patch, a defect name for a
            defective one); ``None`` takes the first patch of the dataset.
        num_epochs: number of passes over the N-copy batch.
        N: number of copies of the patch in the batch.
        idx2label: optional mapping ``label -> class name``; defaults to
            ``train_ds.idx2label``. Accepted so that calls written for the
            earlier definition of this function (which this one replaces)
            keep working.

    Raises:
        RuntimeError: if no patch satisfies the selection criterion.
    """
    if idx2label is None:
        idx2label = train_ds.idx2label

    # 1. Find a patch matching the criterion; the matching sample itself is
    #    kept instead of being read from the dataset a second time.
    found = None
    for i in range(len(train_ds)):
        img, label = train_ds[i]
        if target_label is None or idx2label[label] == target_label:
            found = (img, label)
            break
    if found is None:
        raise RuntimeError("Не найден подходящий патч для sanity check!")
    img, label = found

    # 2. Dataset made of N copies of that patch, served as one batch
    single_ds = [(img, label)] * N
    single_loader = DataLoader(single_ds, batch_size=N, shuffle=True)

    for model_name, model_fn in MODELS.items():
        print(f"\n==== Sanity check: {model_name} ====")
        model = model_fn().to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        loss_fn = torch.nn.CrossEntropyLoss()

        for epoch in range(num_epochs):
            model.train()
            for img_batch, label_batch in single_loader:
                img_batch = img_batch.to(device)
                label_batch = label_batch.to(device)
                optimizer.zero_grad()
                output = model(img_batch)
                loss = loss_fn(output, label_batch)
                loss.backward()
                optimizer.step()

            # Evaluate on the same batch
            model.eval()
            with torch.no_grad():
                output = model(img_batch)
                probs = torch.softmax(output, dim=1)
                pred_class = probs.argmax(dim=1).cpu().numpy()
                gt_class = label_batch.cpu().numpy()
                acc = (pred_class == gt_class).mean()
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Acc = {acc:.3f}, GT: {idx2label[gt_class[0]]}, Pred: {idx2label[pred_class[0]]}")

        # Show the patch (ImageNet normalisation undone)
        img_np = img[:3].permute(1,2,0).cpu().numpy() if isinstance(img, torch.Tensor) else img
        img_np = (img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0,1)
        plt.figure(figsize=(4, 4))
        plt.imshow(img_np)
        plt.title(f"Sanity Check: {model_name}\nClass: {idx2label[label]}")
        plt.axis("off")
        plt.show()

        # Release this model before the next one is built so two models
        # never sit on the device at the same time.
        del model, optimizer
        torch.cuda.empty_cache()


In [None]:
classification_sanity_check(train_ds, MODELS, device, target_label='Broken end', num_epochs=20, N=8)
# (or any class other than 'No defect', if the goal is to check specifically on a defect patch)


In [None]:
classification_sanity_check(train_ds, MODELS, device, target_label='No defect', num_epochs=20, N=8)


# Inference

In [None]:
import numpy as np
import torch
from pathlib import Path

def infer_full_image_classification_with_models(
    image,
    models_dict,
    preprocess,
    patch_h=224,
    patch_w=224,
    stride_h=64,
    stride_w=64,
    device='cuda',
    model_names=None,
    config=None
):
    """
    Classify every sliding-window patch of ``image`` with each model.

    Args:
        image: array indexed as ``image[y, x, ...]``; only its first two
            dimensions (height, width) are interpreted here.
        models_dict: mapping ``name -> zero-argument model factory``.
        preprocess: callable turning one patch into a tensor; all results
            are stacked into a single batch.
        patch_h, patch_w: patch size in pixels.
        stride_h, stride_w: step of the sliding window in pixels.
        device: device used for inference.
        model_names: names to evaluate; defaults to every key of
            ``models_dict``.
        config: optional dict; ``config["batch_size"]`` sets the inference
            batch size (default 8).

    Returns:
        (results_dict, patch_coords, (map_H, map_W)) where ``results_dict``
        maps model name to a ``(map_H, map_W)`` array of predicted class
        indices and ``patch_coords`` lists the ``(y, x)`` top-left corner of
        every patch in row-major order.

    Raises:
        ValueError: if the image is smaller than one patch.

    Weights are loaded from ``best_{model_name}.pth``; if loading fails the
    error is printed and the model is used as returned by its factory.
    """
    H, W = image.shape[:2]
    if model_names is None:
        model_names = list(models_dict.keys())
    batch_size = config["batch_size"] if config and "batch_size" in config else 8

    if H < patch_h or W < patch_w:
        raise ValueError(
            f"Image of size {H}x{W} is smaller than the patch size {patch_h}x{patch_w}"
        )

    # The grid shape depends only on the geometry, so it is computed once,
    # up front; it is therefore defined even when no model is evaluated.
    map_H = (H - patch_h) // stride_h + 1
    map_W = (W - patch_w) // stride_w + 1

    # Top-left corner of every patch, row-major
    patch_coords = [
        (y, x)
        for y in range(0, H - patch_h + 1, stride_h)
        for x in range(0, W - patch_w + 1, stride_w)
    ]
    patch_batch = torch.stack(
        [preprocess(image[y:y + patch_h, x:x + patch_w]) for y, x in patch_coords]
    )  # [N, C, H, W]

    results_dict = {}
    for model_name in model_names:
        torch.cuda.empty_cache()
        model = models_dict[model_name]().to(device)
        model.eval()
        try:
            model.load_state_dict(torch.load(f'best_{model_name}.pth', map_location=device))
        except Exception as e:
            # deliberate best effort: keep going with the factory's weights
            print(f"[{model_name}] checkpoint not loaded: {e}")

        preds = []
        with torch.no_grad():
            for start in range(0, len(patch_batch), batch_size):
                batch = patch_batch[start:start + batch_size].to(device)
                # argmax over logits equals argmax over softmax probabilities
                preds.append(model(batch).argmax(dim=1).cpu().numpy())
        preds = np.concatenate(preds, axis=0)  # [num_patches]

        # Fold the flat predictions back into the 2-D patch grid
        results_dict[model_name] = preds.reshape(map_H, map_W)

        # free the model before the next one is created
        del model
        torch.cuda.empty_cache()

    return results_dict, patch_coords, (map_H, map_W)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

def show_classification_map(
    image, class_map, patch_h, patch_w, stride_h, stride_w, idx2label=None, model_name="Model"
):
    """
    Display the source image next to a pseudo-colour map of predicted classes.

    The patch-level ``class_map`` is painted onto an image-sized canvas:
    cell ``(i, j)`` fills the ``patch_h x patch_w`` rectangle whose top-left
    corner is ``(i * stride_h, j * stride_w)``. Where patches overlap, the
    one painted last wins; pixels not covered by any patch stay 0.
    If ``idx2label`` is given, colour-bar ticks are labelled with class names.
    """
    # Expand the patch grid back to image coordinates
    img_h, img_w = image.shape[:2]
    canvas = np.zeros((img_h, img_w), dtype=np.int32)
    n_rows, n_cols = class_map.shape
    for (row, col), cls in np.ndenumerate(class_map):
        top = row * stride_h
        left = col * stride_w
        canvas[top:top + patch_h, left:left + patch_w] = cls

    top_class = class_map.max()
    palette = plt.get_cmap('tab20' if top_class < 20 else 'nipy_spectral')

    plt.figure(figsize=(18, 6))

    # left panel: the input image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"{model_name}: Original")
    plt.axis('off')

    # right panel: class map with a colour bar
    plt.subplot(1, 2, 2)
    im = plt.imshow(canvas, cmap=palette, vmin=0, vmax=top_class)
    plt.title(f"{model_name}: Predicted Classes Map")
    plt.axis('off')
    cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
    if idx2label:
        tick_names = [idx2label[i] for i in range(top_class + 1)]
        cbar.set_ticks(range(len(tick_names)))
        cbar.set_ticklabels(tick_names)

    plt.tight_layout()
    plt.show()

def save_classification_map_image(
    image, class_map, patch_h, patch_w, stride_h, stride_w, idx2label=None,
    model_name="Model", save_dir="./inference_results", img_id=0
):
    """
    Save a figure with the source image and the predicted-class map side by side.

    The file is written to ``{save_dir}/{model_name}_classmap_{img_id}.png``;
    ``save_dir`` is created if it does not exist. The class map is expanded
    to image coordinates: cell ``(i, j)`` fills the ``patch_h x patch_w``
    rectangle at ``(i * stride_h, j * stride_w)``, later patches overwrite
    earlier ones where they overlap, uncovered pixels stay 0.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Expand the patch grid back to image coordinates
    img_h, img_w = image.shape[:2]
    canvas = np.zeros((img_h, img_w), dtype=np.int32)
    n_rows, n_cols = class_map.shape
    for (row, col), cls in np.ndenumerate(class_map):
        top = row * stride_h
        left = col * stride_w
        canvas[top:top + patch_h, left:left + patch_w] = cls

    top_class = class_map.max()
    palette = plt.get_cmap('tab20' if top_class < 20 else 'nipy_spectral')

    plt.figure(figsize=(18, 6))

    # left panel: the input image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"{model_name}: Original")
    plt.axis('off')

    # right panel: colour-coded class map
    plt.subplot(1, 2, 2)
    im = plt.imshow(canvas, cmap=palette, vmin=0, vmax=top_class)
    plt.title(f"{model_name}: Predicted Classes Map")
    plt.axis('off')
    cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
    if idx2label:
        tick_names = [idx2label[i] for i in range(top_class + 1)]
        cbar.set_ticks(range(len(tick_names)))
        cbar.set_ticklabels(tick_names)

    plt.tight_layout()
    out_path = os.path.join(save_dir, f"{model_name}_classmap_{img_id}.png")
    plt.savefig(out_path, bbox_inches='tight', pad_inches=0.1)
    plt.close()
    print(f"Saved: {out_path}")


In [None]:
# --- Patch preprocessing ---
# Resize to the patch size, normalise with ImageNet mean/std, convert to a tensor.
preprocess_albu = A.Compose([
    A.Resize(PATCH_H, PATCH_W),  # 224, 224
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
def preprocess_patch(patch):
    """Run ``preprocess_albu`` on one image patch and return its 'image' output."""
    return preprocess_albu(image=patch)['image']

# --- Image selection ---
IMG_DIR = Path('./aitex_data/extracted/Defect_images')
img_files = sorted(IMG_DIR.glob('*.png'))
if not img_files:
    # fail with a clear message instead of a bare IndexError below
    raise FileNotFoundError(f"No PNG images found in {IMG_DIR}")
img_path = img_files[0]

img = cv2.imread(str(img_path))
if img is None:
    # cv2.imread signals an unreadable file by returning None, not by raising
    raise IOError(f"Could not read image: {img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# --- Inference ---
# Classify every sliding-window patch of `img` with each model in MODELS.
result_classmaps, patch_coords, (map_H, map_W) = infer_full_image_classification_with_models(
    img,
    models_dict=MODELS,
    preprocess=preprocess_patch,
    patch_h=PATCH_H,
    patch_w=PATCH_W,
    stride_h=STRIDE_H,
    stride_w=STRIDE_W,
    device=device,
    config=CONFIG
)

# --- Visualization ---
for model_name, class_map in result_classmaps.items():
    show_classification_map(
        img, class_map,
        patch_h=PATCH_H, patch_w=PATCH_W,
        stride_h=STRIDE_H, stride_w=STRIDE_W,
        idx2label=train_ds.idx2label,  # or your own {int: str} mapping
        model_name=model_name
    )

# Save one figure per model; note that `img_id` receives the model's
# enumeration index here (all maps come from the same image).
save_dir = "./inference_results"
for idx, (model_name, class_map) in enumerate(result_classmaps.items()):
    save_classification_map_image(
        img, class_map,
        patch_h=PATCH_H, patch_w=PATCH_W,
        stride_h=STRIDE_H, stride_w=STRIDE_W,
        idx2label=train_ds.idx2label,  # or your own mapping
        model_name=model_name,
        save_dir=save_dir,
        img_id=idx
    )

In [None]:

# Per-model summary: how many patches were assigned to each predicted class.
for model_name, class_map in result_classmaps.items():
    uniques, counts = np.unique(class_map, return_counts=True)
    print(f"Model: {model_name}")
    for idx, cnt in zip(uniques, counts):
        # class name followed by the number of patches predicted as that class
        print(f"  {train_ds.idx2label[idx]}: {cnt} патчей")
