In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import networkx as nx
import matplotlib.pyplot as plt
import torch.optim as optim
from scipy import sparse as sp
import random
from graphviz import Graph
import pickle
import numpy as np
import heapq

device = torch.device('cuda')

In [2]:
# Encoder
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self, in_features, out_features, n_heads, is_concat = True, dropout = 0.6, leacky_relu_negative_slope = 0.2):
        super(GraphAttentionLayer, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.is_concat = is_concat
        self.n_heads = n_heads

        if is_concat:
            assert out_features % n_heads == 0

            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias = False)

        self.attn = nn.Linear(self.n_hidden * 2, 1, bias = False)
        self.activation = nn.LeakyReLU(negative_slope = leacky_relu_negative_slope)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout) 

    def forward(self, x, adj):
        n_nodes = x.shape[0]
        g=self.linear(x).view(n_nodes, self.n_heads, self.n_hidden)
        g_repeat = g.repeat(n_nodes, 1,1)
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim = -1)
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        e = self.activation(self.attn(g_concat))
        e = e.squeeze(-1)
        adj = adj.unsqueeze(2)
        adj = adj.repeat(1, 1, self.n_heads)
        assert adj.shape[0] == 1 or adj.shape[0] == n_nodes
        assert adj.shape[1] == 1 or adj.shape[1] == n_nodes
        assert adj.shape[2] == 1 or adj.shape[2] == self.n_heads
        e=e.masked_fill(adj == 0, 1)
        a = self.softmax(e)
        a = self.dropout(a)
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            attn_res = attn_res.mean(dim=1)
            return attn_res


In [3]:
def is_connected(curr_node, neighbor):
    return adj_matrix[curr_node][neighbor] > 0

