In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from torch_geometric.nn import GATv2Conv, BatchNorm

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so this file must come from a trusted source.
# Presumably the file holds a PyTorch Geometric graph object (the notebook
# reads .x, .y, .edge_index, .edge_attr and .num_edges from it) -- confirm.
data = torch.load('../data/processed/graph_data_full.pt', weights_only=False).to(device)

print(f"Data Loaded: {data}")

# Share of entries in data.y that are set (label 1 is reported as "Delayed").
delayed_ratio = data.y.sum() / len(data.y)
print(f"Label 1 ratio (Delayed): {delayed_ratio:.2%}")

In [None]:
class FlightDelayGNN(torch.nn.Module):
    """Two-layer GATv2 encoder followed by an MLP that scores every edge.

    Node features and edge attributes are first projected to ``hidden_dim``.
    Two ``GATv2Conv`` layers (each followed by batch norm and ELU) then
    refine the node embeddings while attending over the projected edge
    attributes. Finally every edge is scored from the concatenation of its
    source embedding, its destination embedding and its raw edge attributes.

    Args:
        num_nodes: ``in_features`` of the node projection, i.e. the size of
            the last dimension of ``x`` passed to ``forward``.
            NOTE(review): the name suggests one-hot node identities as
            features -- confirm against the data pipeline.
        input_edge_feats: Number of raw features per edge.
        hidden_dim: Width of the node and edge embeddings.
        heads: Number of attention heads in the first GATv2 layer; their
            outputs are concatenated.
        dropout: Dropout probability applied after the first graph layer and
            inside the decoder MLP. Defaults to 0.2, the value that used to
            be hard-coded in both places.
    """

    def __init__(self, num_nodes, input_edge_feats, hidden_dim, heads=2, dropout=0.2):
        super().__init__()
        self.dropout = dropout

        # Input projections for nodes and edges.
        self.node_emb = torch.nn.Linear(num_nodes, hidden_dim)
        self.edge_proj = torch.nn.Linear(input_edge_feats, hidden_dim)

        # First graph layer: multi-head, heads concatenated.
        self.conv1 = GATv2Conv(hidden_dim, hidden_dim, heads=heads,
                               concat=True, edge_dim=hidden_dim)
        self.bn1 = BatchNorm(hidden_dim * heads)

        # Second graph layer: single head, back down to hidden_dim.
        self.conv2 = GATv2Conv(hidden_dim * heads, hidden_dim, heads=1,
                               concat=False, edge_dim=hidden_dim)
        self.bn2 = BatchNorm(hidden_dim)

        # Edge decoder: [src embedding | dst embedding | raw edge attrs] -> 1 logit.
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(2 * hidden_dim + input_edge_feats, 64),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(64, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1)
        )

    def forward(self, x, edge_index, edge_attr):
        """Return one raw (pre-sigmoid) logit per edge.

        Args:
            x: Node feature matrix whose last dimension equals the
                ``num_nodes`` value given to the constructor.
            edge_index: Tensor whose first dimension has size 2; row 0 is
                read as source node indices and row 1 as destination node
                indices.
            edge_attr: Edge feature matrix with ``input_edge_feats`` columns,
                one row per edge.

        Returns:
            Tensor with one row per edge and a single column of logits.
        """
        x = self.node_emb(x)
        edge_attr_emb = self.edge_proj(edge_attr)

        x = self.conv1(x, edge_index, edge_attr=edge_attr_emb)
        x = self.bn1(x)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index, edge_attr=edge_attr_emb)
        x = self.bn2(x)
        x = F.elu(x)

        # Gather the endpoint embeddings of every edge and append the raw
        # (un-projected) edge attributes before decoding.
        src_idx, dst_idx = edge_index
        edge_cat = torch.cat([x[src_idx], x[dst_idx], edge_attr], dim=1)

        return self.decoder(edge_cat)

