In [1]:
#–ò–ú–ü–û–†–¢ –ò –§–£–ù–ö–¶–ò–ò –° –ö–õ–ê–°–°–ê–ú–ò
import os
import h5py
from matplotlib import pyplot as plt
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, accuracy_score
import torch
import torchaudio
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split
import tqdm

class SpectrogramDataset(Dataset):
    def __init__(self, h5_path):
        self.h5_path = h5_path
        self.file = h5py.File(h5_path, 'r')
        
        self.spectrograms = self.file['spectrograms']
        self.labels = self.file['labels']

        print(self.file.attrs['class_names'])

        self.class_names = [name for name in self.file.attrs['class_names']]
        self.num_classes = self.file.attrs['num_classes']
        
        print(f"Loaded {len(self.spectrograms)} samples from HDF5")
    
    def __len__(self):
        return len(self.spectrograms)
    
    def __getitem__(self, idx):
        # –ó–∞–≥—Ä—É–∂–∞–µ–º –Ω–µ–ø–æ—Å—Ä–µ–¥—Å—Ç–≤–µ–Ω–Ω–æ –∏–∑ HDF5
        spectrogram = torch.tensor(self.spectrograms[idx], dtype=torch.float32)

        label = torch.tensor(self.labels[idx], dtype=torch.uint8)

        return spectrogram, label
    
    def __del__(self):
        if hasattr(self, 'file'):
            self.file.close()
    
    def get_class_names(self):
        return self.class_names 

class TinyCNN(nn.Module):
    def __init__(self, num_classes, hidden=128):
        super(TinyCNN, self).__init__()
        
        self.conv_1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [16, 64, 256]
            nn.Dropout(0.15),
        )

        self.conv_2 = nn.Sequential(   
            nn.Conv2d(16, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [32, 32, 128]
            nn.Dropout(0.2),
        )

        self.conv_3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [64, 16, 64]
            nn.Dropout(0.2),
        )

        self.conv_4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),  # [128, 8, 32]
            nn.Dropout(0.3),
        )

        self.avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1,4))
        )
        
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 1 * 4, hidden),
            nn.BatchNorm1d(hidden),
            nn.ReLU(),
            nn.Dropout(0.4),

            nn.Linear(hidden, num_classes)
        )
        
    def forward(self, x):
        x = self.conv_1(x)
        x = self.conv_2(x)
        x = self.conv_3(x)
        x = self.conv_4(x)
        x = self.avg_pool(x)
        x = self.fc_layers(x)
        return x

def calculate_all_metrics(outputs, labels):
    if torch.is_tensor(outputs):
        outputs = outputs.detach().cpu().numpy()
    if torch.is_tensor(labels):
        labels = labels.detach().cpu().numpy()
    
    predicted = np.argmax(outputs, axis=1)

    accuracy = accuracy_score(labels, predicted)
    precision = precision_score(labels, predicted, average='weighted', zero_division=0)
    recall = recall_score(labels, predicted, average='weighted', zero_division=0)
    f1 = f1_score(labels, predicted, average='weighted', zero_division=0)
    cm = confusion_matrix(labels, predicted)

    
    metrics = {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': cm
    }
    
    return metrics

def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    all_outputs = []
    all_labels = []

    for spectrograms, labels in dataloader:
        
        spectrograms, labels = spectrograms.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(spectrograms)
        loss = criterion(outputs, labels)
        
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        all_outputs.append(outputs)
        all_labels.append(labels)
    
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    
    metrics = calculate_all_metrics(all_outputs, all_labels)
    avg_loss = running_loss / len(dataloader)
    
    return avg_loss, metrics

def evaluate_epoch(model, dataloader, criterion, device):
    model.eval()
    running_loss = 0.0
    all_outputs = []
    all_labels = []
    
    with torch.no_grad():
        for spectrograms, labels in dataloader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            
            outputs = model(spectrograms)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            all_outputs.append(outputs)
            all_labels.append(labels)
    
    all_outputs = torch.cat(all_outputs)
    all_labels = torch.cat(all_labels)
    
    metrics = calculate_all_metrics(all_outputs, all_labels)
    avg_loss = running_loss / len(dataloader)
    
    return avg_loss, metrics

