In [3]:
import torch
import torchaudio
import librosa
import numpy as np
import soundfile as sf
import os
import glob
import sys
from pathlib import Path
import random
from IPython.display import Audio, display


# --- Optional Google Drive / Colab setup (disabled) --------------------------
# Every line below that starts with "##" is code that has been switched off.
# It is only relevant when the project lives on Google Drive; `drive` is not
# imported anywhere in this cell, so it would have to be imported first
# (presumably from google.colab - TODO confirm) before re-enabling the mount.
##drive.mount('/content/drive')

# Root folder of the project. Swap in the commented Drive path when running
# from Google Drive; the empty string is the value currently in effect.
##DRIVE_DIR = "/content/drive/MyDrive/DeepLearning_StyleTransfer"
DRIVE_DIR = ""
# Disabled: make the project root importable.
##if DRIVE_DIR not in sys.path:
##    sys.path.append(DRIVE_DIR)

# Disabled: make the "models" sub-folder importable as well.
##models_path = os.path.join(DRIVE_DIR, "models")
##if models_path not in sys.path:
##    sys.path.append(models_path)


# --- Project-local modules ----------------------------------------------------
# These must be importable from the current sys.path (see the disabled block
# above for the Drive variant).
from content_encoder import ContentEncoder
from SimpleDecoder_TransformerOnly import Decoder
# Alternative decoder implementation, currently not used:
#from new_decoder import Decoder
# NOTE(review): `inverse_STFT` imported here is shadowed by the function of the
# same name defined in the next cell, so the notebook-local version is the one
# that actually runs.
from utilityFunctions import get_STFT, get_CQT, inverse_STFT, get_overlap_windows, sections2spectrogram, concat_stft_cqt
from dataloader import get_dataloader, diagnose_window_counts
from style_encoder import StyleEncoder

In [4]:
def inverse_STFT(stft_tensor, n_fft=1024, hop_length=256):
    """
    Rebuild a time-domain waveform from a real/imaginary STFT tensor.

    Args:
        stft_tensor: torch.Tensor of shape (2, time, freq); index 0 along the
            first axis is the real part, index 1 the imaginary part.
        n_fft: FFT size; also used as the length of the Hann synthesis window.
        hop_length: hop (in samples) between consecutive STFT frames.

    Returns:
        torch.Tensor of shape (samples,) holding the reconstructed waveform,
        on the same device as ``stft_tensor``.
    """
    # torch.istft wants frequency before time, so swap the last two axes.
    freq_major = stft_tensor.permute(0, 2, 1)  # (2, freq, frames)

    # Merge the two real channels into one complex spectrogram and add the
    # leading batch axis of size 1.
    spectrum = torch.complex(freq_major[0], freq_major[1]).unsqueeze(0)  # (1, freq, frames)

    # The window is created on the input's device so CPU and GPU tensors both work.
    synthesis_window = torch.hann_window(n_fft, device=stft_tensor.device)

    reconstructed = torch.istft(
        spectrum,
        n_fft=n_fft,
        hop_length=hop_length,
        window=synthesis_window,
        return_complex=False,
    )

    # Drop the batch axis again -> (samples,)
    return reconstructed.squeeze(0)


# function to generate class embeddings for style transfer
def generate_class_embeddings_from_dataloader(style_encoder, test_loader, device):
    """
    Build one class embedding per instrument ("piano" = label 0, "violin" = label 1).

    Batches are read from ``test_loader`` in order until a sample of each class
    has been seen. For every class, the first matching sample is passed through
    ``style_encoder`` and the second element of its return value is stored as
    that class's embedding. When the first batch already contains both classes
    only that batch is consumed.

    Args:
        style_encoder: module called as ``style_encoder(sections, labels)`` and
            returning a 2-tuple whose second element is the class embedding
            with a leading batch axis of size 1. It is switched to eval mode.
        test_loader: iterable of ``(sections, labels)`` tensor pairs; sections
            are indexed along their first (batch) axis.
        device: torch device the batches and label tensors are moved to.

    Returns:
        dict with keys "piano" and "violin"; each value is the class embedding
        with the batch axis squeezed out, moved to the CPU.

    Raises:
        ValueError: if the loader is exhausted (or empty) before both classes
            were found.
    """
    class_ids = {"piano": 0, "violin": 1}
    class_embeddings = {}

    style_encoder.eval()

    with torch.no_grad():
        for sections, labels in test_loader:
            sections = sections.to(device)  # (B, S, 2, T, F)
            labels = labels.to(device)      # (B,)

            print(f"📊 Generating class embeddings from batch shape: {sections.shape}")
            print(f"📋 Available labels: {labels}")

            for class_name, class_id in class_ids.items():
                if class_name in class_embeddings:
                    continue  # already obtained from an earlier batch

                matches = torch.where(labels == class_id)[0]
                if len(matches) == 0:
                    continue  # class absent from this batch; try the next one

                first = int(matches[0])
                # Slice instead of index so the batch axis is kept: (1, S, 2, T, F)
                class_sections = sections[first:first + 1]
                _, class_emb = style_encoder(
                    class_sections, torch.tensor([class_id], device=device)
                )
                class_embeddings[class_name] = class_emb.squeeze(0).cpu()
                print(f"✅ {class_name.capitalize()} class embedding generated: {class_emb.shape}")

            if len(class_embeddings) == len(class_ids):
                break  # both classes covered; no need to read further batches

    if len(class_embeddings) != len(class_ids):
        raise ValueError(f"Could not generate embeddings for both classes. Found: {list(class_embeddings.keys())}")

    return class_embeddings

