In [3]:
import torch
import torchaudio
import librosa
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import soundfile as sf
from utilityFunctions import *
from style_encoder import *
from content_encoder import *
from discriminator import *
from losses import *

# STFT test

In [None]:
audio, sr = load_audio("./piano.wav", sample_rate=22050, cut_time_seconds=10)
display(Audio(audio.numpy(), rate=sr))

Loading audio...
Original audio shape: torch.Size([1, 1832906]), Sample rate: 44100
Processed audio shape: torch.Size([1, 220500]), Sample rate: 22050


In [3]:
stft = get_STFT(audio)
cqt = get_CQT(audio)

print("STFT shape:", stft.shape)
print("CQT shape:", cqt.shape)

sections1 = get_overlap_windows(stft)
sections2 = get_overlap_windows(cqt)

print("stft size:", sections1.shape)
print("cqt size:", sections2.shape)

STFT shape: torch.Size([2, 862, 513])
CQT shape: torch.Size([2, 862, 84])
stft size: torch.Size([4, 2, 287, 513])
cqt size: torch.Size([4, 2, 287, 84])


In [4]:
conc = concat_stft_cqt(stft, cqt)
print("Concatenated shape:", conc.shape)

sections = get_overlap_windows(conc)
print("Concatenated sections shape:", sections.shape)

Concatenated shape: torch.Size([2, 862, 597])
Concatenated sections shape: torch.Size([4, 2, 287, 597])


# Models and loss test

In [4]:
def test_style_encoder():
    
    # Parametri del test
    batch_size = 16
    num_frames = 4  # S, numero di frame temporali
    in_channels = 2  # Canali (reale e immaginario)
    T = 287  # Bin temporali
    F = 597  # Bin di frequenza
    num_classes = 2  # Numero di classi di strumenti

    # Configurazione del modello
    model = StyleEncoder(
        in_channels=in_channels,
        cnn_out_dim=256,
        transformer_dim=256,
        num_heads=4,
        num_layers=4,
        use_cls=True
    )

    # Inizializzazione dei pesi
    initialize_weights(model)

    # Passa il modello in modalità valutazione per evitare problemi con BatchNorm
    model.eval()

    # Crea dati random
    x = torch.randn(batch_size, num_frames, in_channels, T, F)  # (B, S, C, T, F)
    labels = torch.randint(0, num_classes, (batch_size,))  # (B,) etichette casuali

    # Esegui il forward pass
    try:
        style_emb, class_emb = model(x, labels)
        
        # Verifica le forme degli output
        print(f"Input shape: {x.shape}")  # Atteso: (16, 4, 2, 287, 597)
        print(f"Style embedding shape: {style_emb.shape}")  # Atteso: (16, 256)
        if class_emb is not None:
            print(f"Class embedding shape: {class_emb.shape}")  # Atteso: (2, 256)

        # Verifica che non ci siano valori NaN o infiniti
        if torch.isnan(style_emb).any() or torch.isinf(style_emb).any():
            print("Errore: style_emb contiene NaN o valori infiniti")
        else:
            print("Style embedding valido")
        if class_emb is not None:
            if torch.isnan(class_emb).any() or torch.isinf(class_emb).any():
                print("Errore: class_emb contiene NaN o valori infiniti")
            else:
                print("Class embedding valido")

        # Verifica i valori dell'output
        print(f"Style embedding mean: {style_emb.mean().item():.4f}")
        print(f"Style embedding std: {style_emb.std().item():.4f}")
        if class_emb is not None:
            print(f"Class embedding mean: {class_emb.mean().item():.4f}")
            print(f"Class embedding std: {class_emb.std().item():.4f}")

    except Exception as e:
        print(f"Errore durante il forward pass: {e}")


test_style_encoder()

Input shape: torch.Size([16, 4, 2, 287, 597])
Style embedding shape: torch.Size([16, 256])
Class embedding shape: torch.Size([2, 256])
Style embedding valido
Class embedding valido
Style embedding mean: 0.0000
Style embedding std: 1.0001
Class embedding mean: 0.0000
Class embedding std: 1.0009


In [2]:
def test_content_encoder():
    batch_size = 16
    num_frames = 4
    in_channels = 2
    T = 287
    F = 597

    model = ContentEncoder(
        in_channels=in_channels,
        cnn_out_dim=256,
        transformer_dim=256,
        num_heads=4,
        num_layers=4,
        channels_list=[16, 32, 64, 128, 256]
    )
    initialize_weights(model)
    model.eval()

    x = torch.randn(batch_size, num_frames, in_channels, T, F)
    content_emb = model(x)
    
    print(f"Input shape: {x.shape}")  # Atteso: (16, 4, 2, 287, 597)
    print(f"Content embedding shape: {content_emb.shape}")  # Atteso: (16, 4, 256)
    if torch.isnan(content_emb).any() or torch.isinf(content_emb).any():
        print("Errore: content_emb contiene NaN o valori infiniti")
    else:
        print("Content embedding valido")
    print(f"Content embedding mean: {content_emb.mean().item():.4f}")
    print(f"Content embedding std: {content_emb.std().item():.4f}")


