In [1]:
%load_ext autoreload
%autoreload 2
import os

# os.environ['TORCH_LOGS']="+dynamo"
# os.environ['TORCHDYNAMO_VERBOSE']="1"
import torch
from torch import nn
from torch.nn import functional as F
import torchaudio
import torchaudio.functional as Fa
from pytorch_metric_learning import losses
import torchmetrics
import webdataset as wds
from tqdm import tqdm
from huggingface_hub import get_token
from accelerate import Accelerator
from transformers import WhisperConfig, WhisperModel, WhisperFeatureExtractor
import nltk
from nltk.tokenize import RegexpTokenizer
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
from audiomentations import (
    Compose,
    AddGaussianNoise,
    TimeStretch,
    PitchShift,
    PolarityInversion,
    AdjustDuration,
    Normalize,
)
from itertools import chain
from sklearn.manifold import TSNE
from sklearn.metrics import classification_report

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# import os
# os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# from huggingface_hub import HfApi
# hf = HfApi()

# hf.snapshot_download(
#     "columbiaslp/MSP_PODCAST",
#     repo_type="dataset",
#     local_dir_use_symlinks=False,
#     local_dir="./data/MSP_PODCAST",
# )

In [3]:
# Local WebDataset shard patterns, one per MSP-PODCAST split. The `{a..b}` part is a
# brace range covering every shard file of that split.
train_url = "./data/MSP_PODCAST/data/train/train_{01..21}-of-21.tar"
dev_url = "./data/MSP_PODCAST/data/development/development_{01..05}-of-05.tar"
# NOTE(review): test_url is not referenced anywhere else in the visible notebook.
test_url = "./data/MSP_PODCAST/data/test1/test1_{01..08}-of-08.tar"

In [5]:
# Global configuration shared by the data pipeline, the models and the training loops.
sample_rate = 16_000
# Clip length handed to the Whisper feature extractor as `chunk_length`. It has to be an
# int chosen so that 3000 * (max_duration / 30 / 2) is a whole number; the assert below
# checks exactly that expression, and the notebook's Classifier uses the same expression
# as the number of encoder positions to keep.
max_duration = 6
assert 3000 * (max_duration / 30 / 2) % 1 == 0
# Target duration (seconds) used by the AdjustDuration augmentation below.
crop_duration = 3
# NOTE(review): shift_offset is not referenced anywhere else in the visible notebook.
shift_offset = 1
feature_extractor = WhisperFeatureExtractor(chunk_length=max_duration)
batch_size = 128
# Number of batches that make up one "epoch" of the resampled (infinite) train loader.
epoch_length = 500
dataloader_workers = 32
# Waveform augmentations used to build the augmented view of each training sample.
# Transforms with p=0.5 are applied to roughly half of the samples, p=1.0 to all of them.
augmentation = Compose(
    [
        AdjustDuration(duration_seconds=crop_duration, padding_mode="reflect", p=0.5),
        PolarityInversion(p=0.5),
        Normalize(p=1.0),
        AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.015, p=1.0),
        TimeStretch(min_rate=0.8, max_rate=1.2, leave_length_unchanged=True, p=0.5),
        PitchShift(min_semitones=-1, max_semitones=1, p=0.5),
    ]
)

In [6]:
from data import (
    collation_fn,
    decode,
    crop,
    apply_augmentation,
    MSP_PODCAST_EMOTION_TO_IX,
    MSP_PODCAST_EMOTIONS,
)

In [9]:
# Training pipeline, assembled one WebDataset stage at a time: shard source -> sample
# shuffle -> decode -> augmentation -> crop -> batching. `resampled=True` is kept
# because it is essential for distributed training to work correctly.
trainset = wds.WebDataset(train_url, shardshuffle=True, resampled=True)
trainset = trainset.shuffle(8192)
trainset = trainset.map(decode())
trainset = trainset.map(apply_augmentation(augmentation=augmentation))
trainset = trainset.map(crop(crop_duration=max_duration, random=True))
trainset = trainset.batched(
    batchsize=batch_size,
    collation_fn=collation_fn(feature_extractor=feature_extractor),
    partial=False,
)

