In [None]:
# Install the third-party packages this notebook imports: speechbrain, online_triplet_loss and NeMo.
# NOTE(review): no versions are pinned, so a fresh environment may resolve to different releases.
!pip install speechbrain
!pip install online_triplet_loss
!pip install nemo_toolkit["all"]

In [None]:
from huggingface_hub import notebook_login

# Interactive Hugging Face authentication prompt.
# NOTE(review): presumably required so hf_hub_download can fetch the
# "Ubenwa/CryCeleb2023" dataset — confirm whether that dataset is gated.
notebook_login()

In [None]:
import os
import random
import zipfile

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from huggingface_hub import hf_hub_download
from nemo.collections import asr as nemo_asr
from online_triplet_loss.losses import *
from speechbrain.dataio.dataio import read_audio
from speechbrain.nnet.losses import AdditiveAngularMargin, LogSoftmaxWrapper
from speechbrain.utils.metric_stats import EER
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:
class StreamingDataset(Dataset):
    """Map-style dataset of 1-D waveforms with optional random cropping.

    The (sample, target) pairs are converted to tensors and stored sorted by
    ascending sample length (stable sort, so ties keep their input order).

    Args:
        samples: Iterable of 1-D array-likes (one waveform per item).
        targets: Iterable of labels, paired positionally with ``samples``.
        cut_interval: ``None`` to return whole waveforms, or a ``(low, high)``
            pair; each ``__getitem__`` call then draws a crop duration
            uniformly from that interval (in units of ``1 / sample_rate``) and
            returns a random crop of that many points.
        num_copies: How many times each pair is repeated in the dataset. Only
            applied when ``cut_interval`` is given, because without random
            cropping the copies would be exact duplicates.
        sample_rate: Number of waveform points per unit of ``cut_interval``.
            Defaults to 16000, the value that was previously hard-coded.
    """

    def __init__(self, samples, targets, cut_interval=None, num_copies=1, sample_rate=16000):
        self.cut_interval = cut_interval
        self.num_copies = num_copies
        self.sample_rate = sample_rate

        # Build the pairs first and sort them as a list; unlike unpacking
        # ``zip(*sorted(...))`` this also works when the input is empty.
        pairs = sorted(
            (
                (torch.as_tensor(sample), torch.as_tensor(target))
                for sample, target in zip(samples, targets)
            ),
            key=lambda pair: len(pair[0]),
        )

        if cut_interval is not None:
            # Repeat every pair so one pass over the dataset sees several
            # different random crops of the same recording.
            pairs = [pair for pair in pairs for _ in range(num_copies)]

        self.samples = [sample for sample, _ in pairs]
        self.targets = [target for _, target in pairs]

    def __getitem__(self, index):
        """Return ``(sample, target)``; the sample is randomly cropped if cropping is enabled."""
        sample = self.samples[index]
        target = self.targets[index]
        if self.cut_interval is not None:
            duration = round(random.uniform(*self.cut_interval) * self.sample_rate)
            # When the waveform is shorter than the requested duration the
            # start is pinned to 0 and the slice returns the whole waveform.
            start_idx = random.randint(0, max(0, len(sample) - duration))
            sample = sample[start_idx : start_idx + duration]
        return sample, target

    def __len__(self):
        """Number of stored (sample, target) pairs, including repeated copies."""
        return len(self.samples)


def download_data(dest="data"):
    """Download the CryCeleb2023 files from the Hugging Face Hub and unpack the audio.

    If ``<dest>/audio/train`` already exists the function only prints a notice
    and returns without downloading anything.

    Args:
        dest: Directory the files are downloaded to and the archive is
            extracted into.
    """
    extracted_train_dir = os.path.join(dest, "audio", "train")
    if os.path.exists(extracted_train_dir):
        print(
            f"It appears that data is already downloaded. \n"
            f"If you think it should be re-downloaded, remove {dest} directory and re-run"
        )
        return

    # Download data from Huggingface
    required_files = (
        "metadata.csv",
        "audio.zip",
        "dev_pairs.csv",
        "test_pairs.csv",
        "sample_submission.csv",
    )
    for required_file in required_files:
        hf_hub_download(
            repo_id="Ubenwa/CryCeleb2023",
            filename=required_file,
            local_dir=dest,
            repo_type="dataset",
        )

    # Unpack the audio archive next to the downloaded files.
    archive_path = os.path.join(dest, "audio.zip")
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(dest)

    print(f"Data downloaded to {dest}/ directory")


