In [None]:
! git clone https://github.com/credwood/rir_ml.git

In [None]:
! pip install -r rir_ml/requirements.txt

In [3]:
import sys
sys.path.append('/content/rir_ml')


In [4]:
import logging
from datetime import datetime
import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
from torch import optim
from tqdm import tqdm

from core.dataset_utils import RIRHDF5Dataset, denormalize
from core.models import CNN1D
from core.training_utils import WeightedMSELoss


In [5]:
# Mount Google Drive to access your data
from google.colab import drive
drive.mount('/content/drive')

# Now you can use your HDF5 files like this:
rir_path = '/content/drive/MyDrive/rir_data/rir_dataset.h5'
metrics_path = '/content/drive/MyDrive/rir_data/rir_metrics.h5'

Mounted at /content/drive


In [6]:
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from sklearn.model_selection import train_test_split

# Full dataset
full_dataset = RIRHDF5Dataset(rir_path, metrics_path)

indices = list(range(len(full_dataset)))
train_idx, val_idx = train_test_split(indices, test_size=0.2, random_state=42)

# Get train targets
train_targets = np.stack([full_dataset[i][1].numpy() for i in train_idx])
target_mean = train_targets.mean(axis=0)
target_std = train_targets.std(axis=0)

train_set = RIRHDF5Dataset(
    rir_path, metrics_path,
    normalize_targets=True,
    target_mean=target_mean,
    target_std=target_std,
    subset_indices=train_idx
)

val_set = RIRHDF5Dataset(
    rir_path, metrics_path,
    normalize_targets=True,
    target_mean=target_mean,
    target_std=target_std,
    subset_indices=val_idx
)


In [7]:
# Set up log file path
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = "training_logs"
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"train_{timestamp}.log")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s | %(levelname)s | %(message)s",
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)

logger = logging.getLogger()

In [8]:
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator
import seaborn as sns
from IPython.display import clear_output, display
from collections import defaultdict


def init_weights_kaiming(m):
    if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
        init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def plot_metrics(history):
    """
    Plots training and validation loss, and real-world MAE curves for all metrics.
    Ensures integer epoch ticks.
    """
    epochs = list(range(1, len(history['train_loss']) + 1))

    # --- Plot Training and Validation Loss ---
    plt.figure(figsize=(8, 5))
    ax = sns.lineplot(x=epochs, y=history['train_loss'], label='Train Loss')
    sns.lineplot(x=epochs, y=history['val_loss'], label='Val Loss')
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.title("Training and Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # --- Plot Real-World MAE (Seconds) ---
    plt.figure(figsize=(10, 5))
    ax = sns.lineplot(x=epochs, y=history['mae_rt60'], label='RT60 (s)')
    sns.lineplot(x=epochs, y=history['mae_edt'], label='EDT (s)')
    sns.lineplot(x=epochs, y=history['mae_d50'], label='D50')
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.title("Real-World MAE — Time-Based Metrics")
    plt.xlabel("Epoch")
    plt.ylabel("MAE (seconds)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()

    # --- Plot Real-World MAE (Decibels) ---
    plt.figure(figsize=(6, 5))
    ax = sns.lineplot(x=epochs, y=history['mae_c50'], label='C50 (dB)', color='orange')
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    plt.title("Real-World MAE — C50")
    plt.xlabel("Epoch")
    plt.ylabel("MAE (dB)")
    plt.grid(True)
    plt.legend()
    plt.tight_layout()
    plt.show()


def train_model(model, train_dataset, val_dataset, num_epochs=20, batch_size=64, lr=1e-4, device='cuda'):
    """
    Trains the model on normalized metrics, using AdamW and validation-based LR scheduler.
    """
    model = model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )
    metric_weights = torch.tensor([0.5, 1.0, 2.0, 0.5])
    criterion = WeightedMSELoss(metric_weights).to(device)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    best_val_loss = float("inf")
    history = defaultdict(list)
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for rirs, targets in tqdm(train_loader, desc=f"Epoch {epoch+1} Training"):
            rirs, targets = rirs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(rirs.unsqueeze(1))  # [B, 1, N]
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        avg_train_loss = running_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        mae_sum = torch.zeros(4, device=device)
        num_batches = 0
        with torch.no_grad():
            for rirs, targets in tqdm(val_loader, desc=f"Epoch {epoch+1} Validation"):
                rirs, targets = rirs.to(device), targets.to(device)
                outputs = model(rirs.unsqueeze(1))
                loss = criterion(outputs, targets)
                val_loss += loss.item()

                # Denormalize for metric reporting
                # Denormalize to linear space
                outputs_real = denormalize(outputs, val_dataset.target_mean, val_dataset.target_std)
                targets_real = denormalize(targets, val_dataset.target_mean, val_dataset.target_std)

                # Compute MAE in real-world scale
                mae_batch = torch.mean(torch.abs(outputs_real - targets_real), dim=0)

                mae_report = mae_batch.clone()

                mae_sum += mae_report
                num_batches += 1

        avg_val_loss = val_loss / num_batches
        avg_mae = mae_sum / num_batches
        avg_mae[2] = 10 * torch.log10(avg_mae[2])
        avg_mae = avg_mae.cpu().numpy()

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), "best_cnn_model.pt")
            logger.info(f"New best model saved at epoch {epoch+1} with val loss {avg_val_loss:.4f}")

        logger.info(f"Epoch {epoch+1:02d} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        logger.info(f"MAE — RT60: {avg_mae[0]:.4f}s | EDT: {avg_mae[1]:.4f}s | C50: {avg_mae[2]:.2f}dB | D50: {avg_mae[3]:.3f}")
        history["epoch"].append(epoch + 1)
        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)
        history["mae_rt60"].append(avg_mae[0])
        history["mae_edt"].append(avg_mae[1])
        history["mae_c50"].append(avg_mae[2])
        history["mae_d50"].append(avg_mae[3])
        print(f"Epoch {epoch+1} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"Real-World MAE: RT60={avg_mae[0]:.4f}s, EDT={avg_mae[1]:.4f}s, C50={avg_mae[2]:.2f}dB, D50={avg_mae[3]:.3f}")


        #plot_metrics(history)

    return model


In [9]:
model = CNN1D()
model.apply(init_weights_kaiming)
model = train_model(model, train_set, val_set, num_epochs=20, batch_size=1024, lr=1e-4, device='cuda')

  self.register_buffer("weights", torch.tensor(weights, dtype=torch.float32))
Epoch 1 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 1 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 1 | Train Loss: 1.0083 | Val Loss: 1.1102
Real-World MAE: RT60=0.0809s, EDT=0.1527s, C50=32.70dB, D50=0.146


Epoch 2 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 2 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s]


