In [13]:
from signal_processors import *
import torch.nn as nn
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def mlp(in_size, hidden_size, n_layers):
    channels = [in_size] + [hidden_size] * n_layers
    net = []
    for i in range(n_layers):
        net.append(nn.Linear(channels[i], channels[i + 1]))
        net.append(nn.LayerNorm(channels[i + 1]))
        net.append(nn.LeakyReLU())
    return nn.Sequential(*net)

def gru(n_input, hidden_size):
    return nn.GRU(n_input, hidden_size, batch_first=True)

class textsynth_DDSP(nn.Module):
    def __init__(self, hidden_size, N_filter_bank, deepness, compression, frame_size, sampling_rate):
        super().__init__()

        self.N_filter_bank = N_filter_bank
        self.seed = seed_maker(frame_size, sampling_rate, N_filter_bank)
        self.frame_size = frame_size
        self.param_per_env = int(frame_size / (2*N_filter_bank*compression))
        
        self.f_encoder = mlp(1, hidden_size, deepness)
        self.l_encoder = mlp(N_filter_bank, hidden_size, deepness)
        self.z_encoder = gru(2 * hidden_size, hidden_size)
    
        self.a_decoder_1 = mlp(3 * hidden_size, hidden_size, deepness)
        self.a_decoder_2 = nn.Linear(hidden_size, 16 * self.param_per_env)
        self.p_decoder_1 = mlp(3 * hidden_size, hidden_size, deepness)
        self.p_decoder_2 = nn.Linear(hidden_size, 16 * self.param_per_env)

    def encoder(self, spectral_centroid, loudness):
        f = self.f_encoder(spectral_centroid)
        # print("f shape: ",f.shape)
        l = self.l_encoder(loudness)
        # print("l shape: ",l.shape)
        z, _ = self.z_encoder(torch.cat([f,l], dim=-1).unsqueeze(0))
        # print("z_1 shape: ",z.shape)
        z = z.squeeze(0)
        # print("z_2 shape: ",z.shape)
        return torch.cat([f,l,z], dim=-1)

    def decoder(self, latent_vector):
        a = self.a_decoder_1(latent_vector)
        a = self.a_decoder_2(a)
        a = torch.sigmoid(a)
        p = self.p_decoder_1(latent_vector)
        p = self.p_decoder_2(p)
        p = 2*torch.pi*torch.sigmoid(p)
        real_param = a * torch.cos(p)
        imag_param = a * torch.sin(p)
        return real_param, imag_param

    def forward(self, spectral_centroid, loudness):
        # print("sp.centroid shape: ",spectral_centroid.shape)
        # print("loudness.shape:    ",loudness.shape)
        #encoder
        latent_vector = self.encoder(spectral_centroid, loudness)
        # print("latent_vector.shape: ",latent_vector.shape)
        #decoder
        real_param, imag_param = self.decoder(latent_vector)
        # print("real_param.shape: ",real_param.shape)
        # print("imag_param.shape: ",imag_param.shape)

        signal = textsynth_env_batches(real_param, imag_param, self.seed, self.N_filter_bank, self.frame_size)
        return signal, self.seed

# Initialize model and move it to the appropriate device
hidden_size = 128  # Example hidden size
N_filter_bank = 16  # Example filter bank size
frame_size = 2**15  # Example frame size
sampling_rate = 44100  # Example sampling rate
compression = 8  # Placeholder for compression

model = textsynth_DDSP(hidden_size=128, N_filter_bank=16, deepness=2, compression=8, frame_size=2**15, sampling_rate=44100).to(device)

In [14]:
# # Generate dummy input data
# batch_size = 7 # Example batch size
# spectral_stats = torch.randn(batch_size, 1).to(device)  # Dummy spectral stats with shape (batch_size, 2)
# loudness_stats = torch.randn(batch_size, N_filter_bank).to(device)  # Dummy loudness stats with shape (batch_size, 2 * N_filter_bank)

# # Forward pass
# output_signal, seed = model(spectral_stats, loudness_stats)

# for i in range(7):
#     plotter(output_signal[i,:], 44100)

In [23]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import librosa
import torchaudio

def feature_extractor(signal, sample_rate):
    size = signal.shape[0]
    sp_centroid = torchaudio.functional.spectral_centroid(signal, sample_rate, 0, torch.hamming_window(size), size, size, size) 

    low_lim = 20  # Low limit of filter
    high_lim = sample_rate / 2  # Centre freq. of highest filter

     # Initialize filter bank
    erb_bank = fb.EqualRectangularBandwidth(size, sample_rate, N_filter_bank, low_lim, high_lim)
    
    # Generate subbands for noise
    erb_bank.generate_subbands(signal)
    
    # Extract subbands
    erb_subbands_signal = erb_bank.subbands[:, 1:-1]

    loudness = torch.norm(erb_subbands_signal, dim=0)
    return [sp_centroid[0], loudness]

