In [1]:
import pandas as pd

# df = pd.read_csv("../datasets/ESD_score_list.csv")
# df.columns = ["path", "score"]
# df.head()

In [2]:
import torch
import pickle as pk
import os


class EsdStrengthDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        path = self.df.iloc[idx]["path"]
        score = self.df.iloc[idx]["score"]
        emotion = self.df.iloc[idx]["emotion"]
        emotion_id = self.df.iloc[idx]["emotion_id"]
        mel_path = self.df.iloc[idx]["mel_path"]
        
        with open(mel_path, "rb") as f:
            mel = pk.load(f)
        
        return {
            "path": path,
            "score": score,
            "emotion": emotion,
            "emotion_id": emotion_id,
            "mel": mel,
        }

    @staticmethod
    def get_mel_path_from_audio_path(audio_path, preprocessed_basedir):
        basename = audio_path.split("/")[-1].replace(".wav", "") + "_" + audio_path.split("/")[-2].lower()
        return f"{preprocessed_basedir}/mel/mel_{basename}.pkl"
    
    @classmethod
    def from_csv(cls, path, emotion2idx, preprocessed_basedir, val_speakers, **kwargs):
        df = pd.read_csv(path)
        df.columns = ["path", "score"]
        
        # drop rows without mel file
        df["mel_path"] = df["path"].apply(lambda x: cls.get_mel_path_from_audio_path(x, preprocessed_basedir))
        df = df[df["mel_path"].apply(lambda x: os.path.exists(x))]
        
        df["emotion"] = df["path"].apply(lambda x: x.split("/")[-2].lower())
        df["emotion_id"] = df["emotion"].apply(lambda x: emotion2idx[x])
        
        df["speaker"] = df["path"].apply(lambda x: x.split("/")[-3])
        
        print("Loaded dataset with", len(df), "samples")
        
        # train test split
        val_df = df[df["speaker"].isin(val_speakers)]
        train_df = df[~df["speaker"].isin(val_speakers)]
        print("Train samples:", len(train_df))
        print("Val samples:", len(val_df))

        # reset index
        val_df = val_df.reset_index(drop=True)
        train_df = train_df.reset_index(drop=True)

        return cls(train_df, **kwargs), cls(val_df, **kwargs)

In [3]:
e2id = {
    "angry": 0,
    "happy": 1,
    "neutral": 2,
    "sad": 3,
    "surprise": 4,
}

train_ds, val_ds = EsdStrengthDataset.from_csv("../datasets/ESD_score_list_all.csv", emotion2idx=e2id, preprocessed_basedir="../datasets/esd_processed", val_speakers={"0001", "0005", "0011", "0015"})

train_ds[0]
# train_ds.df.head()

Loaded dataset with 14597 samples
Train samples: 12249
Val samples: 2348


{'path': '0012/Angry/0012_000351.wav',
 'score': 0.52011,
 'emotion': 'angry',
 'emotion_id': 0,
 'mel': array([[ -5.891084 ,  -7.9609103,  -6.724277 , ..., -11.107079 ,
         -11.025017 , -10.827656 ],
        [ -5.492388 ,  -7.182169 ,  -7.0519705, ..., -11.132161 ,
         -11.4717655, -11.510298 ],
        [ -5.143411 ,  -7.0178847,  -8.056166 , ..., -11.138936 ,
         -11.512925 , -11.512925 ],
        ...,
        [ -5.949753 ,  -7.2775326,  -7.0753965, ..., -11.334373 ,
         -11.512925 , -11.512925 ],
        [ -5.7215767,  -6.2575912,  -6.301642 , ...,  -9.436365 ,
          -9.496191 ,  -9.729062 ],
        [ -5.091888 ,  -5.370623 ,  -5.3995743, ...,  -8.487104 ,
          -8.542133 ,  -8.772532 ]], dtype=float32)}

In [4]:
len(train_ds)

12249

In [5]:
import numpy as np
import torch
import math 
from torch.functional import F


def mixup_data(x_emo, x_neu, alpha=1.0, lam=None):
    """Applies mixup augmentation to the data.

    Args:
    x_emo (Tensor): Input data (e.g., features of speech samples).
    x_neu (Tensor): Input data (e.g., features of speech samples).

    Returns:
    mixed_x (Tensor): The mixed input data.
    lam (float): The mixup coefficient.
    """
    if lam is None:
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

    x_mixed = lam * x_emo + (1 - lam) * x_neu

    return x_mixed, lam