def get_baby_ids_with_both_periods(manifest_df):
    """Return the index of ``baby_id`` values that have exactly two period entries.

    The check counts non-null ``period`` values per ``baby_id``, so it relies
    on ``manifest_df`` holding at most one row per (baby_id, period) pair.

    Args:
        manifest_df: DataFrame with ``baby_id`` and ``period`` columns.

    Returns:
        A pandas Index (named ``baby_id``) of the matching ids.
    """
    periods_per_baby = manifest_df.groupby("baby_id")["period"].count()
    has_two_periods = periods_per_baby.eq(2)
    return periods_per_baby.loc[has_two_periods].index


def collate_with_padding(batch):
    """Collate variable-length 1-D samples into one zero-padded batch tensor.

    Args:
        batch: Either a list of 1-D tensors, or a list of tuples whose first
            element is the 1-D sample tensor and whose remaining elements are
            tensors that can be stacked (e.g. targets).

    Returns:
        A tuple ``(samples, lengths, *others)``: ``samples`` is the stacked,
        right-padded batch; ``lengths`` holds each sample's length divided by
        the longest length in the batch; ``others`` are the remaining tuple
        fields, each stacked along a new first dimension.
    """
    if isinstance(batch[0], tuple):
        # Transpose [(sample, target, ...), ...] into (samples, targets, ...).
        columns = tuple(zip(*batch))
    else:
        columns = (batch,)
    waveforms, extras = columns[0], columns[1:]

    sizes = torch.as_tensor([waveform.shape[0] for waveform in waveforms])
    longest = int(sizes.max())
    padded = torch.stack(
        [F.pad(waveform, [0, longest - waveform.shape[0]]) for waveform in waveforms]
    )
    relative_lengths = sizes / longest
    return (padded, relative_lengths, *(torch.stack(column) for column in extras))


def compute_cosine_similarity_score(row, cry_dict):
    """Cosine similarity between the two encoded cries referenced by one pair row.

    Args:
        row: Mapping with ``baby_id_B`` and ``baby_id_D`` entries.
        cry_dict: Mapping keyed by ``(baby_id, period)`` whose values contain
            a ``"cry_encoded"`` tensor.

    Returns:
        The similarity along the last dimension, as a Python float.
    """
    embedding_b = cry_dict[(row["baby_id_B"], "B")]["cry_encoded"]
    embedding_d = cry_dict[(row["baby_id_D"], "D")]["cry_encoded"]
    similarity = torch.nn.functional.cosine_similarity(embedding_b, embedding_d, dim=-1)
    return similarity.item()


def compute_eer_and_plot_verification_scores(pairs_df):
    """Compute the equal error rate of the pair scores and plot their distributions.

    Args:
        pairs_df: DataFrame that must have a ``score`` column and a ``label``
            column (1 for positive pairs, 0 for negative pairs).

    Returns:
        ``(eer, threshold)`` exactly as returned by SpeechBrain's ``EER``.
    """
    # Select rows and column in one .loc call instead of chained indexing.
    positive_scores = pairs_df.loc[pairs_df["label"] == 1, "score"].values
    negative_scores = pairs_df.loc[pairs_df["label"] == 0, "score"].values
    eer, threshold = EER(torch.tensor(positive_scores), torch.tensor(negative_scores))

    # One histogram per label, each normalised on its own (common_norm=False).
    ax = sns.histplot(pairs_df, x="score", hue="label", stat="percent", common_norm=False)
    ax.set_title(f"EER={round(eer, 4)} - Thresh={round(threshold, 4)}")
    # axvline takes a scalar x position; draw it on the histogram's own axes.
    ax.axvline(x=threshold, color="red", ls="--")
    plt.show()
    return eer, threshold

In [None]:
# Download data
# download_data returns early when <dataset_path>/audio/train already exists.
dataset_path = "data"
download_data(dataset_path)

# Read manifest
# baby_id and chronological_index are forced to str so pandas does not parse them as numbers.
metadata = pd.read_csv(
    f"{dataset_path}/metadata.csv", dtype={"baby_id": str, "chronological_index": str}
)

# Load train data
# .copy() so adding the "cry" column below does not write into a view of `metadata`.
train_metadata = metadata.loc[metadata["split"] == "train"].copy()
# Read every listed audio file into a NumPy array (one array per metadata row).
train_metadata["cry"] = train_metadata.apply(
    lambda row: read_audio(f'{dataset_path}/{row["file_name"]}').numpy(), axis=1
)
# Concatenate all segments for each (baby_id, period) group
# Result: one row per (baby_id, period) with columns baby_id, period and cry.
manifest_df = pd.DataFrame(
    train_metadata.groupby(["baby_id", "period"])["cry"].agg(lambda x: np.concatenate(x.values)),
    columns=["cry"],
).reset_index()

