In [None]:
import os
import torch
import pickle
import warnings
import torchaudio
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

# убираем предупреждения
warnings.filterwarnings("ignore", message=".*TorchCodec.*")
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")

# класс создания общего датасета
class MoodthemeAudioDataset(Dataset):
    def __init__(self, data_dir, labels_dict, mel_spectr_func):
        
        self.data_dir = data_dir
        self.labels_dict = labels_dict
        self.mel_spectr_func = mel_spectr_func

        # множество тегов
        self.moodthemes = set()
        for tags in labels_dict.values():
            self.moodthemes.update(tags)

        # сортриуем и преобразуем теги в индексы
        self.moodthemes = sorted(list(self.moodthemes))
        self.tag_to_idx = {tag: idx for idx, tag in enumerate(self.moodthemes)}
        self.num_classes = len(self.moodthemes)

        # собираем аудиофайлы и теги
        self.audio_files = []
        self.labels = []
        
        for folder_name in os.listdir(data_dir):
            folder_path = os.path.join(data_dir, folder_name)
            if os.path.isdir(folder_path):
                
                for file_name in os.listdir(folder_path):
                    if self._is_audio_file(file_name):
                        key = os.path.join(folder_name, file_name)
                        if key in labels_dict:
                            full_path = os.path.join(folder_path, file_name)
                            self.audio_files.append(full_path)
                            
                            file_tags = labels_dict[key]
                            multi_hot = self._tags_to_multi_hot(file_tags)
                            self.labels.append(multi_hot)

    # проверка является ли файл аудиофайлом
    def _is_audio_file(self, filename):
        return filename.lower().endswith(('.wav', '.mp3', '.flac', '.m4a', '.ogg'))
    

    # преобразует список тегов в multi-hot вектор
    def _tags_to_multi_hot(self, tags):
        multi_hot = np.zeros(self.num_classes, dtype=np.float32)
        for tag in tags:
            if tag in self.tag_to_idx:
                multi_hot[self.tag_to_idx[tag]] = 1.0
        return multi_hot


    def get_class_names(self):
        return self.moodthemes
    
    def __len__(self):
        return len(self.audio_files)
    
    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        multi_hot_label = self.labels[idx]
        
        try:
            # Получаем mel-спектрограмму
            mel_spectrogram = self.mel_spectr_func(audio_path)
            
            # Добавляем dimension для канала если нужно
            if len(mel_spectrogram.shape) == 2:
                mel_spectrogram = mel_spectrogram.unsqueeze(0)  # [1, n_mels, time]
            
            # Преобразуем метку в тензор
            label_tensor = torch.tensor(multi_hot_label, dtype=torch.float32)
            
            return mel_spectrogram, label_tensor
            
        except Exception as e:
            print(f"Error processing {audio_path}: {e}")
            # Возвращаем нулевые тензоры в случае ошибки
            mel_shape = (1, 128, 1000)  # пример формы
            return torch.zeros(mel_shape), torch.zeros(self.num_classes)
        

def preprocess_audio(audio_path, target_sr = 22050) -> torch.Tensor:
    waveform, sr = torchaudio.load(audio_path, normalize=True, channels_first=True)
    #уменьшил частоту дискретизации, чтобы ещё меньше датасет весил, а то везде 44100 Гц или почти везде
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    
    mel_spectrogram = torchaudio.transforms.MelSpectrogram(
        sample_rate=target_sr,
        n_mels=128,        
        n_fft=1024,       
        hop_length=256,
        f_min=20,
        f_max=11025
    )(waveform)

    mel_spectrogram = torch.log(mel_spectrogram + 1e-6)

    return mel_spectrogram[0,:,:512]



#------------------------------------------------------------------------------

tags_dict = {}
moodthemes = []

# открываем подготовленный словрь
with open("../data/track_tags.pkl", 'rb') as file:
    tags_dict = pickle.load(file)