# The dataset already yields full batches, hence `batch_size=None` on the loader.
# A resampled dataset never ends, so `with_epoch` imposes a fixed number of batches
# per epoch.
trainloader = wds.WebLoader(
    trainset,
    batch_size=None,
    num_workers=dataloader_workers,
    prefetch_factor=4,
    pin_memory=True,
    persistent_workers=True,
).with_epoch(epoch_length)

In [10]:
# Development pipeline: no shuffling, no resampling, no augmentation, and a
# deterministic (random=False) crop so that repeated passes see the same inputs.
# NOTE(review): `partial=False` discards the trailing incomplete batch, so some dev
# samples are never embedded or scored. Switching to partial=True would also require
# the models to cope with a final batch of arbitrary size (including one) -- confirm
# before changing.
devset = (
    wds.WebDataset(dev_url, shardshuffle=False, resampled=False)
    .map(decode())
    .map(crop(crop_duration=max_duration, random=False))
    .batched(
        batchsize=batch_size,
        collation_fn=collation_fn(feature_extractor=feature_extractor),
        partial=False,
    )
)
# Batching is done inside the dataset, hence batch_size=None here.
devloader = wds.WebLoader(
    devset,
    batch_size=None,
    num_workers=dataloader_workers,
    prefetch_factor=4,
    pin_memory=True,
    persistent_workers=True,
)

In [25]:
# Pull one batch directly from the dataset (not through the multi-worker loader) to
# use as a smoke-test input for the model cells below.
batch = next(iter(trainset))

In [18]:
from model import WhisperBackbone, Classifier

In [27]:
# Backbone used for the embedding / contrastive experiments, moved to the GPU.
# `pooling="mean"` selects the pooling mode implemented in the project's `model` module.
model = WhisperBackbone(
    "openai/whisper-tiny.en", max_duration=max_duration, pooling="mean"
).cuda()

In [20]:
# Instantiates the `Classifier` imported from the project's `model` module.
# NOTE(review): `classifier` is not used by any later visible cell, and the name
# `Classifier` is redefined further down in this notebook with a different signature.
classifier = Classifier(
    "openai/whisper-tiny.en",
    max_duration=max_duration,
    projection_dim=128,
    num_classes=len(MSP_PODCAST_EMOTION_TO_IX),
)

In [78]:
# `whisper_weights` is the checkpoint that later cells restore before the supervised
# run. Nothing in the notebook defines it, so on a fresh kernel this cell used to fail
# with NameError. If it is missing, snapshot it from the model built above (presumably
# still holding the pretrained Whisper weights at this point -- confirm).
try:
    whisper_weights
except NameError:
    # Clone onto the CPU: state_dict() tensors share storage with the live parameters,
    # so without a copy the snapshot would follow every later optimizer step.
    whisper_weights = {
        name: tensor.detach().cpu().clone()
        for name, tensor in model.state_dict().items()
    }
model.load_state_dict(whisper_weights)

<All keys matched successfully>

In [28]:
# Smoke test: run the backbone once on the original (non-augmented) view of the sample
# batch, without tracking gradients.
with torch.inference_mode():
    out = model(batch["feats"]["orig"].cuda(), batch["attention_mask"]["orig"].cuda())

In [29]:
# Show the embeddings of the smoke-test batch as an image. matshow() requires the
# matrix as its argument (calling it with none raises TypeError); the tensor is cast to
# float32 on the CPU so it converts cleanly to a NumPy array.
plt.matshow(out.float().cpu().numpy())
# Last expression: display the embedding shape as the cell output.
out.shape

torch.Size([128, 384])

