In [None]:
import torch
from torch_geometric.data import Data, InMemoryDataset, download_url
import torch_geometric.transforms as T

from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import GraphConv
from torch_geometric.nn import global_mean_pool

from torch_geometric.loader import DataLoader

In [None]:
# Define a dataset class
class Dataset(InMemoryDataset):
    """In-memory graph dataset backed by a single pre-processed file.

    After the base-class initializer has run, the collated ``(data, slices)``
    pair is loaded from ``self.processed_paths[0]``. The processed file is
    expected to be created ahead of time (see ``create_dataset.ipynb``);
    ``process`` only contains a placeholder for the raw-data reading step.

    Args:
        root: Root directory of the dataset.
        transform: Optional transform forwarded to the base class.
        pre_transform: Optional transform applied in ``process`` before saving.
        pre_filter: Optional predicate applied in ``process`` before saving.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        print("INFO: self.processed_paths = ",self.processed_paths)
        # NOTE: torch.load unpickles the file, so only load files you trust.
        # NOTE(review): recent PyTorch releases change the `weights_only` default of
        # torch.load, which may reject this pickled tuple -- confirm for your version.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Names of the raw input files (placeholder values)."""
        return ['some_file_1', 'some_file_2']

    @property
    def processed_file_names(self):
        """Name of the single processed file holding the collated dataset."""
        return ['data.pt']

    def process(self):
        """Filter, transform, collate and save the list of graphs.

        Raises:
            RuntimeError: If no graphs were read, i.e. the ``data_list``
                placeholder below has not been replaced with loading code.
        """
        # Read data into huge `Data` list.
        data_list = None

        # Fail with an actionable message instead of a TypeError from iterating
        # or collating `None` when the placeholder above was never filled in.
        if data_list is None:
            raise RuntimeError(
                "Dataset.process() has no graphs to process: expected an existing "
                f"processed file at {self.processed_paths[0]} (create it with "
                "create_dataset.ipynb) or real loading code in place of the "
                "`data_list = None` placeholder."
            )

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# Load dataset
# NOTE: DATA SHOULD BE SAVED IN <root>/processed/data.pt, create this with create_dataset.ipynb
root = "/path/to/pyg_datasets"

# Transform pipeline handed to the dataset: ToUndirected followed by NormalizeFeatures.
transform = T.Compose([
    T.ToUndirected(),
    T.NormalizeFeatures(),
])
dataset = Dataset(
    root,
    transform=transform,
    pre_transform=None,
    pre_filter=None,
)

# Sanity check: show the first graph's node features and the dataset summary,
# then the result of running the transform on the first graph once more.
data = dataset[0]
print(data.x)
print(dataset)
if transform is not None:
    print(transform(dataset[0]).edge_index)
    print(transform(dataset[0]).x)

