In [8]:
import numpy as np
from IPython.display import Audio
import torch.nn as nn
import torch
from collections import namedtuple

In [9]:
class Music:
    """Shared timing and sampling settings for the beats in this notebook.

    The class attributes are defaults; each string annotation records the
    value's unit.
    """

    samplerate: "samples / s" = 4000
    beat_duration: "s" = 1/4  # length of one rendered beat, in seconds
    bpm: "beats / m" = 120
    rythm: "beats / measure" = 4  # (sic) spelling kept so attribute access does not break

    @property
    def bps(self) -> "beats / s":
        """Tempo in beats per second."""
        seconds_per_min: "s / m" = 60
        # A single division instead of bpm * (1/60): one rounding step, not two.
        return self.bpm / seconds_per_min

    @property
    def spb(self) -> "s / beat":
        """Seconds per beat (the reciprocal of ``bps``)."""
        return 1 / self.bps

    @property
    def samples(self) -> 'samples':
        """Number of samples in one beat, ``beat_duration * samplerate``.

        Rounded rather than truncated with ``int()``: a float product such as
        ``0.29 * 100 == 28.999999999999996`` would otherwise lose a sample.
        """
        return int(round(self.beat_duration * self.samplerate))

    @property
    def beat_slice(self) -> slice:
        """Slice selecting the first ``samples`` entries of a buffer (one beat)."""
        # Reuse `samples` so the slice length can never disagree with it.
        return np.s_[0:self.samples]

    def new_measure(self) -> np.ndarray:
        """Return a zero-filled float buffer sized for one measure of ``rythm`` beats."""
        samples_per_measure = int(round(self.samplerate * self.rythm * self.spb))
        return np.zeros(samples_per_measure)


music = Music()

In [12]:
# https://www.cuidevices.com/blog/understanding-audio-frequency-range-in-audio-design


class Beat:
    """A single pure-tone beat: one sine wave of fixed frequency and amplitude.

    Args:
        c: integer code of the beat; BeatEmbed writes the waveform into row ``c``
            of its embedding table.
        freq: tone frequency in Hz.
        amp: peak amplitude of the sine.
        music: shared settings providing ``samplerate``, ``samples`` and
            ``beat_duration``.
    """

    def __init__(self, c: int, freq: float, amp: float, music: "Music"):
        self.music = music
        self.c = c
        self.freq = freq
        self.amp = amp

    @property
    def duration(self) -> 's':
        """Length of the beat in seconds, read from the shared settings."""
        return self.music.beat_duration

    def play_beat(self) -> np.ndarray:
        """Render the beat as a float32 waveform of ``music.samples`` samples.

        Sample ``k`` is taken at ``t = k / samplerate``. The previous
        ``np.linspace(0, duration, n)`` used the default ``endpoint=True``, which
        spaced samples ``duration / (n - 1)`` apart. That detuned the tone and
        repeated the boundary sample when consecutive beats were concatenated.
        """
        n = self.music.samples
        t = np.arange(n, dtype=np.float32) / np.float32(self.music.samplerate)
        note = self.amp * np.sin(2 * np.pi * self.freq * t, dtype=np.float32)
        return note
    

class BeatEmbed:
    """Lookup table mapping a beat code ``c`` to that beat's rendered waveform.

    The table has one row per beat and ``music.samples`` columns. Row ``beat.c``
    holds ``beat.play_beat()`` written into the ``music.beat_slice`` columns.

    Raises:
        ValueError: if a beat code is outside ``[0, len(beats))`` or repeated.
    """

    def __init__(self, beats, music: "Music"):
        self.vocab_size = len(beats)
        self.music = music
        emb_table = np.zeros((self.vocab_size, self.emb_dim), dtype=np.float32)
        seen = set()
        for beat in beats:
            # Without these checks a negative code would silently wrap to a row
            # from the end, and a repeated code would overwrite an earlier beat.
            if not 0 <= beat.c < self.vocab_size:
                raise ValueError(
                    f"beat code {beat.c} outside [0, {self.vocab_size})")
            if beat.c in seen:
                raise ValueError(f"duplicate beat code {beat.c}")
            seen.add(beat.c)
            emb_table[beat.c, self.music.beat_slice] = beat.play_beat()

        # float32 table wrapped without the legacy torch.Tensor(ndarray) constructor.
        self.emb_table = torch.from_numpy(emb_table)

    @property
    def emb_dim(self):
        """Width of each embedding row: the number of samples in one beat."""
        return self.music.samples

    def __call__(self, c):
        """Return the embedding row(s) of ``emb_table`` selected by ``c``."""
        return self.emb_table[c]


# One Beat per frequency band. Each `c` is that beat's row in the embedding
# table, so the codes run 0..4 in the same order as the `beats` list below.
quiet = Beat(music=music, c=0, freq=10, amp=0)  # amp=0 -> silence
sub = Beat(music=music, c=1, freq=30, amp=1)
bass = Beat(music=music, c=2, freq=60, amp=1)
mid = Beat(music=music, c=3, freq=250, amp=.2)
high = Beat(music=music, c=4, freq=2000, amp=.005)

beats = [quiet, sub, bass, mid, high]
beat_embed = BeatEmbed(beats=beats, music=music)

In [14]:
# Render the `sub` beat's waveform for playback in the next cell.
sound = Beat.play_beat(sub)

In [19]:
# NOTE(review): IPython's Audio normalizes by default, so a beat's `amp` does
# not affect loudness here.
Audio(data=sound, rate=music.samplerate)

In [16]:
class Attention(nn.Module):
    """Single-head self-attention with a bilinear query-key form.

    Scores are ``X @ W_qk @ X^T`` multiplied by ``dim ** -0.5`` and
    softmax-normalised over the last axis. The normalised weights mix the
    values ``X @ W_v``, and the result goes through the output projection
    ``proj``.

    Args:
        dim: feature size (last axis) of the input.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        # Buffers rather than plain attributes, so these follow .to(device) and
        # appear in state_dict. Like the original plain tensors they are not
        # trained. TODO(review): decide whether W_qk should be an nn.Parameter.
        self.register_buffer("W_qk", torch.randn(dim, dim))
        # Identity value map for now; placeholder for nn.Linear(dim, dim, bias=False).
        self.register_buffer("W_v", torch.eye(dim))
        self.proj = nn.Linear(dim, dim)
        self.scale = self.dim ** -0.5  # 1/sqrt(dim): multiply by it, don't divide

    def forward(self, X):
        """Attend over the sequence axis.

        Args:
            X: tensor of shape ``(..., S, dim)``, e.g. ``(B, S, E)``.

        Returns:
            Tensor with the same shape as ``X``.
        """
        scores = X @ self.W_qk @ X.transpose(-2, -1) * self.scale  # (..., S, S)
        attn = torch.softmax(scores, dim=-1)
        V = X @ self.W_v  # (..., S, dim)
        z = attn @ V  # mix values with the normalised weights, not the raw scores
        return self.proj(z)

In [17]:
# Attention layer over 512-dimensional features.
attn = Attention(dim=512)

In [18]:
# Bare last expression: Jupyter shows the module repr, which lists child modules (here only `proj`).
attn

Attention(
  (proj): Linear(in_features=512, out_features=512, bias=True)
)