def train_model(model, train_loader, val_loader, optimizer, criterion, patience=10, num_epochs=50, device='cuda', save_path='../data/models/genres/best_model.pth'):
    
    history = {
        'train_loss': [],
        'train_accuracy': [],
        'train_f1': [],
        'val_loss': [],
        'val_accuracy': [],
        'val_f1': []
    }
    
    best_val_accuracy = 0.0
    patience_counter = 0
    
    print(f"–ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: {device}")
    print(f"–ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤ –º–æ–¥–µ–ª–∏: {sum(p.numel() for p in model.parameters())}")
    
    # –ì–ª–∞–≤–Ω—ã–π –ø—Ä–æ–≥—Ä–µ—Å—Å-–±–∞—Ä –¥–ª—è —ç–ø–æ—Ö
    epoch_pbar = tqdm.tqdm(range(num_epochs), desc='Epochs')
    
    for epoch in epoch_pbar:
        # –û–±—É—á–µ–Ω–∏–µ
        train_loss, train_metrics = train_epoch(model, train_loader, criterion, optimizer, device)
        
        # –í–∞–ª–∏–¥–∞—Ü–∏—è
        val_loss, val_metrics = evaluate_epoch(model, val_loader, criterion, device)
        
        # –°–æ—Ö—Ä–∞–Ω—è–µ–º –∏—Å—Ç–æ—Ä–∏—é
        history['train_loss'].append(train_loss)
        history['train_accuracy'].append(train_metrics['accuracy'])
        history['train_f1'].append(train_metrics['f1_score'])
        history['val_loss'].append(val_loss)
        history['val_accuracy'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1_score'])
        
        # –û–±–Ω–æ–≤–ª—è–µ–º –ø—Ä–æ–≥—Ä–µ—Å—Å-–±–∞—Ä —ç–ø–æ—Ö–∏
        epoch_pbar.set_postfix({
            'Train Loss': f'{train_loss:.4f}',
            'Train Acc': f'{train_metrics["accuracy"]:.4f}',
            'Val Acc': f'{val_metrics["accuracy"]:.4f}',
            'Best Val Acc': f'{best_val_accuracy:.4f}'
        })
        
        # –†–∞–Ω–Ω—è—è –æ—Å—Ç–∞–Ω–æ–≤–∫–∞
        if val_metrics['accuracy'] > best_val_accuracy:
            best_val_accuracy = val_metrics['accuracy']
            patience_counter = 0
            torch.save(model.state_dict(), save_path)
        else:
            patience_counter += 1
            
        if patience_counter >= patience:
            print(f"\n–†–∞–Ω–Ω—è—è –æ—Å—Ç–∞–Ω–æ–≤–∫–∞ –Ω–∞ —ç–ø–æ—Ö–µ {epoch+1}")
            break
    
    # –ó–∞–≥—Ä—É–∂–∞–µ–º –ª—É—á—à—É—é –º–æ–¥–µ–ª—å
    model.load_state_dict(torch.load(save_path))
    print(f"\n–õ—É—á—à–∞—è —Ç–æ—á–Ω–æ—Å—Ç—å –Ω–∞ –≤–∞–ª–∏–¥–∞—Ü–∏–∏: {best_val_accuracy:.4f}")
    
    return history, model

def load_model(model, model_path, device='cuda'):
    # –°–æ–∑–¥–∞–µ–º –º–æ–¥–µ–ª—å
    model = TinyCNN(num_classes=10)
    
    # –ó–∞–≥—Ä—É–∂–∞–µ–º –≤–µ—Å–∞
    model.load_state_dict(torch.load(model_path, map_location=device))

    print(f"–ú–æ–¥–µ–ª—å –∑–∞–≥—Ä—É–∂–µ–Ω–∞ –∏–∑ {model_path}")
    
    return model

def plot_training_history(history):
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    
    # Loss
    axes[0, 0].plot(history['train_loss'], label='Train Loss', linewidth=2)
    axes[0, 0].plot(history['val_loss'], label='Val Loss', linewidth=2)
    axes[0, 0].set_title('Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[0, 1].plot(history['train_accuracy'], label='Train Accuracy', linewidth=2)
    axes[0, 1].plot(history['val_accuracy'], label='Val Accuracy', linewidth=2)
    axes[0, 1].set_title('Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    
    # F1-Score
    axes[1, 0].plot(history['train_f1'], label='Train F1', linewidth=2)
    axes[1, 0].plot(history['val_f1'], label='Val F1', linewidth=2)
    axes[1, 0].set_title('F1-Score')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('F1-Score')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    
    # –°—Ä–∞–≤–Ω–µ–Ω–∏–µ –º–µ—Ç—Ä–∏–∫ –Ω–∞ –ø–æ—Å–ª–µ–¥–Ω–µ–π —ç–ø–æ—Ö–µ
    metrics_names = ['Accuracy', 'F1-Score']
    train_metrics = [history['train_accuracy'][-1], history['train_f1'][-1]]
    val_metrics = [history['val_accuracy'][-1], history['val_f1'][-1]]
    
    x = np.arange(len(metrics_names))
    width = 0.35
    
    axes[1, 1].bar(x - width/2, train_metrics, width, label='Train', alpha=0.8)
    axes[1, 1].bar(x + width/2, val_metrics, width, label='Val', alpha=0.8)
    axes[1, 1].set_title('Final Metrics Comparison')
    axes[1, 1].set_xlabel('Metrics')
    axes[1, 1].set_ylabel('Score')
    axes[1, 1].set_xticks(x)
    axes[1, 1].set_xticklabels(metrics_names)
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    # –î–æ–±–∞–≤–ª—è–µ–º –∑–Ω–∞—á–µ–Ω–∏—è –Ω–∞ —Å—Ç–æ–ª–±—Ü—ã
    for i, v in enumerate(train_metrics):
        axes[1, 1].text(i - width/2, v + 0.01, f'{v:.3f}', ha='center', va='bottom')
    for i, v in enumerate(val_metrics):
        axes[1, 1].text(i + width/2, v + 0.01, f'{v:.3f}', ha='center', va='bottom')
    
    plt.tight_layout()
    plt.show()
    
    # –í—ã–≤–æ–¥–∏–º —á–∏—Å–ª–æ–≤—ã–µ —Ä–µ–∑—É–ª—å—Ç–∞—Ç—ã
    print("\n–§–ò–ù–ê–õ–¨–ù–´–ï –†–ï–ó–£–õ–¨–¢–ê–¢–´:")
    print(f"Train Loss: {history['train_loss'][-1]:.4f}")
    print(f"Val Loss: {history['val_loss'][-1]:.4f}")
    print(f"Train Accuracy: {history['train_accuracy'][-1]:.4f}")
    print(f"Val Accuracy: {history['val_accuracy'][-1]:.4f}")
    print(f"Train F1-Score: {history['train_f1'][-1]:.4f}")
    print(f"Val F1-Score: {history['val_f1'][-1]:.4f}")

In [None]:
# #–ó–ê–ì–†–£–ó–ö–ê –î–ê–¢–ê–°–ï–¢–ê
# def fix_labels(labels):
#     if isinstance(labels, np.ndarray):
#         return labels.astype(np.int64)
#     else:
#         return np.array(labels, dtype=np.int64)

dataset = SpectrogramDataset("/home/egr/projects/nmus/data/datasets/genres_dataset.h5")
# dataset.labels = fix_labels(dataset.labels)

# —Ä–∞–∑–±–∏–≤–∞–µ–º –¥–∞—Ç–∞—Å–µ—Ç –Ω–∞ train –∏ test
dataset_size = len(dataset)
indices = list(range(dataset_size))

train_idx, test_idx = train_test_split(
    indices, 
    test_size = 0.2,
    train_size = 0.8,       
    random_state = 42        # –¥–ª—è –≤–æ—Å–ø—Ä–æ–∏–∑–≤–æ–¥–∏–º–æ—Å—Ç–∏
)

train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,  
    num_workers=4,   
    pin_memory=True,
    persistent_workers=True,
    prefetch_factor=2
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,  
    num_workers=2,   
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=2
)

def check_data_shapes(train_loader, test_loader):
    print("–ü—Ä–æ–≤–µ—Ä–∫–∞ —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–µ–π –¥–∞–Ω–Ω—ã—Ö:")
    
    # –ü—Ä–æ–≤–µ—Ä—è–µ–º train loader
    for batch_idx, (spectrograms, labels) in enumerate(train_loader):
        print(f"Train batch {batch_idx}:")
        print(f"  Spectrograms: {spectrograms.shape}")  # –î–æ–ª–∂–Ω–æ –±—ã—Ç—å: [batch, 1, 128, 256]
        print(f"  Labels: {labels.shape}")              # –î–æ–ª–∂–Ω–æ –±—ã—Ç—å: [batch]
        break
    
    # –ü—Ä–æ–≤–µ—Ä—è–µ–º test loader  
    for batch_idx, (spectrograms, labels) in enumerate(test_loader):
        print(f"Test batch {batch_idx}:")
        print(f"  Spectrograms: {spectrograms.shape}")
        print(f"  Labels: {labels.shape}")
        break

check_data_shapes(train_loader, test_loader)
print(dataset.class_names)

b'rock'
Loaded 1998 samples from HDF5
–ü—Ä–æ–≤–µ—Ä–∫–∞ —Ä–∞–∑–º–µ—Ä–Ω–æ—Å—Ç–µ–π –¥–∞–Ω–Ω—ã—Ö:
DATA - shape: torch.Size([32, 1, 128, 512])
       min: -4.6927, max: 3.7392
       mean: -0.0043, std: 1.0054
target exemple: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=torch.uint8)


In [3]:
num_epochs = 10           # –∫–æ–ª–∏—á–µ—Å—Ç–≤–æ —ç–ø–æ—Ö –¥–ª—è –æ–±—É—á–µ–Ω–∏—è
threshold = 0.5            # –¥–ª—è –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–π
patience = 20              # –æ—Å—Ç–∞–Ω–æ–≤–∫–∞ –æ–±—É—á–µ–Ω–∏—è –ø–æ—Å–ª–µ N —ç–ø–æ—Ö –±–µ–∑ —É–ª—É—á—à–µ–Ω–∏–π

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  

model = TinyCNN(dataset.num_classes)
#model = load_model(model, "../data/models/genres/best_model.pth", device)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()

history, trained_model = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    num_epochs=num_epochs,
    patience=patience,
    device=device
)

