In [ ]:
# Install the third-party packages this notebook needs (run this cell first):
#   - music2latent: provides the EncoderDecoder imported in the next cell
#   - soundfile: used to read the dataset audio and to write generated .wav files
!pip install music2latent soundfile

In [ ]:
## Import Libraries

import torch
from torch.utils.data import Dataset, DataLoader
import os
import soundfile as sf
from tqdm import tqdm
from music2latent import EncoderDecoder
import torch.nn as nn
import torch.nn.functional as F
import scipy.signal
import IPython.display as ipd

In [ ]:
## Download NSYNTH_GUITAR_MP3 dataset

In [ ]:
# Clone the helper repository, then run its download script for the
# NSYNTH_GUITAR_MP3 dataset. Later cells read audio from
# ./NSYNTH_GUITAR_MP3/nsynth-guitar-train and ./NSYNTH_GUITAR_MP3/nsynth-guitar-valid,
# so run this from the notebook's working directory.
!git clone https://github.com/SonyCSLParis/test-lfs.git
!bash ./test-lfs/download.sh NSYNTH_GUITAR_MP3

In [ ]:
## Define the Dataset Class
# Define the dataset class
class MusicLatentDataset(Dataset):
    """Dataset of standardized latent encodings of the audio files in a folder.

    On construction, every matching audio file below ``root_dir`` is encoded
    once with ``encoder``. The mean and standard deviation of each latent row
    (dimension 0 of the encoded 2-D latents) are computed over all encoded
    frames, the latents are standardized with them, and the statistics are
    written to ``mean_std.pth`` in the current working directory.

    Args:
        root_dir: Directory searched recursively for audio files.
        encoder: Object exposing ``encode(waveform)``; the result is expected
            to be a torch tensor that becomes 2-D after ``squeeze(0)``.
        extensions: File suffixes treated as audio (matched case-insensitively).

    Raises:
        RuntimeError: If no audio file is found under ``root_dir``.
    """

    def __init__(self, root_dir, encoder, extensions=(".wav", ".mp3", ".flac")):
        self.root_dir = root_dir
        self.encoder = encoder
        # Immutable, lower-cased copy: avoids a shared mutable default and
        # lets files such as "NOTE.WAV" match ".wav".
        self.extensions = tuple(ext.lower() for ext in extensions)

        self.audio_files = []
        for folder, _, names in os.walk(root_dir):
            self.audio_files.extend(
                os.path.join(folder, name)
                for name in names
                if name.lower().endswith(self.extensions)
            )
        if not self.audio_files:
            raise RuntimeError(
                f"No audio files with extensions {self.extensions} found under {root_dir!r}"
            )

        encoded = [
            self._encode_audio(path)
            for path in tqdm(self.audio_files, desc="Encoding Audio")
        ]

        # Concatenate all files along the frame axis (dim 1) so the statistics
        # are taken per latent row over every frame of every file.
        all_frames = torch.hstack(encoded)
        self.mean = all_frames.mean(dim=1)
        # Guard against a constant row, which would otherwise divide by zero
        # and fill the dataset with inf/NaN.
        self.std = all_frames.std(dim=1).clamp_min(1e-8)

        # unsqueeze(-1) lines the per-row statistics up against the frame axis.
        self.latent_data = [
            (latent - self.mean.unsqueeze(-1)) / self.std.unsqueeze(-1)
            for latent in encoded
        ]

        # Persist the statistics so latents can be un-normalized later without
        # re-encoding the dataset. (The in-memory tensors are already the
        # values written, so they are not read back from disk.)
        torch.save({'mean': self.mean, 'std': self.std}, 'mean_std.pth')

    def __len__(self):
        """Return the number of encoded audio files."""
        return len(self.latent_data)

    def __getitem__(self, idx):
        """Return the standardized latent of file ``idx`` as float32."""
        return self.latent_data[idx].float()

    def _encode_audio(self, filename):
        """Load one audio file and return its latent, cut to at most 32 frames."""
        waveform, sample_rate = sf.read(filename)
        waveform = waveform.astype('float32')

        # sf.read returns (frames, channels) for multi-channel files. The rest
        # of this method assumes one 1-D signal (squeeze(0) below only removes
        # a size-1 leading axis), so down-mix to mono first.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)

        # Resample to 44100 Hz if necessary.
        if sample_rate != 44100:
            num_samples = int(len(waveform) * 44100 / sample_rate)
            waveform = scipy.signal.resample(waveform, num_samples)

        latent = self.encoder.encode(waveform).float()

        # Drop the leading batch axis when it has size 1.
        latent = latent.squeeze(0)

        # Keep only the first 32 frames of the last axis.
        return latent[..., :32]

    def unnormalize(self, latent):
        """Invert the standardization applied in ``__init__``.

        Args:
            latent: Tensor whose second-to-last axis matches ``self.mean`` /
                ``self.std`` and whose last axis is frames; an optional
                leading batch axis is supported.

        Returns:
            ``latent * std + mean`` with the statistics applied per latent row.
        """
        # The statistics are per row, so they must broadcast over the trailing
        # frame axis, not be aligned with it.
        std = self.std.to(latent.device).unsqueeze(-1)
        mean = self.mean.to(latent.device).unsqueeze(-1)
        return latent * std + mean