In [27]:
# Embed the whole dev split with the current (pre-contrastive) weights and project the
# embeddings to 2-D with t-SNE. `emotions` (the label index of every embedded sample)
# is collected here once and reused by the later t-SNE plots.
model.eval()  # evaluation pass: make sure train-only behaviour such as dropout is off
emotions = []
pre_contrastive_embeddings = []
for batch in tqdm(devloader):
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            embeddings = model(
                batch["feats"]["orig"].cuda(), batch["attention_mask"]["orig"].cuda()
            )
        emotions.append(batch["emotion_ix"])
        # Under autocast the output may be bfloat16, which NumPy cannot represent;
        # casting to float32 is a no-op when the output is already float32.
        pre_contrastive_embeddings.append(embeddings.float().cpu())
pre_contrastive_embeddings = torch.cat(pre_contrastive_embeddings).numpy()
emotions = torch.cat(emotions).numpy()
pre_contrastive_tsne = TSNE(n_jobs=32).fit_transform(pre_contrastive_embeddings)

77it [00:22,  3.36it/s]


In [28]:
# Scatter colour for each emotion name; plot_tsne looks colours up by these keys.
colors = {
    "neutral": "gray",
    "happy": "yellow",
    "sad": "blue",
    "angry": "red",
    "fear": "purple",
    "disgust": "darkgreen",
    "surprise": "pink",
    "contempt": "brown",
}

In [29]:
def plot_tsne(title, embeddings, labels=None):
    """Scatter-plot 2-D t-SNE coordinates, one colour per emotion.

    Args:
        title: Title drawn above the axes.
        embeddings: Array whose first two columns are the x / y coordinates of each
            sample.
        labels: Optional array with the emotion index of every row of `embeddings`.
            Defaults to the notebook-level `emotions` array collected from the dev
            loader, which is what the original implementation always used.

    Returns:
        The matplotlib Figure. It is closed in pyplot before being returned, so it is
        only rendered when the caller displays or saves it explicitly.
    """
    if labels is None:
        labels = emotions
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_title(title)
    # `colors` lists the emotions in plotting order, so it doubles as the class list.
    for emotion, color in colors.items():
        points = embeddings[labels == MSP_PODCAST_EMOTION_TO_IX[emotion]]
        ax.scatter(
            points[:, 0], points[:, 1], alpha=0.6, s=2, c=color, label=emotion
        )
    legend = ax.legend()
    # The scatter markers are tiny (s=2); enlarge only the legend copies so the colours
    # are readable. Uses the public setter instead of poking the private `_sizes`.
    for handle in legend.legend_handles:
        handle.set_sizes([30])
    # Close this specific figure rather than whatever pyplot considers "current".
    plt.close(fig)
    return fig

In [31]:
# Self-supervised contrastive objective: NT-Xent wrapped in SelfSupervisedLoss, which is
# later called with the embeddings of the original and the augmented view of a batch.
# NOTE(review): the name is misspelled ("conrastive"); it is kept because the training
# cell refers to it under this exact name.
conrastive_loss_func = losses.SelfSupervisedLoss(losses.NTXentLoss())
# AdamW with its default hyper-parameters over every backbone parameter.
optimizer = torch.optim.AdamW(model.parameters())

In [32]:
# Self-supervised contrastive training: the original and augmented views of each sample
# form the positive pairs for the loss.
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    # The train loader is capped at `epoch_length` batches per epoch (with_epoch), so
    # use that for the bar instead of a hard-coded 500.
    with tqdm(total=epoch_length) as pbar:
        tot_train_loss = 0
        for i, batch in enumerate(trainloader):
            optimizer.zero_grad()
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                embeddings_orig = model(
                    batch["feats"]["orig"].cuda(),
                    batch["attention_mask"]["orig"].cuda(),
                )
                embeddings_aug = model(
                    batch["feats"]["aug"].cuda(), batch["attention_mask"]["aug"].cuda()
                )
                loss = conrastive_loss_func(embeddings_orig, embeddings_aug)
            loss.backward()
            optimizer.step()
            tot_train_loss += loss.detach()
            pbar.update(1)
            if i % 10 == 0:
                # `i` is zero-based, so i + 1 losses have been accumulated. Dividing by
                # `i` divided by zero on the first batch and overstated the mean after.
                pbar.set_description(f"loss: {tot_train_loss / (i + 1):.2f}")

