In [8]:
import librosa
import librosa.display
import torch
import sys
import glob
import time
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import optim
from pytorch_model_summary import summary
import shutil

sys.path.insert(1, '../utils')
sys.path.insert(1, '../models')

from data import StutterData, load_data
from audioCNN import AudioCNN


In [9]:
model = AudioCNN()

In [10]:
print(summary(AudioCNN(), torch.zeros((1, 1, 20, 50)), show_input=True))

---------------------------------------------------------------------------
          Layer (type)         Input Shape         Param #     Tr. Param #
   AdaptiveAvgPool2d-1      [1, 1, 20, 50]               0               0
              Conv2d-2      [1, 1, 32, 32]              80              80
   AdaptiveAvgPool2d-3      [1, 8, 30, 30]               0               0
             Dropout-4      [1, 8, 16, 16]               0               0
              Conv2d-5      [1, 8, 16, 16]           1,168           1,168
   AdaptiveAvgPool2d-6     [1, 16, 14, 14]               0               0
              Linear-7           [1, 1024]         524,800         524,800
              Linear-8            [1, 512]         131,328         131,328
              Linear-9            [1, 256]             257             257
Total params: 657,633
Trainable params: 657,633
Non-trainable params: 0
---------------------------------------------------------------------------


In [11]:
model.forward(torch.zeros((1, 1, 20, 50)))  # Test forward pass

tensor([[-0.0079]], grad_fn=<AddmmBackward>)

In [123]:
lr = 0.001
batch_size = 1
epochs = 25
validation_split=0.2
shuffle_dataset=True
random_seed=42
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")     #Check whether a GPU is present.

# optimizer = optim.SGD(model.parameters(), lr = 0.00001, momentum=0.9, weight_decay=5e-4)
# scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 200], gamma=0.1)

criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr = lr)

In [124]:
dataset = StutterData('../data/*')
train_loader, validation_loader = load_data(dataset, batch_size, validation_split=0.2, shuffle_dataset=True, random_seed=42)

In [125]:
model.to(device)

AudioCNN(
  (maxpool): AdaptiveAvgPool2d(output_size=32)
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1))
  (maxpool1): AdaptiveAvgPool2d(output_size=16)
  (dropout): Dropout(p=0.2, inplace=False)
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
  (maxpool2): AdaptiveAvgPool2d(output_size=8)
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=1, bias=True)
)

In [24]:
def save_ckp(state, is_best, checkpoint_path, best_model_path):
    """
    state: checkpoint we want to save
    is_best: is this the best checkpoint; min validation loss
    checkpoint_path: path to save checkpoint
    best_model_path: path to save best model
    """
    f_path = checkpoint_path
    # save checkpoint data to the path given, checkpoint_path
    torch.save(state, f_path)
    # if it is a best model, min validation loss
    if is_best:
        best_fpath = best_model_path
        # copy that checkpoint file to best path given, best_model_path
        shutil.copyfile(f_path, best_fpath)

In [25]:
valid_loss_min = float('inf') #init val_loss
checkpoint_path = '../models/audioCNN_ckpt.pth'
best_model_path = '../models/audioCNN_best_ckpt.pth'
for epoch in range(epochs):
    losses=[]
#     scheduler.step()
    
    start = time.time()
    
    for b_idx, x in enumerate(train_loader):
#         print(b_idx)
        print(x['mfcc'].shape)
        inputs, targets = x['mfcc'].to(device), x['label'].to(device)
        
        optimizer.zero_grad()
        
        op = model(inputs).view(-1)
#         print(op[0], targets[0])
#         print(type(op.view(-1)[0]), type(targets[0]))
#         print(b_idx, op)
#         print(targets)
        loss = criterion(op, targets)
        loss.backward()
        
        optimizer.step()
        losses.append(loss.item())
        end = time.time()
        if b_idx % 100 == 0:
            print('Batch Index : %d Loss : %.10f Time : %.3f seconds ' % (b_idx, np.mean(losses), end - start))    
    model.eval()
    total = 0
    correct = 0
    acc = 0
    
    with torch.no_grad():
        for b_idx, x in enumerate(validation_loader):
            inputs, targets = x['mfcc'].to(device), x['label'].to(device)

            outputs = torch.sigmoid(model(inputs))
