In [1]:
import os
# Display the kernel's current working directory. The data paths in the config cell
# are relative, so they presumably resolve against this directory - verify in
# data.data_loader before moving the notebook.
os.getcwd()

'/mnt/AD_for_fusion'

In [2]:
# Experiment settings read by the training cell, grouped by purpose.
config = dict(
    # Input CSV locations plus batching and split settings.
    data=dict(
        PET_path="data/PET.csv",
        MRI_path="data/MRI_0328.csv",
        low_dim_path="data/ALL.csv",
        labels_path="data/LABEL.csv",
        batch_size=128,
        shuffle=False,      # handed to both the training and the validation loader
        test_size=0.2,
        val_size=0.1,
        random_state=42,    # split seed; the training cell uses it only when repeat_times == 1
    ),
    # Sizes passed to Fusion_model. The training cell does not forward
    # "type" or "num_heads" to the model constructor.
    model=dict(
        type="GCN",
        pet_input_size=166,
        mri_input_size=498,
        low_dim_input_size=17,
        embedding_dim=64,
        output_dim=2,
        hidden_channels=128,
        num_heads=8,
    ),
    # Optimisation settings.
    train=dict(
        repeat_times=10,    # number of independent train/validation splits
        epochs=100,         # epochs per split
        learning_rate=0.001,
        device="cuda:1",    # requested torch device string
    ),
    # Arguments forwarded to the EarlyStopping helper.
    earlystopping=dict(
        patience=5,
        delta=0.001,
    ),
)

In [None]:
# 集成版本

import torch
import torch.nn as nn
import torch.optim as optim
# import json
from data.data_loader import load_and_align_data, create_data_loader
from models.model import *
from utils import * 
from torch_geometric.data import Data          
import torch.optim.lr_scheduler as lr_scheduler
from torch_geometric.utils import dense_to_sparse
#from config import config
import random
import numpy as np
import pandas as pd

# Select the training device. The device requested in config['train']['device'] is
# honoured when CUDA is usable and the requested GPU index exists; otherwise fall
# back to the default GPU, or to the CPU when CUDA is unavailable.
requested_device = torch.device(config['train'].get('device', 'cuda'))
if not torch.cuda.is_available():
    device = torch.device("cpu")
elif requested_device.type == "cuda" and (requested_device.index or 0) >= torch.cuda.device_count():
    device = torch.device("cuda")
else:
    device = requested_device

# Seed the stdlib RNG so the list of split seeds drawn below is identical on every run.
random.seed(99)

# Collector for the results: one validation-accuracy entry is appended per repeat.
result = []

repeat_times = config['train']['repeat_times']
if repeat_times < 1:
    # Without this check an empty seed list only fails later, at the final summary.
    raise ValueError(f"config['train']['repeat_times'] must be >= 1, got {repeat_times}")

# `random_state` is the list of split seeds, one per repeat.
if repeat_times == 1:
    # Single run: use the fixed seed from the config.
    random_state = [config['data']['random_state']]
    print(f'仅进行{repeat_times}次分割样本，随机数种子为：{random_state}')
else:
    # Several runs: draw one split seed per repeat.
    random_state = [random.randint(0, 10000) for _ in range(repeat_times)]
    print(f'共分割样本{repeat_times}次，随机数种子为：{random_state}')