plot_training_history(history)

–ò—Å–ø–æ–ª—å–∑—É–µ—Ç—Å—è —É—Å—Ç—Ä–æ–π—Å—Ç–≤–æ: cuda
–ö–æ–ª–∏—á–µ—Å—Ç–≤–æ –ø–∞—Ä–∞–º–µ—Ç—Ä–æ–≤ –º–æ–¥–µ–ª–∏: 164842


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]


RuntimeError: Expected floating point type for target with class probabilities, got Byte

In [None]:
import glob
import random
import torchaudio
from IPython.display import Audio, display

def random_audio_predictor(model, preprocess_function, class_names, folder_path='../data/train/', device='cuda', threshold=0.5):
    print("üé≤ –°–õ–£–ß–ê–ô–ù–´–ô –ê–£–î–ò–û –ê–ù–ê–õ–ò–ó")
    print("=" * 50)
    
    # –ù–∞—Ö–æ–¥–∏–º –≤—Å–µ –∞—É–¥–∏–æ —Ñ–∞–π–ª—ã –≤ –ø–∞–ø–∫–µ
    audio_extensions = ['*.wav', '*.mp3', '*.flac', '*.m4a', '*.ogg']
    audio_files = []
    
    for ext in audio_extensions:
        # –ò—â–µ–º —Ñ–∞–π–ª—ã –≤ –æ—Å–Ω–æ–≤–Ω–æ–π –ø–∞–ø–∫–µ –∏ –ø–æ–¥–ø–∞–ø–∫–∞—Ö
        audio_files.extend(glob.glob(os.path.join(folder_path, '**', ext), recursive=True))
        audio_files.extend(glob.glob(os.path.join(folder_path, ext)))
    
    if not audio_files:
        print(f"‚ùå –í –ø–∞–ø–∫–µ {folder_path} –Ω–µ –Ω–∞–π–¥–µ–Ω–æ –∞—É–¥–∏–æ —Ñ–∞–π–ª–æ–≤!")
        return None
    
    print(f"üìÅ –ü–∞–ø–∫–∞: {folder_path}")
    print(f"üìä –ù–∞–π–¥–µ–Ω–æ —Ñ–∞–π–ª–æ–≤: {len(audio_files)}")
    print(f"üéöÔ∏è  –ü–æ—Ä–æ–≥: {threshold}")
    
    # –í—ã–±–∏—Ä–∞–µ–º —Å–ª—É—á–∞–π–Ω—ã–π —Ñ–∞–π–ª
    selected_file = random.choice(audio_files)
    filename = os.path.basename(selected_file)
    file_path = os.path.dirname(selected_file)
    
    print(f"\nüéØ –í—ã–±—Ä–∞–Ω —Ñ–∞–π–ª: {filename}")
    print(f"üìÇ –ü—É—Ç—å: {file_path}")
    print("‚îÄ" * 50)
    
    # –ê–≤—Ç–æ–≤–æ—Å–ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏–µ –∞—É–¥–∏–æ
    print("üîä –ê–≤—Ç–æ–≤–æ—Å–ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏–µ...")
    display(Audio(selected_file, autoplay=True))
    
    # –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ —Ç–µ–≥–æ–≤
    print("\nü§ñ –ê–Ω–∞–ª–∏–∑ —Ç–µ–≥–æ–≤...")
    active_tags = quick_predict(
        model=model,
        audio_path=selected_file,
        preprocess_function=preprocess_function,
        class_names=class_names,
        device=device,
        threshold=threshold
    )
        
    if len(active_tags) > 0:
        print(f"‚úÖ –ù–∞–π–¥–µ–Ω–æ —Ç–µ–≥–æ–≤: {len(active_tags)}")
        print("\nüìã –°–ø–∏—Å–æ–∫ —Ç–µ–≥–æ–≤:")
        for tag in active_tags:
            print(f"   ‚Ä¢ {tag['genre']}: {tag['confidence']:.3f}")
    else:
        print("‚ùå –ù–µ—Ç —Ç–µ–≥–æ–≤ –≤—ã—à–µ –ø–æ—Ä–æ–≥–∞")
    
    return {
        'file': selected_file,
        'filename': filename,
        'active_tags': active_tags,
        'threshold': threshold
    }

