# Imports, paths, dataset

In [None]:
from ddsp import DDSP, AudioDataset
from ddsp.utils import find_checkpoint
from ddsp.callbacks import BetaWarmupCallback
from ddsp.synths import NoiseBandSynth, SineSynth
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelCheckpoint

from IPython.display import display
from IPython.display import Audio

from torch.utils.data import DataLoader, random_split, Subset
import lightning as L
import torch
# Global torch configuration for the whole notebook: choose the float32 matmul
# precision setting and make CUDA the default device for new tensors.
torch.set_float32_matmul_precision('medium')
# NOTE(review): `set_default_tensor_type` is deprecated in recent PyTorch
# releases; `set_default_device` below looks like its replacement. Confirm that
# nothing relies on the legacy call before removing it.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.set_default_device('cuda')

import os
import shutil

# Model parameters
# model_name = 'liget'

# training_path_root = '/home/btadeusz/code/ddsp_vae/training/tarta-relena-new'
# models_path_root = '/home/btadeusz/code/ddsp_vae/models/tarta-relena-new'

# Voces solos
# model_name = 'tarta-voces-solo'
# dataset_path = '/mnt/mariadata/datasets/tarta-relena/voces_solo/processed'

# Voces fx
# model_name = 'tarta-voces-fx'
# dataset_path = '/mnt/mariadata/datasets/tarta-relena/voces_fx/processed'

# Melodic
# model_name = 'tarta-melodic'
# dataset_path = '/mnt/mariadata/datasets/tarta-relena/melodic/processed'

# Percussive
# model_name = 'drums-1lat'
# dataset_path = '/mnt/mariadata/datasets/seven_manifolds/drums'

# Lows
# model_name = 'tarta-lows'
# dataset_path = '/mnt/mariadata/datasets/tarta-relena/lows/processed'

# Noisy
# model_name = 'tarta-noisy'
# dataset_path = '/mnt/mariadata/datasets/tarta-relena/noisy/processed'

# Water
# model_name = 'tarta-water'
# dataset_path = '/mnt/mariadata/datasets/tarta-relena/water/processed'

# Tarta Relena
# model_name = 'tarta-relena'
# dataset_path = '/mnt/mariadata/datasets/tarta-relena/temas/processed'

# ICLC

# training_path_root = '/home/btadeusz/code/ddsp_vae/training/iclc'
# models_path_root = '/home/btadeusz/code/ddsp_vae/models/iclc'

# # guitar loops
# model_name = 'iclc-guitar-loops'
# dataset_path = '/mnt/mariadata/datasets/iclc/guitar-loops/processed'

# --- Active experiment group: noise-models ---
models_path_root = '/home/btadeusz/code/ddsp_vae/models/noise-models'
training_path_root = '/home/btadeusz/code/ddsp_vae/training/noise-models'

# Other datasets of this group (uncomment one name/path pair to switch):
#
# ankersmit
# model_name = 'ankersmit-simple'
# dataset_path = '/mnt/mariadata/datasets/noise-artists/ankersmit/processed'
#
# klein
# model_name = 'klein'
# dataset_path = '/mnt/mariadata/datasets/noise-artists/klein/processed'

# hecker (currently selected)
dataset_path = '/mnt/mariadata/datasets/noise-artists/hecker/processed'
model_name = 'hecker'

# Derived locations: per-model training directory and exported synth file.
synth_output_path = os.path.join(models_path_root, model_name + '.ts')
training_path = os.path.join(training_path_root, model_name)


In [None]:
# Audio and batching configuration shared by the dataset and the model.
fs = 44100  # sampling rate in Hz
batch_size = 16
# Length of one training excerpt in samples (1.5 seconds). `1.5 * fs` alone is
# a float; a sample count has to be an integer, so cast it explicitly.
n_signal = int(1.5 * fs)

## Dataset loading

In [None]:
# # Load dataset
# dataset = AudioDataset(dataset_path=dataset_path, n_signal=n_signal, sampling_rate=fs)

# train_set, val_set = random_split(dataset, [0.9, 0.1], generator=torch.Generator(device='cuda'))
# train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, generator=torch.Generator(device='cuda'))
# val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0, generator=torch.Generator(device='cuda'))

