In [None]:
import os
import torch
from torchmetrics.classification import MulticlassAUROC
from torch_geometric.loader import DataLoader
from IPython.display import Javascript

In [None]:
# Report whether a CUDA device is visible, then select the matching torch device.
print(torch.cuda.is_available())
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 64

# Only the training loader shuffles; the validation loader is built with shuffle=False.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
valid_loader = DataLoader(valid_dataset, shuffle=False, batch_size=batch_size)

# Walk the training loader once and print a short report for every batch.
for step, data in enumerate(train_loader):
    print(
        f'Step {step + 1}:',
        '=======',
        f'Number of graphs in the current batch: {data.num_graphs}',
        data,
        sep='\n',
    )

In [None]:
def graph_test(model, loss_fn, loader, device, num_classes=2):
    """
    Evaluate a graph classifier on every batch of a loader and return its macro AUROC.

    The model is switched to eval mode and left in that mode; a caller that
    resumes training afterwards must call model.train() itself.

    model: pytorch GNN model, called as model(x, edge_index, batch)
    loss_fn: loss function (unused here; kept so existing calls keep working)
    loader: DataLoader yielding batches that expose x, edge_index, batch and y
    device: device used to bind the model inputs, the labels and the metric
    num_classes: number of classes handed to MulticlassAUROC (defaults to 2)

    Returns the value MulticlassAUROC(num_classes, average='macro') computes
    over the concatenated outputs and labels of all batches.

    Raises ValueError if the loader yields no batches.
    """
    model.eval()
    # The metric is a stateful module; keep it on the same device as its inputs.
    ml_auroc = MulticlassAUROC(num_classes, average='macro').to(device)

    # Collect per-batch tensors and concatenate once at the end, instead of
    # re-allocating an ever-growing tensor on every batch.
    batch_preds = []
    batch_labels = []

    with torch.no_grad():
        for data in loader:
            out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))
            batch_preds.append(out)
            # Labels are cast to integer class indices before being scored.
            batch_labels.append(data.y.to(device).long())

    if not batch_preds:
        raise ValueError("graph_test received an empty loader; there is nothing to evaluate")

    preds = torch.cat(batch_preds)
    labels = torch.cat(batch_labels)
    return ml_auroc(preds, labels)

In [None]:
from model.GraphConvNet import GNN

# Put the model on the same device the batches are moved to during training
# and evaluation; otherwise the forward pass fails with a device mismatch
# whenever CUDA is selected.
model = GNN().to(device)

# Optimisation hyper-parameters.
learning_rate = 1e-3
epoch_num = 200

# The optimizer is created after the move so it is bound to the on-device parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()

In [None]:
# Calls the google.colab front-end API to limit the output frame (maxHeight: 300),
# presumably so the per-epoch log scrolls instead of growing — TODO confirm outside Colab.
display(
    Javascript("google.colab.output.setIframeHeight(0, true, {maxHeight: 300})")
)

def train():
    """
    Run one training epoch over train_loader and return the mean batch loss.

    Relies on the notebook-level model, optimizer, loss_fn, train_loader and device.

    Returns a Python float: the sum of the per-batch loss values divided by
    the number of batches in train_loader.
    """
    model.train()
    total_loss = 0.0
    for data in train_loader:  # Iterate in batches over the training dataset.
        # Clear gradients first, so anything left over from an interrupted
        # run cannot leak into this update.
        optimizer.zero_grad()
        out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))  # Forward pass.
        loss = loss_fn(out, data.y.to(device))  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        # .item() yields a plain float, so the running total does not keep
        # every batch's autograd history alive for the whole epoch.
        total_loss += loss.item()
    return total_loss / len(train_loader)

# Main loop: one optimisation pass per epoch, then AUROC on the training and
# validation loaders (in that order), followed by a one-line progress report.
for epoch in range(epoch_num):
    train_loss = train()
    train_auroc, valid_auroc = [
        graph_test(model, loss_fn, split_loader, device)
        for split_loader in (train_loader, valid_loader)
    ]
    print(
        f'Epoch: {epoch:03d}, Train Loss: {train_loss:.5f}, '
        f'Train Auc: {train_auroc:.4f}, Valid Auc: {valid_auroc:.4f}'
    )