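## Application Development with the Audio Spectrogram Transformer Model

**Environment:**

1. MindSpore 2.2.14
2. MindNLP 0.3.1
3. Python 3.9

**Huawei Cloud ModelArts is used as the AI platform.**

The environment setup reuses code from MindNLP-related projects in the AI Gallery community for installing MindNLP.

### 1 Environment Setup

1. Set up the Python 3.9 environment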

In [None]:
%%capture captured_output
!/home/ma-user/anaconda3/bin/conda create -n python-3.9.0 python=3.9.0 -y --override-channels --channel https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main
!/home/ma-user/anaconda3/envs/python-3.9.0/bin/pip install ipykernel

In [None]:
import json
import os

data = {
   "display_name": "python-3.9.0",
   "env": {
      "PATH": "/home/ma-user/anaconda3/envs/python-3.9.0/bin:/home/ma-user/anaconda3/envs/python-3.7.10/bin:/modelarts/authoring/notebook-conda/bin:/opt/conda/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/home/ma-user/modelarts/ma-cli/bin:/home/ma-user/modelarts/ma-cli/bin"
   },
   "language": "python",
   "argv": [
      "/home/ma-user/anaconda3/envs/python-3.9.0/bin/python",
      "-m",
      "ipykernel",
      "-f",
      "{connection_file}"
   ]
}

# Create the kernel directory (and any missing parents) if it does not exist yet
os.makedirs("/home/ma-user/anaconda3/share/jupyter/kernels/python-3.9.0/", exist_ok=True)

with open('/home/ma-user/anaconda3/share/jupyter/kernels/python-3.9.0/kernel.json', 'w') as f:
    json.dump(data, f, indent=4)

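*Note: after the code above finishes, switch the kernel to python-3.9.0 using the selector in the top-left or top-right corner.*

2. Install MindSpore 2.2.14. See the installation guide: [MindSpore installation](https://www.mindspore.cn/install/)

3. Install MindNLP and its dependencies. See the official MindNLP repository: [MindNLP](https://github.com/mindspore-lab/mindnlp)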

In [1]:
%%capture captured_output

!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.2.14/MindSpore/unified/x86_64/mindspore-2.2.14-cp39-cp39-linux_x86_64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple
!wget https://mindspore-demo.obs.cn-north-4.myhuaweicloud.com/mindnlp_install/mindnlp-0.3.1-py3-none-any.whl
!pip install mindnlp-0.3.1-py3-none-any.whl
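
After installation, it can help to confirm that the expected versions are visible in the active kernel. The sketch below uses only the standard library. The helper name `installed_versions` is hypothetical, and the distribution names `mindspore` and `mindnlp` are assumed from the wheel file names above.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Expected in this notebook: mindspore 2.2.14 and mindnlp 0.3.1.
for name, version in installed_versions(["mindspore", "mindnlp"]).items():
    print(f"{name}: {version or 'not installed'}")
```

If either package is reported as not installed, the most likely cause is that the notebook is still running on the old kernel rather than python-3.9.0.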

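### 2 The Audio Spectrogram Transformer Model
Paper: https://arxiv.org/abs/2104.01778

Code: https://github.com/YuanGongND/ast

#### 2.1 The AST Model

AST applies the Transformer architecture to audio and is the first convolution-free, purely attention-based audio classification model. Earlier work built end-to-end audio classifiers around convolutional neural networks (CNNs) and, to capture long-range context in audio, added attention modules on top of the CNN to form CNN-attention hybrids. The authors evaluated AST on several audio classification benchmarks: it reaches 0.485 mAP on AudioSet, 95.6% accuracy on ESC-50, and 98.1% accuracy on Speech Commands V2. These results show that a purely attention-based model can also perform well on audio datasets.

#### 2.2 MindNLP
`dir(mindnlp.transformers)` lists the names exported by the package, including the AST (Audio Spectrogram Transformer) classes, which can then be imported.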

In [None]:
import mindnlp
# List the names of the available models
dir(mindnlp.transformers)

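#### 2.3 Processing Audio Data with the AST Model

1. Dataset download and processing

This example uses GTZAN, an audio dataset of music genres. It covers ten genres: Blues, Classical, Country, Disco, Hip-Hop, Jazz, Metal, Pop, Reggae, and Rock.

The GTZAN dataset comes from the paper "Musical genre classification of audio signals", doi: 10.1109/TSA.2002.800560.

[GTZAN dataset](https://aistudio.baidu.com/datasetdetail/121525)

2. Download and extract the dataset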

In [3]:
!wget -O GTZANDataset.zip https://bj.bcebos.com/ai-studio-online/de73217a9c3f41e0a8537d4199f8ba74a727831b5cc242ea9f257e044add6816?authorization=bce-auth-v1%2F5cfe9a5e1454405eb2a975c43eace6ec%2F2022-09-04T15%3A27%3A10Z%2F-1%2F%2F91554791dd44205ae1e67f55af9f3561686c2ce934fb4fa7f2c533bd4e6b6948

--2024-08-16 18:46:48--  https://bj.bcebos.com/ai-studio-online/de73217a9c3f41e0a8537d4199f8ba74a727831b5cc242ea9f257e044add6816?authorization=bce-auth-v1%2F5cfe9a5e1454405eb2a975c43eace6ec%2F2022-09-04T15%3A27%3A10Z%2F-1%2F%2F91554791dd44205ae1e67f55af9f3561686c2ce934fb4fa7f2c533bd4e6b6948
Resolving proxy.modelarts.com (proxy.modelarts.com)... 192.168.6.3
Connecting to proxy.modelarts.com (proxy.modelarts.com)|192.168.6.3|:80... connected.
Proxy request sent, awaiting response... 200 OK
Length: 1221298356 (1.1G) [application/octet-stream]
Saving to: ‘GTZANDataset.zip’


2024-08-16 18:47:11 (51.4 MB/s) - ‘GTZANDataset.zip’ saved [1221298356/1221298356]



In [4]:
!unzip -q GTZANDataset.zip
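
A quick look at the extracted files can catch an incomplete download before the data is loaded. The sketch below assumes the archive unpacks to `./genres/<genre>/<genre>.NNNNN.wav`, as the `blues/blues.00000.wav` path used in this notebook suggests. The helper `count_wavs_per_genre` is hypothetical.

```python
from pathlib import Path


def count_wavs_per_genre(root):
    """Count .wav files in each immediate subdirectory of root."""
    counts = {}
    for genre_dir in sorted(Path(root).iterdir()):
        if genre_dir.is_dir():
            counts[genre_dir.name] = sum(1 for _ in genre_dir.glob("*.wav"))
    return counts


root = Path("./genres")
if root.is_dir():
    for genre, n in count_wavs_per_genre(root).items():
        print(f"{genre}: {n} files")
else:
    print("genres directory not found; run the unzip cell first")
```

A genre directory with noticeably fewer files than the others would point to a truncated archive.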

3. Load and process the GTZAN dataset with MindSpore

[mindspore.dataset.GTZANDataset documentation](https://www.mindspore.cn/docs/zh-CN/r2.3.0/api_python/dataset/mindspore.dataset.GTZANDataset.html#mindspore.dataset.GTZANDataset)

In [5]:
import mindspore.dataset as ds
# Directory of the extracted audio dataset
gtzan_dataset_directory = "./genres"

# Read 500 samples from the GTZAN dataset directory
# dataset = ds.GTZANDataset(gtzan_dataset_directory, usage="all", num_samples=500)

# Read all samples from the GTZAN dataset directory
dataset = ds.GTZANDataset(gtzan_dataset_directory)


In [6]:
from IPython.display import Audio

# Play the first blues clip inline
Audio(gtzan_dataset_directory + "/blues/blues.00000.wav")

In [7]:
# Take one sample from the dataset
for index, data in enumerate(dataset):
    sample = data
    # Each row has three columns: [waveform, sample_rate, label].
    # waveform has dtype float32.
    # sample_rate has dtype uint32.
    # label has dtype string.
    print(sample)
    print(index)
    break

[Tensor(shape=[1, 661504], dtype=Float32, value=
[[ 2.83822138e-02, -5.44755384e-02,  1.09927669e-01 ... -7.06808716e-02, -1.35837883e-01, -1.64464250e-01]]), Tensor(shape=[], dtype=UInt32, value= 22050), Tensor(shape=[], dtype=String, value= 'pop')]
0


In [8]:
# waveform column
sample[0]

Tensor(shape=[1, 661504], dtype=Float32, value=
[[ 2.83822138e-02, -5.44755384e-02,  1.09927669e-01 ... -7.06808716e-02, -1.35837883e-01, -1.64464250e-01]])

In [9]:
# sample_rate column
sample[1]

Tensor(shape=[], dtype=UInt32, value= 22050)

In [10]:
# label column
print(sample[2])

pop


4. Process the audio with ASTFeatureExtractor

In [11]:
from mindnlp.transformers import ASTFeatureExtractor

feature_extractor = ASTFeatureExtractor()



In [12]:
inputs = feature_extractor(sample[0], return_tensors="ms")
input_values = inputs.input_values
print(input_values.shape)

It is strongly recommended to pass the `sampling_rate` argument to this function. Failing to do so can result in silent errors that might be hard to debug.


(1, 1024, 128)
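
The shape `(1, 1024, 128)` reads as 1024 time frames by 128 mel bins. The arithmetic below sketches why a roughly 30-second clip ends up with exactly 1024 frames. It assumes the extractor defaults match the Hugging Face `ASTFeatureExtractor` (16 kHz sampling rate, 25 ms window, 10 ms hop, `max_length=1024`), which is not confirmed here for MindNLP. `num_frames` is a hypothetical helper.

```python
def num_frames(num_samples, sample_rate, window_ms=25, hop_ms=10):
    """Frames produced by a sliding window that only covers complete windows."""
    window = sample_rate * window_ms // 1000
    hop = sample_rate * hop_ms // 1000
    if num_samples < window:
        return 0
    return 1 + (num_samples - window) // hop


NUM_SAMPLES = 661504   # waveform length printed above
TRUE_RATE = 22050      # sample_rate column of the GTZAN sample
ASSUMED_RATE = 16000   # assumption: rate the extractor uses when none is passed
MAX_LENGTH = 1024      # assumption: extractor pads or truncates to this many frames

duration_s = NUM_SAMPLES / TRUE_RATE
frames = num_frames(NUM_SAMPLES, ASSUMED_RATE)
kept = min(frames, MAX_LENGTH)

# Samples covered by the kept frames, converted back to seconds of real audio.
hop = ASSUMED_RATE * 10 // 1000
window = ASSUMED_RATE * 25 // 1000
covered_s = ((kept - 1) * hop + window) / TRUE_RATE

print(f"clip duration: {duration_s:.2f} s")
print(f"frames before truncation: {frames}, kept: {kept}")
print(f"audio covered by kept frames: {covered_s:.2f} s")
```

Under these assumptions only the first part of the clip reaches the model. The 22,050 Hz waveform is also read as if it were 16 kHz, so resampling to the extractor's rate and passing `sampling_rate` explicitly would likely be the way to address the warning shown above.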


5. Load the AST model

In [13]:
from mindnlp.transformers import AutoModelForAudioClassification

model = AutoModelForAudioClassification.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")

26.1kB [00:00, 196kB/s] 
100%|██████████| 330M/330M [00:24<00:00, 14.3MB/s] 


In [14]:
outputs = model(input_values)
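
Before the Transformer layers, AST splits the spectrogram into overlapping 16×16 patches, as described in the paper. The sketch below counts them for the 1024×128 input produced by the feature extractor, assuming a stride of 10 in both time and frequency (which the `10-10` in the checkpoint name is taken to indicate). `num_patches` is a hypothetical helper.

```python
def num_patches(num_frames, num_mel_bins, patch_size=16, time_stride=10, freq_stride=10):
    """Count patches produced by sliding a square window over the spectrogram."""
    time_steps = (num_frames - patch_size) // time_stride + 1
    freq_steps = (num_mel_bins - patch_size) // freq_stride + 1
    return time_steps * freq_steps


patches = num_patches(1024, 128)
print(patches)  # 101 time steps * 12 frequency steps = 1212
```

Each patch becomes one token, so the sequence the Transformer processes grows with the number of time frames.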

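6. Predict the class of the clip. The checkpoint is fine-tuned on AudioSet, so the output is an AudioSet label rather than one of the GTZAN genres; for this clip it is Music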

In [15]:
predicted_class_idx = outputs.logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])

Predicted class: Music
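
`argmax` reports only the single highest-scoring class. AudioSet clips can carry several labels at once, so listing the top few classes is often more informative. The sketch below uses plain Python lists; the logits and label names are made-up placeholders, not model output, and `top_k_labels` is a hypothetical helper. Sigmoid scores are used on the assumption that the checkpoint was trained as a multi-label classifier.

```python
import math


def top_k_labels(logits, id2label, k=3):
    """Return the k highest-scoring (label, sigmoid score) pairs, best first."""
    scored = [(id2label[i], 1.0 / (1.0 + math.exp(-x))) for i, x in enumerate(logits)]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Placeholder values for illustration only.
example_logits = [2.0, -1.0, 0.5, -3.0]
example_id2label = {0: "Music", 1: "Speech", 2: "Guitar", 3: "Silence"}

for label, score in top_k_labels(example_logits, example_id2label, k=2):
    print(f"{label}: {score:.3f}")
```

With the real model, the inputs would come from `outputs.logits` (converted to a Python list) and `model.config.id2label`.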
