In [1]:
%load_ext autoreload
%autoreload 2
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"
import time
import gc
from itertools import chain
from tqdm import tqdm
import IPython.display as ipd
from functools import partial
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
# import torch.multiprocessing as mp

# mp.set_start_method("spawn", force=True)
from torch.nn import functional as F
from torch.utils.tensorboard import SummaryWriter
import torchaudio
import torchmetrics
import webdataset as wds
from braceexpand import braceexpand
from sklearn.metrics import classification_report
from transformers import WhisperFeatureExtractor

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def release_gpu_memory():
    """Reclaim GPU memory held by objects that are no longer referenced.

    The garbage collector runs first so that unreachable tensors (for example,
    ones trapped in reference cycles) are actually freed. PyTorch is then asked
    to release its now-unused cached CUDA blocks.
    """
    for reclaim in (gc.collect, torch.cuda.empty_cache):
        reclaim()

In [3]:
# train_urls = list(braceexpand("./data/IEMOCAP_audio/data/Session{2..5}.tar"))
# dev_urls = list(braceexpand("./data/IEMOCAP_audio/data/Session1.tar"))

In [4]:
# train_urls = list(braceexpand("./data/RAVDESS_audio/data/Actor_{01..19}.tar"))
# dev_urls = list(braceexpand("./data/RAVDESS_audio/data/Actor_{20..24}.tar"))

In [5]:
# MSP-Podcast webdataset shards for the train / development / test1 splits.
train_urls = [*braceexpand("./data/MSP_PODCAST/data/train/train_{01..21}-of-21.tar")]
dev_urls = [
    *braceexpand("./data/MSP_PODCAST/data/development/development_{01..05}-of-05.tar")
]
test_urls = [*braceexpand("./data/MSP_PODCAST/data/test1/test1_{01..08}-of-08.tar")]

In [6]:
from data import (
    apply_augmentation,
    decode,
    crop,
    collate_and_featurize,
    MSP_PODCAST_EMOTIONS,
    IEMOCAP_EMOTION_TO_IX,
    IEMOCAP_EMOTIONS,
    RAVDESS_EMOTIONS,
    RAVDESS_EMOTION_TO_IX,
    Cache,
    FilterLabels,
    MSP_PODCAST_EMOTION_TO_IX,
    SampledShards,
)

In [7]:
# Show the full MSP-Podcast label vocabulary, in the order defined by data.MSP_PODCAST_EMOTIONS.
MSP_PODCAST_EMOTIONS

['neutral',
 'happy',
 'sad',
 'angry',
 'fear',
 'disgust',
 'surprise',
 'contempt',
 'other',
 'no_agreement']

In [8]:
# from collections import Counter
# counter = Counter()
# for batch in tqdm(dl):
#     counter.update(MSP_PODCAST_EMOTIONS[x] for x in batch["emotion_ix"])
# for i, (label, count) in enumerate(counter.most_common()):
#     print(f"{i+1}) {label:<15}\t{count}\t({count / counter.total():.1%})"