loss: 0.59: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [03:23<00:00,  2.45it/s]


In [33]:
# Re-embed the dev split after self-supervised contrastive training and project to 2-D.
# The preceding training cell leaves the model in train mode, so switch to eval first.
model.eval()
post_contrastive_embeddings = []
for batch in tqdm(devloader):
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            embeddings = model(
                batch["feats"]["orig"].cuda(), batch["attention_mask"]["orig"].cuda()
            )
        # Cast to float32 so the later .numpy() cannot fail on a bfloat16 output.
        post_contrastive_embeddings.append(embeddings.float().cpu())
post_contrastive_embeddings = torch.cat(post_contrastive_embeddings).numpy()
post_contrastive_tsne = TSNE(n_jobs=32).fit_transform(post_contrastive_embeddings)

77it [00:22,  3.36it/s]


In [36]:
# Save the contrastively trained weights, then reset the model to the Whisper checkpoint
# for the supervised-contrastive experiment.
model = model.cpu()
# state_dict() returns tensors that share storage with the live parameters, and
# load_state_dict() copies into those parameters in place. Without cloning, the reset
# below would silently overwrite this snapshot with the Whisper weights.
contrastive_weights = {
    name: tensor.detach().clone() for name, tensor in model.state_dict().items()
}
model.load_state_dict(whisper_weights)
model = model.cuda()

In [37]:
# Supervised contrastive objective: plain NT-Xent, later called with emotion labels.
supervised_contrastive_loss_func = losses.NTXentLoss()
# Fresh AdamW (default hyper-parameters) so no optimizer state carries over from the
# previous training run.
optimizer = torch.optim.AdamW(model.parameters())

In [44]:
# Supervised contrastive training: only the original view is embedded and the emotion
# labels are passed to the loss alongside the embeddings.
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    # The train loader is capped at `epoch_length` batches per epoch (with_epoch), so
    # use that for the bar instead of a hard-coded 500.
    with tqdm(total=epoch_length) as pbar:
        tot_train_loss = 0
        for i, batch in enumerate(trainloader):
            optimizer.zero_grad()
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                embeddings_orig = model(
                    batch["feats"]["orig"].cuda(),
                    batch["attention_mask"]["orig"].cuda(),
                )
                loss = supervised_contrastive_loss_func(
                    embeddings_orig, labels=batch["emotion_ix"].cuda()
                )
            loss.backward()
            optimizer.step()
            tot_train_loss += loss.detach()
            pbar.update(1)
            if i % 10 == 0:
                # `i` is zero-based, so i + 1 losses have been accumulated. Dividing by
                # `i` divided by zero on the first batch and overstated the mean after.
                pbar.set_description(f"loss: {tot_train_loss / (i + 1):.2f}")

loss: 4.51: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [01:12<00:00,  6.94it/s]


In [46]:
# Re-embed the dev split after supervised contrastive training and project to 2-D.
# The preceding training cell leaves the model in train mode, so switch to eval first.
model.eval()
post_supervised_contrastive_embeddings = []
for batch in tqdm(devloader):
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            embeddings = model(
                batch["feats"]["orig"].cuda(), batch["attention_mask"]["orig"].cuda()
            )
        # Cast to float32 so the later .numpy() cannot fail on a bfloat16 output.
        post_supervised_contrastive_embeddings.append(embeddings.float().cpu())
post_supervised_contrastive_embeddings = torch.cat(
    post_supervised_contrastive_embeddings
).numpy()
post_supervised_contrastive_tsne = TSNE(n_jobs=32).fit_transform(
    post_supervised_contrastive_embeddings
)

77it [00:23,  3.34it/s]


In [53]:
# Keep the supervised-contrastive weights, taken while the model sits on the CPU, then
# move the model back to the GPU.
# NOTE(review): state_dict() returns references rather than copies; this snapshot is
# presumably only safe because the model is moved to the GPU right afterwards. Cloning
# the tensors would make that explicit -- confirm.
model = model.cpu()
supervised_contrastive_weights = model.state_dict()
model = model.cuda()

