In [2]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Input, Dropout, BatchNormalization
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import RandomFlip, RandomRotation, RandomZoom
from sklearn.preprocessing import MultiLabelBinarizer
from PIL import Image
from sklearn.metrics import precision_recall_curve

# Конфигурация
DATASET_PATH = 'subcategory_images\dataset'
IMAGE_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 30  # Увеличили количество эпох
LEARNING_RATE = 0.0005  # Уменьшили learning rate
MODEL_SAVE_PATH = 'multi_label_classifier_improved.h5'

# Сопоставление подкатегорий с основными категориями
SUBCAT_TO_CAT = {
    'ak': 'Страйкбольное оружие',
    'backpack': 'Снаряжение и защита',
    'helmet': 'Снаряжение и защита',
    'HK': 'Страйкбольное оружие',
    'M serias': 'Страйкбольное оружие',
    'mashinegun': 'Страйкбольное оружие',
    'pistol': 'Страйкбольное оружие',
    'pouch': 'Аксессуары и Запчасти',
    'rifle': 'Страйкбольное оружие',
    'shutgun': 'Страйкбольное оружие',
    'vest': 'Снаряжение и защита'
}

# Получаем список всех подкатегорий и категорий
subcategories = sorted(list(SUBCAT_TO_CAT.keys()))
categories = sorted(list(set(SUBCAT_TO_CAT.values())))

# Инициализация MultiLabelBinarizer
mlb_subcat = MultiLabelBinarizer()
mlb_subcat.fit([subcategories])
mlb_cat = MultiLabelBinarizer()
mlb_cat.fit([categories])

# Аугментация данных
data_augmentation = tf.keras.Sequential([
    RandomFlip("horizontal"),
    RandomRotation(0.1),
    RandomZoom(0.1),
])

def load_and_preprocess_data(dataset_path):
    image_paths = []
    subcat_labels = []
    cat_labels = []
    
    for subcat in subcategories:
        subcat_path = os.path.join(dataset_path, subcat)
        if not os.path.exists(subcat_path):
            continue
            
        cat = SUBCAT_TO_CAT[subcat]
        for image_name in os.listdir(subcat_path):
            image_path = os.path.join(subcat_path, image_name)
            image_paths.append(image_path)
            subcat_labels.append([subcat])
            cat_labels.append([cat])
    
    subcat_labels = mlb_subcat.transform(subcat_labels)
    cat_labels = mlb_cat.transform(cat_labels)
    
    return image_paths, subcat_labels, cat_labels

# Загрузка данных
image_paths, subcat_labels, cat_labels = load_and_preprocess_data(DATASET_PATH)

# Создание tf.data.Dataset с аугментацией
def process_data(item):
    image = tf.io.read_file(item['image_path'])
    image = tf.image.decode_jpeg(image, channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)
    image = data_augmentation(image)  # Применяем аугментацию
    image = tf.keras.applications.efficientnet.preprocess_input(image)
    return image, (item['subcat_label'], item['cat_label'])

dataset = tf.data.Dataset.from_tensor_slices({
    'image_path': image_paths,
    'subcat_label': subcat_labels,
    'cat_label': cat_labels
})

dataset = dataset.map(process_data, num_parallel_calls=tf.data.AUTOTUNE)
dataset = dataset.shuffle(buffer_size=len(image_paths))
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

# Разделение на train/validation
train_size = int(0.8 * len(image_paths))
train_dataset = dataset.take(train_size)
val_dataset = dataset.skip(train_size)

# Обновленная версия f1_score
def f1_score(y_true, y_pred, threshold=0.5):
    y_pred = tf.cast(y_pred > threshold, tf.float32)
    y_true = tf.cast(y_true, tf.float32)  # Добавляем приведение типа
    tp = tf.reduce_sum(y_true * y_pred)
    precision = tp / (tf.reduce_sum(y_pred) + 1e-7)
    recall = tp / (tf.reduce_sum(y_true) + 1e-7)
    return 2 * (precision * recall) / (precision + recall + 1e-7)

