In [1]:
import os
import sys
import cv2
import torch
import random
import warnings
import torchvision
import numpy as np
import pandas as pd
import torch.nn as nn
import seaborn as sns
# import pydicom as pdcm
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from sklearn.model_selection import KFold
import random
import copy
import nibabel as nib 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import torch
from torch import nn
from models.cnn import cnn3d
from models import (cnn, C3DNet, resnet, ResNetV2, ResNeXt, ResNeXtV2, WideResNet, PreActResNet,
        EfficientNet, DenseNet, ShuffleNet, ShuffleNetV2, SqueezeNet, MobileNet, MobileNetV2)

# from opts import parse_opts



def main(cnn_name, model_depth, n_classes, in_channels, sample_size):
 
    # simple CNN 
    if cnn_name == 'cnn':
        """
        3D simple cnn model
        """
        print(cnn_name)
        model = cnn3d()
    
    # C3D
    elif cnn_name == 'C3D':
        """
        "Learning spatiotemporal features with 3d convolutional networks." 
        """
        model = C3DNet.get_model(
            sample_size=sample_size,
            sample_duration=16,
            num_classes=n_classes,
            in_channels=1)

    # ResNet
    elif cnn_name == 'resnet':
        """
        3D resnet
        model_depth = [10, 18, 34, 50, 101, 152, 200]
        """
        model = resnet.generate_model(
            model_depth=model_depth,
            n_classes=n_classes,
            n_input_channels=in_channels,
            shortcut_type='B',
            conv1_t_size=7,
            conv1_t_stride=1,
            no_max_pool=False,
            widen_factor=1.0)
    
    # ResNetV2
    elif cnn_name == 'ResNetV2':
        """
        3D resnet
        model_depth = [10, 18, 34, 50, 101, 152, 200]
        """
        model = ResNetV2.generate_model(
            model_depth=model_depth,
            n_classes=n_classes,
            n_input_channels=in_channels,
            shortcut_type='B',
            conv1_t_size=7,
            conv1_t_stride=1,
            no_max_pool=False,
            widen_factor=1.0)

    # ResNeXtV2
    elif cnn_name == 'ResNeXt':
        """
        WideResNet
        model_depth = [50, 101, 152, 200]
        """
        model = ResNeXt.generate_model(
            model_depth=model_depth,
            n_classes=n_classes,
            in_channels=in_channels,
            sample_size=sample_size,
            sample_duration=16)
    
    # ResNeXtV2
    elif cnn_name == 'ResNeXtV2':
        """
        WideResNet
        model_depth = [50, 101, 152, 200]
        """
        model = ResNeXtV2.generate_model(
            model_depth=model_depth,
            n_classes=n_classes,
            n_input_channels=in_channels)

    # PreActResNet
    elif cnn_name == 'PreActResNet':
        """
        WideResNet
        model_depth = [50, 101, 152, 200]
        """
        model = PreActResNet.generate_model(
            model_depth=model_depth,
            n_classes=n_classes,
            n_input_channels=in_channels)

    # WideResNet
    elif cnn_name == 'WideResNet':
        """
        WideResNet
        model_depth = [50, 101, 152, 200]
        """
        model = WideResNet.generate_model(
            model_depth=model_depth,
            n_classes=n_classes,
            n_input_channels=in_channels)

    # DenseNet
    elif cnn_name == 'DenseNet':
        """
        3D resnet
        model_depth = [121, 169, 201]
        """
        model = DenseNet.generate_model(
            model_depth=model_depth,
            num_classes=n_classes,
            n_input_channels=in_channels)

    # SqueezeNet
    elif cnn_name == 'SqueezeNet':
        """
        SqueezeNet
        "SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and 
        <0.5MB model size"
        """
        model = SqueezeNet.get_model(
            version=1.0,
            sample_size=sample_size,
            sample_duration=16,
            num_classes=n_classes,
            in_channels=in_channels)
   
    # ShuffleNetV2
    elif cnn_name == 'ShuffleNetV2':
        """
        ShuffleNetV2
        "ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
        """
        model = ShuffleNetV2.get_model(
            sample_size=sample_size,
            num_classes=n_classes,
            width_mult=1.,
            in_channels=in_channels)

    # ShuffleNet
    elif cnn_name == 'ShuffleNet':
        """
        ShuffleNet
        """
        model = ShuffleNet.get_model(
            groups=3,
            num_classes=n_classes,
            in_channels=in_channels)

    # MobileNet
    elif cnn_name == 'MobileNet':
        """
        MobileNet
        "MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications" 
        """
        model = MobileNet.get_model(
            sample_size=sample_size,
            num_classes=n_classes,
            in_channels=in_channels)

    # MobileNetV2
    elif cnn_name == 'MobileNetV2':
        """
        MobileNet
        "MobileNetV2: Inverted Residuals and Linear Bottlenecks"
        """
        model = MobileNetV2.get_model(
            sample_size=sample_size,
            num_classes=n_classes,
            in_channels=in_channels)
    
    # EfficientNet
    elif cnn_name == 'EfficientNet':
        """
        EfficientNet
        """
        model = EfficientNet3D.from_name(
            'efficientnet-b4', 
            override_params={'num_classes': n_classes}, 
            in_channels=in_channels)
    

    return model



