In [None]:
import numpy as np
import os
from Bio import SeqIO
def load_embeddings(npy_folder_path, max_length, embedding_type='proteinbert'):
    """Load per-residue embeddings from ``*.npy`` files and pad/truncate them to ``max_length`` rows.

    Args:
        npy_folder_path: directory with one ``<protein_id>.npy`` file per protein.
        max_length: number of residue rows every returned array has.
        embedding_type: currently unused; kept so call sites stay self-describing.

    Returns:
        dict mapping protein_id (file name without ``.npy``) to a 2-D array of shape
        ``(max_length, embedding_dim)``; rows past the sequence end are zeros.
    """
    features_dict = {}
    # sorted() makes the dict order independent of the filesystem's listing order.
    for filename in sorted(os.listdir(npy_folder_path)):
        if not filename.endswith('.npy'):
            continue
        protein_id = filename[:-len('.npy')]
        feature = np.squeeze(np.load(os.path.join(npy_folder_path, filename)))
        if feature.ndim == 1:
            # A single-residue embedding squeezes down to (dim,); restore the residue axis.
            feature = feature[np.newaxis, :]
        feature = feature[:max_length]
        pad_rows = max_length - feature.shape[0]
        if pad_rows > 0:
            # np.pad keeps the embedding's dtype (the old np.zeros padding upcast to float64).
            feature = np.pad(feature, ((0, pad_rows), (0, 0)))
        features_dict[protein_id] = feature
    return features_dict

def create_one_hot_features(fasta_file, max_length, amino_acids='ACDEFGHIKLMNPQRSTVWY'):
    """One-hot encode every FASTA record into a fixed ``(max_length, len(amino_acids))`` array.

    Residues beyond ``max_length`` are dropped. Positions past the sequence end are all-zero
    rows. Characters not in ``amino_acids`` are also all-zero rows; the match is
    case-sensitive, so lowercase letters are treated as unknown.

    Returns:
        dict mapping ``record.id`` to a float64 array.
    """
    aa_index = {aa: i for i, aa in enumerate(amino_acids)}
    one_hot_features_dict = {}
    for record in SeqIO.parse(fasta_file, "fasta"):
        # Filling a preallocated zero matrix also covers empty sequences: the old
        # np.array([]) path produced a 1-D array that np.pad could not pad as 2-D.
        encoded = np.zeros((max_length, len(amino_acids)))
        for pos, aa in enumerate(str(record.seq)[:max_length]):
            idx = aa_index.get(aa)
            if idx is not None:
                encoded[pos, idx] = 1.0
        one_hot_features_dict[record.id] = encoded
    return one_hot_features_dict
# Fixed per-sequence length (in residues) used to pad/truncate the embeddings and one-hot
# encodings; it matches the max_residues=160 default of the graph-building functions.
max_length = 160

def combine_features(one_hot_features, proteinbert_features, esm_features):
    """Concatenate ProteinBERT, ESM and one-hot features per residue (column-wise).

    Only sequence IDs present in all three dicts are kept. Iteration follows the order of
    ``one_hot_features``, i.e. FASTA order.

    Returns:
        dict mapping seq_id to ``[proteinbert | esm | one_hot]`` concatenated along axis 1.
    """
    return {
        seq_id: np.concatenate(
            [proteinbert_features[seq_id], esm_features[seq_id], one_hot],
            axis=1,
        )
        for seq_id, one_hot in one_hot_features.items()
        if seq_id in proteinbert_features and seq_id in esm_features
    }

In [None]:
EXAMPLE_ROOT = '/Workspace10/yumengzhang/xingxingpeng/example'


def build_split_features(fasta_name, proteinbert_dir, esm_dir):
    """Build fused (ProteinBERT + ESM + one-hot) per-residue features for one split/class.

    Args:
        fasta_name: FASTA file name under ``EXAMPLE_ROOT/example_fasta``.
        proteinbert_dir: folder name under ``EXAMPLE_ROOT/example_ProteinBERT``.
        esm_dir: folder name under ``EXAMPLE_ROOT/example_esm``.

    Returns:
        dict seq_id -> combined feature array (see combine_features).
    """
    one_hot = create_one_hot_features(f'{EXAMPLE_ROOT}/example_fasta/{fasta_name}', max_length)
    proteinbert = load_embeddings(f'{EXAMPLE_ROOT}/example_ProteinBERT/{proteinbert_dir}', max_length, 'proteinbert')
    esm = load_embeddings(f'{EXAMPLE_ROOT}/example_esm/{esm_dir}', max_length, 'esm')
    return combine_features(one_hot, proteinbert, esm)


