In [1]:
import os
import shutil
from tqdm import tqdm

def organize_files(src_root_dir, dst_root_dir):
    """Copy every ``.npy`` file of a two-level source tree into per-case folders.

    The source tree is walked as ``<src_root_dir>/<main_dir>/<sub_dir>/<file>``.
    The text before the first underscore of ``sub_dir`` selects the destination
    case folder, and each ``.npy`` file is copied to
    ``<dst_root_dir>/<case>/<main_dir>_<sub_dir>_<file>``.

    Sub-directories whose prefix is not one of ``case1`` .. ``case4`` are
    skipped: no destination folder exists for them, so copying into one would
    fail.

    Args:
        src_root_dir: Root directory of the source tree.
        dst_root_dir: Root directory of the organised copy; the four case
            folders are created below it when missing.
    """
    cases = ('case1', 'case2', 'case3', 'case4')

    # Create one destination folder per known case (no-op when it exists).
    for case in cases:
        os.makedirs(os.path.join(dst_root_dir, case), exist_ok=True)

    # Walk the source tree: first level = main directories, second level =
    # per-case sample directories.
    for main_dir in tqdm(os.listdir(src_root_dir), desc="处理目录"):
        main_dir_path = os.path.join(src_root_dir, main_dir)
        if not os.path.isdir(main_dir_path):
            continue

        for sub_dir in os.listdir(main_dir_path):
            sub_dir_path = os.path.join(main_dir_path, sub_dir)
            if not os.path.isdir(sub_dir_path):
                continue

            case_name = sub_dir.split('_')[0]
            if case_name not in cases:
                # Unknown prefix: there is no destination folder for it.
                continue
            case_dir = os.path.join(dst_root_dir, case_name)

            for filename in os.listdir(sub_dir_path):
                if not filename.endswith('.npy'):
                    continue
                src_npy_path = os.path.join(sub_dir_path, filename)
                # Prefix with both directory names so files from different
                # source folders cannot collide in the flat case folder.
                dst_npy_path = os.path.join(case_dir, f"{main_dir}_{sub_dir}_{filename}")
                shutil.copy(src_npy_path, dst_npy_path)

# Source tree to read and destination root for the per-case copy.
# NOTE(review): the source name says "withbandpass" while the destination says
# "nobandpass" — confirm that this pairing is intended.
src_directory, dst_directory = (
    'dataset_1k2k3k_withbandpass_extrafeatures_v3',
    'data_1k2k3k_nobandpass_organized_dataset',
)

print(f"源目录: {src_directory}")
print(f"目标目录: {dst_directory}")

# Only organise when the source tree is present; otherwise just report it.
if os.path.exists(src_directory):
    organize_files(src_directory, dst_directory)
    print("文件整理完成。")
else:
    print(f"源目录不存在: {src_directory}")

源目录: dataset_1k2k3k_withbandpass_extrafeatures_v3
目标目录: data_1k2k3k_nobandpass_organized_dataset


处理目录: 100%|██████████| 19/19 [00:27<00:00,  1.43s/it]

文件整理完成。





In [2]:
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm import tqdm

import torchaudio
from sklearn.model_selection import train_test_split
import os
import sys

In [3]:
import os
from pathlib import Path
import numpy as np
from tqdm import tqdm

# Collect one record per .npy file: identifiers parsed from the file name plus
# the array stored in the file.
data = []
case_names = ['case1', 'case2', 'case3', 'case4']

for case in case_names:
    case_path = Path(f'data_1k2k3k_nobandpass_organized_dataset/{case}')
    # Sorting fixes the record order regardless of the file system's listing
    # order and lets tqdm display a total.
    for path in tqdm(sorted(case_path.glob("*.npy")), desc=f"处理 {case}"):
        name = path.stem

        try:
            # Split the file name (extension included) on underscores:
            # first piece = prefix, second-to-last piece = sample set,
            # everything in between = case id (may span several pieces).
            parts = path.name.split('_')
            prefix = parts[0]
            case_id = '_'.join(parts[1:-2])
            sample_set = parts[-2]

            # Load the file directly: the path came from the glob above, and
            # loading unconditionally means a failed load can never leave the
            # previous iteration's array in `energy_features`.
            energy_features = np.load(path)
        except Exception as e:
            # Best effort: report the unreadable or oddly named file and skip it.
            print(f"处理文件 {path} 时出错: {str(e)}")
            continue

        data.append({
            "name": name,
            "path": str(path),
            "case": case,
            "prefix": prefix,
            "case_id": case_id,
            "sample_set": sample_set,
            "energy_features": energy_features
        })

# Number of collected records
print(f"收集了 {len(data)} 个项目。")

# Per-case record counts, computed in a single pass over the records
case_counts = {case: 0 for case in case_names}
for item in data:
    case_counts[item['case']] += 1
print("\n数据分布:")
for case, count in case_counts.items():
    print(f"{case}: {count} 个项目")

# Check that every record carries energy features
items_with_features = sum(1 for item in data if item['energy_features'] is not None)
print(f"\n具有能量特征的项目: {items_with_features} / {len(data)}")

处理 case1: 684it [00:03, 200.75it/s]
处理 case2: 684it [00:04, 149.60it/s]
处理 case3: 1368it [00:10, 134.72it/s]
处理 case4: 1368it [00:09, 141.56it/s]

收集了 4104 个项目。

数据分布:
case1: 684 个项目
case2: 684 个项目
case3: 1368 个项目
case4: 1368 个项目

具有能量特征的项目: 4104 / 4104





In [4]:
import pandas as pd
# Build a DataFrame with one row per collected record and preview its first rows.
df = pd.DataFrame(data)
df.head()

Unnamed: 0,name,path,case,prefix,case_id,sample_set,energy_features
0,A10_case1_1_sample_10_2,data_1k2k3k_nobandpass_organized_dataset\case1...,case1,A10,case1_1_sample,10,4.071372111649566
1,A10_case1_1_sample_11_2,data_1k2k3k_nobandpass_organized_dataset\case1...,case1,A10,case1_1_sample,11,3.197134666321631
2,A10_case1_1_sample_12_2,data_1k2k3k_nobandpass_organized_dataset\case1...,case1,A10,case1_1_sample,12,3.637433915029988
3,A10_case1_1_sample_13_3,data_1k2k3k_nobandpass_organized_dataset\case1...,case1,A10,case1_1_sample,13,0.4854435119452205
4,A10_case1_1_sample_14_3,data_1k2k3k_nobandpass_organized_dataset\case1...,case1,A10,case1_1_sample,14,0.2141085064245373


In [5]:
import pandas as pd

# Rebuild the frame from the collected records so this cell does not depend on
# whatever `df` currently holds.
df = pd.DataFrame(data)
# Shuffle all rows and renumber the index. The seed is fixed (the same value
# passed to train_test_split further down) so a Restart & Run All reproduces
# the same row order.
df = df.sample(frac=1, random_state=101).reset_index(drop=True)

Let's display a random sample of the dataset and run the cell a couple of times to get a feeling for the energy features and their case labels.

In [8]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import TrainingArguments, Trainer
import torch
from torch import nn

# Helper that converts the energy_features column to plain floats
def process_energy_features(df):
    """Add an ``energy_float`` column holding each energy feature as a float.

    Values that are ``np.ndarray`` instances are unwrapped with ``item()``
    (so they must hold exactly one element); anything else is passed to
    ``float`` directly. The column is added to ``df`` in place and the same
    object is returned.
    """
    def _as_float(value):
        # item() extracts the single value stored in the array
        return float(value.item()) if isinstance(value, np.ndarray) else float(value)

    df["energy_float"] = [_as_float(value) for value in df["energy_features"]]
    return df

# Replace the raw energy_features column with its float version: compute
# energy_float, drop the original column, then give the new column the old name.
df = (
    process_energy_features(df)
    .drop(columns=["energy_features"])
    .rename(columns={"energy_float": "energy_features"})
)

# Helper: map a label to its integer id
def label_to_id(label, label_list):
    """Return the position of ``label`` inside ``label_list``.

    The lookup is delegated to ``label_list.index``, so a label that is not
    present raises ``ValueError`` and the first match wins for duplicates.
    """
    position = label_list.index(label)
    return position

# Stratified 80/20 split on the "case" column; each part gets a fresh 0..n-1 index.
train_df, test_df = (
    split.reset_index(drop=True)
    for split in train_test_split(
        df, test_size=0.2, random_state=101, stratify=df["case"]
    )
)

# Report the shape of both parts
print("训练数据集形状:", train_df.shape)
print("测试数据集形状:", test_df.shape)

# Wrap both frames as Dataset objects
train_dataset, eval_dataset = Dataset.from_pandas(train_df), Dataset.from_pandas(test_df)

# Column that holds the classification target
output_column = "case"

# Per-case sample counts for each part (heading and counts on separate lines)
print("\n训练数据集中每个 case 的样本数:", train_df[output_column].value_counts(), sep="\n")
print("\n验证数据集中每个 case 的样本数:", test_df[output_column].value_counts(), sep="\n")

# Sorted list of distinct labels; a label's position in this list is its id
label_list = sorted(df["case"].unique())
num_labels = len(label_list)
print(f"\n这是一个有 {num_labels} 个类别的分类问题: {label_list}")

# Batched preprocessing function (used with Dataset.map)
def preprocess_function(examples):
    """Turn a batch of examples into the two columns kept for training.

    ``examples["case"]`` is mapped to integer ids through ``label_to_id`` and
    the module-level ``label_list``; ``examples["energy_features"]`` is passed
    through unchanged.
    """
    return {
        "labels": [label_to_id(case_label, label_list) for case_label in examples["case"]],
        "energy_features": examples["energy_features"],
    }

# Apply the preprocessing to both datasets (train first, then eval); every
# original column is removed so only the function's outputs remain.
train_dataset, eval_dataset = (
    split.map(preprocess_function, batched=True, remove_columns=split.column_names)
    for split in (train_dataset, eval_dataset)
)

# Show the first five processed examples of each dataset
print("\n带有能量特征的训练数据集:", train_dataset[:5], sep="\n")
print("\n带有能量特征的验证数据集:", eval_dataset[:5], sep="\n")

# Count the examples whose energy feature is present (not None)
train_with_features = sum(item['energy_features'] is not None for item in train_dataset)
eval_with_features = sum(item['energy_features'] is not None for item in eval_dataset)
print(f"\n训练样本中包含能量特征的数量: {train_with_features} / {len(train_dataset)}")
print(f"验证样本中包含能量特征的数量: {eval_with_features} / {len(eval_dataset)}")

# Model definition
class SimpleClassifier(nn.Module):
    """Single linear layer mapping ``input_dim`` features to ``num_labels`` logits."""

    def __init__(self, input_dim, num_labels):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=num_labels)

    def forward(self, x):
        """Return the logits produced by the linear layer for ``x``."""
        logits = self.linear(x)
        return logits

# Batch collation function
def collate_fn(batch):
    """Stack a list of examples into float features and long labels.

    Each element of ``batch`` is indexed with ``'energy_features'`` and
    ``'labels'``; the collected values become one tensor per key.
    """
    def _column(key):
        return [example[key] for example in batch]

    energy_features = torch.tensor(_column('energy_features'), dtype=torch.float)
    labels = torch.tensor(_column('labels'), dtype=torch.long)
    return {'energy_features': energy_features, 'labels': labels}

# Metric computation
# NOTE: a later cell defines another `compute_metrics`, which replaces this
# one once that cell has run.
def compute_metrics(eval_pred):
    """Return ``{'accuracy': ...}`` for a ``(predictions, labels)`` pair.

    The predicted class is the argmax of ``predictions`` along axis 1; the
    accuracy is the mean of the element-wise comparison with ``labels``.
    """
    scores, labels = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    accuracy = (predicted_ids == labels).mean()
    return {'accuracy': accuracy}

# Instantiate the linear baseline: one input feature, one logit per label.
model = SimpleClassifier(input_dim=1, num_labels=num_labels)

# Sanity-check the processed column: its dtype and its first five values.
energy_column = df["energy_features"]
print("\n数据类型检查:")
print("energy_features 的数据类型:", energy_column.dtype)
print("\n前5个样本的 energy_features:")
print(energy_column.head())

训练数据集形状: (3283, 7)
测试数据集形状: (821, 7)

训练数据集中每个 case 的样本数:
case
case3    1095
case4    1094
case1     547
case2     547
Name: count, dtype: int64

验证数据集中每个 case 的样本数:
case
case4    274
case3    273
case1    137
case2    137
Name: count, dtype: int64

这是一个有 4 个类别的分类问题: ['case1', 'case2', 'case3', 'case4']


Map: 100%|██████████| 3283/3283 [00:00<00:00, 247544.31 examples/s]
Map: 100%|██████████| 821/821 [00:00<00:00, 399017.80 examples/s]


带有能量特征的训练数据集:
{'energy_features': [0.07321549042228526, 33.11227428622621, 3.6374339150299875, 0.013138705499983416, 175.39704312275228], 'labels': [0, 3, 0, 0, 3]}

带有能量特征的验证数据集:
{'energy_features': [11.063609211383412, 245.78904389559597, 0.12018817243463786, 11.255839661838182, 993.8337578568035], 'labels': [3, 2, 0, 1, 3]}

训练样本中包含能量特征的数量: 3283 / 3283
验证样本中包含能量特征的数量: 821 / 821

数据类型检查:
energy_features 的数据类型: float64

前5个样本的 energy_features:
0      2.595416
1      1.882986
2    885.436748
3     93.589134
4     11.226658
Name: energy_features, dtype: float64





In [9]:
# Preview the first five rows of the frame after the energy_features conversion.
df.head(5)

Unnamed: 0,name,path,case,prefix,case_id,sample_set,energy_features
0,A4_case1_1_sample_7_2,data_1k2k3k_nobandpass_organized_dataset\case1...,case1,A4,case1_1_sample,7,2.595416
1,A9_case1_2_sample_8_2,data_1k2k3k_nobandpass_organized_dataset\case1...,case1,A9,case1_2_sample,8,1.882986
2,A4_case2_2_sample_10_1,data_1k2k3k_nobandpass_organized_dataset\case2...,case2,A4,case2_2_sample,10,885.436748
3,E4_case4_3_sample_14_3,data_1k2k3k_nobandpass_organized_dataset\case4...,case4,E4,case4_3_sample,14,93.589134
4,E9_case4_12_sample_71_3,data_1k2k3k_nobandpass_organized_dataset\case4...,case4,E9,case4_12_sample,71,11.226658


## Model

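Before diving into training, we need to build our classification model. It encodes each energy feature, passes the result through a Transformer encoder, and then applies a merge (pooling) strategy — `mean`, `sum`, or `max` over the sequence dimension — before the classification head.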

In [11]:
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List, Union
import torch
import torch.nn as nn
import numpy as np
from transformers import PreTrainedModel, AutoConfig, TrainingArguments, Trainer, EvalPrediction
from transformers.file_utils import ModelOutput
from packaging import version
import os

@dataclass
class EnergyFeatureClassifierOutput(ModelOutput):
    """Output container returned by ``Wav2Vec2ForEnergyClassification.forward``."""

    # Loss value; stays None when forward() is called without labels.
    loss: Optional[torch.FloatTensor] = None
    # Raw outputs of the classification head.
    logits: torch.FloatTensor = None
    # NOTE(review): forward() stores the single pooled hidden-state tensor here,
    # not a tuple of tensors as the annotation suggests.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # forward() always passes None for this field.
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class Wav2Vec2ClassificationHead(nn.Module):
    """Classification head: dropout -> dense -> tanh -> dropout -> projection.

    Reads ``hidden_size``, ``final_dropout`` and ``num_labels`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        hidden_size = config.hidden_size
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        """Map ``features`` to logits; extra keyword arguments are ignored."""
        projected = self.dense(self.dropout(features))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)

class EnergyFeatureEncoder(nn.Module):
    """Project scalar energy features into ``config.hidden_size`` dimensions.

    The projection is Linear(1 -> 256), ReLU, Linear(256 -> hidden_size),
    followed by layer normalisation and dropout
    (``config.hidden_dropout_prob``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        hidden_size = config.hidden_size
        projection_layers = [
            nn.Linear(1, 256),  # the input feature dimension is 1
            nn.ReLU(),
            nn.Linear(256, hidden_size),
        ]
        self.proj = nn.Sequential(*projection_layers)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, energy_features):
        """Encode ``energy_features`` after appending a trailing axis of size 1."""
        # The trailing axis gives the first Linear layer its single input feature.
        projected = self.proj(energy_features.unsqueeze(-1))
        return self.dropout(self.layer_norm(projected))

class SimpleWav2Vec2Model(nn.Module):
    """Stack of ``nn.TransformerEncoder`` layers over batch-first hidden states.

    Reads ``hidden_size``, ``num_attention_heads`` and ``num_hidden_layers``
    from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)

    def forward(self, hidden_states, attention_mask=None):
        """Run the encoder and return hidden states in batch-first layout.

        Args:
            hidden_states: Tensor laid out as (batch, seq, hidden).
            attention_mask: Optional (batch, seq) mask where 1 marks a real
                position and 0 marks padding.

        Returns:
            Tensor with the same layout as ``hidden_states``.
        """
        key_padding_mask = None
        if attention_mask is not None:
            # `src_key_padding_mask` expects a (batch, seq) mask in which True
            # marks positions to ignore, i.e. the inverse of the 1 = keep
            # attention mask. A 4-D additive mask is not accepted there.
            key_padding_mask = attention_mask == 0

        # The encoder layers are built with the default sequence-first layout,
        # so swap the batch and sequence axes on the way in and out.
        encoder_outputs = self.encoder(hidden_states.transpose(0, 1), src_key_padding_mask=key_padding_mask)
        return encoder_outputs.transpose(0, 1)

class Wav2Vec2ForEnergyClassification(PreTrainedModel):
    """Classifier over scalar energy features.

    Pipeline (see ``forward``): ``EnergyFeatureEncoder`` -> add a sequence axis
    of length 1 -> ``SimpleWav2Vec2Model`` -> pooling over the sequence axis
    (``config.pooling_mode``) -> ``Wav2Vec2ClassificationHead``.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.pooling_mode = config.pooling_mode
        self.energy_feature_encoder = EnergyFeatureEncoder(config)
        self.wav2vec2 = SimpleWav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

    def freeze_feature_extractor(self):
        """Put the energy feature encoder in eval mode and stop its gradients."""
        self.energy_feature_encoder.eval()
        for param in self.energy_feature_encoder.parameters():
            param.requires_grad = False

    def merged_strategy(self, hidden_states, mode="mean"):
        """Pool ``hidden_states`` over dimension 1.

        Args:
            hidden_states: Tensor whose dimension 1 is reduced.
            mode: ``"mean"``, ``"sum"`` or ``"max"``.

        Raises:
            ValueError: If ``mode`` is not one of the supported pooling modes.
        """
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            # ValueError is a subclass of Exception, so existing handlers that
            # catch Exception keep working.
            raise ValueError(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
        return outputs

    def forward(
            self,
            energy_features,
            attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            labels=None,
    ):
        """Compute logits and, when ``labels`` is given, the loss.

        Args:
            energy_features: Tensor of scalar energy values; the encoder
                appends a feature axis of size 1 (presumably the input is
                shaped (batch,) — TODO confirm against the data collator).
            attention_mask: Optional mask forwarded to ``SimpleWav2Vec2Model``.
            output_attentions: Accepted but not used by this implementation.
            output_hidden_states: Accepted but not used by this implementation.
            return_dict: When falsy a tuple is returned; defaults to
                ``config.use_return_dict``.
            labels: Optional targets used to compute the loss.

        Returns:
            ``EnergyFeatureClassifierOutput`` or, when ``return_dict`` is
            falsy, ``(loss, logits)`` / ``(logits,)``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.energy_feature_encoder(energy_features)
        hidden_states = hidden_states.unsqueeze(1)  # add the sequence dimension (length 1)
        hidden_states = self.wav2vec2(hidden_states, attention_mask)
        # With a sequence of length 1, every pooling mode yields the same vector.
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)

        loss = None
        if labels is not None:
            # Infer the problem type once from the label count and label dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    # Flatten both sides: comparing (batch, 1) logits with
                    # (batch,) labels would broadcast to (batch, batch) and
                    # average the wrong pairs.
                    loss = loss_fct(logits.reshape(-1), labels.reshape(-1))
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        # NOTE(review): `hidden_states` here is the pooled tensor, not a tuple
        # of per-layer states.
        return EnergyFeatureClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=hidden_states,
            attentions=None,
        )

@dataclass
class DataCollatorForEnergyFeatures:
    """Collate examples with ``labels`` and ``energy_features`` into tensors."""

    @staticmethod
    def _to_scalar(value) -> float:
        """Reduce one stored energy feature to a single float.

        Tensors, arrays, lists and tuples contribute their first element (0.0
        when empty); Python and NumPy numbers are converted directly; any
        other value falls back to the default 0.0.
        """
        if isinstance(value, torch.Tensor):
            return float(value.reshape(-1)[0]) if value.numel() > 0 else 0.0
        if isinstance(value, np.ndarray):
            # `.flat` also works for 0-d arrays, on which len() would raise.
            return float(value.flat[0]) if value.size > 0 else 0.0
        if isinstance(value, (list, tuple)):
            if len(value) == 0:
                return 0.0
            return DataCollatorForEnergyFeatures._to_scalar(value[0])
        if isinstance(value, (int, float, np.number)):
            # np.number covers NumPy scalars such as np.float32, which are not
            # instances of the built-in float.
            return float(value)
        return 0.0  # default for missing or unsupported values

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor, np.ndarray]]]) -> Dict[str, torch.Tensor]:
        """Build the batch dictionary.

        Args:
            features: Examples, each providing ``"labels"`` and
                ``"energy_features"``.

        Returns:
            ``{"labels": tensor, "energy_features": float tensor}``. Labels are
            ``torch.long`` when the first label is a Python or NumPy integer,
            otherwise ``torch.float``.
        """
        label_features = [feature["labels"] for feature in features]
        energy_features = [feature["energy_features"] for feature in features]

        # Integer labels (including NumPy integers) are class ids; an empty
        # batch falls through to float.
        has_integer_labels = bool(label_features) and isinstance(label_features[0], (int, np.integer))
        d_type = torch.long if has_integer_labels else torch.float

        batch = {}
        batch["labels"] = torch.tensor(label_features, dtype=d_type)

        # Reduce every energy feature to one float, then stack them.
        scalar_energy_features = [self._to_scalar(ef) for ef in energy_features]
        batch["energy_features"] = torch.tensor(scalar_energy_features, dtype=torch.float)

        return batch

def compute_metrics(p: EvalPrediction):
    """Compute the evaluation metric for a prediction object.

    When ``p.predictions`` is a tuple, only its first element is used. With
    the module-level flag ``is_regression`` set, the squeezed predictions are
    scored with mean squared error; otherwise the argmax along axis 1 is
    compared with ``p.label_ids`` to obtain the accuracy.
    """
    raw_predictions = p.predictions
    if isinstance(raw_predictions, tuple):
        raw_predictions = raw_predictions[0]

    if is_regression:
        preds = np.squeeze(raw_predictions)
        return {"mse": ((preds - p.label_ids) ** 2).mean().item()}

    preds = np.argmax(raw_predictions, axis=1)
    return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}

# Load the base configuration
model_name_or_path = "c3f9d884181a224a6ac87bf8885c84d1cff3384f"  # configuration of a pretrained wav2vec2 model
config = AutoConfig.from_pretrained(model_name_or_path)

# Set the energy-feature dimension and the other settings the model classes read
# NOTE(review): `energy_feature_dim` is not read by EnergyFeatureEncoder, which
# hard-codes an input size of 1.
config.energy_feature_dim = 1  # one-dimensional input
config.num_labels = 4
config.problem_type = "single_label_classification"
config.pooling_mode = "mean"
config.hidden_dropout_prob = 0.1
config.final_dropout = 0.1
config.hidden_size = 768  # make sure this matches the pretrained model's hidden size
config.num_attention_heads = 12
config.num_hidden_layers = 12

# Data collator instance
data_collator = DataCollatorForEnergyFeatures()

# Read by compute_metrics to choose between MSE and accuracy
is_regression = False

# Model instance
model = Wav2Vec2ForEnergyClassification(config)

# Training arguments
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` — confirm against the installed version.
training_args = TrainingArguments(
    output_dir="./results_energybased",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    evaluation_strategy="steps",
    num_train_epochs=1,
    # Mixed precision needs a CUDA device; fall back to full precision on
    # CPU-only machines instead of failing at start-up.
    fp16=torch.cuda.is_available(),
    save_steps=10,
    eval_steps=10,
    logging_steps=10,
    learning_rate=1e-4,
    save_total_limit=2,
)

# Trainer
trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,  # defined in the preprocessing cell
    eval_dataset=eval_dataset,    # defined in the preprocessing cell
)

# Start training
trainer.train()

  attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
Could not estimate the number of tokens of the input, floating-point operations will not be computed
  2%|▏         | 10/410 [00:01<00:29, 13.71it/s]

{'loss': 1.4959, 'grad_norm': 5.934451580047607, 'learning_rate': 9.75609756097561e-05, 'epoch': 0.02}


                                                
  2%|▏         | 10/410 [00:02<00:29, 13.71it/s]  

{'eval_loss': 1.1893579959869385, 'eval_accuracy': 0.4604141414165497, 'eval_runtime': 1.3284, 'eval_samples_per_second': 618.036, 'eval_steps_per_second': 155.074, 'epoch': 0.02}


  5%|▍         | 20/410 [00:03<00:56,  6.95it/s]

{'loss': 1.2647, 'grad_norm': 5.565686225891113, 'learning_rate': 9.51219512195122e-05, 'epoch': 0.05}


                                                
  5%|▍         | 20/410 [00:04<00:56,  6.95it/s]  

{'eval_loss': 0.968656599521637, 'eval_accuracy': 0.4835566282272339, 'eval_runtime': 1.2354, 'eval_samples_per_second': 664.568, 'eval_steps_per_second': 166.749, 'epoch': 0.05}


  7%|▋         | 30/410 [00:06<01:08,  5.51it/s]

{'loss': 1.1053, 'grad_norm': 10.294601440429688, 'learning_rate': 9.26829268292683e-05, 'epoch': 0.07}


                                                
  7%|▋         | 30/410 [00:07<01:08,  5.51it/s]  

{'eval_loss': 1.0071831941604614, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.2717, 'eval_samples_per_second': 645.58, 'eval_steps_per_second': 161.985, 'epoch': 0.07}


 10%|▉         | 40/410 [00:08<01:01,  6.03it/s]

{'loss': 1.0794, 'grad_norm': 2.9233145713806152, 'learning_rate': 9.02439024390244e-05, 'epoch': 0.1}


                                                
 10%|▉         | 40/410 [00:10<01:01,  6.03it/s]  

{'eval_loss': 1.0682835578918457, 'eval_accuracy': 0.4652862250804901, 'eval_runtime': 1.1942, 'eval_samples_per_second': 687.48, 'eval_steps_per_second': 172.498, 'epoch': 0.1}


 12%|█▏        | 50/410 [00:11<00:51,  6.94it/s]

{'loss': 1.1559, 'grad_norm': 7.282463550567627, 'learning_rate': 8.78048780487805e-05, 'epoch': 0.12}


                                                
 12%|█▏        | 50/410 [00:12<00:51,  6.94it/s]  

{'eval_loss': 1.048591136932373, 'eval_accuracy': 0.47015833854675293, 'eval_runtime': 1.2186, 'eval_samples_per_second': 673.729, 'eval_steps_per_second': 169.048, 'epoch': 0.12}


 15%|█▍        | 60/410 [00:13<00:50,  6.86it/s]

{'loss': 1.1213, 'grad_norm': 18.375062942504883, 'learning_rate': 8.53658536585366e-05, 'epoch': 0.15}


                                                
 15%|█▍        | 60/410 [00:14<00:50,  6.86it/s]  

{'eval_loss': 1.0111593008041382, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.2256, 'eval_samples_per_second': 669.903, 'eval_steps_per_second': 168.088, 'epoch': 0.15}


 17%|█▋        | 70/410 [00:15<00:48,  6.99it/s]

{'loss': 0.9405, 'grad_norm': 5.121417999267578, 'learning_rate': 8.292682926829268e-05, 'epoch': 0.17}


                                                
 17%|█▋        | 70/410 [00:16<00:48,  6.99it/s]  

{'eval_loss': 1.0978834629058838, 'eval_accuracy': 0.4908648133277893, 'eval_runtime': 1.2347, 'eval_samples_per_second': 664.928, 'eval_steps_per_second': 166.84, 'epoch': 0.17}


 20%|█▉        | 80/410 [00:17<00:46,  7.15it/s]

{'loss': 1.3695, 'grad_norm': 10.376787185668945, 'learning_rate': 8.073170731707318e-05, 'epoch': 0.19}


                                                
 20%|█▉        | 80/410 [00:19<00:46,  7.15it/s]  

{'eval_loss': 1.180609107017517, 'eval_accuracy': 0.47868454456329346, 'eval_runtime': 1.2721, 'eval_samples_per_second': 645.372, 'eval_steps_per_second': 161.933, 'epoch': 0.19}


 22%|██▏       | 90/410 [00:20<00:45,  7.01it/s]

{'loss': 1.4868, 'grad_norm': 4.4701924324035645, 'learning_rate': 7.853658536585367e-05, 'epoch': 0.22}


                                                
 22%|██▏       | 90/410 [00:21<00:45,  7.01it/s]  

{'eval_loss': 1.1565855741500854, 'eval_accuracy': 0.4823386073112488, 'eval_runtime': 1.1762, 'eval_samples_per_second': 698.034, 'eval_steps_per_second': 175.146, 'epoch': 0.22}


 24%|██▍       | 100/410 [00:22<00:41,  7.53it/s]

{'loss': 1.0092, 'grad_norm': 5.3505473136901855, 'learning_rate': 7.609756097560976e-05, 'epoch': 0.24}


                                                 
 24%|██▍       | 100/410 [00:23<00:41,  7.53it/s] 

{'eval_loss': 1.019242763519287, 'eval_accuracy': 0.47259441018104553, 'eval_runtime': 1.2182, 'eval_samples_per_second': 673.958, 'eval_steps_per_second': 169.105, 'epoch': 0.24}


 27%|██▋       | 110/410 [00:24<00:44,  6.72it/s]

{'loss': 0.8289, 'grad_norm': 2.3280856609344482, 'learning_rate': 7.365853658536585e-05, 'epoch': 0.27}


                                                 
 27%|██▋       | 110/410 [00:25<00:44,  6.72it/s] 

{'eval_loss': 1.1530696153640747, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.1916, 'eval_samples_per_second': 689.0, 'eval_steps_per_second': 172.88, 'epoch': 0.27}


 29%|██▉       | 120/410 [00:26<00:38,  7.44it/s]

{'loss': 1.1783, 'grad_norm': 6.410701274871826, 'learning_rate': 7.121951219512195e-05, 'epoch': 0.29}


                                                 
 29%|██▉       | 120/410 [00:28<00:38,  7.44it/s] 

{'eval_loss': 1.1240954399108887, 'eval_accuracy': 0.47137635946273804, 'eval_runtime': 1.2515, 'eval_samples_per_second': 656.037, 'eval_steps_per_second': 164.608, 'epoch': 0.29}


 32%|███▏      | 130/410 [00:29<00:41,  6.77it/s]

{'loss': 1.0242, 'grad_norm': 3.56785249710083, 'learning_rate': 6.878048780487805e-05, 'epoch': 0.32}


                                                 
 32%|███▏      | 130/410 [00:30<00:41,  6.77it/s] 

{'eval_loss': 1.014062523841858, 'eval_accuracy': 0.47259441018104553, 'eval_runtime': 1.2798, 'eval_samples_per_second': 641.514, 'eval_steps_per_second': 160.964, 'epoch': 0.32}


 34%|███▍      | 140/410 [00:31<00:38,  6.96it/s]

{'loss': 1.0119, 'grad_norm': 4.759372234344482, 'learning_rate': 6.634146341463415e-05, 'epoch': 0.34}


                                                 
 34%|███▍      | 140/410 [00:32<00:38,  6.96it/s] 

{'eval_loss': 1.0231857299804688, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.2138, 'eval_samples_per_second': 676.393, 'eval_steps_per_second': 169.716, 'epoch': 0.34}


 37%|███▋      | 150/410 [00:33<00:34,  7.54it/s]

{'loss': 0.9755, 'grad_norm': 3.1859962940216064, 'learning_rate': 6.390243902439025e-05, 'epoch': 0.37}


                                                 
 37%|███▋      | 150/410 [00:35<00:34,  7.54it/s] 

{'eval_loss': 1.017181158065796, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.2531, 'eval_samples_per_second': 655.175, 'eval_steps_per_second': 164.392, 'epoch': 0.37}


 39%|███▉      | 160/410 [00:36<00:31,  7.82it/s]

{'loss': 0.999, 'grad_norm': 4.143805503845215, 'learning_rate': 6.146341463414634e-05, 'epoch': 0.39}


                                                 
 39%|███▉      | 160/410 [00:37<00:31,  7.82it/s] 

{'eval_loss': 1.0147358179092407, 'eval_accuracy': 0.47259441018104553, 'eval_runtime': 1.3288, 'eval_samples_per_second': 617.834, 'eval_steps_per_second': 155.023, 'epoch': 0.39}


 41%|████▏     | 170/410 [00:38<00:35,  6.78it/s]

{'loss': 0.9565, 'grad_norm': 9.80313491821289, 'learning_rate': 5.902439024390244e-05, 'epoch': 0.41}


                                                 
 41%|████▏     | 170/410 [00:39<00:35,  6.78it/s] 

{'eval_loss': 1.0360673666000366, 'eval_accuracy': 0.47259441018104553, 'eval_runtime': 1.2188, 'eval_samples_per_second': 673.596, 'eval_steps_per_second': 169.014, 'epoch': 0.41}


 44%|████▍     | 180/410 [00:40<00:33,  6.95it/s]

{'loss': 1.1007, 'grad_norm': 7.774517059326172, 'learning_rate': 5.6585365853658533e-05, 'epoch': 0.44}


                                                 
 44%|████▍     | 180/410 [00:42<00:33,  6.95it/s] 

{'eval_loss': 0.9991481900215149, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.27, 'eval_samples_per_second': 646.458, 'eval_steps_per_second': 162.205, 'epoch': 0.44}


 46%|████▋     | 190/410 [00:43<00:32,  6.74it/s]

{'loss': 1.0925, 'grad_norm': 5.032519817352295, 'learning_rate': 5.414634146341464e-05, 'epoch': 0.46}


                                                 
 46%|████▋     | 190/410 [00:44<00:32,  6.74it/s] 

{'eval_loss': 1.0065792798995972, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.2542, 'eval_samples_per_second': 654.612, 'eval_steps_per_second': 164.251, 'epoch': 0.46}


 49%|████▉     | 200/410 [00:45<00:31,  6.73it/s]

{'loss': 0.9326, 'grad_norm': 6.278726577758789, 'learning_rate': 5.1707317073170736e-05, 'epoch': 0.49}


                                                 
 49%|████▉     | 200/410 [00:46<00:31,  6.73it/s] 

{'eval_loss': 1.0147473812103271, 'eval_accuracy': 0.47259441018104553, 'eval_runtime': 1.2275, 'eval_samples_per_second': 668.851, 'eval_steps_per_second': 167.824, 'epoch': 0.49}


 51%|█████     | 210/410 [00:47<00:31,  6.31it/s]

{'loss': 1.1036, 'grad_norm': 7.69120979309082, 'learning_rate': 4.926829268292683e-05, 'epoch': 0.51}


                                                 
 51%|█████     | 210/410 [00:49<00:31,  6.31it/s] 

{'eval_loss': 1.0611748695373535, 'eval_accuracy': 0.47259441018104553, 'eval_runtime': 1.2883, 'eval_samples_per_second': 637.251, 'eval_steps_per_second': 159.895, 'epoch': 0.51}


 54%|█████▎    | 220/410 [00:50<00:29,  6.54it/s]

{'loss': 1.1026, 'grad_norm': 4.128659725189209, 'learning_rate': 4.682926829268293e-05, 'epoch': 0.54}


                                                 
 54%|█████▎    | 220/410 [00:51<00:29,  6.54it/s] 

{'eval_loss': 1.0064667463302612, 'eval_accuracy': 0.47259441018104553, 'eval_runtime': 1.2769, 'eval_samples_per_second': 642.978, 'eval_steps_per_second': 161.332, 'epoch': 0.54}


 56%|█████▌    | 230/410 [00:52<00:24,  7.37it/s]

{'loss': 1.1025, 'grad_norm': 5.319516181945801, 'learning_rate': 4.4390243902439024e-05, 'epoch': 0.56}


                                                 
 56%|█████▌    | 230/410 [00:53<00:24,  7.37it/s] 

{'eval_loss': 1.0025233030319214, 'eval_accuracy': 0.47381243109703064, 'eval_runtime': 1.3103, 'eval_samples_per_second': 626.566, 'eval_steps_per_second': 157.214, 'epoch': 0.56}


 59%|█████▊    | 240/410 [00:55<00:23,  7.29it/s]

{'loss': 1.058, 'grad_norm': 2.92557430267334, 'learning_rate': 4.195121951219512e-05, 'epoch': 0.58}


                                                 
 59%|█████▊    | 240/410 [00:56<00:23,  7.29it/s] 

{'eval_loss': 0.9859239459037781, 'eval_accuracy': 0.4835566282272339, 'eval_runtime': 1.3476, 'eval_samples_per_second': 609.249, 'eval_steps_per_second': 152.869, 'epoch': 0.58}


 61%|██████    | 250/410 [00:57<00:21,  7.30it/s]

{'loss': 1.1351, 'grad_norm': 8.762629508972168, 'learning_rate': 3.951219512195122e-05, 'epoch': 0.61}


                                                 
 61%|██████    | 250/410 [00:58<00:21,  7.30it/s] 

{'eval_loss': 0.9825782179832458, 'eval_accuracy': 0.4872107207775116, 'eval_runtime': 1.4443, 'eval_samples_per_second': 568.439, 'eval_steps_per_second': 142.629, 'epoch': 0.61}


 63%|██████▎   | 260/410 [00:59<00:20,  7.29it/s]

{'loss': 1.0278, 'grad_norm': 5.99693489074707, 'learning_rate': 3.707317073170732e-05, 'epoch': 0.63}


                                                 
 63%|██████▎   | 260/410 [01:01<00:20,  7.29it/s] 

{'eval_loss': 1.0397249460220337, 'eval_accuracy': 0.4872107207775116, 'eval_runtime': 1.2616, 'eval_samples_per_second': 650.739, 'eval_steps_per_second': 163.279, 'epoch': 0.63}


 66%|██████▌   | 270/410 [01:02<00:19,  7.11it/s]

{'loss': 1.1505, 'grad_norm': 3.9375317096710205, 'learning_rate': 3.4634146341463416e-05, 'epoch': 0.66}


                                                 
 66%|██████▌   | 270/410 [01:03<00:19,  7.11it/s] 

{'eval_loss': 0.9953012466430664, 'eval_accuracy': 0.4847746789455414, 'eval_runtime': 1.3188, 'eval_samples_per_second': 622.52, 'eval_steps_per_second': 156.199, 'epoch': 0.66}


 68%|██████▊   | 280/410 [01:04<00:21,  6.00it/s]

{'loss': 0.9947, 'grad_norm': 3.86800217628479, 'learning_rate': 3.2195121951219514e-05, 'epoch': 0.68}


                                                 
 68%|██████▊   | 280/410 [01:05<00:21,  6.00it/s] 

{'eval_loss': 0.9717870354652405, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.2413, 'eval_samples_per_second': 661.382, 'eval_steps_per_second': 165.95, 'epoch': 0.68}


 71%|███████   | 290/410 [01:06<00:15,  7.68it/s]

{'loss': 0.86, 'grad_norm': 2.8648548126220703, 'learning_rate': 2.975609756097561e-05, 'epoch': 0.71}


                                                 
 71%|███████   | 290/410 [01:08<00:15,  7.68it/s] 

{'eval_loss': 1.0334666967391968, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.2491, 'eval_samples_per_second': 657.272, 'eval_steps_per_second': 164.918, 'epoch': 0.71}


 73%|███████▎  | 300/410 [01:09<00:14,  7.59it/s]

{'loss': 1.1464, 'grad_norm': 5.990636348724365, 'learning_rate': 2.731707317073171e-05, 'epoch': 0.73}


                                                 
 73%|███████▎  | 300/410 [01:10<00:14,  7.59it/s] 

{'eval_loss': 1.0052778720855713, 'eval_accuracy': 0.47990256547927856, 'eval_runtime': 1.4365, 'eval_samples_per_second': 571.511, 'eval_steps_per_second': 143.4, 'epoch': 0.73}


 76%|███████▌  | 310/410 [01:11<00:13,  7.24it/s]

{'loss': 1.1137, 'grad_norm': 7.039565563201904, 'learning_rate': 2.4878048780487805e-05, 'epoch': 0.76}


                                                 
 76%|███████▌  | 310/410 [01:13<00:13,  7.24it/s] 

{'eval_loss': 0.9726890325546265, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.304, 'eval_samples_per_second': 629.625, 'eval_steps_per_second': 157.981, 'epoch': 0.76}


 78%|███████▊  | 320/410 [01:14<00:13,  6.50it/s]

{'loss': 0.9822, 'grad_norm': 6.407066345214844, 'learning_rate': 2.2439024390243904e-05, 'epoch': 0.78}


                                                 
 78%|███████▊  | 320/410 [01:15<00:13,  6.50it/s] 

{'eval_loss': 0.9701511859893799, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.3287, 'eval_samples_per_second': 617.896, 'eval_steps_per_second': 155.038, 'epoch': 0.78}


 80%|████████  | 330/410 [01:16<00:11,  7.17it/s]

{'loss': 1.047, 'grad_norm': 6.650015830993652, 'learning_rate': 2e-05, 'epoch': 0.8}


                                                 
 80%|████████  | 330/410 [01:18<00:11,  7.17it/s] 

{'eval_loss': 0.9714782238006592, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.4252, 'eval_samples_per_second': 576.078, 'eval_steps_per_second': 144.546, 'epoch': 0.8}


 83%|████████▎ | 340/410 [01:19<00:10,  6.76it/s]

{'loss': 1.0672, 'grad_norm': 3.011765718460083, 'learning_rate': 1.7560975609756096e-05, 'epoch': 0.83}


                                                 
 83%|████████▎ | 340/410 [01:20<00:10,  6.76it/s] 

{'eval_loss': 0.9702971577644348, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.3916, 'eval_samples_per_second': 589.954, 'eval_steps_per_second': 148.027, 'epoch': 0.83}


 85%|████████▌ | 350/410 [01:21<00:07,  7.59it/s]

{'loss': 0.9919, 'grad_norm': 3.513758420944214, 'learning_rate': 1.5121951219512196e-05, 'epoch': 0.85}


                                                 
 85%|████████▌ | 350/410 [01:23<00:07,  7.59it/s] 

{'eval_loss': 0.9764745235443115, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.7068, 'eval_samples_per_second': 481.018, 'eval_steps_per_second': 120.694, 'epoch': 0.85}


 88%|████████▊ | 360/410 [01:24<00:08,  6.09it/s]

{'loss': 0.9431, 'grad_norm': 11.876274108886719, 'learning_rate': 1.2682926829268294e-05, 'epoch': 0.88}


                                                 
 88%|████████▊ | 360/410 [01:26<00:08,  6.09it/s] 

{'eval_loss': 0.9711700081825256, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.3813, 'eval_samples_per_second': 594.363, 'eval_steps_per_second': 149.134, 'epoch': 0.88}


 90%|█████████ | 370/410 [01:27<00:06,  6.11it/s]

{'loss': 0.9602, 'grad_norm': 5.580501079559326, 'learning_rate': 1.024390243902439e-05, 'epoch': 0.9}


                                                 
 90%|█████████ | 370/410 [01:28<00:06,  6.11it/s] 

{'eval_loss': 0.9689047932624817, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.3125, 'eval_samples_per_second': 625.537, 'eval_steps_per_second': 156.956, 'epoch': 0.9}


 93%|█████████▎| 380/410 [01:29<00:04,  6.73it/s]

{'loss': 0.9685, 'grad_norm': 7.638960361480713, 'learning_rate': 7.804878048780489e-06, 'epoch': 0.93}


                                                 
 93%|█████████▎| 380/410 [01:30<00:04,  6.73it/s] 

{'eval_loss': 0.9678650498390198, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.2593, 'eval_samples_per_second': 651.94, 'eval_steps_per_second': 163.58, 'epoch': 0.93}


 95%|█████████▌| 390/410 [01:32<00:02,  7.00it/s]

{'loss': 0.9696, 'grad_norm': 9.11452865600586, 'learning_rate': 5.365853658536585e-06, 'epoch': 0.95}


                                                 
 95%|█████████▌| 390/410 [01:33<00:02,  7.00it/s] 

{'eval_loss': 0.9682180285453796, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.5129, 'eval_samples_per_second': 542.671, 'eval_steps_per_second': 136.164, 'epoch': 0.95}


 98%|█████████▊| 400/410 [01:34<00:01,  6.81it/s]

{'loss': 0.9777, 'grad_norm': 2.898589611053467, 'learning_rate': 2.9268292682926833e-06, 'epoch': 0.97}


                                                 
 98%|█████████▊| 400/410 [01:36<00:01,  6.81it/s] 

{'eval_loss': 0.9683374762535095, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.6236, 'eval_samples_per_second': 505.665, 'eval_steps_per_second': 126.878, 'epoch': 0.97}


100%|██████████| 410/410 [01:37<00:00,  6.43it/s]

{'loss': 0.8512, 'grad_norm': 7.0485663414001465, 'learning_rate': 4.878048780487805e-07, 'epoch': 1.0}


                                                 
100%|██████████| 410/410 [01:39<00:00,  6.43it/s] 

{'eval_loss': 0.9682682752609253, 'eval_accuracy': 0.48112058639526367, 'eval_runtime': 1.4587, 'eval_samples_per_second': 562.836, 'eval_steps_per_second': 141.223, 'epoch': 1.0}


100%|██████████| 410/410 [01:39<00:00,  4.11it/s]

{'train_runtime': 99.6941, 'train_samples_per_second': 32.931, 'train_steps_per_second': 4.113, 'train_loss': 1.0654193878173828, 'epoch': 1.0}





TrainOutput(global_step=410, training_loss=1.0654193878173828, metrics={'train_runtime': 99.6941, 'train_samples_per_second': 32.931, 'train_steps_per_second': 4.113, 'total_flos': 0.0, 'train_loss': 1.0654193878173828, 'epoch': 0.9987819732034104})

In [12]:
# Run a standalone evaluation pass on eval_dataset after training; the returned
# metrics dict is the last expression, so it is displayed as this cell's output.
# NOTE(review): in the training log above, eval_accuracy is identical
# (0.48112058639526367) at every evaluation from step 310 through 410 and this
# call reproduces the step-410 metrics exactly — possibly the model is
# predicting a single class; verify with a confusion matrix / label counts.
trainer.evaluate(eval_dataset=eval_dataset)

100%|██████████| 206/206 [00:01<00:00, 143.15it/s]


{'eval_loss': 0.9682682752609253,
 'eval_accuracy': 0.48112058639526367,
 'eval_runtime': 1.4516,
 'eval_samples_per_second': 565.574,
 'eval_steps_per_second': 141.91,
 'epoch': 0.9987819732034104}