In [1]:
import torch
import torchaudio
import numpy as np
import os
import tqdm as tqdm
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset,DataLoader
from transformers import Wav2Vec2Model, Wav2Vec2PreTrainedModel, AutoConfig

In [2]:
if torch.cuda.is_available():
    device='cuda'
else:
    device='cpu'

In [3]:
audio_path = './speech_res/'
name_set = set()
for file in os.listdir(audio_path):
    if file.endswith('mp3'):
        name_set.add(file)

In [6]:
name_set

{'diaid0_uid21_iid246_20-30_women_269.mp3',
 'diaid1627_uid46_iid167_over 30_women_318.mp3',
 'diaid5072_uid53_iid193_over 30_women_318.mp3'}

In [13]:
datalist = list(name_set)
audio_name = datalist[2]
name_list = audio_name.split('_')
age_labels = {'under 20':0, '20-30':1, 'over 30':2}
gender_labels = {'women':0, 'men':1}
age = age_labels[name_list[-3]]
gender = gender_labels[name_list[-2]]

In [16]:
float(gender)

0.0

In [4]:
class AudioDataset(Dataset):
    def __init__(self,audio_path, target_sample_rate, device, transformation=None):
        name_set = set()
        for file in os.listdir(audio_path):
            if file.endswith('mp3'):
                name_set.add(file)
        self.datalist = list(name_set)
        self.audio_path = audio_path
        self.device = device
        self.target_sample_rate = target_sample_rate
        self.transformation = None
        if transformation:
            self.transformation = transformation.to(device)
            
    def __len__(self):
        return len(self.datalist)
    
    def __getitem__(self,idx):
        audio_file_path = os.path.join(self.audio_path, self.datalist[idx])
        audio_name = self.datalist[idx]
        age_label, gender_label = self._get_label(audio_name)
        waveform, sample_rate = torchaudio.load(audio_file_path)
        waveform = waveform.to(self.device)
        if sample_rate != self.target_sample_rate:
            waveform = self._resample(waveform, sample_rate)
            
        return waveform, age_label, gender_label
        
    
    def _get_label(self, audio_name):
        name_list = audio_name.split('_')
        age_labels = {'under 20':0, '20-30':1, 'over 30':2}
        gender_labels = {'women':0, 'men':1}
        
        age = age_labels[name_list[-3]]
        gender = gender_labels[name_list[-2]]
        
        return age, gender
    
    def _resample(self, waveform, sample_rate):
        resampler = torchaudio.transforms.Resample(sample_rate,self.target_sample_rate)
        
        return resampler(waveform)

In [5]:
train_set = AudioDataset('./speech_res/', 16000, device)

In [100]:
train_set[0][0].size()

torch.Size([259321])

In [6]:
def pad_sequence(batch):
    # Make all tensor in a batch the same length by padding with zeros
    batch = [item.t() for item in batch]
    batch = torch.nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=0.)
    return batch.permute(0, 2, 1)

def collate_fn(batch):

    audios, age_labels, gender_labels = [], [], []

    for waveform, age, gender in batch:
        audios += [waveform]
        age_labels += [torch.tensor(age)]
        gender_labels += [torch.tensor(gender)]

    audios = pad_sequence(audios)
    age_labels = torch.stack(age_labels)
    gender_labels = torch.stack(gender_labels)

    return audios.squeeze(dim=1), age_labels, gender_labels

if device == "cuda":
    num_workers = 1
    pin_memory = True
else:
    num_workers = 0
    pin_memory = False

batch_size = 2

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [12]:
a = iter(train_loader)

In [13]:
train_data = next(a)

In [15]:
train_data[1].size()

torch.Size([2])

In [16]:
len(train_loader.dataset)

3

