In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
import numpy as np
import os

os.chdir('..')  # This changes the working directory to DiffGFDN

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset, custom_collate, RoomDataset
from diff_gfdn.gain_filters import OneHotEncoding, SinusoidalEncoding
from src.run_model import load_and_validate_config
from diff_gfdn.solver import convert_common_slopes_rir_to_room_dataset


### Visualise the 3D mesh of the geometry

In [None]:
# Paths are relative to the working directory set in the first cell.
fig_path = 'figures/'
config_path = 'data/config/'
config_name = 'treble_data_grid_training_1000Hz_colorless_loss.yml'
config_file = f'{config_path}{config_name}'

# Read the YAML config through the project loader, using DiffGFDNConfig as
# the target config class.
config_dict = load_and_validate_config(config_file, DiffGFDNConfig)

In [None]:
# Choose the dataset loader from the dataset path: paths containing
# "3room_FDTD" are read with ThreeRoomDataset, everything else goes through
# the common-slopes converter.
dataset_path = config_dict.room_dataset_path
if "3room_FDTD" not in dataset_path:
    room_data = convert_common_slopes_rir_to_room_dataset(
        dataset_path,
        num_freq_bins=config_dict.trainer_config.num_freq_bins,
    )
else:
    room_data = ThreeRoomDataset(Path(dataset_path).resolve(), config_dict)

# The number of groups in the config follows the number of rooms in the data.
config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})

# Draw the dataset's 3D mesh.
room_data.plot_3D_meshgrid(room_data.mesh_3D)

### Check if the one-hot encoding works as expected

In [None]:
def plot_3D_receiver_points(room_data: RoomDataset, rec_points: torch.Tensor):
    """Scatter-plot receiver positions in 3D to check the one-hot encoding.

    Args:
        room_data: dataset whose last ``room_dims`` and ``room_start_coord``
            entries are used to set the axis limits.
        rec_points: tensor whose first three columns are read as the x, y and
            z coordinates of the points to draw.
    """
    points = rec_points.cpu().detach().numpy()
    x_rec, y_rec, z_rec = (points[:, axis] for axis in range(3))

    # The axis limits come from the last room's dimensions and start
    # coordinate, plus a 0.5 margin.
    last_dims = room_data.room_dims[-1]
    last_start = room_data.room_start_coord[-1]

    ax = plt.figure().add_subplot(111, projection='3d')

    # Plain black crosses; no data-driven colouring.
    ax.scatter(x_rec, y_rec, z_rec, color='k', marker='x')

    ax.set_xlim(0, last_dims[0] + last_start[0] + 0.5)
    ax.set_ylim(0, last_dims[1] + last_start[1] + 0.5)
    ax.set_zlim(0, last_dims[-1] + 0.5)

    ax.set(xlabel='X axis',
           ylabel='Y axis',
           zlabel='Z axis',
           title='Receiver grid')

    plt.show()

In [None]:
import logging

# The notebook has no logger of its own; without a configured handler,
# warnings from this one are printed to stderr by logging's last-resort handler.
logger = logging.getLogger(__name__)

# add number of groups to the config dictionary
config_dict = config_dict.model_copy(update={"num_groups": room_data.num_rooms})

# The data is the source of truth for the sample rate: if the config
# disagrees, warn and take the dataset's value. model_copy is used (as for
# num_groups above) so the existing config object is never mutated in place.
if config_dict.sample_rate != room_data.sample_rate:
    logger.warning("Config sample rate does not match data, alterning it")
    config_dict = config_dict.model_copy(
        update={"sample_rate": room_data.sample_rate})

# get the training config.
trainer_config = config_dict.trainer_config

# prepare the training and validation data for DiffGFDN
# NOTE(review): the positional 1.0 is presumably the train/validation split
# ratio (all data used for training) - confirm against load_dataset.
# shuffle=False keeps the batches in dataset order.
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, 1.0,
    trainer_config.batch_size, shuffle=False)

