In [1]:
import os
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
import random
from google.colab import drive


import soundfile as sf
import glob
import sys
from pathlib import Path
import random
from IPython.display import Audio, display
from google.colab import drive

# All project files were moved to Google Drive: mount it so they are reachable from Colab.
drive.mount('/content/drive')

# Project root on Drive. The model/utility modules (content_encoder.py, decoder.py,
# style_encoder.py, utilityFunctions.py) are expected in its "models" subfolder.
DRIVE_DIR = "/content/drive/MyDrive/DeepLearning_StyleTransfer"
models_path = os.path.join(DRIVE_DIR, "models")

# Make both directories importable; the membership check keeps re-runs of this
# cell from appending duplicate entries.
for _import_dir in (DRIVE_DIR, models_path):
    if _import_dir not in sys.path:
        sys.path.append(_import_dir)
del _import_dir


from content_encoder import ContentEncoder
from decoder import Decoder
from style_encoder import StyleEncoder, initialize_weights
from utilityFunctions import get_STFT, get_CQT, inverse_STFT, get_overlap_windows, sections2spectrogram, concat_stft_cqt, load_audio

Mounted at /content/drive


In [2]:
# Parameters
SAMPLE_RATE = 22050
CUT_TIME_SECONDS = 10
BATCH_SIZE = 16  # Half piano, half violin, i.e. 8 + 8
WINDOW_SIZE = 287
OVERLAP_FRAMES = int(WINDOW_SIZE * 0.3)  # 30% of WINDOW_SIZE, passed to get_overlap_windows
N_FFT = 1024
HOP_LENGTH = 256
N_BINS = 84
TRANSFORMER_DIM = 256  # Must match the model's d_encoder
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reproducibility: the balanced batch is drawn with the `random` module, so the
# saved class embeddings depend on this seed. Seed torch as well for completeness.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# Paths
DATASET_PATH = os.path.join(DRIVE_DIR, 'test_dataset')
PIANO_PATH = os.path.join(DATASET_PATH, 'piano')
VIOLIN_PATH = os.path.join(DATASET_PATH, 'violin')
MODEL_WEIGHTS_PATH = os.path.join(DRIVE_DIR, 'models')
OUTPUT_PATH = os.path.join(DRIVE_DIR, 'class_embeddings.pth')


In [3]:
# Fixed number of overlapping sections kept per sample (S = 4);
# AudioDataset truncates or zero-pads every item to this count.
MAX_SECTIONS = 4

# Custom dataset: every item is a fixed-size stack of STFT+CQT sections plus a class label.
class AudioDataset(Dataset):
    """Piano/violin dataset yielding a fixed number of spectrogram sections per file.

    Each item is ``(sections, label)``. ``sections`` is produced by
    ``get_overlap_windows`` on the concatenated STFT+CQT spectrogram and then
    truncated or zero-padded along dim 0 to exactly ``max_sections`` entries.
    ``label`` is 0 for piano and 1 for violin.

    Args:
        piano_dir: Directory containing the piano ``.mp3`` files.
        violin_dir: Directory containing the violin ``.mp3`` files.
        sample_rate: Target sample rate forwarded to ``load_audio``.
        cut_time_seconds: Clip length forwarded to ``load_audio``.
        max_sections: Number of sections every item is truncated/padded to.

    Note:
        The spectrogram and windowing settings come from the notebook globals
        ``N_FFT``, ``HOP_LENGTH``, ``N_BINS``, ``WINDOW_SIZE`` and ``OVERLAP_FRAMES``.
    """

    def __init__(self, piano_dir, violin_dir, sample_rate=22050, cut_time_seconds=10, max_sections=MAX_SECTIONS):
        # Sorted listings give a stable index -> file mapping (os.listdir order is
        # arbitrary), which is required for reproducible, seeded batch sampling.
        self.piano_files = self._list_mp3(piano_dir)
        self.violin_files = self._list_mp3(violin_dir)
        self.sample_rate = sample_rate
        self.cut_time_seconds = cut_time_seconds
        self.max_sections = max_sections
        self.file_list = self.piano_files + self.violin_files
        self.labels = [0] * len(self.piano_files) + [1] * len(self.violin_files)  # 0: piano, 1: violin

    @staticmethod
    def _list_mp3(directory):
        """Return the sorted full paths of the ``.mp3`` files in ``directory`` (case-insensitive extension)."""
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith('.mp3')
        )

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        label = self.labels[idx]
        waveform, sr = load_audio(file_path, self.sample_rate, self.cut_time_seconds)

        # Shape comments assume the helpers return (channels, time, freq) — TODO confirm.
        stft = get_STFT(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH)  # presumably (2, T, F1)
        cqt = get_CQT(waveform, sample_rate=sr, n_bins=N_BINS, hop_length=HOP_LENGTH)  # presumably (2, T, F2)
        spectrogram = concat_stft_cqt(stft, cqt)  # presumably (2, T, F1+F2)

        # Split into overlapping windows along time: presumably (S, 2, WINDOW_SIZE, F).
        sections = get_overlap_windows(spectrogram, window_size=WINDOW_SIZE, overlap_frames=OVERLAP_FRAMES)

        return self._fit_sections(sections, self.max_sections), label

    @staticmethod
    def _fit_sections(sections, max_sections):
        """Truncate or zero-pad ``sections`` along dim 0 to exactly ``max_sections`` entries."""
        num_sections = sections.shape[0]
        if num_sections > max_sections:
            return sections[:max_sections]
        if num_sections < max_sections:
            # Take the pad's trailing shape, dtype and device from `sections` itself
            # so the concatenation cannot fail on a mismatched window size or dtype.
            padding = sections.new_zeros((max_sections - num_sections, *sections.shape[1:]))
            return torch.cat([sections, padding], dim=0)
        return sections

