<a href="https://colab.research.google.com/github/eloimoliner/unconditional-diff-STFT/blob/main/colab/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Unconditional synthesis of music using an STFT-based diffusion model

With this notebook, you can generate music unconditionally (i.e., without any input prompt or conditioning signal) using an STFT-based diffusion model.

I provide three pretrained models, each trained on a different type of instrumentation:
  - piano
  - strings
  - orchestra

### Instructions for running:

* Make sure you are using a GPU runtime: click __Runtime >> Change runtime type >> GPU__
* Press ▶️ on the left of each of the cells
* View the code: Double-click any of the cells
* Hide the code: Double-click the right side of the cell


In [None]:
#@title #Setup environment

#@markdown Execute this cell to download the code and weights 
! git clone https://github.com/eloimoliner/unconditional-diff-STFT.git
%cd unconditional-diff-STFT

! mkdir experiments


!pip install omegaconf
! pip install dotmap


In [None]:
#@title #Imports and others

#@markdown

import soundfile as sf
import os
import logging
from tqdm import tqdm
import torch
import numpy as npp
import dataset_loader
from omegaconf import OmegaConf
from omegaconf.omegaconf import open_dict
from torch.utils.data import DataLoader
import numpy as np

from getters import get_sde
from unet_STFT import Unet2d
import scipy.signal

import yaml
from pathlib import Path
from dotmap import DotMap

# Read the experiment configuration that ships with the repository and expose
# its nested keys as attributes (e.g. args.inference.stereo, args.sample_rate).
args = DotMap(yaml.safe_load(Path('conf/conf.yaml').read_text()))

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Absolute path of the current working directory; the experiment folders for
# the downloaded weights are created beneath it.
dirname = os.getcwd()



