In [1]:
import Bio.PDB as PDB
import pathlib
import os
import urllib.request
import numpy as np
import biographs as bg
from Bio import SeqIO
import torch
import networkx as nx
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import scipy
from tqdm import tqdm
from torch_geometric.data import Dataset, download_url, Data
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
import freesasa

In [2]:
# One-letter amino-acid codes; the list order defines the one-hot feature columns.
pro_res_table = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'W', 'Y']

# Three-letter residue name -> one-letter code for the 20 standard amino acids.
# Residues whose name is not a key here are skipped when sequences are built.
ressymbl = {'ALA': 'A', 'CYS': 'C', 'ASP': 'D', 'GLU':'E', 'PHE': 'F', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LYS': 'K', 'LEU': 'L', 'MET': 'M', 'ASN': 'N', 'PRO': 'P', 'GLN':'Q', 'ARG':'R', 'SER': 'S','THR': 'T', 'VAL': 'V', 'TRP':'W', 'TYR': 'Y'}
# Site type -> label column index (the label matrix has len(sites_dict) columns).
sites_dict = {'INORGANIC_BINDING': 0,
 'METAL_BINDING': 1,
 'COMPOUND_BINDING': 2,
 'UNCLASSIFIED': 3,
 'MULTIPLE_LIGAND_BINDING': 4,
 'COMPOUND_DRUG_BINDING': 5,
 'ACTIVE_COMPLEXED': 6,
 'ACTIVE': 7,
 'PROTEIN_PROTEIN_INTERACTION': 8,
 'PROTEIN_DNA_INTERACTION': 9,
 'PROTEIN_RNA_INTERACTION': 10}
# Per-residue physico-chemical descriptors, 8 values per residue in this order:
#   0 polarity (1 = polar, 0 = nonpolar)
#   1 van der Waals size (labelled "radius" originally; the 48-163 range looks like
#     volumes in A^3 -- TODO confirm)
#   2 isoelectric point (pI)
#   3 hydrophobicity (values appear to match the Kyte-Doolittle scale)
#   4 side-chain charge (-1 / 0 / +1; note His is listed as +1)
#   5 side-chain pKa (0 where none is listed)
#   6 normalized frequency of turn (Crawford et al., 1973)
#   7 normalized frequency of alpha-helix (Burgess et al., 1974)
aa_properties = {
    'G': [0, 48, 6.06, -0.4, 0, 0, 1.38, 0.12],     # Glycine
    'A': [0, 67, 6.01, 1.8, 0, 0, 0.6, 0.486],      # Alanine
    'V': [0, 105, 6.00, 4.2, 0, 0, 0.48, 0.379],     # Valine
    'I': [0, 124, 6.05, 4.5, 0, 0, 0.67, 0.37],     # Isoleucine
    'L': [0, 124, 6.01, 3.8, 0, 0, 0.7, 0.42],     # Leucine
    'P': [0, 90, 6.30, -1.6, 0, 0, 1.47, 0.208],     # Proline
    'S': [1, 73, 5.68, -0.8, 0, 0, 1.26, 0.2],     # Serine
    'T': [1, 93, 5.60, -0.7, 0, 0, 1.05, 0.272],     # Threonine
    'C': [1, 86, 5.05, 2.5, 0, 8.33, 1.29, 0.2],   # Cysteine
    'M': [0, 124, 5.74, 1.9, 0, 0, 0.67, 0.417],     # Methionine
    'D': [1, 91, 2.85, -3.5, -1, 3.65, 1.24, 0.288], # Aspartic acid
    'N': [1, 96, 5.41, -3.5, 0, 0, 1.42, 0.193],     # Asparagine
    'E': [1, 109, 3.15, -3.5, -1, 4.25, 0.64, 0.538],# Glutamic acid
    'Q': [1, 114, 5.65, -3.5, 0, 0, 0.92, 0.418],    # Glutamine
    'K': [1, 135, 9.60, -3.9, 1, 10.28, 1.1, 0.402],# Lysine
    'R': [1, 148, 10.76, -4.5, 1, 12.48, 0.79, 0.262],# Arginine
    'H': [1, 118, 7.60, -3.2, 1, 6.0, 0.95, 0.4], # Histidine
    'F': [0, 135, 5.49, 2.8, 0, 0, 1.05, 0.288],     # Phenylalanine
    'Y': [1, 141, 5.64, -1.3, 0, 10.1, 1.35, 0.161], # Tyrosine
    'W': [0, 163, 5.89, -0.9, 0, 0, 1.23, 0.462]     # Tryptophan
}


In [3]:
def get_sequence_and_ca_coordinates(structure):
    """Extract sequence, residue index, coordinates and per-residue SASA from model 0.

    Residues whose name is not in ``ressymbl`` and hetero residues (``H_*``) are
    skipped. Dictionary keys are ``(residue_number + insertion_code, chain_id)``,
    e.g. ``('52', 'A')`` or ``('52A', 'A')``.

    Returns:
        sequence: one-letter sequence of the kept residues (all chains, file order).
        sequence_dict: key -> (one-letter code, index of the residue in ``sequence``).
        coords: array with one row per kept residue -- the CA atom, or the backbone
            C atom if CA is missing, or N if both are missing.
        sasa_dict: key -> total SASA of the residue as reported by freesasa.

    Raises:
        KeyError: if a kept residue has none of CA/C/N, or freesasa has no area for it.
    """
    try:
        result, _ = freesasa.calcBioPDB(structure)
        residue_areas = result.residueAreas()
    except Exception:
        # freesasa can fail on unknown residues, nucleotides or hetero atoms:
        # retry on a filtered copy of the structure.
        result, _ = freesasa.calcBioPDB(filter_unknown_residues(structure))
        residue_areas = result.residueAreas()

    sequence = []
    sequence_dict = {}
    sasa_dict = {}
    coords = []
    for chain in structure[0]:
        for residue in chain:
            hetflag, resseq, icode = residue.get_id()
            if hetflag.startswith('H_'):
                continue
            resname = residue.get_resname()
            if resname not in ressymbl:
                continue
            symbol = ressymbl[resname]
            # Residue number plus insertion code when present ('52' or '52A').
            res_key = str(resseq) + icode.strip()
            key = (res_key, chain.id)
            sequence_dict[key] = (symbol, len(sequence))
            sasa_dict[key] = residue_areas[chain.id][res_key].total
            sequence.append(symbol)
            # Representative backbone atom: CA, else backbone C, else N.
            for atom_name in ('CA', 'C', 'N'):
                if atom_name in residue:
                    coords.append(residue[atom_name].coord)
                    break
            else:
                raise KeyError(f"residue {residue.get_id()} in chain {chain.id} has no CA, C or N atom")

    return ''.join(sequence), sequence_dict, np.array(coords), sasa_dict

In [4]:
def get_phis_chem_properties(seq):
    """Return a ``(len(seq), n_props)`` float array of ``aa_properties`` rows.

    Each one-letter code in ``seq`` contributes its ``aa_properties`` entry; codes
    that are not in the table get an all-zero row.
    """
    n_props = len(aa_properties['A'])
    zero_row = [0.0] * n_props
    rows = [aa_properties[symbol] if symbol in aa_properties else zero_row for symbol in seq]
    return np.array(rows, dtype=float).reshape(len(seq), n_props)

In [5]:
def one_hot_symbftrs(sequence):
    """One-hot encode ``sequence`` over ``pro_res_table``.

    Returns a float array of shape ``(len(sequence), len(pro_res_table))``.
    Raises ValueError for a symbol that is not in ``pro_res_table``.
    """
    encoding = np.zeros((len(sequence), len(pro_res_table)))
    for position, symbol in enumerate(sequence):
        encoding[position, pro_res_table.index(symbol)] = 1
    return encoding

In [6]:
def get_node_labels(seq_dict, file, labels_df=None):
    """Build multi-hot site labels of shape ``(len(seq_dict), len(sites_dict))`` (dtype long).

    Rows of ``labels_df`` (default: the global ``train_pdb_df``) whose ``PDBID`` equals
    the stem of ``file`` are applied. When ``POS_CHAINS`` starts with "YES", residues
    are matched on ``(position, chain)``; otherwise on the position alone, ignoring the
    chain. A label is set only when the annotated one-letter code matches the structure.

    Args:
        seq_dict: ``(position, chain) -> (one-letter code, node index)``.
        file: path of the structure file; its stem selects the annotation rows.
        labels_df: optional annotation table with columns ``PDBID``, ``SITE_TYPE``,
            ``AminoAcidsWithChains`` and ``POS_CHAINS``.
    """
    if labels_df is None:
        labels_df = train_pdb_df
    y = torch.zeros((len(seq_dict), len(sites_dict)), dtype=torch.long)
    pdb_id = os.path.splitext(os.path.basename(file))[0]
    rows = labels_df[labels_df['PDBID'] == pdb_id][['SITE_TYPE', 'AminoAcidsWithChains', 'POS_CHAINS']]
    # Chain-agnostic lookup for annotations without chain ids. If several chains share
    # a residue number, the chain iterated last wins.
    seq_by_position = {key[0]: value for key, value in seq_dict.items()}
    for _, row in rows.iterrows():
        # NOTE(review): eval() executes arbitrary code from the CSV; if these columns only
        # hold Python literals, ast.literal_eval is the safe equivalent.
        pos_chain = eval(row['POS_CHAINS'])[0]
        if pos_chain == "YES":
            for amino_acid, position, chain in eval(row['AminoAcidsWithChains']):
                entry = seq_dict.get((position, chain))
                if entry is not None and entry[0] == amino_acid:
                    y[entry[1], sites_dict[row['SITE_TYPE']]] = 1
        else:
            for amino_acid, position in eval(row['AminoAcidsWithChains']):
                # Membership is checked before indexing (a missing position is not an error).
                entry = seq_by_position.get(str(position))
                if entry is not None and entry[0] == amino_acid:
                    y[entry[1], sites_dict[row['SITE_TYPE']]] = 1
    return y

In [145]:
def filter_unknown_residues(structure):
    """Return a deep copy of ``structure`` without UNK, nucleotide and hetero residues.

    The input structure is left untouched: adding its residues to new chains would
    re-parent them, so the copy is made first and unwanted residues are detached from it.
    A residue is dropped when its stripped name is ``UNK`` or a nucleotide code
    (``A C G T U I N`` or DNA-style ``DA DC DG DT DU DI``), or when its hetero flag
    is not blank.
    """
    import copy

    skipped_names = {'A', 'C', 'G', 'T', 'U', 'I', 'N', 'DA', 'DC', 'DG', 'DT', 'DU', 'DI', 'UNK'}
    filtered_structure = copy.deepcopy(structure)
    filtered_structure.id = "filtered"
    for model in filtered_structure:
        for chain in model:
            # Iterate over a snapshot because residues are detached while looping.
            for residue in list(chain):
                if residue.get_resname().strip() in skipped_names or residue.id[0] != ' ':
                    chain.detach_child(residue.id)
    return filtered_structure

In [3]:
# Site annotations of the training split; the label builders read from this table.
train_pdb_df = pd.read_csv(filepath_or_buffer="train_df_full3_with_comb_classes_wo_val.csv")

In [5]:
# Site annotations of the test split.
test_pdb_df = pd.read_csv(filepath_or_buffer='test_df_full3_with_comb_classes.csv')

In [5]:
# Site annotations of the validation split.
val_pdb_df = pd.read_csv(filepath_or_buffer='val_df.csv')

# Датасет

In [25]:
class Train_Graph_Dataset(Dataset):
    """Residue-level protein graphs built from the PDB files in ``<root>/raw_files``.

    Every ``.pdb`` file becomes one ``Data`` object with

    * ``x`` -- per-residue features: one-hot residue type (``pro_res_table`` order),
      sin/cos of phi and psi, the ``aa_properties`` row and the residue SASA;
    * ``edge_index`` -- ``(2, num_edges)`` residue pairs whose representative backbone
      atoms are closer than the cutoff of ``_get_adjacency``;
    * ``y`` -- multi-hot site labels of shape ``(num_residues, len(sites_dict))``.

    Graphs are cached as ``<root>/processed_type_4/<stem>.pt`` and loaded from there
    on later runs.
    """

    def __init__(self, root, specific_files=None, transform=None, pre_transform=None, labels_df=None):
        """
        Args:
            root: dataset directory; raw PDB files are read from ``<root>/raw_files``.
            specific_files: optional collection of raw file paths or base names to keep.
            transform: optional callable applied by ``__getitem__`` to single graphs.
            pre_transform: optional callable applied to each graph before it is cached.
            labels_df: annotation table used by ``_get_node_labels``; defaults to the
                global ``train_pdb_df``.
        """
        self.root = root
        self.specific_files = specific_files
        self.labels_df = labels_df
        self.custom_processed_dir = os.path.join(self.root, 'processed_type_4')
        os.makedirs(self.custom_processed_dir, exist_ok=True)
        # Set before super().__init__: the PyG base constructor may already call
        # self.process(), and that state must not be discarded afterwards.
        self.data_prot = []
        self._has_processed = False
        super(Train_Graph_Dataset, self).__init__(root, transform=transform, pre_transform=pre_transform)
        self.data = self.processed_paths
        self._initialize_data()

    def _processed_path(self, raw_file):
        """Cache location ``<custom_processed_dir>/<stem>.pt`` of a raw file."""
        return os.path.join(self.custom_processed_dir, os.path.splitext(os.path.basename(raw_file))[0] + '.pt')

    def _initialize_data(self):
        """Fill ``self.data_prot`` from the cached ``.pt`` files, processing raw files first if needed.

        ``process()`` runs at most once per construction. If it has not run yet and some
        cache files are missing, it is called here. Afterwards every existing cache file
        is loaded, so ``data_prot`` mirrors the cache (files that failed are absent).
        """
        processed_files = [self._processed_path(file) for file in self.raw_paths]

        if all(os.path.exists(f) for f in processed_files):
            print("Все обработанные файлы уже существуют. Загружаем их.")
        elif not getattr(self, '_has_processed', False):
            print("Некоторые обработанные файлы отсутствуют. Запускаем обработку.")
            self.process()

        self.data_prot = []
        for file_path in processed_files:
            if not os.path.exists(file_path):
                continue  # raw file that failed to process
            try:
                # weights_only=False unpickles arbitrary objects -- only load caches this class wrote.
                data = torch.load(file_path, map_location=torch.device('cpu'), weights_only=False)
                self.data_prot.append(data)
            except Exception as e:
                print(f"Ошибка при загрузке {file_path}: {str(e)}")

    @property
    def raw_file_names(self):
        """Sorted paths of the raw ``.pdb`` files in ``<root>/raw_files``.

        Only ``.pdb`` files are listed because ``process()`` ignores every other suffix;
        listing them would leave their cache files permanently "missing". If
        ``specific_files`` is given, a file is kept when its path (normalized) or its
        base name appears there.
        """
        all_files = sorted(entry.path for entry in os.scandir(self.root + "/raw_files")
                           if entry.is_file() and pathlib.Path(entry.path).suffix == ".pdb")
        if not self.specific_files:
            return all_files
        wanted = set(self.specific_files)
        wanted_paths = {os.path.normcase(os.path.normpath(p)) for p in self.specific_files}
        return [f for f in all_files
                if f in wanted or os.path.basename(f) in wanted
                or os.path.normcase(os.path.normpath(f)) in wanted_paths]

    @property
    def processed_file_names(self):
        """Cache file names: each raw file's stem with a ``.pt`` extension."""
        return [os.path.splitext(os.path.basename(file))[0] + '.pt' for file in self.raw_paths]

    def download(self):
        """No-op: raw PDB files must already be present in ``<root>/raw_files``."""

    def process(self):
        """Build and cache a graph for every raw ``.pdb`` file.

        Does nothing if every expected cache file already exists. Otherwise all raw files
        are processed; per-file failures are logged and collected rather than aborting the
        run. Sets ``self.data_prot`` to the successfully built graphs.
        """
        if all(os.path.exists(os.path.join(self.custom_processed_dir, f)) for f in self.processed_file_names):
            print("Все обработанные файлы уже существуют. Пропускаем повторную обработку.")
            return
        # Tells _initialize_data that processing already happened in this construction.
        self._has_processed = True
        self.data_prot = []
        pre_transform = getattr(self, 'pre_transform', None)

        data_list = []
        files_with_errors = []
        count1 = 0  # files processed successfully
        count2 = 0  # files that raised
        count3 = 0  # adjacency size == number of node-feature rows
        count4 = 0  # adjacency size != number of node-feature rows
        print(f"Найдено файлов для обработки: {len(self.raw_paths)}")

        for file in tqdm(self.raw_paths):
            if pathlib.Path(file).suffix != ".pdb":
                continue
            print(f"\nОбработка файла: {file}")
            try:
                struct = self._get_structure(file)
                print(f"Структура загружена успешно")

                seq, seq_dict, Ca_coords, sasa_dict = self._get_sequence_and_ca_coordinates(struct)
                print(f"Последовательность получена")

                node_feats = self._get_node_ftrs(seq, struct, file, sasa_dict)
                print(f"Размер признаков узлов: {node_feats.shape}")

                mat = self._get_adjacency(Ca_coords)
                print("Матрица смежности получена")
                edge_index = self._get_edgeindex(mat)
                print("Список список ребер получен")
                print(f"Размер матрицы смежности: {mat.shape}")
                print(f"Размер матрицы узлов: {node_feats.shape}")
                if mat.shape[0] == node_feats.shape[0]:
                    count3 += 1
                else:
                    count4 += 1

                labels = self._get_node_labels(seq_dict, file)
                print('Метки получены')

                data = Data(x=node_feats, edge_index=edge_index, y=labels)
                if pre_transform is not None:
                    data = pre_transform(data)
                data_list.append(data)
                count1 += 1

                torch.save(data, self._processed_path(file))
                print(f"Данные сохранены успешно {file}")
                print("---------------------")
            except Exception as e:
                files_with_errors.append(file)
                print(f"Ошибка при обработке {file}: {str(e)}")
                count2 += 1

        self.data_prot = data_list
        print(self.custom_processed_dir)
        print(f"\nИтоговая статистика:")
        print(f"Успешно обработано файлов: {count1}")
        print(f"Ошибок при обработке: {count2}")
        print(f"Количество файлов с которых матрицы смежности и узлов равны {count3}")
        print(f"Количество файлов с которых матрицы смежности и узлов не равны {count4}")
        print(f"Размер data_prot: {len(self.data_prot)}")
        print(files_with_errors)

    def len(self):
        """PyG-style length hook (used by ``Dataset.indices()``, ``shuffle()``, ...)."""
        return len(self.data_prot)

    def get(self, idx):
        """PyG-style accessor returning the untransformed graph at ``idx``."""
        return self.data_prot[idx]

    def __len__(self):
        return len(self.data_prot)

    def __getitem__(self, idx):
        """Return ``data_prot[idx]``; ``self.transform`` (if set) is applied to single graphs, not slices."""
        data = self.data_prot[idx]
        transform = getattr(self, 'transform', None)
        if transform is not None and not isinstance(data, list):
            data = transform(data)
        return data

    def _get_adjacency(self, ca_coords, threshold=6.0):
        """Boolean contact map: True where two residues are closer than ``threshold``.

        Distances are Euclidean in the units of ``ca_coords`` (Angstrom for PDB files).
        The diagonal is always False (no self-loops). Empty input gives a ``(0, 0)`` map.
        """
        coords = np.asarray(ca_coords)
        if coords.size == 0:
            return np.zeros((0, 0), dtype=bool)
        distances = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
        adjacency_matrix = distances < threshold
        np.fill_diagonal(adjacency_matrix, False)
        return adjacency_matrix

    def _get_edgeindex(self, adjacency_matrix):
        """Convert an adjacency matrix into a ``(2, num_edges)`` long tensor.

        NOTE(review): networkx lists each undirected edge once, so only one direction
        (i, j) is stored; PyG message passing usually expects both directions (e.g. via
        ``torch_geometric.utils.to_undirected``) -- confirm this is intended.
        """
        nx_graph = nx.from_numpy_array(adjacency_matrix)
        # reshape keeps the (2, 0) shape for graphs without edges.
        edge_index = np.array(list(nx_graph.edges()), dtype=np.int64).reshape(-1, 2).T
        return torch.tensor(edge_index, dtype=torch.long)

    def _get_structure(self, file):
        """Parse ``file`` with Bio.PDB; the structure is named after the file stem."""
        parser = PDB.PDBParser()
        structure_id = os.path.splitext(os.path.basename(file))[0]
        return parser.get_structure(structure_id, file)

    def _filter_unknown_residues(self, structure):
        """Return a deep copy of ``structure`` without UNK, nucleotide and hetero residues.

        The input is left untouched: re-adding its residues to new chains would re-parent
        them, so the copy is made first and unwanted residues are detached from it.
        A residue is dropped when its stripped name is ``UNK`` or a nucleotide code
        (``A C G T U I N`` or DNA-style ``DA DC DG DT DU DI``), or its hetero flag is not blank.
        """
        import copy

        skipped_names = {'A', 'C', 'G', 'T', 'U', 'I', 'N', 'DA', 'DC', 'DG', 'DT', 'DU', 'DI', 'UNK'}
        filtered_structure = copy.deepcopy(structure)
        filtered_structure.id = "filtered"
        for model in filtered_structure:
            for chain in model:
                for residue in list(chain):  # snapshot: residues are detached while looping
                    if residue.get_resname().strip() in skipped_names or residue.id[0] != ' ':
                        chain.detach_child(residue.id)
        return filtered_structure

    def _get_sequence_and_ca_coordinates(self, structure):
        """Extract sequence, residue index, coordinates and per-residue SASA from model 0.

        Residues whose name is not in ``ressymbl`` (e.g. non-standard ones) and hetero
        residues (``H_*``) are skipped. Dictionary keys are
        ``(residue_number + insertion_code, chain_id)``, e.g. ``('52A', 'A')``.

        Returns:
            sequence, sequence_dict (key -> (one-letter code, node index)),
            coords (one row per residue: CA, else backbone C, else N) and
            sasa_dict (key -> total SASA from freesasa).

        Raises:
            KeyError: if a kept residue has none of CA/C/N, or freesasa has no area for it.
        """
        try:
            result, _ = freesasa.calcBioPDB(structure)
            residue_areas = result.residueAreas()
        except Exception:
            # freesasa can fail on unknown residues, nucleotides or hetero atoms:
            # retry on a filtered copy of the structure.
            result, _ = freesasa.calcBioPDB(self._filter_unknown_residues(structure))
            residue_areas = result.residueAreas()

        sequence = []
        sequence_dict = {}
        sasa_dict = {}
        coords = []
        for chain in structure[0]:
            for residue in chain:
                hetflag, resseq, icode = residue.get_id()
                resname = residue.get_resname()
                if resname not in ressymbl or hetflag.startswith('H_'):
                    continue
                symbol = ressymbl[resname]
                res_key = str(resseq) + icode.strip()  # '52' or '52A'
                key = (res_key, chain.id)
                sequence_dict[key] = (symbol, len(sequence))
                sasa_dict[key] = residue_areas[chain.id][res_key].total
                sequence.append(symbol)
                for atom_name in ('CA', 'C', 'N'):
                    if atom_name in residue:
                        coords.append(residue[atom_name].coord)
                        break
                else:
                    raise KeyError(f"residue {residue.get_id()} in chain {chain.id} has no CA, C or N atom")

        return ''.join(sequence), sequence_dict, np.array(coords), sasa_dict

    def _get_one_hot_symbftrs(self, sequence):
        """One-hot encode ``sequence`` over ``pro_res_table`` (ValueError on unknown symbols)."""
        one_hot_symb = np.zeros((len(sequence), len(pro_res_table)))
        for row, res in enumerate(sequence):
            one_hot_symb[row, pro_res_table.index(res)] = 1
        return one_hot_symb

    def _get_phis_chem_properties(self, sequence):
        """Stack the ``aa_properties`` row of each residue; unknown symbols get a zero row."""
        aa_prop_is_seq = np.zeros((len(sequence), len(aa_properties['A'])))
        for i, residue in enumerate(sequence):
            if residue in aa_properties:
                aa_prop_is_seq[i, :] = aa_properties[residue]
        return aa_prop_is_seq

    def _get_sasa(self, sasa_dict, sequence):
        """Per-residue SASA as a 1-D float array, in ``sasa_dict`` insertion order.

        Raises:
            ValueError: if ``sasa_dict`` and ``sequence`` differ in length, so the caller
                sees the real cause instead of a later error on a missing result.
        """
        if len(sasa_dict) != len(sequence):
            raise ValueError(f"SASA has {len(sasa_dict)} entries but the sequence has {len(sequence)} residues")
        return np.fromiter(sasa_dict.values(), dtype=float, count=len(sasa_dict))

    def _handle_sequence_mismatch(self, sequence, ang_seq, angles):
        """Align angle rows to ``sequence`` when the peptide-builder sequence differs.

        Greedy left-to-right scan: a residue takes the next unused angle row if the codes
        match; otherwise it gets the placeholder ``[1, 0, 1, 0]``. Entries of ``ang_seq``
        are never skipped, so an extra residue there blocks matching until the same code
        appears. NOTE(review): the placeholder is sin=1/cos=0 (pi/2), whereas
        ``_get_angles`` encodes missing terminal angles as 0 rad ([0, 1, 0, 1]).
        """
        full_angles = np.zeros((len(sequence), angles.shape[1]))
        res_idx = 0
        for i, residue in enumerate(sequence):
            if res_idx < len(angles) and residue == ang_seq[res_idx]:
                full_angles[i, :] = angles[res_idx]
                res_idx += 1
            else:
                full_angles[i, :] = np.array([1, 0, 1, 0])  # placeholder
        return full_angles

    def _get_node_ftrs(self, sequence, structure, file, sasa_dict):
        """Node features ``[one-hot | sin/cos phi, psi | aa_properties | SASA]`` as a float tensor.

        Errors are logged with the file name and re-raised, so ``process`` records the
        file as failed instead of crashing later on a ``None`` result.
        """
        try:
            one_hot_symb = self._get_one_hot_symbftrs(sequence)
            phis_chem_properties = self._get_phis_chem_properties(sequence)
            sasa = self._get_sasa(sasa_dict, sequence)
            angles, ang_seq = self._get_angles(structure[0])
            # Re-align whenever the sequences differ, not only when their lengths do:
            # equal-length sequences with different residues would otherwise be misaligned.
            if sequence != str(ang_seq):
                print('angles rewrite')
                angles = self._handle_sequence_mismatch(sequence, ang_seq, angles)
            return torch.tensor(np.hstack((one_hot_symb, angles, phis_chem_properties, sasa.reshape(-1, 1))),
                                dtype=torch.float)
        except Exception as e:
            print(f"Ошибка при обработке {file}: {str(e)}")
            raise

    def _get_angles(self, structure):
        """sin/cos of backbone phi and psi for every CA-traced polypeptide.

        ``structure`` is iterated as a collection of chains (callers pass model 0).
        The first phi and last psi of each polypeptide, and any angle reported as
        ``None``, are set to 0 rad before taking sin/cos.

        Returns:
            (angles, seq): array of shape ``(n, 4)`` with columns
            ``[sin phi, cos phi, sin psi, cos psi]`` and the matching one-letter sequence.
        """
        builder = PDB.CaPPBuilder()
        angles_trig = []
        seq_parts = []
        for chain in structure:
            for poly in builder.build_peptides(chain):
                seq_parts.append(str(poly.get_sequence()))
                phi_psi = poly.get_phi_psi_list()
                phi_psi[0] = (0, phi_psi[0][1])
                phi_psi[-1] = (phi_psi[-1][0], 0)
                phi_psi = np.array([[0 if value is None else value for value in pair] for pair in phi_psi],
                                   dtype=float)
                phi, psi = phi_psi[:, 0], phi_psi[:, 1]
                angles_trig.append(np.column_stack((np.sin(phi), np.cos(phi), np.sin(psi), np.cos(psi))))
        return np.vstack(angles_trig), ''.join(seq_parts)

    def _get_node_labels(self, seq_dict, file):
        """Multi-hot site labels of shape ``(len(seq_dict), len(sites_dict))`` (dtype long).

        Uses ``self.labels_df`` if set, else the global ``train_pdb_df``; rows whose
        ``PDBID`` equals the file stem are applied. When ``POS_CHAINS`` starts with "YES",
        residues are matched on (position, chain); otherwise on position alone. A label
        is set only when the annotated one-letter code matches the structure.
        """
        labels_df = getattr(self, 'labels_df', None)
        if labels_df is None:
            labels_df = train_pdb_df
        y = torch.zeros((len(seq_dict), len(sites_dict)), dtype=torch.long)
        pdb_id = os.path.splitext(os.path.basename(file))[0]
        rows = labels_df[labels_df['PDBID'] == pdb_id][['SITE_TYPE', 'AminoAcidsWithChains', 'POS_CHAINS']]
        # Chain-agnostic lookup; if chains share a residue number the last chain wins.
        seq_by_position = {key[0]: value for key, value in seq_dict.items()}
        for _, row in rows.iterrows():
            # NOTE(review): eval() executes arbitrary code from the CSV; if these columns only
            # hold Python literals, ast.literal_eval is the safe equivalent.
            pos_chain = eval(row['POS_CHAINS'])[0]
            if pos_chain == "YES":
                for amino_acid, position, chain in eval(row['AminoAcidsWithChains']):
                    entry = seq_dict.get((position, chain))
                    if entry is not None and entry[0] == amino_acid:
                        y[entry[1], sites_dict[row['SITE_TYPE']]] = 1
            else:
                for amino_acid, position in eval(row['AminoAcidsWithChains']):
                    entry = seq_by_position.get(str(position))
                    if entry is not None and entry[0] == amino_acid:
                        y[entry[1], sites_dict[row['SITE_TYPE']]] = 1
        return y
    

In [26]:
# NOTE(review): absolute, machine-specific path -- consider a configurable DATA_DIR constant.
train_dataset = Train_Graph_Dataset(root="D:/Proteins/train")


Все обработанные файлы уже существуют. Загружаем их.


In [7]:
class Val_Graph_Dataset(Dataset):
    def __init__(self, root, specific_files=None, transform=None, pre_transform=None):
        """Index ``<root>/raw_files`` and load (or build) the cached graphs.

        Args:
            root: dataset directory; raw PDB files are read from ``<root>/raw_files``.
            specific_files: optional collection of raw files to restrict the dataset to.
            transform: optional callable, forwarded to the base-class constructor.
            pre_transform: optional callable, forwarded to the base-class constructor.
        """
        self.root = root
        self.specific_files = specific_files  # optional subset of raw files to process
        self.custom_processed_dir = os.path.join(self.root, 'processed_type_4')
        os.makedirs(self.custom_processed_dir, exist_ok=True)
        # Set before super().__init__: the PyG base constructor may already call
        # self.process(), and that state must not be discarded afterwards.
        self.data_prot = []
        self._has_processed = False
        super(Val_Graph_Dataset, self).__init__(root, transform=transform, pre_transform=pre_transform)
        self.data = self.processed_paths
        self._initialize_data()

    def _initialize_data(self):
        """Fill ``self.data_prot`` from the cached ``.pt`` files, processing raw files first if needed.

        If some cache files are missing and ``process()`` has not already run during this
        construction (``self._has_processed``), it is called once. Afterwards every existing
        cache file is loaded, so ``data_prot`` mirrors the cache (files that failed are absent).
        """
        processed_files = [os.path.join(self.custom_processed_dir, os.path.splitext(os.path.basename(file))[0] + '.pt')
                           for file in self.raw_paths]

        if all(os.path.exists(f) for f in processed_files):
            print("Все обработанные файлы уже существуют. Загружаем их.")
        elif not getattr(self, '_has_processed', False):
            print("Некоторые обработанные файлы отсутствуют. Запускаем обработку.")
            self.process()

        self.data_prot = []
        for file_path in processed_files:
            if not os.path.exists(file_path):
                continue  # raw file that failed to process
            try:
                # weights_only=False unpickles arbitrary objects -- only load caches this class wrote.
                data = torch.load(file_path, map_location=torch.device('cpu'), weights_only=False)
                self.data_prot.append(data)
            except Exception as e:
                print(f"Ошибка при загрузке {file_path}: {str(e)}")

    @property
    def raw_file_names(self):
        """Sorted paths of the raw ``.pdb`` files in ``<root>/raw_files``.

        Only ``.pdb`` files are listed because ``process()`` ignores every other suffix;
        listing them would leave their cache files permanently "missing". If
        ``specific_files`` is given, a file is kept when its path (normalized) or its base
        name appears there. Sorting makes the dataset order independent of the filesystem.
        """
        all_files = sorted(entry.path for entry in os.scandir(self.root + "/raw_files")
                           if entry.is_file() and pathlib.Path(entry.path).suffix == ".pdb")
        if not self.specific_files:
            return all_files
        wanted = set(self.specific_files)
        wanted_paths = {os.path.normcase(os.path.normpath(p)) for p in self.specific_files}
        return [f for f in all_files
                if f in wanted or os.path.basename(f) in wanted
                or os.path.normcase(os.path.normpath(f)) in wanted_paths]

    @property
    def processed_file_names(self):
        """Cache file names: each raw file's stem with a ``.pt`` extension."""
        names = []
        for raw_path in self.raw_paths:
            stem, _extension = os.path.splitext(os.path.basename(raw_path))
            names.append(stem + '.pt')
        return names

    def download(self):
        """No-op: raw PDB files are read from ``<root>/raw_files`` and are never downloaded."""
        return None
    
    def process(self):
        """Build and cache a graph for every raw ``.pdb`` file.

        Does nothing if every expected cache file already exists. Otherwise all raw files
        are processed; per-file failures are logged and collected rather than aborting the
        run. Sets ``self.data_prot`` to the successfully built graphs and
        ``self._has_processed`` so the constructor does not process a second time.
        """
        if all(os.path.exists(os.path.join(self.custom_processed_dir, f)) for f in self.processed_file_names):
            print("Все обработанные файлы уже существуют. Пропускаем повторную обработку.")
            return
        self._has_processed = True
        self.data_prot = []
        pre_transform = getattr(self, 'pre_transform', None)

        data_list = []
        files_with_errors = []
        count1 = 0  # files processed successfully
        count2 = 0  # files that raised
        count3 = 0  # adjacency size == number of node-feature rows
        count4 = 0  # adjacency size != number of node-feature rows
        print(f"Найдено файлов для обработки: {len(self.raw_paths)}")

        for file in tqdm(self.raw_paths):
            if pathlib.Path(file).suffix != ".pdb":
                continue
            print(f"\nОбработка файла: {file}")
            try:
                struct = self._get_structure(file)
                print(f"Структура загружена успешно")

                seq, seq_dict, Ca_coords, sasa_dict = self._get_sequence_and_ca_coordinates(struct)
                print(f"Последовательность получена")

                node_feats = self._get_node_ftrs(seq, struct, file, sasa_dict)
                print(f"Размер признаков узлов: {node_feats.shape}")

                mat = self._get_adjacency(Ca_coords)
                print("Матрица смежности получена")
                edge_index = self._get_edgeindex(mat)
                print("Список список ребер получен")
                print(f"Размер матрицы смежности: {mat.shape}")
                print(f"Размер матрицы узлов: {node_feats.shape}")
                if mat.shape[0] == node_feats.shape[0]:
                    count3 += 1
                else:
                    count4 += 1

                labels = self._get_node_labels(seq_dict, file)
                print('Метки получены')

                data = Data(x=node_feats, edge_index=edge_index, y=labels)
                if pre_transform is not None:
                    data = pre_transform(data)
                data_list.append(data)
                count1 += 1

                torch.save(data, os.path.join(self.custom_processed_dir,
                                              os.path.splitext(os.path.basename(file))[0] + '.pt'))
                print(f"Данные сохранены успешно {file}")
                print("---------------------")
            except Exception as e:
                files_with_errors.append(file)
                print(f"Ошибка при обработке {file}: {str(e)}")
                count2 += 1

        self.data_prot = data_list
        print(self.custom_processed_dir)
        print(f"\nИтоговая статистика:")
        print(f"Успешно обработано файлов: {count1}")
        print(f"Ошибок при обработке: {count2}")
        print(f"Количество файлов с которых матрицы смежности и узлов равны {count3}")
        print(f"Количество файлов с которых матрицы смежности и узлов не равны {count4}")
        print(f"Размер data_prot: {len(self.data_prot)}")
        print(files_with_errors)
    def __len__(self):
        """Number of loaded graphs."""
        return len(self.data_prot)

    def len(self):
        """PyG-style length hook (used by ``Dataset.indices()``, ``shuffle()``, ...)."""
        return len(self.data_prot)
    
    # Index-based access to the loaded graphs in data_prot.
    def __getitem__(self, idx):
        """Return ``data_prot[idx]``; ``self.transform`` (if set) is applied to single graphs, not slices."""
        data = self.data_prot[idx]
        transform = getattr(self, 'transform', None)
        if transform is not None and not isinstance(data, list):
            data = transform(data)
        return data

    def get(self, idx):
        """PyG-style accessor returning the untransformed graph at ``idx``."""
        return self.data_prot[idx]
     
    def _get_adjacency(self, ca_coords, threshold=6.0):
        """Boolean contact map: True where two residues are closer than ``threshold``.

        Distances are Euclidean in the units of ``ca_coords`` (Angstrom for PDB files).
        The diagonal is always False (no self-loops). Empty input gives a ``(0, 0)`` map.
        """
        coords = np.asarray(ca_coords)
        if coords.size == 0:
            return np.zeros((0, 0), dtype=bool)
        distances = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
        adjacency_matrix = distances < threshold
        np.fill_diagonal(adjacency_matrix, False)
        return adjacency_matrix
   
    def _get_edgeindex(self, adjacency_matrix):
        """Convert an adjacency matrix into a ``(2, num_edges)`` long tensor.

        NOTE(review): networkx lists each undirected edge once, so only one direction
        (i, j) is stored; PyG message passing usually expects both directions (e.g. via
        ``torch_geometric.utils.to_undirected``) -- confirm this is intended.
        """
        nx_graph = nx.from_numpy_array(adjacency_matrix)
        # reshape keeps the (2, 0) shape for graphs without edges (a bare .T gives shape (0,)).
        edge_index = np.array(list(nx_graph.edges()), dtype=np.int64).reshape(-1, 2).T
        return torch.tensor(edge_index, dtype=torch.long)
        
    def _get_structure(self, file):
        """Parse ``file`` with Bio.PDB; the structure is named after the file stem.

        Structure ids should be strings, so the stem is used instead of a placeholder object.
        """
        parser = PDB.PDBParser()
        structure_id = os.path.splitext(os.path.basename(file))[0]
        return parser.get_structure(structure_id, file)

    def _filter_unknown_residues(self, structure):
        """Return a deep copy of ``structure`` without UNK, nucleotide and hetero residues.

        The input is left untouched: re-adding its residues to new chains would re-parent
        them, so the copy is made first and unwanted residues are detached from it.
        A residue is dropped when its stripped name is ``UNK`` or a nucleotide code
        (``A C G T U I N`` or DNA-style ``DA DC DG DT DU DI``), or its hetero flag is not blank.
        """
        import copy

        skipped_names = {'A', 'C', 'G', 'T', 'U', 'I', 'N', 'DA', 'DC', 'DG', 'DT', 'DU', 'DI', 'UNK'}
        filtered_structure = copy.deepcopy(structure)
        filtered_structure.id = "filtered"
        for model in filtered_structure:
            for chain in model:
                for residue in list(chain):  # snapshot: residues are detached while looping
                    if residue.get_resname().strip() in skipped_names or residue.id[0] != ' ':
                        chain.detach_child(residue.id)
        return filtered_structure


    # Function to get the sequence from a Biopython structure.
    # Residues whose names are not in ressymbl (e.g. non-standard residues that would
    # map to U / X) are skipped, not converted into A.
    
    def _get_sequence_and_ca_coordinates(self, structure):
        """Extract sequence, residue index, coordinates and per-residue SASA from model 0.

        Residues whose name is not in ``ressymbl`` and hetero residues (``H_*``) are
        skipped. Dictionary keys are ``(residue_number + insertion_code, chain_id)``,
        e.g. ``('52', 'A')`` or ``('52A', 'A')``.

        Returns:
            sequence, sequence_dict (key -> (one-letter code, node index)),
            coords (one row per residue: CA, else backbone C, else N) and
            sasa_dict (key -> total SASA from freesasa).

        Raises:
            KeyError: if a kept residue has none of CA/C/N, or freesasa has no area for it.
        """
        try:
            result, _ = freesasa.calcBioPDB(structure)
            residue_areas = result.residueAreas()
        except Exception:
            # freesasa can fail on unknown residues, nucleotides or hetero atoms:
            # retry on a filtered copy (the method is the underscore-prefixed one).
            result, _ = freesasa.calcBioPDB(self._filter_unknown_residues(structure))
            residue_areas = result.residueAreas()

        sequence = []
        sequence_dict = {}
        sasa_dict = {}
        coords = []
        for chain in structure[0]:
            for residue in chain:
                hetflag, resseq, icode = residue.get_id()
                resname = residue.get_resname()
                if resname not in ressymbl or hetflag.startswith('H_'):
                    continue
                symbol = ressymbl[resname]
                res_key = str(resseq) + icode.strip()  # '52' or '52A'
                key = (res_key, chain.id)
                sequence_dict[key] = (symbol, len(sequence))
                sasa_dict[key] = residue_areas[chain.id][res_key].total
                sequence.append(symbol)
                for atom_name in ('CA', 'C', 'N'):
                    if atom_name in residue:
                        coords.append(residue[atom_name].coord)
                        break
                else:
                    raise KeyError(f"residue {residue.get_id()} in chain {chain.id} has no CA, C or N atom")

        return ''.join(sequence), sequence_dict, np.array(coords), sasa_dict
    

    # One hot encoding for symbols
    def _get_one_hot_symbftrs(self, sequence):
        """One-hot encode ``sequence`` against ``pro_res_table`` (one row per residue).

        A letter missing from ``pro_res_table`` raises ValueError (from ``list.index``).
        """
        encoding = np.zeros((len(sequence), len(pro_res_table)))
        for position, symbol in enumerate(sequence):
            encoding[position, pro_res_table.index(symbol)] = 1
        return encoding
    
    def _get_phis_chem_properties(self, sequence):
        """Stack the ``aa_properties`` vector of each residue; unknown letters give a zero row."""
        n_props = len(aa_properties['A'])
        table = np.zeros((len(sequence), n_props))
        for position, symbol in enumerate(sequence):
            props = aa_properties.get(symbol)
            if props is not None:
                table[position, :] = props
        return table

    def _get_sasa(self, sasa_dict, sequence):
        if len(sasa_dict) == len(sequence):
            sasa = np.zeros(len(sasa_dict))
            for i, key in enumerate(sasa_dict.keys()):
                sasa[i] = sasa_dict[key]
            return sasa

    def _handle_sequence_mismatch(self, sequence, ang_seq, angles):
        """Синхронизирует углы с основной последовательностью"""
        full_angles = np.zeros((len(sequence), angles.shape[1]))
        res_idx = 0
        
        for i, residue in enumerate(sequence):
            if res_idx < len(angles) and residue == ang_seq[res_idx]:
                full_angles[i,:] = angles[res_idx]
                res_idx += 1
            else:
                full_angles[i,:] = np.array([1,0,1,0])  # Заполнитель
        
        return full_angles

    def _get_node_ftrs(self, sequence, structure, file, sasa_dict):
        try:
            one_hot_symb = self._get_one_hot_symbftrs(sequence)
            # print('one-hot done')
            phis_chem_properties = self._get_phis_chem_properties(sequence)
            # print('phis done')
            sasa = self._get_sasa(sasa_dict, sequence)
            # print('sasa done')
            angles, ang_seq = self._get_angles(structure[0])
            
            # Синхронизация длин последовательностей
            if len(sequence) != len(ang_seq):
                print('angles rewrite')
                angles = self._handle_sequence_mismatch(sequence, ang_seq, angles)
            return torch.tensor(np.hstack((one_hot_symb, angles, phis_chem_properties, sasa.reshape(-1,1))), dtype = torch.float)
                    # else:
                    #     one_hot_symb = self._get_one_hot_symbftrs(ang_seq)
                    #     phis_chem_properties = self._get_phis_chem_properties(sequence)
                    #     return torch.tensor(np.hstack((one_hot_symb, angles, phis_chem_properties)), dtype = torch.float)
        except Exception as e:
    
                    print(f"Ошибка при обработке {file}: {str(e)}") 

    def _get_angles(self, structure):
        """Backbone-dihedral features for every peptide built by ``PDB.CaPPBuilder``.

        ``structure`` is iterated as a collection of chains (the caller passes a model).

        Returns:
            (angles, seq): ``angles`` stacks one row ``[sin phi, cos phi, sin psi, cos psi]``
            per residue of the built peptides over all chains; ``seq`` is the
            concatenation of those peptides' sequences.
        """
        trig_blocks = []
        seq = ""
        for chain in structure:
            for peptide in PDB.CaPPBuilder().build_peptides(chain):
                seq += peptide.get_sequence()
                dihedrals = peptide.get_phi_psi_list()
                # Force the first residue's phi and the last residue's psi to 0.
                dihedrals[0] = (0, dihedrals[0][1])
                dihedrals[-1] = (dihedrals[-1][0], 0)
                # Any remaining undefined (None) angle is also replaced by 0.
                phi_psi = np.array([[0 if angle is None else angle for angle in pair] for pair in dihedrals])
                phi, psi = phi_psi[:, 0], phi_psi[:, 1]
                trig_blocks.append(np.column_stack((np.sin(phi), np.cos(phi), np.sin(psi), np.cos(psi))))
        return np.vstack(trig_blocks), seq

    def _get_node_labels(self, seq_dict, file):
        """Build a multi-hot site-label matrix for the residues of one PDB file.

        The rows of ``val_pdb_df`` whose ``PDBID`` equals the file stem are read. When a
        row's ``POS_CHAINS`` is "YES", site residues are matched by (position, chain);
        otherwise they are matched by position only, ignoring the chain.

        Args:
            seq_dict: (residue number [+ insertion code], chain id) -> (one-letter code, node index).
            file: path of the PDB file; its stem is used as the PDB id.

        Returns:
            torch.long tensor of shape (len(seq_dict), len(sites_dict)); entry [i, s] is 1
            when node i is annotated with site type s and its residue letter matches.
        """
        import ast  # literal_eval parses the stored Python literals without executing code

        y = torch.zeros((len(seq_dict), len(sites_dict)), dtype=torch.long)
        pdb_id = os.path.splitext(os.path.basename(file))[0]
        # Chain-agnostic lookup for rows without chain info.
        # NOTE(review): if a position exists in several chains, the last chain wins.
        seq_dict_by_position = {key[0]: val for key, val in seq_dict.items()}
        rows = val_pdb_df[val_pdb_df['PDBID'] == pdb_id][['SITE_TYPE', 'AminoAcidsWithChains', 'POS_CHAINS']]
        for _, row in rows.iterrows():
            pos_chain = ast.literal_eval(row['POS_CHAINS'])[0]
            if pos_chain == "YES":
                for amino_acid, position, chain in ast.literal_eval(row['AminoAcidsWithChains']):
                    # seq_dict keys store the residue number as a string; normalise like the
                    # chain-agnostic branch does, in case the table stores integers.
                    key = (str(position), chain)
                    if key in seq_dict and seq_dict[key][0] == amino_acid:
                        y[seq_dict[key][1], sites_dict[row['SITE_TYPE']]] = 1
            else:
                for amino_acid, position in ast.literal_eval(row['AminoAcidsWithChains']):
                    key = str(position)
                    if key in seq_dict_by_position and seq_dict_by_position[key][0] == amino_acid:
                        y[seq_dict_by_position[key][1], sites_dict[row['SITE_TYPE']]] = 1
        return y
    

In [None]:
VAL_ROOT = r"D:/Proteins/val"  # dataset root directory for the validation split
val_dataset = Val_Graph_Dataset(root=VAL_ROOT)


In [None]:
class Test_Graph_Dataset(Dataset):
    """Residue-level graph dataset for the test split, built from raw PDB files.

    Every ``.pdb`` file in ``<root>/raw_files`` becomes one ``Data`` graph:

    * nodes are residues (see ``_get_sequence_and_ca_coordinates``);
    * edges join residues whose representative backbone atoms are closer than
      ``6`` in PDB coordinate units;
    * node features are one-hot residue type, sin/cos of phi/psi, the
      ``aa_properties`` vector and SASA;
    * labels are multi-hot site types read from the module-level ``test_pdb_df``.

    Graphs are cached one file per structure in ``<root>/processed_type_4`` and
    kept in memory in ``self.data_prot``.
    """

    def __init__(self, root, specific_files=None, transform=None, pre_transform=None):
        self.root = root
        # Optional list of raw file paths to restrict processing to.
        self.specific_files = specific_files
        self.custom_processed_dir = os.path.join(self.root, 'processed_type_4')
        os.makedirs(self.custom_processed_dir, exist_ok=True)
        # Set before the base constructor, because it calls self.process(); clearing
        # data_prot afterwards would throw away that work and force a second pass.
        self.data_prot = []
        self._processed_in_session = False
        super(Test_Graph_Dataset, self).__init__(root, transform=transform, pre_transform=pre_transform)
        self.data = self.processed_paths
        self._initialize_data()

    def _processed_path(self, file):
        """Cache path ``<custom_processed_dir>/<file stem>.pt`` for a raw file."""
        return os.path.join(self.custom_processed_dir, os.path.splitext(os.path.basename(file))[0] + '.pt')

    def _initialize_data(self):
        """Fill ``self.data_prot`` from the cache, or process if anything is missing."""
        if self._processed_in_session:
            # process() already ran (from the base constructor) and filled data_prot.
            return
        processed_files = [self._processed_path(file) for file in self.raw_paths]

        if all(os.path.exists(f) for f in processed_files):
            print("Все обработанные файлы уже существуют. Загружаем их.")
            self.data_prot = []
            for file_path in processed_files:
                try:
                    # weights_only=False unpickles arbitrary objects: only load caches this class wrote.
                    data = torch.load(file_path, map_location=torch.device('cpu'), weights_only=False)
                    self.data_prot.append(data)
                except Exception as e:
                    print(f"Ошибка при загрузке {file_path}: {str(e)}")
                    continue
        else:
            print("Некоторые обработанные файлы отсутствуют. Запускаем обработку.")
            self.process()

    @property
    def raw_file_names(self):
        """Paths of the files in ``<root>/raw_files``, optionally limited to ``specific_files``."""
        all_files = [f.path for f in os.scandir(self.root + "/raw_files") if f.is_file()]
        if self.specific_files:
            # Keep only the files listed in specific_files.
            return [f for f in all_files if f in self.specific_files]
        else:
            return all_files

    @property
    def processed_file_names(self):
        return [os.path.splitext(os.path.basename(file))[0] + '.pt' for file in self.raw_paths]

    def download(self):
        # Raw files are expected to be present already; nothing to download.
        pass

    def process(self):
        """Build one graph per raw ``.pdb`` file and store them in ``self.data_prot``.

        Files whose cached ``.pt`` already exists are loaded instead of rebuilt, so a
        re-run only processes new or previously failed files. A file that raises while
        being processed is reported, recorded in the summary and skipped.
        """
        if all(os.path.exists(os.path.join(self.custom_processed_dir, f)) for f in self.processed_file_names):
            print("Все обработанные файлы уже существуют. Пропускаем повторную обработку.")
            return

        data_list = []
        files_with_errors = []
        count1 = 0  # graphs built or loaded successfully
        count2 = 0  # files that raised during processing
        count3 = 0  # files whose adjacency and node-feature row counts agree
        count4 = 0  # files whose adjacency and node-feature row counts differ
        print(f"Найдено файлов для обработки: {len(self.raw_paths)}")

        for file in tqdm(self.raw_paths):
            if pathlib.Path(file).suffix != ".pdb":
                continue
            out_path = self._processed_path(file)
            if os.path.exists(out_path):
                # Reuse the graph cached by an earlier run instead of recomputing it.
                try:
                    data_list.append(torch.load(out_path, map_location=torch.device('cpu'), weights_only=False))
                    count1 += 1
                    continue
                except Exception as e:
                    # Unreadable cache entry: report it and rebuild the graph below.
                    print(f"Ошибка при загрузке {out_path}: {str(e)}")

            print(f"\nОбработка файла: {file}")
            try:
                struct = self._get_structure(file)
                print("Структура загружена успешно")

                seq, seq_dict, ca_coords, sasa_dict = self._get_sequence_and_ca_coordinates(struct)
                print("Последовательность получена")

                node_feats = self._get_node_ftrs(seq, struct, file, sasa_dict)
                print(f"Размер признаков узлов: {node_feats.shape}")

                mat = self._get_adjacency(ca_coords)
                print("Матрица смежности получена")
                edge_index = self._get_edgeindex(mat)
                print("Список список ребер получен")
                print(f"Размер матрицы смежности: {mat.shape}")
                print(f"Размер матрицы узлов: {node_feats.shape}")
                if mat.shape[0] == node_feats.shape[0]:
                    count3 += 1
                else:
                    count4 += 1

                labels = self._get_node_labels(seq_dict, file)
                print('Метки получены')

                data = Data(x=node_feats, edge_index=edge_index, y=labels)
                if self.pre_transform is not None:
                    data = self.pre_transform(data)
                torch.save(data, out_path)
                data_list.append(data)
                count1 += 1
                print(f"Данные сохранены успешно {file}")
                print("---------------------")
            except Exception as e:
                files_with_errors.append(file)
                print(f"Ошибка при обработке {file}: {str(e)}")
                count2 += 1
                continue

        self.data_prot = data_list
        self._processed_in_session = True
        print(self.custom_processed_dir)
        print(f"\nИтоговая статистика:")
        print(f"Успешно обработано файлов: {count1}")
        print(f"Ошибок при обработке: {count2}")
        print(f"Количество файлов с которых матрицы смежности и узлов равны {count3}")
        print(f"Количество файлов с которых матрицы смежности и узлов не равны {count4}")
        print(f"Размер data_prot: {len(self.data_prot)}")
        print(files_with_errors)

    def __len__(self):
        return len(self.data_prot)

    def __getitem__(self, idx):
        """Return graph ``idx``, passed through ``self.transform`` when one was given."""
        data = self.data_prot[idx]
        return data if self.transform is None else self.transform(data)

    def _get_adjacency(self, ca_coords, threshold=6.0):
        """Boolean contact matrix: True where two residues are closer than ``threshold``.

        The diagonal is always False (a residue is not its own neighbour).
        """
        coords = np.asarray(ca_coords)
        if coords.size == 0:
            coords = coords.reshape(0, 3)  # empty protein -> 0x0 matrix instead of an IndexError
        distances = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
        # Push the diagonal up to the threshold so self-distances never count as contacts.
        distances = distances + np.eye(coords.shape[0]) * threshold
        return distances < threshold

    def _get_edgeindex(self, adjacency_matrix):
        """Convert a contact matrix to a ``(2, num_edges)`` long tensor.

        TODO(review): ``nx.Graph`` is undirected, so each contact is listed once as
        (u, v). PyG message passing presumably flows only along the listed direction,
        so half of each neighbourhood may never be aggregated. Switching to
        ``np.nonzero(adjacency_matrix)`` (both directions) should be done together with
        the train/val datasets and a cache rebuild to keep inputs consistent.
        """
        nx_graph = nx.from_numpy_array(adjacency_matrix)
        # reshape(-1, 2) keeps the (2, 0) shape PyG expects when there are no edges.
        edge_index = np.array(list(nx_graph.edges()), dtype=np.int64).reshape(-1, 2).T
        return torch.tensor(edge_index, dtype=torch.long)

    def _get_structure(self, file):
        """Parse ``file`` with Biopython's PDBParser; the structure id is the file stem."""
        parser = PDB.PDBParser()
        # The original passed the builtin `id` function here as the structure id.
        return parser.get_structure(pathlib.Path(file).stem, file)

    def _filter_unknown_residues(self, structure):
        """Return a new Structure without UNK, single-letter nucleotide or hetero residues.

        Used only as the freesasa fallback input.
        NOTE(review): the residue objects are added to the new chains, not copied;
        Chain.add likely re-parents them — confirm this side effect is harmless.
        """
        filtered_structure = PDB.Structure.Structure("filtered")

        for model in structure:
            new_model = PDB.Model.Model(model.id)
            for chain in model:
                new_chain = PDB.Chain.Chain(chain.id)
                for residue in chain:
                    resname = residue.get_resname().strip()
                    # Skip nucleotides, UNK and hetero residues (non-blank hetero flag).
                    if resname in ['A', 'C', 'G', 'T', 'U', 'N'] or resname == 'UNK' or residue.id[0] != ' ':
                        continue
                    new_chain.add(residue)
                new_model.add(new_chain)
            filtered_structure.add(new_model)

        return filtered_structure

    # Function to get the sequence, per-residue coordinates and SASA from a PDB structure.
    # Residues whose names are not in `ressymbl` (e.g. non-standard U / X residues) are skipped.
    def _get_sequence_and_ca_coordinates(self, structure):
        """Collect the residue sequence, one coordinate per residue and per-residue SASA.

        Only the first model (``structure[0]``) is walked. A residue is kept when its
        three-letter name is in ``ressymbl`` and its hetero flag does not start with ``'H_'``.

        Returns:
            sequence: one-letter codes of the kept residues, in chain/residue order.
            sequence_dict: (residue number [+ insertion code] as str, chain id) ->
                (one-letter code, index of the residue in ``sequence``).
            coords: np.ndarray with one coordinate row per kept residue: the CA atom,
                or the backbone C atom if CA is missing, or N if both are missing.
            sasa_dict: same keys as ``sequence_dict`` -> total SASA reported by freesasa.

        Raises:
            KeyError: when a kept residue has none of CA/C/N, or freesasa reports no
                area for it.
        """
        try:
            result, _ = freesasa.calcBioPDB(structure)
            residue_areas = result.residueAreas()
        except Exception:
            # freesasa can fail on this structure; retry on a filtered copy.
            filtered_structure = self._filter_unknown_residues(structure)
            result, _ = freesasa.calcBioPDB(filtered_structure)
            residue_areas = result.residueAreas()

        symbols = []
        sequence_dict = {}
        sasa_dict = {}
        ca_coords = []
        for chain in structure[0]:
            for residue in chain:
                resname = residue.get_resname()
                if resname not in ressymbl or residue.id[0].startswith('H_'):
                    continue
                symbol = ressymbl[resname]
                _, resseq, icode = residue.get_id()
                # Insertion codes are appended to the residue number, e.g. '52A'.
                res_key = str(resseq) + icode if icode != ' ' else str(resseq)
                sequence_dict[(res_key, chain.id)] = (symbol, len(symbols))
                sasa_dict[(res_key, chain.id)] = residue_areas[chain.id][res_key].total
                symbols.append(symbol)
                # Prefer CA; fall back to the backbone C atom, then to N.
                for atom_name in ('CA', 'C', 'N'):
                    if atom_name in residue:
                        ca_coords.append(residue[atom_name].coord)
                        break
                else:
                    raise KeyError(f"residue {res_key} of chain {chain.id} has no CA, C or N atom")

        return "".join(symbols), sequence_dict, np.array(ca_coords), sasa_dict

    def _get_one_hot_symbftrs(self, sequence):
        """One-hot encode ``sequence`` against ``pro_res_table`` (one row per residue)."""
        one_hot_symb = np.zeros((len(sequence), len(pro_res_table)))
        for row, res in enumerate(sequence):
            one_hot_symb[row, pro_res_table.index(res)] = 1
        return one_hot_symb

    def _get_phis_chem_properties(self, sequence):
        """Stack the ``aa_properties`` vector of each residue; unknown letters give a zero row."""
        aa_prop_is_seq = np.zeros((len(sequence), len(aa_properties['A'])))
        for i, residue in enumerate(sequence):
            if residue in aa_properties:
                aa_prop_is_seq[i, :] = aa_properties[residue]
        return aa_prop_is_seq

    def _get_sasa(self, sasa_dict, sequence):
        """Return per-residue SASA values as a 1-D float array aligned with ``sequence``.

        Raises:
            ValueError: if ``sasa_dict`` and ``sequence`` differ in length.
        """
        if len(sasa_dict) != len(sequence):
            raise ValueError(
                f"SASA entries ({len(sasa_dict)}) do not match sequence length ({len(sequence)})"
            )
        return np.array(list(sasa_dict.values()), dtype=float)

    def _handle_sequence_mismatch(self, sequence, ang_seq, angles):
        """Align angle rows to ``sequence`` by greedy left-to-right matching against ``ang_seq``.

        Unmatched residues get the placeholder [1, 0, 1, 0].
        NOTE(review): that is sin=1/cos=0 (90°), while undefined terminal angles become
        0 (sin=0/cos=1) — confirm the placeholder is intended.
        """
        full_angles = np.zeros((len(sequence), angles.shape[1]))
        res_idx = 0

        for i, residue in enumerate(sequence):
            if res_idx < len(angles) and residue == ang_seq[res_idx]:
                full_angles[i, :] = angles[res_idx]
                res_idx += 1
            else:
                full_angles[i, :] = np.array([1, 0, 1, 0])  # placeholder

        return full_angles

    def _get_node_ftrs(self, sequence, structure, file, sasa_dict):
        """Assemble node features: one-hot, [sin phi, cos phi, sin psi, cos psi], properties, SASA.

        Errors are printed with the file name and re-raised instead of returning None.
        """
        try:
            one_hot_symb = self._get_one_hot_symbftrs(sequence)
            phis_chem_properties = self._get_phis_chem_properties(sequence)
            sasa = self._get_sasa(sasa_dict, sequence)
            angles, ang_seq = self._get_angles(structure[0])

            # Re-align whenever the residues differ, not only when the lengths differ.
            if str(ang_seq) != str(sequence):
                print('angles rewrite')
                angles = self._handle_sequence_mismatch(sequence, ang_seq, angles)
            return torch.tensor(
                np.hstack((one_hot_symb, angles, phis_chem_properties, sasa.reshape(-1, 1))),
                dtype=torch.float,
            )
        except Exception as e:
            print(f"Ошибка при обработке {file}: {str(e)}")
            raise

    def _get_angles(self, structure):
        """Return (rows of [sin phi, cos phi, sin psi, cos psi], concatenated peptide sequence).

        ``structure`` is iterated as a collection of chains (the caller passes a model).
        """
        angles_trig = []
        seq = ""
        for chain in structure:
            polypeptides = PDB.CaPPBuilder().build_peptides(chain)
            for poly in polypeptides:
                seq += poly.get_sequence()
                phi_psi = poly.get_phi_psi_list()
                # Force the first residue's phi and the last residue's psi to 0.
                phi_psi[0] = (0, phi_psi[0][1])
                phi_psi[-1] = (phi_psi[-1][0], 0)
                # Any remaining undefined (None) angle is also replaced by 0.
                phi_psi = np.array([[0 if value is None else value for value in item] for item in phi_psi])
                sin_cos_poly = np.column_stack((
                    np.sin(phi_psi[:, 0]), np.cos(phi_psi[:, 0]),
                    np.sin(phi_psi[:, 1]), np.cos(phi_psi[:, 1]),
                ))
                angles_trig.append(sin_cos_poly)
        return np.vstack(angles_trig), seq

    def _get_node_labels(self, seq_dict, file):
        """Build a multi-hot site-label matrix from the ``test_pdb_df`` rows of this PDB id.

        Rows with ``POS_CHAINS`` == "YES" match residues by (position, chain); others by
        position only. Returns a torch.long tensor of shape (len(seq_dict), len(sites_dict)).
        """
        import ast  # literal_eval parses the stored Python literals without executing code

        y = torch.zeros((len(seq_dict), len(sites_dict)), dtype=torch.long)
        pdb_id = os.path.splitext(os.path.basename(file))[0]
        # Chain-agnostic lookup; NOTE(review): with duplicate positions across chains the last chain wins.
        seq_dict_by_position = {key[0]: val for key, val in seq_dict.items()}
        rows = test_pdb_df[test_pdb_df['PDBID'] == pdb_id][['SITE_TYPE', 'AminoAcidsWithChains', 'POS_CHAINS']]
        for _, row in rows.iterrows():
            pos_chain = ast.literal_eval(row['POS_CHAINS'])[0]
            if pos_chain == "YES":
                for amino_acid, position, chain in ast.literal_eval(row['AminoAcidsWithChains']):
                    # seq_dict keys hold the residue number as a string.
                    key = (str(position), chain)
                    if key in seq_dict and seq_dict[key][0] == amino_acid:
                        y[seq_dict[key][1], sites_dict[row['SITE_TYPE']]] = 1
            else:
                for amino_acid, position in ast.literal_eval(row['AminoAcidsWithChains']):
                    key = str(position)
                    if key in seq_dict_by_position and seq_dict_by_position[key][0] == amino_acid:
                        y[seq_dict_by_position[key][1], sites_dict[row['SITE_TYPE']]] = 1
        return y
    

In [None]:
TEST_ROOT = r"D:/Proteins/test"  # dataset root: holds raw_files/ and the processed cache
test_dataset = Test_Graph_Dataset(root=TEST_ROOT)

In [30]:
from torch_geometric.loader import DataLoader
import torch_geometric

In [14]:
def get_target_index(y):
    """Return the index of the first entry of ``y`` equal to 1, or ``len(y)`` if there is none.

    Collapses a multi-hot label row into one class index, with ``len(y)`` used as an
    extra class for rows that contain no 1.
    """
    return next((position for position, flag in enumerate(y) if flag == 1), len(y))

In [15]:
# Collapse each node's multi-hot site labels into one class index (len(row) = "no site").
# Graphs whose labels are already 1-D were converted by an earlier run of this cell;
# skipping them makes the cell safe to re-run (iterating a 0-d tensor would crash).
for split_dataset in (train_dataset, val_dataset, test_dataset):
    for data in split_dataset:
        if data.y.dim() == 2:
            data.y = torch.tensor([get_target_index(node) for node in data.y], dtype=torch.long)

# Обучение графовой нейронной сети

In [31]:
BATCH_SIZE = 32  # graphs per mini-batch for every split
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE)