In [None]:
# Define your model
class GNN(torch.nn.Module):
    """Three GraphConv layers, mean-pool readout and a sigmoid-activated linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the three graph convolution layers.
        out_channels: Number of outputs of the final linear layer.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        # Fixed seed so that parameter initialization is repeatable.
        torch.manual_seed(12345)
        # NOTE: `.jittable()` IS NEEDED FOR DEPLOYMENT IN CMAKE
        self.conv1 = GraphConv(in_channels, hidden_channels).jittable()
        self.conv2 = GraphConv(hidden_channels, hidden_channels).jittable()
        self.conv3 = GraphConv(hidden_channels, hidden_channels).jittable()
        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x, edge_index, batch):
        """Return one sigmoid-activated output row per graph in the batch."""
        # Node embeddings: ReLU after the first two convolutions, none after the third.
        hidden = self.conv1(x, edge_index).relu()
        hidden = self.conv2(hidden, edge_index).relu()
        hidden = self.conv3(hidden, edge_index)

        # Readout: average node embeddings per graph -> [batch_size, hidden_channels]
        pooled = global_mean_pool(hidden, batch)

        # Classifier head; dropout is only active in training mode.
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        # NOTE: DON'T SOFTMAX IF USING BCELOSS, USE SIGMOID INSTEAD
        return torch.sigmoid(self.lin(pooled))

# Binary classifier: a single output unit, input width taken from the dataset
model = GNN(
    in_channels=dataset.num_node_features,
    hidden_channels=64,
    out_channels=1,
)
print(model)

In [None]:
# Use CUDA when it is available, otherwise stay on the CPU, and move the model there
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = model.to(device)

In [None]:
# Split dataset
torch.manual_seed(12345)
dataset = dataset.shuffle()

print(len(dataset))

# Train / validation / test fractions: exactly three non-negative values summing to one.
split_fracs = [0.8, 0.1, 0.1]
assert len(split_fracs) == 3, f"expected 3 split fractions, got {len(split_fracs)}"
assert all(frac >= 0 for frac in split_fracs), "split fractions must be non-negative"
assert abs(sum(split_fracs) - 1.0) < 1e-6, f"split fractions must sum to 1, got {sum(split_fracs)}"

# Cumulative boundaries as plain Python floats (no float32 tensor round-trip).
fracs = [sum(split_fracs[:idx]) for idx in range(1, len(split_fracs) + 1)]
print(fracs)
# The small epsilon keeps a product such as 28.999999999999996 from truncating one index low.
split1, split2 = [int(len(dataset) * frac + 1e-9) for frac in fracs[:-1]]
train_dataset = dataset[:split1]
val_dataset = dataset[split1:split2]
test_dataset = dataset[split2:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of validation graphs: {len(val_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

In [None]:
# Create dataloaders (DataLoader is imported in the first cell of the notebook)
batch_size = 16
# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Preview the first training batch only.
for step, data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print('=======')
    print(f'Number of graphs in the current batch: {data.num_graphs}')
    print(data)
    print()
    break

In [None]:
# Adam optimizer over all model parameters (learning rate 0.01), used by the training routine
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)

def _balanced_bce_weights(y, num_classes=2):
    """Return one weight per sample, inversely proportional to its class frequency in `y`.

    Args:
        y: Tensor of integer class labels in ``[0, num_classes)``; any shape, it is flattened.
        num_classes: Number of label values to count.

    Returns:
        1-D float tensor with ``y.numel()`` entries, where each entry equals
        ``y.numel() / (number of samples sharing that label)``.
    """
    labels = y.reshape(-1).long()
    # bincount with minlength keeps one slot per class even when a class is missing
    # from the batch, so indexing by label never goes out of range.
    counts = torch.bincount(labels, minlength=num_classes).float()
    # Absent classes have count 0; clamp so the division stays finite (they are never indexed).
    class_weights = labels.numel() / counts.clamp(min=1.0)
    return class_weights[labels]


def _forward_batch(data):
    """Run the model on one batch and return ``(probabilities, labels, weighted BCE loss)``."""
    data = data.to(device)
    labels = data.y.reshape(-1)
    # reshape(-1) instead of squeeze: a batch with a single graph must stay 1-D.
    probs = model(data.x, data.edge_index, data.batch).reshape(-1)
    criterion = torch.nn.BCELoss(weight=_balanced_bce_weights(labels))
    loss = criterion(probs, labels.float())
    return probs, labels, loss


def train(loader):
    """Run one optimization pass over every batch of `loader`.

    Args:
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``, ``y`` and ``to()``.
    """
    model.train()
    for data in loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients left over from the previous step.
        _, _, loss = _forward_batch(data)  # Perform a single forward pass and compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.


@torch.no_grad()
def val(loader):
    """Evaluate the model on `loader`.

    Args:
        loader: Iterable of batches that also exposes ``dataset`` (used for its length).

    Returns:
        Tuple ``(accuracy, loss)``: the fraction of correct predictions and the
        sample-weighted mean of the per-batch losses, both over ``len(loader.dataset)``.
    """
    model.eval()

    correct = 0
    loss_tot = 0.0
    for data in loader:  # Iterate in batches over the training/test dataset.
        probs, labels, loss = _forward_batch(data)
        # NOTE: JUST FOR USING BCELOSS -> threshold the single sigmoid output at 0.5
        pred = probs.round()
        correct += int((pred == labels).sum())  # Check against ground-truth labels.
        # The batch loss is a mean; scale by the batch size so that dividing by the
        # dataset size below yields a per-sample average.
        loss_tot += loss.item() * labels.numel()
    n_samples = len(loader.dataset)
    return correct / n_samples, loss_tot / n_samples

# Train and test the model
nepochs = 5
for epoch in range(1, nepochs+1):
    train(train_loader)
    # Re-evaluate on the training split in eval mode so both splits are measured the same way.
    train_acc, train_loss = val(train_loader)
    val_acc, val_loss = val(val_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f} Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f} Loss: {val_loss:.4f}')

# Final evaluation on the held-out test split; report the test metrics, not the training ones.
test_acc, test_loss = val(test_loader)
print(f'Test Acc: {test_acc:.4f} Loss: {test_loss:.4f}')