# Install dependencies

In [None]:
%%writefile requirements.txt
requests
tqdm>=4.65.0
torch>=2.0.0
# torchvision>=0.14.0
opencv-python
numpy>=1.23.0
timm>=0.6.13
# tensorboard>=2.13.0
albumentations>=1.4.0
# segmentation-models-pytorch>=0.3.3
scikit-learn>=1.1.0
matplotlib>=3.5.0
# transformers>=4.31.0


In [None]:
! pip install -U -q pip
! pip install -q -r requirements.txt

# Training config

In [None]:
# --- Training settings for classification ---
CONFIG = {
    'batch_size': 64,
    'num_workers': 2,
    'num_epochs': 100,
    'learning_rate': 1e-3,
    'weight_decay': 1e-3,
    'early_stop_patience': 10,
    'dataroot': './aitex_data/extracted',
    'log_dir': 'runs/classification_experiment',   # changed for the classification task
    'resume': False
}

import torch
import os

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): this allocator option is set after torch is imported; presumably
# no CUDA memory has been allocated yet at this point — confirm it takes effect.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"


## Models

In [None]:
import timm

# For classification the model factory takes `num_classes` (not `classes`).
# Set the real number of classes (usually ~12 defect types in AITEX);
# presumably the "+ 1" accounts for the 'No defect' class — confirm.
NUM_CLASSES = 12 + 1

# Registry of model factories: name -> zero-argument callable that builds the model.
MODELS = {
    "beit_base_patch16_224": lambda: timm.create_model(
        "beit_base_patch16_224",
        pretrained=True,
        num_classes=NUM_CLASSES,
        in_chans=3,
        drop_rate=0.3,          # <--- dropout (author's note: after the MLP)
        drop_path_rate=0.1      # <--- DropPath between blocks
    ),
    # "convnext_base": lambda: timm.create_model(
    #     "convnext_base",
    #     pretrained=True,
    #     num_classes=NUM_CLASSES,
    #     in_chans=3,
    #     drop_path_rate=0.1      # DropPath only
    # ),
    # "resnet": lambda: timm.create_model(
    #     "resnet50",
    #     pretrained=True,
    #     num_classes=NUM_CLASSES,
    #     in_chans=3
    #     # Dropout is not used
    # ),
}


In [None]:
# import torch
# import gc
# from tqdm import tqdm

# def find_max_batch_size(
#     model_fn,
#     device='cuda',
#     image_size=(224, 224),
#     max_test=512,
#     num_classes=13,
#     step=4
# ):
#     print(f"\n=== –ü–æ–∏—Å–∫ –º–∞–∫—Å–∏–º–∞–ª—å–Ω–æ–≥–æ batch_size –¥–ª—è {model_fn.__name__ if hasattr(model_fn, '__name__') else model_fn} ===")
#     batch_size = step
#     last_ok = 0
#     model = model_fn().to(device)
#     model.eval()
#     torch.cuda.empty_cache()
#     gc.collect()

#     tried = []
#     total_attempts = (max_test // step)
#     pbar = tqdm(total=total_attempts, desc='–ü–æ–¥–±–æ—Ä batch_size', ncols=100)
#     while batch_size <= max_test:
#         try:
#             dummy = torch.randn(batch_size, 3, image_size[0], image_size[1]).to(device)
#             with torch.no_grad():
#                 out = model(dummy)
#             last_ok = batch_size
#             tried.append(batch_size)
#             batch_size += step
#             del dummy, out
#             torch.cuda.empty_cache()
#             gc.collect()
#             pbar.update(1)
#         except RuntimeError as e:
#             pbar.close()
#             if 'out of memory' in str(e):
#                 print(f"\nOOM at batch_size={batch_size}. Last OK: {last_ok}")
#                 break
#             else:
#                 print(f"\nError at batch_size={batch_size}: {e}")
#                 break
#     pbar.close()
#     del model
#     torch.cuda.empty_cache()
#     gc.collect()
#     print(f"–ú–∞–∫—Å–∏–º–∞–ª—å–Ω—ã–π batch_size: {last_ok} (–¥–ª—è image_size={image_size})")
#     return last_ok

# # –ü—Ä–∏–º–µ—Ä: –¥–ª—è –≤—Å–µ—Ö –º–æ–¥–µ–ª–µ–π
# for name, model_fn in MODELS.items():
#     print(f"\n=== {name} ===")
#     max_bs = find_max_batch_size(model_fn, device='cuda', image_size=(224,224), step=8)
#     print(f"Max batch_size for {name}: {max_bs}\n")


## Metrics

In [None]:
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, top_k_accuracy_score
)
import numpy as np

def compute_classification_metrics(y_true, y_pred, y_prob, num_classes, top_k=3):
    """Compute a dictionary of classification metrics.

    Args:
        y_true: ground-truth class indices.
        y_pred: predicted class indices.
        y_prob: per-class scores passed to ``top_k_accuracy_score``, or ``None``
            to skip the top-k metrics.
        num_classes: total number of classes; ``range(num_classes)`` is used as
            the label set for the top-k metrics.
        top_k: additional ``k`` for top-k accuracy. The result is stored under
            the key ``f"top{top_k}"``.

    Returns:
        dict with ``accuracy``, macro/micro ``precision``/``recall``/``f1``,
        ``confusion_matrix`` and, when ``y_prob`` is given, ``top1``, ``top3``
        and ``f"top{top_k}"``.
    """
    metrics = {"accuracy": accuracy_score(y_true, y_pred)}
    for average in ("macro", "micro"):
        metrics[f"precision_{average}"] = precision_score(y_true, y_pred, average=average, zero_division=0)
        metrics[f"recall_{average}"] = recall_score(y_true, y_pred, average=average, zero_division=0)
        metrics[f"f1_{average}"] = f1_score(y_true, y_pred, average=average, zero_division=0)
    metrics["confusion_matrix"] = confusion_matrix(y_true, y_pred)

    if y_prob is not None:
        labels = list(range(num_classes))
        metrics["top1"] = top_k_accuracy_score(y_true, y_prob, k=1, labels=labels)
        # "top3" is always reported so existing callers keep working.
        metrics["top3"] = top_k_accuracy_score(y_true, y_prob, k=min(3, num_classes), labels=labels)
        # Previously `top_k` was accepted but ignored; honour it under its own key.
        if top_k not in (1, 3):
            k = max(1, min(top_k, num_classes))
            metrics[f"top{top_k}"] = top_k_accuracy_score(y_true, y_prob, k=k, labels=labels)
    return metrics


# Prepare data

## download_dataset

In [None]:
import requests
from tqdm import tqdm

import zipfile
from pathlib import Path

url = "https://www.kaggle.com/api/v1/datasets/download/nexuswho/aitex-fabric-image-database"
output_dir = Path("./aitex_data")
output_dir.mkdir(exist_ok=True)

zip_path = output_dir / "aitex.zip"

# Skip the download when the archive is already on disk.
if zip_path.exists():
    print(f"[INFO] –ê—Ä—Ö–∏–≤ —É–∂–µ —Å—É—â–µ—Å—Ç–≤—É–µ—Ç –ø–æ –ø—É—Ç–∏: {zip_path}")
else:
    print(f"[INFO] –°–∫–∞—á–∏–≤–∞–µ–º –∞—Ä—Ö–∏–≤ –∏–∑ {url}...")
    # Stream into a temporary ".part" file and rename it only after the whole
    # body has been written, so an interrupted transfer is never mistaken for a
    # complete archive on the next run.
    part_path = zip_path.with_name(zip_path.name + ".part")
    with requests.get(url, stream=True, timeout=60) as response:
        if response.status_code != 200:
            raise Exception(f"–û—à–∏–±–∫–∞ –ø—Ä–∏ —Å–∫–∞—á–∏–≤–∞–Ω–∏–∏: —Å—Ç–∞—Ç—É—Å {response.status_code}")
        # Content-Length may be absent; tqdm then shows a plain byte counter.
        total_bytes = int(response.headers.get("Content-Length", 0)) or None
        with open(part_path, "wb") as f, tqdm(total=total_bytes, unit="B", unit_scale=True) as pbar:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)
                pbar.update(len(chunk))
    part_path.replace(zip_path)
    print("[INFO] –°–∫–∞—á–∏–≤–∞–Ω–∏–µ –∑–∞–≤–µ—Ä—à–µ–Ω–æ.")

# Unpack the archive (once).
extract_dir = output_dir / "extracted"
if not extract_dir.exists():
    # Extract into a scratch directory first: a half-finished extraction must
    # not leave behind an "extracted" directory that looks complete.
    tmp_extract_dir = output_dir / "extracted.part"
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(tmp_extract_dir)
    tmp_extract_dir.rename(extract_dir)
    print(f"[INFO] –ê—Ä—Ö–∏–≤ —É—Å–ø–µ—à–Ω–æ —Ä–∞—Å–ø–∞–∫–æ–≤–∞–Ω –≤ {extract_dir}")
else:
    print(f"[INFO] –ê—Ä—Ö–∏–≤ —É–∂–µ –±—ã–ª —Ä–∞—Å–ø–∞–∫–æ–≤–∞–Ω –≤ {extract_dir}")


## remove_images_without_masks

In [None]:
import os

