导入库和数据：

从Kaggle下载Titanic数据集，并导入到我们的Python环境中。

In [50]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import CSVLogger

# 读取数据HousePrices
train_data = pd.read_csv('train.csv')
test_data = pd.read_csv('test.csv')


数据探索和可视化：

查看数据集的基本信息，了解各个特征。

In [51]:
train_data.info()
train_data.describe()
train_data.head()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 891 entries, 0 to 890
Data columns (total 12 columns):
 #   Column       Non-Null Count  Dtype  
---  ------       --------------  -----  
 0   PassengerId  891 non-null    int64  
 1   Survived     891 non-null    int64  
 2   Pclass       891 non-null    int64  
 3   Name         891 non-null    object 
 4   Sex          891 non-null    object 
 5   Age          714 non-null    float64
 6   SibSp        891 non-null    int64  
 7   Parch        891 non-null    int64  
 8   Ticket       891 non-null    object 
 9   Fare         891 non-null    float64
 10  Cabin        204 non-null    object 
 11  Embarked     889 non-null    object 
dtypes: float64(2), int64(5), object(5)
memory usage: 83.7+ KB


Unnamed: 0,PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
0,1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
1,2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
2,3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
3,4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
4,5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


自定义Dataset类

In [52]:
class TitanicDataset(Dataset):
    def __init__(self, data, scaler=None, train=True):
        self.data = data
        self.train = train
        self.scaler = scaler
        self.features = self._preprocess(data)

    def _preprocess(self, data):
        data['Age'].fillna(data['Age'].median(), inplace=True)
        data['Embarked'].fillna(data['Embarked'].mode()[0], inplace=True)
        data['Fare'].fillna(data['Fare'].median(), inplace=True)
        data.drop(['Cabin', 'Ticket', 'Name'], axis=1, inplace=True)

        data = pd.get_dummies(data, columns=['Sex', 'Embarked'], drop_first=True)

        if self.train:
            features = data.drop(['Survived', 'PassengerId'], axis=1).values
        else:
            features = data.drop(['PassengerId'], axis=1).values

        if self.scaler:
            features = self.scaler.transform(features)

        return features

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        if self.train:
            label = self.data.iloc[idx]['Survived']
            return torch.tensor(self.features[idx], dtype=torch.float32), torch.tensor(label, dtype=torch.float32)
        else:
            return torch.tensor(self.features[idx], dtype=torch.float32)


数据预处理和划分：

In [53]:
# 合并训练和测试数据以确保相同的预处理
all_data = pd.concat([train_data.drop(['Survived'], axis=1), test_data], ignore_index=True)

# 标准化
scaler = StandardScaler()
scaler.fit(pd.get_dummies(all_data.drop(['PassengerId', 'Cabin', 'Ticket', 'Name'], axis=1), columns=['Sex', 'Embarked'], drop_first=True))

# 创建训练和验证数据集
train_dataset = TitanicDataset(train_data, scaler=scaler, train=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 划分训练集和验证集
train_indices, val_indices = train_test_split(list(range(len(train_dataset))), test_size=0.2, random_state=42)
train_loader = DataLoader(train_dataset, batch_size=32, sampler=torch.utils.data.SubsetRandomSampler(train_indices))
val_loader = DataLoader(train_dataset, batch_size=32, sampler=torch.utils.data.SubsetRandomSampler(val_indices))




定义模型：

使用PyTorch Lightning定义模型。

In [54]:
class TitanicModel(pl.LightningModule):
    def __init__(self):
        super(TitanicModel, self).__init__()
        self.layer_1 = nn.Linear(8, 64)
        self.layer_2 = nn.Linear(64, 32)
        self.layer_3 = nn.Linear(32, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.criterion = nn.BCELoss()

    def forward(self, x):
        x = self.relu(self.layer_1(x))
        x = self.relu(self.layer_2(x))
        x = self.sigmoid(self.layer_3(x))
        return x

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y.unsqueeze(1))
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y.unsqueeze(1))
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        return optimizer


训练模型：

使用PyTorch Lightning的Trainer进行训练。

In [55]:
early_stopping_callback = EarlyStopping(monitor='val_loss', mode='min', patience=15)
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode='min',
    filename='TitanicModel-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    verbose=True
)

