In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch_geometric.graphgym import train
from torch_geometric.nn import GATConv
from torch_geometric.data import Data
import networkx as nx
import numpy as np
from torch_geometric.data import DataLoader
import os
import pickle



In [2]:
# Work from the project root (parent of the notebook's directory) so the relative
# 'data/' paths used below resolve. The starting directory is remembered on the
# first run so re-executing this cell does not climb another directory level.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())
os.chdir(os.path.dirname(os.path.abspath(NOTEBOOK_DIR)))

Преобразование графа в нужный формат

In [3]:
def nx_to_pyg_data(G, node_characteristics, label):
    """Convert a NetworkX graph into a PyTorch Geometric ``Data`` object.

    Args:
        G: NetworkX graph with arbitrary hashable node labels. Nodes are
            relabelled to contiguous local indices ``0..n-1`` in ``G.nodes()``
            order, because PyG requires ``edge_index`` entries to index rows of
            ``x`` (a subgraph of a larger graph keeps its global ids otherwise).
        node_characteristics: array-like stored as ``x``. Assumed to have one
            row per node in ``G.nodes()`` order — TODO confirm against how the
            feature matrix was built.
        label: graph label, stored as ``y = torch.tensor([label])``.

    Returns:
        ``Data(x=..., edge_index=..., y=...)`` with ``x`` as float32 and
        ``edge_index`` of shape ``(2, num_edges)`` (``(2, 0)`` when ``G`` has
        no edges). Undirected graphs get both edge directions.
    """
    # Map arbitrary node labels (e.g. global ids of a subgraph) to local 0..n-1.
    local_index = {node: idx for idx, node in enumerate(G.nodes())}
    edges = [(local_index[u], local_index[v]) for u, v in G.edges()]
    if not G.is_directed():
        # PyG represents an undirected edge as two directed edges.
        edges += [(v, u) for u, v in edges if u != v]
    # reshape(-1, 2) keeps the (2, E) layout even when the edge list is empty.
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).t().contiguous()

    # float32 matches the default dtype of torch.nn / PyG layer weights; float64
    # input (e.g. a numpy array) would otherwise raise a dtype mismatch.
    x = torch.as_tensor(node_characteristics, dtype=torch.float)
    y = torch.tensor([label])

    return Data(x=x, edge_index=edge_index, y=y)

Графы для обучения

In [4]:
# Per-dataset results keyed by dataset name; filled by the process_* loaders.
(
    small_components_all,
    small_components_families_all,
    largest_component_all,
    largest_component_families_all,
) = (dict() for _ in range(4))

In [5]:
# Load a pickled object from disk.
def load_data_from_file(filename):
    """Unpickle and return the object stored in ``filename``.

    Any exception (missing file, corrupt pickle, ...) is printed and ``None``
    is returned instead of raising.

    NOTE(security): ``pickle.load`` can execute arbitrary code; only use this on
    trusted files.
    """
    try:
        with open(filename, 'rb') as handle:
            loaded = pickle.load(handle)
    except Exception as err:
        print(f"Error loading {filename}: {err}")
        return None
    return loaded

# Load the family labels and node lists of the small connected components for `name`.
def process_small_components(name, data_dir='data'):
    """Load small-component data for dataset ``name`` into the module-level dicts.

    Reads ``<data_dir>/small_components_families_<name>.pkl`` (expected keys:
    ``'families'`` and ``'components'``) and stores them in
    ``small_components_families_all[name]`` and ``small_components_all[name]``.

    Args:
        name: dataset name used in the file name and as the dict key.
        data_dir: directory holding the pickle (default ``'data'``).

    Returns:
        True if the file existed and its data was stored; False if the file is
        missing, could not be loaded, or held no data.
    """
    filename = os.path.join(data_dir, f'small_components_families_{name}.pkl')
    if not os.path.exists(filename):
        return False
    data = load_data_from_file(filename)
    if not data:
        # Load failed (error already printed by load_data_from_file) or the file was
        # empty: report failure instead of True so callers don't later hit a KeyError.
        return False
    small_components_families_all[name] = data['families']
    small_components_all[name] = data['components']
    return True

# Load the clusters and family labels of the largest connected component for `name`.
def process_largest_component(name, data_dir='data'):
    """Load largest-component cluster data for dataset ``name`` into the module-level dicts.

    Reads ``<data_dir>/largest_component_clusters_<name>.pkl`` (expected keys:
    ``'clusters'`` and ``'families'``) and stores them in
    ``largest_component_all[name]`` and ``largest_component_families_all[name]``.

    Args:
        name: dataset name used in the file name and as the dict key.
        data_dir: directory holding the pickle (default ``'data'``).

    Returns:
        True if the file existed and its data was stored; False if the file is
        missing, could not be loaded, or held no data.
    """
    filename = os.path.join(data_dir, f'largest_component_clusters_{name}.pkl')
    if not os.path.exists(filename):
        return False
    data = load_data_from_file(filename)
    if not data:
        # Load failed (error already printed by load_data_from_file) or the file was
        # empty: report failure instead of True so callers don't later hit a KeyError.
        return False
    largest_component_all[name] = data['clusters']
    largest_component_families_all[name] = data['families']
    return True
    