# One call per (split, class); the result names are used by the graph-building cell.
decoy_amp_eval_combined_features = build_split_features('DECOY_eval.fasta', 'eval_non_AMP_proteinbert', 'example_eval_non_amp_esm')
amp_eval_combined_features = build_split_features('AMP_eval.fasta', 'eval_AMP_proteinbert', 'example_eval_AMP_esm')
amp_test_combined_features = build_split_features('AMP_test.fasta', 'test_AMP_proteinbert', 'example_test_AMP_esm')
decoy_amp_test_combined_features = build_split_features('DECOY_test.fasta', 'test_non_amp_proteinbert', 'example_test_non_amp_esm')
amp_train_combined_features = build_split_features('AMP_train.fasta', 'train_AMP_proteinbert', 'example_train_AMP_esm')
decoy_amp_train_combined_features = build_split_features('DECOY_train.fasta', 'train_non_amp_proteinbert', 'example_train_non_amp_esm')

In [None]:
from Bio.PDB import PDBParser
import torch
from torch_geometric.data import Data
import numpy as np
import os
import logging
# Show the graph builders' logging.warning/error messages with timestamps. basicConfig is a
# no-op if the root logger already has handlers (e.g. configured earlier in the kernel).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

def get_residue_positions(pdb_file, max_residues=160):
    """Collect CA-atom coordinates from the first model of ``pdb_file``.

    Only standard residues (blank hetero flag in ``residue.id[0]``) that have a CA atom are
    used. Chains are walked in file order, and collection stops as soon as
    ``max_residues`` positions have been gathered.

    Returns:
        list of CA coordinate arrays (``Atom.coord``), at most ``max_residues`` long.
    """
    first_model = PDBParser().get_structure('PDB', pdb_file)[0]

    positions = []
    for chain in first_model:
        for residue in chain:
            if residue.id[0] != ' ' or 'CA' not in residue:
                continue
            positions.append(residue['CA'].coord)
            if len(positions) >= max_residues:
                return positions
    return positions

def build_edges_with_attr(residue_positions, cutoff):
    """Build an undirected contact graph from residue coordinates.

    Every pair (i, j), i < j, whose Euclidean distance is strictly below ``cutoff`` yields
    two directed edges: (i, j) followed by (j, i). Both carry the distance as their
    attribute. Pairs are ordered lexicographically by (i, j).

    Args:
        residue_positions: sequence of coordinate vectors (e.g. from get_residue_positions).
        cutoff: distance threshold in the same units as the coordinates.

    Returns:
        edge_index: LongTensor of shape (2, num_edges).
        edge_attr: FloatTensor of shape (num_edges, 1) with the pair distances.
        Both tensors are empty (but correctly shaped) when no pair is within ``cutoff``.
        Previously this case raised because ``.max()`` was called on an empty tensor.
    """
    num_residues = len(residue_positions)
    if num_residues < 2:
        return torch.empty((2, 0), dtype=torch.long), torch.empty((0, 1), dtype=torch.float)

    coords = np.asarray(residue_positions)
    # All i < j pairs in the same lexicographic order as the old nested loop.
    src, dst = np.triu_indices(num_residues, k=1)
    dists = np.linalg.norm(coords[src] - coords[dst], axis=1)
    within = dists < cutoff
    src, dst, dists = src[within], dst[within], dists[within]

    # Interleave (i, j) and (j, i) for each kept pair.
    rows = np.column_stack((src, dst)).ravel()
    cols = np.column_stack((dst, src)).ravel()
    edge_index = torch.as_tensor(np.stack((rows, cols)), dtype=torch.long)
    edge_attr = torch.as_tensor(np.repeat(dists, 2).reshape(-1, 1), dtype=torch.float)
    return edge_index, edge_attr

def create_graph(feature_array, pdb_file, cutoff=10.0, is_amp=True, max_residues=160):
    """Build one labelled PyG ``Data`` graph for a peptide.

    Node features ``x`` are ``feature_array`` converted to float32. Edges and distance
    attributes come from the CA contacts in ``pdb_file`` (build_edges_with_attr). The
    graph label ``y`` is ``[1]`` for AMPs and ``[0]`` otherwise.

    NOTE(review): ``x`` has one row per (padded) feature position, while edges only index
    residues with a CA atom. Padded rows therefore become isolated nodes, and residues
    without CA shift the index alignment. Confirm this is intended.
    """
    positions = get_residue_positions(pdb_file, max_residues)
    edges, distances = build_edges_with_attr(positions, cutoff)
    node_features = torch.tensor(feature_array, dtype=torch.float)
    label = torch.tensor([int(bool(is_amp))], dtype=torch.long)
    return Data(x=node_features, edge_index=edges, edge_attr=distances, y=label)

