In [5]:
# ---- Imports ----
import os
import shutil
import random
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torch.nn.functional as F

from datetime import datetime

In [6]:
from model.gru_audio_model import GRUAudioModel, GRUAudioConfig
from model.audio_dataset import AudioDatasetConfig, MuLawAudioDataset

In [9]:
# ---- Training Settings ----
out_dir = "output/"+ datetime.now().strftime("%Y.%m.%d_%H.%M")
model_save_every = 2
learning_rate = 1e-3
batch_size = 32
num_epochs = 2
validation_data_dir = None # "data/val"  # or None
resume_checkpoint = None  # e.g., "output/run1/checkpoints/epoch_10.pt"

# ---- Set Random Seeds ----
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ---- Device ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ---- Model + Dataset Config ----
model_config = GRUAudioConfig(
    num_conditioning_params=2,
    embedding_dim=16,
    hidden_size=24, #128,
    num_layers=2,
    dropout=0.1
)
print("Model config:", model_config)

data_config = AudioDatasetConfig(
    data_dir="data/PoissonGeiger68_16k",
    sequence_length=1024,
    parameter_specs={"r": (0, 3)} #, "mod": (0.0, 1.0)}
)

Using device: cuda
Model config: GRUAudioConfig(num_conditioning_params=2, embedding_dim=16, hidden_size=128, num_layers=2, dropout=0.1)


In [8]:
# ---- Create Output Folders ----
os.makedirs(out_dir, exist_ok=True)
os.makedirs(f"{out_dir}/checkpoints", exist_ok=True)
os.makedirs(f"{out_dir}/tensorboard", exist_ok=True)

# ---- Save Config ----
with open(f"{out_dir}/config.txt", "w") as f:
    f.write(str(model_config))
    f.write("\n")
    f.write(str(data_config))

# ---- Load Datasets ----
train_ds = MuLawAudioDataset(data_config)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)

val_dl = None
if validation_data_dir:
    val_config = AudioDatasetConfig(
        data_dir=validation_data_dir,
        sequence_length=data_config.sequence_length,
        parameter_specs=data_config.parameter_specs
    )
    val_ds = MuLawAudioDataset(val_config)
    val_dl = DataLoader(val_ds, batch_size=batch_size)

ValueError: num_samples should be a positive integer value, but got num_samples=0

In [None]:
# ---- Initialize Model ----
model = GRUAudioModel(model_config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

start_epoch = 1
if resume_checkpoint:
    checkpoint = torch.load(resume_checkpoint, map_location=device)
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resumed from epoch {start_epoch}")

# ---- TensorBoard ----
writer = SummaryWriter(log_dir=f"{out_dir}/tensorboard")

# ---- Training Loop ----
for epoch in range(start_epoch, num_epochs + 1):
    model.train()
    total_loss = 0.0

    for batch in train_dl:
        x, cond, y = [b.to(device) for b in batch]  # [B,T-1], [B,T-1,p], [B,T-1]

        optimizer.zero_grad()
        logits = model(x, cond)  # [B,T-1,256]
        loss = F.cross_entropy(logits.view(-1, 256), y.view(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_train_loss = total_loss / len(train_dl)
    writer.add_scalar("Loss/train", avg_train_loss, epoch)
    print(f"Epoch {epoch} | Train Loss: {avg_train_loss:.4f}")

    # ---- Validation ----
    if val_dl:
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in val_dl:
                x, cond, y = [b.to(device) for b in batch]
                logits = model(x, cond)
                loss = F.cross_entropy(logits.view(-1, 256), y.view(-1))
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_dl)
        writer.add_scalar("Loss/val", avg_val_loss, epoch)
        print(f"Epoch {epoch} | Val Loss: {avg_val_loss:.4f}")

    scheduler.step()

    # ---- Save Checkpoint ----
    if epoch % model_save_every == 0 or epoch == num_epochs:
        save_path = f"{out_dir}/checkpoints/epoch_{epoch}.pt"
        torch.save({
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, save_path)
        print(f"Saved checkpoint: {save_path}")

writer.close()