## Creating a simple synth

In this example we'll create a new synthesizer using
torchsynth SynthModules. Synths in torchsynth are created using
an approach inspired by modular synthesizers that involves interconnecting individual
modules. We'll create a simple single oscillator synth with an
attack-decay-sustain-release (ADSR) envelope controlling
the amplitude. More complicated architectures can be created
using the same ideas!

In [1]:
# Just some fiddly stuff to determine if you're in colab or jupyter.
# Also, if you don't have torchsynth installed, let's install it

try:
    import torchsynth
except ImportError:
    # Only a missing package should trigger an installation attempt; any other
    # error raised while importing torchsynth is a real problem and is left
    # to propagate instead of being swallowed.
    import subprocess
    import sys

    # Invoke pip through the interpreter that is running this notebook so the
    # package lands in the active environment. This stays best-effort: a
    # failed install does not raise here (check=False).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "torchsynth"],
        check=False,
    )

def iscolab():  # pragma: no cover
    """
    Report whether the active IPython shell is a Google Colab shell.

    The check looks for "google.colab" in the string form of the object
    returned by get_ipython(). This function only inspects the environment:
    it does not install anything.

    Returns:
        True if "google.colab" appears in str(get_ipython()), else False.

    Raises:
        NameError: if get_ipython is not defined, i.e. the code is not
            running inside an IPython session.
    """
    return "google.colab" in str(get_ipython())


def isnotebook():  # pragma: no cover
    """
    Report whether the code is running inside a notebook front end.

    Returns:
        True when iscolab() is true or when the IPython shell class is named
        "ZMQInteractiveShell" (Jupyter notebook or qtconsole). False for any
        other shell class, such as "TerminalInteractiveShell", and False when
        a NameError is raised during the checks (probably a standard Python
        interpreter, where get_ipython is not defined).
    """
    try:
        if iscolab():
            return True
        shell_name = get_ipython().__class__.__name__
    except NameError:
        # Probably the standard Python interpreter
        return False

    # Every shell other than the ZMQ one is treated as "not a notebook"
    return shell_name == "ZMQInteractiveShell"


# print(f"isnotebook = {isnotebook()}")

## Creating the SimpleSynth class

All synths in torchsynth derive from `AbstractSynth`, which
provides helpful functionality for managing child modules and their parameters.

There are two steps involved in creating a class that derives from `AbstractSynth`:
   1. Define the modules that will be used in the `__init__` method
   2. Define how modules are connected in the `output` method

In [2]:
# %reset -f

def iscolab():  # pragma: no cover
    """
    Report whether the active IPython shell is a Google Colab shell.

    Returns:
        True if "google.colab" appears in the string form of the object
        returned by get_ipython(), else False. get_ipython must be defined,
        otherwise a NameError is raised.
    """
    shell_description = str(get_ipython())
    return "google.colab" in shell_description


def isnotebook():  # pragma: no cover
    """
    Report whether the code is running inside a notebook front end.

    Returns:
        True when iscolab() is true or the IPython shell class is named
        "ZMQInteractiveShell" (Jupyter notebook or qtconsole); False for a
        terminal IPython shell, any other shell type, or when a NameError
        occurs (probably the standard Python interpreter).
    """
    # Shell class names that count as a notebook front end
    notebook_shells = ("ZMQInteractiveShell",)
    try:
        if iscolab():
            return True
        return get_ipython().__class__.__name__ in notebook_shells
    except NameError:
        # Probably the standard Python interpreter
        return False


# print(f"isnotebook = {isnotebook()}")

if not isnotebook():

    class IPD:
        """
        No-op stand-in for IPython.display used outside of notebooks.

        Both methods accept any arguments and return None, so calls such as
        ipd.Audio(...) and ipd.display(...) do nothing.
        """

        @staticmethod
        def Audio(*args, **kwargs):
            return None

        @staticmethod
        def display(*args, **kwargs):
            return None

    ipd = IPD()
