# In [1]:
import torch.nn.functional as F
import torch.nn as nn
import torch
from torch.nn import Parameter
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing, GATConv
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros

class GAT(nn.Module):
    """Stack of ``GATConv`` layers with an optional linear input projection.

    One ``GATConv`` is created per entry of ``h_feats``. Every layer except
    the last is built with ``concat=True``, so the following layer receives
    ``h_feat * heads`` input features; the last layer is built with
    ``concat=False``.

    Args:
        in_feats: Number of input features per node.
        h_feats: Output feature size of each ``GATConv`` layer.
        heads: Number of attention heads of each layer, indexed in step
            with ``h_feats``.
        dropout: Dropout probability. Passed to every ``GATConv`` and also
            applied to the activations between layers.
        negative_slope: Passed to every ``GATConv``.
        linear_layer: If not ``None``, the output size of an ``nn.Linear``
            applied to the input before the first ``GATConv``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_feats=1,
                 h_feats=(8, 8, 1),
                 heads=(8, 8, 4),
                 dropout=0.6,
                 negative_slope=0.2,
                 linear_layer=None,
                 **kwargs):
        super().__init__()
        self.dropout = dropout
        self.layers = nn.ModuleList()

        self.linear_layer = linear_layer
        if linear_layer is not None:
            print('Applying linear')
            self.linear = nn.Linear(in_feats, linear_layer)
            # The first GATConv consumes the projected features.
            in_feats = linear_layer

        last_idx = len(h_feats) - 1
        for i, h_feat in enumerate(h_feats):
            self.layers.append(GATConv(in_feats, h_feat,
                                       heads=heads[i],
                                       dropout=dropout,
                                       negative_slope=negative_slope,
                                       concat=i != last_idx))
            # With concat=True the heads are concatenated, so the next
            # layer's input width is h_feat * heads[i].
            in_feats = h_feat * heads[i]

    def forward(self, X, A, edge_attr=None, return_alphas=False):
        """Run the network on node features ``X`` over edges ``A``.

        Args:
            X: Node feature tensor, passed to the optional linear layer and
                then to each ``GATConv``.
            A: Passed to every ``GATConv`` as its second positional
                argument (presumably the edge index; confirm with caller).
            edge_attr: Forwarded to every ``GATConv`` as ``edge_attr``.
            return_alphas: If true, also collect what each layer returns as
                its second output when called with ``return_alpha=True``.

        Returns:
            The output of the last layer, or, when ``return_alphas`` is
            true, a tuple ``(output, alphas, edge_index)`` where ``alphas``
            has one entry per layer and ``edge_index`` is the third value
            returned by the last layer.
        """
        # NOTE(review): the ``return_alpha`` keyword and the 3-tuple result
        # are not the stock PyG ``GATConv`` API (which uses
        # ``return_attention_weights``); this presumably targets a patched
        # GATConv. Confirm against the installed torch_geometric.
        if self.linear_layer is not None:
            # No activation is applied after the projection.
            X = self.linear(X)

        alphas = []
        for layer in self.layers[:-1]:
            if return_alphas:
                X, alpha, _ = layer(
                    X, A, edge_attr=edge_attr, return_alpha=True)
                alphas.append(alpha)
            else:
                X = layer(X, A, edge_attr=edge_attr)
            X = F.relu(X)
            # Tie dropout to the module's train/eval mode; F.dropout is
            # active by default regardless of model.eval().
            X = F.dropout(X, self.dropout, training=self.training)

        # The last layer gets no activation or dropout.
        if return_alphas:
            X, alpha, edge_index = self.layers[-1](
                X, A, edge_attr=edge_attr, return_alpha=True)
            alphas.append(alpha)
            return X, alphas, edge_index

        return self.layers[-1](X, A, edge_attr=edge_attr)

# Cell output: ModuleNotFoundError: No module named 'torch_geometric'