In [1]:
from transformers import AutoFeatureExtractor

# Feature extractor that turns raw audio arrays into batched model inputs
# ("input_values") for wav2vec2.
feature_extractor = AutoFeatureExtractor.from_pretrained(
    'facebook/wav2vec2-base')

# Candidate checkpoints:
#facebook/wav2vec2-base
#jonatasgrosman/wav2vec2-large-xlsr-53-english

print(feature_extractor)

# Encode a toy batch.
out = feature_extractor(
    # Toy data: three "utterances", each represented by only 4 samples.
    raw_speech=[[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2],
                [0.3, 0.3, 0.3, 0.3]],
    # Sampling rate = number of samples per second of audio; 16000 means
    # 16 kHz (the printed extractor config also has "sampling_rate": 16000).
    sampling_rate=16000,
    # Maximum length in samples (at 16 kHz, 16000 samples would be 1 second).
    max_length=8,
    # Truncate anything longer than max_length.
    truncation=True,
    # Padding up to max_length is disabled here, so the 4-sample inputs are
    # NOT padded to 8 (the recorded output shape is (4,)).
    #padding='max_length',
)

# The encoder output is dict-like; the only key here is 'input_values' (the
# printed config has "return_attention_mask": false).
out = out['input_values']

# input_values holds one 1-D numpy array per utterance. Because padding is off,
# each keeps its own length (4 in the recorded output), not a fixed 16000.
# NOTE: "do_normalize" is true in the printed config; the all-zero recorded
# output is presumably each constant sequence being normalized to zero mean.
out[0], out[0].shape

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "


Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}



(array([0., 0., 0., 0.], dtype=float32), (4,))

In [2]:
from datasets import load_dataset, load_from_disk
import torch

#pip install soundfile
#pip install librosa

# Load the dataset (SUPERB keyword spotting, cached locally on disk).
#dataset = load_dataset(path='superb', name='ks')
dataset = load_from_disk('./datas/superb/ks')

# Inspect the label names.
print(dataset['train'].features['label'].names)
#['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', '_silence_', '_unknown_']

print(dataset['train'].features['label'].num_classes)
#12

# Inspect one example.
print(dataset['train'][0])

# Subsample, because the full data is too large to run here.
# NOTE: the train subsample is commented out, so train stays full-size; the
# validation split is emptied completely (range(0)); test is cut to 200 rows.
#dataset['train'] = dataset['train'].shuffle(seed=1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(seed=1).select(range(0))
dataset['test'] = dataset['test'].shuffle(seed=1).select(range(200))


# Per the original note, the training-set class distribution is imbalanced:
# label 11 ('_unknown_') has too many samples and label 10 ('_silence_') too
# few, so both classes are dropped from every split and never evaluated.
# NOTE(review): the name `f` is reused for a different function in the next
# cell — re-running cells out of order will pick up the wrong one.
def f(data):
    # Batched filter predicate: True for rows whose label is neither 10 nor 11.
    return [i != 11 and i != 10 for i in data['label']]


dataset = dataset.filter(function=f, batched=True, batch_size=100, num_proc=4)

# Per-class counts of the remaining training labels (0..9).
print(torch.LongTensor(dataset['train']['label']).bincount())

dataset

['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', '_silence_', '_unknown_']
12


Loading cached shuffled indices for dataset at datas/superb/ks/validation/cache-fb7bc93e9da9e87d.arrow
Loading cached shuffled indices for dataset at datas/superb/ks/test/cache-de92a67d19ff7904.arrow


{'file': '/root/.cache/huggingface/datasets/downloads/extracted/05734a36d88019a09725c20cc024e1c4e7982e37d7d55c0c1ca1742ea1cdd47f/_background_noise_/doing_the_dishes.wav', 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/05734a36d88019a09725c20cc024e1c4e7982e37d7d55c0c1ca1742ea1cdd47f/_background_noise_/doing_the_dishes.wav', 'array': array([ 0.        ,  0.        ,  0.        , ..., -0.00592041,
       -0.00405884, -0.00253296], dtype=float32), 'sampling_rate': 16000}, 'label': 10}
 