In [ ]:
## Define the Diffusion U-Net Model

class DiffusionUnet(nn.Module):
    """1-D convolutional U-Net conditioned on a scalar time value.

    The encoder is a stack of ``num_layers`` stride-2 convolutions whose widths
    double from ``base_channels``. A width-preserving bottleneck convolution
    follows, and a time embedding is added to its output. The decoder is
    ``num_layers`` stride-2 transposed convolutions that halve the width each
    time, each fed the sum of the running features and the matching encoder
    output. A 1x1 convolution maps to ``out_channels``.

    The skip additions only line up when the input length is divisible by
    ``2 ** num_layers``.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the returned tensor.
        num_layers: Number of down-sampling (and up-sampling) stages.
        base_channels: Width of the first encoder stage.
        time_embedding_size: Hidden width of the time-embedding MLP.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        num_layers=3,
        base_channels=128,
        time_embedding_size=128,
    ):
        super().__init__()
        self.num_layers = num_layers

        # Encoder widths: base, 2*base, 4*base, ...
        self.channels = [base_channels * (2 ** level) for level in range(num_layers)]

        # Encoder: each stage halves the length (kernel 4, stride 2, padding 1).
        self.down_layers = nn.ModuleList()
        width = in_channels
        for stage_width in self.channels:
            self.down_layers.append(
                nn.Conv1d(width, stage_width, kernel_size=4, stride=2, padding=1)
            )
            width = stage_width

        # Width-preserving convolution at the coarsest resolution.
        self.bottleneck = nn.Conv1d(width, width, kernel_size=3, padding=1)

        # Maps the scalar time to one value per bottleneck channel.
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_embedding_size),
            nn.ReLU(),
            nn.Linear(time_embedding_size, width),
            nn.ReLU(),
        )

        # Decoder: each stage doubles the length and halves the width.
        self.up_layers = nn.ModuleList()
        for _ in range(num_layers):
            halved = width // 2
            self.up_layers.append(
                nn.ConvTranspose1d(width, halved, kernel_size=4, stride=2, padding=1)
            )
            width = halved

        # 1x1 projection to the requested number of output channels.
        self.final_layer = nn.Conv1d(width, out_channels, kernel_size=1)

    def forward(self, x, t):
        """Run the network.

        Args:
            x: Input of shape ``(batch, in_channels, length)``.
            t: One time value per batch element, shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, out_channels, length)``.
        """
        features = x.float()
        t = t.float()

        # Encoder pass; remember every stage's activation for the decoder.
        encoder_outputs = []
        for conv in self.down_layers:
            features = F.relu(conv(features))
            encoder_outputs.append(features)

        features = self.bottleneck(features)

        # (batch,) -> (batch, 1) -> MLP -> (batch, channels) -> (batch, channels, 1),
        # then broadcast over the length axis.
        time_embedding = self.time_mlp(t[..., None])[..., None]
        features = features + time_embedding

        # Decoder pass, consuming encoder activations from deepest to shallowest.
        for deconv, skip in zip(self.up_layers, reversed(encoder_outputs)):
            features = F.relu(deconv(features + skip))

        return self.final_layer(features)

In [ ]:
## Define the RectifiedFlows Class

# Define the RectifiedFlows class
class RectifiedFlows(torch.nn.Module):
    """Rectified-flow training objective.

    Data ``y`` and Gaussian noise are linearly interpolated as
    ``(1 - t) * y + t * noise``, and a model is trained to predict the
    velocity ``y - noise`` from the interpolated sample and ``t``.

    Args:
        sigma_data: Factor the data is multiplied by before noising; the
            returned data estimate is divided by it again.
        P_mean: Mean of the normal variable passed through a sigmoid to draw
            training times.
        P_std: Standard deviation of that normal variable.
    """

    def __init__(self,
                 sigma_data=1,
                 P_mean=0.,
                 P_std=1.
                 ):
        super().__init__()
        self.sigma_data = sigma_data
        self.P_std = P_std
        self.P_mean = P_mean

    def add_noise(self, x, noise, times):
        """Interpolate between ``x`` (at time 0) and ``noise`` (at time 1).

        Args:
            x: Clean samples; the first axis is the batch.
            noise: Noise tensor broadcastable with ``x``.
            times: A Python number (used for the whole batch), a 1-D tensor
                with one time per batch element, or a 2-D tensor whose second
                axis is matched against the last axis of ``x``.

        Returns:
            ``(1 - times) * x + times * noise``.
        """
        if isinstance(times, (int, float)):
            times = torch.full((x.shape[0],), float(times), dtype=x.dtype, device=x.device)
        if times.ndim == 1:
            # One time per sample: (batch,) -> (batch, 1, ..., 1).
            times = times.reshape([times.shape[0]] + [1] * (x.ndim - 1))
        elif times.ndim == 2:
            # Times varying along the last axis: (batch, n) -> (batch, 1, ..., n).
            times = times.reshape([times.shape[0]] + [1] * (x.ndim - 2) + [-1])
        return (1. - times) * x + times * noise

    def forward(self, model, y, sigma=None, return_loss=True,
                **model_kwargs):
        """Draw random times and noise, query ``model``, and score its output.

        Args:
            model: Callable ``model(noisy_samples, times, **model_kwargs)``
                returning a tensor shaped like ``y``.
            y: Batch of clean samples.
            sigma: Unused; kept for interface compatibility.
            return_loss: When true, also return the loss.
            **model_kwargs: Forwarded to ``model``.

        Returns:
            ``({'loss': loss}, estimate)`` when ``return_loss`` is true,
            otherwise just ``estimate``. ``estimate`` is the data estimate
            implied by the predicted velocity,
            ``(prediction + noise) / sigma_data``.
        """
        y = y.float() * self.sigma_data
        # Random draws stay in this order (times, then noise) so seeded runs
        # reproduce the same samples.
        times = torch.sigmoid(
            torch.randn(y.shape[0], dtype=torch.float32,
                        device=y.device) * self.P_std + self.P_mean
        )
        noises = torch.randn_like(y)
        noisy_samples = self.add_noise(y, noises, times)
        fv = model(
            noisy_samples,
            times,
            **model_kwargs
        )
        estimate = (fv + noises) / self.sigma_data
        if not return_loss:
            # Previously this path fell through and returned None.
            return estimate
        v = y - noises  # target velocity
        loss = torch.nn.functional.mse_loss(fv, v)
        return {'loss': loss}, estimate

In [ ]:
## Define the Inference Function
def inference(rectified_flows, net, latents_shape, num_steps):
    """Sample by integrating the network's output with fixed-size steps.

    Starts from Gaussian noise at time 1 and takes ``num_steps`` steps of size
    ``1 / num_steps``, each adding ``step * net(sample, time)`` to the sample
    and lowering the time by the same amount.

    Args:
        rectified_flows: Object providing ``sigma_data``.
        net: Module called as ``net(sample, times)``; its first parameter's
            device decides where sampling runs.
        latents_shape: Shape of the tensor to generate; the first entry is
            the batch size.
        num_steps: Number of integration steps.

    Returns:
        The final sample divided by ``rectified_flows.sigma_data``.
    """
    scale = rectified_flows.sigma_data
    dt = 1 / num_steps
    device = next(net.parameters()).device

    sample = torch.randn(latents_shape, dtype=torch.float32, device=device)
    # One time value per batch element, starting at the pure-noise end.
    t = torch.ones(latents_shape[0], dtype=torch.float32, device=sample.device)

    for _ in tqdm(range(num_steps), desc="Sampling", leave=False):
        velocity = net(sample, t)
        sample = sample + dt * velocity
        t = t - dt

    return sample / scale

In [ ]:
## Initialize the Encoder/Decoder and Datasets
# Initialize the encoder/decoder
encdec = EncoderDecoder()

# Folders holding the training and validation audio
audio_folder_train = "./NSYNTH_GUITAR_MP3/nsynth-guitar-train"
audio_folder_val = "./NSYNTH_GUITAR_MP3/nsynth-guitar-valid"

dataset = MusicLatentDataset(root_dir=audio_folder_train, encoder=encdec)
dataloader = DataLoader(dataset, batch_size=500, shuffle=True)

dataset_val = MusicLatentDataset(root_dir=audio_folder_val, encoder=encdec)

# MusicLatentDataset standardizes each instance with the statistics of its own
# files and writes them to mean_std.pth. The model is trained on latents
# standardized with the TRAINING statistics, so the validation latents are
# re-expressed in that same scale here; otherwise the validation loss is
# measured on differently scaled data.
train_mean, train_std = dataset.mean, dataset.std
dataset_val.latent_data = [
    (
        latent * dataset_val.std.unsqueeze(-1)
        + dataset_val.mean.unsqueeze(-1)
        - train_mean.unsqueeze(-1)
    ) / train_std.unsqueeze(-1)
    for latent in dataset_val.latent_data
]
dataset_val.mean, dataset_val.std = train_mean, train_std

# Building the validation set overwrote mean_std.pth with validation
# statistics; restore the training statistics on disk.
torch.save({'mean': train_mean, 'std': train_std}, 'mean_std.pth')

dataloader_val = DataLoader(dataset_val, batch_size=500, shuffle=False)

In [ ]:
## Initialize the Model
# Initialize the model
in_channels = 64  # latent dimension
out_channels = 64  # output dimension
# Prefer the GPU, but fall back to the CPU so this cell also runs on machines
# without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DiffusionUnet(in_channels, out_channels, num_layers=5).to(device)

In [ ]:
## Optionally Load a Pre-Trained Model
# Resume from the checkpoint written by the training cell when one exists;
# on a fresh run there is none, so keep the randomly initialized weights.
checkpoint_path = 'model_diffusion.pth'
if os.path.exists(checkpoint_path):
    # map_location keeps the load working when the checkpoint was saved on a
    # different device than the one the model currently lives on.
    state_dict = torch.load(checkpoint_path, map_location=next(model.parameters()).device)
    model.load_state_dict(state_dict)
else:
    print(f"No checkpoint found at {checkpoint_path}; using randomly initialized weights.")

In [ ]:
# Create the rectified-flow objective with its default settings, then apply
# .cuda() to it, mirroring how the model is placed.
rectified_flows = RectifiedFlows()
rectified_flows = rectified_flows.cuda()

In [ ]:
## Training and Validation loop

epochs = 5000
lr = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Send batches to whichever device the model lives on instead of assuming CUDA.
train_device = next(model.parameters()).device

for epoch in range(epochs):
    # ---- Training pass: one optimizer step per batch ----
    model.train()
    total_loss = 0
    for batch in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{epochs}"):
        y = batch.to(train_device).float()
        loss_dict, _ = rectified_flows(model, y)
        loss = loss_dict['loss']
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch {epoch + 1}, Loss: {avg_loss:.4f}')

    # Overwrite the checkpoint after every epoch so training can be resumed.
    torch.save(model.state_dict(), 'model_diffusion.pth')

    # ---- Validation pass: same objective, no gradients, no weight updates ----
    model.eval()
    val_total_loss = 0
    with torch.no_grad():
        for batch in tqdm(dataloader_val, desc=f"Validation Epoch {epoch + 1}/{epochs}", leave=False):
            y = batch.to(train_device).float()
            loss_dict, _ = rectified_flows(model, y)
            loss = loss_dict['loss']
            val_total_loss += loss.item()
    avg_val_loss = val_total_loss / len(dataloader_val)
    print(f'Epoch {epoch + 1}, Validation Loss: {avg_val_loss:.4f}')

In [ ]:
## Test the Model
# Sample new latents with the trained model, map them back to the encoder's
# scale, decode them to audio and write one .wav file per sample.
num_samples = 5
num_steps = 100  # Number of integration steps taken by inference()
os.makedirs('generated_audio', exist_ok=True)

# One latent per call: (batch_size=1, channels, seq_len), with the channel
# count taken from the dataset statistics.
latents_shape = (1, *dataset.mean.shape, 32)

with torch.no_grad():
    for i in range(num_samples):
        generated_latents = inference(rectified_flows, model, latents_shape, num_steps)
        generated_latents = generated_latents.squeeze(0)  # -> (channels, seq_len)

        # Undo the dataset standardization; [:, None] applies the per-channel
        # statistics across the sequence axis.
        std_device = dataset.std.to(generated_latents.device)
        mean_device = dataset.mean.to(generated_latents.device)
        generated_latents = generated_latents * std_device[:, None] + mean_device[:, None]

        # Decode to a waveform and write it out at 44.1 kHz.
        generated_latents_np = generated_latents.cpu().numpy()
        wv_rec = encdec.decode(generated_latents_np)
        output_filename = f'generated_audio/diffusion_sample_{i + 1}.wav'
        sf.write(output_filename, wv_rec[0], samplerate=44100)

In [ ]:
# Use IPython audio player to play generated audio samples.
# The file names must match the ones written by the generation cell
# (generated_audio/diffusion_sample_<n>.wav, numbered from 1).
for i in range(1, num_samples + 1):
    output_filename = f'generated_audio/diffusion_sample_{i}.wav'
    print(f"Playing {output_filename}")
    ipd.display(ipd.Audio(output_filename))