In [1]:
!pip install pandas numpy torch scikit-learn xgboost matplotlib tqdm

Looking in indexes: https://mirrors.cernet.edu.cn/pypi/web/simple


In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import xgboost as xgb

In [3]:
data = pd.read_excel('data.xlsx')

In [4]:
features = ['T', 'Tn', 'Tx', 'Po']
target = 'WW'

In [5]:
X = data[features].values
y = data[target].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [6]:
# 数据标准化
scaler = MinMaxScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 转换为PyTorch张量
X_train = torch.tensor(X_train, dtype=torch.float32)
X_test = torch.tensor(X_test, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
y_test = torch.tensor(y_test, dtype=torch.long)

In [7]:
# 定义BP神经网络模型
class BPNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(BPNN, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

In [8]:
# 初始化BP神经网络模型
input_size = X_train.shape[1]
hidden_size = 64
num_classes = len(np.unique(y))  # 类别数量
model = BPNN(input_size, hidden_size, num_classes)

In [9]:
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [10]:
# 训练BP神经网络
epochs = 100
batch_size = 32
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader)}")

Epoch 1/100: 100%|██████████| 365/365 [00:00<00:00, 406.03it/s]


Epoch 1/100, Loss: 1.2084512357842432


Epoch 2/100: 100%|██████████| 365/365 [00:00<00:00, 542.67it/s]


Epoch 2/100, Loss: 1.0316019019035443


Epoch 3/100: 100%|██████████| 365/365 [00:00<00:00, 566.97it/s]


Epoch 3/100, Loss: 1.0082599865247126


Epoch 4/100: 100%|██████████| 365/365 [00:00<00:00, 594.96it/s]


Epoch 4/100, Loss: 0.9855352307019168


Epoch 5/100: 100%|██████████| 365/365 [00:00<00:00, 559.74it/s]


Epoch 5/100, Loss: 0.9658485672245287


Epoch 6/100: 100%|██████████| 365/365 [00:00<00:00, 580.07it/s]


Epoch 6/100, Loss: 0.949379501195803


Epoch 7/100: 100%|██████████| 365/365 [00:00<00:00, 567.70it/s]


Epoch 7/100, Loss: 0.9375446803765755


Epoch 8/100: 100%|██████████| 365/365 [00:00<00:00, 578.18it/s]


Epoch 8/100, Loss: 0.9261897231618019


Epoch 9/100: 100%|██████████| 365/365 [00:00<00:00, 559.38it/s]


Epoch 9/100, Loss: 0.9204853027650755


Epoch 10/100: 100%|██████████| 365/365 [00:00<00:00, 580.41it/s]


Epoch 10/100, Loss: 0.9141580954806445


Epoch 11/100: 100%|██████████| 365/365 [00:00<00:00, 555.08it/s]


Epoch 11/100, Loss: 0.9109495199706457


Epoch 12/100: 100%|██████████| 365/365 [00:00<00:00, 530.82it/s]


Epoch 12/100, Loss: 0.9066013322301107


Epoch 13/100: 100%|██████████| 365/365 [00:00<00:00, 530.41it/s]


Epoch 13/100, Loss: 0.9044029323205556


Epoch 14/100: 100%|██████████| 365/365 [00:00<00:00, 540.89it/s]


Epoch 14/100, Loss: 0.9023259472357084


Epoch 15/100: 100%|██████████| 365/365 [00:00<00:00, 541.71it/s]


Epoch 15/100, Loss: 0.8995303087038536


Epoch 16/100: 100%|██████████| 365/365 [00:00<00:00, 575.12it/s]


Epoch 16/100, Loss: 0.8985288430566657


Epoch 17/100: 100%|██████████| 365/365 [00:00<00:00, 592.43it/s]


Epoch 17/100, Loss: 0.898613237191553


Epoch 18/100: 100%|██████████| 365/365 [00:00<00:00, 573.14it/s]


Epoch 18/100, Loss: 0.896026171641807


Epoch 19/100: 100%|██████████| 365/365 [00:00<00:00, 541.61it/s]


Epoch 19/100, Loss: 0.8957129901402617


Epoch 20/100: 100%|██████████| 365/365 [00:00<00:00, 592.70it/s]


Epoch 20/100, Loss: 0.8943711745412383


Epoch 21/100: 100%|██████████| 365/365 [00:00<00:00, 608.02it/s]


Epoch 21/100, Loss: 0.8934971072085917


Epoch 22/100: 100%|██████████| 365/365 [00:00<00:00, 552.03it/s]


Epoch 22/100, Loss: 0.8910057309555681


