In [None]:
import sys
sys.path.append("..")

import random
import math
import itertools
from copy import deepcopy
from io import BytesIO
from pathlib import Path
from typing import Optional, Callable, List, Tuple, Iterable, Generator, Union, Dict

import PIL.Image
import PIL.ImageDraw
import plotly
import plotly.express as px
import plotly.graph_objects as go
plotly.io.templates.default = "plotly_dark"
import numpy as np
import pandas as pd
pd.options.plotting.backend = "plotly"
from sklearn.manifold import TSNE
from sklearn.decomposition import IncrementalPCA

from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, IterableDataset, RandomSampler
import torchvision.transforms as VT
import torchvision.transforms.functional as VF
import torchaudio.transforms as AT
import torchaudio.functional as AF
from torchvision.utils import make_grid
from IPython.display import display, Audio
import torchaudio
from torchaudio.io import StreamReader

from src.datasets import *
from src.algo import GreedyLibrary
from src.util.image import *
from src.util import to_torch_device, iter_batches
from src.patchdb import PatchDB, PatchDBIndex
from src.models.encoder import *
from src.util.audio import *
from src.util.files import *
from src.util.embedding import *
from scripts import datasets
from src.algo import AudioUnderstander 

def resize(img, scale: float, mode: VF.InterpolationMode = VF.InterpolationMode.NEAREST):
    """
    Rescale the trailing two dimensions of `img` by the factor `scale`.

    Each target side is truncated to an integer and never drops below one
    pixel. Antialiasing is always disabled.
    """
    target_size = []
    for side in img.shape[-2:]:
        target_size.append(max(1, int(side * scale)))
    return VF.resize(img, size=target_size, interpolation=mode, antialias=False)

In [None]:
def wave_sin(t: torch.Tensor) -> torch.Tensor:
    """
    Sine oscillator with a period of 1.

    `t` is a phase measured in cycles (not radians): integer values map to
    zero crossings and 0.25 maps to the positive peak.
    """
    angle = t * 2. * torch.pi
    return angle.sin()

def to_mono(t: torch.Tensor) -> torch.Tensor:
    """
    Mix a signal down to a single channel.

    :param t: a 1-D tensor (returned unchanged) or a 2-D tensor whose first
        dimension is averaged away.
    :return: a 1-D tensor.
    :raises ValueError: if `t` is neither 1-D nor 2-D.
    """
    if t.ndim == 1:
        return t
    if t.ndim == 2:
        return t.mean(0)
    # Both 1-D and 2-D inputs are valid, so the message names both.
    raise ValueError(f"Expected 1 or 2 dimensions, got {t.shape}")