In [9]:
def display_datum(datum=None, batch=None, batch_ix=None, sample_rate=None):
    """Print a sample's metadata, then plot and play every waveform it carries.

    Pass either a single ``datum`` dict, or a collated ``batch`` dict together with
    ``batch_ix``. In the second case each batch value is indexed with ``batch_ix``.
    If both ``datum`` and ``batch`` are given, ``datum`` wins.

    Every key starting with "wav" is treated as a waveform. It is truncated to
    its "num_samples" entry (or "num_samples_<suffix>" for "wav_<suffix>") when
    that entry is present. It is then drawn as waveform + spectrogram and
    rendered as an inline audio player.

    Args:
        datum: A single sample dict; must contain "key" and "url".
        batch: A collated batch dict, used only when ``datum`` is None.
        batch_ix: Index of the element of ``batch`` to display.
        sample_rate: Audio sample rate in Hz. Defaults to the notebook-level
            ``sample_rate`` global. That global is looked up only when a
            waveform is actually rendered, so this function can be defined
            before the config cell runs.
    """
    if datum is None:
        assert batch is not None and batch_ix is not None
        datum = {key: val[batch_ix] for key, val in batch.items()}

    print(f'Key: {datum["key"]}\tURL: {datum["url"]}')
    # Groups of optional metadata; each group is printed on one line if any key is present.
    maybe_print_keys = [
        ["speaker", "gender"],
        ["transcript"],
        ["valence", "activation", "arousal", "domination"],
        ["emotion_ix", "emotion"],
    ]
    for key_list in maybe_print_keys:
        print_string = "".join(
            f"{key}: {datum[key]}\t" for key in key_list if key in datum
        )
        if print_string:
            print(print_string)

    wavs = []
    for wav_key in filter(lambda x: x.startswith("wav"), datum.keys()):
        # "wav" pairs with "num_samples"; "wav_<suffix>" pairs with "num_samples_<suffix>".
        num_samples_key = (
            "num_samples"
            if wav_key == "wav"
            else "num_samples_" + wav_key.removeprefix("wav_")
        )
        # Drop any padding beyond the true length when the length is recorded.
        num_samples = datum.get(num_samples_key, datum[wav_key].shape[0])
        wavs.append((wav_key, datum[wav_key][:num_samples]))

    if wavs and sample_rate is None:
        sample_rate = globals()["sample_rate"]

    for wav_key, wav in wavs:
        fig, axes = plt.subplots(nrows=2, figsize=(12, 4))
        axes[0].plot(wav)
        axes[0].set_title(wav_key)
        axes[1].specgram(wav, Fs=sample_rate)
        # Lay out after drawing so titles/ticks are accounted for.
        fig.tight_layout()
        ipd.display(ipd.Audio(wav, rate=sample_rate))
        ipd.display(fig)
        # Close this specific figure (not just the "current" one) to avoid leaking figures.
        plt.close(fig)

In [10]:
from audiomentations import Normalize, Trim, PolarityInversion
from data import PitchShift, apply_augmentation, Compose

In [11]:
sample_rate = 16_000
# Crop length in seconds. It must be an int because it is passed to crop() and
# to WhisperFeatureExtractor(chunk_length=...). It must also satisfy the
# check below: 3000 * max_duration / 60 (i.e. 1500 * max_duration / 30,
# presumably the number of Whisper encoder positions for a max_duration-second
# crop, given 3000 mel frames per 30 s) must be an integer.
# Exact integer arithmetic is used instead of float `% 1` to avoid rounding noise.
max_duration = 10
assert isinstance(max_duration, int) and max_duration > 0
assert (3000 * max_duration) % 60 == 0
batch_size = 64
dataloader_workers = 20
# Model class order: class index i <-> emotions[i].
emotions = ["neutral", "happy", "sad", "angry", "fear", "surprise", "disgust"]
ix_to_emotion = dict(enumerate(emotions))
# Per-dataset label names for the same classes, listed in the same order as `emotions`.
# NOTE(review): `emotions` puts "surprise" before "disgust", while MSP_PODCAST_EMOTIONS
# lists "disgust" before "surprise". If FilterLabels/decode emit MSP_PODCAST_EMOTION_TO_IX
# indices (rather than positions in msp_labels_wanted), ix_to_emotion and the
# classification-report target names would swap these two classes. Confirm against
# data.FilterLabels.
msp_labels_wanted = ["neutral", "happy", "sad", "angry", "fear", "surprise", "disgust"]
ravdess_labels_wanted = [
    "neutral",
    "happy",
    "sad",
    "angry",
    "fearful",
    "surprised",
    "disgust",
]
iemocap_labels_wanted = [
    "neutral",
    "happy",
    "sad",
    "angry",
    "fearful",
    "surprised",
    "disgust",
]
# Every wanted label must exist in its dataset, and each list must cover all model classes.
assert set(msp_labels_wanted).issubset(MSP_PODCAST_EMOTIONS)
assert len(emotions) == len(msp_labels_wanted)
assert set(ravdess_labels_wanted).issubset(RAVDESS_EMOTIONS)
assert len(emotions) == len(ravdess_labels_wanted)
assert set(iemocap_labels_wanted).issubset(IEMOCAP_EMOTIONS)
assert len(emotions) == len(iemocap_labels_wanted)

In [12]:
# duration = 0
# for batch in tqdm(train_dl):
#     pass
#     duration += batch["num_samples"].sum() / sample_rate