else:  # pragma: no cover
    # Display and plotting helpers are only imported inside a notebook
    import IPython.display as ipd
    import librosa
    import librosa.display
    import matplotlib.pyplot as plt
    
import random
import time

import numpy as np
# import numpy.random
import torch
from torch import tensor


def time_plot(signal, sample_rate=44100, show=True):  # pragma: no cover
    """
    Plot a signal against a time axis; does nothing outside a notebook.

    Args:
        signal: sequence of samples to plot; its length sets the number of
            points on the time axis.
        sample_rate: samples per unit of time, used to scale the time axis
            to len(signal) / sample_rate.
        show: when True, call plt.show() after drawing.
    """
    if not isnotebook():
        return

    n_samples = len(signal)
    # endpoint=False keeps the spacing between points at 1 / sample_rate
    time_axis = np.linspace(0, n_samples / sample_rate, n_samples, endpoint=False)

    plt.plot(time_axis, signal)
    plt.xlabel("Time")
    plt.ylabel("Amplitude")
    if show:
        plt.show()

from typing import Optional

import torch
import IPython.display as ipd
from torchsynth.signal import Signal


from torchsynth.module import (
    ADSR,
    LFO,
    VCA,
    AudioMixer,
    ControlRateUpsample,
    ControlRateVCA,
    ModulationMixer,
    MonophonicKeyboard,
    Noise,
    SineVCO,
    SquareSawVCO,
    SynthModule,
)
from torchsynth.synth import AbstractSynth
from torchsynth.config import SynthConfig, BASE_REPRODUCIBLE_BATCH_SIZE
from torch import Tensor as T

class SimpleSynth(AbstractSynth):
    """
    A small demonstration synth assembled from torchsynth modules.

    The constructor registers a keyboard, two ADSR envelopes, a control-rate
    upsampler, a sine and a square-saw oscillator, two VCAs and a two-input
    audio mixer. As currently wired, `output` returns only the signal of the
    sine oscillator ("vco"); the envelopes are computed but not applied.

    Args:
        synthconfig: Synthesizer configuration that defines the
            batch_size, buffer_size, and sample_rate among other
            variables that control synthesizer functioning
    """

    def __init__(self, synthconfig: Optional[SynthConfig] = None):
        # The parent AbstractSynth has to be initialised before modules
        # can be registered.
        super().__init__(synthconfig=synthconfig)

        # Settings for the mixer that is meant to combine both oscillators
        mixer_settings = {
            "n_input": 2,
            "curves": [1.0, 1.0],
            "names": ["vco", "vco1"],
        }

        # Each entry is (name, SynthModule class, optional params dict).
        # Once added, every module is reachable as an attribute with the
        # given name.
        module_specs = [
            ("keyboard", MonophonicKeyboard),
            ("adsr", ADSR),
            ("adsr_1", ADSR),
            ("upsample", ControlRateUpsample),
            ("vco", SineVCO),
            ("vco1", SquareSawVCO),
            ("vca", VCA),
            ("vca1", VCA),
            ("mixer", AudioMixer, mixer_settings),
        ]
        self.add_synth_modules(module_specs)

        # (module, parameter, property) -> value, applied in this order
        hyperparameter_overrides = (
            (("keyboard", "duration", "curve"), 1.0),
            (("vco", "mod_depth", "curve"), 1.0),
            (("vco", "tuning", "curve"), 0),
            (("vco", "tuning", "symmetric"), True),
        )
        for target, value in hyperparameter_overrides:
            self.set_hyperparameter(target, value)

    def output(self) -> torch.Tensor:
        """
        Called when the synth is triggered: runs the individual modules and
        returns the resulting signal.

        Returns:
            The output of the sine oscillator ("vco") driven by
            self.midi_notes and self.midi_notes_length.
        """
        # The keyboard is a parameter module: it yields the midi_f0 note
        # value and the duration the note is held for. Only the duration is
        # used below.
        _midi_f0, held_duration = self.keyboard()

        # Amplitude envelopes are generated from the note duration
        amp_envelope = self.adsr(held_duration)
        amp_envelope_1 = self.adsr_1(held_duration)

        # ADSR output is at the control rate (by default 100x lower than the
        # sample rate, for performance reasons), so it has to be upsampled
        # before it could be combined with an oscillator output.
        amp_envelope = self.upsample(amp_envelope)
        amp_envelope_1 = self.upsample(amp_envelope_1)

        # NOTE: the upsampled envelopes are not applied at the moment; the
        # VCA / mixer stage and the second oscillator ("vco1") are disabled,
        # so the raw sine oscillator signal is returned.
        # NOTE(review): midi_notes and midi_notes_length are not assigned
        # anywhere in this class -- presumably they are set by the caller or
        # by a customised AbstractSynth.forward; confirm.
        return self.vco(self.midi_notes, self.midi_notes_length)

