In [None]:
import torch
import torchaudio
import librosa
import numpy as np
import soundfile as sf
import os
import glob
import sys
from pathlib import Path
import random
from IPython.display import Audio, display
from google.colab import drive

# All project files live on Google Drive, so mount it before anything else.
drive.mount('/content/drive')

# Make the project folder and its 'models' subfolder importable
# (content_encoder.py, decoder.py and utilityFunctions.py are expected under 'models').
# Paths are appended, so entries already on sys.path keep precedence.
DRIVE_DIR = "/content/drive/MyDrive/DeepLearning_StyleTransfer"
models_path = os.path.join(DRIVE_DIR, "models")
for extra_path in (DRIVE_DIR, models_path):
    if extra_path not in sys.path:
        sys.path.append(extra_path)


### utilityFunctions dà problemi se non messo direttamente nell'ambiente di Colab
from content_encoder import ContentEncoder
from decoder import Decoder
from utilityFunctions import get_STFT, get_CQT, inverse_STFT, get_overlap_windows, sections2spectrogram, concat_stft_cqt

In [27]:
# Override the inverse_STFT imported from utilityFunctions, which had GPU compatibility problems.
def inverse_STFT(stft_tensor, n_fft=1024, hop_length=256):
    """
    Reconstruct a waveform from a stacked real/imaginary STFT tensor.

    Args:
        stft_tensor: torch.Tensor of shape (2, time, freq); index 0 along the first
            axis is the real part, index 1 the imaginary part.
        n_fft: FFT size of the forward STFT (also the Hann window length).
        hop_length: hop size, in samples, of the forward STFT.

    Returns:
        torch.Tensor (samples,) - reconstructed waveform, on the input's device.
    """
    # torch.istft wants (freq, frames), so swap the last two axes: (2, freq, time).
    freq_major = stft_tensor.permute(0, 2, 1)

    # Rebuild the complex spectrogram and add a batch axis: (1, freq, frames).
    spec = torch.complex(freq_major[0], freq_major[1]).unsqueeze(0)

    # Creating the window on the input's device is what avoids the CPU/GPU mismatch.
    hann = torch.hann_window(n_fft, device=stft_tensor.device)

    signal = torch.istft(
        spec,
        n_fft=n_fft,
        hop_length=hop_length,
        window=hann,
        return_complex=False,
    )
    return signal.squeeze(0)

In [None]:
# Create a placeholder class_embeddings.pth filled with random values, so the
# pipeline can run before real class embeddings are available.
# Row 0 is used as the piano target and row 1 as the violin target
# (per the target_class_id values used by process_test_set).

# Parameters
d_encoder = 256  # embedding size; NOTE(review): presumably must match the Decoder's expected dim — confirm
EMBEDDING_SEED = 0  # fixed seed -> identical placeholder embeddings on every run
OVERWRITE_CLASS_EMBEDDINGS = False  # set True to regenerate the file even if it already exists
path_class_embeddings = os.path.join(DRIVE_DIR, "class_embeddings.pth")

if os.path.exists(path_class_embeddings) and not OVERWRITE_CLASS_EMBEDDINGS:
    # Never clobber an existing (possibly real/trained) embeddings file on Drive.
    print(f"File {path_class_embeddings} già presente: non sovrascritto.")
else:
    generator = torch.Generator().manual_seed(EMBEDDING_SEED)
    class_embeddings = torch.randn(2, d_encoder, generator=generator)  # [2, d_encoder]

    # Create the output directory if it does not exist yet.
    os.makedirs(os.path.dirname(path_class_embeddings), exist_ok=True)

    # torch.save raises on failure, so reaching the next line means the file was written.
    torch.save(class_embeddings, path_class_embeddings)
    print(f"File {path_class_embeddings} creato con successo!")

# Sanity check: reload whatever is actually on disk and report its shape.
# map_location="cpu" allows loading a file saved from a GPU session on a CPU-only runtime.
loaded_tensor = torch.load(path_class_embeddings, map_location="cpu")
class_embeddings = loaded_tensor
print(f"Dimensione del tensore caricato: {loaded_tensor.shape}")

In [29]:
# Input/output directories (Google Drive is mounted in the first cell).
#
# TEST_DIR   = DRIVE_DIR/test_dataset
#                -> piano/    (*.mp3)
#                -> violin/   (*.mp3)
#
# OUTPUT_DIR = DRIVE_DIR/output
#                -> from_piano_to_violin/
#                -> from_violin_to_piano/
#              (created by process_test_set)
TEST_DIR = os.path.join(DRIVE_DIR, "test_dataset")
OUTPUT_DIR = os.path.join(DRIVE_DIR, "output")
PATH_DIR = OUTPUT_DIR  # legacy alias of OUTPUT_DIR; not used by the visible cells