In [160]:
# t-SNE of the dev embeddings computed before any contrastive training; the figure is
# saved under the name "Whisper".
fig = plot_tsne(
    "TSNE on embeddings pre-contrastive (Whisper weights)",
    embeddings=pre_contrastive_tsne,
)
fig.savefig("Whisper")

In [161]:
# t-SNE of the dev embeddings computed after self-supervised contrastive training; the
# figure is saved under the name "Contrastive".
fig = plot_tsne(
    "TSNE on embeddings post-contrastive (new weights)",
    embeddings=post_contrastive_tsne,
)
fig.savefig("Contrastive")

In [162]:
# t-SNE of the dev embeddings computed after supervised contrastive training; the figure
# is saved under the name "Supervised Contrastive".
fig = plot_tsne(
    "TSNE on embeddings post-supervised-contrastive (using labels)",
    embeddings=post_supervised_contrastive_tsne,
)
fig.savefig("Supervised Contrastive")

In [54]:
# Put the weights saved after self-supervised contrastive training back into the model.
model.load_state_dict(contrastive_weights)

<All keys matched successfully>

In [121]:
class Classifier(nn.Module):
    """Whisper encoder followed by masked mean pooling, a projection and a class head.

    Note that this definition shadows the `Classifier` imported from the project's
    `model` module earlier in the notebook.

    Args:
        pretrained_model_name_or_path: Hugging Face identifier or path of the Whisper
            checkpoint whose encoder is reused.
        max_duration: Clip length in seconds; the encoder's positional embeddings are
            truncated to 3000 * (max_duration / 30 / 2) positions.
        projection_dim: Width of the hidden projection layer (default 128, the value
            that used to be hard-coded).
        num_classes: Number of output logits. Defaults to
            len(MSP_PODCAST_EMOTION_TO_IX), the value that used to be hard-coded.
    """

    def __init__(
        self,
        pretrained_model_name_or_path,
        max_duration,
        projection_dim=128,
        num_classes=None,
    ):
        super().__init__()
        if num_classes is None:
            num_classes = len(MSP_PODCAST_EMOTION_TO_IX)

        pretrained = WhisperModel.from_pretrained(pretrained_model_name_or_path)
        state_dict = pretrained.state_dict()
        # Same quantity as 3000 * (max_duration / 30 / 2), computed with one division
        # and rounded, so floating-point error cannot truncate it to one position short.
        offset = round(3000 * max_duration / 60)
        # Keep only the first `offset` positional embeddings so the checkpoint fits a
        # model configured for the shorter input.
        state_dict["encoder.embed_positions.weight"] = state_dict[
            "encoder.embed_positions.weight"
        ][:offset, :]
        config = WhisperConfig.from_pretrained(
            pretrained_model_name_or_path, max_source_positions=offset
        )
        model = WhisperModel(config)
        model.load_state_dict(state_dict)
        self.encoder = model.get_encoder()
        # Read the encoder width from the config instead of hard-coding 384, so other
        # Whisper sizes work as well.
        self.projection = nn.Linear(
            in_features=config.d_model, out_features=projection_dim
        )
        self.classification_head = nn.Linear(
            in_features=projection_dim, out_features=num_classes
        )

    def forward(
        self, input_values: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes).

        Args:
            input_values: Features passed straight to the Whisper encoder.
            attention_mask: Per-frame mask at the input frame rate, indexed as
                (batch, frames); non-zero entries mark frames to include in the mean.
        """
        hidden = self.encoder(input_values).last_hidden_state
        # second conv has stride of 2, so drop half
        output_attention_mask = attention_mask[:, ::2, None]
        hidden = hidden * output_attention_mask
        # Masked mean over the time axis. Reducing without keepdim keeps the batch axis
        # even for a batch of one (a bare squeeze() used to drop it), and the clamp
        # avoids 0/0 = NaN for a fully masked row.
        valid_frames = output_attention_mask.sum(dim=1).clamp(min=1)
        pooled = hidden.sum(dim=1) / valid_frames
        projected = F.gelu(self.projection(pooled))
        return self.classification_head(projected)

In [180]:
# Rebind `model` to a fresh instance of the notebook-defined Classifier; from here on
# `model` no longer refers to the WhisperBackbone used above.
model = Classifier("openai/whisper-tiny.en", max_duration=max_duration)

In [181]:
# Initialise the classifier's encoder from the contrastive checkpoint. strict=False is
# needed because the checkpoint has no entries for the new projection and
# classification_head layers (they are reported as missing keys and stay randomly
# initialised).
model.load_state_dict(contrastive_weights, strict=False)

_IncompatibleKeys(missing_keys=['projection.weight', 'projection.bias', 'classification_head.weight', 'classification_head.bias'], unexpected_keys=[])

In [182]:
# Move the classifier to the GPU. The commented-out lines below would freeze the
# encoder; as written, every parameter stays trainable.
model = model.cuda()
# for params in model.encoder.parameters():
#     params.requires_grad = False

In [187]:
# Supervised fine-tuning setup: cross-entropy on the emotion logits, AdamW at lr=0.005
# over all classifier parameters.
cross_entropy_loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.005)
# Each scheduler step multiplies the learning rate by gamma=0.1; the training cell calls
# step() once per epoch.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)

In [188]:
# Smoke test: one forward pass of the classifier on whatever `batch` currently holds,
# without tracking gradients. `preds` is overwritten by the training loop afterwards.
with torch.inference_mode():
    preds = model(
        batch["feats"]["orig"].cuda(),
        batch["attention_mask"]["orig"].cuda(),
    )

In [189]:
# Supervised fine-tuning of the classifier with cross-entropy on the emotion labels.
num_epochs = 2
for epoch in range(num_epochs):
    model.train()
    # The train loader is capped at `epoch_length` batches per epoch (with_epoch), so
    # use that for the bar instead of a hard-coded 500.
    with tqdm(total=epoch_length) as pbar:
        tot_train_loss = 0
        for i, batch in enumerate(trainloader):
            optimizer.zero_grad()
            with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                preds = model(
                    batch["feats"]["orig"].cuda(),
                    batch["attention_mask"]["orig"].cuda(),
                )
                loss = cross_entropy_loss_func(preds, batch["emotion_ix"].cuda())
            loss.backward()
            optimizer.step()
            tot_train_loss += loss.detach()
            pbar.update(1)
            if i % 10 == 0:
                # `i` is zero-based, so i + 1 losses have been accumulated. Dividing by
                # `i` divided by zero on the first batch and overstated the mean after.
                pbar.set_description(f"loss: {tot_train_loss / (i + 1):.2f}")
    # Decay the learning rate once per epoch.
    lr_scheduler.step()

loss: 1.76: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [02:26<00:00,  3.41it/s]
loss: 1.73: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [02:31<00:00,  3.29it/s]


In [None]:
# Score the dev split: collect the arg-max class of every sample alongside its label.
# The training cell leaves the model in train mode, so switch to eval before predicting.
model.eval()
predictions = []
truth = []
for batch in tqdm(devloader):
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            preds = (
                model(
                    batch["feats"]["orig"].cuda(),
                    batch["attention_mask"]["orig"].cuda(),
                )
                .argmax(-1)
                .cpu()
            )
        predictions.append(preds)
        truth.append(batch["emotion_ix"])

predictions = torch.cat(predictions).numpy()
truth = torch.cat(truth).numpy()

In [None]:
# Per-class precision / recall / F1 on the dev split.
# NOTE(review): `target_names` is matched against the distinct labels that actually
# occur in truth/predictions; if some emotion never occurs, the name list may no longer
# line up. Passing an explicit `labels=` list would pin the mapping -- confirm the index
# order of MSP_PODCAST_EMOTIONS first.
print(classification_report(truth, predictions, target_names=MSP_PODCAST_EMOTIONS))