def remove_images_without_masks(image_dir, mask_dir, image_suffix=".png", mask_suffix="_mask.png"):
    """
    Delete every image in ``image_dir`` that has no matching mask in ``mask_dir``.

    The expected mask name is the image file name without its extension plus
    ``mask_suffix``. Files that do not end with ``image_suffix`` are ignored.

    Args:
        image_dir: directory containing the images (files are deleted here).
        mask_dir: directory containing the masks.
        image_suffix: only files ending with this suffix are considered.
        mask_suffix: suffix appended to the image base name to form the mask name.

    Returns:
        int: number of images that were deleted.

    Raises:
        FileNotFoundError: if ``image_dir`` or ``mask_dir`` does not exist.
    """
    # Without this guard a missing or mistyped mask directory would make every
    # mask look absent and the whole image directory would be wiped.
    if not os.path.isdir(mask_dir):
        raise FileNotFoundError(f"Mask directory not found: {mask_dir}")

    removed = 0
    for img_name in sorted(os.listdir(image_dir)):
        if not img_name.endswith(image_suffix):
            continue
        mask_name = os.path.splitext(img_name)[0] + mask_suffix
        if os.path.exists(os.path.join(mask_dir, mask_name)):
            continue
        img_path = os.path.join(image_dir, img_name)
        print(f"–£–¥–∞–ª—è–µ—Ç—Å—è {img_path} (–º–∞—Å–∫–∞ {mask_name} –Ω–µ –Ω–∞–π–¥–µ–Ω–∞)")
        os.remove(img_path)
        removed += 1
    print(f"–£–¥–∞–ª–µ–Ω–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π –±–µ–∑ –º–∞—Å–æ–∫: {removed}")
    return removed

# Example call: drop defect images that have no mask file.
remove_images_without_masks(
    image_dir="./aitex_data/extracted/Defect_images",
    mask_dir="./aitex_data/extracted/Mask_images"
)


import os
import random
import cv2
import matplotlib.pyplot as plt

def show_and_save_random_image_with_mask(
    image_dir, mask_dir,
    image_suffix=".png", mask_suffix="_mask.png",
    save_dir="./random_samples"
):
    """
    Show and save one randomly chosen image together with its mask.

    The two pictures are displayed stacked vertically (2 rows) and written to
    ``save_dir`` as two separate files.

    Args:
        image_dir: directory with the source images.
        mask_dir: directory with the masks.
        image_suffix: only files ending with this suffix are candidates.
        mask_suffix: suffix appended to the image base name to form the mask name.
        save_dir: output directory; created if missing.

    Returns:
        None. Prints a message and returns early when there are no images, the
        mask file is missing, or either file cannot be decoded.
    """
    os.makedirs(save_dir, exist_ok=True)
    image_files = [f for f in os.listdir(image_dir) if f.endswith(image_suffix)]
    if not image_files:
        print("–ù–µ—Ç –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–π –≤ –¥–∏—Ä–µ–∫—Ç–æ—Ä–∏–∏.")
        return

    img_name = random.choice(image_files)
    base_name = os.path.splitext(img_name)[0]
    mask_name = base_name + mask_suffix

    img_path = os.path.join(image_dir, img_name)
    mask_path = os.path.join(mask_dir, mask_name)

    if not os.path.exists(mask_path):
        print(f"–ú–∞—Å–∫–∞ –¥–ª—è {img_name} –Ω–µ –Ω–∞–π–¥–µ–Ω–∞: {mask_path}")
        return

    img = cv2.imread(img_path)
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals a decode failure by returning None instead of raising;
    # stop here rather than crash inside cvtColor/imshow.
    if img is None or mask is None:
        print(f"Could not read image or mask: {img_path}, {mask_path}")
        return
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Vertical layout: image on top, mask below.
    plt.figure(figsize=(6, 8))
    plt.subplot(2, 1, 1)
    plt.imshow(img_rgb)
    plt.title(f"–ò–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ: {img_name}")
    plt.axis('off')

    plt.subplot(2, 1, 2)
    plt.imshow(mask, cmap='gray')
    plt.title(f"–ú–∞—Å–∫–∞: {mask_name}")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

    # Save the two files separately (the image is written in its original BGR form).
    img_save_path = os.path.join(save_dir, f"{base_name}_image.png")
    mask_save_path = os.path.join(save_dir, f"{base_name}_mask.png")
    cv2.imwrite(img_save_path, img)
    cv2.imwrite(mask_save_path, mask)
    print(f"–°–æ—Ö—Ä–∞–Ω–µ–Ω–æ –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏–µ: {img_save_path}")
    print(f"–°–æ—Ö—Ä–∞–Ω–µ–Ω–∞ –º–∞—Å–∫–∞: {mask_save_path}")

# Example call: preview one random defect image with its mask.
show_and_save_random_image_with_mask(
    image_dir="./aitex_data/extracted/Defect_images",
    mask_dir="./aitex_data/extracted/Mask_images"
)



## dataset_stats

In [None]:
from pathlib import Path
import cv2
import numpy as np

SRC_MSK_DIR = Path("./aitex_data/extracted/Mask_images")

# Collect the number of defect (non-zero) pixels for every readable mask that
# contains at least one such pixel; empty masks are left out of the statistics.
pixels_list = []
for mask_path in SRC_MSK_DIR.glob("*.png"):
    msk = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
    if msk is None:
        continue
    num = int(np.count_nonzero(msk > 0))
    if num > 0:
        pixels_list.append(num)

# Extremes over the collected counts (None / 0 when nothing was collected).
min_pixels = min(pixels_list) if pixels_list else None
max_pixels = max(pixels_list, default=0)

print(f"–ú–∏–Ω–∏–º–∞–ª—å–Ω–æ–µ –∫–æ–ª–∏—á–µ—Å—Ç–≤–æ –ø–∏–∫—Å–µ–ª–µ–π –¥–µ—Ñ–µ–∫—Ç–∞ –≤ –æ–¥–Ω–æ–π –º–∞—Å–∫–µ: {min_pixels}")
print(f"–ú–∞–∫—Å–∏–º–∞–ª—å–Ω–æ–µ: {max_pixels}")
print(f"–ú–µ–¥–∏–∞–Ω–Ω–æ–µ: {np.median(pixels_list)}")
print(f"–ì–∏—Å—Ç–æ–≥—Ä–∞–º–º–∞ –ø–æ –≤—Å–µ–º –∏–∑–æ–±—Ä–∞–∂–µ–Ω–∏—è–º:")

# Histogram of defect-pixel counts across all masks.
import matplotlib.pyplot as plt
plt.hist(pixels_list, bins=30)
plt.xlabel("–î–µ—Ñ–µ–∫—Ç–Ω—ã—Ö –ø–∏–∫—Å–µ–ª–µ–π –Ω–∞ –º–∞—Å–∫–µ")
plt.ylabel("–ß–∞—Å—Ç–æ—Ç–∞")
plt.show()


# Preprocess data

## slice_to_patches

In [None]:
from pathlib import Path
import cv2
import numpy as np
import pandas as pd
from tqdm import tqdm

# --- Defect labels: zero-padded defect code -> human-readable name ---
DEFECT_LABELS = {
    '000': 'No defect',
    '002': 'Broken end',
    '006': 'Broken yarn',
    '010': 'Broken pick',
    '016': 'Weft curling',
    '019': 'Fuzzyball',
    '022': 'Cut selvage',
    '023': 'Crease',
    '025': 'Warp ball',
    '027': 'Knots',
    '029': 'Contamination',
    '030': 'Nep',
    '036': 'Weft crack'
}

# --- Patch slicing parameters ---
SRC_IMG_DIR = Path("./aitex_data/extracted/Defect_images")
SRC_MSK_DIR = Path("./aitex_data/extracted/Mask_images")
DST_IMG_DIR = Path("./aitex_patches/images")
DST_MSK_DIR = Path("./aitex_patches/masks")

PATCH_W = PATCH_H = 224
# Step of 32 px between patch origins, so neighbouring patches overlap heavily.
STRIDE_W = STRIDE_H = (256 - 224)
# MIN_DEFECT_FRAC = 0.005       # minimum defect fraction
# Minimum number of mask pixels for a patch to be considered defective.
MIN_DEFECT_PIXELS = 9 
# Probability of keeping a patch that is not considered defective.
KEEP_NEG = 0.05

DST_IMG_DIR.mkdir(parents=True, exist_ok=True)
DST_MSK_DIR.mkdir(parents=True, exist_ok=True)

# One (file, defect_code, defect_label) tuple per saved patch.
rows = []
PATCH_AREA = PATCH_W * PATCH_H

def has_large_defect(mask, min_size=20):
    """Return True if ``mask`` has a connected non-zero region of at least ``min_size`` pixels."""
    binary = (mask > 0).astype(np.uint8)
    n_components, _, component_stats, _ = cv2.connectedComponentsWithStats(binary)
    # Component 0 is the background, so only rows 1.. are inspected.
    areas = component_stats[1:n_components, cv2.CC_STAT_AREA]
    return bool((areas >= min_size).any())

# Slide a PATCH_H x PATCH_W window over every source image, label each patch
# and write the patch, its binary mask and a CSV row.
for img_path in tqdm(sorted(SRC_IMG_DIR.glob("*.png")), desc="Cropping AITEX (grid)"):
    mask_path = SRC_MSK_DIR / img_path.name.replace(".png", "_mask.png")
    img = cv2.imread(str(img_path))
    msk = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)

    # cv2.imread returns None when a file is missing or unreadable.
    if img is None or msk is None:
        print(f"‚ùå  –ü—Ä–æ–ø—É—Å–∫ {img_path.name} (—Ñ–∞–π–ª–∞ –Ω–µ—Ç)")
        continue

    msk_bin = (msk > 0).astype(np.uint8)
    # The defect code is the second "_"-separated field of the file stem.
    # NOTE(review): assumes every file name has that field and it is numeric — confirm.
    orig_defect_code_str = img_path.stem.split('_')[1]
    orig_defect_code = int(orig_defect_code_str)
    orig_defect_label = DEFECT_LABELS.get(orig_defect_code_str, "Unknown")

    # Two loops: over y and over x.
    for y in range(0, img.shape[0] - PATCH_H + 1, STRIDE_H):
        for x in range(0, img.shape[1] - PATCH_W + 1, STRIDE_W):
            img_crop = img[y:y+PATCH_H, x:x+PATCH_W]
            msk_crop = msk_bin[y:y+PATCH_H, x:x+PATCH_W]
            pos_pix = int(msk_crop.sum())
            defect_frac = pos_pix / PATCH_AREA  # computed but not used by the active filter below

            # Combined filter (disabled)
            # is_defective = (defect_frac >= MIN_DEFECT_FRAC) or (pos_pix >= MIN_DEFECT_PIXELS)

            is_defective = pos_pix >= MIN_DEFECT_PIXELS

            
            patch_defect_code = orig_defect_code
            patch_defect_label = orig_defect_label

            # If the patch is "clean", override its label. Clean patches are
            # kept only with probability KEEP_NEG.
            # NOTE(review): np.random is not seeded in this block, so the kept
            # set presumably differs between runs — confirm if that matters.
            if not is_defective or not has_large_defect(msk_crop, min_size=20):
                if np.random.rand() > KEEP_NEG:
                    continue
                patch_defect_code = 0
                patch_defect_label = DEFECT_LABELS['000']

            suffix = f"x{x:04d}_y{y:04d}"
            fname  = f"{img_path.stem}_{suffix}.png"
            cv2.imwrite(str(DST_IMG_DIR / fname), img_crop)
            # The mask is stored as 0/255 so it is visible when opened as an image.
            cv2.imwrite(str(DST_MSK_DIR / fname), msk_crop * 255)
            rows.append((fname, patch_defect_code, patch_defect_label))