class SDESampling_context:
    """
    DDPM-like discretization of the SDE as in https://arxiv.org/abs/2107.00630,
    with context imputation and optional stereo generation.

    At every reverse step, the samples where ``mask`` is 0 are overwritten by a
    freshly noised copy of the context. This keeps the generated region (mask 1)
    consistent with the given context. In stereo mode, a single trajectory is
    denoised from the mean of the two contexts until step ``split``. From there,
    the left and right channels are finished separately, each imputed with its
    own context.
    """

    def __init__(self, model, sde):
        # model(audio, sigma) -> tensor used by the update rule below
        # (presumably a noise estimate; TODO confirm against Unet2d).
        self.model = model
        # sde must expose t_min, t_max, sigma(t) and mean(t).
        self.sde = sde

    def create_schedules(self, nb_steps, stereo_split):
        """Build the noise schedules and the stereo branching step.

        Args:
            nb_steps: number of discretization steps. The schedules have
                ``nb_steps + 1`` entries, indexed 0..nb_steps.
            stereo_split: position of the stereo branch point as a fraction
                of the schedule (e.g. 0.05).

        Returns:
            ``(sigma_schedule, m_schedule, split)``. ``split`` is the reverse-loop
            index at which the stereo channels separate. It is clamped to
            ``[1, max(nb_steps - 1, 1)]`` so that the loop in ``predict``
            (which visits n = nb_steps-1 .. 1) can actually reach it.
        """
        t_schedule = torch.arange(0, nb_steps + 1) / nb_steps
        t_schedule = (self.sde.t_max - self.sde.t_min) * \
            t_schedule + self.sde.t_min
        # NOTE(review): stereo_split is mapped into [t_min, t_max] and then
        # multiplied by nb_steps. This equals stereo_split * nb_steps only when
        # t_min == 0 and t_max == 1. It is kept as-is to preserve existing results.
        split = (self.sde.t_max - self.sde.t_min) * \
            stereo_split + self.sde.t_min
        split = int(split * nb_steps)
        # A small stereo_split * nb_steps truncates to 0, which the reverse loop
        # never visits. That used to leave the stereo state unset and crash.
        split = min(max(split, 1), max(nb_steps - 1, 1))
        sigma_schedule = self.sde.sigma(t_schedule)
        m_schedule = self.sde.mean(t_schedule)

        return sigma_schedule, m_schedule, split

    def _reverse_step(self, audio, n, sigma, m):
        """One stochastic update from noise level sigma[n] down to sigma[n-1]."""
        audio = m[n-1] / m[n] * audio + (m[n] / m[n-1] * (sigma[n-1])**2 / sigma[n] - m[n-1] / m[n] * sigma[n]) * \
            self.model(audio, sigma[n])
        noise = torch.randn_like(audio)
        audio += sigma[n-1]*(1 - (sigma[n-1]*m[n] /
                                  (sigma[n]*m[n-1]))**2)**0.5 * noise
        return audio

    @staticmethod
    def _impute(audio, context, mask, m_level, sigma_level):
        """Replace the mask == 0 region of ``audio`` by the context noised to the given level."""
        context_noisy = m_level * context + sigma_level * torch.randn_like(context)
        return (1 - mask) * context_noisy + mask * audio

    def _final_denoise(self, audio, context, mask, sigma, m):
        """Jump from noise level sigma[0] to the clean signal, then paste the clean context back in."""
        audio = (audio - sigma[0] * self.model(audio, sigma[0])) / m[0]
        return (1 - mask) * context + mask * audio

    def predict(
        self,
        contextL,
        contextR,
        mask,
        nb_steps,
        stereo=False,
        stereo_split=0.05
    ):
        """Sample audio constrained to equal the context where ``mask`` is 0.

        Args:
            contextL, contextR: left/right context tensors. The notebook passes
                shape (1, segment_size). For mono generation, pass the same
                tensor twice.
            mask: tensor combined elementwise with the contexts. It is 1 where
                audio is generated and 0 where the context is kept.
            nb_steps: number of reverse diffusion steps.
            stereo: if True, return separate left and right channels.
            stereo_split: fraction controlling where the channels branch
                (see ``create_schedules``).

        Returns:
            A single tensor when ``stereo`` is False, otherwise the tuple
            ``(audio_left, audio_right)``.

        Raises:
            ValueError: if ``stereo`` is True and ``nb_steps < 2``. In that case
                the reverse loop has no step at which the channels could branch.
        """
        if stereo and nb_steps < 2:
            raise ValueError("stereo sampling needs nb_steps >= 2")

        with torch.no_grad():
            sigma, m, split = self.create_schedules(nb_steps, stereo_split)

            # Start from the mid (L+R)/2 context noised to level nb_steps-1.
            context = (contextL + contextR) / 2
            audio = m[nb_steps-1] * context + sigma[nb_steps-1] * torch.randn_like(context)

            audio_stereo = None
            for n in range(nb_steps - 1, 0, -1):
                # begins at n = nb_steps - 1, stops at n = 1
                audio = self._reverse_step(audio, n, sigma, m)
                audio = self._impute(audio, context, mask, m[n-1], sigma[n-1])
                if stereo and n == split:
                    # Save the shared state; the right channel restarts from here
                    # while this trajectory continues as the left channel.
                    audio_stereo = torch.clone(audio)
                    context = contextL

            # The noise level is now sigma[0]: take the jump step.
            audio = self._final_denoise(audio, context, mask, sigma, m)
            if not stereo:
                return audio

            audio_left = audio
            audio = audio_stereo
            context = contextR
            for n in range(split - 1, 0, -1):
                audio = self._reverse_step(audio, n, sigma, m)
                audio = self._impute(audio, context, mask, m[n-1], sigma[n-1])
            audio_right = self._final_denoise(audio, context, mask, sigma, m)
            return audio_left, audio_right


In [None]:
#@title #Generate music


#@markdown This may take a while, be patient

model = 'orchestra'  #@param ["piano", "strings", "orchestra"]

if model=="piano":
    ! wget https://github.com/eloimoliner/unconditional-diff-STFT/releases/download/weights_piano/weights_piano_uncond_synth.pt

    ! mkdir experiments/piano
    ! mv weights_piano_uncond_synth.pt experiments/piano/
    path_experiment= os.path.join(dirname, "experiments/piano")
    checkpoint="weights_piano_uncond_synth.pt"
elif model=="strings":
    ! wget https://github.com/eloimoliner/unconditional-diff-STFT/releases/download/weights_strings/weights_strings.pt

    ! mkdir experiments/strings
    ! mv weights_strings.pt experiments/strings/
    path_experiment= os.path.join(dirname, "experiments/strings")
    checkpoint="weights_strings.pt"

elif model=="orchestra":
    ! wget https://github.com/eloimoliner/unconditional-diff-STFT/releases/download/weights_orchestral/weights-200000-orchestra.pt

    ! mkdir experiments/orchestra
    ! mv weights-200000-orchestra.pt experiments/orchestra/
    path_experiment= os.path.join(dirname, "experiments/orchestra")
    checkpoint="weights-200000-orchestra.pt"

if not(os.path.exists(path_experiment)):
    os.mkdir(path_experiment)

args.model_dir=path_experiment

model=Unet2d(args).to(device)

