In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics.pairwise import cosine_similarity
import dgl
from dgl.nn import GATConv

In [None]:
import os
import pandas as pd
import numpy as np

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# A previously disabled guard printed "device CPU" and exited on CPU-only
# machines; the CPU fallback is intentionally allowed here.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Root directory of the project data.
# Defaults to the original cluster location, but can be redirected without
# editing the notebook by exporting MSC_PROJECT_DATA_HOME before starting the
# kernel (an unset or empty variable falls back to the default).
DATA_HOME = os.environ.get("MSC_PROJECT_DATA_HOME") or "/lyceum/jhk1c21/msc_project/data"
# <DATA_HOME>/graph/v14 -- holds the node table read later in the notebook.
V14_PATH = os.path.join(DATA_HOME, "graph", "v14")
# <V14_PATH>/filtered -- holds the filtered ids, edges and embedding arrays.
FILTERED_PATH = os.path.join(V14_PATH, "filtered")

In [None]:
# Read every input artefact from disk.

# CSV tables: the node table (indexed by its 'id' column) and the similarity edges.
nodes = pd.read_csv(os.path.join(V14_PATH, "nodes_v14.csv"), index_col='id')
similarity = pd.read_csv(os.path.join(FILTERED_PATH, "similarity_edges.csv"))

# Per-field embedding arrays, loaded in the order title, abstract, keywords, domains.
titles, abstracts, keywords, domains = (
    np.load(os.path.join(FILTERED_PATH, embedding_file))
    for embedding_file in (
        'title_embedding.npy',
        'abstract_embedding.npy',
        'keywords_embedding.npy',
        'domains_embedding.npy',
    )
)

# Ids of the filtered nodes and the edge array that refers to them.
ids, edges = (
    np.load(os.path.join(FILTERED_PATH, array_file))
    for array_file in ("filtered_id.npy", 'filtered_edge.npy')
)

In [None]:
# Lookup tables between the original ids (strings, per the original note) and
# their integer position in `ids`, in both directions.
id_to_int = dict(zip(ids, range(len(ids))))
int_to_id = dict(zip(id_to_int.values(), id_to_int.keys()))

# Edge list as a two-column frame: column 0 of `edges` is the source endpoint,
# column 1 the destination endpoint.
df = pd.DataFrame({'src': edges[:, 0], 'dst': edges[:, 1]})

# Swap the original ids for their integer positions. An endpoint that is not
# present in `ids` raises KeyError.
df = df.assign(
    src=df['src'].apply(lambda node_id: id_to_int[node_id]),
    dst=df['dst'].apply(lambda node_id: id_to_int[node_id]),
)

In [None]:
class MultiFeatureGAT(nn.Module):
    """One GAT branch per node-feature set, fused by a single linear layer.

    Every entry of ``in_dims`` gets its own ``GATConv(in_dim, hidden_dim,
    num_heads)``. ``forward`` concatenates the branch outputs along the last
    axis, flattens everything after the first axis, and projects the result to
    ``out_dim`` with ``self.fc``.

    Args:
        in_dims: Input feature size of each branch, one entry per feature set.
        hidden_dim: Per-head output size of every GAT branch.
        out_dim: Size of the final projection.
        num_heads: Number of attention heads in every GAT branch.
    """

    def __init__(self, in_dims, hidden_dim, out_dim, num_heads):
        super().__init__()
        # Materialise once so any iterable (including a generator) can be both
        # iterated and measured below.
        in_dims = list(in_dims)
        self.gat_layers = nn.ModuleList([
            GATConv(in_dim, hidden_dim, num_heads) for in_dim in in_dims
        ])
        # Each branch contributes num_heads * hidden_dim values per node.
        self.fc = nn.Linear(hidden_dim * num_heads * len(in_dims), out_dim)

    def forward(self, graph, inputs):
        """Run every branch on its own feature tensor and fuse the results.

        Args:
            graph: Graph object handed unchanged to every GAT branch.
            inputs: Sequence of feature tensors, one per branch, in the same
                order as ``in_dims`` at construction time.

        Returns:
            Tensor whose first axis is that of the branch outputs and whose
            last axis has size ``out_dim``.

        Raises:
            ValueError: If the number of feature tensors differs from the
                number of branches. ``zip`` would otherwise drop surplus
                tensors silently.
        """
        inputs = list(inputs)
        if len(inputs) != len(self.gat_layers):
            raise ValueError(
                f"expected {len(self.gat_layers)} feature tensors, got {len(inputs)}"
            )

        hs = [gat_layer(graph, feature) for gat_layer, feature in zip(self.gat_layers, inputs)]
        # Branch outputs are (N, num_heads, hidden_dim) per the DGL GATConv
        # documentation; concatenating on the last axis keeps heads aligned.
        h = torch.cat(hs, dim=-1)
        # flatten() instead of view(N, -1): identical layout, but also valid
        # when N == 0, where view() cannot infer the -1 dimension.
        h = h.flatten(start_dim=1)
        return self.fc(h)