# model = main(cnn_name = 'ResNet',
#              model_depth = 101, 
#              n_classes = 2,
#              in_channels=160,
#              sample_size = 192)

model = resnet.generate_model(
            model_depth = 101,
            n_classes = 2,
            n_input_channels = 160,
            shortcut_type = 'B',
            conv1_t_size = 7,
            conv1_t_stride = 1,
            no_max_pool = False,
            widen_factor = 1.0)



In [3]:
all_list = []
for root, dirs, files in os.walk('./AD/'):
    for file in files:
        file_path = os.path.join(root, file)
        all_list.append(file_path)

for root, dirs, files in os.walk('./CN/'):
    for file in files:
        file_path = os.path.join(root, file)
        all_list.append(file_path)
all_list = list(filter(lambda x: '.DS_Store' not in x, all_list))
# all_list.remove('./CN/011_S_0016/ADNI_011_S_0016_MR_MPR-R__GradWarp__B1_Correction__N3__Scaled_Br_20061206170814835_S13160_I31928.nii')
# all_list.remove('./CN/013_S_0502/ADNI_013_S_0502_MR_MPR__GradWarp__B1_Correction__N3__Scaled_Br_20070926112008188_S27531_I75291.nii')

In [4]:
means, stdevs = [], []
all_data = np.zeros((100, 192, 192, 160))
for idx, val in enumerate(all_list):
    img = nib.load(val) #读取nii
    img_fdata = img.get_fdata()
    if img_fdata.shape[2] > 160:
        img_fdata = img_fdata[:, :, 4:164]
    img_fdata = cv2.resize(img_fdata, (192, 192))
    all_data[idx] = img_fdata

means = np.mean(all_data, axis = (0, 1, 2))
stds = np.std(all_data, axis = (0, 1, 2))

In [5]:
test = torch.randn((2, 2, 2))

In [7]:
test.unsqueeze(2).size() # 160

torch.Size([2, 2, 1, 2])

In [8]:
class MRIdata(Dataset):
    def __init__(self, path_list, transform = None):
        self.path_list = path_list
        self.transform = transform


    def __len__(self):
        return len(self.path_list)
    
    def __getitem__(self, index):
        
        img = nib.load(self.path_list[index]) #读取nii
        img_fdata = img.get_fdata()
        if img_fdata.shape[2] > 160:
            img_fdata = img_fdata[:, :, 4:164]
        img_fdata = cv2.resize(img_fdata, (192, 192))
        
#         print(img_fdata.shape, type(img_fdata))

#         inp_data = read_dicom_img(self.train_dir, str(self.data['BraTS21ID'][index]))
        inp_data = self.transform(img_fdata[:])
        if self.path_list[index].split('/')[1] == 'AD':
            label = torch.tensor([0, 1], dtype = torch.float)
        else:
            label = torch.tensor([1, 0], dtype = torch.float)
        
#         print(self.path_list[index].split('/')[1])
        return inp_data.float().unsqueeze(1), label

In [9]:
transforms = T.Compose([T.ToTensor(), T.Normalize(means, stds)])


random.shuffle(all_list)

train_dataset = MRIdata(all_list[:80], transform = transforms)
train_loader = DataLoader(train_dataset, shuffle = True, batch_size = 16)

test_dataset = MRIdata(all_list[80:], transform = transforms)
test_loader = DataLoader(test_dataset, shuffle = False, batch_size = 32)

