In [1]:
import random
import numpy as np
import torch
import os

def set_seed(seed):
    """
    Seed every random number generator used in this notebook so results are reproducible.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU generator plus the
    current and all CUDA devices), switches cuDNN to deterministic kernels with autotuning
    off, and exports ``PYTHONHASHSEED``.

    Note: per the Python docs, ``PYTHONHASHSEED`` is only read when an interpreter starts,
    so setting it here does not change str/bytes hashing in the running kernel; it only
    affects Python processes launched afterwards.

    Args:
        seed (int): seed value applied to every generator.
    """
    seeders = (
        random.seed,                 # Python stdlib RNG
        np.random.seed,              # NumPy global RNG
        torch.manual_seed,           # PyTorch generator
        torch.cuda.manual_seed,      # current GPU
        torch.cuda.manual_seed_all,  # every GPU (multi-GPU setups)
    )
    for seeder in seeders:
        seeder(seed)

    # cuDNN: deterministic algorithms only, no benchmark-driven kernel selection
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = str(seed)


seed = 2003  # any integer works; change it for a different reproducible run
set_seed(seed)

In [2]:
import os
import shutil
from tqdm import tqdm

def organize_files(src_root_dir, dst_root_dir, cases=('case1', 'case2', 'case3', 'case4')):
    """
    Flatten ``src_root_dir/<main_dir>/<sub_dir>/*.npy`` into ``dst_root_dir/<case>/``.

    Every copied file is renamed ``<main_dir>_<sub_dir>_<filename>`` so names stay unique
    after flattening. The case a file belongs to is the part of ``sub_dir`` before the
    first underscore (e.g. ``case2_6`` -> ``case2``).

    Args:
        src_root_dir (str): root holding one folder per main directory.
        dst_root_dir (str): destination root; one sub-folder per case is created.
        cases (Iterable[str]): case names to create and accept. Sub-directories whose
            case prefix is not listed are reported and skipped, instead of failing on a
            destination folder that was never created.
    """
    cases = list(cases)

    # One destination folder per case (no-op if it already exists)
    for case in cases:
        os.makedirs(os.path.join(dst_root_dir, case), exist_ok=True)

    # sorted(): deterministic processing order (os.listdir order is arbitrary)
    for main_dir in tqdm(sorted(os.listdir(src_root_dir)), desc="处理目录"):
        main_dir_path = os.path.join(src_root_dir, main_dir)
        if not os.path.isdir(main_dir_path):
            continue
        for sub_dir in sorted(os.listdir(main_dir_path)):
            sub_dir_path = os.path.join(main_dir_path, sub_dir)
            if not os.path.isdir(sub_dir_path):
                continue

            case_name = sub_dir.split('_')[0]
            if case_name not in cases:
                print(f"跳过未知 case 的目录: {sub_dir_path}")
                continue
            case_dir = os.path.join(dst_root_dir, case_name)

            for filename in os.listdir(sub_dir_path):
                if filename.endswith('.npy'):
                    src_npy_path = os.path.join(sub_dir_path, filename)
                    dst_npy_path = os.path.join(case_dir, f"{main_dir}_{sub_dir}_{filename}")
                    shutil.copy(src_npy_path, dst_npy_path)
                            # print(f"已复制 {src_npy_path} 到 {dst_npy_path}")

# NOTE(review): the source is the "withbandpass" dataset while the destination is named
# "nobandpass", and the recorded output of this cell shows a "nobandpass" source.
# Confirm which dataset is intended before re-running.
src_directory = 'dataset_1k2k3k_withbandpass_extrafeatures'
dst_directory = 'data_1k2k3k_nobandpass_organized_dataset_1'

print(f"源目录: {src_directory}")
print(f"目标目录: {dst_directory}")

# Only organize when the source tree is actually present
if os.path.exists(src_directory):
    organize_files(src_directory, dst_directory)
    print("文件整理完成。")
else:
    print(f"源目录不存在: {src_directory}")

源目录: dataset_1k2k3k_nobandpass_extrafeatures
目标目录: data_1k2k3k_nobandpass_organized_dataset_1
源目录不存在: dataset_1k2k3k_nobandpass_extrafeatures


In [3]:
import numpy as np
import pandas as pd

from pathlib import Path
from tqdm import tqdm

import torchaudio
from sklearn.model_selection import train_test_split
import os
import sys

In [4]:
import os
from pathlib import Path
import numpy as np
from tqdm import tqdm

# Collect one record per organized .npy file: metadata parsed from the file name plus
# the loaded energy-feature array.
data = []

for case in ['case1', 'case2', 'case3', 'case4']:
    case_path = Path(f'data_1k2k3k_nobandpass_organized_dataset_1/{case}')
    # Sort: glob order is filesystem-dependent, and row order feeds the later seeded
    # shuffle / train-test split, so a stable order keeps runs reproducible.
    for path in tqdm(sorted(case_path.glob("*.npy")), desc=f"处理 {case}"):
        name = path.stem
        try:
            # Names look like "E7_case2_6_sample_36_3.npy":
            # <prefix>_<case_id ...>_<sample_set>_<n>.npy
            parts = path.name.split('_')
            prefix = parts[0]
            case_id = '_'.join(parts[1:-2])  # may span several underscore-separated parts
            sample_set = parts[-2]
            # Load inside the try so one unreadable file is reported and skipped. Every
            # field is assigned fresh for this file, so a failure can never attach a
            # previous file's features to this record.
            energy_features = np.load(path)
        except Exception as e:
            print(f"处理文件 {path} 时出错: {str(e)}")
            # Skip damaged / malformed files
            continue

        data.append({
            "name": name,
            "path": str(path),
            "case": case,
            "prefix": prefix,
            "case_id": case_id,
            "sample_set": sample_set,
            "energy_features": energy_features
        })

# Number of collected records
print(f"收集了 {len(data)} 个项目。")

# Per-case record counts
case_counts = {case: sum(1 for item in data if item['case'] == case) for case in ['case1', 'case2', 'case3', 'case4']}
print("\n数据分布:")
for case, count in case_counts.items():
    print(f"{case}: {count} 个项目")

# Check that every record carries energy features
items_with_features = sum(1 for item in data if item['energy_features'] is not None)
print(f"\n具有能量特征的项目: {items_with_features} / {len(data)}")

处理 case1: 0it [00:00, ?it/s]
处理 case2: 0it [00:00, ?it/s]
处理 case3: 0it [00:00, ?it/s]
处理 case4: 0it [00:00, ?it/s]

收集了 0 个项目。

数据分布:
case1: 0 个项目
case2: 0 个项目
case3: 0 个项目
case4: 0 个项目

具有能量特征的项目: 0 / 0





In [5]:
import pandas as pd
df = pd.DataFrame(data=data)  # one row per collected record
df.head(n=5)

In [6]:
import pandas as pd
df = pd.DataFrame(data)
# Shuffle with an explicit seed: re-running this cell now yields the same row order, and
# therefore the same stratified train/test split downstream, instead of depending on how
# much of NumPy's global RNG earlier cell executions have consumed.
df = df.sample(frac=1, random_state=seed).reset_index(drop=True)
df.head()

Let's display a random sample of the dataset (re-run it a couple of times) to get a feeling for the energy features and their case labels.

In [7]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import TrainingArguments, Trainer
import torch
from torch import nn

# 处理 energy_features 列的函数
def process_energy_features(df):
    new_values = []
    for v in df["energy_features"]:
        if isinstance(v, np.ndarray):
            new_values.append(float(v.item()))  # 使用 item() 来获取数组中的单个值
        else:
            new_values.append(float(v))
    df["energy_float"] = new_values
    return df

# Replace energy_features with its plain-float version under the same column name
df = (
    process_energy_features(df)
    .drop(columns=["energy_features"])
    .rename(columns={"energy_float": "energy_features"})
)

# Label-encoding helper
def label_to_id(label, label_list):
    """Return the integer id of ``label``: its index in ``label_list`` (ValueError if absent)."""
    position = label_list.index(label)
    return position

# Target column that holds the class label
output_column = "case"

# Stratified 80/20 split so every case keeps its share in both halves
train_df, test_df = train_test_split(
    df,
    test_size=0.2,
    random_state=101,
    stratify=df[output_column],
)

# Fresh 0..n-1 indices for both splits
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

print("训练数据集形状:", train_df.shape)
print("测试数据集形状:", test_df.shape)

# Wrap the pandas frames as Hugging Face datasets
train_dataset = Dataset.from_pandas(train_df)
eval_dataset = Dataset.from_pandas(test_df)

# Class balance per split
print("\n训练数据集中每个 case 的样本数:")
print(train_df[output_column].value_counts())
print("\n验证数据集中每个 case 的样本数:")
print(test_df[output_column].value_counts())

# Sorted, de-duplicated label names; a label's position in this list is its integer id
label_list = sorted(df[output_column].unique())
num_labels = len(label_list)
print(f"\n这是一个有 {num_labels} 个类别的分类问题: {label_list}")

# Batched Dataset.map callback
def preprocess_function(examples):
    """
    Encode a batch for training.

    Each ``case`` name is mapped to its integer id with ``label_to_id`` over the global
    ``label_list``; ``energy_features`` is passed through unchanged.
    """
    return {
        "labels": [label_to_id(case_name, label_list) for case_name in examples["case"]],
        "energy_features": examples["energy_features"],
    }

# Encode labels and keep only the model inputs: every original column is removed and
# replaced by what preprocess_function returns ("labels", "energy_features").
train_dataset = train_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
eval_dataset = eval_dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=eval_dataset.column_names,
)

# Peek at the first five encoded rows of each split
print("\n带有能量特征的训练数据集:")
print(train_dataset[:5])
print("\n带有能量特征的验证数据集:")
print(eval_dataset[:5])

# Count rows whose energy feature is present (not None)
train_with_features = sum(row['energy_features'] is not None for row in train_dataset)
eval_with_features = sum(row['energy_features'] is not None for row in eval_dataset)
print(f"\n训练样本中包含能量特征的数量: {train_with_features} / {len(train_dataset)}")
print(f"验证样本中包含能量特征的数量: {eval_with_features} / {len(eval_dataset)}")

# Baseline model
class SimpleClassifier(nn.Module):
    """Single linear layer mapping ``input_dim`` features to ``num_labels`` logits."""

    def __init__(self, input_dim, num_labels):
        super(SimpleClassifier, self).__init__()
        # The attribute name fixes the state_dict keys (linear.weight / linear.bias)
        self.linear = nn.Linear(in_features=input_dim, out_features=num_labels)

    def forward(self, x):
        logits = self.linear(x)
        return logits

# Batch collation
def collate_fn(batch):
    """Stack per-example dicts into float ``energy_features`` and long ``labels`` tensors."""
    features = [example['energy_features'] for example in batch]
    label_ids = [example['labels'] for example in batch]
    return {
        'energy_features': torch.tensor(features, dtype=torch.float),
        'labels': torch.tensor(label_ids, dtype=torch.long),
    }

# Evaluation metric
def compute_metrics(eval_pred):
    """Accuracy of arg-max predictions; ``eval_pred`` unpacks to (logits, label ids)."""
    logits, labels = eval_pred
    hits = np.argmax(logits, axis=1) == labels
    return {'accuracy': hits.mean()}

# Baseline: one linear layer over the single scalar energy feature
model = SimpleClassifier(input_dim=1, num_labels=num_labels)

# Sanity-check the converted feature column
print("\n数据类型检查:")
print("energy_features 的数据类型:", df.energy_features.dtype)
print("\n前5个样本的 energy_features:")
print(df.energy_features.head())

  from .autonotebook import tqdm as notebook_tqdm


KeyError: 'energy_features'

In [45]:
df.iloc[:5]  # first five rows

Unnamed: 0,name,path,case,prefix,case_id,sample_set,energy_features
0,E7_case2_6_sample_36_3,data_1k2k3k_nobandpass_organized_dataset_1\cas...,case2,E7,case2_6_sample,36,0.428939
1,E8_case4_7_sample_41_2,data_1k2k3k_nobandpass_organized_dataset_1\cas...,case4,E8,case4_7_sample,41,184.313711
2,E6_case3_11_sample_66_2,data_1k2k3k_nobandpass_organized_dataset_1\cas...,case3,E6,case3_11_sample,66,25.09655
3,A1_case4_8_sample_47_2,data_1k2k3k_nobandpass_organized_dataset_1\cas...,case4,A1,case4_8_sample,47,128.370998
4,A4_case4_11_sample_66_2,data_1k2k3k_nobandpass_organized_dataset_1\cas...,case4,A4,case4_11_sample,66,92.202606


## Model

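Before moving on to training, we build the classification model: the encoder's hidden states are pooled into a single vector according to the configured merge (pooling) strategy (`mean`, `sum` or `max`), and that vector is then classified.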

In [49]:
from dataclasses import dataclass
from typing import Optional, Tuple, Dict, List, Union
import torch
import torch.nn as nn
import numpy as np
from transformers import PreTrainedModel, AutoConfig, TrainingArguments, Trainer, EvalPrediction
from transformers.file_utils import ModelOutput

@dataclass
class EnergyFeatureClassifierOutput(ModelOutput):
    """
    Output container returned by Wav2Vec2ForEnergyClassification.forward.

    Keep the field order stable: as a transformers ModelOutput this object is presumably
    also indexable / convertible to a tuple in declaration order (library convention).
    """

    # Training loss; None when no labels were passed to forward.
    loss: Optional[torch.FloatTensor] = None
    # Unnormalized class scores from the classification head, one column per label.
    logits: torch.FloatTensor = None
    # Not populated by Wav2Vec2ForEnergyClassification (left as None).
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Not populated by Wav2Vec2ForEnergyClassification (left as None).
    attentions: Optional[Tuple[torch.FloatTensor]] = None

class Wav2Vec2ClassificationHead(nn.Module):
    """
    Two-layer classification head: dropout -> dense -> tanh -> dropout -> out_proj.

    Reads ``config.hidden_size``, ``config.final_dropout`` and ``config.num_labels``;
    maps a (..., hidden_size) tensor to (..., num_labels) logits.
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        # Attribute names are kept stable because they define the state_dict keys
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features):
        # The same dropout module is applied before and after the hidden layer
        hidden = torch.tanh(self.dense(self.dropout(features)))
        return self.out_proj(self.dropout(hidden))

class EnergyFeatureEncoder(nn.Module):
    """
    Lift each scalar energy feature to a ``config.hidden_size`` vector.

    Every scalar gets a trailing feature axis and passes through
    Linear(1, 256) -> ReLU -> Linear(256, hidden_size), then LayerNorm and dropout
    (``config.hidden_dropout_prob``).

    NOTE(review): LayerNorm right after a projection of a single scalar is close to
    scale-invariant once the input dominates the biases, so large raw energies (the
    displayed frame ranges from ~0.4 to ~184) may map to nearly identical vectors.
    Consider normalizing / log-scaling the feature upstream — TODO confirm.
    """

    def __init__(self, config):
        super().__init__()
        # Keep the nn.Sequential layout: its indices (proj.0 / proj.2) are the state_dict keys
        self.proj = nn.Sequential(
            nn.Linear(1, 256),
            nn.ReLU(),
            nn.Linear(256, config.hidden_size),
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, energy_features):
        as_column = energy_features.unsqueeze(-1)  # (...,) -> (..., 1) for Linear(1, 256)
        return self.dropout(self.layer_norm(self.proj(as_column)))

class SimpleWav2Vec2Model(nn.Module):
    """
    Stack of ``config.num_hidden_layers`` standard Transformer encoder layers
    (``config.hidden_size`` wide, ``config.num_attention_heads`` heads).

    Takes batch-first input of shape (batch, seq, hidden_size) and returns the same shape;
    it transposes to the sequence-first layout nn.TransformerEncoder uses by default.
    """

    def __init__(self, config):
        super().__init__()
        encoder_layer = nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_attention_heads)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_hidden_layers)

    def forward(self, hidden_states, attention_mask=None):
        """
        Args:
            hidden_states: (batch, seq, hidden_size) tensor.
            attention_mask: optional (batch, seq) mask in the Hugging Face convention:
                1 / True = real position, 0 / False = padding.

        Returns:
            (batch, seq, hidden_size) encoded tensor.
        """
        key_padding_mask = None
        if attention_mask is not None:
            # src_key_padding_mask must be a (batch, seq) boolean mask that is True at the
            # positions to IGNORE, i.e. the inverse of an HF attention mask. An HF-style
            # additive (batch, 1, 1, seq) mask does not fit this argument.
            key_padding_mask = attention_mask == 0
        encoded = self.encoder(hidden_states.transpose(0, 1), src_key_padding_mask=key_padding_mask)
        return encoded.transpose(0, 1)

class Wav2Vec2ForEnergyClassification(PreTrainedModel):
    """
    Classifier over scalar energy features built from wav2vec2-style components.

    Forward pipeline:
        EnergyFeatureEncoder (scalar -> hidden_size vector)
        -> length-1 sequence -> SimpleWav2Vec2Model (Transformer encoder)
        -> pooling over the sequence axis (``config.pooling_mode``)
        -> Wav2Vec2ClassificationHead (-> ``config.num_labels`` logits).

    Loss depends on ``config.problem_type``:
        "single_label_classification" -> CrossEntropyLoss on integer class ids,
        "multi_label_classification"  -> BCEWithLogitsLoss on multi-hot targets,
        anything else (e.g. "regression") -> MSELoss.
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_labels
        self.energy_feature_encoder = EnergyFeatureEncoder(config)
        self.wav2vec2 = SimpleWav2Vec2Model(config)
        self.classifier = Wav2Vec2ClassificationHead(config)

    def freeze_feature_extractor(self):
        """Disable gradients for the energy-feature encoder so training leaves it unchanged."""
        for param in self.energy_feature_encoder.parameters():
            param.requires_grad = False

    def merged_strategy(self, hidden_states, mode="mean"):
        """Pool hidden states over dim 1 (the sequence axis in forward) with mean, sum or max."""
        if mode == "mean":
            return torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            return torch.sum(hidden_states, dim=1)
        elif mode == "max":
            return torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError("Invalid pooling mode: choose from ['mean', 'sum', 'max']")

    def forward(self, energy_features, attention_mask=None, labels=None, return_dict=True):
        """
        Args:
            energy_features: float tensor of scalar features; with this notebook's collator
                it is shape (batch,).
            attention_mask: optional mask forwarded to the encoder — presumably 1 = keep,
                0 = padding, shape (batch, 1) for the length-1 sequence; TODO confirm.
            labels: optional targets whose form depends on ``config.problem_type``.
            return_dict: if False, return ``(loss, logits)`` (or just ``logits`` when no
                labels were given) instead of an EnergyFeatureClassifierOutput.
        """
        # Encode each scalar and add a length-1 sequence axis for the Transformer encoder
        hidden_states = self.energy_feature_encoder(energy_features).unsqueeze(1)
        hidden_states = self.wav2vec2(hidden_states, attention_mask)
        pooled_output = self.merged_strategy(hidden_states, mode=self.config.pooling_mode)
        logits = self.classifier(pooled_output)

        loss = self._compute_loss(logits, labels) if labels is not None else None

        if not return_dict:
            return (loss, logits) if loss is not None else logits

        return EnergyFeatureClassifierOutput(loss=loss, logits=logits)

    def _compute_loss(self, logits, labels):
        """Loss matching ``config.problem_type``, with targets shaped/cast to fit the logits."""
        problem_type = self.config.problem_type
        if problem_type == "single_label_classification":
            return nn.CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))
        if problem_type == "multi_label_classification":
            # Multi-hot targets: one independent sigmoid per label
            return nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))
        # Regression (any other problem_type): compare like-shaped tensors — one target per
        # row when num_labels == 1, otherwise a (batch, num_labels) target.
        loss_fct = nn.MSELoss()
        if self.num_labels == 1:
            return loss_fct(logits.view(-1), labels.view(-1).to(logits.dtype))
        return loss_fct(logits, labels.to(logits.dtype))

@dataclass
class DataCollatorForEnergyFeatures:
    """
    Collate per-example dicts into a batch of tensors for the Trainer.

    Each feature dict must contain ``labels`` and ``energy_features``. Labels become
    ``torch.long`` when they are integers (Python ``int`` or NumPy integer scalars), as
    CrossEntropyLoss requires, and ``torch.float`` otherwise (e.g. regression targets).
    Energy features are always collated as ``torch.float``.
    """

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor, np.ndarray]]]) -> Dict[str, torch.Tensor]:
        labels = [feature["labels"] for feature in features]
        energy_features = [feature["energy_features"] for feature in features]

        # np.int64 and friends are not subclasses of int; without this check they would be
        # collated as float and break CrossEntropyLoss, which needs integer class indices.
        label_dtype = torch.long if isinstance(labels[0], (int, np.integer)) else torch.float

        batch = {
            "labels": torch.tensor(labels, dtype=label_dtype),
            "energy_features": torch.tensor(energy_features, dtype=torch.float),
        }
        return batch

def compute_metrics(p: EvalPrediction):
    """
    Trainer metric callback.

    Classification: accuracy of the arg-max over the logits -> ``{"accuracy": ...}``.
    Regression (module-level ``is_regression`` True): mean squared error of the squeezed
    predictions -> ``{"mse": ...}``. ``is_regression`` must be defined before evaluation runs.
    """
    predictions = p.predictions
    # Models that return extra outputs hand the Trainer a tuple; the logits come first
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    if is_regression:
        preds = np.squeeze(predictions)
        return {"mse": ((preds - p.label_ids) ** 2).mean()}

    preds = np.argmax(predictions, axis=1)
    return {"accuracy": (preds == p.label_ids).mean()}

# Load configuration.
# Only the config is downloaded (network access required); the model below is built from
# it with freshly initialized weights — no pretrained weights are loaded.
model_name_or_path = "facebook/wav2vec2-large-xlsr-53"  # Example pretrained model
config = AutoConfig.from_pretrained(model_name_or_path)

# Update configuration for energy features (overrides the downloaded wav2vec2 sizes)
config.energy_feature_dim = 1
config.num_labels = num_labels  # taken from label_list so the head matches the data
config.problem_type = "single_label_classification"
config.pooling_mode = "mean"
config.hidden_dropout_prob = 0.1
config.final_dropout = 0.1
config.hidden_size = 768
config.num_attention_heads = 12
config.num_hidden_layers = 12

# Initialize model, data collator, and trainer
model = Wav2Vec2ForEnergyClassification(config)
data_collator = DataCollatorForEnergyFeatures()
# Read by compute_metrics at evaluation time
is_regression = config.problem_type == "regression"

training_args = TrainingArguments(
    output_dir="./results_energybased",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    # NOTE(review): renamed to `eval_strategy` in newer transformers releases
    evaluation_strategy="steps",
    num_train_epochs=1,
    # fp16 mixed precision needs a CUDA device; TrainingArguments rejects fp16=True on a
    # CPU-only machine, so only enable it when a GPU is present.
    fp16=torch.cuda.is_available(),
    save_steps=10,
    eval_steps=10,
    logging_steps=10,
    learning_rate=1e-4,
    save_total_limit=2,
    # NOTE(review): TrainingArguments also has its own `seed` (not set here) that the Trainer
    # applies; pass seed=seed if it should match the notebook-wide seed — confirm.
)

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,  # Ensure these datasets are defined
    eval_dataset=eval_dataset,
)

# Start training
trainer.train()


  0%|          | 0/410 [00:00<?, ?it/s]Could not estimate the number of tokens of the input, floating-point operations will not be computed
  2%|▏         | 10/410 [00:01<00:41,  9.55it/s]

{'loss': 1.6713, 'grad_norm': 7.762096405029297, 'learning_rate': 9.75609756097561e-05, 'epoch': 0.02}



  2%|▏         | 10/410 [00:02<00:41,  9.55it/s]  

{'eval_loss': 1.4148029088974, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 1.8498, 'eval_samples_per_second': 443.838, 'eval_steps_per_second': 111.365, 'epoch': 0.02}


  5%|▍         | 20/410 [00:04<01:03,  6.09it/s]

{'loss': 1.4571, 'grad_norm': 4.453792095184326, 'learning_rate': 9.51219512195122e-05, 'epoch': 0.05}



  5%|▍         | 20/410 [00:06<01:03,  6.09it/s] 

{'eval_loss': 1.3842964172363281, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.1663, 'eval_samples_per_second': 378.986, 'eval_steps_per_second': 95.093, 'epoch': 0.05}


  7%|▋         | 30/410 [00:08<01:23,  4.53it/s]

{'loss': 1.425, 'grad_norm': 9.702316284179688, 'learning_rate': 9.26829268292683e-05, 'epoch': 0.07}



  7%|▋         | 30/410 [00:11<01:23,  4.53it/s] 

{'eval_loss': 1.424659013748169, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.9421, 'eval_samples_per_second': 279.055, 'eval_steps_per_second': 70.019, 'epoch': 0.07}


 10%|▉         | 40/410 [00:13<01:15,  4.89it/s]

{'loss': 1.4267, 'grad_norm': 4.601459980010986, 'learning_rate': 9.02439024390244e-05, 'epoch': 0.1}



 10%|▉         | 40/410 [00:16<01:15,  4.89it/s] 

{'eval_loss': 1.4792412519454956, 'eval_accuracy': 0.1668696711327649, 'eval_runtime': 3.1141, 'eval_samples_per_second': 263.636, 'eval_steps_per_second': 66.15, 'epoch': 0.1}


 12%|█▏        | 50/410 [00:18<01:07,  5.34it/s]

{'loss': 1.3881, 'grad_norm': 9.50358772277832, 'learning_rate': 8.78048780487805e-05, 'epoch': 0.12}



 12%|█▏        | 50/410 [00:21<01:07,  5.34it/s] 

{'eval_loss': 1.3414673805236816, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 3.0396, 'eval_samples_per_second': 270.098, 'eval_steps_per_second': 67.771, 'epoch': 0.12}


 15%|█▍        | 60/410 [00:23<01:08,  5.08it/s]

{'loss': 1.3666, 'grad_norm': 6.82242488861084, 'learning_rate': 8.53658536585366e-05, 'epoch': 0.15}



 15%|█▍        | 60/410 [00:26<01:08,  5.08it/s] 

{'eval_loss': 1.37977933883667, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.8939, 'eval_samples_per_second': 283.705, 'eval_steps_per_second': 71.185, 'epoch': 0.15}


 17%|█▋        | 70/410 [00:28<01:04,  5.28it/s]

{'loss': 1.3262, 'grad_norm': 6.691375255584717, 'learning_rate': 8.292682926829268e-05, 'epoch': 0.17}



 17%|█▋        | 70/410 [00:31<01:04,  5.28it/s] 

{'eval_loss': 1.4056100845336914, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.7353, 'eval_samples_per_second': 300.154, 'eval_steps_per_second': 75.313, 'epoch': 0.17}


 20%|█▉        | 80/410 [00:32<01:15,  4.35it/s]

{'loss': 1.5459, 'grad_norm': 10.356614112854004, 'learning_rate': 8.048780487804879e-05, 'epoch': 0.19}



 20%|█▉        | 80/410 [00:34<01:15,  4.35it/s]  

{'eval_loss': 1.3554842472076416, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 1.5706, 'eval_samples_per_second': 522.725, 'eval_steps_per_second': 131.159, 'epoch': 0.19}


 22%|██▏       | 90/410 [00:36<00:58,  5.51it/s]

{'loss': 1.3697, 'grad_norm': 5.958805084228516, 'learning_rate': 7.804878048780489e-05, 'epoch': 0.22}



 22%|██▏       | 90/410 [00:39<00:58,  5.51it/s] 

{'eval_loss': 1.4826003313064575, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.7252, 'eval_samples_per_second': 301.265, 'eval_steps_per_second': 75.591, 'epoch': 0.22}


 24%|██▍       | 100/410 [00:41<00:51,  6.00it/s]

{'loss': 1.4209, 'grad_norm': 5.731048107147217, 'learning_rate': 7.560975609756099e-05, 'epoch': 0.24}



 24%|██▍       | 100/410 [00:44<00:51,  6.00it/s]

{'eval_loss': 1.3583461046218872, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 3.0416, 'eval_samples_per_second': 269.923, 'eval_steps_per_second': 67.727, 'epoch': 0.24}


 27%|██▋       | 110/410 [00:46<01:10,  4.28it/s]

{'loss': 1.2343, 'grad_norm': 6.526340484619141, 'learning_rate': 7.317073170731707e-05, 'epoch': 0.27}



 27%|██▋       | 110/410 [00:49<01:10,  4.28it/s]

{'eval_loss': 1.4393787384033203, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 3.0007, 'eval_samples_per_second': 273.604, 'eval_steps_per_second': 68.651, 'epoch': 0.27}


 29%|██▉       | 120/410 [00:51<00:47,  6.13it/s]

{'loss': 1.4419, 'grad_norm': 6.005385398864746, 'learning_rate': 7.073170731707317e-05, 'epoch': 0.29}



 29%|██▉       | 120/410 [00:54<00:47,  6.13it/s]

{'eval_loss': 1.4177879095077515, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.9003, 'eval_samples_per_second': 283.073, 'eval_steps_per_second': 71.027, 'epoch': 0.29}


 32%|███▏      | 130/410 [00:56<00:57,  4.88it/s]

{'loss': 1.3328, 'grad_norm': 4.08013916015625, 'learning_rate': 6.829268292682928e-05, 'epoch': 0.32}



 32%|███▏      | 130/410 [00:59<00:57,  4.88it/s]

{'eval_loss': 1.356595754623413, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 3.0389, 'eval_samples_per_second': 270.163, 'eval_steps_per_second': 67.787, 'epoch': 0.32}


 34%|███▍      | 140/410 [01:01<00:49,  5.47it/s]

{'loss': 1.351, 'grad_norm': 5.449190139770508, 'learning_rate': 6.585365853658538e-05, 'epoch': 0.34}



 34%|███▍      | 140/410 [01:03<00:49,  5.47it/s] 

{'eval_loss': 1.3594731092453003, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 1.8516, 'eval_samples_per_second': 443.405, 'eval_steps_per_second': 111.256, 'epoch': 0.34}


 37%|███▋      | 150/410 [01:04<00:44,  5.81it/s]

{'loss': 1.3687, 'grad_norm': 6.454527378082275, 'learning_rate': 6.341463414634146e-05, 'epoch': 0.37}



 37%|███▋      | 150/410 [01:06<00:44,  5.81it/s]

{'eval_loss': 1.3350822925567627, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.1339, 'eval_samples_per_second': 384.739, 'eval_steps_per_second': 96.536, 'epoch': 0.37}


 39%|███▉      | 160/410 [01:08<00:51,  4.81it/s]

{'loss': 1.4177, 'grad_norm': 4.109144687652588, 'learning_rate': 6.097560975609756e-05, 'epoch': 0.39}



 39%|███▉      | 160/410 [01:11<00:51,  4.81it/s]

{'eval_loss': 1.4151747226715088, 'eval_accuracy': 0.1668696711327649, 'eval_runtime': 2.8303, 'eval_samples_per_second': 290.079, 'eval_steps_per_second': 72.785, 'epoch': 0.39}


 41%|████▏     | 170/410 [01:13<00:41,  5.78it/s]

{'loss': 1.3804, 'grad_norm': 8.673711776733398, 'learning_rate': 5.853658536585366e-05, 'epoch': 0.41}



 41%|████▏     | 170/410 [01:16<00:41,  5.78it/s]

{'eval_loss': 1.357816219329834, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.8121, 'eval_samples_per_second': 291.95, 'eval_steps_per_second': 73.254, 'epoch': 0.41}


 44%|████▍     | 180/410 [01:18<00:42,  5.44it/s]

{'loss': 1.3372, 'grad_norm': 7.230207920074463, 'learning_rate': 5.6097560975609764e-05, 'epoch': 0.44}



 44%|████▍     | 180/410 [01:20<00:42,  5.44it/s]

{'eval_loss': 1.3549745082855225, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.8522, 'eval_samples_per_second': 287.851, 'eval_steps_per_second': 72.226, 'epoch': 0.44}


 46%|████▋     | 190/410 [01:22<00:44,  4.99it/s]

{'loss': 1.4133, 'grad_norm': 5.4263916015625, 'learning_rate': 5.365853658536586e-05, 'epoch': 0.46}



 46%|████▋     | 190/410 [01:25<00:44,  4.99it/s]

{'eval_loss': 1.345613956451416, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.8558, 'eval_samples_per_second': 287.487, 'eval_steps_per_second': 72.134, 'epoch': 0.46}


 49%|████▉     | 200/410 [01:27<00:41,  5.12it/s]

{'loss': 1.368, 'grad_norm': 7.628442764282227, 'learning_rate': 5.121951219512195e-05, 'epoch': 0.49}



 49%|████▉     | 200/410 [01:30<00:41,  5.12it/s]

{'eval_loss': 1.3600053787231445, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.9903, 'eval_samples_per_second': 274.558, 'eval_steps_per_second': 68.89, 'epoch': 0.49}


 51%|█████     | 210/410 [01:32<00:45,  4.40it/s]

{'loss': 1.3804, 'grad_norm': 7.013303279876709, 'learning_rate': 4.878048780487805e-05, 'epoch': 0.51}



 51%|█████     | 210/410 [01:34<00:45,  4.40it/s] 

{'eval_loss': 1.3740049600601196, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 1.7872, 'eval_samples_per_second': 459.368, 'eval_steps_per_second': 115.262, 'epoch': 0.51}


 54%|█████▎    | 220/410 [01:35<00:37,  5.09it/s]

{'loss': 1.3286, 'grad_norm': 5.795467376708984, 'learning_rate': 4.634146341463415e-05, 'epoch': 0.54}



 54%|█████▎    | 220/410 [01:38<00:37,  5.09it/s]

{'eval_loss': 1.3339807987213135, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.6346, 'eval_samples_per_second': 311.627, 'eval_steps_per_second': 78.191, 'epoch': 0.54}


 56%|█████▌    | 230/410 [01:40<00:31,  5.80it/s]

{'loss': 1.3521, 'grad_norm': 5.837274551391602, 'learning_rate': 4.390243902439025e-05, 'epoch': 0.56}



 56%|█████▌    | 230/410 [01:43<00:31,  5.80it/s]

{'eval_loss': 1.3328840732574463, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.8732, 'eval_samples_per_second': 285.747, 'eval_steps_per_second': 71.698, 'epoch': 0.56}


 59%|█████▊    | 240/410 [01:45<00:37,  4.59it/s]

{'loss': 1.2589, 'grad_norm': 3.926546335220337, 'learning_rate': 4.146341463414634e-05, 'epoch': 0.58}



 59%|█████▊    | 240/410 [01:48<00:37,  4.59it/s]

{'eval_loss': 1.347190022468567, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.9562, 'eval_samples_per_second': 277.725, 'eval_steps_per_second': 69.685, 'epoch': 0.58}


 61%|██████    | 250/410 [01:50<00:35,  4.45it/s]

{'loss': 1.3679, 'grad_norm': 9.493637084960938, 'learning_rate': 3.9024390243902444e-05, 'epoch': 0.61}



 61%|██████    | 250/410 [01:53<00:35,  4.45it/s]

{'eval_loss': 1.3648983240127563, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.8793, 'eval_samples_per_second': 285.137, 'eval_steps_per_second': 71.545, 'epoch': 0.61}


 63%|██████▎   | 260/410 [01:55<00:24,  6.02it/s]

{'loss': 1.4589, 'grad_norm': 7.927194118499756, 'learning_rate': 3.6585365853658535e-05, 'epoch': 0.63}



 63%|██████▎   | 260/410 [01:58<00:24,  6.02it/s]

{'eval_loss': 1.3503652811050415, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.7842, 'eval_samples_per_second': 294.88, 'eval_steps_per_second': 73.989, 'epoch': 0.63}


 66%|██████▌   | 270/410 [02:00<00:33,  4.19it/s]

{'loss': 1.4082, 'grad_norm': 4.393301486968994, 'learning_rate': 3.414634146341464e-05, 'epoch': 0.66}



 66%|██████▌   | 270/410 [02:02<00:33,  4.19it/s] 

{'eval_loss': 1.3515386581420898, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 1.9715, 'eval_samples_per_second': 416.429, 'eval_steps_per_second': 104.488, 'epoch': 0.66}


 68%|██████▊   | 280/410 [02:03<00:24,  5.34it/s]

{'loss': 1.3336, 'grad_norm': 5.0311455726623535, 'learning_rate': 3.170731707317073e-05, 'epoch': 0.68}



 68%|██████▊   | 280/410 [02:05<00:24,  5.34it/s]

{'eval_loss': 1.341119408607483, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 1.7595, 'eval_samples_per_second': 466.616, 'eval_steps_per_second': 117.08, 'epoch': 0.68}


 71%|███████   | 290/410 [02:07<00:22,  5.43it/s]

{'loss': 1.3198, 'grad_norm': 6.903152942657471, 'learning_rate': 2.926829268292683e-05, 'epoch': 0.71}



 71%|███████   | 290/410 [02:09<00:22,  5.43it/s]

{'eval_loss': 1.3707672357559204, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.7311, 'eval_samples_per_second': 300.612, 'eval_steps_per_second': 75.428, 'epoch': 0.71}


 73%|███████▎  | 300/410 [02:11<00:18,  5.86it/s]

{'loss': 1.4667, 'grad_norm': 5.198715686798096, 'learning_rate': 2.682926829268293e-05, 'epoch': 0.73}



 73%|███████▎  | 300/410 [02:14<00:18,  5.86it/s]

{'eval_loss': 1.339008092880249, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.9335, 'eval_samples_per_second': 279.866, 'eval_steps_per_second': 70.222, 'epoch': 0.73}


 76%|███████▌  | 310/410 [02:16<00:20,  4.78it/s]

{'loss': 1.3467, 'grad_norm': 5.782828330993652, 'learning_rate': 2.4390243902439026e-05, 'epoch': 0.76}



 76%|███████▌  | 310/410 [02:19<00:20,  4.78it/s]

{'eval_loss': 1.3324939012527466, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.8608, 'eval_samples_per_second': 286.979, 'eval_steps_per_second': 72.007, 'epoch': 0.76}


 78%|███████▊  | 320/410 [02:21<00:15,  5.65it/s]

{'loss': 1.3322, 'grad_norm': 6.878253936767578, 'learning_rate': 2.1951219512195124e-05, 'epoch': 0.78}



 78%|███████▊  | 320/410 [02:24<00:15,  5.65it/s]

{'eval_loss': 1.330183982849121, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.9143, 'eval_samples_per_second': 281.716, 'eval_steps_per_second': 70.686, 'epoch': 0.78}


 80%|████████  | 330/410 [02:26<00:18,  4.27it/s]

{'loss': 1.385, 'grad_norm': 5.335165500640869, 'learning_rate': 1.9512195121951222e-05, 'epoch': 0.8}



 80%|████████  | 330/410 [02:29<00:18,  4.27it/s]

{'eval_loss': 1.3323036432266235, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 3.0729, 'eval_samples_per_second': 267.172, 'eval_steps_per_second': 67.037, 'epoch': 0.8}


 83%|████████▎ | 340/410 [02:31<00:12,  5.70it/s]

{'loss': 1.3607, 'grad_norm': 4.17697286605835, 'learning_rate': 1.707317073170732e-05, 'epoch': 0.83}



 83%|████████▎ | 340/410 [02:33<00:12,  5.70it/s] 

{'eval_loss': 1.3308049440383911, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 1.5101, 'eval_samples_per_second': 543.657, 'eval_steps_per_second': 136.411, 'epoch': 0.83}


 85%|████████▌ | 350/410 [02:34<00:09,  6.08it/s]

{'loss': 1.3632, 'grad_norm': 3.8333632946014404, 'learning_rate': 1.4634146341463415e-05, 'epoch': 0.85}



 85%|████████▌ | 350/410 [02:37<00:09,  6.08it/s]

{'eval_loss': 1.333175539970398, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.7356, 'eval_samples_per_second': 300.119, 'eval_steps_per_second': 75.304, 'epoch': 0.85}


 88%|████████▊ | 360/410 [02:39<00:09,  5.07it/s]

{'loss': 1.3446, 'grad_norm': 11.858530044555664, 'learning_rate': 1.2195121951219513e-05, 'epoch': 0.88}



 88%|████████▊ | 360/410 [02:42<00:09,  5.07it/s]

{'eval_loss': 1.3317232131958008, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.9173, 'eval_samples_per_second': 281.427, 'eval_steps_per_second': 70.614, 'epoch': 0.88}


 90%|█████████ | 370/410 [02:44<00:07,  5.17it/s]

{'loss': 1.3604, 'grad_norm': 5.671450614929199, 'learning_rate': 9.756097560975611e-06, 'epoch': 0.9}



 90%|█████████ | 370/410 [02:47<00:07,  5.17it/s]

{'eval_loss': 1.3317124843597412, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 2.8319, 'eval_samples_per_second': 289.908, 'eval_steps_per_second': 72.742, 'epoch': 0.9}


 93%|█████████▎| 380/410 [02:49<00:05,  5.47it/s]

{'loss': 1.3624, 'grad_norm': 6.296205997467041, 'learning_rate': 7.317073170731707e-06, 'epoch': 0.93}



 93%|█████████▎| 380/410 [02:52<00:05,  5.47it/s]

{'eval_loss': 1.3306169509887695, 'eval_accuracy': 0.3337393422655298, 'eval_runtime': 3.1058, 'eval_samples_per_second': 264.348, 'eval_steps_per_second': 66.329, 'epoch': 0.93}


 95%|█████████▌| 390/410 [02:54<00:03,  5.23it/s]

{'loss': 1.3475, 'grad_norm': 8.907796859741211, 'learning_rate': 4.8780487804878055e-06, 'epoch': 0.95}



 95%|█████████▌| 390/410 [02:57<00:03,  5.23it/s]

{'eval_loss': 1.3306026458740234, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.9272, 'eval_samples_per_second': 280.476, 'eval_steps_per_second': 70.375, 'epoch': 0.95}


 98%|█████████▊| 400/410 [02:59<00:01,  5.82it/s]

{'loss': 1.3191, 'grad_norm': 3.5276119709014893, 'learning_rate': 2.4390243902439027e-06, 'epoch': 0.97}



 98%|█████████▊| 400/410 [03:02<00:01,  5.82it/s]

{'eval_loss': 1.3307002782821655, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 2.8692, 'eval_samples_per_second': 286.138, 'eval_steps_per_second': 71.796, 'epoch': 0.97}


100%|██████████| 410/410 [03:03<00:00,  4.05it/s]

{'loss': 1.3692, 'grad_norm': 8.218183517456055, 'learning_rate': 0.0, 'epoch': 1.0}



100%|██████████| 410/410 [03:05<00:00,  4.05it/s] 

{'eval_loss': 1.3304896354675293, 'eval_accuracy': 0.3325213154689403, 'eval_runtime': 1.6878, 'eval_samples_per_second': 486.437, 'eval_steps_per_second': 122.054, 'epoch': 1.0}


100%|██████████| 410/410 [03:06<00:00,  2.20it/s]

{'train_runtime': 186.1365, 'train_samples_per_second': 17.638, 'train_steps_per_second': 2.203, 'train_loss': 1.3807036981350038, 'epoch': 1.0}





TrainOutput(global_step=410, training_loss=1.3807036981350038, metrics={'train_runtime': 186.1365, 'train_samples_per_second': 17.638, 'train_steps_per_second': 2.203, 'total_flos': 0.0, 'train_loss': 1.3807036981350038, 'epoch': 0.9987819732034104})

In [48]:
trainer.evaluate(eval_dataset)  # metrics on the held-out split

100%|██████████| 206/206 [00:01<00:00, 126.30it/s]


{'eval_loss': 1.330702543258667,
 'eval_accuracy': 0.3325213154689403,
 'eval_runtime': 1.6414,
 'eval_samples_per_second': 500.194,
 'eval_steps_per_second': 125.505,
 'epoch': 0.9987819732034104}