# In [None]:
def create_srir_sofa(
    filepath,
    rirs,
    source_pos,
    mic_pos,
    db_name="Default_db",
    room_name="Room_name",
    listener_name="mic",
    sr=24000,
    comment="N/A",
):
    """Write a set of room impulse responses to a SOFA (netCDF4) file.

    Any existing file at ``filepath`` is removed first.  The dataset is opened
    in a ``with`` block so it is closed even if a write fails, and a partially
    written file is deleted on failure so that a later "does the output
    already exist?" check does not mistake it for a finished result.

    Args:
        filepath: Destination path (``str`` or ``os.PathLike``).
        rirs: Array of shape ``(M, R, N)`` -- measurements x receivers x
            samples -- stored as ``Data.IR``.
        source_pos: Array of shape ``(M, 3)`` stored as ``SourcePosition``.
        mic_pos: Values assigned to ``ListenerPosition`` (dimensions
            ``("M", "C")``).
        db_name: Stored as ``DatabaseName`` and used as the first half of
            ``Title``.
        room_name: Stored as ``RoomShortName`` and used as the second half of
            ``Title``.
        listener_name: Stored as ``ListenerShortName``.
        sr: Sampling rate in hertz, stored as ``Data.SamplingRate``.
        comment: Stored as the ``Comment`` attribute.

    Raises:
        AssertionError: If ``rirs`` is not 3-dimensional or ``source_pos`` is
            not of shape ``(M, 3)``.
    """
    print("Starting create_srir_sofa function")

    M = rirs.shape[0]  # measurements
    R = rirs.shape[1]  # receivers
    N = rirs.shape[2]  # samples per impulse response
    E = 1  # emitters
    I = 1  # singleton dimension
    C = 3  # coordinate components

    print(f"Shapes: M={M}, R={R}, N={N}, E={E}, I={I}, C={C}")

    # The first check only fails when `rirs` has more than three dimensions.
    assert rirs.shape == (M, R, N), f"RIRs shape mismatch: expected {(M, R, N)}, got {rirs.shape}"
    assert source_pos.shape == (M, C), f"Source position shape mismatch: expected {(M, C)}, got {source_pos.shape}"

    print(f"Checking if file exists: {filepath}")
    if os.path.exists(filepath):
        print(f"Overwriting {filepath}")
        os.remove(filepath)

    print("Creating Dataset")
    try:
        with Dataset(os.fspath(filepath), "w", format="NETCDF4") as rootgrp:
            _populate_srir_sofa(
                rootgrp,
                rirs,
                source_pos,
                mic_pos,
                {"M": M, "N": N, "E": E, "R": R, "I": I, "C": C},
                db_name,
                room_name,
                listener_name,
                sr,
                comment,
            )
            print("Closing file")
    except BaseException:
        # Never leave a truncated file behind; re-raise the original error.
        if os.path.exists(filepath):
            os.remove(filepath)
        raise

    print(f"SOFA file saved to {filepath}")
    print("create_srir_sofa function completed")


