In [24]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchvision.models as models

import os
import numpy as np
from sklearn import metrics
from tqdm import trange, tqdm
import nibabel as nib
import matplotlib.pyplot as plt
import torch.nn.functional as F
import utilities as UT
from ranksvm import get_dynamic_image

In [25]:
import torch
from torch import nn
from torch.utils.data import Dataset
import torchvision.models as models

import os
import numpy as np
from sklearn import metrics
from tqdm import trange, tqdm

import matplotlib.pyplot as plt

import utilities as UT
from ranksvm import get_dynamic_image

LABEL_PATH = '/home/raytrack/.jupyter/Dynamic/Preprocessed'

def prep_data(LABEL_PATH, TEST_NUM):
    # This function is used to prepare train/test labels for 5-fold cross-validation
    TEST_LABEL = f'{LABEL_PATH}/fold_CNvsAD_{TEST_NUM}.csv'

    # combine train labels
    filenames = [f'{LABEL_PATH}/fold_CNvsAD_{i}.csv' for i in range(5)]
    filenames.remove(TEST_LABEL)

    combined_train_list_path = f'{LABEL_PATH}/combined_train_list_{TEST_NUM}.csv'
    with open(combined_train_list_path, 'w') as combined_train_list:
        for fold in filenames:
            for line in open(fold, 'r'):
                combined_train_list.write(line)
    TRAIN_LABEL = combined_train_list_path
    
    return TRAIN_LABEL, TEST_LABEL


class Dataset_Early_Fusion(Dataset):
    def __init__(self, label_file):
        self.files = UT.read_csv(label_file)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        temp = self.files[idx]
        full_path = temp[0]

        label_str = full_path.split('/')[-3]  # Get the label string from the file path
        if label_str == 'CN':
            label = 0
        elif label_str == 'AD':
            label = 1
        else:
            raise ValueError(f'Unexpected label: {label_str}')

        im = np.load(full_path)
       #  im = np.load(full_path, allow_pickle=True)
        im = get_dynamic_image(im)
        im = np.expand_dims(im, 0)
        im = np.concatenate([im, im, im], 0)

        return im, label, full_path  # label is now an int

    
        print(full_path)




In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DGM(nn.Module):
    def __init__(self, in_channels, reduction=4, groups=4):  # reduction 默认值从 16 改为 4
        super(DGM, self).__init__()
        self.groups = groups
        mid_channels = max(1, in_channels // reduction)  # 保证 mid_channels > 0 且与 groups 兼容

        # 动态权重调整层
        self.weight_layer = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, 1, groups=self.groups, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, in_channels, 1, groups=self.groups, bias=False),
            nn.Sigmoid()
        )
        # 分组点卷积
        self.pointwise_groups = nn.Conv2d(in_channels, in_channels, 1, groups=self.groups, bias=False)

    def forward(self, x):
        # 动态权重
        weights = self.weight_layer(x)
        # 应用动态权重
        x = x * weights
        # 分组点卷积进行特征融合
        x = self.pointwise_groups(x)
        return x

    
