In [2]:
import torch
import fvcore

from DDSP.models.ddsp_decoder import DDSP_Decoder
from DDSP.models.synths.hpn_synth import *
from DDSP.models.decoder.decoders import *
from DDSP.utils import build_model
import json

In [49]:
# Which DDSP variant to analyse; selects configs/<model>/<model>_<version>.json
ddsp_model = "hpn"
version = "full"
config_path = "configs/{0}/{0}_{1}.json".format(ddsp_model, version)

In [50]:
# Read the JSON experiment config and split out its decoder / synth sections
with open(config_path, 'r') as config_file:
    config = json.load(config_file)

decoder_config, synth_config = config["decoder"], config["synth"]

In [52]:
# Instantiate the decoder + synth model from the config so its embeddability
# (latency, memory, FLOPs) can be analysed, then switch it to eval mode.
model = build_model(
    decoder_config=decoder_config,
    synth_config=synth_config,
    ddsp_mode=ddsp_model,
)
model.eval()

In [None]:
# Inference time for a single BUFFER-sample block (single thread, 100 runs)
import torch.utils.benchmark as benchmark


def inference(model, x):
    """Run one forward pass of ``model`` on ``x`` without recording autograd history."""
    with torch.no_grad():
        model(x)


BUFFER = 512
# Frames per buffer. 64 looks like the model's hop size (samples per frame);
# NOTE(review): confirm against the config.
n_frames = BUFFER // 64
# BUFFER / 64 is a frame count, not a duration in ms (it would only equal ms at
# a 64 kHz sample rate), so label the benchmark row with frames.
sub_label = f"{n_frames} frames"

# Random dummy inputs; dict literal keeps the original key and creation order.
x = {
    'audio': torch.randn(1, BUFFER, 1),
    'f0': torch.randn(1, n_frames, 1),
    'f0_scaled': torch.randn(1, n_frames, 1),
    'loudness_scaled': torch.randn(1, n_frames, 1),
}

t0 = benchmark.Timer(
    stmt='inference(model,x)',
    globals={'x': x, 'model': model, 'inference': inference},
    num_threads=1,
    label=f"TEST with BUFFER={BUFFER}",
    sub_label=sub_label,
    description=f'{ddsp_model}-{version}',
)
print(t0.timeit(100))


In [None]:
# Memory occupied by the model's parameters (buffers are not included)
def format_num_bytes(num_bytes):
    """Format a byte count in KB or MB (1024-based units).

    MB is used once the size reaches one full MB (1024 ** 2 bytes); anything
    smaller is shown in KB, so sub-megabyte models don't print as "0.10 MB".

    Args:
        num_bytes: Size in bytes.

    Returns:
        A string such as "12.34 MB" or "56.78 KB".
    """
    if num_bytes >= 1024 ** 2:
        return f"{num_bytes / 1024 ** 2:.2f} MB"
    return f"{num_bytes / 1024:.2f} KB"


param_memory = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Model Parameters Memory: {format_num_bytes(param_memory)}")

In [None]:
# Profile a single forward pass on a 16000-sample input with the PyTorch profiler
from torch.profiler import profile, record_function, ProfilerActivity

BUFFER = 16000
frames = int(BUFFER / 64)

# (key, time-axis length) for every dummy input tensor, in creation order
x = {
    name: torch.randn(1, length, 1)
    for name, length in (
        ('audio', BUFFER),
        ('f0', frames),
        ('f0_hz', frames),
        ('f0_scaled', frames),
        ('loudness_scaled', frames),
    )
}

profiler_options = dict(
    activities=[ProfilerActivity.CPU],
    record_shapes=True,
    profile_memory=True,
    with_flops=True,
    with_modules=True,
)
with profile(**profiler_options) as prof_large:
    with record_function("model_inference"), torch.no_grad():
        model(x)

In [None]:
# Top 10 operators by total CPU time, under a "<MODEL>-<Version>" heading
heading = "-".join((ddsp_model.upper(), version.capitalize()))
print(heading)
print(prof_large.key_averages().table(sort_by="cpu_time_total", row_limit=10))

In [None]:
# FLOPs analysis with fvcore
from fvcore.nn import FlopCountAnalysis, parameter_count_table, flop_count_str, flop_count_table