# --- Save the CSV with the patch labels ---
label_path = Path("./aitex_patches/patch_labels.csv")
label_df = pd.DataFrame(rows, columns=["file", "defect_code", "defect_label"])
label_df.to_csv(label_path, index=False)
print("üìù  Saved", len(rows), "patch labels ‚Üí", label_path)
print("‚úÖ  –ù–∞—Ä–µ–∑–∫–∞ –ø–∞—Ç—á–µ–π –∑–∞–≤–µ—Ä—à–µ–Ω–∞.")


## Undersampling

In [None]:
import pandas as pd

# Load all patch labels and separate defective patches from clean ones.
df = pd.read_csv('./aitex_patches/patch_labels.csv')
_is_clean = df['defect_label'] == 'No defect'
has_defect = df[~_is_clean]
no_defect  = df[_is_clean]

print(f"–î–µ—Ñ–µ–∫—Ç–Ω—ã—Ö –ø–∞—Ç—á–µ–π: {len(has_defect)}")
print(f"–ß–∏—Å—Ç—ã—Ö –ø–∞—Ç—á–µ–π: {len(no_defect)}")

# Target ratio of clean to defective patches (1.0 means 1:1); the sample size
# is capped by the number of clean patches actually available.
desired_ratio = 1.0
n_defect = len(has_defect)
n_no_defect = min(int(n_defect * desired_ratio), len(no_defect))

# Undersample the clean patches, then shuffle the combined table.
no_defect_sampled = no_defect.sample(n=n_no_defect, random_state=42)
df_balanced = pd.concat([has_defect, no_defect_sampled]).sample(frac=1, random_state=42)

balanced_label_path = './aitex_patches/patch_labels_balanced.csv'
df_balanced.to_csv(balanced_label_path, index=False)
print(f"Balanced CSV saved: {balanced_label_path}")

# --- Per-class summary of the balanced set ---
summary = (
    df_balanced['defect_label']
    .value_counts()
    .rename_axis('defect_label')
    .reset_index(name='num_patches')
)
_total_patches = summary['num_patches'].sum()
summary['percentage'] = summary['num_patches'].div(_total_patches).mul(100).round(2)

print("\n=== Patch distribution (balanced) ===")
print(summary.to_string(index=False))

PATCH_LABEL_PATH = balanced_label_path

## Visualize data after processing

In [None]:
import random
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
import pandas as pd