SAMPLES_PER_CLASS = 5  # number of random files processed per class

# Seed the RNGs early so the random choice of test files is reproducible across runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

# class_embeddings.pth is expected to hold a [2, d_enc] tensor
# (row 0 = piano target, row 1 = violin target).
path_class_embeddings = os.path.join(DRIVE_DIR, "class_embeddings.pth")

# Audio / spectrogram configuration
SAMPLE_RATE = 22050
N_FFT = 1024
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_BINS = 84  # number of CQT bins
WINDOW_SIZE = 287  # frames per section fed to the content encoder
OVERLAP_PERCENTAGE = 0.3
OVERLAP_FRAMES = int(WINDOW_SIZE * OVERLAP_PERCENTAGE)  # frames shared by consecutive sections
TRANSFORMER_DIM = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SECTION_LENGTH = 1.0
SECTION_LENGTH = 1.0

In [30]:
# Style transfer for a single waveform
def style_transfer(waveform, sr, content_encoder, decoder, class_embeddings, target_class_id):
    """
    Re-synthesize ``waveform`` conditioned on the class embedding ``target_class_id``.

    Pipeline: STFT + CQT -> overlapping sections -> content encoder -> decoder
    conditioned on the target class embedding -> sections merged back into a full
    spectrogram -> inverse STFT.

    Args:
        waveform: torch.Tensor audio as returned by torchaudio.load and downmixed by
            the caller (presumably (1, samples) — TODO confirm what get_STFT/get_CQT expect).
        sr: int - sample rate of ``waveform``. The CQT is computed with SAMPLE_RATE,
            so audio at any other rate is resampled to SAMPLE_RATE first.
        content_encoder: content encoder model, already on DEVICE.
        decoder: decoder model, already on DEVICE.
        class_embeddings: torch.Tensor [n_classes, d_enc] on DEVICE.
        target_class_id: int - row of ``class_embeddings`` used as the target style.

    Returns:
        (output_audio, sample_rate): reconstructed waveform on CPU and its sample rate
        (SAMPLE_RATE if resampling was applied, otherwise ``sr``).
    """
    # Keep the waveform rate consistent with the SAMPLE_RATE passed to get_CQT
    # (and presumably the rate the models were trained on).
    if sr != SAMPLE_RATE:
        waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=SAMPLE_RATE)
        sr = SAMPLE_RATE

    stft = get_STFT(waveform, n_fft=N_FFT, hop_length=HOP_LENGTH).to(DEVICE)
    cqt = get_CQT(waveform, sample_rate=SAMPLE_RATE, n_bins=N_BINS, hop_length=HOP_LENGTH).to(DEVICE)
    input_spectrogram = concat_stft_cqt(stft, cqt)
    sections = get_overlap_windows(input_spectrogram, window_size=WINDOW_SIZE, overlap_frames=OVERLAP_FRAMES)
    sections = sections.unsqueeze(0)  # add batch dimension

    class_emb = class_embeddings[target_class_id].unsqueeze(0)

    content_encoder.eval()
    decoder.eval()
    with torch.no_grad():
        content_emb = content_encoder(sections)
        output_stft = decoder(content_emb, class_emb, target_length=content_emb.size(1))

    output_stft = output_stft.squeeze(0)
    # Length (in frames) of the original STFT, used to trim the merged sections.
    original_time = stft.size(1)
    full_spectrogram = sections2spectrogram(output_stft, original_size=original_time, overlap=OVERLAP_FRAMES)

    output_audio = inverse_STFT(full_spectrogram, n_fft=N_FFT, hop_length=HOP_LENGTH)

    return output_audio.cpu(), sr