Loading cached processed dataset at datas/superb/ks/train/cache-6877731e6b99c3ac_00000_of_00004.arrow


 

Loading cached processed dataset at datas/superb/ks/train/cache-6877731e6b99c3ac_00001_of_00004.arrow


 

Loading cached processed dataset at datas/superb/ks/train/cache-6877731e6b99c3ac_00002_of_00004.arrow


 

Loading cached processed dataset at datas/superb/ks/train/cache-6877731e6b99c3ac_00003_of_00004.arrow


 

Loading cached processed dataset at datas/superb/ks/test/cache-9f0c351c2fc613e6_00000_of_00004.arrow


 

Loading cached processed dataset at datas/superb/ks/test/cache-9f0c351c2fc613e6_00001_of_00004.arrow


 

Loading cached processed dataset at datas/superb/ks/test/cache-9f0c351c2fc613e6_00002_of_00004.arrow


 

Loading cached processed dataset at datas/superb/ks/test/cache-9f0c351c2fc613e6_00003_of_00004.arrow


tensor([1860, 1853, 1843, 1842, 1839, 1852, 1864, 1839, 1885, 1861])


DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 18538
    })
    validation: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 0
    })
    test: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 170
    })
})

In [3]:
# Preprocessing function for dataset.map(batched=True): turns a batch of
# decoded audio into fixed-length 'input_values'.
# NOTE(review): this shadows the filter function `f` from the previous cell.
def f(data):
    # Pull out the decoded waveforms. Per the example printed in the previous
    # cell, each 'audio' entry is a dict whose 'array' is a 1-D float32 array.
    data = [i['array'] for i in data['audio']]

    # Encode with the wav2vec2 feature extractor.
    return feature_extractor(
        # The raw waveforms.
        raw_speech=data,
        # Samples per second; 16000 = 16 kHz (matches the 'sampling_rate' of
        # the printed example and of the extractor config).
        sampling_rate=16000,
        # Fixed length in samples: 16000 samples = 1 second at 16 kHz.
        max_length=16000,
        # Truncate clips longer than max_length.
        truncation=True,
        # Pad shorter clips up to max_length so every row has the same length.
        padding='max_length',
    )


# Preprocess every split: drop the 'audio' and 'file' columns and add the
# extractor's 'input_values'.
# NOTE: the validation split has 0 rows, and the recorded output shows it ends
# up with only a 'label' column (no 'input_values').
dataset = dataset.map(function=f,
                      remove_columns=['audio', 'file'],
                      batched=True,
                      batch_size=100,
                      num_proc=4)

# Optional: cache the mapped dataset to disk and reload it later.
#dataset.save_to_disk('datas/superb/ks_maped')
#dataset = load_from_disk('datas/superb/ks_maped')

# Spot-check one processed example: its label and first 20 input samples.
print(dataset['train'][0]['label'])
print(dataset['train'][0]['input_values'][:20])

dataset

        

#0:   0%|          | 0/47 [00:00<?, ?ba/s]

#2:   0%|          | 0/47 [00:00<?, ?ba/s]

#3:   0%|          | 0/47 [00:00<?, ?ba/s]

#1:   0%|          | 0/47 [00:00<?, ?ba/s]

        

#0:   0%|          | 0/1 [00:00<?, ?ba/s]

#2:   0%|          | 0/1 [00:00<?, ?ba/s]

#3:   0%|          | 0/1 [00:00<?, ?ba/s]

#1:   0%|          | 0/1 [00:00<?, ?ba/s]

3
[-0.0017543029971420765, -0.0013663222780451179, -0.0013663222780451179, -0.0017543029971420765, -0.0013663222780451179, -0.0036942060105502605, -0.005634109023958445, -0.000978341675363481, 0.002125503495335579, 0.00018560023454483598, -0.004082186613231897, -0.002918244805186987, 0.0009615615126676857, 0.0029014647006988525, -0.00253026420250535, -0.0095139155164361, -0.00330622517503798, 0.0013495421735569835, -0.0017543029971420765, -0.0021422835998237133]


DatasetDict({
    train: Dataset({
        features: ['label', 'input_values'],
        num_rows: 18538
    })
    validation: Dataset({
        features: ['label'],
        num_rows: 0
    })
    test: Dataset({
        features: ['label', 'input_values'],
        num_rows: 170
    })
})