Epoch 2 | Train Loss: 1.0017 | Val Loss: 1.1066
Real-World MAE: RT60=0.0800s, EDT=0.1470s, C50=32.88dB, D50=0.136


Epoch 3 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 3 Validation: 100%|██████████| 10/10 [00:07<00:00,  1.41it/s]


Epoch 3 | Train Loss: 1.0012 | Val Loss: 1.1082
Real-World MAE: RT60=0.0804s, EDT=0.1457s, C50=33.01dB, D50=0.136


Epoch 4 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 4 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 4 | Train Loss: 1.0006 | Val Loss: 1.1086
Real-World MAE: RT60=0.0802s, EDT=0.1432s, C50=32.66dB, D50=0.134


Epoch 5 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 5 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 5 | Train Loss: 1.0006 | Val Loss: 1.1109
Real-World MAE: RT60=0.0798s, EDT=0.1417s, C50=33.22dB, D50=0.133


Epoch 6 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 6 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 6 | Train Loss: 1.0004 | Val Loss: 1.1070
Real-World MAE: RT60=0.0804s, EDT=0.1456s, C50=32.76dB, D50=0.133


Epoch 7 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 7 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 7 | Train Loss: 1.0001 | Val Loss: 1.1070
Real-World MAE: RT60=0.0804s, EDT=0.1426s, C50=32.91dB, D50=0.135


Epoch 8 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 8 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s]


Epoch 8 | Train Loss: 0.9997 | Val Loss: 1.1068
Real-World MAE: RT60=0.0799s, EDT=0.1443s, C50=32.94dB, D50=0.135


Epoch 9 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 9 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.51it/s]


Epoch 9 | Train Loss: 0.9980 | Val Loss: 1.1084
Real-World MAE: RT60=0.0800s, EDT=0.1409s, C50=33.02dB, D50=0.136


Epoch 10 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 10 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 10 | Train Loss: 0.9997 | Val Loss: 1.1067
Real-World MAE: RT60=0.0808s, EDT=0.1428s, C50=32.79dB, D50=0.134


Epoch 11 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 11 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 11 | Train Loss: 0.9997 | Val Loss: 1.1069
Real-World MAE: RT60=0.0799s, EDT=0.1422s, C50=32.95dB, D50=0.134


Epoch 12 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 12 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 12 | Train Loss: 0.9993 | Val Loss: 1.1071
Real-World MAE: RT60=0.0799s, EDT=0.1440s, C50=32.97dB, D50=0.134


Epoch 13 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 13 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 13 | Train Loss: 0.9991 | Val Loss: 1.1066
Real-World MAE: RT60=0.0801s, EDT=0.1425s, C50=32.89dB, D50=0.135


Epoch 14 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 14 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]


Epoch 14 | Train Loss: 0.9991 | Val Loss: 1.1065
Real-World MAE: RT60=0.0802s, EDT=0.1448s, C50=32.88dB, D50=0.134


Epoch 15 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 15 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.53it/s]


Epoch 15 | Train Loss: 0.9987 | Val Loss: 1.1065
Real-World MAE: RT60=0.0803s, EDT=0.1438s, C50=32.85dB, D50=0.135


Epoch 16 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 16 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 16 | Train Loss: 0.9991 | Val Loss: 1.1066
Real-World MAE: RT60=0.0801s, EDT=0.1431s, C50=32.91dB, D50=0.135


Epoch 17 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 17 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 17 | Train Loss: 0.9991 | Val Loss: 1.1065
Real-World MAE: RT60=0.0802s, EDT=0.1431s, C50=32.88dB, D50=0.135


Epoch 18 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 18 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.55it/s]


Epoch 18 | Train Loss: 0.9987 | Val Loss: 1.1066
Real-World MAE: RT60=0.0804s, EDT=0.1432s, C50=32.88dB, D50=0.136


Epoch 19 Training: 100%|██████████| 39/39 [01:36<00:00,  2.47s/it]
Epoch 19 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]


Epoch 19 | Train Loss: 0.9992 | Val Loss: 1.1065
Real-World MAE: RT60=0.0802s, EDT=0.1431s, C50=32.90dB, D50=0.135


Epoch 20 Training: 100%|██████████| 39/39 [01:36<00:00,  2.48s/it]
Epoch 20 Validation: 100%|██████████| 10/10 [00:06<00:00,  1.54it/s]

Epoch 20 | Train Loss: 0.9990 | Val Loss: 1.1065
Real-World MAE: RT60=0.0802s, EDT=0.1441s, C50=32.91dB, D50=0.135



