In [3]:
import torch
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import os
import codecs
import pickle

In [4]:
dataset_path = "./dataset/data_thchs30"
mfcc_mat = np.load(os.path.join(dataset_path, "mfcc_vec_680x26.npy"))
mfcc_mat = torch.tensor(mfcc_mat)
mfcc_mat.shape

torch.Size([13388, 680, 26])

In [5]:
with codecs.open(os.path.join(dataset_path, "all_texts.txt"), encoding="utf-8") as file_read:
    text_lines = file_read.readlines()
len(text_lines), text_lines[:3]

(13388,
 ['绿是阳春烟景大块文章的底色四月的林峦更是绿得鲜活秀媚诗意盎然\n',
  '他仅凭腰部的力量在泳道上下翻腾蛹动蛇行状如海豚一直以一头的优势领先\n',
  '炮眼打好了炸药怎么装岳正才咬了咬牙倏地脱去衣服光膀子冲进了水窜洞\n'])

In [6]:
token_set = set(list(''.join(text_lines).replace("\n","")))
token_map = dict((j,i+1) for i,j in enumerate(token_set))
print(len(token_map))
seq_lines = [list(map(lambda x: token_map[x], text_line.replace("\n",""))) for text_line in text_lines]
len(seq_lines), len(seq_lines[0])

2883


(13388, 30)

In [8]:
pad_lines = [(seq_line + [0]*48)[:48] for seq_line in seq_lines]

In [9]:
class MyDataset(torch.utils.data.Dataset):
    def __init__(self):
        dataset_path = "./dataset/data_thchs30"
        self.mfcc_mat = np.load(os.path.join(dataset_path, "mfcc_vec_680x26.npy"))
        with codecs.open(os.path.join(dataset_path, "all_texts.txt"), encoding="utf-8") as file_read:
            text_lines = file_read.readlines()
        token_set = set(list(''.join(text_lines).replace("\n","")))
        token_map = dict((j,i+1) for i,j in enumerate(token_set))
        seq_lines = [list(map(lambda x: token_map[x], text_line.replace("\n",""))) for text_line in text_lines]
        self.pad_lines = [(seq_line + [0]*48)[:48] for seq_line in seq_lines]
        self.pad_lines = torch.tensor(self.pad_lines).unsqueeze(-1)
    
    def __len__(self):
        return len(self.mfcc_mat)

    def __getitem__(self,idx):
        return self.mfcc_mat[idx], self.pad_lines[idx]

In [10]:
my_dataset = MyDataset()
my_dataset

<__main__.MyDataset at 0x28de0ff25c8>

In [11]:
data_train, data_test = torch.utils.data.random_split(my_dataset, [13000, 388])

