In [1]:
# ---- Imports ----
import os
import shutil
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torch.nn.functional as F

from datetime import datetime

In [2]:
from model.gru_audio_model import GRUAudioModel, GRUAudioConfig
from model.audio_dataset import AudioDatasetConfig, MuLawAudioDataset

In [3]:
# ---- Training Settings ----
# Every run writes into its own timestamped directory (minute resolution), so
# re-running this cell in a later minute starts a fresh output folder.
out_dir = "output/synthodd_" + datetime.now().strftime("%Y.%m.%d_%H.%M")
learning_rate = 1e-4
batch_size = 24
num_epochs = 2000
model_save_every = 250  # checkpoint interval, in epochs

validation_data_dir = None # "data/val"  # or None
resume_checkpoint = None  # e.g., "output/run1/checkpoints/epoch_10.pt"

# ---- Set Random Seeds ----
# One seed value shared by all three RNGs so it only has to be changed in one place.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---- Conditioning Parameters ----
# Single source of truth for the conditioning parameters. Per the original note,
# file names look like hjskadhflda_p1xx.xx_p2xx.xx......wav.
# NOTE(review): the tuples look like (min, max) ranges per parameter -- confirm
# against AudioDatasetConfig.
parameter_specs = {"mp": (52, 76)} #, "mod": (0.0, 1.0)}

# ---- Model + Dataset Config ----
model_config = GRUAudioConfig(
    # Derived instead of hard-coded so the model cannot silently go out of sync
    # when a parameter is added to or removed from parameter_specs.
    # Assumes the dataset emits one conditioning channel per entry -- TODO confirm.
    num_conditioning_params=len(parameter_specs),
    embedding_dim=8,
    hidden_size=48,
    num_layers=4,
    dropout=0.1
)
print("Model config:", model_config)

data_config = AudioDatasetConfig(
    data_dir="data/MuMeRNN/data/synthodd16",
    sequence_length=256,
    parameter_specs=parameter_specs
)

Using device: cuda
Model config: GRUAudioConfig(num_conditioning_params=1, embedding_dim=8, hidden_size=48, num_layers=3, dropout=0.1)


In [4]:
# ---- Create Output Folders ----
os.makedirs(out_dir, exist_ok=True)
os.makedirs(f"{out_dir}/checkpoints", exist_ok=True)
os.makedirs(f"{out_dir}/tensorboard", exist_ok=True)

# ---- Save Config ----

# Machine readable: the config objects themselves are serialized, so loading
# this file later requires the same config classes to be importable.
torch.save({
    "model_config": model_config,
    "data_config": data_config
}, f"{out_dir}/config.pt")

# Human readable: one config repr per line.
# "\n" (a real newline) is required here; "\\n" would write a literal
# backslash-n and put both entries on a single line.
with open(f"{out_dir}/config.txt", "w") as f:
    f.write("model_config = " + repr(model_config) + "\n")
    f.write("data_config = " + repr(data_config) + "\n")

# ---- Load Datasets ----
train_ds = MuLawAudioDataset(data_config)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

# Validation is optional: val_dl stays None unless validation_data_dir is set.
val_dl = None
if validation_data_dir:
    # Reuse the training sequence length and parameter specs so both datasets
    # are built with the same settings.
    val_config = AudioDatasetConfig(
        data_dir=validation_data_dir,
        sequence_length=data_config.sequence_length,
        parameter_specs=data_config.parameter_specs
    )
    val_ds = MuLawAudioDataset(val_config)
    val_dl = DataLoader(val_ds, batch_size=batch_size)