def visualize_patches_with_masks_and_labels(
    img_dir,
    mask_dir,
    csv_path,
    min_defect_pixels=MIN_DEFECT_PIXELS,
    n_pos=6,
    n_neg=6
):
    """
    Show random patches with and without a defect, one column per patch:
    - the original patch (RGB)
    - its mask
    - the mask overlaid on the patch
    The title carries the status (DEFECT/CLEAN), the defect class and the file name.

    Args:
        img_dir: directory with the patch images.
        mask_dir: directory with the patch masks (same file names as the images).
        csv_path: CSV with at least the columns ``file`` and ``defect_label``.
        min_defect_pixels: a patch counts as DEFECT when its mask has at least
            this many non-zero pixels.
        n_pos: maximum number of DEFECT patches to show.
        n_neg: maximum number of CLEAN patches to show.

    Returns:
        None. Prints a message and returns when nothing can be shown.
    """
    df = pd.read_csv(csv_path)
    pos_samples, neg_samples = [], []

    # Classify every listed patch by the number of non-zero mask pixels.
    for fname, defect_label in zip(df['file'], df['defect_label']):
        mask = cv2.imread(str(Path(mask_dir) / fname), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            continue
        bucket = pos_samples if (mask > 0).sum() >= min_defect_pixels else neg_samples
        bucket.append((fname, defect_label))

    pos_samples = random.sample(pos_samples, min(n_pos, len(pos_samples)))
    neg_samples = random.sample(neg_samples, min(n_neg, len(neg_samples)))
    all_samples = [(fname, "DEFECT", label) for fname, label in pos_samples] + \
                  [(fname, "CLEAN", label) for fname, label in neg_samples]

    # A zero-column grid cannot be drawn (zero-width figure, subplot(3, 0, ...)).
    if not all_samples:
        print("No patches with readable masks to visualize.")
        return

    n_cols = len(all_samples)
    plt.figure(figsize=(n_cols * 4, 10))
    for i, (fname, status, defect_label) in enumerate(all_samples):
        img = cv2.imread(str(Path(img_dir) / fname))
        mask = cv2.imread(str(Path(mask_dir) / fname), cv2.IMREAD_GRAYSCALE)
        # cv2.imread returns None for a missing/unreadable file; leave the column empty.
        if img is None or mask is None:
            print(f"Skipping unreadable patch: {fname}")
            continue
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # 1. Source patch
        plt.subplot(3, n_cols, i + 1)
        plt.imshow(img_rgb)
        plt.title(f"{status}\n{defect_label}\n{fname}", fontsize=9)
        plt.axis('off')

        # 2. Mask on its own
        plt.subplot(3, n_cols, n_cols + i + 1)
        plt.imshow(mask, cmap='gray')
        plt.title("Mask", fontsize=9)
        plt.axis('off')

        # 3. Mask overlaid on the patch
        plt.subplot(3, n_cols, 2 * n_cols + i + 1)
        plt.imshow(img_rgb)
        plt.imshow(mask, cmap='Reds', alpha=0.5)
        plt.title("Overlay", fontsize=9)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# --- Run the visualization ---
visualize_patches_with_masks_and_labels(
    img_dir=DST_IMG_DIR,
    mask_dir=DST_MSK_DIR,
    csv_path=PATCH_LABEL_PATH,
    min_defect_pixels=MIN_DEFECT_PIXELS,  # key difference (author's note): same threshold as used when slicing
    n_pos=3,
    n_neg=1
)


In [None]:
import pandas as pd

# Paths to the data
PATCH_LABEL_PATH = './aitex_patches/patch_labels_balanced.csv'  # or your own final path

# Read the labels and show the first rows
df = pd.read_csv(PATCH_LABEL_PATH)
print(df.head())


# Create Dataset

## augmentations

In [None]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

def get_strong_classification_augmentations(image_size=(256, 256)):
    """
    Build the strong augmentation pipeline used for training.

    Only parameters that the transforms accept are passed (the author removed
    the ones that produced warnings). The image is first resized to 110% of
    ``image_size`` and then randomly cropped back to ``image_size``.

    Args:
        image_size: ``(height, width)`` of the output.

    Returns:
        An ``A.Compose`` ending with ImageNet-statistics normalization and
        ``ToTensorV2``.
    """
    height, width = image_size

    geometric = [
        A.Resize(int(height * 1.1), int(width * 1.1)),
        A.RandomCrop(height, width, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.Affine(
            scale=(0.85, 1.15),
            translate_percent=0.15,
            rotate=(-30, 30),
            shear=(-12, 12),
            p=0.7
        ),
        A.ElasticTransform(p=0.25),  # unsupported alpha_affine argument removed
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.2),
    ]
    photometric = [
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
        A.HueSaturationValue(hue_shift_limit=15, sat_shift_limit=20, val_shift_limit=15, p=0.4),
        A.RandomGamma(gamma_limit=(80, 120), p=0.3),
        A.GaussNoise(p=0.4),          # default noise parameters
        A.GaussianBlur(blur_limit=(3, 9), p=0.3),
        A.CoarseDropout(p=0.4),       # default parameters
    ]
    pipeline = geometric + photometric

    # GridMask is added only when the installed albumentations exposes it.
    grid_mask_cls = getattr(A, 'GridMask', None)
    if grid_mask_cls is not None:
        pipeline.append(grid_mask_cls(num_grid=(3, 7), rotate=15, p=0.3))

    # Normalization and conversion to a tensor always come last.
    pipeline += [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(pipeline)


def get_val_classification_augmentations(image_size=(256, 256)):
    """
    Build the light pipeline for validation/test: resize to ``image_size``
    (``(height, width)``), normalize, convert to a tensor.
    """
    height, width = image_size
    steps = [
        A.Resize(height, width),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(steps)


## dataset_class

In [None]:
import pandas as pd
import numpy as np
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

# --- —Ç–≤–æ–π –∫–∞—Å—Ç–æ–º–Ω—ã–π –∫–ª–∞—Å—Å PatchDataset —Å copy-paste ---
import random
import cv2
from torch.utils.data import Dataset

class PatchDataset(Dataset):
    """
    Patch classification dataset with class-aware copy-paste augmentation.

    Each item is ``(img_tensor, class_idx)`` where ``class_idx`` indexes
    ``label2idx`` / ``idx2label``.

    Copy-paste: a clean patch (``defect_code == 0``) is, with probability
    ``copy_paste_prob``, overwritten by the pixels of a donor defect patch
    wherever the two images differ, and the sample is then labelled with the
    donor's class. When ``class_copy_paste`` is true the donor class is drawn
    with probability proportional to how far that class falls below the
    median class size.

    Args:
        df: table with the columns ``file``, ``defect_code``, ``defect_label``.
        img_dir: directory containing the patch images.
        transform: callable invoked as ``transform(image=img)['image']``.
        copy_paste_prob: probability of applying copy-paste to a clean patch.
        class_copy_paste: choose the donor class by deficit instead of uniformly
            over all donor files.
        random_state: seed of the internal ``np.random.RandomState``.
    """
    def __init__(
        self,
        df: pd.DataFrame,
        img_dir: str,
        transform,
        copy_paste_prob: float = 0.3,
        class_copy_paste: bool = True,
        random_state: int = 42
    ):
        self.df = df.reset_index(drop=True)
        self.img_dir = Path(img_dir)
        self.transform = transform
        self.copy_paste_prob = copy_paste_prob
        self.class_copy_paste = class_copy_paste
        # NOTE(review): with DataLoader workers each worker presumably gets its
        # own copy of this generator state — confirm if that matters.
        self.rng = np.random.RandomState(random_state)

        # label <-> index mapping (sorted label names)
        labels = sorted(self.df['defect_label'].unique())
        self.label2idx = {lbl: i for i, lbl in enumerate(labels)}
        self.idx2label = {i: lbl for lbl, i in self.label2idx.items()}

        # defect code -> label name; needed to relabel a sample after copy-paste
        self.code2label = {
            int(code): label
            for code, label in zip(self.df['defect_code'], self.df['defect_label'])
        }

        # Pool of donor patches: every row whose code is not 0.
        defect_rows = self.df[self.df['defect_code'] != 0]
        self.defect_pool = [
            (fname, int(code))
            for fname, code in zip(defect_rows['file'], defect_rows['defect_code'])
        ]
        # Lookup tables so __getitem__ does not rescan the pool on every call.
        self._code_by_file = {}
        self._files_by_code = {}
        for fname, code in self.defect_pool:
            self._code_by_file.setdefault(fname, code)
            self._files_by_code.setdefault(code, []).append(fname)
        self._donor_files = [fname for fname, _ in self.defect_pool]

        # Per-class deficit relative to the median class size -> sampling probabilities.
        # With no defect rows there is nothing to compute (np.median([]) is NaN).
        self.class_probs = {}
        if self.defect_pool:
            counts = pd.Series([code for _, code in self.defect_pool]).value_counts()
            cnt = {int(code): int(n) for code, n in counts.items()}
            median = int(np.median(list(cnt.values()))) or 1
            deficit = {c: max(0, median - n) for c, n in cnt.items()}
            total_def = sum(deficit.values())
            if total_def > 0:
                self.class_probs = {c: d / total_def for c, d in deficit.items()}

    def __len__(self):
        return len(self.df)

    def _read_rgb(self, fname):
        """Load ``img_dir / fname`` as an RGB array; raise if it cannot be read."""
        path = self.img_dir / str(fname)
        bgr = cv2.imread(str(path))
        if bgr is None:
            # cv2.imread returns None instead of raising on a missing/corrupt file.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    def _pick_donor(self):
        """Draw the file name of a donor defect patch."""
        if self.class_copy_paste and self.class_probs:
            codes, probs = zip(*self.class_probs.items())
            sel_code = int(self.rng.choice(codes, p=probs))
            candidates = self._files_by_code[sel_code]
        else:
            candidates = self._donor_files
        return str(self.rng.choice(candidates))

    def __getitem__(self, idx: int):
        row = self.df.iloc[idx]
        label_name = row['defect_label']
        code = int(row['defect_code'])
        img = self._read_rgb(row['file'])

        # Copy-paste is applied to clean patches only.
        if code == 0 and self.defect_pool and self.rng.rand() < self.copy_paste_prob:
            donor = self._pick_donor()
            donor_img = self._read_rgb(donor)
            # The pixel-wise comparison needs identical shapes; otherwise keep
            # the clean patch untouched.
            if donor_img.shape == img.shape:
                # Simple masked paste: take donor pixels wherever any channel differs.
                changed = (donor_img != img).any(axis=2)
                img[changed] = donor_img[changed]
                # The patch now shows the donor's defect, so it must carry the
                # donor's label (previously the label stayed 'No defect').
                code = self._code_by_file[donor]
                label_name = self.code2label[code]

        # Augmentation and normalization
        img = self.transform(image=img)['image']
        label = self.label2idx[label_name]
        return img, label



## dataset creation helpers

In [None]:
def upsample_train_to_median(train_df, label_col='defect_code', random_state=42):
    """
    Oversample every class smaller than the median class size up to the median.

    Extra rows are drawn with replacement from the class's own rows. When at
    least one class was extended the result is shuffled and re-indexed;
    otherwise ``train_df`` is returned unchanged.

    Args:
        train_df: input table.
        label_col: column holding the class label.
        random_state: seed for sampling and shuffling.

    Returns:
        pd.DataFrame
    """
    per_class = train_df[label_col].value_counts()
    target = int(per_class.median())

    extras = [
        train_df[train_df[label_col] == cls].sample(
            n=target - size, replace=True, random_state=random_state
        )
        for cls, size in per_class.items()
        if size < target
    ]
    if not extras:
        return train_df

    combined = pd.concat([train_df] + extras)
    return combined.sample(frac=1, random_state=random_state).reset_index(drop=True)


In [None]:

# --- prepare_datasets for patch classification ---
def prepare_datasets(
    patch_label_path: str,
    img_dir: str,
    test_size: float = 0.05,
    val_size: float = 0.1,
    batch_size: int = 16,
    num_workers: int = 4,
    random_state: int = 42,
    image_size: tuple = (224, 224),
    train_aug_fn=None,
    val_aug_fn=None,
    copy_paste_prob: float = 0.3
) -> tuple:
    """
    Build the train/val/test ``PatchDataset`` objects from a patch-label CSV.

    The table is split with stratification on ``defect_code``: first the test
    part is split off, then the validation part is taken from the remainder.
    Only the training dataset uses copy-paste.

    Args:
        patch_label_path: CSV with the columns ``file``, ``defect_code``,
            ``defect_label``.
        img_dir: directory containing the patch images.
        test_size: fraction of all rows used for the test split.
        val_size: fraction of the remaining rows used for validation.
        batch_size: unused here; kept for interface compatibility.
        num_workers: unused here; kept for interface compatibility.
        random_state: seed for both splits and for the training dataset.
        image_size: ``(height, width)`` used when a default transform is built.
        train_aug_fn: ready-made training transform; default is the strong pipeline.
        val_aug_fn: ready-made validation/test transform; default is the light pipeline.
        copy_paste_prob: copy-paste probability for the training dataset.

    Returns:
        ``(train_ds, val_ds, test_ds)``
    """
    df = pd.read_csv(patch_label_path)

    # Stratified split. The DataFrame itself is split so that rows and their
    # stratification labels always stay aligned. Previously the second split
    # stratified shuffled file names against labels taken in CSV order.
    trainval_df, test_df = train_test_split(
        df, test_size=test_size,
        stratify=df['defect_code'], random_state=random_state
    )
    train_df, val_df = train_test_split(
        trainval_df, test_size=val_size,
        stratify=trainval_df['defect_code'], random_state=random_state
    )
    # Restore CSV order inside each split before re-indexing.
    train_df = train_df.sort_index().reset_index(drop=True)
    val_df = val_df.sort_index().reset_index(drop=True)
    test_df = test_df.sort_index().reset_index(drop=True)
    # NOTE(review): if the patches were cut with overlap from the same source
    # images, a per-patch split presumably lets near-identical pixels land in
    # different splits — consider splitting by source image.

    # Upsampling to the median (disabled)
    # train_df = upsample_train_to_median(train_df, label_col='defect_code', random_state=random_state)
    print("Train class distribution")
    print(train_df['defect_code'].value_counts())

    # Transforms: explicit None checks, so a supplied transform object is never
    # replaced just because it evaluates as falsy.
    if train_aug_fn is not None:
        train_transform = train_aug_fn
    else:
        train_transform = get_strong_classification_augmentations(image_size)
    if val_aug_fn is not None:
        val_transform = val_aug_fn
    else:
        val_transform = get_val_classification_augmentations(image_size)

    # Datasets
    train_ds = PatchDataset(
        train_df, img_dir,
        transform=train_transform,
        copy_paste_prob=copy_paste_prob,
        class_copy_paste=True,
        random_state=random_state
    )
    val_ds = PatchDataset(
        val_df, img_dir,
        transform=val_transform,
        copy_paste_prob=0.0
    )
    test_ds = PatchDataset(
        test_df, img_dir,
        transform=val_transform,
        copy_paste_prob=0.0
    )

    # Each PatchDataset derives its label->index map from its own rows, so a
    # class missing from one split would shift the indices there. Install one
    # mapping, built from the full table, on all three datasets.
    label2idx = {lbl: i for i, lbl in enumerate(sorted(df['defect_label'].unique()))}
    for ds in (train_ds, val_ds, test_ds):
        ds.label2idx = dict(label2idx)
        ds.idx2label = {i: lbl for lbl, i in label2idx.items()}

    return train_ds, val_ds, test_ds


## dataloader creation helper

In [None]:

def get_classification_dataloaders(
    train_ds, val_ds, test_ds,
    batch_size: int = 16,
    num_workers: int = 4
) -> tuple:
    """
    Wrap the three datasets in DataLoaders.

    The training loader shuffles and drops the last incomplete batch; the
    validation and test loaders keep order and all samples. All loaders share
    ``batch_size``, ``num_workers`` and ``pin_memory=True``.

    Returns:
        ``(train_loader, val_loader, test_loader)``
    """
    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=True)

    train_loader = DataLoader(train_ds, shuffle=True, drop_last=True, **shared)
    val_loader, test_loader = (
        DataLoader(eval_ds, shuffle=False, **shared) for eval_ds in (val_ds, test_ds)
    )
    return train_loader, val_loader, test_loader



## Initialize dataloaders

In [None]:

# === Usage example ===
PATCH_LABEL_PATH = './aitex_patches/patch_labels_balanced.csv'
IMG_DIR = './aitex_patches/images'
MSK_DIR = './aitex_patches/masks'

# Build the datasets; the transforms are created here and passed in ready-made.
train_ds, val_ds, test_ds = prepare_datasets(
    patch_label_path=PATCH_LABEL_PATH,
    img_dir=IMG_DIR,
    # msk_dir=MSK_DIR,
    test_size=0.05,
    val_size=0.1,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers'],
    random_state=42,
    image_size=(PATCH_H, PATCH_W),
    train_aug_fn=get_strong_classification_augmentations((PATCH_H, PATCH_W)),
    val_aug_fn=get_val_classification_augmentations((PATCH_H, PATCH_W)),
    copy_paste_prob=0.8
)

print("label2idx:", train_ds.label2idx)
print("idx2label:", train_ds.idx2label)
print("–í—Å–µ–≥–æ –∫–ª–∞—Å—Å–æ–≤:", len(train_ds.label2idx))

# Wrap the datasets in DataLoaders using the batch size / workers from CONFIG.
train_loader, val_loader, test_loader = get_classification_dataloaders(
    train_ds, val_ds, test_ds,
    batch_size=CONFIG['batch_size'],
    num_workers=CONFIG['num_workers']
)


In [None]:

import numpy as np
import torch
from collections import Counter
import torch.nn as nn

def get_weighted_loss(train_ds, smoothing=0.1, device='cuda'):
    """
    Build a class-weighted ``CrossEntropyLoss`` with label smoothing.

    Weights are inverse class frequencies, ``total / (num_classes * count)``,
    rescaled so that they sum to ``num_classes``. A class with no samples gets
    weight 0 instead of causing a division by zero.

    Labels are read from ``train_ds.df['defect_label']`` when that table is
    available, which avoids loading and augmenting every image just to count
    classes. These are the on-disk labels, before any augmentation in
    ``__getitem__``. Otherwise the dataset is iterated.

    Args:
        train_ds: dataset exposing ``label2idx`` and either a ``df`` table with
            a ``defect_label`` column or ``(image, label_idx)`` items.
        smoothing: label-smoothing factor passed to the loss.
        device: device the weight tensor is moved to.

    Returns:
        ``nn.CrossEntropyLoss``
    """
    label2idx = train_ds.label2idx
    num_classes = len(label2idx)

    # Collect the class index of every training sample.
    frame = getattr(train_ds, 'df', None)
    if frame is not None and hasattr(frame, 'columns') and 'defect_label' in frame.columns:
        train_labels = [label2idx[name] for name in frame['defect_label']]
    else:
        train_labels = [int(label) for _, label in train_ds]

    # Per-class counts
    class_counts = Counter(train_labels)
    total_samples = len(train_labels)

    # Inverse-frequency weights; absent classes keep weight 0.
    class_weights = np.zeros(num_classes)
    for cls_idx in range(num_classes):
        cls_count = class_counts.get(cls_idx, 0)
        if cls_count > 0:
            class_weights[cls_idx] = total_samples / (num_classes * cls_count)

    # Normalize so that the weights sum to num_classes.
    weight_sum = class_weights.sum()
    if weight_sum > 0:
        class_weights = class_weights / weight_sum * num_classes
    class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

    print(f"Class weights: {class_weights}")

    # Weighted CrossEntropyLoss with label smoothing
    loss_fn = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=smoothing)
    return loss_fn

import pandas as pd

def print_class_distribution_from_dataset(dataset, label_names=None, title="Train set"):
    """
    Print a table with the number and percentage of samples per class.

    Every item of ``dataset`` is fetched. When the dataset has an ``idx2label``
    attribute, the second element of the item is mapped through it; otherwise
    the last element of the item is used as the class label. Any failure while
    fetching or mapping an item counts that item as "unknown".

    Args:
        dataset: indexable dataset with ``__len__`` (e.g. a classification or
            segmentation patch dataset).
        label_names: optional dict mapping class labels to display names;
            labels without an entry are shown as they are.
        title: caption used in the printed header.
    """
    def _label_of(position):
        try:
            if hasattr(dataset, "idx2label"):
                # (img, label_idx, ...) -> class name
                _, class_idx = dataset[position][:2]
                return dataset.idx2label[class_idx]
            # e.g. (img, mask, code, label): the label is the last element
            *_, raw_label = dataset[position]
            return raw_label
        except Exception:
            # Best-effort fallback for items that cannot be read
            return "unknown"

    labels = [_label_of(i) for i in range(len(dataset))]

    counts = pd.DataFrame({'class_label': labels})['class_label'].value_counts()
    summary = counts.reset_index()
    summary.columns = ['class_label', 'num_patches']
    total = summary['num_patches'].sum()
    summary['percentage'] = (summary['num_patches'] / total * 100).round(2)
    if label_names:
        pretty = summary['class_label'].map(label_names)
        summary['class_label'] = pretty.fillna(summary['class_label'])
    print(f"\n=== Patch distribution in {title} ===")
    print(summary.to_string(index=False))


print_class_distribution_from_dataset(train_ds, label_names=DEFECT_LABELS, title="Train set")

## visualize dataset

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

def visualize_classification_dataset_grid(dataset, num_samples=6, label_names=None):
    """
    Show a single row of randomly chosen samples from a classification dataset.

    Each column is one image; its title carries the class code and, when
    available, a readable class name.

    Args:
        dataset: indexable dataset whose samples are ``(img, code, label)``,
            ``(img, code)`` or any other sequence starting with the image.
        num_samples: number of samples to draw; capped at ``len(dataset)``.
        label_names: optional mapping looked up with ``str(code)``; falls back
            to the sample's own label when the code is missing.
    """
    # Sampling without replacement cannot draw more items than the dataset holds
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("Nothing to visualize: no samples requested or the dataset is empty.")
        return

    indices = np.random.choice(len(dataset), num_samples, replace=False)
    plt.figure(figsize=(num_samples * 4, 4))

    for col, idx in enumerate(indices):
        sample = dataset[idx]
        # Expected (img, class_code, class_label) or (img, class_code)
        if len(sample) == 3:
            image, code, label = sample
        elif len(sample) == 2:
            image, code = sample
            label = str(code)
        else:
            image = sample[0]
            code = None
            label = "unknown"

        # Readable class name
        if label_names is not None:
            class_name = label_names.get(str(code), str(label))
        else:
            class_name = str(label)

        title = f"Code: {code}\nLabel: {class_name}"

        # Undo the ImageNet normalization when the image is a tensor
        if isinstance(image, torch.Tensor):
            rgb = image[:3].permute(1, 2, 0).cpu().numpy()
            rgb = (rgb * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])).clip(0, 1)
        else:
            rgb = image

        plt.subplot(1, num_samples, col + 1)
        plt.imshow(rgb)
        plt.title(title, fontsize=11)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Show 8 randomly chosen training samples with their class names
