In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd /content/drive/MyDrive/HRTF_Interpolation_Project

/content/drive/MyDrive/HRTF_Interpolation_Project


In [3]:
!pip install sofa scipy torch matplotlib

Collecting sofa
  Downloading sofa-0.7.6.tar.gz (26 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyramid (from sofa)
  Downloading pyramid-2.0.2-py3-none-any.whl.metadata (20 kB)
Collecting transaction (from sofa)
  Downloading transaction-5.0-py3-none-any.whl.metadata (14 kB)
Collecting validate_email (from sofa)
  Downloading validate_email-1.3.tar.gz (4.7 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pyDNS (from sofa)
  Downloading pydns-2.3.6.tar.gz (28 kB)
  [1;31merror[0m: [1msubprocess-exited-with-error[0m
  
  [31m×[0m [32mpython setup.py egg_info[0m did not run successfully.
  [31m│[0m exit code: [1;36m1[0m
  [31m╰─>[0m See above for output.
  
  [1;35mnote[0m: This error originates from a subprocess, and is likely not a problem with pip.
  Preparing metadata (setup.py) ... [?25l[?25herror
[1;31merror[0m: [1mmetadata-generation-failed[0m

[31m×[0m Encountered error while generating package metadata.
[31m╰─>

In [4]:
!pip install python-sofa


Collecting python-sofa
  Downloading python_sofa-0.2.0-py3-none-any.whl.metadata (1.5 kB)
Collecting netcdf4 (from python-sofa)
  Downloading netCDF4-1.7.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.8 kB)
Collecting datetime (from python-sofa)
  Downloading DateTime-5.5-py3-none-any.whl.metadata (33 kB)
Collecting zope.interface (from datetime->python-sofa)
  Downloading zope.interface-7.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting cftime (from netcdf4->python-sofa)
  Downloading cftime-1.6.4.post1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (8.7 kB)
Downloading python_sofa-0.2.0-py3-none-any.whl (49 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.0/50.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloadi

# DATASET.PY

In [5]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset
import sofa
from scipy.fft import rfft, rfftfreq
from scipy.special import sph_harm

# --- CONFIG ---
SH_ORDER = 16
NUM_COEFF = (SH_ORDER + 1) ** 2
FREQ_RANGE = (172, 16000)
N_KNOWN = 120
N_TOTAL = 480
N_FREQS = 93

# --- HELPERS ---

def real_spherical_harmonics(N, theta, phi):
    P = len(theta)
    C = (N + 1) ** 2
    Y = np.zeros((P, C))
    idx = 0
    for n in range(N + 1):
        for m in range(-n, n + 1):
            Ynm = sph_harm(m, n, phi, theta)
            if m < 0:
                Y[:, idx] = np.sqrt(2) * (-1)**m * Ynm.imag
            elif m == 0:
                Y[:, idx] = Ynm.real
            else:
                Y[:, idx] = np.sqrt(2) * (-1)**m * Ynm.real
            idx += 1
    return Y

def compute_area_weights(Y_known):
    return np.ones(Y_known.shape[0]) / Y_known.shape[0]

def process_sofa_file(file_path):
    db = sofa.Database.open(file_path)
    ir = db.Data.IR.get_values()  # (M, R, N)
    fs = db.Data.SamplingRate.get_values(indices={"M": 0})

    # Get source positions (azimuth, elevation) in radians
    pos = db.Source.Position.get_values(system="spherical")[:, :2]
    azimuth = np.deg2rad(pos[:, 0])
    elevation = np.pi/2 - np.deg2rad(pos[:, 1])

    # SH basis matrix for all 480 directions
    Y = real_spherical_harmonics(SH_ORDER, elevation, azimuth)

    # Extract magnitude HRTFs from HRIRs
    n_measurements, n_receivers, n_samples = ir.shape
    freqs = rfftfreq(n_samples, d=1/fs)
    valid_idx = np.where((freqs >= FREQ_RANGE[0]) & (freqs <= FREQ_RANGE[1]))[0]

    HRTF_mag = np.zeros((n_measurements, n_receivers, len(valid_idx)))
    for m in range(n_measurements):
        for r in range(n_receivers):
            hrir = ir[m, r, :]
            hrtf = np.abs(rfft(hrir))[valid_idx]
            HRTF_mag[m, r, :] = 20 * np.log10(hrtf + 1e-8)

    return HRTF_mag, Y

# --- MAIN DATASET CLASS ---

class HRTFDataset(Dataset):
    def __init__(self, sofa_dir, known_idx_path="known_directions_idx.npy"):
        self.file_paths = [os.path.join(sofa_dir, f) for f in os.listdir(sofa_dir)
                           if f.endswith(".sofa") and "measured" in f]

        if os.path.exists(known_idx_path):
            self.known_idx = np.load(known_idx_path)
        else:
            raise FileNotFoundError("Missing known_directions_idx.npy")

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        hrtf_mag, Y = process_sofa_file(file_path)

        # Left ear
        x = hrtf_mag[self.known_idx, 0, :].T  # [93, 120]
        y = hrtf_mag[:, 0, :].T               # [93, 480]

        Y_known = Y[self.known_idx, :]       # [120, 289]
        Y_inv = np.linalg.pinv(Y_known)      # [289, 120]
        area = compute_area_weights(Y_known) # [120]

        return {
            "x": torch.tensor(x, dtype=torch.float32),
            "y": torch.tensor(y, dtype=torch.float32),
            "Y_inv": torch.tensor(Y_inv.T[:, :81], dtype=torch.float32),  # L=8 for first SHT
            "area": torch.tensor(area, dtype=torch.float32),
            "Y": torch.tensor(Y, dtype=torch.float32)
        }


120 known_directions (once)

In [6]:
import numpy as np
import sofa
import os

path = "HRIRs/pp1_HRIRs_measured.sofa"
db = sofa.Database.open(path)

n_dirs = db.Dimensions.M  # should be 440
np.random.seed(42)
known_idx = np.sort(np.random.choice(n_dirs, 120, replace=False))
np.save("known_directions_idx.npy", known_idx)

print("Saved known_directions_idx.npy with shape:", known_idx.shape)


Saved known_directions_idx.npy with shape: (120,)


# MODEL.PY

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math  # also needed for stdv in SHConv

# layers

class SHT(nn.Module):
    def __init__(self, L, Y_inv, area):
        """
        Input shape  : [batch, n_ch, 120]
        Output shape : [batch, n_ch, (8+1)**2]
        """

        super().__init__()

        self.Y_inv = Y_inv[:, : (L + 1) ** 2]
        self.area = area

    def forward(self, x):
        x = torch.mul(self.area, x)
        x = torch.matmul(x, self.Y_inv)

        return x

class SHConv(nn.Module):
    def __init__(self, in_ch, out_ch, L):
        """
        Input shape  : [batch, in_ch, (L+1)**2]
        Output shape : [batch, out_ch, (L+1)**2]
        """

        super().__init__()

        ncpt = L + 1

        self.weight = nn.Parameter(torch.empty(in_ch, out_ch, ncpt))
        self.repeats = nn.Parameter(torch.tensor([(2 * l + 1) for l in range(L + 1)]), requires_grad=False)

        stdv = 1.0 / math.sqrt(in_ch * (L + 1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):

        w = torch.repeat_interleave(self.weight, self.repeats, dim=2)
        x = torch.mul(w.unsqueeze(0), x.unsqueeze(2)).sum(1)

        return x


class ISHT1(nn.Module):
    def __init__(self, Y):
        """
        Input shape  : [batch, n_ch, (L+1)**2]
        Output shape : [batch, n_ch, 120]
        """

        super().__init__()

        self.Y = Y

    def forward(self, x):
        x = torch.matmul(x, self.Y[: x.shape[-1], :])

        return x

class ISHT2(nn.Module):
    def __init__(self, Y480_289):
        """
        Input shape  : [batch, n_ch, (L+1)**2]
        Output shape : [batch, n_ch, 480]
        """

        super().__init__()

        self.Y480_289 = Y480_289

    def forward(self, x):
        x = torch.matmul(x, self.Y480_289)

        return x


# model

class SCNN(nn.Module):
    def __init__(self, Y, Y_inv, area,Y480_289, in_ch, out_ch, L,nonlinear=None, fullband=True, bn=True):
        """
        In channel shape  : [batch, in_ch, n_vertex]
        Out channel shape : [batch, out_ch, n_vertex]
        """
        super().__init__()

        self.first = nn.Sequential(SHT(8, Y_inv, area),ISHT1(Y))
        self.shconv1 = nn.Sequential(SHT(L, Y_inv, area), SHConv(in_ch, 93*2, L), ISHT1(Y))
        self.shconv2 = nn.Sequential(SHT(16, Y_inv, area), SHConv(93*2, out_ch, 16), ISHT1(Y))
        self.final = nn.Sequential(SHT(16, Y_inv, area),ISHT2(Y480_289))

        self.impulse1 = nn.Conv1d(in_ch, 93*2, kernel_size=1, stride=1, bias=not bn) if fullband else lambda _: 0
        self.impulse2 = nn.Conv1d(93*2, in_ch, kernel_size=1, stride=1, bias=not bn) if fullband else lambda _: 0
        self.nonlinear = F.relu if nonlinear is not None else nn.Identity()

    def forward(self, x):
        x = self.first(x)

        x = self.shconv1(x) + self.impulse1(x)
        x = self.nonlinear(x)

        x = self.shconv2(x) + self.impulse2(x)
        x = self.nonlinear(x)

        x = self.final(x)

        return x

# TRAIN.PY

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# from dataset import HRTFDataset
# from model import SCNN
import os

# --- Hyperparameters ---
EPOCHS = 50
BATCH_SIZE = 1
LEARNING_RATE = 1e-3
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SH_ORDER = 16

# --- LSD LOSS FUNCTION ---
def lsd_loss(y_pred, y_true):
    return torch.sqrt(torch.mean((y_pred - y_true) ** 2))

# --- Load dataset ---
dataset = HRTFDataset(sofa_dir="HRIRs", known_idx_path="known_directions_idx.npy")
dataset.file_paths = dataset.file_paths[:10]  # Limit to first 10 files for quick test
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# --- Sample a batch to get input/output shapes ---
sample = next(iter(loader))
in_ch = sample["x"].shape[1]      # 93 frequency bins
out_ch = sample["y"].shape[1]

# --- Init model ---
model = SCNN(
    Y=sample["Y"].to(DEVICE),
    Y_inv=sample["Y_inv"].to(DEVICE),
    area=sample["area"].to(DEVICE),
    Y480_289=sample["Y"].to(DEVICE),
    in_ch=in_ch,
    out_ch=out_ch,
    L=SH_ORDER,
    nonlinear=True
).to(DEVICE)

optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- Training loop ---
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for batch in loader:
        x = batch["x"].to(DEVICE)               # [B, 93, 120]
        y = batch["y"].to(DEVICE)               # [B, 93, 480]
        model.Y_inv = batch["Y_inv"].to(DEVICE) # dynamically assign if needed
        model.area = batch["area"].to(DEVICE)
        model.Y = batch["Y"].to(DEVICE)

        optimizer.zero_grad()
        y_pred = model(x)                       # → [B, 93, 480]
        loss = lsd_loss(y_pred, y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    avg_loss = total_loss / len(loader)
    print(f"Epoch {epoch+1}/{EPOCHS} | LSD Loss: {avg_loss:.4f}")

# --- Save the model ---
torch.save(model.state_dict(), "scnn_model_test10.pth")
print("✅ Model saved as scnn_model_test10.pth")


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [1, 120] but got: [1, 81].