def create_graphs_for_sequences(features_dict, pdb_folder, is_amp=True, max_residues=160):
    """Build a graph for every sequence that has a matching ``<seq_id>.pdb`` in ``pdb_folder``.

    Sequences without a PDB file are skipped with a warning. Sequences whose graph
    construction raises are skipped with an error log (best effort).

    Returns:
        dict seq_id -> Data, in ``features_dict`` order, containing only successful graphs.
    """
    graphs = {}
    for seq_id, features in features_dict.items():
        pdb_file = os.path.join(pdb_folder, f"{seq_id}.pdb")
        if not os.path.exists(pdb_file):
            logging.warning(f"No PDB file found for {seq_id}")
            continue
        try:
            graphs[seq_id] = create_graph(features, pdb_file, is_amp=is_amp, max_residues=max_residues)
        except Exception as e:
            logging.error(f"Error processing {seq_id} from {pdb_file}: {e}")
    return graphs

In [None]:
# Build one graph per sequence for each split: AMP graphs get y=1, decoy graphs y=0.
# NOTE(review): the AMP eval folder is spelled "expample_amp_eval_pdb", unlike every other
# folder here. Confirm it matches the directory on disk, or no eval AMP graphs are built.

# --- evaluation split ---
amp_graphs = create_graphs_for_sequences(
    amp_eval_combined_features,
    '/Workspace10/yumengzhang/xingxingpeng/example/example_PDB/expample_amp_eval_pdb',
    is_amp=True,
)
decoy_graphs = create_graphs_for_sequences(
    decoy_amp_eval_combined_features,
    '/Workspace10/yumengzhang/xingxingpeng/example/example_PDB/example_non_amp_eval_pdb',
    is_amp=False,
)

# --- test split ---
amp_test_graphs = create_graphs_for_sequences(
    amp_test_combined_features,
    '/Workspace10/yumengzhang/xingxingpeng/example/example_PDB/example_amp_test_pdb',
    is_amp=True,
)
decoy_test_graphs = create_graphs_for_sequences(
    decoy_amp_test_combined_features,
    '/Workspace10/yumengzhang/xingxingpeng/example/example_PDB/example_non_amp_test_pdb',
    is_amp=False,
)

# --- training split ---
amp_train_graphs = create_graphs_for_sequences(
    amp_train_combined_features,
    '/Workspace10/yumengzhang/xingxingpeng/example/example_PDB/example_amp_train_pdb',
    is_amp=True,
)
decoy_train_graphs = create_graphs_for_sequences(
    decoy_amp_train_combined_features,
    '/Workspace10/yumengzhang/xingxingpeng/example/example_PDB/example_non_amp_train_pdb',
    is_amp=False,
)

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os

