In [1]:
#全局变量
hub_token = open('/root/hub_token.txt').read().strip()
repo_id = 'lansinuote/au.1.audio_classification'
push_to_hub = True

In [2]:
from transformers import AutoFeatureExtractor

#音频数据批处理工具类
feature_extractor = AutoFeatureExtractor.from_pretrained(
    'facebook/wav2vec2-base')

print(feature_extractor)

#编码成数据
#input_values是个list，里面每一项是一维的numpy数组，并且长度是确定的16000
feature_extractor(
    #数据,模拟三句话，每句话用4个数字表示
    raw_speech=[[0.1, 0.1, 0.1, 0.1], [0.2, 0.2, 0.2, 0.2],
                [0.3, 0.3, 0.3, 0.3]],
    #采样率，也就是用多少个数组表示1秒的数据，这里16k就意味着是16khz的采样率
    sampling_rate=16000,
    #定义最大长度，16000个数字意味着是1秒的长度
    max_length=8,
    #超出max_length的部分截断
    truncation=True,
    #不足max_length的补充pad
    #padding='max_length',
    #返回attention_mask
    return_attention_mask=True,
)



Wav2Vec2FeatureExtractor {
  "do_normalize": true,
  "feature_extractor_type": "Wav2Vec2FeatureExtractor",
  "feature_size": 1,
  "padding_side": "right",
  "padding_value": 0.0,
  "return_attention_mask": false,
  "sampling_rate": 16000
}



{'input_values': [array([0., 0., 0., 0.], dtype=float32), array([0., 0., 0., 0.], dtype=float32), array([0., 0., 0., 0.], dtype=float32)], 'attention_mask': [array([1, 1, 1, 1], dtype=int32), array([1, 1, 1, 1], dtype=int32), array([1, 1, 1, 1], dtype=int32)]}

In [3]:
from datasets import load_dataset, concatenate_datasets


def get_dataset():
    #pip install soundfile
    #pip install librosa

    #加载数据集
    dataset = load_dataset(path='superb', name='ks')

    #查看label
    print(dataset['train'].features['label'].names)
    #['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', '_silence_', '_unknown_']

    print(dataset['train'].features['label'].num_classes)
    #12

    #查看数据样例
    print(dataset['train'][0])

    #在训练集中各类样本的分布是不均匀的，11太多了，10太少了，干脆放弃这两个类别，也不测试了
    def f(data):
        return [i != 11 and i != 10 for i in data['label']]

    dataset = dataset.filter(function=f,
                             batched=True,
                             batch_size=100,
                             num_proc=4)

    #重新切分
    dataset = concatenate_datasets(list(dataset.values()))
    dataset = dataset.train_test_split(test_size=100, seed=0)

    print(dataset)

    #数据预处理函数
    def f(data):
        #取出数据中的音频数据，每一个是一维的numpy数组
        data = [i['array'] for i in data['audio']]

        #编码成数据
        return feature_extractor(
            #数据
            raw_speech=data,
            #采样率，也就是用多少个数组表示1秒的数据，这里16k就意味着是16khz的采样率
            sampling_rate=16000,
            #定义最大长度，16000个数字意味着是1秒的长度
            max_length=16000,
            #超出max_length的部分截断
            truncation=True,
            #不足max_length的补充pad
            padding='max_length',
            #返回attention_mask
            return_attention_mask=True,
        )

    #预处理数据
    dataset = dataset.map(function=f,
                          remove_columns=['audio', 'file'],
                          batched=True,
                          batch_size=100,
                          num_proc=4)

    return dataset


if push_to_hub:
    dataset = get_dataset()
    dataset.push_to_hub(repo_id=repo_id, token=hub_token)

#直接使用我处理好的数据集
dataset = load_dataset(path=repo_id)

print(dataset)
print(dataset['train'][0]['label'])
print(dataset['train'][0]['input_values'][:20])
print(dataset['train'][0]['attention_mask'][:20])

Found cached dataset superb (/root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e)


  0%|          | 0/3 [00:00<?, ?it/s]

['yes', 'no', 'up', 'down', 'left', 'right', 'on', 'off', 'stop', 'go', '_silence_', '_unknown_']
12
{'file': '/root/.cache/huggingface/datasets/downloads/extracted/05734a36d88019a09725c20cc024e1c4e7982e37d7d55c0c1ca1742ea1cdd47f/_background_noise_/doing_the_dishes.wav', 'audio': {'path': '/root/.cache/huggingface/datasets/downloads/extracted/05734a36d88019a09725c20cc024e1c4e7982e37d7d55c0c1ca1742ea1cdd47f/_background_noise_/doing_the_dishes.wav', 'array': array([ 0.        ,  0.        ,  0.        , ..., -0.00592041,
       -0.00405884, -0.00253296], dtype=float32), 'sampling_rate': 16000}, 'label': 10}
 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-a244d3a271dc796e_00000_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-a244d3a271dc796e_00001_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-a244d3a271dc796e_00002_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-a244d3a271dc796e_00003_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-cacab44f84ace8eb_00000_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-cacab44f84ace8eb_00001_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-cacab44f84ace8eb_00002_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-cacab44f84ace8eb_00003_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-f9a3c81d22a605e1_00000_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-f9a3c81d22a605e1_00001_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-f9a3c81d22a605e1_00002_of_00004.arrow


 