def _populate_srir_sofa(
    rootgrp,
    rirs,
    source_pos,
    mic_pos,
    dims,
    db_name,
    room_name,
    listener_name,
    sr,
    comment,
):
    """Fill an open netCDF4 dataset with SOFA attributes, dimensions and variables."""

    def add_variable(name, dimensions, values, **attrs):
        # Create one float64 variable, attach its attributes, then store data.
        print(f"Creating {name}")
        var = rootgrp.createVariable(name, "f8", dimensions)
        for key, value in attrs.items():
            setattr(var, key, value)
        var[:] = values
        return var

    cartesian = {"Units": "metre", "Type": "cartesian"}
    M, N, E, R, I, C = (dims[key] for key in ("M", "N", "E", "R", "I", "C"))

    print("Setting Required Attributes")
    # One timestamp so DateCreated and DateModified can never disagree.
    timestamp = time.ctime()
    rootgrp.Conventions = "SOFA"
    rootgrp.Version = "2.1"
    rootgrp.SOFAConventions = "SingleRoomSRIR"
    rootgrp.SOFAConventionsVersion = "1.0"
    rootgrp.APIName = "pysofaconventions"
    rootgrp.APIVersion = "0.1.5"
    rootgrp.AuthorContact = "chris.ick@nyu.edu"
    rootgrp.Organization = "Music and Audio Research Lab - NYU"
    rootgrp.License = "Use whatever you want"
    rootgrp.DataType = "FIR"
    rootgrp.DateCreated = timestamp
    rootgrp.DateModified = timestamp
    rootgrp.Title = db_name + " - " + room_name
    rootgrp.RoomType = "shoebox"
    rootgrp.DatabaseName = db_name
    rootgrp.ListenerShortName = listener_name
    rootgrp.RoomShortName = room_name
    rootgrp.Comment = comment

    print("Creating Required Dimensions")
    for dim_name in ("M", "N", "E", "R", "I", "C"):
        rootgrp.createDimension(dim_name, dims[dim_name])

    print("Creating Required Variables")
    add_variable("ListenerPosition", ("M", "C"), mic_pos, **cartesian)
    add_variable("ListenerUp", ("I", "C"), np.asarray([0, 0, 1]), **cartesian)
    add_variable("ListenerView", ("I", "C"), np.asarray([1, 0, 0]), **cartesian)
    # Previously declared "spherical" while using the cartesian unit "metre".
    # The stored value is all zeros, so declaring it cartesian keeps the data
    # identical and makes Type and Units agree.
    add_variable("EmitterPosition", ("E", "C", "I"), np.zeros((E, C, I)), **cartesian)
    add_variable("SourcePosition", ("M", "C"), source_pos, **cartesian)
    add_variable("SourceUp", ("I", "C"), np.asarray([0, 0, 1]), **cartesian)
    add_variable("SourceView", ("I", "C"), np.asarray([1, 0, 0]), **cartesian)
    add_variable("ReceiverPosition", ("R", "C", "I"), np.zeros((R, C, I)), **cartesian)
    add_variable("Data.SamplingRate", ("I",), sr, Units="hertz")
    add_variable("Data.Delay", ("I", "R"), np.zeros((I, R)))
    add_variable(
        "Data.IR",
        ("M", "R", "N"),
        rirs,
        ChannelOrdering="acn",  # standard ambisonic channel ordering
        Normalization="sn3d",
    )

# In [None]:
from pathlib import Path
import numpy as np
import os
from netCDF4 import Dataset
import time
import trimesh
import matplotlib.pyplot as plt
from PIL import Image
import csv 
from rlr_audio_propagation import Config, Context, ChannelLayout, ChannelLayoutType
import matplotlib.pyplot as plt

# Database name passed to create_srir_sofa as `db_name`.
GIBSON_DB_NAME = "GIBSON"
# Directory that the __main__ loop scans for "*.glb" scene files.
DATASET_DIR = "/datasets/soundspaces/scene_datasets/gibson_copy"
# Output directory for the SOFA files, CSV files and plots.
dest_path_sofa = Path("/datasets/soundspaces/ss_rooms_tetra")
# NOTE: executed at import time, so importing this module creates the directory.
dest_path_sofa.mkdir(parents=True, exist_ok=True)
# Format tags; each one is embedded in an output name "soundspaces_<fmt>_<room>.sofa".
audio_fmts = ["mic"]


def sofa_file_exists(glb_file, dest_path_sofa, audio_fmts):
    """Return True if a SOFA output for this scene exists for any format.

    The expected file for a format is
    ``dest_path_sofa / "soundspaces_<fmt>_<scene stem>.sofa"``, where the
    scene stem is ``glb_file.name`` without its extension.
    """

    def expected_output(fmt):
        # Same naming scheme the writer uses for its SOFA files.
        scene_stem = os.path.splitext(glb_file.name)[0]
        return dest_path_sofa / f"soundspaces_{fmt}_{scene_stem}.sofa"

    return any(expected_output(fmt).exists() for fmt in audio_fmts)


