In [1]:
# from signal_processors import *

# # Load audio file
# signal_path = "noises/fire.wav"
# y, fs = torchaudio.load(signal_path)
# y = y[0, :65536]  # Take the first channel and first 65536 samples
# plotter(y, fs)

# # Generate three seeds, one per resynthesis below
# seed_1 = seed_maker(65536, 44100, 16)
# seed_2 = seed_maker(65536, 44100, 16)
# seed_3 = seed_maker(65536, 44100, 16)

# # Extract parameters
# parameters_real, parameters_imag = textsynth_env_param_extractor(y, fs, 16, 1/8)

# # Resynthesize signal
# resynthesis_1 = textsynth_env(parameters_real, parameters_imag, seed_1, 16, 65536)
# resynthesis_2 = textsynth_env(parameters_real, parameters_imag, seed_2, 16, 65536)
# resynthesis_3 = textsynth_env(parameters_real, parameters_imag, seed_3, 16, 65536)

# # Plot resynthesized signal
# plotter(resynthesis_1, fs)
# plotter(resynthesis_2, fs)
# plotter(resynthesis_3, fs)

In [2]:
from models.model_1 import *
import torch
import torch.optim as optim
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------------------------------------------------------------------------
# Configuration
# Every tunable value is defined exactly once here and passed by name below,
# so the model and the dataset cannot silently disagree on shared settings
# (frame size, sampling rate, filter-bank size).
# ---------------------------------------------------------------------------
hidden_size = 128                     # DDSP_textenv `hidden_size`
N_filter_bank = 16                    # shared by the model and the dataset
frame_size = 2**15                    # shared by the model and the dataset
sampling_rate = 44100                 # shared by the model and the dataset
compression = 8                       # DDSP_textenv `compression`
deepness = 2                          # DDSP_textenv `deepness`
hop_size = 2**10                      # SoundDataset `hop_size`
batch_size = 32                       # DataLoader batch size
learning_rate = 1e-2                  # Adam learning rate
audio_path = 'sounds/fire_long.wav'   # training audio, relative to the notebook

# Model initialization: build the network and move its parameters to `device`.
model = DDSP_textenv(
    hidden_size=hidden_size,
    N_filter_bank=N_filter_bank,
    deepness=deepness,
    compression=compression,
    frame_size=frame_size,
    sampling_rate=sampling_rate,
).to(device)

# Dataset maker: `compute_dataset()` populates `dataset.content`, which is
# what the DataLoader actually iterates over.
dataset = SoundDataset(
    audio_path=audio_path,
    frame_size=frame_size,
    hop_size=hop_size,
    sampling_rate=sampling_rate,
    N_filter_bank=N_filter_bank,
)
dataset.compute_dataset()
actual_dataset = dataset.content
dataloader = DataLoader(actual_dataset, batch_size=batch_size, shuffle=True)

# Initialize the optimizer over all model parameters.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Hyperparameters for multiscale FFT (loss function); both are forwarded
# unchanged to `multispectrogram_loss` in the training cell.
scales = [2048, 1024, 512, 256]
overlap = 0.5

3453


In [4]:
# Training loop
#
# Runs `num_epochs` passes over `dataloader`, optimising `model` with the
# multiscale spectrogram loss, and writes a checkpoint after every epoch.
num_epochs = 3  # Define the number of epochs

# Destination of the per-epoch checkpoint (overwritten each epoch).
# NOTE(review): placeholder filename -- adjust to the project's layout.
checkpoint_path = "ddsp_textenv_checkpoint.pt"

# Lowest average epoch loss seen so far. Initialised here so the cell does not
# depend on a variable left over from another (possibly deleted) cell.
best_loss = float("inf")

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # Unpack batch data: each batch is a (features, segments) pair, where
        # features[0] is the spectral centroid and features[1] the loudness.
        features, segments = batch
        # unsqueeze(1) inserts a singleton dimension at index 1 for the centroid.
        spectral_centroid = features[0].unsqueeze(1).to(device)
        loudness = features[1].to(device)
        segments = segments.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass; the model's second return value is not needed here.
        reconstructed_signal, _ = model(spectral_centroid, loudness)

        # Compute loss between the target segments and the reconstruction.
        loss = multispectrogram_loss(segments, reconstructed_signal, scales, overlap)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Accumulate the loss as a Python float (detached from the graph).
        running_loss += loss.item()

    # Print average loss for the epoch
    avg_loss = running_loss / len(dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    # Keep `best_loss` meaningful: it is stored in the checkpoint below.
    best_loss = min(best_loss, avg_loss)

    # Save checkpoint: enough state to resume training (model + optimizer).
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_loss': best_loss,
    }
    torch.save(checkpoint, checkpoint_path)

print("Training complete.")

Epoch 1/3: 100%|██████████| 108/108 [00:40<00:00,  2.64it/s]


Epoch [1/3], Loss: 5.2539


Epoch 2/3: 100%|██████████| 108/108 [00:41<00:00,  2.58it/s]


Epoch [2/3], Loss: 5.2187


Epoch 3/3: 100%|██████████| 108/108 [00:40<00:00,  2.66it/s]

Epoch [3/3], Loss: 5.2366
Training complete.