In [17]:
from sklearn.metrics import accuracy_score, classification_report
def calculate_metrics(true_labels, predicted_labels):
    """Compute overall accuracy and a per-class report for node predictions.

    Args:
        true_labels: list of 1-D label tensors (one per graph or batch).
        predicted_labels: list of 1-D predicted-label tensors, in the same order.

    Returns:
        (accuracy, report): report is sklearn's ``classification_report`` as a dict.
    """
    # Flatten the per-graph tensors; .cpu() first because predictions made on a CUDA
    # device cannot be converted to numpy directly.
    true_labels_flat = torch.cat(true_labels).cpu().numpy()
    predicted_labels_flat = torch.cat(predicted_labels).cpu().numpy()

    accuracy = accuracy_score(true_labels_flat, predicted_labels_flat)
    # Per-class precision / recall / F1.
    report = classification_report(true_labels_flat, predicted_labels_flat, output_dict=True)

    return accuracy, report

In [18]:
import matplotlib.pyplot as plt
def plot_loss(train_dict, val_dict):
    """Plot training and validation loss per epoch in one labelled figure.

    Args:
        train_dict: mapping epoch -> mean training loss.
        val_dict: mapping epoch -> mean validation loss.
    """
    plt.figure(figsize=(15, 8))
    # Each curve is drawn against its own epochs, so the dicts need not share keys.
    plt.plot(list(train_dict.keys()), list(train_dict.values()), marker='o', label='train')
    plt.plot(list(val_dict.keys()), list(val_dict.values()), marker='^', label='val')
    plt.title('Потери по эпохам')
    plt.xlabel('Эпоха')
    plt.ylabel('Потеря')
    plt.grid(True)
    plt.legend()  # without a legend the two curves are indistinguishable
    plt.show()