In [None]:
# Show the module tree of the full model
print(str(model))

In [None]:
# Trace only the decoder sub-module with fvcore; `x` is the dummy input dict
# built in the profiling cell.
decoder = model.decoder
flops = FlopCountAnalysis(decoder, x)

In [None]:
# FLOP totals keyed by module name, shown as the cell output
flops_per_module = flops.by_module()
flops_per_module

In [None]:
# fvcore does not count FLOPs for GRU and ReLU/LeakyReLU, so estimate them by hand
def gru_inference_flops(input_size, hidden_size, sequence_length, num_layers, batch_size):
    """Estimate (and print) the FLOPs of a unidirectional GRU forward pass.

    Counting convention: a multiply-accumulate is 2 FLOPs; each element-wise
    activation, gating product or sum is 1 FLOP; bias additions are ignored.

    Only the first layer consumes ``input_size`` features. In a stacked GRU,
    every later layer consumes the previous layer's hidden state, so its input
    weight matrices are ``hidden_size x hidden_size``.

    Args:
        input_size: Number of input features of the first layer.
        hidden_size: Number of hidden units in every layer.
        sequence_length: Number of timesteps processed.
        num_layers: Number of stacked layers; values below 1 give 0 FLOPs.
        batch_size: Number of sequences processed.

    Returns:
        The total FLOP estimate (the same value is printed).
    """

    def _flops_per_timestep(layer_input_size):
        # Cost of one layer for one timestep, given that layer's input width.
        # W * x_t and U * h_{t-1} for one gate: two matrix-vector MACs.
        gate_matvecs = 2 * (layer_input_size * hidden_size + hidden_size * hidden_size)
        flops = 0
        flops += gate_matvecs + hidden_size  # update gate z_t + sigmoid
        flops += gate_matvecs + hidden_size  # reset gate r_t + sigmoid
        flops += gate_matvecs + hidden_size  # candidate matvecs + r_t * U_h * h_{t-1}
        flops += hidden_size  # tanh for h_tilde
        flops += 3 * hidden_size  # h_t = (1 - z_t) * h_{t-1} + z_t * h_tilde
        return flops

    if num_layers < 1:
        flops_all_layers = 0
    else:
        # Layer 0 reads the external input; layers 1..num_layers-1 read hidden states.
        flops_all_layers = (
            _flops_per_timestep(input_size)
            + (num_layers - 1) * _flops_per_timestep(hidden_size)
        )

    total_flops = flops_all_layers * sequence_length * batch_size
    print(total_flops)
    return total_flops

def relu_inference_flops(input_size, sequence_length, batch_size):
    """Estimate (and print) ReLU FLOPs as one operation per activated element.

    Args:
        input_size: Features per timestep passed through the activation.
        sequence_length: Number of timesteps.
        batch_size: Number of sequences.

    Returns:
        ``input_size * sequence_length * batch_size`` (the value is also printed).
    """
    element_count = input_size * sequence_length * batch_size
    print(element_count)
    return element_count

In [None]:
# Hand-counted FLOPs for a 1024 -> 512 single-layer GRU and a 512-wide ReLU.
# Use the integer frame count (BUFFER // 64), the same count the dummy input
# tensors were built with, instead of the float BUFFER / 64.
# NOTE(review): sizes presumably mirror the decoder config — confirm.
gru_inference_flops(1024, 512, BUFFER // 64, 1, 1)
relu_inference_flops(512, BUFFER // 64, 1)

In [None]:
# Hand-counted FLOPs for a 32 -> 16 single-layer GRU and a 16-wide ReLU,
# over the integer frame count (BUFFER // 64) used for the dummy inputs.
# NOTE(review): sizes presumably mirror the decoder config — confirm.
gru_inference_flops(32, 16, BUFFER // 64, 1, 1)
relu_inference_flops(16, BUFFER // 64, 1)

In [None]:
# Hand-counted FLOPs for 128- and 64-wide ReLU activations, over the integer
# frame count (BUFFER // 64) used for the dummy inputs.
relu_inference_flops(128, BUFFER // 64, 1)
relu_inference_flops(64, BUFFER // 64, 1)