In [None]:
import os
import torch
# Export the installed PyTorch version as the $TORCH environment variable
# (visible to any child process / shell command) and echo it to the output.
os.environ.update(TORCH=torch.__version__)
print(os.environ['TORCH'])

In [None]:
import torch_geometric.transforms as T
from torch_geometric.datasets import OGB_MAG
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import SAGEConv, GATConv, Linear, HGTConv, HeteroConv, GCNConv, to_hetero
from tqdm.auto import tqdm

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the global PyTorch RNG so weight initialisation and loader shuffling are
# repeatable across "Restart & Run All" (CUDA kernels and the neighbour
# sampler may still introduce some nondeterminism).
SEED = 42
torch.manual_seed(SEED)

# OGB-MAG with metapath2vec features for the featureless node types; ToUndirected
# adds reverse edge types (e.g. ('paper', 'rev_writes', 'author')) used by the models.
dataset = OGB_MAG(root='./data', preprocess='metapath2vec', transform=T.ToUndirected())
data = dataset[0]
num_of_class = dataset.num_classes

In [None]:
class GNN(torch.nn.Module):
    """Two-layer GraphSAGE network written in homogeneous form.

    It is written to be converted with ``to_hetero``: the final ``['paper']``
    lookup only makes sense once ``x`` has become a dict keyed by node type.

    Args:
        hidden_channels: width of the first SAGE layer's output.
        out_channels: width of the second SAGE layer's output (one score per class).
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # (-1, -1): source/target input sizes are left for PyG to infer lazily.
        self.conv1 = SAGEConv((-1, -1), hidden_channels)
        self.conv2 = SAGEConv((-1, -1), out_channels)

    def forward(self, x, edge_index):
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        out = self.conv2(hidden, edge_index)
        return out['paper']

In [None]:
class HeteroGNN(torch.nn.Module):
    """Three-layer heterogeneous GNN over the paper/author part of OGB-MAG.

    Every layer is a ``HeteroConv`` that runs one convolution per listed edge
    type and sums the results arriving at the same destination node type.
    Only ``paper`` and ``author`` are destination types, so after the first
    layer only those two node types carry features. A ReLU follows every
    layer, and a final linear layer maps ``paper`` embeddings to outputs.

    Args:
        hidden_channels: output width of every convolution.
        out_channels: number of outputs per ``paper`` node.
    """

    def __init__(self, hidden_channels, out_channels):
        super().__init__()
        # Layers 1 and 3 use GCNConv on paper->paper citations, layer 2 uses
        # SAGEConv there. Sub-modules are created in the same order as before.
        self.conv1 = self._hetero_layer(GCNConv(-1, hidden_channels), hidden_channels)
        self.conv2 = self._hetero_layer(SAGEConv((-1, -1), hidden_channels), hidden_channels)
        self.conv3 = self._hetero_layer(GCNConv(-1, hidden_channels), hidden_channels)

        self.lin = Linear(hidden_channels, out_channels)

    @staticmethod
    def _hetero_layer(cites_conv, hidden_channels):
        """Combine ``cites_conv`` with author<->paper SAGEConvs into one sum-aggregated HeteroConv."""
        return HeteroConv({
            ('paper', 'cites', 'paper'): cites_conv,
            ('author', 'writes', 'paper'): SAGEConv((-1, -1), hidden_channels),
            ('paper', 'rev_writes', 'author'): SAGEConv((-1, -1), hidden_channels),
        }, aggr='sum')

    def forward(self, x_dict, edge_index_dict):
        for layer in (self.conv1, self.conv2, self.conv3):
            x_dict = layer(x_dict, edge_index_dict)
            x_dict = {node_type: h.relu() for node_type, h in x_dict.items()}

        return self.lin(x_dict['paper'])

In [None]:
class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer classifier for ``paper`` nodes.

    The features of each known node type are projected to ``hidden_channels``
    by a per-type linear layer followed by ReLU. The result then passes through
    two ``HGTConv`` layers, and a final linear layer is applied to the ``paper``
    embeddings.

    Args:
        hidden_channels: width after the per-type input projection.
        hidden_channels2: output width of the first HGTConv.
        hidden_channels3: output width of the second HGTConv.
        out_channels: number of outputs per ``paper`` node.
        num_heads: attention heads used by each HGTConv.
        in_channels: input feature size of every node type (default 128,
            the previously hard-coded value).
        metadata: ``(node_types, edge_types)`` graph metadata. Defaults to the
            notebook-global ``data.metadata()`` for backward compatibility.
    """

    def __init__(self, hidden_channels, hidden_channels2, hidden_channels3, out_channels, num_heads,
                 in_channels=128, metadata=None):
        super().__init__()
        if metadata is None:
            metadata = data.metadata()

        self.lin_paper = Linear(in_channels, hidden_channels)
        self.lin_author = Linear(in_channels, hidden_channels)
        self.lin_institution = Linear(in_channels, hidden_channels)
        self.lin_field_of_study = Linear(in_channels, hidden_channels)

        self.conv1 = HGTConv(hidden_channels, hidden_channels2, metadata,
                             num_heads, group='sum')
        self.conv2 = HGTConv(hidden_channels2, hidden_channels3, metadata,
                             num_heads, group='sum')

        self.lin = Linear(hidden_channels3, out_channels)

    def forward(self, x_dict, edge_index_dict):
        projections = {
            'paper': self.lin_paper,
            'author': self.lin_author,
            'institution': self.lin_institution,
            'field_of_study': self.lin_field_of_study,
        }
        # Build a fresh dict instead of overwriting entries of the caller's
        # x_dict. Node types without a projection are passed through unchanged.
        h_dict = {
            node_type: projections[node_type](x).relu_() if node_type in projections else x
            for node_type, x in x_dict.items()
        }

        h_dict = self.conv1(h_dict, edge_index_dict)
        h_dict = self.conv2(h_dict, edge_index_dict)

        return self.lin(h_dict['paper'])