# Define data augmentations

from ddsp.rave_transforms import *

# augmentation_pipeline = Compose([
#     RandomPitch(n_signal=n_signal, pitch_range=[0.5, 2.0], prob=0.35),
#     RandomCompress(prob=0.05),
#     RandomGain(gain_range=(-6, 3), prob=0.05),
# ])
# No augmentation is applied for now; the dataset receives `None` as transform.
augmentation_pipeline = None

# Build the audio dataset from the processed files.
dataset = AudioDataset(
  dataset_path=dataset_path,
  n_signal=n_signal,
  sampling_rate=fs,
  transform_fn=augmentation_pipeline,
)

# Validation subset: a random 20% of the items.
total_len = len(dataset)
val_len = int(0.2 * total_len)
indices = torch.randperm(total_len, generator=torch.Generator(device='cuda'))
val_indices = indices[:val_len]

# Training deliberately covers the full dataset, so the validation items are
# also part of the training set.
train_indices = torch.arange(total_len, device='cuda')

# Index-based views on the same underlying dataset.
train_set = Subset(dataset, train_indices)
val_set = Subset(dataset, val_indices)

# Loaders: only the training loader shuffles; each gets its own CUDA generator.
train_loader = DataLoader(
  train_set,
  batch_size=batch_size,
  shuffle=True,
  num_workers=0,
  generator=torch.Generator(device='cuda'),
)
val_loader = DataLoader(
  val_set,
  batch_size=batch_size,
  shuffle=False,
  num_workers=0,
  generator=torch.Generator(device='cuda'),
)

# Synth training

In [None]:
# --- DDSP model hyper-parameters ---
num_params = 4
latent_size = num_params  # latent size and parameter count share one value
resampling_rate = 32
max_freq = 18000
perceptual_loss_weight = 0.0

# --- Training configuration ---
learning_rate = 1e-4
max_epochs = 10000
beta = 5.0
# Step counts handed to BetaWarmupCallback as start_steps / end_steps.
warmup_start, warmup_end = 1000, 3000
capacity = 64
decoder_gru_layers = 2
latent_smoothing_kernel = 513

# Set to True to throw away everything from a previous run of this model.
restart = False

# A restart removes the old training directory first; in every case the
# directory exists once this cell has run.
if restart:
  shutil.rmtree(training_path, ignore_errors=True)
os.makedirs(training_path, exist_ok=True)

## Model initialisation and training

### Load from checkpoint

In [None]:
# Load the synth model from a checkpoint in the training directory.
# With `return_none=True` the lookup yields None when nothing is found, so stop
# here with a clear message instead of handing None to `load_from_checkpoint`.
ckpt = find_checkpoint(training_path, return_none=True)
if ckpt is None:
  raise FileNotFoundError(
    f'No checkpoint found in {training_path}; use the "Init new" cell instead.'
  )
# strict=False is forwarded to the loader (non-strict state-dict matching).
ddsp = DDSP.load_from_checkpoint(ckpt, strict=False)

### Init new

In [None]:
# Sizes of the two synth components.
n_filters = 64
n_sines = 300

# Synth configurations; both run at the model's sampling rate and resampling factor.
nbn = NoiseBandSynth.to_config(
  n_filters=n_filters,
  fs=fs,
  resampling_factor=resampling_rate,
)
sines = SineSynth.to_config(
  n_sines=n_sines,
  fs=fs,
  resampling_factor=resampling_rate,
)

# Sines plus noise bands. Sines-only alternative: configs = [sines]
configs = [sines, nbn]

# Fresh model built from the hyper-parameters above, then moved to the GPU.
ddsp = DDSP(
  synth_configs=configs,
  fs=fs,
  resampling_factor=resampling_rate,
  latent_size=latent_size,
  num_params=num_params,
  capacity=capacity,
  decoder_gru_layers=decoder_gru_layers,
  latent_smoothing_kernel=latent_smoothing_kernel,
  n_melbands=128,
  perceptual_loss_weight=perceptual_loss_weight,
  learning_rate=learning_rate,
)
ddsp = ddsp.to('cuda')

### Training