test_content_encoder()

Input shape: torch.Size([16, 4, 2, 287, 597])
Content embedding shape: torch.Size([16, 4, 256])
Content embedding valido
Content embedding mean: 0.0000
Content embedding std: 1.0000


In [5]:
# randomly initialized
def setup_test_data():
    batch_size = 16
    num_frames = 4
    in_channels = 2
    T = 287
    F = 597
    transformer_dim = 256
    num_classes = 2

    # Initialize models
    style_encoder = StyleEncoder(
        in_channels=in_channels,
        cnn_out_dim=transformer_dim,
        transformer_dim=transformer_dim,
        num_heads=4,
        num_layers=4,
        use_cls=True
    )
    content_encoder = ContentEncoder(
        in_channels=in_channels,
        cnn_out_dim=transformer_dim,
        transformer_dim=transformer_dim,
        num_heads=4,
        num_layers=4,
        channels_list=[16, 32, 64, 128, 256]
    )
    discriminator = Discriminator(input_dim=transformer_dim)

    # Initialize weights
    initialize_weights(style_encoder)
    initialize_weights(content_encoder)
    initialize_weights(discriminator)

    # Set models to evaluation mode
    style_encoder.eval()
    content_encoder.eval()
    discriminator.eval()

    # Generate random input data
    x = torch.randn(batch_size, num_frames, in_channels, T, F)
    labels = torch.randint(0, num_classes, (batch_size,))

    # Get embeddings
    style_emb, class_emb = style_encoder(x, labels)
    content_emb = content_encoder(x)

    return style_emb, class_emb, content_emb, discriminator, labels



# Test function for infoNCE_loss
def test_infoNCE_loss():
    print("\nTesting infoNCE_loss...")
    style_emb, _, _, _, labels = setup_test_data()
    
    # Verify input shapes
    assert style_emb.shape == (16, 256), f"Expected style_emb shape (16, 256), got {style_emb.shape}"
    assert labels.shape == (16,), f"Expected labels shape (16,), got {labels.shape}"
    
    # Compute loss
    loss = infoNCE_loss(style_emb, labels, temperature=0.1)
    
    # Verify output
    assert loss.dim() == 0, f"Expected scalar loss, got shape {loss.shape}"
    assert loss.item() >= 0, f"Expected non-negative loss, got {loss.item()}"
    print(f"infoNCE_loss: {loss.item():.4f}")
    print("infoNCE_loss test passed!")



# Test function for margin_loss
def test_margin_loss():
    print("\nTesting margin_loss...")
    _, class_emb, _, _, _ = setup_test_data()
    
    # Verify input shape
    assert class_emb.shape == (2, 256), f"Expected class_emb shape (2, 256), got {class_emb.shape}"
    
    # Compute loss
    loss = margin_loss(class_emb, margin=1.0)
    
    # Verify output
    assert loss.dim() == 0, f"Expected scalar loss, got shape {loss.shape}"
    assert loss.item() >= 0, f"Expected non-negative loss, got {loss.item()}"
    print(f"margin_loss: {loss.item():.4f}")
    print("margin_loss test passed!")



# Test function for adversarial_loss
def test_adversarial_loss():
    print("\nTesting adversarial_loss...")
    style_emb, class_emb, content_emb, discriminator, labels = setup_test_data()
    
    # Verify input shapes
    assert style_emb.shape == (16, 256), f"Expected style_emb shape (16, 256), got {style_emb.shape}"
    assert class_emb.shape == (2, 256), f"Expected class_emb shape (2, 256), got {class_emb.shape}"
    assert content_emb.shape == (16, 4, 256), f"Expected content_emb shape (16, 4, 256), got {content_emb.shape}"
    assert labels.shape == (16,), f"Expected labels shape (16,), got {labels.shape}"
    
    # Compute loss
    total_loss, style_loss, content_loss = adversarial_loss(
        style_emb, class_emb, content_emb, discriminator, labels,
        lambda_content=1.0, lambda_class=0.5, lambda_style=1.0
    )
    
    # Verify outputs
    assert total_loss.dim() == 0, f"Expected scalar total_loss, got shape {total_loss.shape}"
    assert style_loss.dim() == 0, f"Expected scalar style_loss, got shape {style_loss.shape}"
    assert content_loss.dim() == 0, f"Expected scalar content_loss, got shape {content_loss.shape}"
    print(f"adversarial_loss - total: {total_loss.item():.4f}, style: {style_loss.item():.4f}, content: {content_loss.item():.4f}")
    print("adversarial_loss test passed!")




