In [1]:
import os
import sys
import cv2
import torch
import random
import warnings
import torchvision
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
# import pydicom as pdcm
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import KFold
import random

import nibabel as nib 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ModuleNotFoundError: No module named 'nibabel'

In [2]:
all_list = []
for root, dirs, files in os.walk('./AD/'):
    for file in files:
        file_path = os.path.join(root, file)
        all_list.append(file_path)

for root, dirs, files in os.walk('./CN/'):
    for file in files:
        file_path = os.path.join(root, file)
        all_list.append(file_path)
all_list = list(filter(lambda x: '.DS_Store' not in x, all_list))
# all_list.remove('./CN/011_S_0016/ADNI_011_S_0016_MR_MPR-R__GradWarp__B1_Correction__N3__Scaled_Br_20061206170814835_S13160_I31928.nii')
# all_list.remove('./CN/013_S_0502/ADNI_013_S_0502_MR_MPR__GradWarp__B1_Correction__N3__Scaled_Br_20070926112008188_S27531_I75291.nii')

In [3]:
class MRIdata(Dataset):
    def __init__(self, path_list, transform = None):
        self.path_list = path_list
        self.transform = transform


    def __len__(self):
        return len(self.path_list)
    
    def __getitem__(self, index):
        
        img = nib.load(self.path_list[index]) #读取nii
        img_fdata = img.get_fdata()
        if img_fdata.shape[2] > 160:
            img_fdata = img_fdata[:, :, 4:164]
        img_fdata = cv2.resize(img_fdata, (192, 192))
        
#         print(img_fdata.shape, type(img_fdata))

#         inp_data = read_dicom_img(self.train_dir, str(self.data['BraTS21ID'][index]))
        inp_data = self.transform(img_fdata[:])
        if self.path_list[index].split('/')[1] == 'AD':
            label = torch.tensor(1, dtype = torch.float)
        else:
            label = torch.tensor(0, dtype = torch.float)
        
#         print(self.path_list[index].split('/')[1])
        return inp_data.float(), label

In [4]:
transforms = T.Compose([T.ToTensor()])


random.shuffle(all_list)

train_dataset = MRIdata(all_list[:80], transform = transforms)
train_loader = DataLoader(train_dataset, shuffle = True, batch_size = 4)

test_dataset = MRIdata(all_list[80:], transform = transforms)
test_loader = DataLoader(test_dataset, shuffle = False, batch_size = 4)

In [5]:
class MRINet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(self.conv_layer(in_chan = 160, out_chan = 128),
                                  self.conv_layer(in_chan = 128, out_chan = 128),
                                  self.conv_layer(in_chan = 128, out_chan = 256))
        
        self.fc = nn.Sequential(nn.Linear(123904, 512),
                                nn.Dropout(p = 0.15),
                                nn.Linear(512, 1))
        
    
    def conv_layer(self, in_chan, out_chan):
        conv_layer = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, kernel_size=(3, 3), padding = 0),
            nn.LeakyReLU(),
            nn.MaxPool2d((2,2)),
            nn.BatchNorm2d(out_chan))
        
        return conv_layer    
           
    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)    
        return out

In [None]:
net = MRINet().to(device)
LRate = 0.0001
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(net.parameters(), lr = LRate)
EPOCHS = 100
best_score = np.inf
for epoch in range(EPOCHS):
    total_loss = 0.0
    count = 0
    net.train()
    for indx, (data, label) in enumerate(train_loader, 0):
        inputs, labels = data.to(device), label.to(device)
        optimizer.zero_grad()
        outputs = net(inputs).squeeze(1)
        
        loss = criterion(outputs, labels)
        
        loss.backward()
        
        total_loss += loss.detach().item()
        optimizer.step()
        
        count += 1
    
    print(f"Epoch:{epoch}/{EPOCHS} - train Loss:{total_loss/count}")
    
    net.eval()
    total_loss = 0.0
    count = 0
    for indx, (data, label) in enumerate(test_loader, 0):
        inputs, labels = data.to(device), label.to(device)
        optimizer.zero_grad()
        outputs = net(inputs).squeeze(1)
        
        loss = criterion(outputs, labels)
        
        loss.backward()
        
        total_loss += loss.detach().item()
        optimizer.step()
        count += 1
    if best_score > total_loss / count:
        print('loss {} -> {} reduce saving ....'.format(best_score, total_loss / count))
        torch.save(net, './best_model.pt')
        best_score = total_loss / count
    print(f"Epoch:{epoch}/{EPOCHS} - test Loss:{total_loss/count}")

print("Training Complete")    

In [22]:
 for indx, (data, label) in enumerate(train_loader, 0):
    inputs, labels = data.to(device), label.to(device)
    print(labels)

CN
AD
AD
CN
tensor([1., 1., 1., 1.])
CN
AD
CN
AD
tensor([1., 1., 1., 1.])
AD
AD
AD
AD
tensor([1., 1., 1., 1.])
CN
AD
AD
CN
tensor([1., 1., 1., 1.])
CN
CN
CN
CN
tensor([1., 1., 1., 1.])
CN
AD
AD
AD
tensor([1., 1., 1., 1.])
CN
CN
CN
CN
tensor([1., 1., 1., 1.])
AD
CN
AD
AD
tensor([1., 1., 1., 1.])
AD
AD
AD
CN
tensor([1., 1., 1., 1.])
CN
CN
AD
AD
tensor([1., 1., 1., 1.])
CN
AD
CN
CN
tensor([1., 1., 1., 1.])
AD
CN
AD
CN
tensor([1., 1., 1., 1.])
CN
CN
CN
AD
tensor([1., 1., 1., 1.])
CN
AD
CN
CN
tensor([1., 1., 1., 1.])
AD
CN
CN
CN
tensor([1., 1., 1., 1.])
CN
CN
AD
AD
tensor([1., 1., 1., 1.])
AD
CN
CN
AD
tensor([1., 1., 1., 1.])
CN
AD
AD
AD
tensor([1., 1., 1., 1.])
AD
CN
AD
AD
tensor([1., 1., 1., 1.])
CN
CN
CN
CN
tensor([1., 1., 1., 1.])
