In [1]:
import torch
import skorch

In [2]:
import numpy as np

In [3]:
from deepspeech.model import DeepSpeech
from deepspeech.data.data_loader import SpectrogramParser
from deepspeech.data.data_loader import BucketingSampler

In [4]:
from data import dataset

In [5]:
ds = dataset.SwearDataset(dataset.DEFAULT_PROVIDERS)

In [6]:
X, y = ds.load()

In [7]:
class SwearAudioDataset:
    def __init__(self, X, y, parser):
        self.X = X
        self.y = y
        assert len(X) == len(y)
        self.parser = parser
        
    def table(self):
        for i, (word, fpath) in enumerate(self.X):
            frames = self.parser.parse_audio(fpath)
            y = np.zeros(frames.shape[1])
            y[-1] = int(self.y[i])
            yield frames, y
            
    def load(self):
        Xy = list(self.table())
        return (
            [n[0] for n in Xy],
            [n[1] for n in Xy],
        )

In [8]:
base_model = DeepSpeech.load_model('models/librispeech_pretrained.pth')

In [9]:
audio_conf = DeepSpeech.get_audio_conf(base_model)
parser = SpectrogramParser(audio_conf)

In [10]:
ds = SwearAudioDataset(X, y, parser)

In [11]:
X, y = ds.load()

In [94]:
seq_lens = np.array([x.shape[1] for x in X])
max_seq_len = max(seq_lens)
max_seq_len, np.mean(seq_lens), np.median(seq_lens)

(663, 233.85714285714286, 173.0)

In [95]:
X_pad = np.zeros((len(X), X[0].shape[0], max_seq_len), dtype='float32')
y_pad = np.zeros((len(y), max_seq_len), dtype='float32') - 999

In [96]:
for i, _ in enumerate(X):
    X_pad[i, :, :seq_lens[i]] = X[i]
    y_pad[i, :seq_lens[i]] = y[i]

In [97]:
X_comb = {'lens': seq_lens, 'X': X_pad}

In [98]:
from deepspeech.model import SequenceWise

In [64]:
class RNNExtractor(torch.nn.Module):
    def forward(self, x):
        assert type(x) == tuple
        return x[0]

In [65]:
def get_model(model):
    model.rnns = torch.nn.GRU(672, 10, bias=False)
    model.fc = torch.nn.Sequential(
        RNNExtractor(),
        SequenceWise(torch.nn.Linear(10, 1)),
    )
    model.inference_softmax = torch.nn.Sigmoid()
    return model

In [66]:
model = get_model(base_model)

In [67]:
model

DeepSpeech(
  (conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(0, 10))
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): Hardtanh(min_val=0, max_val=20, inplace)
    (3): Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): Hardtanh(min_val=0, max_val=20, inplace)
  )
  (rnns): GRU(672, 10, bias=False)
  (fc): Sequential(
    (0): RNNExtractor()
    (1): SequenceWise (
    Linear(in_features=10, out_features=1, bias=True))
  )
  (inference_softmax): Sigmoid()
)

In [132]:
class SpeechNet(skorch.NeuralNet):
    def get_loss(self, y_pred, y_true, **kwargs):
        y_true = y_true[:, :y_pred.shape[1]]
        y_pred = y_pred.squeeze()
        
        # remove padding
        nz = torch.nonzero(y_true != -999)
        y_true = y_true[nz[:, 0], nz[:, 1]]
        y_pred = y_pred[nz[:, 0], nz[:, 1]]
        
        return super().get_loss(y_pred, y_true, **kwargs)

In [133]:
def bucket(Xi, yi):
    X, max_len = Xi['X'], max(Xi['lens'])
    return X[:, :, :max_len], yi[:, :max_len]

In [134]:
def bucketing_dataloader(ds, **kwargs):
    dl = torch.utils.data.DataLoader(ds, **kwargs)
    for Xi, yi in dl:
        Xi, yi = bucket(Xi, yi)
        Xi = Xi[:, None, :, :]
        yield Xi, yi