# список тегов
for tags in tags_dict.values():
    for tag in tags:
        if tag not in moodthemes:
            moodthemes.append(tag)

# создаём датасет
dataset = MoodthemeAudioDataset(
    data_dir = "../data/train",
    labels_dict = tags_dict,
    mel_spectr_func = preprocess_audio
)

# разбиваем датасет на train и test
dataset_size = len(dataset)
indices = list(range(dataset_size))

train_idx, test_idx = train_test_split(
    indices, 
    test_size = 0.02,
    train_size = 0.10,       
    random_state = 42        # для воспроизводимости
)

train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

train_loader = DataLoader(train_dataset, batch_size = 64, shuffle = True)
test_loader = DataLoader(test_dataset, batch_size = 64, shuffle = False)

# тут будут классы моделей
class SimpleFCModel(nn.Module):
    def __init__(self, num_classes: int, input_height: int, input_width: int, hidden_size: int):
        super(SimpleFCModel, self).__init__()
        
        self.input_height = input_height
        self.input_width = input_width
        self.num_classes = num_classes
        
        # Вычисляем размер после вытягивания в вектор
        self.flatten_size = input_height * input_width
        
        self.layers = nn.Sequential(
            # Вытягиваем в плоский вектор
            nn.Flatten(),
            
            # Первый скрытый слой
            nn.Linear(self.flatten_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.25),  # для регуляризации
            
            # Второй скрытый слой (можно добавить больше)
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(), 
            nn.Dropout(0.25),
            
            # Выходной слой
            nn.Linear(hidden_size // 2, num_classes)
        )
    
    def forward(self, x):
        return self.layers(x)

In [29]:
from tqdm import tqdm
from sklearn.metrics import f1_score

def train_epoch(model, loader, optimizer, criterion, device):
    """
    Простая функция обучения для мультитеговой классификации аудио
    """
    model.train()
    running_loss = 0.0
    all_targets = []
    all_predictions = []
    
    pbar = tqdm(loader, desc='Training')
    
    for data, targets in pbar:
        
        data = data.to(device)                          # 
        targets = targets.to(device)                    # Перемещаем данные на устройство
        
        optimizer.zero_grad()                           # Обнуляем градиенты
        
        outputs = model(data)                           # Прямой проход
        
        loss = criterion(outputs, targets)              # Вычисляем loss
        
        loss.backward()                                 # 
        optimizer.step()                                # Обратный проход
        
        with torch.no_grad():                           # 
            probs = torch.sigmoid(outputs)              #
            predictions = (probs > 0.5).float()         # Сохраняем предсказания и цели для метрик
                                                        #
            all_targets.append(targets.cpu())           #
            all_predictions.append(predictions.cpu())   #
        
        running_loss += loss.item()                     # Обновляем статистику
        
        # Обновляем прогресс-бар
        current_loss = running_loss / len(loader)       
        pbar.set_postfix({'Loss': f'{current_loss:.4f}'})
    
    # Вычисляем метрики
    all_targets = torch.cat(all_targets)
    all_predictions = torch.cat(all_predictions)
    
    # Accuracy
    accuracy = (all_predictions == all_targets).float().mean().item()
    
    # F1-score (macro)
    f1 = f1_score(all_targets.numpy(), all_predictions.numpy(), average='macro', zero_division=0)
    
    epoch_loss = running_loss / len(loader)
    
    return {
        'loss': epoch_loss,
        'accuracy': accuracy,
        'f1': f1
    }


def evaluate_epoch(model, loader, criterion, device):
    """
    Функция оценки для мультитеговой классификации аудио
    """
    model.eval()
    running_loss = 0.0
    all_targets = []
    all_predictions = []
    all_probs = []
    
    pbar = tqdm(loader, desc='Evaluation')
    
    with torch.no_grad():
        for data, targets in pbar:
            data = data.to(device)                      #
            targets = targets.to(device)                # Перемещаем данные на устройство
            
            outputs = model(data)                       # Прямой проход
            
            loss = criterion(outputs, targets)          # Вычисляем loss
            
            probs = torch.sigmoid(outputs)              # 
            predictions = (probs > 0.5).float()         # Получаем вероятности и предсказания

            all_targets.append(targets.cpu())           # 
            all_predictions.append(predictions.cpu())   #
            all_probs.append(probs.cpu())               # Сохраняем для метрик
            
            running_loss += loss.item()                 # Обновляем статистику
            
            # Обновляем прогресс-бар
            current_loss = running_loss / len(loader)
            pbar.set_postfix({'Loss': f'{current_loss:.4f}'})
    
    # Объединяем все батчи
    all_targets = torch.cat(all_targets)
    all_predictions = torch.cat(all_predictions)
    all_probs = torch.cat(all_probs)
    
    # Вычисляем метрики
    accuracy = (all_predictions == all_targets).float().mean().item()
    f1 = f1_score(all_targets.numpy(), all_predictions.numpy(), average='macro', zero_division=0)
    epoch_loss = running_loss / len(loader)
    
    return {
        'loss': epoch_loss,
        'accuracy': accuracy,
        'f1': f1,
        'predictions': all_predictions,
        'targets': all_targets,
        'probabilities': all_probs
    }

def train_model(model, train_loader, val_loader, optimizer, criterion, device, num_epochs, save_dir='../data/models'):
    # Создаем директорию для сохранения
    os.makedirs(save_dir, exist_ok=True)
    
    # Инициализируем лучшие метрики
    best_f1 = 0.0
    best_epoch = 0
    
    # История обучения
    history = {
        'train_loss': [], 'train_accuracy': [], 'train_f1': [],
        'val_loss': [], 'val_accuracy': [], 'val_f1': []
    }
    
    for epoch in range(1, num_epochs + 1):
        print(f'\nEpoch {epoch}/{num_epochs}')
        print('-' * 50)
        
        # Обучение
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device)
        
        # Валидация
        val_metrics = evaluate_epoch(model, val_loader, criterion, device)
        
        # Сохраняем историю
        history['train_loss'].append(train_metrics['loss'])
        history['train_accuracy'].append(train_metrics['accuracy'])
        history['train_f1'].append(train_metrics['f1'])
        history['val_loss'].append(val_metrics['loss'])
        history['val_accuracy'].append(val_metrics['accuracy'])
        history['val_f1'].append(val_metrics['f1'])
        
        # Выводим метрики
        print(f'Train - Loss: {train_metrics["loss"]:.4f}, Acc: {train_metrics["accuracy"]:.4f}, F1: {train_metrics["f1"]:.4f}')
        print(f'Val   - Loss: {val_metrics["loss"]:.4f}, Acc: {val_metrics["accuracy"]:.4f}, F1: {val_metrics["f1"]:.4f}')
        
        # Сохраняем лучшую модель
        if val_metrics['f1'] > best_f1:
            best_f1 = val_metrics['f1']
            best_epoch = epoch
            
            # Сохраняем модель
            best_model_path = os.path.join(save_dir, 'FCmodel_best.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_f1': best_f1,
                'val_loss': val_metrics['loss'],
                'val_accuracy': val_metrics['accuracy']
            }, best_model_path)
            print(f'✓ New best model saved! F1: {best_f1:.4f}')
        
        # Сохраняем бэкап каждые 5 эпох
        if epoch % 5 == 0:
            backup_path = os.path.join(save_dir, f'FCmodel_{epoch}.pth')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_metrics': train_metrics,
                'val_metrics': val_metrics,
                'history': history,
                'best_f1': best_f1,  
                'current_val_f1': val_metrics['f1'],
            }, backup_path)
            print(f'✓ Backup saved: {backup_path}')
    
    print(f'\nTraining completed!')
    print(f'Best model: epoch {best_epoch}, F1: {best_f1:.4f}')
    
    return history, best_f1

