In [1]:
import torch
import numpy as np
import torch_geometric
from torch import Tensor
from torch.nn import Linear
from torch_geometric.data import Data
from torch_geometric.nn.conv import GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import LinkNeighborLoader
from sklearn.metrics import roc_auc_score, RocCurveDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm
torch.set_default_dtype(torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
import gc
gc.collect()
torch.cuda.empty_cache()

Using device: cuda


In [30]:
#data = torch.load("data/act-mooc/graph.pt")
data = torch.load("data/junyi/graph.pt")

In [31]:
data_hom = data.to_homogeneous()

In [32]:
data = data_hom
# no data features is implemented as just one random number drawn from normal
data.x = torch.randn(data.num_nodes, 1)
data.edge_attr = data.edge_attr.float()
data.time = data.time.float()
data = data.to(device)
data

Data(edge_index=[2, 32434622], node_id=[74460], edge_attr=[32434622, 6], time=[32434622], edge_y=[32434622], node_type=[74460], edge_type=[32434622], x=[74460, 1])

In [33]:
time_array = data.time.cpu().numpy()
quantile_70 = np.quantile(time_array, 0.7)
quantile_85 = np.quantile(time_array, 0.85)

# Create masks for splitting the data
train_mask = time_array < quantile_70
val_mask = (time_array >= quantile_70) & (time_array < quantile_85)
test_mask = time_array >= quantile_85

# Function to create a new Data object from the original data and a mask
def create_subset(data, mask):
    # Filter data using the mask
    subset = Data()
    for key, item in data:
        if key in ['node_id', 'node_type', 'x']:
            subset[key] = item
        elif key in ['edge_index']:
            subset[key] = item[:,mask]
        else:
            subset[key] = item[mask]
    return subset

# Create subsets
train_data = create_subset(data, train_mask)
val_data = create_subset(data, val_mask)
test_data = create_subset(data, test_mask)

In [34]:
train_data

Data(edge_index=[2, 22704066], node_id=[74460], edge_attr=[22704066, 6], time=[22704066], edge_y=[22704066], node_type=[74460], edge_type=[22704066], x=[74460, 1])

In [35]:
batch_size = 100000
train_loader = LinkNeighborLoader(train_data, num_neighbors=[30] * 2, batch_size=batch_size, edge_label_index=train_data.edge_index)
val_loader = LinkNeighborLoader(val_data, num_neighbors=[30] * 2, batch_size=batch_size, edge_label_index=val_data.edge_index)
test_loader = LinkNeighborLoader(test_data, num_neighbors=[30] * 2, batch_size=batch_size, edge_label_index=test_data.edge_index)



In [36]:
class TimeEncoder(torch.nn.Module):
    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        self.time_lin = Linear(1, time_dim)

    def forward(self, t: Tensor):
        return self.time_lin(t.view(-1, 1)).cos()

class GATModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, edge_dim, time_enc):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, edge_dim=edge_dim)
        self.conv2 = GATConv(hidden_channels, out_channels, edge_dim=edge_dim)
        self.time_enc = time_enc

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, t: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        # edge_attr: Edge features
        # t: Timestamp of edges
        time_enc = self.time_enc(t)
        edge_attr = torch.cat([time_enc, edge_attr], dim=-1)
        x = self.conv1(x, edge_index, edge_attr).relu()
        x = self.conv2(x, edge_index, edge_attr)
        return x

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, time_enc):
        super().__init__()
        self.time_enc = time_enc
        dim = 2 * in_channels + time_enc.time_dim
        self.lin = Linear(dim, dim)
        self.lin_final = Linear(dim, 1)

    def forward(self, z_src, z_dst, t):
        time_enc = self.time_enc(t)
        input = torch.cat([z_src, z_dst, time_enc], dim=-1)
        h = self.lin(input)
        h = h.relu()
        return self.lin_final(h)

emd_dim = hidden_dim = time_dim = 100