# Создание модели с улучшенной архитектурой
def create_model(num_subcategories, num_categories):
    base_model = EfficientNetB0(weights='imagenet', include_top=False, input_shape=(*IMAGE_SIZE, 3))
    
    # Замораживаем первые 150 слоев
    for layer in base_model.layers[:150]:
        layer.trainable = False
    
    inputs = Input(shape=(*IMAGE_SIZE, 3))
    x = base_model(inputs)
    x = GlobalAveragePooling2D()(x)
    
    # Улучшенная головная часть
    x = Dense(512, activation='relu')(x)
    x = BatchNormalization()(x)
    x = Dropout(0.5)(x)
    
    subcat_output = Dense(num_subcategories, activation='sigmoid', name='subcat_output')(x)
    cat_output = Dense(num_categories, activation='sigmoid', name='cat_output')(x)
    
    model = Model(inputs=inputs, outputs=[subcat_output, cat_output])
    
    model.compile(
        optimizer=Adam(learning_rate=LEARNING_RATE),
        loss={
            'subcat_output': 'binary_crossentropy',
            'cat_output': 'binary_crossentropy'
        },
        metrics={
            'subcat_output': ['accuracy', f1_score],
            'cat_output': ['accuracy', f1_score]
        }
    )
    
    return model

model = create_model(len(subcategories), len(categories))
model.summary()

# Коллбэки
callbacks = [
    ModelCheckpoint(
        MODEL_SAVE_PATH, 
        save_best_only=True, 
        monitor='val_subcat_output_f1_score', 
        mode='max'
    ),
    EarlyStopping(
        patience=7, 
        restore_best_weights=True, 
        monitor='val_subcat_output_f1_score',
        mode='max'  # Явно указываем, что хотим максимизировать F1-score
    ),
    ReduceLROnPlateau(
        monitor='val_loss', 
        factor=0.2, 
        patience=3, 
        min_lr=1e-6,
        mode='min'  # Для val_loss используем минимизацию
    )
]

# Обучение
history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=EPOCHS,
    callbacks=callbacks
)

# Функция предсказания с TTA (Test-Time Augmentation)
def predict_with_tta(image_path, model, tta_steps=5, threshold=0.5):
    img = Image.open(image_path).convert('RGB')
    img = img.resize(IMAGE_SIZE)
    img_array = np.array(img)
    img_array = tf.keras.applications.efficientnet.preprocess_input(img_array)
    
    subcat_preds = []
    cat_preds = []
    
    for _ in range(tta_steps):
        augmented_img = data_augmentation(img_array)
        augmented_img = np.expand_dims(augmented_img, axis=0)
        subcat_pred, cat_pred = model.predict(augmented_img)
        subcat_preds.append(subcat_pred)
        cat_preds.append(cat_pred)
    
    subcat_pred_avg = np.mean(subcat_preds, axis=0)
    cat_pred_avg = np.mean(cat_preds, axis=0)
    
    subcat_indices = np.where(subcat_pred_avg[0] > threshold)[0]
    cat_indices = np.where(cat_pred_avg[0] > threshold)[0]
    
    predicted_subcats = mlb_subcat.classes_[subcat_indices]
    predicted_cats = mlb_cat.classes_[cat_indices]
    
    subcat_probs = {subcat: float(subcat_pred_avg[0][i]) for i, subcat in enumerate(mlb_subcat.classes_)}
    cat_probs = {cat: float(cat_pred_avg[0][i]) for i, cat in enumerate(mlb_cat.classes_)}
    
    return {
        'categories': list(predicted_cats),
        'subcategories': list(predicted_subcats),
        'category_probabilities': cat_probs,
        'subcategory_probabilities': subcat_probs
    }

# Пример использования
if __name__ == "__main__":
    model = tf.keras.models.load_model(MODEL_SAVE_PATH, custom_objects={'f1_score': f1_score})
    result = predict_with_tta('тест\дроб3.jpeg', model)
    print(result)

  DATASET_PATH = 'subcategory_images\dataset'


