In [11]:
import os

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset, random_split

# Custom Dataset backed by a CSV file
class MNISTDataset(Dataset):
    """Map-style dataset that serves rows of a CSV file as tensors.

    Column 0 of every row is treated as the label and all remaining columns
    as the features of that sample.

    Args:
        csv_file: Path (or any source accepted by ``pandas.read_csv``) of the
            CSV file. The whole file is read eagerly at construction time.
    """

    def __init__(self, csv_file):
        frame = pd.read_csv(csv_file)
        self.data = frame

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        row_count, _ = self.data.shape
        return row_count

    def __getitem__(self, idx):
        """Return ``(features, label)`` for row ``idx``.

        The features are cast to float32 before being wrapped in a tensor;
        the label is wrapped as-is, keeping the dtype pandas inferred.
        """
        target = self.data.iloc[idx, 0]
        features = self.data.iloc[idx, 1:].to_numpy().astype(np.float32)
        return torch.tensor(features), torch.tensor(target)

# LightningDataModule that serves the CSV-backed datasets
class MNISTDataModule(pl.LightningDataModule):
    """Build train/validation/test dataloaders from ``train.csv``/``test.csv``.

    Args:
        data_dir: Directory holding ``train.csv`` and ``test.csv``. An empty
            string resolves the files relative to the working directory.
        batch_size: Batch size used by every dataloader.
        num_workers: ``num_workers`` value passed to every dataloader.
        split_seed: Seed for the train/validation split. With ``None`` the
            split is drawn from torch's global RNG instead.
    """

    def __init__(self, data_dir='', batch_size=64, num_workers=2, split_seed=42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.split_seed = split_seed
        # Populated by the first `setup` call.
        self.train_dataset = None
        self.val_dataset = None

    def _csv_path(self, filename):
        """Return the path of ``filename`` inside ``data_dir``."""
        # os.path.join inserts the separator that plain string concatenation
        # dropped when `data_dir` had no trailing slash ('' still works).
        return os.path.join(self.data_dir, filename)

    def setup(self, stage=None):
        """Load ``train.csv`` and split it 80/20 into train and validation sets.

        Lightning calls this hook again for every stage (fit, validate, ...).
        Re-drawing a random split on each call would move training rows into
        the validation set of a later ``trainer.validate`` run, so the split
        is created once and, when ``split_seed`` is set, from a dedicated
        seeded generator.
        """
        if self.train_dataset is not None and self.val_dataset is not None:
            return

        dataset = MNISTDataset(csv_file=self._csv_path('train.csv'))

        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size

        split_kwargs = {}
        if self.split_seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(self.split_seed)
        self.train_dataset, self.val_dataset = random_split(
            dataset, [train_size, val_size], **split_kwargs
        )

    def train_dataloader(self):
        """Return the shuffled dataloader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Return the dataloader over the validation split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """Return a dataloader over the rows of ``test.csv``."""
        # NOTE(review): MNISTDataset reads column 0 as the label. Confirm that
        # test.csv actually carries a label column; if it does not, the first
        # feature column is consumed as the label here.
        test_dataset = MNISTDataset(csv_file=self._csv_path('test.csv'))
        return DataLoader(test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)


In [15]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Accuracy
import numpy as np

class CombinedMNISTModel(pl.LightningModule):
    """Two-branch classifier for 28x28 single-channel images, 10 classes.

    A small CNN reads the input as a 1x28x28 image while an MLP reads the
    same values as a flat 784-element vector. Each branch emits 32 features;
    the concatenated 64 features go through a final MLP and ``forward``
    returns log-probabilities.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, learning_rate=1e-3):
        super().__init__()
        self.learning_rate = learning_rate

        # Image branch: two conv + pool stages, 1x28x28 -> 16x7x7
        self.cnn_branch = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(8, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.cnn_fc = nn.Linear(16 * 7 * 7, 32)

        # Sequence branch: flat 784-vector -> 32 features
        self.seq_branch = nn.Sequential(
            nn.Linear(28 * 28, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )

        # Head applied to the concatenated 32 + 32 = 64 branch features
        self.fc_combined = nn.Sequential(
            nn.ReLU(),
            nn.Linear(64, 16),
            nn.ReLU(),
            nn.Linear(16, 10)
        )

        # One accuracy metric per stage so their internal states stay separate
        self.train_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.val_accuracy = Accuracy(task='multiclass', num_classes=10)
        self.test_accuracy = Accuracy(task='multiclass', num_classes=10)

    def forward(self, x):
        """Return per-class log-probabilities of shape ``[batch_size, 10]``.

        ``x`` may be any tensor whose element count is a multiple of 784,
        e.g. ``[batch_size, 784]`` or ``[batch_size, 1, 28, 28]``; a single
        ``[784]`` sample is treated as a batch of one.
        """
        # `reshape` (unlike `view`) also accepts non-contiguous input.
        x_img = x.reshape(-1, 1, 28, 28)           # [batch_size, 1, 28, 28]
        # Derive the flat view from x_img so both branches always agree on
        # the batch dimension, whatever layout the caller passed in.
        x_seq = x_img.reshape(x_img.size(0), -1)   # [batch_size, 784]

        # Image branch
        x_img = self.cnn_branch(x_img)
        x_img = x_img.view(x_img.size(0), -1)
        x_img = self.cnn_fc(x_img)

        # Sequence branch
        x_seq = self.seq_branch(x_seq)

        # Concatenate along the feature dimension -> [batch_size, 64]
        x_combined = torch.cat((x_img, x_seq), dim=1)

        x_combined = self.fc_combined(x_combined)
        return F.log_softmax(x_combined, dim=1)

    def _shared_step(self, batch, stage, metric, on_step):
        """Compute NLL loss and accuracy for ``batch`` and log both.

        Values are logged as ``{stage}_loss`` and ``{stage}_acc``; ``on_step``
        selects whether per-step values are logged next to the epoch values.
        Returns the loss tensor.
        """
        x, y = batch
        # forward() already applies log_softmax, hence nll_loss here.
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)

        preds = torch.argmax(log_probs, dim=1)
        acc = metric(preds, y)

        self.log(f'{stage}_loss', loss, on_step=on_step, on_epoch=True, prog_bar=True, logger=True)
        self.log(f'{stage}_acc', acc, on_step=on_step, on_epoch=True, prog_bar=True, logger=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; logs ``train_loss``/``train_acc`` per step and per epoch."""
        return self._shared_step(batch, 'train', self.train_accuracy, on_step=True)

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; logs ``val_loss``/``val_acc`` per epoch."""
        return self._shared_step(batch, 'val', self.val_accuracy, on_step=False)

    def test_step(self, batch, batch_idx):
        """Run one test batch; logs ``test_loss``/``test_acc`` per epoch."""
        return self._shared_step(batch, 'test', self.test_accuracy, on_step=False)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)


In [16]:
# Seed Python, NumPy and torch RNGs (and dataloader workers) so that weight
# initialisation and batch shuffling are repeatable across kernel restarts.
pl.seed_everything(42, workers=True)

# Create the data module
mnist_dm = MNISTDataModule(data_dir='', batch_size=64)

# A single Trainer drives training, validation and testing
trainer = pl.Trainer(max_epochs=10)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:

# Train the model. The instance is bound to a name so later cells can reuse
# the trained weights directly instead of going through the trainer.
model = CombinedMNISTModel()
trainer.fit(model, datamodule=mnist_dm)




  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | cnn_branch     | Sequential         | 1.2 K  | train
1 | cnn_fc         | Linear             | 25.1 K | train
2 | seq_branch     | Sequential         | 26.2 K | train
3 | fc_combined    | Sequential         | 1.2 K  | train
4 | train_accuracy | MulticlassAccuracy | 0      | train
5 | val_accuracy   | MulticlassAccuracy | 0      | train
6 | test_accuracy  | MulticlassAccuracy | 0      | train
--------------------------------------------------------------
53.8 K    Trainable params
0         Non-trainable params
53.8 K    Total params
0.215     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Epoch 9: 100%|██████████| 525/525 [00:18<00:00, 29.10it/s, v_num=14, train_loss_step=0.000793, train_acc_step=1.000, val_loss=0.0785, val_acc=0.982, train_loss_epoch=0.0247, train_acc_epoch=0.992]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 525/525 [00:18<00:00, 29.08it/s, v_num=14, train_loss_step=0.000793, train_acc_step=1.000, val_loss=0.0785, val_acc=0.982, train_loss_epoch=0.0247, train_acc_epoch=0.992]


In [10]:
# Evaluate on the validation dataloader supplied by the data module. No model
# is passed here; the recorded output below shows the weights being restored
# from a saved checkpoint before evaluation.
# NOTE(review): this cell's execution count (10) is lower than the training
# cell's (17), and its output names a version_10 checkpoint while the training
# progress bar reports v_num=14 — the metrics shown look stale; re-run this
# cell after training to confirm them.
trainer.validate(datamodule=mnist_dm)

Restoring states from the checkpoint path at /lightning_logs/version_10/checkpoints/epoch=9-step=5250.ckpt
Loaded model weights from the checkpoint at /lightning_logs/version_10/checkpoints/epoch=9-step=5250.ckpt


Validation DataLoader 0: 100%|██████████| 132/132 [00:04<00:00, 30.04it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc            0.9923809766769409
        val_loss            0.02614075317978859
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.02614075317978859, 'val_acc': 0.9923809766769409}]