In [None]:
from ogb.nodeproppred import NodePropPredDataset
import pandas as pd
from sklearn.feature_extraction.text import TfidfTransformer, CountVectorizer
from sklearn.naive_bayes import MultinomialNB

In [None]:
# Load the ogbn-arxiv node-property-prediction dataset, rooted at ./arxiv/.
dataset = NodePropPredDataset(
    root='./arxiv/',
    name='ogbn-arxiv',
)

In [None]:
# Unpack the first dataset item into the graph and its label array.
# `label` is flattened into one value per row of the node mapping further down.
graph, label = dataset[0] # graph: library-agnostic graph object

In [None]:
# Fetch the dataset's predefined split and pull out each partition's indices.
split_idx = dataset.get_idx_split()
train_idx = split_idx["train"]
valid_idx = split_idx["valid"]
test_idx = split_idx["test"]

In [None]:
# Flatten the label array into a 1-D Series named "label" so it can be joined
# against the node mapping by row position.
nodelabels = pd.Series(
    data=label.flatten(),
    name="label",
)

In [None]:
# Node-index -> paper-id mapping. Every column is read as a string so paper ids
# compare equal to the (also string-typed) ids in the text file.
nodeid2paperid = pd.read_csv(
    "arxiv/ogbn_arxiv/mapping/nodeidx2paperid.csv",
    dtype=str,
)
nodeid2paperid

In [None]:
# Attach each row's label by aligning the mapping and the label Series on
# their row indices.
nodeid2paperid2label = nodeid2paperid.merge(
    nodelabels,
    left_index=True,
    right_index=True,
)
nodeid2paperid2label

In [None]:
# Tab-separated paper id / title / abstract table. Column names are supplied
# explicitly, so the first line of the file is read as data, not as a header.
paperid2text = pd.read_csv(
    "arxiv/titleabs.tsv",
    sep="\t",
    names=["paper id", "title", "abstract"],
    dtype=str,
)
paperid2text

In [None]:
# Attach title/abstract text to every node row.
#
# A left join keeps exactly one output row per row of `nodeid2paperid2label`, in
# the same order, so the fresh 0..n-1 index of the result still matches the row
# positions of the node mapping. Later cells select the train/valid/test rows
# with label-based `.loc[...]` on that index; an inner join would silently drop
# nodes that have no text entry and shift every following row.
nodeid2text = pd.merge(nodeid2paperid2label, paperid2text, on="paper id", how="left")

# Duplicate paper ids in the text table would add rows and break the alignment
# described above, so fail loudly instead of mis-assigning splits.
assert len(nodeid2text) == len(nodeid2paperid2label), (
    "merge changed the number of rows: duplicate paper ids in the text table"
)

# Build one document per node: "<title> <abstract>". na_rep="" keeps `text` a
# string even when a title or abstract is missing; without it str.cat yields
# NaN for that row, which the downstream vectorizer cannot tokenize.
nodeid2text = nodeid2text.assign(
    text=nodeid2text["title"].str.cat(nodeid2text["abstract"], sep=" ", na_rep="")
)

# Keep only the columns the classifier needs.
nodeid2text = nodeid2text[["node idx", "label", "text"]]
nodeid2text

In [None]:
# Slice the text table into the train / validation / test partitions by
# looking up each split's indices as row labels.
nodeid2text_train, nodeid2text_valid, nodeid2text_test = (
    nodeid2text.loc[part_idx] for part_idx in (train_idx, valid_idx, test_idx)
)

In [None]:
# Learn the vocabulary from the training text and build its term-count matrix
# in a single pass. fit_transform gives the same result as fit() followed by
# transform() on the same corpus, without tokenizing the training text twice.
vectorizer = CountVectorizer()
X_train_counts = vectorizer.fit_transform(nodeid2text_train["text"])

In [None]:
# Turn raw counts into term frequencies (IDF weighting is switched off).
tf_transformer = TfidfTransformer(use_idf=False)
tf_transformer.fit(X_train_counts)
X_train = tf_transformer.transform(X_train_counts)
X_train

In [None]:
# Fit a multinomial Naive Bayes classifier on the training term frequencies.
clf = MultinomialNB()
clf = clf.fit(X_train, nodeid2text_train["label"])

In [None]:
def _predict_labels(texts):
    """Run raw texts through the fitted count -> term-frequency pipeline and classify them."""
    counts = vectorizer.transform(texts)
    frequencies = tf_transformer.transform(counts)
    return clf.predict(frequencies)


# Predict labels for the held-out partitions with the transformers fitted on train.
valid_pred = _predict_labels(nodeid2text_valid["text"])
test_pred = _predict_labels(nodeid2text_test["text"])

In [None]:
# Validation accuracy: fraction of rows whose predicted label equals the true label.
(nodeid2text_valid["label"] == valid_pred).mean()

In [None]:
# Test accuracy: fraction of rows whose predicted label equals the true label.
(nodeid2text_test["label"] == test_pred).mean()

In [None]:
# Number of distinct class labels present in the text table.
distinct_labels = nodeid2text["label"].unique()
len(distinct_labels)