In [1]:
%%capture
!pip install torchaudio
!pip install librosa
!pip install jiwer
!pip install ffmpeg-python

In [5]:
import json
import logging
import os
import re
import sys

import torch
import torchaudio
from torch import nn
from torch.nn import functional as F

import transformers
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2Model,
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Processor,
)

class CustomClassificationModel(Wav2Vec2PreTrainedModel):
    """Wav2Vec2 encoder with a two-layer head that produces 5 logits per example.

    The head projects every encoder frame to ``inner_dim`` features, applies
    ``tanh``, flattens all frames of an example into one vector and maps that
    vector to 5 outputs. Because of the flattening step the model only accepts
    inputs for which the encoder yields exactly ``feature_size`` frames.
    """

    def __init__(self, config):
        super().__init__(config)

        self.wav2vec2 = Wav2Vec2Model(config)

        # Width of the per-frame projection applied before flattening.
        self.inner_dim = 512
        # Number of encoder frames expected per example.
        # NOTE(review): presumably the frame count produced for the fixed-length
        # (padded) audio this model is used with — confirm against the data pipeline.
        self.feature_size = 999

        self.tanh = nn.Tanh()
        # Use the encoder width from the config instead of a literal 1024 so the
        # head always matches the encoder it is attached to.
        self.linear1 = nn.Linear(config.hidden_size, self.inner_dim)
        self.linear2 = nn.Linear(self.inner_dim * self.feature_size, 5)
        self.init_weights()

    def freeze_feature_extractor(self):
        """Call ``_freeze_parameters()`` on the encoder's feature extractor.

        NOTE(review): presumably this disables gradient updates for that
        sub-module; the helper is defined by the transformers library.
        """
        self.wav2vec2.feature_extractor._freeze_parameters()

    def forward(
        self,
        input_values,
        attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Compute the classification logits.

        Args:
            input_values: Either a mapping with an ``'input_values'`` entry
                (the form existing callers use) or the input tensor itself.
            attention_mask: Accepted for interface compatibility but NOT
                forwarded to the encoder.
                NOTE(review): forwarding it would change the encoder output for
                padded inputs; confirm how the checkpoint was trained before
                changing this.
            output_attentions: Passed through to the encoder.
            output_hidden_states: Passed through to the encoder.
            return_dict: Passed through to the encoder; defaults to
                ``config.use_return_dict``. The return value of this method is
                always a plain tensor.

        Returns:
            The output of ``linear2``: one row of 5 logits per example.

        Raises:
            ValueError: If the encoder does not yield ``feature_size`` frames.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Existing callers wrap the tensor in a dict; unwrap it when needed.
        if not torch.is_tensor(input_values):
            input_values = input_values['input_values']

        outputs = self.wav2vec2(
            input_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # First encoder output; assumed (batch, frames, hidden_size) — the head
        # below relies on that layout.
        hidden = outputs[0]
        if hidden.shape[1] != self.feature_size:
            # A plain ``view(-1, ...)`` could silently mix frames of different
            # batch items when the frame count is off; fail loudly instead.
            raise ValueError(
                f"expected {self.feature_size} encoder frames per example, "
                f"got {hidden.shape[1]}"
            )

        x = self.tanh(self.linear1(hidden))
        # Flatten all frames of each example into a single feature vector.
        x = self.linear2(x.reshape(x.shape[0], -1))

        return x

In [6]:
from transformers import (Wav2Vec2ForCTC,Wav2Vec2Processor)

# Run directory of the fine-tuned model; the checkpoint with the weights is a
# sub-directory of it.
base_dir = '/workspace/model_dir/dialects-20s/'
model_dir = os.path.join(base_dir, 'checkpoint-4400')

# Restore the classifier from the checkpoint, then move it to the GPU.
model = CustomClassificationModel.from_pretrained(model_dir)
model = model.to("cuda")
# The processor is loaded from the run directory itself, not from the checkpoint.
processor = Wav2Vec2Processor.from_pretrained(base_dir)

Special tokens have been added in the vocabulary, make sure the associated word embedding are fine-tuned or trained.


In [8]:
#@title Process Audio
## code created by Eric Lam
from IPython.display import HTML, Audio
from base64 import b64decode
import wave
from scipy.io.wavfile import read as wav_read
import io
import numpy as np
import ffmpeg
import soundfile as sf
import torch
import re
import sys


# Module-level transform that resamples audio from 48 kHz to 16 kHz.
# NOTE(review): nothing in the code visible here references it; it may be
# used elsewhere or be a leftover.
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)

def load_file_to_data(file, start=0, stop=20, srate=16_000):
    """Load a time window of an audio file as mono audio at ``srate`` Hz.

    Args:
        file: Path of the audio file, as accepted by ``sf.read``.
        start: Start of the window in seconds.
        stop: End of the window in seconds; ``None`` reads to the end of the file.
        srate: Sampling rate in Hz of the returned audio.

    Returns:
        dict with two entries: ``"speech"`` (1-D array of samples at ``srate``)
        and ``"sampling_rate"`` (``srate``).
    """
    def _to_frames(seconds, rate):
        # Convert a position in seconds to a frame index at the given rate.
        return None if seconds is None else int(seconds * rate)

    # Optimistic read that assumes the file is already at `srate`.
    speech, native_rate = sf.read(
        file, start=_to_frames(start, srate), stop=_to_frames(stop, srate)
    )

    needs_resampling = native_rate != srate
    if needs_resampling:
        # The frame offsets above were computed for the wrong rate, so the
        # window covered the wrong time span; read it again at the real rate.
        speech, native_rate = sf.read(
            file,
            start=_to_frames(start, native_rate),
            stop=_to_frames(stop, native_rate),
        )

    if speech.ndim > 1:
        # Multi-channel files come back as (frames, channels); downmix to mono.
        speech = speech.mean(axis=1)

    if needs_resampling:
        from math import gcd

        from scipy.signal import resample_poly

        # Polyphase resampling with the smallest integer up/down factors.
        divisor = gcd(int(srate), int(native_rate))
        speech = resample_poly(speech, int(srate) // divisor, int(native_rate) // divisor)

    return {"speech": speech, "sampling_rate": srate}


def predict(data, max_samples=320_000):
    """Run the module-level ``model`` on one audio clip and return the predicted class id.

    Args:
        data: Mapping with ``"speech"`` (the audio samples) and
            ``"sampling_rate"`` entries, e.g. as built by ``load_file_to_data``.
        max_samples: Value passed to the processor as both ``max_length`` and
            ``pad_to_multiple_of``. Truncation is not requested, so longer
            inputs are not cut here.

    Returns:
        ``torch.argmax(logits, dim=-1)[0]``: the index of the largest logit for
        the first example, as a tensor.
    """
    features = processor(
        data["speech"],
        sampling_rate=data["sampling_rate"],
        max_length=max_samples,
        pad_to_multiple_of=max_samples,
        padding=True,
        return_tensors="pt",
    )

    # Send the inputs to wherever the model lives instead of assuming a GPU.
    device = next(model.parameters()).device
    # The model's forward is called with the tensor wrapped in a dict.
    input_values = {'input_values': features.input_values.to(device)}
    attention_mask = features.attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask)
    pred_id = torch.argmax(logits, dim=-1)[0]
    return pred_id

In [56]:
import IPython
import random
import glob

# Collect the dev-set recordings; without `recursive=True` the `**` component
# matches exactly one directory level below `wav/`.
files = glob.glob('/workspace/data_dir/dev/wav/**/*.wav')
print(len(files))
if not files:
    raise FileNotFoundError('no .wav files matched /workspace/data_dir/dev/wav/**/*.wav')

# randrange excludes its upper bound; randint(0, len(files)) could return
# len(files) and index past the end of the list.
idx = random.randrange(len(files))
file = files[idx]
# Class names indexed by predicted id.
# NOTE(review): the order must match the label ids used in training — confirm.
names=['EGY','NOR','GLF','LAV','MSA']
pred_id = predict(load_file_to_data(file))
# The reference label is the name of the directory that contains the file.
print('TRUE: ',os.path.basename(os.path.dirname(file)))
print('PRED: ',names[int(pred_id)])
IPython.display.Audio(file)

1566
TRUE:  LAV
PRED:  LAV