In [4]:
import torch


#数据整理函数
def collate_fn(data):
    """Collate a list of dataset rows into a model-ready batch.

    Args:
        data: list of dicts, each with an integer 'label' and an
            'input_values' sequence (a list of floats or a 1-D numpy array).
            All 'input_values' must have the same length.

    Returns:
        dict with, in this order, 'labels' (int64 tensor of shape [batch]) and
        'input_values' (float32 tensor of shape [batch, length]). An empty
        batch yields two empty 1-D tensors.
    """
    labels = torch.tensor([row['label'] for row in data], dtype=torch.long)

    # Convert row by row and stack. Building a tensor straight from a list of
    # numpy arrays is extremely slow in PyTorch, and the legacy
    # torch.FloatTensor / torch.LongTensor constructors are discouraged.
    rows = [torch.as_tensor(row['input_values'], dtype=torch.float32)
            for row in data]
    input_values = (torch.stack(rows) if rows
                    else torch.empty(0, dtype=torch.float32))

    # An 'attention_mask' entry would go here if the feature extractor
    # returned one (its config has return_attention_mask = false).
    return {
        'labels': labels,
        'input_values': input_values,
    }


# Training data loader: shuffled batches of 8; the final partial batch is
# dropped (drop_last=True).
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

# Grab one batch for a sanity check.
# NOTE: `data` remains a global and is reused by later cells (the model smoke
# test and the `show` demo).
for i, data in enumerate(loader):
    break

# Print each tensor's shape and its first 3 entries.
for k, v in data.items():
    print(k, v.shape, v[:3])

# Batches per epoch (2317 in the recorded output = 18538 // 8).
len(loader)

labels torch.Size([8]) tensor([9, 3, 3])
input_values torch.Size([8, 16000]) tensor([[ 0.0047,  0.0047,  0.0047,  ...,  0.0089,  0.0121,  0.0142],
        [-0.0019, -0.0002, -0.0019,  ..., -0.0036, -0.0136, -0.0152],
        [-0.0830, -0.0911, -0.0830,  ...,  0.0991,  0.0991,  0.1031]])


2317

In [5]:
from transformers import AutoModelForAudioClassification, Wav2Vec2Model

#加载模型
#model = AutoModelForAudioClassification.from_pretrained('facebook/wav2vec2-base', num_labels=12)

#不加这个参数可能导致loss为nan
#model.config.ctc_zero_infinity = True


#定义下游任务模型
class Model(torch.nn.Module):
    """Keyword-spotting classifier: a frozen wav2vec2 encoder plus a small head.

    The pretrained ``Wav2Vec2Model`` produces per-frame hidden states. ``fc1``
    projects every frame, the frames are mean-pooled over the time axis, and
    ``fc2`` maps the pooled vector to class logits.

    Args:
        num_labels: number of output classes. The default 10 is the number of
            classes left after labels 10 and 11 were filtered out of the data.
        pretrained_name: Hugging Face model id for the encoder. The head layer
            sizes are taken from its config, so larger checkpoints also work.
    """

    def __init__(self, num_labels=10, pretrained_name='facebook/wav2vec2-base'):
        super().__init__()
        self.pretrained = Wav2Vec2Model.from_pretrained(pretrained_name)

        # Freeze the encoder; only fc1/fc2 are trained. (The original author
        # reports that updating its parameters could turn them all into NaN.)
        for param in self.pretrained.parameters():
            param.requires_grad_(False)

        # Derive the layer sizes from the encoder config instead of hard-coding
        # them (768 / 256 for wav2vec2-base).
        config = self.pretrained.config
        hidden_size = config.hidden_size
        proj_size = getattr(config, 'classifier_proj_size', 256)
        self.fc1 = torch.nn.Linear(hidden_size, proj_size)
        self.fc2 = torch.nn.Linear(proj_size, num_labels)

        # Copy the initial head weights from the HF sequence-classification
        # model's projector/classifier so the head matches its initialization.
        # NOTE(review): the base checkpoint is presumably a pretraining
        # checkpoint without a classification head, so these weights are most
        # likely freshly initialized by transformers, not pretrained — confirm
        # via the "newly initialized" load warning.
        head_source = AutoModelForAudioClassification.from_pretrained(
            pretrained_name, num_labels=num_labels)
        self.fc1.load_state_dict(head_source.projector.state_dict())
        self.fc2.load_state_dict(head_source.classifier.state_dict())
        # Only the two head layers were needed; release the second full model.
        del head_source

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_values, labels=None):
        """Compute logits and, when ``labels`` is given, the cross-entropy loss.

        Args:
            input_values: batch of raw waveforms passed to the encoder.
            labels: optional int64 class indices, one per example.

        Returns:
            dict with 'loss' (scalar tensor, or None without labels) and
            'logits' of shape [batch, num_labels].
        """
        # The encoder is frozen, so skip building its autograd graph.
        with torch.no_grad():
            hidden = self.pretrained(input_values=input_values).last_hidden_state

        hidden = self.fc1(hidden)
        # Mean-pool over dim 1 (the per-frame axis of last_hidden_state).
        pooled = hidden.mean(dim=1)
        logits = self.fc2(pooled)

        loss = self.criterion(logits, labels) if labels is not None else None

        return {'loss': loss, 'logits': logits}


