In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision.transforms import ToTensor
import numpy as np
import torch.nn.functional as F

from MRI_dataset import ADNIDataset

In [2]:
# 初始化数据集
data_dir = "D:/Data/MRI/ADNI/Image"
csv_path = "D:/Data/MRI/ADNI/pheno_ADNI_longitudinal_new.csv"
dataset = ADNIDataset(data_dir=data_dir, csv_path=csv_path)

# 数据集大小
dataset_size = len(dataset)
train_size = int(dataset_size * 0.7)
val_size = int(dataset_size * 0.15)
test_size = dataset_size - train_size - val_size

# 数据集切分
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# DataLoader
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [3]:
print(dataset[0][0].shape)
print(len(train_dataset))
print(len(test_dataset))
print(len(val_dataset))
for i, data in enumerate(train_loader):
    inputs, labels = data
    print(inputs.shape)
    print(labels)
    break

torch.Size([1, 91, 109, 91])
2772
594
594
torch.Size([4, 1, 91, 109, 91])
tensor([0, 2, 0, 0])


In [13]:
# 定义CNN网络结构
class SimpleCNN3D(nn.Module):
    def __init__(self):
        super(SimpleCNN3D, self).__init__()
        self.conv1 = nn.Conv3d(1, 8, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv3d(8, 16, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(16 * 22 * 27 * 22, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 3)  # Assuming 3 classes for CN, MCI, and AD

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 22 * 27 * 22)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [10]:
test_net = SimpleCNN3D()
for i, data in enumerate(train_loader):
    inputs, labels = data
    outputs = test_net(inputs)
    print(outputs.shape)
    break

torch.Size([4, 8, 45, 54, 45])
torch.Size([4, 16, 22, 27, 22])
torch.Size([4, 209088])
torch.Size([4, 3])


In [14]:
# 实例化网络
net = SimpleCNN3D()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 训练循环（简化版）
epochs = 2  # 仅示例，实际训练可能需要更多epochs
for epoch in range(epochs):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 100 == 99:  # 每100个batch打印一次
            print(f"[{epoch + 1}, {i + 1}] loss: {running_loss / 100}")
            running_loss = 0.0

print('Finished Training')

[1, 100] loss: 22.597203034162522
[1, 200] loss: 1.2069450399279595
[1, 300] loss: 0.969096185863018
[1, 400] loss: 0.9365454030036926
[1, 500] loss: 0.9310309332609177
[1, 600] loss: 0.9269363398849965
[2, 100] loss: 0.6482769296504557
[2, 200] loss: 0.5796486678160727
[2, 300] loss: 0.6155265440791845
[2, 400] loss: 0.5369631446828135
[2, 500] loss: 0.49649072224041446
[2, 600] loss: 0.46559560024994423
Finished Training


In [6]:
print("\033[33mInfomations:\033[0m")

[33mtest[0m