# Load dev data
# .copy() so adding the "cry" column below does not write into a view of `metadata`.
dev_metadata = metadata.loc[metadata["split"] == "dev"].copy()
# Pair ids are read as str so they match the str baby_id values used as cry_dict keys.
dev_pairs = pd.read_csv(
    f"{dataset_path}/dev_pairs.csv", dtype={"baby_id_B": str, "baby_id_D": str}
)
# Read every listed audio file into a NumPy array (one array per metadata row).
dev_metadata["cry"] = dev_metadata.apply(
    lambda row: read_audio(f'{dataset_path}/{row["file_name"]}').numpy(), axis=1
)
# Concatenate all segments for each (baby_id, period) group
# Result: dict keyed by (baby_id, period) whose values are {"cry": concatenated array}.
cry_dict = pd.DataFrame(
    dev_metadata.groupby(["baby_id", "period"])["cry"].agg(lambda x: np.concatenate(x.values)),
    columns=["cry"],
).to_dict(orient="index")

# Load test data
# .copy() so adding the "cry" column below does not write into a view of `metadata`.
test_metadata = metadata.loc[metadata["split"] == "test"].copy()
# Read the pair ids as str, exactly like dev_pairs: the cry_dict_test keys come from
# metadata["baby_id"], which is read as str, so ids parsed as numbers would never match.
test_pairs = pd.read_csv(
    f"{dataset_path}/test_pairs.csv", dtype={"baby_id_B": str, "baby_id_D": str}
)
# Read every listed audio file into a NumPy array (one array per metadata row).
test_metadata["cry"] = test_metadata.apply(
    lambda row: read_audio(f'{dataset_path}/{row["file_name"]}').numpy(), axis=1
)
# Concatenate all segments for each (baby_id, period) group
# Result: dict keyed by (baby_id, period) whose values are {"cry": concatenated array}.
cry_dict_test = pd.DataFrame(
    test_metadata.groupby(["baby_id", "period"])["cry"].agg(lambda x: np.concatenate(x.values)),
    columns=["cry"],
).to_dict(orient="index")

In [None]:
# Training configuration.
# (low, high) range for the random crop duration; StreamingDataset multiplies the drawn
# value by 16000 to get a length in samples — presumably seconds of 16 kHz audio, confirm.
CUT_INTERVAL = (3, 5)
# Batch size of the training DataLoader (the trailing comment records a previously tried value).
TRAIN_BATCH_SIZE = 32  #160
# Number of passes over the training DataLoader.
NUM_EPOCHS = 200
# Learning rate passed to AdamW.
LR = 1e-4
# Margin and scale passed to AdditiveAngularMargin.
MARGIN = 0.2  # 0.5
SCALE = 30
# Validation, submission writing and checkpointing run every EVAL_FREQ epochs.
EVAL_FREQ = 1


# Define device
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Define dataloaders
# Keep only the babies that have an entry for both periods.
baby_ids_with_both_periods = get_baby_ids_with_both_periods(manifest_df)
has_both_periods = manifest_df["baby_id"].isin(baby_ids_with_both_periods)
manifest_df = manifest_df.loc[has_both_periods]

# Map every remaining baby id to a contiguous integer class label (sorted id order).
all_targets = manifest_df["baby_id"].tolist()
id2label = {baby_id: label for label, baby_id in enumerate(np.unique(all_targets))}

# Only the period "B" recordings are used as training samples.
period_b_rows = manifest_df.loc[manifest_df["period"] == "B"]
train_samples = period_b_rows["cry"].tolist()
train_targets = [torch.as_tensor(id2label[baby_id]) for baby_id in period_b_rows["baby_id"]]

# Random crops of CUT_INTERVAL are drawn on every access.
train_data = StreamingDataset(train_samples, train_targets, CUT_INTERVAL)
# NOTE(review): with shuffle=False the length-sorted dataset yields the same batches in the
# same order every epoch, and drop_last=True always drops the same trailing recordings —
# confirm this is intended.
train_dataloader = DataLoader(
    train_data,
    batch_size=TRAIN_BATCH_SIZE,
    shuffle=False,
    drop_last=True,
    collate_fn=collate_with_padding,
)

# Define model
# Pretrained NeMo speaker model, used below as the embedding encoder.
encoder = nemo_asr.models.EncDecSpeakerLabelModel.from_pretrained(
    "nvidia/speakerverification_en_titanet_large"
)
# Enable gradients for every encoder parameter, i.e. the whole encoder is fine-tuned.
encoder.requires_grad_(True)
# Linear classification head with one output per training baby id.
# NOTE(review): 192 is presumably the encoder's embedding size — confirm against the model config.
classifier = nn.Linear(192, len(id2label))
# Bundle both modules so one .to(), .parameters() and .state_dict() call covers them.
model = nn.ModuleList([encoder, classifier]).to(device)