In [None]:
# The encoder is constructed without arguments, so one instance is reused
# for every batch instead of rebuilding it per iteration.
encoder = OneHotEncoding()

closest_points = []
for data in train_dataset:
    # Only the second return value is needed here.
    # NOTE(review): presumably the mesh point nearest to each listener
    # position - confirm against OneHotEncoding.
    one_hot_encode, closest_point, _ = encoder(data['mesh_3D'],
                                               data['listener_position'])
    # Flatten each batch to (num_points, num_coords) before collecting, so a
    # smaller final batch does not break the concatenation below
    # (torch.stack requires every batch to have the same shape).
    closest_points.append(closest_point.reshape(-1, closest_point.shape[-1]))

# Same result as stack + view(-1, D) when all batches have equal size.
closest_points = torch.cat(closest_points, dim=0)

plot_3D_receiver_points(room_data, closest_points)

### Plot the Fourier encodings to see if they can capture spatial variations

In [None]:
from slope2noise.rooms import RoomGeometry


def _to_numpy(values):
    """Return `values` as a numpy array, moving torch tensors to the CPU first.

    Batches may live on `trainer_config.device` (e.g. a GPU) or track
    gradients; neither can be written into a numpy array directly.
    """
    if torch.is_tensor(values):
        return values.detach().cpu().numpy()
    return np.asarray(values)


# Room geometry object used by the plotting cell that follows.
room = RoomGeometry(room_data.sample_rate,
                    room_data.num_rooms,
                    np.array(room_data.room_dims),
                    np.array(room_data.room_start_coord),
                    aperture_coords=room_data.aperture_coords)

num_fourier_features = config_dict.output_filter_config.num_fourier_features
# 3 coordinates x 2 values per coordinate for every Fourier feature.
# NOTE(review): the 2 is presumably sine + cosine - confirm against
# SinusoidalEncoding.
num_encoded_features = 3 * 2 * num_fourier_features
encoded_positions = np.zeros((room_data.num_rec, num_encoded_features))
shuffled_positions = np.zeros((room_data.num_rec, 3))

# One encoder instance serves every batch; its only argument is constant.
encoder = SinusoidalEncoding(num_fourier_features=num_fourier_features)

# Fill the arrays batch by batch using a running row offset, so the rows line
# up even if some batch holds fewer than `trainer_config.batch_size` items.
row_offset = 0
for data in train_dataset:
    batch_positions = _to_numpy(data['listener_position']).reshape(-1, 3)
    batch_encodings = _to_numpy(
        encoder(data['norm_listener_position'])).reshape(-1, num_encoded_features)

    rows = slice(row_offset, row_offset + batch_positions.shape[0])
    encoded_positions[rows, :] = batch_encodings
    shuffled_positions[rows, :] = batch_positions
    row_offset += batch_positions.shape[0]

In [None]:
# For every Fourier order, draw two maps over the receiver positions: columns
# [6k, 6k+3) of the encoding are plotted as the sine part and columns
# [6k+3, 6k+6) as the cosine part.
# NOTE(review): this column layout is assumed from the slicing below - confirm
# it matches the output ordering of SinusoidalEncoding.
for k in range(config_dict.output_filter_config.num_fourier_features):
    first_col = k * 6
    panels = (
        (first_col,
         f'Sine Fourier encodings, order = {k}',
         f'figures/spatial_encodings/sine_encoding_order={k}.png'),
        (first_col + 3,
         f'Cosine Fourier encodings, order = {k}',
         f'figures/spatial_encodings/cosine_encoding_order={k}.png'),
    )
    for col_start, panel_title, panel_path in panels:
        room.plot_amps_at_receiver_points(
            shuffled_positions,
            np.squeeze(np.array(room_data.source_position)),
            encoded_positions[:, col_start:col_start + 3].T,
            scatter_plot=False,
            cur_freq_hz=None,
            title=panel_title,
            save_path=panel_path)