In [None]:
import logging
from datetime import datetime
import os

import torch
from torch import optim
from tqdm import tqdm

from core.dataset_utils import RIRHDF5Dataset, denormalize
from core.models import CNN1D
from core.training_utils import WeightedMSELoss


In [2]:
from torch.utils.data import DataLoader, random_split

# Split configuration: a fixed seed makes the train/val partition identical on every run.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

# Full dataset
# NOTE(review): if RIRHDF5Dataset computes target_mean/target_std over ALL samples, the
# validation split leaks into the normalization statistics — TODO confirm in core.dataset_utils.
dataset = RIRHDF5Dataset()

# Split into train/val with a dedicated seeded generator (independent of the global torch RNG).
train_size = int(TRAIN_FRACTION * len(dataset))
val_size = len(dataset) - train_size
train_set, val_set = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
assert len(train_set) + len(val_set) == len(dataset), "split sizes must cover the full dataset"

In [3]:
# Set up log file path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = "training_logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"train_{timestamp}.log")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger()

In [None]:
def _find_target_stats(dataset):
    """Return ``(target_mean, target_std)`` from ``dataset`` or a dataset it wraps.

    ``torch.utils.data.random_split`` returns ``Subset`` objects, which do not expose
    the parent dataset's attributes. This walks the ``.dataset`` chain (so nested
    subsets also work) until it finds an object carrying both statistics.

    Raises:
        AttributeError: if no dataset in the chain has ``target_mean`` and ``target_std``.
    """
    current = dataset
    while current is not None:
        if hasattr(current, "target_mean") and hasattr(current, "target_std"):
            return current.target_mean, current.target_std
        current = getattr(current, "dataset", None)
    raise AttributeError(
        "Could not find target_mean/target_std on the validation dataset or any dataset it wraps."
    )


def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for rirs, targets in tqdm(loader, desc=desc):
        rirs, targets = rirs.to(device), targets.to(device)
        optimizer.zero_grad()
        # unsqueeze(1) inserts a channel axis; assumes rirs is [B, N] -> [B, 1, N] (TODO confirm)
        outputs = model(rirs.unsqueeze(1))
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, criterion, device, target_mean, target_std, num_metrics, desc):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        ``(avg_loss, avg_mae)``. ``avg_loss`` is the criterion averaged per sample. Batch
        losses are weighted by batch size, which assumes the criterion returns a per-batch
        mean (TODO confirm for WeightedMSELoss). ``avg_mae`` is a NumPy array with the
        per-metric MAE in denormalized units, averaged over all samples.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    abs_err_sum = torch.zeros(num_metrics, device=device)
    num_samples = 0
    with torch.no_grad():
        for rirs, targets in tqdm(loader, desc=desc):
            rirs, targets = rirs.to(device), targets.to(device)
            outputs = model(rirs.unsqueeze(1))
            batch_n = targets.shape[0]
            loss_sum += criterion(outputs, targets).item() * batch_n

            # Denormalize so the MAE is reported in real-world units.
            outputs_real = denormalize(outputs, target_mean, target_std)
            targets_real = denormalize(targets, target_mean, target_std)
            # Accumulate sums (not batch means) so a smaller final batch is not over-weighted.
            abs_err_sum += torch.abs(outputs_real - targets_real).sum(dim=0)
            num_samples += batch_n
    if num_samples == 0:
        raise ValueError("Validation loader produced no samples.")
    return loss_sum / num_samples, (abs_err_sum / num_samples).cpu().numpy()


def train_model(model, train_dataset, val_dataset, num_epochs=20, batch_size=64, lr=1e-3, device='cpu',
                metric_weights=None, checkpoint_path="best_cnn_model.pt"):
    """
    Train ``model`` on normalized metrics, using AdamW and a validation-based LR scheduler.

    Each epoch runs one training pass and then evaluates on ``val_dataset``. Whenever the
    validation loss beats the best value seen so far, the state dict is saved to
    ``checkpoint_path``. Progress is reported through the module-level ``logger``, whose
    StreamHandler (configured in the logging cell) echoes it to the console.

    Args:
        model: network to train; it is called as ``model(rirs.unsqueeze(1))``.
        train_dataset: dataset yielding ``(rirs, targets)`` pairs.
        val_dataset: dataset yielding ``(rirs, targets)`` pairs. It, or a dataset it wraps
            (e.g. the parent of a ``random_split`` Subset), must expose ``target_mean`` and
            ``target_std`` for ``denormalize``.
        num_epochs: number of training epochs.
        batch_size: batch size for both loaders.
        lr: initial AdamW learning rate.
        device: device the model and batches are moved to.
        metric_weights: per-target weights passed to ``WeightedMSELoss``. Defaults to
            ``[2.0, 1.0, 0.5, 0.5]``, in the RT60, EDT, C50, D50 order used by the MAE log
            line. That log line expects exactly four targets.
        checkpoint_path: file the best state dict is written to.

    Returns:
        The model after the final epoch. This is not necessarily the best checkpoint;
        reload ``checkpoint_path`` for that.

    Raises:
        ValueError: if the training set yields no batch (fewer samples than ``batch_size``
            with ``drop_last=True``), or if the validation set is empty.
        AttributeError: if no target statistics can be found on ``val_dataset``.
    """
    if metric_weights is None:
        metric_weights = [2.0, 1.0, 0.5, 0.5]
    # Resolve stats up front so a missing attribute fails before any training time is spent.
    target_mean, target_std = _find_target_stats(val_dataset)

    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    # Halve the LR after 3 epochs without val-loss improvement. LR changes are logged
    # explicitly below instead of via the deprecated `verbose` flag.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    criterion = WeightedMSELoss(metric_weights)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    if len(train_loader) == 0:
        raise ValueError(
            f"Training set has {len(train_dataset)} samples, fewer than batch_size={batch_size}; "
            "with drop_last=True no training batch would be produced."
        )

    best_val_loss = float("inf")
    for epoch in range(1, num_epochs + 1):
        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, criterion, device, desc=f"Epoch {epoch} Training"
        )
        avg_val_loss, avg_mae = _evaluate(
            model, val_loader, criterion, device, target_mean, target_std,
            num_metrics=len(metric_weights), desc=f"Epoch {epoch} Validation",
        )

        prev_lr = optimizer.param_groups[0]["lr"]
        scheduler.step(avg_val_loss)
        new_lr = optimizer.param_groups[0]["lr"]
        if new_lr < prev_lr:
            logger.info(f"Reducing learning rate: {prev_lr:.2e} -> {new_lr:.2e}")

        # Compare against the previous best BEFORE updating it. Updating first (as the
        # old code did) made this condition always false, so no checkpoint was written.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            logger.info(f"New best model saved at epoch {epoch} with val loss {avg_val_loss:.4f}")

        logger.info(f"Epoch {epoch:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        logger.info(f"MAE — RT60: {avg_mae[0]:.4f}s | EDT: {avg_mae[1]:.4f}s | C50: {avg_mae[2]:.2f}dB | D50: {avg_mae[3]:.3f}")

    return model


In [None]:
# Seed the global torch RNG so weight initialisation and DataLoader shuffling repeat across runs.
torch.manual_seed(0)
# Use a GPU when one is available; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = CNN1D()
model = train_model(model, train_set, val_set, num_epochs=20, batch_size=64, lr=1e-3, device=device)