# Writing an Oscillator in PyTorch

In [None]:
import torch
import IPython.display as ipd
import matplotlib.pyplot as plt

We create a sinusoidal oscillator as a PyTorch module by subclassing the
`torch.nn.Module` base class. Our oscillator receives a sequence of amplitude values and
a sequence of angular frequency values (in radians per sample), and optionally an initial phase.

The equation for this oscillator is as follows:

$$
y[n] = a[n] \sin(\phi[n])
$$

where $a[n]$ and $\phi[n]$ are the amplitude and phase at the $n^\text{th}$ sample. However,
our module receives a time-varying frequency as input rather than a phase, so we need to
convert that frequency into a time-varying phase. Recall the relationship between angular
frequency and phase:

$$
\omega = \frac{d\phi}{dt}
$$

Angular frequency is the time derivative of phase. Therefore, phase can be calculated by
integrating frequency, which in discrete time becomes a running sum:

$$
\phi[n] = \phi_0 + \sum_{k=0}^{n}\omega[k]
$$

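where $\omega[n]$ is the angular frequency at the $n^\text{th}$ sample, in radians per sample.
The phase $\phi[n]$ is the initial phase $\phi_0$ plus the sum of all frequency values up to
and including sample $n$, which is exactly what an inclusive cumulative sum (`torch.cumsum`) computes.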
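To check that a running sum implements this discrete integral, the snippet below compares `torch.cumsum` against an explicit loop over the equation. The values of $\omega$ and $\phi_0$ are arbitrary illustrative assumptions, not values from this notebook. It also confirms that a constant $\omega$ gives a phase that grows linearly, $\phi[n] = \phi_0 + (n+1)\omega$, which is what produces a steady pitch.

```python
import torch

# Illustrative values (assumptions): 5 samples of a time-varying angular
# frequency in radians per sample, and an initial phase in radians.
omega = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5]], dtype=torch.float64)
phi_0 = 0.25

# Vectorised running sum. torch.cumsum is inclusive, so element n already
# holds omega[0] + ... + omega[n], matching the k = 0..n limits.
phase_cumsum = torch.cumsum(omega, dim=1) + phi_0

# The same equation written as an explicit loop over samples.
phase_loop = torch.empty_like(omega)
running = phi_0
for n in range(omega.shape[1]):
    running += omega[0, n].item()
    phase_loop[0, n] = running

assert torch.allclose(phase_cumsum, phase_loop)

# Special case: a constant omega gives phi[n] = phi_0 + (n + 1) * omega,
# i.e. the phase grows linearly over time.
const = torch.full((1, 5), 0.3, dtype=torch.float64)
idx = torch.arange(5, dtype=torch.float64).unsqueeze(0)
assert torch.allclose(torch.cumsum(const, dim=1) + phi_0, phi_0 + (idx + 1) * 0.3)
```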

In [None]:
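from typing import Union

class Sinusoid(torch.nn.Module):
    """Sinusoidal oscillator: y[n] = a[n] * sin(phi_0 + sum_{k=0}^{n} omega[k])."""

    def forward(
            self,
            amp: torch.Tensor,                   # Amplitude (batch_size, n_steps)
            omega: torch.Tensor,                 # Angular frequency in radians per sample (batch_size, n_steps)
            initial_phase: Union[torch.Tensor, float] = 0.0,  # Initial phase (batch_size, 1) or a scalar
        ) -> torch.Tensor:
        # Integrate frequency into phase with an inclusive running sum over time
        phase = torch.cumsum(omega, dim=1) + initial_phase
        return amp * torch.sin(phase)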
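A minimal sketch of how the tensor shapes in the `forward` signature interact, assuming a batch of two signals: an `initial_phase` of shape `(batch_size, 1)` broadcasts across the time axis, so each batch item can start at a different phase. Offsetting the phase by $\pi/2$ should turn the sine into a cosine, which the assertion checks. The module is repeated so the snippet runs on its own; the batch size, length, and phase values are illustrative choices.

```python
import math
import torch

class Sinusoid(torch.nn.Module):
    # Repeated from the cell above so this snippet runs on its own.
    def forward(self, amp, omega, initial_phase=0.0):
        phase = torch.cumsum(omega, dim=1) + initial_phase
        return amp * torch.sin(phase)

sr = 16000
n_steps = 1000
# Two batch items sharing a constant 440 Hz frequency, in radians per sample.
omega = torch.full((2, n_steps), 2 * math.pi * 440 / sr, dtype=torch.float64)
amp = torch.ones(2, n_steps, dtype=torch.float64)
# Per-item initial phase with shape (batch_size, 1); it broadcasts over time.
initial_phase = torch.tensor([[0.0], [math.pi / 2]], dtype=torch.float64)

y = Sinusoid()(amp, omega, initial_phase)
print(y.shape)  # torch.Size([2, 1000])

# Shifting the phase by pi/2 turns the sine into a cosine of the same phase.
base_phase = torch.cumsum(omega[0], dim=0)
assert torch.allclose(y[1], torch.cos(base_phase), atol=1e-9)
```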

In [None]:
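sr = 16000                                    # Sample rate in Hz
freq = torch.ones(1, sr) * 440                # One second of a constant 440 Hz tone, shape (1, n_steps)
freq = freq * 2 * torch.pi / sr               # Convert Hz to angular frequency in radians per sample

amp = torch.linspace(1, 0, sr).unsqueeze(0)   # Linear fade-out from 1 to 0, shape (1, n_steps)

osc = Sinusoid()
y = osc(amp, freq)                            # Target signal, shape (1, n_steps)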
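The conversion `freq * 2 * torch.pi / sr` maps hertz to radians per sample, since one cycle is $2\pi$ radians and there are `sr` samples per second. The sketch below wraps that conversion in a helper, `hz_to_omega`, a hypothetical name introduced here for illustration, and checks one consequence: at 400 Hz and a 16 kHz sample rate, the waveform should repeat every 16000 / 400 = 40 samples.

```python
import math
import torch

def hz_to_omega(freq_hz: torch.Tensor, sr: int) -> torch.Tensor:
    """Hypothetical helper: convert Hz to angular frequency in radians per sample."""
    return 2 * math.pi * freq_hz / sr

sr = 16000
omega = hz_to_omega(torch.full((1, sr), 400.0, dtype=torch.float64), sr)
print(round(omega[0, 0].item(), 6))  # 0.15708, i.e. 2*pi*400/16000 = pi/20

# Unit-amplitude oscillator output (same maths as Sinusoid with amp = 1).
y = torch.sin(torch.cumsum(omega, dim=1))

# 40 samples advance the phase by 40 * pi/20 = 2*pi, so the signal repeats.
assert torch.allclose(y[:, 40:], y[:, :-40], atol=1e-8)
```

The same reasoning sets an upper limit: $\omega$ should stay below $\pi$ radians per sample (a frequency below `sr / 2`), otherwise the oscillator aliases to a lower pitch.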

In [None]:
ipd.Audio(y[0].numpy(), rate=sr)

In [None]:
learn_amp = torch.nn.Parameter(torch.rand_like(amp))  # Learnable amplitude, randomly initialised in [0, 1)
y_hat = osc(learn_amp, freq)                          # Output before training

In [None]:
ipd.Audio(y_hat[0].detach().numpy(), rate=sr)

In [None]:
optimizer = torch.optim.Adam([learn_amp], lr=0.001)  # Only the amplitude is optimised; frequency stays fixed

In [None]:
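log_loss = []
for i in range(1000):
    # Forward pass: synthesise audio with the current amplitude estimate
    y_hat = osc(learn_amp, freq)
    # L1 (mean absolute error) loss against the target signal
    loss = torch.mean(torch.abs(y_hat - y))
    log_loss.append(loss.item())

    # Backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()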
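Because the frequency is fixed, the output depends linearly on each amplitude sample, so the gradient of the L1 loss has a simple closed form: $\partial L / \partial a[n] = \operatorname{sign}(\hat{y}[n] - y[n]) \sin(\phi[n]) / N$. The sketch below checks autograd against that formula on a short illustrative signal with an assumed constant $\omega = 0.3$. It also shows why samples near zero crossings of the sine are weakly constrained: there $\sin(\phi[n]) \approx 0$, so the loss barely depends on $a[n]$.

```python
import torch

torch.manual_seed(0)
n_steps = 64
# Illustrative setup (assumption): constant angular frequency, linear fade-out target.
omega = torch.full((1, n_steps), 0.3)
phase = torch.cumsum(omega, dim=1)
target = torch.linspace(1, 0, n_steps).unsqueeze(0) * torch.sin(phase)

learn_amp = torch.nn.Parameter(torch.rand(1, n_steps))
y_hat = learn_amp * torch.sin(phase)
loss = torch.mean(torch.abs(y_hat - target))
loss.backward()

# Analytic gradient of the mean absolute error w.r.t. each amplitude sample.
expected = torch.sign(y_hat.detach() - target) * torch.sin(phase) / n_steps
assert torch.allclose(learn_amp.grad, expected)
```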

In [None]:
ipd.Audio(y_hat[0].detach().numpy(), rate=sr)

In [None]:
plt.plot(log_loss)