def mixup_criterion(pred, y_emo, y_neu, lam) -> torch.Tensor:
    w_emo = math.sqrt(1)
    w_neu = math.sqrt(1)
    l_emo = sum([F.cross_entropy(pred[i], y_emo[i]) * lam[i] for i in range(len(pred))]) / len(pred)
    l_neu = sum([F.cross_entropy(pred[i], y_neu[i]) * (1 - lam[i]) for i in range(len(pred))]) / len(pred)
    
    loss = (w_emo * l_emo + w_neu * l_neu) / (w_emo + w_neu)
    return torch.mean(loss)


def rank_loss(ri, rj, lam_diff) -> torch.Tensor:
    p_hat_ij = F.sigmoid(ri - rj)
    rank_loss = torch.mean(- lam_diff @ torch.log(p_hat_ij) - (1 - lam_diff) @ torch.log(1 - p_hat_ij)) / lam_diff.size(0)
    return rank_loss


In [6]:
from torch import nn

class StrengthNetMixupLoss(nn.Module):
    alpha = 0.1
    beta = 1.0

    def forward(
        self, 
        predi, predj, y_emo, y_neu, lam_i, lam_j
    ):
        hi, ri = predi
        hj, rj = predj
        lam_diff = (lam_i - lam_j) / 2 + 0.5

        losses = {}
        mixup_loss_i = mixup_criterion(hi, y_emo, y_neu, lam_i)
        mixup_loss_j = mixup_criterion(hj, y_emo, y_neu, lam_j)
        mixup_loss = mixup_loss_i + mixup_loss_j
        losses.update({
            "mi": mixup_loss_i.item(), 
            "mj": mixup_loss_j.item(),
        })
        
        ranking_loss = rank_loss(ri, rj, lam_diff)
        losses.update({
            "rank": ranking_loss.item(),
        })

        total_loss = self.alpha * mixup_loss + self.beta * ranking_loss
        losses.update({
            "total": total_loss.item(),
        })

        return total_loss, losses

In [7]:
import random


class RankMixupDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, emo_ids, neu_ids, alpha=1.0):
        self.dataset = dataset
        self.alpha = alpha
        self.emo_ids = emo_ids
        self.neu_ids = neu_ids

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample_emo = self.dataset[self.emo_ids[np.random.randint(0, len(self.emo_ids))]]
        sample_neu = self.dataset[self.neu_ids[np.random.randint(0, len(self.neu_ids))]]

        x_emo, y_emo = sample_emo["mel"], sample_emo["emotion_id"]
        x_neu, y_neu = sample_neu["mel"], sample_neu["emotion_id"]

        x_emo = torch.from_numpy(x_emo).float()
        x_neu = torch.from_numpy(x_neu).float()
        y_emo = torch.tensor(y_emo, requires_grad=False)
        y_neu = torch.tensor(y_neu, requires_grad=False)

        return x_emo, x_neu, y_emo, y_neu

    def collate_fn(self, batch):
        mel_emo, mel_neu, y_emo, y_neu = zip(*batch)
        batch_size = len(mel_neu)
        
        lam_i = np.random.beta(self.alpha, self.alpha) if self.alpha != 0 else 0
        lam_j = np.random.beta(self.alpha, self.alpha) if self.alpha != 0 else 0
        if lam_i > 0.5:
            lam_i = 1 - lam_i
        if lam_j < 0.5:
            lam_j = 1 - lam_j
        lam_i = [lam_i] * batch_size
        lam_j = [lam_j] * batch_size
        
        xis = []
        xjs = []
        xi_lens = []
        xj_lens = []
        for i in range(batch_size):
            # if self.rand_lam_per_batch is False:
            _lam_i = np.random.beta(self.alpha, self.alpha)
            _lam_j = np.random.beta(self.alpha, self.alpha)
            lam_i[i] = _lam_i
            lam_j[i] = _lam_j
            
            neu_len = mel_neu[i].shape[0]
            emo_len = mel_emo[i].shape[0]
            min_mel_len = min(neu_len, emo_len)
            neu_start = random.randint(0, neu_len - min_mel_len)
            emo_start = random.randint(0, emo_len - min_mel_len)
            neu_mel = mel_neu[i][neu_start:neu_start+min_mel_len]
            emo_mel = mel_emo[i][emo_start:emo_start+min_mel_len]
            xi, _ = mixup_data(emo_mel, neu_mel, lam_i[i])
            xj, _ = mixup_data(emo_mel, neu_mel, lam_j[i])
            xis.append(xi)
            xjs.append(xj)
        xi_lens = [len(x) for x in xis]
        xj_lens = [len(x) for x in xjs]
        max_xi_len = max(xi_lens)
        max_xj_len = max(xj_lens)
        xi = [F.pad(x, (0, 0, 0, max_xi_len - len(x))) for x in xis]
        xj = [F.pad(x, (0, 0, 0, max_xj_len - len(x))) for x in xjs]
        xi = torch.stack(xi)
        xj = torch.stack(xj)
        y_emo = torch.tensor(y_emo, requires_grad=False)
        y_neu = torch.tensor(y_neu, requires_grad=False)
        lam_i = torch.tensor(lam_i, requires_grad=False)
        lam_j = torch.tensor(lam_j, requires_grad=False)

        return xi, xj, y_emo, y_neu, lam_i, lam_j, xi_lens, xj_lens