In [13]:
# Training pipeline.
# - Shard sampling: nshards is the largest multiple of the worker count that does
#   not exceed the number of shards, or all shards when there are fewer shards
#   than workers. Presumably this makes split_by_worker give every worker the
#   same number of shards.
# - FilterLabels runs before shuffle/decode, so (presumably) unwanted classes
#   are dropped before any audio is decoded.
# - Audio is normalized and trimmed, then randomly cropped to max_duration seconds.
train_ds = wds.DataPipeline(
    SampledShards(
        urls=train_urls,
        nshards=min(
            len(train_urls),
            dataloader_workers * max(1, len(train_urls) // dataloader_workers),
        ),
    ),
    wds.split_by_worker,
    wds.tarfile_to_samples(),
    FilterLabels(
        labels_wanted=msp_labels_wanted,
        labels_to_ix=MSP_PODCAST_EMOTION_TO_IX,
        label_key="emotion",
    ),
    wds.shuffle(64 * batch_size),
    wds.map(decode(dataset="MSP_PODCAST", ix_to_label=ix_to_emotion, decode_json=False)),
    wds.map(
        apply_augmentation(
            augmentation=Compose([Normalize(p=1.0), Trim(p=1.0)]),
            out_key="wav",
            sample_rate=sample_rate,
        )
    ),
    wds.map(crop(crop_duration=max_duration, random=True, keys=["wav"])),
    # Disabled GPU pitch-shift augmentation:
    # wds.map(
    #     apply_augmentation(
    #         augmentation=Compose([PitchShift(sample_rate=sample_rate, p=1.0, device='cuda:1')]),
    #         out_key="wav",
    #         p=1.,
    #     )
    # ),
)
# Batch inside the pipeline (so batching runs in the loader workers) and featurize
# each batch with the Whisper feature extractor. partial=False drops the trailing
# short batch, so every training batch has exactly batch_size items.
train_ds_batched = train_ds.compose(
    wds.batched(
        batchsize=batch_size,
        collation_fn=collate_and_featurize(
            list_keys=["key", "url", "emotion", "transcript"],
            tensor_keys=["emotion_ix"],
            wav_keys=["wav"],
            feature_keys=["wav"],
            feature_extractor=WhisperFeatureExtractor,
            feature_kwargs={"chunk_length": max_duration},
        ),
        partial=False,
    )
)
# The pipeline already yields batches, so the loader itself must not batch (batch_size=None).
train_dl = wds.WebLoader(
    train_ds_batched,
    batch_size=None,
    shuffle=False,
    num_workers=dataloader_workers,
    persistent_workers=True,
    pin_memory=True,
)

In [14]:
from utils import JupyterProfiler

In [15]:
# with JupyterProfiler():
#     for i, batch in enumerate(train_ds_batched):
#         if i == 1:
#             break

In [16]:
# Smoke test: drain one full epoch of the training loader (no model involved) to
# gauge loader throughput and surface audio-decoding errors before training.
for batch in tqdm(train_dl):
    continue

292it [00:13, 21.40it/s]Exception ignored from cffi callback <function SoundFile._init_virtual_io.<locals>.vio_tell at 0x7fce6cc2de40>Exception ignored from cffi callback <function SoundFile._init_virtual_io.<locals>.vio_tell at 0x7fce6cc31800>Exception ignored from cffi callback <function SoundFile._init_virtual_io.<locals>.vio_tell at 0x7fce6cc44400>:
:
:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/kd2939/dev/miniconda3/envs/emo/lib/python3.11/site-packages/soundfile.py", line 1264, in vio_tell
  File "/home/kd2939/dev/miniconda3/envs/emo/lib/python3.11/site-packages/soundfile.py", line 1264, in vio_tell
Exception ignored from cffi callback <function SoundFile._init_virtual_io.<locals>.vio_tell at 0x7fce6cde7060>  File "/home/kd2939/dev/miniconda3/envs/emo/lib/python3.11/site-packages/soundfile.py", line 1264, in vio_tell
        :
    @_ffi.callback("sf_vio_tell")@_ffi.callback("sf_vio_tell")@_ffi.callback("s

In [27]:
# Development pipeline. Unlike training, validation must be reproducible and must
# cover the whole dev set:
#   * crop(random=False): use a fixed crop offset instead of a random one, so dev
#     metrics do not change from run to run (presumably this keeps the start of
#     the clip; confirm in data.crop);
#   * batched(partial=True): keep each worker's trailing short batch instead of
#     silently dropping up to batch_size - 1 dev samples per worker.
# NOTE(review): training applies Trim(p=1.0) after Normalize, but dev does not.
#   Confirm this train/dev preprocessing mismatch is intentional.
# Cache(shuffle=False) sits after batching; presumably it stores the featurized
# batches so later validation passes skip decoding.
dev_pipeline = [
    wds.SimpleShardList(dev_urls),
    wds.split_by_worker,
    wds.tarfile_to_samples(),
    FilterLabels(
        labels_wanted=msp_labels_wanted,
        labels_to_ix=MSP_PODCAST_EMOTION_TO_IX,
        label_key="emotion",
    ),
    wds.map(
        decode(dataset="MSP_PODCAST", ix_to_label=ix_to_emotion, decode_json=False)
    ),
    wds.map(
        apply_augmentation(
            augmentation=Compose([Normalize(p=1.0)]),
            out_key="wav",
            sample_rate=sample_rate,
        )
    ),
    wds.map(crop(crop_duration=max_duration, random=False, keys=["wav"])),
    wds.batched(
        batchsize=batch_size,
        collation_fn=collate_and_featurize(
            list_keys=["key", "url", "emotion"],
            tensor_keys=["emotion_ix"],
            wav_keys=["wav"],
            feature_keys=["wav"],
            feature_extractor=WhisperFeatureExtractor,
            feature_kwargs={"chunk_length": max_duration},
        ),
        partial=True,
    ),
    Cache(shuffle=False),
]
# The pipeline yields ready-made batches, so the loader does not batch (batch_size=None).
# NOTE(review): split_by_worker distributes shards across workers; with only
# len(dev_urls) shards, any workers beyond that count presumably receive nothing.
dev_dl = wds.WebLoader(
    wds.DataPipeline(*dev_pipeline),
    batch_size=None,
    num_workers=8,
    prefetch_factor=4,
    shuffle=False,
    pin_memory=True,
    persistent_workers=True,
)

In [None]:
# Warm-up: pull one dev batch so the persistent dev workers start before training.
_ = next(iter(dev_dl))

In [28]:
# training_data_duration = 0.
# for i, batch in enumerate(tqdm(train_dl)):
#     training_data_duration += batch['num_samples'].sum() / sample_rate
# num_train_batches = i + 1
# train_dl.length = num_train_batches
# for i, batch in enumerate(tqdm(dev_dl)):
#     pass
# num_dev_batches = i + 1
# dev_dl.length = num_dev_batches

In [29]:
# print(f'Total training data duration: {training_data_duration/3600:.2f} hrs')

In [30]:
from model import Classifier

In [31]:
# Emotion classifier built on the Whisper small.en checkpoint; one output per entry of
# `emotions`. Other checkpoints tried: "openai/whisper-tiny.en", "openai/whisper-medium.en".
model = Classifier(
    "openai/whisper-small.en",
    max_duration=max_duration,
    projection_dim=256,
    num_classes=len(emotions),
    freeze_conv_pos=True,
).cuda()
cross_entropy_criterion = nn.CrossEntropyLoss().cuda()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-4)
# Multiply the learning rate by 0.97 on every scheduler step (train() steps it once per epoch).
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.97)
# Loss scaling is not required for bfloat16 autocast (the dtype train() uses);
# the scaler is kept so the training loop also works with float16.
grad_scaler = torch.cuda.amp.GradScaler()