def prepare_soundspaces(glb_file, dest_path_sofa, audio_fmts=["mic"]):
    """Simulate impulse responses for one scene mesh and export the results.

    Steps, in order:
      1. Load the mesh from ``glb_file`` and run trimesh's repair helpers on a
         copy of its vertices/faces.
      2. Pick a microphone-array centre inside the mesh and place four
         listeners 6 cm around it.
      3. Cast horizontal rays from the centre, sample candidate source
         positions along them, then randomly shift each source vertically.
      4. Run the acoustic simulation, show the scene, and save a two-panel
         scatter plot to ``dest_path_sofa / "<room>_plots.png"``.
      5. Call ``prepare_sofa`` to write the CSV and SOFA outputs.

    Args:
        glb_file: Path of the scene file; passed to ``trimesh.load`` and
            ``os.path.basename``.
        dest_path_sofa: Output directory; combined with ``/``, so it must
            support that operator (e.g. ``pathlib.Path``).
        audio_fmts: Format tags forwarded unchanged to ``prepare_sofa``.
            NOTE(review): mutable default; it is not mutated in this function.

    Side effects:
        Rebinds several module-level names (see the ``global`` statement),
        sets ``PIL.Image.MAX_IMAGE_PIXELS``, draws from NumPy's global random
        state, opens a scene viewer, and writes files.
    """
    # These names are published at module level, so they remain reachable
    # after the call returns (e.g. from another notebook cell).
    global source_spheres, cfg, ctx, scene, mic_positions, source_positions, adjusted_source_positions

    # Reset the module-level state left over from a previous call.
    # NOTE(review): cfg, ctx and scene are all re-created again further down,
    # so these three initial instances are never used.
    source_spheres = []
    cfg = Config()
    ctx = Context(cfg)
    scene = trimesh.Scene()
    mic_positions = []
    source_positions = []
    adjusted_source_positions = []

    # Disable Pillow's image-size limit -- presumably so large textures
    # embedded in the scene file can be loaded (TODO confirm).
    Image.MAX_IMAGE_PIXELS = None
    mesh = trimesh.load(glb_file, force='mesh')

    # MESH REPAIR PROCESS
    # Work on a geometry-only copy (vertices and faces) of the loaded mesh.
    vertices = mesh.vertices.copy()
    faces = mesh.faces.copy()
    new_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    broken_faces = trimesh.repair.broken_faces(new_mesh)
    print(f"Number of broken faces: {len(broken_faces)}")
    trimesh.repair.fix_inversion(new_mesh)
    trimesh.repair.fix_normals(new_mesh)
    trimesh.repair.fix_winding(new_mesh)
    new_mesh.fill_holes()
    # Paint everything white, then mark the faces that were broken *before*
    # the repair in red.
    new_mesh.visual.face_colors = np.ones((len(new_mesh.faces), 4)) * 255
    new_mesh.visual.face_colors[broken_faces] = [255, 0, 0, 255]
    # Re-count after the repair so the two printed numbers can be compared.
    broken_faces = trimesh.repair.broken_faces(new_mesh)
    print(f"Number of broken faces: {len(broken_faces)}")

    scene = trimesh.Scene()
    scene.add_geometry(new_mesh)

    cfg = Config()

    source_spheres = []

    def add_sphere(scene, pos, color=[0,0,0], r=0.2):
        """Add a sphere of radius ``r`` at ``pos`` to ``scene`` and return it."""
        sphere = trimesh.creation.uv_sphere(radius=r)
        sphere.apply_translation(pos)
        sphere.visual.face_colors = color
        scene.add_geometry(sphere)
        return sphere

    def is_point_inside_mesh(mesh, point):
        """Return the result of ``mesh.contains`` for a single point."""
        return mesh.contains([point])[0]

    def get_random_point_inside_mesh(mesh, min_distance_from_surface=0.2):
        """Rejection-sample a point inside ``mesh`` away from its surface.

        NOTE(review): there is no iteration cap, so this loops until a
        sample inside the mesh and far enough from the surface is drawn.
        """
        while True:
            # Uniform sample within the mesh's axis-aligned bounding box.
            point = np.random.uniform(mesh.bounds[0], mesh.bounds[1])
            if is_point_inside_mesh(mesh, point):
                # Check distance from surface
                _, distance, _ = mesh.nearest.on_surface([point])
                if distance[0] >= min_distance_from_surface:
                    return point

    def calculate_weighted_average_ray_length(mesh, point, num_rays=100):
        """Return a distance-squared-weighted mean ray length from ``point``.

        Rays are cast in ``num_rays`` random directions; the result is
        ``sum(d**3) / sum(d**2)``, which emphasises the longer rays.
        """
        angles = np.random.uniform(0, 2*np.pi, num_rays)
        elevations = np.random.uniform(-np.pi/2, np.pi/2, num_rays)
        # Unit direction vectors built from azimuth and elevation.
        directions = np.column_stack([
            np.cos(elevations) * np.cos(angles),
            np.cos(elevations) * np.sin(angles),
            np.sin(elevations)
        ])
        origins = np.tile(point, (num_rays, 1))
        distances = trimesh.proximity.longest_ray(mesh, origins, directions)

        # Apply weights to the distances (square the distances)
        weights = distances ** 2

        # Calculate weighted average
        weighted_average = np.sum(distances * weights) / np.sum(weights)

        return weighted_average

    # Find a suitable microphone position
    min_avg_ray_length = 3.0  # we can adjust this value as needed
    max_attempts = 100
    for attempt in range(max_attempts):
        mic_center = get_random_point_inside_mesh(new_mesh)
        avg_ray_length = calculate_weighted_average_ray_length(new_mesh, mic_center)

        if avg_ray_length >= min_avg_ray_length:
            print(f"Found suitable microphone position after {attempt+1} attempts")
            break
    else:
        # for/else: reached only when no attempt hit the threshold; the last
        # sampled mic_center is kept.
        print(f"Could not find a suitable position after {max_attempts} attempts. Using the last attempted position.")

    mic_radius = 0.06
    # (theta, phi) pairs in degrees, converted by spherical_to_cartesian
    # below -- presumably a tetrahedral capsule layout (TODO confirm).
    mic_positions = [
        (55, 45),
        (125, 315),
        (125, 135),
        (55, 225)
    ]

    def spherical_to_cartesian(r, theta, phi):
        """Convert (r, theta, phi) in degrees to (x, y, z).

        ``theta`` is measured from the +z axis and ``phi`` is the azimuth in
        the x-y plane, as fixed by the formulas below.
        """
        theta_rad = np.radians(theta)
        phi_rad = np.radians(phi)
        x = r * np.sin(theta_rad) * np.cos(phi_rad)
        y = r * np.sin(theta_rad) * np.sin(phi_rad)
        z = r * np.cos(theta_rad)
        return x, y, z

    # Capsule offsets relative to the array centre, then absolute positions.
    mic_cartesian = [spherical_to_cartesian(mic_radius, theta, phi) for theta, phi in mic_positions]
    mic_absolute_positions = [mic_center + np.array(pos) for pos in mic_cartesian]

    # Add microphone spheres to the scene
    for mic_pos in mic_absolute_positions:
        add_sphere(scene, mic_pos, [255, 0, 0], r=0.02)  # Red color for microphones

    # Build the simulation context: one object carrying the repaired mesh.
    ctx = Context(cfg)
    ctx.add_object()
    ctx.set_object_position(0, [0, 0, 0])
    ctx.add_mesh_vertices(new_mesh.vertices.flatten().tolist())
    ctx.add_mesh_indices(new_mesh.faces.flatten().tolist(), 3, "default")
    ctx.finalize_object_mesh(0)

    # Add listeners (microphones)
    for i, mic_pos in enumerate(mic_absolute_positions):
        ctx.add_listener(ChannelLayout(ChannelLayoutType.Mono, 1))
        ctx.set_listener_position(i, mic_pos.tolist())

    # First sample a large number of evenly spaced angles
    num_initial_rays = 200 #1500
    initial_angles = np.linspace(0, 2*np.pi, num_initial_rays, endpoint=False)
    # Horizontal rays only: the z component of every direction is zero.
    initial_ray_directions = np.column_stack((np.cos(initial_angles), np.sin(initial_angles), np.zeros_like(initial_angles)))

    # Get distances for these initial rays
    ray_origins = np.tile(mic_center, (num_initial_rays, 1))
    distances = trimesh.proximity.longest_ray(new_mesh, ray_origins, initial_ray_directions)

    # Cap ray lengths at 10.
    max_distance = 10.0
    distances = np.minimum(distances, max_distance)

    # Calculate the number of rays to keep based on distances
    num_rays_to_keep = 100 #1000
    probabilities = distances / np.sum(distances) # longer distances get higher probabilities
    # NOTE(review): sampling without replacement needs at least
    # num_rays_to_keep non-zero probabilities; np.random.choice raises
    # ValueError otherwise (and on NaN when all distances are zero).
    selected_indices = np.random.choice(num_initial_rays, size=num_rays_to_keep, replace=False, p=probabilities)

    # Sort the selected indices to maintain the order
    selected_indices.sort()

    # Only use the selected rays
    ray_directions = initial_ray_directions[selected_indices]
    distances = distances[selected_indices]

    # Draw each kept ray in the scene as a line from the centre to its end.
    for i, direction in enumerate(ray_directions):
        ray_end = mic_center + direction * distances[i]
        ray_points = np.vstack((mic_center, ray_end))
        ray_path = trimesh.load_path(ray_points)
        scene.add_geometry(ray_path)

    # Sample points along the rays
    num_sources = 1000
    d = distances**2  # squaring makes it more likely to choose longer rays to sample from
    idx_rays = np.random.choice(np.arange(len(distances)), size=num_sources, replace=True, p=d/d.sum())
    # sqrt of a uniform draw biases the position toward the far end of a ray.
    dist_proportion = np.sqrt(np.random.uniform(0, 1, size=num_sources))
    source_dist = distances[idx_rays] * dist_proportion

    min_distance = 0.2
    min_distance_from_mic = 0.1
    source_positions = []
    for i, idx in enumerate(idx_rays):
        # Up to 10 tries per candidate; one that never satisfies both spacing
        # constraints is dropped, so fewer than num_sources may be placed.
        attempts = 0
        while attempts < 10:
            new_pos = mic_center + ray_directions[idx] * source_dist[i]
            # Accept only if far enough from every accepted source and from
            # every microphone capsule.
            if (not source_positions or all(np.linalg.norm(new_pos - pos) >= min_distance for pos in source_positions)) and all(np.linalg.norm(new_pos - mic_pos) >= min_distance_from_mic for mic_pos in mic_absolute_positions):
                source_positions.append(new_pos)
                sphere = add_sphere(scene, new_pos, [0, 0, 255], r=0.05)
                source_spheres.append(sphere)
                break
            else:
                # Re-draw the distance along the same ray and try again.
                source_dist[i] = distances[idx] * np.sqrt(np.random.uniform(0, 1))
            attempts += 1

    # Add sources
    for i, position in enumerate(source_positions):
        ctx.add_source()
        ctx.set_source_position(i, position.tolist())

    def adjust_source_elevation(mesh, position):
        """Return ``position`` shifted by a random z offset that stays inside ``mesh``.

        The offset is uniform within +/- half the mesh's z extent.  Up to 11
        draws are made in total; if none lands inside the mesh the original
        position is returned unchanged.
        """
        # Calculate the total height of the mesh
        mesh_height = mesh.bounds[1][2] - mesh.bounds[0][2]
        max_elevation_change = mesh_height / 2

        elevation_change = np.random.uniform(-max_elevation_change, max_elevation_change)
        elevation_vector = np.array([0, 0, elevation_change])
        new_position = position + elevation_vector

        if is_point_inside_mesh(mesh, new_position):
            return new_position
        else:
            # If outside, try to find valid position within the mesh
            for _ in range(10):  # Try up to 10 times
                elevation_change = np.random.uniform(-max_elevation_change, max_elevation_change)
                elevation_vector = np.array([0, 0, elevation_change])
                new_position = position + elevation_vector
                if is_point_inside_mesh(mesh, new_position):
                    return new_position
            # If we couldn't find valid position, return the original
            return position

    # Adjust elevations of source positions
    # NOTE(review): the spacing constraints enforced above are not re-checked
    # after the vertical shift.
    adjusted_source_positions = []
    for i, position in enumerate(source_positions):
        new_position = adjust_source_elevation(new_mesh, position)
        adjusted_source_positions.append(new_position)

        # Update the sphere in the scene
        source_spheres[i].apply_translation(new_position - position)
        # Update source position in the simulation context
        ctx.set_source_position(i, new_position.tolist())

    # Replace the original source_positions with the adjusted ones
    source_positions = adjusted_source_positions

    print(f"Adjusted {len(source_positions)} source positions for elevation")

    # Run simulation
    ctx.simulate()
    efficiency = ctx.get_indirect_ray_efficiency()
    print(f"Overall Indirect Ray Efficiency = {efficiency}")
    scene.show()

    # Generate and save the plots
    room_name = os.path.splitext(os.path.basename(glb_file))[0]

    # Create a figure with two subplots side by side
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))

    # Top-down view
    vertices = new_mesh.vertices
    ax1.scatter(vertices[:, 0], vertices[:, 1], c='gray', alpha=0.1, s=1)
    ax1.scatter(mic_center[0], mic_center[1], c='red', s=100, label='Microphone')
    # NOTE(review): if no source was placed this array is empty and the 2-D
    # indexing below raises IndexError.
    new_sources = np.array(source_positions)
    ax1.scatter(new_sources[:, 0], new_sources[:, 1], c='blue', s=25, alpha=0.5, label='Sound Sources')
    ax1.set_xlabel('X')
    ax1.set_ylabel('Y')
    ax1.set_title(f'Top-down view of {room_name}')
    ax1.legend()
    ax1.axis('equal')
    ax1.grid(True)

    # Side view
    ax2.scatter(vertices[:, 0], vertices[:, 2], c='gray', alpha=0.1, s=1)  # X vs Z
    ax2.scatter(mic_center[0], mic_center[2], c='red', s=100, label='Microphone')
    ax2.scatter(new_sources[:, 0], new_sources[:, 2], c='blue', s=25, alpha=0.5, label='Sound Sources')
    ax2.set_xlabel('X')
    ax2.set_ylabel('Z')
    ax2.set_title(f'Side view of {room_name}')
    ax2.legend()
    ax2.axis('equal')
    ax2.grid(True)

    plt.tight_layout()

    # Save the plot
    plot_path = dest_path_sofa / f"{room_name}_plots.png"
    plt.savefig(plot_path)
    plt.close()  # Close the figure to free up memory

    print(f"Plots saved as {plot_path}")

    # Prepare SOFA file
    # NOTE(review): no room name is passed here; prepare_sofa resolves it
    # from the module-level `glb_file`, not from this function's argument.
    prepare_sofa(cfg, ctx, source_positions, mic_center, mic_absolute_positions, dest_path_sofa, audio_fmts)
    print(f"Processing completed for {glb_file}")