class PeptideDataset(Dataset):
    """Per-peptide MaSIF-style surface features read from ``root_dir/<peptide_id>/``.

    Each peptide directory is expected to contain ``p1_input_feat.npy``,
    ``p1_rho_wrt_center.npy``, ``p1_theta_wrt_center.npy`` and ``p1_mask.npy``.
    The peptide IDs, and their order, come from the FASTA headers.

    Args:
        root_dir: directory containing one sub-directory per peptide ID.
        fasta_file: FASTA file whose header IDs define the dataset order.
        max_vertices: stored as ``self.max_vertices``. When None it is computed as the largest
            vertex count found on disk. Note that ``__getitem__`` always pads/truncates to
            ``PAD_VERTICES``, independent of this value.
        label: label returned with every sample (default 0, the previous hard-coded value).

    ``__getitem__`` returns ``(dict of float32 tensors, long label tensor)``. If any file is
    missing or unreadable it prints an error and returns ``None``; callers must handle that.
    """

    # Fixed vertex count every sample is padded/truncated to. MaSIF_site_PyTorch pools the
    # vertex axis with AvgPool1d(kernel_size=6, stride=5): 5109 -> 1021, and 1021 * 40 = 40840
    # is the hard-coded input width of its fc1 and of FusionModel.reduce_masif.
    PAD_VERTICES = 5109

    def __init__(self, root_dir, fasta_file, max_vertices=None, label=0):
        self.root_dir = root_dir
        self.label = label
        self.peptides = self.parse_fasta(fasta_file)
        self.max_vertices = max_vertices if max_vertices is not None else self.determine_max_vertices()

    def parse_fasta(self, fasta_file):
        """Return the FASTA header IDs (first whitespace-delimited token after '>') in file order."""
        peptides = []
        with open(fasta_file, 'r') as file:
            for line in file:
                if line.startswith('>'):
                    # Mirrors how Bio.SeqIO derives record.id, so IDs agree with the one-hot
                    # features. A header like "> id" previously yielded "" instead of "id".
                    tokens = line[1:].split()
                    peptides.append(tokens[0] if tokens else '')
        return peptides

    def determine_max_vertices(self):
        """Largest first-axis length among existing ``p1_input_feat.npy`` files (0 if none)."""
        max_vertices = 0
        for peptide_id in self.peptides:
            path = os.path.join(self.root_dir, peptide_id, 'p1_input_feat.npy')
            if os.path.exists(path):
                # mmap_mode='r' reads only the .npy header, so the shape is obtained without
                # loading the whole array into memory.
                max_vertices = max(max_vertices, np.load(path, mmap_mode='r').shape[0])
        return max_vertices

    @staticmethod
    def _pad_or_truncate(arr, length):
        """Zero-pad (same dtype) or truncate ``arr`` along axis 0 to exactly ``length`` rows."""
        current_length = arr.shape[0]
        if current_length < length:
            padding = np.zeros((length - current_length,) + arr.shape[1:], dtype=arr.dtype)
            return np.concatenate((arr, padding), axis=0)
        return arr[:length]

    def __len__(self):
        return len(self.peptides)

    def __getitem__(self, idx):
        peptide_id = self.peptides[idx]
        peptide_dir = os.path.join(self.root_dir, peptide_id)
        try:
            features = {
                'input_feat': np.load(os.path.join(peptide_dir, 'p1_input_feat.npy')),
                'rho_coords': np.load(os.path.join(peptide_dir, 'p1_rho_wrt_center.npy')),
                'theta_coords': np.load(os.path.join(peptide_dir, 'p1_theta_wrt_center.npy')),
                'mask': np.load(os.path.join(peptide_dir, 'p1_mask.npy')),
            }
            for key, arr in features.items():
                features[key] = np.nan_to_num(self._pad_or_truncate(arr, self.PAD_VERTICES))

            features_tensor = {key: torch.tensor(val, dtype=torch.float32) for key, val in features.items()}
            return features_tensor, torch.tensor(self.label, dtype=torch.long)
        except Exception as e:
            # Best effort: report and return None (callers that unpack the result must check).
            print(f"Error loading data for {peptide_id}: {e}")
            return None

In [None]:
# AMP (positive-class) surface-feature datasets, one per split. Each constructor scans every
# peptide's p1_input_feat.npy to compute max_vertices.
# NOTE: the "trian" spelling in model2_trian_amp_dataset is referenced by later cells.
root_dir = '/Workspace10/yumengzhang/xingxingpeng/example/example_surface_feature/amp_eval'
fasta_file = '/Workspace10/yumengzhang/xingxingpeng/example/example_fasta/AMP_eval.fasta'
model2_eval_amp_dataset = PeptideDataset(root_dir=root_dir, fasta_file=fasta_file)

root_dir = '/Workspace10/yumengzhang/xingxingpeng/example/example_surface_feature/amp_train'
fasta_file = '/Workspace10/yumengzhang/xingxingpeng/example/example_fasta/AMP_train.fasta'
model2_trian_amp_dataset = PeptideDataset(root_dir=root_dir, fasta_file=fasta_file)

root_dir = '/Workspace10/yumengzhang/xingxingpeng/example/example_surface_feature/amp_test'
fasta_file = '/Workspace10/yumengzhang/xingxingpeng/example/example_fasta/AMP_test.fasta'
model2_test_amp_dataset = PeptideDataset(root_dir=root_dir, fasta_file=fasta_file)

In [None]:
# Decoy (non-AMP) surface-feature datasets, one per split, mirroring the AMP cell above.
# NOTE: the "trian" spelling in model2_decoy_trian_amp_dataset is referenced by later cells.
root_dir = '/Workspace10/yumengzhang/xingxingpeng/example/example_surface_feature/non_amp_eval'
fasta_file = '/Workspace10/yumengzhang/xingxingpeng/example/example_fasta/DECOY_eval.fasta'
model2_decoy_eval_amp_dataset = PeptideDataset(root_dir=root_dir, fasta_file=fasta_file)

root_dir = '/Workspace10/yumengzhang/xingxingpeng/example/example_surface_feature/non_amp_test'
fasta_file = '/Workspace10/yumengzhang/xingxingpeng/example/example_fasta/DECOY_test.fasta'
model2_decoy_test_amp_dataset = PeptideDataset(root_dir=root_dir, fasta_file=fasta_file)

root_dir = '/Workspace10/yumengzhang/xingxingpeng/example/example_surface_feature/non_amp_train'
fasta_file = '/Workspace10/yumengzhang/xingxingpeng/example/example_fasta/DECOY_train.fasta'
model2_decoy_trian_amp_dataset = PeptideDataset(root_dir=root_dir, fasta_file=fasta_file)