In [12]:
class MyModel(torch.nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.stage = torch.nn.Sequential(
            torch.nn.Conv1d(26, 50, kernel_size=5, stride=1, padding=2),
            torch.nn.Tanh(),
            torch.nn.Conv1d(50, 2884, kernel_size=5, stride=1, padding=2)
        )

    def forward(self, x):
        batch_size = x.size(0)
        x = self.stage(x)
        return x
    

In [13]:
net = MyModel()
tmp = torch.randn((4,26,680))
print(net(tmp).shape)
torch.transpose(torch.transpose(net(tmp), 0, 2), 1, 2).shape

torch.Size([4, 2884, 680])


torch.Size([680, 4, 2884])

In [7]:
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [29]:
# this is for demo
batch_size = 8
ctc_loss = torch.nn.CTCLoss()
log_probs = torch.randn(680, batch_size, 2884).log_softmax(2).detach().requires_grad_()
targets = torch.randint(1, 2884, (batch_size, 48), dtype=torch.long)
input_lengths = torch.full((batch_size,), 680, dtype=torch.long)
target_lengths = torch.randint(24, 48, (batch_size,), dtype=torch.long)
loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
loss.backward()
loss

tensor(153.8136, grad_fn=<MeanBackward0>)

In [30]:
dataloader_train = torch.utils.data.DataLoader(data_train, batch_size = batch_size, shuffle=True)
dataloader_test = torch.utils.data.DataLoader(data_test, batch_size = batch_size, shuffle=False)

In [11]:
for i,j in dataloader_train:
    print(i,j)
    break

tensor([[[ 2.0944e+01, -4.1253e+01,  6.7133e-01,  ..., -1.1126e-01,
           3.9241e-01,  1.4940e-01],
         [ 2.0943e+01, -4.1742e+01,  3.5420e-01,  ..., -9.3708e-02,
           4.7631e-01,  5.7791e-01],
         [ 1.3963e+01, -2.5509e+01, -1.0801e+01,  ..., -5.8580e-02,
           4.5159e-01, -8.9573e-02],
         ...,
         [ 1.7983e+01,  3.9722e+00, -1.8970e+01,  ..., -4.4753e-01,
           1.1611e+00,  2.9819e+00],
         [ 1.8745e+01,  3.0303e+00, -2.1659e+01,  ..., -5.2191e-01,
           3.2556e+00,  4.6692e+00],
         [ 1.9063e+01,  1.8797e+00, -2.2774e+01,  ..., -6.0087e-01,
           3.0301e+00,  5.0095e+00]],

        [[ 1.1203e+01, -8.8239e+00, -1.2761e+01,  ..., -4.1323e-01,
           2.8512e-01, -7.1445e-01],
         [ 1.0919e+01, -7.9935e+00, -1.2961e+01,  ..., -3.7978e-01,
          -1.0755e-01, -1.8388e+00],
         [ 1.0632e+01, -8.7431e+00, -1.2541e+01,  ..., -4.3447e-01,
          -6.0046e-01, -6.2409e-01],
         ...,
         [ 1.5245e+01,  2

In [31]:
for epoch in range(10):
    for batch_idx, (x, y_true) in enumerate(dataloader_train):
        x = torch.transpose(x, 1, 2)
        x = x.type(torch.FloatTensor)
        y_true = y_true.squeeze(-1)
        x, y_true = x.cuda(), y_true.cuda()
        logits = model(x).log_softmax(2)
        ctc_loss = torch.nn.CTCLoss().cuda()
        log_probs = torch.transpose(torch.transpose(
            logits, 0, 2), 1, 2).requires_grad_()
        targets = y_true
        input_lengths = torch.full((batch_size,), 680, dtype=torch.long)
        target_lengths = torch.tensor(
            [sum([1 for j in i if j > 0]) for i in y_true], dtype=torch.long)
        loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
        loss.backward()
        print(batch_idx, loss.item())


0 130.65274047851562
1 123.19671630859375
2 128.10986328125
3 129.64865112304688
4 128.28305053710938
5 120.76300048828125
6 126.61337280273438
7 120.30543518066406
8 113.39134216308594
9 126.79279327392578
10 125.12100219726562
11 136.17364501953125
12 134.31430053710938
13 132.18272399902344
14 126.92364501953125
15 127.32121276855469
16 144.0885772705078
17 124.84217834472656
18 126.72932434082031
19 129.77639770507812
20 124.54583740234375
21 126.73545837402344
22 130.92990112304688
23 126.32046508789062
24 144.81375122070312
25 131.81141662597656
26 123.07147979736328
27 118.42012786865234
28 138.7265625
29 129.5321044921875
30 121.78547668457031
31 130.73707580566406
32 122.19229888916016
33 144.44847106933594
34 134.82272338867188
35 134.0120849609375
36 132.3302001953125
37 126.5694580078125
38 121.38441467285156
39 132.16123962402344
40 132.57183837890625
41 124.88327026367188
42 129.2086181640625
43 132.24630737304688
44 139.4387664794922
45 129.31007385253906
46 127.56576538

KeyboardInterrupt: 