model = Model()

# Total parameter count (frozen encoder + trainable head) in units of 10,000;
# the recorded 9457.1146 is about 94.6M parameters.
print(sum(i.numel() for i in model.parameters()) / 10000)

# Smoke test on the sample batch `data` taken from the loader cell.
out = model(**data)

out['loss'], out['logits'].shape

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.bias', 'quantizer.codevectors', 'project_hid.weight', 'project_q.bias', 'quantizer.weight_proj.bias', 'quantizer.weight_proj.weight', 'project_q.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['project_hid.bias', 'quan

9457.1146


(tensor(2.2916, grad_fn=<NllLossBackward0>), torch.Size([8, 10]))

In [6]:
from IPython.display import Audio, display


def show(audio, out, label):
    """Render an inline 16 kHz audio player, then print "<out> / <label>"."""
    player = Audio(audio, rate=16000)
    display(player)
    print(out, label, sep=' / ')


# Demo: play the first clip of the sample batch `data` with placeholder
# prediction/label strings.
show(data['input_values'][0], 'a', 'b')

a / b


In [7]:
#测试
def test():
    model.eval()

    names = dataset['test'].features['label'].names

    #数据加载器
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=16,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=True,
    )

    correct = 0
    total = 0
    for i, data in enumerate(loader_test):
        #计算
        with torch.no_grad():
            out = model(**data)

        out = out['logits'].argmax(dim=1)
        correct += (out == data['labels']).sum().item()
        total += 16

        if i % 1 == 0:
            show(data['input_values'][0], names[out[0]],
                 names[data['labels'][0]])

        if i == 8:
            break

    print(correct / total)


# Evaluate before training the head (recorded accuracy ~0.076).
test()

yes / yes


right / on


on / right


right / off


on / right


on / yes


yes / stop


yes / off


on / yes
0.0763888888888889