In [5]:
# Input / output directory layout
#
# TEST DIR:
#   /content/drive/MyDrive/test_dataset
#       -> /piano
#       -> /violin
#
# OUTPUT DIR:
#   /content/drive/MyDrive/output
#       -> /from_piano_to_violin
#       -> /from_violin_to_piano
#       (created at the end of this cell)

# TEST_DIR = os.path.join(DRIVE_DIR, "test_dataset")
# OUTPUT_DIR = os.path.join(DRIVE_DIR, "output")
TEST_DIR = "dataset/test"
OUTPUT_DIR = "style_transfer_output"

# checkpoint_path = os.path.join(DRIVE_DIR, "/checkpoints/epoch100_simpleDecoder.pth")
# Built with os.path.join so the separator is right on every OS; the previous
# literal used a backslash, which is an invalid "\S" escape sequence and only
# resolves to a real file on Windows.
checkpoint_path = os.path.join("checkpoints", "SIMPLEDECODERcheckpoint_epoch_100.pth")

# Configurations
SAMPLE_RATE = 22050
N_FFT = 1024
HOP_LENGTH = 256
WIN_LENGTH = 1024
N_BINS = 84
WINDOW_SIZE = 287
OVERLAP_PERCENTAGE = 0.3
# Number of frames shared by two consecutive sections.
OVERLAP_FRAMES = int(WINDOW_SIZE * OVERLAP_PERCENTAGE)
TRANSFORMER_DIM = 256
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SECTION_LENGTH = 1.0

# Build the three networks with the hyper-parameters below; they must match
# the ones used when the checkpoint was written or load_state_dict will fail.
content_encoder = ContentEncoder(
    cnn_out_dim=TRANSFORMER_DIM,
    transformer_dim=TRANSFORMER_DIM,
    num_heads=4,
    num_layers=4,
    # channels_list=[16, 32, 64, 128, 256]
    channels_list=[32, 64, 128, 256, 512, 512],  # Updated channels list
).to(DEVICE)

decoder = Decoder(
    d_model=TRANSFORMER_DIM,
    nhead=4,
    num_layers=4,
).to(DEVICE)

style_encoder = StyleEncoder(
    cnn_out_dim=TRANSFORMER_DIM,
    transformer_dim=TRANSFORMER_DIM,
    num_heads=4,
    num_layers=4,
).to(DEVICE)

