In [1]:
!pip install torch_geometric



In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GATConv

In [3]:
import pandas as pd
node_info = pd.read_csv('node_info_sorted.csv')
edge_info = pd.read_csv('edge_info_sorted.csv')

In [4]:
title_to_idx = {title: idx for idx, title in enumerate(node_info['name'].unique())}

In [5]:
edges = edge_info[['source_name', 'target_name']].values
sources = [title_to_idx[source] for source, _ in edges]
targets = [title_to_idx[target] for _, target in edges]

In [6]:
num_nodes = len(node_info['name'].unique())
x = torch.eye(num_nodes)
edge_index = torch.tensor([sources, targets], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)

In [7]:
print(num_nodes)

5047


In [8]:
class GATModel(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, num_heads):
        super(GATModel, self).__init__()
        self.conv1 = GATConv(num_features, hidden_channels, num_heads)
        self.conv2 = GATConv(hidden_channels * num_heads, hidden_channels, num_heads)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = self.conv2(x, edge_index)
        return x

In [9]:
model = GATModel(num_features=num_nodes, hidden_channels=64, num_heads=4)

In [10]:
loader = DataLoader([data], batch_size=1)
embeddings = model(data.x, data.edge_index)



In [11]:
print(embeddings)

tensor([[-0.0043,  0.0002,  0.0002,  ...,  0.0023,  0.0079, -0.0090],
        [ 0.0106, -0.0009,  0.0322,  ..., -0.0097,  0.0151, -0.0143],
        [-0.0053, -0.0015,  0.0037,  ..., -0.0076,  0.0212,  0.0058],
        ...,
        [-0.0027, -0.0142, -0.0374,  ..., -0.0043, -0.0091, -0.0180],
        [-0.0178,  0.0121,  0.0063,  ...,  0.0026, -0.0062,  0.0180],
        [-0.0092,  0.0019, -0.0008,  ...,  0.0028, -0.0095,  0.0072]],
       grad_fn=<AddBackward0>)


In [12]:
import pickle

# Get the list of node IDs (assuming 'node_info' dataframe has a 'name' column)
node_ids = node_info['name'].tolist()

# Combine the embeddings and IDs into a dictionary
embeddings_data = {'embeddings': embeddings.detach().numpy(), 'ids': node_ids}

# Save the dictionary to a pickle file
with open('node_embeddings.pkl', 'wb') as f:
    pickle.dump(embeddings_data, f)