In [6]:
import torch
import torch.nn as nn
import torchvision
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch import optim
import pandas as pd                                                             
from sklearn.metrics import accuracy_score,precision_score,recall_score 

In [7]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced by any of the cells shown here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [8]:
class MyDateset(Dataset):
    """Tabular dataset backed by a CSV file.

    In training mode the last CSV column is treated as the label and all
    other columns as features; otherwise every column is a feature and no
    label is returned.

    Args:
        root: path (or file-like object) handed to ``pd.read_csv``.
        train: if True, split off the last column as the label.
        transform: optional callable applied to each feature row (a
            ``pd.Series``) before it is returned.
        target_transform: optional callable applied to each label before it
            is returned (training mode only).

    NOTE(review): without a ``transform`` that turns the row into a tensor,
    items are pandas objects, which presumably ``DataLoader``'s default
    collate function cannot batch -- verify before using with a DataLoader.
    """

    def __init__(self, root, train, transform=None, target_transform=None):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.data = pd.read_csv(root)
        self.train = train
        if train:
            # The label column is assumed to be the last column.
            self.X = self.data.iloc[:, :-1]
            self.y = self.data.iloc[:, -1]
        else:
            self.X = self.data
            self.y = None

    def __getitem__(self, index):
        """Return ``(features, label)`` in training mode, else ``features``."""
        # Positional indexing in both modes so the result never depends on
        # the DataFrame's index labels.
        features = self.X.iloc[index]
        if self.transform is not None:
            features = self.transform(features)
        if not self.train:
            return features
        label = self.y.iloc[index]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return features, label

    def __len__(self):
        return len(self.data)

In [9]:
# TODO: absolute local path -- make this relative to a configurable data directory.
DATA_PATH = 'D:/temp_files/datasets/spaceship_titanic/train.csv'
data = MyDateset(DATA_PATH, train=True)
# Sanity-check one sample through the normal indexing protocol.
X, y = data[2]

In [10]:
def train_test_dataloader(train: Dataset, test: Dataset, batch_size, shuffle=False):
    """Wrap a training and a test dataset in DataLoaders with identical settings.

    Args:
        train: dataset for the first (training) loader.
        test: dataset for the second (test/validation) loader.
        batch_size: number of samples per batch for both loaders.
        shuffle: whether both loaders reshuffle their data every epoch.

    Returns:
        ``(train_loader, test_loader)``.
    """
    loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle)
    return tuple(DataLoader(dataset=ds, **loader_kwargs) for ds in (train, test))

