In [None]:
# pip install --upgrade jupyter ipywidgets

import os
import json
import numpy as np
import librosa
from tqdm.notebook import tqdm
from mindnlp.transformers import AutoModelForAudioClassification, AutoFeatureExtractor
from mindspore import nn
import mindspore as ms

In [None]:
# 由于模型配置标签与esc数据集标签不一致，需要创建一个映射
def create_ast_to_esc_mapping():
    """创建AST标签到ESC-50标签的映射"""
    ast_to_esc = {
        # 动物声音
        74: '00',  # Dog -> dog
        75: '00',  # Bark -> dog
        76: '00',  # Yip -> dog
        77: '00',  # Howl -> dog
        78: '00',  # Bow-wow -> dog
        79: '00',  # Growling -> dog
        80: '00',  # Whimper (dog) -> dog
        
        99: '01',  # Chicken, rooster -> rooster
        100: '06', # Cluck -> hen
        101: '01', # Crowing -> rooster
        
        93: '02',  # Pig -> pig
        94: '02',  # Oink -> pig
        
        90: '03',  # Cattle, bovinae -> cow
        91: '03',  # Moo -> cow
        
        132: '04', # Frog -> frog
        133: '04', # Croak -> frog
        
        81: '05',  # Cat -> cat
        82: '05',  # Purr -> cat
        83: '05',  # Meow -> cat
        84: '05',  # Hiss -> cat
        85: '05',  # Caterwaul -> cat
        
        126: '07', # Insect -> insects
        128: '07', # Mosquito -> insects
        129: '07', # Fly, housefly -> insects
        130: '07', # Buzz -> insects
        131: '07', # Bee, wasp, etc. -> insects
        
        97: '08',  # Sheep -> sheep
        96: '08',  # Bleat -> sheep
        
        117: '09', # Crow -> crow
        118: '09', # Caw -> crow
        
        # 环境声音
        289: '10', # Rain -> rain
        290: '10', # Raindrop -> rain
        291: '10', # Rain on surface -> rain
        
        294: '11', # Ocean -> sea_waves
        295: '11', # Waves, surf -> sea_waves
        
        298: '12', # Fire -> crackling_fire
        299: '12', # Crackle -> crackling_fire
        
        127: '13', # Cricket -> crickets
        
        111: '14', # Bird -> chirping_birds
        112: '14', # Bird vocalization, bird call, bird song -> chirping_birds
        113: '14', # Chirp, tweet -> chirping_birds

        448: '15', # Drip -> water_drops
        
        283: '16', # Wind -> wind
        284: '16', # Rustling leaves -> wind
        285: '16', # Wind noise -> wind
        
        449: '17', # Pour -> pouring_water
        450: '17', # Trickle, dribble -> pouring_water
        451: '17', # Gush -> pouring_water
        
        374: '18', # Toilet flush -> toilet_flush
        
        286: '19', # Thunderstorm -> thunderstorm
        287: '19', # Thunder -> thunderstorm
        
        # 人类声音
        23: '20',  # Baby cry, infant cry -> crying_baby
        
        49: '21',  # Sneeze -> sneezing
        
        63: '22',  # Clapping -> clapping
        
        41: '23',  # Breathing -> breathing
        42: '23',  # Wheeze -> breathing
        
        47: '24',  # Cough -> coughing
        
        53: '25',  # Walk, footsteps -> footsteps
        
        16: '26',  # Laughter -> laughing
        17: '26',  # Baby laughter -> laughing
        18: '26',  # Giggle -> laughing
        19: '26',  # Snicker -> laughing
        20: '26',  # Belly laugh -> laughing
        21: '26',  # Chuckle, chortle -> laughing
        
        375: '27', # Toothbrush -> brushing_teeth
        376: '27', # Electric toothbrush -> brushing_teeth
        
        43: '28',  # Snoring -> snoring
        
        54: '29',  # Chewing, mastication -> drinking_sipping
        
        # 室内声音
        358: '30', # Knock -> door_wood_knock
        359: '30', # Knock -> door_wood_knock
        
        491: '31', # Clicking -> mouse_click
        
        386: '32', # Computer keyboard -> keyboard_typing
        
        361: '33', # Squeak -> door_wood_creaks
        
        364: '34', # Dishes, pots, and pans -> can_opening

        377: '36', # Vacuum cleaner -> vacuum_cleaner
        
        395: '37', # Alarm clock -> clock_alarm
        
        407: '38', # Tick -> clock_tick
        408: '38', # Tick-tock -> clock_tick
        
        443: '39', # Shatter -> glass_breaking
        
        # 交通工具声音
        339: '40', # Helicopter -> helicopter
        
        347: '41', # Chainsaw -> chainsaw
        
        396: '42', # Siren -> siren
        397: '42', # Civil defense siren -> siren
        
        308: '43', # Vehicle horn, car horn, honking -> car_horn
        309: '43', # Toot -> car_horn
        
        343: '44', # Engine -> engine
        344: '44', # Light engine -> engine
        348: '44', # Medium engine -> engine
        349: '44', # Heavy engine -> engine
        
        329: '45', # Train -> train
        330: '45', # Train whistle -> train
        331: '45', # Train horn -> train
        
        201: '46', # Church bell -> church_bells
        
        340: '47', # Fixed-wing aircraft, airplane -> airplane
        
        432: '48', # Fireworks -> fireworks
        433: '48', # Firecracker -> fireworks
        
        421: '49', # Sawing -> hand_saw
    }
    return ast_to_esc