In [10]:
for indx, (data, label) in enumerate(train_loader, 0):
    inputs, labels = data.to(device), label.to(device)
    break

In [12]:
inputs.size()

torch.Size([16, 160, 1, 192, 192])

In [7]:
class MRINet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(self.conv_layer(in_chan = 160, out_chan = 128),
                                  self.conv_layer(in_chan = 128, out_chan = 128),
                                  self.conv_layer(in_chan = 128, out_chan = 256))
        
        self.fc = nn.Sequential(nn.Linear(123904, 512),
                                nn.Dropout(p = 0.15),
                                nn.Linear(512, 1))
        self.sigmoid = nn.Sigmoid()
    
    def conv_layer(self, in_chan, out_chan):
        conv_layer = nn.Sequential(
            nn.Conv2d(in_chan, out_chan, kernel_size=(3, 3), padding = 0),
            nn.LeakyReLU(),
            nn.MaxPool2d((2,2)),
            nn.BatchNorm2d(out_chan))
        
        return conv_layer    
           
    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)    
        return self.sigmoid(out)

In [13]:
# net = MRINet().to(device)
net = model.to(device)
LRate = 0.001
# criterion = nn.BCEWithLogitsLoss()
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

optimizer = optim.Adam(net.parameters(), lr = LRate)
EPOCHS = 50
best_score = np.inf


best_model = copy.deepcopy(net)  # Will work
    
for epoch in range(EPOCHS):
    total_loss = 0.0
    count = 0
    net.train()
    for indx, (data, label) in enumerate(train_loader, 0):
        inputs, labels = data.to(device), label.to(device)
        
        optimizer.zero_grad()
        outputs = net(inputs)#.squeeze(1)
        
        loss = criterion(outputs, labels)
        
        loss.backward()
        
        total_loss += loss.detach().item()
        optimizer.step()
        
        count += 1
    
    print(f"Epoch:{epoch}/{EPOCHS} - train Loss:{total_loss/count}")
    
    net.eval()
    total_loss = 0.0
    count = 0
    for indx, (data, label) in enumerate(test_loader, 0):
        with torch.no_grad():
            inputs, labels = data.to(device), label.to(device)

            outputs = net(inputs)

            loss = criterion(outputs, labels)
            total_loss += loss.detach().item()

            count += 1
    if best_score > total_loss / count:
        print('loss {} -> {} reduce saving ....'.format(best_score, total_loss / count))
        best_model = copy.deepcopy(net)  
        
        best_score = total_loss / count
    print(f"Epoch:{epoch}/{EPOCHS} - test Loss:{total_loss/count}")
    print('-----------------------')

print("Training Complete")    
torch.save(best_model, './best_model.pt')

Epoch:0/50 - train Loss:2.0702829480171205
loss inf -> 4.3319292068481445 reduce saving ....
Epoch:0/50 - test Loss:4.3319292068481445
-----------------------
Epoch:1/50 - train Loss:1.4037230730056762
loss 4.3319292068481445 -> 1.7819896936416626 reduce saving ....
Epoch:1/50 - test Loss:1.7819896936416626
-----------------------
Epoch:2/50 - train Loss:0.880765688419342
loss 1.7819896936416626 -> 0.5765385627746582 reduce saving ....
Epoch:2/50 - test Loss:0.5765385627746582
-----------------------
Epoch:3/50 - train Loss:0.7046285748481751
Epoch:3/50 - test Loss:0.6742491126060486
-----------------------
Epoch:4/50 - train Loss:0.7502418041229248
Epoch:4/50 - test Loss:0.5933481454849243
-----------------------
Epoch:5/50 - train Loss:0.692220401763916
Epoch:5/50 - test Loss:0.6457471251487732
-----------------------
Epoch:6/50 - train Loss:0.7099916577339173
Epoch:6/50 - test Loss:0.6960732936859131
-----------------------


KeyboardInterrupt: 

In [None]:
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()

In [15]:
torch.save(best_model, './best_model.pt')

In [20]:
net.eval()
correct = 0
total = 0
for indx, (data, label) in enumerate(test_loader, 0):
    with torch.no_grad():
        inputs, labels = data.to(device), label.to(device)
        
        outputs = net(inputs)
        _, predicted = torch.max(outputs, 1)
        _, labelss = torch.max(labels, 1)
        total += labels.size(0)
        correct += (predicted == labelss).sum().item()

