In [None]:
import torch, torchaudio

In [None]:
from collections import OrderedDict

def extract_prefix(prefix, weights):
    """Select the entries of ``weights`` whose key starts with ``prefix``.

    Args:
        prefix: String that a key must begin with to be kept.
        weights: Mapping from string keys to values (e.g. a state dict).

    Returns:
        An ``OrderedDict`` holding the matching entries, in the iteration
        order of ``weights``, with ``prefix`` removed from every key.
    """
    cut = len(prefix)
    return OrderedDict(
        (name[cut:], weights[name])
        for name in weights
        if name.startswith(prefix)
    )

In [None]:
from transformers import Wav2Vec2Model, Wav2Vec2Processor

class Wav2Vec2ConvEncoder:
    """Callable wrapper around the ``feature_extractor`` sub-module of the
    pretrained ``facebook/wav2vec2-base`` model."""

    def __init__(self, device="cuda"):
        """Load the pretrained feature extractor (in eval mode, on ``device``)
        and the matching processor.

        Args:
            device: Device the feature extractor and its inputs are moved to.
        """
        self.device = device

        conv_stack = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").feature_extractor
        conv_stack.eval()
        self.encoder = conv_stack.to(device)

        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
        processor._sample_rate = 16000
        self.preprocessor = processor

    def __call__(self, x):
        """Normalize each waveform and run it through the feature extractor.

        Args:
            x: Tensor indexed as ``x[:, 0]``; the original author's note gives
               its layout as [bs, 1, time].

        Returns:
            Whatever the wrapped feature extractor returns for the
            per-row standardized (zero-mean, unit-std) input.
        """
        waveform = x[:, 0]
        # Per-row standardization along the last axis; 1e-6 guards against a zero std.
        centered = waveform - waveform.mean(-1, keepdim=True)
        spread = waveform.std(-1, keepdim=True) + 1e-6
        return self.encoder((centered / spread).to(self.device))

In [None]:
class Wav2Vec2FullEncoder:
    """Callable wrapper around the complete pretrained
    ``facebook/wav2vec2-base`` model that returns its last hidden state."""

    def __init__(self, device="cuda"):
        """Load the pretrained model (in eval mode, on ``device``) and the
        matching processor.

        Args:
            device: Device the model and its inputs are moved to.
        """
        self.device = device

        backbone = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        backbone.eval()
        self.encoder = backbone.to(device)

        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")
        processor._sample_rate = 16000
        self.preprocessor = processor

    def __call__(self, x):
        """Normalize each waveform, encode it, and swap the last two axes.

        Args:
            x: Tensor indexed as ``x[:, 0]``; the original author's note gives
               its layout as [bs, 1, time].

        Returns:
            ``last_hidden_state`` of the wrapped model with its last two
            dimensions transposed.
        """
        waveform = x[:, 0]
        # Per-row standardization along the last axis; 1e-6 guards against a zero std.
        centered = waveform - waveform.mean(-1, keepdim=True)
        spread = waveform.std(-1, keepdim=True) + 1e-6
        encoded = self.encoder((centered / spread).to(self.device))
        return encoded.last_hidden_state.transpose(-2, -1)

In [None]:
from torch import nn
import numpy as np
import librosa
import glob

