In [1]:
import os
import torch 
from tqdm.auto import tqdm
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from transformers import XLNetTokenizer, XLNetForSequenceClassification, AdamW
# 接下来，您可以使用与BERT相同的训练和验证循环逻辑，但确保所有的模型、数据和优化器都已经切换到XLNet。

In [2]:
# 加载数据
# 加载数据
new_file_path = './dataset2/OpenOffice_total_10_10.csv'
# 指定需要提取的列
columns_to_extract = ['bug_id', 'product', 'abstracts', 'description', 'component', 'severity', 'priority', 'developer',  'status']
# columns_to_extract = [ 'description', 'developer']
df = pd.read_csv(new_file_path, usecols=columns_to_extract, encoding='latin-1')
# 将developer列作为标签
label_dict = {label: idx for idx, label in enumerate(df['developer'].unique())}
print(f' the number of label is {len(label_dict)}')
df['label'] = df['developer'].replace(label_dict).infer_objects()
# 合并bug_id和summary作为模型的输入
df['text_input'] = df['abstracts'].astype(str) + " " + df['description'].astype(str)  # 使用空格作为分隔符
X_train, X_val, y_train, y_val = train_test_split(df.index.values, df.label.values, test_size=0.15, random_state=42, stratify=df.label.values)
df['data_type'] = ['not_set']*df.shape[0]
df.loc[X_train, 'data_type'] = 'train'
df.loc[X_val, 'data_type'] = 'val'

 the number of label is 170


  df['label'] = df['developer'].replace(label_dict).infer_objects()


In [3]:
# new_file_path = 'dataprocessed_description_more_than10.csv'
# df = pd.read_csv(new_file_path, encoding='latin-1')
# df = df[['bug_id', 'summary', 'who', 'description']]
# label_dict = {label: idx for idx, label in enumerate(df['who'].unique())}
# df['label'] = df['who'].replace(label_dict)
# print(f' the number of label is {len(label_dict)}')
# # 合并bug_id和summary作为模型的输入
# df['text_input'] = df['description']  # 使用空格作为分隔符
# X_train, X_val, y_train, y_val = train_test_split(df.index.values, df.label.values, test_size=0.15, random_state=42, stratify=df.label.values)
# df['data_type'] = ['not_set']*df.shape[0]
# df.loc[X_train, 'data_type'] = 'train'
# df.loc[X_val, 'data_type'] = 'val'

In [4]:
# 使用XLNet的分词器
tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')

In [5]:
encoded_data_train = tokenizer.batch_encode_plus(
    df[df.data_type=='train'].text_input.values,
    add_special_tokens=True,
    return_attention_mask=True,
    padding='max_length',  # 更新此处
    truncation=True,  # 添加此行
    max_length=512,
    return_tensors='pt'
)
encoded_data_val = tokenizer.batch_encode_plus(
    df[df.data_type=='val'].text_input.values,
    add_special_tokens=True,
    return_attention_mask=True,
    padding='max_length',  # 更新此处
    truncation=True,  # 添加此行
    max_length=512,
    return_tensors='pt'
)

In [6]:
# 准备Tensor数据
input_ids_train = encoded_data_train['input_ids']
attention_masks_train = encoded_data_train['attention_mask']
labels_train = torch.tensor(df[df.data_type=='train'].label.values)
input_ids_val = encoded_data_val['input_ids']
attention_masks_val = encoded_data_val['attention_mask']
labels_val = torch.tensor(df[df.data_type=='val'].label.values)
dataset_train = TensorDataset(input_ids_train, attention_masks_train, labels_train)
dataset_val = TensorDataset(input_ids_val, attention_masks_val, labels_val)

In [7]:
# 定义DataLoader
batch_size = 2
train_loader = DataLoader(dataset_train, sampler=RandomSampler(dataset_train), batch_size=batch_size)
val_loader = DataLoader(dataset_val, sampler=SequentialSampler(dataset_val), batch_size=32)

In [8]:
# 初始化XLNet模型
model = XLNetForSequenceClassification.from_pretrained("xlnet-base-cased", num_labels=len(label_dict))
optimizer = AdamW(model.parameters(), lr=1e-5, eps=1e-8)
# 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