In [None]:
# Synth training: callbacks, logger, trainer, optional resume, then fit.

## Callbacks
# Beta warm-up between `warmup_start` and `warmup_end` steps.
beta_warmup = BetaWarmupCallback(beta=beta, start_steps=warmup_start, end_steps=warmup_end)

# Best checkpoint, judged by the validation loss; only one file is kept.
best_checkpoint_callback = ModelCheckpoint(
  filename='best',
  monitor='val_loss',
  mode='min',
  save_top_k=1,
)

# Optional "last" checkpoint (disabled):
# last_checkpoint_callback = ModelCheckpoint(filename='last', save_top_k=1, save_last=True)

training_callbacks = [beta_warmup, best_checkpoint_callback]

## Logger
tb_logger = TensorBoardLogger(training_path_root, name=model_name)

# Optional profiling (disabled):
# from lightning.pytorch.profilers import PyTorchProfiler
# profiler = PyTorchProfiler(with_stack=True)

## Trainer
trainer = L.Trainer(
  accelerator='cuda',
  precision=16,
  max_epochs=max_epochs,
  callbacks=training_callbacks,
  logger=tb_logger,
  log_every_n_steps=0,
  # profiler=profiler
)
# trainer.lightning_module.strict_loading = False

# Resume from the best checkpoint of an earlier run, if one exists;
# `ckpt` stays None otherwise and training starts from scratch.
ckpt = find_checkpoint(training_path, return_none=True, typ='best')
if ckpt is not None:
  print(f'Restoring from checkpoint {ckpt}')

## Start training
ddsp.train()
trainer.fit(
  ddsp,
  train_dataloaders=train_loader,
  val_dataloaders=val_loader,
  ckpt_path=ckpt,
)

In [None]:
# Resynthesis check: encode a random 15-second excerpt, decode it again, and
# show the input audio, the latent trajectories and the resynthesised audio.
from random import randint

import matplotlib.pyplot as plt
import seaborn as sns

ddsp = ddsp.cuda()
ddsp.eval()  # inference mode, as in the other analysis cells

# Take the excerpt straight from the raw audio buffer. The cell used to draw
# `dataset[randint(0, len(dataset))]` first; that value was overwritten right
# away, and because `randint` includes its upper bound it could index one past
# the end of the dataset.
samples = 44100 * 15
start = randint(0, len(dataset._audio) - samples)
test_x = dataset._audio[start:start+samples].unsqueeze(0)

with torch.no_grad():
  # encode -> sample latents -> smooth -> decode to synth parameters -> audio
  mu, scale = ddsp.encoder(test_x)
  latents, _ = ddsp.encoder.reparametrize(mu, scale)
  latents = ddsp._smooth_latents(latents)
  params = ddsp.decoder(latents)
  test_y = ddsp._synthesize(params)

  # test_y = ddsp(test_x)

print("Input")
display(Audio(test_x.squeeze(0).cpu().numpy(), rate=fs))

print("Latents")
# One line per latent dimension.
sns.set_theme(style="whitegrid")
sns.lineplot(data=latents.squeeze(0).cpu().numpy(), palette='tab10', linewidth=1.0)
plt.title('Latents')
plt.xlabel('Time')
plt.ylabel('Latent value')
plt.show()


print("Resynthesis")
display(Audio(test_y.squeeze().cpu().numpy(), rate=fs))

In [None]:
import torch

# Latent sweep: drive every latent dimension with the same linear ramp and
# listen to what the decoder and synth produce.
num_latents = 3500  # number of steps in the ramp
latent_dim = latent_size  # use the model's latent size

# Ramp from 1 down to -1, copied to every latent dimension.
# Resulting shape: (1, num_latents, latent_dim)
latents_grid = torch.linspace(1, -1, num_latents, device='cuda')
latents_grid = latents_grid[None, :, None].repeat(1, 1, latent_dim)

# Decode the ramp to synth parameters and render the audio.
ddsp.eval()
with torch.no_grad():
  params = ddsp.decoder(latents_grid)
  audio_out = ddsp._synthesize(params)

# Listen to the result.
display(Audio(audio_out.squeeze().cpu().numpy(), rate=fs))

