In [1]:
# I'm going to try on the whole dataset without streaming
from datasets import load_dataset
import torch
from torch import nn
from transformers import Wav2Vec2Model, Wav2Vec2Processor

torch.manual_seed(42)

<torch._C.Generator at 0x75e392f70690>

In [2]:
ds = load_dataset(
    'facebook/voxpopuli', 'sl', split='train'
).shuffle(seed=42)

In [3]:
processor = Wav2Vec2Processor.from_pretrained('facebook/wav2vec2-base-960h')
model = Wav2Vec2Model.from_pretrained('facebook/wav2vec2-base-960h')

# Padding doesn't really matter - regardless of padding, the number of frames, the embedding dimension, and the features themselves are all identical
def feat(sample, padding=True):
    enc = processor(sample['audio']['array'], sampling_rate=16_000, return_tensors='pt', padding=padding)
    with torch.no_grad():
        output = model(**enc)
    return {'features': output.last_hidden_state}

features = ds.map(feat)

Some weights of Wav2Vec2Model were not initialized from the model checkpoint at facebook/wav2vec2-base-960h and are newly initialized: ['wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1', 'wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Map:   0%|          | 0/2099 [00:00<?, ? examples/s]

In [10]:
# minimum and maximum feature lengths

In [11]:
features = features.map(lambda x: {'n_frames': len(x['features'][0])}, num_proc=9)
min(features['n_frames']), max(features['n_frames'])

Map (num_proc=9):   0%|          | 0/2099 [00:00<?, ? examples/s]

(27, 1948)

## 5. Data Filtering and Label Mapping
* Filtering Short Samples: Exclude samples with fewer than 200 frames.
* Label Mapping: Assign each unique speaker an integer label. Sort the speaker list using the `sorted` function before mapping and store mappings in a dictionary.

In [13]:
ds_filtered = features.filter(lambda x: x['n_frames'] >= 200, num_proc=9)

Filter (num_proc=9):   0%|          | 0/2099 [00:00<?, ? examples/s]

In [14]:
len(ds_filtered)

1937

In [15]:
# First speaker ID of the filtered dataset
ds_filtered[0]

{'audio_id': '20150908-0900-PLENARY-3-sl_20150908-10:38:22_5',
 'language': 13,
 'audio': {'path': None,
  'array': array([-0.05892944, -0.05279541, -0.01495361, ..., -0.0057373 ,
         -0.00714111, -0.0105896 ]),
  'sampling_rate': 16000},
 'raw_text': 'Ali je na vrsti za uspešno metodo na živalih potem tudi kloniranje na ljudeh?',
 'normalized_text': 'ali je na vrsti za uspešno metodo na živalih potem tudi kloniranje na ljudeh?',
 'gender': 'male',
 'speaker_id': '125004',
 'is_gold_transcript': True,
 'accent': 'None',
 'features': [[[-0.14782148599624634,
    -0.0040324958972632885,
    0.05790264159440994,
    -0.0076219444163143635,
    0.3669835925102234,
    0.017433855682611465,
    -0.04617757350206375,
    -0.0925062894821167,
    -0.05182841420173645,
    -0.39585620164871216,
    0.013623909093439579,
    0.004006413742899895,
    0.038226064294576645,
    0.09433609247207642,
    -0.03560760244727135,
    -0.17768095433712006,
    -0.30195948481559753,
    0.3375417888

In [19]:
speakers = list(set(ds_filtered['speaker_id']))
speakers = sorted(speakers)
speakers

['125003',
 '125004',
 '125103',
 '125104',
 '197446',
 '197447',
 '197452',
 '23693',
 '28294',
 '96911',
 '96933',
 '97019']

In [20]:
spk_mapping = {spk: i for i, spk in enumerate(speakers)}
spk_mapping

{'125003': 0,
 '125004': 1,
 '125103': 2,
 '125104': 3,
 '197446': 4,
 '197447': 5,
 '197452': 6,
 '23693': 7,
 '28294': 8,
 '96911': 9,
 '96933': 10,
 '97019': 11}

In [21]:
# number of unique speakers
len(spk_mapping), len(set(ds_filtered['speaker_id']))

(12, 12)

In [23]:
import warnings

In [52]:
class SpeakerClassification(nn.Module):

    def __init__(self, n_classes, *args, **kwargs):
        super(SpeakerClassification, self).__init__(*args, **kwargs)
        
        self.relu = nn.ReLU()
        
        # Conv1: In channels: 768, Out channels: 256, Kernel size: 3
        # BatchNorm1: Features: 256
        # ReLU1: Applied after BatchNorm1
        # MaxPool1: Kernel size: 2
        self.conv1 = nn.Conv1d(in_channels=768, out_channels=256, kernel_size=3)
        self.bnorm1 = nn.BatchNorm1d(256)
        self.maxpool1 = nn.MaxPool1d(kernel_size=2)
        
        # Conv2: In channels: 256, Out channels: 128, Kernel size: 3
        # BatchNorm2: Features: 128
        # ReLU2: Applied after BatchNorm2
        # MaxPool2: Kernel size: 2
        self.conv2 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3)
        self.bnorm2 = nn.BatchNorm1d(128)
        self.maxpool2 = nn.MaxPool1d(kernel_size=2)
        
        # Conv3: In channels: 128, Out channels: 32, Kernel size: 3
        # BatchNorm3: Features: 32
        # ReLU3: Applied after BatchNorm3
        # MaxPool3: Kernel size: 2
        self.conv3 = nn.Conv1d(in_channels=128,  out_channels=32, kernel_size=3)
        self.bnorm3 = nn.BatchNorm1d(32)
        self.maxpool3 = nn.MaxPool1d(kernel_size=2)
        
        # Global Average Pooling: Averaged across the temporal dimension
        # FC1: Input features: 32, Output features: 128
        # ReLU4: Applied after FC1
        # FC2: Input features: 128, Output features: Number of classes
        self.fc1 = nn.Linear(in_features=32, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=n_classes)
        
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bnorm1(x)
        x = self.relu(x)
        x = self.maxpool1(x)

        x = self.conv2(x)
        x = self.bnorm2(x)
        x = self.relu(x)
        x = self.maxpool2(x)

        x = self.conv3(x)
        x = self.bnorm3(x)
        x = self.relu(x)
        x = self.maxpool3(x)

        x = torch.mean(x, axis=-1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        
        return x
model = SpeakerClassification(n_classes=len(spk_mapping))

In [36]:
x = torch.tensor(ds_filtered[1]['features'])
x.shape

torch.Size([1, 413, 768])

In [53]:
with torch.no_grad():
    y = model(x.transpose(2, 1))

In [54]:
y

tensor([[-0.1705,  0.2408, -0.3702, -0.2456, -0.0187, -0.3471,  0.0137, -0.0910,
         -0.2151,  0.0485, -0.3126,  0.1476]])

In [56]:
y.shape

torch.Size([1, 12])

### Data Splitting:
* Split the data into training, validation, and test sets using sklearn.
* Divide the dataset by taking 80% for train and 20% for test.
* Then take another 10% of the train data for valid set. Set random state to 42 and use sklearn for this stage.

In [59]:
from sklearn.model_selection import train_test_split
import numpy as np

In [60]:
ix = np.arange(len(ds_filtered))
trix, tsix = train_test_split(ix, test_size=0.2, random_state=42)
trix, valix = train_test_split(trix, test_size=0.1, random_state=42)

In [61]:
train_dataset = ds_filtered.select(trix)
test_dataset = ds_filtered.select(tsix)
val_dataset = ds_filtered.select(valix)

In [62]:
len(train_dataset)

1394

In [63]:
len(test_dataset)

388

In [64]:
len(val_dataset)

155

## DataLoader:
Use PyTorch’s DataLoader with a custom collate function to convert features and labels to tensors, truncating features to the first 200 frames. Batch sizes:
* Train: 100
* Validation: 10
* Test: 1

In [66]:
from torch.utils.data import DataLoader

In [97]:
def collate(batch):
    feats = torch.stack(
        [torch.tensor(sample['features'][0])[:200] for sample in batch]
    ).transpose(2, 1)
    labels = torch.tensor([spk_mapping[sample['speaker_id']] for sample in batch])
    return {'batch': feats, 'labels': labels}
    

train_dataloader = DataLoader(train_dataset, batch_size=100, collate_fn=collate)
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=collate)
val_dataloader = DataLoader(val_dataset, batch_size=10, collate_fn=collate)

In [88]:
# Training / validation loop

In [89]:
len(train_dataset)

1394

In [90]:
n_batches_train = 0
for i in train_dataloader:
    n_batches_train += 1
print(n_batches_train)

14


In [108]:
from torch.optim import Adam
from tqdm import tqdm

In [115]:
from sklearn.metrics import accuracy_score

In [124]:
model = SpeakerClassification(n_classes=len(spk_mapping))
model.to(device)

losser = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)
device = torch.device('cuda')


n_epochs = 100

for i in tqdm(range(n_epochs), desc='Epoch'):
    model.train()
    epoch_train_loss = 0
    
    y_true_train = []
    y_pred_train = []

    # for batch in tqdm(train_dataloader, total=len(train_dataloader), desc='Train batch', leave=False):
    for batch in train_dataloader:
        optimizer.zero_grad()
        outputs = model(batch['batch'].to(device))
        loss = losser(outputs, batch['labels'].to(device))
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        epoch_train_loss += batch_loss

        y_true = batch['labels'].tolist()
        y_true_train.extend(y_true)
        y_pred = torch.argmax(nn.functional.softmax(outputs, dim=1), axis=1).tolist()
        y_pred_train.extend(y_pred)
        # print(f'\tBatch train loss: {batch_loss}; Batch train accuracy: {round(accuracy_score(y_true, y_pred), 2)}')
    epoch_train_loss = epoch_train_loss / len(train_dataloader)
    
    # Validation
    y_true_val = []
    y_pred_val = []
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        # for batch in tqdm(val_dataloader, total=len(val_dataloader), desc='Val batch', leave=False):
        for batch in val_dataloader:
            output = model(batch['batch'].to(device))
            loss = losser(output, batch['labels'].to(device))
            batch_val_loss = loss.item()
            epoch_val_loss += batch_val_loss
            y_true = batch['labels'].tolist()
            y_true_val.extend(y_true)
            y_pred = torch.argmax(nn.functional.softmax(output, dim=1), axis=1).tolist()
            y_pred_val.extend(y_pred)
            # print(f'\tBatch val loss: {batch_val_loss}; Batch val accuracy: {round(accuracy_score(y_true, y_pred), 2)}')

    train_acc = round(accuracy_score(y_true_train, y_pred_train), 2)
    val_acc = round(accuracy_score(y_true_val, y_pred_val), 2)
    epoch_val_loss = epoch_val_loss / len(val_dataloader)
    print(f'Epoch {i}: Train loss: {round(epoch_train_loss, 3)}; Train acc: {train_acc}; Val loss: {round(epoch_val_loss, 3)}; Val acc: {val_acc}')

Epoch:   1%|█▋                                                                                                                                                                    | 1/100 [03:48<6:17:14, 228.64s/it]

Epoch 0: Train loss: 2.282; Train acc: 0.32; Val loss: 2.337; Val acc: 0.27


Epoch:   2%|███▎                                                                                                                                                                  | 2/100 [07:37<6:13:25, 228.63s/it]

Epoch 1: Train loss: 1.958; Train acc: 0.43; Val loss: 2.231; Val acc: 0.25


Epoch:   3%|████▉                                                                                                                                                                 | 3/100 [11:26<6:09:48, 228.74s/it]

Epoch 2: Train loss: 1.714; Train acc: 0.48; Val loss: 2.098; Val acc: 0.35


Epoch:   4%|██████▋                                                                                                                                                               | 4/100 [15:16<6:07:14, 229.53s/it]

Epoch 3: Train loss: 1.494; Train acc: 0.53; Val loss: 1.932; Val acc: 0.36


Epoch:   5%|████████▎                                                                                                                                                             | 5/100 [19:12<6:06:38, 231.56s/it]

Epoch 4: Train loss: 1.291; Train acc: 0.61; Val loss: 1.526; Val acc: 0.44


Epoch:   6%|█████████▉                                                                                                                                                            | 6/100 [23:08<6:05:13, 233.12s/it]

Epoch 5: Train loss: 1.112; Train acc: 0.66; Val loss: 1.59; Val acc: 0.44


Epoch:   7%|███████████▌                                                                                                                                                          | 7/100 [27:03<6:02:34, 233.92s/it]

Epoch 6: Train loss: 0.971; Train acc: 0.7; Val loss: 1.428; Val acc: 0.55


Epoch:   8%|█████████████▎                                                                                                                                                        | 8/100 [30:55<5:57:50, 233.37s/it]

Epoch 7: Train loss: 0.812; Train acc: 0.76; Val loss: 1.667; Val acc: 0.41


Epoch:   9%|██████████████▉                                                                                                                                                       | 9/100 [34:53<5:56:00, 234.74s/it]

Epoch 8: Train loss: 0.707; Train acc: 0.8; Val loss: 1.453; Val acc: 0.57


Epoch:  10%|████████████████▌                                                                                                                                                    | 10/100 [38:49<5:52:26, 234.96s/it]

Epoch 9: Train loss: 0.595; Train acc: 0.83; Val loss: 1.3; Val acc: 0.6


Epoch:  11%|██████████████████▏                                                                                                                                                  | 11/100 [42:42<5:47:36, 234.34s/it]

Epoch 10: Train loss: 0.467; Train acc: 0.89; Val loss: 1.15; Val acc: 0.65


Epoch:  12%|███████████████████▊                                                                                                                                                 | 12/100 [46:34<5:42:58, 233.84s/it]

Epoch 11: Train loss: 0.368; Train acc: 0.93; Val loss: 1.492; Val acc: 0.54


Epoch:  13%|█████████████████████▍                                                                                                                                               | 13/100 [50:27<5:38:29, 233.45s/it]

Epoch 12: Train loss: 0.289; Train acc: 0.95; Val loss: 1.137; Val acc: 0.66


Epoch:  14%|███████████████████████                                                                                                                                              | 14/100 [54:24<5:36:18, 234.63s/it]

Epoch 13: Train loss: 0.222; Train acc: 0.97; Val loss: 0.978; Val acc: 0.69


Epoch:  15%|████████████████████████▊                                                                                                                                            | 15/100 [58:12<5:29:30, 232.60s/it]

Epoch 14: Train loss: 0.167; Train acc: 0.98; Val loss: 1.302; Val acc: 0.57


Epoch:  16%|██████████████████████████                                                                                                                                         | 16/100 [1:01:54<5:21:03, 229.33s/it]

Epoch 15: Train loss: 0.127; Train acc: 0.99; Val loss: 1.933; Val acc: 0.5


Epoch:  17%|███████████████████████████▋                                                                                                                                       | 17/100 [1:05:35<5:13:57, 226.95s/it]

Epoch 16: Train loss: 0.105; Train acc: 0.99; Val loss: 1.512; Val acc: 0.55


Epoch:  18%|█████████████████████████████▎                                                                                                                                     | 18/100 [1:09:22<5:10:05, 226.89s/it]

Epoch 17: Train loss: 0.102; Train acc: 0.99; Val loss: 1.528; Val acc: 0.54


Epoch:  19%|██████████████████████████████▉                                                                                                                                    | 19/100 [1:13:05<5:04:35, 225.62s/it]

Epoch 18: Train loss: 0.094; Train acc: 0.99; Val loss: 1.309; Val acc: 0.63


Epoch:  20%|████████████████████████████████▌                                                                                                                                  | 20/100 [1:16:50<5:00:39, 225.50s/it]

Epoch 19: Train loss: 0.097; Train acc: 0.99; Val loss: 1.157; Val acc: 0.66


Epoch:  21%|██████████████████████████████████▏                                                                                                                                | 21/100 [1:20:36<4:57:08, 225.68s/it]

Epoch 20: Train loss: 0.069; Train acc: 1.0; Val loss: 1.112; Val acc: 0.7


Epoch:  22%|███████████████████████████████████▊                                                                                                                               | 22/100 [1:24:26<4:55:02, 226.95s/it]

Epoch 21: Train loss: 0.066; Train acc: 0.99; Val loss: 1.198; Val acc: 0.63


Epoch:  23%|█████████████████████████████████████▍                                                                                                                             | 23/100 [1:28:23<4:55:07, 229.96s/it]

Epoch 22: Train loss: 0.061; Train acc: 0.99; Val loss: 2.033; Val acc: 0.51


Epoch:  24%|███████████████████████████████████████                                                                                                                            | 24/100 [1:32:21<4:54:26, 232.46s/it]

Epoch 23: Train loss: 0.058; Train acc: 0.99; Val loss: 1.45; Val acc: 0.62


Epoch:  25%|████████████████████████████████████████▊                                                                                                                          | 25/100 [1:36:17<4:51:43, 233.38s/it]

Epoch 24: Train loss: 0.039; Train acc: 1.0; Val loss: 1.356; Val acc: 0.63


Epoch:  26%|██████████████████████████████████████████▍                                                                                                                        | 26/100 [1:40:13<4:48:55, 234.26s/it]

Epoch 25: Train loss: 0.023; Train acc: 1.0; Val loss: 0.942; Val acc: 0.72


Epoch:  26%|██████████████████████████████████████████▍                                                                                                                        | 26/100 [1:41:24<4:48:38, 234.03s/it]


KeyboardInterrupt: 