def create_spectrogram(audio_path, sample_rate=22050, n_fft=2048, hop_length=512, n_mels=128):
    
    # –ó–∞–≥—Ä—É–∂–∞–µ–º –∞—É–¥–∏–æ—Ñ–∞–π–ª
    waveform, original_sample_rate = torchaudio.load(audio_path)

    # –ü—Ä–µ–æ–±—Ä–∞–∑—É–µ–º –∫ –º–æ–Ω–æ –µ—Å–ª–∏ –Ω—É–∂–Ω–æ
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)
    
    # –†–µ—Å–µ–º–ø–ª–∏—Ä—É–µ–º –µ—Å–ª–∏ –Ω–µ–æ–±—Ö–æ–¥–∏–º–æ
    if original_sample_rate != sample_rate:
        resampler = torchaudio.transforms.Resample(original_sample_rate, sample_rate)
        waveform = resampler(waveform)
    
    # –°–æ–∑–¥–∞–µ–º mel-—Å–ø–µ–∫—Ç—Ä–æ–≥—Ä–∞–º–º—É
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels
    )

    # –ü—Ä–∏–º–µ–Ω—è–µ–º –ø—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏–µ
    spectrogram = mel_spectrogram(waveform)
    
    # –ü—Ä–∏–º–µ–Ω—è–µ–º –ª–æ–≥–∞—Ä–∏—Ñ–º–∏—á–µ—Å–∫–æ–µ –ø—Ä–µ–æ–±—Ä–∞–∑–æ–≤–∞–Ω–∏–µ –¥–ª—è –ª—É—á—à–µ–≥–æ –≤–æ—Å–ø—Ä–∏—è—Ç–∏—è
    spectrogram = torchaudio.functional.amplitude_to_DB(
        spectrogram, 
        multiplier=10, 
        amin=1e-10, 
        db_multiplier=0, 
        top_db=80.0
    )
    
    # –ù–æ—Ä–º–∞–ª–∏–∑—É–µ–º —Å–ø–µ–∫—Ç—Ä–æ–≥—Ä–∞–º–º—É
    spectrogram = (spectrogram - spectrogram.mean()) / (spectrogram.std() + 1e-8)

    return spectrogram[:,:,0:256]