torch.backends.cudnn.benchmark = True
sde = get_sde(args.sde_type, args.sde_kwargs)

segment_size=args.audio_len

model_dir = os.path.join(path_experiment, checkpoint) #hardcoded for now
state_dict= torch.load(model_dir, map_location=device)

if hasattr(model, 'module') and isinstance(model.module, nn.Module):
    model.module.load_state_dict(state_dict['model'])
else:
    model.load_state_dict(state_dict['model'])

sampler=SDESampling_context(model, sde)

hop_size=0.5 #@param {type:"slider", min:0, max:1, step:0.01}
overlapsize=int(args.audio_len/4) 
numchunks=15 #@param {type:"integer"}
T=51 #@param {type:"integer"}
args.inference.stereo=True  #@param {type:"boolean"}
pointer=0

for i in tqdm(range(numchunks)):
    if i==0:
        if args.inference.stereo:
            contextL=torch.zeros((1,segment_size)).to(device)
            contextR=torch.zeros((1,segment_size)).to(device)
            mask=torch.ones((1,segment_size)).to(device)
        else:
            context=torch.zeros((1,segment_size)).to(device)
            mask=torch.ones((1,segment_size)).to(device)
    else:
        if args.inference.stereo:
            mask=torch.cat((torch.zeros((1,overlapsize)),torch.ones((1,segment_size-overlapsize))),dim=1).to(device)
            contextL=torch.cat((predL[:,segment_size-overlapsize::],torch.zeros((1,segment_size-overlapsize)).to(device)),dim=1).to(device)
            contextR=torch.cat((predR[:,segment_size-overlapsize::],torch.zeros((1,segment_size-overlapsize)).to(device)),dim=1).to(device)

        else:
            mask=torch.cat((torch.zeros((1,overlapsize)),torch.ones((1,segment_size-overlapsize))),dim=1).to(device)
            context=torch.cat((pred[:,segment_size-overlapsize::],torch.zeros((1,segment_size-overlapsize)).to(device)),dim=1).to(device)

    if args.inference.stereo:
        predL, predR=sampler.predict(contextL, contextR, mask, T, stereo=True, stereo_split=0.05)
        pred_2=torch.stack((predL.squeeze(0), predR.squeeze(0)), dim=1)
    else:
        pred=sampler.predict(context, context, mask, T, stereo=False)
        pred_2=pred.squeeze(0)

    if i==0:
        bwe_data=pred_2
    else:
        bwe_data=torch.cat((bwe_data,pred_2[overlapsize::]),dim=0)

    pointer=pointer+segment_size-overlapsize

bwe_data=bwe_data.cpu().numpy()
wav_output_name="unconditional.wav"
sf.write(wav_output_name, bwe_data, args.sample_rate)


--2022-05-30 08:06:56--  https://github.com/eloimoliner/unconditional-diff-STFT/releases/download/weights_orchestral/weights-200000-orchestra.pt
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/493309012/0094c7e6-ae92-4773-8ac7-79af159e587e?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220530%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220530T080656Z&X-Amz-Expires=300&X-Amz-Signature=416cac14672ee347c8cd6cca2007b8ca369b0f0fa8c9125857c4acaf8ad3afaf&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=493309012&response-content-disposition=attachment%3B%20filename%3Dweights-200000-orchestra.pt&response-content-type=application%2Foctet-stream [following]
--2022-05-30 08:06:56--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/

 53%|█████▎    | 8/15 [01:58<01:45, 15.08s/it]

In [None]:
#@title #Listen to the generated sound
import IPython.display as ipd

# bwe_data is (samples,) for mono or (samples, channels) for stereo, while
# ipd.Audio expects (channels, samples). Transpose on the fly instead of
# reassigning bwe_data, so re-running this cell (or the other listen cell)
# does not flip the array back into the wrong orientation.
ipd.Audio(np.transpose(bwe_data), rate=int(args.sample_rate))

In [5]:
#@title #Listen to the generated sound
import IPython.display as ipd

# bwe_data is (samples,) for mono or (samples, channels) for stereo, while
# ipd.Audio expects (channels, samples). Transpose on the fly instead of
# reassigning bwe_data, so re-running this cell (or the other listen cell)
# does not flip the array back into the wrong orientation.
ipd.Audio(np.transpose(bwe_data), rate=int(args.sample_rate))

In [None]:
#@title #Download

#@markdown Execute this cell to download the generated music
from google.colab import files as colab_files

# Hand the WAV file written by the generation cell to the browser as a download.
colab_files.download(wav_output_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>