### Load dataset and save IRs

In [None]:
from diff_gfdn.dataloader import ThreeRoomDataset, load_dataset
from diff_gfdn.config.config import DiffGFDNConfig
from diff_gfdn.model import DiffGFDN
from diff_gfdn.utils import get_response, db
from diff_gfdn.losses import get_stft_torch, get_edr_from_stft

from pathlib import Path
from typing import Tuple, Optional
from numpy.typing import ArrayLike
import matplotlib.pyplot as plt
import numpy as np
import os
import torch
import torchaudio
import soundfile as sf

### Helper functions

In [None]:
def plot_spectrogram(S: torch.Tensor, freqs: ArrayLike, time_frames: ArrayLike, title: Optional[str] = None):
    """Display the magnitude of ``S`` in dB as an image with time/frequency axes.

    Args:
        S: tensor to display. Its magnitude is passed through ``db`` before
            plotting. Presumably laid out as (frequency, time) given the axis
            labels below -- TODO confirm against the callers.
        freqs: frequency axis values; only the minimum and maximum are used,
            to set the vertical extent of the image.
        time_frames: time axis values; only the minimum and maximum are used,
            to set the horizontal extent of the image.
        title: optional figure title.
    """
    # Detach and move to the CPU *before* taking the magnitude. Applying
    # np.abs directly to a tensor triggers an implicit NumPy conversion, which
    # raises for tensors that require grad or live on a non-CPU device.
    magnitude = torch.abs(S.detach().cpu())

    plt.figure()
    plt.imshow(db(magnitude).cpu().detach().numpy(), aspect='auto', origin='lower',
               extent=[time_frames.min(), time_frames.max(), freqs.min(), freqs.max()])
    plt.xlabel('Time (s)')
    plt.ylabel('Frequency (Hz)')
    cbar = plt.colorbar()
    cbar.set_label('dB')
    if title is not None:
        plt.title(title)
    plt.show()


def plot_edr(h: torch.Tensor,
             fs: float,
             win_size: int = 2**9,
             hop_size: int = 2**8,
             title: Optional[str] = None) -> torch.Tensor:
    """Compute the EDR of ``h`` from its STFT, plot it, and return it.

    EDR presumably stands for energy decay relief -- TODO confirm against
    ``get_edr_from_stft``.

    Args:
        h: signal to analyse (the notebook passes impulse responses).
        fs: sampling rate, forwarded to ``get_stft_torch``.
        win_size: STFT window size in samples; also used as the FFT size.
        hop_size: STFT hop size in samples.
        title: optional title for the figure.

    Returns:
        The value returned by ``get_edr_from_stft``. Only the EDR is returned;
        the frequency and time axes are used for plotting and then discarded.
    """
    S, freqs, time_frames = get_stft_torch(
        h, fs, win_size=win_size, hop_size=hop_size, nfft=win_size, freq_axis=0)
    edr = get_edr_from_stft(S)
    plot_spectrogram(edr, freqs, time_frames, title)
    return edr

### Read config files and dataset

In [None]:
import logging

# logger used to report a config/data mismatch below
logger = logging.getLogger(__name__)

config_dict = DiffGFDNConfig()
room_data = ThreeRoomDataset(Path(config_dict.room_dataset_path).resolve(), save_irs=True)

# add number of groups to the config dictionary
config_dict = config_dict.copy(update={"num_groups": room_data.num_rooms})

# the dataset is the source of truth for the sample rate: if the config
# disagrees, overwrite the config value with the one read from the data
if config_dict.sample_rate != room_data.sample_rate:
    logger.warning("Config sample rate does not match data, alterning it")
    config_dict.sample_rate = room_data.sample_rate

# get the training config
trainer_config = config_dict.trainer_config

# prepare the training and validation data for DiffGFDN
train_dataset, valid_dataset = load_dataset(
    room_data, trainer_config.device, trainer_config.train_valid_split,
    trainer_config.batch_size)

### Check output data and compare with true IR

In [None]:
# build the DiffGFDN from the dataset properties and the loaded configuration
model = DiffGFDN(
    room_data.sample_rate,
    room_data.num_rooms,
    config_dict.delay_length_samps,
    room_data.absorption_coeffs,
    room_data.room_dims,
    trainer_config.device,
    config_dict.feedback_loop_config,
    config_dict.output_filter_config,
)

print(model)

# locations of the audio files and of the per-epoch model checkpoints
audio_directory = Path("../audio/")
checkpoint_dir = Path("../output/mps/checkpoints")
save_ir_dir = audio_directory / 'mps'

# number of epochs whose checkpoints are inspected
max_epochs = 3

# the IRs are either written to disk or plotted, never both
save_ir = False
plot_ir = not save_ir

In [None]:
for epoch in range(max_epochs):
    # load the trained weights for the particular epoch
    checkpoint = torch.load((checkpoint_dir / f'model_e{epoch}.pt').resolve(), weights_only=True)
    # load the trained model state; the returned value is kept in case it is
    # inspected elsewhere in the notebook
    opt_params = model.load_state_dict(checkpoint)
    # eval() only switches the modules to inference behaviour; it does NOT
    # disable gradient tracking -- that is done with torch.no_grad() below
    model.eval()

    # IRs are only written to disk for the last epoch
    save_this_epoch = save_ir and (epoch == max_epochs - 1)
    if save_this_epoch:
        # make sure the output directory exists before writing into it
        Path(save_ir_dir).mkdir(parents=True, exist_ok=True)

    for data in train_dataset:
        position = data['listener_position']
        # no gradients are needed for inspection; this also guarantees that h
        # can be plotted and saved without an explicit detach()
        with torch.no_grad():
            H, h = get_response(data, model)

        # initialised per batch so that a batch with zero positions cannot
        # leave the flag undefined (or stale from a previous batch)
        outer_loop_break = False

        for num_pos in range(position.shape[0]):
            filename = f'ir_({position[num_pos,0]:.2f}, {position[num_pos, 1]:.2f}, {position[num_pos, 2]:.2f}).wav'

            # find the true IR corresponding to this position
            filepath_true = os.path.join(Path(audio_directory/'true').resolve(), filename)
            # sf.read returns (data, samplerate); only the samples are needed
            h_true = torch.from_numpy(sf.read(filepath_true)[0])

            if plot_ir:
                # plot the EDRs of the true and estimated
                plot_edr(h_true, model.sample_rate, title=f'True RIR EDR, epoch={epoch}')
                plot_edr(h[num_pos, :], model.sample_rate, title=f'Estimated RIR EDR, epoch={epoch}')

                # overlay the two time-domain IRs, truncating the estimate to
                # the length of the true IR; h_true lives on the CPU, so the
                # estimate is moved there too before stacking
                plt.figure()
                plt.plot(torch.stack((h_true, h[num_pos, :len(h_true)].cpu()), dim=-1))
                plt.show()

            if save_this_epoch:
                # the estimated IR is duplicated into two identical channels
                filepath = os.path.join(save_ir_dir, filename)
                torchaudio.save(filepath,
                                torch.stack((h[num_pos, :], h[num_pos, :]),
                                            dim=1).cpu(),
                                int(model.sample_rate),
                                bits_per_sample=32,
                                channels_first=False)
            else:
                # when nothing is saved, one position per epoch is enough:
                # leave both the position loop and the batch loop
                outer_loop_break = True
                break

        if outer_loop_break:
            break

### Observation
- There is no high-frequency content in the estimated signal.