Epoch 1/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.7495 - cat_output_f1_score: 0.7036 - cat_output_loss: 0.5612 - loss: 1.5045 - subcat_output_accuracy: 0.3029 - subcat_output_f1_score: 0.2168 - subcat_output_loss: 0.9433

  if self._should_save_model(epoch, batch, logs, filepath):


[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m93s[0m 2s/step - cat_output_accuracy: 0.7535 - cat_output_f1_score: 0.7062 - cat_output_loss: 0.5566 - loss: 1.4968 - subcat_output_accuracy: 0.3079 - subcat_output_f1_score: 0.2183 - subcat_output_loss: 0.9403 - learning_rate: 5.0000e-04
Epoch 2/30


  current = self.get_monitor_value(logs)
  callback.on_epoch_end(epoch, logs)


[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9696 - cat_output_f1_score: 0.9058 - cat_output_loss: 0.1908 - loss: 0.8738 - subcat_output_accuracy: 0.7343 - subcat_output_f1_score: 0.3351 - subcat_output_loss: 0.6830



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m60s[0m 2s/step - cat_output_accuracy: 0.9695 - cat_output_f1_score: 0.9059 - cat_output_loss: 0.1908 - loss: 0.8728 - subcat_output_accuracy: 0.7357 - subcat_output_f1_score: 0.3355 - subcat_output_loss: 0.6820 - learning_rate: 5.0000e-04
Epoch 3/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9871 - cat_output_f1_score: 0.9439 - cat_output_loss: 0.1164 - loss: 0.6610 - subcat_output_accuracy: 0.8289 - subcat_output_f1_score: 0.3938 - subcat_output_loss: 0.5445



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 2s/step - cat_output_accuracy: 0.9869 - cat_output_f1_score: 0.9440 - cat_output_loss: 0.1163 - loss: 0.6606 - subcat_output_accuracy: 0.8289 - subcat_output_f1_score: 0.3940 - subcat_output_loss: 0.5442 - learning_rate: 5.0000e-04
Epoch 4/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9882 - cat_output_f1_score: 0.9767 - cat_output_loss: 0.0756 - loss: 0.5471 - subcat_output_accuracy: 0.8710 - subcat_output_f1_score: 0.4574 - subcat_output_loss: 0.4715



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 2s/step - cat_output_accuracy: 0.9884 - cat_output_f1_score: 0.9766 - cat_output_loss: 0.0755 - loss: 0.5467 - subcat_output_accuracy: 0.8712 - subcat_output_f1_score: 0.4576 - subcat_output_loss: 0.4711 - learning_rate: 5.0000e-04
Epoch 5/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9926 - cat_output_f1_score: 0.9821 - cat_output_loss: 0.0452 - loss: 0.4433 - subcat_output_accuracy: 0.9095 - subcat_output_f1_score: 0.5070 - subcat_output_loss: 0.3981



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 2s/step - cat_output_accuracy: 0.9926 - cat_output_f1_score: 0.9821 - cat_output_loss: 0.0454 - loss: 0.4431 - subcat_output_accuracy: 0.9093 - subcat_output_f1_score: 0.5073 - subcat_output_loss: 0.3978 - learning_rate: 5.0000e-04
Epoch 6/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9892 - cat_output_f1_score: 0.9808 - cat_output_loss: 0.0620 - loss: 0.4149 - subcat_output_accuracy: 0.9242 - subcat_output_f1_score: 0.5498 - subcat_output_loss: 0.3530



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m61s[0m 2s/step - cat_output_accuracy: 0.9892 - cat_output_f1_score: 0.9809 - cat_output_loss: 0.0617 - loss: 0.4141 - subcat_output_accuracy: 0.9241 - subcat_output_f1_score: 0.5503 - subcat_output_loss: 0.3524 - learning_rate: 5.0000e-04
Epoch 7/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9970 - cat_output_f1_score: 0.9917 - cat_output_loss: 0.0333 - loss: 0.3213 - subcat_output_accuracy: 0.9468 - subcat_output_f1_score: 0.6153 - subcat_output_loss: 0.2880



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 2s/step - cat_output_accuracy: 0.9970 - cat_output_f1_score: 0.9917 - cat_output_loss: 0.0332 - loss: 0.3209 - subcat_output_accuracy: 0.9465 - subcat_output_f1_score: 0.6156 - subcat_output_loss: 0.2877 - learning_rate: 5.0000e-04
Epoch 8/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9978 - cat_output_f1_score: 0.9960 - cat_output_loss: 0.0198 - loss: 0.2641 - subcat_output_accuracy: 0.9566 - subcat_output_f1_score: 0.6867 - subcat_output_loss: 0.2443



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 2s/step - cat_output_accuracy: 0.9978 - cat_output_f1_score: 0.9960 - cat_output_loss: 0.0199 - loss: 0.2640 - subcat_output_accuracy: 0.9562 - subcat_output_f1_score: 0.6870 - subcat_output_loss: 0.2441 - learning_rate: 5.0000e-04
Epoch 9/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9940 - cat_output_f1_score: 0.9841 - cat_output_loss: 0.0515 - loss: 0.2662 - subcat_output_accuracy: 0.9388 - subcat_output_f1_score: 0.7161 - subcat_output_loss: 0.2147



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 2s/step - cat_output_accuracy: 0.9940 - cat_output_f1_score: 0.9842 - cat_output_loss: 0.0514 - loss: 0.2661 - subcat_output_accuracy: 0.9390 - subcat_output_f1_score: 0.7163 - subcat_output_loss: 0.2147 - learning_rate: 5.0000e-04
Epoch 10/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9986 - cat_output_f1_score: 0.9872 - cat_output_loss: 0.0340 - loss: 0.2167 - subcat_output_accuracy: 0.9437 - subcat_output_f1_score: 0.7580 - subcat_output_loss: 0.1826



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m62s[0m 2s/step - cat_output_accuracy: 0.9986 - cat_output_f1_score: 0.9874 - cat_output_loss: 0.0338 - loss: 0.2161 - subcat_output_accuracy: 0.9439 - subcat_output_f1_score: 0.7584 - subcat_output_loss: 0.1824 - learning_rate: 5.0000e-04
Epoch 11/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9997 - cat_output_f1_score: 0.9954 - cat_output_loss: 0.0243 - loss: 0.1815 - subcat_output_accuracy: 0.9767 - subcat_output_f1_score: 0.8057 - subcat_output_loss: 0.1573



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 2s/step - cat_output_accuracy: 0.9996 - cat_output_f1_score: 0.9953 - cat_output_loss: 0.0246 - loss: 0.1816 - subcat_output_accuracy: 0.9768 - subcat_output_f1_score: 0.8063 - subcat_output_loss: 0.1570 - learning_rate: 5.0000e-04
Epoch 12/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9989 - cat_output_f1_score: 0.9859 - cat_output_loss: 0.0307 - loss: 0.1686 - subcat_output_accuracy: 0.9631 - subcat_output_f1_score: 0.8362 - subcat_output_loss: 0.1379



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m64s[0m 2s/step - cat_output_accuracy: 0.9989 - cat_output_f1_score: 0.9860 - cat_output_loss: 0.0305 - loss: 0.1682 - subcat_output_accuracy: 0.9631 - subcat_output_f1_score: 0.8364 - subcat_output_loss: 0.1377 - learning_rate: 5.0000e-04
Epoch 13/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9927 - cat_output_f1_score: 0.9817 - cat_output_loss: 0.0497 - loss: 0.1635 - subcat_output_accuracy: 0.9582 - subcat_output_f1_score: 0.8851 - subcat_output_loss: 0.1138



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 2s/step - cat_output_accuracy: 0.9928 - cat_output_f1_score: 0.9820 - cat_output_loss: 0.0490 - loss: 0.1624 - subcat_output_accuracy: 0.9583 - subcat_output_f1_score: 0.8852 - subcat_output_loss: 0.1135 - learning_rate: 5.0000e-04
Epoch 14/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 1.0000 - cat_output_f1_score: 0.9999 - cat_output_loss: 0.0108 - loss: 0.1073 - subcat_output_accuracy: 0.9700 - subcat_output_f1_score: 0.8985 - subcat_output_loss: 0.0965



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m72s[0m 2s/step - cat_output_accuracy: 1.0000 - cat_output_f1_score: 0.9999 - cat_output_loss: 0.0108 - loss: 0.1073 - subcat_output_accuracy: 0.9699 - subcat_output_f1_score: 0.8985 - subcat_output_loss: 0.0965 - learning_rate: 5.0000e-04
Epoch 15/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9970 - cat_output_f1_score: 0.9946 - cat_output_loss: 0.0216 - loss: 0.1052 - subcat_output_accuracy: 0.9793 - subcat_output_f1_score: 0.9183 - subcat_output_loss: 0.0836



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 2s/step - cat_output_accuracy: 0.9970 - cat_output_f1_score: 0.9945 - cat_output_loss: 0.0219 - loss: 0.1055 - subcat_output_accuracy: 0.9791 - subcat_output_f1_score: 0.9183 - subcat_output_loss: 0.0836 - learning_rate: 5.0000e-04
Epoch 16/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9949 - cat_output_f1_score: 0.9940 - cat_output_loss: 0.0252 - loss: 0.1049 - subcat_output_accuracy: 0.9761 - subcat_output_f1_score: 0.8979 - subcat_output_loss: 0.0797



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 2s/step - cat_output_accuracy: 0.9950 - cat_output_f1_score: 0.9941 - cat_output_loss: 0.0249 - loss: 0.1044 - subcat_output_accuracy: 0.9760 - subcat_output_f1_score: 0.8982 - subcat_output_loss: 0.0795 - learning_rate: 5.0000e-04
Epoch 17/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9994 - cat_output_f1_score: 0.9937 - cat_output_loss: 0.0179 - loss: 0.0884 - subcat_output_accuracy: 0.9818 - subcat_output_f1_score: 0.9219 - subcat_output_loss: 0.0705



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9993 - cat_output_f1_score: 0.9937 - cat_output_loss: 0.0179 - loss: 0.0884 - subcat_output_accuracy: 0.9817 - subcat_output_f1_score: 0.9218 - subcat_output_loss: 0.0705 - learning_rate: 5.0000e-04
Epoch 18/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9982 - cat_output_f1_score: 0.9972 - cat_output_loss: 0.0150 - loss: 0.0744 - subcat_output_accuracy: 0.9795 - subcat_output_f1_score: 0.9457 - subcat_output_loss: 0.0594



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9981 - cat_output_f1_score: 0.9971 - cat_output_loss: 0.0151 - loss: 0.0745 - subcat_output_accuracy: 0.9794 - subcat_output_f1_score: 0.9456 - subcat_output_loss: 0.0593 - learning_rate: 5.0000e-04
Epoch 19/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9932 - cat_output_f1_score: 0.9922 - cat_output_loss: 0.0216 - loss: 0.0727 - subcat_output_accuracy: 0.9731 - subcat_output_f1_score: 0.9414 - subcat_output_loss: 0.0511



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9931 - cat_output_f1_score: 0.9922 - cat_output_loss: 0.0217 - loss: 0.0728 - subcat_output_accuracy: 0.9731 - subcat_output_f1_score: 0.9416 - subcat_output_loss: 0.0511 - learning_rate: 5.0000e-04
Epoch 20/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 1.0000 - cat_output_f1_score: 0.9995 - cat_output_loss: 0.0063 - loss: 0.0443 - subcat_output_accuracy: 0.9919 - subcat_output_f1_score: 0.9735 - subcat_output_loss: 0.0380



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 1.0000 - cat_output_f1_score: 0.9994 - cat_output_loss: 0.0064 - loss: 0.0444 - subcat_output_accuracy: 0.9918 - subcat_output_f1_score: 0.9734 - subcat_output_loss: 0.0381 - learning_rate: 5.0000e-04
Epoch 21/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9944 - cat_output_f1_score: 0.9936 - cat_output_loss: 0.0136 - loss: 0.0517 - subcat_output_accuracy: 0.9916 - subcat_output_f1_score: 0.9675 - subcat_output_loss: 0.0381



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m55s[0m 2s/step - cat_output_accuracy: 0.9944 - cat_output_f1_score: 0.9936 - cat_output_loss: 0.0138 - loss: 0.0518 - subcat_output_accuracy: 0.9915 - subcat_output_f1_score: 0.9675 - subcat_output_loss: 0.0380 - learning_rate: 5.0000e-04
Epoch 22/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9971 - cat_output_f1_score: 0.9945 - cat_output_loss: 0.0126 - loss: 0.0468 - subcat_output_accuracy: 0.9902 - subcat_output_f1_score: 0.9718 - subcat_output_loss: 0.0342



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9972 - cat_output_f1_score: 0.9945 - cat_output_loss: 0.0127 - loss: 0.0469 - subcat_output_accuracy: 0.9900 - subcat_output_f1_score: 0.9717 - subcat_output_loss: 0.0342 - learning_rate: 5.0000e-04
Epoch 23/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9991 - cat_output_f1_score: 0.9979 - cat_output_loss: 0.0097 - loss: 0.0430 - subcat_output_accuracy: 0.9837 - subcat_output_f1_score: 0.9700 - subcat_output_loss: 0.0333



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9991 - cat_output_f1_score: 0.9979 - cat_output_loss: 0.0099 - loss: 0.0431 - subcat_output_accuracy: 0.9839 - subcat_output_f1_score: 0.9702 - subcat_output_loss: 0.0332 - learning_rate: 5.0000e-04
Epoch 24/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9988 - cat_output_f1_score: 0.9965 - cat_output_loss: 0.0139 - loss: 0.0463 - subcat_output_accuracy: 0.9822 - subcat_output_f1_score: 0.9697 - subcat_output_loss: 0.0324



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9988 - cat_output_f1_score: 0.9965 - cat_output_loss: 0.0138 - loss: 0.0462 - subcat_output_accuracy: 0.9823 - subcat_output_f1_score: 0.9697 - subcat_output_loss: 0.0324 - learning_rate: 5.0000e-04
Epoch 25/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9992 - cat_output_f1_score: 0.9982 - cat_output_loss: 0.0092 - loss: 0.0376 - subcat_output_accuracy: 0.9903 - subcat_output_f1_score: 0.9730 - subcat_output_loss: 0.0285



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9991 - cat_output_f1_score: 0.9982 - cat_output_loss: 0.0092 - loss: 0.0376 - subcat_output_accuracy: 0.9902 - subcat_output_f1_score: 0.9733 - subcat_output_loss: 0.0284 - learning_rate: 5.0000e-04
Epoch 26/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9996 - cat_output_f1_score: 0.9988 - cat_output_loss: 0.0060 - loss: 0.0259 - subcat_output_accuracy: 0.9948 - subcat_output_f1_score: 0.9890 - subcat_output_loss: 0.0199



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9996 - cat_output_f1_score: 0.9988 - cat_output_loss: 0.0061 - loss: 0.0260 - subcat_output_accuracy: 0.9947 - subcat_output_f1_score: 0.9889 - subcat_output_loss: 0.0200 - learning_rate: 5.0000e-04
Epoch 27/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9968 - cat_output_f1_score: 0.9963 - cat_output_loss: 0.0115 - loss: 0.0349 - subcat_output_accuracy: 0.9920 - subcat_output_f1_score: 0.9756 - subcat_output_loss: 0.0234



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9969 - cat_output_f1_score: 0.9963 - cat_output_loss: 0.0114 - loss: 0.0348 - subcat_output_accuracy: 0.9921 - subcat_output_f1_score: 0.9758 - subcat_output_loss: 0.0234 - learning_rate: 5.0000e-04
Epoch 28/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9987 - cat_output_f1_score: 0.9986 - cat_output_loss: 0.0042 - loss: 0.0218 - subcat_output_accuracy: 0.9951 - subcat_output_f1_score: 0.9890 - subcat_output_loss: 0.0176



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9987 - cat_output_f1_score: 0.9986 - cat_output_loss: 0.0042 - loss: 0.0218 - subcat_output_accuracy: 0.9951 - subcat_output_f1_score: 0.9889 - subcat_output_loss: 0.0176 - learning_rate: 5.0000e-04
Epoch 29/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 0.9968 - cat_output_f1_score: 0.9978 - cat_output_loss: 0.0079 - loss: 0.0270 - subcat_output_accuracy: 0.9944 - subcat_output_f1_score: 0.9849 - subcat_output_loss: 0.0190



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 0.9969 - cat_output_f1_score: 0.9978 - cat_output_loss: 0.0078 - loss: 0.0268 - subcat_output_accuracy: 0.9944 - subcat_output_f1_score: 0.9849 - subcat_output_loss: 0.0190 - learning_rate: 5.0000e-04
Epoch 30/30
[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 2s/step - cat_output_accuracy: 1.0000 - cat_output_f1_score: 0.9958 - cat_output_loss: 0.0112 - loss: 0.0278 - subcat_output_accuracy: 0.9963 - subcat_output_f1_score: 0.9861 - subcat_output_loss: 0.0166



[1m29/29[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m54s[0m 2s/step - cat_output_accuracy: 1.0000 - cat_output_f1_score: 0.9959 - cat_output_loss: 0.0111 - loss: 0.0277 - subcat_output_accuracy: 0.9962 - subcat_output_f1_score: 0.9859 - subcat_output_loss: 0.0166 - learning_rate: 5.0000e-04




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 2s/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 70ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 71ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 70ms/step
{'categories': ['Страйкбольное оружие'], 'subcategories': ['rifle'], 'category_probabilities': {'Аксессуары и Запчасти': 0.0036617859732359648, 'Снаряжение и защита': 0.0009852994699031115, 'Страйкбольное оружие': 0.9983514547348022}, 'subcategory_probabilities': {'HK': 0.006292761769145727, 'M serias': 0.006835001055151224, 'ak': 0.00799906812608242, 'backpack': 0.011783991940319538, 'helmet': 0.0028981147333979607, 'mashinegun': 0.00692513445392251, 'pistol': 0.0034643891267478466, 'pouch': 0.0049142478965222836, 'rifle': 0.9984733462333679, 'shutgun': 0.030544739216566086, 'vest': 0.0042833611369132996}}


In [7]:
result = predict_with_tta('тест\мка ар.webp', model)
print(result)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 73ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 68ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 73ms/step
{'categories': ['Страйкбольное оружие'], 'subcategories': ['M serias'], 'category_probabilities': {'Аксессуары и Запчасти': 0.0018220817437395453, 'Снаряжение и защита': 0.0002366189582971856, 'Страйкбольное оружие': 0.999935507774353}, 'subcategory_probabilities': {'HK': 0.05416689068078995, 'M serias': 0.5270873308181763, 'ak': 0.0765266865491867, 'backpack': 0.011207496747374535, 'helmet': 0.002302464796230197, 'mashinegun': 0.0039348965510725975, 'pistol': 0.0019907939713448286, 'pouch': 0.00211693299934268, 'rifle': 0.023626241832971573, 'shutgun': 0.010622268542647362, 'vest': 0.0013050779234617949}}