class Wav2Vec2MOS(nn.Module):
    """MOS regressor: a wav2vec2-base encoder followed by a small dense head.

    The encoder output (768 features per frame) is mapped to one score per
    frame by ``self.dense`` and then averaged over frames.

    Note: loading a fine-tuned checkpoint and disabling gradients on the
    encoder are currently NOT performed. ``freeze`` only controls whether the
    encoder is kept in eval mode when the module is switched to training mode.
    """

    def __init__(self, freeze=True, cuda=True):
        """Build the model from the pretrained ``facebook/wav2vec2-base`` weights.

        Args:
            freeze: If truthy, the encoder stays in eval mode even after
                ``train()`` is called on this module.
            cuda: If truthy, move the module to the GPU and move inputs to the
                GPU in ``calculate_one`` / ``calculate_dir``.
        """
        super().__init__()
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        self.freeze = freeze

        self.dense = nn.Sequential(
            nn.Linear(768, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 1)
        )

        self.encoder.eval()
        self.eval()
        self.cuda_flag = cuda
        if cuda:
            self.cuda()
        self.processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

    def forward(self, x):
        """Predict one score per batch item.

        Args:
            x: Input passed straight to the wav2vec2 encoder.

        Returns:
            Tensor of shape [batch, 1, 1]: the dense-head output averaged over
            the time and feature axes.
        """
        x = self.encoder(x)['last_hidden_state'] # [Batch, time, feats]
        x = self.dense(x) # [batch, time, 1]
        x = x.mean(dim=[1,2], keepdims=True) # [batch, 1, 1]
        return x

    def train(self, mode=True):
        """Set training mode, keeping the encoder in eval mode when frozen.

        Mirrors ``nn.Module.train``: ``mode`` defaults to ``True`` and the
        module itself is returned, so ``model.train()`` and chained calls such
        as ``model.eval().cuda()`` work as with any other module.
        """
        super().train(mode)
        if self.freeze:
            self.encoder.eval()
        return self

    def calculate_dir(self, path, mean=True):
        """Score every ``*.wav`` file directly inside directory ``path``.

        Args:
            path: Directory to scan (non-recursive).
            mean: If true, return the average score; otherwise the list of
                per-file scores in sorted filename order.

        Returns:
            ``np.mean`` of the scores when ``mean`` is true, else a list of floats.
        """
        # Local import: the notebook never imports the ``tqdm`` module itself,
        # so the previous ``tqdm.tqdm(...)`` call could not resolve.
        from tqdm import tqdm as progress_bar

        wav_files = sorted(glob.glob(f"{path}/*.wav"))
        pred_mos = [self.calculate_one(wav_file) for wav_file in progress_bar(wav_files)]
        if mean:
            return np.mean(pred_mos)
        else:
            return pred_mos

    def calculate_one(self, path):
        """Score a single audio file.

        The file is loaded with librosa at 16 kHz, converted to model input by
        the wav2vec2 processor, and scored without tracking gradients.

        Args:
            path: Path of the audio file.

        Returns:
            The predicted score as a Python float.
        """
        signal = librosa.load(path, sr=16_000)[0]
        x = self.processor(signal, return_tensors="pt", padding=True, sampling_rate=16000).input_values
        with torch.no_grad():
            if self.cuda_flag:
                x = x.cuda()
            res = self.forward(x).mean()
        return res.cpu().item()

In [None]:
# Detect GPU availability once, then derive the torch device from it.
cuda = torch.cuda.is_available()
device = torch.device('cuda') if cuda else torch.device('cpu')

# Show the selected device and the CUDA flag, one per line.
print(device, cuda, sep="\n")

In [None]:
# Instantiate the MOS predictor; it is placed on the GPU only when `cuda` is true.
model = Wav2Vec2MOS(cuda=cuda)

In [None]:
import torchaudio

In [None]:
# Path to one example recording, used below as a quick forward-pass check.
filepath = "/home/fred/Projetos/DATASETS/MOS/BRSPEECH_MOS_DATASET_old/wavs/2959_2564_000001.wav"

# The result of torchaudio.load is unpacked into the waveform (`data`) and its sample rate (`sr`).
# NOTE(review): `sr` is not checked afterwards — confirm the file is sampled at the rate the model expects.
data,sr = torchaudio.load(filepath)

In [None]:
# Move the waveform to the device selected earlier instead of unconditionally
# calling .cuda(), so this cell also runs on CPU-only machines.
data = data.to(device)

In [None]:
# Run the MOS model on the example waveform loaded above.
output = model(data)

In [None]:
# Display the prediction produced by the previous cell.
output

In [None]:
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence 

class VCC2018WavDatasetLoad(Dataset):
    """Dataset pairing audio file paths with their labels.

    Each item is loaded from disk on access with ``torchaudio.load``.
    """

    def __init__(self, audio_names: list, labels: list):
        """Store the parallel lists of file paths and labels.

        Args:
            audio_names: Paths of the audio files.
            labels: Label for each file, at the same index as its path.
        """
        self.labels = labels
        self.audio_names = audio_names

    def __len__(self):
        """Number of audio files in the dataset."""
        return len(self.audio_names)

    def __getitem__(self, idx):
        """Load item ``idx``.

        Returns:
            Dict with the loaded waveform under ``"data"`` and the label under
            ``"target"``. The sample rate reported by the loader is discarded.
        """
        waveform, _sample_rate = torchaudio.load(self.audio_names[idx])
        return {"data": waveform, "target": self.labels[idx]}