Loading cached processed dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-f9a3c81d22a605e1_00003_of_00004.arrow
Loading cached split indices for dataset at /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-162e10e2bdb3b3eb.arrow and /root/.cache/huggingface/datasets/superb/ks/1.9.0/b8183f71eabe8c559d7f3f528ab37a6a21ad1ee088fd3423574cecad8b3ec67e/cache-e7c1d7f49a9ce54d.arrow


DatasetDict({
    train: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 23582
    })
    test: Dataset({
        features: ['file', 'audio', 'label'],
        num_rows: 100
    })
})
      

#0:   0%|          | 0/59 [00:00<?, ?ba/s]

 

#1:   0%|          | 0/59 [00:00<?, ?ba/s]

 

#2:   0%|          | 0/59 [00:00<?, ?ba/s]

#3:   0%|          | 0/59 [00:00<?, ?ba/s]

      

#0:   0%|          | 0/1 [00:00<?, ?ba/s]

  

#1:   0%|          | 0/1 [00:00<?, ?ba/s]

#2:   0%|          | 0/1 [00:00<?, ?ba/s]

#3:   0%|          | 0/1 [00:00<?, ?ba/s]

Pushing split train to the Hub.


Pushing dataset shards to the dataset hub:   0%|          | 0/7 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Creating parquet from Arrow format:   0%|          | 0/4 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/7 [00:00<?, ?it/s]

Pushing split test to the Hub.


Pushing dataset shards to the dataset hub:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Deleting unused files from dataset repository:   0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration lansinuote--au.1.audio_classification-e816449103f4e676


Downloading and preparing dataset None/None to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--au.1.audio_classification-e816449103f4e676/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/161M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/161M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/161M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/160M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/162M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/161M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/161M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/23582 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/100 [00:00<?, ? examples/s]

Dataset parquet downloaded and prepared to /root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--au.1.audio_classification-e816449103f4e676/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['label', 'input_values', 'attention_mask'],
        num_rows: 23582
    })
    test: Dataset({
        features: ['label', 'input_values', 'attention_mask'],
        num_rows: 100
    })
})
9
[0.008107735775411129, 0.009123360738158226, 0.009123360738158226, 0.007092110347002745, -1.7265629139728844e-05, -0.005095391534268856, -1.7265629139728844e-05, -1.7265629139728844e-05, -0.002048515947535634, -0.004079766571521759, -0.004079766571521759, -0.0010328907519578934, 0.0020139848347753286, 0.004045234993100166, -1.7265629139728844e-05, 0.0030296097975224257, 0.004045234993100166, 0.00607648491859436, 0.0030296097975224257, 0.0009983595227822661]
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


In [4]:
import torch


#数据整理函数
def collate_fn(data):
    label = [i['label'] for i in data]
    input_values = [i['input_values'] for i in data]
    attention_mask = [i['attention_mask'] for i in data]

    return {
        'labels': torch.LongTensor(label),
        'input_values': torch.FloatTensor(input_values),
        'attention_mask': torch.LongTensor(attention_mask)
    }


#数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

for i, data in enumerate(loader):
    break

for k, v in data.items():
    print(k, v.shape, v[:3])

len(loader)

labels torch.Size([8]) tensor([4, 5, 7])
input_values torch.Size([8, 16000]) tensor([[ 5.9352e-05,  1.1180e-03,  1.1180e-03,  ...,  5.9352e-05,
          5.9352e-05, -9.9932e-04],
        [ 2.9495e-02,  6.0386e-02,  8.0513e-02,  ..., -9.1264e-02,
         -8.5647e-02, -6.5521e-02],
        [ 1.3402e-04,  1.3402e-04,  1.3402e-04,  ...,  2.7440e-04,
         -1.4674e-04,  9.7631e-04]])
attention_mask torch.Size([8, 16000]) tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]])


2947

In [5]:
from transformers import AutoModelForAudioClassification, Wav2Vec2Model, PreTrainedModel, PretrainedConfig

#加载模型
#model = AutoModelForAudioClassification.from_pretrained('facebook/wav2vec2-base', num_labels=10)