In [8]:
emo_ids = train_ds.df[train_ds.df["emotion"] != "neutral"].index
neu_ids = train_ds.df[train_ds.df["emotion"] == "neutral"].index
train_ds[0]

{'path': '0012/Angry/0012_000351.wav',
 'score': 0.52011,
 'emotion': 'angry',
 'emotion_id': 0,
 'mel': array([[ -5.891084 ,  -7.9609103,  -6.724277 , ..., -11.107079 ,
         -11.025017 , -10.827656 ],
        [ -5.492388 ,  -7.182169 ,  -7.0519705, ..., -11.132161 ,
         -11.4717655, -11.510298 ],
        [ -5.143411 ,  -7.0178847,  -8.056166 , ..., -11.138936 ,
         -11.512925 , -11.512925 ],
        ...,
        [ -5.949753 ,  -7.2775326,  -7.0753965, ..., -11.334373 ,
         -11.512925 , -11.512925 ],
        [ -5.7215767,  -6.2575912,  -6.301642 , ...,  -9.436365 ,
          -9.496191 ,  -9.729062 ],
        [ -5.091888 ,  -5.370623 ,  -5.3995743, ...,  -8.487104 ,
          -8.542133 ,  -8.772532 ]], dtype=float32)}

In [9]:
mix_train_ds = RankMixupDataset(train_ds, emo_ids, neu_ids, alpha=1.0)
# mix_train_ds[0]

In [10]:
loader = torch.utils.data.DataLoader(mix_train_ds, batch_size=4, shuffle=True, collate_fn=mix_train_ds.collate_fn)

In [11]:
for batch in loader:
    xi, xj, y_emo, y_neu, lam_i, lam_j, xi_lens, xj_lens = batch
    print(xi.shape, xj.shape, y_emo, y_neu, lam_i, lam_j, xi_lens, xj_lens)
    break

torch.Size([4, 205, 80]) torch.Size([4, 205, 80]) tensor([4, 0, 4, 4]) tensor([2, 2, 2, 2]) tensor([0.8583, 0.6731, 0.4974, 0.7996]) tensor([0.1412, 0.0053, 0.1834, 0.3923]) [148, 182, 205, 169] [148, 182, 205, 169]


In [12]:
# # train_ds.df["emotion"].value_counts()
# esd_basedir = "../datasets/esd_processed"

# lst = []
# for f in os.listdir(f"{esd_basedir}/mel"):
#     if "neutral" in f:
#         speaker = f.split("_")[1]
#         basename = f.split("_")[2]
#         emotion = f.split("_")[3].replace(".pkl", "")
#         # build audio path like 0011/Angry/0011_000352.wav
#         emotion_upper = emotion[0].upper() + emotion[1:]
#         audio_path = f"{speaker}/{emotion_upper}/{speaker}_{basename}.wav"
#         lst.append({
#             "path": audio_path,
#             "score": 0,
#         })
# import pandas as pd

# df = pd.DataFrame(lst)
# df.to_csv("../datasets/ESD_score_list_neu.csv", index=False, header=False)

In [13]:
# df1 = pd.read_csv("../datasets/ESD_score_list.csv")
# df2 = pd.read_csv("../datasets/ESD_score_list_neu.csv")

# df1.columns = ["path", "score"]
# df2.columns = ["path", "score"]

# # concat 2 df
# df = pd.concat([df1, df2])
# df.to_csv("../datasets/ESD_score_list_all.csv", index=False, header=False)