## Fine tuning with CLAP loss

# Latent space analysis

### Analyse Space

In [None]:
# Collect the smoothed latents of every dataset item using the trained encoder.
latents = []
ddsp.eval()
with torch.no_grad():
  # Index-based access: only `__len__` / `__getitem__` of the dataset are relied on.
  for i in range(len(dataset)):
    audio = dataset[i].unsqueeze(0).to('cuda')
    mu, scale = ddsp.encoder(audio)
    cur_latents, _ = ddsp.encoder.reparametrize(mu, scale)
    cur_latents = ddsp._smooth_latents(cur_latents)
    latents.append(cur_latents)
latents = torch.hstack(latents).squeeze(0) # [num_latents, latent_size]

# Fit PCA, get the mean and the quantiles
ddsp.analyze_latent_space(latents)

# Write the analysed model back to the checkpoint. `ckpt` is looked up before
# training starts, so it is still None after a fresh run; look it up again in
# that case and fail with a clear message when there is still nothing to use.
if ckpt is None:
  ckpt = find_checkpoint(training_path, return_none=True, typ='best')
if ckpt is None:
  raise FileNotFoundError(f'No checkpoint found in {training_path} to save the analysed model to.')
trainer.save_checkpoint(ckpt)


### Range Scaling test

In [None]:
# Round-trip check for latent normalisation on a random 5-second excerpt:
# the audio decoded from denormalised latents should match the reference.
chunk_length = 5 * 44100  # 5 seconds
start = randint(0, len(dataset._audio) - chunk_length)
audio_chunk = dataset._audio[start:start + chunk_length].unsqueeze(0).to('cuda')

print("Original audio")
display(Audio(audio_chunk.squeeze(0).cpu().numpy(), rate=fs))

ddsp.eval()
with torch.no_grad():
  # Reference path: encode -> smooth -> decode -> synthesise.
  mu, scale = ddsp.encoder(audio_chunk)
  latents, _ = ddsp.encoder.reparametrize(mu, scale)
  latents = ddsp._smooth_latents(latents)
  params_orig = ddsp.decoder(latents)
  audio_orig = ddsp._synthesize(params_orig)

  print("Original latents range", latents.min().item(), latents.max().item())
  print("Original latents audio")
  display(Audio(audio_orig.squeeze(0).cpu().numpy(), rate=fs))

  # Round trip: normalise, then denormalise, then decode again.
  normalized = ddsp.normalize_latents(latents)
  print("Normalized latents range", normalized.min().item(), normalized.max().item())

  denormalized = ddsp.denormalize_latents(normalized)
  params_denormalized = ddsp.decoder(denormalized)
  audio_denormalized = ddsp._synthesize(params_denormalized)

  print("Denormalized latents range", denormalized.min().item(), denormalized.max().item())
  print("Denormalized latents audio")
  display(Audio(audio_denormalized.squeeze(0).cpu().numpy(), rate=fs))

### PCA test

In [None]:
# PCA round-trip check on a random 5-second excerpt: latents -> params ->
# latents, then compare the audio decoded from both latent versions.
chunk_length = 5 * 44100  # 5 seconds
start = randint(0, len(dataset._audio) - chunk_length)
audio_chunk = dataset._audio[start:start + chunk_length].unsqueeze(0).to('cuda')

ddsp.eval()
with torch.no_grad():
  # Encode the excerpt to smoothed latents.
  mu, scale = ddsp.encoder(audio_chunk)
  latents, _ = ddsp.encoder.reparametrize(mu, scale)
  latents = ddsp._smooth_latents(latents)

  # Forward and inverse transform, printing the shape at every stage.
  print('latents', latents.shape)
  params = ddsp.latents_to_params(latents)
  print('params', params.shape)
  latents_recon = ddsp.params_to_latents(params)
  print('latents_recon', latents_recon.shape)

  # Mean squared error between the original and the reconstructed latents.
  info_loss = torch.nn.functional.mse_loss(latents, latents_recon)
  print(f"Information loss (MSE) between latents and reconstructed latents: {info_loss.item():.6f}")

  # Audio decoded from each latent version.
  audio_from_latents = ddsp._synthesize(ddsp.decoder(latents))
  audio_from_latents_recon = ddsp._synthesize(ddsp.decoder(latents_recon))

