In [1]:
import os
import shutil
from tqdm import tqdm

def organize_files(src_root_dir, dst_root_dir):
    cases = ['case1', 'case2', 'case3', 'case4']

    # Create directories for each case
    for case in cases:
        case_dir = os.path.join(dst_root_dir, case)
        if not os.path.exists(case_dir):
            os.makedirs(case_dir)

    # Traverse the source directory
    for main_dir in tqdm(os.listdir(src_root_dir), desc="Processing directories"):
        main_dir_path = os.path.join(src_root_dir, main_dir)
        if os.path.isdir(main_dir_path):
            for sub_dir in os.listdir(main_dir_path):
                sub_dir_path = os.path.join(main_dir_path, sub_dir)
                if os.path.isdir(sub_dir_path):
                    case_name = sub_dir.split('_')[0]
                    case_dir = os.path.join(dst_root_dir, case_name)

                    # 首先复制原始的case1_1.wav文件
                    if case_name == 'case1':
                        original_wav = os.path.join(sub_dir_path, f"{main_dir}_{sub_dir}.wav")
                        if os.path.exists(original_wav):
                            dst_original_wav = os.path.join(case_dir, f"{main_dir}_{sub_dir}.wav")
                            shutil.copy(original_wav, dst_original_wav)

                    # 然后复制分割后的音频和对应的npy文件
                    for filename in os.listdir(sub_dir_path):
                        if filename.startswith('sample_') and (filename.endswith('.wav') or filename.endswith('.npy')):
                            src_path = os.path.join(sub_dir_path, filename)
                            dst_path = os.path.join(case_dir, f"{main_dir}_{sub_dir}_{filename}")
                            shutil.copy(src_path, dst_path)
src_directory = 'dataset_1k2k3k_withbandpass_extrafeatures_v3'
dst_directory = 'data_1k2k3k_nobandpass_organized_dataset_extrafeatures'
    
print(f"源目录: {src_directory}")
print(f"目标目录: {dst_directory}")

if not os.path.exists(src_directory):
    print(f"源目录不存在: {src_directory}")
    
organize_files(src_directory, dst_directory)
print("文件整理完成。")

源目录: dataset_1k2k3k_withbandpass_extrafeatures_v3
目标目录: data_1k2k3k_nobandpass_organized_dataset_extrafeatures


Processing directories: 100%|██████████| 19/19 [00:04<00:00,  3.88it/s]

文件整理完成。





In [2]:
import os
import shutil

# 定义路径
root_path = 'data_1k2k3k_nobandpass_organized_dataset_extrafeatures'
new_structure_path = 'data_1k2k3k_nobandpass_organized_withoutA3A7_dataset'

# 创建新的目录结构
if not os.path.exists(new_structure_path):
    os.makedirs(new_structure_path)

# 遍历case1到case4
for case in ['case1', 'case2', 'case3', 'case4']:
    case_path = os.path.join(root_path, case)
    for root, dirs, files in os.walk(case_path):
        for file in files:
            if file.endswith('.wav'):
                # 解析文件名获取前缀、case_id和sample_set
                parts = file.split('_')
                prefix = parts[0]  # 获取前缀部分，如A1

                # 跳过prefix为"A3"或"A7"的文件
                if prefix in ['A3', 'A7']:
                    continue

                case_id = f"{parts[1]}_{parts[2]}"
                sample_set = parts[3]
                sample_id = parts[4].split('.')[0]  # 获取sample集编号

                # 创建新的路径
                new_case_path = os.path.join(new_structure_path, case)
                if not os.path.exists(new_case_path):
                    os.makedirs(new_case_path)

                new_prefix_path = os.path.join(new_case_path, prefix)
                if not os.path.exists(new_prefix_path):
                    os.makedirs(new_prefix_path)

                new_case_id_path = os.path.join(new_prefix_path, case_id)
                if not os.path.exists(new_case_id_path):
                    os.makedirs(new_case_id_path)

                new_sample_set_path = os.path.join(new_case_id_path, f'sample_{sample_set}')
                if not os.path.exists(new_sample_set_path):
                    os.makedirs(new_sample_set_path)

                # 移动文件到新的路径
                old_file_path = os.path.join(root, file)
                new_file_path = os.path.join(new_sample_set_path, file)
                shutil.move(old_file_path, new_file_path)

                npy_file = file.replace('.wav', '.npy')
                old_npy_path = os.path.join(root, npy_file)
                if os.path.exists(old_npy_path):  # 检查npy文件是否存在
                    new_npy_path = os.path.join(new_sample_set_path, npy_file)
                    shutil.move(old_npy_path, new_npy_path)

print("文件重新组织完成。")


文件重新组织完成。


In [3]:
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm import tqdm

import torchaudio
from sklearn.model_selection import train_test_split
import os
import sys

In [4]:
import os
from pathlib import Path
import torchaudio
from tqdm import tqdm

data = []

for case in ['case1', 'case2', 'case3', 'case4']:
    case_path = Path(f'data_1k2k3k_nobandpass_organized_withoutA3A7_dataset/{case}')
    for path in tqdm(case_path.glob("**/*.wav")):
        name = path.stem
        # 获取前缀、case_id和sample_set
        parts = path.parts[-4:]  # 获取最后4部分: case文件夹, 前缀文件夹, case_id文件夹, 和文件名
        prefix = parts[1]  # 前缀文件夹
        case_id = parts[2]  # case_id文件夹
        sample_set = parts[3].split('_')[1]  # 从文件名中提取sample_set

        try:
            # 加载文件
            s = torchaudio.load(path)
             # 加载对应的 .npy 文件
            npy_path = path.with_suffix('.npy')
            if npy_path.exists():
                energy_features = np.load(npy_path)
            else:
                energy_features = None
                print(f"Warning: No .npy file found for {path}")
            data.append({
                "name": name,
                "path": str(path),
                "case": case,
                "prefix": prefix,
                "case_id": case_id,
                "sample_set": sample_set,
                "energy_features": energy_features
            })
        except Exception as e:
            # 跳过损坏的文件
            pass

# 显示收集到的数据条目数
print(f"Collected {len(data)} items.")


684it [00:09, 69.99it/s]
684it [00:09, 69.47it/s]
1368it [00:16, 83.85it/s] 
1368it [00:15, 86.43it/s] 

Collected 4104 items.





In [5]:
import pandas as pd
df = pd.DataFrame(data)
df.head()

Unnamed: 0,name,path,case,prefix,case_id,sample_set,energy_features
0,A1_case1_1_sample_10_2,data_1k2k3k_nobandpass_organized_withoutA3A7_d...,case1,case1_1,sample_sample,case1,0.030397265777557
1,A1_case1_1_sample_11_2,data_1k2k3k_nobandpass_organized_withoutA3A7_d...,case1,case1_1,sample_sample,case1,0.0303082682398088
2,A1_case1_1_sample_12_2,data_1k2k3k_nobandpass_organized_withoutA3A7_d...,case1,case1_1,sample_sample,case1,0.0263083679363482
3,A1_case1_1_sample_13_3,data_1k2k3k_nobandpass_organized_withoutA3A7_d...,case1,case1_1,sample_sample,case1,0.0126777785655913
4,A1_case1_1_sample_14_3,data_1k2k3k_nobandpass_organized_withoutA3A7_d...,case1,case1_1,sample_sample,case1,0.0162335163620257


In [6]:
import os
import pandas as pd
import torch
import torchaudio
from scipy.io import wavfile
from scipy.signal import resample
from torchaudio.utils import download_asset
from IPython.display import Audio

# Define function to resample audio
def resample_audio(data, orig_sr, target_sr=16000):
    number_of_samples = round(len(data) * float(target_sr) / orig_sr)
    resampled_data = resample(data, number_of_samples)
    return resampled_data

# Load and process your dataframe `df`
file_path_column = "path"
df["status"] = df[file_path_column].apply(lambda path: True if os.path.exists(path) else None)
df = df.dropna(subset=["status"])
df = df.drop("status", axis=1)

# Shuffle and reset index of dataframe
df = df.sample(frac=1).reset_index(drop=True)
print(df.head())  # Check the first few rows after preprocessing