In [8]:
from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train(lr=5e-4, log_every=10, save_path='models/10.语音分类.model'):
    """Train the global ``model`` for one pass over the global ``loader``.

    Uses AdamW with a linear decay schedule over ``len(loader)`` steps and
    gradient-norm clipping at 1.0. The model is saved whole with ``torch.save``.

    Args:
        lr: peak learning rate.
        log_every: print "step loss lr batch_accuracy" every this many steps;
            ``0`` or ``None`` disables logging.
        save_path: where to save the whole model; its parent directory is
            created if missing. ``None`` or ``''`` skips saving.
    """
    import os

    # Only the head is trainable (the encoder is frozen), so hand the
    # optimizer just those parameters.
    trainable = [p for p in model.parameters() if p.requires_grad]

    # torch.optim.AdamW replaces the deprecated transformers.AdamW.
    # weight_decay=0.0 matches transformers.AdamW's default. The update is
    # numerically near-identical (torch applies eps after bias correction).
    optimizer = torch.optim.AdamW(trainable,
                                  lr=lr,
                                  betas=(0.9, 0.999),
                                  eps=1e-8,
                                  weight_decay=0.0)

    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    # NOTE(review): train() also puts the frozen encoder submodule in training
    # mode, presumably enabling any dropout inside it during training.
    model.train()
    for i, data in enumerate(loader):
        out = model(**data)

        loss = out['loss']
        loss.backward()

        torch.nn.utils.clip_grad_norm_(trainable, 1.0)

        optimizer.step()
        scheduler.step()
        # The optimizer owns every parameter that can receive a gradient, so a
        # separate model.zero_grad() is unnecessary.
        optimizer.zero_grad()

        if log_every and i % log_every == 0:
            pred = out['logits'].argmax(dim=1)
            labels = data['labels']
            # Divide by the real batch size instead of a hard-coded 8.
            accuracy = (labels == pred).sum().item() / len(labels)
            current_lr = optimizer.param_groups[0]['lr']
            print(i, loss.item(), current_lr, accuracy)

    if save_path:
        directory = os.path.dirname(save_path)
        if directory:
            # torch.save does not create missing directories.
            os.makedirs(directory, exist_ok=True)
        torch.save(model, save_path)


# One epoch over `loader`; the whole model is then saved to
# models/10.语音分类.model.
train()



0 2.318608283996582 0.0004997842037116962 0.125
10 2.2279653549194336 0.0004976262408286577 0.25
20 2.2289209365844727 0.0004954682779456193 0.0
30 2.135363817214966 0.0004933103150625809 0.375
40 2.2534003257751465 0.0004911523521795425 0.0
50 2.127800703048706 0.0004889943892965041 0.375
60 2.043365240097046 0.0004868364264134657 0.25
70 2.0175156593322754 0.00048467846353042725 0.25
80 1.7809906005859375 0.0004825205006473889 0.625
90 2.009760618209839 0.0004803625377643505 0.375
100 1.783545732498169 0.000478204574881312 0.375
110 1.7011560201644897 0.00047604661199827366 0.75
120 1.9074225425720215 0.0004738886491152352 0.5
130 1.6788890361785889 0.00047173068623219684 0.5
140 1.5558550357818604 0.0004695727233491584 0.75
150 1.6103097200393677 0.00046741476046611996 0.5
160 1.8042408227920532 0.0004652567975830816 0.5
170 1.575299859046936 0.00046309883470004314 0.75
180 1.393979787826538 0.0004609408718170048 0.5
190 1.8946713209152222 0.00045878290893396637 0.375
200 1.90949869

1610 1.0958058834075928 0.00015235217954251187 0.625
1620 0.6424765586853027 0.00015019421665947346 0.625
1630 0.8516159057617188 0.00014803625377643502 0.75
1640 0.6461579203605652 0.00014587829089339664 0.875
1650 0.6041488647460938 0.00014372032801035823 0.875
1660 0.863368809223175 0.00014156236512731982 0.75
1670 0.6324833631515503 0.0001394044022442814 0.875
1680 0.8206380009651184 0.000137246439361243 0.75
1690 1.1162978410720825 0.00013508847647820458 0.75
1700 0.31985193490982056 0.00013293051359516617 0.875
1710 0.5460153818130493 0.00013077255071212776 0.75
1720 0.24905581772327423 0.00012861458782908932 1.0
1730 0.5569127798080444 0.00012645662494605093 0.75
1740 0.4469282329082489 0.00012429866206301252 0.875
1750 0.9005675911903381 0.0001221406991799741 0.875
1760 0.9380244016647339 0.0001199827362969357 0.875
1770 0.29966840147972107 0.00011782477341389729 0.875
1780 0.5130868554115295 0.00011566681053085887 0.75
1790 0.5179408192634583 0.00011350884764782046 0.75
1800 1

In [9]:
# Reload the whole pickled model saved by train() and re-evaluate it.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted
# files. Newer PyTorch releases default to weights_only=True, which rejects a
# full-module pickle like this one; weights_only=False would be needed there.
model = torch.load('models/10.语音分类.model')
test()

yes / yes


right / on


on / right


off / off


right / right


yes / yes


yes / stop


off / off


yes / yes
0.6875