#定义下游任务模型
class Model(PreTrainedModel):
    config_class = PretrainedConfig

    def __init__(self, config):
        super().__init__(config)
        self.pretrained = Wav2Vec2Model.from_pretrained(
            'facebook/wav2vec2-base')

        self.fc1 = torch.nn.Linear(768, 256)
        self.fc2 = torch.nn.Linear(256, 10)

        #加载预训练模型的参数
        parameters = AutoModelForAudioClassification.from_pretrained(
            'facebook/wav2vec2-base', num_labels=10)
        self.fc1.load_state_dict(parameters.projector.state_dict())
        self.fc2.load_state_dict(parameters.classifier.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_values, attention_mask, labels):
        logits = self.pretrained(input_values=input_values,
                                 attention_mask=attention_mask)
        logits = logits.last_hidden_state

        logits = self.fc1(logits)
        logits = logits.mean(dim=1)
        logits = self.fc2(logits)

        loss = self.criterion(logits, labels)

        return {'loss': loss, 'logits': logits}


model = Model(PretrainedConfig())

#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)

out = model(**data)

out['loss'], out['logits'].shape

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.bias', 'project_q.bias', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.bias', 'project_q.weight', 'quantizer.weight_proj.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['project_hid.bias', 'project_q.bias', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.bias', 'project_q.weig

9457.1146


(tensor(2.2947, grad_fn=<NllLossBackward0>), torch.Size([8, 10]))

In [6]:
from IPython.display import Audio, display


def show(audio, out, label):
    display(Audio(audio, rate=16000))
    print('%s / %s' % (out, label))


show(data['input_values'][0], 'a', 'b')

a / b


In [7]:
#测试
def test():
    model.eval()

    names = dataset['test'].features['label'].names

    #数据加载器
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=16,
        collate_fn=collate_fn,
        shuffle=False,
        drop_last=True,
    )

    correct = 0
    total = 0
    for i, data in enumerate(loader_test):
        #计算
        with torch.no_grad():
            out = model(**data)

        out = out['logits'].argmax(dim=1)
        correct += (out == data['labels']).sum().item()
        total += 16

        if i % 1 == 0:
            show(data['input_values'][0], names[out[0]],
                 names[data['labels'][0]])

        if i == 8:
            break

    print(correct / total)


test()

on / yes


on / go


up / no


up / no


off / up


up / yes
0.125


In [8]:
from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train():
    optimizer = AdamW(model.parameters(),
                      lr=1e-5,
                      betas=(0.9, 0.999),
                      eps=1e-8)

    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.train()
    model.to(device)
    for i, data in enumerate(loader):
        for k in data.keys():
            data[k] = data[k].to(device)

        out = model(**data)

        loss = out['loss']
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 100 == 0:
            out = out['logits'].argmax(dim=1)
            accuracy = (data['labels'] == out).sum().item() / 8
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)

    model.to('cpu')


if push_to_hub:
    train()
    model.push_to_hub(repo_id=repo_id, use_auth_token=hub_token)



0 2.301429033279419 9.996606718696981e-06 0.125
100 2.229736566543579 9.657278588394979e-06 0.375
200 1.7464560270309448 9.317950458092976e-06 1.0
300 1.209434986114502 8.978622327790974e-06 1.0
400 1.277097225189209 8.639294197488971e-06 0.75
500 0.843075692653656 8.29996606718697e-06 1.0
600 0.764723539352417 7.960637936884968e-06 0.875
700 1.1249730587005615 7.621309806582966e-06 0.75
800 0.47292712330818176 7.281981676280965e-06 1.0
900 0.4129190444946289 6.9426535459789625e-06 1.0
1000 0.6378757357597351 6.60332541567696e-06 0.875
1100 0.26956620812416077 6.263997285374958e-06 1.0
1200 0.33857256174087524 5.924669155072956e-06 1.0
1300 0.4451664090156555 5.585341024770954e-06 0.875
1400 0.16815416514873505 5.246012894468951e-06 1.0
1500 0.16068410873413086 4.9066847641669495e-06 1.0
1600 0.33739838004112244 4.567356633864948e-06 0.875
1700 0.13309237360954285 4.228028503562946e-06 1.0
1800 0.12347780913114548 3.888700373260944e-06 1.0
1900 0.4648444950580597 3.549372242958942e-06 

Upload 1 LFS files:   0%|          | 0/1 [00:00<?, ?it/s]

pytorch_model.bin:   0%|          | 0.00/378M [00:00<?, ?B/s]

In [10]:
#直接使用我训练好的模型
model = Model.from_pretrained(repo_id)
test()

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/378M [00:00<?, ?B/s]

Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2Model: ['project_hid.bias', 'project_q.bias', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.bias', 'project_q.weight', 'quantizer.weight_proj.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at facebook/wav2vec2-base were not used when initializing Wav2Vec2ForSequenceClassification: ['project_hid.bias', 'project_q.bias', 'project_hid.weight', 'quantizer.codevectors', 'quantizer.weight_proj.bias', 'project_q.weig

yes / yes


go / go


no / no


no / no


up / up


yes / yes
1.0