visualize_classification_dataset_grid(train_ds, num_samples=8, label_names=DEFECT_LABELS)

# Training

## libraries

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import autocast, GradScaler
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
from timm.data.mixup import Mixup

# --- MixUp + CutMix configuration ---
mixup_fn = Mixup(
    mixup_alpha=0.4,        # MixUp blending strength — a typical value
    cutmix_alpha=1.0,       # CutMix is enabled as well (usually 0.5-1.0, worth experimenting with)
    label_smoothing=0.1,    # required; can be lowered a bit if smoothing is already applied elsewhere
    # NOTE(review): the training loss is also created with label_smoothing, so
    # targets may be smoothed twice — confirm this is intended.
    num_classes=NUM_CLASSES
)


#### train_epoch

In [None]:
import numpy as np
import torch
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from timm.utils import ModelEmaV2


def train_epoch(
    model, train_loader, optimizer, loss_fn, scaler,
    device, epoch, model_name, num_classes,
    mixup_fn=None, ema=None
):
    """
    Run one training epoch with optional MixUp/CutMix and an EMA update.

    Args:
        model: network being trained (switched to train mode here).
        train_loader: iterable of batches; the first two items of each batch
            are used as ``(images, labels)``.
        optimizer: optimizer stepped once per batch through ``scaler``.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``; it receives
            soft labels for batches that went through ``mixup_fn``.
        scaler: AMP ``GradScaler`` used for backward/step.
        device: ``torch.device`` the batch is moved to; autocast is enabled
            only when its type is ``'cuda'``.
        epoch: epoch number, used for progress-bar and log text only.
        model_name: name used in the log line.
        num_classes: forwarded to ``compute_classification_metrics``.
        mixup_fn: optional callable ``(images, labels) -> (images, labels)``.
        ema: optional object with ``update(model)``, called after every step.

    Returns:
        The dict returned by ``compute_classification_metrics`` with the mean
        batch loss added under ``'loss'``.
    """
    model.train()
    running_loss = 0.0
    all_preds, all_probs, all_targets = [], [], []

    for batch in tqdm(train_loader, desc=f"Train {epoch}"):
        images, labels = batch[:2]
        images = images.to(device)
        labels = labels.to(device)

        # MIXUP / CUTMIX
        # timm's Mixup asserts an even batch size, so an odd-sized batch
        # (typically the last one) is trained without mixing instead of crashing.
        if mixup_fn is not None and images.size(0) % 2 == 0:
            images, labels = mixup_fn(images, labels)

        optimizer.zero_grad()
        with autocast(enabled=(device.type=='cuda')):
            outputs = model(images)
            loss = loss_fn(outputs, labels)

        # backward + step; gradients are unscaled first so that clipping
        # operates on their true magnitude
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        # EMA update
        if ema is not None:
            ema.update(model)

        running_loss += loss.item()

        # Accumulate predictions. Logits produced under autocast may be
        # half precision; softmax is taken in float32 so each row of
        # probabilities sums to 1 within normal tolerance.
        probs = torch.softmax(outputs.detach().float(), dim=1).cpu().numpy()
        preds = np.argmax(probs, axis=1)
        all_preds.append(preds)
        all_probs.append(probs)
        if labels.ndim > 1:
            # Soft (mixed) labels: score against the dominant class
            all_targets.append(labels.argmax(dim=1).cpu().numpy())
        else:
            all_targets.append(labels.detach().cpu().numpy())

    y_true = np.concatenate(all_targets)
    y_pred = np.concatenate(all_preds)
    y_prob = np.concatenate(all_probs)

    metrics = compute_classification_metrics(y_true, y_pred, y_prob, num_classes)
    metrics['loss'] = running_loss / len(train_loader)
    print(f"[{model_name}] Train epoch {epoch}: {metrics}")
    return metrics


