In [None]:
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset, custom_collate, RoomDataset
from diff_gfdn.gain_filters import OneHotEncoding

### Visualise the 3D mesh of the geometry

In [None]:
# Build the default configuration, load the three-room dataset it points to,
# and render that dataset's 3D mesh.
config_dict = DiffGFDNConfig()
room_dataset_path = Path(config_dict.room_dataset_path).resolve()
room_data = ThreeRoomDataset(room_dataset_path)
room_data.plot_3D_meshgrid(room_data.mesh_3D)

### Check if the one-hot encoding works as expected

In [None]:
def plot_3D_receiver_points(room_data: RoomDataset, rec_points: torch.Tensor):
    """Plot the 3D receiver points to see if one-hot encoding works"""
    rec_pos = rec_points.cpu().detach().numpy()
    x_rec = rec_pos[:, 0]
    y_rec = rec_pos[:, 1]
    z_rec = rec_pos[:, 2]

    # Plot using scatter without any additional data for color
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # Plot the X, Y, Z points
    ax.scatter(x_rec, y_rec, z_rec, color='k', marker='x')

    # Set the limits for all axes
    ax.set_xlim(0,
                room_data.room_dims[-1][0] + room_data.room_start_coord[-1][0] + 0.5)
    ax.set_ylim(0,
                room_data.room_dims[-1][1] + room_data.room_start_coord[-1][1] + 0.5)
    ax.set_zlim(0, room_data.room_dims[-1][-1] + 0.5)


    # Labels and title
    ax.set_xlabel('X axis')
    ax.set_ylabel('Y axis')
    ax.set_zlabel('Z axis')
    ax.set_title('Receiver grid')

    # Show the plot
    plt.show()

In [None]:
# add number of groups to the config dictionary
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

if config_dict.sample_rate != room_data.sample_rate:
    logger.warn("Config sample rate does not match data, alterning it")
    config_dict.sample_rate = sample_rate

# get the training config
trainer_config = config_dict.trainer_config

# prepare the training and validation data for DiffGFDN
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, trainer_config.train_valid_split,
    trainer_config.batch_size)

In [None]:
closest_points = []
for data in train_dataset:
    encoder = OneHotEncoding()
    one_hot_encode, closest_point, _ = encoder(data['mesh_3D'], data['listener_position'])
    closest_points.append(closest_point)

closest_points = torch.stack(closest_points)
closest_points = closest_points.view(-1, closest_points.shape[-1])

plot_3D_receiver_points(room_data, closest_points)