# Функция для загрузки модели
def load_model(model, checkpoint_path, optimizer=None):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    
    print(f'Model loaded from {checkpoint_path}')
    print(f'Epoch: {checkpoint["epoch"]}, Val F1: {checkpoint.get("val_f1", "N/A")}')
    
    return checkpoint

def plot_history(history, figsize=(15, 10)):
    """
    Строит графики истории обучения для формата из train_model
    
    Args:
        history: словарь из train_model с ключами:
            'train_loss', 'train_accuracy', 'train_f1'
            'val_loss', 'val_accuracy', 'val_f1'
        figsize: размер фигуры
    """
    epochs = range(1, len(history['train_loss']) + 1)
    
    # Создаем subplots
    fig, axes = plt.subplots(2, 2, figsize=figsize)
    fig.suptitle('Training History', fontsize=16, fontweight='bold')
    
    # 1. График Loss
    axes[0, 0].plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    axes[0, 0].plot(epochs, history['val_loss'], 'r-', label='Val Loss', linewidth=2)
    axes[0, 0].set_title('Loss')
    axes[0, 0].set_xlabel('Epoch')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].legend()
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. График Accuracy
    axes[0, 1].plot(epochs, history['train_accuracy'], 'b-', label='Train Accuracy', linewidth=2)
    axes[0, 1].plot(epochs, history['val_accuracy'], 'r-', label='Val Accuracy', linewidth=2)
    axes[0, 1].set_title('Accuracy')
    axes[0, 1].set_xlabel('Epoch')
    axes[0, 1].set_ylabel('Accuracy')
    axes[0, 1].legend()
    axes[0, 1].grid(True, alpha=0.3)
    axes[0, 1].set_ylim(0, 1)
    
    # 3. График F1-score
    axes[1, 0].plot(epochs, history['train_f1'], 'b-', label='Train F1', linewidth=2)
    axes[1, 0].plot(epochs, history['val_f1'], 'r-', label='Val F1', linewidth=2)
    axes[1, 0].set_title('F1-Score (Macro)')
    axes[1, 0].set_xlabel('Epoch')
    axes[1, 0].set_ylabel('F1-Score')
    axes[1, 0].legend()
    axes[1, 0].grid(True, alpha=0.3)
    axes[1, 0].set_ylim(0, 1)
    
    # 4. Сводный график всех метрик (валидация) - ИСПРАВЛЕННАЯ ЧАСТЬ
    axes[1, 1].plot(epochs, history['val_accuracy'], 'g-', label='Val Accuracy', linewidth=2)
    axes[1, 1].plot(epochs, history['val_f1'], 'orange', label='Val F1', linewidth=2)
    
    # ИСПРАВЛЕНИЕ: преобразуем в numpy array перед делением
    val_loss = np.array(history['val_loss'])
    if max(val_loss) > 0:  # Проверяем чтобы не делить на 0
        val_loss_norm = val_loss / max(val_loss)
    else:
        val_loss_norm = val_loss
    
    axes[1, 1].plot(epochs, val_loss_norm, 'r-', label='Val Loss (norm)', linewidth=2)
    axes[1, 1].set_title('Validation Metrics Summary')
    axes[1, 1].set_xlabel('Epoch')
    axes[1, 1].set_ylabel('Value')
    axes[1, 1].legend()
    axes[1, 1].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # Выводим лучшие значения
    best_val_acc = max(history['val_accuracy'])
    best_val_f1 = max(history['val_f1'])
    best_epoch_acc = history['val_accuracy'].index(best_val_acc) + 1
    best_epoch_f1 = history['val_f1'].index(best_val_f1) + 1
    
    print("Training Results Summary:")
    print(f"Best Validation Accuracy: {best_val_acc:.4f} (epoch {best_epoch_acc})")
    print(f"Best Validation F1-Score: {best_val_f1:.4f} (epoch {best_epoch_f1})")
    print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")
    print(f"Final Train Accuracy: {history['train_accuracy'][-1]:.4f}")
    print(f"Final Train F1: {history['train_f1'][-1]:.4f}")

