In [1]:
from utils.datasets import InTheWildDataset
from utils.metrics import ABXAccuracy
from utils.training import SpeechCLRTrainerVanilla

from models import SpeechEmbedder

import torch
import torch.nn as nn
from torch.nn.functional import cosine_similarity

2024-11-24 17:57:25.280959: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1732467445.303868 2489434 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1732467445.310418 2489434 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-24 17:57:25.337865: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
INFO:speechbrain.utils.quirks:Applied quirks (see `speechbrain.utils.quirks`): [allow_tf32, disable_jit_profiling]
INFO:speec

In [2]:
from torch.utils.data import DataLoader
from utils.datasets import InTheWildDataset


# Training split of the In-The-Wild dataset, served as triplets and restricted
# to bona-fide recordings (spoofed clips are excluded).
# NOTE(review): root_dir is an absolute, user-specific path; consider making it configurable.
# NOTE(review): sampling_rate is presumably in Hz and max_duration in seconds — confirm in InTheWildDataset.
train_dataset = InTheWildDataset(
    root_dir="/home/infres/amathur-23/DADA/datasets/InTheWild",
    config="configs/data/inthewild_toy.yaml",
    metadata_file="meta.csv",
    filename_col="file",
    bonafide_label="bona-fide",
    include_spoofs=False,
    split="train",
    mode="triplet",
    sampling_rate=16000,
    max_duration=4,
)

# Reshuffled mini-batches of 16 items drawn from the training split.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=16)

In [3]:
import os

# Expose only GPU 0 to this process. This has to happen before the first CUDA
# query below so that it is taken into account.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the embedding network and move it to the selected device.
model = SpeechEmbedder()
model = model.to(device)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at ehcalabres/wav2vec2-lg-xlsr-en-speech-emotion-recognition and are newly initialized: ['lm_head.bias', 'lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
class AdaTriplet(nn.Module):
    """
    Adaptive triplet loss with automatically updated margins, after
    Nguyen et al. 'AdaTriplet: Adaptive Gradient Triplet Loss with Automatic Margin Learning
                   for Forensic Medical Image Matching'

    The loss of one triplet is

        max(phi_an - phi_ap + eps, 0) + lambda_ * max(phi_an - beta, 0)

    where phi_ap / phi_an are the cosine similarities anchor-positive and
    anchor-negative. The margins ``eps`` and ``beta`` are recomputed on every
    forward pass from running means of the observed similarities, so the values
    passed to the constructor only matter until the first call to ``forward``.

    The running statistics are kept as plain Python floats: they carry no
    autograd graph, are device independent, are not part of ``state_dict()``
    and print cleanly in ``repr``.

    Args:
        K_d: divisor applied to the running mean of (phi_ap - phi_an) to obtain ``eps``.
        K_an: divisor applied to the running mean of phi_an to obtain ``beta``.
        eps: initial margin of the triplet term.
        beta: initial threshold of the negative-similarity term.
        lambda_: weight of the negative-similarity term.
    """

    def __init__(self, K_d=2, K_an=2, eps=0, beta=0, lambda_=1):
        super(AdaTriplet, self).__init__()
        self.K_d = K_d
        self.K_an = K_an
        self.eps = eps
        self.beta = beta
        self.lambda_ = lambda_

        # Running means over all forward calls so far, and the number of calls.
        self.mu_d = 0
        self.mu_an = 0
        self.counter = 0

    def reset(self):
        """Forget the running statistics (the current margins are left untouched)."""
        self.mu_d = 0
        self.mu_an = 0
        self.counter = 0

    def update_stats(self, phi_ap, phi_an):
        """Fold the batch means of (phi_ap - phi_an) and phi_an into the running means.

        Every call contributes with equal weight, regardless of the batch size.
        """
        delta = phi_ap - phi_an
        seen = self.counter
        # .item() turns the 0-dim tensors into Python floats, so no graph or
        # device memory is kept alive by the stored statistics.
        self.mu_d = (seen * self.mu_d + delta.mean().item()) / (seen + 1)
        self.mu_an = (seen * self.mu_an + phi_an.mean().item()) / (seen + 1)
        self.counter = seen + 1

    def update_margins(self):
        """Derive the margins from the current running means."""
        self.eps = self.mu_d / self.K_d
        # NOTE(review): the paper's AutoMargin appears to define
        # beta = 1 + (mu_an - 1) / K_an — confirm that this simpler form is intended.
        self.beta = self.mu_an / self.K_an

    def __repr__(self):
        return f"AdaTriplet(K_d={self.K_d}, K_an={self.K_an}, eps={self.eps}, beta={self.beta}, lambda_={self.lambda_})"

    def forward(self, anchor, positive, negative):
        """Update the margins from this batch, then return the mean loss.

        The three inputs are compared with ``cosine_similarity`` using its
        default dimension (dim=1). Returns a scalar tensor.
        """
        phi_ap = cosine_similarity(anchor, positive)
        phi_an = cosine_similarity(anchor, negative)

        # The margins are statistics, not learnable quantities: keep their
        # computation out of the autograd graph.
        with torch.no_grad():
            self.update_stats(phi_ap, phi_an)
            self.update_margins()

        triplet_term = torch.clamp_min(phi_an - phi_ap + self.eps, 0)
        negative_term = torch.clamp_min(phi_an - self.beta, 0)
        return torch.mean(triplet_term + self.lambda_ * negative_term)


In [5]:
# Loss with both margins starting at zero; they are re-estimated on every forward pass.
criterion = AdaTriplet(
    K_d=2,
    K_an=2,
    eps=0,
    beta=0,
    lambda_=1,
)

In [6]:
# Show the loss module's repr; no forward pass has run yet, so the margins are still the constructor values.
criterion

AdaTriplet(K_d=2, K_an=2, eps=0, beta=0, lambda_=1)

In [7]:
# Adam over every parameter of the embedding model, learning rate 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [8]:
from itertools import islice

from tqdm import tqdm

# Short smoke test of the training step: run at most NUM_STEPS optimisation steps.
NUM_STEPS = 5
n = 0  # number of steps that were skipped (non-finite loss or failed backward pass)

model.train()
num_steps = min(NUM_STEPS, len(train_loader))

# Iterate the loader once instead of rebuilding its iterator on every step,
# so the batches are distinct draws from a single shuffled pass.
for batch in tqdm(islice(train_loader, num_steps), total=num_steps):
    optimizer.zero_grad()
    output = model(batch)
    loss = criterion(output["anchor"], output["positive"], output["negative"])

    # backward() does not raise on NaN/inf, so check the loss explicitly.
    if not torch.isfinite(loss):
        n += 1
        print(f"non-finite loss ({loss.item()}); skipping optimizer step")
        continue

    try:
        loss.backward()
    except RuntimeError as e:
        # Best effort: report the autograd failure and do not step on missing/stale gradients.
        n += 1
        print(e)
        continue

    optimizer.step()

100%|██████████| 5/5 [00:11<00:00,  2.29s/it]


In [9]:
# Number of steps the training loop above counted as failed.
print(n)

0


In [10]:
# Show the criterion again: eps and beta now reflect the running statistics gathered during the forward passes.
print(criterion)

AdaTriplet(K_d=2, K_an=2, eps=0.04331747442483902, beta=0.360461950302124, lambda_=1)
