In [27]:
import os

import numpy as np
import torch
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.neural_network import MLPClassifier
from sklearn.svm import SVC
from speechbrain.nnet.losses import LogSoftmaxWrapper, AdditiveAngularMargin
from speechbrain.pretrained import EncoderClassifier

from data import create_dataloader

In [28]:
def compute_features(model, loader, criterion):
    """Build per-utterance ``[loss, score]`` features from a pretrained classifier.

    For every utterance in ``loader`` the model's ``classify_batch`` prediction is
    scored with ``criterion`` against the utterance's speaker id, and that loss is
    paired with the ``score`` returned by ``classify_batch``.

    Args:
        model: classifier exposing ``hparams.label_encoder.lab2ind`` (speaker id ->
            class index) and ``classify_batch(x)`` returning a 4-tuple whose first two
            items are ``(pred, score)``. Inputs are moved to ``model.device`` when the
            attribute exists, otherwise to ``"cuda"``.
        loader: iterable of ``(x, label, speaker_ids)`` batches. ``label`` holds the
            downstream target per utterance; ``speaker_ids`` are keys of ``lab2ind``.
        criterion: callable ``criterion(pred, target)`` returning a one-element tensor.

    Returns:
        ``(features, labels)``: ``features`` is a float32 array of shape (n, 2) with
        columns ``[loss, score]``; ``labels`` is an array of shape (n, 1).
    """
    lab2ind = model.hparams.label_encoder.lab2ind
    device = getattr(model, "device", "cuda")
    features = []
    labels = []
    for x, label, speaker_ids in loader:
        x = x.to(device)
        targets = torch.tensor([[lab2ind[spk]] for spk in speaker_ids], device=device)
        with torch.inference_mode():
            pred, score, _, _ = model.classify_batch(x)
            flat_score = score.reshape(-1)
            # The criterion may reduce over the batch, so evaluate it one utterance at
            # a time to keep exactly one loss per sample (identical for batch_size=1).
            for j in range(targets.shape[0]):
                loss = criterion(pred[j:j + 1], targets[j:j + 1])
                features.append([loss.item(), flat_score[j].item()])
        labels.extend(torch.as_tensor(label).reshape(-1).tolist())
    # reshape keeps the documented 2-D shapes even when the loader is empty.
    return (
        np.asarray(features, dtype=np.float32).reshape(-1, 2),
        np.asarray(labels).reshape(-1, 1),
    )


In [29]:
# Additive-angular-margin softmax loss; applied to the classifier output it gives the
# per-utterance "loss" feature used below.
criterion = LogSoftmaxWrapper(AdditiveAngularMargin(margin=0.2, scale=30))
model = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb", run_opts={"device": "cuda"}
)
model.eval()
# The dataset location is machine-specific: set VOX1_DATA_ROOT to point elsewhere
# instead of editing the notebook. The default keeps the original author's path.
DATA_ROOT = os.environ.get(
    "VOX1_DATA_ROOT", "/home/baole/cmu/2023spring/11785/data/vox1/vox1_all"
)
train_loader, val_loader, test_loader = create_dataloader(
    root='split',
    data_root=DATA_ROOT,
    batch_size=1,
)

In [30]:
# Extract [loss, score] features and labels for each split, in train/val/test order.
(train_x, train_y), (val_x, val_y), (test_x, test_y) = (
    compute_features(model, split_loader, criterion)
    for split_loader in (train_loader, val_loader, test_loader)
)

In [31]:
# Pack every split as [feature columns..., label] so consumers can slice
# X = arr[:, :-1] and y = arr[:, -1]; then persist each split to disk.
train, val, test = (
    np.column_stack((split_x, split_y))
    for split_x, split_y in ((train_x, train_y), (val_x, val_y), (test_x, test_y))
)
for out_path, split_array in (('train.npy', train), ('val.npy', val), ('test.npy', test)):
    np.save(out_path, split_array)

In [32]:
def classify(classifier, train, test):
    """Fit ``classifier`` on ``train`` and print a classification report on ``test``.

    Both arrays store the feature columns first and the label in the last column.
    Returns None; the report is printed.
    """
    train_features, train_labels = train[:, :-1], train[:, -1]
    test_features, test_labels = test[:, :-1], test[:, -1]
    classifier.fit(train_features, train_labels)
    predictions = classifier.predict(test_features)
    print(classification_report(test_labels, predictions))

In [33]:
classify(LogisticRegression(), train=train, test=test)

              precision    recall  f1-score   support

         0.0       0.55      0.40      0.46      2000
         1.0       0.53      0.68      0.59      2000

    accuracy                           0.54      4000
   macro avg       0.54      0.54      0.53      4000
weighted avg       0.54      0.54      0.53      4000



In [34]:
classify(SVC(), train=train, test=test)

              precision    recall  f1-score   support

         0.0       0.57      0.28      0.38      2000
         1.0       0.52      0.79      0.63      2000

    accuracy                           0.54      4000
   macro avg       0.55      0.54      0.50      4000
weighted avg       0.55      0.54      0.50      4000



In [35]:
# Fixed random_state so re-running the notebook reproduces this report.
classify(RandomForestClassifier(random_state=0), train, test)

              precision    recall  f1-score   support

         0.0       0.50      0.51      0.51      2000
         1.0       0.50      0.50      0.50      2000

    accuracy                           0.50      4000
   macro avg       0.50      0.50      0.50      4000
weighted avg       0.50      0.50      0.50      4000



In [36]:
# Fixed random_state (forwarded to the base trees) so re-runs reproduce this report.
classify(AdaBoostClassifier(random_state=0), train, test)

              precision    recall  f1-score   support

         0.0       0.56      0.32      0.40      2000
         1.0       0.52      0.75      0.62      2000

    accuracy                           0.53      4000
   macro avg       0.54      0.53      0.51      4000
weighted avg       0.54      0.53      0.51      4000



In [37]:
# Fixed random_state so re-running the notebook reproduces this report.
classify(GradientBoostingClassifier(random_state=0), train, test)

              precision    recall  f1-score   support

         0.0       0.56      0.36      0.44      2000
         1.0       0.53      0.71      0.60      2000

    accuracy                           0.54      4000
   macro avg       0.54      0.54      0.52      4000
weighted avg       0.54      0.54      0.52      4000



In [38]:
# Fixed random_state (weight init / shuffling) so re-runs reproduce this report.
classify(MLPClassifier(random_state=0), train, test)

              precision    recall  f1-score   support

         0.0       0.57      0.28      0.38      2000
         1.0       0.52      0.79      0.63      2000

    accuracy                           0.54      4000
   macro avg       0.55      0.54      0.50      4000
weighted avg       0.55      0.54      0.50      4000