def prepare_sofa(cfg, ctx, source_positions, mic_center, mic_absolute_positions, dest_path_sofa, audio_fmts=None, room_name=None):
    """Export simulated impulse responses as a CSV of positions plus SOFA files.

    For every source, channel 0 of the impulse response at each listener is
    read from ``ctx``.  All responses are zero-padded at the end to the
    longest one and stacked into an array of shape
    ``(sources, listeners, samples)``.  Source positions are expressed
    relative to ``mic_center`` and rounded to 3 decimals; they are written
    once to ``<room_name>_relative_positions.csv`` and stored in each SOFA
    file, where the listener position is therefore the origin.

    Args:
        cfg: Simulation config; only ``cfg.sample_rate`` is read.
        ctx: Simulation context queried through
            ``get_ir_channel(listener_index, source_index, channel)``.
        source_positions: Sequence of absolute 3-D source positions.
        mic_center: Absolute 3-D position of the microphone-array centre.
        mic_absolute_positions: One entry per listener; only its length is
            used (entry ``i`` corresponds to listener index ``i``).
        dest_path_sofa: Output directory; must support the ``/`` operator.
        audio_fmts: Format tags; one file ``soundspaces_<fmt>_<room>.sofa`` is
            written per tag.  Defaults to ``["mic"]``.
        room_name: Scene name used in output names.  When omitted it is taken
            from the module-level ``glb_file`` (the previous behaviour).

    Raises:
        ValueError: If there are no source positions or no listeners.
    """
    if audio_fmts is None:
        audio_fmts = ["mic"]
    if room_name is None:
        # Backward-compatible fallback: the batch driver leaves the current
        # scene path in a module-level name.
        room_name = os.path.splitext(os.path.basename(glb_file))[0]

    num_sources = len(source_positions)
    num_listeners = len(mic_absolute_positions)
    if num_sources == 0 or num_listeners == 0:
        raise ValueError(
            f"Need at least one source and one listener, got "
            f"{num_sources} source(s) and {num_listeners} listener(s)"
        )

    sr = int(cfg.sample_rate)
    print(f"Total number of source positions: {num_sources}")

    # Gather positions and impulse responses once; they do not depend on the
    # output format.
    coords = []
    channels_per_source = []
    for source_index, source_position in enumerate(source_positions):
        relative_position = np.asarray(source_position, dtype=float) - mic_center
        x, y, z = (round(float(value), 3) for value in relative_position)
        coords.append([x, y, z])

        # Listeners are mono, so only channel 0 is kept.
        channels = [
            np.asarray(ctx.get_ir_channel(listener_index, source_index, 0), dtype=float)
            for listener_index in range(num_listeners)
        ]
        channels_per_source.append(channels)

        source_shape = (num_listeners, max(len(channel) for channel in channels))
        print(f"IR {source_index}:")
        print(f"  Position: ({x}, {y}, {z})")
        print(f"  Channels: {source_shape[0]}")
        print(f"  Samples: {source_shape[1]}")
        print(f"  Shape: {source_shape}")

    # Zero-filled target array: copying each channel in pads it at the end.
    max_length = max(
        len(channel) for channels in channels_per_source for channel in channels
    )
    rirs = np.zeros((num_sources, num_listeners, max_length))
    for source_index, channels in enumerate(channels_per_source):
        for listener_index, channel in enumerate(channels):
            rirs[source_index, listener_index, :len(channel)] = channel

    # One CSV row per source (previously the rows were repeated per format).
    csv_filepath = dest_path_sofa / f"{room_name}_relative_positions.csv"
    with open(csv_filepath, 'w', newline='') as csvfile:
        csv_writer = csv.writer(csvfile)
        csv_writer.writerow(['Source_Index', 'X', 'Y', 'Z'])  # Write header
        csv_writer.writerows(
            [source_index, *coord] for source_index, coord in enumerate(coords)
        )

    source_pos = np.array(coords)
    # Source positions are relative to the array centre, so the listener sits
    # at the origin for every measurement.
    mic_pos = np.tile([0, 0, 0], (num_sources, 1))

    for fmt in audio_fmts:
        # Same stem for the file name and the stored room name.  The room
        # name used to be a plain string missing its f-prefix, so the
        # placeholder text itself was written into the file.
        sofa_stem = f"soundspaces_{fmt}_{room_name}"
        filepath = dest_path_sofa / f"{sofa_stem}.sofa"
        create_srir_sofa(
            filepath,
            rirs,
            source_pos,
            mic_pos,
            db_name=GIBSON_DB_NAME,
            room_name=sofa_stem,
            listener_name="mic",
            sr=sr,
        )
    print("SOFA file has been created.")
    print(f"Final IR array shape: {rirs.shape}")
          
if __name__ == "__main__":
    # Batch driver: process every scene that does not yet have a SOFA output.
    dataset_dir = Path(DATASET_DIR)
    # NOTE: `glb_file` must remain a module-level name -- prepare_sofa reads it
    # to derive the room name when none is passed explicitly.
    # sorted() makes the processing order deterministic across runs.
    for glb_file in sorted(dataset_dir.glob("*.glb")):
        if sofa_file_exists(glb_file, dest_path_sofa, audio_fmts):
            print(f"Skipping {glb_file.name} - SOFA file already exists")
            continue
        print(f"Processing {glb_file.name}")
        # Pass the same format list the skip check used, rather than relying
        # on the function's own default.
        prepare_soundspaces(str(glb_file), dest_path_sofa, audio_fmts)