In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
import numpy as np
import os
from typing import Optional
from IPython import display
from importlib import reload


# Move from the notebook folder to the repository root so that `diff_gfdn`, `src`, etc.
# are importable. Guarded so re-running this cell does not keep climbing directories.
if not Path('diff_gfdn').is_dir():
    os.chdir('..')

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset, custom_collate, RoomDataset
from diff_gfdn.gain_filters import OneHotEncoding, SinusoidalEncoding
from src.run_model import load_and_validate_config
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset


### Load the configuration and visualise the 2D mesh of the geometry

In [None]:
# Load and validate the DiffGFDN configuration used throughout this notebook.
config_name = 'treble_data_grid_training_full_band_colorless_loss.yml'
config_path = 'data/config/'
fig_path = 'figures/'
config_file = f'{config_path}{config_name}'
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# The three-room FDTD dataset has a dedicated loader; any other dataset path is
# converted from common-slope RIRs into a room dataset.
is_three_room_fdtd = "3room_FDTD" in config_dict.room_dataset_path
if not is_three_room_fdtd:
    room_data = convert_common_slopes_rir_to_room_dataset(
        config_dict.room_dataset_path,
        num_freq_bins=config_dict.trainer_config.num_freq_bins,
    )
else:
    dataset_dir = Path(config_dict.room_dataset_path).resolve()
    room_data = ThreeRoomDataset(dataset_dir, config_dict)

# One group per room in the dataset.
config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})
room_data.plot_2D_meshgrid(room_data.mesh_2D)

### Check if the one-hot encoding works as expected

In [None]:
def plot_2D_receiver_points(room_data: RoomDataset, rec_pos: torch.Tensor, col: str = 'k',
                            title: Optional[str] = None, margin: float = 0.5):
    """Scatter-plot receiver positions in the X-Y plane over the extent of the rooms.

    Used to check visually which receiver positions end up in a dataset split
    (e.g. the closest mesh points returned by the one-hot encoder).

    Args:
        room_data: dataset object exposing ``room_dims`` and ``room_start_coord``
            with one entry per room; only the first two coordinates are used.
        rec_pos: receiver positions (tensor or array-like) whose first two
            columns are X and Y. Tensors are detached and moved to the CPU.
        col: matplotlib colour of the markers.
        title: optional suffix appended to the plot title.
        margin: extra space added beyond the furthest room edge on each axis.

    Returns:
        The created ``(fig, ax)`` pair, so callers can overlay more points.
    """
    if torch.is_tensor(rec_pos):
        rec_pos = rec_pos.detach().cpu().numpy()
    rec_pos = np.asarray(rec_pos)

    fig, ax = plt.subplots()
    ax.scatter(rec_pos[:, 0], rec_pos[:, 1], color=col, marker='x')

    # Axis limits run from the origin to the furthest room edge over *all* rooms,
    # instead of assuming the last room listed is the outermost one.
    room_end = (np.asarray(room_data.room_dims, dtype=float)[:, :2]
                + np.asarray(room_data.room_start_coord, dtype=float)[:, :2])
    ax.set_xlim(0, room_end[:, 0].max() + margin)
    ax.set_ylim(0, room_end[:, 1].max() + margin)

    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_title('Receiver grid' if title is None else f'Receiver grid, {title}')
    return fig, ax

In [None]:
import logging

# Add the number of groups (one per room) to the config dictionary.
config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})

# The dataset's sample rate takes precedence over the one in the config file.
if config_dict.sample_rate != room_data.sample_rate:
    logging.getLogger(__name__).warning(
        "Config sample rate does not match data, altering it")
    # Update via model_copy (as for num_groups) rather than mutating the config in place.
    config_dict = config_dict.model_copy(update={"sample_rate": room_data.sample_rate})

# Get the training config.
trainer_config = config_dict.trainer_config

# Prepare the training and validation data for DiffGFDN.
# NOTE(review): 0.8 is presumably the training fraction of the split - confirm against load_dataset.
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, 0.8,
    trainer_config.batch_size, shuffle=True, drop_last=True)

#### Training points

In [None]:
# Map every training listener position to its closest point on the 2D mesh via the
# one-hot encoder, then plot those points to check the encoding.
encoder = OneHotEncoding()
closest_points = []
for data in train_dataset:
    _, closest_point, _ = encoder(data['mesh_2D'], data['listener_position'])
    # Flatten to (num_points, coord_dim) so batches of unequal size can still be joined.
    closest_points.append(closest_point.reshape(-1, closest_point.shape[-1]))

closest_points = torch.cat(closest_points, dim=0)

fig, ax = plot_2D_receiver_points(room_data, closest_points)
plt.show()

#### Validation points

In [None]:
# Same check for the validation split, plotted in red.
encoder = OneHotEncoding()
closest_points_valid = []
for data in valid_dataset:
    _, closest_point, _ = encoder(data['mesh_2D'], data['listener_position'])
    # Flatten per batch; the last validation batch may be smaller, which torch.stack would reject.
    closest_points_valid.append(closest_point.reshape(-1, closest_point.shape[-1]))

closest_points_valid = torch.cat(closest_points_valid, dim=0)
fig, ax = plot_2D_receiver_points(room_data, closest_points_valid, col='r')
plt.show()

### Plot the Fourier encodings to see whether they capture spatial variations

In [None]:
from slope2noise.rooms import RoomGeometry
room = RoomGeometry(room_data.sample_rate,
                    room_data.num_rooms,
                    np.array(room_data.room_dims),
                    np.array(room_data.room_start_coord),
                    aperture_coords=room_data.aperture_coords)

