<a href="https://colab.research.google.com/github/TheoBacqueyrisse/graph-neural-networks/blob/main/Graph_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Graph Transformer Architecture**

In [1]:
# Let us first clone the GitHub repository
%%capture
!git clone https://github.com/TheoBacqueyrisse/Graph-Neural-Networks.git

In [2]:
# Install dependencies
%%capture
%cd /content/Graph-Neural-Networks
!pip install -r requirements.txt

In [3]:
from utils import *

In [4]:
from torch_geometric.nn import TransformerConv

In [78]:
class GraphTransformer(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GraphTransformer, self).__init__()
        self.conv = TransformerConv(input_dim, hidden_dim, heads=4, edge_dim=2)
        self.out = torch.nn.Linear(hidden_dim, output_dim)

    def forward(self, data):
        x = data.x
        x = x.to(torch.float32)
        x = x.squeeze().t()

        edge_index = data.edge_index
        edge_index = edge_index.t()

        edge_attr = data.edge_attr
        edge_attr = edge_attr.to(torch.float32)

        x = self.conv(x, edge_index, edge_attr)
        x = torch.sigmoid(x)
        x = self.out(x)
        return x

# Create an instance of the GNNTransformer model
model = GraphTransformer(input_dim = 1, hidden_dim = 32, output_dim = 1)

In [79]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

NUM_EPOCHS = 10

loss_function = L1Loss()

optimizer = Adam(params = model.parameters(), lr = 0.003)

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=0.00001)

In [80]:
NB_GRAPHS_PER_BATCH = 64

train = ZINC('/content/Graph-Neural-Networks/data', split = 'train')
train = train[train.y > -10] # Drop Outliers

val = ZINC('/content/Graph-Neural-Networks/data', split = 'val')

test = ZINC('/content/Graph-Neural-Networks/data', split = 'test')

train_loader = DataLoader(train,
                          batch_size = NB_GRAPHS_PER_BATCH,
                          shuffle = True)

val_loader = DataLoader(val,
                        batch_size = NB_GRAPHS_PER_BATCH,
                        shuffle = False)

test_loader = DataLoader(test,
                         batch_size = NB_GRAPHS_PER_BATCH,
                         shuffle = False)

print("Number of Batches in Train Loader :", len(train_loader))
print("Number of Batches in Val Loader :", len(val_loader))
print("Number of Batches in Test Loader :", len(test_loader))

Number of Batches in Train Loader : 3433
Number of Batches in Val Loader : 382
Number of Batches in Test Loader : 79


In [74]:
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0

    for batch_data in train_loader:

        optimizer.zero_grad()
        predictions = model(batch_data)

        loss = loss_function(predictions, batch_data.y)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    average_train_loss = total_loss / len(train_loader)

    model.eval()
    val_total_loss = 0.0
    with torch.no_grad():
        for val_batch_data in val_loader:
            val_predictions = model(val_batch_data)
            val_loss = loss_function(val_predictions, val_batch_data.y)
            val_total_loss += val_loss.item()

    average_val_loss = val_total_loss / len(val_loader)

    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] -> Train Loss: {average_train_loss:.4f} - Val Loss: {average_val_loss:.4f}")

RuntimeError: ignored

In [82]:
# for epoch in range(NUM_EPOCHS):
#     optimizer.zero_grad()
#     predictions = model(train)
#     loss = loss_function(predictions, train.y)
#     loss.backward()
#     optimizer.step()

#     print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] - Train Loss: {loss.item():.4f}")

#     model.eval()
#     val_total_loss = 0.0
#     with torch.no_grad():
#       val_predictions = model(val)
#       val_loss = loss_function(val_predictions, val.y)
#       loss.backward()
#       optimizer.step()

#     print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}] -> Val Loss: {val_loss.item():.4f}")