In [1]:
import torch
import skorch

In [133]:
torch.cuda.is_available()

False

In [2]:
import numpy as np

In [None]:
from sklearn.model_selection import StratifiedKFold

In [3]:
from deepspeech.model import DeepSpeech
from deepspeech.model import SequenceWise
from deepspeech.data.data_loader import SpectrogramParser
from deepspeech.data.data_loader import BucketingSampler

In [4]:
from data import dataset
from utils import RNNValueExtractor
from utils import Identity
from utils import bucketing_dataloader

In [5]:
base_model = DeepSpeech.load_model(
    'models/librispeech_pretrained.pth'
)
audio_conf = DeepSpeech.get_audio_conf(base_model)
parser = SpectrogramParser(audio_conf)

In [6]:
ds = dataset.SwearDataset(dataset.DEFAULT_PROVIDERS)

In [7]:
X, y = ds.load()

In [8]:
ds = dataset.SwearBinaryAudioDataset(X, y, parser)

In [9]:
X, y = ds.load()

In [10]:
seq_lens = np.array([x.shape[1] for x in X])
max_seq_len = max(seq_lens)
max_seq_len, np.mean(seq_lens), np.median(seq_lens)

(663, 233.85714285714286, 173.0)

In [11]:
X_pad = np.zeros(
    (len(X), X[0].shape[0], max_seq_len), 
    dtype='float32'
)
for i, _ in enumerate(X):
    X_pad[i, :, :seq_lens[i]] = X[i]

In [120]:
split = StratifiedKFold(n_splits=5)

In [121]:
train_idcs, test_idcs = next(split.split(y, y=y))

In [122]:
X_train = {'lens': seq_lens[train_idcs], 'X': X_pad[train_idcs]}
y_train = np.array(y)[train_idcs]

X_test = {'lens': seq_lens[test_idcs], 'X': X_pad[test_idcs]}
y_test = np.array(y)[test_idcs]

In [125]:
len(X_train['X']), len(X_test['X'])

(173, 44)

In [126]:
class NoSwearModel(torch.nn.Module):
    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model
        self.base_model.rnns = Identity()
        self.base_model.lookahead = Identity()
        self.base_model.fc = Identity()
        self.base_model.inference_softmax = Identity()
        
        self.rnn = torch.nn.GRU(672, 10,
                                bias=False, batch_first=True)
        self.clf = torch.nn.Linear(10, 2)
        
    def forward(self, X, lens):
        # run base model, output is NxTxH with
        # T=Time, N=samples, H=hidden.
        with torch.no_grad():
            y_pre = self.base_model(X)
        
        # run RNN over sequence and extract last item
        y, _ = self.rnn(y_pre)
        # presumably we cannot use lens since cnns reduce that too
        #y = y[:, lens - 1]
        y = y[:, -1]
        
        # run classifier
        y = self.clf(y)
        y = torch.softmax(y, dim=-1)
        return y

In [127]:
def bucket(Xi, yi):
    Xi['X'] = Xi['X'][:, :, :max(Xi['lens'])]
    return Xi, yi

In [128]:
net = skorch.NeuralNetClassifier(
    NoSwearModel(base_model), 
    
    iterator_train=bucketing_dataloader,
    iterator_train__bucket_fn=bucket,
    iterator_valid=bucketing_dataloader,
    iterator_valid__bucket_fn=bucket,
    
    batch_size=2,
    
    callbacks=[
        skorch.callbacks.Freezer('base_model.*'),
        skorch.callbacks.ProgressBar(),
    ]
)

In [129]:
%pdb on
net.fit(X_comb, y)

Automatic pdb calling has been turned ON


HBox(children=(IntProgress(value=0, max=109), HTML(value='')))

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=NoSwearModel(
    (base_model): DeepSpeech(
      (conv): Sequential(
        (0): Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(0, 10))
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Hardtanh(min_val=0, max_val=20, inplace)
        (3): Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1))
        (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): Hardtanh(min_val=0, max_val=20, inplace)
      )
      (rnns): Identity()
      (fc): Identity()
      (inference_softmax): Identity()
      (lookahead): Identity()
    )
    (rnn): GRU(672, 10, bias=False, batch_first=True)
    (clf): Linear(in_features=10, out_features=2, bias=True)
  ),
)

In [131]:
from sklearn.metrics import accuracy_score

In [132]:
accuracy_score(y_test, net.predict(X_test))

0.5227272727272727