In [None]:
import os
from itertools import islice
from pathlib import Path

import webdataset as wds
from torch.utils.data import DataLoader

In [None]:
# Hugging Face access token for the private dataset repo. Reading it here makes a
# missing variable fail immediately with a KeyError instead of as a curl error later.
hf_token = os.environ['CMERAKI_HF_TOKEN']

# Shard range: webdataset brace-expands {000000..000100} into one spec per shard.
url = "https://huggingface.co/datasets/cmeraki/audiofolder_webdataset/resolve/main/en__gs__{000000..000100}.tar"
# Stream each shard through curl (webdataset runs `pipe:` specs in a shell):
#   -f   make HTTP errors (401/404) fail the read instead of streaming an error page as "tar" data
#   -sS  hide the progress meter but still report errors on stderr
#   -L   follow redirects (Hub `resolve` URLs typically redirect to file storage)
# The token is referenced as $CMERAKI_HF_TOKEN and expanded by that shell, so the secret is
# never part of the spec string itself, which webdataset may echo in errors/tracebacks.
url = f'pipe:curl -sSfL {url} -H "Authorization: Bearer $CMERAKI_HF_TOKEN"'

# Directory where webdataset caches downloaded shards.
cache_dir = Path('~/.cache/wds/tmp/').expanduser()
cache_dir.mkdir(parents=True, exist_ok=True)

### Read the source WebDataset shards

In [None]:
def get_sample(item):
    """Reduce a decoded webdataset sample to a ``(transcript, audio)`` tuple.

    The transcript comes from the ``raw_text`` field of the sample's decoded
    ``json`` entry; the ``wav`` entry is passed through untouched.
    """
    return (item['json']['raw_text'], item['wav'])

In [None]:
# Stream the source shards (downloads cached under cache_dir), decode entries with
# webdataset's default handlers (e.g. `json` -> dict), and keep only (transcript, audio).
dataset = (
    wds.WebDataset(url, shardshuffle=False, cache_dir=cache_dir)
    .decode()
    .map(get_sample)
)

In [None]:
# Peek at the sample at 0-based position `idx`; islice stops pulling from the
# stream as soon as it reaches it.
idx = 10
elem = next(islice(dataset, idx, None), None)
if elem is None:
    # Fail loudly rather than showing some other sample (e.g. the last one).
    raise IndexError(f"dataset has fewer than {idx + 1} samples")

# Show the transcript and the audio payload size instead of dumping the raw bytes.
print(elem[0])
print(f"audio payload: {len(elem[1])} bytes")

In [None]:
# Create DataLoader for batched loading. Batching happens inside the dataset pipeline
# (i.e. inside each worker), so batch_size=None tells the DataLoader not to re-batch.
dataloader = DataLoader(
    dataset.batched(256),
    batch_size=None,
    num_workers=2
)

# Take just the first batch. next() raises if the loader is empty, instead of leaving
# `elem` bound to the single sample from the previous cell and printing its lengths.
elem = next(iter(dataloader))

# Expect two columns (transcripts, audio payloads), each with one entry per sample in the batch.
print(len(elem[0]), len(elem[1]))

### Transform: tokenize the audio with Mimi and write new shards

In [None]:
import os
import io
import torch

from tqdm import tqdm
from pathlib import Path
import webdataset as wds
import torchaudio
from torch.utils.data import DataLoader

from transformers import MimiModel, AutoFeatureExtractor

In [None]:
# Mimi model from the kyutai/mimi checkpoint, used here as an audio tokenizer.
tokenizer = MimiModel.from_pretrained("kyutai/mimi")

# Run on the first GPU when one is available; otherwise leave the model on CPU.
use_gpu = torch.cuda.is_available()
device = 'cuda:0' if use_gpu else 'cpu'
if use_gpu:
    tokenizer.to(device)

tokenizer.eval()

# Number of quantizers requested from the encoder (passed as `num_quantizers`).
n_codebooks = 8

In [None]:
def tokenize(item):
    """Replace a sample's raw WAV bytes with Mimi codebook tokens.

    Decodes ``item['wav']`` with torchaudio, downmixes to mono, resamples to the
    model's configured sampling rate and encodes it with the module-level
    ``tokenizer`` on ``device``, keeping ``n_codebooks`` quantizers.

    Args:
        item: decoded webdataset sample holding raw WAV bytes under ``'wav'``.

    Returns:
        The same dict, mutated in place: ``'wav'`` is removed and ``'mimi.npy'``
        holds the codes as an int16 numpy array (presumably shaped
        (1, n_codebooks, frames) — confirm against MimiModel.encode).
    """
    audio, sr = torchaudio.load(io.BytesIO(item['wav']))
    # torchaudio returns (channels, time); average any multichannel input down to a
    # single channel. For mono input this leaves the samples unchanged.
    audio = audio.mean(dim=0, keepdim=True)
    audio = torchaudio.functional.resample(
        audio, orig_freq=sr, new_freq=tokenizer.config.sampling_rate
    )

    # Inference only: skip autograd bookkeeping for the encoder forward pass.
    with torch.inference_mode():
        codes = tokenizer.encode(
            audio.unsqueeze(0).to(device),  # add batch dim -> (1, 1, time)
            padding_mask=None,
            num_quantizers=n_codebooks,
        ).audio_codes
        # int16 keeps the stored shards compact (assumes codebook size < 2**15).
        item['mimi.npy'] = codes.to(torch.int16).cpu().numpy()

    item.pop('wav')

    return item

In [None]:
# Stream the source shards and tokenize each decoded sample lazily.
# No DataLoader here: `tokenize` runs the model on `device` in this process. When that
# device is CUDA, forked DataLoader workers could not run it (CUDA does not support the
# fork start method), so the writer loop iterates `dataset` directly.
dataset = wds.WebDataset(url, shardshuffle=False, cache_dir=cache_dir).decode().map(tokenize)

In [None]:
# Write tokenized samples to local tar shards, starting a new shard once ~1e9 bytes are written.
# `%06d` is a printf-style pattern filled in by ShardWriter with the shard index, so the
# string must not be an f-string.
with wds.ShardWriter("transform_out__%06d.tar", maxsize=1e9) as sink:
    for sample in tqdm(dataset, desc='Tokenizing audio...'):
        sink.write(sample)

### Read the transformed (tokenized) shards back

In [None]:
# Read back every shard the writer produced. The shard count depends on how much data
# was written (size-based rollover), so glob for the files instead of hard-coding a range.
# Zero-padded indices make the lexicographic sort match the write order.
transform_shards = sorted(
    str(p.absolute()) for p in Path('.').glob('transform_out__*.tar')
)
if not transform_shards:
    raise FileNotFoundError("no transform_out__*.tar shards found in the working directory")

transform_dataset = wds.WebDataset(
    transform_shards,
    shardshuffle=False
).decode()

In [None]:
# Tally samples and code frames across the transformed shards.
total_len = 0  # sum of the last-axis length of every 'mimi.npy' code array
total_fls = 0  # number of samples read

for elem in tqdm(transform_dataset):
    total_fls += 1
    total_len += elem['mimi.npy'].shape[-1]

# Display the result (the original cell computed the totals but never showed them).
{'samples': total_fls, 'frames': total_len}