Some weights of the model checkpoint at xlnet-base-cased were not used when initializing XLNetForSequenceClassification: ['lm_loss.weight', 'lm_loss.bias']
- This IS expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLNetForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of XLNetForSequenceClassification were not initialized from the model checkpoint at xlnet-base-cased and are newly initialized: ['logits_proj.bias', 'sequence_summary.summary.weight', 'logits_proj.weight', 'sequence_summary.summary.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions a

XLNetForSequenceClassification(
  (transformer): XLNetModel(
    (word_embedding): Embedding(32000, 768)
    (layer): ModuleList(
      (0-11): 12 x XLNetLayer(
        (rel_attn): XLNetRelativeAttention(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (ff): XLNetFeedForward(
          (layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layer_1): Linear(in_features=768, out_features=3072, bias=True)
          (layer_2): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (activation_function): GELUActivation()
        )
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (sequence_summary): SequenceSummary(
    (summary): Linear(in_features=768, out_features=768, bias=True)
    (activation): Tanh()
    (first_dropout): Identity()
    (last

In [9]:
checkpoint_path = 'model_checkpoint_xlnet_top1-top10_eclipse_dataprocessed4444444OpenOffice_total_10_10.pth'
# 检查是否有可用的检查点
if os.path.isfile(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    print(f'Resuming training from epoch {start_epoch}')
else:
    start_epoch = 0
    print('Starting training from scratch')

Starting training from scratch


In [10]:
import pymysql
from datetime import datetime

# 数据库连接信息
host = '38.147.173.234'
user = 'lijianye'
password = '660013'
db = 'training_statistics_db'
experiment_num = 13
# 模型名称，根据实际情况手动设置
model_name = 'xlnet'
# 学习率和可选特性，根据实际情况手动设置
learning_rate = 1e-5  # 示例学习率
optional_feature = 'abstract+descrition'  # 示例可选特性
dataset = new_file_path
num_epochs = 15
for epoch in range(start_epoch, num_epochs):
    model.train()
    start_time = datetime.now()
        
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix(loss=loss.item())
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
    model.eval()
    correct_topk = {k: 0 for k in range(1, 11)}
    total = 0
    val_progress_bar = tqdm(val_loader, desc="Validating")
    
    for batch in val_progress_bar:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs[0]
        total += labels.size(0)
        
        # 计算top1到top10的正确率
        _, predicted_topk = torch.topk(logits, k=10, dim=1)
        labels_expanded = labels.unsqueeze(1)
        for k in range(1, 11):
            correct_topk[k] += (predicted_topk[:, :k] == labels_expanded).any(dim=1).sum().item()
                
    # 打印每个topK的准确率
    top10accuracy = []  # 初始化存储Top1到Top10准确率的数组

    for k in range(1, 11):
        accuracy = 100 * correct_topk[k] / total
        top10accuracy.append(accuracy)  # 将计算出的准确率添加到数组中
        print(f'Accuracy after epoch {epoch + 1}: Top{k}: {accuracy:.2f}%')
        print(top10accuracy)
    import pandas as pd
    import os
        # ...训练结束时间
    end_time = datetime.now()
    duration = (end_time - start_time).total_seconds()/60.0
    # 定义数据字典，用于创建DataFrame
    data = {
            'epoch': [epoch],
            'start_time': [start_time],
            'end_time': [end_time],
            'duration': [duration],
            'user_id': [1],
            'model': [model_name],
            'top1_accuracy': [top10accuracy[0]],
            'top2_accuracy': [top10accuracy[1]],
            'top3_accuracy': [top10accuracy[2]],
            'top4_accuracy': [top10accuracy[3]],
            'top5_accuracy': [top10accuracy[4]],
            'top6_accuracy': [top10accuracy[5]],
            'top7_accuracy': [top10accuracy[6]],
            'top8_accuracy': [top10accuracy[7]],
            'top9_accuracy': [top10accuracy[8]],
            'top10_accuracy': [top10accuracy[9]],
            'optional_feature': [optional_feature],
            'learning_rate': [learning_rate],
            'dataset': [dataset],
            'experiment_num':[experiment_num],
    }

        # 创建DataFrame
    df = pd.DataFrame(data)

        # 检查train.csv文件是否存在来决定是否添加表头
    file_exists = os.path.isfile('modified_train_with_updated_duration.csv')

        # 如果文件存在，不写入表头，模式为追加；如果文件不存在，写入表头，模式为写入
    df.to_csv('xuezhangtrainoutcome.csv', mode='a', header=not file_exists, index=False)

    print(f'Epoch {epoch + 1} training data inserted into train.csv.')

Epoch 1:   0%|          | 0/8783 [00:00<?, ?it/s]

Validating:   0%|          | 0/97 [00:00<?, ?it/s]

Accuracy after epoch 1: Top1: 26.38%
[26.37858755240245]
Accuracy after epoch 1: Top2: 36.76%
[26.37858755240245, 36.76233473073202]
Accuracy after epoch 1: Top3: 42.89%
[26.37858755240245, 36.76233473073202, 42.88939051918736]
Accuracy after epoch 1: Top4: 47.37%
[26.37858755240245, 36.76233473073202, 42.88939051918736, 47.3718155433731]
Accuracy after epoch 1: Top5: 50.34%
[26.37858755240245, 36.76233473073202, 42.88939051918736, 47.3718155433731, 50.33860045146727]
Accuracy after epoch 1: Top6: 53.14%
[26.37858755240245, 36.76233473073202, 42.88939051918736, 47.3718155433731, 50.33860045146727, 53.144147049338926]
Accuracy after epoch 1: Top7: 55.85%
[26.37858755240245, 36.76233473073202, 42.88939051918736, 47.3718155433731, 50.33860045146727, 53.144147049338926, 55.852950661077074]
Accuracy after epoch 1: Top8: 57.76%
[26.37858755240245, 36.76233473073202, 42.88939051918736, 47.3718155433731, 50.33860045146727, 53.144147049338926, 55.852950661077074, 57.755562721702674]
Accuracy af

Epoch 2:   0%|          | 0/8783 [00:00<?, ?it/s]

Validating:   0%|          | 0/97 [00:00<?, ?it/s]

Accuracy after epoch 2: Top1: 32.12%
[32.11867139632377]
Accuracy after epoch 2: Top2: 44.05%
[32.11867139632377, 44.05030635278942]
Accuracy after epoch 2: Top3: 50.53%
[32.11867139632377, 44.05030635278942, 50.53208642373428]
Accuracy after epoch 2: Top4: 56.05%
[32.11867139632377, 44.05030635278942, 50.53208642373428, 56.04643663334408]
Accuracy after epoch 2: Top5: 59.72%
[32.11867139632377, 44.05030635278942, 50.53208642373428, 56.04643663334408, 59.72267010641728]
Accuracy after epoch 2: Top6: 62.82%
[32.11867139632377, 44.05030635278942, 50.53208642373428, 56.04643663334408, 59.72267010641728, 62.818445662689456]
Accuracy after epoch 2: Top7: 65.20%
[32.11867139632377, 44.05030635278942, 50.53208642373428, 56.04643663334408, 59.72267010641728, 62.818445662689456, 65.20477265398259]
Accuracy after epoch 2: Top8: 67.30%
[32.11867139632377, 44.05030635278942, 50.53208642373428, 56.04643663334408, 59.72267010641728, 62.818445662689456, 65.20477265398259, 67.3008706868752]
Accuracy a

Epoch 3:   0%|          | 0/8783 [00:00<?, ?it/s]

Validating:   0%|          | 0/97 [00:00<?, ?it/s]

Accuracy after epoch 3: Top1: 35.28%
[35.27894227668494]
Accuracy after epoch 3: Top2: 50.27%
[35.27894227668494, 50.27410512737826]
Accuracy after epoch 3: Top3: 57.53%
[35.27894227668494, 50.27410512737826, 57.529829087391164]
Accuracy after epoch 3: Top4: 62.62%
[35.27894227668494, 50.27410512737826, 57.529829087391164, 62.62495969042244]
Accuracy after epoch 3: Top5: 65.37%
[35.27894227668494, 50.27410512737826, 57.529829087391164, 62.62495969042244, 65.36601096420509]
Accuracy after epoch 3: Top6: 67.91%
[35.27894227668494, 50.27410512737826, 57.529829087391164, 62.62495969042244, 65.36601096420509, 67.91357626572074]
Accuracy after epoch 3: Top7: 70.27%
[35.27894227668494, 50.27410512737826, 57.529829087391164, 62.62495969042244, 65.36601096420509, 67.91357626572074, 70.26765559496937]
Accuracy after epoch 3: Top8: 72.11%
[35.27894227668494, 50.27410512737826, 57.529829087391164, 62.62495969042244, 65.36601096420509, 67.91357626572074, 70.26765559496937, 72.10577233150596]
Accura

Epoch 4:   0%|          | 0/8783 [00:00<?, ?it/s]

Validating:   0%|          | 0/97 [00:00<?, ?it/s]

Accuracy after epoch 4: Top1: 38.18%
[38.1812318606901]
Accuracy after epoch 4: Top2: 51.92%
[38.1812318606901, 51.918735891647856]
Accuracy after epoch 4: Top3: 59.17%
[38.1812318606901, 51.918735891647856, 59.17445985166076]
Accuracy after epoch 4: Top4: 64.62%
[38.1812318606901, 51.918735891647856, 59.17445985166076, 64.62431473718155]
Accuracy after epoch 4: Top5: 67.85%
[38.1812318606901, 51.918735891647856, 59.17445985166076, 64.62431473718155, 67.84908094163173]
Accuracy after epoch 4: Top6: 70.46%
[38.1812318606901, 51.918735891647856, 59.17445985166076, 64.62431473718155, 67.84908094163173, 70.46114156723638]
Accuracy after epoch 4: Top7: 72.59%
[38.1812318606901, 51.918735891647856, 59.17445985166076, 64.62431473718155, 67.84908094163173, 70.46114156723638, 72.58948726217349]
Accuracy after epoch 4: Top8: 74.23%
[38.1812318606901, 51.918735891647856, 59.17445985166076, 64.62431473718155, 67.84908094163173, 70.46114156723638, 72.58948726217349, 74.23411802644308]
Accuracy afte

Epoch 5:   0%|          | 0/8783 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# 训练和验证循环 增加top5 top10
num_epochs = 130  # 这里设置一个小一点的数，以便于测试，您可以根据需要调整
for epoch in range(start_epoch, num_epochs):
    model.train()
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}")
    for batch in progress_bar:
        optimizer.zero_grad()
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs[0]
        loss.backward()
        optimizer.step()
        progress_bar.set_postfix(loss=loss.item())
    torch.save({'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict()}, checkpoint_path)
    model.eval()
    correct_topk = {k: 0 for k in range(1, 11)}
    total = 0
    val_progress_bar = tqdm(val_loader, desc="Validating")
    
    for batch in val_progress_bar:
        input_ids = batch[0].to(device)
        attention_mask = batch[1].to(device)
        labels = batch[2].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs[0]
        total += labels.size(0)
        
        # 计算top1到top10的正确率
        _, predicted_topk = torch.topk(logits, k=10, dim=1)
        labels_expanded = labels.unsqueeze(1)
        for k in range(1, 11):
            correct_topk[k] += (predicted_topk[:, :k] == labels_expanded).any(dim=1).sum().item()
    # 打印每个topK的准确率
    for k in range(1, 11):
        accuracy = 100 * correct_topk[k] / total
        print(f'Accuracy after epoch {epoch + 1}: Top{k}: {accuracy:.2f}%')

Epoch 6:   0%|          | 0/2547 [00:00<?, ?it/s]

Validating:   0%|          | 0/85 [00:00<?, ?it/s]

Accuracy after epoch 6: Top1: 54.02%
Accuracy after epoch 6: Top2: 66.67%
Accuracy after epoch 6: Top3: 73.82%
Accuracy after epoch 6: Top4: 76.75%
Accuracy after epoch 6: Top5: 78.49%
Accuracy after epoch 6: Top6: 80.20%
Accuracy after epoch 6: Top7: 81.72%
Accuracy after epoch 6: Top8: 82.91%
Accuracy after epoch 6: Top9: 84.09%
Accuracy after epoch 6: Top10: 84.72%


Epoch 7:   0%|          | 0/2547 [00:00<?, ?it/s]

Validating:   0%|          | 0/85 [00:00<?, ?it/s]

Accuracy after epoch 7: Top1: 54.21%
Accuracy after epoch 7: Top2: 66.22%
Accuracy after epoch 7: Top3: 71.56%
Accuracy after epoch 7: Top4: 74.86%
Accuracy after epoch 7: Top5: 77.23%
Accuracy after epoch 7: Top6: 79.24%
Accuracy after epoch 7: Top7: 80.83%
Accuracy after epoch 7: Top8: 82.13%
Accuracy after epoch 7: Top9: 83.09%
Accuracy after epoch 7: Top10: 83.95%


Epoch 8:   0%|          | 0/2547 [00:00<?, ?it/s]

Validating:   0%|          | 0/85 [00:00<?, ?it/s]

Accuracy after epoch 8: Top1: 54.76%
Accuracy after epoch 8: Top2: 66.15%
Accuracy after epoch 8: Top3: 71.41%
Accuracy after epoch 8: Top4: 74.60%
Accuracy after epoch 8: Top5: 77.01%
Accuracy after epoch 8: Top6: 78.72%
Accuracy after epoch 8: Top7: 80.35%
Accuracy after epoch 8: Top8: 81.50%
Accuracy after epoch 8: Top9: 82.76%
Accuracy after epoch 8: Top10: 83.57%


Epoch 9:   0%|          | 0/2547 [00:00<?, ?it/s]

Validating:   0%|          | 0/85 [00:00<?, ?it/s]

Accuracy after epoch 9: Top1: 55.77%
Accuracy after epoch 9: Top2: 66.89%
Accuracy after epoch 9: Top3: 72.56%
Accuracy after epoch 9: Top4: 75.57%
Accuracy after epoch 9: Top5: 78.05%
Accuracy after epoch 9: Top6: 79.64%
Accuracy after epoch 9: Top7: 81.05%
Accuracy after epoch 9: Top8: 82.39%
Accuracy after epoch 9: Top9: 83.54%
Accuracy after epoch 9: Top10: 84.43%


Epoch 10:   0%|          | 0/2547 [00:00<?, ?it/s]

Validating:   0%|          | 0/85 [00:00<?, ?it/s]

Accuracy after epoch 10: Top1: 55.28%
Accuracy after epoch 10: Top2: 65.81%
Accuracy after epoch 10: Top3: 72.41%
Accuracy after epoch 10: Top4: 75.64%
Accuracy after epoch 10: Top5: 77.57%
Accuracy after epoch 10: Top6: 79.76%
Accuracy after epoch 10: Top7: 81.39%
Accuracy after epoch 10: Top8: 82.57%
Accuracy after epoch 10: Top9: 83.39%
Accuracy after epoch 10: Top10: 84.28%


Epoch 11:   0%|          | 0/2547 [00:00<?, ?it/s]

Validating:   0%|          | 0/85 [00:00<?, ?it/s]

Accuracy after epoch 11: Top1: 55.02%
Accuracy after epoch 11: Top2: 66.07%
Accuracy after epoch 11: Top3: 72.30%
Accuracy after epoch 11: Top4: 75.64%
Accuracy after epoch 11: Top5: 77.72%
Accuracy after epoch 11: Top6: 79.64%
Accuracy after epoch 11: Top7: 80.94%
Accuracy after epoch 11: Top8: 82.17%
Accuracy after epoch 11: Top9: 83.24%
Accuracy after epoch 11: Top10: 83.98%


Epoch 12:   0%|          | 0/2547 [00:00<?, ?it/s]

KeyboardInterrupt: 