def quick_predict(model, audio_path, preprocess_function, class_names, device='cuda', threshold=0.5):
    model.to(device)
    model.eval()
    
    input_tensor = preprocess_function(audio_path).float().unsqueeze(0).to(device)
    print(input_tensor.shape)

    # –ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏–µ
    with torch.no_grad():
        outputs = model(input_tensor)
        print(f"outputs: {outputs}")
        probabilities = torch.softmax(outputs, dim=1)
        print(f"probabilities: {probabilities}")
        probs = probabilities.cpu().numpy()[0]
        print(f"probs: {probs}")
    
        # –ü–æ–ª—É—á–∞–µ–º —Ç–æ–ø-3 –ø—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–∏—è
        top3_indices = np.argsort(probs)[-3:][::-1]
        top3_predictions = []
        indexes = torch.argmax(outputs, dim=1).cpu().numpy()
        print(f"–ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–Ω—ã–π –∂–∞–Ω—Ä: {dataset.idx_to_genre[indexes[0]]}")
        print(f"–ü–æ—Å–ª–µ argmax: {indexes}")
        for idx in top3_indices:
            genre_name = dataset.idx_to_genre[idx]
            confidence = probs[idx]
            top3_predictions.append({
                'genre': genre_name,
                'confidence': float(confidence),
                'index': idx
            })
    
    return top3_predictions