if os.path.exists(checkpoint_path):
    print(f"📂 Loading checkpoint: {checkpoint_path}")
    # NOTE(review): torch.load unpickles arbitrary objects. Only load trusted
    # checkpoints; if this file holds nothing but state dicts, consider passing
    # weights_only=True (this is what the FutureWarning in the output is about).
    checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

    # Best-effort restore: on any failure the notebook keeps running with
    # whatever weights the models currently hold. Be aware that a failure
    # part-way through leaves earlier models restored and later ones random.
    try:
        content_encoder.load_state_dict(checkpoint['content_encoder'])
        style_encoder.load_state_dict(checkpoint['style_encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        print("✅ All models loaded successfully!")
    except Exception as e:
        print(f"⚠️ Error loading checkpoint: {e}")
        print("🔧 Using randomly initialized models...")
else:
    print(f"⚠️ Checkpoint not found: {checkpoint_path}")
    print("🔧 Using randomly initialized models...")

# Build the test dataloader from the two per-instrument folders.
piano_dir = os.path.join(TEST_DIR, "piano")
violin_dir = os.path.join(TEST_DIR, "violin")
dataloader = get_dataloader(piano_dir, violin_dir, batch_size=16, shuffle=False)


# Make output dir (one sub-folder per transfer direction)
p2v_dir = os.path.join(OUTPUT_DIR, "from_piano_to_violin")
v2p_dir = os.path.join(OUTPUT_DIR, "from_violin_to_piano")
os.makedirs(p2v_dir, exist_ok=True)
os.makedirs(v2p_dir, exist_ok=True)

📂 Loading checkpoint: checkpoints\SIMPLEDECODERcheckpoint_epoch_100.pth


  checkpoint = torch.load(checkpoint_path, map_location=DEVICE)


✅ All models loaded successfully!


In [6]:
# Compute the per-instrument class embeddings from the test dataloader.
class_embeddings = generate_class_embeddings_from_dataloader(
    style_encoder=style_encoder,
    test_loader=dataloader,
    device=DEVICE,
)

📊 Generating class embeddings from batch shape: torch.Size([16, 4, 2, 287, 597])
📋 Available labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')
✅ Piano class embedding generated: torch.Size([1, 256])
✅ Violin class embedding generated: torch.Size([1, 256])


  return torch._transformer_encoder_layer_fwd(


In [7]:
# Inference: decode every test clip with the class embedding of the *other*
# instrument and write the result to disk as a WAV file.
content_encoder.eval()
decoder.eval()

# Lookup table indexed by class label: row 0 = piano, row 1 = violin.
# Built once here instead of re-stacking embeddings for every batch.
class_emb_table = torch.stack(
    [class_embeddings["piano"], class_embeddings["violin"]]
).to(DEVICE)  # (2, emb_dim)

with torch.no_grad():
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        # inputs: (B, S, 2, T, F), labels: (B,)
        inputs = inputs.to(DEVICE)  # move to GPU
        labels = labels.to(DEVICE)

        batch_size = inputs.size(0)
        num_sections = inputs.size(1)    # S: sections per clip
        section_frames = inputs.size(3)  # T: time frames per section
        # The time axis is axis 3 of the 5-D batch. The earlier code read
        # inputs[i].size(3), which on the 4-D per-sample tensor is the
        # frequency axis and produced a wrong total length.
        original_time = (num_sections - 1) * (section_frames - OVERLAP_FRAMES) + section_frames

        # Extract content embeddings
        content_embeddings = content_encoder(inputs)

        # Pick the opposite class for each sample (0 -> 1, 1 -> 0) and look up
        # its embedding in one indexing operation.
        inverse_labels = 1 - labels
        class_emb = class_emb_table[inverse_labels.long()]  # (B, emb_dim)

        # Reconstruct with the opposite style
        output_stfts = decoder(content_embeddings, class_emb, target_length=content_embeddings.size(1))

        # One host transfer for all labels instead of one .item() per sample.
        source_labels = labels.tolist()

        # Loop over the samples of the batch
        for i in range(batch_size):
            label = source_labels[i]  # 0 = piano, 1 = violin
            source_class = "piano" if label == 0 else "violin"
            target_class = "violin" if label == 0 else "piano"
            save_dir = p2v_dir if label == 0 else v2p_dir

            stft_output = output_stfts[i]  # (S, 2, T, F)
            # Overlap-merge the sections back into one spectrogram.
            full_stft = sections2spectrogram(stft_output, original_size=original_time, overlap=OVERLAP_FRAMES)

            # Use the notebook-wide STFT settings rather than implicit defaults.
            waveform = inverse_STFT(full_stft, n_fft=N_FFT, hop_length=HOP_LENGTH)
            filename = f"{source_class}_to_{target_class}_sample{batch_idx}_{i}.wav"
            output_path = os.path.join(save_dir, filename)
            sf.write(output_path, waveform.cpu().numpy(), SAMPLE_RATE)
            print(f"Salvato: {output_path}")

Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_0.wav
Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_1.wav
Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_2.wav
Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_3.wav
Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_4.wav
Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_5.wav
Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_6.wav
Salvato: style_transfer_output\from_piano_to_violin\piano_to_violin_sample0_7.wav
Salvato: style_transfer_output\from_violin_to_piano\violin_to_piano_sample0_8.wav
Salvato: style_transfer_output\from_violin_to_piano\violin_to_piano_sample0_9.wav
Salvato: style_transfer_output\from_violin_to_piano\violin_to_piano_sample0_10.wav
Salvato: style_transfer_output\from_violin_to_piano\violin_to_piano_sample0_11.wav
Salvato: style