In [135]:
net = SpeechNet(
    model, 
    criterion=torch.nn.MSELoss,
    iterator_train=bucketing_dataloader,
    iterator_valid=bucketing_dataloader,
    batch_size=8,
    callbacks=[
        skorch.callbacks.Freezer(lambda n: not n.startswith('rnns') and not n.startswith('fc')),
        skorch.callbacks.ProgressBar(),
    ]
)

In [136]:
%pdb on
net.fit(X_comb, y_pad)

Automatic pdb calling has been turned ON


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

  epoch    train_loss    valid_loss      dur
-------  ------------  ------------  -------
      1        [36m0.0111[0m        [32m0.0102[0m  28.8133


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

      2        [36m0.0101[0m        [32m0.0092[0m  28.5385


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

      3        [36m0.0092[0m        [32m0.0086[0m  28.4663


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

      4        [36m0.0085[0m        [32m0.0081[0m  28.3448


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

      5        [36m0.0079[0m        [32m0.0077[0m  28.3073


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

      6        [36m0.0074[0m        [32m0.0073[0m  28.5440


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

      7        [36m0.0070[0m        [32m0.0070[0m  28.3965


HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

<class '__main__.SpeechNet'>[initialized](
  module_=DeepSpeech(
    (conv): Sequential(
      (0): Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(0, 10))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): Hardtanh(min_val=0, max_val=20, inplace)
      (3): Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1))
      (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): Hardtanh(min_val=0, max_val=20, inplace)
    )
    (rnns): GRU(672, 10, bias=False)
    (fc): Sequential(
      (0): RNNExtractor()
      (1): SequenceWise (
      Linear(in_features=10, out_features=1, bias=True))
    )
    (inference_softmax): Sigmoid()
  ),
)

In [137]:
from skorch.helper import SliceDict

In [165]:
targets = y_pad[list(range(len(seq_lens))), seq_lens - 1]

In [190]:
where_true = np.where(targets == 1)[0][:2]
where_false = np.where(targets == 0)[0][:2]

In [197]:
pred_where_true = net.predict_proba(SliceDict(**X_comb)[where_true])
pred_where_false = net.predict_proba(SliceDict(**X_comb)[where_false])

In [198]:
SliceDict(**X_comb)[where_true]

SliceDict(**{'lens': array([112, 150]), 'X': array([[[ 9.039093 , 13.191983 , 12.980769 , ...,  0.       ,
          0.       ,  0.       ],
        [ 8.538684 , 12.893001 , 14.551824 , ...,  0.       ,
          0.       ,  0.       ],
        [ 7.3704205, 12.335937 , 15.226274 , ...,  0.       ,
          0.       ,  0.       ],
        ...,
        [ 4.7000933,  8.631623 ,  8.778911 , ...,  0.       ,
          0.       ,  0.       ],
        [ 4.435214 ,  8.660207 ,  8.742351 , ...,  0.       ,
          0.       ,  0.       ],
        [ 3.977667 ,  8.559978 ,  8.88186  , ...,  0.       ,
          0.       ,  0.       ]],

       [[ 5.1243763, 10.000149 , 11.343134 , ...,  0.       ,
          0.       ,  0.       ],
        [ 6.7059717, 10.578775 , 12.60751  , ...,  0.       ,
          0.       ,  0.       ],
        [ 5.8380523, 10.350276 , 12.807459 , ...,  0.       ,
          0.       ,  0.       ],
        ...,
        [ 3.9761453,  7.5290995,  7.8088303, ...,  0.       ,
 

In [200]:
pred_where_true.shape

(2, 70, 1)

In [199]:
pred_where_true[0, seq_lens[where_true[0]] - 1]

IndexError: index 111 is out of bounds for axis 1 with size 70

> [0;32m<ipython-input-199-0b6892b490ef>[0m(1)[0;36m<module>[0;34m()[0m
[0;32m----> 1 [0;31m[0mpred_where_true[0m[0;34m[[0m[0;36m0[0m[0;34m,[0m [0mseq_lens[0m[0;34m[[0m[0mwhere_true[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m][0m [0;34m-[0m [0;36m1[0m[0;34m][0m[0;34m[0m[0m
[0m
ipdb> q


In [189]:
pred_where_false.sum(1)

array([[4.1872063],
       [1.5879295],
       [8.0309925],
       [7.169367 ],
       [2.3060193]], dtype=float32)