# The node projection inside FlightDelayGNN is a Linear layer applied to the
# last dimension of data.x, so its input width must be the per-node feature
# count (columns of x), not the number of rows. The two values only coincide
# when x is square (e.g. one-hot node identities); using shape[1] gives the
# same model in that case and still works when x carries other features.
num_nodes = data.x.shape[0]
node_feat_dim = data.x.shape[1]
input_edge_feats = data.edge_attr.shape[1]
hidden_dim = 64
heads = 2

model = FlightDelayGNN(node_feat_dim, input_edge_feats, hidden_dim, heads=heads).to(device)
print(model)

In [None]:
# Random 80/20 split over edge positions. A dedicated, seeded generator keeps
# the split identical across kernel restarts (otherwise train/test metrics
# are not comparable between runs) without touching the global RNG state.
SPLIT_SEED = 42
TRAIN_FRACTION = 0.8

num_edges = data.num_edges
split_rng = torch.Generator().manual_seed(SPLIT_SEED)
indices = torch.randperm(num_edges, generator=split_rng)

train_size = int(TRAIN_FRACTION * num_edges)
train_idx = indices[:train_size]
test_idx = indices[train_size:]

# The two index sets must partition the edges exactly.
assert len(train_idx) + len(test_idx) == num_edges

print(f"Train samples: {len(train_idx)}")
print(f"Test samples: {len(test_idx)}")

In [None]:
# Training hyper-parameters, gathered in one place.
LEARNING_RATE = 0.001
WEIGHT_DECAY = 5e-4
POSITIVE_CLASS_WEIGHT = 4.0  # loss weight applied to label-1 targets

# Binary cross-entropy on raw logits, with label-1 targets up-weighted.
pos_weight = torch.tensor([POSITIVE_CLASS_WEIGHT], device=device)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

def train():
    """Run one optimisation step over the whole graph.

    The model is evaluated on every edge of ``data``, but only the rows
    selected by ``train_idx`` contribute to the loss.

    Returns:
        The training loss of this step as a Python float.
    """
    model.train()

    logits = model(data.x, data.edge_index, data.edge_attr)
    train_loss = criterion(logits[train_idx], data.y[train_idx])

    # Clear gradients left over from the previous step before back-propagating.
    optimizer.zero_grad()
    train_loss.backward()
    optimizer.step()

    return train_loss.item()

@torch.no_grad()
def test(mask_idx):
    """Score the model on the edges selected by ``mask_idx``.

    The model is run in eval mode on the whole graph; logits are turned into
    hard 0/1 predictions by thresholding the sigmoid at 0.5.

    Args:
        mask_idx: Index tensor selecting which edge rows to evaluate.

    Returns:
        Tuple ``(accuracy, f1)`` computed with scikit-learn.
    """
    model.eval()

    logits = model(data.x, data.edge_index, data.edge_attr)
    probabilities = torch.sigmoid(logits[mask_idx])

    # scikit-learn works on host memory, so move both tensors to the CPU once.
    predicted = (probabilities > 0.5).float().cpu()
    expected = data.y[mask_idx].cpu()

    return accuracy_score(expected, predicted), f1_score(expected, predicted)

print("Starting training...")

NUM_EPOCHS = 999  # epochs are numbered 1..999 inclusive
EVAL_EVERY = 20   # report train/test metrics every 20th epoch

losses = []
for epoch in range(1, NUM_EPOCHS + 1):
    loss = train()
    losses.append(loss)

    # Only pay for the two extra full-graph evaluations on reporting epochs.
    if epoch % EVAL_EVERY != 0:
        continue

    train_acc, train_f1 = test(train_idx)
    test_acc, test_f1 = test(test_idx)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train F1: {train_f1:.4f}, Test F1: {test_f1:.4f}')

# Per-epoch training loss curve.
plt.plot(losses)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training Loss")
plt.show()

In [None]:
# Persist only the learned parameters (state_dict), not the module object.
model_path = Path('../state/models/gnn_flight_delay.pth')

# Create the target directory first: on a fresh checkout it may not exist,
# and failing here would throw away the whole training run.
model_path.parent.mkdir(parents=True, exist_ok=True)

torch.save(model.state_dict(), model_path)
print("Model saved")