## Test time-invariance of EnCodec (whole-file vs. fixed-length chunked encoding)

In [1]:
from encodec import EncodecModel
import torchaudio

# Pretrained 24 kHz EnCodec model, configured with the target bandwidth below
# (presumably in kbps -- confirm against the encodec docs).
model = EncodecModel.encodec_model_24khz()
bandwidth = 3.0
model.set_target_bandwidth(bandwidth)



In [2]:
import os
from pathlib import Path

# Source recordings (BirdCLEF 2023 training audio).
in_root = Path("data/raw/birdclef-2023/train_audio")

# Reconstructions go under out_dir: single-pass results in "whole",
# chunk-by-chunk results in "split".
out_dir = Path("data/encodec_invariance")
out_whole_dir = out_dir / "whole"
out_split_dir = out_dir / "split"

for directory in (out_dir, out_whole_dir, out_split_dir):
    directory.mkdir(parents=True, exist_ok=True)

In [40]:
import torch
from encodec.utils import convert_audio

def convert(file):
    """Load ``in_root / file`` and adapt it to the EnCodec model's input format.

    The waveform is passed through ``convert_audio`` with the model's sample
    rate and channel count, and a leading batch axis of size 1 is added.
    """
    waveform, source_rate = torchaudio.load(in_root / file)
    adapted = convert_audio(waveform, source_rate, model.sample_rate, model.channels)
    return adapted.unsqueeze(0)

def encode(wav):
    """Return ``model.encode(wav)`` computed without gradient tracking."""
    with torch.no_grad():
        return model.encode(wav)

def decode(encoded_frames):
    """Decode EnCodec frames without gradients; return the first batch item."""
    with torch.no_grad():
        batch = model.decode(encoded_frames)
    # Drop the batch axis (callers pass a single-item batch built by convert()).
    return batch[0]

def save(file, decoded, parent_dir):
    """Write ``decoded`` as a WAV file under ``parent_dir``.

    The output path mirrors ``file`` below ``parent_dir`` with its extension
    replaced by ``.wav``; missing parent directories are created.

    Args:
        file: Relative path of the recording, e.g. ``"abethr1/XC128013.ogg"``.
        decoded: Waveform tensor handed to ``torchaudio.save``.
        parent_dir: Root directory for the output file.
    """
    # with_suffix works for any extension length (".ogg", ".flac", ...) and for
    # names without an extension, unlike slicing off a fixed 4 characters.
    path = parent_dir / Path(file).with_suffix(".wav")
    path.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(path, decoded, model.sample_rate)

In [43]:
def reconstruct_whole(file):
    """Encode and decode the entire recording in a single pass.

    The reconstruction is saved under ``out_whole_dir`` and returned.
    """
    codes = encode(convert(file))
    result = decode(codes)
    save(file, result, out_whole_dir)
    return result

def reconstruct_split(file, chunk_size=5):
    """Encode and decode the recording independently in fixed-length chunks.

    Each decoded chunk is saved as ``<stem>_<i>.wav`` under ``out_split_dir``.
    The chunks are then concatenated along time, saved as ``<stem>.wav`` and
    returned for comparison with ``reconstruct_whole``.

    Args:
        file: Path of the recording relative to ``in_root``.
        chunk_size: Chunk length in seconds. It should be an integer so that
            ``chunk_size * model.sample_rate`` is an integer sample count
            (``torch.split`` needs an int). The last chunk may be shorter.
    """
    wav = convert(file)
    samples_per_chunk = chunk_size * model.sample_rate
    chunks = torch.split(wav, samples_per_chunk, dim=-1)
    decoded_chunks = [decode(encode(chunk)) for chunk in chunks]

    # Strip the extension via pathlib rather than assuming it is 3 chars long.
    stem = Path(file).with_suffix("").as_posix()
    for i, decoded_chunk in enumerate(decoded_chunks):
        save(f"{stem}_{i}.wav", decoded_chunk, out_split_dir)

    # decode() drops the batch axis, so time is the last dimension.
    reconstructed = torch.cat(decoded_chunks, dim=-1)
    save(file, reconstructed, out_split_dir)
    return reconstructed

In [9]:
# Recording (path relative to in_root) used for the whole-vs-chunked comparison.
file = "abethr1/XC128013.ogg"

In [41]:
# Baseline: encode/decode the full recording in one pass.
recon_whole = reconstruct_whole(file=file)

In [44]:
# Same recording, encoded/decoded independently in 5-second chunks.
recon_split = reconstruct_split(file, chunk_size=5)

In [85]:
# Mean squared difference between the whole-file and chunked reconstructions.
torch.mean((recon_whole - recon_split) ** 2)

tensor(0.0001)

In [50]:
# Element-wise exact-equality mask between the two reconstructions.
torch.eq(recon_whole, recon_split)

tensor([[ True,  True,  True,  ..., False, False, False]])

In [87]:
# Fraction of samples that are bit-identical between the two reconstructions.
# Averaging over every element (instead of dividing by size(1)) stays correct
# if the reconstructions ever have more than one channel.
(recon_whole == recon_split).float().mean()

tensor(0.1087)

In [75]:
# Absolute per-sample differences, largest first.
torch.sort((recon_whole - recon_split).abs(), descending=True)[0]

tensor([[0.4122, 0.4025, 0.3914,  ..., 0.0000, 0.0000, 0.0000]])

In [84]:
raw = convert(file).squeeze(0)  # drop the size-1 batch axis added by convert()
# MSE of the whole-file reconstruction against the input. recon_whole may be
# longer than the input, so it is cut to the input's length first.
torch.mean((recon_whole[:, : raw.shape[1]] - raw) ** 2)

tensor(0.0011)