In [19]:
from torch_geometric.data import Batch
from torch_geometric.nn import BatchNorm, SAGEConv
import torch
from torch_geometric.nn import SAGEConv
from torch_geometric.data import Data
import torch.nn as nn
import torch.nn.functional as F
class Net(torch.nn.Module):
    """GraphSAGE node classifier: stacked SAGEConv -> BatchNorm -> LeakyReLU -> dropout, then a linear head.

    Args:
        in_channels: size of the input node features.
        hidden_channels: width of every SAGEConv layer.
        out_channels: number of output logits per node.
        num_layers: number of SAGEConv layers (values below 1 still build one layer).
        dropout: dropout probability after every layer (default 0.2, the previously
            hard-coded value).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout=0.2):
        super().__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.batch_norms.append(BatchNorm(hidden_channels))
        # Keep the attribute name `linear1` and this creation order so existing
        # checkpoints (state_dict keys) keep loading.
        self.linear1 = nn.Linear(hidden_channels, out_channels)
        for _ in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.batch_norms.append(BatchNorm(hidden_channels))

    def forward(self, x, edge_index, get_embedds=False):
        """Return per-node logits, or the last hidden representation if ``get_embedds`` is True."""
        for conv, batch_norm in zip(self.convs, self.batch_norms):
            x = conv(x, edge_index)
            x = batch_norm(x)
            x = F.leaky_relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        if get_embedds:
            return x
        return self.linear1(x)


In [27]:
from sklearn.utils.class_weight import compute_class_weight
# Balanced class weights over the training labels, laid out as a dense vector indexed
# by class id. compute_class_weight only returns weights for classes that occur, so a
# class missing from train would otherwise shift every later weight and shorten the
# vector below the model's number of outputs. Missing classes get weight 1.0.
classes = torch.cat([data.y for data in train_dataset]).numpy()
classes_unique = np.unique(classes)
present_weights = compute_class_weight(class_weight='balanced', y=classes, classes=classes_unique)
class_weights = np.ones(int(classes.max()) + 1)
class_weights[classes_unique] = present_weights

In [36]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_layers = 3
model = Net(in_channels=train_dataset.num_features, hidden_channels=20,
            out_channels=train_dataset.num_classes, num_layers=num_layers).to(device)
# The class-weight tensor must live on the same device as the logits; otherwise
# CrossEntropyLoss fails with a device mismatch when training on CUDA.
loss_op = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float, device=device))
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

In [29]:
import time

In [None]:
def train():
    """Run one optimisation epoch over ``train_loader`` and return the mean per-batch loss."""
    model.train()
    running = 0
    for batch in train_loader:
        batch = batch.to(device)
        optimizer.zero_grad()
        logits = model(batch.x, batch.edge_index).float()
        batch_loss = loss_op(logits, batch.y)
        batch_loss.backward()
        optimizer.step()
        running += batch_loss.item()
    return running / len(train_loader)


@torch.no_grad()
def validate(loader):
    """Return the mean per-batch loss of ``model`` over ``loader`` (eval mode, no gradients)."""
    model.eval()
    batch_losses = []
    for batch in loader:
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index)
        batch_losses.append(loss_op(logits.float(), batch.y).item())
    return sum(batch_losses) / len(loader)

def predict_labels(data):
    """Return a list with the per-node argmax class for each graph (or batch) in ``data``."""
    model.eval()
    predictions = []
    with torch.no_grad():
        for graph in data:
            graph = graph.to(device)  # each graph is passed to the model separately
            logits = model(graph.x, graph.edge_index)
            predictions.append(logits.argmax(dim=1))
    return predictions
        
def get_embedds(data):
    """Return a list with the last hidden node representations for each graph (or batch) in ``data``."""
    model.eval()
    embeddings = []
    with torch.no_grad():
        for graph in data:
            graph = graph.to(device)  # each graph is passed to the model separately
            embeddings.append(model(graph.x, graph.edge_index, get_embedds=True))
    return embeddings


NUM_EPOCHS = 50
checkpoint_dir = f'D:/Proteins/model_num_layers{num_layers}'
# Loop-invariant: create the checkpoint directory once, not every epoch.
os.makedirs(checkpoint_dir, exist_ok=True)

times = []
train_losses = {}
val_losses = {}
for epoch in range(1, NUM_EPOCHS + 1):
    start = time.time()
    loss = train()
    checkpoint = {
        'epoch': epoch,  # number of completed epochs
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,  # mean training loss of this epoch
    }
    torch.save(checkpoint, f'{checkpoint_dir}/pitstop{epoch}.pth')
    train_losses[epoch] = loss
    val_l = validate(val_loader)
    val_losses[epoch] = val_l
    # Record the epoch time before printing/plotting so figure rendering is not counted.
    times.append(time.time() - start)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_l:.4f}')
    # The final epoch is plotted once, after the loop.
    if epoch % 10 == 0 and epoch != NUM_EPOCHS:
        plot_loss(train_losses, val_losses)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
plot_loss(train_losses, val_losses)

In [51]:
# Embeddings and their labels must come from the same fixed iteration order.
# train_loader shuffles, so use an unshuffled loader for both passes.
train_eval_loader = DataLoader(dataset=train_dataset, batch_size=32)
train_embeds = get_embedds(train_eval_loader)
train_true = [data.y for data in train_eval_loader]

In [55]:
# .cpu(): embeddings computed on a CUDA device cannot be converted to numpy directly.
train_embeds = torch.cat(train_embeds, 0).cpu().numpy()

In [57]:
train_embeds_df = pd.DataFrame(data=train_embeds)  # one row per node, one column per embedding dim

In [62]:
# .cpu() so the labels convert to numpy even if they were left on a CUDA device.
train_embeds_df['target'] = torch.cat(train_true).cpu().numpy()

In [63]:
train_embeds_df.head(n=5)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,target
0,0.126747,0.0884,-0.000347,0.135243,0.228443,0.035914,0.05114,0.113129,0.541479,0.163079,...,-0.001117,0.018419,0.070321,0.038027,0.152599,-2.2e-05,0.126592,0.080979,-0.000542,11
1,-0.003797,-0.001753,0.048125,-0.001851,-0.000344,-0.000793,0.036572,-0.003471,-0.006198,0.333027,...,-0.000665,-0.007981,0.14804,-0.001369,0.246156,-0.003534,0.161483,0.164086,0.213851,11
2,0.383381,0.180651,-0.00015,0.042861,0.036018,0.086786,-0.001451,0.3257,-3.4e-05,0.394632,...,-0.005355,-0.01193,0.192516,0.278623,-0.000201,0.281599,0.145916,0.145331,-0.005085,11
3,0.322193,-0.005495,0.033965,-0.00324,-6.8e-05,-0.004374,-0.001248,0.296459,-0.006078,0.001239,...,0.20046,0.013263,-0.000394,-0.004697,0.237077,0.223938,0.112776,0.131771,-0.000244,11
4,0.277039,0.256184,-0.001456,0.406127,0.42355,-0.002068,0.157232,0.28783,-0.005377,-0.004623,...,0.325236,0.041335,-0.006421,-0.004293,0.029904,0.173493,-0.001518,-0.001641,-0.00046,11


In [64]:
# .cpu(): embeddings computed on a CUDA device cannot be converted to numpy directly.
test_embeds = torch.cat(get_embedds(test_loader), 0).cpu().numpy()
test_embeds_df = pd.DataFrame(test_embeds)
# test_loader is not shuffled, so its labels come out in the same order as the embeddings
# (no dependence on a `y_true` variable defined elsewhere).
test_embeds_df['target'] = torch.cat([data.y for data in test_loader]).cpu().numpy()
test_embeds_df.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,11,12,13,14,15,16,17,18,19,target
0,0.403704,-0.001924,0.004701,0.126324,0.413402,-0.002392,0.146437,0.392951,-0.002192,0.231975,...,0.107337,-0.000375,-0.003272,-0.002262,0.009248,0.261667,-0.001019,-0.002223,-0.000816,11
1,0.419015,-0.002151,0.093981,-0.0052,0.453908,-0.00053,0.074189,0.394629,-0.019187,0.570207,...,0.30864,-0.008045,-0.000411,-0.001636,0.053601,0.392125,0.029525,0.01276,-0.002089,11
2,0.393625,0.00415,0.029989,-0.00275,0.327608,0.101878,0.112841,0.33722,-0.005666,0.474712,...,0.024433,-0.006729,0.244288,0.202892,0.099007,0.356002,0.150365,0.167605,0.036087,11
3,-0.011047,-0.010926,-0.02176,-0.019147,-0.007775,-0.013697,0.075258,-0.009342,-0.009849,0.459967,...,-0.007403,-0.009316,0.387904,0.121147,0.35563,-0.008478,0.181939,0.348604,0.050303,11
4,-0.010412,-0.008552,-0.013178,-0.010227,-0.008046,-0.00732,0.100035,-0.009387,-0.003819,0.25087,...,-0.005823,-0.01012,0.384825,0.244414,0.281254,-0.008412,0.199354,0.257666,0.135895,11


In [65]:
from sklearn.ensemble import GradientBoostingClassifier

In [None]:
gb_model = GradientBoostingClassifier(
    n_estimators=200,    # number of boosting stages (trees)
    learning_rate=0.1,   # shrinkage applied to each tree's contribution
    max_depth=3,         # maximum depth of each tree
    random_state=42,
)

# Every column except the last ('target') is an embedding feature.
gb_X_train = train_embeds_df.iloc[:, :-1]
gb_y_train = train_embeds_df.iloc[:, -1]
gb_model.fit(gb_X_train, gb_y_train)

# Predict on the test embeddings and report per-class metrics.
gb_X_test = test_embeds_df.iloc[:, :-1]
gb_y_test = test_embeds_df.iloc[:, -1]
y_pred = gb_model.predict(gb_X_test)
print("Classification Report:")
print(classification_report(gb_y_test, y_pred))

In [70]:
# Persist the node embeddings (plus the 'target' column) for later experiments.
train_embeds_df.to_csv(path_or_buf='train_embedds_df.csv', index=False)
test_embeds_df.to_csv(path_or_buf='test_embedds_df.csv', index=False)