# Load checkpoint
# Set checkpoint_path to a file written by the training loop to resume from it;
# leave it as None to start from the pretrained encoder and a fresh classifier.
checkpoint_path = None #os.path.join("checkpoints", "titanet_checkpoint_epoch=00_eer=0.41987180709838867.pt")
if checkpoint_path is not None:
    # map_location remaps the stored tensors onto the current device, so a checkpoint
    # saved on a GPU can still be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint)

# Define optimizer
# AdamW over every parameter registered in `model` (encoder and classifier head).
optimizer = optim.AdamW(model.parameters(), lr=LR)

# Margin loss
# Additive angular margin loss wrapped in SpeechBrain's LogSoftmaxWrapper.
# NOTE(review): additive angular margin is normally applied to cosine similarities, while
# `classifier` is a plain nn.Linear — confirm against the SpeechBrain documentation.
criterion = LogSoftmaxWrapper(AdditiveAngularMargin(margin=MARGIN, scale=SCALE))

In [None]:
def _embed_cries(encoder, cries, device):
    """Store an encoder embedding under ``"cry_encoded"`` for every entry of ``cries``.

    Args:
        encoder: Callable taking ``input_signal`` and ``input_signal_length``
            and returning a pair whose second element is the embedding batch.
        cries: Dict whose values are dicts holding a 1-D ``"cry"`` array; each
            value is updated in place.
        device: Device the input tensors are created on.
    """
    with torch.no_grad():
        for entry in tqdm(cries.values()):
            # Add a leading batch dimension: one full recording per forward pass.
            signal = torch.as_tensor(entry["cry"][None], device=device)
            signal_length = torch.ones(1, device=device) * signal.shape[1]
            _, embeddings = encoder(input_signal=signal, input_signal_length=signal_length)
            # Keep the embedding on the CPU so scoring does not hold GPU memory.
            entry["cry_encoded"] = embeddings[0].to("cpu")


train_losses = []
for i in range(NUM_EPOCHS):
    ##### TRAINING #####
    model.train()
    total_loss = 0.
    for samples, lengths, targets in train_dataloader:
        samples, lengths, targets = samples.to(device), lengths.to(device), targets.to(device)
        # `lengths` are relative to the padded batch width, so scale them back to sample counts.
        _, embeddings = encoder(input_signal=samples, input_signal_length=lengths * samples.shape[1])
        logits = classifier(embeddings)
        # The criterion is called with a singleton middle axis on both logits and targets.
        loss = criterion(logits[:, None, :], targets[:, None], lengths)
        # Alternative triplet-loss objectives, kept for reference:
        #margin = 0.2
        #loss = batch_all_triplet_loss(targets, embeddings, margin=margin)[0]
        #loss = batch_hard_triplet_loss(targets, embeddings, margin=margin)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean loss per batch for this epoch.
    train_losses.append(total_loss / len(train_dataloader))

    if i % EVAL_FREQ == 0:
        # Evaluation mode for both the dev and the test pass; model.train() at the
        # top of the next epoch switches back.
        model.eval()

        ##### VALIDATION #####
        # Compute embeddings
        _embed_cries(encoder, cry_dict, device)

        # Compute scores
        dev_pairs["score"] = dev_pairs.apply(
            lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict), axis=1
        )

        eer, threshold = compute_eer_and_plot_verification_scores(pairs_df=dev_pairs)

        print(eer, threshold)
        ##########

        ##### TEST #####
        # Compute embeddings
        _embed_cries(encoder, cry_dict_test, device)

        # Compute scores
        test_pairs["score"] = test_pairs.apply(
            lambda row: compute_cosine_similarity_score(row=row, cry_dict=cry_dict_test), axis=1
        )

        # Write submission file (named after the epoch and the dev EER)
        submission = test_pairs[["id", "score"]]
        os.makedirs("submissions", exist_ok=True)
        submission.to_csv(
            os.path.join("submissions", f"titanet_submission_epoch={i:02d}_eer={eer}.csv"),
            index=False,
        )
        ##########

        # Save model (encoder and classifier weights together)
        os.makedirs("checkpoints", exist_ok=True)
        torch.save(
            model.state_dict(),
            os.path.join("checkpoints", f"titanet_checkpoint_epoch={i:02d}_eer={eer}.pt")
        )

# Plot loss
# One point per epoch: the mean training loss per batch recorded in train_losses.
loss_fig, loss_ax = plt.subplots()
loss_ax.plot(train_losses)
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Training loss")
plt.show()