In [None]:
# Graph branch: AMP graphs first, then decoy graphs (labels come from each Data.y).
GCN_train = [*amp_train_graphs.values(), *decoy_train_graphs.values()]
GCN_eval = [*amp_graphs.values(), *decoy_graphs.values()]
GCN_test = [*amp_test_graphs.values(), *decoy_test_graphs.values()]

# Surface branch in the same AMP-then-decoy order; `+` on torch Datasets builds a ConcatDataset.
# NOTE(review): the two branches are later paired purely by position. This only lines up if
# no sequence was dropped from the graph dicts (missing embeddings/PDB files) -- verify.
fingerprint_trian = model2_trian_amp_dataset + model2_decoy_trian_amp_dataset
fingerprint_eval = model2_eval_amp_dataset + model2_decoy_eval_amp_dataset
fingerprint_test = model2_test_amp_dataset + model2_decoy_test_amp_dataset

In [None]:
from torch.utils.data import Dataset

class CombinedDataset(Dataset):
    """Pair a graph dataset with a surface-feature dataset, index by index.

    ``__getitem__`` returns ``(graph_data, feature_data, label)``. Here ``feature_data, label``
    is the tuple returned by ``feature_dataset[idx]``.

    NOTE(review): pairing is purely positional; both datasets must list the same peptides
    in the same order.
    """

    def __init__(self, graph_dataset, feature_dataset):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(graph_dataset) != len(feature_dataset):
            raise ValueError("Datasets must be of the same size")
        self.graph_dataset = graph_dataset
        self.feature_dataset = feature_dataset

    def __len__(self):
        return len(self.graph_dataset)

    def __getitem__(self, idx):
        graph_data = self.graph_dataset[idx]
        feature_item = self.feature_dataset[idx]
        if feature_item is None:
            # The feature dataset signals load failures with None; unpacking it would
            # otherwise fail with an unhelpful "cannot unpack NoneType" TypeError.
            raise RuntimeError(f"Surface features for index {idx} could not be loaded (feature dataset returned None)")
        feature_data, label = feature_item
        return graph_data, feature_data, label
    
# Pair graph samples with surface samples position-by-position for each split.
train_combined_dataset = CombinedDataset(graph_dataset=GCN_train, feature_dataset=fingerprint_trian)
test_combined_dataset = CombinedDataset(graph_dataset=GCN_test, feature_dataset=fingerprint_test)
eval_combined_dataset = CombinedDataset(graph_dataset=GCN_eval, feature_dataset=fingerprint_eval)
from torch_geometric.data import Batch


def custom_collate_fn(batch):
    """Collate ``(graph, masif_dict, label)`` items into ``(PyG Batch, dict of stacked tensors)``.

    The per-item label (third element) is not used; labels are read from the graphs' ``y``.
    Every masif dict must have the same keys and per-key tensor shapes as the first one.
    """
    graphs, surfaces = [], []
    for item in batch:
        graphs.append(item[0])
        surfaces.append(item[1])

    graph_batch = Batch.from_data_list(graphs)

    masif_batch = {}
    for key in surfaces[0]:
        masif_batch[key] = torch.stack([surface[key] for surface in surfaces])
    return graph_batch, masif_batch


