In [5]:
import math

import hydra
from hydra import compose, initialize

from main.inference import ContextRegularizedGenerator
from audio_diffusion_pytorch import KarrasSchedule
import torch
import torchaudio

device = torch.device("cuda")

In [6]:
with initialize(config_path="../exp"):
    cfg = compose(config_name="base_slakh_1")

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  with initialize(config_path="../exp"):


In [8]:
def load_model(path):
    if path.endswith('ckpt'):
        state_dict = torch.load(path)['state_dict']
    else:
        state_dict = torch.load(path)
    model = hydra.utils.instantiate(cfg.model)
    model.load_state_dict(state_dict)
    model.to(device)
    return model

In [15]:
sampling_rate = 22050

# @markdown Generation length in seconds (will be rounded to be a power of 2 of sample_rate*length)
length = 10
length_samples = 2**math.ceil(math.log2(length * sampling_rate))

# @markdown Number of samples to generate
num_samples = 1

# @markdown Number of diffusion steps (higher tends to be better but takes longer to generate)
num_steps = 100

smin = 1e-4
smax = 1.0
rho = 7.0

sigma_schedule=KarrasSchedule(sigma_min=smin, sigma_max=smax, rho=rho)

model_piano = load_model("../logs/ckpts/piano-vital-sun-29_epoch880.ckpt")
model_drums = load_model("../logs/ckpts/drums-lunar-blaze-24_epoch933.pt")
model_mix = load_model("../logs/ckpts/frosty-waterfall-235-epoch=328.ckpt")

cr_generator = ContextRegularizedGenerator(stem_to_model={'drums': model_drums,
                                                       'piano': model_drums,
                                                       'mixture': model_drums},
                                           sigma_schedule=sigma_schedule)

In [23]:
from audio_diffusion_pytorch import KarrasSampler

noise = torch.randn(1, 1, 2**18)


In [None]:
sample_piano = model_piano.model.sample(
    noise.cuda(),
    num_steps=40
)

# sample_drums = model_drums.model.sample(
#     noise.cuda(),
#     num_steps=20
# )

# sample_mix = model_mix.model.sample(
#     noise.cuda(),
#    num_steps=20
# )

In [16]:
output = cr_generator.generate(1, cfg.length, num_steps=20)

In [17]:
torchaudio.save(f"drums.wav", output['drums'].reshape(1,-1).cpu(), 22050)
torchaudio.save(f"piano.wav", output['piano'].reshape(1,-1).cpu(), 22050)
torchaudio.save(f"mixture.wav", output['mixture'].reshape(1, -1).cpu(), 22050)