<a href="https://colab.research.google.com/github/eloimoliner/CQTdiff/blob/main/notebook/demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Solving Audio Inverse Problems with a Diffusion Model

This notebook is a demo of the method proposed in:

> E. Moliner and V. Välimäki, "Solving audio inverse problems with a diffusion model", submitted to the IEEE International Conference on Acoustics, Speech, and Signal Processing (ICASSP), Rhodes, Greece, May 2023.

Listen to our [audio samples](http://research.spa.aalto.fi/publications/papers/icassp23-cqt-diff/)

### Instructions for running:

* Make sure to use a GPU runtime: click __Runtime >> Change runtime type >> GPU__
* Run the cells in order by pressing ▶️ on the left of each cell
* View the code: double-click any of the cells
* Hide the code: double-click the right side of the cell


In [None]:
# Clone the repository that provides the `utils`/`models` packages, `conf/conf.yaml`
# and the data-preparation script used by the following cells.
!git clone https://github.com/eloimoliner/bwe_historical_recordings.git

Cloning into 'bwe_historical_recordings'...
remote: Enumerating objects: 167, done.[K
remote: Counting objects: 100% (167/167), done.[K
remote: Compressing objects: 100% (108/108), done.[K
remote: Total 167 (delta 94), reused 109 (delta 51), pack-reused 0[K
Receiving objects: 100% (167/167), 76.75 KiB | 3.07 MiB/s, done.
Resolving deltas: 100% (94/94), done.


In [None]:
# Later cells import `utils.*`/`models.*` and read `conf/conf.yaml` relative to the
# working directory, so they must run from the repository root.
%cd bwe_historical_recordings

/content/bwe_historical_recordings/denoising-historical-recordings/bwe_historical_recordings


In [None]:
# Run the repository's data-preparation script (its output shows it downloading
# audio_examples.zip from the project's GitHub releases).
!bash prepare_data.sh

--2022-04-13 07:56:45--  https://github.com/eloimoliner/bwe_historical_recordings/releases/download/v0.0-alpha/audio_examples.zip
Resolving github.com (github.com)... 140.82.114.4
Connecting to github.com (github.com)|140.82.114.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/448304570/9414369f-d90e-4e18-9379-a7c5aab87836?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20220413%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20220413T075645Z&X-Amz-Expires=300&X-Amz-Signature=de3e5ddfb996e09221a35341afac07c4c12e6cf76a27e19a5012d94f8283d6da&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=448304570&response-content-disposition=attachment%3B%20filename%3Daudio_examples.zip&response-content-type=application%2Foctet-stream [following]
--2022-04-13 07:56:45--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/448304570/9414369f-d90e-

In [None]:
# The setup cell runs `import hydra`; that module is provided by the Hydra configuration
# framework, published on PyPI as `hydra-core` (the PyPI package named `hydra` is an
# unrelated project). `%pip` installs into the running kernel's environment.
%pip install -q hydra-core



In [None]:
# Setup cell: imports, device selection and configuration loading.
# The `utils.*` / `models.*` modules and `conf/conf.yaml` are expected to be found in the
# working directory, so this cell must run after the `%cd bwe_historical_recordings` cell.
# NOTE(review): several of these imports (e.g. hydra, logging, DataLoader, SummaryWriter,
# datetime, lowpass_utils, dataset_loader, stft_loss, discriminators) are not used by the
# visible cells.
import os
import hydra
import logging
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
# Sanity check that a GPU runtime is active (see the instructions at the top).
print("CUDA??",torch.cuda.is_available())
import soundfile as sf
import datetime
import numpy as np
import scipy
from tqdm import tqdm

import utils.utils as utils 
import utils.lowpass_utils as lowpass_utils 
import  utils.dataset_loader as dataset_loader
import  utils.stft_loss as stft_loss
import models.discriminators as discriminators
import models.unet2d_generator as unet2d_generator
import models.audiounet as audiounet
import models.seanet as seanet
import models.denoiser as denoiser

import yaml
from pathlib import Path





# Run the models on the GPU when one is available, otherwise on the CPU.
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Experiment configuration; later cells read its `stft`, `unet_generator`, `denoiser`
# and `bwe` sections.
args = yaml.safe_load(Path('conf/conf.yaml').read_text())
class dotdict(dict):
    """A dict whose keys can also be read, assigned and deleted as attributes.

    Reading an attribute that is not a key returns None rather than raising.
    Assigning or deleting an attribute writes to / removes the dict key of that name.
    """

    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]
args=dotdict(args)
unet_args=dotdict(args.unet_generator)
args_denoiser=dotdict(args.denoiser)

# Bandwidth-extension generator. Its weights are loaded in the preferences cell once
# the instrument mode has been chosen; load_state_dict does not change the eval mode.
gener_model = unet2d_generator.Unet2d(unet_args=unet_args).to(device)
# Inference only: put layers such as dropout / batch-norm (if present) into eval mode.
gener_model.eval()

# Pretrained denoiser. The path is relative to the repository root (the working
# directory set by the %cd cell), consistent with how conf/conf.yaml is loaded.
checkpoint_filepath_denoiser=os.path.join('experiments_denoiser','pretrained_model','checkpoint_denoiser')
unet_model = denoiser.MultiStage_denoise(unet_args=args_denoiser)
unet_model.load_state_dict(torch.load(checkpoint_filepath_denoiser, map_location=device))
unet_model.to(device)
unet_model.eval()



def apply_denoiser_model(segment):
    """Denoise one time-domain segment with the pretrained denoiser ``unet_model``.

    The segment is converted with ``utils.do_stft``, passed through the denoiser
    without gradient tracking, and converted back with ``utils.do_istft`` using the
    STFT window/hop sizes from the config. When the denoiser has more than one stage,
    element ``[0]`` of its output is used.

    Returns:
        The tensor returned by ``utils.do_istft``.
    """
    win_size = args.stft["win_size"]
    hop_size = args.stft["hop_size"]

    spectrum = utils.do_stft(segment, win_size=win_size, hop_size=hop_size, device=device)
    with torch.no_grad():
        estimate = unet_model(spectrum)

    # Multi-stage denoisers return a sequence of outputs; keep the first one.
    if args_denoiser.num_stages > 1:
        estimate = estimate[0]

    return utils.do_istft(estimate, win_size, hop_size, device)

def apply_bwe_model(x):
    """Bandwidth-extend one time-domain segment with the generator ``gener_model``.

    Gaussian noise scaled by ``args.bwe["add_noise"]["power"]`` is added to the input.
    The result goes through ``utils.do_stft``, the generator (without gradient
    tracking) and ``utils.do_istft``. The output is then trimmed to the input length.

    Args:
        x: time-domain tensor whose first dimension is the batch (``process_audio``
           passes a single-item batch) and whose last dimension is time.

    Returns:
        1-D numpy array holding the bandwidth-extended signal of the first batch item.
    """
    win_size = args.stft["win_size"]
    hop_size = args.stft["hop_size"]

    # NOTE: noise is always added; a config flag such as args.bwe["add_noise"]["add_noise"]
    # is not consulted. The noise power was marked as not yet tuned by the original author.
    # randn_like draws the noise directly on x's device and dtype (no host->device copy).
    noise = args.bwe["add_noise"]["power"] * torch.randn_like(x)
    x = x + noise

    xF = utils.do_stft(x, win_size=win_size, hop_size=hop_size, device=device)
    with torch.no_grad():
        y_gF = gener_model(xF)

    y_g = utils.do_istft(y_gF, win_size, hop_size, device)
    # Keep exactly as many samples as the input segment has.
    y_g = y_g[:, :x.shape[-1]]
    return y_g[0].detach().cpu().numpy()







CUDA?? True


In [None]:
def _overlap_fade(chunk, segment_size, window_left, window_right, is_first, is_last):
    """Trim a model output to one segment (as float64) and apply the cross-fade windows.

    The first ``len(window_left)`` samples are faded in unless this is the first chunk.
    The last ``len(window_right)`` samples are faded out unless this is the last chunk.
    Consecutive chunks then overlap-add without a seam.
    """
    faded = np.array(chunk[:segment_size], dtype=np.float64)
    if not is_first:
        faded[:len(window_left)] *= window_left
    if not is_last:
        faded[segment_size - len(window_right):] *= window_right
    return faded


def process_audio(audio, use_denoiser=True, use_bwe=True):
    """Denoise and/or bandwidth-extend an audio file, chunk by chunk.

    The file is read with soundfile, mixed down to mono, and resampled to 22050 Hz if
    needed. It is then processed in 5 s chunks that overlap by 1024 samples. Each
    model output is cross-faded with complementary Hann halves and overlap-added.
    When both stages are enabled, bandwidth extension runs on the denoiser output.

    Args:
        audio: path (or file-like object) accepted by ``soundfile.read``.
        use_denoiser: run ``apply_denoiser_model`` on each chunk.
        use_bwe: run ``apply_bwe_model`` on each chunk.

    Returns:
        ``(denoised_data, bwe_data)``: float64 arrays with one value per (resampled)
        input sample. The output of a disabled stage stays all zeros.
    """
    # `import scipy` alone does not guarantee that the `scipy.signal` submodule is loaded.
    from scipy.signal import resample

    model_sr = 22050
    data, samplerate = sf.read(audio)

    # Stereo to mono
    if data.ndim > 1:
        data = np.mean(data, axis=1)

    if samplerate != model_sr:
        print("Resampling")
        data = resample(data, int((model_sr / samplerate) * len(data)) + 1)

    length_data = len(data)
    denoised_data = np.zeros(shape=(length_data,))
    bwe_data = np.zeros(shape=(length_data,))
    if length_data == 0:
        return denoised_data, bwe_data

    segment_size = model_sr * 5  # 5 s segment
    overlapsize = 1024  # samples (46 ms)
    hop = segment_size - overlapsize
    # Periodic Hann window: its two halves sum to exactly 1 across the overlap.
    window = np.hanning(2 * overlapsize + 1)[:-1]
    window_left = window[:overlapsize]
    window_right = window[overlapsize:]

    # Chunk k starts at k * hop. The last chunk is the first one that reaches the end of
    # the signal. Counting ceil(length / segment_size) chunks would ignore the overlap
    # and could stop early, leaving the tail of the output silent.
    overflow = length_data - segment_size
    numchunks = 1 + (-(-overflow // hop) if overflow > 0 else 0)

    for chunk_idx in tqdm(range(numchunks)):
        pointer = chunk_idx * hop
        is_first = chunk_idx == 0
        is_last = chunk_idx == numchunks - 1

        chunk = data[pointer:pointer + segment_size]
        valid_len = len(chunk)
        end = pointer + valid_len
        # The last chunk is zero-padded up to the model's segment length.
        chunk = np.concatenate((chunk, np.zeros(segment_size - valid_len)))
        segment = torch.from_numpy(chunk).float().to(device).unsqueeze(0)

        if use_denoiser:
            denoised = apply_denoiser_model(segment)
            segment = denoised  # bandwidth extension is applied to the denoised signal
            faded = _overlap_fade(denoised[0].detach().cpu().numpy(), segment_size,
                                  window_left, window_right, is_first, is_last)
            denoised_data[pointer:end] += faded[:valid_len]

        if use_bwe:
            pred_time = apply_bwe_model(segment)
            faded = _overlap_fade(pred_time, segment_size,
                                  window_left, window_right, is_first, is_last)
            bwe_data[pointer:end] += faded[:valid_len]

    return denoised_data, bwe_data

In [None]:
#@title #Upload file to denoise
#@markdown not implemented yet, sorry :(
##@markdown Execute this cell to upload a single audio recording you would like to denoise (accepted extensions: .wav, .flac, .mp3)
# Opens Colab's file picker and stores the result in `uploaded`.
# NOTE(review): `uploaded` is not read by any later cell; the Enhance cell processes the
# hard-coded path in `fn` instead. The `files` module imported here is also used by the
# Download cell.
from google.colab import files
uploaded=files.upload()

In [None]:
#Please select your preferences

use_denoiser=True #@param {type:"boolean"} 
use_bwe=True #@param {type:"boolean"} 

mode="orchestra" #@param ["piano", "strings", "orchestra"]

# One pretrained bandwidth-extension checkpoint per instrument mode. Paths are relative
# to the repository root (the working directory set by the %cd cell), like conf/conf.yaml.
bwe_checkpoints = {
    "orchestra": os.path.join("experiments_bwe", "orchestra", "checkpoint_orchestra"),
    "piano": os.path.join("experiments_bwe", "piano", "checkpoint_piano"),
    "strings": os.path.join("experiments_bwe", "strings", "checkpoint_strings"),
}
if mode not in bwe_checkpoints:
    raise ValueError("Unknown mode {!r}; expected one of {}".format(mode, sorted(bwe_checkpoints)))

checkpoint_filepath = bwe_checkpoints[mode]
gener_model.load_state_dict(torch.load(checkpoint_filepath, map_location=device))


In [None]:
#@title #Enhance

#@markdown Execute this cell to enhance the audio file given in `fn`. Modify `fn` to add the path to your audio file
# Add the path to your audio file here
fn="audio_examples/1st_Movement-Allegro_mod_-_PHILADELPHIA_SYMPHONY_ORCHESTRA_noisy_input.wav"
output_sr = 22050  # process_audio resamples its input to 22050 Hz

if not (use_denoiser or use_bwe):
    raise ValueError("Enable at least one of use_denoiser / use_bwe in the preferences cell.")

print('Processing file "{name}"'.format(
    name=fn))
# Use the preferences selected in the form cell above.
denoise_data, bwe_data=process_audio(fn, use_denoiser=use_denoiser, use_bwe=use_bwe)

basename=os.path.splitext(fn)[0]
# Only write the outputs that were computed. wav_output_name ends up naming the last
# file written (the BWE output when enabled); the Download cell fetches that file.
if use_denoiser:
    wav_output_name=basename+"_denoised"+".wav"
    sf.write(wav_output_name, denoise_data, output_sr)
if use_bwe:
    wav_output_name=basename+"_bwe"+".wav"
    sf.write(wav_output_name, bwe_data, output_sr)

Denoising uploaded file "audio_examples/1st_Movement-Allegro_mod_-_PHILADELPHIA_SYMPHONY_ORCHESTRA_noisy_input.wav"
Resampling


  0%|          | 0/40 [00:00<?, ?it/s]

adding noise


  2%|▎         | 1/40 [00:02<01:26,  2.23s/it]

adding noise


  5%|▌         | 2/40 [00:04<01:23,  2.20s/it]

adding noise


  8%|▊         | 3/40 [00:06<01:20,  2.18s/it]

adding noise


 10%|█         | 4/40 [00:08<01:18,  2.18s/it]

adding noise


 12%|█▎        | 5/40 [00:10<01:16,  2.19s/it]

adding noise


 15%|█▌        | 6/40 [00:13<01:14,  2.18s/it]

adding noise


 18%|█▊        | 7/40 [00:15<01:11,  2.18s/it]

adding noise


 20%|██        | 8/40 [00:17<01:09,  2.18s/it]

adding noise


 22%|██▎       | 9/40 [00:19<01:07,  2.18s/it]

adding noise


 25%|██▌       | 10/40 [00:21<01:05,  2.18s/it]

adding noise


 28%|██▊       | 11/40 [00:23<01:03,  2.18s/it]

adding noise


 30%|███       | 12/40 [00:26<01:00,  2.17s/it]

adding noise


 32%|███▎      | 13/40 [00:28<00:58,  2.17s/it]

adding noise


 35%|███▌      | 14/40 [00:30<00:56,  2.17s/it]

adding noise


 38%|███▊      | 15/40 [00:32<00:54,  2.17s/it]

adding noise


 40%|████      | 16/40 [00:34<00:51,  2.17s/it]

adding noise


 42%|████▎     | 17/40 [00:36<00:49,  2.17s/it]

adding noise


 45%|████▌     | 18/40 [00:39<00:47,  2.17s/it]

adding noise


 48%|████▊     | 19/40 [00:41<00:45,  2.16s/it]

adding noise


 50%|█████     | 20/40 [00:43<00:43,  2.16s/it]

adding noise


 52%|█████▎    | 21/40 [00:45<00:41,  2.16s/it]

adding noise


 55%|█████▌    | 22/40 [00:47<00:38,  2.16s/it]

adding noise


 57%|█████▊    | 23/40 [00:49<00:36,  2.16s/it]

adding noise


 60%|██████    | 24/40 [00:52<00:34,  2.16s/it]

adding noise


 62%|██████▎   | 25/40 [00:54<00:32,  2.16s/it]

adding noise


 65%|██████▌   | 26/40 [00:56<00:30,  2.16s/it]

adding noise


 68%|██████▊   | 27/40 [00:58<00:28,  2.16s/it]

adding noise


 70%|███████   | 28/40 [01:00<00:25,  2.16s/it]

adding noise


 72%|███████▎  | 29/40 [01:02<00:23,  2.16s/it]

adding noise


 75%|███████▌  | 30/40 [01:05<00:21,  2.16s/it]

adding noise


 78%|███████▊  | 31/40 [01:07<00:19,  2.16s/it]

adding noise


 80%|████████  | 32/40 [01:09<00:17,  2.16s/it]

adding noise


 82%|████████▎ | 33/40 [01:11<00:15,  2.16s/it]

adding noise


 85%|████████▌ | 34/40 [01:13<00:12,  2.16s/it]

adding noise


 88%|████████▊ | 35/40 [01:15<00:10,  2.15s/it]

adding noise


 90%|█████████ | 36/40 [01:17<00:08,  2.16s/it]

adding noise


 92%|█████████▎| 37/40 [01:20<00:06,  2.15s/it]

adding noise


 95%|█████████▌| 38/40 [01:22<00:04,  2.16s/it]

adding noise


 98%|█████████▊| 39/40 [01:24<00:02,  2.16s/it]

adding noise


100%|██████████| 40/40 [01:26<00:00,  2.17s/it]


In [None]:
#@title #Download

#@markdown Execute this cell to download the enhanced recording
# Import here too, so this cell works even when the optional upload cell was skipped.
from google.colab import files
files.download(wav_output_name)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>