#             print(outputs, targets)
            
            
            predicted = torch.round(outputs.data)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()
            valid_loss = criterion(predicted.view(-1), targets.data)
            acc = 100.*correct/total
        print('Epoch : %d Val_Acc : %.3f Val_loss: %.3f' % (epoch, acc, valid_loss))
        print('--------------------------------------------------------------')
    checkpoint = {
            'epoch': epoch + 1,
            'valid_loss_min': valid_loss,
            'valid_acc': acc,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }

    save_ckp(checkpoint, False, checkpoint_path, best_model_path)
    
    if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(valid_loss_min,valid_loss))
            # save checkpoint as best model
            save_ckp(checkpoint, True, checkpoint_path, best_model_path)
            valid_loss_min = valid_loss
    model.train()   

torch.Size([1, 20, 6])
Batch Index : 0 Loss : 1.1542206705 Time : 2.679 seconds 
torch.Size([1, 20, 33])
torch.Size([1, 20, 87])
torch.Size([1, 20, 16])
torch.Size([1, 20, 25])
torch.Size([1, 20, 22])
torch.Size([1, 20, 16])
torch.Size([1, 20, 47])
torch.Size([1, 20, 28])
torch.Size([1, 20, 12])
torch.Size([1, 20, 17])
torch.Size([1, 20, 17])
torch.Size([1, 20, 124])
torch.Size([1, 20, 22])
torch.Size([1, 20, 75])
torch.Size([1, 20, 23])
torch.Size([1, 20, 264])
torch.Size([1, 20, 8])
torch.Size([1, 20, 8])
torch.Size([1, 20, 5])
torch.Size([1, 20, 13])
torch.Size([1, 20, 81])
torch.Size([1, 20, 17])
torch.Size([1, 20, 34])
torch.Size([1, 20, 17])
torch.Size([1, 20, 95])
torch.Size([1, 20, 6])
torch.Size([1, 20, 20])
torch.Size([1, 20, 10])
torch.Size([1, 20, 13])
torch.Size([1, 20, 15])
torch.Size([1, 20, 7])
torch.Size([1, 20, 38])
torch.Size([1, 20, 8])
torch.Size([1, 20, 30])
torch.Size([1, 20, 18])
torch.Size([1, 20, 290])
torch.Size([1, 20, 91])
torch.Size([1, 20, 98])
torch.Size




torch.Size([1, 20, 92])
torch.Size([1, 20, 208])
torch.Size([1, 20, 7])
torch.Size([1, 20, 14])
torch.Size([1, 20, 9])
torch.Size([1, 20, 53])
torch.Size([1, 20, 56])
torch.Size([1, 20, 207])
torch.Size([1, 20, 110])
torch.Size([1, 20, 98])
torch.Size([1, 20, 70])
torch.Size([1, 20, 42])
torch.Size([1, 20, 140])
torch.Size([1, 20, 62])
torch.Size([1, 20, 14])
torch.Size([1, 20, 30])
torch.Size([1, 20, 18])
torch.Size([1, 20, 9])
torch.Size([1, 20, 9])
torch.Size([1, 20, 30])
torch.Size([1, 20, 30])
torch.Size([1, 20, 16])
torch.Size([1, 20, 11])
torch.Size([1, 20, 27])
Batch Index : 100 Loss : 1.1972619561 Time : 5.019 seconds 
torch.Size([1, 20, 48])
torch.Size([1, 20, 7])
torch.Size([1, 20, 14])
torch.Size([1, 20, 57])
torch.Size([1, 20, 51])
torch.Size([1, 20, 164])
torch.Size([1, 20, 73])
torch.Size([1, 20, 14])
torch.Size([1, 20, 23])
torch.Size([1, 20, 70])
torch.Size([1, 20, 8])
torch.Size([1, 20, 32])
torch.Size([1, 20, 52])
torch.Size([1, 20, 114])
torch.Size([1, 20, 11])
tor



torch.Size([1, 20, 4])
torch.Size([1, 20, 69])
torch.Size([1, 20, 15])
torch.Size([1, 20, 19])
torch.Size([1, 20, 15])
torch.Size([1, 20, 35])
torch.Size([1, 20, 24])
torch.Size([1, 20, 12])
torch.Size([1, 20, 13])
torch.Size([1, 20, 58])
torch.Size([1, 20, 89])
torch.Size([1, 20, 17])
torch.Size([1, 20, 8])
torch.Size([1, 20, 34])
torch.Size([1, 20, 22])
torch.Size([1, 20, 54])
torch.Size([1, 20, 11])
torch.Size([1, 20, 15])
torch.Size([1, 20, 11])
torch.Size([1, 20, 10])
torch.Size([1, 20, 7])
torch.Size([1, 20, 59])
torch.Size([1, 20, 108])
torch.Size([1, 20, 6])
torch.Size([1, 20, 161])
torch.Size([1, 20, 7])
torch.Size([1, 20, 26])
torch.Size([1, 20, 7])