# Test function for disentanglement_loss (both cross-covariance and HSIC)
def test_disentanglement_loss():
    print("\nTesting disentanglement_loss...")
    style_emb, _, content_emb, _, _ = setup_test_data()
    
    # Verify input shapes
    assert style_emb.shape == (16, 256), f"Expected style_emb shape (16, 256), got {style_emb.shape}"
    assert content_emb.shape[0] == 16 and content_emb.shape[2] == 256, f"Expected content_emb batch and dim (16, *, 256), got {content_emb.shape}"
    
    # Test cross-covariance penalty
    print("Testing cross-covariance penalty..")
    loss_cov = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=False)
    assert loss_cov.dim() == 0, f"Expected scalar loss_cov, got shape {loss_cov.shape}"
    assert loss_cov.item() >= 0, f"Expected non-negative loss_cov, got {loss_cov.item()}"
    print(f"disentanglement_loss (cross-cov): {loss_cov.item():.4f}")
    
    # Test HSIC
    print("Testing HSIC...")
    loss_hsic = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)
    assert loss_hsic.dim() == 0, f"Expected scalar loss_hsic, got shape {loss_hsic.shape}"
    assert loss_hsic.item() >= 0, f"Expected non-negative loss_hsic, got {loss_hsic.item()}"
    print(f"disentanglement_loss (HSIC): {loss_hsic.item():.4f}")
    print("disentanglement_loss test passed!")

In [6]:
def run_all_tests():
    print("Running all loss tests...")
    test_infoNCE_loss()
    test_margin_loss()
    test_adversarial_loss()
    test_disentanglement_loss()
    print("\nAll tests completed successfully!")
    
run_all_tests()

Running all loss tests...

Testing infoNCE_loss...
infoNCE_loss: 2.7081
infoNCE_loss test passed!

Testing margin_loss...
margin_loss: 0.8602
margin_loss test passed!

Testing adversarial_loss...
adversarial_loss - total: 1.1448, style: 1.8263, content: -0.6815
adversarial_loss test passed!

Testing disentanglement_loss...
Testing cross-covariance penalty..
Norm of S: 0.5449
Norm of C: 0.0000
disentanglement_loss (cross-cov): 0.0000
Testing HSIC...
Norm of S: 0.5449
Norm of C: 0.0000
Norm of K: 12.6356
Norm of L: 16.0000
disentanglement_loss (HSIC): 0.0000
disentanglement_loss test passed!

All tests completed successfully!


In [7]:
def test_disentanglement_loss():
    print("\nTesting disentanglement_loss...")
    style_emb, _, content_emb, _, _ = setup_test_data()
    assert style_emb.shape == (16, 256), f"Expected style_emb shape (16, 256), got {style_emb.shape}"
    assert content_emb.shape[0] == 16 and content_emb.shape[2] == 256, f"Expected content_emb batch and dim (16, *, 256), got {content_emb.shape}"
    
    print("Testing cross-covariance penalty...")
    loss_cov = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=False)
    assert loss_cov.dim() == 0, f"Expected scalar loss_cov, got shape {loss_cov.shape}"
    assert loss_cov.item() >= 0, f"Expected non-negative loss_cov, got {loss_cov.item()}"
    print(f"disentanglement_loss (cross-cov): {loss_cov.item():.4f}")
    
    print("Testing HSIC...")
    loss_hsic = disentanglement_loss(style_emb, content_emb.mean(dim=1), use_hsic=True)
    assert loss_hsic.dim() == 0, f"Expected scalar loss_hsic, got shape {loss_hsic.shape}"
    assert loss_hsic.item() >= 0, f"Expected non-negative loss_hsic, got {loss_hsic.item()}"
    print(f"disentanglement_loss (HSIC): {loss_hsic.item():.4f}")
    print("disentanglement_loss test passed!")
    
test_disentanglement_loss()


Testing disentanglement_loss...
Testing cross-covariance penalty...
Norm of S: 2.1769
Norm of C: 0.0000
disentanglement_loss (cross-cov): 0.0000
Testing HSIC...
Norm of S: 2.1769
Norm of C: 0.0000
Norm of K: 14.9298
Norm of L: 16.0000
disentanglement_loss (HSIC): 0.0000
disentanglement_loss test passed!


# Batch dataloader test

In [1]:
import torch
import numpy as np
import random
import IPython.display as ipd
from utilityFunctions import inverse_STFT, plot_stft
from dataloader import get_dataloader