# 配置CSVLogger
csv_logger = CSVLogger(save_dir='logs/', name='titanic')

model = TitanicModel()

trainer = Trainer(
    max_epochs=50,
    check_val_every_n_epoch=1,
    logger=csv_logger,
    callbacks=[early_stopping_callback, checkpoint_callback]
)

trainer.fit(model, train_loader, val_loader)

# 加载最优模型
best_model_path = checkpoint_callback.best_model_path
best_model = TitanicModel.load_from_checkpoint(best_model_path)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/titanic

  | Name      | Type    | Params
--------------------------------------
0 | layer_1   | Linear  | 576   
1 | layer_2   | Linear  | 2.1 K 
2 | layer_3   | Linear  | 33    
3 | relu      | ReLU    | 0     
4 | sigmoid   | Sigmoid | 0     
5 | criterion | BCELoss | 0     
--------------------------------------
2.7 K     Trainable params
0         Non-trainable params
2.7 K     Total params
0.011     Total estimated model params size (MB)


                                                                            

E:\Program Files\anaconda3\envs\mylearn\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
E:\Program Files\anaconda3\envs\mylearn\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.
E:\Program Files\anaconda3\envs\mylearn\Lib\site-packages\pytorch_lightning\loops\fit_loop.py:298: The number of training batches (23) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Epoch 0: 100%|██████████| 23/23 [00:00<00:00, 116.66it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 160.39it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 151.12it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 164.52it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 158.51it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 160.08it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 161.14it/s][A
Epoch 0: 100%|██████████| 23/23 [00:00<00:00, 91.91it/s, v_num=0]      [A

Epoch 0, global step 23: 'val_loss' reached 0.63097 (best 0.63097), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=00-val_loss=0.63.ckpt' as top 1


Epoch 1: 100%|██████████| 23/23 [00:00<00:00, 116.04it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 150.85it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 145.84it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 143.16it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 138.17it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 134.18it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 141.97it/s][A
Epoch 1: 100%|██████████| 23/23 [00:00<00:00, 89.65it/s, v_num=0]      [A

Epoch 1, global step 46: 'val_loss' reached 0.55339 (best 0.55339), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=01-val_loss=0.55.ckpt' as top 1


Epoch 2: 100%|██████████| 23/23 [00:00<00:00, 120.09it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 178.47it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 164.72it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 153.47it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 149.82it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 148.38it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 149.89it/s][A
Epoch 2: 100%|██████████| 23/23 [00:00<00:00, 93.87it/s, v_num=0]      [A

Epoch 2, global step 69: 'val_loss' reached 0.47432 (best 0.47432), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=02-val_loss=0.47.ckpt' as top 1


Epoch 3: 100%|██████████| 23/23 [00:00<00:00, 111.72it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.66it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 166.65it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 157.88it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 148.11it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 151.52it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 153.83it/s][A
Epoch 3: 100%|██████████| 23/23 [00:00<00:00, 88.44it/s, v_num=0]      [A

Epoch 3, global step 92: 'val_loss' reached 0.43071 (best 0.43071), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=03-val_loss=0.43.ckpt' as top 1


Epoch 4: 100%|██████████| 23/23 [00:00<00:00, 118.70it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 163.83it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 141.80it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 142.16it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 132.88it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 134.76it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 139.20it/s][A
Epoch 4: 100%|██████████| 23/23 [00:00<00:00, 88.61it/s, v_num=0]      [A

Epoch 4, global step 115: 'val_loss' reached 0.41586 (best 0.41586), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=04-val_loss=0.42.ckpt' as top 1


Epoch 5: 100%|██████████| 23/23 [00:00<00:00, 117.64it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 198.11it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 181.10it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 166.27it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 161.43it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 155.75it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 161.69it/s][A
Epoch 5: 100%|██████████| 23/23 [00:00<00:00, 92.47it/s, v_num=0]      [A

Epoch 5, global step 138: 'val_loss' reached 0.40821 (best 0.40821), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=05-val_loss=0.41.ckpt' as top 1


Epoch 6: 100%|██████████| 23/23 [00:00<00:00, 122.36it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 142.86it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 153.85it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 142.86it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 148.15it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 147.06it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 146.34it/s][A
Epoch 6: 100%|██████████| 23/23 [00:00<00:00, 94.10it/s, v_num=0]      [A

Epoch 6, global step 161: 'val_loss' reached 0.40422 (best 0.40422), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=06-val_loss=0.40.ckpt' as top 1


Epoch 7: 100%|██████████| 23/23 [00:00<00:00, 116.44it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.64it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 163.25it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 155.15it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 151.89it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 145.62it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 148.75it/s][A
Epoch 7: 100%|██████████| 23/23 [00:00<00:00, 90.25it/s, v_num=0]      [A

Epoch 7, global step 184: 'val_loss' reached 0.39996 (best 0.39996), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=07-val_loss=0.40.ckpt' as top 1


Epoch 8: 100%|██████████| 23/23 [00:00<00:00, 128.05it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.80it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 153.90it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 150.04it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 148.17it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 142.88it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 142.87it/s][A
Epoch 8: 100%|██████████| 23/23 [00:00<00:00, 97.62it/s, v_num=0]      [A

Epoch 8, global step 207: 'val_loss' was not in top 1


Epoch 9: 100%|██████████| 23/23 [00:00<00:00, 124.22it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.31it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 151.84it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 145.00it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 143.95it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 145.62it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 148.73it/s][A
Epoch 9: 100%|██████████| 23/23 [00:00<00:00, 95.24it/s, v_num=0]      [A

Epoch 9, global step 230: 'val_loss' reached 0.39896 (best 0.39896), saving model to 'logs/titanic\\version_0\\checkpoints\\TitanicModel-epoch=09-val_loss=0.40.ckpt' as top 1


Epoch 10: 100%|██████████| 23/23 [00:00<00:00, 120.17it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 142.86it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 153.85it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 150.00it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 148.15it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 147.04it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 153.84it/s][A
Epoch 10: 100%|██████████| 23/23 [00:00<00:00, 93.65it/s, v_num=0]     [A

Epoch 10, global step 253: 'val_loss' was not in top 1


Epoch 11: 100%|██████████| 23/23 [00:00<00:00, 111.39it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 130.18it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 136.21it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 136.13it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 133.18it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 131.44it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 129.56it/s][A
Epoch 11: 100%|██████████| 23/23 [00:00<00:00, 84.62it/s, v_num=0]     [A

Epoch 11, global step 276: 'val_loss' was not in top 1


Epoch 12: 100%|██████████| 23/23 [00:00<00:00, 99.19it/s, v_num=0] 
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 137.71it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 128.92it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 121.77it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 122.57it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 120.52it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 123.75it/s][A
Epoch 12: 100%|██████████| 23/23 [00:00<00:00, 76.57it/s, v_num=0]     [A

Epoch 12, global step 299: 'val_loss' was not in top 1


Epoch 13: 100%|██████████| 23/23 [00:00<00:00, 108.42it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 142.71it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 142.78it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 142.79it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 136.62it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 137.82it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 138.56it/s][A
Epoch 13: 100%|██████████| 23/23 [00:00<00:00, 84.27it/s, v_num=0]     [A

Epoch 13, global step 322: 'val_loss' was not in top 1


Epoch 14: 100%|██████████| 23/23 [00:00<00:00, 110.75it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 137.45it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 140.05it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 140.97it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 140.77it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 140.10it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 143.93it/s][A
Epoch 14: 100%|██████████| 23/23 [00:00<00:00, 87.28it/s, v_num=0]     [A

Epoch 14, global step 345: 'val_loss' was not in top 1


Epoch 15: 100%|██████████| 23/23 [00:00<00:00, 110.08it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 163.80it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 152.61it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 142.15it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 137.43it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 134.75it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 139.19it/s][A
Epoch 15: 100%|██████████| 23/23 [00:00<00:00, 85.53it/s, v_num=0]     [A

Epoch 15, global step 368: 'val_loss' was not in top 1


Epoch 16: 100%|██████████| 23/23 [00:00<00:00, 111.67it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 158.91it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 143.81it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 136.94it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 133.74it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 135.47it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 142.38it/s][A
Epoch 16: 100%|██████████| 23/23 [00:00<00:00, 86.78it/s, v_num=0]     [A

Epoch 16, global step 391: 'val_loss' was not in top 1


Epoch 17: 100%|██████████| 23/23 [00:00<00:00, 113.12it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 142.86it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 142.86it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 150.00it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 153.84it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 152.76it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 154.89it/s][A
Epoch 17: 100%|██████████| 23/23 [00:00<00:00, 85.43it/s, v_num=0]     [A

Epoch 17, global step 414: 'val_loss' was not in top 1


Epoch 18: 100%|██████████| 23/23 [00:00<00:00, 114.85it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 164.89it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 142.21it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 135.99it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 129.69it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 132.13it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 137.24it/s][A
Epoch 18: 100%|██████████| 23/23 [00:00<00:00, 88.27it/s, v_num=0]     [A

Epoch 18, global step 437: 'val_loss' was not in top 1


Epoch 19: 100%|██████████| 23/23 [00:00<00:00, 113.60it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.67it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 153.83it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 150.00it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 146.20it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 137.50it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 141.62it/s][A
Epoch 19: 100%|██████████| 23/23 [00:00<00:00, 87.85it/s, v_num=0]     [A

Epoch 19, global step 460: 'val_loss' was not in top 1


Epoch 20: 100%|██████████| 23/23 [00:00<00:00, 112.66it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.67it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 162.79it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 151.77it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 151.93it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 154.66it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 156.54it/s][A
Epoch 20: 100%|██████████| 23/23 [00:00<00:00, 88.73it/s, v_num=0]     [A

Epoch 20, global step 483: 'val_loss' was not in top 1


Epoch 21: 100%|██████████| 23/23 [00:00<00:00, 114.22it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.58it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 149.71it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 147.34it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 146.20it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 144.79it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 148.02it/s][A
Epoch 21: 100%|██████████| 23/23 [00:00<00:00, 89.77it/s, v_num=0]     [A

Epoch 21, global step 506: 'val_loss' was not in top 1


Epoch 22: 100%|██████████| 23/23 [00:00<00:00, 114.60it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 142.90it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 142.87it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 150.01it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 142.87it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 135.14it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 139.53it/s][A
Epoch 22: 100%|██████████| 23/23 [00:00<00:00, 88.57it/s, v_num=0]     [A

Epoch 22, global step 529: 'val_loss' was not in top 1


Epoch 23: 100%|██████████| 23/23 [00:00<00:00, 102.78it/s, v_num=0]
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 166.67it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 133.33it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 125.00it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 124.99it/s][A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 123.28it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 123.54it/s][A
Epoch 23: 100%|██████████| 23/23 [00:00<00:00, 78.89it/s, v_num=0]     [A

Epoch 23, global step 552: 'val_loss' was not in top 1


Epoch 24: 100%|██████████| 23/23 [00:00<00:00, 99.03it/s, v_num=0] 
Validation: |          | 0/? [00:00<?, ?it/s][A
Validation:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/6 [00:00<?, ?it/s][A
Validation DataLoader 0:  17%|█▋        | 1/6 [00:00<00:00, 124.83it/s][A
Validation DataLoader 0:  33%|███▎      | 2/6 [00:00<00:00, 109.47it/s][A
Validation DataLoader 0:  50%|█████     | 3/6 [00:00<00:00, 106.10it/s][A
Validation DataLoader 0:  67%|██████▋   | 4/6 [00:00<00:00, 99.81it/s] [A
Validation DataLoader 0:  83%|████████▎ | 5/6 [00:00<00:00, 101.07it/s][A
Validation DataLoader 0: 100%|██████████| 6/6 [00:00<00:00, 102.59it/s][A
Epoch 24: 100%|██████████| 23/23 [00:00<00:00, 73.67it/s, v_num=0]     [A

Epoch 24, global step 575: 'val_loss' was not in top 1


Epoch 24: 100%|██████████| 23/23 [00:00<00:00, 72.29it/s, v_num=0]


对测试数据进行预测：

同样的预处理步骤，然后用训练好的模型进行预测。

In [56]:
# 创建测试数据集
test_dataset = TitanicDataset(test_data, scaler=scaler, train=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 预测
best_model.eval()
test_pred = []
with torch.no_grad():
    for batch in test_loader:
        preds = best_model(batch)
        preds = (preds.numpy() > 0.5).astype(int)
        test_pred.extend(preds)

# 保存预测结果
submission = pd.DataFrame({'PassengerId': test_data['PassengerId'], 'Survived': np.array(test_pred).ravel()})
submission.to_csv('submission.csv', index=False)