Epoch 23/100: 100%|██████████| 365/365 [00:00<00:00, 570.17it/s]


Epoch 23/100, Loss: 0.8916250610188262


Epoch 24/100: 100%|██████████| 365/365 [00:00<00:00, 608.63it/s]


Epoch 24/100, Loss: 0.8901368858063058


Epoch 25/100: 100%|██████████| 365/365 [00:00<00:00, 567.78it/s]


Epoch 25/100, Loss: 0.8898826161476031


Epoch 26/100: 100%|██████████| 365/365 [00:00<00:00, 590.27it/s]


Epoch 26/100, Loss: 0.8893930863844205


Epoch 27/100: 100%|██████████| 365/365 [00:00<00:00, 598.55it/s]


Epoch 27/100, Loss: 0.8893253412965226


Epoch 28/100: 100%|██████████| 365/365 [00:00<00:00, 451.20it/s]


Epoch 28/100, Loss: 0.8883709736882823


Epoch 29/100: 100%|██████████| 365/365 [00:00<00:00, 489.70it/s]


Epoch 29/100, Loss: 0.8886611833147806


Epoch 30/100: 100%|██████████| 365/365 [00:00<00:00, 507.45it/s]


Epoch 30/100, Loss: 0.8870517184473063


Epoch 31/100: 100%|██████████| 365/365 [00:00<00:00, 505.66it/s]


Epoch 31/100, Loss: 0.887042544720924


Epoch 32/100: 100%|██████████| 365/365 [00:00<00:00, 498.68it/s]


Epoch 32/100, Loss: 0.8866656950891835


Epoch 33/100: 100%|██████████| 365/365 [00:00<00:00, 504.21it/s]


Epoch 33/100, Loss: 0.8857105993244746


Epoch 34/100: 100%|██████████| 365/365 [00:00<00:00, 560.73it/s]


Epoch 34/100, Loss: 0.8862834060028808


Epoch 35/100: 100%|██████████| 365/365 [00:00<00:00, 576.34it/s]


Epoch 35/100, Loss: 0.8860746538802369


Epoch 36/100: 100%|██████████| 365/365 [00:00<00:00, 558.31it/s]


Epoch 36/100, Loss: 0.8865515685244783


Epoch 37/100: 100%|██████████| 365/365 [00:00<00:00, 584.57it/s]


Epoch 37/100, Loss: 0.8857907970474191


Epoch 38/100: 100%|██████████| 365/365 [00:00<00:00, 580.37it/s]


Epoch 38/100, Loss: 0.8859374132058392


Epoch 39/100: 100%|██████████| 365/365 [00:00<00:00, 620.95it/s]


Epoch 39/100, Loss: 0.8842936951003663


Epoch 40/100: 100%|██████████| 365/365 [00:00<00:00, 594.02it/s]


Epoch 40/100, Loss: 0.8842808715284687


Epoch 41/100: 100%|██████████| 365/365 [00:00<00:00, 588.27it/s]


Epoch 41/100, Loss: 0.8846998435177215


Epoch 42/100: 100%|██████████| 365/365 [00:00<00:00, 607.99it/s]


Epoch 42/100, Loss: 0.8836341167149478


Epoch 43/100: 100%|██████████| 365/365 [00:00<00:00, 573.19it/s]


Epoch 43/100, Loss: 0.8832188048591353


Epoch 44/100: 100%|██████████| 365/365 [00:00<00:00, 605.62it/s]


Epoch 44/100, Loss: 0.8842625814757935


Epoch 45/100: 100%|██████████| 365/365 [00:00<00:00, 583.36it/s]


Epoch 45/100, Loss: 0.8827857046911162


Epoch 46/100: 100%|██████████| 365/365 [00:00<00:00, 601.43it/s]


Epoch 46/100, Loss: 0.8836572591572591


Epoch 47/100: 100%|██████████| 365/365 [00:00<00:00, 613.67it/s]


Epoch 47/100, Loss: 0.8830771954908763


Epoch 48/100: 100%|██████████| 365/365 [00:00<00:00, 605.36it/s]


Epoch 48/100, Loss: 0.8837855748117787


Epoch 49/100: 100%|██████████| 365/365 [00:00<00:00, 550.05it/s]


Epoch 49/100, Loss: 0.8828155154234743


Epoch 50/100: 100%|██████████| 365/365 [00:00<00:00, 601.45it/s]


Epoch 50/100, Loss: 0.8834146777244464


Epoch 51/100: 100%|██████████| 365/365 [00:00<00:00, 605.01it/s]


Epoch 51/100, Loss: 0.8822105046820967