In [5]:
# ---- Initialize Model ----
model = GRUAudioModel(model_config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# NOTE(review): halving the LR every 10 epochs shrinks it by roughly 1000x per
# 100 epochs, so with num_epochs in the thousands almost all of the run happens
# at a near-zero learning rate -- confirm this schedule is intended.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

start_epoch = 1
if resume_checkpoint:
    checkpoint = torch.load(resume_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    if 'scheduler' in checkpoint:
        # Restore the scheduler so the LR decay continues where it left off.
        scheduler.load_state_dict(checkpoint['scheduler'])
    else:
        # Older checkpoints carry no scheduler state. The optimizer state already
        # holds the decayed LR; realign the epoch counter so the next decay step
        # lands on the same epoch it would have without the interruption
        # (scheduler.step() is called exactly once per epoch below).
        scheduler.last_epoch = checkpoint['epoch']
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resumed from epoch {start_epoch}")


def _batch_loss(batch):
    """Move one batch to the device, run the model, and return the mean cross-entropy loss.

    The batch is expected to unpack into (x, cond, y); per the original shape
    notes these are [B,T-1], [B,T-1,p] and [B,T-1]. The number of classes is
    read from the last dimension of the logits instead of being hard-coded.
    """
    x, cond, y = [b.to(device) for b in batch]
    logits = model(x, cond)  # [B,T-1,num_classes]
    return F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))


# ---- TensorBoard ----
writer = SummaryWriter(log_dir=f"{out_dir}/tensorboard")

# ---- Training Loop ----
# try/finally guarantees the writer is closed (flushing pending events) even if
# training is interrupted or raises.
try:
    for epoch in range(start_epoch, num_epochs + 1):
        model.train()
        total_loss = 0.0

        for batch in train_dl:
            optimizer.zero_grad()
            loss = _batch_loss(batch)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # Mean of the per-batch losses for this epoch.
        avg_train_loss = total_loss / len(train_dl)
        writer.add_scalar("Loss/train", avg_train_loss, epoch)
        print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}")

        # ---- Validation ----
        if val_dl:
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch in val_dl:
                    val_loss += _batch_loss(batch).item()

            avg_val_loss = val_loss / len(val_dl)
            writer.add_scalar("Loss/val", avg_val_loss, epoch)
            print(f"Epoch {epoch} | Val Loss: {avg_val_loss:.4f}")

        scheduler.step()

        # ---- Save Checkpoint ----
        # Saved every model_save_every epochs and always at the final epoch.
        if epoch % model_save_every == 0 or epoch == num_epochs:
            save_path = f"{out_dir}/checkpoints/epoch_{epoch}.pt"
            torch.save({
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                # Needed to resume the LR schedule exactly.
                'scheduler': scheduler.state_dict(),
            }, save_path)
            print(f"Saved checkpoint: {save_path}")
finally:
    writer.close()

Epoch 1 | Train Loss: 5.5013
Epoch 2 | Train Loss: 5.1230
Epoch 3 | Train Loss: 4.8488
Epoch 4 | Train Loss: 4.7754
Epoch 5 | Train Loss: 4.7472
Epoch 6 | Train Loss: 4.7342
Epoch 7 | Train Loss: 4.7274
Epoch 8 | Train Loss: 4.7234
Epoch 9 | Train Loss: 4.7210
Epoch 10 | Train Loss: 4.7195
Epoch 11 | Train Loss: 4.7186
Epoch 12 | Train Loss: 4.7181
Epoch 13 | Train Loss: 4.7177
Epoch 14 | Train Loss: 4.7174
Epoch 15 | Train Loss: 4.7171
Epoch 16 | Train Loss: 4.7168
Epoch 17 | Train Loss: 4.7165
Epoch 18 | Train Loss: 4.7163
Epoch 19 | Train Loss: 4.7161
Epoch 20 | Train Loss: 4.7160
Epoch 21 | Train Loss: 4.7158
Epoch 22 | Train Loss: 4.7157
Epoch 23 | Train Loss: 4.7156
Epoch 24 | Train Loss: 4.7155
Epoch 25 | Train Loss: 4.7154
Epoch 26 | Train Loss: 4.7153
Epoch 27 | Train Loss: 4.7152
Epoch 28 | Train Loss: 4.7150
Epoch 29 | Train Loss: 4.7148
Epoch 30 | Train Loss: 4.7145
Epoch 31 | Train Loss: 4.7140
Epoch 32 | Train Loss: 4.7134
Epoch 33 | Train Loss: 4.7124
Epoch 34 | Train Lo

In [6]:
# Load the TensorBoard notebook extension, then open TensorBoard inline.
# The log directory is the top-level "output" folder (the parent of each run's
# out_dir), and the UI is served on port 6011.
%load_ext tensorboard
%tensorboard --logdir output --port 6011