In [11]:
class MyNeuralNetwork(nn.Module):
    """Fully connected classifier: two ReLU hidden layers and a linear output.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of both hidden layers.
        num_classes: number of output scores per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        layers = [
            nn.Linear(input_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, hidden_size), nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        ]
        # Attribute name and layer order determine the state_dict keys; keep them.
        self.linear_relu_sequential = nn.Sequential(*layers)

    def forward(self, x):
        """Return the unnormalised class scores (no softmax is applied)."""
        logits = self.linear_relu_sequential(x)
        return logits

In [12]:
# Hyperparameters and training objects.
SEED = 42
torch.manual_seed(SEED)  # make weight initialisation (and any DataLoader shuffling) reproducible

input_size = 10      # NOTE(review): must equal the number of feature columns fed to the model -- confirm against preprocessing
hidden_size = 20
num_classes = 2      # two output scores, i.e. a binary target
num_epochs = 10
batch_size = 64
learning_rate = 1e-3
# NOTE(review): the model is not moved to `device` here, so it stays on the CPU.
model = MyNeuralNetwork(input_size, hidden_size, num_classes)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

In [13]:
def _train_one_epoch(train: DataLoader, model, loss_function, optimizer, device, log_interval):
    """Run one optimisation pass over ``train``, printing the loss every ``log_interval`` batches."""
    model.train()  # training-mode behaviour for any dropout / batch-norm layers
    n = len(train.dataset)
    samples_done = 0
    for batch, (train_X, train_y) in enumerate(train):
        train_X, train_y = train_X.to(device), train_y.to(device)
        # Forward pass: predictions and loss.
        train_y_pred = model(train_X)
        loss = loss_function(train_y_pred, train_y)
        # Backward pass and parameter update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if batch % log_interval == 0:
            # Samples processed before this batch; counted explicitly so a
            # short final batch does not skew the number.
            current = samples_done
            print(f"当前损失:{loss.item():7f}\t[{current:5d}/{n:>5d}]")
        samples_done += len(train_X)


def _evaluate(validate: DataLoader, model, loss_function, device, average):
    """Evaluate ``model`` on ``validate``.

    Returns:
        ``(accuracy, precision, recall, mean_loss)`` where ``mean_loss`` is the
        mean of the per-batch losses.

    Raises:
        ValueError: if ``validate`` yields no batches.
    """
    num_batches = len(validate)
    if num_batches == 0:
        raise ValueError("validation DataLoader yields no batches")
    model.eval()  # inference-mode behaviour for any dropout / batch-norm layers
    validate_loss = 0.0
    y_parts, y_pred_parts = [], []
    with torch.no_grad():
        for validate_X, validate_y in validate:
            validate_X, validate_y = validate_X.to(device), validate_y.to(device)
            scores = model(validate_X)
            validate_loss += loss_function(scores, validate_y).item()
            # .cpu() so .numpy() also works when the model runs on a GPU.
            y_parts.append(validate_y.cpu().numpy())
            y_pred_parts.append(scores.argmax(1).cpu().numpy())
    validate_loss /= num_batches
    # Concatenate once rather than np.append per batch (which copies every time).
    y_full = np.concatenate(y_parts)
    y_pred_full = np.concatenate(y_pred_parts)
    accuracy = accuracy_score(y_full, y_pred_full)
    # 'macro' averages the per-class scores; for a binary task average='binary'
    # would report the positive class only. zero_division=0 gives the same value
    # sklearn falls back to, without the UndefinedMetricWarning.
    precision = precision_score(y_full, y_pred_full, average=average, zero_division=0)
    recall = recall_score(y_full, y_pred_full, average=average, zero_division=0)
    return accuracy, precision, recall, validate_loss


def train_validate_loops(train: DataLoader,
                         validate: DataLoader,
                         model,
                         num_epochs,
                         loss_function,
                         optimizer,
                         log_interval=100,
                         average='macro'):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Args:
        train: loader of ``(features, labels)`` training batches.
        validate: loader of ``(features, labels)`` validation batches.
        model: module whose output is argmax-ed over dim 1 to get predictions.
        num_epochs: number of passes over ``train``.
        loss_function: callable ``(model_output, labels) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        log_interval: print the training loss every this many batches.
        average: ``average`` argument passed to sklearn's precision/recall.

    Returns:
        A list with one dict per epoch holding ``accuracy``, ``precision``,
        ``recall`` and ``loss`` (mean validation batch loss).

    Raises:
        ValueError: if ``validate`` yields no batches.
    """
    # Move batches to wherever the model's parameters live (CPU or GPU).
    device = next(model.parameters()).device
    history = []
    for num_epoch in range(num_epochs):
        print(f"轮次 {num_epoch + 1}\n------------------------------\n训练集：\n")
        _train_one_epoch(train, model, loss_function, optimizer, device, log_interval)
        accuracy, precision, recall, validate_loss = _evaluate(
            validate, model, loss_function, device, average)
        print(f"验证集：\n 准确度：{accuracy*100}%，精确度：{precision*100}%，召回率：{recall*100}%\n平均损失：{validate_loss:>8f} ")
        print("=======================================")
        history.append({"accuracy": accuracy, "precision": precision,
                        "recall": recall, "loss": validate_loss})
    return history