def custom_wav_collate_fn(data):
    """Collate dataset items into one zero-padded batch on ``device``.

    Args:
        data: List of dicts, each with a ``'data'`` entry (the waveform; items
            may differ in length) and a ``'target'`` entry (scalar label).

    Returns:
        Dict with:
          ``'data'``: waveforms padded with zeros to the longest item, batch
              dimension first. A trailing singleton channel axis is dropped,
              so mono input yields [batch, time].
          ``'target'``: tensor of the labels, shape [batch].
    """
    # as_tensor avoids the copy (and the UserWarning) that torch.tensor()
    # triggers when the item is already a tensor.
    features = [torch.as_tensor(d['data']) for d in data]
    labels = torch.tensor([d['target'] for d in data])
    new_features = pad_sequence([f.T for f in features], batch_first=True)
    # Drop only the trailing channel axis. A bare .squeeze() would also remove
    # the batch axis whenever the batch holds a single item.
    if new_features.dim() == 3 and new_features.size(-1) == 1:
        new_features = new_features.squeeze(-1)
    return {
        'data': new_features.to(device),
        'target': labels.to(device)
    }

In [None]:

def train(model, iterator, optimizer, criterion, scheduler, epoch=0):
    """Run one training pass over ``iterator``.

    Args:
        model: Module to optimize; it is switched to training mode.
        iterator: Sized iterable of batches, each a dict with ``'data'`` and
            ``'target'`` tensors.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, labels)``.
        scheduler: LR scheduler, stepped once per batch (not once per epoch).
        epoch: Index of the current epoch (currently unused).

    Returns:
        Sum of the per-batch loss values (not their mean).
    """
    model.train(mode=True)
    epoch_loss = 0
    total_steps = len(iterator)
    # Log roughly 100 times per pass. The max() keeps the interval >= 1 so
    # the modulo below cannot divide by zero on short iterators.
    log_every = max(1, total_steps // 100)
    for i, batch in enumerate(iterator):

        data = batch['data'].to(device, dtype=torch.float32)
        labels = batch['target'].to(device, dtype=torch.float32)
        optimizer.zero_grad()
        output = model(data)
        # Same number of elements but a different layout (e.g. [B, 1, 1] vs
        # [B]): align the labels so the loss does not silently broadcast.
        if output.shape != labels.shape and output.numel() == labels.numel():
            labels = labels.reshape(output.shape)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        epoch_loss += loss.item()

        if i % log_every == 0:
            print("Train step {0} / {1}  loss: {2:.5f}".format(i, total_steps, loss.item()))

    return epoch_loss

In [None]:

def evaluate(model, iterator, criterion, epoch):
    """Compute the summed loss of ``model`` over ``iterator`` without gradients.

    Args:
        model: Module to evaluate; it is switched to eval mode.
        iterator: Iterable of batches, each a dict with ``'data'`` and
            ``'target'`` tensors.
        criterion: Loss called as ``criterion(output, labels)``.
        epoch: Epoch index, used only in the progress message.

    Returns:
        Sum of the per-batch loss values (not their mean).
    """
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(iterator):

            data = batch['data'].to(device, dtype=torch.float32)
            labels = batch['target'].to(device, dtype=torch.float32)
            output = model(data)
            # Same number of elements but a different layout (e.g. [B, 1, 1]
            # vs [B]): align the labels so the loss does not silently broadcast.
            if output.shape != labels.shape and output.numel() == labels.numel():
                labels = labels.reshape(output.shape)
            loss = criterion(output, labels)

            epoch_loss += loss.item()

            # Only two positional arguments are supplied, so the loss is
            # field {1}; the previous {2} raised IndexError.
            print("Test Epoch {0}   loss: {1:.5f}".format(epoch, loss.item()))

    return epoch_loss

In [None]:
def epoch_time(start_time, end_time):
    """Split the interval between two timestamps into whole minutes and seconds.

    Args:
        start_time: Start timestamp in seconds.
        end_time: End timestamp in seconds.

    Returns:
        Tuple ``(minutes, seconds)`` of ints, each truncated toward zero.
    """
    duration = end_time - start_time
    minutes = int(duration / 60)
    seconds = int(duration - 60 * minutes)
    return minutes, seconds


In [None]:
import pandas as pd

# Metadata CSVs for the two splits; the 'filepath' and 'score' columns are read below.
metadata_train = "/home/fred/Projetos/MOS/custom_mos_prediction/data/train.csv"
metadata_test = "/home/fred/Projetos/MOS/custom_mos_prediction/data/test.csv"

train_data = pd.read_csv(metadata_train)
test_data = pd.read_csv(metadata_test)

# Training split: file paths and scores -> dataset -> loader.
audio_names_train = train_data['filepath'].to_list()
scores_train = train_data['score'].to_list()
dataset_train = VCC2018WavDatasetLoad(audio_names_train, scores_train)
loader_train = DataLoader(
    dataset_train,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_wav_collate_fn,
)

# Test split: same pipeline.
audio_names_test = test_data['filepath'].to_list()
scores_test = test_data['score'].to_list()
dataset_test = VCC2018WavDatasetLoad(audio_names_test, scores_test)
loader_test = DataLoader(
    dataset_test,
    batch_size=1,
    shuffle=True,
    collate_fn=custom_wav_collate_fn,
)

In [None]:
# Sanity check: run the model on the first dataset item only (note the `break`)
# and print the input waveform shape and the model output shape.
# NOTE(review): this rebinds the notebook-level names `data` and `output`.
for data in dataset_train:
    print(data['data'].shape)
    output = model(data['data'].to(device))
    print(output.shape)
    break

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Exponential LR decay of 5% per epoch. The previous factor `epoch * 0.95`
# was 0 on the first step (learning rate 0) and then grew without bound.
# train() calls scheduler.step() once per batch, so the step counter is
# converted to an epoch index before applying the decay.
steps_per_epoch = max(1, len(loader_train))
lambda2 = lambda step: 0.95 ** (step // steps_per_epoch)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=[lambda2])

# Per-epoch history containers (only the loss lists are filled by the training loop).
train_losses, val_losses = [], []
train_accuracies, val_accuracies = [], []
train_f1s, val_f1s = [], []

# Lowest validation loss seen so far.
best_valid_loss = float('inf')

In [None]:
import time

N_EPOCHS = 10

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss = train(model, loader_train, optimizer, criterion, scheduler, epoch)
    val_loss = evaluate(model, loader_test, criterion, epoch)
    end_time = time.time()

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    # Keep only the checkpoint with the lowest validation loss. Previously the
    # file was overwritten every epoch, so 'best-val-model.pt' was simply the
    # last epoch's weights and `best_valid_loss` was never used.
    if val_loss < best_valid_loss:
        best_valid_loss = val_loss
        torch.save(model.state_dict(), 'best-val-model.pt')

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins} m {epoch_secs} s')
    print(f'\tTrain Loss: {train_loss}')
    print(f'\t Val. Loss: {val_loss}')


# Extract Embeddings

In [9]:
import torch, torchaudio

In [10]:
# Detect GPU availability once, then derive the torch device from it.
cuda = torch.cuda.is_available()
device = torch.device('cuda') if cuda else torch.device('cpu')


In [11]:
from transformers import Wav2Vec2Model, Wav2Vec2Processor

# Processor used to turn raw waveforms into model input values.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base")

# Only the `feature_extractor` sub-module of the pretrained model is kept.
# From here on `model` refers to this sub-module, replacing any earlier `model`.
model = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base").feature_extractor

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_q.weight', 'project_hid.weight', 'project_hid.bias', 'quantizer.weight_proj.weight', 'quantizer.codevectors', 'project_q.bias', 'quantizer.weight_proj.bias']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Path to one example recording, used to check the feature extractor's output.
filepath = "/home/fred/Projetos/DATASETS/MOS/BRSPEECH_MOS_DATASET_old/wavs/2959_2564_000001.wav"

# The result of torchaudio.load is unpacked into the waveform (`audio_data`) and its sample rate (`sr`).
audio_data, sr = torchaudio.load(filepath)

In [None]:
# Run the feature extractor directly on the raw waveform (the processor is not applied here).
output = model(audio_data)

In [None]:
# Inspect the shape of the extracted features.
output.shape

In [12]:
from tqdm import tqdm
import os

# Extract one embedding per VCC2018 utterance listed in the metadata file and
# save each one as "<file stem>.pt" in `output_dir`.
metadata_filepath = '/home/fred/Projetos/DATASETS/MOS/VCC2018_MOS_preprocessed/mos_list.txt'
wavs_filepath = '/home/fred/Projetos/DATASETS/MOS/VCC2018_MOS_preprocessed/wav'
with open(metadata_filepath, encoding="utf-8") as f:
    content_file = f.readlines()

output_dir = "./VCC2018_wav2vec_embeddings"
os.makedirs(output_dir, exist_ok=True)

# Feed the model on whatever device it currently lives on, so the cell works
# whether or not the model has been moved to the GPU.
model_device = next(model.parameters()).device

for line in tqdm(content_file):
    # Each line is "<relative wav path>,<mos>". Strip the newline and skip
    # blank lines, which would otherwise break the unpacking below.
    line = line.strip()
    if not line:
        continue
    filepath, mos = line.split(',')
    filename = os.path.basename(filepath)
    filepath = os.path.join(wavs_filepath, filepath)
    if not os.path.exists(filepath):
        continue
    audio_data, sr = torchaudio.load(filepath)
    audio_data = audio_data.squeeze()

    # NOTE(review): the processor is told the audio is 16 kHz; `sr` is not checked.
    audio_data = processor(audio_data, return_tensors="pt", padding=True, sampling_rate=16000).input_values

    # Extract Embedding (inference only, so no autograd graph is built)
    with torch.no_grad():
        file_embedding = model(audio_data.to(model_device))

    # Saving embedding (as a CPU tensor so it loads on any machine)
    output_filename = filename.split(".")[0] + ".pt"
    output_filepath = os.path.join(output_dir, output_filename)

    torch.save(file_embedding.cpu(), output_filepath)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20580/20580 [15:53<00:00, 21.59it/s]


In [14]:
# Move the feature extractor to the selected device for the extraction below.
model = model.to(device)

# Extract Embeddings from BRSpeech

In [16]:
# Extract one embedding per BRSpeech utterance listed in the metadata CSV and
# save it under `output_dir`, mirroring the utterance's relative directory.
metadata_filepath = '/home/fred/Projetos/DATASETS/MOS/BRSPEECH_MOS_DATASET/metadata.csv'
wavs_filepath = "/home/fred/Projetos/DATASETS/MOS/BRSPEECH_MOS_DATASET/"
with open(metadata_filepath, encoding="utf-8") as f:
    content_file = f.readlines()

output_dir = "./brspeech_mos_wav2vec_embeddings"

# Feed the model on whatever device it currently lives on.
model_device = next(model.parameters()).device

for line in tqdm(content_file):
    # Each line is "<relative wav path>,<score>,<condition>,<database>". Strip
    # the newline and skip blank lines, which would break the unpacking below.
    line = line.strip()
    if not line:
        continue
    filepath, score, condition, database = line.split(",")
    filename = os.path.basename(filepath)
    complete_filepath = os.path.join(wavs_filepath, filepath)
    # Rows whose path does not resolve to an existing file are skipped.
    if not os.path.exists(complete_filepath):
        continue
    audio_data, sr = torchaudio.load(complete_filepath)
    audio_data = audio_data.to(model_device)
    # NOTE(review): unlike the VCC2018 extraction cell, the raw waveform is fed
    # to the model without going through `processor` — confirm this is intended.
    # Extract Embedding (inference only, so no autograd graph is built)
    with torch.no_grad():
        file_embedding = model(audio_data)

    # Saving embedding (as a CPU tensor so it loads on any machine)
    output_filename = filename.split(".")[0] + ".pt"
    output_filepath = os.path.join(output_dir, os.path.dirname(filepath), output_filename)

    os.makedirs(os.path.dirname(output_filepath), exist_ok=True)
    torch.save(file_embedding.cpu(), output_filepath)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 2428/2428 [00:23<00:00, 104.29it/s]