In [None]:
def train():
    """Run one epoch of mini-batch training of the global ``model`` on ``train_loader``.

    Uses the notebook-level ``model``, ``optimizer``, ``criterion``,
    ``train_loader`` and ``device``. The loss of each batch is taken only over
    its first ``batch['paper'].batch_size`` output rows (the seed papers).

    Returns:
        float: mean loss per seed node over the epoch, or NaN if the loader
        yielded no batches.
    """
    model.train()

    total_examples = total_loss = 0
    for batch in tqdm(train_loader):
        # Rebind rather than rely on .to() mutating the batch in place.
        batch = batch.to(device)
        optimizer.zero_grad()
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)
        loss = criterion(out[:batch_size], batch['paper'].y[:batch_size])
        loss.backward()
        optimizer.step()

        # Weight each batch's mean loss by its size to get a per-node average.
        total_examples += batch_size
        total_loss += float(loss) * batch_size

    return total_loss / total_examples if total_examples else float('nan')

@torch.no_grad()
def test():
    """Return the accuracy of the global ``model`` over all of ``test_loader``.

    Only the first ``batch['paper'].batch_size`` rows of each batch are scored.
    NeighborLoader places the seed nodes first; the remaining rows are sampled
    neighbours used as context. Correct predictions are accumulated across
    *every* batch before dividing. Previously the function returned after the
    first batch.

    Returns:
        float: correct / total over all seed papers, or NaN if the loader
        yielded no batches.
    """
    model.eval()

    total_examples = total_correct = 0
    for batch in tqdm(test_loader):
        batch = batch.to(device)
        batch_size = batch['paper'].batch_size
        out = model(batch.x_dict, batch.edge_index_dict)

        pred = out[:batch_size].argmax(dim=1)
        total_correct += int((pred == batch['paper'].y[:batch_size]).sum())
        total_examples += batch_size

    return total_correct / total_examples if total_examples else float('nan')

In [None]:
# Mini-batch loaders seeded on "paper" nodes: for every seed, 15 neighbours
# are sampled per edge type at each of 2 hops.
train_loader = NeighborLoader(
    data,
    # Sample 15 neighbors for each node and each edge type for 2 iterations:
    num_neighbors=[15] * 2,
    # Reshuffle the training seeds every epoch so mini-batches are not identical:
    shuffle=True,
    # Use a batch size of 128 for sampling training nodes of type "paper":
    batch_size=128,
    input_nodes=('paper', data['paper'].train_mask),
)