# 定义全局注意力模块（GAM）
class GAM_Attention(nn.Module):
    def __init__(self, in_channels):
        super(GAM_Attention, self).__init__()

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)  # 全局平均池化层
        self.channel_attention = nn.Sequential(
            nn.Linear(in_channels, in_channels // 16),  # 线性变换层，减小通道数
            nn.ReLU(inplace=True),  # ReLU 激活函数
            nn.Linear(in_channels // 16, in_channels),  # 线性变换层，恢复通道数
            nn.Sigmoid()  # Sigmoid 激活函数，产生通道注意力权重
        )

    def forward(self, x):
        b, c, h, w = x.size()

        # 全局平均池化，将特征图变成全局平均值
        x_global = self.global_avgpool(x).view(b, c)

        # 通道注意力：通过线性变换和 Sigmoid 操作产生通道权重
        x_channel_att = self.channel_attention(x_global).view(b, c, 1, 1)

        # 将输入特征图按通道加权
        x = x * x_channel_att

        return x


# 定义 LinearBottleNeck_1 模块
class LinearBottleNeck_1(nn.Module):
    def __init__(self, in_c, out_c, s, t):
        super().__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(in_c, in_c * t, 1),  # 1x1 卷积层，升维操作
            nn.BatchNorm2d(in_c * t),  # 批归一化
            nn.ReLU6(inplace=True),  # ReLU6 激活函数

            nn.Conv2d(in_c * t, in_c * t, 3, stride=s, padding=1, groups=in_c * t),  # 3x3 深度可分离卷积
            nn.BatchNorm2d(in_c * t),  # 批归一化
            nn.ReLU6(inplace=True),  # ReLU6 激活函数

            nn.Conv2d(in_c * t, in_c * t, 1, stride=1, padding=0, groups=1),  # 1x1 卷积层
            nn.BatchNorm2d(in_c * t),  # 批归一化

            nn.Conv2d(in_c * t, out_c, 1),  # 1x1 卷积层，降维操作
            nn.BatchNorm2d(out_c)  # 批归一化
        )

        self.stride = s  # 步长
        self.in_channels = in_c  # 输入通道数
        self.out_channels = out_c  # 输出通道数

        # 添加全局注意力模块
        self.attention = GAM_Attention(out_c)  # 使用定义的全局注意力模块

    def forward(self, x):
        residual = self.residual(x)

        if self.stride == 1 and self.in_channels == self.out_channels:
            residual += x  # 恒等映射，如果步长为1且通道数不变，则加上原始输入

        # 应用全局注意力
        residual = self.attention(residual)

        return residual


# 定义 LinearBottleNeck_2 模块
class LinearBottleNeck_2(nn.Module):
    def __init__(self, in_c, out_c, s, t):
        super().__init__()

        self.residual = nn.Sequential(
            nn.Conv2d(in_c, in_c * t, 1),  # 1x1 卷积层，升维操作
            nn.BatchNorm2d(in_c * t),  # 批归一化
            nn.ReLU6(inplace=True),  # ReLU6 激活函数

            nn.Conv2d(in_c * t, in_c * t, 3, stride=s, padding=1, groups=in_c * t),  # 3x3 深度可分离卷积
            nn.BatchNorm2d(in_c * t),  # 批归一化
            nn.ReLU6(inplace=True),  # ReLU6 激活函数

            nn.Conv2d(in_c * t, in_c * t, 1, stride=1, padding=0, groups=1),  # 1x1 卷积层
            nn.BatchNorm2d(in_c * t),  # 批归一化

            nn.Conv2d(in_c * t, out_c, 1),  # 1x1 卷积层，降维操作
            nn.BatchNorm2d(out_c)  # 批归一化
        )

        self.residual_1 = nn.Sequential(
            nn.Conv2d(in_c, in_c * t, 1),  # 1x1 卷积层，升维操作
            nn.BatchNorm2d(in_c * t),  # 批归一化
            nn.ReLU6(inplace=True),  # ReLU6 激活函数

            nn.Conv2d(in_c * t, in_c * t, 5, stride=s, padding=2, groups=in_c * t),  # 5x5 深度可分离卷积
            nn.BatchNorm2d(in_c * t),  # 批归一化
            nn.ReLU6(inplace=True),  # ReLU6 激活函数

            nn.Conv2d(in_c * t, in_c * t, 1, stride=1, padding=0, groups=1),  # 1x1 卷积层
            nn.BatchNorm2d(in_c * t),  # 批归一化

            nn.Conv2d(in_c * t, out_c, 1),  # 1x1 卷积层，降维操作
            nn.BatchNorm2d(out_c)  # 批归一化
        )

        self.residual_2 = nn.Sequential(
            nn.Conv2d(in_c, out_c, 1, stride=2),  # 1x1 卷积层，步长为2，降维操作
            nn.BatchNorm2d(out_c)  # 批归一化
        )

        self.stride = s  # 步长
        self.in_channels = in_c  # 输入通道数
        self.out_channels = out_c  # 输出通道数

    def forward(self, x):
        residual = self.residual(x)
        residual_1 = self.residual_1(x)
        residual_2 = self.residual_2(x)

        # 多尺度特征融合
        out_feature = residual_1 + residual + residual_2

        return out_feature

# 定义 SuperDAM
class SuperDAM(nn.Module):
    def __init__(self, class_num=2):
        super().__init__()
        self.pre = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU6(inplace=True),
        )

        self.stage1 = nn.Sequential(
            LinearBottleNeck_1(32, 16, 1, 1),
            DGM(16)  # 在 stage1 后加入 DGM
        )
        self.stage2 = self.make_stage(2, 16, 24, 2, 6)
        self.stage3 = self.make_stage(3, 24, 32, 2, 6)
        self.stage4 = nn.Sequential(
            self.make_stage(4, 32, 64, 2, 6),
            DGM(64)  # 在 stage4 后加入  DGM
        )
        self.stage5 = self.make_stage(3, 64, 96, 1, 6)
        self.stage6 = nn.Sequential(
            self.make_stage(3, 96, 160, 2, 6),
            DGM(160)  # 在 stage6 后加入 DGM
        )
        self.stage7 = LinearBottleNeck_1(160, 320, 1, 6)

        self.conv1 = nn.Sequential(
            nn.Conv2d(320, 1280, 1),
            nn.BatchNorm2d(1280),
            nn.ReLU6(inplace=True),
            DGM(1280)  # 在 conv1 后加入  DGM
        )
        self.conv2 = nn.Conv2d(1280, class_num, 1)

    def forward(self, x):
        x = self.pre(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.stage5(x)
        x = self.stage6(x)
        x = self.stage7(x)
        x = self.conv1(x)
        x = F.adaptive_avg_pool2d(x, 1)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        return x

    def make_stage(self, repeat, in_c, out_c, s, t):
        layers = []
        if s == 1:
            layers.append(LinearBottleNeck_1(in_c, out_c, s, t))
        else:
            layers.append(LinearBottleNeck_2(in_c, out_c, s, t))

        while repeat - 1:
            layers.append(LinearBottleNeck_1(out_c, out_c, 1, t))
            repeat -= 1

        return nn.Sequential(*layers)


if __name__ == "__main__":
    # 定义测试的输入数据
    batch_size = 4
    sample_data = torch.rand(batch_size, 3, 110, 110)  # 输入数据形状为 (batch_size, 3, 110, 110)

    # 初始化模型
    model =SuperDAM (class_num=2)  # 类别数为2

    # 测试模型输出
    output = model(sample_data)

    # 打印模型输出形状
    print(f"Input shape: {sample_data.shape}")  # 输入形状
    print(f"Model output shape: {output.shape}")  # 输出形状，期望是 (batch_size, num_classes)



Input shape: torch.Size([4, 3, 110, 110])
Model output shape: torch.Size([4, 2])


In [None]:
def train(train_dataloader, val_dataloader):
# Assuming 'net' is your model instance
    # 检查是否有可用的 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = SuperDAM( class_num=2)
    net.to(device)
    # 加载预训练权重
    pretrained_weights_path = '/home/raytrack/.jupyter/Dynamic/newmodel_weights.pth'
    pretrained_dict = torch.load(pretrained_weights_path)

    # 获取模型的现有权重字典
    model_dict = net.state_dict()

    # 过滤掉不匹配的权重
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and v.size() == model_dict[k].size()}

    # 更新现有的模型权重字典
    model_dict.update(pretrained_dict)

    # 加载过滤后的权重字典
    net.load_state_dict(model_dict, strict=False)
    
    #opt = torch.optim.Adam(net.parameters(), lr=LR, weight_decay=0.001)
    #opt = torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9)
    #scheduler = torch.optim.lr_scheduler.ExponentialLR(opt, gamma= 0.985)
    # scheduler = torch.optim.lr_scheduler.CyclicLR(opt, 
    #                                               base_lr=LR, 
    #                                               max_lr=0.001, 
    #                                               step_size_up=100,
    #                                               cycle_momentum=False)
    opt  = torch.optim.AdamW(net.parameters(), lr=0.001, weight_decay=0.01)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50)
    LOSS_WEIGHTS = torch.tensor([1., 1.]) 
    loss_fcn = torch.nn.CrossEntropyLoss(weight=LOSS_WEIGHTS.to(device))
    #loss_fcn = torch.nn.CrossEntropyLoss(weight=torch.tensor([1.0, 1.0]).to(device))    
    t = trange(EPOCHS, desc=' ', leave=True)

    train_hist = []
    val_hist = []
    pred_result = []
    old_acc = 0
    old_auc = 0
    test_acc = 0
    best_epoch = 0
    test_performance = []
    for e in t:    
        y_true = []
        y_pred = []
        
        val_y_true = []
        val_y_pred = []                
        
        train_loss = 0
        val_loss = 0

        # training
        net.train()
        for step, (img, label, _) in enumerate(train_dataloader):
            img = img.float().to(device)
            label = label.long().to(device)
            opt.zero_grad()
            out = net(img)
            loss = loss_fcn(out, label)

            loss.backward()
            opt.step()
            
            label = label.cpu().detach()
            out = out.cpu().detach()
            y_true, y_pred = UT.assemble_labels(step, y_true, y_pred, label, out)        

            train_loss += loss.item()

        train_loss = train_loss/(step+1)
        acc = float(torch.sum(torch.max(y_pred, 1)[1]==y_true))/ float(len(y_pred))
        auc = metrics.roc_auc_score(y_true, y_pred[:,1])
        f1 = metrics.f1_score(y_true, torch.max(y_pred, 1)[1])
        precision = metrics.precision_score(y_true, torch.max(y_pred, 1)[1])
        recall = metrics.recall_score(y_true, torch.max(y_pred, 1)[1])
        ap = metrics.average_precision_score(y_true, torch.max(y_pred, 1)[1]) #average_precision

        scheduler.step()

        # val
        net.eval()
        full_path = []
        with torch.no_grad():
            for step, (img, label, _) in enumerate(val_dataloader):
                img = img.float().to(device)
                label = label.long().to(device)
                out = net(img)
                loss = loss_fcn(out, label)
                val_loss += loss.item()

                label = label.cpu().detach()
                out = out.cpu().detach()
                val_y_true, val_y_pred = UT.assemble_labels(step, val_y_true, val_y_pred, label, out)
                
                for item in _:
                    full_path.append(item)
                
        val_loss = val_loss/(step+1)
        val_acc = float(torch.sum(torch.max(val_y_pred, 1)[1]==val_y_true))/ float(len(val_y_pred))
        val_auc = metrics.roc_auc_score(val_y_true, val_y_pred[:,1])
        val_f1 = metrics.f1_score(val_y_true, torch.max(val_y_pred, 1)[1])
        val_precision = metrics.precision_score(val_y_true, torch.max(val_y_pred, 1)[1])
        val_recall = metrics.recall_score(val_y_true, torch.max(val_y_pred, 1)[1])
        val_ap = metrics.average_precision_score(val_y_true, torch.max(val_y_pred, 1)[1]) #average_precision


        train_hist.append([train_loss, acc, auc, f1, precision, recall, ap])
        val_hist.append([val_loss, val_acc, val_auc, val_f1, val_precision, val_recall, val_ap])             

        t.set_description("Epoch: %i, train loss: %.4f, train acc: %.4f, val loss: %.4f, val acc: %.4f, test acc: %.4f" 
                          %(e, train_loss, acc, val_loss, val_acc, test_acc))


        if(old_acc<val_acc):
            old_acc = val_acc
            old_auc = val_auc
            best_epoch = e
            test_loss = 0
            test_y_true = val_y_true
            test_y_pred = val_y_pred            

            test_loss = val_loss
            test_acc = float(torch.sum(torch.max(test_y_pred, 1)[1]==test_y_true))/ float(len(test_y_pred))
            test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
            test_f1 = metrics.f1_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_precision = metrics.precision_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_recall = metrics.recall_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_ap = metrics.average_precision_score(test_y_true, torch.max(test_y_pred, 1)[1]) #average_precision

            test_performance = [best_epoch, test_loss, test_acc, test_auc, test_f1, test_precision, test_recall, test_ap]
        
        if(old_acc==val_acc) and (old_auc<val_auc):
            old_acc = val_acc
            old_auc = val_auc
            best_epoch = e
            test_loss = 0
            test_y_true = val_y_true
            test_y_pred = val_y_pred            

            test_loss = val_loss
            test_acc = float(torch.sum(torch.max(test_y_pred, 1)[1]==test_y_true))/ float(len(test_y_pred))
            test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
            test_f1 = metrics.f1_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_precision = metrics.precision_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_recall = metrics.recall_score(test_y_true, torch.max(test_y_pred, 1)[1])
            test_ap = metrics.average_precision_score(test_y_true, torch.max(test_y_pred, 1)[1]) #average_precision

            test_performance = [best_epoch, test_loss, test_acc, test_auc, test_f1, test_precision, test_recall, test_ap]
    return train_hist, val_hist, test_performance, test_y_true, test_y_pred, full_path