# Listen to the three versions in order.
for caption, signal in (
  ("Original audio chunk", audio_chunk.squeeze(0)),
  ("Audio from original latents", audio_from_latents.squeeze()),
  ("Audio from latents after params transform", audio_from_latents_recon.squeeze()),
):
  print(caption)
  display(Audio(signal.cpu().numpy(), rate=fs))

In [None]:
# # Freeze parameters
# ddsp = ddsp.cuda()
# ddsp.train()
# # ddsp.encoder.eval()
# for param in ddsp.encoder.parameters():
#   param.requires_grad = False

# ddsp._recons_loss = CLAPLoss()

# max_epochs = 100
# trainer = L.Trainer(
#   callbacks=training_callbacks,
#   max_epochs=max_epochs,
#   accelerator='cuda',
#   precision=16,
#   log_every_n_steps=0,
#   logger=tb_logger,
#   # profiler=profiler
# )

# trainer.fit(
#   model=ddsp,
#   train_dataloaders=train_loader,
#   val_dataloaders=val_loader,
# )

# Model export

In [None]:
# Export the model
# Runs the project's `cli.export` module in a shell (IPython `!` syntax) on the
# training directory, selecting the 'best' checkpoint type and writing the
# result to `synth_output_path`.
# NOTE(review): the paths are interpolated unquoted; paths with spaces would break.
!python -m cli.export --model_directory {training_path} --output_path {synth_output_path} --type best

print(f'Model saved at {synth_output_path}')

# Prior training

## Params & configuration

In [None]:
# --- Prior model architecture ---
sequence_length = 256
embedding_dim = 32
nhead = 8
num_layers = 4
quantization_channels = 16

# --- Prior optimisation ---
lr = 1e-4
prior_batch_size = 128
prior_epochs = 10000

# --- Prior dataset ---
dataset_stride_factor = 0.05

# When True, the prior training directory is wiped before training.
reinitiate = True

In [None]:
from ddsp.prior import Prior, PriorDataset

# The prior is stored next to the synth model, under '<model_name>-prior'.
prior_model_name = f'{model_name}-prior'

# Training dir
prior_training_path = os.path.join(training_path_root, prior_model_name)

# Reinitiate the training path when requested. The directory is created in
# every case (same pattern as the synth training directory), so later cells can
# rely on it even when `reinitiate` is False.
if reinitiate:
  shutil.rmtree(prior_training_path, ignore_errors=True)
os.makedirs(prior_training_path, exist_ok=True)

## Prior dataset
The prior dataset is the audio dataset encoded into the latent space by the synth encoder and arranged into sequences, ready for sequence-prediction training.

In [None]:
# Encode the audio dataset with the exported synth model and cut the latents
# into sequences. Sequences carry one extra step (sequence_length + 1).
prior_dataset = PriorDataset(
  audio_dataset_path=dataset_path,
  encoding_model_path=synth_output_path,
  sampling_rate=fs,
  sequence_length=sequence_length + 1,
  stride_factor=dataset_stride_factor,
  device='cuda',
)
normalization_dict = prior_dataset.normalization_dict

# One CUDA generator is shared by the 80/20 split and both loaders.
generator = torch.Generator(device='cuda')
prior_train_set, prior_val_set = random_split(
  prior_dataset,
  [0.8, 0.2],
  generator=generator,
)
prior_train_loader = DataLoader(
  prior_train_set,
  batch_size=prior_batch_size,
  shuffle=True,
  generator=generator,
)
prior_val_loader = DataLoader(
  prior_val_set,
  batch_size=prior_batch_size,
  shuffle=False,
  generator=generator,
)

# From here on the latent size is whatever the encoded dataset provides.
latent_size = prior_dataset[0].shape[-1]

## Model initialisation and training

### Load from checkpoint

In [None]:
# Restore a previously trained prior from the checkpoint found in the prior
# training directory.
# NOTE(review): no `return_none=True` here, unlike the other lookups —
# presumably `find_checkpoint` raises when nothing is found; confirm.
prior_checkpoint = find_checkpoint(prior_training_path)
prior = Prior.load_from_checkpoint(prior_checkpoint)