Epoch 52/100: 100%|██████████| 365/365 [00:00<00:00, 594.52it/s]


Epoch 52/100, Loss: 0.8816255443716703


Epoch 53/100: 100%|██████████| 365/365 [00:00<00:00, 506.33it/s]


Epoch 53/100, Loss: 0.8816096521403691


Epoch 54/100: 100%|██████████| 365/365 [00:00<00:00, 425.31it/s]


Epoch 54/100, Loss: 0.8814419381422539


Epoch 55/100: 100%|██████████| 365/365 [00:00<00:00, 523.93it/s]


Epoch 55/100, Loss: 0.8829176294477019


Epoch 56/100: 100%|██████████| 365/365 [00:00<00:00, 541.54it/s]


Epoch 56/100, Loss: 0.8815663056830837


Epoch 57/100: 100%|██████████| 365/365 [00:00<00:00, 507.23it/s]


Epoch 57/100, Loss: 0.8812090221741428


Epoch 58/100: 100%|██████████| 365/365 [00:00<00:00, 572.74it/s]


Epoch 58/100, Loss: 0.8811959174397873


Epoch 59/100: 100%|██████████| 365/365 [00:00<00:00, 547.57it/s]


Epoch 59/100, Loss: 0.8813593115708599


Epoch 60/100: 100%|██████████| 365/365 [00:00<00:00, 442.49it/s]


Epoch 60/100, Loss: 0.8808051615545195


Epoch 61/100: 100%|██████████| 365/365 [00:00<00:00, 596.45it/s]


Epoch 61/100, Loss: 0.8805587081060018


Epoch 62/100: 100%|██████████| 365/365 [00:00<00:00, 515.25it/s]


Epoch 62/100, Loss: 0.8806949797558458


Epoch 63/100: 100%|██████████| 365/365 [00:00<00:00, 439.27it/s]


Epoch 63/100, Loss: 0.879486834921249


Epoch 64/100: 100%|██████████| 365/365 [00:00<00:00, 515.81it/s]


Epoch 64/100, Loss: 0.8811757757239146


Epoch 65/100: 100%|██████████| 365/365 [00:00<00:00, 498.66it/s]


Epoch 65/100, Loss: 0.8810796220825143


Epoch 66/100: 100%|██████████| 365/365 [00:00<00:00, 479.24it/s]


Epoch 66/100, Loss: 0.8797766968812029


Epoch 67/100: 100%|██████████| 365/365 [00:00<00:00, 485.69it/s]


Epoch 67/100, Loss: 0.8801073342969973


Epoch 68/100: 100%|██████████| 365/365 [00:00<00:00, 514.24it/s]


Epoch 68/100, Loss: 0.8805855097019509


Epoch 69/100: 100%|██████████| 365/365 [00:01<00:00, 218.74it/s]


Epoch 69/100, Loss: 0.8801763443097677


Epoch 70/100: 100%|██████████| 365/365 [00:02<00:00, 168.41it/s]


Epoch 70/100, Loss: 0.8788144617048028


Epoch 71/100: 100%|██████████| 365/365 [00:01<00:00, 287.97it/s]


Epoch 71/100, Loss: 0.8791736829770754


Epoch 72/100: 100%|██████████| 365/365 [00:02<00:00, 146.86it/s]


Epoch 72/100, Loss: 0.8796009944726343


Epoch 73/100: 100%|██████████| 365/365 [00:01<00:00, 271.87it/s]


Epoch 73/100, Loss: 0.8801121917489457


Epoch 74/100: 100%|██████████| 365/365 [00:00<00:00, 405.06it/s]


Epoch 74/100, Loss: 0.8789877431033409


Epoch 75/100: 100%|██████████| 365/365 [00:01<00:00, 311.19it/s]


Epoch 75/100, Loss: 0.8789909165199489


Epoch 76/100: 100%|██████████| 365/365 [00:01<00:00, 292.59it/s]


Epoch 76/100, Loss: 0.8781783562816986


Epoch 77/100: 100%|██████████| 365/365 [00:00<00:00, 377.61it/s]


Epoch 77/100, Loss: 0.8792455598099591


Epoch 78/100: 100%|██████████| 365/365 [00:00<00:00, 448.28it/s]


Epoch 78/100, Loss: 0.8786418677192844


Epoch 79/100: 100%|██████████| 365/365 [00:01<00:00, 328.58it/s]


Epoch 79/100, Loss: 0.8781034863158448


Epoch 80/100: 100%|██████████| 365/365 [00:00<00:00, 414.46it/s]


Epoch 80/100, Loss: 0.8786624438142123


