In [22]:
from torch_geometric import datasets

dataset = datasets.Planetoid(root='/tmp/Cora', name='Cora')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [23]:
graph = dataset[0]

In [32]:
from torch_geometric.transforms import RandomNodeSplit

split = RandomNodeSplit(num_test=1000)
graph = split(graph)

In [34]:
graph.test_mask.sum()

tensor(1000)

In [35]:
import torch
device = torch.device("cuda:0")
device
graph = graph.to(device)

In [189]:
from torch_geometric.nn import GCNConv, GATConv
import torch.nn.functional as F

class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
#         self.conv1 = GCNConv(dataset.num_node_features, 16, normalize=True)
#         self.dropout = nn.Dropout(p=0.5)
#         self.conv2 = GCNConv(16, dataset.num_classes, normalize=True)
        self.conv1 = GATConv(dataset.num_node_features, 16, heads=5)
        self.dropout = nn.Dropout(p=0.5)
        self.conv2 = GATConv((-1,-1), dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index).relu()
        x = self.dropout(x)
        output = self.conv2(x, edge_index)

        return output

In [190]:
model = GCN().to(device)
optim = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [43]:
from torchmetrics import Accuracy
accuracy = Accuracy(task="multiclass", num_classes=graph.y.unique().shape[0]).to(device)

with torch.inference_mode():
    model.eval()
    logits = model(graph)
    preds = torch.softmax(logits, dim = 1).argmax(dim=1)
    correct = (preds[graph.test_mask] == graph.y[graph.test_mask]).sum()
    print(accuracy(preds[graph.test_mask], graph.y[graph.test_mask]))

tensor(0.8700, device='cuda:0')


In [174]:
from torch import nn

from torch_geometric.typing import (
    Adj,
    OptPairTensor,
    OptTensor,
    SparseTensor,
    torch_sparse,
)
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils import add_self_loops as add_self_loops_fn
from torch_geometric.utils import (
    is_torch_sparse_tensor,
    scatter,
    spmm,
    to_edge_index,
)
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.utils.sparse import set_sparse_value
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False,
             add_self_loops=True, flow="source_to_target", dtype=None):

    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        assert edge_index.size(0) == edge_index.size(1)

        adj_t = edge_index

        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            adj_t = torch_sparse.fill_diag(adj_t, fill_value)

        deg = torch_sparse.sum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = torch_sparse.mul(adj_t, deg_inv_sqrt.view(1, -1))

        return adj_t

    if is_torch_sparse_tensor(edge_index):
        assert edge_index.size(0) == edge_index.size(1)

        if edge_index.layout == torch.sparse_csc:
            raise NotImplementedError("Sparse CSC matrices are not yet "
                                      "supported in 'gcn_norm'")

        adj_t = edge_index
        if add_self_loops:
            adj_t, _ = add_self_loops_fn(adj_t, None, fill_value, num_nodes)

        edge_index, value = to_edge_index(adj_t)
        col, row = edge_index[0], edge_index[1]

        deg = scatter(value, col, 0, dim_size=num_nodes, reduce='sum')
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
        value = deg_inv_sqrt[row] * value * deg_inv_sqrt[col]

        return set_sparse_value(adj_t, value), None

    assert flow in ['source_to_target', 'target_to_source']
    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if add_self_loops:
        edge_index, edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, fill_value, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    row, col = edge_index[0], edge_index[1]
    idx = col if flow == 'source_to_target' else row
    deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    return edge_index, edge_weight

class GNNLayer(nn.Module):
    def __init__(self, in_features, out_features, agg = "mean"):
        super(GNNLayer, self).__init__()
        self.lin = nn.Linear(in_features, out_features)
        self.init_weights()
        self.agg:str = agg

    def lift(self, x, edge_index):
        source_lift = torch.index_select(x, 0, edge_index[0])
        target_lift = torch.index_select(x, 0, edge_index[1])
        return source_lift, target_lift
    
    def init_weights(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        self.lin.bias.data.fill_(0.01)

    def forward(self, x, edge_index):
        edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, None, x.shape[0],False, False)
        x = self.lin(x)
        source_lift, target_lift = self.lift(x, edge_index)
        x = torch.scatter_reduce(x, 0,  edge_index[1].repeat(x.shape[1],1).t(), source_lift, reduce=self.agg)
        return x 

In [208]:
graph = graph.to(device)