In [None]:
def load_data(json_path):
    """加载ESC-50数据集的json文件"""
    with open(json_path, 'r') as fp:
        data = json.load(fp)
    return data['data']

def preprocess_function(audio_path, feature_extractor):
    """预处理音频文件"""
    # 使用librosa读取音频文件
    audio, sr = librosa.load(audio_path, sr=16000)
    # audio, sr = librosa.load(audio_path)
    # 使用特征提取器处理音频
    inputs = feature_extractor(
        audio,
        sampling_rate=16000,
        return_tensors="ms",
        padding=True
    )
    
    # 获取输入值
    input_values = inputs.input_values
    
    # 打印形状以便调试
    # print(f"Input shape before processing: {input_values.shape}")
    
    # 不要squeeze，保持原始形状
    return input_values

def evaluate(model, data, feature_extractor):
    """评估函数"""
    model.set_train(False)
    correct = 0
    total = 0
    
    # 获取AST到ESC的映射
    ast_to_esc = create_ast_to_esc_mapping()
    
    for item in tqdm(data):
        # 获取音频路径和标签
        audio_path = item['wav']
        # audio_path=audio_path.replace('_16k','')
        # 从标签中提取数字部分（去掉'/m/07rwj'前缀）
        esc_label = item['labels'].replace('/m/07rwj', '')
        
        try:
            # 预处理音频
            audio = preprocess_function(audio_path, feature_extractor)
            
            # 模型推理
            outputs = model(audio)
            ast_pred = outputs.logits.argmax(axis=-1).asnumpy()[0]
            
            # 将AST预测映射到ESC标签
            esc_pred = ast_to_esc.get(ast_pred, None)
            # print('################################')
            # print(ast_pred)
            # print(esc_pred)
            # print(esc_label)
            if esc_pred is None:
                continue
            
            total += 1
            if esc_pred == esc_label:
                correct += 1
            
        except Exception as e:
            print(f"处理文件 {audio_path} 时出错: {str(e)}")
            print(f"错误类型: {type(e)}")
            import traceback
            print(traceback.format_exc())
            continue

    # 计算acc
    accuracy = correct / total if total > 0 else 0
    return accuracy

In [None]:
def main():
    # 设置模型下载镜像地址
    os.environ["HF_HOME"] = "https://hf-mirror.com/"
    # 设置设备
    ms.set_context(device_target='Ascend')
    
    # 加载AST模型和特征提取器
    # 使用论文中的ast-p模型
    model_name = "MIT/ast-finetuned-audioset-10-10-0.4593"
    model = AutoModelForAudioClassification.from_pretrained(model_name)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    
    # 打印模型配置
    # print(f"模型配置：{model.config}")
    
    # 测试单个样本以验证预处理
    test_fold = 1
    test_data = load_data(f'./data/datafiles/esc_train_data_{test_fold}.json')
    if test_data:
        test_item = test_data[0]
        print(f"\n测试单个样本：{test_item['wav']}")
        try:
            audio = preprocess_function(test_item['wav'], feature_extractor)
            print(f"测试样本形状：{audio.shape}")
            outputs = model(audio)
            print("测试样本处理成功")
        except Exception as e:
            print(f"测试样本处理失败：{str(e)}")
    
    total_ac=0
    # 对每个fold进行评估
    for fold in range(1, 6):
        print(f"\n处理 Fold {fold}...")
        
        # 加载训练集和测试集
        # train_data = load_data(f'./data/datafiles/esc_train_data_{fold}.json')
        eval_data = load_data(f'./data/datafiles/esc_eval_data_{fold}.json')
        
        # 评估
        try:
            # train_accuracy = evaluate(model, train_data, feature_extractor)
            eval_accuracy = evaluate(model, eval_data, feature_extractor)
            
            print(f"Fold {fold}:")
            # print(f"Training Accuracy: {train_accuracy:.4f}")
            print(f"Test Accuracy: {eval_accuracy:.4f}")
            total_ac+=eval_accuracy
            print("-" * 50)
        except Exception as e:
            print(f"评估 Fold {fold} 时出错: {str(e)}")
            continue
    print(f"average Accuracy is {total_ac/5}")

In [2]:
if __name__ == "__main__":
    main() 




测试单个样本：./data/ESC-50-master/audio_16k/2-100648-A-43.wav
测试样本形状：(1, 1024, 128)
测试样本处理成功

处理 Fold 1...




  0%|          | 0/400 [00:00<?, ?it/s]

Fold 1:
Test Accuracy: 0.9287
--------------------------------------------------

处理 Fold 2...


  0%|          | 0/400 [00:00<?, ?it/s]

Fold 2:
Test Accuracy: 0.9833
--------------------------------------------------

处理 Fold 3...


  0%|          | 0/400 [00:00<?, ?it/s]

Fold 3:
Test Accuracy: 0.9662
--------------------------------------------------

处理 Fold 4...


  0%|          | 0/400 [00:00<?, ?it/s]

Fold 4:
Test Accuracy: 0.9511
--------------------------------------------------

处理 Fold 5...


  0%|          | 0/400 [00:00<?, ?it/s]

Fold 5:
Test Accuracy: 0.9484
--------------------------------------------------
average Accuracy is 0.9555199380842737