class SoundDataset(Dataset):
    def __init__(self, audio_path, frame_size, hop_size, sampling_rate):
        self.audio_path = audio_path
        self.frame_size = frame_size
        self.hop_size   = hop_size
        self.sampling_rate = sampling_rate
        self.audio, _ = librosa.load(audio_path, sr=sampling_rate)
        self.content = []

    def compute_dataset(self):
        audio_tensor = torch.tensor(self.audio)
        size = audio_tensor.shape[0]
        dataset_size = (size - self.frame_size) // self.hop_size
        for i in range(dataset_size):
            segment = audio_tensor[i * self.hop_size: i * self.hop_size + self.frame_size]
            features = feature_extractor(segment, self.sampling_rate)
            self.content.append([features, segment])
        print(dataset_size)

dataset = SoundDataset(audio_path='noises/fire_long.wav', frame_size=2**15, hop_size=2**10, sampling_rate=44100)
dataset.compute_dataset()
actual_dataset = dataset.content

dataloader = DataLoader(actual_dataset, batch_size=32, shuffle=True)

3453


In [24]:
import torch

def multiscale_fft(signal, scales, overlap):
    stfts = []
    for s in scales:
        S = torch.stft(
            signal,
            s,
            int(s * (1 - overlap)),
            s,
            torch.hann_window(s).to(signal),
            True,
            normalized=True,
            return_complex=True,
        ).abs()
        stfts.append(S)
    return stfts

def multispectrogram_loss(original_signal, reconstructed_signal, scales, overlap):
    ori_stft = multiscale_fft(original_signal, scales, overlap)
    rec_stft = multiscale_fft(reconstructed_signal, scales, overlap)

    loss = 0
    for s_x, s_y in zip(ori_stft, rec_stft):
        lin_loss = (s_x - s_y).abs().mean()
        log_loss = (torch.log(s_x + 1e-8) - torch.log(s_y + 1e-8)).abs().mean()
        loss += lin_loss + log_loss

    return loss

In [25]:
import torch
import torch.optim as optim
from tqdm import tqdm

# Initialize the model, optimizer, and loss function
model = textsynth_DDSP(hidden_size=128, N_filter_bank=16, deepness=2, compression=8, frame_size=2**15, sampling_rate=44100).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-2)

# Hyperparameters for multiscale FFT
scales = [2048, 1024, 512, 256]  # Example scales
overlap = 0.5  # Example overlap


In [26]:
# Training loop
num_epochs = 100  # Define the number of epochs
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Unpack batch data
        features, segments = batch
        spectral_centroid = features[0].unsqueeze(1).to(device)
        loudness = features[1].to(device)
        segments = segments.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        reconstructed_signal, _ = model(spectral_centroid, loudness)

        # Compute loss
        loss = multispectrogram_loss(segments, reconstructed_signal, scales, overlap)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the loss
        running_loss += loss.item()

    # Print average loss for the epoch
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

print("Training complete.")

Epoch 1/100: 100%|██████████| 108/108 [00:40<00:00,  2.70it/s]


Epoch [1/100], Loss: 5.4969


Epoch 2/100: 100%|██████████| 108/108 [00:38<00:00,  2.80it/s]


Epoch [2/100], Loss: 5.3219


Epoch 3/100: 100%|██████████| 108/108 [00:38<00:00,  2.77it/s]


Epoch [3/100], Loss: 5.3253


Epoch 4/100: 100%|██████████| 108/108 [00:41<00:00,  2.58it/s]


Epoch [4/100], Loss: 5.2686


Epoch 5/100: 100%|██████████| 108/108 [00:39<00:00,  2.72it/s]


Epoch [5/100], Loss: 5.3364


Epoch 6/100: 100%|██████████| 108/108 [00:38<00:00,  2.77it/s]


Epoch [6/100], Loss: 5.3082


Epoch 7/100: 100%|██████████| 108/108 [00:38<00:00,  2.80it/s]


Epoch [7/100], Loss: 5.3074


Epoch 8/100: 100%|██████████| 108/108 [00:38<00:00,  2.79it/s]


Epoch [8/100], Loss: 5.2878


Epoch 9/100: 100%|██████████| 108/108 [00:38<00:00,  2.79it/s]


Epoch [9/100], Loss: 5.2550


Epoch 10/100: 100%|██████████| 108/108 [00:38<00:00,  2.78it/s]


Epoch [10/100], Loss: 5.2495


Epoch 11/100: 100%|██████████| 108/108 [00:38<00:00,  2.80it/s]


