In [122]:
import torch
from torch import Tensor
from torch.nn import Linear
from torch_geometric.nn.conv import GATConv
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, RocCurveDisplay
import matplotlib.pyplot as plt
torch.set_default_dtype(torch.float32)
cuda_available = torch.cuda.is_available()
print("CUDA Available:", cuda_available)

CUDA Available: True


In [123]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [124]:
data = torch.load("data/act-mooc/graph.pt")

In [125]:
#data = torch.load("data/junyi/graph.pt")
#del data[('resource', 'rev_accesses', 'user')]
data_hom = data.to_homogeneous()

In [130]:
data = data_hom
data.x = torch.randn(data.num_nodes, 1)
data = data.to(device)
data

Data(edge_index=[2, 823498], node_id=[7144], edge_attr=[823498, 4], time=[823498], edge_y=[823498], node_type=[7144], edge_type=[823498], x=[7144, 1])

In [136]:
data.edge_attr =data.edge_attr.float()
data.time =data.time.float()

data.time.dtype

torch.float32

In [137]:
class TimeEncoder(torch.nn.Module):
    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        self.time_lin = Linear(1, time_dim)

    def forward(self, t: Tensor):
        # embed time and add to edge_attr
        return self.time_lin(t.view(-1, 1)).cos()

class GATModel(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, edge_dim, time_enc):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, edge_dim=edge_dim)
        self.conv2 = GATConv(hidden_channels, out_channels, edge_dim=edge_dim)
        self.time_enc = time_enc

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, t: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        # edge_attr: Edge features
        # t: Timestamp of edges
        time_enc = self.time_enc(t)
        edge_attr = torch.cat([time_enc, edge_attr], dim=-1)
        x = self.conv1(x, edge_index, edge_attr).relu()
        x = self.conv2(x, edge_index, edge_attr)
        return x

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, time_enc):
        super().__init__()
        self.time_enc = time_enc
        dim = 2 * in_channels + time_enc.time_dim
        self.lin = Linear(dim, dim)
        self.lin_final = Linear(dim, 1)

    def forward(self, z_src, z_dst, t):
        time_enc = self.time_enc(t)
        input = torch.cat([z_src, z_dst, time_enc], dim=-1)
        h = self.lin(input)
        h = h.relu()
        return self.lin_final(h)


emd_dim = hidden_dim = time_dim = 16

time_enc = TimeEncoder(time_dim).to(device)
gnn = GATModel(data.x.size(1), hidden_dim, emd_dim, time_dim+data.edge_attr.size(1), time_enc).to(device)
link_pred = LinkPredictor(emd_dim, time_enc).to(device)

# Test forward run of the models
embs = gnn(data.x, data.edge_index, data.edge_attr, data.time)
print(embs.shape)
link_preds = link_pred(embs[data.edge_index[0]], embs[data.edge_index[1]], data.time)
print(link_preds.shape)

torch.Size([7144, 16])
torch.Size([823498, 1])


In [138]:
optimizer = torch.optim.Adam(set(time_enc.parameters()) | set(gnn.parameters()) | set(link_pred.parameters()), lr=0.001)
criterion = torch.nn.BCEWithLogitsLoss()

In [139]:
@torch.no_grad()
def test(gnn, link_pred, data):
    gnn.eval()
    link_pred.eval()
    node_embs = gnn(data.x, data.edge_index, data.edge_attr, data.time)
    link_preds = link_pred(node_embs[data.edge_index[0]], node_embs[data.edge_index[1]], data.time).sigmoid()
    #RocCurveDisplay.from_predictions(data.edge_y, link_preds)
    #plt.show()
    return roc_auc_score(data.edge_y.cpu().numpy(), link_preds.cpu().numpy())

In [140]:
for epoch in range(0, 100):
    gnn.train()
    link_pred.train()
    optimizer.zero_grad()
    
    node_embs = gnn(data.x, data.edge_index, data.edge_attr, data.time)

    # take all 1 edges and sample the same number of 0 edges from edge_index
    positive_edges = data.edge_index[:, data.edge_y == 1]
    negative_indices = torch.nonzero(data.edge_y == 0).squeeze()
    
    negative_indices = negative_indices[torch.randperm(negative_indices.size(0))][:positive_edges.size(1)]
    negative_edges = data.edge_index[:, negative_indices]
    
    all_edges = torch.cat([positive_edges, negative_edges], dim=1)
    all_labels = torch.cat([torch.ones(positive_edges.shape[1]), 
                            torch.zeros(negative_edges.shape[1])], dim=0).double()
    all_times = torch.cat([data.time[data.edge_y == 1], data.time[negative_indices]])
    all_labels = all_labels.to(device)
    
    link_preds = link_pred(node_embs[all_edges[0]], node_embs[all_edges[1]], all_times).flatten()
    
    loss = criterion(link_preds, all_labels)

    loss.backward()
    optimizer.step()
    
    test_auc = test(gnn, link_pred, data)
    print(f'Epoch: {epoch}, Train Loss: {loss}, Test AUC: {test_auc}')
    

Epoch: 0, Train Loss: 0.6937603642696201, Test AUC: 0.498769654065793
Epoch: 1, Train Loss: 0.6945780535233474, Test AUC: 0.5068827445771643
Epoch: 2, Train Loss: 0.6932498224727709, Test AUC: 0.5034052086940485
Epoch: 3, Train Loss: 0.693683015662305, Test AUC: 0.5113500072179161
Epoch: 4, Train Loss: 0.6926424382397065, Test AUC: 0.5071598980185239
Epoch: 5, Train Loss: 0.6932940874504707, Test AUC: 0.5124394207090478
Epoch: 6, Train Loss: 0.6927942042985717, Test AUC: 0.5093633135861677
Epoch: 7, Train Loss: 0.6930127761544214, Test AUC: 0.512417262688953
Epoch: 8, Train Loss: 0.6932954716262995, Test AUC: 0.5192316627745428
Epoch: 9, Train Loss: 0.6924372624923244, Test AUC: 0.5172259100331127
Epoch: 10, Train Loss: 0.692597188023448, Test AUC: 0.5162985950823489
Epoch: 11, Train Loss: 0.6929247273750987, Test AUC: 0.52617324939778
Epoch: 12, Train Loss: 0.6922663260362717, Test AUC: 0.5280391530139831
Epoch: 13, Train Loss: 0.6919736347255832, Test AUC: 0.5232051721997351
Epoch: 1