In [None]:
class SynthBase:
    """
    Minimal base class for time-domain synthesizer voices.

    Subclasses override `process`, which maps a tensor of time stamps
    (seconds) to a tensor of sample values. The base implementation is a
    440 Hz sine.

    Set `multi_channel = True` on subclasses whose `process` accepts a 2-D
    time tensor directly; otherwise `__call__` feeds the rows of a 2-D input
    to `process` one at a time.
    """

    multi_channel: bool = False

    def __init__(self, rate: int = 44_100):
        # Sample rate in Hz; used by `render` and for playback in `play`.
        self.rate = rate

    def process(self, t: torch.Tensor) -> torch.Tensor:
        """Return the voice's samples for the time stamps `t` (in seconds)."""
        return wave_sin(t * 440.)

    def render(self, seconds: float = 1.) -> torch.Tensor:
        """
        Render `seconds` of audio at `self.rate`.

        Sample k is evaluated at time k / rate, so consecutive samples are
        exactly one sample period apart and the end point `seconds` itself is
        not included.

        :return: whatever `__call__` returns for the 1-D time grid of
            int(rate * seconds) samples.
        """
        num_samples = int(self.rate * seconds)
        t = torch.arange(num_samples, dtype=torch.get_default_dtype()) / self.rate
        return self(t)

    def __call__(self, t: torch.Tensor) -> torch.Tensor:
        """
        Evaluate the voice on a 1-D or 2-D time tensor.

        A 1-D input is passed straight to `process`. For a 2-D input the
        result is mixed down over the first dimension: multi-channel voices
        process the whole tensor at once, all others are evaluated row by row.

        :raises ValueError: if `t` is neither 1-D nor 2-D.
        """
        if t.ndim == 1:
            return self.process(t)
        if t.ndim == 2:
            if self.multi_channel:
                return to_mono(self.process(t))
            return torch.stack([self.process(row) for row in t]).mean(0)
        raise ValueError(
            f"Expected a 1-D or 2-D time tensor, got shape {tuple(t.shape)}"
        )

    def play(self, seconds: float = 1.):
        """
        Render `seconds` of audio and display it in the notebook: a waveform
        plot next to a magnitude spectrogram, followed by an auto-playing
        audio widget.
        """
        audio = self.render(seconds)
        # Image size (height, width) shared by both plots; the width is kept
        # at one pixel or more so very short renders can still be resized.
        audio_shape = (128, max(1, int(128 * seconds)))
        win_length = 512
        spec = AF.spectrogram(
            waveform=audio,
            pad=0,
            window=torch.hamming_window(win_length),
            n_fft=1024,
            hop_length=win_length // 2,
            win_length=win_length,
            power=1.,
            normalized=True,
        )
        # Keep the first half of the spectrogram rows, flip vertically and
        # scale to the plot size.
        spec = VF.resize(VF.vflip(spec[:spec.shape[0] // 2]).unsqueeze(0), audio_shape)
        display(VF.to_pil_image(make_grid([
            # plot_audio is not defined in this notebook; presumably it comes
            # from one of the project's star imports.
            plot_audio(audio, audio_shape, tensor=True),
            spec.expand(3, -1, -1),
        ])))
        display(Audio(audio, rate=self.rate, autoplay=True))
            
# Audition one second of the base voice at the default sample rate.
SynthBase().play(seconds=1.)

In [None]:
class DecayEnvelope(SynthBase):
    """
    Envelope that falls from 1 at t = 0 to 0 at t = `length`.

    The linear ramp `1 - t / length` is clamped to [0, 1] and then raised to
    `power`, which bends the shape of the fall.
    """

    # process() is purely element-wise, so 2-D time tensors are handled directly.
    multi_channel = True

    def __init__(
            self,
            length: float = 1.,
            power: float = 2.,
            rate: int = 44_100,
    ):
        super().__init__(rate=rate)
        self.length, self.power = length, power

    def process(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the envelope at the time stamps `t` (seconds)."""
        remaining = 1. - t / self.length
        return torch.pow(remaining.clamp(min=0, max=1), self.power)
        
# Audition two seconds of the default decay envelope.
DecayEnvelope().play(seconds=2)

In [None]:
class AttackEnvelope(SynthBase):
    """
    Envelope that rises from 0 at t = 0 to 1 at t = `length`.

    The linear ramp `t / length` is clamped to [0, 1] and then raised to
    `power`, which bends the shape of the rise.
    """

    # process() is purely element-wise, so 2-D time tensors are handled directly.
    multi_channel = True

    def __init__(
            self,
            length: float = 1.,
            power: float = 2.,
            rate: int = 44_100,
    ):
        super().__init__(rate=rate)
        self.length, self.power = length, power

    def process(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the envelope at the time stamps `t` (seconds)."""
        progress = t / self.length
        return torch.pow(progress.clamp(min=0, max=1), self.power)
    
# Audition two seconds of the default attack envelope.
AttackEnvelope().play(seconds=2)

In [None]:
class AttackDecayEnvelope(SynthBase):
    """
    Two-stage envelope: a rise over `attack_length` seconds followed by a
    fall over `decay_length` seconds that starts when the attack ends.

    Each stage is a linear ramp clamped to [0, 1] and raised to its own
    power (`attack_power`, `decay_power`).
    """

    # process() is purely element-wise, so 2-D time tensors are handled directly.
    multi_channel = True

    def __init__(
            self,
            attack_length: float = .1,
            decay_length: float = 1.,
            attack_power: float = .5,
            decay_power: float = 2.,
            rate: int = 44_100,
    ):
        super().__init__(rate=rate)
        self.attack_length, self.decay_length = attack_length, decay_length
        self.attack_power, self.decay_power = attack_power, decay_power

    def process(self, t: torch.Tensor) -> torch.Tensor:
        """Evaluate the envelope at the time stamps `t` (seconds)."""
        # Rising stage: reaches 1 at t = attack_length and stays there.
        rise = torch.pow(
            (t / self.attack_length).clamp(min=0, max=1),
            self.attack_power,
        )
        # Falling stage: stays at 1 until the attack is over, then drops to 0.
        fall_ramp = 1. - (t - self.attack_length) / self.decay_length
        fall = torch.pow(fall_ramp.clamp(min=0, max=1), self.decay_power)
        # Subtract how far the falling stage has dropped below 1.
        return rise - (1. - fall)
        
# Audition two seconds of the default attack/decay envelope.
AttackDecayEnvelope().play(seconds=2)

In [None]:
class SynthFreqRatio(SynthBase):
    """
    Additive voice: one oscillator per entry of `freq_ratios`, each running
    at `freq * ratio`, weighted by `amplitudes`, shaped by a fixed
    attack/decay envelope and averaged into a single signal.

    :param freq: base frequency that every ratio is multiplied with.
    :param freq_ratios: one frequency multiplier per oscillator.
    :param amplitudes: optional per-oscillator weights, same length as
        `freq_ratios`; defaults to all ones.
    :param wave: callable applied to the (n_ratios, samples) tensor of
        `t * freq * ratio` values.
    :param rate: sample rate in Hz.
    :raises ValueError: if `amplitudes` and `freq_ratios` differ in length.
    """

    def __init__(
            self,
            freq: float = 440.,
            freq_ratios: Iterable[float] = (1, 2, 3, 4, 5, 6),
            amplitudes: Optional[Iterable[float]] = None,
            wave: Callable = wave_sin,
            rate: int = 44_100,
    ):
        super().__init__(rate=rate)
        self.freq = freq
        self.freq_ratios = torch.Tensor(list(freq_ratios))
        if amplitudes is not None:
            self.amplitudes = torch.Tensor(list(amplitudes))
            if self.amplitudes.shape != self.freq_ratios.shape:
                raise ValueError(
                    f"Expected {self.freq_ratios.shape[0]} amplitudes"
                    f" (one per frequency ratio), got {self.amplitudes.shape[0]}"
                )
        else:
            self.amplitudes = torch.ones_like(self.freq_ratios)
        self.wave = wave

    def process(self, t: torch.Tensor) -> torch.Tensor:
        """
        Render the voice for the 1-D time stamps `t` (seconds).

        :raises ValueError: if non-default amplitudes are set but `wave` does
            not return one row per frequency ratio.
        """
        freqs = self.freq_ratios * self.freq
        # (n_ratios, samples): every oscillator's own scaled time axis.
        phases = t.unsqueeze(0) * freqs.unsqueeze(1)

        oscs = self.wave(phases)
        if oscs.ndim == 2 and oscs.shape[0] == self.amplitudes.shape[0]:
            # All-ones (default) amplitudes leave the signal unchanged.
            oscs = oscs * self.amplitudes.unsqueeze(1)
        elif not bool(torch.all(self.amplitudes == 1)):
            # A wave that does not keep one row per ratio (e.g. a synth that
            # mixes its rows down) leaves nothing to apply the weights to.
            raise ValueError(
                f"Cannot apply {self.amplitudes.shape[0]} amplitudes to wave"
                f" output of shape {tuple(oscs.shape)}"
            )

        # Fixed envelope: 0.01 s attack, 2 s decay.
        oscs = oscs * AttackDecayEnvelope(.01, 2)(t)
        return oscs.mean(0)

# Audition two seconds of a voice whose partials follow Fibonacci ratios.
# (The class defined in this cell is SynthFreqRatio; no `SynthMulti` is
# defined in this notebook.)
SynthFreqRatio(freq_ratios=(1, 2, 3, 5, 8, 13, 21)).play(2)

In [None]:
class SynthKali(SynthBase):
    """
    Oscillator whose phase is modulated by an iterated fold of the form
    `f = abs(f) / dot(f, f) - param` (Kali-set style) over three components.

    NOTE: `process` returns one row per component, i.e. shape (3, samples);
    the rows are not mixed down. With the default `wave=torch.sin` the
    argument is taken in radians, so `freq` acts as an angular rate there
    rather than as a frequency in Hz.
    """

    def __init__(
            self,
            freq: float = 440.,
            wave: Callable = torch.sin,
            rate: int = 44_100,
    ):
        super().__init__(rate=rate)
        self.freq = freq
        # Offsets subtracted from the three components during the fold.
        self.param = torch.Tensor([.5, .6, .7])
        self.wave = wave

    def process(self, t: torch.Tensor) -> torch.Tensor:
        """Return the (3, samples) output for the 1-D time stamps `t`."""
        t3 = t.unsqueeze(0).repeat(3, 1)
        shift = self.param.unsqueeze(1)

        # Start value of the fold: a slow ramp offset away from zero.
        f = t3 / 10. + .1
        for step in range(21):
            squared_norm = torch.sum(f * f, dim=0, keepdim=True)
            f = torch.abs(f) / squared_norm
            # Only the first nine rounds subtract the offsets.
            if step <= 8:
                f = f - shift

        # The running sum of the fold along time becomes the phase modulation.
        modulation = torch.cumsum(f, dim=1) * .3
        return self.wave(t3 * self.freq + modulation)

# process() yields one row per component (3, samples) rather than a single
# signal -- presumably why play() is disabled here. Inspect five samples instead.
#SynthKali().play(2)
SynthKali().process(torch.linspace(0, 1, steps=5))

In [None]:
class StrangeSynth(SynthBase):
    """
    Stack of `n_voices` sine oscillators around 440 Hz.

    Every voice runs on its own scaled copy of the time axis, and all voices
    share a phase offset that fades out quadratically during the first
    second. The voices are averaged into one signal.
    """

    def __init__(
            self,
            n_voices: int = 10,
            rate: int = 44_100,
    ):
        super().__init__(rate=rate)
        self.n_voices = n_voices

    def process(self, t: torch.Tensor) -> torch.Tensor:
        """Render the voice stack for the 1-D time stamps `t` (seconds)."""
        # Per-voice time scale, from 2 (first voice) down to 1 (last voice).
        voice_pos = torch.linspace(0, 1, self.n_voices).unsqueeze(1)
        time_scale = 2. - voice_pos.pow(1/12.)
        voice_time = t.unsqueeze(0).repeat(self.n_voices, 1) * time_scale

        # Shared offset: (1 - t)^2 while t < 1, zero afterwards.
        sweep = (1. - t).clamp_min(0).pow(2)
        voice_time = voice_time + sweep.unsqueeze(0)

        voices = torch.sin(voice_time * 2 * torch.pi * 440.)
        return voices.mean(0)

# Audition two seconds of the default ten-voice stack.
StrangeSynth().play(seconds=2)

In [None]:
# Use a SynthKali voice as the oscillator of a SynthFreqRatio patch with
# low frequency ratios, and audition two seconds of it.
SynthFreqRatio(
    freq_ratios=(.1, .2, .3, .4),
    wave=SynthKali(),
).play(seconds=2)