torch.Size([1, 20, 4])
torch.Size([1, 20, 28])
torch.Size([1, 20, 101])
torch.Size([1, 20, 38])
torch.Size([1, 20, 7])
torch.Size([1, 20, 87])
torch.Size([1, 20, 22])
torch.Size([1, 20, 11])
torch.Size([1, 20, 22])
torch.Size([1, 20, 24])
torch.Size([1, 20, 7])
torch.Size([1, 20, 7])
torch.Size([1, 20, 9])
torch.Size([1, 20, 13])
torch.Size([1, 20, 36])
torch.Size([1, 20, 66])
torch.Size([1, 20, 156])
torch.Size([1, 20, 35])
torch.Size([1, 20, 25])
torch.Size([1, 20, 30])
torch.Size([1, 20, 50])
torch.Size([1, 20, 4])
torch.Size([1, 20, 11])
torch.Size([1, 20, 13])
torch.Size([1, 20, 44])
torch.Size([1, 20, 21])
torch.Size([1, 20, 26])
torch.Size([1, 20, 119])
torch.Size([1, 20, 7])
torch.Size([1, 20, 18])
torch.Size([1, 20, 10])
torch.Size([1, 20, 46])
torch.Size([1, 20, 8])
torch.Size([1, 20, 81])
torch.Size([1, 20, 69])
torch.Size([1, 20, 6])
torch.Size([1, 20, 28])
torch.Size([1, 20, 1])
torch.Size([1, 20, 29])
torch.Size([1, 20, 66])
torch.Size([1, 20, 61])
torch.Size([1, 20, 54])




Batch Index : 200 Loss : 0.9269539786 Time : 7.416 seconds 
torch.Size([1, 20, 3])
torch.Size([1, 20, 25])
torch.Size([1, 20, 121])
torch.Size([1, 20, 17])
torch.Size([1, 20, 9])
torch.Size([1, 20, 8])




