In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import networkx as nx
import matplotlib.pyplot as plt
import torch.optim as optim
from scipy import sparse as sp
import random
from graphviz import Graph
import pickle
import numpy as np
import heapq

device = torch.device('cuda')

In [2]:
# Encoder
class GraphAttentionLayer(torch.nn.Module):
    def __init__(self, in_features, out_features, n_heads, is_concat = True, dropout = 0.6, leacky_relu_negative_slope = 0.2):
        super(GraphAttentionLayer, self).__init__()
        self.W = torch.nn.Parameter(torch.randn(in_features, out_features))
        self.is_concat = is_concat
        self.n_heads = n_heads

        if is_concat:
            assert out_features % n_heads == 0

            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias = False)

        self.attn = nn.Linear(self.n_hidden * 2, 1, bias = False)
        self.activation = nn.LeakyReLU(negative_slope = leacky_relu_negative_slope)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout) 

    def forward(self, x, adj):
        n_nodes = x.shape[0]
        g=self.linear(x).view(n_nodes, self.n_heads, self.n_hidden)
        g_repeat = g.repeat(n_nodes, 1,1)
        g_repeat_interleave = g.repeat_interleave(n_nodes, dim=0)
        g_concat = torch.cat([g_repeat_interleave, g_repeat], dim = -1)
        g_concat = g_concat.view(n_nodes, n_nodes, self.n_heads, 2 * self.n_hidden)
        e = self.activation(self.attn(g_concat))
        e = e.squeeze(-1)
        adj = adj.repeat(1, 1, self.n_heads)
        assert adj.shape[0] == 1 or adj.shape[0] == n_nodes
        assert adj.shape[1] == 1 or adj.shape[1] == n_nodes
        assert adj.shape[2] == 1 or adj.shape[2] == self.n_heads
        e=e.masked_fill(adj == 0, 1)
        a = self.softmax(e)
        a = self.dropout(a)
        attn_res = torch.einsum('ijh,jhf->ihf', a, g)
        if self.is_concat:
            return attn_res.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            attn_res = attn_res.mean(dim=1)
            return attn_res

