In [2]:
import os
import torch
from tqdm.auto import tqdm
import warnings
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
import pandas as pd

# 忽略特定的警告
warnings.filterwarnings("ignore", message="Be aware, overflowing tokens are not returned*")

# 加载数据
new_file_path = './dataset/Mozilla_bug_raw_processed.csv'
df = pd.read_csv(new_file_path, encoding='latin-1')
df = df[['bug_id', 'summary', 'who','description']]
label_dict = {label: idx for idx, label in enumerate(df['who'].unique())}
df['label'] = df['who'].replace(label_dict)

  df['label'] = df['who'].replace(label_dict)


In [3]:


# 合并bug_id和summary作为模型的输入
df['text_input'] = df['description']  # 使用空格作为分隔符
X_train, X_val, y_train, y_val = train_test_split(df.index.values, df.label.values, test_size=0.15, random_state=42, stratify=df.label.values)
df['data_type'] = ['not_set']*df.shape[0]
df.loc[X_train, 'data_type'] = 'train'
df.loc[X_val, 'data_type'] = 'val'
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# 对训练和验证数据的合并文本进行编码
encoded_data_train = tokenizer.batch_encode_plus(
    df[df.data_type=='train'].text_input.values,  # 使用合并后的文本
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=512, 
    return_tensors='pt'
)
encoded_data_val = tokenizer.batch_encode_plus(
    df[df.data_type=='val'].text_input.values,  # 使用合并后的文本
    add_special_tokens=True, 
    return_attention_mask=True, 
    pad_to_max_length=True, 
    max_length=512, 
    return_tensors='pt'
)
# 准备Tensor数据
input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(df[df.data_type=='train'].label.values)
input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(df[df.data_type=='val'].label.values)
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)
# 定义DataLoader
batch_size = 4
train_loader = DataLoader(dataset_train, sampler=RandomSampler(dataset_train), batch_size=batch_size)
val_loader = DataLoader(dataset_val, sampler=SequentialSampler(dataset_val), batch_size=32)


Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [4]:
# 初始化模型
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=len(label_dict), output_attentions=False, output_hidden_states=False)
optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)
# 加载模型
# 计算层数
num_transformer_layers = len(model.bert.encoder.layer)
print(f'The BERT model has {num_transformer_layers} transformer layers.')
print(model)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


The BERT model has 12 transformer layers.
BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
             

In [5]:
# import pandas as pd
# from transformers import BertTokenizer

# # 加载数据集
# df = pd.read_csv('filtered_bug_raw_10_to_13.csv', encoding='latin-1')
# df['text_input'] = df['bug_id'].astype(str) + " " + df['summary']+df['description']

# tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# # 对所有合并后的文本进行编码
# encoded_inputs = tokenizer.batch_encode_plus(df['text_input'].tolist(), add_special_tokens=True, truncation=False, padding=False)

# # 计算所有编码后的长度
# lengths = [len(input_ids) for input_ids in encoded_inputs['input_ids']]

# # 计算平均长度、中位数、最大和最小长度
# average_length = sum(lengths) / len(lengths)
# median_length = sorted(lengths)[len(lengths) // 2]
# max_length = max(lengths)
# min_length = min(lengths)

# print(f"Average length: {average_length}")
# print(f"Median length: {median_length}")
# print(f"Max length: {max_length}")
# print(f"Min length: {min_length}")


In [6]:
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

checkpoint_path = 'model_checkpoint_bert_morethan10_jump.pth'

# 检查是否有可用的检查点
if os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f'Resuming training from epoch {start_epoch}')
else:
    start_epoch = 0
    print('Starting training from scratch')



Starting training from scratch


In [7]:
import pymysql
from datetime import datetime

# 数据库连接信息
host = '38.147.173.234'
user = 'root'
password = '123456'
db = 'training_statistics_db'

