In [1]:
import numpy as np
# Install necessary libraries (run this cell first)
!pip install music2latent soundfile

Collecting music2latent
  Downloading music2latent-0.1.6-py3-none-any.whl.metadata (2.9 kB)
Downloading music2latent-0.1.6-py3-none-any.whl (19 kB)
Installing collected packages: music2latent
Successfully installed music2latent-0.1.6


## Import Libraries

In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
import soundfile as sf
from tqdm import tqdm
from music2latent import EncoderDecoder
import torch.nn as nn
import torch.nn.functional as F
import scipy.signal
import IPython.display as ipd

## Set Device

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


## Download NSYNTH_GUITAR_MP3 dataset

In [6]:
!git clone https://github.com/SonyCSLParis/test-lfs.git
!bash ./test-lfs/download.sh NSYNTH_GUITAR_MP3

Cloning into 'test-lfs'...
remote: Enumerating objects: 42, done.[K
remote: Counting objects: 100% (42/42), done.[K
remote: Compressing objects: 100% (34/34), done.[K
remote: Total 42 (delta 5), reused 40 (delta 3), pack-reused 0 (from 0)[K
Unpacking objects: 100% (42/42), 5.92 KiB | 466.00 KiB/s, done.
--2024-10-21 16:31:02--  https://media.githubusercontent.com/media/SonyCSLParis/test-lfs/refs/heads/master/NSYNTH_GUITAR_MP3.zip
Resolving media.githubusercontent.com (media.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to media.githubusercontent.com (media.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 334999208 (319M) [application/zip]
Saving to: ‘NSYNTH_GUITAR_MP3.zip’


2024-10-21 16:31:24 (58.8 MB/s) - ‘NSYNTH_GUITAR_MP3.zip’ saved [334999208/334999208]

Fix archive (-F) - assume mostly intact archive
Zip entry offsets do not need adjusting
 copying:

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)




  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_004-071-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_036-102-075.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_005-030-050.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_006-042-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_037-043-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_synthetic_008-047-075.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_032-096-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_034-013-025.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_008-053-127.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_electronic_003-071-100.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-guitar-train/guitar_acoustic_008-023-100.mp3  
  inflating: NSYNTH_GUITAR_MP3/nsynth-

## Define the Dataset Class

In [7]:
# Define the dataset class
class MusicLatentDataset(Dataset):
    """In-memory dataset of standardized latent encodings of audio files.

    Every matching file under ``root_dir`` is encoded once, at construction
    time, cropped to at most the first 32 latent frames and standardized per
    latent channel with statistics computed over all encoded files.

    Args:
        root_dir: Directory searched recursively for audio files.
        encoder: Object exposing ``encode(waveform)``.
        extensions: File-name suffixes (case-sensitive) identifying audio files.
        max_samples: Maximum number of files to encode; a negative value means
            "use every file found".

    Attributes:
        audio_files: Sorted list of the file paths that were encoded.
        latent_data: List of standardized latents, one ``(channels, frames)``
            tensor per file.
        mean, std: Per-channel statistics, shape ``(channels,)``. They are also
            written to ``mean_std.pth`` in the current working directory.

    Raises:
        RuntimeError: If no audio file is selected.
    """

    def __init__(self, root_dir, encoder, extensions=(".wav", ".mp3", ".flac"),
                 max_samples=-1):

        self.root_dir = root_dir
        self.encoder = encoder
        self.extensions = tuple(extensions)

        # Walk through all subfolders to gather audio files. The result is
        # sorted because os.walk order is unspecified, which would otherwise
        # make the `max_samples` subset differ between machines/runs.
        self.audio_files = sorted(
            os.path.join(root, file)
            for root, _, files in os.walk(root_dir)
            for file in files
            if any(file.endswith(ext) for ext in self.extensions)
        )

        if max_samples >= 0:
            self.audio_files = self.audio_files[:max_samples]

        if not self.audio_files:
            raise RuntimeError(
                f"No audio files with extensions {self.extensions} selected from {root_dir!r}"
            )

        encoded = [self._encode_audio(file) for file in tqdm(self.audio_files, desc="Encoding Audio")]

        # Per-channel statistics over every frame of every file:
        # hstack concatenates the (channels, frames) latents along the frame axis.
        stacked = torch.hstack(encoded)
        self.mean = stacked.mean(dim=1)
        # Clamp so that a constant channel does not produce a division by zero.
        self.std = stacked.std(dim=1).clamp_min(1e-8)

        # Standardize: (channels, 1) statistics broadcast over the frame axis.
        mean_col = self.mean.unsqueeze(-1)
        std_col = self.std.unsqueeze(-1)
        self.latent_data = [(latent - mean_col) / std_col for latent in encoded]

        # Persist the statistics so latents can be un-normalized later.
        torch.save({'mean': self.mean, 'std': self.std}, 'mean_std.pth')

    def __len__(self):
        """Return the number of encoded files."""
        return len(self.latent_data)

    def __getitem__(self, idx):
        """Return the standardized latent of file ``idx`` as float32."""
        latent = self.latent_data[idx]
        return latent.float()

    def _encode_audio(self, filename):
        """Load ``filename``, bring it to 44100 Hz and return its cropped latent."""
        # Load the audio file using soundfile
        waveform, sample_rate = sf.read(filename)

        # Ensure it's in float32 precision
        waveform = waveform.astype('float32')

        # soundfile returns (frames, channels) for multi-channel files, while
        # the rest of this class assumes one 2-D latent per file. Downmix so the
        # encoder always receives a 1-D mono signal.
        # NOTE(review): confirm that downmixing is the desired policy for
        # stereo data; the dataset used in this notebook is not affected if it
        # is mono.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)

        # Resample to 44100 Hz if necessary
        if sample_rate != 44100:
            num_samples = int(len(waveform) * 44100 / sample_rate)
            waveform = scipy.signal.resample(waveform, num_samples)
            sample_rate = 44100

        # Encode, force float32 and drop the leading batch dimension (if it is 1)
        latent = self.encoder.encode(waveform).float().squeeze(0)

        # Keep at most the first 32 latent frames
        return latent[..., :32]

    def unnormalize(self, latent):
        """Invert the per-channel standardization.

        Args:
            latent: Tensor whose second-to-last axis is the channel axis, e.g.
                ``(channels, frames)`` or ``(batch, channels, frames)``.

        Returns:
            Tensor of the same shape in the encoder's original latent scale.
        """
        # Statistics are (channels,); add a trailing axis so they broadcast
        # over frames instead of being matched against the frame axis.
        mean = self.mean.to(latent.device).unsqueeze(-1)
        std = self.std.to(latent.device).unsqueeze(-1)
        return latent * std + mean

## Define the Diffusion U-Net Model

In [8]:
class DiffusionUnet(nn.Module):
    """1-D convolutional U-Net conditioned on a scalar time value.

    The encoder is a stack of ``num_layers`` stride-2 convolutions whose width
    doubles at every level (``base_channels``, ``2*base_channels``, ...). A
    same-width bottleneck convolution follows, the time embedding is added to
    its output, and a mirrored stack of stride-2 transposed convolutions halves
    the width at every level while adding the matching encoder activation as a
    skip connection. A 1x1 convolution maps to ``out_channels``.

    Because every level halves/doubles the sequence length, the input length
    must be divisible by ``2 ** num_layers`` for the skip additions to line up.

    Args:
        in_channels: Channels of the input tensor ``x``.
        out_channels: Channels of the returned tensor.
        num_layers: Number of down (and up) sampling levels.
        base_channels: Width of the first encoder level.
        time_embedding_size: Hidden width of the time-embedding MLP.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_layers=3,
        base_channels=128,
        time_embedding_size=128,
    ):
        super().__init__()
        self.num_layers = num_layers

        # Width of each encoder level, kept for reference
        self.channels = [base_channels * (2 ** level) for level in range(num_layers)]

        # Encoder: each convolution halves the length and moves to the next width
        self.down_layers = nn.ModuleList()
        width = in_channels
        for next_width in self.channels:
            self.down_layers.append(
                nn.Conv1d(width, next_width, kernel_size=4, stride=2, padding=1)
            )
            width = next_width

        # Bottleneck keeps both width and length
        self.bottleneck = nn.Conv1d(width, width, kernel_size=3, padding=1)

        # Scalar time -> one value per bottleneck channel
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_embedding_size),
            nn.ReLU(),
            nn.Linear(time_embedding_size, width),
            nn.ReLU(),
        )

        # Decoder: each transposed convolution doubles the length and halves the width
        self.up_layers = nn.ModuleList()
        for _ in range(num_layers):
            halved = width // 2
            self.up_layers.append(
                nn.ConvTranspose1d(width, halved, kernel_size=4, stride=2, padding=1)
            )
            width = halved

        # Pointwise projection to the requested output channels
        self.final_layer = nn.Conv1d(width, out_channels, kernel_size=1)

    def forward(self, x, t):
        """Run the network.

        Args:
            x: Input of shape ``(batch, in_channels, length)``.
            t: One time value per batch element, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, out_channels, length)``.
        """
        # Work in float32 regardless of the incoming dtypes
        h = x.float()
        t = t.float()

        # Encoder, remembering every activation for the skip connections
        encoder_activations = []
        for down in self.down_layers:
            h = F.relu(down(h))
            encoder_activations.append(h)

        h = self.bottleneck(h)

        # (batch,) -> (batch, 1) -> (batch, channels) -> (batch, channels, 1),
        # then broadcast over the length axis
        time_embedding = self.time_mlp(t.unsqueeze(-1)).unsqueeze(-1)
        h = h + time_embedding

        # Decoder: deepest encoder activation is consumed first
        for up, skip in zip(self.up_layers, reversed(encoder_activations)):
            h = F.relu(up(h + skip))

        return self.final_layer(h)

## Define the RectifiedFlows Class

In [9]:
# Define the RectifiedFlows class
class RectifiedFlows(torch.nn.Module):
    """Rectified-flow training objective.

    Clean data ``y`` and Gaussian noise are linearly interpolated,
    ``x_t = (1 - t) * y + t * noise``, and the wrapped model is trained to
    predict the velocity ``y - noise`` from ``(x_t, t)``.

    Args:
        sigma_data: Factor the clean data is multiplied by before noising; it
            is divided out again in the returned reconstruction.
        P_mean: Mean of the Gaussian that is passed through a sigmoid to draw
            the training times.
        P_std: Standard deviation of that Gaussian.
    """

    def __init__(self,
                 sigma_data=1,
                 P_mean=0.,
                 P_std=1.
                 ):
        super().__init__()
        self.sigma_data = sigma_data
        self.P_std = P_std
        self.P_mean = P_mean

    def add_noise(self, x, noise, times):
        """Interpolate between ``x`` (at time 0) and ``noise`` (at time 1).

        Args:
            x: Clean batch, first axis is the batch axis.
            noise: Noise tensor broadcastable with ``x``.
            times: A Python number applied to the whole batch, a 1-D tensor
                with one time per batch element, or a 2-D tensor whose second
                axis is mapped onto the last axis of ``x``.

        Returns:
            ``(1 - times) * x + times * noise``.
        """
        if isinstance(times, int):
            times = float(times)
        if isinstance(times, float):
            times = torch.ones((x.shape[0],), dtype=x.dtype, device=x.device) * times
        if times.ndim == 1:
            # (batch,) -> (batch, 1, ..., 1)
            times = times.reshape([times.shape[0]] + (x.ndim - 1) * [1])
        elif times.ndim == 2:
            # (batch, n) -> (batch, 1, ..., 1, n)
            times = times.reshape([times.shape[0]] + [1] * (x.ndim - 2) + [-1])
        return (1. - times) * x + times * noise

    def forward(self, model, y, sigma=None, return_loss=True,
                **model_kwargs):
        """Draw random times and noise, run ``model`` and score its velocity.

        Args:
            model: Callable ``model(noisy_samples, times, **model_kwargs)``
                returning a tensor shaped like ``y``.
            y: Clean batch.
            sigma: Accepted for interface compatibility; currently unused.
            return_loss: When true, return ``({'loss': mse}, reconstruction)``.
                When false, return only the reconstruction (previously this
                case fell through and returned ``None``).
            **model_kwargs: Forwarded to ``model``.

        Returns:
            See ``return_loss``. The reconstruction is the clean-data estimate
            implied by the predicted velocity, ``(prediction + noise) / sigma_data``.
        """
        y = y.float() * self.sigma_data  # Ensure y is float32
        # One time in (0, 1) per batch element
        times = torch.sigmoid(
            torch.randn(y.shape[0], dtype=torch.float32,
                        device=y.device) * self.P_std + self.P_mean
        )
        noises = torch.randn_like(y)
        v = y - noises  # regression target
        noisy_samples = self.add_noise(y, noises, times)
        fv = model(
            noisy_samples,
            times,
            **model_kwargs
        )
        reconstruction = (fv + noises) / self.sigma_data
        if not return_loss:
            return reconstruction
        loss = nn.MSELoss()(v, fv)
        return {'loss': loss}, reconstruction

## Define the Inference Function

In [10]:
def inference(rectified_flows, net, latents_shape, num_steps):
    """Sample latents by Euler-integrating the learned velocity field.

    Sampling starts from Gaussian noise at time 1 and takes ``num_steps``
    equal steps towards time 0, each time adding ``step_size * net(x, t)``.

    Args:
        rectified_flows: Object providing ``sigma_data``; the final sample is
            divided by it.
        net: Velocity model called as ``net(sample, times)``. The sample is
            created on the device of its first parameter.
        latents_shape: Shape of the tensor to generate; the first entry is the
            batch size.
        num_steps: Number of integration steps, at least 1.

    Returns:
        Generated tensor of shape ``latents_shape`` (float32).

    Raises:
        ValueError: If ``num_steps`` is smaller than 1.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    sigma_data = rectified_flows.sigma_data
    dtype = torch.float32  # Ensuring float32
    device = next(net.parameters()).device
    step_size = 1 / num_steps

    # Sampling never needs gradients; disabling them here keeps memory flat
    # even when the caller forgets to wrap the call in torch.no_grad().
    with torch.no_grad():
        current_sample = torch.randn(latents_shape, dtype=dtype, device=device)
        for i in tqdm(range(num_steps), desc="Sampling", leave=False):
            # Derive the time from the step index instead of repeatedly
            # subtracting step_size, which would accumulate rounding error.
            times = torch.full((latents_shape[0],), 1.0 - i * step_size,
                               dtype=dtype, device=device)
            v = net(current_sample, times)
            current_sample = current_sample + step_size * v
    return current_sample / sigma_data

## Initialize the Encoder/Decoder and Datasets

In [12]:
# Initialize the encoder/decoder
encdec = EncoderDecoder()

# Locations of the training and validation splits
audio_folder_train = "./NSYNTH_GUITAR_MP3/nsynth-guitar-train"
audio_folder_val = "./NSYNTH_GUITAR_MP3/nsynth-guitar-valid"

# Every file is encoded inside the constructor, so these two calls are the
# expensive part of the notebook.
dataset = MusicLatentDataset(root_dir=audio_folder_train, encoder=encdec,
                             max_samples=-1)

dataset_val = MusicLatentDataset(root_dir=audio_folder_val, encoder=encdec,
                                 max_samples=-1)

# Each MusicLatentDataset constructor writes its statistics to 'mean_std.pth',
# so at this point the file holds the *validation* statistics. Generation
# un-normalizes with the training statistics (dataset.mean / dataset.std), so
# write those back to keep the file consistent with the trained model.
torch.save({'mean': dataset.mean, 'std': dataset.std}, 'mean_std.pth')

# NOTE(review): dataset_val is standardized with its own mean/std rather than
# the training statistics, so the validation loss is measured in a slightly
# different space than the training loss.


Encoding Audio: 100%|██████████| 32690/32690 [09:04<00:00, 60.01it/s]
Encoding Audio: 100%|██████████| 2081/2081 [00:34<00:00, 61.19it/s]


In [13]:
# Training batches are reshuffled every epoch; validation batches keep a fixed order.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=500,
    shuffle=True,
)
dataloader_val = DataLoader(
    dataset=dataset_val,
    batch_size=500,
    shuffle=False,
)

## Initialize the Model

In [14]:
# Build the U-Net: its input and output widths both equal the latent channel count
in_channels = 64  # latent dimension
out_channels = 64  # output dimension
model = DiffusionUnet(
    in_channels=in_channels,
    out_channels=out_channels,
    num_layers=5,
)
model = model.to(device)

## (Optional) Load Pretrained Weights
### If available, you can load pretrained weights for the U-Net model.

In [None]:
# Optional warm start: load weights only when a checkpoint from a previous run
# exists, so "Run All" on a fresh checkout does not stop here.
checkpoint_path = 'model_diffusion.pth'
if os.path.exists(checkpoint_path):
    # map_location lets a checkpoint written on a GPU machine load on a
    # CPU-only one (and vice versa).
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    print(f"No checkpoint found at {checkpoint_path}; keeping randomly initialized weights.")

## Initialize RectifiedFlows

In [15]:
# Create the rectified-flow objective with its default settings, then place it
# on the same device as the model
rectified_flows = RectifiedFlows()
rectified_flows = rectified_flows.to(device)

## Training and Validation loop

In [23]:
from tqdm import tqdm

# Hyper-parameters and optimizer
epochs = 1000
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Placeholder values shown in the progress bar until the first epoch finishes
avg_loss = -1
avg_val_loss = -1

# One progress-bar tick per epoch
progress_bar = tqdm(
    range(epochs),
    desc=f"Epoch 1/{epochs}, Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}",
)

for epoch in progress_bar:
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    total_loss = 0
    for batch in dataloader:
        y = batch.to(device).float()  # Ensure y is float32
        optimizer.zero_grad()
        loss_dict = rectified_flows(model, y)[0]
        loss = loss_dict['loss']
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Mean of the per-batch losses
    avg_loss = total_loss / len(dataloader)

    # ---- Validation pass: same objective, no gradients, no updates ----
    model.eval()
    val_total_loss = 0
    with torch.no_grad():
        for batch in dataloader_val:
            y = batch.to(device).float()
            loss_dict = rectified_flows(model, y)[0]
            loss = loss_dict['loss']
            val_total_loss += loss.item()

    avg_val_loss = val_total_loss / len(dataloader_val)

    # Show the numbers of the epoch that just finished
    progress_bar.set_description(f"Epoch {epoch + 1}/{epochs}, Loss: {avg_loss:.4f}, Validation Loss: {avg_val_loss:.4f}")

    # Overwrite the checkpoint after every epoch
    torch.save(model.state_dict(), 'model_diffusion.pth')


Epoch 1000/1000, Loss: 0.7298, Validation Loss: 0.9340: 100%|██████████| 1000/1000 [23:13<00:00,  1.39s/it]


## Audio Generation

In [24]:
# Generate samples
num_samples = 5
num_steps = 100  # Number of diffusion steps
os.makedirs('generated_audio', exist_ok=True)

# (batch_size=1, channels, seq_len): the channel count is read from the
# training statistics, the sequence length is 32 frames
latents_shape = (1,) + tuple(dataset.mean.shape) + (32,)

with torch.no_grad():
    for i in range(num_samples):
        print(f"Generating sample {i+1}/{num_samples}")
        generated_latents = inference(rectified_flows, model, latents_shape, num_steps)

        # Training statistics on the device the sample lives on
        mean_device = dataset.mean.to(generated_latents.device)
        std_device = dataset.std.to(generated_latents.device)

        # Drop the batch axis, then undo the per-channel standardization;
        # the trailing axis makes the (channels,) statistics broadcast over frames
        generated_latents = generated_latents.squeeze(0)
        generated_latents = generated_latents * std_device.unsqueeze(-1) + mean_device.unsqueeze(-1)

        # Decode the [channels, seq_len] latents into audio
        generated_latents_np = generated_latents.cpu().numpy()
        wv_rec = encdec.decode(generated_latents_np)

        # Write the first decoded waveform to disk
        output_filename = f'generated_audio/diffusion_sample_{i + 1}.wav'
        sf.write(output_filename, wv_rec[0], samplerate=44100)

Generating sample 1/5


                                                           

Generating sample 2/5


                                                           

Generating sample 3/5


                                                           

Generating sample 4/5





                                                           

Generating sample 5/5


                                                           

## Play Generated Audio


In [25]:
# Show one inline audio player per generated file (files are numbered from 1).
for i in range(1, num_samples + 1):
    output_filename = f'generated_audio/diffusion_sample_{i}.wav'
    print(f"Playing {output_filename}")
    player = ipd.Audio(output_filename)
    ipd.display(player)

Playing generated_audio/diffusion_sample_1.wav


Playing generated_audio/diffusion_sample_2.wav


Playing generated_audio/diffusion_sample_3.wav


Playing generated_audio/diffusion_sample_4.wav


Playing generated_audio/diffusion_sample_5.wav