In [3]:
# Decoder
class Decoder(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(Decoder, self).__init__()
        self.n_heads = n_heads
        self.hidden_features = hidden_features
        self.d_h = d_h
        self.in_features = in_features

        self.phi1 = torch.nn.Linear(d_h, 1)
        self.phi2 = torch.nn.Linear(d_h, 1)
        self.softmax = nn.Softmax(dim=1)
        self.C = torch.nn.Parameter(torch.randn(1)) # constant C
        self.activation = nn.Tanh()

    def forward(self, output, next_node, adj_matrix_original, next_nodes):
        v_i = output[next_node, 0]
        v_j = adj_matrix_original[next_node, :]
        v_j = v_j.unsqueeze(1).cuda() * output
        
        v_i = v_i.unsqueeze(0)
        phi1_v_i = torch.matmul(v_i, self.phi1.state_dict()['weight'])

        phi2_v_j = torch.matmul(v_j, self.phi2.state_dict()['weight'])

        attn_input = torch.matmul(phi1_v_i, phi2_v_j.transpose(0, 1)) / (self.d_h ** 0.5)

        attn_output = self.C * self.activation(attn_input)
        #???????????????
        # adj_matrix_original[row_indices, :] = 1
        # adj_matrix_original[np.logical_not(np.isin(np.arange(adj_matrix_original.shape[0]), row_indices)), :] = 0
        # print(adj_matrix_original[next_node,:])
        adj_matrix_original = torch.where(torch.tensor(np.isin(torch.arange(adj_matrix_original.shape[1]), next_nodes)), torch.tensor(0), adj_matrix_original)
        # print("adj_matrix_origianl[next_node,:] :", adj_matrix_original.to(device)[next_node,:])

        attn_output = attn_output * adj_matrix_original.to(device)[next_node, :]
        attn_output = attn_output.unsqueeze(0)

        attn_weights = self.softmax(attn_output)
        attn_weights *= (attn_output != 0).float()
        # print(attn_weights)

        p= torch.max(attn_weights)
        next_node = torch.argmax(attn_weights)

        return next_node, p


In [4]:
class GAT(torch.nn.Module):
    def __init__(self, in_features, hidden_features, out_features, n_heads, d_h):
        super(GAT, self).__init__()
        self.n_heads = n_heads
        self.attention1 = GraphAttentionLayer(in_features, hidden_features, n_heads, is_concat = True, dropout = dropout)
        self.attention2 = GraphAttentionLayer(hidden_features, out_features, 1, is_concat = False, dropout = dropout)
        self.norm= nn.LayerNorm(out_features)
        self.decoder = Decoder(in_features, hidden_features, out_features, n_heads, d_h)
    
    def forward(self, x, adj):
        x = self.attention1(x, adj)
        x = self.attention2(x, adj)
        x = self.norm(x)
        output = F.softmax(x, dim=0)
        output = torch.mean(output, dim=1)
        output = output.unsqueeze(0)
        output = output.transpose(0,1)
        return output
    
    def decode(self, output, i, adj_matrix_original, next_nodes):
        return self.decoder(output, i, adj_matrix_original, next_nodes)

In [10]:
in_features =  1
n_heads = 4

def generate_random_weighted_graph(num_nodes, num_edges, max_weight=10):
    # 방향 그래프 생성
    graph = nx.Graph()
    
    # 노드 추가
    nodes = range(num_nodes)
    graph.add_nodes_from(nodes)
    
    # 노드에 가중치 할당 및 노드 특징 벡터 생성
    x = torch.zeros(num_nodes, in_features)
    for node in graph.nodes:
        weight = random.randint(1, max_weight)
        graph.nodes[node]['weight'] = weight
        x[node] = weight

    # # 최소 하나의 에지를 가지도록 보장하기 위해 모든 노드에 대해 연결되지 않은 노드를 선택하여 에지를 추가합니다.
    # connected_nodes = set()
    # for node in nodes:
    #     if node not in connected_nodes:
    #         # 현재 노드와 연결되지 않은 노드 선택
    #         unconnected_nodes = list(set(nodes) - connected_nodes - {node})
    #         if len(unconnected_nodes) == 0 :
    #             break
    #         target = random.choice(unconnected_nodes)

    #         # 에지 추가
    #         graph.add_edge(node, target)
    #         connected_nodes.add(node)
    #         connected_nodes.add(target)

    # 모든 노드를 하나의 연결된 컴포넌트로 만들기 위해 하나의 노드를 선택하여 그 노드와 나머지 노드들을 연결합니다.
    connected_nodes = set()
    start_node = random.choice(nodes)
    connected_nodes.add(start_node)
    for node in nodes:
        if node != start_node:
            graph.add_edge(start_node, node)
            connected_nodes.add(node)

    # 추가적인 랜덤 에지를 생성하여 num_edges를 충족시킵니다.
    additional_edges = num_edges - len(graph.edges)
    for _ in range(additional_edges):
        # 임의의 출발 노드와 도착 노드 선택
        source = random.choice(nodes)
        target = random.choice(nodes)

        # 출발 노드와 도착 노드가 같은 경우나 이미 연결된 경우 건너뜁니다.
        if source == target or graph.has_edge(source, target):
            continue

        # 에지 추가
        graph.add_edge(source, target)

    graph_original = graph

        # Generate v_prev tensor
    j = random.randint(0, num_nodes-1)

    adj_matrix = nx.adjacency_matrix(graph)
    adj_matrix_original = torch.Tensor(adj_matrix.todense())
    
    adj_matrix = adj_matrix + sp.eye(adj_matrix.shape[0]) # Add self-loop
    adj_tensor = torch.Tensor(adj_matrix.todense())

    adj_tensor = adj_tensor.unsqueeze(2) # adj_tensor (num_nodes, num_nodes, n_heads)
    # adj_tensor = adj_tensor.repeat(1, 1, n_heads) #

    # # Visualize the graph
    # nx.draw_networkx(graph, with_labels=True)
    # plt.show()
    
    return graph, x, adj_tensor, adj_matrix_original, graph_original, j

In [11]:
num_graphs = 100
output_file = 'random_undirected_graphs.pkl'

graphs = []

for _ in range(num_graphs):
    m = np.random.randint(2,7)
    num_nodes, num_edges, max_weight = np.random.randint(m,8), np.random.randint(m+1,9), np.random.randint(1,30)
    graph, x, adj_tensor, adj_matrix_original, graph_original, j= generate_random_weighted_graph(num_nodes, num_edges, max_weight)
    next_nodes = []
    graphs.append((x, adj_tensor, j, adj_matrix_original, next_nodes))


# 그래프를 pickle 파일로 저장
with open(output_file, 'wb') as f:
    pickle.dump(graphs, f)

  adj_matrix = nx.adjacency_matrix(graph)


In [12]:
# pickle 파일에서 그래프 데이터 로드
with open('random_undirected_graphs.pkl', 'rb') as f:
    graphs = pickle.load(f)

In [13]:
hidden_features = 4 * n_heads
out_features = n_heads
d_h = 4 * n_heads
dropout = 0.6
gat_model = GAT(in_features, hidden_features, out_features, n_heads, d_h).cuda()
gat_models = []
for graph_idx, (x, adj_tensor, j, adj_matrix_original, next_nodes) in enumerate(graphs):
    gat_models.append(gat_model)
    x = x.cuda()
    adj_tensor = adj_tensor.cuda()
    output = gat_model(x, adj_tensor)
    print(f"Graph {graph_idx+1} - Output:")
    print(output.shape)
    # print(output)
    n = output.size(0)
    next_nodes = [1] # next_node 값을 저장할 리스트
    branch_point = []
    row_indices = []
    possibilities=[]
    
    for i in range(n-1):
        next_node = next_nodes[i]
        row_indices = []
        row_indices = np.where(adj_matrix_original[:,next_node] == 1)[0]
        row_indices = np.setdiff1d(row_indices, next_nodes)
        # print("next_node", next_node)
        # print("next_nodes", next_nodes)
        # print("row_indices", row_indices)
        # print("!!", len(row_indices), type(row_indices))
        if len(row_indices) >= 2:
            branch_point.append(next_node)
            # print("branch_point has been updated")
        elif len(row_indices) == 0:
            if len(branch_point) > 0:
                next_node = branch_point.pop()
                # print("00000000000000000000",next_node)
                row_indices = np.where(np.isin(adj_matrix_original[:,next_node], next_nodes))[0]
                row_indices = np.setdiff1d(row_indices, next_nodes)
                # print("modified_row_indices:", row_indices)
                if len(row_indices) >= 2:
                    branch_point.append(next_node)
                    # print("22")
        
        next_node, p = gat_model.decode(output, next_node, adj_matrix_original, next_nodes)
        next_nodes.append(next_node.item())
        possibilities.append(p.item())
        # print("branch_point*******************************")
        # print(branch_point)

        print(f"Current Node: {next_nodes[i]}, Next Node: {next_nodes[i+1]}")
        


    print("Next Nodes:", next_nodes)
    print("")
    print("possibilities", possibilities)
    print("")
   


Graph 1 - Output:
torch.Size([5, 1])
Current Node: 1, Next Node: 4
Current Node: 4, Next Node: 0
Current Node: 0, Next Node: 2
Current Node: 2, Next Node: 3
Next Nodes: [1, 4, 0, 2, 3]

possibilities [0.19950979948043823, 0.19983011484146118, 0.19968289136886597, 0.19939309358596802]

Graph 2 - Output:
torch.Size([6, 1])
Current Node: 1, Next Node: 4
Current Node: 4, Next Node: 2
Current Node: 2, Next Node: 5
Current Node: 5, Next Node: 3
Current Node: 3, Next Node: 0
Next Nodes: [1, 4, 2, 5, 3, 0]

possibilities [0.1666344851255417, 0.16657082736492157, 0.16650520265102386, 0.16643138229846954, 0.16631805896759033]

Graph 3 - Output:
torch.Size([7, 1])
Current Node: 1, Next Node: 2
Current Node: 2, Next Node: 6
Current Node: 6, Next Node: 4
Current Node: 4, Next Node: 0
Current Node: 0, Next Node: 5
Current Node: 5, Next Node: 3
Next Nodes: [1, 2, 6, 4, 0, 5, 3]

possibilities [0.14269335567951202, 0.14283311367034912, 0.1427982896566391, 0.14274416863918304, 0.14270778000354767, 0.14

참고
https://chioni.github.io/posts/gat/