From 514c521efee7e71efc64c67ede5db099cd364bbe Mon Sep 17 00:00:00 2001
From: SeasonMay <1447833641@qq.com>
Date: Sun, 4 May 2025 15:48:11 +0800
Subject: [PATCH] =?UTF-8?q?=E3=80=90=E5=BC=80=E6=BA=90=E5=AE=9E=E4=B9=A0?=
 =?UTF-8?q?=E3=80=91Barthez=E6=A8=A1=E5=9E=8B=E5=BE=AE=E8=B0=83?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 llm/finetune/barthez/barthez_finetune.md  |  25 +++
 llm/finetune/barthez/barthez_mindspore.py | 254 +++++++++++++++++++++
 llm/finetune/barthez/barthez_pytorch.py   | 259 ++++++++++++++++++++++
 3 files changed, 538 insertions(+)
 create mode 100644 llm/finetune/barthez/barthez_finetune.md
 create mode 100644 llm/finetune/barthez/barthez_mindspore.py
 create mode 100644 llm/finetune/barthez/barthez_pytorch.py

diff --git a/llm/finetune/barthez/barthez_finetune.md b/llm/finetune/barthez/barthez_finetune.md
new file mode 100644
index 000000000..ed4be3e01
--- /dev/null
+++ b/llm/finetune/barthez/barthez_finetune.md
@@ -0,0 +1,25 @@
+# Finetune BARThez
+
+## Dataset
+
+Allocine French sentiment classification dataset
+
+## MindNLP + Ascend
+
+| Epoch | train_loss | eval_loss | f1_score |
+| :---: | ---------- | --------- | -------- |
+|   1   | 0.2217     | 0.1630    | 0.9432   |
+|   2   | 0.1782     | 0.1563    | 0.9453   |
+|   3   | 0.1661     | 0.1582    | 0.9469   |
+|   4   | 0.1582     | 0.1542    | 0.9503   |
+|   5   | 0.1522     | 0.1499    | 0.9509   |
+
+## PyTorch + CUDA
+
+| Epoch | train_loss | eval_loss | f1_score |
+| :---: | ---------- | --------- | -------- |
+|   1   | 0.2468     | 0.1690    | 0.9427   |
+|   2   | 0.1833     | 0.1586    | 0.9450   |
+|   3   | 0.1685     | 0.1611    | 0.9443   |
+|   4   | 0.1605     | 0.1559    | 0.9471   |
+|   5   | 0.1543     | 0.1522    | 0.9499   |
\ No newline at end of file
diff --git a/llm/finetune/barthez/barthez_mindspore.py b/llm/finetune/barthez/barthez_mindspore.py
new file mode 100644
index 000000000..b1d8bd98e
--- /dev/null
+++ b/llm/finetune/barthez/barthez_mindspore.py
@@ -0,0 +1,254 @@
+import os
+
+os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+import random
+import numpy as np
+import mindspore as ms
+from mindspore import nn, ops, Tensor, context
+from mindspore.dataset import GeneratorDataset
+import pandas as pd
+from mindnlp.transformers import AutoTokenizer, AutoModelForSequenceClassification
+from mindnlp.engine import Trainer, TrainingArguments
+from datasets import load_dataset
+from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
+
+
+# Data preprocessing function
+def preprocess_function(examples, tokenizer, max_length=128):
+    # Tokenize input text
+    # Important: For BARThez, we need to ensure inputs are properly formatted
+    texts = examples["review"]
+    if isinstance(texts, str):
+        texts = [texts]
+
+    inputs = tokenizer(
+        texts,
+        max_length=max_length,
+        truncation=True,
+        padding="max_length",
+        return_tensors="np"  # Return numpy arrays instead of lists
+    )
+
+    # Add labels to inputs dictionary
+    if isinstance(examples["label"], (int, float)):
+        inputs["labels"] = np.array([examples["label"]], dtype=np.int32)
+    else:
+        inputs["labels"] = np.array(examples["label"], dtype=np.int32)
+
+    return inputs
+
+
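+# Quick standalone sanity check for preprocess_function (a sketch only; it assumes the
+# "moussaKam/barthez" tokenizer is reachable and is not part of the training pipeline):
+#
+#     tok = AutoTokenizer.from_pretrained("moussaKam/barthez")
+#     sample = {"review": "Un film magnifique.", "label": 1}
+#     enc = preprocess_function(sample, tok, max_length=128)
+#     print(enc["input_ids"].shape)  # (1, 128): one padded/truncated sequence
+#     print(enc["labels"])           # [1]: the single label wrapped in an array
+
+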
+# Function to compute evaluation metrics
+def compute_metrics(eval_preds):
+    predictions, labels = eval_preds
+
+    # Process predictions
+    if isinstance(predictions, tuple):
+        # Use the first element (usually logits)
+        predictions = predictions[0]
+
+    try:
+        pred_array = np.array(predictions)
+        print(f"Prediction shape: {pred_array.shape}")
+
+        if len(pred_array.shape) > 2:
+            # For sequence classification, we typically only care about the logits of the first token
+            predictions = pred_array[:, 0, :]
+            print(f"Reshaped shape: {predictions.shape}")
+
+        # Get indices of maximum probability class
+        preds = np.argmax(predictions, axis=-1)
+        print(f"Labels shape: {np.array(labels).shape}")
+        print(f"Predicted classes: {preds[:10]}")  # Print first 10 predictions
+
+    except Exception as e:
+        print(f"Error processing predictions: {e}")
+        # If conversion fails, use a zero array as fallback
+        preds = np.zeros_like(labels)
+
+    # Calculate metrics
+    precision, recall, f1, _ = precision_recall_fscore_support(
+        labels, preds, average='binary'
+    )
+    accuracy = accuracy_score(labels, preds)
+
+    return {
+        "accuracy": accuracy,
+        "precision": precision,
+        "recall": recall,
+        "f1": f1
+    }
+
+
+# Load and process the Allocine dataset
+def load_allocine_dataset(sample_ratio=0.1):
+    """
+    Load the Allocine dataset using Hugging Face's datasets library.
+    The dataset contains "review" text and "label" labels (0 for negative, 1 for positive).
+
+    Parameters:
+    sample_ratio: Proportion of data to use, range (0, 1]
+    """
+    # Load Allocine dataset
+    dataset = load_dataset("allocine")
+
+    # Take a subset of the dataset according to sample_ratio
+    if sample_ratio < 1.0:
+        train_subset = dataset["train"].shuffle(seed=42).select(range(int(len(dataset["train"]) * sample_ratio)))
+        test_subset = dataset["test"].shuffle(seed=42).select(range(int(len(dataset["test"]) * sample_ratio)))
+
+        dataset = {
+            "train": train_subset,
+            "test": test_subset
+        }
+
+    return dataset
+
+
+# Create MindSpore dataset
+def create_mindspore_dataset(dataset, tokenizer, batch_size=8):
+    """
+    Create a MindSpore dataset from a Hugging Face dataset.
+
+    Parameters:
+    dataset: Hugging Face dataset
+    tokenizer: Tokenizer for preprocessing
+    batch_size: Batch size for training
+    """
+    # Process the dataset in chunks and collect all features
+    features = []
+
+    for i in range(0, len(dataset), 100):  # Process in chunks to avoid memory issues
+        batch = dataset[i:min(i + 100, len(dataset))]
+        texts = batch["review"]
+        labels = batch["label"]
+
+        # Tokenize the texts
+        encodings = tokenizer(
+            texts,
+            max_length=128,
+            padding="max_length",
+            truncation=True,
+            return_tensors="np"
+        )
+
+        # Add each example to the features list
+        for j in range(len(texts)):
+            features.append({
+                "input_ids": encodings["input_ids"][j],
+                "attention_mask": encodings["attention_mask"][j],
+                "labels": labels[j]
+            })
+
+    # Create a generator function
+    def generator():
+        for item in features:
+            yield (
+                Tensor(item["input_ids"], dtype=ms.int32),
+                Tensor(item["attention_mask"], dtype=ms.int32),
+                Tensor(item["labels"], dtype=ms.int32)
+            )
+
+    # Create and return the MindSpore dataset
+    return GeneratorDataset(
+        generator,
+        column_names=["input_ids", "attention_mask", "labels"]
+    ).batch(batch_size)
+
+
+# Main program
+def main():
+    # Load model and tokenizer
+    print("Loading BARThez model and tokenizer...")
+    model_name = "moussaKam/barthez"
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForSequenceClassification.from_pretrained("moussaKam/barthez-sentiment-classification",
+                                                               num_labels=2)
+
+    # Load data, use only 50% for quick experimentation
+    print("Loading Allocine dataset (50%)...")
+    dataset = load_allocine_dataset(sample_ratio=0.5)
+
+    # Allocine dataset is already split into train and test sets
+    train_dataset_raw = dataset["train"]
+    test_dataset_raw = dataset["test"]
+
+    print(f"Number of training samples: {len(train_dataset_raw)}")
+    print(f"Number of test samples: {len(test_dataset_raw)}")
+
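+    # Optional sketch for a quick class-balance check before training (assumes the
+    # Hugging Face dataset exposes "label" as a flat list of 0/1 ints):
+    #     from collections import Counter
+    #     print(Counter(train_dataset_raw["label"]))
+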
+    # Data preprocessing and creating MindSpore datasets
+    print("Preprocessing data and creating MindSpore datasets...")
+    batch_size = 16  # Reduce batch size to decrease memory usage
+
+    # Create MindSpore datasets directly from Hugging Face datasets
+    train_dataset = create_mindspore_dataset(train_dataset_raw, tokenizer, batch_size=batch_size)
+    val_dataset = create_mindspore_dataset(test_dataset_raw, tokenizer, batch_size=batch_size)
+
+    # Define training parameters
+    training_args = TrainingArguments(
+        output_dir="./results_barthez_classification",
+        evaluation_strategy="epoch",
+        learning_rate=2e-6,
+        per_device_train_batch_size=batch_size,
+        per_device_eval_batch_size=batch_size,
+        num_train_epochs=10,
+        weight_decay=0.01,
+        save_strategy="epoch",
+        save_total_limit=2,
+        logging_dir="./logs",
+        logging_strategy="epoch",
+        logging_steps=10,
+        metric_for_best_model="accuracy",
+        greater_is_better=True,
+        load_best_model_at_end=True,
+    )
+
+    # Initialize trainer
+    print("Initializing trainer...")
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+        tokenizer=tokenizer,
+        compute_metrics=compute_metrics
+    )
+
+    # Start training
+    print("Starting training...")
+    try:
+        trainer.train()
+
+        # Save model
+        output_dir = './barthez_allocine_mindspore_model/'
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        print(f"Saving model to {output_dir}")
+        model.save_pretrained(output_dir)
+        tokenizer.save_pretrained(output_dir)
+
+        # Model evaluation
+        print("Performing final model evaluation...")
+        eval_results = trainer.evaluate()
+        print(f"Final evaluation results: {eval_results}")
+
+        # Save training statistics (using the trainer's state)
+        if hasattr(trainer, 'state') and hasattr(trainer.state, 'log_history'):
+            stats_df = pd.DataFrame(trainer.state.log_history)
+            stats_df.to_csv(os.path.join(output_dir, 'training_stats.csv'), index=False)
+            print("Training statistics saved")
+
+        print("\nTraining completed!")
+
+    except Exception as e:
+        print(f"Error during training: {e}")
+        # Print detailed error information
+        import traceback
+        traceback.print_exc()
+
+
+if __name__ == "__main__":
+    ms.set_context(device_target="Ascend", device_id=2)
+    main()
\ No newline at end of file
diff --git a/llm/finetune/barthez/barthez_pytorch.py b/llm/finetune/barthez/barthez_pytorch.py
new file mode 100644
index 000000000..2d6c8e77a
--- /dev/null
+++ b/llm/finetune/barthez/barthez_pytorch.py
@@ -0,0 +1,259 @@
+import os
+
+os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
+import random
+import numpy as np
+import torch
+from torch.utils.data import Dataset, DataLoader
+import pandas as pd
+from torch.optim import AdamW  # Import AdamW from torch.optim (not from transformers)
+from transformers import get_linear_schedule_with_warmup
+from transformers import (
+    AutoTokenizer,
+    AutoModelForSequenceClassification
+)
+from datasets import load_dataset
+from sklearn.metrics import accuracy_score, f1_score
+from tqdm.auto import tqdm
+
+# Check whether CUDA is available
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+# Load the BARThez model and tokenizer
+model_name = "moussaKam/barthez"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
+model.to(device)
+
+
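+# Illustration (a sketch, assuming the tokenizer above loaded correctly) of the encoding
+# that AllocineDataset.__getitem__ below builds each item from:
+#
+#     enc = tokenizer("Un film magnifique.", max_length=128, padding='max_length',
+#                     truncation=True, return_tensors='pt')
+#     # enc['input_ids'].shape == torch.Size([1, 128]); __getitem__ flattens it to [128]
+
+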
+# Dataset loading and processing
+class AllocineDataset(Dataset):
+    def __init__(self, texts, labels, tokenizer, max_length=128):
+        self.texts = texts
+        self.labels = labels
+        self.tokenizer = tokenizer
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.texts)
+
+    def __getitem__(self, idx):
+        text = str(self.texts[idx])
+        label = self.labels[idx]
+
+        encoding = self.tokenizer(
+            text,
+            add_special_tokens=True,
+            max_length=self.max_length,
+            padding='max_length',
+            truncation=True,
+            return_attention_mask=True,
+            return_tensors='pt'
+        )
+
+        return {
+            'input_ids': encoding['input_ids'].flatten(),
+            'attention_mask': encoding['attention_mask'].flatten(),
+            'labels': torch.tensor(label, dtype=torch.long)
+        }
+
+
+# Load and process the Allocine dataset
+def load_allocine_dataset(sample_ratio=0.5):
+    """
+    Load the Allocine dataset using Hugging Face's datasets library.
+    The dataset contains "review" text and "label" labels (0 for negative, 1 for positive).
+
+    Parameters:
+    sample_ratio: Proportion of data to use, range (0, 1]
+    """
+    # Load Allocine dataset
+    dataset = load_dataset("allocine")
+
+    # Take a subset of the dataset according to sample_ratio
+    if sample_ratio < 1.0:
+        train_subset = dataset["train"].shuffle(seed=42).select(range(int(len(dataset["train"]) * sample_ratio)))
+        test_subset = dataset["test"].shuffle(seed=42).select(range(int(len(dataset["test"]) * sample_ratio)))
+
+        dataset = {
+            "train": train_subset,
+            "test": test_subset
+        }
+
+    return dataset
+
+
+# Define the training function
+def train_model(model, train_dataloader, val_dataloader, epochs=3):
+    # Optimizer setup
+    optimizer = AdamW(model.parameters(), lr=2e-6, eps=1e-8)
+
+    # Learning rate scheduler
+    total_steps = len(train_dataloader) * epochs
+    scheduler = get_linear_schedule_with_warmup(
+        optimizer,
+        num_warmup_steps=0,
+        num_training_steps=total_steps
+    )
+
+    # Record training statistics
+    training_stats = []
+
+    # Start the training loop
+    for epoch in range(epochs):
+        print(f"\n======== Epoch {epoch + 1} / {epochs} ========")
+
+        # Training
+        model.train()
+        total_train_loss = 0
+
+        train_progress_bar = tqdm(train_dataloader, desc="Training", leave=True)
+        for batch in train_progress_bar:
+            # Clear previously computed gradients
+            optimizer.zero_grad()
+
+            # Move data to the device
+            input_ids = batch['input_ids'].to(device)
+            attention_mask = batch['attention_mask'].to(device)
+            labels = batch['labels'].to(device)
+
+            # Forward pass
+            outputs = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                labels=labels
+            )
+
+            loss = outputs.loss
+            total_train_loss += loss.item()
+
+            # Backward pass
+            loss.backward()
+
+            # Gradient clipping to prevent exploding gradients
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+
+            # Update parameters
+            optimizer.step()
+            scheduler.step()
+
+        # Compute average training loss
+        avg_train_loss = total_train_loss / len(train_dataloader)
+        print(f"Average training loss: {avg_train_loss:.4f}")
+
+        # Evaluation
+        model.eval()
+        total_eval_loss = 0
+        predictions = []
+        true_labels = []
+
+        for batch in tqdm(val_dataloader, desc="Validation", leave=True):
+            with torch.no_grad():
+                input_ids = batch['input_ids'].to(device)
+                attention_mask = batch['attention_mask'].to(device)
+                labels = batch['labels'].to(device)
+
+                outputs = model(
+                    input_ids=input_ids,
+                    attention_mask=attention_mask,
+                    labels=labels
+                )
+
+                loss = outputs.loss
+                total_eval_loss += loss.item()
+
+                # Collect predictions
+                logits = outputs.logits
+                preds = torch.argmax(logits, dim=1).cpu().numpy()
+
+                predictions.extend(preds)
+                true_labels.extend(labels.cpu().numpy())
+
+        # Compute evaluation metrics
+        avg_val_loss = total_eval_loss / len(val_dataloader)
+        accuracy = accuracy_score(true_labels, predictions)
+        f1 = f1_score(true_labels, predictions, average='weighted')
+
+        print(f"Validation loss: {avg_val_loss:.4f}")
+        print(f"Accuracy: {accuracy:.4f}")
+        print(f"F1 Score: {f1:.4f}")
+
+        # Save per-epoch training statistics
+        training_stats.append({
+            'epoch': epoch + 1,
+            'Training Loss': avg_train_loss,
+            'Validation Loss': avg_val_loss,
+            'Accuracy': accuracy,
+            'F1 Score': f1
+        })
+
+    return training_stats
+
+
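+# A minimal inference sketch; predict_sentiment is a hypothetical helper that is not
+# called by main() (it assumes the binary head above, with label 1 meaning positive):
+def predict_sentiment(text, model, tokenizer, max_length=128):
+    model.eval()
+    encoding = tokenizer(
+        text,
+        max_length=max_length,
+        padding='max_length',
+        truncation=True,
+        return_tensors='pt'
+    ).to(device)
+    with torch.no_grad():
+        logits = model(**encoding).logits
+    return int(torch.argmax(logits, dim=1).item())
+
+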
print("正在加载Allocine数据集(50%)...") + dataset = load_allocine_dataset(sample_ratio=0.5) + + # Allocine数据集已经分割为训练集和测试集 + train_dataset_raw = dataset["train"] + test_dataset_raw = dataset["test"] + + print(f"训练样本数: {len(train_dataset_raw)}") + print(f"测试样本数: {len(test_dataset_raw)}") + + # 创建自定义数据集实例 + train_dataset = AllocineDataset( + train_dataset_raw["review"], + train_dataset_raw["label"], + tokenizer + ) + + val_dataset = AllocineDataset( + test_dataset_raw["review"], + test_dataset_raw["label"], + tokenizer + ) + + # 创建数据加载器 + batch_size = 16 # 减小批量大小以减少内存使用 + train_dataloader = DataLoader( + train_dataset, + batch_size=batch_size, + shuffle=True, + num_workers=4 + ) + + val_dataloader = DataLoader( + val_dataset, + batch_size=batch_size, + shuffle=False, + num_workers=4 + ) + + # 训练模型 + print("开始训练...") + training_stats = train_model( + model, + train_dataloader, + val_dataloader, + epochs=10 + ) + + # 保存模型 + output_dir = './barthez_allocine_model/' + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + print(f"保存模型到 {output_dir}") + model.save_pretrained(output_dir) + tokenizer.save_pretrained(output_dir) + + print("\n训练完成!") + + + +if __name__ == "__main__": + main() \ No newline at end of file