time_enc = TimeEncoder(time_dim).to(device)
gnn = GATModel(data.x.size(1), hidden_dim, emd_dim, time_dim+data.edge_attr.size(1), time_enc).to(device)
link_pred = LinkPredictor(emd_dim, time_enc).to(device)

# Test forward run of the models
batch = next(iter(train_loader)).to(device)
embs = gnn(batch.x, batch.edge_index, batch.edge_attr, batch.time)
print(embs.shape)
link_preds = link_pred(embs[batch.edge_index[0]], embs[batch.edge_index[1]], batch.time)
print(link_preds.shape)

torch.Size([4102, 100])
torch.Size([109329, 1])


In [37]:
optimizer = torch.optim.Adam(set(time_enc.parameters()) | set(gnn.parameters()) | set(link_pred.parameters()), lr=0.0001)
criterion = torch.nn.BCEWithLogitsLoss()

In [38]:
@torch.no_grad()
def test(gnn, link_pred):
    gnn.eval()
    link_pred.eval()
    scores = []
    for batch in tqdm(val_loader, desc='Validation Batches'):
        node_embs = gnn(batch.x, batch.edge_index, batch.edge_attr, batch.time)
        link_preds = link_pred(node_embs[batch.edge_index[0]], node_embs[batch.edge_index[1]], batch.time).sigmoid()
        #RocCurveDisplay.from_predictions(data.edge_y.cpu(), link_preds.cpu())
        #plt.show()
        scores += [roc_auc_score(batch.edge_y.cpu().numpy(), link_preds.cpu().numpy())]
    return np.mean(scores)

In [39]:
for epoch in range(0, 10):
    gnn.train()
    link_pred.train()
    
    for batch in tqdm(train_loader, desc='Training Batches'):
        optimizer.zero_grad()
        node_embs = gnn(batch.x, batch.edge_index, batch.edge_attr, batch.time)
    
        # take all 1 edges and sample the same number of 0 edges from edge_index
        positive_edges = batch.edge_index[:, batch.edge_y == 1]
        negative_indices = torch.nonzero(batch.edge_y == 0).squeeze()
        
        negative_indices = negative_indices[torch.randperm(negative_indices.size(0))][:positive_edges.size(1)]
        negative_edges = batch.edge_index[:, negative_indices]
        
        all_edges = torch.cat([positive_edges, negative_edges], dim=1)
        all_labels = torch.cat([torch.ones(positive_edges.shape[1]), 
                                torch.zeros(negative_edges.shape[1])], dim=0).double()
        all_times = torch.cat([batch.time[batch.edge_y == 1], batch.time[negative_indices]])
        all_labels = all_labels.to(device)
        
        link_preds = link_pred(node_embs[all_edges[0]], node_embs[all_edges[1]], all_times).flatten()
        
        loss = criterion(link_preds, all_labels)

        loss.backward()
        optimizer.step()

    if epoch % 1 == 0:
        test_auc = test(gnn, link_pred)
        print(f'Epoch: {epoch}, Train Loss: {loss}, Test AUC: {test_auc}')
    

Training Batches: 100%|██████████| 228/228 [01:06<00:00,  3.43it/s]
Validation Batches: 100%|██████████| 49/49 [00:13<00:00,  3.59it/s]


Epoch: 0, Train Loss: 0.6939550127618092, Test AUC: 0.506664245570543


Training Batches: 100%|██████████| 228/228 [01:06<00:00,  3.40it/s]
Validation Batches: 100%|██████████| 49/49 [00:13<00:00,  3.58it/s]


Epoch: 1, Train Loss: 0.6945394015802459, Test AUC: 0.4911348697783684


Training Batches: 100%|██████████| 228/228 [01:06<00:00,  3.41it/s]
Validation Batches: 100%|██████████| 49/49 [00:13<00:00,  3.59it/s]


Epoch: 2, Train Loss: 0.6944559231766214, Test AUC: 0.5068690503978903


Training Batches:  96%|█████████▋| 220/228 [01:05<00:02,  3.35it/s]


KeyboardInterrupt: 