In [1]:
# ---- Imports ----
import os
import shutil
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torch.nn.functional as F

from datetime import datetime

In [2]:
from model.gru_audio_model import GRUAudioModel, GRUAudioConfig
from model.audio_dataset import AudioDatasetConfig, MuLawAudioDataset

In [3]:
# ---- Training Settings ----

model_config = GRUAudioConfig(
    num_conditioning_params=1,  # make sure this agrees with data_config.parameter_specs below !!!!!
    hidden_size=40,
    num_layers=4,
    dropout=.1
)

data_config = AudioDatasetConfig(
    data_dir="data/sine_data", #"data/nsynth.64.76.dl", #data_dir="data/MuMeRNN/data/synthodd16",
    sequence_length=256,
    parameter_specs={"oct": (0, 3)}, # {"instID": (1, 2), "p": (64.0, 76.0), "a": (0,1)}, #filenames hjskadhflda_p1xx.xx_p2xx.xx......wav
    add_noise= False,                        # Whether to add white noise
    snr_db= 10.0                           # Desired signal-to-noise ratio (dB)
)

out_dir = "output/sine_data" #"output/nsynth.64.76_"+ datetime.now().strftime("%Y.%m.%d_%H.%M")

learning_rate = 1e-3
batch_size = 256
num_epochs = 100
model_save_every = 100

validation_data_dir = None # "data/val"  # or None
resume_checkpoint = None  # e.g., "output/run1/checkpoints/epoch_10.pt"

# ---- Set Random Seeds ----
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

print("Model config:", model_config)


Using device: cuda
Model config: GRUAudioConfig(num_conditioning_params=1, hidden_size=40, num_layers=4, dropout=0.1)


In [4]:
# ---- Create Output Folders ----
os.makedirs(out_dir, exist_ok=True)
os.makedirs(f"{out_dir}/checkpoints", exist_ok=True)
os.makedirs(f"{out_dir}/tensorboard", exist_ok=True)

# ---- Save Config ----

#machine readable
torch.save({
    "model_config": model_config,
    "data_config": data_config
}, f"{out_dir}/config.pt")

#human readable
with open(f"{out_dir}/config.txt", "w") as f:
    f.write("model_config = " + repr(model_config) + "\\n")
    f.write("data_config = " + repr(data_config) + "\\n")

# ---- Load Datasets ----
train_ds = MuLawAudioDataset(data_config)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

val_dl = None
if validation_data_dir:
    val_config = AudioDatasetConfig(
        data_dir=validation_data_dir,
        sequence_length=data_config.sequence_length,
        parameter_specs=data_config.parameter_specs
    )
    val_ds = MuLawAudioDataset(val_config)
    val_dl = DataLoader(val_ds, batch_size=batch_size)

for sine_oct0.0.wav, normed params are: tensor([0.])
for sine_oct1.0.wav, normed params are: tensor([0.3333])
for sine_oct2.0.wav, normed params are: tensor([0.6667])
for sine_oct3.0.wav, normed params are: tensor([1.])


In [5]:
# ---- Initialize Model ----
model = GRUAudioModel(model_config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = None # torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

In [6]:
start_epoch = 1
if resume_checkpoint:
    checkpoint = torch.load(resume_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resumed from epoch {start_epoch}")

# ---- TensorBoard ----
writer = SummaryWriter(log_dir=f"{out_dir}/tensorboard")

# ------------------------------
# # SANITY CHECK - save untrained model
# save_path = f"{out_dir}/checkpoints/epoch_{0}.pt"
# torch.save({
#     'epoch': 0,
#     'model': model.state_dict(),
#     'optimizer': optimizer.state_dict(),
# }, save_path)
# print(f"Saved checkpoint: {save_path}")

# ---- Training Loop ----
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    total_loss = 0.0

    for batch in train_dl:
        x, cond, y = [b.to(device) for b in batch]  # [B,T-1], [B,T-1,p], [B,T-1]
        #print(f"shape of x = {x.shape}, cond = {cond.shape}, and y = {y.shape}")

        optimizer.zero_grad()
        logits = model(x, cond)  # [B,T-1,256]
        loss = F.cross_entropy(logits.view(-1, 256), y.view(-1))
        loss.backward()

        print(model.input_proj.weight.grad.abs().mean())
        
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_dl)
    writer.add_scalar("Loss/train", avg_train_loss, epoch)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}")

    # ---- Validation ----
    if val_dl:
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_dl:
                x, cond, y = [b.to(device) for b in batch]
                logits = model(x, cond)
                loss = F.cross_entropy(logits.view(-1, 256), y.view(-1))
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_dl)
        writer.add_scalar("Loss/val", avg_val_loss, epoch)
        print(f"Epoch {epoch} | Val Loss: {avg_val_loss:.4f}")

    if scheduler != None : 
        scheduler.step()

    # ---- Save Checkpoint ----
    if epoch % model_save_every == 0 or epoch == num_epochs:
        save_path = f"{out_dir}/checkpoints/epoch_{epoch}.pt"
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, save_path)
        print(f"Saved checkpoint: {save_path}")

writer.close()

tensor(0.0018, device='cuda:0')
Epoch 1 | Train Loss: 5.5539
tensor(0.0014, device='cuda:0')
Epoch 2 | Train Loss: 5.5442
tensor(0.0013, device='cuda:0')
Epoch 3 | Train Loss: 5.5361
tensor(0.0014, device='cuda:0')
Epoch 4 | Train Loss: 5.5280
tensor(0.0017, device='cuda:0')
Epoch 5 | Train Loss: 5.5191
tensor(0.0022, device='cuda:0')
Epoch 6 | Train Loss: 5.5083
tensor(0.0027, device='cuda:0')
Epoch 7 | Train Loss: 5.4953
tensor(0.0034, device='cuda:0')
Epoch 8 | Train Loss: 5.4794
tensor(0.0040, device='cuda:0')
Epoch 9 | Train Loss: 5.4600
tensor(0.0046, device='cuda:0')
Epoch 10 | Train Loss: 5.4365
tensor(0.0051, device='cuda:0')
Epoch 11 | Train Loss: 5.4085
tensor(0.0055, device='cuda:0')
Epoch 12 | Train Loss: 5.3762
tensor(0.0056, device='cuda:0')
Epoch 13 | Train Loss: 5.3393
tensor(0.0055, device='cuda:0')
Epoch 14 | Train Loss: 5.2978
tensor(0.0052, device='cuda:0')
Epoch 15 | Train Loss: 5.2529
tensor(0.0046, device='cuda:0')
Epoch 16 | Train Loss: 5.2055
tensor(0.0039, de

In [7]:
# %load_ext tensorboard
# %tensorboard --logdir output --port 6011