test_loader = NeighborLoader(
    data,
    # Sample 15 neighbors for each node and each edge type for 2 iterations:
    num_neighbors=[15] * 2,
    # Use a batch size of 128 for sampling test nodes of type "paper"
    # (order is irrelevant for evaluation, so no shuffling):
    batch_size=128,
    input_nodes=('paper', data['paper'].test_mask),
)

In [None]:
# Shared by every experiment cell: the classification loss applied to the
# model outputs, and the value that controls how many epochs each cell runs.
criterion, numofep = torch.nn.CrossEntropyLoss(), 4

In [None]:
# Baseline: homogeneous GraphSAGE turned into a heterogeneous model by to_hetero.
loss_GNN = []
acc_GNN = []
model = GNN(hidden_channels=64, out_channels=num_of_class)
model = to_hetero(model, data.metadata(), aggr='sum')
# Move the model first, then run one no-grad forward pass so the lazily sized
# (-1) layers materialise their weights before the optimizer collects them.
model.to(device)
with torch.no_grad():
    init_batch = next(iter(train_loader)).to(device)
    model(init_batch.x_dict, init_batch.edge_index_dict)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)  # Define optimizer.
# range(1, numofep + 1) runs exactly `numofep` epochs, numbered from 1.
for epoch in range(1, numofep + 1):
    print(f"Epoch {epoch}")
    loss = train()
    print(f"Current loss is: {loss}")
    loss_GNN.append(loss)
    acc_GNN.append(test())

In [None]:
# Explicit HeteroConv model over paper/author relations.
loss_HeteroGNN = []
acc_HeteroGNN = []
model = HeteroGNN(hidden_channels=64, out_channels=num_of_class)
# Move the model first, then run one no-grad forward pass so the lazily sized
# (-1) convolutions materialise their weights before the optimizer collects them.
model.to(device)
with torch.no_grad():
    init_batch = next(iter(train_loader)).to(device)
    model(init_batch.x_dict, init_batch.edge_index_dict)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)  # Define optimizer.
# range(1, numofep + 1) runs exactly `numofep` epochs, numbered from 1.
for epoch in range(1, numofep + 1):
    print(f"Epoch {epoch}")
    loss = train()
    print(f"Current loss is: {loss}")
    loss_HeteroGNN.append(loss)
    acc_HeteroGNN.append(test())

In [None]:
# Heterogeneous Graph Transformer.
loss_HGT = []
acc_HGT = []
model = HGT(hidden_channels=64, hidden_channels2=32, hidden_channels3=16, out_channels=num_of_class, num_heads=2)
# Move the model before building the optimizer so it collects device-resident parameters.
model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)  # Define optimizer.
# range(1, numofep + 1) runs exactly `numofep` epochs, numbered from 1.
for epoch in range(1, numofep + 1):
    print(f"Epoch {epoch}")
    loss = train()
    print(f"Current loss is: {loss}")
    loss_HGT.append(loss)
    # Evaluate, don't train again: calling train() here recorded a loss as "accuracy".
    acc_HGT.append(test())

In [None]:
import matplotlib.pyplot as plt

# Per-epoch history of every trained model; the x-axis is the 1-based epoch number.
histories = {
    "GNN": {"loss": loss_GNN, "acc": acc_GNN},
    "HeteroGNN": {"loss": loss_HeteroGNN, "acc": acc_HeteroGNN},
    "HGT": {"loss": loss_HGT, "acc": acc_HGT},
}
ylabels = {"loss": "Mean training loss", "acc": "Test accuracy"}

# One figure per metric, saved as loss.png / acc.png. Both are left open so
# they also render inline in the notebook.
for metric in ("loss", "acc"):
    fig, ax = plt.subplots()
    for model_name, history in histories.items():
        values = history[metric]
        ax.plot(range(1, len(values) + 1), values, label=f"{metric}_{model_name}")
    ax.set(xlabel="Epoch", ylabel=ylabels[metric], title=f"{ylabels[metric]} per epoch")
    ax.legend()
    fig.savefig(f"{metric}.png")