# Libraries

In [None]:
!pip install torch torchvision torchaudio
!pip install scikit-learn
!pip install numpy

# Code adapted to run in batches and generate outputs for all classes available in the complete dataset


In this version, I used the training/validation split provided by MusicNet itself: each batch is validated against the data in the `Validation_Musics` folder. This version is adapted to run on Colab; it was a test to check the processing speed.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from datetime import datetime
import os
import numpy as np
import random
import torch.nn.functional as F
from sklearn.metrics import f1_score, precision_score, recall_score

# ============================
# Reproducibility: seed every random number generator used by the script
# ============================
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Input folders: preprocessed training songs and the validation songs
processed_path = "/content/Processed_Musics"    # training data
validation_path = "/content/Validation_Musics"  # validation data

# Logs and checkpoints are written under /content/Logs (created if missing);
# each run gets its own timestamped log file.
output_dir = os.path.join("/content", "Logs")
os.makedirs(output_dir, exist_ok=True)
_run_stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
log_file_name = os.path.join(output_dir, f"incremental_training_log_{_run_stamp}.txt")

# ============================
# Dataset subset for incremental training
# ============================
class ProcessedDatasetSubset(Dataset):
    """
    Loads a contiguous slice of the ``.pt`` files in ``data_dir`` into memory.

    Files are sorted by name. The slice starts at ``start_file`` and holds at
    most ``max_files`` files, or all remaining files when ``max_files`` is None.
    Each file must contain a ``"mel_spec_segments"`` sequence and a parallel
    ``"y"`` sequence of labels. All data is loaded during initialization.

    Args:
        data_dir: Folder containing the ``.pt`` files.
        start_file: Index (in sorted order) of the first file to load.
        max_files: Maximum number of files to load, or None for all remaining.
        class_to_idx: Optional fixed mapping from raw label to class index.
            Pass the SAME mapping to every subset (and to the validation set)
            so that index ``k`` means the same class everywhere. When None, the
            mapping is built from the labels present in this subset only.

    Raises:
        ValueError: if a segment's first dimension is not 128, if a file's
            segment and label counts differ, or if a label is missing from
            ``class_to_idx``.
    """
    def __init__(self, data_dir, start_file=0, max_files=None, class_to_idx=None):
        self.data = []
        self.labels = []
        files = sorted([f for f in os.listdir(data_dir) if f.endswith(".pt")])
        # Honor start_file even when no max_files limit is given.
        stop = None if max_files is None else start_file + max_files
        self.file_names = files[start_file:stop]  # Saved for logging
        for file in self.file_names:
            file_path = os.path.join(data_dir, file)
            data_loaded = torch.load(file_path, weights_only=True)
            mel_spec_segments = data_loaded["mel_spec_segments"]
            labels = data_loaded["y"]
            # zip() would silently drop the unmatched tail of the longer sequence.
            if len(mel_spec_segments) != len(labels):
                raise ValueError(
                    f"{file}: {len(mel_spec_segments)} segments but {len(labels)} labels"
                )
            for segment, label in zip(mel_spec_segments, labels):
                if segment.shape[0] != 128:
                    raise ValueError(f"Invalid spectrogram found: {segment.shape}")
                self.data.append(segment.clone().detach().float())
                self.labels.append(int(label))
        if class_to_idx is None:
            # Local mapping: only consistent within this subset.
            unique_classes = sorted(set(self.labels))
            class_to_idx = {cls: idx for idx, cls in enumerate(unique_classes)}
        else:
            missing = sorted(set(self.labels) - set(class_to_idx))
            if missing:
                raise ValueError(f"Labels not present in class_to_idx: {missing}")
            class_to_idx = dict(class_to_idx)  # copy so later caller edits don't leak in
        self.class_to_idx = class_to_idx
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}
        self.mapped_labels = [self.class_to_idx[label] for label in self.labels]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(segment_tensor, class_index)`` for sample ``idx``."""
        return self.data[idx], self.mapped_labels[idx]

# ============================
# Dataset for validation
# ============================
class ProcessedDatasetValidation(Dataset):
    """
    Loads all data from the validation folder during initialization.

    Files are the ``.pt`` files in ``data_dir`` sorted by name, optionally
    truncated to the first ``max_files``. Each file must contain a
    ``"mel_spec_segments"`` sequence and a parallel ``"y"`` sequence of labels.

    Args:
        data_dir: Folder containing the validation ``.pt`` files.
        max_files: Maximum number of files to load, or None for all of them.
        class_to_idx: Optional fixed mapping from raw label to class index.
            Use the same mapping as the training datasets so that predicted
            indices are comparable. When None, the mapping is built from the
            validation labels only.

    Raises:
        ValueError: if a segment's first dimension is not 128, if a file's
            segment and label counts differ, or if a label is missing from
            ``class_to_idx``.
    """
    def __init__(self, data_dir, max_files=None, class_to_idx=None):
        self.data = []
        self.labels = []
        files = sorted([f for f in os.listdir(data_dir) if f.endswith(".pt")])
        if max_files is not None:
            files = files[:max_files]
        for file in files:
            file_path = os.path.join(data_dir, file)
            data_loaded = torch.load(file_path, weights_only=True)
            mel_spec_segments = data_loaded["mel_spec_segments"]
            labels = data_loaded["y"]
            # zip() would silently drop the unmatched tail of the longer sequence.
            if len(mel_spec_segments) != len(labels):
                raise ValueError(
                    f"{file}: {len(mel_spec_segments)} segments but {len(labels)} labels"
                )
            for segment, label in zip(mel_spec_segments, labels):
                if segment.shape[0] != 128:
                    raise ValueError(f"Invalid spectrogram found:: {segment.shape}")
                self.data.append(segment.clone().detach().float())
                self.labels.append(int(label))
        if class_to_idx is None:
            # Local mapping: only consistent within this dataset.
            unique_classes = sorted(set(self.labels))
            class_to_idx = {cls: idx for idx, cls in enumerate(unique_classes)}
        else:
            missing = sorted(set(self.labels) - set(class_to_idx))
            if missing:
                raise ValueError(f"Labels not present in class_to_idx: {missing}")
            class_to_idx = dict(class_to_idx)  # copy so later caller edits don't leak in
        self.class_to_idx = class_to_idx
        self.idx_to_class = {idx: cls for cls, idx in self.class_to_idx.items()}
        self.mapped_labels = [self.class_to_idx[label] for label in self.labels]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(segment_tensor, class_index)`` for sample ``idx``."""
        return self.data[idx], self.mapped_labels[idx]