### Or initialize for training

In [None]:
from lightning.pytorch.callbacks import EarlyStopping

# Fresh prior model built from the parameters above.
prior = Prior(
  latent_size=latent_size,
  quantization_channels=quantization_channels,
  embedding_dim=embedding_dim,
  nhead=nhead,
  num_layers=num_layers,
  max_len=sequence_length,
  lr=lr,
  normalization_dict=normalization_dict,
)

prior_logger = TensorBoardLogger(prior_training_path, name=prior_model_name)

# Callbacks: checkpoint on the lowest validation loss, and stop early once the
# validation accuracy reaches 0.99 or has not improved for 1000 checks.
prior_checkpoint_callback = ModelCheckpoint(
  dirpath=prior_training_path,
  filename='best',
  monitor='val_loss',
  mode='min',
)
early_stopping = EarlyStopping(
  monitor='val_acc',
  mode='max',
  patience=1000,
  stopping_threshold=0.99,
)
prior_callbacks = [prior_checkpoint_callback, early_stopping]

prior_trainer = L.Trainer(
  accelerator='cuda',
  max_epochs=prior_epochs,
  callbacks=prior_callbacks,
  logger=prior_logger,
  log_every_n_steps=4,
)

# Resume when the training directory already holds a checkpoint.
ckpt_path = find_checkpoint(prior_training_path, return_none=True)
if ckpt_path is not None:
  print(f'Resuming training from checkpoint {ckpt_path}')

# Start training
prior_trainer.fit(
  prior,
  train_dataloaders=prior_train_loader,
  val_dataloaders=prior_val_loader,
  ckpt_path=ckpt_path,
)

In [None]:
# Report the final validation metrics of the prior.
# The cell used to read 'val_acc' but store and print it as the validation
# loss; both metrics are now reported under their own names.
final_metrics = prior_trainer.callback_metrics
prior_val_acc = final_metrics['val_acc']
print(f'Final validation accuracy: {prior_val_acc}')

# 'val_loss' is the metric monitored by the checkpoint callback; print it when present.
prior_val_loss = final_metrics.get('val_loss')
if prior_val_loss is not None:
  print(f'Final validation loss: {prior_val_loss}')

# Unconditional generation

In [None]:
# Load the exported synth model (the file written by the model-export cell)
# with torch.jit, for use in the generation cells below.
from torch import jit
synth = jit.load(synth_output_path)

In [None]:
import matplotlib.pyplot as plt

# Unconditional generation: prime the prior with the beginning of a real latent
# sequence, let it continue, then decode the result with the exported synth.
prior = prior.cuda()
prior.eval()

num_steps = 10000
prime_len = sequence_length // 4

# `randint` includes both bounds, so the largest valid index is len - 1.
orig_sequence = prior_dataset[randint(0, len(prior_dataset) - 1)].cuda()
prime = orig_sequence[:prime_len, :]
# prime = torch.randn_like(prime)
print(prime.shape)
# NOTE(review): the third argument (1.0) is presumably the sampling temperature; confirm.
sequence = prior.generate(prime, num_steps, 1.0)
latents = sequence.unsqueeze(0)

# Decode to audio
with torch.no_grad():
  # latents, _ = synth.pretrained.encoder.reparametrize(mu, logvar)
  audio = synth.decode(latents.permute(0, 2, 1).to('cpu'))

# Plot the first latent dimension: generated vs. original, with the end of the
# prime marked by a dashed vertical line.
plt.plot(latents.squeeze(0).cpu().numpy()[:,0], label='generated', marker='o')

plt.axvline(x=prime_len-1, color='r', linestyle='--')

plt.plot(orig_sequence.cpu().numpy()[:,0], linestyle=':', label='original', marker='x')
plt.legend()
plt.grid()
# zoom around the prime
margin = 64
# plt.xlim(prime_len-margin, prime_len+margin)
plt.show()

# Peak-normalise using the absolute maximum. Dividing by `audio.max()` breaks
# on silent output (division by zero) and flips the sign when every sample is
# negative.
audio = audio.cpu().numpy().squeeze()
peak = abs(audio).max()
if peak > 0:
  audio = audio / peak