Epoch 81/100: 100%|██████████| 365/365 [00:00<00:00, 455.43it/s]


Epoch 81/100, Loss: 0.8786755011506276


Epoch 82/100: 100%|██████████| 365/365 [00:00<00:00, 375.75it/s]


Epoch 82/100, Loss: 0.8788589061939553


Epoch 83/100: 100%|██████████| 365/365 [00:00<00:00, 440.27it/s]


Epoch 83/100, Loss: 0.8789445541492881


Epoch 84/100: 100%|██████████| 365/365 [00:00<00:00, 442.22it/s]


Epoch 84/100, Loss: 0.8779810013836378


Epoch 85/100: 100%|██████████| 365/365 [00:00<00:00, 419.66it/s]


Epoch 85/100, Loss: 0.8777179967050683


Epoch 86/100: 100%|██████████| 365/365 [00:00<00:00, 472.29it/s]


Epoch 86/100, Loss: 0.8778859620224939


Epoch 87/100: 100%|██████████| 365/365 [00:00<00:00, 454.85it/s]


Epoch 87/100, Loss: 0.8774934114658669


Epoch 88/100: 100%|██████████| 365/365 [00:00<00:00, 405.38it/s]


Epoch 88/100, Loss: 0.8777025295446996


Epoch 89/100: 100%|██████████| 365/365 [00:01<00:00, 321.46it/s]


Epoch 89/100, Loss: 0.8778802603891451


Epoch 90/100: 100%|██████████| 365/365 [00:00<00:00, 403.66it/s]


Epoch 90/100, Loss: 0.8765781244186506


Epoch 91/100: 100%|██████████| 365/365 [00:04<00:00, 79.47it/s] 


Epoch 91/100, Loss: 0.877702346893206


Epoch 92/100: 100%|██████████| 365/365 [00:01<00:00, 283.60it/s]


Epoch 92/100, Loss: 0.8774860575591048


Epoch 93/100: 100%|██████████| 365/365 [00:00<00:00, 388.39it/s]


Epoch 93/100, Loss: 0.8763759135383449


Epoch 94/100: 100%|██████████| 365/365 [00:01<00:00, 326.22it/s]


Epoch 94/100, Loss: 0.8760893400401285


Epoch 95/100: 100%|██████████| 365/365 [00:01<00:00, 361.88it/s]


Epoch 95/100, Loss: 0.876043275447741


Epoch 96/100: 100%|██████████| 365/365 [00:00<00:00, 425.34it/s]


Epoch 96/100, Loss: 0.8770491906633116


Epoch 97/100: 100%|██████████| 365/365 [00:00<00:00, 407.22it/s]


Epoch 97/100, Loss: 0.8762902660729134


Epoch 98/100: 100%|██████████| 365/365 [00:00<00:00, 418.18it/s]


Epoch 98/100, Loss: 0.8774234614960135


Epoch 99/100: 100%|██████████| 365/365 [00:00<00:00, 434.12it/s]


Epoch 99/100, Loss: 0.8763706453042488


Epoch 100/100: 100%|██████████| 365/365 [00:00<00:00, 379.03it/s]

Epoch 100/100, Loss: 0.8765541475929626





In [11]:
# BP神经网络模型评估
model.eval()
with torch.no_grad():
    outputs = model(X_test)
    _, predicted = torch.max(outputs.data, 1)
    accuracy = accuracy_score(y_test.numpy(), predicted.numpy())
    print(f'BPNN Accuracy: {accuracy}')

BPNN Accuracy: 0.6880356530682208


In [12]:
# XGBoost模型训练
xgb_model = xgb.XGBClassifier(objective='multi:softmax', num_class=num_classes, random_state=42)
xgb_model.fit(X_train.numpy(), y_train.numpy())

# XGBoost模型评估
y_pred_xgb = xgb_model.predict(X_test.numpy())
accuracy_xgb = accuracy_score(y_test.numpy(), y_pred_xgb)
print(f'XGBoost Accuracy: {accuracy_xgb}')

XGBoost Accuracy: 0.7123757284881728


In [13]:
from sklearn.ensemble import ExtraTreesClassifier

# 训练极端随机树模型
et_model = ExtraTreesClassifier(n_estimators=100, random_state=42)
et_model.fit(X_train.numpy(), y_train.numpy())

# 评估极端随机树模型
y_pred_et = et_model.predict(X_test.numpy())
accuracy_et = accuracy_score(y_test.numpy(), y_pred_et)
print(f'ExtraTrees Accuracy: {accuracy_et}')

ExtraTrees Accuracy: 0.7271169009256085
