In [7]:
from pathlib import Path

from torch.utils.data import Dataset  # not the one from PyG!
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
import torch.nn as nn
from typing import List, Union
from torch import Tensor

In [8]:
class MyDataset(Dataset):
    def __init__(self, path: Path):
        super().__init__()
        self.graphs = list(path.glob("**/*.pt"))
    
    def __getitem__(self, idx):
        return torch.load(self.graphs[idx])
    
    def __len__(self) -> int:
        return len(self.graphs)


dataset = MyDataset(Path("/home/hrc/gsoc/"))

In [9]:
dataset[0]

Data(x=[272, 6], edge_index=[2, 2370], edge_attr=[2370, 4], y=[2370])

In [10]:
from torch_geometric.nn import GATConv

In [11]:
from torch_geometric.nn import MessagePassing

In [12]:
### Reproduced from somewhere on GitHub
def make_mlp(
    input_size: int,
    sizes: List,
    hidden_activation: str = "SiLU",
    output_activation: str = None,
) -> nn.Sequential:
    """Construct an MLP with specified fully-connected layers."""
    hidden_activation = getattr(nn, hidden_activation)
    if output_activation is not None:
        output_activation = getattr(nn, output_activation)
    layers = []
    n_layers = len(sizes)
    sizes = [input_size] + sizes
    
    # Hidden layers
    for i in range(n_layers - 1):
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        layers.append(hidden_activation())
        
    # Final layer
    layers.append(nn.Linear(sizes[-2], sizes[-1]))
    if output_activation is not None:
        layers.append(output_activation())
    return nn.Sequential(*layers)

In [13]:
from torch_geometric.utils import add_self_loops

In [14]:
class MyMessagePassing(MessagePassing):
    def __init__(self, num_input_node_features,num_input_edge_features,num_hidden,num_output_node_features):
        super().__init__(aggr='add')
        in_channels = 2 * num_input_node_features + num_input_edge_features
        hidden_channels = [in_channels] * num_hidden
        self.phi = make_mlp(in_channels,hidden_channels,"ReLU","Sigmoid")
        self.gamma = make_mlp(num_input_node_features + hidden_channels[-1], hidden_channels +[num_output_node_features],"ReLU","Sigmoid")
    
    def forward(self,x,edge_index,edge_attr):
        out = self.propagate(edge_index,x=x,edge_attr=edge_attr)
        return out
    
    def message(self,x_i,x_j,edge_attr):
        inp = torch.concat([x_i,x_j,edge_attr],dim=-1)
        return self.phi(inp)
    
    def update(self,aggr_out,x):
        inp = torch.concat([aggr_out,x],dim=-1)
        return self.gamma(inp)

In [15]:
class MyModel(nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
#         self.conv1 = GATConv(-1, hidden_channels)
#         self.conv2 = GATConv(-1, 6)
        self.conv1 = MyMessagePassing(6,4,3,16) # This might not be the best design for Message Passing Layer, but it is just an optional demo
        self.conv2 = MyMessagePassing(16,4,3,6)
        self.fc = nn.Sequential(
            nn.Linear(16,6),
            nn.ReLU(),
            nn.Linear(6,1)
        )
        
    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x = self.conv1(x, edge_index,edge_attr)
        x = F.relu(x)
        x = self.conv2(x, edge_index,edge_attr)
        x = F.relu(x)
        
        fc_inp = torch.cat([x[edge_index[0]],x[edge_index[1]],edge_attr],dim=-1)
        out = self.fc(fc_inp)
        out = F.sigmoid(out)
        return out

In [16]:
model = MyModel(6)

In [17]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Define the loss function
criterion = torch.nn.BCELoss()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
loader = DataLoader(dataset, batch_size=256)

def train_model():
    model.train()
    
    loss_all = 0
    for graph_data in loader:

        output = model(graph_data)        
        loss = criterion(output.reshape(-1), graph_data.y)
        loss_all += loss.item()
        print(loss.item())
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(loss_all)


In [18]:
# Train the model
for epoch in range(10):
    train_model()



0.6139814257621765
0.5935847163200378
0.5744291543960571
0.5559366345405579
0.5401993989944458
0.5223404765129089
0.5070285797119141
0.4892081916332245
0.4832633137702942
0.46790605783462524
0.4477982223033905
0.4390258193016052
0.427238404750824
0.42406028509140015
0.41388407349586487
0.4172343909740448
0.40494853258132935
0.4035087525844574
0.3976539075374603
0.396498441696167
0.38354969024658203
0.38267773389816284
0.37266334891319275
0.37485218048095703
0.37360572814941406
0.3542994558811188
0.342974454164505
0.3443525731563568
0.3403206169605255
0.3476988971233368
0.33552122116088867
0.34299296140670776
0.34122344851493835
0.3293152451515198
0.3382563889026642
0.34206247329711914
0.3217293620109558
0.3369494676589966
0.32852599024772644
0.3222833275794983
16.475583344697952
0.33540210127830505
0.3312946557998657
0.33396658301353455
0.33110126852989197
0.33217787742614746
0.32950201630592346
0.32790839672088623
0.32243672013282776
0.3336218595504761
0.3326578438282013
0.32200002670

In [19]:
# Test the model
model.eval()
with torch.no_grad():
    for graph_data in loader:
        output = model(graph_data)
        pred = output.argmax(dim=1)
        correct = pred.eq(graph_data.y).sum().item()
        accuracy = correct / graph_data.y.size(0)
        print('Accuracy: {:.4f}'.format(accuracy))

Accuracy: 0.8181
Accuracy: 0.8218
Accuracy: 0.8204
Accuracy: 0.8219
Accuracy: 0.8202
Accuracy: 0.8244
Accuracy: 0.8246
Accuracy: 0.8282
Accuracy: 0.8184
Accuracy: 0.8197
Accuracy: 0.8268
Accuracy: 0.8234
Accuracy: 0.8251
Accuracy: 0.8218
Accuracy: 0.8241
Accuracy: 0.8180
Accuracy: 0.8238
Accuracy: 0.8215
Accuracy: 0.8226
Accuracy: 0.8206
Accuracy: 0.8259
Accuracy: 0.8245
Accuracy: 0.8277
Accuracy: 0.8229
Accuracy: 0.8198
Accuracy: 0.8278
Accuracy: 0.8302
Accuracy: 0.8232
Accuracy: 0.8249
Accuracy: 0.8251
Accuracy: 0.8276
Accuracy: 0.8171
Accuracy: 0.8171
Accuracy: 0.8261
Accuracy: 0.8210
Accuracy: 0.8177
Accuracy: 0.8328
Accuracy: 0.8209
Accuracy: 0.8270
Accuracy: 0.8319