In [120]:
class Wav2Vec2ClassificationModel(Wav2Vec2PreTrainedModel):
    def __init__(self, config, hidden_size, dropout):
        super().__init__(config)
        
        self.wav2vec2 = Wav2Vec2Model(config)
        self.hidden_size = hidden_size
        self.fc = nn.Linear(config.hidden_size, self.hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.gender_fc = nn.Linear(self.hidden_size, 2)
        self.age_fc = nn.Linear(self.hidden_size, 3)
        self.tanh = nn.Tanh()
        
        self.init_weights()
        
    def freeze_feature_extractor(self):
        self.wav2vec2.feature_extractor._freeze_parameters()
    
    def merged_strategy(self, hidden_states):
        outputs = torch.mean(hidden_states, dim=1)
        
        return outputs
    
    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        
        hidden_states = outputs[0]
        x = self.merged_strategy(hidden_states)
        x = self.dropout(x)
        x = self.fc(x)
        x = self.tanh(x)
        x = self.dropout(x)
        gender_logits = self.gender_fc(x)
        gender_logits = F.log_softmax(gender_logits, dim=-1)
        age_logits = self.age_fc(x)
        age_logits = F.log_softmax(age_logits, dim=-1)
        
        return age_logits, gender_logits

In [79]:
model_name_or_path = "lighteternal/wav2vec2-large-xlsr-53-greek"
config = AutoConfig.from_pretrained(
    model_name_or_path,
    finetuning_task="wav2vec2_clf",
)

Downloading:   0%|          | 0.00/1.56k [00:00<?, ?B/s]



In [125]:
clf_model = Wav2Vec2ClassificationModel.from_pretrained(
    model_name_or_path,
    config=config,
    hidden_size=1024,
    dropout=0.1,
)

Some weights of the model checkpoint at lighteternal/wav2vec2-large-xlsr-53-greek were not used when initializing Wav2Vec2ClassificationModel: ['lm_head.weight', 'lm_head.bias']
- This IS expected if you are initializing Wav2Vec2ClassificationModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2ClassificationModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of Wav2Vec2ClassificationModel were not initialized from the model checkpoint at lighteternal/wav2vec2-large-xlsr-53-greek and are newly initialized: ['gender_fc.bias', 'fc.bias', 'fc.weight', 'age_fc.weight', 'age_fc.bias', 'gender_fc.weight']
You should probably TRAIN this model on a down-stream task to be able to us

In [83]:
print(model)

Wav2Vec2ClassificationModel(
  (wav2vec2): Wav2Vec2Model(
    (feature_extractor): Wav2Vec2FeatureEncoder(
      (conv_layers): ModuleList(
        (0): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(1, 512, kernel_size=(10,), stride=(5,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (1): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (2): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
          (activation): GELUActivation()
        )
        (3): Wav2Vec2LayerNormConvLayer(
          (conv): Conv1d(512, 512, kernel_size=(3,), stride=(2,))
          (layer_norm): LayerNorm((512,), eps=

In [123]:
def train_single_epoch(model, dataloader, optimizer, device):
    model.train()
    for waveform, age_label, gender_label in tqdm.tqdm(dataloader):
        waveform = waveform.to(device)
        age_label = age_label.to(device)
        gender_label = gender_label.to(device)
        
        age_logits, gender_logits = model(waveform)
        
        age_loss = F.nll_loss(age_logits, age_label)
        gender_loss = F.nll_loss(gender_logits, gender_label)
        loss = age_loss + gender_loss
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"loss:{loss.item()}")

In [124]:
def train(model, dataloader, optimizer, device, epochs):
    for i in tqdm.tqdm(range(epochs)):
        print(f"epoch:{i+1}")
        train_single_epoch(model, dataloader, optimizer, device)
        print('-------------------------------------------')
    print('Finished Training')

In [129]:
clf_model.freeze_feature_extractor()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)

In [135]:
epochs=1
train(clf_model, train_loader, optimizer, device, epochs)

  0%|                                                     | 0/1 [00:00<?, ?it/s]

epoch:1



  0%|                                                     | 0/2 [00:00<?, ?it/s][A
 50%|██████████████████████▌                      | 1/2 [01:31<01:31, 91.69s/it][A
100%|█████████████████████████████████████████████| 2/2 [02:12<00:00, 66.35s/it][A
100%|████████████████████████████████████████████| 1/1 [02:12<00:00, 132.71s/it]

loss:1.8445322513580322
-------------------------------------------
Finished Training





In [130]:
device

'cpu'