Epoch [11/100], Loss: 5.2599


Epoch 12/100: 100%|██████████| 108/108 [00:38<00:00,  2.80it/s]


Epoch [12/100], Loss: 5.2466


Epoch 13/100: 100%|██████████| 108/108 [00:38<00:00,  2.80it/s]


Epoch [13/100], Loss: 5.2294


Epoch 14/100: 100%|██████████| 108/108 [00:40<00:00,  2.69it/s]


Epoch [14/100], Loss: 5.2171


Epoch 15/100: 100%|██████████| 108/108 [14:01<00:00,  7.79s/it]  


Epoch [15/100], Loss: 5.2337


Epoch 16/100: 100%|██████████| 108/108 [00:40<00:00,  2.67it/s]


Epoch [16/100], Loss: 5.2687


Epoch 17/100: 100%|██████████| 108/108 [00:40<00:00,  2.67it/s]


Epoch [17/100], Loss: 5.2429


Epoch 18/100: 100%|██████████| 108/108 [00:40<00:00,  2.67it/s]


Epoch [18/100], Loss: 5.2195


Epoch 19/100: 100%|██████████| 108/108 [00:40<00:00,  2.65it/s]


Epoch [19/100], Loss: 5.2494


Epoch 20/100: 100%|██████████| 108/108 [01:26<00:00,  1.25it/s]


Epoch [20/100], Loss: 5.2412


Epoch 21/100: 100%|██████████| 108/108 [01:01<00:00,  1.77it/s]


Epoch [21/100], Loss: 5.2586


Epoch 22/100: 100%|██████████| 108/108 [01:01<00:00,  1.77it/s]


Epoch [22/100], Loss: 5.2111


Epoch 23/100: 100%|██████████| 108/108 [00:57<00:00,  1.88it/s]


Epoch [23/100], Loss: 5.2346


Epoch 24/100: 100%|██████████| 108/108 [03:32<00:00,  1.97s/it]


Epoch [24/100], Loss: 5.2414


Epoch 25/100: 100%|██████████| 108/108 [01:00<00:00,  1.78it/s]


Epoch [25/100], Loss: 5.1905


Epoch 26/100: 100%|██████████| 108/108 [01:00<00:00,  1.77it/s]


Epoch [26/100], Loss: 5.1737


Epoch 27/100: 100%|██████████| 108/108 [01:00<00:00,  1.79it/s]


Epoch [27/100], Loss: 5.2156


Epoch 28/100: 100%|██████████| 108/108 [01:00<00:00,  1.79it/s]


Epoch [28/100], Loss: 5.2334


Epoch 29/100: 100%|██████████| 108/108 [01:00<00:00,  1.78it/s]


Epoch [29/100], Loss: 5.1982


Epoch 30/100: 100%|██████████| 108/108 [01:00<00:00,  1.78it/s]


Epoch [30/100], Loss: 5.2005


Epoch 31/100: 100%|██████████| 108/108 [05:08<00:00,  2.85s/it] 


Epoch [31/100], Loss: 5.2131


Epoch 32/100: 100%|██████████| 108/108 [01:00<00:00,  1.79it/s]


Epoch [32/100], Loss: 5.1867


Epoch 33/100: 100%|██████████| 108/108 [01:00<00:00,  1.78it/s]


Epoch [33/100], Loss: 5.1896


Epoch 34/100: 100%|██████████| 108/108 [14:53:39<00:00, 496.47s/it]     


Epoch [34/100], Loss: 5.1694


Epoch 35/100: 100%|██████████| 108/108 [00:44<00:00,  2.45it/s]


Epoch [35/100], Loss: 5.1502


Epoch 36/100: 100%|██████████| 108/108 [00:45<00:00,  2.39it/s]


Epoch [36/100], Loss: 5.1464


Epoch 37/100: 100%|██████████| 108/108 [00:39<00:00,  2.73it/s]


Epoch [37/100], Loss: 5.1521


Epoch 38/100: 100%|██████████| 108/108 [00:39<00:00,  2.73it/s]


Epoch [38/100], Loss: 5.1703


Epoch 39/100: 100%|██████████| 108/108 [00:39<00:00,  2.72it/s]


Epoch [39/100], Loss: 5.1565


Epoch 40/100: 100%|██████████| 108/108 [00:42<00:00,  2.51it/s]


Epoch [40/100], Loss: 5.1456


Epoch 41/100: 100%|██████████| 108/108 [00:40<00:00,  2.68it/s]


Epoch [41/100], Loss: 5.2246


Epoch 42/100:  28%|██▊       | 30/108 [00:11<00:30,  2.57it/s]


KeyboardInterrupt: 