torch.Size([1, 20, 24])
torch.Size([1, 20, 11])
torch.Size([1, 20, 14])
torch.Size([1, 20, 69])
torch.Size([1, 20, 28])
torch.Size([1, 20, 21])
torch.Size([1, 20, 157])
torch.Size([1, 20, 18])
torch.Size([1, 20, 7])
torch.Size([1, 20, 40])
torch.Size([1, 20, 26])
torch.Size([1, 20, 13])
torch.Size([1, 20, 12])
torch.Size([1, 20, 98])
torch.Size([1, 20, 42])
torch.Size([1, 20, 9])
torch.Size([1, 20, 20])
torch.Size([1, 20, 13])
torch.Size([1, 20, 360])
torch.Size([1, 20, 16])
torch.Size([1, 20, 38])
torch.Size([1, 20, 38])
torch.Size([1, 20, 23])
torch.Size([1, 20, 30])
torch.Size([1, 20, 47])
torch.Size([1, 20, 55])
torch.Size([1, 20, 16])
torch.Size([1, 20, 6])
torch.Size([1, 20, 14])
torch.Size([1, 20, 49])
torch.Size([1, 20, 36])
torch.Size([1, 20, 28])
torch.Size([1, 20, 19])
torch.Size([1, 20, 42])
torch.Size([1, 20, 6])
torch.Size([1, 20, 162])
torch.Size([1, 20, 166])
torch.Size([1, 20, 42])
torch.Size([1, 20, 85])
torch.Size([1, 20, 14])
torch.Size([1, 20, 8])
torch.Size([1, 20



torch.Size([1, 20, 58])
torch.Size([1, 20, 77])
torch.Size([1, 20, 14])
torch.Size([1, 20, 27])
torch.Size([1, 20, 96])
torch.Size([1, 20, 29])
torch.Size([1, 20, 59])
torch.Size([1, 20, 11])
torch.Size([1, 20, 46])
torch.Size([1, 20, 10])
torch.Size([1, 20, 73])
torch.Size([1, 20, 84])
torch.Size([1, 20, 65])
torch.Size([1, 20, 31])
torch.Size([1, 20, 16])
torch.Size([1, 20, 11])
torch.Size([1, 20, 10])
torch.Size([1, 20, 28])
torch.Size([1, 20, 18])
torch.Size([1, 20, 12])
torch.Size([1, 20, 15])
torch.Size([1, 20, 17])
torch.Size([1, 20, 105])
torch.Size([1, 20, 26])
torch.Size([1, 20, 30])
torch.Size([1, 20, 18])
torch.Size([1, 20, 54])
torch.Size([1, 20, 8])
torch.Size([1, 20, 102])
torch.Size([1, 20, 35])
torch.Size([1, 20, 17])
torch.Size([1, 20, 81])
torch.Size([1, 20, 71])
torch.Size([1, 20, 79])
torch.Size([1, 20, 6])
torch.Size([1, 20, 19])
torch.Size([1, 20, 46])
torch.Size([1, 20, 14])
torch.Size([1, 20, 9])
torch.Size([1, 20, 17])
torch.Size([1, 20, 53])
torch.Size([1, 20

KeyboardInterrupt: 

In [285]:
model.to('cuda')

AudioCNN(
  (maxpool): AdaptiveAvgPool2d(output_size=32)
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1))
  (maxpool1): AdaptiveAvgPool2d(output_size=16)
  (dropout): Dropout(p=0.2, inplace=False)
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
  (maxpool2): AdaptiveAvgPool2d(output_size=8)
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=1, bias=True)
)

In [283]:
validation_loader, _ = load_data(dataset, batch_size, validation_split=0.0, shuffle_dataset=True, random_seed=42)

In [287]:
model.eval()
total = 0
correct = 0
acc = 0
for b_idx, x in enumerate(validation_loader):
            inputs, targets = x['mfcc'].to(device), x['label'].to(device)

            outputs = torch.sigmoid(model(inputs))
#             print(outputs, targets)
            
            
            predicted = torch.round(outputs.data)
#             print(outputs.data[0], predicted.data, targets)
            total += targets.size(0)
            correct += predicted.eq(targets.data).cpu().sum()
            valid_loss = criterion(predicted.view(-1), targets.data)
            acc = 100.*correct/total
print('Epoch : %d Val_Acc : %.3f Val_loss: %.3f' % (epoch, acc, valid_loss))
print('--------------------------------------------------------------')



Epoch : 0 Val_Acc : 90.629 Val_loss: 0.313
--------------------------------------------------------------


In [118]:
checkpoint_fpath = '../models/audioCNN_best_ckpt.pth'

In [42]:
for i in train_loader:
    print(i['mfcc'].shape)
    break

torch.Size([1, 20, 71])


In [14]:
def load_ckp(checkpoint_fpath, model, optimizer):
    """
    checkpoint_path: path to save checkpoint
    model: model that we want to load checkpoint parameters into       
    optimizer: optimizer we defined in previous training
    """
    # load check point
    checkpoint = torch.load(checkpoint_fpath)
    # initialize state_dict from checkpoint to model
    model.load_state_dict(checkpoint['state_dict'])
    # initialize optimizer from checkpoint to optimizer
    optimizer.load_state_dict(checkpoint['optimizer'])
    # initialize valid_loss_min from checkpoint to valid_loss_min
    valid_loss_min = checkpoint['valid_loss_min']
    # return model, optimizer, epoch value, min validation loss 
    return model, optimizer, checkpoint['epoch'], valid_loss_min.item()

In [119]:
model,optimizer, chkpt, v_loass = load_ckp(checkpoint_fpath, model, optimizer)

TypeError: forward() missing 1 required positional argument: 'x'

In [376]:
def correct_audio_in_chunks(filename, seg_length=1, zero_fill = False):
    '''
    Correct stuttered speech in chunks of duration seg_length

    Correct each chunk of audio and merge the corrected chunks together. Effectiveness depends on the audio in question and the size of the chunks.

    Parameters:
    filename (string): The path of the stuttered audio file
    seg_length (number): Length of the chunks in seconds
    zero_fill (boolean): (Optional) Whether the stuttered bits are removed or replaced with zeroes

    Returns:
    y (numpy ndarray): The sampled amplitude of the corrected audio
    sr (number): The sampling rate of the corrected audio
    '''
    y, sr = librosa.load(filename)
    
    # one second long chunks (num_secs * sr) = chunk_length
    num_chunks = len(y) // sr
    extra = len(y) % sr
    corrected_audio = []

    for i in range(num_chunks):
      audio_segment = y[int(sr * i * seg_length) : int(sr * (i + 1) * seg_length)]
      corr_segment, _ = correct_audio_segment(audio_segment, sr, zero_fill=zero_fill)
      corrected_audio.extend(corr_segment)
    corrected_audio = np.array(corrected_audio)
    print('correction complete')
    return corrected_audio, sr

In [377]:
def correct_audio_segment(y, sr, zero_fill = False):
    '''
    Correct the audio specified audio chunk

    Corrects the audio clip by predicting the threshold corresponding to the maximum amplitude of the chunk, and removing the stuttered clips if zero_fill is set to False, or replacing said clips with zeroes if zero_fill is set to True.

    Parameters:
    y (list): The sampled amplitude of the soundwave
    sr (number): The sampling rate in hertz
    zero_fill (boolean): (Optional) Whether the stuttered bits are removed or replaced with zeroes
    
    Returns:
    y (list): The sampled amplitude of the corrected audio segment
    sr (number): The sampling rate of the corrected audio segment
    '''
    if len(y) == 0:
      return y, sr
    
    maxv = max(y)
    mfcc = librosa.feature.mfcc(y,sr)
#     print(type(mfcc))
#     print(.shape)
    mfcc = torch.from_numpy(mfcc).view((-1,mfcc.shape[0],mfcc.shape[1]))
    st = torch.sigmoid(model(mfcc))
    print(st, torch.round(st.data)==0, torch.round(st))
    
# outputs = torch.sigmoid(model(inputs))
# #             print(outputs, targets)


#         predicted = torch.round(outputs.data)
    if torch.round(st) == 1 :
        print('hereee',y.shape, type(y))
        return y, sr
    else:
        return np.array([]), sr
#     pred_thresh = model.predict()
#     frame_duration = 0.3
#     frame_len = int(frame_duration * sr)
#     n = len(y)
#     num_frames = int(n // frame_len)

#     corrected_audio_signal = []

#     count = -1

#     for i in range(num_frames):
#       frame = y[(i) * frame_len : frame_len * (i+1)]
#       frame_max = max(frame)
#       if (frame_max > pred_thresh):
#         count += 1
#         corrected_audio_signal[(count)*frame_len:frame_len*(count+1)] = frame
#       else:
#         if (zero_fill):
#           count += 1
#           # Zero fill
#           corrected_audio_signal[(count)*frame_len:frame_len*(count+1)] = np.repeat([0], frame_len)
#         else:
#           print('skipped frame with max {}'.format(frame_max))
#     corrected_audio_signal = np.array(corrected_audio_signal)
    return corrected_audio_signal, sr

In [378]:
i = StutterData('./F_0050_10y9m_1.wav', single_file=True)

In [379]:
model.to('cpu')

AudioCNN(
  (maxpool): AdaptiveAvgPool2d(output_size=32)
  (conv1): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1))
  (maxpool1): AdaptiveAvgPool2d(output_size=16)
  (dropout): Dropout(p=0.2, inplace=False)
  (conv2): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1))
  (maxpool2): AdaptiveAvgPool2d(output_size=8)
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=1, bias=True)
)