# Train on train + eval (`+` builds a ConcatDataset); evaluate on the held-out test split.
final_train_dataset = train_combined_dataset + eval_combined_dataset
train_combined_loader = DataLoader(final_train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
# Evaluation metrics do not depend on order; shuffle=False keeps test passes deterministic.
test_combined_loader = DataLoader(test_combined_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

In [None]:
from torch_geometric.data import Batch


def custom_collate_fn(batch):
    """Turn ``(graph, masif_dict, label)`` items into a PyG ``Batch`` plus a dict of stacked tensors.

    NOTE(review): this redefines the identical collate function from the previous cell.
    The ignored third element is the per-item label; training reads labels from ``Batch.y``.
    """
    data_gcn_batch = Batch.from_data_list([entry[0] for entry in batch])
    masif_dicts = [entry[1] for entry in batch]
    data_masif_batch = {
        name: torch.stack([sample[name] for sample in masif_dicts])
        for name in masif_dicts[0].keys()
    }
    return data_gcn_batch, data_masif_batch


# NOTE(review): duplicates the loader setup of the previous cell; these are the loaders the
# training loop actually uses. Test evaluation does not need shuffling.
final_train_dataset = train_combined_dataset + eval_combined_dataset
train_combined_loader = DataLoader(final_train_dataset, batch_size=32, shuffle=True, collate_fn=custom_collate_fn)
test_combined_loader = DataLoader(test_combined_dataset, batch_size=32, shuffle=False, collate_fn=custom_collate_fn)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool


class ImprovedGCN(torch.nn.Module):
    """Six stacked GCNConv layers, one multi-head GATConv layer, and global mean pooling.

    Args:
        num_features: node-feature width of ``data.x``.
        num_classes: width of the classification head ``fc``.
        heads: GAT attention heads. The per-head width is ``16 // heads`` and heads are
            concatenated, so the graph embedding is ``(16 // heads) * heads`` wide (16 for
            the default ``heads=4``).
        dropout: dropout probability applied after conv1..conv5.

    ``forward(data, return_features=True)`` returns the pooled graph embedding
    ``(num_graphs, (16 // heads) * heads)``. Otherwise it returns log-probabilities
    ``(num_graphs, num_classes)``.
    """

    def __init__(self, num_features, num_classes, heads=4, dropout=0.5):
        super(ImprovedGCN, self).__init__()
        self.conv1 = GCNConv(num_features, 1024)
        self.conv2 = GCNConv(1024, 512)
        self.conv3 = GCNConv(512, 256)
        self.conv4 = GCNConv(256, 128)
        self.conv5 = GCNConv(128, 64)
        self.conv6 = GCNConv(64, 32)

        # With concat=True the GAT output is (16 // heads) * heads wide. That equals 16 only
        # when heads divides 16; sizing fc from it avoids a shape mismatch for e.g. heads=3.
        gat_out_dim = (16 // heads) * heads
        self.attn1 = GATConv(32, 16 // heads, heads=heads, concat=True)
        self.fc = nn.Linear(gat_out_dim, num_classes)
        self.dropout = dropout

    def forward(self, data, return_features=False):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        # GCNConv documents edge_weight as a 1-D [num_edges] tensor; edge_attr is stored as
        # [num_edges, 1] (the pairwise CA distance), so flatten it.
        edge_weight = edge_attr.view(-1) if edge_attr is not None else None
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = F.relu(conv(x, edge_index, edge_weight=edge_weight))
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.conv6(x, edge_index, edge_weight=edge_weight))
        x = F.elu(self.attn1(x, edge_index, edge_attr=edge_attr))
        x = global_mean_pool(x, batch)  # one embedding per graph in the batch
        if return_features:
            return x
        return F.log_softmax(self.fc(x), dim=1)

class MaSIF_site_PyTorch(nn.Module):
    """Simplified MaSIF-style surface feature extractor with learnable polar Gaussians.

    For each of ``n_rotations`` angular offsets, Gaussians over (rho, theta) weight the
    vertex features. The per-rotation descriptors are concatenated, average-pooled along
    the vertex axis, and flattened.

    Shape assumptions (inferred from the indexing in ``forward``; confirm against the data):
        input_feat:                     (batch, n_vertices, num_points, n_feat)
        rho_coords, theta_coords, mask: (batch, n_vertices, num_points)
    With n_vertices == 5109, n_feat == 5 and n_rotations == 8, the flattened output has
    40 * 1021 == 40840 features. That is the fixed input width of ``fc1``.

    ``n_thetas``, ``n_rhos`` and ``dropout_rate`` are accepted but not used by the computation.
    """

    def __init__(self, n_thetas, n_rhos, n_feat, n_rotations, dropout_rate=0.5):
        super(MaSIF_site_PyTorch, self).__init__()
        self.n_thetas = n_thetas
        self.n_rhos = n_rhos
        self.n_feat = n_feat
        self.n_rotations = n_rotations

        # Learnable Gaussian centres/widths, one per rotation.
        self.mu_rho = nn.Parameter(torch.Tensor(self.n_rotations, 1))
        self.sigma_rho = nn.Parameter(torch.Tensor(self.n_rotations, 1))
        self.mu_theta = nn.Parameter(torch.Tensor(self.n_rotations, 1))
        self.sigma_theta = nn.Parameter(torch.Tensor(self.n_rotations, 1))

        nn.init.uniform_(self.mu_rho, 0, 1)
        nn.init.constant_(self.sigma_rho, 0.5)
        nn.init.uniform_(self.mu_theta, 0, 2 * np.pi)
        nn.init.constant_(self.sigma_theta, 0.5)

        # Pools the vertex axis: L -> floor((L - 6) / 5) + 1, i.e. 5109 -> 1021.
        self.avgpool1d = nn.AvgPool1d(kernel_size=6, stride=5)
        self.fc1 = nn.Linear(40840, 2)

    def forward(self, input_feat, rho_coords, theta_coords, mask, return_features=False):
        """Return flattened surface features if ``return_features``, else log-probabilities (batch, 2)."""
        batch_size, n_vertices, num_points, n_feat = input_feat.size()
        input_feat = input_feat.mean(dim=2)  # average over points -> (batch, n_vertices, n_feat)

        output_feats = []
        for k in range(self.n_rotations):
            rotated_theta_coords = theta_coords + k * 2 * np.pi / self.n_rotations
            rotated_theta_coords %= 2 * np.pi

            rho_gauss = torch.exp(-torch.square(rho_coords - self.mu_rho[k]) / (2 * torch.square(self.sigma_rho[k]) + 1e-5))
            theta_gauss = torch.exp(-torch.square(rotated_theta_coords - self.mu_theta[k]) / (2 * torch.square(self.sigma_theta[k]) + 1e-5))

            gauss_activations = rho_gauss * theta_gauss * mask
            # NOTE(review): normalises over dim=1 (vertices); the reference MaSIF implementation
            # normalises over the patch points (dim=2 here). Confirm which is intended.
            gauss_activations /= torch.sum(gauss_activations, dim=1, keepdim=True) + 1e-5

            gauss_activations = gauss_activations.unsqueeze(3)
            gauss_activations = gauss_activations.expand(-1, -1, -1, n_feat)

            gauss_desc = torch.sum(gauss_activations * input_feat.unsqueeze(2), dim=2)
            output_feats.append(gauss_desc)

        output_feats = torch.cat(output_feats, dim=2)  # (batch, n_vertices, n_feat * n_rotations)
        output_feats = output_feats.permute(0, 2, 1)  # (batch, 40, n_vertices)
        output_feats = self.avgpool1d(output_feats)  # (batch, 40, 1021) for n_vertices == 5109
        output_feats = output_feats.permute(0, 2, 1)  # (batch, 1021, 40)
        output_feats = output_feats.reshape(batch_size, -1)  # (batch, 40840)

        if return_features:
            return output_feats
        # Previously this path fell through and returned None, leaving fc1 unused.
        return F.log_softmax(self.fc1(output_feats), dim=1)


# Use the same device as the fusion/training cells (cuda:0). The previous 'cuda:3' failed with
# "invalid device ordinal" on machines with fewer than four GPUs.
# num_features=3604 is the combined feature width; its last 20 columns are the one-hot encoding.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_gcn = ImprovedGCN(num_features=3604, num_classes=2).to(device)
model_masif = MaSIF_site_PyTorch(n_thetas=16, n_rhos=5, n_feat=5, n_rotations=8).to(device)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool

class FusionModel(nn.Module):
    """Late fusion of a graph embedding and a projected MaSIF surface embedding.

    Args:
        model_gcn: called as ``model_gcn(data_gcn, return_features=True)``; expected to
            return a ``(batch, gcn_dim)`` tensor.
        model_masif: called as ``model_masif(input_feat, rho, theta, mask,
            return_features=True)``; expected to return ``(batch, masif_in_features)``.
        output_features: input width of the final layer, ``gcn_dim + masif_reduced_features``.
        num_classes: number of output classes (previously ignored; the head was fixed at 2).
        masif_in_features: flattened MaSIF feature width (default 40840, as before).
        masif_reduced_features: width the MaSIF features are projected to (default 16, as before).
    """

    def __init__(self, model_gcn, model_masif, output_features, num_classes,
                 masif_in_features=40840, masif_reduced_features=16):
        super(FusionModel, self).__init__()
        self.model_gcn = model_gcn
        self.model_masif = model_masif
        self.reduce_masif = nn.Linear(masif_in_features, masif_reduced_features)
        self.fusion_layer = nn.Linear(output_features, num_classes)

    def forward(self, data_gcn, data_masif):
        """Return log-probabilities of shape ``(batch, num_classes)``.

        ``data_masif`` must provide 'input_feat', 'rho_coords', 'theta_coords' and 'mask'.
        """
        gcn_features = self.model_gcn(data_gcn, return_features=True)
        masif_features = self.model_masif(
            data_masif['input_feat'],
            data_masif['rho_coords'],
            data_masif['theta_coords'],
            data_masif['mask'],
            return_features=True,
        )
        masif_features = F.relu(self.reduce_masif(masif_features))
        combined_features = F.relu(torch.cat((gcn_features, masif_features), dim=1))
        return F.log_softmax(self.fusion_layer(combined_features), dim=1)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model_gcn = ImprovedGCN(num_features=3604, num_classes=2).to(device)
model_masif = MaSIF_site_PyTorch(n_thetas=16, n_rhos=5, n_feat=5, n_rotations=8).to(device)

# Embedding widths as seen by FusionModel:
gcn_output_features = 16  # ImprovedGCN(return_features=True): GAT output, (16 // 4) * 4 heads
masif_output_features = 5 * 8  # per-vertex MaSIF descriptor width: n_feat * n_rotations
masif_reduced_features = 16  # width of FusionModel.reduce_masif's output
total_output_features = gcn_output_features + masif_reduced_features  # 32 = fusion_layer input

fusion_model = FusionModel(model_gcn, model_masif, total_output_features, num_classes=2).to(device)

In [None]:
from sklearn.metrics import roc_auc_score, confusion_matrix, matthews_corrcoef
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# fusion_model already returns log-probabilities. CrossEntropyLoss re-applies log_softmax,
# which is idempotent, so this is equivalent to NLLLoss on those outputs.
criterion = nn.CrossEntropyLoss()

# AdamW with two LR schedulers sharing the same optimizer.
optimizer = AdamW(fusion_model.parameters(), lr=0.0001, weight_decay=1e-2)
scheduler_cosine = CosineAnnealingLR(optimizer, T_max=10)
scheduler_plateau = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

# Checkpointing state for the training loop.
best_model_path = '1114_GCN_finger1.pth'
best_accuracy = 0.0
def evaluate_model(model, data_loader, criterion, device):
    """Evaluate a binary classifier and return its summary metrics.

    Assumes labels are 0/1 and that column 1 of the model output is the positive class.

    Returns:
        (avg_loss, accuracy_percent, auc, sensitivity, specificity, mcc).
        ``auc`` is NaN when only one class is present in the labels. Sensitivity and
        specificity are NaN when their denominator is zero.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for data_gcn, data_masif in data_loader:
            data_gcn = data_gcn.to(device)
            masif_inputs = {
                key: data_masif[key].to(device)
                for key in ('input_feat', 'rho_coords', 'theta_coords', 'mask')
            }
            labels = data_gcn.y.to(device)
            outputs = model(data_gcn, masif_inputs)
            total_loss += criterion(outputs, labels).item()

            predicted = outputs.argmax(dim=1)
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)

            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())

    if total_samples == 0:
        raise ValueError("evaluate_model received an empty data loader")

    avg_loss = total_loss / len(data_loader)
    accuracy = 100 * total_correct / total_samples
    labels_arr = np.asarray(all_labels)
    preds_arr = (np.asarray(all_probs) > 0.5).astype(int)

    # roc_auc_score raises when only one class is present.
    auc_score = roc_auc_score(labels_arr, all_probs) if len(np.unique(labels_arr)) > 1 else float('nan')
    # labels=[0, 1] forces a 2x2 matrix even if a class is absent, so the 4-way unpack is safe.
    tn, fp, fn, tp = confusion_matrix(labels_arr, preds_arr, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else float('nan')
    specificity = tn / (tn + fp) if (tn + fp) else float('nan')
    mcc = matthews_corrcoef(labels_arr, preds_arr)

    return avg_loss, accuracy, auc_score, sensitivity, specificity, mcc
# NOTE(review): the best checkpoint is chosen by *test* accuracy, and ReduceLROnPlateau is
# driven by *test* loss. The reported test metrics are therefore not an unbiased estimate.
# The eval split is merged into training upstream; holding it out for model selection would fix this.
# NOTE(review): CosineAnnealingLR and ReduceLROnPlateau both adjust the same optimizer's lr.
epochs = 20
for epoch in range(epochs):
    fusion_model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0

    for data_gcn, data_masif in train_combined_loader:
        data_gcn = data_gcn.to(device)
        masif_inputs = {
            name: data_masif[name].to(device)
            for name in ('input_feat', 'rho_coords', 'theta_coords', 'mask')
        }
        labels = data_gcn.y.to(device)

        optimizer.zero_grad()
        outputs = fusion_model(data_gcn, masif_inputs)
        loss = criterion(outputs, labels)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()

        # argmax returns the first maximal index, like torch.max(..., 1)[1].
        predicted = outputs.detach().argmax(dim=1)
        total_correct += (predicted == labels).sum().item()
        total_samples += labels.size(0)

    avg_loss = total_loss / len(train_combined_loader)
    accuracy = 100 * total_correct / total_samples
    print(f'Epoch {epoch+1}/{epochs}, Train Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.2f}%')

    test_metrics = evaluate_model(fusion_model, test_combined_loader, criterion, device)
    test_loss, test_accuracy, test_auc, test_sens, test_spec, test_mcc = test_metrics
    print(f'Epoch {epoch+1}/{epochs}: Test Loss: {test_loss:.4f}, Accuracy: {test_accuracy:.2f}%, AUC: {test_auc:.4f}, SENS: {test_sens:.4f}, SPEC: {test_spec:.4f}, MCC: {test_mcc:.4f}')

    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        torch.save(fusion_model.state_dict(), best_model_path)
        print(f'Saved new best model with accuracy: {best_accuracy:.2f}%')

    scheduler_cosine.step()
    scheduler_plateau.step(test_loss)