#### validate_epoch

In [None]:


def validate_epoch(
    model, val_loader, device, epoch, model_name, num_classes
):
    """
    Evaluate the given model (usually the EMA copy) on ``val_loader`` once.

    The loss is a plain, unweighted cross-entropy averaged over batches.
    Returns the dict from ``compute_classification_metrics`` with that mean
    loss stored under ``'loss'``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    running_loss = 0.0
    pred_chunks, prob_chunks, target_chunks = [], [], []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Val  {epoch}"):
            # Only the first two items of a batch are used
            inputs, targets = (item.to(device) for item in batch[:2])
            logits = model(inputs)
            running_loss += criterion(logits, targets).item()

            batch_probs = torch.softmax(logits.cpu(), dim=1).numpy()
            pred_chunks.append(np.argmax(batch_probs, axis=1))
            prob_chunks.append(batch_probs)
            target_chunks.append(targets.cpu().numpy())

    metrics = compute_classification_metrics(
        np.concatenate(target_chunks),
        np.concatenate(pred_chunks),
        np.concatenate(prob_chunks),
        num_classes,
    )
    metrics['loss'] = running_loss / len(val_loader)
    print(f"[{model_name}] Val epoch {epoch}: {metrics}")
    return metrics



### train_loop

In [None]:

def train_loop(
    model, optimizer, scheduler,
    train_loader, val_loader,
    loss_fn, device, scaler,
    num_classes,
    num_epochs=50, early_stop_patience=5,
    model_name='model', mixup_fn=None,
    ema_decay: float = 0.9999
):
    """
    Main training loop with EMA, LR scheduling and early stopping on macro-F1.

    Each epoch trains ``model`` with ``train_epoch`` and validates the EMA
    copy with ``validate_epoch``. When the validation macro-F1 improves, the
    raw weights are saved to ``best_{model_name}.pth`` and the EMA state to
    ``ema_{model_name}.pth``. Training stops after ``early_stop_patience``
    epochs without improvement.

    Args:
        scheduler: LR scheduler stepped once per epoch, or ``None``.
            ``ReduceLROnPlateau`` receives the validation macro-F1; any other
            scheduler is stepped without arguments.
        ema_decay: decay passed to ``ModelEmaV2``.
        (remaining arguments are forwarded to ``train_epoch`` /
        ``validate_epoch`` unchanged)

    Returns:
        ``(train_history, val_history)``: lists with one metrics dict per epoch.
    """
    # EMA initialization
    ema = ModelEmaV2(model, decay=ema_decay, device=device)
    # Start below any reachable score so that the first epoch always writes a
    # checkpoint, even when its macro-F1 is exactly 0.
    best_f1 = float('-inf')
    counter = 0
    train_history, val_history = [], []
    is_plateau = isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

    for epoch in range(num_epochs):
        # train
        train_metrics = train_epoch(
            model, train_loader, optimizer, loss_fn, scaler,
            device, epoch, model_name, num_classes,
            mixup_fn=mixup_fn, ema=ema
        )
        # validate on the EMA model
        ema_model = ema.module
        val_metrics = validate_epoch(
            ema_model, val_loader, device, epoch, model_name, num_classes
        )

        # Only ReduceLROnPlateau takes the monitored metric. For every other
        # scheduler the positional argument of step() is the deprecated epoch
        # index, so passing F1 there would pin the schedule near its start.
        # NOTE(review): a scheduler sized in optimizer steps (e.g. OneCycleLR
        # with total_steps = batches * epochs) advances only one step per
        # epoch here — confirm the schedule length matches.
        if scheduler is not None:
            if is_plateau:
                scheduler.step(val_metrics['f1_macro'])
            else:
                scheduler.step()
        train_history.append(train_metrics)
        val_history.append(val_metrics)

        # early stopping + saving the best weights
        # NOTE(review): the score comes from the EMA weights while
        # best_{model_name}.pth stores the raw weights — confirm which file
        # downstream code should load.
        if val_metrics['f1_macro'] > best_f1:
            best_f1 = val_metrics['f1_macro']
            counter = 0
            torch.save(model.state_dict(), f"best_{model_name}.pth")
            ema_state = ema.state_dict()
            torch.save(ema_state, f"ema_{model_name}.pth")
        else:
            counter += 1
            if counter >= early_stop_patience:
                print(f"Early stopping at epoch {epoch}")
                break
        print(f"Epoch {epoch}: Train F1(macro)={train_metrics['f1_macro']:.4f}  Val F1(macro)={val_metrics['f1_macro']:.4f}")

    print(f"Best Val F1(macro): {best_f1:.4f}")
    return train_history, val_history


### Plot metrics

In [None]:
import math
import matplotlib.pyplot as plt

def plot_metrics(train_history, val_history, model_name, save_path=None):
    """
    Plot train and validation curves for a fixed set of metrics.

    One subplot is drawn per metric on a 3-column grid. A metric missing from
    an epoch's dict is plotted as 0. The figure is written to ``save_path``
    when one is given, and is always shown.
    """
    n_cols = 3  # 4 also works — the layout just gets wider
    keys = ['loss', 'accuracy', 'precision_macro', 'recall_macro', 'f1_macro', 'top1', 'top3']
    n_rows = math.ceil(len(keys) / n_cols)

    def _series(history, key):
        # Per-epoch values of one metric, defaulting to 0 when absent
        return [epoch_metrics.get(key, 0) for epoch_metrics in history]

    plt.figure(figsize=(n_cols * 6, n_rows * 4))
    for cell, key in enumerate(keys):
        plt.subplot(n_rows, n_cols, cell + 1)
        train_vals = _series(train_history, key)
        val_vals = _series(val_history, key)
        plt.plot(train_vals, label=f"Train {key}")
        plt.plot(val_vals, label=f"Val {key}")
        plt.title(f"{model_name}: {key}")
        plt.xlabel("Epoch")
        plt.legend()
        plt.grid()

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path)
    plt.show()


In [None]:


def train_single_model(
    model_name, model_fn,
    train_loader, val_loader,
    config, num_classes,
    mixup_fn=None
):
    """
    Build model, AdamW, OneCycleLR, weighted smoothed loss and AMP scaler,
    then hand everything to ``train_loop`` and return its result.

    ``config`` must provide ``learning_rate``, ``weight_decay``, ``num_epochs``
    and ``early_stop_patience``; ``smoothing`` (default 0.1) and ``ema_decay``
    (default 0.9999) are optional.

    NOTE(review): class weights are computed from the module-level
    ``train_ds``, not from ``train_loader`` — confirm both describe the same
    data.
    NOTE(review): the OneCycleLR horizon is counted in batches
    (``len(train_loader) * num_epochs``) — confirm the training loop steps the
    scheduler at that rate.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model_fn().to(device)

    peak_lr = config['learning_rate']
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=peak_lr, weight_decay=config['weight_decay']
    )

    schedule_length = len(train_loader)*config['num_epochs']
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=peak_lr,
        total_steps=schedule_length,
        pct_start=0.1,
        anneal_strategy='cos',
        div_factor=25.0
    )

    # Weighted loss + label smoothing
    loss_fn = get_weighted_loss(train_ds, smoothing=config.get('smoothing',0.1), device=device)
    # Gradient scaling only matters on CUDA
    scaler = GradScaler(enabled=(device.type=='cuda'))

    loop_options = dict(
        num_epochs=config['num_epochs'],
        early_stop_patience=config['early_stop_patience'],
        model_name=model_name,
        mixup_fn=mixup_fn,
        ema_decay=config.get('ema_decay',0.9999),
    )
    return train_loop(
        model, optimizer, scheduler,
        train_loader, val_loader,
        loss_fn, device, scaler,
        num_classes,
        **loop_options
    )


## train_model

## run training 

In [None]:
import gc

def train_all_models(models_dict, train_loader, val_loader, config, num_classes, mixup_fn=None):
    """
    Train every model factory in ``models_dict`` one after another.

    For each entry the model is trained with ``train_single_model``, its
    curves are plotted and saved to ``metrics_{name}.png``, and cached GPU
    memory is released before and after the run.
    """
    def _release_memory():
        # Drop cached CUDA blocks, then collect unreachable Python objects
        torch.cuda.empty_cache()
        gc.collect()

    for name, model_fn in models_dict.items():
        print(f"==== Training model: {name} ====")
        _release_memory()
        histories = train_single_model(
            name, model_fn, train_loader, val_loader, config, num_classes, mixup_fn=mixup_fn
        )
        train_hist, val_hist = histories
        plot_metrics(train_hist, val_hist, model_name=name, save_path=f"metrics_{name}.png")
        _release_memory()

# Train every configured model with the shared loaders, config and MixUp/CutMix
train_all_models(MODELS, train_loader, val_loader, CONFIG, NUM_CLASSES, mixup_fn=mixup_fn)

## Run training

# Test

### tta_predict

In [None]:
import torch

def tta_predict_classification(model, images):
    """
    Test-time augmentation for classification: original + horizontal flip.

    Puts ``model`` into eval mode, runs it on ``images`` and on ``images``
    flipped along dimension 3, and returns the mean of the two outputs
    (gradients disabled).
    """
    model.eval()
    with torch.no_grad():
        plain_out = model(images)
        mirrored_out = model(images.flip(3))
        # Class scores need no "un-flipping", unlike dense outputs:
        # the two outputs are simply averaged.
        averaged = (plain_out + mirrored_out) / 2
    return averaged