def play_batch(batch_size=4, plot_stft_flag=False):
    """
    Carica un batch randomico dal dataloader, stampa le label e lo strumento associato,
    riproduce l'audio ricostruito dalla STFT e opzionalmente mostra la STFT.
    """
    # Crea il dataloader (usa shuffle per randomizzare i batch)
    dataloader = get_dataloader(
        piano_dir="dataset/train/piano",
        violin_dir="dataset/train/violin",
        batch_size=batch_size,
        shuffle=True,
        stats_path="stats_stft_cqt.npz"
    )
    print(f"\nDataloader creato con {len(dataloader)} batch di dimensione {batch_size}")
          
    # Scegli un batch randomico
    num_batches = len(dataloader)
    batch_idx = random.randint(0, num_batches - 1)
    for i, (x, labels) in enumerate(dataloader):
        if i == batch_idx:
            break

    print(f"\n🎲 Batch randomico scelto: {batch_idx+1}/{num_batches}")
    print(f"Shape batch: {x.shape}, Shape labels: {labels.shape}")
    print(f"Labels batch: {labels.tolist()}")

    # Associa label a strumento
    for idx, label in enumerate(labels):
        strumento = "Piano" if label.item() == 0 else "Violin"
        print(f"Sample {idx}: Label={label.item()} ({strumento})")

    # Per ogni audio nel batch: ricostruisci e riproduci
    for idx in range(x.shape[0]):
        label = labels[idx].item()
        strumento = "Piano" if label == 0 else "Violin"
        # Prendi la prima finestra e il primo canale (shape: (2, T, F))
        stft_tensor = x[idx, 0]  # (2, T, F_stft+CQT)
        # Prendi solo la parte STFT (usa costante STFT_F)
        stft_part = stft_tensor[:, :, :513]  # (2, T, 513)
        # La funzione inverse_STFT si aspetta (2, T, F), quindi non serve permutare
        # Ricostruisci audio
        try:
            audio = inverse_STFT(stft_part)
            audio = audio / (torch.max(torch.abs(audio)) + 1e-8)
            print(f"\n🔊 Sample {idx} - {strumento}:")
            display(ipd.Audio(audio.cpu().numpy(), rate=22050))
        except Exception as e:
            print(f"Errore nella ricostruzione audio: {e}")

        # Plot STFT se richiesto
        if plot_stft_flag:
            print(f"STFT di Sample {idx} - {strumento}:")
            plot_stft(stft_part)

# Esempio di utilizzo:
play_batch(batch_size=16, plot_stft_flag=False)


Dataloader creato con 38 batch di dimensione 16

🎲 Batch randomico scelto: 38/38
Shape batch: torch.Size([16, 4, 2, 287, 597]), Shape labels: torch.Size([16])
Labels batch: [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1]
Sample 0: Label=0 (Piano)
Sample 1: Label=0 (Piano)
Sample 2: Label=0 (Piano)
Sample 3: Label=0 (Piano)
Sample 4: Label=0 (Piano)
Sample 5: Label=0 (Piano)
Sample 6: Label=0 (Piano)
Sample 7: Label=0 (Piano)
Sample 8: Label=1 (Violin)
Sample 9: Label=1 (Violin)
Sample 10: Label=1 (Violin)
Sample 11: Label=1 (Violin)
Sample 12: Label=1 (Violin)
Sample 13: Label=1 (Violin)
Sample 14: Label=1 (Violin)
Sample 15: Label=1 (Violin)

🔊 Sample 0 - Piano:



🔊 Sample 1 - Piano:



🔊 Sample 2 - Piano:



🔊 Sample 3 - Piano:



🔊 Sample 4 - Piano:



🔊 Sample 5 - Piano:



🔊 Sample 6 - Piano:



🔊 Sample 7 - Piano:



🔊 Sample 8 - Violin:



🔊 Sample 9 - Violin:



🔊 Sample 10 - Violin:



🔊 Sample 11 - Violin:



🔊 Sample 12 - Violin:



🔊 Sample 13 - Violin:



🔊 Sample 14 - Violin:



🔊 Sample 15 - Violin:


In [1]:
import torch
import numpy as np
import random
import IPython.display as ipd
from NEW_utilityFunctions import inverse_STFT, plot_stft
from NEW_dataloader import get_dataloader

loader = get_dataloader(batch_size=8, piano_dir="dataset/train/piano", violin_dir="dataset/train/violin", shuffle=False)
x, labels = next(iter(loader))

print(f"Shape batch: {x.shape}, Shape labels: {labels.shape}")
print(f"Labels batch: {labels.tolist()}")

Shape batch: torch.Size([8, 4, 2, 287, 597]), Shape labels: torch.Size([8])
Labels batch: [0, 0, 0, 0, 1, 1, 1, 1]