# Process each audio file in the dataframe
for index, row in df.iterrows():
    file_path = row[file_path_column]
    orig_sample_rate, data = wavfile.read(file_path)

    # Resample audio to target sample rate (16000 Hz)
    resampled_data = resample_audio(data, orig_sample_rate, target_sr=16000)

    # Save the resampled data back to a file
    resampled_path = file_path.replace(".wav", "_resampled.wav")
    torchaudio.save(resampled_path, torch.from_numpy(resampled_data).unsqueeze(0), sample_rate=16000)

    # Update the path in the dataframe
    df.at[index, file_path_column] = resampled_path

    # Print progress or any other processing steps
    print(f"Processed {file_path} and saved resampled audio to {resampled_path}")

# Optionally, save the updated dataframe with processed data paths
df.to_csv("processed_audio_data.csv", index=False)


                      name                                               path  \
0   A9_case4_9_sample_52_3  data_1k2k3k_nobandpass_organized_withoutA3A7_d...   
1    A8_case4_1_sample_4_1  data_1k2k3k_nobandpass_organized_withoutA3A7_d...   
2    E8_case3_1_sample_4_1  data_1k2k3k_nobandpass_organized_withoutA3A7_d...   
3  A10_case4_3_sample_18_3  data_1k2k3k_nobandpass_organized_withoutA3A7_d...   
4   A9_case4_6_sample_31_3  data_1k2k3k_nobandpass_organized_withoutA3A7_d...   

    case   prefix        case_id sample_set      energy_features  
0  case4  case4_9  sample_sample      case4    0.575738765581533  
1  case4  case4_1  sample_sample      case4  0.14253931506919848  
2  case3  case3_1  sample_sample      case3   1.1833151451022026  
3  case4  case4_3  sample_sample      case4   0.1667274870129316  
4  case4  case4_6  sample_sample      case4    0.708377408278764  
Processed data_1k2k3k_nobandpass_organized_withoutA3A7_dataset\case4\A9\case4_9\sample_sample\A9_case4_9_sample

Let's explore how many labels (emotions) are in the dataset with what distribution.

In [7]:
# Filter broken and non-existent paths

print(f"Step 0: {len(df)}")

df["status"] = df["path"].apply(lambda path: True if os.path.exists(path) else None)
df = df.dropna(subset=["status"])
df = df.drop("status", axis=1)
print(f"Step 1: {len(df)}")

df = df.sample(frac=1)
df = df.reset_index(drop=True)

# Print unique emotions and count
print("Labels: ", df["case"].unique())
print()
print(df.groupby("case").count()[["path"]])


Step 0: 4104
Step 1: 4104
Labels:  ['case4' 'case3' 'case1' 'case2']

       path
case       
case1   684
case2   684
case3  1368
case4  1368


For training purposes, we need to split data into train test sets; in this specific example, we break with a `20%` rate for the test set.

In [9]:
import os
import pandas as pd
import torchaudio
import librosa
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from scipy.ndimage import maximum_filter1d, uniform_filter1d
from tqdm import tqdm
from pathlib import Path

save_path = "data_1k2k3k_nobandpass_organized_withoutA3A7_dataset"

# 添加一列来标识每个参与者
df['participant'] = df['name'].apply(lambda x: x.split('_')[0])

# 确保路径存在
df["status"] = df["path"].apply(lambda path: True if os.path.exists(path) else None)
df = df.dropna(subset=["status"])
df = df.drop("status", axis=1)

# 获取所有参与者的唯一列表
participants = df['participant'].unique()

# 随机划分参与者
train_participants, eval_participants = train_test_split(participants, test_size=0.2, random_state=101)

# 根据参与者划分数据集
train_df = df[df['participant'].isin(train_participants)].reset_index(drop=True)
eval_df = df[df['participant'].isin(eval_participants)].reset_index(drop=True)

# 打印参与者信息以确认划分
print("Unique participants in training dataset:", train_df['participant'].unique())
print("Unique participants in evaluation dataset:", eval_df['participant'].unique())

# 保存为 CSV 文件
save_path = "data_1k2k3k_nobandpass_organized_withoutA3A7_dataset"
train_df.to_csv(f"{save_path}/train.csv", sep="\t", encoding="utf-8", index=False)
eval_df.to_csv(f"{save_path}/test.csv", sep="\t", encoding="utf-8", index=False)
# 打印数据集的形状
print(train_df.shape)
print(eval_df.shape)

# 加载数据集
from datasets import load_dataset

data_files = {
    "train": f"{save_path}/train.csv",
    "validation": f"{save_path}/test.csv",
}

dataset = load_dataset("csv", data_files={"train": f"{save_path}/train.csv", "validation": f"{save_path}/test.csv"}, delimiter="\t")
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]


# 检查加载后的数据集
print("Loaded training dataset size:", len(train_dataset))
print("Loaded validation dataset size:", len(eval_dataset))

# 指定输入和输出列
input_column = "path"
output_column = "case"

# 加载能量特征
def load_energy_features(example):
    try:
        # 从 wav 文件路径获取对应的 npy 文件路径
        wav_path = example['path']
        npy_path = wav_path.replace('_resampled.wav', '.npy')
        
        # 加载 npy 文件
        features = np.load(npy_path)
        example['energy_features'] = features
        return example
    except Exception as e:
        print(f"Error loading energy features from {example}: {e}")
        example['energy_features'] = None
        return example

train_dataset = train_dataset.map(load_energy_features)
eval_dataset = eval_dataset.map(load_energy_features)

# 检查能量特征加载情况
print("Train dataset with energy features:")
print(train_dataset[:5])
print("Validation dataset with energy features:")
print(eval_dataset[:5])

# 打印数据集的前几行，检查数据完整性
print("Train dataset preview:")
print(train_dataset[:10])
print("Validation dataset preview:")
print(eval_dataset[:10])

# 打印每个 case 的样本数
print("Sample count per case in training dataset:")
print(train_dataset.to_pandas()[output_column].value_counts())
print("Sample count per case in validation dataset:")
print(eval_dataset.to_pandas()[output_column].value_counts())

# 识别和排序标签列表
label_list = train_dataset.unique(output_column)
label_list.sort()  # Let's sort it for determinism
num_labels = len(label_list)
print(f"A classification problem with {num_labels} classes: {label_list}")


Unique participants in training dataset: ['A1' 'E10' 'A2' 'A9' 'E1' 'A4' 'A8' 'A10' 'E7' 'E5' 'E4' 'E3' 'E6' 'A6'
 'E11']
Unique participants in evaluation dataset: ['E9' 'A5' 'E8' 'E2']
(3240, 8)
(864, 8)


  from .autonotebook import tqdm as notebook_tqdm
Generating train split: 3240 examples [00:00, 215857.82 examples/s]
Generating validation split: 864 examples [00:00, 123834.02 examples/s]


Loaded training dataset size: 3240
Loaded validation dataset size: 864


Map: 100%|██████████| 3240/3240 [00:00<00:00, 3861.30 examples/s]
Map: 100%|██████████| 864/864 [00:00<00:00, 3987.58 examples/s]

