In [None]:
import torch
from torch_geometric.data import Data
import networkx as nx
import numpy as np
import os
import pickle
from torch_geometric.loader import DataLoader

In [2]:
# Switch to the parent directory so the relative 'data/...' and 'models_files/...'
# paths used below resolve (presumably the notebook lives one level below them).
# NOTE(review): not idempotent — re-running this cell climbs one more directory.
os.chdir('..')

Преобразование графа networkx в формат PyTorch Geometric (`Data`)

In [3]:
def nx_to_pyg_data(G, node_characteristics, label):
    """Convert a networkx graph with per-node features into a PyG ``Data`` object.

    Nodes are renumbered 0..n-1 in ``G.nodes()`` iteration order. That numbering
    is used both for ``edge_index`` and for the rows of ``x``.

    Args:
        G: networkx graph; its ``G.edges()`` become ``edge_index``.
        node_characteristics: either a sequence of feature vectors whose i-th
            entry belongs to the i-th node of ``G.nodes()``, or a dict
            ``node -> feature vector``. A dict is reordered to match ``G.nodes()``.
        label: target value(s), stored as ``y = torch.tensor([label])``, i.e.
            with an extra leading dimension of size 1.

    Returns:
        ``Data(x=x, edge_index=edge_index, y=y)``. ``edge_index`` always has
        shape ``(2, num_edges)``, including graphs with no edges.

    Raises:
        ValueError: if the number of feature rows differs from the number of nodes.
    """
    node_mapping = {node: i for i, node in enumerate(G.nodes())}

    # Remap edge endpoints to contiguous node indices. reshape(-1, 2) keeps the
    # (num_edges, 2) layout when there are no edges, so the result is (2, 0)
    # instead of a 1-D empty tensor.
    # NOTE(review): for an undirected G, G.edges() yields each edge once (one
    # direction only); confirm this matches how the model was trained.
    edge_index = torch.tensor(
        [(node_mapping[src], node_mapping[dst]) for src, dst in G.edges()],
        dtype=torch.long,
    ).reshape(-1, 2).t().contiguous()

    # Build the feature matrix with rows in the same order as node_mapping.
    if isinstance(node_characteristics, dict):
        rows = [node_characteristics[node] for node in node_mapping]
    else:
        rows = node_characteristics
    x = torch.tensor(rows, dtype=torch.float)
    if x.dim() == 0 or x.size(0) != len(node_mapping):
        raise ValueError(
            f"expected {len(node_mapping)} feature rows (one per node), "
            f"got tensor of shape {tuple(x.shape)}"
        )

    # The extra list wrapper gives y a leading dimension of size 1.
    y = torch.tensor([label], dtype=torch.float)

    return Data(x=x, edge_index=edge_index, y=y)

Загрузка графов и компонент (с метками семейств), к которым применяется модель

In [4]:
# Per-dataset containers keyed by dataset name (e.g. 'juncea'); each starts as
# its own empty dict and is filled by the process_* loading helpers.
(
    small_components_all,
    small_components_families_all,
    largest_component_all,
    largest_component_families_all,
    G_all,
    node2name_dict_all,
    name2seq_all,
    node_freq_dict_all,
    node_family_dict_all,
) = ({} for _ in range(9))