class Voice(AbstractSynth):
    """
    A synth voice with two oscillators and a noise source, modulated by
    ADSR envelopes and LFOs that are routed through a modulation mixer.

    Args:
        synthconfig: configuration passed on to AbstractSynth
        nebula: name handed to `load_hyperparameters` once all modules
            have been registered
        *args, **kwargs: forwarded unchanged to AbstractSynth.__init__
    """

    def __init__(
        self,
        synthconfig: Optional[SynthConfig] = None,
        nebula: Optional[str] = "default",
        *args,
        **kwargs,
    ):
        AbstractSynth.__init__(self, synthconfig=synthconfig, *args, **kwargs)

        # Four modulation sources are mixed into five modulation targets
        mod_matrix_settings = {
            "n_input": 4,
            "n_output": 5,
            "input_names": ["adsr_1", "adsr_2", "lfo_1", "lfo_2"],
            "output_names": [
                "vco_1_pitch",
                "vco_1_amp",
                "vco_2_pitch",
                "vco_2_amp",
                "noise_amp",
            ],
        }

        # Final audio mix of both oscillators and the noise source
        mixer_settings = {
            "n_input": 3,
            "curves": [1.0, 1.0, 0.025],
            "names": ["vco_1", "vco_2", "noise"],
        }

        # Register all modules as children, in this order
        module_specs = [
            ("keyboard", MonophonicKeyboard),
            ("adsr_1", ADSR),
            ("adsr_2", ADSR),
            ("lfo_1", LFO),
            ("lfo_2", LFO),
            ("lfo_1_amp_adsr", ADSR),
            ("lfo_2_amp_adsr", ADSR),
            ("lfo_1_rate_adsr", ADSR),
            ("lfo_2_rate_adsr", ADSR),
            ("control_vca", ControlRateVCA),
            ("control_upsample", ControlRateUpsample),
            ("mod_matrix", ModulationMixer, mod_matrix_settings),
            ("vco_1", SineVCO),
            ("vco_2", SquareSawVCO),
            ("noise", Noise, {"seed": 13}),
            ("vca", VCA),
            ("mixer", AudioMixer, mixer_settings),
        ]
        self.add_synth_modules(module_specs)

        # Load the nebula
        self.load_hyperparameters(nebula)

    def output(self) -> T:
        """
        Run the full modulation and audio chain and return the mixed signal.

        As a side effect, the upsampled pitch modulation of the first
        oscillator is handed to `time_plot` on every call.

        Returns:
            The mixer output combining both oscillators and the noise source.
        """
        # By convention a note event uses one note-on duration for all ADSRs
        pitch, held_duration = self.keyboard()

        # Envelopes that shape the LFO rates and amplitudes
        rate_env_1 = self.lfo_1_rate_adsr(held_duration)
        rate_env_2 = self.lfo_2_rate_adsr(held_duration)
        amp_env_1 = self.lfo_1_amp_adsr(held_duration)
        amp_env_2 = self.lfo_2_amp_adsr(held_duration)

        # LFOs, each scaled by its amplitude envelope
        lfo_1_raw = self.lfo_1(rate_env_1)
        lfo_1_signal = self.control_vca(lfo_1_raw, amp_env_1)
        lfo_2_raw = self.lfo_2(rate_env_2)
        lfo_2_signal = self.control_vca(lfo_2_raw, amp_env_2)

        # Envelopes for the oscillators and the noise source
        env_1 = self.adsr_1(held_duration)
        env_2 = self.adsr_2(held_duration)

        # Mix all modulation sources into the five modulation targets
        modulation = self.mod_matrix(env_1, env_2, lfo_1_signal, lfo_2_signal)
        (vco_1_pitch, vco_1_amp, vco_2_pitch, vco_2_amp, noise_amp) = modulation

        # First oscillator: pitch-modulated, then amplitude-shaped
        vco_1_pitch_mod = self.control_upsample(vco_1_pitch)
        vco_1_raw = self.vco_1(pitch, vco_1_pitch_mod)
        vco_1_amp_mod = self.control_upsample(vco_1_amp)
        vco_1_out = self.vca(vco_1_raw, vco_1_amp_mod)

        # Debug view of the first oscillator's pitch modulation; the signal
        # is upsampled again, detached and moved to the CPU for plotting.
        pitch_mod_for_plot = self.control_upsample(vco_1_pitch).clone().detach().cpu().T
        time_plot(pitch_mod_for_plot, self.sample_rate.item())

        # Second oscillator: same chain as the first
        vco_2_pitch_mod = self.control_upsample(vco_2_pitch)
        vco_2_raw = self.vco_2(pitch, vco_2_pitch_mod)
        vco_2_amp_mod = self.control_upsample(vco_2_amp)
        vco_2_out = self.vca(vco_2_raw, vco_2_amp_mod)

        # Noise source shaped by its own amplitude modulation
        noise_raw = self.noise()
        noise_amp_mod = self.control_upsample(noise_amp)
        noise_out = self.vca(noise_raw, noise_amp_mod)

        return self.mixer(vco_1_out, vco_2_out, noise_out)