best_model = TinyCNN(dataset.num_genres)
load_model(best_model, '../data/models/genres/best_model.pth')
best_model.to(device)
best_model.eval()

# –ü—Ä–æ—Å—Ç–æ–π –≤—ã–∑–æ–≤ - –æ–¥–∏–Ω —Å–ª—É—á–∞–π–Ω—ã–π —Ñ–∞–π–ª
random_audio_predictor(
    model=best_model,
    preprocess_function=create_spectrogram,
    class_names=dataset.get_genre_names(),
    folder_path='../data/genres', #'/home/egr/–ú—É–∑—ã–∫–∞',  # –í–∞—à–∞ –ø–∞–ø–∫–∞ —Å –∞—É–¥–∏–æ
    device=device,
    threshold=0.5  # –ú–æ–∂–Ω–æ –Ω–∞—Å—Ç—Ä–æ–∏—Ç—å –ø–æ—Ä–æ–≥
)


–ú–æ–¥–µ–ª—å –∑–∞–≥—Ä—É–∂–µ–Ω–∞ –∏–∑ ../data/models/genres/best_model.pth
üé≤ –°–õ–£–ß–ê–ô–ù–´–ô –ê–£–î–ò–û –ê–ù–ê–õ–ò–ó
üìÅ –ü–∞–ø–∫–∞: ../data/genres
üìä –ù–∞–π–¥–µ–Ω–æ —Ñ–∞–π–ª–æ–≤: 2000
üéöÔ∏è  –ü–æ—Ä–æ–≥: 0.5

üéØ –í—ã–±—Ä–∞–Ω —Ñ–∞–π–ª: country.00027.wav
üìÇ –ü—É—Ç—å: ../data/genres
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
üîä –ê–≤—Ç–æ–≤–æ—Å–ø—Ä–æ–∏–∑–≤–µ–¥–µ–Ω–∏–µ...



ü§ñ –ê–Ω–∞–ª–∏–∑ —Ç–µ–≥–æ–≤...
torch.Size([1, 1, 128, 256])
outputs: tensor([[-0.0129, -0.0243,  0.0774, -0.1104,  0.0126, -0.0162, -0.0210, -0.0169,
          0.1516,  0.0214]], device='cuda:0')
probabilities: tensor([[0.0979, 0.0968, 0.1072, 0.0888, 0.1004, 0.0976, 0.0971, 0.0975, 0.1154,
         0.1013]], device='cuda:0')
probs: [0.09790107 0.09679391 0.10715699 0.08880325 0.10042746 0.09758142
 0.0971074  0.09750592 0.11540172 0.1013209 ]
–ü—Ä–µ–¥—Å–∫–∞–∑–∞–Ω–Ω—ã–π –∂–∞–Ω—Ä: reggae
–ü–æ—Å–ª–µ argmax: [8]
‚úÖ –ù–∞–π–¥–µ–Ω–æ —Ç–µ–≥–æ–≤: 3

üìã –°–ø–∏—Å–æ–∫ —Ç–µ–≥–æ–≤:
   ‚Ä¢ reggae: 0.115
   ‚Ä¢ country: 0.107
   ‚Ä¢ rock: 0.101


{'file': '../data/genres/country.00027.wav',
 'filename': 'country.00027.wav',
 'active_tags': [{'genre': 'reggae',
   'confidence': 0.11540171504020691,
   'index': np.int64(8)},
  {'genre': 'country',
   'confidence': 0.10715699195861816,
   'index': np.int64(2)},
  {'genre': 'rock', 'confidence': 0.10132090002298355, 'index': np.int64(9)}],
 'threshold': 0.5}