# Run style transfer on a random subset of the test set, in both directions.
def process_test_set(test_dir, output_dir, samples_per_class=5, seed=None, checkpoint_path=None):
    """
    Process a given number of random files from each test class folder.

    For each direction (piano -> violin, violin -> piano), ``samples_per_class``
    files are drawn at random from ``test_dir/<source>/*.mp3``. Each file is passed
    to process_file, and the output goes to ``output_dir/from_<source>_to_<target>``.

    Args:
        test_dir: str - test dataset directory with 'piano' and 'violin' subfolders.
        output_dir: str - directory where the generated audio is saved.
        samples_per_class: int - number of random files to process per class.
        seed: int or None - if given, file selection uses a private RNG seeded with
            this value (reproducible); if None, the global ``random`` module is used.
        checkpoint_path: str or None - checkpoint with 'content_encoder' and 'decoder'
            state dicts; defaults to ``models_path/checkpoint_epoch_100.pth``.

    Raises:
        ValueError: if either class folder has fewer than ``samples_per_class`` files.
    """
    # One entry per direction: (source, target, target row in class_embeddings,
    # header printed before processing that direction).
    directions = [
        ("piano", "violin", 1, "Processamento file piano → violino:"),
        ("violin", "piano", 0, "\nProcessamento file violino → piano:"),
    ]

    # Gather candidates before loading the models, so missing data fails fast.
    # glob order depends on the filesystem; sorting makes a seeded selection reproducible.
    candidates = {
        source: sorted(glob.glob(os.path.join(test_dir, source, "*.mp3")))
        for source, _, _, _ in directions
    }
    if len(candidates["piano"]) < samples_per_class or len(candidates["violin"]) < samples_per_class:
        raise ValueError(f"Non abbastanza file: piano ({len(candidates['piano'])}), violino ({len(candidates['violin'])})")

    rng = random.Random(seed) if seed is not None else random
    selected = {source: rng.sample(paths, samples_per_class) for source, paths in candidates.items()}

    # Output folders: one per transfer direction.
    target_dirs = {}
    for source, target, _, _ in directions:
        target_dirs[source] = os.path.join(output_dir, f"from_{source}_to_{target}")
        Path(target_dirs[source]).mkdir(parents=True, exist_ok=True)

    # Build the models and load their trained weights.
    content_encoder = ContentEncoder().to(DEVICE)
    decoder = Decoder().to(DEVICE)
    if checkpoint_path is None:
        checkpoint_path = os.path.join(models_path, "checkpoint_epoch_100.pth")
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)
    content_encoder.load_state_dict(checkpoint['content_encoder'])
    decoder.load_state_dict(checkpoint['decoder'])

    # map_location lets embeddings saved from a GPU session load on a CPU-only runtime.
    class_embeddings = torch.load(path_class_embeddings, map_location=DEVICE)

    for source, target, target_id, header in directions:
        print(header)
        for audio_path in selected[source]:
            process_file(audio_path, content_encoder, decoder, class_embeddings,
                         source_class=source, target_class_id=target_id, target_class=target,
                         output_dir=target_dirs[source])

def process_file(audio_path, content_encoder, decoder, class_embeddings, source_class, target_class_id, target_class, output_dir):
    """
    Load one audio file, play it, apply style transfer, then play and save the result.

    Args:
        audio_path: str - path of the input audio file.
        content_encoder: content encoder model, already on DEVICE.
        decoder: decoder model, already on DEVICE.
        class_embeddings: torch.Tensor [n_classes, d_enc] on DEVICE.
        source_class: str - name of the input's class (used in messages and the file name).
        target_class_id: int - row of ``class_embeddings`` used as the target style.
        target_class: str - name of the target class (used in messages and the file name).
        output_dir: str - directory where the converted audio is written.

    Returns:
        (output_audio, sr): the converted waveform (CPU tensor) and its sample rate,
        as returned by style_transfer.
    """
    waveform, sr = torchaudio.load(audio_path)
    # Downmix any multi-channel file (not just stereo) to mono, keeping a (1, samples) shape.
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    print(f"\nFile: {os.path.basename(audio_path)} ({source_class} → {target_class})")
    print("Audio originale:")
    display(Audio(waveform.numpy(), rate=sr))

    output_audio, sr = style_transfer(waveform, sr, content_encoder, decoder, class_embeddings, target_class_id)

    print(f"Audio con stile trasferito ({target_class}):")
    display(Audio(output_audio.numpy(), rate=sr))

    # NOTE(review): the output name keeps the input's extension (e.g. .mp3); soundfile
    # picks the format from it, and MP3 writing needs a recent libsndfile — confirm.
    output_filename = f"{source_class}_to_{target_class}_{os.path.basename(audio_path)}"
    output_path = os.path.join(output_dir, output_filename)
    sf.write(output_path, output_audio.numpy(), sr)
    print(f"Salvato: {output_path}")

    return output_audio, sr

In [None]:
# Run the full evaluation: both transfer directions on SAMPLES_PER_CLASS random files per class.
process_test_set(
    test_dir=TEST_DIR,
    output_dir=OUTPUT_DIR,
    samples_per_class=SAMPLES_PER_CLASS,
)