In [2]:
import torch
from transformers import AutoTokenizer, AutoModel
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch import nn
from transformers import AdamW
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

In [3]:
# 检查设备
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
batch_size = 1
lr = 5e-5
max_length = 128
epochs = 5

Using device: mps


In [4]:
# 模型和存储路径
model_name = "facebook/esm2_t33_650M_UR50D"
model_path = "./models"

# 加载模型和 tokenizer
Tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_path)
model = AutoModel.from_pretrained(model_name, cache_dir=model_path)


Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t33_650M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# 读取数据
file_path = "./data/df_combined_filtered_grouped.pkl"
df = pd.read_pickle(file_path)
df_non_null_toxicity = df[df['toxicity'].notna()]
data = df_non_null_toxicity[['sequence', 'toxicity']]
data_dict = {
    "sequence": data['sequence'].tolist(),
    "toxicity": data['toxicity'].tolist(),
}

"""
# 统计toxicity的分布
toxicity_values = data['toxicity'].value_counts()
print(toxicity_values)

output:
toxicity
False    15912
True      3634
Name: count, dtype: int64
"""

"\n# 统计toxicity的分布\ntoxicity_values = data['toxicity'].value_counts()\nprint(toxicity_values)\n\noutput:\ntoxicity\nFalse    15912\nTrue      3634\nName: count, dtype: int64\n"

In [6]:
# 数据集类
class SequenceStabilityDataset(Dataset):
    def __init__(self, sequences, toxicity, tokenizer=Tokenizer, max_length=max_length):
        self.sequences = sequences
        self.toxicity = torch.tensor(toxicity, dtype=torch.float32)
        self.encoded_sequences = tokenizer(
            sequences, return_tensors="pt", padding=True, truncation=True, max_length=max_length
        )

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        return {
            'sequence': {key: val[idx] for key, val in self.encoded_sequences.items()},
            'toxicity': self.toxicity[idx]
        }

In [7]:
# 数据集分割
train_sequences, val_sequences, train_toxicity, val_toxicity = train_test_split(
    data_dict['sequence'],
    data_dict['toxicity'],
    test_size=0.2,
    random_state=42
)

train_dataset = SequenceStabilityDataset(train_sequences, train_toxicity, tokenizer=Tokenizer)
val_dataset = SequenceStabilityDataset(val_sequences, val_toxicity, tokenizer=Tokenizer)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


In [8]:
# 定义模型
class ToxicityPredictor(nn.Module):
    def __init__(self, base_model):
        super(ToxicityPredictor, self).__init__()
        self.base_model = base_model
        self.classifier = nn.Linear(base_model.config.hidden_size, 2)  # 二分类

    def forward(self, input_ids, attention_mask):
        # 获取 base_model 的输出
        outputs = self.base_model(input_ids=input_ids, attention_mask=attention_mask)
        # 取 [CLS] 的隐藏状态
        cls_hidden_state = outputs.last_hidden_state[:, 0, :]
        # 通过分类层
        logits = self.classifier(cls_hidden_state)
        return logits

In [9]:
# 初始化模型
toxicity_model = ToxicityPredictor(model)
toxicity_model.to(device)

# 损失函数和优化器
criterion = nn.MSELoss()
optimizer = AdamW(toxicity_model.parameters(), lr=lr)



In [10]:
# 训练和验证
# for epoch in range(epochs):
#     # 训练阶段
#     toxicity_model.train()
#     train_loss = 0.0
#     train_progress = tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{epochs} (Training)")
#     for batch in train_progress:
#         input_ids = batch['sequence']['input_ids'].to(device)
#         attention_mask = batch['sequence']['attention_mask'].to(device)
#         labels = batch['toxicity'].to(device)
#
#         optimizer.zero_grad()
#         outputs = toxicity_model(input_ids, attention_mask)
#         loss = criterion(outputs.squeeze(), labels)
#         loss.backward()
#         optimizer.step()
#
#         train_loss += loss.item()
#         avg_batch_loss = train_loss / (len(train_progress))
#         train_progress.set_postfix({"Batch Loss": loss.item(), "Avg Loss": avg_batch_loss})
#     avg_train_loss = train_loss / len(train_dataloader)
#
#     # 验证阶段
#     toxicity_model.eval()
#     val_loss = 0.0
#     val_progress = tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{epochs} (Validation)")
#     with torch.no_grad():
#         for batch in val_progress:
#             input_ids = batch['sequence']['input_ids'].to(device)
#             attention_mask = batch['sequence']['attention_mask'].to(device)
#             labels = batch['toxicity'].to(device)
#
#             outputs = toxicity_model(input_ids, attention_mask)
#             loss = criterion(outputs.squeeze(), labels)
#             val_loss += loss.item()
#             avg_val_batch_loss = val_loss / (len(val_progress))
#             val_progress.set_postfix({"Batch Loss": loss.item(), "Avg Loss": avg_val_batch_loss})
#     avg_val_loss = val_loss / len(val_dataloader)
#
#     print(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")
#
# # 保存模型
# save_path = "./toxicity_model.pth"
# torch.save(toxicity_model.state_dict(), save_path)
# print(f"Model saved to {save_path}")

In [11]:
# 评估模型性能
toxicity_model.eval()
predictions = []
true_values = []
count = 0 # 用于测试
with torch.no_grad():
    for batch in tqdm(val_dataloader):
        input_ids = batch['sequence']['input_ids'].to(device)
        attention_mask = batch['sequence']['attention_mask'].to(device)
        labels = batch['toxicity'].to(device)

        outputs = toxicity_model(input_ids, attention_mask)

        # 选取每个样本得分最大的类别
        predicted_labels = torch.argmax(outputs, dim=1)  # 选择最大概率的类别

        # 保存预测结果和真实标签
        predictions.extend(predicted_labels.cpu().numpy())
        true_values.extend(labels.cpu().numpy())

# 计算准确率、精度、召回率、F1分数、AUC-ROC
accuracy = accuracy_score(true_values, predictions)
precision = precision_score(true_values, predictions)
recall = recall_score(true_values, predictions)
f1 = f1_score(true_values, predictions)

# 打印所有评估指标
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")
print(f"F1 Score: {f1:.4f}")

100%|██████████| 3910/3910 [02:54<00:00, 22.41it/s]

Accuracy: 0.8210
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