In [5]:
# Load pickled data from a file.
def load_data_from_file(filename):
    """Unpickle and return the object stored in ``filename``.

    Any exception raised while opening or unpickling the file is reported on
    stdout and ``None`` is returned instead of propagating it.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only use on trusted files.
    result = None
    try:
        with open(filename, 'rb') as stream:
            result = pickle.load(stream)
    except Exception as exc:
        print(f"Error loading {filename}: {exc}")
        result = None
    return result

def _load_pickle_fields(filename, name, fields):
    """Load ``filename`` and store selected entries of it under key ``name``.

    Args:
        filename: path to the pickle file.
        name: key (dataset name) under which values are stored.
        fields: sequence of ``(destination_dict, key)`` pairs; for each pair
            ``destination_dict[name] = data[key]`` is performed, in order.

    Returns:
        True if the file existed, loaded to a truthy object and every field was
        stored; False if the file is missing or loading produced a falsy result
        (``load_data_from_file`` returns None on errors).

    Raises:
        KeyError: if the loaded object lacks a requested key. Nothing is stored
            in that case.
    """
    if not os.path.exists(filename):
        return False
    data = load_data_from_file(filename)
    if not data:
        # Report the failure so the caller does not assume the data is available.
        return False
    # Read every key before writing anything, so a missing key leaves no partial state.
    values = [(destination, data[key]) for destination, key in fields]
    for destination, value in values:
        destination[name] = value
    return True


def process_components(name):
    """Load the graph and per-node dictionaries for dataset ``name``.

    Reads ``data/data_{name}.pkl`` and stores its 'graph', 'node_freq_dict' and
    'node_family_dict' entries in ``G_all``, ``node_freq_dict_all`` and
    ``node_family_dict_all``.

    Returns:
        True if the file exists and was loaded, False otherwise.
    """
    data_filename = os.path.join('data', f'data_{name}.pkl')
    return _load_pickle_fields(data_filename, name, [
        (G_all, 'graph'),
        (node_freq_dict_all, 'node_freq_dict'),
        (node_family_dict_all, 'node_family_dict'),
    ])


def process_small_components(name):
    """Load the small components and their family labels for dataset ``name``.

    Reads ``data/small_components_families_{name}.pkl`` and stores its
    'families' and 'components' entries in ``small_components_families_all``
    and ``small_components_all``.

    Returns:
        True if the file exists and was loaded, False otherwise.
    """
    families_filename = os.path.join('data', f'small_components_families_{name}.pkl')
    return _load_pickle_fields(families_filename, name, [
        (small_components_families_all, 'families'),
        (small_components_all, 'components'),
    ])


def process_largest_component(name):
    """Load the clusters of the largest component and their labels for ``name``.

    Reads ``data/largest_component_clusters_{name}.pkl`` and stores its
    'clusters' and 'families' entries in ``largest_component_all`` and
    ``largest_component_families_all``.

    Returns:
        True if the file exists and was loaded, False otherwise.
    """
    clusters_filename = os.path.join('data', f'largest_component_clusters_{name}.pkl')
    return _load_pickle_fields(clusters_filename, name, [
        (largest_component_all, 'clusters'),
        (largest_component_families_all, 'families'),
    ])


def process_dicts(name):
    """Load the node-name and sequence dictionaries for dataset ``name``.

    Reads ``data/dicts_{name}.pkl`` and stores its 'node2name_dict' and
    'name2seq' entries in ``node2name_dict_all`` and ``name2seq_all``.

    Returns:
        True if the file exists and was loaded, False otherwise.
    """
    dicts_filename = os.path.join('data', f'dicts_{name}.pkl')
    return _load_pickle_fields(dicts_filename, name, [
        (node2name_dict_all, 'node2name_dict'),
        (name2seq_all, 'name2seq'),
    ])
    

In [6]:
# Load every data file for each dataset, reporting the ones that are unavailable.
names = ['juncea', 'nigra', 'rapa']
loaders = (
    (process_small_components, 'small components families'),
    (process_largest_component, 'largest component'),
    (process_components, 'all components'),
    (process_dicts, 'all dicts'),
)
for name in names:
    for loader, description in loaders:
        if not loader(name):
            print(f"File for {name} ({description}) not found.")

In [None]:
from funcs.embeddings import node_characteristics

In [8]:
# k-mers passed to node_characteristics when node features are built.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
filename = 'data/kmers.pkl'
with open(filename, mode='rb') as kmers_stream:
    random_kmers = pickle.load(kmers_stream)
    
# Family labels; a family's position in this list is its class index.
families_to_filter = ['LTR', 'Helitron', 'DNA/MuDR', 'LINE']
families_dict = dict(zip(families_to_filter, range(len(families_to_filter))))

In [10]:
# Family name -> class index; presumably matches the GNN's output class order (confirm against training code).
families_dict

{'LTR': 0, 'Helitron': 1, 'DNA/MuDR': 2, 'LINE': 3}

In [13]:
# Run the trained GNN on every component (small components + clusters of the
# largest component) of each dataset and collect one predicted class per component.

# Load the model once: it does not depend on the dataset being processed.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted model files.
model_filename = 'models_files/gnn.pkl'
with open(model_filename, 'rb') as f:
    model = pickle.load(f)

print("Модель успешно загружена.")
model.eval()

# Class index of 'LINE' (3 in families_dict); the first such component is printed.
line_class = families_dict['LINE']
predictions = {}

for name in names:
    G = G_all[name]
    components = list(small_components_all[name]) + list(largest_component_all[name])
    # Family labels aligned with `components`; not used in this cell.
    families = list(small_components_families_all[name]) + list(largest_component_families_all[name])

    train_data = []
    for component_nodes in components:
        G_sub = G.subgraph(component_nodes)
        # nx_to_pyg_data numbers nodes in G_sub.nodes() order, so features are
        # computed in that same order. Iterating `component_nodes` (a separately
        # ordered collection) could pair a feature row with the wrong node.
        node_embeddings = [
            node_characteristics(
                node, G_sub, node_freq_dict_all[name], random_kmers,
                node2name_dict_all[name], name2seq_all[name],
            )
            for node in G_sub.nodes()
        ]
        # Placeholder target: labels are not needed for inference.
        target = [0] * len(families_to_filter)
        train_data.append(nx_to_pyg_data(G_sub, node_embeddings, target))

    # shuffle=False keeps batch i aligned with components[i], so predictions[name]
    # and the printed component refer to the same graph.
    loader = DataLoader(train_data, batch_size=1, shuffle=False)

    predictions[name] = []
    line_component_shown = False
    with torch.no_grad():
        for i, batch in enumerate(loader):
            output = model(batch.x, batch.edge_index, batch.batch)
            prediction_int = output.argmax(dim=1).item()
            predictions[name].append(prediction_int)

            # Show the first component predicted as LINE, but keep predicting the
            # rest so predictions[name] covers every component.
            if prediction_int == line_class and not line_component_shown:
                print(components[i])
                line_component_shown = True

Модель успешно загружена.
{'R30342', 'R14592', 'N4246', 'R14587', 'R37769', 'R26845', 'N4058', 'R14589', 'R31237', 'R28350', 'N3552', 'R14586', 'R250', 'R29677', 'R14585', 'R14590', 'N3627', 'N5997', 'N2253', 'R14588', 'R16869', 'R37768', 'R14591', 'R4187'}
Модель успешно загружена.
{'R1059', 'N21', 'R293'}
Модель успешно загружена.
{'R1493', 'N277', 'R1769', 'R1767', 'N1141', 'N18', 'R546', 'R1768'}


In [None]:
# Print, in FASTA format, the sequences of nodes of one 'juncea' component whose
# length lies in [MIN_SEQ_LEN, MAX_SEQ_LEN).
# NOTE(review): the node set is hard-coded from a previous run of the prediction
# cell; re-running that cell may report a different component — update it then.
MIN_SEQ_LEN = 5000
MAX_SEQ_LEN = 6000
component_to_inspect = {'R30342', 'R14592', 'N4246', 'R14587', 'R37769', 'R26845', 'N4058', 'R14589', 'R31237', 'R28350', 'N3552', 'R14586', 'R250', 'R29677', 'R14585', 'R14590', 'N3627', 'N5997', 'N2253', 'R14588', 'R16869', 'R37768', 'R14591', 'R4187'}

# sorted(): set iteration order of strings varies between runs; sorting makes the output reproducible.
for node in sorted(component_to_inspect):
    node_name = node2name_dict_all['juncea'][node]
    seq = name2seq_all['juncea'][node_name]
    if MIN_SEQ_LEN <= len(seq) < MAX_SEQ_LEN:
        # Use the sequence name as the FASTA header so records are distinguishable
        # (a constant '>seq' header gives every record the same ID).
        print(f'>{node_name}')
        print(seq)
    