# Build the dataset, explicitly capping each item at MAX_SECTIONS sections.
dataset = AudioDataset(
    piano_dir=PIANO_PATH,
    violin_dir=VIOLIN_PATH,
    sample_rate=SAMPLE_RATE,
    cut_time_seconds=CUT_TIME_SECONDS,
    max_sections=MAX_SECTIONS,
)
# Balanced batch sampling (piano / violin)
def create_balanced_batch(dataset, batch_size, rng=None):
    """Pick dataset indices for a batch with (near-)equal piano and violin counts.

    Args:
        dataset: Object exposing a ``labels`` sequence where 0 = piano and
            1 = violin (e.g. ``AudioDataset``).
        batch_size: Total number of indices to return. For an odd value the
            extra sample is taken from the violin class.
        rng: Optional ``random.Random`` instance for reproducible draws. When
            ``None`` the global ``random`` module is used.

    Returns:
        A shuffled list of ``batch_size`` distinct indices into ``dataset``.

    Raises:
        ValueError: if ``batch_size`` is negative or either class has too few
            samples to fill its half of the batch.
    """
    sampler = random if rng is None else rng

    if batch_size < 0:
        raise ValueError(f"batch_size must be non-negative, got {batch_size}")

    piano_indices = [i for i, label in enumerate(dataset.labels) if label == 0]
    violin_indices = [i for i, label in enumerate(dataset.labels) if label == 1]

    # Split in half; an odd remainder goes to the violin half so the function
    # always returns exactly `batch_size` indices.
    n_piano = batch_size // 2
    n_violin = batch_size - n_piano

    # Make sure each class has enough samples
    if len(piano_indices) < n_piano or len(violin_indices) < n_violin:
        raise ValueError("Not enough piano or violin samples for a balanced batch")

    selected_indices = sampler.sample(piano_indices, n_piano) + sampler.sample(violin_indices, n_violin)
    sampler.shuffle(selected_indices)  # avoid a fixed piano-then-violin order
    return selected_indices


In [4]:
# Build the style encoder. The hyper-parameters must match the checkpoint,
# otherwise load_state_dict (strict by default) raises on missing/mismatched keys.
model = StyleEncoder(
    in_channels=2,
    cnn_out_dim=256,
    transformer_dim=TRANSFORMER_DIM,
    num_heads=4,
    num_layers=4,
    use_cls=True
).to(DEVICE)

# Load the pretrained weights from the configured MODEL_WEIGHTS_PATH directory
# (same folder as `models_path`, but defined once in the config cell).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint_path = os.path.join(MODEL_WEIGHTS_PATH, "checkpoint_epoch_100.pth")
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
model.load_state_dict(checkpoint['style_encoder'])
model.eval()  # inference mode (disables dropout / batch-norm updates)

# Fresh dataset instance (default max_sections) for this cell.
dataset = AudioDataset(PIANO_PATH, VIOLIN_PATH, SAMPLE_RATE, CUT_TIME_SECONDS)

# Draw one balanced batch of indices and load the corresponding samples.
batch_indices = create_balanced_batch(dataset, BATCH_SIZE)
batch_data = [dataset[i] for i in batch_indices]

# Stack into tensors on DEVICE: sections presumably (B, S, 2, T, F), labels (B,).
sections_batch = torch.stack([sections for sections, _ in batch_data]).to(DEVICE)
labels_batch = torch.tensor([label for _, label in batch_data], dtype=torch.long, device=DEVICE)

# Forward pass without autograd; the model returns style embeddings and class embeddings.
with torch.no_grad():
    style_emb, class_emb = model(sections_batch, labels_batch)

# Expected: one row per class (0 = piano, 1 = violin), each TRANSFORMER_DIM wide.
assert class_emb.shape == (2, TRANSFORMER_DIM), f"unexpected class_emb shape {tuple(class_emb.shape)}"

# Save a detached CPU copy so the file loads on machines without CUDA
# (and without needing map_location).
torch.save(class_emb.detach().cpu(), OUTPUT_PATH)
print(f"Class embeddings salvati in {OUTPUT_PATH}")

# Optional check: report the shape and the first 5 dimensions of each class embedding.
class_emb_np = class_emb.cpu().numpy()
print(f"Shape del tensore salvato: {class_emb_np.shape}")
print(f"Embedding piano (classe 0): {class_emb_np[0][:5]}...")
print(f"Embedding violino (classe 1): {class_emb_np[1][:5]}...")

Class embeddings salvati in /content/drive/MyDrive/DeepLearning_StyleTransfer/class_embeddings.pth
Shape del tensore salvato: (2, 256)
Embedding piano (classe 0): [-1.3264339   0.23438288 -0.49311617  0.6152496   0.7076507 ]...
Embedding violino (classe 1): [-1.3810045   0.33650976 -0.4032237   0.44803926  0.69075215]...