def collate_fn(batch):
    """Collate ``(spectrogram, label)`` pairs without stacking the spectrograms.

    Spectrograms are kept as a plain list (their lengths may differ), while the
    labels are packed into a single ``torch.long`` tensor.
    """
    targets = torch.tensor([sample[1] for sample in batch], dtype=torch.long)
    spectrograms = [sample[0] for sample in batch]
    return spectrograms, targets

def get_total_num_classes(data_dir, max_files=None):
    """Collect the distinct integer labels stored in the ``.pt`` files of *data_dir*.

    Files are read in name order; only the first ``max_files`` are used when
    ``max_files`` is not None. Each file's ``"y"`` entry supplies the labels.

    Returns:
        ``(count, classes)``: the number of distinct labels and their sorted list.
    """
    pt_files = sorted(name for name in os.listdir(data_dir) if name.endswith(".pt"))
    if max_files is not None:
        pt_files = pt_files[:max_files]
    seen = set()
    for name in pt_files:
        payload = torch.load(os.path.join(data_dir, name), weights_only=True)
        seen |= {int(label) for label in payload["y"]}
    classes = sorted(seen)
    return len(classes), classes




# ============================
# CNN Model
# ============================
class CNN(nn.Module):
    """Small 2-D CNN that classifies a list of spectrograms one at a time.

    Each input of shape ``(128, L)`` is run through the network as its own
    batch of one (so ``L`` may differ between inputs), and the per-sample
    logits are concatenated into a ``(len(x_list), num_classes)`` tensor.

    Args:
        input_shape: Accepted for interface compatibility; not used by the layers.
        num_classes: Number of output logits per sample.
    """

    def __init__(self, input_shape, num_classes):
        super().__init__()
        # Three 3x3 conv stages (1 -> 32 -> 64 -> 128 channels), padding keeps size.
        # Creation order is kept so parameter initialization under a seed is unchanged.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Halves only the height (the 128-row axis); the L axis is preserved.
        self.pool = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
        # Averages whatever spatial extent remains down to 1x1.
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc2 = nn.Linear(128, num_classes)
        self.dropout = nn.Dropout(p=0.5)

    def _logits_for(self, spec):
        """Return a ``(1, num_classes)`` logit row for one ``(128, L)`` spectrogram."""
        h = spec[None, None]  # (1, 1, 128, L)
        stages = ((self.conv1, True), (self.conv2, True), (self.conv3, False))
        for conv, downsample in stages:
            h = F.relu(conv(h))
            if downsample:
                h = self.pool(h)
        h = torch.flatten(self.global_pool(h), start_dim=1)  # (1, 128)
        return self.fc2(self.dropout(h))

    def forward(self, x_list):
        return torch.cat([self._logits_for(spec) for spec in x_list], dim=0)