# Train and validate one model per split seed, then summarise the per-repeat
# validation accuracies collected in `result`.
for seed in random_state:
    # Seed PyTorch as well, so weight initialisation (and any other torch-side
    # randomness) is repeatable for this split and not only the split itself.
    torch.manual_seed(seed)

    # Load the data and build the datasets for this split; the third returned
    # dataset is not used in this cell.
    train_dataset, val_dataset, _ = load_and_align_data(PET_path=config['data']['PET_path'],
                                                        MRI_path=config['data']['MRI_path'],
                                                        low_dim_path=config['data']['low_dim_path'],
                                                        labels_path=config['data']['labels_path'],
                                                        test_size=config['data']['test_size'],
                                                        val_size=config['data']['val_size'],
                                                        random_state=seed)

    # Build the data loaders.
    train_loader = create_data_loader(train_dataset, batch_size=config['data']['batch_size'], shuffle=config['data']['shuffle'])
    val_loader = create_data_loader(val_dataset, batch_size=config['data']['batch_size'], shuffle=config['data']['shuffle'])

    model = Fusion_model(pet_input_size=config["model"]["pet_input_size"],
                         mri_input_size=config["model"]["mri_input_size"],
                         low_dim_input_size=config["model"]["low_dim_input_size"],
                         embedding_dim=config["model"]["embedding_dim"],
                         output_dim=config["model"]["output_dim"],
                         hidden_channels=config["model"]["hidden_channels"]).to(device)

    # NLLLoss expects log-probabilities - presumably Fusion_model ends in a
    # log_softmax; TODO confirm against models.model.
    criterion = nn.NLLLoss()
    optimizer = optim.Adam(model.parameters(), lr=config['train']['learning_rate'])
    # Multiply the learning rate by 0.3 once the validation loss has stopped improving.
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.3, patience=5)

    # NOTE(review): the early-stopping helper is constructed but never consulted
    # below, so every repeat runs all configured epochs. Wire it into the epoch
    # loop once its call interface (defined in utils) has been checked.
    early_stopping = EarlyStopping(patience=config["earlystopping"]["patience"], delta=config["earlystopping"]["delta"])

    # Validation accuracy of every epoch of this repeat. It has to be created
    # outside the epoch loop: resetting it per epoch would make max() below
    # return only the last epoch's accuracy instead of the best one.
    val_acc_list = []

    for epoch in range(config['train']['epochs']):
        # ---- training pass ----
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        for pet_features, mri_features, low_dim_features, labels, types in train_loader:
            # Move the batch to the training device.
            pet_features = pet_features.to(device)
            mri_features = mri_features.to(device)
            low_dim_features = low_dim_features.float().to(device)
            labels = labels.to(device)
            types = types.to(device)

            optimizer.zero_grad()  # clear gradients from the previous batch
            outputs = model(pet_features, mri_features, low_dim_features, types)  # forward pass
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Per-modality branch freezing (keyed on `types`) is currently disabled.
            # This loop re-enables requires_grad on every parameter; nothing in this
            # cell turns it off, so it only matters if the model freezes parameters
            # itself.
            for param in model.parameters():
                param.requires_grad = True

            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Loss is averaged over batches, accuracy over samples (in percent).
        train_loss = total_loss / len(train_loader)
        train_acc = 100 * correct / total

        # ---- validation pass ----
        model.eval()
        val_total_loss = 0
        val_correct = 0
        val_total = 0

        with torch.no_grad():
            for pet_features, mri_features, low_dim_features, labels, types in val_loader:
                # Move the batch to the training device.
                pet_features = pet_features.to(device)
                mri_features = mri_features.to(device)
                low_dim_features = low_dim_features.float().to(device)
                labels = labels.to(device)
                types = types.to(device)

                outputs = model(pet_features, mri_features, low_dim_features, types)

                loss = criterion(outputs, labels)
                val_total_loss += loss.item()
                predicted = outputs.argmax(dim=1)
                val_total += labels.size(0)
                val_correct += (predicted == labels).sum().item()

        val_loss = val_total_loss / len(val_loader)
        val_acc = 100 * val_correct / val_total
        scheduler.step(val_loss)  # the plateau scheduler is driven by the validation loss
        val_acc_list.append(val_acc)

        print(f'Epoch {epoch+1}, Train Loss: {train_loss:.6f}, Train Acc: {train_acc:.6f}, Val Loss: {val_loss:.6f}, Val Acc: {val_acc:.6f}')

    # Record the best validation accuracy reached during this repeat.
    result.append(max(val_acc_list))
print(f'Finished! \n Acc:{np.mean(result),max(result)}, \n list:{result}')