num_fourier_features = config_dict.output_filter_config.num_fourier_features
# Sine and cosine terms for each of the 3 position coordinates, per order.
num_encoded_features = 3 * 2 * num_fourier_features
encoder = SinusoidalEncoding(num_fourier_features=num_fourier_features)


def _as_numpy(x) -> np.ndarray:
    """Return `x` as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


encoded_batches, position_batches = [], []
for data in train_dataset:
    encoded_batches.append(_as_numpy(encoder(data['norm_listener_position'])))
    position_batches.append(_as_numpy(data['listener_position']))

# Keep only the receivers actually visited. The training split presumably covers only part
# of the num_rec receivers (0.8 split, drop_last=True), so a num_rec-sized buffer would keep
# zero rows that get plotted as spurious points at the origin.
encoded_positions = np.concatenate(encoded_batches, axis=0)
shuffled_positions = np.concatenate(position_batches, axis=0)
assert encoded_positions.shape == (shuffled_positions.shape[0], num_encoded_features)

In [None]:
# Column layout assumed by this slicing: for each order k, 3 sine columns followed by
# 3 cosine columns (one per position coordinate).
num_dims = 3
features_per_order = 2 * num_dims
source_position = np.squeeze(np.array(room_data.source_position))
# Make sure the output directory exists before figures are saved into it.
os.makedirs('figures/spatial_encodings', exist_ok=True)

for k in range(config_dict.output_filter_config.num_fourier_features):
    start = k * features_per_order
    room.plot_amps_at_receiver_points(
        shuffled_positions,
        source_position,
        encoded_positions[:, start:start + num_dims].T,
        scatter_plot=False,
        title=f'Sine Fourier encodings, order = {k}',
        save_path=f'figures/spatial_encodings/sine_encoding_order={k}.png')

    room.plot_amps_at_receiver_points(
        shuffled_positions,
        source_position,
        encoded_positions[:, start + num_dims:start + features_per_order].T,
        scatter_plot=False,
        title=f'Cosine Fourier encodings, order = {k}',
        save_path=f'figures/spatial_encodings/cosine_encoding_order={k}.png')

### Spatial patches for training the CNN
To train the CNN, we need to discretise the grid of receiver locations into square patches. Plot the positions in each patch to investigate how uniformly they are distributed — the patches should not overlap.

In [None]:
from diff_gfdn.dataloader import InputFeatures, to_device
from spatial_sampling.config import SpatialSamplingConfig
from spatial_sampling.dataloader import (SpatialRoomDataset, parse_room_data, load_dataset, SpatialSamplingDataset, split_dataset_by_resolution)

In [None]:
# Load and validate the spatial-sampling configuration for the CNN experiments.
config_name = 'treble_data_grid_training_1000Hz_directional_spatial_sampling_test_cnn.yml'
config_path = 'data/config/spatial_sampling/'
fig_path = 'figures/'
config_file = f'{config_path}{config_name}'
config_dict = load_and_validate_config(config_file, SpatialSamplingConfig)
# Override the batch size from the YAML file for this visualisation.
config_dict = config_dict.model_copy(update={"batch_size": 16})

dataset_dir = Path(config_dict.room_dataset_path).resolve()
room_data = parse_room_data(dataset_dir)

# Candidate grid resolutions: num_grid_spacing, ..., 2, 1 times the base grid spacing.
spacing_multiples = np.arange(config_dict.num_grid_spacing, 0, -1)
grid_resolution_m = spacing_multiples * room_data.grid_spacing_m

# Build the full spatial-sampling dataset and move it to the configured device.
dataset = to_device(SpatialSamplingDataset(config_dict.device, room_data),
                    config_dict.device)

In [None]:
import spatial_sampling
reload(spatial_sampling.dataloader)
from spatial_sampling.dataloader import load_dataset

# Seconds to wait between batches so each batch can be inspected.
pause_s = 1.0

# Loop over the different grid resolutions.
for grid_res in grid_resolution_m:
    # Round once and use the same value for the split, the loader and the title,
    # so all three refer to the same grid resolution.
    grid_res_m = np.round(grid_res, 1)

    # Split data into training and validation sets.
    train_set, valid_set = split_dataset_by_resolution(dataset, grid_res_m)

    train_dataset, valid_dataset, dataset_ref = load_dataset(
        room_data,
        config_dict.device,
        grid_resolution_m=grid_res_m,
        network_type=config_dict.network_type,
        batch_size=config_dict.batch_size)

    # Plot all points in the training subset (red crosses).
    train_subset_listener_pos = dataset_ref.listener_positions[train_set.indices, :2]
    fig, ax = plot_2D_receiver_points(room_data, train_subset_listener_pos, col='r',
                                      title=f'Grid spacing = {grid_res_m}m')

    # Overlay the points of each batch (black circles), one batch at a time.
    for batch in train_dataset:
        cur_rec_pos = batch['listener_position']
        if torch.is_tensor(cur_rec_pos):
            # matplotlib cannot plot tensors that are not on the CPU.
            cur_rec_pos = cur_rec_pos.detach().cpu().numpy()
        ax.scatter(cur_rec_pos[:, 0], cur_rec_pos[:, 1], color='k', marker='o')
        display.display(fig)  # Display the updated figure
        display.clear_output(wait=True)  # Clear the previous output once new output arrives
        plt.pause(pause_s)

    plt.show()