# ============================
# Incremental training function with validation and early stopping
# ============================
def _collect_predictions(model, loader, device):
    """Run ``model`` over ``loader`` without gradients; return (labels, predictions) lists."""
    all_labels = []
    all_preds = []
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = [inp.to(device).float() for inp in inputs]
            outputs = model(inputs)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(outputs.argmax(dim=1).cpu().numpy())
    return all_labels, all_preds


def _weighted_scores(labels, preds):
    """Return weighted (F1, precision, recall); all zeros when there is nothing to score."""
    if len(labels) == 0:
        return 0.0, 0.0, 0.0
    return (
        f1_score(labels, preds, average='weighted', zero_division=0),
        precision_score(labels, preds, average='weighted', zero_division=0),
        recall_score(labels, preds, average='weighted', zero_division=0),
    )


def train_incremental(model, optimizer, criterion, train_loader, val_loader, device, num_epochs, logger,
                      early_stop_threshold=0.80):
    """Train ``model`` on one batch of files, validating after every epoch.

    Each epoch trains over ``train_loader``, scores the training predictions
    and the ``val_loader`` predictions (weighted F1 / precision / recall), and
    reports them through ``logger``. Training stops early once the training F1
    reaches ``early_stop_threshold``; that epoch is still fully logged first.

    Args:
        model: Network called as ``model(list_of_tensors)``, returning logits.
        optimizer: Optimizer over ``model``'s parameters; its state carries
            over between calls, which is what makes training incremental.
        criterion: Loss function applied to ``(logits, labels)``.
        train_loader: Loader yielding ``(list_of_tensors, label_tensor)`` batches.
        val_loader: Loader with the same batch format, used for evaluation only.
        device: Device the inputs and labels are moved to.
        num_epochs: Maximum number of epochs for this batch of files.
        logger: Callable taking one message string.
        early_stop_threshold: Training F1 at or above which training stops.

    Returns:
        ``(model, optimizer)``, with ``model`` left in training mode.
    """
    # NOTE(review): stopping on *training* F1 (computed with dropout active) does
    # not guard against overfitting; consider stopping on validation F1 instead.
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_steps = 0
        all_train_labels = []
        all_train_preds = []

        for inputs, labels in train_loader:
            inputs = [inp.to(device).float() for inp in inputs]
            labels = labels.to(device).long()
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_steps += 1
            all_train_labels.extend(labels.cpu().numpy())
            all_train_preds.extend(outputs.argmax(dim=1).cpu().numpy())

        # An empty loader (a file slice with no segments) must not divide by zero.
        avg_loss = running_loss / num_steps if num_steps else 0.0
        train_f1, train_precision, train_recall = _weighted_scores(all_train_labels, all_train_preds)

        # Validation (evaluation mode: dropout disabled)
        model.eval()
        all_val_labels, all_val_preds = _collect_predictions(model, val_loader, device)
        val_f1, val_precision, val_recall = _weighted_scores(all_val_labels, all_val_preds)

        logger(f"Epoch {epoch+1}/{num_epochs}: Train Loss = {avg_loss:.4f}, "
               f"Train F1 = {train_f1:.4f}, P = {train_precision:.4f}, R = {train_recall:.4f} || "
               f"Val F1 = {val_f1:.4f}, P = {val_precision:.4f}, R = {val_recall:.4f}")

        # Check the early-stopping criterion only after the epoch has been logged.
        if train_f1 >= early_stop_threshold:
            logger(f"EARLY STOPPING ACTIVATED! Train F1 = {train_f1:.4f} surpass {early_stop_threshold:.4f} at epoch {epoch+1}.")
            break

    model.train()  # Leave the model in training mode, as callers expect
    return model, optimizer


