In [1]:
import torch
import torchaudio
import torchaudio.transforms as T
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
import torch.nn as nn
import hf00
import os
import random
from IPython.display import Audio, display
import numpy as np
import librosa
import matplotlib.pyplot as plt

In [2]:
musan_dir = '../audioData/MUSAN/MUSAN/musan/'
speech_dir = os.listdir(musan_dir+'speech')
speech_files = []

for i in speech_dir:
    spech_dir = os.path.join((musan_dir+'speech'), i)
    if os.path.isdir(spech_dir):
        speech_files.extend(os.listdir(spech_dir))

In [6]:
#speech_files.sort()
speech_files.pop(0)

'LICENSE'

In [7]:
for i in range(0, len(speech_files)):
    folder = speech_files[i].split('-')
    #print(folder[1])
    if folder[1] == 'librivox':
        fname = musan_dir+'speech/librivox/'+speech_files[i]
        speech_files[i] = fname
    else:
        fname = musan_dir+'speech/us-gov/'+speech_files[i]
        speech_files[i] = fname

len(speech_files)

426

In [9]:
audio, noise = hf00.get_random_audio_sec(random.choice(speech_files))

(tensor(-0.0003), tensor(-0.0003))

In [16]:
wav, sr = torchaudio.load('../../LibriVox_Kaggle/achtgesichterambiwasse/achtgesichterambiwasse_0007.wav')

encoder = torchaudio.models.wav2vec2_base()

feats, _ = encoder.feature_extractor(audio, 2)
feats1, _ = encoder.extract_features(audio)
feats2, _ = encoder.extract_features(wav)

feats[0].dtype, feats1[0].shape, feats2[0].shape, len(feats2)

(torch.float32, torch.Size([1, 99, 768]), torch.Size([1, 607, 768]), 12)

In [120]:
class audioDataset(Dataset):

    def __init__(self, speech_list):
        #self.audio_df = pd.read_csv(audio_csvfile)
        self.speech_list = speech_list
        #self.audio_dir = audio_dir

    def __len__(self):
        return len(self.speech_list)
    
    def __getitem__(self, index):
    
        audio_path = self.speech_list[index]
        audio, label = hf00.get_random_audio_sec(audio_path)
        
        audio = (audio - audio.min())/(audio.max() - audio.min())
        label = (label - label.min())/(label.max() - label.min())

        audio_feats, _ = encoder.feature_extractor(audio, 2)
        audio_feats = audio_feats.squeeze().reshape(512,99)
        return audio_feats, label

In [121]:
train_dataset = audioDataset(speech_files)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [122]:
torch.manual_seed(13)
torch.cuda.manual_seed(13)

class decoder00(nn.Module):

    def __init__(self):
        super(decoder00, self).__init__()

        self.decode = nn.Sequential(
            nn.ConvTranspose1d(512, 256, kernel_size=10, stride=5, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(256, 128, kernel_size=7, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(128, 64, kernel_size=5, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(64, 32, kernel_size=5, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(32, 16, kernel_size=5, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(16, 8, kernel_size=5, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose1d(8, 1, kernel_size=6, stride=2, padding=1),
            nn.ReLU(),
            
        )

    def forward(self, x):
        x = self.decode(x)

        return x

In [123]:
torch.manual_seed(13)
torch.cuda.manual_seed(13)

model = decoder00()
#model.load_state_dict(torch.load('models/model00_conf03.pt'))
device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

model = model.to(device)

loss_fn = nn.L1Loss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [124]:
import statistics
ll_plot = []
model = model.to(device)
for i in range(0,1):
    epochs = 10

    for epoch in range(0,epochs):
        loss_list = []
        for data in train_dataloader:
        
            model.train()
            inputs, labels = data
            inputs = inputs.to(device, dtype=torch.float)
            labels = labels.to(device, dtype=torch.float)
            #print(inputs.dtype, labels.dtype)
            # Forward pass
            outputs = model(inputs)
        
            # Compute loss
            loss = loss_fn(outputs, labels)
            #loss = loss_fn(outputs, inputs)

            # BP and optim
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())
        
        ll_plot.append(statistics.mean(loss_list))

        print(f"Round: {i} Epoch [{epoch + 1}/{epochs}] Loss: {statistics.mean(loss_list)}")

Round: 0 Epoch [1/10] Loss: 0.5032425920168558
Round: 0 Epoch [2/10] Loss: 0.5023684159473136
Round: 0 Epoch [3/10] Loss: 0.5042374321707973
Round: 0 Epoch [4/10] Loss: 0.5034969702914909
Round: 0 Epoch [5/10] Loss: 0.5004061802669808
Round: 0 Epoch [6/10] Loss: 0.5003851663183283
Round: 0 Epoch [7/10] Loss: 0.5037527581055959
Round: 0 Epoch [8/10] Loss: 0.5019526216718886
Round: 0 Epoch [9/10] Loss: 0.5030947669788643
Round: 0 Epoch [10/10] Loss: 0.5021384457747141


In [129]:
sample = random.choice(speech_files)

audio, noise = hf00.get_random_audio_sec(sample)

audio = (audio - audio.min()) / (audio.max() - audio.min())
audio_feats, _ = encoder.feature_extractor(audio, 2)
audio_feats = audio_feats.squeeze().reshape(512,99)

model = model.cpu()
model.eval()
with torch.inference_mode():
    y_preds = model(audio_feats)

display(Audio(noise, rate=16000))
display(Audio(y_preds, rate=16000))

In [130]:
y_preds, noise

(tensor([[0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[-0.0253, -0.0289, -0.0259,  ..., -0.0128, -0.0114, -0.0110]]))

In [111]:
noise = (noise - noise.min()) / (noise.max() - noise.min())
display(Audio(noise, rate=16000))


In [4]:
bundle = torchaudio.pipelines.WAV2VEC2_BASE

model = bundle.get_model()

In [5]:
aud = '../audioData/MUSAN/MUSAN/musan/speech/librivox/speech-librivox-0004.wav'

wav, sr = torchaudio.load(aud)

feats, _ = model.extract_features(wav[:,0:2*sr])

In [9]:
feats[3].shape, type(feats[0]), wav.shape[1]/sr, len(feats)

(torch.Size([1, 99, 768]), torch.Tensor, 214.047375, 12)

In [24]:
vv, _ = model2.feature_extractor(wav[:,0:2*sr], 2)

In [35]:
vv.shape, _, wav.shape

(torch.Size([1, 99, 512]), tensor(0), torch.Size([1, 3424758]))

In [80]:
reshape = vv.squeeze().reshape(512,99)

dec = nn.Sequential(
    nn.ConvTranspose1d(512, 256, kernel_size=10, stride=5, padding=1),
    nn.ConvTranspose1d(256, 128, kernel_size=7, stride=2, padding=1),
    nn.ConvTranspose1d(128, 64, kernel_size=5, stride=2, padding=1),
    nn.ConvTranspose1d(64, 32, kernel_size=5, stride=2, padding=1),
    nn.ConvTranspose1d(32, 16, kernel_size=5, stride=2, padding=1),
    nn.ConvTranspose1d(16, 8, kernel_size=5, stride=2, padding=1),
    nn.ConvTranspose1d(8, 1, kernel_size=6, stride=2, padding=1)
)

dec(reshape).shape, vv.shape, wav[:,0:2*sr].shape

(torch.Size([1, 32000]), torch.Size([1, 99, 512]), torch.Size([1, 32000]))