# Pick the device once, before any tensor is created, so that the inputs,
# the synths and the timing code all agree. Falls back to the CPU when CUDA
# is not available instead of failing on the first tensor allocation.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Two batch items, each a sequence of eight MIDI note numbers, plus a matching
# tensor holding a length of 1 for every note.
# NOTE(review): the unit of notes_length is defined by the synth, not here --
# presumably seconds; confirm.
notes = torch.tensor(
    [
        [60, 62, 64, 65, 67, 69, 71, 72],
        [74, 76, 78, 79, 81, 83, 85, 86],
    ],
    device=device,
)
notes_length = torch.tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1],
    ],
    device=device,
)

# The batch size follows the number of note sequences
synthconfig = SynthConfig(
    batch_size=notes.shape[0],
    reproducible=False,
    buffer_size_seconds=3,
    sample_rate=44100,
)
customsynth = Voice(synthconfig).to(device)
simplesynth = SimpleSynth(synthconfig).to(device)

start = time.time()
audio, parameters, is_train = simplesynth(0, notes, notes_length)
if device == "cuda":
    # CUDA work is asynchronous; wait for it so the measured time covers the
    # actual synthesis. Skipped on the CPU, where there is nothing to wait for.
    torch.cuda.synchronize()
print("Synthesis taken ", time.time() - start)

print(f"Created {audio.shape[0]} synthesizer sounds ", f"that are each {audio.shape[1]} samples long")

# Play the first sound of the batch, attenuated to 10% and without
# normalisation. Kept as the last expression so the notebook displays it.
ipd.Audio(audio[0].cpu().numpy()*0.1, rate=int(simplesynth.sample_rate.item()), normalize=False)
# To listen to every sound in the batch instead:
# for i in range(synthconfig.batch_size):
#     ipd.display(ipd.Audio(audio[i].cpu().numpy(), rate=int(simplesynth.sample_rate.item())))

Synthesis taken  0.005130290985107422
Created 2 synthesizer sounds  that are each 352800 samples long