In [32]:
def train(
    model: nn.Module,
    dataloader: wds.WebLoader,
    criterion,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler.LRScheduler,
    grad_scaler: torch.cuda.amp.GradScaler,
    forward_pass_iters: int,
    grad_accum_iters: int,
    grad_norm: float,
    num_epochs: int,
    print_every: int,
):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader`` under bf16 autocast.

    Args:
        model: Module called as ``model(feats, attention_mask)``; moved to CUDA
            and put in train mode.
        dataloader: Yields batch dicts with "feats", "attention_mask" and
            "emotion_ix". It is assumed to expose a mutable ``length`` attribute,
            where -1 means unknown. After the first non-empty epoch it is set to
            the observed number of batches, so later progress bars show a total.
        criterion: Loss, called as ``criterion(logits, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        lr_scheduler: Stepped once per epoch; may be None.
        grad_scaler: AMP gradient scaler used for backward/step.
        forward_pass_iters: Unused; kept for interface compatibility.
        grad_accum_iters: Number of batches whose gradients are accumulated per
            optimizer step. Each batch loss is divided by this value so the
            accumulated gradient is a mean. A trailing, incomplete accumulation
            window is still applied at the end of the epoch.
        grad_norm: Max global gradient norm for clipping, or None to disable.
        num_epochs: Number of passes over ``dataloader``.
        print_every: Refresh the running-loss display every this many batches.

    Returns:
        A list with the mean per-batch training loss of each epoch (NaN for an
        epoch in which the dataloader yielded no batches).
    """

    def _optimizer_step():
        # Unscale before clipping so max_norm applies to the true gradients.
        if grad_norm is not None:
            grad_scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_norm)
        grad_scaler.step(optimizer)
        grad_scaler.update()
        optimizer.zero_grad()

    model.train().cuda()
    epoch_losses = []
    for epoch_ix in range(num_epochs):
        total = None if dataloader.length == -1 else dataloader.length
        # Start every epoch from clean gradients (e.g. after an interrupted run).
        optimizer.zero_grad()
        num_batches = 0
        with tqdm(total=total) as pbar:
            tot_epoch_loss = torch.tensor(0.0).cuda()
            for batch_idx, batch in enumerate(dataloader):
                num_batches = batch_idx + 1

                # FORWARD PASS
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    out = model(batch["feats"].cuda(), batch["attention_mask"].cuda())
                    loss = criterion(out, batch["emotion_ix"].cuda())

                # BACKWARD PASS (averaged over the accumulation window)
                grad_scaler.scale(loss / grad_accum_iters).backward()

                # GRADIENT UPDATE
                if num_batches % grad_accum_iters == 0:
                    _optimizer_step()

                # PBAR UPDATE (loss is kept on-device; sync only when displaying)
                tot_epoch_loss += loss.detach()
                pbar.update(1)
                if num_batches % print_every == 0:
                    pbar.set_description(
                        f"train_loss: {tot_epoch_loss.item() / num_batches:.2f}"
                    )

            # Apply leftover accumulated gradients so they do not leak into the next epoch.
            if num_batches % grad_accum_iters != 0:
                _optimizer_step()

            # FINAL PBAR UPDATE: mean over all batches (not batch_idx, which is count - 1).
            mean_loss = (
                tot_epoch_loss.item() / num_batches if num_batches else float("nan")
            )
            pbar.set_description(f"train_loss: {mean_loss:.2f}")
        epoch_losses.append(mean_loss)

        # LR SCHEDULER
        if lr_scheduler is not None:
            lr_scheduler.step()

        if dataloader.length == -1 and num_batches > 0:
            dataloader.length = num_batches
    return epoch_losses

In [33]:
def validate(
    model: nn.Module,
    dataloader: wds.WebLoader,
    criterion,
    forward_pass_iters: int,
    print_every: int,
    target_names: list[str],
):
    """Evaluate ``model`` on one pass over ``dataloader`` and print a classification report.

    Args:
        model: Module called as ``model(feats, attention_mask)``; moved to CUDA
            and put in eval mode. It is left in eval mode afterwards.
        dataloader: Yields batch dicts with "feats", "attention_mask" and
            "emotion_ix". It is assumed to expose a mutable ``length`` attribute,
            where -1 means unknown; after the pass it is set to the observed
            number of batches.
        criterion: Loss, called as ``criterion(logits, targets)``.
        forward_pass_iters: Unused; kept for interface compatibility.
        print_every: Refresh the running-loss display every this many batches.
        target_names: Class names; class index i is reported as target_names[i].

    Returns:
        The mean per-batch loss over the pass.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval().cuda()
    total = None if dataloader.length == -1 else dataloader.length
    preds = []
    truth = []
    num_batches = 0
    with tqdm(total=total) as pbar:
        tot_epoch_loss = torch.tensor(0.0).cuda()
        for batch_idx, batch in enumerate(dataloader):
            num_batches = batch_idx + 1
            truth.append(batch["emotion_ix"])

            # FORWARD PASS
            with torch.no_grad(), torch.autocast(
                device_type="cuda", dtype=torch.bfloat16
            ):
                out = model(
                    batch["feats"].cuda(),
                    batch["attention_mask"].cuda(),
                )
                loss = criterion(out, batch["emotion_ix"].cuda())
            preds.append(out.argmax(-1).cpu())

            # PBAR UPDATE (loss was computed under no_grad, so it carries no graph)
            tot_epoch_loss += loss
            pbar.update(1)
            if num_batches % print_every == 0:
                pbar.set_description(
                    f"dev_loss: {tot_epoch_loss.item() / num_batches:.2f}"
                )
        if num_batches == 0:
            raise ValueError("validate(): dataloader yielded no batches")
        # FINAL PBAR UPDATE: mean over all batches (not batch_idx, which is count - 1).
        tot_epoch_loss = tot_epoch_loss.item() / num_batches
        pbar.set_description(f"dev_loss: {tot_epoch_loss:.2f}")
    preds = torch.cat(preds).numpy()
    truth = torch.cat(truth).numpy()
    print(
        classification_report(
            y_true=truth,
            y_pred=preds,
            labels=list(range(len(target_names))),
            target_names=target_names,
            zero_division=0.0,
        )
    )
    if dataloader.length == -1:
        dataloader.length = num_batches
    return tot_epoch_loss

In [34]:
def train_and_validate(
    model: nn.Module,
    train_dataloader: wds.WebLoader,
    dev_dataloader: wds.WebLoader,
    criterion,
    optimizer,
    lr_scheduler,
    grad_scaler: torch.cuda.amp.GradScaler,
    forward_pass_iters: int,
    grad_accum_iters: int,
    grad_norm: float,
    num_epochs: int,
    print_every: int,
    validate_every: int,
    target_names: list[str],
):
    """Alternate single-epoch training with periodic validation.

    Each epoch calls ``train(..., num_epochs=1)``, so ``lr_scheduler`` is stepped
    once per epoch. After every ``validate_every``-th epoch, ``validate`` is
    called on ``dev_dataloader``. The given ``criterion`` is used for both.

    Returns:
        ``(train_losses, dev_losses)``, two lists of ``(epoch_number, mean_loss)``
        pairs with 1-based epoch numbers.
    """
    train_losses = []
    dev_losses = []
    for epoch_ix in range(num_epochs):
        epoch_number = epoch_ix + 1
        epoch_train_loss = train(
            model=model,
            dataloader=train_dataloader,
            criterion=criterion,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            grad_scaler=grad_scaler,
            forward_pass_iters=forward_pass_iters,
            grad_accum_iters=grad_accum_iters,
            grad_norm=grad_norm,
            num_epochs=1,
            print_every=print_every,
        )[0]
        train_losses.append((epoch_number, epoch_train_loss))
        if epoch_number % validate_every == 0:
            epoch_dev_loss = validate(
                model=model,
                dataloader=dev_dataloader,
                criterion=criterion,
                forward_pass_iters=forward_pass_iters,
                print_every=print_every,
                target_names=target_names,
            )
            dev_losses.append((epoch_number, epoch_dev_loss))
    return train_losses, dev_losses

In [35]:
# Fine-tune for 10 epochs and validate after every epoch; no gradient
# accumulation and no gradient-norm clipping.
results = train_and_validate(
    model=model,
    criterion=cross_entropy_criterion,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    grad_scaler=grad_scaler,
    train_dataloader=train_dl,
    dev_dataloader=dev_dl,
    target_names=emotions,
    num_epochs=10,
    validate_every=1,
    print_every=10,
    forward_pass_iters=1,
    grad_accum_iters=1,
    grad_norm=None,
)

train_loss: 1.13: : 326it [01:38,  3.32it/s]

KeyboardInterrupt



In [26]:
train_results, dev_results = results