In [380]:
aud = librosa.core.load('./F_0050_10y9m_1.wav')

In [381]:
aud[0].shape, aud[1]

((2950848,), 22050)

In [382]:
x[0].shape[0]/aud[0].shape[0]

0.7920773960569979

In [387]:
x[0].shape

(2337300,)

In [384]:
x = correct_audio_in_chunks(filename='./F_0050_10y9m_1.wav')

tensor([[0.0111]], grad_fn=<SigmoidBackward>) tensor([[True]]) tensor([[0.]], grad_fn=<RoundBackward>)
tensor([[0.4276]], grad_fn=<SigmoidBackward>) tensor([[True]]) tensor([[0.]], grad_fn=<RoundBackward>)
tensor([[0.9696]], grad_fn=<SigmoidBackward>) tensor([[False]]) tensor([[1.]], grad_fn=<RoundBackward>)
hereee (22050,) <class 'numpy.ndarray'>
tensor([[0.0494]], grad_fn=<SigmoidBackward>) tensor([[True]]) tensor([[0.]], grad_fn=<RoundBackward>)
tensor([[0.9922]], grad_fn=<SigmoidBackward>) tensor([[False]]) tensor([[1.]], grad_fn=<RoundBackward>)
hereee (22050,) <class 'numpy.ndarray'>
tensor([[0.9553]], grad_fn=<SigmoidBackward>) tensor([[False]]) tensor([[1.]], grad_fn=<RoundBackward>)
hereee (22050,) <class 'numpy.ndarray'>
tensor([[0.0639]], grad_fn=<SigmoidBackward>) tensor([[True]]) tensor([[0.]], grad_fn=<RoundBackward>)
tensor([[0.9972]], grad_fn=<SigmoidBackward>) tensor([[False]]) tensor([[1.]], grad_fn=<RoundBackward>)
hereee (22050,) <class 'numpy.ndarray'>
tensor([[0.7

In [385]:
import soundfile

In [386]:
soundfile.write('./cor.wav', x[0], x[1])