In [62]:
# basic imports

import IPython
import librosa

import torch

# NOTE(review): star import kept because later cells rely on names it provides
# (presumably `sigma_rescale`, used when building the model) — switch to explicit
# imports once the needed names are confirmed.
from music2latent.hparams_inference import *
from music2latent.config_loader import load_config

# Prefer Apple's MPS backend (the original target), then CUDA, then CPU, so the
# notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Decoding starts from random noise (torch.randn); fix the seed so reconstructions
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

config = 'config.py'
hparams = load_config(config)

In [63]:
# Exploratory: sum of hparams.layers_list (presumably total layers across UNet levels — confirm).
sum(n_layers for n_layers in hparams.layers_list)

10

In [64]:
# Load model and checkpoint

from music2latent.export import ScriptedUNet

load_path = 'music2latent/models/music2latent.pt'

# `sigma_rescale` is not defined in this notebook; presumably it comes from the
# `music2latent.hparams_inference` star import in the setup cell — confirm.
gen = ScriptedUNet(hparams, sigma_rescale = sigma_rescale).to(device)

# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
checkpoint = torch.load(load_path, map_location=device)
# strict=False tolerates key mismatches; report them instead of ignoring them silently.
load_result = gen.load_state_dict(checkpoint['gen_state_dict'], strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f'missing keys: {load_result.missing_keys}')
    print(f'unexpected keys: {load_result.unexpected_keys}')

# Inference only: switch off training-mode behaviour (dropout, batch-norm statistics, ...).
gen.eval()

# Ratio of spectrogram frames to latent frames; the decoding cells size their initial
# noise as latent_length * downscaling_factor.
downscaling_factor = 2**hparams.freq_downsample_list.count(0)

  checkpoint = torch.load(load_path, map_location=device)


In [65]:
# Input audio. Alternatives used during development: librosa.example('trumpet'),
# 'audio_samples/string-002.wav', 'audio_samples/110_drums.wav'.
audio_path = 'audio_samples/string-003.wav'

# Chunking shared by the experiments below: N_CHUNKS consecutive chunks of CHUNK_SIZE samples.
CHUNK_SIZE = 12288
N_CHUNKS = 25
n_samples = CHUNK_SIZE * N_CHUNKS

wv_np, sr = librosa.load(audio_path, sr=44100)
print(f'original waveform samples: {len(wv_np)}')
# Fail loudly rather than silently producing short or empty trailing chunks.
assert len(wv_np) >= n_samples, f'need at least {n_samples} samples, got {len(wv_np)}'

# (1, n_samples) tensor on the target device, trimmed to a whole number of chunks.
wv = torch.tensor(wv_np, device=device).unsqueeze(0)[:, :n_samples]
wv_chunks = list(torch.split(wv, CHUNK_SIZE, dim=-1))
print(f'waveform samples: {wv.shape}')
print(f'number of chunks: {len(wv_chunks)}')
print(f'chunk length: {wv_chunks[0].shape}')

original waveform samples: 441000
waveform samples: torch.Size([1, 307200])
number of chunks: 25
chunk length: torch.Size([1, 12288])


In [66]:
# Sanity check: the chunks tile the trimmed waveform exactly (25 chunks x 12288 samples).
total_chunk_samples = sum(chunk.shape[-1] for chunk in wv_chunks)
assert total_chunk_samples == wv.shape[-1] == 12288 * 25, (total_chunk_samples, wv.shape)
total_chunk_samples

307200

In [67]:
# Streaming STFT/iSTFT factory (torch itself is already imported in the setup cell).
from music2latent.scripted_audio import create_streaming_processors

In [68]:
# Create streaming processors for STFT and iSTFT
# NOTE(review): `process_chunk` suggests these keep streaming state between calls, so
# reusing the same instances across experiments may carry state over — confirm.

stft_processor, istft_processor = create_streaming_processors(
    hop_size=512, fac=4
)
stft_processor.to(device)
# Trailing `;` suppresses the module repr that would otherwise be displayed.
istft_processor.to(device);

StreamingISTFT()

## Split latent: encode the full waveform once, then decode the latent in chunks

In [69]:
## Encode the full (unchunked) waveform in a single pass
with torch.no_grad():
    # Fresh STFT processor so re-running this cell starts from a clean state.
    # NOTE(review): assumes process_chunk keeps streaming state; harmless if it does not.
    full_stft = create_streaming_processors(hop_size=512, fac=4)[0]
    full_stft.to(device)

    repr_encoder = full_stft.process_chunk(wv)

    full_latent = gen.encoder(repr_encoder)
    # Keep the latent divided by sigma_rescale; the decoding cell multiplies it back in.
    full_latent = full_latent/gen.sigma_rescale

In [70]:
# Slice the full latent into one piece per audio chunk so each piece can be decoded on its
# own; each 12288-sample audio chunk corresponds to 3 latent frames.
latent_frames_per_chunk = 3
n_latent_chunks = len(wv_chunks)
# Guard against silently producing short or empty trailing chunks.
assert full_latent.shape[-1] >= latent_frames_per_chunk * n_latent_chunks, full_latent.shape
full_latent_chunks = [
    full_latent[:, :, i * latent_frames_per_chunk:(i + 1) * latent_frames_per_chunk]
    for i in range(n_latent_chunks)
]
full_latent_chunks[0].shape

torch.Size([1, 64, 3])

