## 0. Imports

In [9]:
import torch
import matplotlib.pyplot as plt
import numpy as np

from hydra.core.global_hydra import GlobalHydra
from hydra import compose, initialize

import sys

sys.path.append("../")

from di_nn.datasets import DistributedSSLDataset
from di_nn.di_ssl_net import DISSLNET
from di_nn.trainer import DISSLNETLightniningModule

GlobalHydra.instance().clear()
initialize(config_path="../config")
config = compose("config")

MODEL_CHECKPOINT_PATH = "./pretrained_weights.ckpt"
DEMO_DATASET_PATH = "./testing_samples"

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path="../config")


## 1. Load model and dataset

In [20]:
dataset = DistributedSSLDataset(DEMO_DATASET_PATH,
                                metadata_microphone_std_in_m=0.0,
                                metadata_rt60_std_in_ms=1000)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    pin_memory=True,
    drop_last=False,
    num_workers=1
)

dataloader = iter(dataloader)

model = DISSLNETLightniningModule(config)
model.eval()
torch.set_grad_enabled(False)
checkpoint = torch.load(MODEL_CHECKPOINT_PATH, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["state_dict"])


<All keys matched successfully>

## 2. Evaluate dataset samples on model

In [21]:
while True:
    try:
        batch = next(dataloader)
    except StopIteration:
        break
    model_output = model(batch[0])[0].numpy()

    true_coords = batch[1]["source_coordinates"][0].numpy()
    
    print("True vs estimated coordinates:", true_coords, model_output)
    #print("Predicted coordinates:", model_output)
    print("Error (meters):", np.linalg.norm(true_coords - model_output))
    print("\n")

True vs estimated coordinates: [2.7 1.2] [2.7244563 1.2100601]
Error (meters): 0.026444543


True vs estimated coordinates: [4.3 1.2] [4.3423724 1.1823423]
Error (meters): 0.04590427


True vs estimated coordinates: [7.5 1.2] [7.461317  1.1922424]
Error (meters): 0.03945315


True vs estimated coordinates: [7.5 1.2] [6.9783673 1.2218285]
Error (meters): 0.5220892


