# Working with the CHiME-9 ECHI Data and Repository

This notebook shows you how to interact with the data and tools provided as part of the CHiME-9 ECHI Challenge. It can be used in either of two ways:
* As part of the CHiME-9 ECHI repo (i.e. from `CHiME9-ECHI/Quickstart.ipynb`)
* As a standalone notebook, which will clone the CHiME-9 ECHI repo into the current working directory (better for use with Google Colab)

## Setting up your environment

This will install a subset of the dependencies required for the full CHiME-9 ECHI package, and add the required modules to your path.

If this notebook is being used on Google Colab, or has been downloaded on its own, the `CHiME9-ECHI` repository will be cloned into the current working directory.

In [None]:
# Install the packages

import importlib.util

packages = {
    "pystoi": "pystoi",
    "soundfile": "soundfile",
    "soxr": "soxr",
    "torch": "torch",
    "ipywidgets": "ipywidgets",
    "pysepm": "https://github.com/schmiph2/pysepm/archive/master.zip",
    "gdown": "gdown",
}
failures = []

for imp_name, pkg in packages.items():
    %pip install {pkg}

    if importlib.util.find_spec(imp_name) is None:
        failures.append(imp_name)

if failures:
    print(f"❌ Failed to import: {', '.join(failures)}")

else:
    from IPython.display import clear_output

    clear_output()
    print("✅ All packages installed")

In [None]:
# Clone the public repo and add to path

import os
import sys
from pathlib import Path

cwd = Path.cwd()
if cwd.stem == "CHiME9-ECHI":
    # Using this script as part of the main repository, locally on Jupyter
    sys.path.append("src")
    repo_root = Path(".")
else:
    # Using as a standalone script
    if "CHiME9-ECHI" not in os.listdir(cwd):
        !git clone https://github.com/CHiME9-ECHI/CHiME9-ECHI.git
    # Check for the same path that gets appended, so re-running the cell is a no-op
    src_path = str(cwd / "CHiME9-ECHI" / "src")
    if src_path not in sys.path:
        print("Adding CHiME9-ECHI/src to sys.path")
        sys.path.append(src_path)
    repo_root = cwd / Path("CHiME9-ECHI")

## Downloading the demo data

A few part-sessions from the development set have been saved into Google Drive
for download. The `tar.gz` file to download is 1.5GB, and when unpacked the data
uses 2.7GB of disk space.

In [None]:
# Download the demo data

cwd = Path.cwd()
data_root = repo_root / "data"
chime9echi_root = data_root / "chime9_echi.demo"
targz_file = data_root / "chime9_echi.demo.tar.gz"


if chime9echi_root.exists():
    print("Data already downloaded!")
else:
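    if not targz_file.exists():
        print("No tar.gz found. Downloading...")
        # Make sure the target folder exists before downloading into it
        data_root.mkdir(parents=True, exist_ok=True)
        !gdown --fuzzy "https://drive.google.com/file/d/1nDCoLr4NA-CAeHEPsylerQkUeJf_hnCW/view?usp=sharing" -O {targz_file}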
    print("Unzipping demo data...")
    !tar -xvzf {targz_file} -C {data_root}

    from IPython.display import clear_output

    clear_output()
    print("✅ Data downloaded and unzipped!")

In [None]:
noisy_ftemplate = str(chime9echi_root / "{device}/dev/{session}.{device}.wav")
ref_ftemplate = str(chime9echi_root / "ref/dev/{session}.{device}.{pid}.wav")
rainbow_ftemplate = str(chime9echi_root / "participant/dev/{pid}.wav")
segments_ftemplate = str(
    chime9echi_root / "metadata/ref/dev/{session}.{device}.{pid}.csv"
)

These variables correspond to:
* `noisy_ftemplate` is the path to the noisy audio, requiring the `session` and `device` to be specified
* `ref_ftemplate` points to the reference conversational speech, requiring `session`, `device` and `pid`
* `rainbow_ftemplate` refers to the clean speech samples for each participant, requiring `pid`
* `segments_ftemplate` gives the path to the CSV file containing the speech segments for each participant, requiring `session`, `device` and `pid`
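
As a minimal sketch of how these templates are filled in, the snippet below rebuilds two of them under a hypothetical root folder (`demo_root` stands in for `chime9echi_root`) and formats them with `str.format`. The `session`, `device` and `pid` values are taken from the example discussed later in this notebook. No files are read, so the paths are only illustrative.

```python
from pathlib import Path

# Hypothetical root folder; in the notebook this is `chime9echi_root`.
demo_root = Path("data/chime9_echi.demo")

noisy_ftemplate = str(demo_root / "{device}/dev/{session}.{device}.wav")
segments_ftemplate = str(
    demo_root / "metadata/ref/dev/{session}.{device}.{pid}.csv"
)

# Each placeholder is filled by keyword, so a template can reuse a name
# (here `device` appears twice in the same path).
noisy_path = Path(noisy_ftemplate.format(session="dev_10", device="aria"))
segments_path = Path(
    segments_ftemplate.format(session="dev_10", device="aria", pid="P180")
)

print(noisy_path.as_posix())
print(segments_path.as_posix())
```

Omitting a required keyword raises a `KeyError`, which makes a missing `pid` or `device` easy to spot.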

## Using the data

Now that we have downloaded some data, let's load some of it in and have a look. First, we need to define some helper functions.

In [None]:
import soundfile as sf
import soxr
from IPython.display import Audio
from ipywidgets import widgets


def load_session_audio(session, device, target, segment):
    """Load the noisy and reference audio for a session"""
    audio = {}

    noisy_fpath = noisy_ftemplate.format(device=device, session=session)

    model_fs = 16000

    noisy_audio, fs = sf.read(noisy_fpath)

    clip_start = int(segment["start"] * fs / model_fs)
    clip_end = int(segment["end"] * fs / model_fs)

    noisy_audio = noisy_audio[clip_start:clip_end]
    noisy_audio = soxr.resample(noisy_audio, fs, model_fs)

    audio["noisy"] = noisy_audio

    ref_file = ref_ftemplate.format(session=session, device=device, pid=target)
    ref_audio, _ = sf.read(ref_file)
    audio["ref"] = ref_audio[segment["start"] : segment["end"]]

    rainbow_file = rainbow_ftemplate.format(pid=target)
    rainbow_audio, fs = sf.read(rainbow_file)
    rainbow_audio = soxr.resample(rainbow_audio, fs, model_fs)
    audio["rainbow"] = rainbow_audio

    return audio


def labeled_audio(audio_dict, rate=16000):
    """Return a VBox of rows, each pairing a label with an audio player."""

    players = []
    for label, data in audio_dict.items():
        label_widget = widgets.Label(value=label, layout=widgets.Layout(width="100px"))
        audio_widget = Audio(data.T, rate=rate)._repr_html_()
        audio_widget = widgets.HTML(value=audio_widget)
        players.append(widgets.HBox([label_widget, audio_widget]))
    return widgets.VBox(players)

Now let's load in the metadata file and see what information we get.

In [None]:
# Load in the session information
import csv
import json

with open(chime9echi_root / "metadata/sessions.dev.csv", "r") as file:
    sessions = list(csv.DictReader(file))

sessions = {s["session"]: s for s in sessions}
print(json.dumps(sessions, indent=4))

This gives us information about the sessions that have been downloaded, including:
* Which participants were present in the session
* Which device was in which position. For example, in `dev_10`, the Aria glasses were worn in `pos3` by `p182`, and the hearing aids were worn in `pos4` by `P180`.

In [None]:
session = "dev_10"
session_info = sessions[session]
device = "aria"
all_targets = [
    session_info[f"pos{i}"]
    for i in range(1, 5)
    if str(i) != session_info[f"{device}_pos"]
]
target = all_targets[0]

As well as this session metadata, we also have reference segments, indicating when each participant is speaking, stored in CSV files. These files are separated by session, device and participant.

**NOTE** The reference segments differ slightly between the Aria glasses and the hearing aids. The two devices are always worn by different people in a session, so the propagation delay of each person's speech is slightly different at each device.
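
To give a feel for the size of this effect, here is a rough sketch that converts a difference in talker-to-device distance into a sample offset at 16 kHz. The distances are hypothetical and the speed of sound is the usual approximation for air at room temperature; none of these numbers come from the dataset.

```python
SPEED_OF_SOUND_M_S = 343.0  # approximate value for air at around 20 degrees C
SAMPLE_RATE_HZ = 16_000


def delay_in_samples(distance_m, fs=SAMPLE_RATE_HZ):
    """Propagation delay, in samples, for sound travelling `distance_m` metres."""
    return distance_m / SPEED_OF_SOUND_M_S * fs


# Hypothetical distances from one talker to the two device wearers.
to_aria_m = 0.8
to_hearing_aids_m = 1.4

offset = delay_in_samples(to_hearing_aids_m) - delay_in_samples(to_aria_m)
print(f"Offset between devices: {offset:.1f} samples")
```

Under these assumptions the offset is a few tens of samples, which is why the segment boundaries are stored per device rather than shared.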

In [None]:
segments_file = segments_ftemplate.format(session=session, device=device, pid=target)
with open(segments_file, "r") as file:
    segments = [
        {a: int(b) for a, b in seg.items()}
        for seg in csv.DictReader(file, fieldnames=["index", "start", "end"])
    ]

print(json.dumps(segments[:3], indent=4))

The timestamps for these segments are the start and end samples at a sampling frequency of 16 kHz. So the first segment starts at sample 407679, which is 407679/16000 = 25.48 s.
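
The conversion can be sketched as below. The helper names are hypothetical, and the 48 kHz rate in the last line is only an assumed example of a file stored at a different rate; `rescale_index` mirrors the `clip_start`/`clip_end` arithmetic in `load_session_audio`.

```python
SEGMENT_FS = 16_000  # rate in which the segment indices are expressed


def samples_to_seconds(sample, fs=SEGMENT_FS):
    """Convert a sample index at rate `fs` to a time in seconds."""
    return sample / fs


def rescale_index(sample, target_fs, source_fs=SEGMENT_FS):
    """Map a sample index at `source_fs` to the equivalent index at `target_fs`."""
    return int(sample * target_fs / source_fs)


start = 407679
print(f"{samples_to_seconds(start):.2f} s")  # 25.48 s
print(rescale_index(start, target_fs=48_000))  # 48 kHz is an assumed rate
```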

In [None]:
segment = segments[30]

For all systems, the only inputs that can be provided to the model are:
* The noisy audio from either the Aria glasses **OR** the hearing aids (never both at once)
* The clean speech sample of the target's voice

A reference signal containing the clean, conversational speech is also provided for use as a training target.

In [None]:
audio = load_session_audio(session, device, target, segment)

labeled_audio(audio)

## Enhancing the speech

Now we've loaded the audio, we can use the baseline system to extract the target
speaker's speech from the noisy audio.

The config file and checkpoints for the baseline system are stored in
`CHiME9-ECHI/checkpoints`. The config can be loaded using `omegaconf`, and the
baseline model can be loaded from the checkpoint using the `get_model` function
provided in the repo.

We also load in an STFT wrapper, based on the requirements of the baseline model.

In [None]:
import torch
from omegaconf import OmegaConf

from shared.core_utils import get_model
from shared.signal_utils import STFTWrapper

cfg = OmegaConf.load(repo_root / f"checkpoints/{device}_config.yaml")

stft = STFTWrapper(**cfg.input.stft)
model = get_model(cfg, repo_root / f"checkpoints/{device}_baseline.pt")  # type: ignore
model.eval()

noisy_audio = torch.from_numpy(audio["noisy"].T)
rainbow_audio = torch.from_numpy(audio["rainbow"]).unsqueeze(0)

The baseline system takes three inputs:
* The STFT of the noisy audio, in shape `[batch, channels, time, freqs]`
* The STFT of the rainbow audio, in shape `[batch, time, freqs]`
* The lengths of the rainbow audio, in shape `[batch]`
    * When the batch size is greater than one, the rainbow passages may be
    zero-padded to match the lengths
    * Providing the rainbow lengths means only the speech is processed, and not
    the zero-padding
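
The idea behind the lengths input can be illustrated with a toy, pure-Python sketch: passages of different lengths are zero-padded to a common length, and the original lengths are kept so that the padding can be ignored later. The `pad_batch` helper and the values are hypothetical; in practice this would be done on tensors, and how the baseline model consumes the lengths internally is not shown here.

```python
def pad_batch(passages, pad_value=0.0):
    """Zero-pad variable-length sequences to a common length.

    Returns the padded batch and the original length of each sequence.
    """
    lengths = [len(p) for p in passages]
    max_len = max(lengths)
    padded = [list(p) + [pad_value] * (max_len - len(p)) for p in passages]
    return padded, lengths


# Two hypothetical passages of different lengths (toy values).
batch, lengths = pad_batch([[0.1, 0.2, 0.3], [0.4, 0.5]])

# With the lengths available, only the real samples are kept.
valid = [row[:n] for row, n in zip(batch, lengths)]

print(batch)
print(lengths)
```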

In [None]:
noisy_stft = stft(noisy_audio).unsqueeze(0).to(torch.float)
rainbow_stft = stft(rainbow_audio).to(torch.float)
rainbow_len = torch.tensor([rainbow_stft.shape[2]]).unsqueeze(0)

with torch.no_grad():
    output_stft = model(noisy_stft, rainbow_stft, rainbow_len).squeeze(1)

In [None]:
output_audio = stft.inverse(output_stft).squeeze(0, 1).detach().numpy()

audio = {"noisy": audio["noisy"], "output": output_audio, "ref": audio["ref"]}
labeled_audio(audio)

## Evaluating the enhanced audio

The `CHiME9-ECHI` repository relies on the
[WavLab Versa toolkit](https://github.com/wavlab-speech/versa) to compute the
scores from a wide variety of metrics. For simplicity, in this notebook we will
use a handful of these metrics directly:
* STOI: Short-Time Objective Intelligibility which assesses how easy a signal is
to understand (from [pystoi](https://github.com/mpariente/pystoi))
* FWSegSNR: Frequency-Weighted Segmental SNR, which scores the SNR of the signal
with a weighting applied to the frequencies to more accurately reflect human
hearing (from [pysepm](https://github.com/schmiph2/pysepm))
* The Composite metrics, designed to evaluate the quality of the signal
(Csig), intrusiveness of the background (Cbak), and the overall quality (Covl),
(also from [pysepm](https://github.com/schmiph2/pysepm))

Note that these metrics are only to be used as a guide, and the final evaluation
of systems will be down to listening tests.

In [None]:
from pysepm.qualityMeasures import composite, fwSNRseg
from pystoi import stoi

if device == "aria":
    noisy_mono = noisy_audio[3, :]  # The fourth channel from the Aria glasses
else:
    # Sum of the left-front and right-front channels of the hearing aids
    noisy_mono = noisy_audio[:2, :].sum(dim=0)

noisy_mono = noisy_mono.detach().cpu().numpy()
ref_audio = audio["ref"]

scores = {}
scores["STOI"] = [
    stoi(ref_audio, noisy_mono, 16000),
    stoi(ref_audio, output_audio, 16000),
]
scores["FWSegSNR"] = [
    fwSNRseg(ref_audio, noisy_mono, 16000),
    fwSNRseg(ref_audio, output_audio, 16000),
]
csig_noisy, cbak_noisy, covl_noisy = composite(ref_audio, noisy_mono, 16000)
csig_output, cbak_output, covl_output = composite(ref_audio, output_audio, 16000)

scores["Csig"] = [csig_noisy, csig_output]
scores["Cbak"] = [cbak_noisy, cbak_output]
scores["Covl"] = [covl_noisy, covl_output]

print("{:<10}{:<10}{:<10}".format("Metric", "Noisy", "Baseline"))

for metric, pair in scores.items():
    print("{:<10}{:<10.2f}{:<10.2f}".format(metric, *pair))