In [216]:
graph.x = graph.x.unsqueeze(dim=0)

In [207]:
class CustomGATLayer(nn.Module):
    def __init__(self, in_features, out_features, heads = 1, agg = "mean"):
        super(CustomGATLayer, self).__init__()
        assert out_features % heads == 0
        
        self.lin = nn.Linear(in_features, out_features)
        
        self.lin_k = nn.Linear(out_features, out_features, bias=False)
        self.lin_q = nn.Linear(out_features, out_features, bias=False)
        self.lin_v = nn.Linear(out_features, out_features, bias=False)
        self.out = out_features
        self.heads = heads
        
        self.init_weights()
        self.agg:str = agg

    def lift(self, x, edge_index):
        source_lift = torch.index_select(x, 0, edge_index[0])
        target_lift = torch.index_select(x, 0, edge_index[1])
        return source_lift, target_lift
    
    def init_weights(self):
        torch.nn.init.xavier_uniform_(self.lin.weight)
        torch.nn.init.xavier_uniform_(self.lin_k.weight)
        torch.nn.init.xavier_uniform_(self.lin_q.weight)
        torch.nn.init.xavier_uniform_(self.lin_v.weight)
        self.lin.bias.data.fill_(0.01)

    def forward(self, x, edge_index):
        b, seq_len, feature_dim = x.size()
        
        x = self.lin(x)
        
        keys = self.lin_k(x)
        queries = self.lin_q(x)
        values = self.lin_v(x)
        
        head_dim = feature_dim // self.heads

        keys    = keys.view(b, seq_len, self.heads, head_dim)
        queries = queries.view(b, seq_len, self.heads, head_dim)
        values  = values.view(b, seq_len, self.heads, head_dim)
        
        qkv = torch.cat((keys, queries, values), 2)
        print(qkv)
        x = flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False)
        print(x)
        source_lift, target_lift = self.lift(x, edge_index)
        x = torch.scatter_reduce(x, 0,  edge_index[1].repeat(x.shape[1],1).t(), source_lift, reduce=self.agg)
        return x 

In [209]:
class CustomGCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GNNLayer(dataset.num_node_features, 16).to(device)
        self.dropout = nn.Dropout(p=0.5)
        self.conv2 = GNNLayer(16, dataset.num_classes).to(device)
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = self.dropout(x)
        x = torch.relu(x)        
        output = self.conv2(x, edge_index)

        return output

In [210]:
class CustomGAT(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = CustomGATLayer(dataset.num_node_features, 16).to(device)
        self.dropout = nn.Dropout(p=0.5)
        self.conv2 = CustomGATLayer(16, dataset.num_classes).to(device)
        
    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = self.dropout(x)
        x = torch.relu(x)        
        output = self.conv2(x, edge_index)

        return output

In [217]:
from torchmetrics import Accuracy
accuracy = Accuracy(task="multiclass", num_classes=graph.y.unique().shape[0]).to(device)

custom_model = CustomGAT().to(device)#CustomGCN().to(device)
custom_optim = torch.optim.Adam(custom_model.parameters(), lr=0.01, weight_decay=5e-4)
loss_fn = nn.CrossEntropyLoss()
loss_increase = 0
prev_loss = 0

def validate(out):
    global loss_increase, prev_loss
    loss = loss_fn(out[graph.val_mask], graph.y[graph.val_mask])
    if loss.item() >= prev_loss:
        loss_increase+=1
    else:
        loss_increase = 0
    prev_loss = loss.item()
    
def train(model, optim):
    for epoch in range(200):
        if loss_increase == 10:
            print(f"Breaked at {str(epoch)}")
            break
        model.train()
        out = model(graph)
        loss = loss_fn(out[graph.train_mask], graph.y[graph.train_mask])
        optim.zero_grad()
        loss.backward()
        optim.step()

        validate(out)
        break

def test(model):
    with torch.inference_mode():
        model.eval()
        logits = model(graph)
        preds = torch.softmax(logits, dim = 1).argmax(dim=1)
        correct = (preds[graph.test_mask] == graph.y[graph.test_mask]).sum()
        print(f"{accuracy(preds[graph.test_mask], graph.y[graph.test_mask]):.2%}")

train(custom_model, custom_optim)
test(custom_model)

RuntimeError: shape '[1, 2708, 1, 1433]' is invalid for input of size 43328

In [192]:
train(model, optim)
test(model)

85.90%