# 模型名称，根据实际情况手动设置
model_name = 'bert-base-cased'
# 学习率和可选特性，根据实际情况手动设置
learning_rate = 1e-5  # 示例学习率
optional_feature = 'descrition'  # 示例可选特性
dataset = new_file_path
num_epochs = 15
for epoch in range(start_epoch, num_epochs):
    model.train()
    start_time = datetime.now()

        
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix(loss=loss.item())
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
    model.eval()
    correct_topk = {k: 0 for k in range(1, 11)}
    total = 0
    val_progress_bar = tqdm(val_loader, desc="Validating")
    
    for batch in val_progress_bar:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs[0]
        total += labels.size(0)
        
        # 计算top1到top10的正确率
        _, predicted_topk = torch.topk(logits, k=10, dim=1)
        labels_expanded = labels.unsqueeze(1)
        for k in range(1, 11):
            correct_topk[k] += (predicted_topk[:, :k] == labels_expanded).any(dim=1).sum().item()
                
    # 打印每个topK的准确率
    top10accuracy = []  # 初始化存储Top1到Top10准确率的数组

    for k in range(1, 11):
        accuracy = 100 * correct_topk[k] / total
        top10accuracy.append(accuracy)  # 将计算出的准确率添加到数组中
        print(f'Accuracy after epoch {epoch + 1}: Top{k}: {accuracy:.2f}%')
        print(top10accuracy)
    import pandas as pd
    import os
        # ...训练过程...
    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds()/60.0
    # 定义数据字典，用于创建DataFrame
    data = {
            'epoch': [epoch],
            'start_time': [start_time],
            'end_time': [end_time],
            'duration': [duration],
            'user_id': [1],
            'model': [model_name],
            'top1_accuracy': [top10accuracy[0]],
            'top2_accuracy': [top10accuracy[1]],
            'top3_accuracy': [top10accuracy[2]],
            'top4_accuracy': [top10accuracy[3]],
            'top5_accuracy': [top10accuracy[4]],
            'top6_accuracy': [top10accuracy[5]],
            'top7_accuracy': [top10accuracy[6]],
            'top8_accuracy': [top10accuracy[7]],
            'top9_accuracy': [top10accuracy[8]],
            'top10_accuracy': [top10accuracy[9]],
            'optional_feature': [optional_feature],
            'learning_rate': [learning_rate],
            'dataset': [dataset]
    }

        # 创建DataFrame
    df = pd.DataFrame(data)

        # 检查train.csv文件是否存在来决定是否添加表头
    file_exists = os.path.isfile('train.csv')

        # 如果文件存在，不写入表头，模式为追加；如果文件不存在，写入表头，模式为写入
    df.to_csv('train.csv', mode='a', header=not file_exists, index=False)

    print(f'Epoch {epoch + 1} training data inserted into train.csv.')

Epoch 1:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 1: Top1: 63.09%
[63.086150490730645]
Accuracy after epoch 1: Top2: 69.68%
[63.086150490730645, 69.67596198784858]
Accuracy after epoch 1: Top3: 72.71%
[63.086150490730645, 69.67596198784858, 72.71381835176818]
Accuracy after epoch 1: Top4: 74.65%
[63.086150490730645, 69.67596198784858, 72.71381835176818, 74.65337279950148]
Accuracy after epoch 1: Top5: 76.02%
[63.086150490730645, 69.67596198784858, 72.71381835176818, 74.65337279950148, 76.01651347561925]
Accuracy after epoch 1: Top6: 77.04%
[63.086150490730645, 69.67596198784858, 72.71381835176818, 74.65337279950148, 76.01651347561925, 77.03692163888456]
Accuracy after epoch 1: Top7: 77.97%
[63.086150490730645, 69.67596198784858, 72.71381835176818, 74.65337279950148, 76.01651347561925, 77.03692163888456, 77.97164667393675]
Accuracy after epoch 1: Top8: 78.74%
[63.086150490730645, 69.67596198784858, 72.71381835176818, 74.65337279950148, 76.01651347561925, 77.03692163888456, 77.97164667393675, 78.74279482785481]
Accu

Epoch 2:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 2: Top1: 70.62%
[70.61847639819287]
Accuracy after epoch 2: Top2: 76.64%
[70.61847639819287, 76.63966349898737]
Accuracy after epoch 2: Top3: 79.38%
[70.61847639819287, 76.63966349898737, 79.38152360180713]
Accuracy after epoch 2: Top4: 81.03%
[70.61847639819287, 76.63966349898737, 79.38152360180713, 81.02508178844057]
Accuracy after epoch 2: Top5: 82.22%
[70.61847639819287, 76.63966349898737, 79.38152360180713, 81.02508178844057, 82.21685620813211]
Accuracy after epoch 2: Top6: 83.12%
[70.61847639819287, 76.63966349898737, 79.38152360180713, 81.02508178844057, 82.21685620813211, 83.12042374201589]
Accuracy after epoch 2: Top7: 83.94%
[70.61847639819287, 76.63966349898737, 79.38152360180713, 81.02508178844057, 82.21685620813211, 83.12042374201589, 83.93830814768656]
Accuracy after epoch 2: Top8: 84.56%
[70.61847639819287, 76.63966349898737, 79.38152360180713, 81.02508178844057, 82.21685620813211, 83.12042374201589, 83.93830814768656, 84.56145817105468]
Accuracy aft