In [28]:
LABEL_PATH = '/home/raytrack/.jupyter/Dynamic/Preprocessed'


GPU = 0
BATCH_SIZE = 2

EPOCHS = 150

LR = 0.00001
LOSS_WEIGHTS = torch.tensor([1., 1.]) 

device = torch.device('cuda:'+str(GPU) if torch.cuda.is_available() else 'cpu')

In [29]:
#DATA_PATH = '/data/scratch/gliang/data/adni/ADNI2_MRI_Feature/Alex_Layer-9_DynamicImage'
#FEATURE_SHAPE=(256,5,5)
#print('DATA_PATH:',DATA_PATH)

train_hist = []
val_hist = []
test_performance = []
test_y_true = np.asarray([])
test_y_pred = np.asarray([])
full_path = np.asarray([])
for i in range(0, 3):
    print('Train Fold', i)
    
    TEST_NUM = i
    TRAIN_LABEL, TEST_LABEL = prep_data(LABEL_PATH, TEST_NUM)
    
    train_dataset = Dataset_Early_Fusion(label_file=TRAIN_LABEL)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, num_workers=0, batch_size=BATCH_SIZE , shuffle=True, drop_last=False)

    val_dataset = Dataset_Early_Fusion(label_file=TEST_LABEL)
    val_dataloader = torch.utils.data.DataLoader(val_dataset, num_workers=0, batch_size=BATCH_SIZE , shuffle=False, drop_last=False)
        
    cur_result = train(train_dataloader, val_dataloader)
    
    train_hist.append(cur_result[0])
    val_hist.append(cur_result[1]) 
    test_performance.append(cur_result[2]) 
    test_y_true = np.concatenate((test_y_true, cur_result[3].numpy()))
    if(len(test_y_pred) == 0):
        test_y_pred = cur_result[4].numpy()
    else:
        test_y_pred = np.vstack((test_y_pred, cur_result[4].numpy()))
    full_path = np.concatenate((full_path, np.asarray(cur_result[5])))