audio_widget = Audio(audio, rate=fs)
display(audio_widget)

## Export prior model

In [None]:
# Export synth and prior together: runs the project's `cli.export` module in a
# shell (IPython `!` syntax) with both the synth training directory and the
# prior training directory, selecting the 'best' checkpoint type.
# NOTE(review): the paths are interpolated unquoted; paths with spaces would break.
prior_model_path = os.path.join(models_path_root, f'{prior_model_name}.ts')
!python -m cli.export --model_directory {training_path} --prior_directory {prior_training_path} --output_path {prior_model_path} --type best

print(f'Model saved at {prior_model_path}')

## Others

In [None]:
# Next-step prediction probe on one random sequence: compare the prior's
# arg-max predictions with the quantised targets.
# prior.eval()
# `randint` includes both bounds, so the largest valid index is len - 1.
orig_sequence = prior_dataset[randint(0, len(prior_dataset) - 1)].cuda()
slen = 256

# Inputs are all steps but the last; targets are the same sequence shifted by
# one step, normalised and quantised.
x = orig_sequence[:-1]
y = prior._quantizer(prior.normalize(orig_sequence[1:]))

# loss = prior.training_step(x.unsqueeze(0), _)
# print(loss)

y_hat = prior(x.unsqueeze(0)).argmax(dim=-1).squeeze(0)

# Quantise the inputs too, so they can be printed next to the targets.
x = prior._quantizer(prior.normalize(x))


idx = 1

print(f'acc: {(y_hat == y).float().sum() / y_hat.numel()}')
print(f'y[{idx}]: {y[idx]}')
print(f'y_hat[{idx}]: {y_hat[idx]}')

# Since y is x shifted by one step, x[idx+1] should equal y[idx].
print(f'x[{idx+1}]: {x[idx+1]}')

In [None]:
# Run the prior's `_step` on a single training batch and print the result.
# The cell used to draw a random dataset item into `x` first; that value was
# overwritten immediately, and its `randint(0, len(prior_dataset))` could index
# one past the end. The `[:, ...]` indexing on the batch was a no-op.
with torch.no_grad():
  x = next(iter(prior_train_loader))
  print(x.shape)
  print(prior._step(x))
In [None]:
# Manual autoregressive generation with the prior: start from one dataset
# sequence and append `num_steps` sampled steps, one at a time.
import torch

# Number of steps to generate
num_steps = 10

# NOTE(review): the model is put in train() mode while `eval_mode` is set to
# True — looks unusual for generation; confirm against the Prior implementation.
prior = prior.train().to('cuda')
prior.eval_mode = True

# Initialize the sequence with the given random codes
# sequence = torch.zeros(1, sequence_length, latent_size, device='cuda')

# Seed sequence: dataset item 101 with a leading batch dimension.
sequence = prior_dataset[101].unsqueeze(0).to('cuda')
# NOTE(review): `norm` is not used below.
norm = prior.normalize(sequence)

# Generate latent codes autoregressively
for i in range(num_steps):
  # Predict the next latent code
  with torch.no_grad():
    # Only the most recent `sequence_length` steps are fed to the model.
    logits = prior(sequence[:, -sequence_length:, :])
    next_code = prior.sample(logits=logits, temperature=0.1)

  # Append the last predicted step to the sequence (the sequence grows by one
  # step per iteration)
  sequence = torch.cat((sequence, next_code[:, -1:, :]), dim=1)

# Split the last dimension into two halves.
# NOTE(review): the names suggest mean / log-variance halves; confirm that the
# latent layout actually matches.
mu, logvar = sequence.chunk(2, dim=-1)

In [None]:
# Listen to a dataset item and to the output of the exported model's
# `pretrained` module on a dataset item.
# NOTE(review): the two lines use different items (100 vs. 2), so this is not a
# like-for-like comparison — confirm whether the same index was intended.
with torch.no_grad():
  display(Audio(dataset[100].cpu().numpy(), rate=fs))
  display(Audio(synth.pretrained(dataset[2].unsqueeze(0).cpu()).squeeze().numpy(), rate=fs))