Epoch 3:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 3: Top1: 72.70%
[72.69823960118399]
Accuracy after epoch 3: Top2: 78.55%
[72.69823960118399, 78.54806044555227]
Accuracy after epoch 3: Top3: 81.37%
[72.69823960118399, 78.54806044555227, 81.36781430129304]
Accuracy after epoch 3: Top4: 82.93%
[72.69823960118399, 78.54806044555227, 81.36781430129304, 82.93347873500545]
Accuracy after epoch 3: Top5: 84.11%
[72.69823960118399, 78.54806044555227, 81.36781430129304, 82.93347873500545, 84.1096744041128]
Accuracy after epoch 3: Top6: 84.96%
[72.69823960118399, 78.54806044555227, 81.36781430129304, 82.93347873500545, 84.1096744041128, 84.95871631095186]
Accuracy after epoch 3: Top7: 85.80%
[72.69823960118399, 78.54806044555227, 81.36781430129304, 82.93347873500545, 84.1096744041128, 84.95871631095186, 85.79996884249883]
Accuracy after epoch 3: Top8: 86.44%
[72.69823960118399, 78.54806044555227, 81.36781430129304, 82.93347873500545, 84.1096744041128, 84.95871631095186, 85.79996884249883, 86.43869761645117]
Accuracy after e

Epoch 4:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 4: Top1: 74.33%
[74.33400841252532]
Accuracy after epoch 4: Top2: 80.27%
[74.33400841252532, 80.26951238510671]
Accuracy after epoch 4: Top3: 82.74%
[74.33400841252532, 80.26951238510671, 82.73874435270291]
Accuracy after epoch 4: Top4: 84.32%
[74.33400841252532, 80.26951238510671, 82.73874435270291, 84.31998753699953]
Accuracy after epoch 4: Top5: 85.39%
[74.33400841252532, 80.26951238510671, 82.73874435270291, 84.31998753699953, 85.39492132730955]
Accuracy after epoch 4: Top6: 86.32%
[74.33400841252532, 80.26951238510671, 82.73874435270291, 84.31998753699953, 85.39492132730955, 86.32185698706964]
Accuracy after epoch 4: Top7: 86.95%
[74.33400841252532, 80.26951238510671, 82.73874435270291, 84.31998753699953, 85.39492132730955, 86.32185698706964, 86.94500701043776]
Accuracy after epoch 4: Top8: 87.45%
[74.33400841252532, 80.26951238510671, 82.73874435270291, 84.31998753699953, 85.39492132730955, 86.32185698706964, 86.94500701043776, 87.45131640442436]
Accuracy aft

Epoch 5:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 5: Top1: 75.04%
[75.03505218881446]
Accuracy after epoch 5: Top2: 80.90%
[75.03505218881446, 80.90045178376694]
Accuracy after epoch 5: Top3: 83.51%
[75.03505218881446, 80.90045178376694, 83.50989250662097]
Accuracy after epoch 5: Top4: 85.04%
[75.03505218881446, 80.90045178376694, 83.50989250662097, 85.04439943916498]
Accuracy after epoch 5: Top5: 86.18%
[75.03505218881446, 80.90045178376694, 83.50989250662097, 85.04439943916498, 86.1816482318118]
Accuracy after epoch 5: Top6: 87.13%
[75.03505218881446, 80.90045178376694, 83.50989250662097, 85.04439943916498, 86.1816482318118, 87.1319520174482]
Accuracy after epoch 5: Top7: 87.86%
[75.03505218881446, 80.90045178376694, 83.50989250662097, 85.04439943916498, 86.1816482318118, 87.1319520174482, 87.85636391961364]
Accuracy after epoch 5: Top8: 88.35%
[75.03505218881446, 80.90045178376694, 83.50989250662097, 85.04439943916498, 86.1816482318118, 87.1319520174482, 87.85636391961364, 88.35488393830815]
Accuracy after epoc

Epoch 6:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 6: Top1: 74.76%
[74.7624240535909]
Accuracy after epoch 6: Top2: 80.48%
[74.7624240535909, 80.47982551799346]
Accuracy after epoch 6: Top3: 82.99%
[74.7624240535909, 80.47982551799346, 82.98800436205016]
Accuracy after epoch 6: Top4: 84.58%
[74.7624240535909, 80.47982551799346, 82.98800436205016, 84.58482629693098]
Accuracy after epoch 6: Top5: 85.72%
[74.7624240535909, 80.47982551799346, 82.98800436205016, 84.58482629693098, 85.72207508957781]
Accuracy after epoch 6: Top6: 86.62%
[74.7624240535909, 80.47982551799346, 82.98800436205016, 84.58482629693098, 85.72207508957781, 86.6178532481695]
Accuracy after epoch 6: Top7: 87.25%
[74.7624240535909, 80.47982551799346, 82.98800436205016, 84.58482629693098, 85.72207508957781, 86.6178532481695, 87.24879264682973]
Accuracy after epoch 6: Top8: 87.88%
[74.7624240535909, 80.47982551799346, 82.98800436205016, 84.58482629693098, 85.72207508957781, 86.6178532481695, 87.24879264682973, 87.87973204548994]
Accuracy after epoch 6:

Epoch 7:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 7: Top1: 74.85%
[74.84810718180402]
Accuracy after epoch 7: Top2: 80.61%
[74.84810718180402, 80.61224489795919]
Accuracy after epoch 7: Top3: 83.16%
[74.84810718180402, 80.61224489795919, 83.1593706184764]
Accuracy after epoch 7: Top4: 84.68%
[74.84810718180402, 80.61224489795919, 83.1593706184764, 84.67829880043621]
Accuracy after epoch 7: Top5: 85.68%
[74.84810718180402, 80.61224489795919, 83.1593706184764, 84.67829880043621, 85.68312821311731]
Accuracy after epoch 7: Top6: 86.61%
[74.84810718180402, 80.61224489795919, 83.1593706184764, 84.67829880043621, 85.68312821311731, 86.61006387287739]
Accuracy after epoch 7: Top7: 87.34%
[74.84810718180402, 80.61224489795919, 83.1593706184764, 84.67829880043621, 85.68312821311731, 86.61006387287739, 87.34226515033494]
Accuracy after epoch 7: Top8: 87.97%
[74.84810718180402, 80.61224489795919, 83.1593706184764, 84.67829880043621, 85.68312821311731, 86.61006387287739, 87.34226515033494, 87.97320454899517]
Accuracy after epo

Epoch 8:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 8: Top1: 74.59%
[74.59105779716467]
Accuracy after epoch 8: Top2: 80.46%
[74.59105779716467, 80.45645739211716]
Accuracy after epoch 8: Top3: 82.64%
[74.59105779716467, 80.45645739211716, 82.63748247390559]
Accuracy after epoch 8: Top4: 84.31%
[74.59105779716467, 80.45645739211716, 82.63748247390559, 84.31219816170743]
Accuracy after epoch 8: Top5: 85.36%
[74.59105779716467, 80.45645739211716, 82.63748247390559, 84.31219816170743, 85.35597445084905]
Accuracy after epoch 8: Top6: 86.13%
[74.59105779716467, 80.45645739211716, 82.63748247390559, 84.31219816170743, 85.35597445084905, 86.1271226047671]
Accuracy after epoch 8: Top7: 86.95%
[74.59105779716467, 80.45645739211716, 82.63748247390559, 84.31219816170743, 85.35597445084905, 86.1271226047671, 86.94500701043776]
Accuracy after epoch 8: Top8: 87.49%
[74.59105779716467, 80.45645739211716, 82.63748247390559, 84.31219816170743, 85.35597445084905, 86.1271226047671, 86.94500701043776, 87.49026328088488]
Accuracy after 

Epoch 9:   0%|          | 0/18187 [00:00<?, ?it/s]

Validating:   0%|          | 0/402 [00:00<?, ?it/s]

Accuracy after epoch 9: Top1: 74.82%
[74.82473905592771]
Accuracy after epoch 9: Top2: 80.49%
[74.82473905592771, 80.48761489328555]
Accuracy after epoch 9: Top3: 82.73%
[74.82473905592771, 80.48761489328555, 82.73095497741082]
Accuracy after epoch 9: Top4: 84.34%
[74.82473905592771, 80.48761489328555, 82.73095497741082, 84.33556628758373]
Accuracy after epoch 9: Top5: 85.39%
[74.82473905592771, 80.48761489328555, 82.73095497741082, 84.33556628758373, 85.39492132730955]
Accuracy after epoch 9: Top6: 86.22%
[74.82473905592771, 80.48761489328555, 82.73095497741082, 84.33556628758373, 85.39492132730955, 86.22059510827232]
Accuracy after epoch 9: Top7: 86.86%
[74.82473905592771, 80.48761489328555, 82.73095497741082, 84.33556628758373, 85.39492132730955, 86.22059510827232, 86.85932388222464]
Accuracy after epoch 9: Top8: 87.45%
[74.82473905592771, 80.48761489328555, 82.73095497741082, 84.33556628758373, 85.39492132730955, 86.22059510827232, 86.85932388222464, 87.45131640442436]
Accuracy aft

Epoch 10:   0%|          | 0/18187 [00:00<?, ?it/s]

KeyboardInterrupt: 