print(test_performance)

test_y_true = torch.tensor(test_y_true)
test_y_pred = torch.tensor(test_y_pred)
test_acc = float(torch.sum(torch.max(test_y_pred, 1)[1]==test_y_true.long()))/ float(len(test_y_pred))
test_auc = metrics.roc_auc_score(test_y_true, test_y_pred[:,1])
test_f1 = metrics.f1_score(test_y_true, torch.max(test_y_pred, 1)[1])
test_precision = metrics.precision_score(test_y_true, torch.max(test_y_pred, 1)[1])
test_recall = metrics.recall_score(test_y_true, torch.max(test_y_pred, 1)[1])
test_ap = metrics.average_precision_score(test_y_true, torch.max(test_y_pred, 1)[1])

print('ACC %.4f, AUC %.4f, F1 %.4f, Prec %.4f, Recall %.4f, AP %.4f' 
      %(test_acc, test_auc, test_f1, test_precision, test_recall, test_ap))

Train Fold 0


Epoch: 149, train loss: 0.0000, train acc: 1.0000, val loss: 0.3391, val acc: 0.8947, test acc: 1.0000: 100%|█| 150


Train Fold 1


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch: 149, train loss: 0.0000, train acc: 1.0000, val loss: 0.9524, val acc: 0.8824, test acc: 0.9412: 100%|█| 150


Train Fold 2


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
Epoch: 149, train loss: 0.0000, train acc: 1.0000, val loss: 0.3123, val acc: 0.8571, test acc: 0.9524: 100%|█| 150

[[32, 0.12782560250489042, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], [47, 0.4907944569650782, 0.9411764705882353, 0.9545454545454546, 0.9523809523809523, 1.0, 0.9090909090909091, 0.9679144385026738], [51, 0.2946976402140122, 0.9523809523809523, 0.9818181818181818, 0.9565217391304348, 0.9166666666666666, 1.0, 0.9166666666666666]]
ACC 0.9649, AUC 0.9716, F1 0.9667, Prec 0.9667, Recall 0.9667, AP 0.9520