In [None]:
# Инициализация модели
model = SimpleFCModel(
    num_classes=len(dataset.get_class_names()),
    input_height=128,
    input_width=512,
    hidden_size=256
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
criterion = nn.BCEWithLogitsLoss()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)

# Запуск обучения
history, best_f1 = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    num_epochs=5,
    save_dir='../data/models'
)

# вывод графиков
plot_history(history)

In [34]:
def quick_predict(model, audio_path, preprocess_function, class_names, device='cuda', treshold = 0.5):
    """Упрощенная версия для быстрой проверки"""
    # Убедимся, что модель на правильном устройстве
    model.to(device)
    model.eval()
    
    try:
        # Препроцессинг
        mel_spec = preprocess_function(audio_path)
        
        # Добавляем batch dimension и перемещаем на устройство
        if len(mel_spec.shape) == 3:  # [channels, height, width]
            input_tensor = mel_spec.unsqueeze(0)  # [1, channels, height, width]
        else:
            input_tensor = mel_spec.unsqueeze(0).unsqueeze(0)  # [1, 1, height, width]
        
        input_tensor = input_tensor.to(device)
        
        # Предсказание
        with torch.no_grad():
            outputs = model(input_tensor)
            probs = torch.sigmoid(outputs).cpu().numpy()[0]
        
        # Результаты
        print(f"\n Prediction for: {os.path.basename(audio_path)}")
        
        active_tags = []
        for i, (class_name, prob) in enumerate(zip(class_names, probs)):
            if prob > treshold:
                print(f"  -- {class_name}: {prob:.3f}")
                active_tags.append(class_name)
        
        if not active_tags:
            print(f"  ❌ No tags predicted above 0.5 threshold")
            
        return probs, active_tags
        
    except Exception as e:
        print(f"❌ Error during prediction: {e}")
        return None, None

