In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
from fastai.audio import *

In [3]:
sg_cfg= SpectrogramConfig(hop=480, n_mels=128, n_fft=960, top_db=80, f_min=20.0, f_max=22050)
config_split= AudioConfig(resample_to = 8000, remove_silence = "all", silence_padding=200, silence_threshold=20, sg_cfg=sg_cfg, duration=10000)

In [4]:
train_df = pd.read_csv(Path("./data/train.tsv"),sep="\t")

In [5]:
class DeepSpeechAudioList(AudioList):
    _bunch = DataBunch
    def __init__(self, items, path, config=AudioConfig(), context=5, **kwargs):
        super().__init__(items=items, path=path, config=config, **kwargs)
        self.context = context
        
    def get(self,i):
        one_spectro = super().get(i).spectro
        one_spectro = F.pad(one_spectro, pad=(self.context,self.context))
        return torch.stack([one_spectro[:,:,i:i+(self.context*2+1)].squeeze() for i in range(one_spectro.shape[-1]-(self.context*2))])

In [6]:
class SentenceCharList(ItemList):
    itoc = ["a","b","c","d","e","f","g","h","i","j","k","l","m","n","o","p","q","r","s","t","u","v","w","x","y","z"," ","'",""]
    ctoi = defaultdict(lambda: 28)
    def __init__(self, items, path, **kwargs):
        super().__init__(items, **kwargs)
        for i,char in enumerate(self.itoc):
            self.ctoi[char]=i
        
    def get(self,i):
#        return torch.tensor([self.ctoi[j] for j in self.items[i].lower()])
        return F.pad(torch.tensor([self.ctoi[j] for j in self.items[i].lower()]), pad=(0,200), mode='constant', value=28)[:100]
#        return torch.tensor([self.ctoi[j] for j in self.items[i].lower()])

In [7]:
#My solution to having to look through training files to make sure they are in training and not validation. 
train_file = defaultdict(bool)
for i in train_df.path:
    train_file[i] = True

In [8]:
data = (DeepSpeechAudioList.from_folder("data/clips", config=config_split)#, processor=[rs])
        .use_partial_data(0.05, seed=42)
        .filter_by_func(lambda x: train_file[x.name])
        .split_by_rand_pct(0.2, seed=42)
        .label_from_func(lambda x: train_df[train_df.path==str(x).split("/")[-1]]["sentence"].iloc[0], label_cls=SentenceCharList)
        .databunch(bs=32)
       )

Preprocessing: Resampling to 8000


Preprocessing: Removing Silence


Preprocessing: Resampling to 8000


Preprocessing: Removing Silence


In [9]:
class DeepSpeech(nn.Module):
    def __init__(self, context=5, bs=64):
        super(DeepSpeech, self).__init__()
        self.bs = bs
        self.context = context
        self.h = None#(torch.zeros((2,5,2048)).cuda(),torch.zeros((2,5,2048)).cuda())#None
        self.flatten = nn.Flatten()#lambda x: torch.reshape(x,(-1,1,2432))
        self.h1 = nn.Linear(128*2*self.context+128,2048)
        self.h2 = nn.Linear(2048,2048)
        self.h3 = nn.Linear(2048,2048)
        self.h4 = nn.LSTM(2048,2048,bidirectional=False,batch_first=True)
        #self.h4 = nn.RNN(2048,2048, nonlinearity="relu",bidirectional=True)
        self.h5 = nn.Linear(2048,29) #ct ∈ {a,b,c, . . . , z, space, apostrophe, blank}
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = nn.Flatten(-2,-1)(x)#torch.stack(x,dim=1)
        x = self.h1(x).clamp(min=0, max=20)
        x = self.h2(x).clamp(min=0, max=20)
        x = self.h3(x).clamp(min=0, max=20)
        if self.h is None:
            ;
        elif self.h[0].shape[1]>x.size(0):
            self.h=tuple([each[:,:x.size(0),:] for each in model.h])
        elif self.h[0].shape[1]<x.size(0):
            self.h=None#tuple([each.expand(-1,x.size(0),-1) for each in model.h])
        x,h = self.h4(x, self.h)
        self.h = to_detach(h, cpu=False)
        x = x.view(-1,166,1,2048)
        x = x.sum(dim=2)
        x = self.h5(x).clamp(min=0, max=20)
        x = self.softmax(x)
        return x

In [10]:
def ctc_loss(input, target, bs=64):
    r"""Loss function that makes CTC Loss easier to use especially for putting into a fastai Learner"""
    i_length = torch.tensor([sum([1 for i in t.sum(1) if i != 0]) for t in input])
    input = input.permute(1,0,2).detach().requires_grad_()
    ctc = nn.CTCLoss(blank=28, zero_infinity=True, reduction="sum")
    t_length = torch.tensor([sum([1 for i in t if i != 0])for t in target])
    ret = ctc(input, target, input_lengths=i_length, target_lengths=t_length)#torch.tensor([100]*input.size(1)))
    return ret

In [22]:
model = DeepSpeech(context=5, bs=32)

In [23]:
learn = Learner(data, model, loss_func=partial(ctc_loss,bs=32), opt_func=AdamW)

In [24]:
learn.fit_one_cycle(10,1e-3)

epoch,train_loss,valid_loss,time
0,22815.015625,22393.337891,00:23
1,22768.455078,22393.097656,00:24
2,22731.625,22394.027344,00:24
3,22616.173828,22393.060547,00:24
4,22804.589844,22393.794922,00:24
5,22659.431641,22393.273438,00:24
6,22736.007812,22394.404297,00:24
7,22791.597656,22392.853516,00:24
8,22724.533203,22393.158203,00:24
9,22799.511719,22393.740234,00:24


In [19]:
preds,targs = learn.get_preds()

In [15]:
preds.shape

torch.Size([652, 166, 29])

In [16]:
targs.shape

torch.Size([652, 100])

In [20]:
preds.argmax(dim=2)[phrase]

tensor([ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2, 24, 28, 28, 28, 28, 18, 18, 18, 18, 18, 18, 18, 18,
        18, 18, 15, 18, 18, 18, 18, 18, 18, 18, 18, 18, 18,  6,  6, 18, 18, 18,
        18, 18, 18, 18, 18, 18, 18, 18, 18, 18, 11, 18, 18, 18, 18, 18, 14, 27,
        27, 27, 27, 27, 27,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,
         2,  2,  2,  2])

In [21]:
phrase = 0
for i in preds.argmax(dim=2)[phrase]: print(learn.data.itoc[i], end='')
print("")
for i in targs[phrase]: print(learn.data.itoc[i], end='')

cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccysssssssssspssssssssssggssssssssssssslssssso''''''ccccccccccccccccccccccccccccccccccccccccccccccccccccc
stop fooling damn you