In [None]:
# IMPORTS, FUNCTIONS AND CLASSES
import os
import glob
import h5py
import torch
import random
import pickle
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from IPython.display import clear_output
from IPython.display import Audio, display
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
import tqdm

# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  

# dataset class
class AudioDataset(Dataset):
    """Spectrogram / label dataset backed by an HDF5 file.

    The file must contain the datasets ``spectrograms`` and ``labels`` and
    the attributes ``class_names`` and ``num_classes``.

    The HDF5 handle is opened lazily and separately in every process.
    An h5py handle must not be shared between a parent process and its
    DataLoader workers, and it cannot be pickled, so the handle is never
    stored in the pickled state and is re-opened whenever the current
    process id differs from the one that opened it.
    """

    def __init__(self, h5_path):
        self.h5_path = h5_path
        self._file = None        # opened on first access, once per process
        self._owner_pid = None   # pid of the process that opened ``_file``

        # Read the metadata through a short-lived handle so nothing
        # process-specific is kept on the instance.
        with h5py.File(h5_path, 'r') as meta:
            self._length = len(meta['spectrograms'])
            # Attribute strings may come back as bytes or as str.
            self.class_names = [
                name.decode('utf-8') if isinstance(name, bytes) else str(name)
                for name in meta.attrs['class_names']
            ]
            self.num_classes = int(meta.attrs['num_classes'])

        print(f"Loaded {self._length} samples from HDF5")

    def _ensure_open(self):
        """Return an HDF5 handle that belongs to the current process."""
        pid = os.getpid()
        if self._file is None or self._owner_pid != pid:
            self._file = h5py.File(self.h5_path, 'r')
            self._owner_pid = pid
        return self._file

    # The three properties keep the attribute names of the eager version.
    @property
    def file(self):
        return self._ensure_open()

    @property
    def spectrograms(self):
        return self._ensure_open()['spectrograms']

    @property
    def labels(self):
        return self._ensure_open()['labels']

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        # Read one sample directly from HDF5.
        spectrogram = torch.tensor(self.spectrograms[idx], dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return spectrogram, label

    def __getstate__(self):
        # Drop the open handle so the dataset can be pickled (e.g. sent to
        # spawned DataLoader workers); it is re-opened on demand.
        state = self.__dict__.copy()
        state['_file'] = None
        state['_owner_pid'] = None
        return state

    def __del__(self):
        handle = getattr(self, '_file', None)
        if handle is not None:
            try:
                handle.close()
            except Exception:
                # Best-effort cleanup: during interpreter shutdown the h5py
                # machinery may already be torn down.
                pass

    def get_class_names(self):
        """Return the list of class names read from the file attributes."""
        return self.class_names
# loss function (for testing)
class FocalLoss(nn.Module):
    """Focal loss computed on raw logits.

    Each element's binary cross-entropy is scaled by
    ``alpha * (1 - exp(-bce)) ** gamma`` and the result is averaged over
    all elements.
    """

    def __init__(self, alpha=1, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        """Return the mean focal loss for ``inputs`` (logits) vs ``targets``."""
        per_element_bce = F.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        # exp(-bce) is the probability the model assigns to the true label.
        confidence = torch.exp(-per_element_bce)
        modulator = (1 - confidence) ** self.gamma
        return (self.alpha * modulator * per_element_bce).mean()
#CNN + classifier    
class AudioCNN(nn.Module):
    """Small three-block CNN followed by a two-layer classifier head.

    Args:
        num_classes: number of output logits.
        input_channels: number of channels of the input spectrogram.

    The adaptive pooling at the end of the convolutional stack fixes the
    feature map to 64 x 4 x 8, so the spatial size of the input may vary.
    Shape comments below are for a 1 x 128 x 512 input.
    """

    def __init__(self, num_classes, input_channels=1):

        super(AudioCNN, self).__init__()

        # CNN
        self.conv_layers = nn.Sequential(

            # block 1
            nn.Conv2d(input_channels, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),    # [16, 64, 256]

            # block 2
            nn.Conv2d(16, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),    # [32, 32, 128]

            # block 3
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((4, 8)),    # [64, 4, 8]
        )

        # FC
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 4 * 8, 128),
            nn.ReLU(),

            nn.Linear(128, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal init for conv layers, Xavier-uniform for linear layers; biases set to zero."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``."""
        x = self.conv_layers(x)
        x = self.classifier(x)
        return x
    
# fully connected model
class SimpleFCModel(nn.Module):
    """Fully connected baseline: flatten, two hidden layers with dropout, linear output.

    The first linear layer expects ``input_height * input_width`` features
    after flattening.
    """

    def __init__(self, num_classes: int, input_height: int, input_width: int, hidden_size: int = 256):
        super().__init__()

        self.input_height = input_height
        self.input_width = input_width
        self.num_classes = num_classes

        # Number of features once the input is flattened to a vector.
        self.flatten_size = input_height * input_width

        half_hidden = hidden_size // 2
        stack = [
            nn.Flatten(),
            # first hidden layer (dropout for regularisation)
            nn.Linear(self.flatten_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.5),
            # second hidden layer
            nn.Linear(hidden_size, half_hidden),
            nn.ReLU(),
            nn.Dropout(0.4),
            # output layer
            nn.Linear(half_hidden, num_classes),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return raw logits of shape ``(batch, num_classes)``."""
        return self.layers(x)
# metrics
def calculate_all_metrics(y_true, y_pred, loss, threshold):
    """Collect the loss, macro F1/precision/recall and the threshold in one dict.

    ``loss`` and ``threshold`` are stored unchanged. The three scores are
    macro-averaged, and undefined scores count as 0 (``zero_division=0``).
    """
    macro_kwargs = {'average': 'macro', 'zero_division': 0}
    return {
        'loss': loss,
        # primary metric: balance across classes
        'f1_macro': f1_score(y_true, y_pred, **macro_kwargs),
        # tracks false positives
        'precision_macro': precision_score(y_true, y_pred, **macro_kwargs),
        # tracks false negatives
        'recall_macro': recall_score(y_true, y_pred, **macro_kwargs),
        # decision threshold associated with these metrics by the caller
        'threshold': threshold,
    }

def train_epoch(model, loader, optimizer, criterion, epoch, num_epochs, threshold):
    """Run one training epoch and return its metrics.

    Args:
        model: network to train; put into train mode here.
        loader: iterable of ``(data, targets)`` batches.
        optimizer: optimizer stepped once per batch.
        criterion: loss taking ``(logits, targets)``.
        epoch, num_epochs: only used for the progress-bar caption.
        threshold: sigmoid probabilities above it count as positive predictions.

    Returns:
        The dict built by ``calculate_all_metrics`` from all batches of the
        epoch, with the loss averaged over ``len(loader)`` batches.
    """
    model.train()
    all_targets = []
    all_predictions = []
    all_loss = 0.0

    # Progress bar
    pbar = tqdm.tqdm(loader, desc=f'Epoch {epoch}/{num_epochs}', 
                bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}')

    # Count batches explicitly: when tqdm is used as an iterator, ``pbar.n``
    # is only refreshed when the bar is redrawn, so it lags behind the real
    # batch count and cannot serve as the running-average denominator.
    for batch_idx, (data, targets) in enumerate(pbar, start=1):
        data = data.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, targets)
        all_loss += loss.item()

        loss.backward()
        optimizer.step()

        # Keep predictions and targets for the epoch-level metrics
        with torch.no_grad():
            probs = torch.sigmoid(outputs)
            predictions = (probs > threshold).float()
            all_targets.append(targets.cpu())
            all_predictions.append(predictions.cpu())

        # Show the running mean loss on the progress bar
        current_loss = all_loss / batch_idx
        pbar.set_postfix({'Loss': f'{current_loss:.4f}'})

    all_targets = torch.cat(all_targets).numpy()
    all_predictions = torch.cat(all_predictions).numpy()
    all_loss /= len(loader)

    return calculate_all_metrics(all_targets, all_predictions, all_loss, threshold)

def evaluate_epoch(model, loader, criterion, threshold, auto_threshold):
    """Evaluate the model on ``loader`` and optionally search a better threshold.

    The returned precision/recall/F1 are always computed with the
    ``threshold`` passed in. When ``auto_threshold`` is true, the candidates
    0.30, 0.35, ..., 0.75 are also scored by macro F1 and the best one is
    returned under the ``'threshold'`` key (otherwise the input threshold is
    returned); the caller decides what to do with that value.

    Returns:
        The dict built by ``calculate_all_metrics``.
    """
    model.eval()
    all_targets = []
    all_predictions = []
    all_probs = []
    all_loss = 0.0

    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device)
            targets = targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)
            all_loss += loss.item()

            probs = torch.sigmoid(outputs)
            predictions = (probs > threshold).float()

            all_targets.append(targets.cpu())
            all_predictions.append(predictions.cpu())
            all_probs.append(probs.cpu())

    all_targets = torch.cat(all_targets).numpy()
    all_predictions = torch.cat(all_predictions).numpy()
    all_probs = torch.cat(all_probs)
    all_loss /= len(loader)

    best_threshold = threshold
    best_f1 = 0.0

    # Search for the best threshold. The loop variable is deliberately not
    # called ``threshold`` so the parameter is not shadowed.
    if auto_threshold:
        for candidate in torch.arange(0.3, 0.8, 0.05):
            candidate_predictions = (all_probs > candidate).float().numpy()
            # zero_division=0 matches calculate_all_metrics and avoids a
            # warning for every class that has no positive prediction.
            f1 = f1_score(all_targets, candidate_predictions,
                          average='macro', zero_division=0)
            if f1 > best_f1:
                best_f1 = f1
                best_threshold = candidate.item()

    return calculate_all_metrics(all_targets, all_predictions, all_loss, best_threshold)

def _make_checkpoint(model, optimizer, history, epoch, best_f1, patience_counter, threshold):
    """Bundle everything needed to resume training into one dict."""
    return {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'history': history,
        'epoch': epoch,
        'best_f1': best_f1,
        'patience_counter': patience_counter,
        'threshold': threshold
    }


def train_model(model, train_loader, val_loader, optimizer, criterion, patience=10,
                num_epochs=50, threshold=0.5, save_dir='../data/models', 
                start_epoch=1, history=None, best_f1=0.0, auto_threshold=False):
    """Train with validation, early stopping and checkpointing.

    Args:
        model, optimizer, criterion: objects passed on to ``train_epoch`` /
            ``evaluate_epoch``.
        train_loader, val_loader: batch iterables for training / validation.
        patience: stop after this many consecutive epochs without a new best
            validation macro F1.
        num_epochs: number of epochs to run, counted from ``start_epoch``.
        threshold: initial decision threshold; after every epoch it is
            replaced by the threshold returned by ``evaluate_epoch``.
        save_dir: directory for ``best_model.pth``,
            ``checkpoint_epoch_<n>.pth`` (every 5th epoch) and
            ``last_model.pth``; created if missing.
        start_epoch, history, best_f1: state to continue from when resuming.
        auto_threshold: forwarded to ``evaluate_epoch``.

    Returns:
        ``(history, best_f1)``.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Create the history when none was passed in
    if history is None:
        history = {
            'train_loss': [], 'val_loss': [],
            'train_f1': [], 'val_f1': [],
            'train_precision': [], 'val_precision': [],
            'train_recall': [], 'val_recall': [],
            'train_threshold': [], 'val_threshold': [],
            'epochs_completed': 0,
            'total_epochs': num_epochs,
            'best_f1': 0.0
        }

    patience_counter = 0
    # Keeps ``epoch`` defined for the final checkpoint when the loop below
    # does not run at all (num_epochs == 0).
    epoch = start_epoch - 1
    last_epoch = start_epoch + num_epochs - 1

    for epoch in range(start_epoch, start_epoch + num_epochs):
        # Training and validation
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, epoch, last_epoch, threshold)
        val_metrics = evaluate_epoch(model, val_loader, criterion, threshold, auto_threshold)

        # Record the history
        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_f1'].append(train_metrics['f1_macro'])
        history['val_f1'].append(val_metrics['f1_macro'])
        history['train_precision'].append(train_metrics['precision_macro'])
        history['val_precision'].append(val_metrics['precision_macro'])
        history['train_recall'].append(train_metrics['recall_macro'])
        history['val_recall'].append(val_metrics['recall_macro'])
        # setdefault keeps histories that lack these lists usable.
        history.setdefault('train_threshold', []).append(train_metrics['threshold'])
        history.setdefault('val_threshold', []).append(val_metrics['threshold'])
        history['epochs_completed'] = epoch
        history['best_f1'] = max(history.get('best_f1', 0.0), val_metrics['f1_macro'])

        # Print the metrics (both rows were computed with the current threshold)
        print("‚îÄ" * 50)
        print(f"Epoch {epoch}/{last_epoch}")
        print(f"Train >>> Loss: {train_metrics['loss']:.4f} | F1: {train_metrics['f1_macro']:.4f} | "
              f"Precision: {train_metrics['precision_macro']:.4f} | Recall: {train_metrics['recall_macro']:.4f} | "
              f"Threshold: {threshold:.2f}")

        print(f"Val   >>> Loss: {val_metrics['loss']:.4f} | F1: {val_metrics['f1_macro']:.4f} | "
              f"Precision: {val_metrics['precision_macro']:.4f} | Recall: {val_metrics['recall_macro']:.4f} | "
              f"Threshold: {threshold:.2f}")

        # Use the threshold suggested by validation from now on
        threshold = val_metrics['threshold'] 

        # Early stopping and saving of the best model
        if val_metrics['f1_macro'] > best_f1:
            best_f1 = val_metrics['f1_macro']
            # Reset before saving so the stored counter matches the state
            # training would resume from.
            patience_counter = 0
            checkpoint = _make_checkpoint(model, optimizer, history, epoch,
                                          best_f1, patience_counter, threshold)
            torch.save(checkpoint, f'{save_dir}/best_model.pth')
            print(f"–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: {best_f1:.4f}")
        else:
            patience_counter += 1
            print(f"–ù–µ—Ç —É–ª—É—á—à–µ–Ω–∏–π: ({patience_counter}/{patience})")

        if patience_counter >= patience:
            print(f"–†–∞–Ω–Ω—è—è –æ—Å—Ç–∞–Ω–æ–≤–∫–∞ –Ω–∞ —ç–ø–æ—Ö–µ {epoch}.")
            break

        # Periodic checkpoint every 5 epochs
        if epoch % 5 == 0:
            checkpoint = _make_checkpoint(model, optimizer, history, epoch,
                                          best_f1, patience_counter, threshold)
            torch.save(checkpoint, f'{save_dir}/checkpoint_epoch_{epoch}.pth')

        print("\n")

    # Final save
    final_checkpoint = _make_checkpoint(model, optimizer, history, epoch,
                                        best_f1, patience_counter, threshold)
    torch.save(final_checkpoint, f'{save_dir}/last_model.pth')

    return history, best_f1

def load_checkpoint(model, optimizer, checkpoint_path, device):
    """Restore model/optimizer state and training bookkeeping from a checkpoint.

    Args:
        model, optimizer: objects whose state dicts are loaded in place.
        checkpoint_path: file written by ``torch.save`` with the keys
            ``model_state_dict``, ``optimizer_state_dict``, ``history``,
            ``epoch``, ``best_f1`` and optionally ``threshold``.
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        ``(model, optimizer, history, start_epoch, best_f1, threshold)``
        where ``start_epoch`` is the stored epoch plus one and ``threshold``
        falls back to 0.5 when the checkpoint has none.
    """
    # weights_only=False: besides tensors the checkpoint carries the history
    # dict and metric values, which can be numpy scalars that the restricted
    # unpickler (the default in recent PyTorch releases) rejects.
    # NOTE: this unpickles arbitrary objects, so only load trusted files.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    history = checkpoint['history']
    start_epoch = checkpoint['epoch'] + 1
    best_f1 = checkpoint['best_f1']
    threshold = checkpoint.get('threshold', 0.5)

    print(f"–ó–∞–≥—Ä—É–∂–µ–Ω —á–µ–∫–ø–æ–∏–Ω—Ç —ç–ø–æ—Ö–∏ {checkpoint['epoch']}")
    print(f"–õ—É—á—à–∏–π F1: {best_f1:.4f}")
    print(f"–í—Å–µ–≥–æ —ç–ø–æ—Ö –≤ –∏—Å—Ç–æ—Ä–∏–∏: {len(history['train_loss'])}")

    return model, optimizer, history, start_epoch, best_f1, threshold

def plot_training_history(history, save_path=None):
    """Draw train/validation curves for loss, F1, precision and recall.

    Reads the ``train_*`` / ``val_*`` lists from ``history`` and plots them
    on a 2x2 grid against epoch numbers starting at 1. When ``save_path``
    is given the figure is also written there at 300 dpi.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('–ò—Å—Ç–æ—Ä–∏—è —Ç—Ä–µ–Ω–∏—Ä–æ–≤–∫–∏', fontsize=16, fontweight='bold')

    epochs = range(1, len(history['train_loss']) + 1)

    # (axis, history key suffix, legend suffix, panel title, y-axis label)
    panels = (
        (axes[0, 0], 'loss', 'Loss', 'Loss', 'Loss'),
        (axes[0, 1], 'f1', 'F1', 'F1 Macro Score', 'F1 Score'),
        (axes[1, 0], 'precision', 'Precision', 'Precision Macro', 'Precision'),
        (axes[1, 1], 'recall', 'Recall', 'Recall Macro', 'Recall'),
    )

    for ax, key, legend_name, title, y_label in panels:
        ax.plot(epochs, history[f'train_{key}'], 'b-', label=f'Train {legend_name}', linewidth=2)
        ax.plot(epochs, history[f'val_{key}'], 'r-', label=f'Val {legend_name}', linewidth=2)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"üìä Graphs saved to: {save_path}")

    plt.show()

def calculate_pos_weight(dataset, smoothing=1.0):
    """Compute normalised per-class positive weights from the dataset labels.

    For every class the smoothed negative/positive count ratio is taken,
    compressed with ``log(ratio + 1)`` and divided by the largest value, so
    the result lies in (0, 1] with the rarest class at 1.

    Args:
        dataset: indexable dataset whose items are ``(data, targets)``;
            every item is read once to collect the targets.
        smoothing: constant added to both counts to avoid division by zero.

    Returns:
        Weight tensor moved to the module-level ``device``.
    """
    # Index access is kept on purpose: the dataset is only required to
    # support ``len`` and ``[]``.
    all_targets = torch.stack([dataset[i][1] for i in range(len(dataset))])

    class_totals = all_targets.sum(dim=0)
    positive_counts = class_totals + smoothing
    negative_counts = (all_targets.size(0) - class_totals) + smoothing

    pos_weight = negative_counts / positive_counts
    pos_weight = torch.log(pos_weight + 1)
    # Tensor.max() instead of the builtin max(), which walks the tensor
    # element by element in Python.
    pos_weight = pos_weight / pos_weight.max()

    print(f"–ö–æ—Ä—Ä–µ–∫—Ü–∏—è –≤–µ—Å–æ–≤: {pos_weight}")

    return pos_weight.to(device)



In [9]:
# DATASET LOADING
# ----------------------------------------------------------
tags_dict = {}
# Load the prepared tag dictionary.
# NOTE(review): pickle.load can execute arbitrary code from the file; only
# use it on data you produced yourself.
with open("../data/track_tags.pkl", 'rb') as file:
    tags_dict = pickle.load(file)

# Open the preprocessed dataset
dataset = AudioDataset("../data/preprocessed_dataset.h5")



# Split the sample indices into train and test parts (80/20)
dataset_size = len(dataset)
indices = list(range(dataset_size))

train_idx, test_idx = train_test_split(
    indices, 
    test_size = 0.2,
    train_size = 0.8,       
    random_state = 42        # fixed seed for reproducibility
)

train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,  
    num_workers=8,   
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2
)

test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,  
    num_workers=4,   
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=2
)

# Sanity check: print shape and basic statistics of the first training batch
for data, targets in train_loader:
    print(f"DATA - shape: {data.shape}")
    print(f"       min: {data.min().item():.4f}, max: {data.max().item():.4f}")
    print(f"       mean: {data.mean().item():.4f}, std: {data.std().item():.4f}")
    print(f"target exemple: {targets[0]}")
    break

Loaded 18486 samples from HDF5
DATA - shape: torch.Size([64, 1, 128, 512])
       min: -13.8155, max: 12.4683
       mean: 0.4846, std: 4.1769
target exemple: tensor([0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 0.])


In [None]:
# MODEL TRAINING
def debug_model_predictions(model, loader, threshold=0.5):
    """Print prediction diagnostics for the first batch of ``loader``.

    Shows the logit and probability ranges, the thresholded predictions and
    targets of the first sample, and two batch-level rates: the share of
    matching elements and the share of positive targets predicted positive.
    """
    model.eval()

    with torch.no_grad():
        for data, targets in loader:
            data = data.to(device)
            targets = targets.to(device)

            logits = model(data)
            probabilities = torch.sigmoid(logits)
            # multi-hot prediction vector
            predictions = (probabilities > threshold).float()

            print("=== MULTI-LABEL DEBUG ===")
            print(f"Outputs range: [{logits.min():.3f}, {logits.max():.3f}]")
            print(f"Probabilities range: [{probabilities.min():.3f}, {probabilities.max():.3f}]")

            # First sample of the batch
            first_pred = predictions[0]
            first_true = targets[0]
            print(f"Sample 0 predictions: {first_pred.cpu().numpy()}")
            print(f"Sample 0 targets:     {first_true.cpu().numpy()}")
            true_positive_hits = ((first_pred == first_true) & (first_true == 1)).sum().item()
            print(f"Sample 0 matched:     {true_positive_hits} correct positive tags")

            # Batch-level statistics
            element_matches = (predictions == targets).float()
            batch_accuracy = element_matches.mean().item()
            predictions_on_positives = predictions[targets == 1]
            positive_accuracy = (predictions_on_positives == 1).float().mean().item()

            print(f"Batch accuracy: {batch_accuracy:.3f}")
            print(f"Positive tag accuracy: {positive_accuracy:.3f}")
            break
def debug_model_outputs(model, data_loader):
    """Print input, output and gradient statistics for the first batch.

    No backward pass is run here; the gradient lines report whatever is
    currently stored in ``param.grad`` of each parameter.
    """
    model.eval()

    with torch.no_grad():
        for data, targets in data_loader:
            data = data.to(device)
            targets = targets.to(device)
            outputs = model(data)

            print("=== MODEL DEBUG ===")
            print(f"Input data range: [{data.min():.3f}, {data.max():.3f}]")
            print(f"Input data mean: {data.mean():.3f}")

            # Statistics of the final model outputs
            print(f"Model outputs range: [{outputs.min():.3f}, {outputs.max():.3f}]")
            print(f"Model outputs mean: {outputs.mean():.3f}")

            # Mean absolute gradient per parameter, when one is stored
            for name, param in model.named_parameters():
                has_gradient = param.requires_grad and param.grad is not None
                if not has_gradient:
                    print(f"Gradient {name}: None")
                    continue
                print(f"Gradient {name}: {param.grad.abs().mean():.6f}")

            break
def debug_model_weights(model=None):
    """Print mean (and std for weights) of every weight/bias parameter.

    Args:
        model: module to inspect. When omitted, the module-level ``model``
            variable is used, which keeps the old zero-argument call working.

    Raises:
        NameError: if no model is passed and no global ``model`` exists.
    """
    if model is None:
        try:
            model = globals()['model']
        except KeyError:
            raise NameError("name 'model' is not defined") from None

    for name, param in model.named_parameters():
        if 'weight' in name:
            print(f"{name:30} mean: {param.data.mean().item():8.4f} std: {param.data.std().item():8.4f}")
        elif 'bias' in name:
            print(f"{name:30} mean: {param.data.mean().item():8.4f}")

num_epochs = 50          # number of epochs to train
threshold = 0.5         # decision threshold applied to the sigmoid outputs
patience = 5           # stop training after N epochs without improvement

start_epoch = 1         # first epoch number (replaced when a checkpoint is loaded below)
history = None          # None lets train_model create a fresh history
best_f1 = 0.0           # best validation F1 so far (replaced when a checkpoint is loaded below)

# Model initialisation
model = AudioCNN(dataset.num_classes)
model.to(device)

print(f"–ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {device}")
print(f"–ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤ –º–æ–¥–µ–ª–∏: {sum(p.numel() for p in model.parameters())}")

# Loss and optimizer; the commented-out lines are alternatives that were tried
criterion = FocalLoss(alpha=1, gamma=2)
#criterion = nn.MultiLabelSoftMarginLoss()
# criterion = nn.BCEWithLogitsLoss(calculate_pos_weight(dataset))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
#optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-3)

# Resume from a saved model (uncomment to continue training)
#model, optimizer, history, start_epoch, best_f1, threshold = load_checkpoint(model, optimizer, "../data/models/best_model.pth", device)

# Start training
# NOTE(review): test_loader is passed as the validation loader, so early
# stopping and best-model selection are driven by the test split.
history, best_f1 = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=num_epochs,
    threshold = threshold,
    save_dir='../data/models',
    patience=patience,
    start_epoch=start_epoch,
    history=history,
    best_f1=best_f1,
    auto_threshold = False
)

# Plot the training curves
plot_training_history(history)

# Diagnostics on the first training batch; ``threshold`` here is the value
# set at the top of this cell, not one updated inside train_model.
debug_model_predictions(model, train_loader, threshold)
debug_model_outputs(model, train_loader)

–ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: cuda
–ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤ –º–æ–¥–µ–ª–∏: 293179


Epoch 1/50:   0%|                    | 0/232 [00:00<?, ?it/s]

Epoch 1/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 19.93it/s, Loss=0.0399]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 1/50
Train >>> Loss: 0.0397 | F1: 0.0048 | Precision: 0.0182 | Recall: 0.0032 | Threshold: 0.50
Val   >>> Loss: 0.0364 | F1: 0.0000 | Precision: 0.0000 | Recall: 0.0000 | Threshold: 0.50
–ù–µ—Ç —É–ª—É—á—à–µ–Ω–∏–π: (1/5)




Epoch 2/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.84it/s, Loss=0.0360]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 2/50
Train >>> Loss: 0.0359 | F1: 0.0486 | Precision: 0.0467 | Recall: 0.0920 | Threshold: 0.30
Val   >>> Loss: 0.0363 | F1: 0.0325 | Precision: 0.0352 | Recall: 0.0768 | Threshold: 0.30
–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: 0.0325




Epoch 3/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.79it/s, Loss=0.0355]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 3/50
Train >>> Loss: 0.0352 | F1: 0.0598 | Precision: 0.0705 | Recall: 0.1056 | Threshold: 0.30
Val   >>> Loss: 0.0350 | F1: 0.0474 | Precision: 0.0551 | Recall: 0.0796 | Threshold: 0.30
–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: 0.0474




Epoch 4/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.53it/s, Loss=0.0350]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 4/50
Train >>> Loss: 0.0347 | F1: 0.0751 | Precision: 0.0825 | Recall: 0.1224 | Threshold: 0.30
Val   >>> Loss: 0.0346 | F1: 0.0788 | Precision: 0.0705 | Recall: 0.1608 | Threshold: 0.30
–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: 0.0788




Epoch 5/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.57it/s, Loss=0.0344]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 5/50
Train >>> Loss: 0.0342 | F1: 0.0920 | Precision: 0.0930 | Recall: 0.1443 | Threshold: 0.30
Val   >>> Loss: 0.0343 | F1: 0.0797 | Precision: 0.0856 | Recall: 0.1444 | Threshold: 0.30
–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: 0.0797




Epoch 6/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.55it/s, Loss=0.0340]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 6/50
Train >>> Loss: 0.0337 | F1: 0.1082 | Precision: 0.1077 | Recall: 0.1676 | Threshold: 0.30
Val   >>> Loss: 0.0343 | F1: 0.0882 | Precision: 0.0871 | Recall: 0.1328 | Threshold: 0.30
–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: 0.0882




Epoch 7/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.59it/s, Loss=0.0336]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 7/50
Train >>> Loss: 0.0335 | F1: 0.1200 | Precision: 0.1221 | Recall: 0.1881 | Threshold: 0.30
Val   >>> Loss: 0.0343 | F1: 0.0853 | Precision: 0.1058 | Recall: 0.1377 | Threshold: 0.30
–ù–µ—Ç —É–ª—É—á—à–µ–Ω–∏–π: (1/5)




Epoch 8/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.77it/s, Loss=0.0333]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 8/50
Train >>> Loss: 0.0330 | F1: 0.1298 | Precision: 0.1214 | Recall: 0.2018 | Threshold: 0.30
Val   >>> Loss: 0.0338 | F1: 0.1046 | Precision: 0.1278 | Recall: 0.1779 | Threshold: 0.30
–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: 0.1046




Epoch 9/50: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 232/232 [00:11<00:00, 20.58it/s, Loss=0.0330]


‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
Epoch 9/50
Train >>> Loss: 0.0327 | F1: 0.1444 | Precision: 0.1338 | Recall: 0.2223 | Threshold: 0.30
Val   >>> Loss: 0.0337 | F1: 0.1136 | Precision: 0.1240 | Recall: 0.1971 | Threshold: 0.30
–ù–æ–≤–∞—è –ª—É—á—à–∞—è –º–æ–¥–µ–ª—å —Å–æ—Ö—Ä–∞–Ω–µ–Ω–∞! F1: 0.1136




Epoch 10/50:  52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç         | 121/232 [00:05<00:05, 21.04it/s, Loss=0.0327]

In [None]:
import torchaudio 

def random_audio_predictor(model, preprocess_function, class_names, folder_path='../data/train/', device='cuda', threshold=0.5):
    """Pick a random audio file, play it and print the predicted tags.

    Args:
        model: network passed on to ``quick_predict``.
        preprocess_function: callable passed on to ``quick_predict``.
        class_names: list of tag names; ``probs`` is indexed by position in it.
        folder_path: directory searched recursively for audio files.
        device: device passed on to ``quick_predict``.
        threshold: probability threshold passed on to ``quick_predict``.

    Returns:
        A dict with the keys ``file``, ``filename``, ``probs``,
        ``active_tags`` and ``threshold``, or ``None`` when no audio file is
        found or ``quick_predict`` returns no probabilities.
    """
    print("üé≤ –°–õ–£–ß–ê–ô–ù–´–ô –ê–£–î–ò–û –ê–ù–ê–õ–ò–ó")
    print("=" * 50)

    # Collect all audio files in the folder and its sub-folders. With
    # recursive=True the '**' component also matches the folder itself, so
    # one pattern per extension is enough. The set guards against listing a
    # file twice, which would bias the random choice.
    audio_extensions = ['*.wav', '*.mp3', '*.flac', '*.m4a', '*.ogg']
    found = set()
    for ext in audio_extensions:
        found.update(glob.glob(os.path.join(folder_path, '**', ext), recursive=True))
    audio_files = sorted(found)

    if not audio_files:
        print(f"‚ùå –í –ø–∞–ø–∫–µ {folder_path} –Ω–µ –Ω–∞–π–¥–µ–Ω–æ –∞—É–¥–∏–æ —Ñ–∞–π–ª–æ–≤!")
        return None

    # Choose a random file
    selected_file = random.choice(audio_files)
    filename = os.path.basename(selected_file)

    # Autoplay the audio
    print("üîä –ê–≤—Ç–æ–≤–æ—Å–ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏–µ...")
    display(Audio(selected_file, autoplay=True))

    # Tag prediction
    print("\nü§ñ –ê–Ω–∞–ª–∏–∑ —Ç–µ–≥–æ–≤...")
    probs, active_tags = quick_predict(
        model=model,
        audio_path=selected_file,
        preprocess_function=preprocess_function,
        class_names=class_names,
        device=device,
        threshold=threshold
    )

    if probs is None:
        return None

    # Print the results
    print(f"\nüè∑Ô∏è  –ü–†–ï–î–°–ö–ê–ó–ê–ù–ù–´–ï –¢–ï–ì–ò:")
    print("‚îÄ" * 30)

    if active_tags:
        print(f"‚úÖ –ù–∞–π–¥–µ–Ω–æ —Ç–µ–≥–æ–≤: {len(active_tags)}")
        print("\nüìã –°–ø–∏—Å–æ–∫ —Ç–µ–≥–æ–≤:")
        for tag in active_tags:
            prob = probs[class_names.index(tag)]
            print(f"   ‚Ä¢ {tag}: {prob:.3f}")
    else:
        print("‚ùå –ù–µ—Ç —Ç–µ–≥–æ–≤ –≤—ã—à–µ –ø–æ—Ä–æ–≥–∞")

    # Ground-truth tags from the module-level ``tags_dict``. The first three
    # characters of every key are dropped before matching the file name
    # (presumably a fixed prefix -- verify against how tags_dict is built).
    true_tags_by_name = {key[3:]: vals for key, vals in tags_dict.items()}
    true_tags = true_tags_by_name.get(filename, "none")
    # The value is computed outside the f-string: nesting the same quote
    # type inside an f-string is a syntax error before Python 3.12.
    print(f"–ü—Ä–∞–≤–∏–ª—å–Ω—ã–µ —Ç–µ–≥–∏: {true_tags}")

    # Statistics
    print(f"\nüìä –°–¢–ê–¢–ò–°–¢–ò–ö–ê:")
    print(f"   –í—Å–µ–≥–æ –≤–æ–∑–º–æ–∂–Ω—ã—Ö —Ç–µ–≥–æ–≤: {len(class_names)}")
    print(f"   –ê–∫—Ç–∏–≤–Ω—ã—Ö —Ç–µ–≥–æ–≤: {len(active_tags)}")
    print(f"   –ú–∞–∫—Å–∏–º–∞–ª—å–Ω–∞—è –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç—å: {max(probs):.3f}")
    print(f"   –°—Ä–µ–¥–Ω—è—è –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç—å: {np.mean(probs):.3f}")

    return {
        'file': selected_file,
        'filename': filename,
        'probs': probs,
        'active_tags': active_tags,
        'threshold': threshold
    }

def preprocess_audio(audio_path, target_sr=22050, noise=False, start=2048, size=512) -> torch.Tensor:
    """Load an audio file and return a cropped log-mel spectrogram of its first channel.

    The waveform is resampled to ``target_sr`` when the file uses a different
    rate, converted to a 128-bin mel spectrogram (``n_fft=2048``,
    ``hop_length=512``, 20 Hz up to the Nyquist frequency) and log-scaled as
    ``log(mel + 1e-6)``.

    Args:
        audio_path: Path handed to ``torchaudio.load``.
        target_sr: Sample rate in Hz the waveform is resampled to. The mel
            transform is configured from the same value.
        noise: When true, one random offset drawn from ``[-1.0, 1.0]`` is
            added to each channel of the log-mel spectrogram.
        start: Upper bound of the slice on the mel axis, and the first term of
            the upper bound on the frame axis.
        size: Second term of the upper bound on the frame axis.

    Returns:
        The 2-D tensor ``log_mel[0, :start, :start + size]``. With the
        defaults that is all 128 mel bins and at most 2560 frames.
    """
    waveform, sr = torchaudio.load(audio_path, normalize=True, channels_first=True)
    # A lower sample rate keeps the dataset smaller; the sources are 44100 Hz
    # (or close to it) almost everywhere.
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)

    # The transform is derived from target_sr so it always describes the
    # waveform it is applied to; hard-coding 22050 / 11025 here would silently
    # mis-scale the filterbank for any other target rate.
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sr,
        n_mels=128,
        n_fft=2048,
        hop_length=512,
        f_min=20,
        f_max=target_sr / 2,  # Nyquist frequency of the resampled signal
    )(waveform)

    # The small epsilon avoids log(0) on silent bins.
    mel_spectrogram = torch.log(mel_spectrogram + 1e-6)

    # Optional perturbation. Each iteration yields a view of one channel, so the
    # in-place add shifts that whole channel by a single random constant.
    # NOTE(review): this is a per-channel offset, not per-element noise --
    # confirm that is the intended augmentation.
    if noise:
        for channel in mel_spectrogram:
            channel += random.uniform(-1.0, 1.0)

    # NOTE(review): with start=2048 the mel slice keeps every bin and the frame
    # slice keeps the first start + size frames. If a window of `size` frames
    # beginning at `start` was intended, this should be
    # [0, :, start:start + size] -- confirm against how the training data was
    # prepared before changing it.
    return mel_spectrogram[0, :start, :start + size]

def quick_predict(model, audio_path, preprocess_function, class_names, device='cuda', threshold=0.5):
    """
    Predict tags for a single audio file.

    The model is moved to ``device`` and switched to eval mode, then
    ``preprocess_function(audio_path)`` is fed to it as a one-item batch and a
    sigmoid is applied to the outputs.

    Returns ``(probs, active_tags)``: the sigmoid scores of the first batch item
    as a NumPy array, and the entries of ``class_names`` whose score is strictly
    greater than ``threshold``. If preprocessing or inference raises, the error
    is printed and ``(None, None)`` is returned instead.
    """
    model.to(device)
    model.eval()

    try:
        features = preprocess_function(audio_path)

        # Every input gets a leading batch axis. A 3-D result is taken to be
        # [channels, height, width] already; anything else also gets a channel
        # axis, giving [1, 1, height, width].
        batch = features.unsqueeze(0)
        if features.dim() != 3:
            batch = batch.unsqueeze(0)
        batch = batch.to(device)

        # Inference only, so gradient tracking is switched off.
        with torch.no_grad():
            logits = model(batch)
            probs = torch.sigmoid(logits).cpu().numpy()[0]

        # Keep the class names whose score clears the threshold.
        active_tags = [
            name for name, score in zip(class_names, probs) if score > threshold
        ]
        return probs, active_tags

    except Exception as e:
        # Best-effort by design: callers test the result for None.
        print(f"‚ùå –û—à–∏–±–∫–∞ –≤–æ –≤—Ä–µ–º—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏—è: {e}")
        return None, None

# Build a fresh network, then let load_checkpoint restore the best saved state.
# Only the first returned value (the model) and the last one (used below as the
# tagging threshold) are kept.
best_model = AudioCNN(dataset.num_classes)
best_model, _, _, _, _, threshold = load_checkpoint(
    best_model,
    optimizer,
    "../data/models/best_model.pth",
    device,
)
best_model.to(device)
best_model.eval()

# Simple call: analyse one randomly chosen file. Left as a bare expression so
# the notebook displays the returned value.
random_audio_predictor(
    model=best_model,
    preprocess_function=preprocess_audio,
    class_names=dataset.get_class_names(),
    folder_path='../data/train',  # audio folder to sample from; point it at your own if needed
    device=device,
    threshold=threshold,  # the threshold can be tuned here
)


–ó–∞–≥—Ä—É–∂–µ–Ω —á–µ–∫–ø–æ–∏–Ω—Ç —ç–ø–æ—Ö–∏ 8
–õ—É—á—à–∏–π F1: 0.1407
–í—Å–µ–≥–æ —ç–ø–æ—Ö –≤ –∏—Å—Ç–æ—Ä–∏–∏: 8
üé≤ –°–õ–£–ß–ê–ô–ù–´–ô –ê–£–î–ò–û –ê–ù–ê–õ–ò–ó
üîä –ê–≤—Ç–æ–≤–æ—Å–ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏–µ...



ü§ñ –ê–Ω–∞–ª–∏–∑ —Ç–µ–≥–æ–≤...

üè∑Ô∏è  –ü–†–ï–î–°–ö–ê–ó–ê–ù–ù–´–ï –¢–ï–ì–ò:
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
‚úÖ –ù–∞–π–¥–µ–Ω–æ —Ç–µ–≥–æ–≤: 3

üìã –°–ø–∏—Å–æ–∫ —Ç–µ–≥–æ–≤:
   ‚Ä¢ background: 0.350
   ‚Ä¢ meditative: 0.315
   ‚Ä¢ soundscape: 0.370
–ü—Ä–∞–≤–∏–ª—å–Ω—ã–µ —Ç–µ–≥–∏: {'relaxing'}

üìä –°–¢–ê–¢–ò–°–¢–ò–ö–ê:
   –í—Å–µ–≥–æ –≤–æ–∑–º–æ–∂–Ω—ã—Ö —Ç–µ–≥–æ–≤: 59
   –ê–∫—Ç–∏–≤–Ω—ã—Ö —Ç–µ–≥–æ–≤: 3
   –ú–∞–∫—Å–∏–º–∞–ª—å–Ω–∞—è –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç—å: 0.370
   –°—Ä–µ–¥–Ω—è—è –≤–µ—Ä–æ—è—Ç–Ω–æ—Å—Ç—å: 0.102


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


{'file': '../data/train/52/1090652.mp3',
 'filename': '1090652.mp3',
 'probs': array([0.0155134 , 0.01942021, 0.0403605 , 0.21273407, 0.3497307 ,
        0.02856989, 0.1258596 , 0.06242988, 0.05462703, 0.01877237,
        0.01873024, 0.06615639, 0.23969632, 0.06060187, 0.10993326,
        0.06949602, 0.05631348, 0.20014621, 0.2210205 , 0.0777566 ,
        0.05719805, 0.0051653 , 0.2246038 , 0.02143176, 0.05988639,
        0.02145566, 0.00101096, 0.16998795, 0.02039509, 0.07859646,
        0.13241966, 0.03267135, 0.1531653 , 0.09587714, 0.31468758,
        0.1493248 , 0.05737305, 0.07158424, 0.07246327, 0.06443791,
        0.17662704, 0.04542641, 0.07062485, 0.01835235, 0.23864506,
        0.01099946, 0.08969235, 0.26335177, 0.02200166, 0.18783759,
        0.20475818, 0.36976668, 0.2067063 , 0.00176064, 0.02009471,
        0.04609552, 0.06068613, 0.01841536, 0.08554162], dtype=float32),
 'active_tags': ['background', 'meditative', 'soundscape'],
 'threshold': 0.30000001192092896}