# EchoMark: Training and Inference

Thank you for taking the time to review the code. Please follow the instructions below to set up and reproduce the EchoMark system.

---

## Requirements

Please ensure the following Python packages are installed:

- `torch`
- `torchaudio`
- `speechbrain`
- `soundfile`
- `asteroid`
- `scipy`
- `numpy`
- `pyacoustics`
- `tqdm`
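- `logging` (Python standard library; no installation needed)
- `argparse` (Python standard library; no installation needed)
- `icecream`
- `IPython` (only needed for audio playback in the inference notebook)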


## Dataset

EchoMark uses a combination of clean speech and real/simulated RIR datasets for training and evaluation.

### Required Datasets

- **AIR / RVB / RWCP**  
  Download from [OpenSLR SLR28](https://www.openslr.org/28/)

- **BUT ReverbDB**  
  Publicly available from the BUT Speech@FIT group (BUT Speech@FIT Reverb Database)

- **LibriSpeech**  
  Downloaded automatically via `torchaudio` by the data-loading code

> Training without the BUT dataset is possible, but it may slightly reduce generalization performance.

### Data Setup

1. Download all datasets into the `../data/` directory.
2. Ensure that the dataset list files point to the correct local file paths.
3. Modify the `rootpath` variable in `new_dataloader.py` to match your local data path.


## Training
Multi-GPU training (replace `N` with the number of GPUs):

```bash
python -m torch.distributed.launch --nproc_per_node=N train_ddp.py
```
Single-GPU training:
```bash
python train.py
```

## Inference
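The cells below load a trained checkpoint, embed a watermark into a reverberant speech example, and then decode it again.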

In [1]:
import torch
from model import *
from new_dataloader import RIRS_Dataset, add_noise, Noise_Dataset, MyLibriSpeech, get_dataloader

# Prefer the GPU, but fall back to the CPU so the demo still runs on machines
# without CUDA (a hard-coded "cuda" would raise at the first `.to(device)`).
device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the inference-time networks and the spectrogram front end that the
# watermark extractor consumes, all placed on `device`.
separator = ReEcho_Separator().to(device)
generator = ReEcho_Generator().to(device)
watermarker = ReEcho_WM(msg_len=5).to(device)  # messages of length 5
spec_transform = SpectrogramTransform().to(device)

# Restore trained weights; `map_location` remaps tensors that were saved on a
# different device (e.g. a GPU checkpoint loaded on a CPU-only machine).
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
checkpoint = torch.load("save/trained_WM_detector.pth", map_location=device)
separator.load_state_dict(checkpoint['separator_state_dict'])
generator.load_state_dict(checkpoint['generator_state_dict'])
watermarker.load_state_dict(checkpoint['watermarker_state_dict'])

# Evaluation mode: layers such as Dropout (present in the model) behave
# deterministically during inference.
separator.eval()
generator.eval()
watermarker.eval()


  WeightNorm.apply(module, name, dim)


ReEcho_WM(
  (wm_net): WMEmbedderExtractor(
    (msg_codebook): Embedding(32, 128)
    (embedder): Sequential(
      (0): Linear(in_features=128, out_features=128, bias=True)
      (1): Swish(
        (silu): SiLU()
      )
      (2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
    )
    (encoder): ConformerEncoder(
      (input_proj): Sequential(
        (0): Linear(in_features=513, out_features=256, bias=True)
        (1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (2): Swish(
          (silu): SiLU()
        )
      )
      (encoder): ConformerEncoder(
        (layers): ModuleList(
          (0-11): 12 x ConformerEncoderLayer(
            (mha_layer): RelPosMHAXL(
              (dropout_att): Dropout(p=0.1, inplace=False)
              (out_proj): Linear(in_features=256, out_features=256, bias=True)
              (linear_pos): Linear(in_features=256, out_features=256, bias=False)
            )
            (convolution_module): ConvolutionModule(
        

In [None]:
import torchaudio
import numpy as np
from IPython.display import Audio, display
from icecream import ic
ic.disable()  # turn off icecream `ic()` debug printing

# Prefer the GPU, but fall back to the CPU when CUDA is unavailable.
device = "cuda" if torch.cuda.is_available() else "cpu"

### load a reverberant speech (defines the target acoustic environment)
rs, sr = torchaudio.load("wav_example/target_reverb/2.wav")

### load the clean speech that the estimated RIR will be applied to
clean, clean_sr = torchaudio.load("wav_example/clean/2.wav")
# The RIR is estimated from `rs`, convolved with `clean`, and the result is
# played back at `sr`, so both recordings must share one sample rate.
# (Previously the second load silently overwrote `sr`.)
if clean_sr != sr:
    raise ValueError(
        f"sample-rate mismatch: reverberant speech is {sr} Hz, "
        f"clean speech is {clean_sr} Hz"
    )

with torch.no_grad():
    rs, clean = rs.to(device), clean.to(device)

    ### the message to embed: one message of five values (batch size 1)
    msg = torch.tensor([[1, 1, 1, 1, 1]], device=device)

    ### estimate a watermarked RIR from the reverberant speech
    _, rir_emb = separator(rs)                        # second output: RIR embedding
    rir_emb_wm = watermarker.embedding(msg, rir_emb)  # embed `msg` into it
    rir_est = generator(rir_emb_wm).squeeze(0)        # time-domain RIR used below

    ### apply the watermarked RIR to the clean speech
    rs_aem = torchaudio.functional.fftconvolve(clean, rir_est, mode='full')
    # drop the convolution tail so the output is no longer than `rs`
    rs_aem = rs_aem[..., :rs.shape[-1]]

rs_cpu = rs.cpu().numpy().squeeze(0)
rs_aem_cpu = rs_aem.cpu().numpy().squeeze(0)

### compare the original and the resynthesized speech
print("Target Environment")
display(Audio(rs_cpu, rate=sr))
print("Transferred Environment with watermark")
display(Audio(rs_aem_cpu, rate=sr))

Target Environment


Transferred Environment with watermark


In [3]:
### Decode the watermarked speech
# Run the extractor on the watermarked speech and compare the decoded message
# element-wise with the embedded `msg`. Judging by the recorded output, in
# mode='test' both returned values are hard 0/1 decisions rather than raw
# logits -- confirm against ReEcho_WM.extraction.
with torch.no_grad():
    rs_aem_spec = spec_transform(rs_aem)
    wm_logit, msg_logit = watermarker.extraction(
        rs_aem_spec,
        mode='test',
    )
    print(wm_logit, msg_logit)
    print(msg.eq(msg_logit))

tensor([1.], device='cuda:0') tensor([[1., 1., 1., 1., 1.]], device='cuda:0')
tensor([[True, True, True, True, True]], device='cuda:0')


In [4]:
### Decode non-watermarked speech
# Control case: feed the original reverberant speech `rs`, which never passed
# through the watermarker, to the same extractor and show only the detection
# output.
with torch.no_grad():
    rs_spec = spec_transform(rs)
    wm_logit, msg_logit = watermarker.extraction(
        rs_spec,
        mode='test',
    )
print(wm_logit)

tensor([0.], device='cuda:0')