# ============================
# Incremental training main function
# ============================
def _apply_class_mapping(dataset, class_to_idx):
    """Re-index ``dataset.labels`` in place with a shared label -> index mapping."""
    dataset.class_to_idx = dict(class_to_idx)
    dataset.idx_to_class = {idx: cls for cls, idx in dataset.class_to_idx.items()}
    dataset.mapped_labels = [dataset.class_to_idx[label] for label in dataset.labels]
    return dataset


def main():
    """Train the CNN incrementally, ``files_per_batch`` training files at a time.

    All training batches and the validation set share ONE label -> index
    mapping built from every class seen in the training folder plus the
    validation folder. The per-dataset mapping built by the dataset classes
    would otherwise give index ``k`` a different meaning in each batch.
    After every batch a checkpoint with the model, optimizer state, and the
    class mapping is saved to ``output_dir``.
    """
    # Incremental training parameters:
    files_per_batch = 25       # Number of files (songs) per batch
    num_epochs_per_batch = 10  # Training epochs per batch
    num_classes, classes_list = get_total_num_classes(processed_path)

    # List of files available in the training folder
    all_train_files = sorted([f for f in os.listdir(processed_path) if f.endswith(".pt")])
    total_train_files = len(all_train_files)

    # Load the validation dataset (you can limit it if you wish)
    validation_dataset = ProcessedDatasetValidation(validation_path, max_files=None)

    # One global mapping; validation-only classes are included so that every
    # validation label has an index (the model simply never trains on them).
    val_classes = set(validation_dataset.labels)
    extra_val_classes = sorted(val_classes - set(classes_list))
    all_classes = sorted(set(classes_list) | val_classes)
    class_to_idx = {cls: idx for idx, cls in enumerate(all_classes)}
    _apply_class_mapping(validation_dataset, class_to_idx)
    val_loader = DataLoader(validation_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    with open(log_file_name, "w", encoding="utf-8") as log_file:
        def log_message(message):
            print(message)
            log_file.write(message + "\n")
            log_file.flush()

        log_message(f"Total de classes no dataset: {num_classes}")
        log_message(f"Classes: {classes_list}")
        if extra_val_classes:
            log_message(f"Validation-only classes (never seen in training): {extra_val_classes}")
        log_message(f"Total training files in the directory: {total_train_files}")
        log_message(f"Starting incremental training with batches of {files_per_batch} files and {num_epochs_per_batch} epochs per batch.")

        # The output layer covers every class in the shared mapping.
        model = CNN(input_shape=(128, 128), num_classes=len(all_classes)).to(device)
        optimizer = optim.Adam(model.parameters(), lr=0.0001)
        criterion = nn.CrossEntropyLoss()

        # Incremental loop: train batch by batch
        for start_file in range(0, total_train_files, files_per_batch):
            end_file = min(start_file + files_per_batch, total_train_files)
            log_message(f"Training batch files: {start_file} until {end_file}")
            # Create dataset for the current batch, re-indexed with the global mapping
            dataset_subset = ProcessedDatasetSubset(processed_path, start_file=start_file, max_files=files_per_batch)
            _apply_class_mapping(dataset_subset, class_to_idx)
            log_message(f"Files at this batch: {dataset_subset.file_names}")
            train_loader = DataLoader(dataset_subset, batch_size=16, shuffle=True, collate_fn=collate_fn)

            # Train the model on this batch and evaluate using the fixed validation set
            model, optimizer = train_incremental(model, optimizer, criterion, train_loader, val_loader, device, num_epochs_per_batch, log_message)

            # Save a checkpoint after training this batch (end index clamped to the file count)
            checkpoint_path = os.path.join(output_dir, f"checkpoint_{start_file}_{end_file}.pt")
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "last_file_index": end_file,
                "class_to_idx": class_to_idx,
            }, checkpoint_path)
            log_message(f"Checkpoint salvo em: {checkpoint_path}")

        log_message("Incremental training completed.")

# Entry point: run the incremental training when this file/cell is executed directly.
if __name__ == "__main__":
    main()