### visualize_pred

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_classification_predictions_tta(
    model,
    dataloader,
    device,
    label_names=None,
    num_samples=10,
    imagenet_norm=True,
    model_title=None
):
    """
    Show the first ``num_samples`` samples of a dataloader with GT and TTA prediction.

    Samples are taken in loader order (not randomly). Each panel shows the
    image with a title carrying the ground-truth and predicted class.

    Args:
        model: classifier; moved to ``device`` and put into eval mode.
        dataloader: iterable of batches whose first two items are
            ``(images, labels)``; extra items are ignored.
        device: device used for inference.
        label_names: optional mapping ``class index -> name``.
        num_samples: maximum number of samples to display.
        imagenet_norm: undo ImageNet mean/std normalization before display.
        model_title: optional figure title.
    """
    model = model.to(device)
    model.eval()
    images_list = []
    titles = []
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    def _name(class_idx):
        return label_names.get(class_idx, str(class_idx)) if label_names else str(class_idx)

    with torch.no_grad():
        for batch in dataloader:
            if len(images_list) >= num_samples:
                break
            # Same convention as the training code: use the first two items only
            images, labels = batch[:2]
            images = images.to(device)
            labels = labels.cpu().numpy()

            # Predictions with TTA
            logits = tta_predict_classification(model, images)
            probs = torch.softmax(logits, dim=1).cpu().numpy()
            preds = np.argmax(probs, axis=1)

            remaining = num_samples - len(images_list)
            for i in range(min(images.size(0), remaining)):
                # Keep the first three channels, as the other visualizers do
                img = images[i][:3].cpu().permute(1, 2, 0).numpy()
                if imagenet_norm:
                    img = (img * std + mean).clip(0, 1)
                images_list.append(img)
                titles.append(f"GT: {_name(labels[i])}\nPred: {_name(preds[i])}")

    shown = len(images_list)
    if shown == 0:
        print("No samples to visualize.")
        return

    # Visualization: the grid is sized by what was actually collected, so a
    # short loader does not leave empty panels
    plt.figure(figsize=(shown * 3, 4))
    if model_title is not None:
        plt.suptitle(model_title, fontsize=16, y=1.08)
    for col in range(shown):
        plt.subplot(1, shown, col + 1)
        plt.imshow(images_list[col])
        plt.title(titles[col], fontsize=10)
        plt.axis("off")
    plt.tight_layout()
    plt.show()


In [None]:
import gc

idx2label = train_ds.idx2label  # Or a custom {int: str} mapping

# Load the saved weights of every configured model and show a few TTA predictions
for model_name in MODELS:
    checkpoint_path = f'best_{model_name}.pth'
    print(f"\n--- Model: {model_name} ---")
    # Bind the name up front: if model construction fails, the cleanup in
    # `finally` would otherwise raise NameError on `del test_model`.
    test_model = None
    try:
        test_model = MODELS[model_name]().to(device)
        test_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        test_model.eval()

        visualize_classification_predictions_tta(
            test_model,
            dataloader=test_loader,
            device=device,
            label_names=idx2label,
            num_samples=5,
            model_title=f"Model: {model_name}"
        )
    except Exception as e:
        # Best-effort: report the failure and continue with the next model
        print(f"‚ùå Could not evaluate {model_name}: {e}")
    finally:
        del test_model
        torch.cuda.empty_cache()
        gc.collect()


## Sanity checks

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

def classification_sanity_check(
    train_ds, MODELS, device, target_label=None, num_epochs=20, N=8, idx2label=None
):
    """
    Sanity check: quickly overfit every model on a single patch (classification).

    If ``target_label`` is None the first patch of the dataset is used;
    otherwise (e.g. "No defect" or "Broken end") the first patch whose class
    name in ``idx2label`` equals ``target_label``. The patch is repeated ``N``
    times to form one batch, each model from ``MODELS`` is trained on it for
    ``num_epochs`` epochs, and loss/accuracy curves plus the patch itself are
    plotted.

    Raises:
        RuntimeError: if no patch satisfies the selection criterion.
    """
    def _wanted(candidate):
        # Without a target any patch qualifies; with one the class name must match
        return (target_label is None) or (idx2label and idx2label[candidate] == target_label)

    # 1. Find a patch of the requested class
    for idx in range(len(train_ds)):
        _, candidate = train_ds[idx]
        if _wanted(candidate):
            break
    else:
        raise RuntimeError("–ù–µ –Ω–∞–π–¥–µ–Ω –ø–æ–¥—Ö–æ–¥—è—â–∏–π –ø–∞—Ç—á –¥–ª—è sanity check!")
    img, label = train_ds[idx]

    print(f"Sanity check: –ø–∞—Ç—á –∫–ª–∞—Å—Å–∞ '{idx2label[label] if idx2label else label}' (–∏–Ω–¥–µ–∫—Å {label})")

    # 2. Build a dataset made of N copies of that patch
    single_loader = DataLoader([(img, label) for _ in range(N)], batch_size=N, shuffle=True)

    # ImageNet statistics used to undo the normalization for display
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    # 3. Check every model
    for model_name, model_fn in MODELS.items():
        print(f"\n==== Sanity check: {model_name} ====")
        model = model_fn().to(device)
        loss_fn = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        losses = []
        accs = []

        for epoch in range(num_epochs):
            model.train()
            for img_batch, label_batch in single_loader:
                img_batch, label_batch = img_batch.to(device), label_batch.to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(img_batch), label_batch)
                loss.backward()
                optimizer.step()

            # Evaluate on the batch that was just trained on
            model.eval()
            with torch.no_grad():
                scores = torch.softmax(model(img_batch), dim=1)
                pred_class = scores.argmax(dim=1).cpu().numpy()
                gt_class = label_batch.cpu().numpy()
                acc = (pred_class == gt_class).mean()
            losses.append(loss.item())
            accs.append(acc)
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Acc = {acc:.3f}, GT: {idx2label[gt_class[0]] if idx2label else gt_class[0]}, Pred: {idx2label[pred_class[0]] if idx2label else pred_class[0]}")

        # Loss / accuracy curves
        _, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10,4))
        loss_ax.plot(losses, label="Loss")
        loss_ax.set_title(f"{model_name} Loss")
        acc_ax.plot(accs, label="Acc")
        acc_ax.set_title(f"{model_name} Accuracy")
        for axis in (loss_ax, acc_ax):
            axis.grid()
            axis.set_xlabel('Epoch')
        plt.suptitle(f"Sanity Check: {model_name} ‚Äî –∫–ª–∞—Å—Å '{idx2label[label] if idx2label else label}'")
        plt.show()

        # The patch itself
        if isinstance(img, torch.Tensor):
            img_np = img[:3].permute(1,2,0).cpu().numpy()
        else:
            img_np = img
        img_np = (img_np * std + mean).clip(0,1)
        plt.figure(figsize=(3,3))
        plt.imshow(img_np)
        plt.title(f"Class: {idx2label[label] if idx2label else label}")
        plt.axis("off")
        plt.show()

# === Example runs ===
# For any class (the first patch found):
classification_sanity_check(
    train_ds, MODELS, device,
    target_label=None,     # or e.g. "No defect"
    num_epochs=20, N=8,
    idx2label=train_ds.idx2label
)

# For a "clean" patch (if needed):
classification_sanity_check(
    train_ds, MODELS, device,
    target_label="No defect",
    num_epochs=20, N=8,
    idx2label=train_ds.idx2label
)


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader

def classification_sanity_check(train_ds, MODELS, device, target_label=None, num_epochs=20, N=8):
    """
    Sanity check: a model must be able to fit one and the same patch (classification).

    ``target_label``: when given, the first patch whose name in
    ``train_ds.idx2label`` equals it is used (e.g. 'No defect' for a clean
    patch, a defect name for a defective one); when None, the first patch of
    the dataset is used.

    Raises:
        RuntimeError: if no patch satisfies the selection criterion.
    """
    def _indexed_labels():
        # Lazily yield (position, label) so the scan stops at the first match
        for position in range(len(train_ds)):
            _, sample_label = train_ds[position]
            yield position, sample_label

    # 1. Find a patch matching the criterion
    idx = next(
        (
            position
            for position, sample_label in _indexed_labels()
            if target_label is None or train_ds.idx2label[sample_label] == target_label
        ),
        None,
    )
    if idx is None:
        raise RuntimeError("–ù–µ –Ω–∞–π–¥–µ–Ω –ø–æ–¥—Ö–æ–¥—è—â–∏–π –ø–∞—Ç—á –¥–ª—è sanity check!")
    img, label = train_ds[idx]

    # 2. Build a dataset made of N copies of the single patch
    repeated = [(img, label)] * N
    single_loader = DataLoader(repeated, batch_size=N, shuffle=True)

    for model_name, model_fn in MODELS.items():
        print(f"\n==== Sanity check: {model_name} ====")
        model = model_fn().to(device)
        optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
        loss_fn = torch.nn.CrossEntropyLoss()

        for epoch in range(num_epochs):
            model.train()
            for img_batch, label_batch in single_loader:
                img_batch = img_batch.to(device)
                label_batch = label_batch.to(device)
                optimizer.zero_grad()
                loss = loss_fn(model(img_batch), label_batch)
                loss.backward()
                optimizer.step()

            # Evaluate on the batch that was just trained on
            model.eval()
            with torch.no_grad():
                pred_class = torch.softmax(model(img_batch), dim=1).argmax(dim=1).cpu().numpy()
                gt_class = label_batch.cpu().numpy()
                acc = (pred_class == gt_class).mean()
            print(f"Epoch {epoch}: Loss = {loss.item():.4f}, Acc = {acc:.3f}, GT: {train_ds.idx2label[gt_class[0]]}, Pred: {train_ds.idx2label[pred_class[0]]}")

        # Show the patch, undoing the ImageNet normalization
        if isinstance(img, torch.Tensor):
            img_np = img[:3].permute(1,2,0).cpu().numpy()
        else:
            img_np = img
        imagenet_std = np.array([0.229, 0.224, 0.225])
        imagenet_mean = np.array([0.485, 0.456, 0.406])
        img_np = (img_np * imagenet_std + imagenet_mean).clip(0,1)
        plt.figure(figsize=(4, 4))
        plt.imshow(img_np)
        plt.title(f"Sanity Check: {model_name}\nClass: {train_ds.idx2label[label]}")
        plt.axis("off")
        plt.show()