# outputs = torch.where(outputs > 0.00001, 1, 0)

In [21]:
predicted

tensor([1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
       device='cuda:0')

In [22]:
labelss

tensor([1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1],
       device='cuda:0')

In [23]:
labels

tensor([[0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [0., 1.],
        [0., 1.],
        [1., 0.],
        [0., 1.],
        [0., 1.]], device='cuda:0')

In [16]:
net = torch.load('./best_model.pt').to(device)

In [12]:
net.eval()

for indx, (data, label) in enumerate(test_loader, 0):
    with torch.no_grad():
        inputs, labels = data.to(device), label.to(device)
        
        outputs = net(inputs).squeeze(1)
# outputs = torch.where(outputs > 0.00001, 1, 0)

In [14]:
outputs

tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0.], device='cuda:0')

In [16]:
labels

tensor([0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 1., 1., 1.,
        0., 0.], device='cuda:0')

In [15]:
from sklearn.metrics import confusion_matrix
tn, fp, fn, tp = confusion_matrix(labels.int().cpu().numpy(), outputs.cpu().numpy()).ravel()
specificity = tn / (tn + fp)
sensitivity = tp / (tp + fn)
print('specificity: ', specificity)
print('sensitivity: ', sensitivity)
# from sklearn.metrics import recall_score
# recall_score(labels.int().cpu().numpy(), outputs.cpu().numpy(), pos_label = 0)

specificity:  1.0
sensitivity:  0.09090909090909091


In [8]:
# net = MRINet().to(device)
net = model.to(device)
LRate = 0.001
# criterion = nn.BCEWithLogitsLoss()
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(device)

optimizer = optim.Adam(net.parameters(), lr = LRate)
EPOCHS = 50
best_score = np.inf


best_model = copy.deepcopy(net)  # Will work
    
for epoch in range(EPOCHS):
    total_loss = 0.0
    count = 0
    net.train()
    for indx, (data, label) in enumerate(train_loader, 0):
        inputs, labels = data.to(device), label.to(device)
        
        optimizer.zero_grad()
        outputs = net(inputs).squeeze(1)
        
        loss = criterion(outputs, labels)
        
        loss.backward()
        
        total_loss += loss.detach().item()
        optimizer.step()
        
        count += 1
    
    print(f"Epoch:{epoch}/{EPOCHS} - train Loss:{total_loss/count}")
    
    net.eval()
    total_loss = 0.0
    count = 0
    for indx, (data, label) in enumerate(test_loader, 0):
        with torch.no_grad():
            inputs, labels = data.to(device), label.to(device)

            outputs = net(inputs).squeeze(1)

            loss = criterion(outputs, labels)
            total_loss += loss.detach().item()

            count += 1
    if best_score > total_loss / count:
        print('loss {} -> {} reduce saving ....'.format(best_score, total_loss / count))
        best_model = copy.deepcopy(net)  
        
        best_score = total_loss / count
    print(f"Epoch:{epoch}/{EPOCHS} - test Loss:{total_loss/count}")
    print('-----------------------')

print("Training Complete")    
torch.save(best_model, './best_model.pt')

Epoch:0/50 - train Loss:22.917480087280275
loss inf -> 32.95305633544922 reduce saving ....
Epoch:0/50 - test Loss:32.95305633544922
-----------------------
Epoch:1/50 - train Loss:22.394531631469725
Epoch:1/50 - test Loss:32.95305633544922
-----------------------
Epoch:2/50 - train Loss:21.879074478149413
Epoch:2/50 - test Loss:32.95305633544922
-----------------------
Epoch:3/50 - train Loss:21.754942321777342
Epoch:3/50 - test Loss:32.95305633544922
-----------------------
Epoch:4/50 - train Loss:21.839803886413574
Epoch:4/50 - test Loss:32.95305633544922
-----------------------
Epoch:5/50 - train Loss:21.852628898620605
Epoch:5/50 - test Loss:34.46629333496094
-----------------------
Epoch:6/50 - train Loss:21.869690322875975
Epoch:6/50 - test Loss:34.20244598388672
-----------------------
Epoch:7/50 - train Loss:22.43526382446289
Epoch:7/50 - test Loss:34.13270568847656
-----------------------
Epoch:8/50 - train Loss:21.58940544128418
loss 32.95305633544922 -> 32.85969924926758 re