In [32]:
# Decoder
class Decoder(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(Decoder, self).__init__()
        self.n_heads = n_heads
        self.hidden_features = hidden_features
        self.d_h = d_h
        self.in_features = in_features

        self.phi1 = torch.nn.Linear(d_h, 1)
        self.phi2 = torch.nn.Linear(d_h, 1)
        self.softmax = nn.Softmax(dim=0)
        self.C = torch.nn.Parameter(torch.randn(1)) # constant C
        self.activation = nn.Tanh()
    def forward(self, adj_matrix, node_weights, start_node):
        # 그래프의 노드 수
        num_nodes = node_weights.size(0)

        # 시작 노드로부터의 최단 거리와 이전 노드를 저장할 배열
        distances = torch.zeros((num_nodes,))
        distances[start_node] = 0

        # 이전 노드를 저장할 배열
        prev_nodes = [None] * num_nodes

        # adjacency matrix와 node_weights를 사용하여 가중치 행렬 계산
        attn = torch.diag(node_weights.squeeze(1))
        weight = torch.matmul(attn, adj_matrix)

        # parameter vectors phi1, phi2
        phi1 = self.phi1.state_dict()['weight']
        phi2 = self.phi2.state_dict()['weight']

        # 다익스트라 알고리즘을 위한 우선순위 큐
        pq = [(0, start_node)]

        while pq:
            curr_dist, curr_node = heapq.heappop(pq)

            # 현재 노드로부터의 거리가 더 긴 경우 건너뛰기
            if curr_dist > distances[curr_node]:
                continue

            # 현재 노드와 연결된 모든 노드들에 대해 최단 거리 갱신
            for neighbor in range(num_nodes):
                curr_weight = weight[curr_node][neighbor]
                if curr_weight > 0:
                    v_i = node_weights[curr_node]
                    v_j = node_weights[neighbor]
                    phi1_v_i = torch.matmul(v_i, phi1)
                    phi2_v_j = torch.matmul(v_j, phi2)
                    new_weight = torch.matmul(phi1_v_i, phi2_v_j) / (self.d_h ** 0.5)
                    new_weight = self.C * self.activation(new_weight)
                    masked_new_weight = torch.where(v_j == 0, float('-inf'), new_weight)
                    softmax_new_weight = self.softmax(masked_new_weight)

                    new_dist = curr_dist + softmax_new_weight

                    # 최단 거리 갱신
                    if new_dist < distances[neighbor]:
                        distances[neighbor] = new_dist
                        prev_nodes[neighbor] = curr_node  # 이전 노드 저장
                        heapq.heappush(pq, (new_dist, neighbor))

        # 시작 노드로부터의 최단 거리와 이전 노드를 반환
        return distances, prev_nodes, weight


In [33]:
# import heapq

# def find_max_weight_path(adj_matrix, node_weights, start_node):
#     num_nodes = len(node_weights)
#     distances = [float('-inf')] * (num_nodes + 1)
#     distances[start_node] = node_weights[start_node]  # 시작 노드의 거리를 0으로 설정
#     prev_nodes = [None] * (num_nodes + 1)

#     pq = [(-node_weights[start_node], start_node)]

#     while pq:
#         curr_weight, curr_node = heapq.heappop(pq)

#         if curr_weight < distances[curr_node]:
#             continue

#         for neighbor in range(num_nodes):
#             edge_weight = adj_matrix[curr_node][neighbor]

#             if edge_weight > 0:
#                 new_weight = curr_weight + node_weights[neighbor]

#                 if new_weight > distances[neighbor]:
#                     distances[neighbor] = new_weight
#                     prev_nodes[neighbor] = curr_node
#                     heapq.heappush(pq, (new_weight, neighbor))

#     max_weight = distances[start_node]  # 시작 노드의 거리를 최대 가중치로 설정
#     max_weight_index = start_node

#     for i in range(num_nodes + 1):
#         if distances[i] > max_weight:
#             max_weight = distances[i]
#             max_weight_index = i

#     if max_weight_index != start_node:
#         path = []
#         node = max_weight_index

#         while node is not None:
#             path.append(node)
#             node = prev_nodes[node]

#         path.reverse()

#         return max_weight, path

#     return None, None


In [34]:
# import heapq
# import torch

# def dijkstra(graph_data, start_node):
#     num_nodes = graph_data.size(0)  # 노드의 개수

#     # 초기화
#     distance = torch.zeros((num_nodes,))
#     visited = torch.zeros((num_nodes,), dtype=torch.bool)
#     previous = torch.zeros((num_nodes,), dtype=torch.long) - 1

#     # 시작 노드 설정
#     distance[start_node] = 0
#     queue = [(0, start_node)]

#     while queue:
#         # 가장 최단 거리를 가진 노드 선택
#         dist, current_node = heapq.heappop(queue)

#         # 이미 방문한 노드인 경우 건너뛰기
#         if visited[current_node]:
#             continue

#         visited[current_node] = True

#         # 현재 노드와 연결된 노드들을 탐색
#         for node in range(num_nodes):
#             weight = graph_data[current_node, node]
#             if weight > 0:
#                 new_distance = distance[current_node] + weight

#                 # 더 짧은 거리를 발견한 경우 업데이트
#                 if new_distance < distance[node]:
#                     distance[node] = new_distance
#                     previous[node] = current_node
#                     heapq.heappush(queue, (new_distance, node))

#     return distance, previous


# def get_shortest_path(start_node, end_node, previous):
#     path = []
#     current_node = end_node
#     while current_node != -1:
#         path.insert(0, current_node)
#         current_node = previous[current_node]

#     return path


# def find_shortest_path(graph_data, start_node):
#     distances, previous = dijkstra(graph_data, start_node)

#     # 최단 거리 중 가장 작은 값을 가지는 경로 탐색
#     min_distance = float('inf')
#     shortest_path = []
#     for node in range(len(graph_data)):
#         if node != start_node:
#             path = get_shortest_path(start_node, node, previous)
#             distance = distances[node]
#             if distance < min_distance:
#                 min_distance = distance
#                 shortest_path = path

#     return shortest_path


# # 사용 예시
# graph_data = torch.tensor([[0.0, 2.0, 0.0, 1.0],
#                            [2.0, 0.0, 0.0, 0.0],
#                            [0.0, 0.0, 0.0, 0.0],
#                            [1.0, 0.0, 0.0, 0.0]], dtype=torch.float32)

# start_node = 0

# shortest_path = find_shortest_path(graph_data, start_node)

# print("시작 노드로부터의 최단 거리:")
# for node, distance in enumerate(distances):
#     print(f"노드 {node}: {distance}")

# print("시작 노드로부터의 최단 경로:")
# for node in range(len(graph_data)):
#     if node != start_node:
#         path = get_shortest_path(start_node, node, previous)
#         print(f"노드 {node}: {path}")

# print("최단 거리가 가장 작은 경로:")
# print(shortest_path)


In [35]:
class GAT(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(GAT, self).__init__()
        self.n_heads = n_heads
        self.attention1 = GraphAttentionLayer(in_features, hidden_features, n_heads, is_concat = True, dropout = dropout)
        self.attention2 = GraphAttentionLayer(hidden_features, out_features, 1, is_concat = False, dropout = dropout)
        self.norm= nn.LayerNorm(out_features)
        self.decoder = Decoder(in_features, hidden_features, out_features, n_heads, d_h)
    
    def encode(self, x, adj):
        x = self.attention1(x, adj)
        x = self.attention2(x, adj)
        x = self.norm(x)
        output = F.softmax(x, dim=0)
        output = torch.mean(output, dim=1)
        output = output.unsqueeze(0)
        output = output.transpose(0,1)
        return output
    
    def decode(self, adj_matrix, node_weights, start_node):
        return self.decoder(adj_matrix, node_weights, start_node)

In [36]:
in_features =  1
n_heads = 4

def generate_random_weighted_graph(num_nodes, num_edges, max_weight=10):
    # 방향 그래프 생성
    graph = nx.Graph()
    
    # 노드 추가
    nodes = range(num_nodes)
    graph.add_nodes_from(nodes)
    
    # 노드에 가중치 할당 및 노드 특징 벡터 생성
    x = torch.zeros(num_nodes, in_features)
    for node in graph.nodes:
        weight = random.randint(1, max_weight)
        graph.nodes[node]['weight'] = weight
        x[node] = weight

    # 간선 추가
    edges = []
    for i in range(num_edges):
        # 임의의 출발 노드와 도착 노드 선택
        source = random.choice(nodes)
        target = random.choice(nodes)
        
        # 출발 노드와 도착 노드가 같은 경우 건너뜀
        if source == target:
            continue
        
        # 간선 추가
        edges.append((source, target))

    graph.add_edges_from(edges)

    graph_original = graph

        # Generate v_prev tensor
    j = random.randint(0, num_nodes-1)

    adj_matrix = nx.adjacency_matrix(graph)
    adj_matrix_original = adj_matrix
    adj_matrix_original = torch.Tensor(adj_matrix_original.todense())
    adj_matrix = adj_matrix + sp.eye(adj_matrix.shape[0]) # Add self-loop
    adj_tensor = torch.Tensor(adj_matrix.todense())

    # adj_tensor = adj_tensor.unsqueeze(2) # adj_tensor (num_nodes, num_nodes, n_heads)
    # adj_tensor = adj_tensor.repeat(1, 1, n_heads) #
    
    return graph, x, adj_tensor, adj_matrix_original, graph_original, j

In [37]:
num_graphs = 100
output_file = 'random_undirected_graphs.pkl'

graphs = []

for _ in range(num_graphs):
    num_nodes, num_edges, max_weight = np.random.randint(2,20), np.random.randint(10,30), np.random.randint(1,30)
    graph, x, adj_tensor, adj_matrix_original, graph_original, j= generate_random_weighted_graph(num_nodes, num_edges, max_weight)
    graphs.append((x, adj_tensor, j, adj_matrix_original))


# 그래프를 pickle 파일로 저장
with open(output_file, 'wb') as f:
    pickle.dump(graphs, f)

  adj_matrix = nx.adjacency_matrix(graph)


In [38]:
# pickle 파일에서 그래프 데이터 로드
with open('random_undirected_graphs.pkl', 'rb') as f:
    graphs = pickle.load(f)

In [42]:
hidden_features = 4 * n_heads
out_features = n_heads
d_h = 4 * n_heads
dropout = 0.6
gat_model = GAT(in_features, hidden_features, out_features, n_heads, d_h).cuda()
gat_models = []
results = {}
for graph_idx, (x, adj_tensor, j, adj_matrix_original) in enumerate(graphs):
    gat_models.append(gat_model)
    x = x.cuda()
    adj_tensor = adj_tensor.cuda()
    output = gat_model.encode(x, adj_tensor)
    print(f"Graph {graph_idx+1} - Output:")
    print(output.shape)
    start_node = 0
    distance, permutation,weight = gat_model.decode(adj_matrix_original.cuda(), output, start_node)
    print(f"Graph {graph_idx+1} - Distance:")
    print(distance)
    print(permutation)

Graph 1 - Output:
torch.Size([18, 1])
Graph 1 - Distance:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
Graph 2 - Output:
torch.Size([18, 1])
Graph 2 - Distance:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
Graph 3 - Output:
torch.Size([3, 1])
Graph 3 - Distance:
tensor([0., 0., 0.])
[None, None, None]
Graph 4 - Output:
torch.Size([17, 1])
Graph 4 - Distance:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
[None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None]
Graph 5 - Output:
torch.Size([10, 1])
Graph 5 - Distance:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
[None, None, None, None, None, None, None, None, None, None]
G

In [None]:
gat_model = GAT(in_features, hidden_features, out_features, n_heads, d_h).cuda()
gat_models = []
results = {}
for graph_idx, (x, adj_tensor, j, adj_matrix_original) in enumerate(graphs):
    gat_models.append(gat_model)
    x = x.cuda()
    adj_tensor = adj_tensor.cuda()
    output = gat_model.encode(x, adj_tensor)
    print(f"Graph {graph_idx+1} - Output:")
    print(output.shape)
    print(output)
    start_node = 0
    print(adj_tensor)
    max_weight, path = find_max_weight_path(adj_matrix_original, output, start_node)
    print("Max Weight:", max_weight)
    print("Path:", path)
    


참고
https://chioni.github.io/posts/gat/

In [None]:
# Decoder
class Decoder(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(Decoder, self).__init__()
        self.n_heads = n_heads
        self.hidden_features = hidden_features
        self.d_h = d_h
        self.in_features = in_features

        self.phi1 = torch.nn.Linear(d_h, 1)
        self.phi2 = torch.nn.Linear(d_h, 1)
        self.softmax = nn.Softmax(dim=1)
        self.C = torch.nn.Parameter(torch.randn(1)) # constant C
        self.activation = nn.Tanh()
    def forward(self, x, adj_matrix_original):
        selected_nodes = []  # SSP를 저장할 리스트
        
        # 첫 번째 노드 선택
        start_node = 0
        selected_nodes.append(start_node)
        current_node = start_node

        # 두 번째 이후의 노드 선택
        while True:
            max_value = float('-inf')
            next_node = None

            for neighbor in v_j:
                v_i = output[current_node,0]
                v_j = adj_matrix_original[-1,current_node].cuda() * x
                
            v_i = v_i + next_node  # v_i + v_2의 특성 값을 계산
            v_i.unsqueeze(0)
            phi1_v_i = torch.matmul(v_i, self.phi1.state_dict()['weight'])  # 갱신된 phi1_v_i 계산
            phi2_v_j = torch.matmul(v_j, self.phi2.state_dict()['weight'])

            attn_input = torch.matmul(phi1_v_i, phi2_v_j.transpose(0, 1)) / (self.d_h ** 0.5)
            attn_output = self.C * self.activation(attn_input)
            masked_attn_output = torch.where(v_j == 0, float('-inf'), attn_output)
            masked_attn_output = masked_attn_output[0]
            masked_attn_output = masked_attn_output.unsqueeze(0)

            attn_weights = self.softmax(masked_attn_output)

            # 다음 노드 선택
            next_node_index = torch.argmax(attn_weights)  # attn_weights에서 최대값의 인덱스를 찾음
            next_node = v_j[next_node_index]  # 해당 인덱스에 해당하는 노드를 선택

            # SSP 종료 조건: v_i와 next_node를 이어주는 엣지가 없을 경우
            if next_node not in v_j:
                break

            selected_nodes.append(next_node)  # 선택한 노드를 selected_nodes에 추가

        return selected_nodes