In [None]:
# Sanity check on a defective patch
classification_sanity_check(train_ds, MODELS, device, target_label='Broken end', num_epochs=20, N=8)
# (or target_label != 'No defect', if you want to pick a defect specifically)


In [None]:
# Sanity check on a defect-free patch
classification_sanity_check(train_ds, MODELS, device, target_label='No defect', num_epochs=20, N=8)


# Inference

In [None]:
import numpy as np
import torch
from pathlib import Path

def infer_full_image_classification_with_models(
    image,
    models_dict,
    preprocess,
    patch_h=224,
    patch_w=224,
    stride_h=64,
    stride_w=64,
    device='cuda',
    model_names=None,
    config=None
):
    """
    Classify every sliding-window patch of an image with each requested model.

    Args:
        image: array of shape ``(H, W, C)``.
        models_dict: mapping ``name -> zero-argument factory`` returning a model.
        preprocess: callable turning one patch into a tensor; all resulting
            tensors must be stackable.
        patch_h, patch_w: patch size in pixels.
        stride_h, stride_w: step between patch origins.
        device: device used for inference.
        model_names: names to run; defaults to all keys of ``models_dict``.
        config: optional dict; ``config["batch_size"]`` sets the inference
            batch size (default 8).

    Returns:
        ``(results_dict, patch_coords, (map_H, map_W))`` where
        ``results_dict[name]`` is a ``(map_H, map_W)`` array of class indices
        and ``patch_coords`` lists the ``(y, x)`` origin of every patch in
        row-major order.

    Raises:
        ValueError: if the image is smaller than one patch.
    """
    H, W, C = image.shape
    if model_names is None:
        model_names = list(models_dict.keys())
    batch_size = config["batch_size"] if config and "batch_size" in config else 8

    # The grid is fixed by the geometry alone, so it is computed once up front
    # and is available even when no model is requested.
    ys = range(0, H - patch_h + 1, stride_h)
    xs = range(0, W - patch_w + 1, stride_w)
    map_H, map_W = len(ys), len(xs)
    if map_H == 0 or map_W == 0:
        raise ValueError(
            f"Image of size {H}x{W} is smaller than a {patch_h}x{patch_w} patch"
        )

    patch_coords = [(y, x) for y in ys for x in xs]
    patch_batch = torch.stack(
        [preprocess(image[y:y+patch_h, x:x+patch_w]) for y, x in patch_coords]
    )  # [N, C, H, W]

    # For each model: the predicted class of every patch
    results_dict = {}
    for model_name in model_names:
        torch.cuda.empty_cache()
        model = models_dict[model_name]().to(device)
        model.eval()
        try:
            model.load_state_dict(torch.load(f'best_{model_name}.pth', map_location=device))
        except Exception as e:
            # Best-effort: inference continues with whatever weights the
            # factory produced, so the message below must not be overlooked.
            print(f"[{model_name}] checkpoint not loaded: {e}")

        preds = []
        with torch.no_grad():
            for i in range(0, len(patch_batch), batch_size):
                batch = patch_batch[i:i+batch_size].to(device)
                logits = model(batch)
                probs = torch.softmax(logits, dim=1)
                pred_class = probs.argmax(dim=1).cpu().numpy()
                preds.append(pred_class)
                del batch, logits, probs
                torch.cuda.empty_cache()
        preds = np.concatenate(preds, axis=0)  # [num_patches]

        # Fold the flat predictions back into the 2D patch grid
        results_dict[model_name] = preds.reshape(map_H, map_W)
        del model
        torch.cuda.empty_cache()

    del patch_batch
    torch.cuda.empty_cache()
    return results_dict, patch_coords, (map_H, map_W)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

def _patch_class_canvas(image, class_map, patch_h, patch_w, stride_h, stride_w):
    """Paint every patch prediction onto an image-sized integer canvas."""
    H, W = image.shape[:2]
    map_H, map_W = class_map.shape
    vis_map = np.zeros((H, W), dtype=np.int32)
    # Patches are painted in row-major order, so where patches overlap the
    # later one overwrites the earlier one.
    for i in range(map_H):
        for j in range(map_W):
            y = i * stride_h
            x = j * stride_w
            vis_map[y:y+patch_h, x:x+patch_w] = class_map[i, j]
    return vis_map


def _draw_classification_figure(
    image, class_map, patch_h, patch_w, stride_h, stride_w, idx2label, model_name
):
    """Draw the two-panel figure (original | predicted class map) as the current figure."""
    vis_map = _patch_class_canvas(image, class_map, patch_h, patch_w, stride_h, stride_w)
    cmap = plt.get_cmap('tab20' if class_map.max() < 20 else 'nipy_spectral')

    plt.figure(figsize=(18, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"{model_name}: Original")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    im = plt.imshow(vis_map, cmap=cmap, vmin=0, vmax=class_map.max())
    plt.title(f"{model_name}: Predicted Classes Map")
    plt.axis('off')
    cbar = plt.colorbar(im, fraction=0.046, pad=0.04)
    if idx2label:
        # One tick per class index from 0 up to the largest predicted class
        labels = [idx2label[i] for i in range(class_map.max() + 1)]
        cbar.set_ticks(range(len(labels)))
        cbar.set_ticklabels(labels)
    plt.tight_layout()


def show_classification_map(
    image, class_map, patch_h, patch_w, stride_h, stride_w, idx2label=None, model_name="Model"
):
    """
    Show the original image next to a pseudo-colour map of predicted classes.

    Args:
        image: array whose first two dimensions are ``(H, W)``.
        class_map: 2D array of class indices, one per patch.
        patch_h, patch_w, stride_h, stride_w: patch geometry used to expand
            the patch grid back to image coordinates.
        idx2label: optional mapping ``class index -> name`` for the colorbar.
        model_name: prefix for the panel titles.
    """
    _draw_classification_figure(
        image, class_map, patch_h, patch_w, stride_h, stride_w, idx2label, model_name
    )
    plt.show()

def save_classification_map_image(
    image, class_map, patch_h, patch_w, stride_h, stride_w, idx2label=None,
    model_name="Model", save_dir="./inference_results", img_id=0
):
    """
    Save the same two-panel figure (original image | predicted class map) as a PNG.

    The file is written to ``{save_dir}/{model_name}_classmap_{img_id}.png``;
    ``save_dir`` is created when missing. The remaining arguments have the
    same meaning as in ``show_classification_map``.
    """
    os.makedirs(save_dir, exist_ok=True)
    _draw_classification_figure(
        image, class_map, patch_h, patch_w, stride_h, stride_w, idx2label, model_name
    )
    fname = os.path.join(save_dir, f"{model_name}_classmap_{img_id}.png")
    plt.savefig(fname, bbox_inches='tight', pad_inches=0.1)
    plt.close()
    print(f"Saved: {fname}")


In [None]:
# --- Patch preprocessing ---
preprocess_albu = A.Compose([
    A.Resize(PATCH_H, PATCH_W),  # 224, 224
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])
def preprocess_patch(patch):
    """Run ``preprocess_albu`` (resize, normalize, to-tensor) on one patch and return its ``'image'`` entry."""
    return preprocess_albu(image=patch)['image']

# --- Pick an image ---
IMG_DIR = Path('./aitex_data/extracted/Defect_images')
img_files = sorted(IMG_DIR.glob('*.png'))
if not img_files:
    # Fail with a clear message instead of an IndexError on img_files[0]
    raise FileNotFoundError(f"No PNG images found in {IMG_DIR}")
img_path = img_files[0]

img = cv2.imread(str(img_path))
if img is None:
    # cv2.imread signals an unreadable file by returning None, not by raising
    raise RuntimeError(f"Could not read image: {img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# --- Inference ---
result_classmaps, patch_coords, (map_H, map_W) = infer_full_image_classification_with_models(
    img,
    models_dict=MODELS,
    preprocess=preprocess_patch,
    patch_h=PATCH_H,
    patch_w=PATCH_W,
    stride_h=STRIDE_H,
    stride_w=STRIDE_W,
    device=device,
    config=CONFIG
)

# --- Visualization ---
for model_name, class_map in result_classmaps.items():
    show_classification_map(
        img, class_map,
        patch_h=PATCH_H, patch_w=PATCH_W,
        stride_h=STRIDE_H, stride_w=STRIDE_W,
        idx2label=train_ds.idx2label,  # or a custom {int: str} mapping
        model_name=model_name
    )

# Save one figure per model; img_id is the model's position in the results
save_dir = "./inference_results"
for idx, (model_name, class_map) in enumerate(result_classmaps.items()):
    save_classification_map_image(
        img, class_map,
        patch_h=PATCH_H, patch_w=PATCH_W,
        stride_h=STRIDE_H, stride_w=STRIDE_W,
        idx2label=train_ds.idx2label,  # or a custom mapping
        model_name=model_name,
        save_dir=save_dir,
        img_id=idx
    )

In [None]:

# Per-model summary: how many patches of the image were assigned to each class
for model_name, class_map in result_classmaps.items():
    uniques, counts = np.unique(class_map, return_counts=True)
    print(f"Model: {model_name}")
    for idx, cnt in zip(uniques, counts):
        print(f"  {train_ds.idx2label[idx]}: {cnt} –ø–∞—Ç—á–µ–π")