In [6]:
# Load the pickled component data for the selected dataset into the module-level dicts.
name = 'arab'
for loader_fn, description in (
    (process_small_components, 'small components families'),
    (process_largest_component, 'largest component'),
):
    if not loader_fn(name):
        print(f"File for {name} ({description}) not found.")

In [7]:
# Load the precomputed characteristics X and labels y from one shared pickle.
# X and y stay None if the file is missing or cannot be read.
# NOTE(security): pickle.load can execute arbitrary code; only load trusted files.
pickle_filename = os.path.join('data', 'X_y_characteristics.pkl')
X = None
y = None
if os.path.exists(pickle_filename):
    try:
        with open(pickle_filename, 'rb') as f:
            data = pickle.load(f)
        X = data['X']
        y = data['y']
    except Exception as e:
        # Report the file actually being loaded (it is not specific to `name`).
        print(f"Error loading {pickle_filename}: {e}")
else:
    print(f"File {pickle_filename} not found.")

In [8]:
# Load the collapsed graph as an undirected edge list: one "node1 node2" pair per
# line; node labels are kept as strings.
G = nx.Graph()

# Same per-dataset naming as the other loaders (name='arab' -> data_arab/).
graph_path = os.path.join(f'data_{name}', 'graph_collapse.txt')
with open(graph_path, 'r') as file:
    for line in file:
        fields = line.split()
        if not fields:
            continue  # tolerate blank lines instead of failing on unpacking
        node1, node2 = fields
        G.add_edge(node1, node2)

In [9]:
# Integer-relabelled copy of G: node k is the k-th node of G.nodes().
# relabel_nodes builds a graph with only the integer labels; the previous
# G.copy() + add_* approach also kept every original string node and edge,
# doubling the graph.
vertex_map = {old_vertex: idx for idx, old_vertex in enumerate(G.nodes())}
G_new = nx.relabel_nodes(G, vertex_map, copy=True)

In [10]:
# Family labels kept for classification; a family's position in the list is its class index.
families_to_filter = ['LTR', 'Helitron', 'DNA/MuDR', 'LINE']
families_dict = dict(zip(families_to_filter, range(len(families_to_filter))))

In [11]:
# Build one PyG graph per component whose family is in `families_to_filter`.
# NOTE(review): assumes X[i] holds the features of components[i], i.e. X follows the
# same small-then-largest component order — confirm against how X was built.
assert X is not None, "X_y_characteristics.pkl was not loaded (see the X/y cell)"

train_data = []

components = list(small_components_all[name]) + list(largest_component_all[name])
families = list(small_components_families_all[name]) + list(largest_component_families_all[name])

for i, component_nodes in enumerate(components):
    family = families[i]
    if family not in families_to_filter:
        continue
    # Subgraph of the integer-relabelled graph G_new for this component's nodes.
    G_sub_new = G_new.subgraph([vertex_map[node] for node in component_nodes])

    # Integer class index: nn.CrossEntropyLoss takes class indices, whereas a
    # one-hot *long* tensor is not a valid target for it.
    target = families_dict[family]

    graph_embedding = nx_to_pyg_data(G_sub_new, X[i], target)
    train_data.append(graph_embedding)

In [12]:
from torch_geometric.loader import DataLoader

BATCH_SIZE = 16
SHUFFLE_SEED = 0

# Mini-batch loader over the per-component graphs. The seeded generator makes the
# shuffle order reproducible across Restart & Run All.
loader = DataLoader(
    train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Обучение

In [13]:
from models.gnn import GNN

In [None]:
# Graph-classification training with SGD and cross-entropy loss.
torch.manual_seed(0)  # reproducible weight initialisation
NUM_EPOCHS = 100

# NOTE(review): in_channels=103 presumably equals the feature width of X — confirm.
model = GNN(in_channels=103, hidden_channels=16, out_channels=4)
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Loss on raw logits; class-index targets are used below.
criterion = nn.CrossEntropyLoss()

for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    # A PyG DataLoader yields one `Batch` object per step, not a tuple, so it
    # cannot be unpacked into (x, edge_index, y, ...).
    for batch in loader:
        optimizer.zero_grad()

        # Forward each graph separately (as originally intended), keeping the
        # model(x, edge_index) call signature. NOTE(review): if GNN accepts a
        # `batch` vector, one batched forward pass would be much faster.
        batch_loss = 0
        for graph in batch.to_data_list():
            target = graph.y
            if target.dim() > 1:
                # One-hot labels -> class indices, as CrossEntropyLoss expects here.
                target = target.argmax(dim=-1)
            out = model(graph.x, graph.edge_index)
            # assumes `out` is (1, num_classes) per graph — TODO confirm in models/gnn.py
            batch_loss = batch_loss + criterion(out, target)

        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()

    print(f"Epoch {epoch+1}, Loss: {total_loss / len(loader)}")