Train dataset with energy features:
{'name': ['A1_case4_9_sample_52_3', 'A1_case4_4_sample_22_1', 'E10_case1_1_sample_9_2', 'A2_case1_1_sample_6_1', 'A9_case4_8_sample_43_2'], 'path': ['data_1k2k3k_nobandpass_organized_withoutA3A7_dataset\\case4\\A1\\case4_9\\sample_sample\\A1_case4_9_sample_52_3_resampled.wav', 'data_1k2k3k_nobandpass_organized_withoutA3A7_dataset\\case4\\A1\\case4_4\\sample_sample\\A1_case4_4_sample_22_1_resampled.wav', 'data_1k2k3k_nobandpass_organized_withoutA3A7_dataset\\case1\\E10\\case1_1\\sample_sample\\E10_case1_1_sample_9_2_resampled.wav', 'data_1k2k3k_nobandpass_organized_withoutA3A7_dataset\\case1\\A2\\case1_1\\sample_sample\\A2_case1_1_sample_6_1_resampled.wav', 'data_1k2k3k_nobandpass_organized_withoutA3A7_dataset\\case4\\A9\\case4_8\\sample_sample\\A9_case4_8_sample_43_2_resampled.wav'], 'case': ['case4', 'case4', 'case1', 'case1', 'case4'], 'prefix': ['case4_9', 'case4_4', 'case1_1', 'case1_1', 'case4_8'], 'case_id': ['sample_sample', 'sample_sample', '




In [10]:
# 统计包含能量特征的样本数
train_with_features = sum(1 for item in train_dataset if item['energy_features'] is not None)
eval_with_features = sum(1 for item in eval_dataset if item['energy_features'] is not None)
print(f"Train samples with energy features: {train_with_features} out of {len(train_dataset)}")
print(f"Validation samples with energy features: {eval_with_features} out of {len(eval_dataset)}")

Train samples with energy features: 3240 out of 3240
Validation samples with energy features: 864 out of 864


## Prepare Data for Training

In [11]:
# Loading the created dataset using datasets
from datasets import load_dataset

# 定义数据文件路径
data_files = {
    "train": "data_1k2k3k_nobandpass_organized_withoutA3A7_dataset/train.csv",
    "validation": "data_1k2k3k_nobandpass_organized_withoutA3A7_dataset/test.csv",
}

# 加载数据集
dataset = load_dataset("csv", data_files=data_files, delimiter="\t")
train_dataset = dataset["train"]
eval_dataset = dataset["validation"]

# 打印数据集信息
print(train_dataset)
print(eval_dataset)


Dataset({
    features: ['name', 'path', 'case', 'prefix', 'case_id', 'sample_set', 'energy_features', 'participant'],
    num_rows: 3240
})
Dataset({
    features: ['name', 'path', 'case', 'prefix', 'case_id', 'sample_set', 'energy_features', 'participant'],
    num_rows: 864
})


In [12]:
input_column = "path"
output_column = "case"

In [13]:
# we need to distinguish the unique labels in our SER dataset
label_list = train_dataset.unique(output_column)
label_list.sort()  # Let's sort it for determinism
num_labels = len(label_list)
print(f"A classification problem with {num_labels} classes: {label_list}")

A classification problem with 4 classes: ['case1', 'case2', 'case3', 'case4']


In order to preprocess the audio into our classification model, we need to set up the relevant Wav2Vec2 assets regarding our language in this case `lighteternal/wav2vec2-large-xlsr-53-greek` fine-tuned by [Dimitris Papadopoulos](https://huggingface.co/lighteternal/wav2vec2-large-xlsr-53-greek). To handle the context representations in any audio length we use a merge strategy plan (pooling mode) to concatenate that 3D representations into 2D representations.

There are three merge strategies `mean`, `sum`, and `max`. In this example, we achieved better results on the mean approach. In the following, we need to initiate the config and the feature extractor from the Dimitris model.

In [14]:
from transformers import AutoConfig, Wav2Vec2Processor

In [15]:
model_name_or_path = "c3f9d884181a224a6ac87bf8885c84d1cff3384f"
pooling_mode = "mean"

In [16]:
# config
config = AutoConfig.from_pretrained(
    model_name_or_path,
    num_labels=num_labels,
    label2id={label: i for i, label in enumerate(label_list)},
    id2label={i: label for i, label in enumerate(label_list)},
    finetuning_task="wav2vec2_clf",
)
setattr(config, 'pooling_mode', pooling_mode)

setattr(config, 'use_energy_features', True)
setattr(config, 'energy_feature_dim', 1)

In [17]:
# processor = Wav2Vec2Processor.from_pretrained(model_name_or_path,)
from transformers import Wav2Vec2FeatureExtractor
processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
# target_sampling_rate = processor.feature_extractor.sampling_rate
target_sampling_rate = 16000
print(f"The target sampling rate: {target_sampling_rate}")

The target sampling rate: 16000


# Preprocess Data

So far, we downloaded, loaded, and split the SER dataset into train and test sets. The instantiated our strategy configuration for using context representations in our classification problem SER. Now, we need to extract features from the audio path in context representation tensors and feed them into our classification model to determine the emotion in the speech.

Since the audio file is saved in the `.wav` format, it is easy to use **[Librosa](https://librosa.org/doc/latest/index.html)** or others, but we suppose that the format may be in the `.mp3` format in case of generality. We found that the **[Torchaudio](https://pytorch.org/audio/stable/index.html)** library works best for reading in `.mp3` data.

An audio file usually stores both its values and the sampling rate with which the speech signal was digitalized. We want to store both in the dataset and write a **map(...)** function accordingly. Also, we need to handle the string labels into integers for our specific classification task in this case, the **single-label classification** you may want to use for your **regression** or even **multi-label classification**.

In [18]:
def speech_file_to_array_fn(path):
    speech_array, sampling_rate = torchaudio.load(path)
    resampler = torchaudio.transforms.Resample(sampling_rate, target_sampling_rate)
    speech = resampler(speech_array).squeeze().numpy()
    return speech

def label_to_id(label, label_list):

    if len(label_list) > 0:
        return label_list.index(label) if label in label_list else -1

    return label

# def preprocess_function(examples):
#     speech_list = [speech_file_to_array_fn(path) for path in examples[input_column]]
#     target_list = [label_to_id(label, label_list) for label in examples[output_column]]

#     result = processor(speech_list, sampling_rate=target_sampling_rate)
#     result["labels"] = list(target_list)

#     return result

In [19]:
df.head(2)

Unnamed: 0,name,path,case,prefix,case_id,sample_set,energy_features,participant
0,A1_case4_9_sample_52_3,data_1k2k3k_nobandpass_organized_withoutA3A7_d...,case4,case4_9,sample_sample,case4,0.0152214764462648,A1
1,E9_case4_5_sample_27_2,data_1k2k3k_nobandpass_organized_withoutA3A7_d...,case4,case4_5,sample_sample,case4,0.0977085709969269,E9


In [20]:
from transformers import Wav2Vec2FeatureExtractor

processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
# 定义输入和输出列
input_column = "path"
output_column = "case"

# 定义预处理函数
def preprocess_function(examples):
    # 处理音频数据
    audio = [torchaudio.load(path)[0].numpy().squeeze() for path in examples["path"]]
    result = processor(audio, sampling_rate=16000, return_tensors="pt", padding=True)
    result["labels"] = [label_list.index(label) for label in examples["emotion"]]
    
    # 加载能量特征
    energy_features = []
    for wav_path in examples["path"]:
        try:
            # 从wav文件路径构造npy文件路径
            npy_path = wav_path.replace('_resampled.wav', '.npy')
            if os.path.exists(npy_path):
                features = np.load(npy_path)
                energy_features.append(features)
            else:
                print(f"Warning: No .npy file found at {npy_path}")
                energy_features.append(None)
        except Exception as e:
            print(f"Error loading energy features from {wav_path}: {e}")
            energy_features.append(None)
            
    result["energy_features"] = energy_features
    
    return result

# # 加载数据集
# data_files = {
#     "train": "data_1k2k3k_nobandpass_organized_withoutA3A7_dataset/train.csv",
#     "validation": "data_1k2k3k_nobandpass_organized_withoutA3A7_dataset/test.csv",
# }
# dataset = load_dataset("csv", data_files=data_files, delimiter="\t")
# train_dataset = dataset["train"]
# eval_dataset = dataset["validation"]

# 重命名 case 字段为 emotion
train_dataset = train_dataset.rename_column("case", "emotion")
eval_dataset = eval_dataset.rename_column("case", "emotion")

# 获取独特的标签
label_list = train_dataset.unique("emotion")
label_list.sort()  # 排序以确保确定性
num_labels = len(label_list)
print(f"A classification problem with {num_labels} classes: {label_list}")

# 预处理和映射数据集
train_dataset = train_dataset.map(
    preprocess_function,
    batch_size=100,
    batched=True,
    num_proc=1,
    remove_columns=["case_id", "sample_set", "prefix"]
)
eval_dataset = eval_dataset.map(
    preprocess_function,
    batch_size=100,
    batched=True,
    num_proc=1,
    remove_columns=["case_id", "sample_set", "prefix"]
)

print("Train dataset:")
print(train_dataset)
print("\nEval dataset:")
print(eval_dataset)

A classification problem with 4 classes: ['case1', 'case2', 'case3', 'case4']


Map: 100%|██████████| 3240/3240 [00:23<00:00, 138.81 examples/s]
Map: 100%|██████████| 864/864 [00:05<00:00, 144.33 examples/s]

Train dataset:
Dataset({
    features: ['name', 'path', 'emotion', 'energy_features', 'participant', 'input_values', 'attention_mask', 'labels'],
    num_rows: 3240
})

Eval dataset:
Dataset({
    features: ['name', 'path', 'emotion', 'energy_features', 'participant', 'input_values', 'attention_mask', 'labels'],
    num_rows: 864
})





In [21]:
# 打印训练集中的唯一参与者
train_participants = train_dataset.unique("participant")
print("Unique participants in training dataset:", train_participants)

# 打印验证集中的唯一参与者
eval_participants = eval_dataset.unique("participant")
print("Unique participants in evaluation dataset:", eval_participants)

Unique participants in training dataset: ['A1', 'E10', 'A2', 'A9', 'E1', 'A4', 'A8', 'A10', 'E7', 'E5', 'E4', 'E3', 'E6', 'A6', 'E11']
Unique participants in evaluation dataset: ['E9', 'A5', 'E8', 'E2']


In [22]:
idx = 0
print(f"Training input_values: {train_dataset[idx]['input_values']}")
print(f"Training attention_mask: {train_dataset[idx]['attention_mask']}")
print(f"Training labels: {train_dataset[idx]['labels']} - {train_dataset[idx]['emotion']}")

Training input_values: [-0.41851750016212463, -0.551163911819458, -0.46926823258399963, 0.7831283807754517, 1.2626723051071167, 0.3006206452846527, -0.4494817852973938, 0.15936148166656494, 0.5248208045959473, -0.01692991331219673, -0.6339441537857056, -0.38629257678985596, 0.5214669108390808, 0.7892333269119263, 0.2575990557670593, 0.15485253930091858, 0.35282137989997864, 0.29341840744018555, 0.06536584347486496, -0.4802039861679077, -0.7313829660415649, -0.09916476905345917, 0.9635612964630127, 0.8484675288200378, -0.04378727823495865, -0.519194483757019, -0.38555410504341125, 0.05971831828355789, 0.35143905878067017, 0.31855508685112, -0.16467133164405823, -0.049604497849941254, 0.49306535720825195, 0.33177292346954346, -0.25170543789863586, -0.11917348206043243, 0.5526325106620789, 0.4570348858833313, -0.11410888284444809, -0.227482870221138, 0.1943720132112503, 0.2140822857618332, -0.14007568359375, 0.03729008138179779, 0.47915467619895935, 0.23227721452713013, -0.203168854117393

Great, now we've successfully read all the audio files, resampled the audio files to 16kHz, and mapped each audio to the corresponding label.

## Model

Before diving into the training part, we need to build our classification model based on the merge strategy.

In [23]:
from dataclasses import dataclass
from typing import Optional, Tuple
import torch
from transformers.file_utils import ModelOutput


@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


In [24]:
import torch
import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.models.wav2vec2.modeling_wav2vec2 import (
    Wav2Vec2PreTrainedModel,
    Wav2Vec2Model
)

class Wav2Vec2ClassificationHead(nn.Module):
    """Head for wav2vec classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size + config.energy_feature_dim, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, energy_features=None, **kwargs):
        x = features
        if energy_features is not None:
            x = torch.cat([x, energy_features], dim=-1)
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.config = config

        self.wav2vec2 = Wav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

        self.init_weights()

    def freeze_feature_extractor(self):
        self.wav2vec2.feature_extractor._freeze_parameters()

    def merged_strategy(self, hidden_states, mode="mean"):
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise Exception(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
        return outputs

    def forward(
            self,
            input_values,
            attention_mask=None,
            energy_features=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            labels=None,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.wav2vec2(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]
        # print(f"Hidden states shape before: {hidden_states.shape}")
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        # print(f"Hidden states shape after: {hidden_states.shape}")

        logits = self.classifier(hidden_states, energy_features)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [25]:
# import torch
# import torch.nn as nn
# from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

# from transformers.models.wav2vec2.modeling_wav2vec2 import (
#     Wav2Vec2PreTrainedModel,
#     Wav2Vec2Model
# )

# class Wav2Vec2ClassificationHeadWithEnergy(nn.Module):
#     def __init__(self, config):
#         super().__init__()
#         # 修改这里的维度
#         self.energy_dense = nn.Linear(config.energy_feature_dim, 128)  # 改为固定输出128维
#         self.energy_activation = nn.ReLU()

#         # 修改输入维度：wav2vec2输出维度(1024) + energy特征维度(128)
#         self.dense = nn.Linear(config.hidden_size + 128, config.hidden_size)  # wav2vec2默认输出是768维

#         self.dropout = nn.Dropout(config.final_dropout)
#         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

#     def forward(self, wav_features, energy_features=None):
#         if energy_features is not None:
#             e = self.energy_dense(energy_features)   # 映射到128维
#             e = self.energy_activation(e)
#             x = torch.cat([wav_features, e], dim=-1) 
#         else:
#             x = wav_features

#         x = self.dropout(x)
#         x = self.dense(x)
#         x = torch.tanh(x)
#         x = self.dropout(x)
#         x = self.out_proj(x)

#         return x

# class Wav2Vec2ForSpeechClassificationWithEnergy(Wav2Vec2PreTrainedModel):
#     def __init__(self, config):
#         super().__init__(config)
#         self.num_labels = config.num_labels
#         self.pooling_mode = config.pooling_mode
#         self.config = config

#         # 主体：Wav2Vec2 base
#         self.wav2vec2 = Wav2Vec2Model(config)
#         # 分头：使用带 energy MLP 的分类头
#         self.classifier = Wav2Vec2ClassificationHeadWithEnergy(config)

#         self.init_weights()

#     def freeze_feature_extractor(self):
#         self.wav2vec2.feature_extractor._freeze_parameters()

#     def merged_strategy(self, hidden_states, mode="mean"):
#         if mode == "mean":
#             outputs = torch.mean(hidden_states, dim=1)
#         elif mode == "sum":
#             outputs = torch.sum(hidden_states, dim=1)
#         elif mode == "max":
#             outputs = torch.max(hidden_states, dim=1)[0]
#         else:
#             raise ValueError("pooling_mode must be one of ['mean', 'sum', 'max']")
#         return outputs

#     def forward(
#         self,
#         input_values,
#         attention_mask=None,
#         energy_features=None,           # <--- 新增参数
#         output_attentions=None,
#         output_hidden_states=None,
#         return_dict=None,
#         labels=None,
#     ):
#         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

#         # 1) 先跑 Wav2Vec2
#         outputs = self.wav2vec2(
#             input_values,
#             attention_mask=attention_mask,
#             output_attentions=output_attentions,
#             output_hidden_states=output_hidden_states,
#             return_dict=return_dict,
#         )

#         # 2) 取最后一层隐藏状态 (B, T, hidden_size)
#         hidden_states = outputs[0]

#         # 3) 在时间维度上做 pooling (mean/sum/max)
#         hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
#         # 此时 hidden_states: [batch_size, hidden_size]

#         # 4) 把 pooled hidden + energy_features (若有) 做融合并分类
#         logits = self.classifier(hidden_states, energy_features)

#         # 5) 算 loss (可选)
#         loss = None
#         if labels is not None:
#             if self.config.problem_type is None:
#                 if self.num_labels == 1:
#                     self.config.problem_type = "regression"
#                 elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
#                     self.config.problem_type = "single_label_classification"
#                 else:
#                     self.config.problem_type = "multi_label_classification"

#             if self.config.problem_type == "regression":
#                 loss_fct = MSELoss()
#                 loss = loss_fct(logits.view(-1, self.num_labels), labels)
#             elif self.config.problem_type == "single_label_classification":
#                 loss_fct = CrossEntropyLoss()
#                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
#             elif self.config.problem_type == "multi_label_classification":
#                 loss_fct = BCEWithLogitsLoss()
#                 loss = loss_fct(logits, labels)

#         if not return_dict:
#             output = (logits,) + outputs[2:]
#             return ((loss,) + output) if loss is not None else output

#         return SpeechClassifierOutput(
#             loss=loss,
#             logits=logits,
#             hidden_states=outputs.hidden_states,
#             attentions=outputs.attentions,
#         )

## Training

The data is processed so that we are ready to start setting up the training pipeline. We will make use of 🤗's [Trainer](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer) for which we essentially need to do the following:

- Define a data collator. In contrast to most NLP models, XLSR-Wav2Vec2 has a much larger input length than output length. *E.g.*, a sample of input length 50000 has an output length of no more than 100. Given the large input sizes, it is much more efficient to pad the training batches dynamically meaning that all training samples should only be padded to the longest sample in their batch and not the overall longest sample. Therefore, fine-tuning XLSR-Wav2Vec2 requires a special padding data collator, which we will define below

- Evaluation metric. During training, the model should be evaluated on the word error rate. We should define a `compute_metrics` function accordingly

- Load a pretrained checkpoint. We need to load a pretrained checkpoint and configure it correctly for training.

- Define the training configuration.

After having fine-tuned the model, we will correctly evaluate it on the test data and verify that it has indeed learned to correctly transcribe speech.

### Set-up Trainer

Let's start by defining the data collator. The code for the data collator was copied from [this example](https://github.com/huggingface/transformers/blob/9a06b6b11bdfc42eea08fa91d0c737d1863c99e3/examples/research_projects/wav2vec2/run_asr.py#L81).

Without going into too many details, in contrast to the common data collators, this data collator treats the `input_values` and `labels` differently and thus applies to separate padding functions on them (again making use of XLSR-Wav2Vec2's context manager). This is necessary because in speech input and output are of different modalities meaning that they should not be treated by the same padding function.
Analogous to the common data collators, the padding tokens in the labels with `-100` so that those tokens are **not** taken into account when computing the loss.

In [26]:
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import torch

import transformers
from transformers import Wav2Vec2Processor


@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
        max_length (:obj:`int`, `optional`):
            Maximum length of the ``input_values`` of the returned list and optionally padding length (see above).
        max_length_labels (:obj:`int`, `optional`):
            Maximum length of the ``labels`` returned list and optionally padding length (see above).
        pad_to_multiple_of (:obj:`int`, `optional`):
            If set will pad the sequence to a multiple of the provided value.
            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
            7.5 (Volta).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [feature["labels"] for feature in features]
        energy_features = [feature.get("energy_features") for feature in features]
        d_type = torch.long if isinstance(label_features[0], int) else torch.float

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch["labels"] = torch.tensor(label_features, dtype=d_type)
        batch["energy_features"] = torch.tensor(np.array(energy_features), dtype=torch.float)

        return batch

In [27]:
from transformers import Wav2Vec2FeatureExtractor

processor = Wav2Vec2FeatureExtractor.from_pretrained('c3f9d884181a224a6ac87bf8885c84d1cff3384f')
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
config.energy_feature_dim = 1

Next, the evaluation metric is defined. There are many pre-defined metrics for classification/regression problems, but in this case, we would continue with just **Accuracy** for classification and **MSE** for regression. You can define other metrics on your own.

In [28]:
is_regression = False

In [29]:
import numpy as np
from transformers import EvalPrediction


def compute_metrics(p: EvalPrediction):
    preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    preds = np.squeeze(preds) if is_regression else np.argmax(preds, axis=1)

    if is_regression:
        return {"mse": ((preds - p.label_ids) ** 2).mean().item()}
    else:
        return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

Now, we can load the pretrained XLSR-Wav2Vec2 checkpoint into our classification model with a pooling strategy.

In [30]:
model = Wav2Vec2ForSpeechClassification.from_pretrained(
    model_name_or_path,
    config=config,
)

Some weights of Wav2Vec2ForSpeechClassification were not initialized from the model checkpoint at c3f9d884181a224a6ac87bf8885c84d1cff3384f and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original0', 'wav2vec2.encoder.pos_conv_embed.conv.parametrizations.weight.original1']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The first component of XLSR-Wav2Vec2 consists of a stack of CNN layers that are used to extract acoustically meaningful - but contextually independent - features from the raw speech signal. This part of the model has already been sufficiently trained during pretraining and as stated in the [paper](https://arxiv.org/pdf/2006.13979.pdf) does not need to be fine-tuned anymore.
Thus, we can set the `requires_grad` to `False` for all parameters of the *feature extraction* part.

In [31]:
model.freeze_feature_extractor()

In a final step, we define all parameters related to training.
To give more explanation on some of the parameters:
- `learning_rate` and `weight_decay` were heuristically tuned until fine-tuning has become stable. Note that those parameters strongly depend on the Common Voice dataset and might be suboptimal for other speech datasets.

For more explanations on other parameters, one can take a look at the [docs](https://huggingface.co/transformers/master/main_classes/trainer.html?highlight=trainer#trainingarguments).

**Note**: If one wants to save the trained models in his/her google drive the commented-out `output_dir` can be used instead.

In [32]:
# from google.colab import drive

# drive.mount('/gdrive')

In [33]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="train_result/123k_extrafeature",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    evaluation_strategy="steps",
    num_train_epochs=1,
    fp16=True,
    save_steps=10,
    eval_steps=10,
    logging_steps=10,
    learning_rate=1e-4,
    save_total_limit=2,
    seed=2003,
)



In [34]:
# !pip uninstall transformers[torch] -y
# !pip install transformers[torch]

In [35]:
# !pip uninstall transformers -y
# !pip install transformers


For future use we can create our training script, we do it in a simple way. You can add more on you own.

In [36]:
from typing import Any, Dict, Union

import torch
from packaging import version
from torch import nn

from transformers import (
    Trainer,
    is_apex_available,
)

if is_apex_available():
    from apex import amp

if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast


class CTCTrainer(Trainer):
    def training_step(self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]]) -> torch.Tensor:
        """
        Perform a training step on a batch of inputs.

        Subclass and override to inject custom behavior.

        Args:
            model (:obj:`nn.Module`):
                The model to train.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.

        Return:
            :obj:`torch.Tensor`: The tensor with training loss on this batch.
        """

        model.train()
        inputs = self._prepare_inputs(inputs)

        if self.use_cuda_amq:
            with autocast():
                loss = self.compute_loss(model, inputs)
        else:
            loss = self.compute_loss(model, inputs)

        if self.args.gradient_accumulation_steps > 1:
            loss = loss / self.args.gradient_accumulation_steps

        if self.use_cuda_amq:
            self.scaler.scale(loss).backward()
        elif self.use_apex:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        elif self.deepspeed:
            self.deepspeed.backward(loss)
        else:
            loss.backward()

        return loss.detach()


Now, all instances can be passed to Trainer and we are ready to start training!

In [37]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # tokenizer=processor.feature_extractor,
)


### Training

Training will take between 10 and 60 minutes depending on the GPU allocated to this notebook.

In case you want to use this google colab to fine-tune your model, you should make sure that your training doesn't stop due to inactivity. A simple hack to prevent this is to paste the following code into the console of this tab (right mouse click -> inspect -> Console tab and insert code).

\\```javascript
function ConnectButton(){
    console.log("Connect pushed");
    document.querySelector("#top-toolbar > colab-connect-button").shadowRoot.querySelector("#connect").click()
}
setInterval(ConnectButton,60000);
```

In [38]:
trainer.train()

  return F.conv1d(input, weight, bias, self.stride,
  attn_output = torch.nn.functional.scaled_dot_product_attention(
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  2%|▏         | 10/405 [00:09<03:06,  2.11it/s]

{'loss': 1.35, 'grad_norm': 3.4573113918304443, 'learning_rate': 9.753086419753087e-05, 'epoch': 0.02}


                                                
  2%|▏         | 10/405 [00:31<03:06,  2.11it/s] 

{'eval_loss': 1.3656243085861206, 'eval_accuracy': 0.3333333432674408, 'eval_runtime': 21.5819, 'eval_samples_per_second': 40.034, 'eval_steps_per_second': 10.008, 'epoch': 0.02}


  5%|▍         | 20/405 [00:38<04:19,  1.48it/s]

{'loss': 1.3598, 'grad_norm': 5.1076979637146, 'learning_rate': 9.530864197530865e-05, 'epoch': 0.05}


                                                
  5%|▍         | 20/405 [00:59<04:19,  1.48it/s] 

{'eval_loss': 1.3451560735702515, 'eval_accuracy': 0.3333333432674408, 'eval_runtime': 21.2482, 'eval_samples_per_second': 40.662, 'eval_steps_per_second': 10.166, 'epoch': 0.05}


  7%|▋         | 30/405 [01:07<04:19,  1.44it/s]

{'loss': 1.2106, 'grad_norm': 3.3813815116882324, 'learning_rate': 9.28395061728395e-05, 'epoch': 0.07}


                                                
  7%|▋         | 30/405 [01:29<04:19,  1.44it/s] 

{'eval_loss': 1.2704964876174927, 'eval_accuracy': 0.34837964177131653, 'eval_runtime': 22.187, 'eval_samples_per_second': 38.942, 'eval_steps_per_second': 9.735, 'epoch': 0.07}


 10%|▉         | 40/405 [01:36<04:14,  1.44it/s]

{'loss': 1.1504, 'grad_norm': nan, 'learning_rate': 9.06172839506173e-05, 'epoch': 0.1}


                                                
 10%|▉         | 40/405 [01:59<04:14,  1.44it/s] 

{'eval_loss': 1.2334085702896118, 'eval_accuracy': 0.4618055522441864, 'eval_runtime': 22.8885, 'eval_samples_per_second': 37.748, 'eval_steps_per_second': 9.437, 'epoch': 0.1}


 12%|█▏        | 50/405 [02:06<04:11,  1.41it/s]

{'loss': 1.0981, 'grad_norm': 8.72940444946289, 'learning_rate': 8.814814814814815e-05, 'epoch': 0.12}


                                                
 12%|█▏        | 50/405 [02:29<04:11,  1.41it/s] 

{'eval_loss': 1.469560980796814, 'eval_accuracy': 0.22337962687015533, 'eval_runtime': 23.0275, 'eval_samples_per_second': 37.52, 'eval_steps_per_second': 9.38, 'epoch': 0.12}


 15%|█▍        | 60/405 [02:36<04:11,  1.37it/s]

{'loss': 1.0392, 'grad_norm': 5.389920711517334, 'learning_rate': 8.592592592592593e-05, 'epoch': 0.15}


                                                
 15%|█▍        | 60/405 [02:59<04:11,  1.37it/s] 

{'eval_loss': 1.1284172534942627, 'eval_accuracy': 0.42592594027519226, 'eval_runtime': 23.3029, 'eval_samples_per_second': 37.077, 'eval_steps_per_second': 9.269, 'epoch': 0.15}


 17%|█▋        | 70/405 [03:06<03:53,  1.43it/s]

{'loss': 0.8729, 'grad_norm': 5.75370979309082, 'learning_rate': 8.34567901234568e-05, 'epoch': 0.17}


                                                
 17%|█▋        | 70/405 [03:29<03:53,  1.43it/s] 

{'eval_loss': 1.2956132888793945, 'eval_accuracy': 0.35300925374031067, 'eval_runtime': 22.7338, 'eval_samples_per_second': 38.005, 'eval_steps_per_second': 9.501, 'epoch': 0.17}


 20%|█▉        | 80/405 [03:36<03:48,  1.42it/s]

{'loss': 0.8359, 'grad_norm': 9.853349685668945, 'learning_rate': 8.098765432098767e-05, 'epoch': 0.2}


                                                
 20%|█▉        | 80/405 [03:59<03:48,  1.42it/s] 

{'eval_loss': 0.8620266318321228, 'eval_accuracy': 0.5023148059844971, 'eval_runtime': 22.9293, 'eval_samples_per_second': 37.681, 'eval_steps_per_second': 9.42, 'epoch': 0.2}


 22%|██▏       | 90/405 [04:06<03:44,  1.40it/s]

{'loss': 0.7748, 'grad_norm': 3.2499732971191406, 'learning_rate': 7.851851851851852e-05, 'epoch': 0.22}


                                                
 22%|██▏       | 90/405 [04:29<03:44,  1.40it/s] 

{'eval_loss': 0.7648174166679382, 'eval_accuracy': 0.5462962985038757, 'eval_runtime': 23.2537, 'eval_samples_per_second': 37.155, 'eval_steps_per_second': 9.289, 'epoch': 0.22}


 25%|██▍       | 100/405 [04:37<03:41,  1.37it/s]

{'loss': 0.8105, 'grad_norm': 3.709836959838867, 'learning_rate': 7.60493827160494e-05, 'epoch': 0.25}


                                                 
 25%|██▍       | 100/405 [05:00<03:41,  1.37it/s]

{'eval_loss': 0.7700876593589783, 'eval_accuracy': 0.5879629850387573, 'eval_runtime': 23.1706, 'eval_samples_per_second': 37.289, 'eval_steps_per_second': 9.322, 'epoch': 0.25}


 27%|██▋       | 110/405 [05:07<03:31,  1.39it/s]

{'loss': 0.8221, 'grad_norm': 4.657902240753174, 'learning_rate': 7.358024691358025e-05, 'epoch': 0.27}


                                                 
 27%|██▋       | 110/405 [05:30<03:31,  1.39it/s]

{'eval_loss': 0.9933385252952576, 'eval_accuracy': 0.49189814925193787, 'eval_runtime': 22.7281, 'eval_samples_per_second': 38.015, 'eval_steps_per_second': 9.504, 'epoch': 0.27}


 30%|██▉       | 120/405 [05:37<03:21,  1.41it/s]

{'loss': 0.7394, 'grad_norm': 4.600498676300049, 'learning_rate': 7.111111111111112e-05, 'epoch': 0.3}


                                                 
 30%|██▉       | 120/405 [05:59<03:21,  1.41it/s]

{'eval_loss': 1.139278531074524, 'eval_accuracy': 0.43865740299224854, 'eval_runtime': 22.578, 'eval_samples_per_second': 38.267, 'eval_steps_per_second': 9.567, 'epoch': 0.3}


 32%|███▏      | 130/405 [06:09<03:24,  1.34it/s]

{'loss': 0.6723, 'grad_norm': 5.716556549072266, 'learning_rate': 6.864197530864198e-05, 'epoch': 0.32}


                                                 
 32%|███▏      | 130/405 [06:32<03:24,  1.34it/s]

{'eval_loss': 0.66737961769104, 'eval_accuracy': 0.6215277910232544, 'eval_runtime': 22.7724, 'eval_samples_per_second': 37.941, 'eval_steps_per_second': 9.485, 'epoch': 0.32}


 35%|███▍      | 140/405 [06:41<03:11,  1.38it/s]

{'loss': 0.7941, 'grad_norm': 60.13768768310547, 'learning_rate': 6.617283950617285e-05, 'epoch': 0.35}


                                                 
 35%|███▍      | 140/405 [07:03<03:11,  1.38it/s]

{'eval_loss': 0.8375363945960999, 'eval_accuracy': 0.5092592835426331, 'eval_runtime': 22.4385, 'eval_samples_per_second': 38.505, 'eval_steps_per_second': 9.626, 'epoch': 0.35}


 37%|███▋      | 150/405 [07:12<03:06,  1.36it/s]

{'loss': 0.7723, 'grad_norm': 2.4577317237854004, 'learning_rate': 6.37037037037037e-05, 'epoch': 0.37}


                                                 
 37%|███▋      | 150/405 [07:35<03:06,  1.36it/s]

{'eval_loss': 0.7266174554824829, 'eval_accuracy': 0.5763888955116272, 'eval_runtime': 22.778, 'eval_samples_per_second': 37.931, 'eval_steps_per_second': 9.483, 'epoch': 0.37}


 40%|███▉      | 160/405 [07:44<03:03,  1.34it/s]

{'loss': 0.6586, 'grad_norm': 4.377495765686035, 'learning_rate': 6.123456790123457e-05, 'epoch': 0.4}


                                                 
 40%|███▉      | 160/405 [08:06<03:03,  1.34it/s]

{'eval_loss': 1.2929784059524536, 'eval_accuracy': 0.38078704476356506, 'eval_runtime': 22.2756, 'eval_samples_per_second': 38.787, 'eval_steps_per_second': 9.697, 'epoch': 0.4}


 42%|████▏     | 170/405 [08:15<02:52,  1.36it/s]

{'loss': 0.6666, 'grad_norm': 4.290340423583984, 'learning_rate': 5.8765432098765437e-05, 'epoch': 0.42}


                                                 
 42%|████▏     | 170/405 [08:39<02:52,  1.36it/s]

{'eval_loss': 0.8536517024040222, 'eval_accuracy': 0.46759259700775146, 'eval_runtime': 24.6202, 'eval_samples_per_second': 35.093, 'eval_steps_per_second': 8.773, 'epoch': 0.42}


 44%|████▍     | 180/405 [08:50<03:00,  1.24it/s]

{'loss': 0.6167, 'grad_norm': 9.656787872314453, 'learning_rate': 5.62962962962963e-05, 'epoch': 0.44}


                                                 
 44%|████▍     | 180/405 [09:13<03:00,  1.24it/s]

{'eval_loss': 0.7595200538635254, 'eval_accuracy': 0.5775462985038757, 'eval_runtime': 23.3863, 'eval_samples_per_second': 36.945, 'eval_steps_per_second': 9.236, 'epoch': 0.44}


 47%|████▋     | 190/405 [09:21<02:40,  1.34it/s]

{'loss': 0.653, 'grad_norm': 5.566981315612793, 'learning_rate': 5.382716049382717e-05, 'epoch': 0.47}


                                                 
 47%|████▋     | 190/405 [09:44<02:40,  1.34it/s]

{'eval_loss': 0.6795032024383545, 'eval_accuracy': 0.6192129850387573, 'eval_runtime': 23.1209, 'eval_samples_per_second': 37.369, 'eval_steps_per_second': 9.342, 'epoch': 0.47}


 49%|████▉     | 200/405 [09:56<02:40,  1.28it/s]

{'loss': 0.7312, 'grad_norm': 5.8433685302734375, 'learning_rate': 5.135802469135803e-05, 'epoch': 0.49}


                                                 
 49%|████▉     | 200/405 [10:19<02:40,  1.28it/s]

{'eval_loss': 1.4204734563827515, 'eval_accuracy': 0.4409722089767456, 'eval_runtime': 23.0849, 'eval_samples_per_second': 37.427, 'eval_steps_per_second': 9.357, 'epoch': 0.49}


 52%|█████▏    | 210/405 [10:31<02:35,  1.25it/s]

{'loss': 0.5666, 'grad_norm': 4.787693023681641, 'learning_rate': 4.888888888888889e-05, 'epoch': 0.52}


                                                 
 52%|█████▏    | 210/405 [10:54<02:35,  1.25it/s]

{'eval_loss': 0.7265880107879639, 'eval_accuracy': 0.6018518805503845, 'eval_runtime': 22.8885, 'eval_samples_per_second': 37.748, 'eval_steps_per_second': 9.437, 'epoch': 0.52}


 54%|█████▍    | 220/405 [11:02<02:13,  1.38it/s]

{'loss': 0.6653, 'grad_norm': 3.6731107234954834, 'learning_rate': 4.641975308641975e-05, 'epoch': 0.54}


                                                 
 54%|█████▍    | 220/405 [11:26<02:13,  1.38it/s]

{'eval_loss': 0.8260233402252197, 'eval_accuracy': 0.5798611044883728, 'eval_runtime': 24.5839, 'eval_samples_per_second': 35.145, 'eval_steps_per_second': 8.786, 'epoch': 0.54}


 57%|█████▋    | 230/405 [11:34<02:10,  1.34it/s]

{'loss': 0.6017, 'grad_norm': 3.953308582305908, 'learning_rate': 4.3950617283950617e-05, 'epoch': 0.57}


                                                 
 57%|█████▋    | 230/405 [11:58<02:10,  1.34it/s]

{'eval_loss': 0.7730549573898315, 'eval_accuracy': 0.5983796119689941, 'eval_runtime': 24.6644, 'eval_samples_per_second': 35.03, 'eval_steps_per_second': 8.758, 'epoch': 0.57}


 59%|█████▉    | 240/405 [12:07<02:16,  1.20it/s]

{'loss': 0.5646, 'grad_norm': 1.5794059038162231, 'learning_rate': 4.17283950617284e-05, 'epoch': 0.59}


                                                 
 59%|█████▉    | 240/405 [12:31<02:16,  1.20it/s]

{'eval_loss': 0.7502283453941345, 'eval_accuracy': 0.5925925970077515, 'eval_runtime': 24.2902, 'eval_samples_per_second': 35.57, 'eval_steps_per_second': 8.892, 'epoch': 0.59}


 62%|██████▏   | 250/405 [12:38<01:53,  1.37it/s]

{'loss': 0.6603, 'grad_norm': 3.872318744659424, 'learning_rate': 3.925925925925926e-05, 'epoch': 0.62}


                                                 
 62%|██████▏   | 250/405 [13:01<01:53,  1.37it/s]

{'eval_loss': 0.6628624200820923, 'eval_accuracy': 0.6215277910232544, 'eval_runtime': 22.4923, 'eval_samples_per_second': 38.413, 'eval_steps_per_second': 9.603, 'epoch': 0.62}


 64%|██████▍   | 260/405 [13:08<01:52,  1.28it/s]

{'loss': 0.6262, 'grad_norm': 7.652580738067627, 'learning_rate': 3.6790123456790125e-05, 'epoch': 0.64}


                                                 
 64%|██████▍   | 260/405 [13:33<01:52,  1.28it/s]

{'eval_loss': 0.8088181018829346, 'eval_accuracy': 0.5775462985038757, 'eval_runtime': 24.7266, 'eval_samples_per_second': 34.942, 'eval_steps_per_second': 8.736, 'epoch': 0.64}


 67%|██████▋   | 270/405 [13:41<01:47,  1.25it/s]

{'loss': 0.5964, 'grad_norm': 2.9723801612854004, 'learning_rate': 3.432098765432099e-05, 'epoch': 0.67}


                                                 
 67%|██████▋   | 270/405 [14:08<01:47,  1.25it/s]

{'eval_loss': 0.8361343741416931, 'eval_accuracy': 0.5625, 'eval_runtime': 27.5513, 'eval_samples_per_second': 31.36, 'eval_steps_per_second': 7.84, 'epoch': 0.67}


 69%|██████▉   | 280/405 [14:17<01:46,  1.17it/s]

{'loss': 0.5397, 'grad_norm': 8.098444938659668, 'learning_rate': 3.209876543209876e-05, 'epoch': 0.69}


                                                 
 69%|██████▉   | 280/405 [14:43<01:46,  1.17it/s]

{'eval_loss': 0.7155152559280396, 'eval_accuracy': 0.5914351940155029, 'eval_runtime': 26.6518, 'eval_samples_per_second': 32.418, 'eval_steps_per_second': 8.105, 'epoch': 0.69}


 72%|███████▏  | 290/405 [14:50<01:28,  1.30it/s]

{'loss': 0.5685, 'grad_norm': 19.149568557739258, 'learning_rate': 2.962962962962963e-05, 'epoch': 0.72}


                                                 
 72%|███████▏  | 290/405 [15:15<01:28,  1.30it/s]

{'eval_loss': 0.6024020314216614, 'eval_accuracy': 0.6157407164573669, 'eval_runtime': 24.4503, 'eval_samples_per_second': 35.337, 'eval_steps_per_second': 8.834, 'epoch': 0.72}


 74%|███████▍  | 300/405 [15:23<01:19,  1.32it/s]

{'loss': 0.713, 'grad_norm': 20.049728393554688, 'learning_rate': 2.7160493827160493e-05, 'epoch': 0.74}


                                                 
 74%|███████▍  | 300/405 [15:50<01:19,  1.32it/s]

{'eval_loss': 0.5996790528297424, 'eval_accuracy': 0.6412037014961243, 'eval_runtime': 27.5807, 'eval_samples_per_second': 31.326, 'eval_steps_per_second': 7.832, 'epoch': 0.74}


 77%|███████▋  | 310/405 [15:58<01:13,  1.30it/s]

{'loss': 0.4938, 'grad_norm': 2.5125513076782227, 'learning_rate': 2.4691358024691357e-05, 'epoch': 0.77}


                                                 
 77%|███████▋  | 310/405 [16:24<01:13,  1.30it/s]

{'eval_loss': 0.5873221158981323, 'eval_accuracy': 0.6331018805503845, 'eval_runtime': 26.6159, 'eval_samples_per_second': 32.462, 'eval_steps_per_second': 8.115, 'epoch': 0.77}


 79%|███████▉  | 320/405 [16:33<01:14,  1.15it/s]

{'loss': 0.652, 'grad_norm': 1.832817792892456, 'learning_rate': 2.2222222222222223e-05, 'epoch': 0.79}


                                                 
 79%|███████▉  | 320/405 [16:59<01:14,  1.15it/s]

{'eval_loss': 0.5546272397041321, 'eval_accuracy': 0.6284722089767456, 'eval_runtime': 25.7911, 'eval_samples_per_second': 33.5, 'eval_steps_per_second': 8.375, 'epoch': 0.79}


 81%|████████▏ | 330/405 [17:06<00:57,  1.31it/s]

{'loss': 0.561, 'grad_norm': 7.611238956451416, 'learning_rate': 1.9753086419753087e-05, 'epoch': 0.81}


                                                 
 81%|████████▏ | 330/405 [17:33<00:57,  1.31it/s]

{'eval_loss': 0.5575839877128601, 'eval_accuracy': 0.6388888955116272, 'eval_runtime': 26.7222, 'eval_samples_per_second': 32.333, 'eval_steps_per_second': 8.083, 'epoch': 0.81}


 84%|████████▍ | 340/405 [17:41<00:58,  1.11it/s]

{'loss': 0.5704, 'grad_norm': 16.253633499145508, 'learning_rate': 1.728395061728395e-05, 'epoch': 0.84}


                                                 
 84%|████████▍ | 340/405 [18:04<00:58,  1.11it/s]

{'eval_loss': 0.5772150754928589, 'eval_accuracy': 0.6412037014961243, 'eval_runtime': 23.4847, 'eval_samples_per_second': 36.79, 'eval_steps_per_second': 9.197, 'epoch': 0.84}


 86%|████████▋ | 350/405 [18:13<00:45,  1.21it/s]

{'loss': 0.5021, 'grad_norm': 2.6168558597564697, 'learning_rate': 1.4814814814814815e-05, 'epoch': 0.86}


                                                 
 86%|████████▋ | 350/405 [18:37<00:45,  1.21it/s]

{'eval_loss': 0.5797354578971863, 'eval_accuracy': 0.6458333134651184, 'eval_runtime': 24.0401, 'eval_samples_per_second': 35.94, 'eval_steps_per_second': 8.985, 'epoch': 0.86}


 89%|████████▉ | 360/405 [18:45<00:39,  1.14it/s]

{'loss': 0.4226, 'grad_norm': 6.3968682289123535, 'learning_rate': 1.2345679012345678e-05, 'epoch': 0.89}


                                                 
 89%|████████▉ | 360/405 [19:10<00:39,  1.14it/s]

{'eval_loss': 0.5775889158248901, 'eval_accuracy': 0.6412037014961243, 'eval_runtime': 24.4466, 'eval_samples_per_second': 35.342, 'eval_steps_per_second': 8.836, 'epoch': 0.89}


 91%|█████████▏| 370/405 [19:20<00:29,  1.17it/s]

{'loss': 0.7349, 'grad_norm': 10.677435874938965, 'learning_rate': 9.876543209876543e-06, 'epoch': 0.91}


                                                 
 91%|█████████▏| 370/405 [19:48<00:29,  1.17it/s]

{'eval_loss': 0.5661423206329346, 'eval_accuracy': 0.6446759104728699, 'eval_runtime': 28.4331, 'eval_samples_per_second': 30.387, 'eval_steps_per_second': 7.597, 'epoch': 0.91}


 94%|█████████▍| 380/405 [19:57<00:24,  1.01it/s]

{'loss': 0.5809, 'grad_norm': 2.1108407974243164, 'learning_rate': 7.4074074074074075e-06, 'epoch': 0.94}


                                                 
 94%|█████████▍| 380/405 [20:23<00:24,  1.01it/s]

{'eval_loss': 0.558044970035553, 'eval_accuracy': 0.6458333134651184, 'eval_runtime': 25.6556, 'eval_samples_per_second': 33.677, 'eval_steps_per_second': 8.419, 'epoch': 0.94}


 96%|█████████▋| 390/405 [20:35<00:14,  1.07it/s]

{'loss': 0.6468, 'grad_norm': 1.6151989698410034, 'learning_rate': 4.938271604938272e-06, 'epoch': 0.96}


                                                 
 96%|█████████▋| 390/405 [20:59<00:14,  1.07it/s]

{'eval_loss': 0.5513487458229065, 'eval_accuracy': 0.6469907164573669, 'eval_runtime': 24.4319, 'eval_samples_per_second': 35.364, 'eval_steps_per_second': 8.841, 'epoch': 0.96}


 99%|█████████▉| 400/405 [21:09<00:03,  1.27it/s]

{'loss': 0.5647, 'grad_norm': 7.272134780883789, 'learning_rate': 2.469135802469136e-06, 'epoch': 0.99}


                                                 
 99%|█████████▉| 400/405 [21:33<00:03,  1.27it/s]

{'eval_loss': 0.5529109239578247, 'eval_accuracy': 0.6469907164573669, 'eval_runtime': 23.4135, 'eval_samples_per_second': 36.902, 'eval_steps_per_second': 9.225, 'epoch': 0.99}


100%|██████████| 405/405 [21:44<00:00,  3.22s/it]

{'train_runtime': 1304.0308, 'train_samples_per_second': 2.485, 'train_steps_per_second': 0.311, 'train_loss': 0.733257022904761, 'epoch': 1.0}





TrainOutput(global_step=405, training_loss=0.733257022904761, metrics={'train_runtime': 1304.0308, 'train_samples_per_second': 2.485, 'train_steps_per_second': 0.311, 'total_flos': 7.8753696657408e+16, 'train_loss': 0.733257022904761, 'epoch': 1.0})

In [39]:
# trainer.evaluate()
trainer.evaluate(eval_dataset=eval_dataset)

100%|██████████| 216/216 [00:23<00:00,  9.27it/s]


{'eval_loss': 0.550628125667572,
 'eval_accuracy': 0.6469907164573669,
 'eval_runtime': 23.4553,
 'eval_samples_per_second': 36.836,
 'eval_steps_per_second': 9.209,
 'epoch': 1.0}

In [40]:
trainer.evaluate(eval_dataset=train_dataset)

100%|██████████| 810/810 [01:27<00:00,  9.25it/s]


{'eval_loss': 0.5462911128997803,
 'eval_accuracy': 0.6478394865989685,
 'eval_runtime': 87.7605,
 'eval_samples_per_second': 36.919,
 'eval_steps_per_second': 9.23,
 'epoch': 1.0}

In [41]:
import torch
from collections import Counter
from transformers import Trainer, EvalPrediction
import numpy as np

class WeightedVoteTrainer(Trainer):
    def __init__(self, *args, num_votes=5, **kwargs):
        super().__init__(*args, **kwargs)
        self.num_votes = num_votes

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        # 进行多次预测
        predictions = []
        confidences = []
        model.train()  # 启用 dropout
        for i in range(self.num_votes):
            torch.manual_seed(i)  # 为每次预测设置不同的随机种子
            inputs_copy = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}
            
            # 添加小的随机噪声到输入
            if 'input_values' in inputs_copy:
                noise = torch.randn_like(inputs_copy['input_values']) * 0.01
                inputs_copy['input_values'] += noise
            
            _, logits, labels = super().prediction_step(model, inputs_copy, prediction_loss_only, ignore_keys)
            probs = torch.softmax(logits, dim=-1)
            confidence, pred = torch.max(probs, dim=-1)
            predictions.append(pred)
            confidences.append(confidence)
        
        model.eval()  # 恢复评估模式
        
        # 将预测结果和置信度堆叠起来
        stacked_preds = torch.stack(predictions, dim=0)
        stacked_confidences = torch.stack(confidences, dim=0)
        
        # 计算加权投票结果
        num_classes = logits.shape[-1]
        weighted_votes = torch.zeros((stacked_preds.shape[1], num_classes), device=stacked_preds.device)
        for i in range(self.num_votes):
            weighted_votes.scatter_add_(1, stacked_preds[i].unsqueeze(1), stacked_confidences[i].unsqueeze(1))
        
        weighted_vote_result = torch.argmax(weighted_votes, dim=1)

        return None, weighted_vote_result, labels

def compute_metrics_with_weighted_vote(eval_pred: EvalPrediction):
    predictions, labels = eval_pred.predictions, eval_pred.label_ids
    
    # 确保预测和标签都是 NumPy 数组
    if isinstance(predictions, torch.Tensor):
        predictions = predictions.cpu().numpy()
    if isinstance(labels, torch.Tensor):
        labels = labels.cpu().numpy()
    
    # 计算准确率
    accuracy = (predictions == labels).mean()
    return {"accuracy": accuracy}

# 创建新的 Trainer 实例
weighted_vote_trainer = WeightedVoteTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics_with_weighted_vote,
    num_votes=6  # 设置投票次数
)

# 进行评估
eval_results = weighted_vote_trainer.evaluate(eval_dataset=eval_dataset)
print(eval_results)

100%|██████████| 216/216 [00:41<00:00,  5.23it/s]

{'eval_model_preparation_time': 0.004, 'eval_accuracy': 0.6863425925925926, 'eval_runtime': 41.6993, 'eval_samples_per_second': 20.72, 'eval_steps_per_second': 5.18}