In [71]:
## Decode each latent chunk with the generator, then stream the spectrograms through the iSTFT
with torch.no_grad():
    # Fresh iSTFT so any overlap state from a previous run of this cell does not leak in.
    # NOTE(review): assumes process_chunk keeps streaming state; harmless if it does not.
    chunk_istft = create_streaming_processors(hop_size=512, fac=4)[1]
    chunk_istft.to(device)

    wv_recons = []
    spec_recons = []
    for latent in full_latent_chunks:
        # Undo the 1/sigma_rescale applied after encoding.
        this_latent = latent*gen.sigma_rescale
        # Number of spectrogram frames to generate for this latent chunk.
        sample_length = int(this_latent.shape[-1]*downscaling_factor)
        # Sample the initial noise directly on the target device (no CPU round trip).
        init_noise = torch.randn(
            (1, hparams.data_channels, hparams.hop*2, sample_length),
            device=latent.device,
        )*hparams.sigma_max
        spec_recon = gen.forward_generator(this_latent, init_noise)

        spec_recons.append(spec_recon)
        wv_recons.append(chunk_istft.process_chunk(spec_recon))

In [72]:
# Stitch the streamed chunks back into one waveform and play it next to the input.
output_chunk = torch.cat(wv_recons, dim=-1)

for label, waveform in (('Original', wv), ('Reconstructed', output_chunk)):
    print(label)
    IPython.display.display(IPython.display.Audio(waveform.cpu().numpy(), rate=sr))

Original


Reconstructed


## Split audio: encode and decode each audio chunk independently

In [50]:
## Encode each audio chunk separately through the streaming STFT + encoder
with torch.no_grad():
    # Fresh STFT processor: the shared `stft_processor` has already processed the full
    # waveform above. NOTE(review): assumes process_chunk keeps streaming state.
    chunk_stft = create_streaming_processors(hop_size=512, fac=4)[0]
    chunk_stft.to(device)

    latent_chunks = []
    for w in wv_chunks:
        repr_encoder = chunk_stft.process_chunk(w)

        latent = gen.encoder(repr_encoder)
        # Keep the latent divided by sigma_rescale; the decoding cell multiplies it back in.
        latent = latent/gen.sigma_rescale
        latent_chunks.append(latent)

In [51]:
# Shape of the first per-chunk latent (torch.Size).
latent_chunks[0].size()

torch.Size([1, 64, 3])

In [52]:
## Decode each per-chunk latent with the generator, then stream the spectrograms through the iSTFT
with torch.no_grad():
    # Fresh iSTFT: the shared `istft_processor` may hold overlap state from earlier cells.
    # NOTE(review): assumes process_chunk keeps streaming state; harmless if it does not.
    chunk_istft = create_streaming_processors(hop_size=512, fac=4)[1]
    chunk_istft.to(device)

    wv_recons = []
    spec_recons = []
    for latent in latent_chunks:
        # Undo the 1/sigma_rescale applied after encoding.
        this_latent = latent*gen.sigma_rescale
        # Number of spectrogram frames to generate for this latent chunk.
        sample_length = int(this_latent.shape[-1]*downscaling_factor)
        # Sample the initial noise directly on the target device (no CPU round trip).
        init_noise = torch.randn(
            (1, hparams.data_channels, hparams.hop*2, sample_length),
            device=latent.device,
        )*hparams.sigma_max
        spec_recon = gen.forward_generator(this_latent, init_noise)

        spec_recons.append(spec_recon)
        wv_recons.append(chunk_istft.process_chunk(spec_recon))

In [53]:
# Join the per-chunk reconstructions and play them alongside the input waveform.
output_chunk = torch.cat(wv_recons, dim=-1)
players = {'Original': wv, 'Reconstructed': output_chunk}

for title, signal in players.items():
    print(title)
    player = IPython.display.Audio(signal.cpu().numpy(), rate=sr)
    IPython.display.display(player)

Original


Reconstructed


## Split STFT: STFT → iSTFT round trip on each audio chunk (no model involved)

In [15]:
# Full-waveform STFT, split below to inspect per-chunk frame counts.
with torch.no_grad():
    # Fresh processor so state from the earlier experiments does not leak in.
    # NOTE(review): assumes process_chunk keeps streaming state; harmless if it does not.
    full_stft = create_streaming_processors(hop_size=512, fac=4)[0]
    full_stft.to(device)

    repr_encoder = full_stft.process_chunk(wv)

In [19]:
# Split the full STFT into 24-frame pieces (12288-sample chunk / hop 512). torch.split
# covers every frame rather than a fixed chunk count; if the frame count is not a
# multiple of 24, the last piece is shorter.
repr_encoder_chunks = list(torch.split(repr_encoder, 24, dim=-1))
repr_encoder_chunks[0].shape

torch.Size([1, 2, 1024, 24])

In [23]:
## STFT -> iSTFT round trip on each audio chunk (no encoder, decoder or UNet involved)
with torch.no_grad():
    # Fresh pair so this experiment starts with clean streaming state.
    # NOTE(review): assumes process_chunk keeps streaming state; harmless if it does not.
    roundtrip_stft, roundtrip_istft = create_streaming_processors(hop_size=512, fac=4)
    roundtrip_stft.to(device)
    roundtrip_istft.to(device)

    wv_recons = []
    for w in wv_chunks:
        # Named `spec` rather than `repr`, which would shadow the builtin.
        spec = roundtrip_stft.process_chunk(w)
        wv_recons.append(roundtrip_istft.process_chunk(spec))

In [24]:
# Reassemble the round-tripped chunks and compare them by ear with the input.
output_chunk = torch.cat(wv_recons, dim=-1)

comparisons = [('Original', wv), ('Reconstructed', output_chunk)]
for caption, audio_tensor in comparisons:
    print(caption)
    IPython.display.display(
        IPython.display.Audio(audio_tensor.cpu().numpy(), rate=sr)
    )

Original


Reconstructed