best_model = SimpleFCModel(
    input_height=128,        
    input_width=512,          
    hidden_size=256,  
    num_classes=len(dataset.get_class_names()),             
)

load_model(best_model, '../data/models/FCmodel_best.pth')
best_model.to(device)
best_model.eval()

test_audio_path = "../data/train/00/7400.mp3"

quick_predict(
    model=model,
    audio_path=test_audio_path,
    preprocess_function=preprocess_audio,
    class_names=dataset.get_class_names(),
    device=device,
    treshold = 0.5
)

Model loaded from ../data/models/FCmodel_best.pth
Epoch: 1, Val F1: 0.022597383639533532

 Prediction for: 7400.mp3
  -- dark: 0.505
  -- film: 0.611


(array([0.16224001, 0.18871887, 0.1908976 , 0.11464704, 0.0152881 ,
        0.12977386, 0.11931247, 0.20430069, 0.49834076, 0.00795187,
        0.04476983, 0.1851877 , 0.50474167, 0.07586961, 0.09089307,
        0.05371616, 0.1069052 , 0.1244147 , 0.08459464, 0.13955484,
        0.06145494, 0.02307401, 0.61050445, 0.12722939, 0.08689252,
        0.0457852 , 0.35765156, 0.37032825, 0.17118411, 0.0810883 ,
        0.1789267 , 0.3510151 , 0.27122477, 0.01916159, 0.21072173,
        0.42612377, 0.31619397, 0.21093695, 0.04952242, 0.09588888,
        0.08292124, 0.24146736, 0.14612049, 0.13461299, 0.16951759,
        0.1774612 , 0.10752928, 0.3087906 , 0.03573575, 0.29870623,
        0.08306908, 0.1216257 , 0.33548093, 0.0572835 , 0.02912924,
        0.24575889, 0.17778079, 0.4892593 , 0.16161305], dtype=float32),
 ['dark', 'film'])