In [112]:
import torch
import numpy as np
import torch_geometric
from torch import Tensor
from torch.nn import Linear, Embedding
from torch_geometric.data import Data
from torch_geometric.nn.conv import GATConv
from sklearn.metrics import roc_auc_score, average_precision_score, RocCurveDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda:2


0

In [113]:
# experimental setting
relative_time = False
semi_transductive = True
heterogenous_msg_passing = True
dataset_path = "data/act-mooc/graph.pt"
#dataset_path = "data/junyi/graph.pt"

In [114]:
# hyperparameters
batch_size = 256
num_epochs = 5
emb_dim = 50
hidden_dim = 50
time_dim = 50
identity_dim = 25
learning_rate = 0.0001

In [115]:
data = torch.load(dataset_path)
data = data.to_homogeneous()
data.x = torch.zeros(data.num_nodes, 1)
data.edge_attr = data.edge_attr.float()
data.time = data.time.float()
data = data.to(device)
data

Data(edge_index=[2, 823498], node_id=[7144], edge_attr=[823498, 4], time=[823498], edge_y=[823498], node_type=[7144], edge_type=[823498], x=[7144, 1])

In [116]:
def create_mask_in_batches(edge_index, nodes, batch_size):
    mask = torch.zeros(edge_index.size(1), dtype=torch.bool, device=device)
    for i in range(0, len(nodes), batch_size):
        batch_nodes = nodes[i:i + batch_size]
        batch_mask = ((edge_index[0].unsqueeze(1) == batch_nodes).any(dim=1)) | \
                     ((edge_index[1].unsqueeze(1) == batch_nodes).any(dim=1))
        mask |= batch_mask
    return mask

In [117]:
# Function to create a new Data object from the original data and a mask
def create_subset(data, mask):
    # Filter data using the mask
    subset = Data()
    for key, item in data:
        if key in ['node_id', 'node_type', 'x']:
            subset[key] = item
        elif key in ['edge_index']:
            subset[key] = item[:,mask]
        else:
            subset[key] = item[mask]
    return subset

In [118]:
if relative_time:
    unique_node_ids = data.node_id[data.node_type==0].unique()
    min_times_per_node = {node_id: data.time[data.edge_index[0,:] == node_id].min() for node_id in unique_node_ids}
    
    # Subtract the minimum time from each node's times
    for node_id, min_time in min_times_per_node.items():
        data.time[data.edge_index[0,:] == node_id] -= min_time
        data.time[data.edge_index[1,:] == node_id] -= min_time

In [119]:
time_array = data.time.cpu().numpy()
quantile_70 = np.quantile(time_array, 0.7)
quantile_85 = np.quantile(time_array, 0.85)

# Create masks for splitting the data
train_mask = time_array < quantile_70
val_mask = time_array < quantile_85
test_mask = np.ones_like(time_array, dtype=bool)

# Create subsets
train_data = create_subset(data, train_mask)
val_data = create_subset(data, val_mask)
test_data = create_subset(data, test_mask)

This class iterates over batches of users and returns all edges they are connected with.

In [121]:
train_loader = TemporalLoader(train_data, 0, batch_size=batch_size)
val_loader = TemporalLoader(val_data, quantile_70, batch_size=batch_size)
test_loader = TemporalLoader(test_data, quantile_85, batch_size=batch_size)

In [122]:
class TimeEncoder(torch.nn.Module):
    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        self.time_lin = Linear(1, time_dim)

    def forward(self, t: Tensor):
        return self.time_lin(t.view(-1, 1)).cos()

class GATModel(torch.nn.Module):
    def __init__(self, emb_dim, hidden_channels, out_channels, edge_dim, time_enc, resource_ids):
        super().__init__()
        self.conv1_user_to_resource = GATConv(emb_dim, hidden_channels, edge_dim=edge_dim)
        self.conv2_user_to_resource = GATConv(hidden_channels, out_channels, edge_dim=edge_dim)
        
        if heterogenous_msg_passing:
            self.conv1_resource_to_user = GATConv(emb_dim, hidden_channels, edge_dim=edge_dim)
            self.conv2_resource_to_user = GATConv(hidden_channels, out_channels, edge_dim=edge_dim)
            
        self.time_enc = time_enc
        if semi_transductive:
            self.mapping = {nid: i+1 for i, nid in enumerate(resource_ids)}
            self.embedding = Embedding(len(resource_ids)+1, emb_dim)
        else:
            self.mapping = {nid: 1 for i, nid in enumerate(resource_ids)}
            self.embedding = Embedding(2, emb_dim)

    def forward(self, node_id: Tensor, node_type: Tensor, edge_index: Tensor, edge_attr: Tensor, t: Tensor, edge_type: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        # edge_attr: Edge features
        # t: Timestamp of edges
        time_enc = self.time_enc(t)
        edge_attr = torch.cat([time_enc, edge_attr], dim=-1)
        

        node_ids = torch.tensor([self.mapping.get(nid, 0) for nid in node_id]).to(device)
        x = self.embedding(node_ids)

        mask = edge_type == 0
        
        if heterogenous_msg_passing:
            x_resources = self.conv1_user_to_resource(x, edge_index[:, mask], edge_attr[mask]).relu()
            x_users = self.conv1_resource_to_user(x, edge_index[:, ~mask], edge_attr[~mask]).relu()
            users = node_type == 0
            h = torch.zeros_like(x_users)
            h[users] = x_users[users]
            h[~users] = x_resources[~users]
        else:
            h = self.conv1_user_to_resource(x, edge_index, edge_attr).relu()

        if heterogenous_msg_passing:
            h_resources = self.conv2_user_to_resource(h, edge_index[:, mask], edge_attr[mask]).relu()
            h_users = self.conv2_resource_to_user(h, edge_index[:, ~mask], edge_attr[~mask]).relu()
            o = torch.zeros_like(h_users)
            o[users,:] = h_users[users,:]
            o[~users,:] = h_resources[~users,:]
        else: 
            o = self.conv2_user_to_resource(h, edge_index, edge_attr).relu()
        
        return o

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, time_enc):
        super().__init__()
        self.time_enc = time_enc
        dim = 2 * in_channels + time_enc.time_dim # 2*
        self.lin = Linear(dim, dim)
        self.lin_final = Linear(dim, 1)

    def forward(self, z_src, z_dst, t):
        time_enc = self.time_enc(t)
        input = torch.cat([z_src, z_dst, time_enc], dim=-1)
        h = self.lin(input)
        h = h.relu()
        
        return self.lin_final(h)

criterion = torch.nn.BCEWithLogitsLoss()

time_enc = TimeEncoder(time_dim).to(device)
gnn = GATModel(identity_dim, hidden_dim, emb_dim, time_dim+data.edge_attr.size(1), time_enc, data.node_id[data.node_type == 1]).to(device)
link_pred = LinkPredictor(emb_dim, time_enc).to(device)

# Test forward run of the models
batch = next(iter(train_loader)).to(device)
train_loader.reset() # reset
embs = gnn(batch.node_id, batch.node_type, batch.edge_index, batch.edge_attr, batch.time, batch.edge_type)
print(embs.shape)
link_preds = link_pred(embs[batch.edge_index[0]], embs[batch.edge_index[1]], batch.time)
print(link_preds.shape)

torch.Size([7144, 50])
torch.Size([256, 1])


In [123]:
sum(p.numel() for p in gnn.parameters())

21650

In [124]:
@torch.no_grad()
def test(gnn, link_pred, time_enc, loader):
    gnn.eval()
    link_pred.eval()
    time_enc.eval()

    true_labels = torch.Tensor().to(device)
    predictions = torch.Tensor().to(device)
    batches = 0
    for batch in tqdm(loader, desc='Batches'):
        batches += 1
        mask = (batch.edge_y < 0)
        node_embs = gnn(batch.node_id, batch.node_type, batch.edge_index[:,mask], batch.edge_attr[mask], batch.time[mask], batch.edge_type[mask])
        mask = (batch.edge_y >= 0) & (batch.edge_type == 1)
        link_preds = link_pred(node_embs[batch.edge_index[0,mask]], node_embs[batch.edge_index[1,mask]], batch.time[mask]).sigmoid()

        predictions = torch.cat((predictions, link_preds[:,0].detach()), dim=0)
        true_labels = torch.cat((true_labels, batch.edge_y[mask]), dim=0)
    if batches == 0:
        return test(gnn, link_pred, time_enc, loader)
    loss = criterion(predictions, true_labels)
    auc = roc_auc_score(true_labels.cpu().numpy(), predictions.cpu().numpy())
    ap = average_precision_score(true_labels.cpu().numpy(), predictions.cpu().numpy())
    return loss, auc, ap

In [126]:
for i in range(0, 3):
    time_enc = TimeEncoder(time_dim).to(device)
    gnn = GATModel(identity_dim, hidden_dim, emb_dim, time_dim+data.edge_attr.size(1), time_enc, data.node_id[data.node_type == 1]).to(device)
    link_pred = LinkPredictor(emb_dim, time_enc).to(device)
    optimizer = torch.optim.Adam(set(time_enc.parameters()) | set(gnn.parameters()) | set(link_pred.parameters()), lr=learning_rate)
    
    for epoch in range(0, num_epochs):
        gnn.train()
        link_pred.train()
        time_enc.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc='Training Batches'):
            batch.to(device)
            # forward pass with edges that are older than this batch
            mask = batch.edge_y < 0
            node_embs = gnn(batch.node_id, batch.node_type, batch.edge_index[:,mask], batch.edge_attr[mask], batch.time[mask], batch.edge_type[mask])
            
            # first predict for all dropout edges
            mask = (batch.edge_y == 1) & (batch.edge_type == 1)
            positive_edges = batch.edge_index[:, mask]
            if positive_edges.size(1) == 0: continue
            pos_out = link_pred(node_embs[positive_edges[0]], node_embs[positive_edges[1]], batch.time[mask])
            loss = criterion(pos_out, torch.ones_like(pos_out))
    
            # sample the same number of 0 edges from edge_index
            negative_indices = torch.nonzero((batch.edge_y == 0) & (batch.edge_type == 1)).squeeze()
            negative_indices = negative_indices[torch.randperm(negative_indices.size(0))][:positive_edges.size(1)]
            negative_edges = batch.edge_index[:, negative_indices]
            neg_out = link_pred(node_embs[negative_edges[0]], node_embs[negative_edges[1]], batch.time[negative_indices])
            loss += criterion(neg_out, torch.zeros_like(neg_out))
    
            # backward pass and gradient update
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss
    
        val_loss, val_auc, val_ap = test(gnn, link_pred, time_enc, val_loader)
        train_loss = total_loss / (2 * len(train_loader))
        print(f'Epoch: {epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}, Val AUC: {val_auc}')
        
    train_loss, train_auc, train_ap = test(gnn, link_pred, time_enc, train_loader)
    test_loss, test_auc, test_ap = test(gnn, link_pred, time_enc, test_loader)
    
    append_to_csv("results.csv", train_loss.item(), val_loss.item(), test_loss.item(), train_auc, val_auc, test_auc, train_ap, val_ap, test_ap, i)

Training Batches: 100%|██████████| 2251/2251 [02:19<00:00, 16.12it/s]
Batches:   0%|          | 0/482 [00:00<?, ?it/s]
Batches: 100%|██████████| 482/482 [00:24<00:00, 19.32it/s]


Epoch: 0, Train Loss: 0.47399836778640747, Val Loss: 0.9258419871330261, Val AUC: 0.7419130624533772


Training Batches: 100%|██████████| 2251/2251 [02:20<00:00, 16.08it/s]
Batches: 100%|██████████| 482/482 [00:24<00:00, 19.62it/s]


Epoch: 1, Train Loss: 0.462100625038147, Val Loss: 0.918373703956604, Val AUC: 0.744299491923845


Training Batches: 100%|██████████| 2251/2251 [02:18<00:00, 16.27it/s]
Batches: 100%|██████████| 482/482 [00:24<00:00, 19.38it/s]


Epoch: 2, Train Loss: 0.4628545045852661, Val Loss: 0.9034318923950195, Val AUC: 0.7417200345296964


Training Batches: 100%|██████████| 2251/2251 [02:00<00:00, 18.63it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.37it/s]


Epoch: 3, Train Loss: 0.463765949010849, Val Loss: 0.9231296181678772, Val AUC: 0.7374693507289175


Training Batches: 100%|██████████| 2251/2251 [02:21<00:00, 15.86it/s]
Batches: 100%|██████████| 482/482 [00:25<00:00, 19.16it/s]


Epoch: 4, Train Loss: 0.4584500193595886, Val Loss: 0.9090256690979004, Val AUC: 0.741870813562711


Batches: 100%|██████████| 2251/2251 [01:53<00:00, 19.87it/s]
Batches:   0%|          | 0/482 [00:00<?, ?it/s]
Batches: 100%|██████████| 482/482 [00:26<00:00, 18.05it/s]
Training Batches: 100%|██████████| 2251/2251 [02:16<00:00, 16.50it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.39it/s]


Epoch: 0, Train Loss: 0.47761037945747375, Val Loss: 0.9328076839447021, Val AUC: 0.7423107339418967


Training Batches: 100%|██████████| 2251/2251 [01:26<00:00, 26.16it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.74it/s]


Epoch: 1, Train Loss: 0.46269509196281433, Val Loss: 0.9034401178359985, Val AUC: 0.738573397393995


Training Batches: 100%|██████████| 2251/2251 [01:27<00:00, 25.79it/s]
Batches: 100%|██████████| 482/482 [00:15<00:00, 31.86it/s]


Epoch: 2, Train Loss: 0.4641542434692383, Val Loss: 0.9101966023445129, Val AUC: 0.7432027991612384


Training Batches: 100%|██████████| 2251/2251 [01:26<00:00, 26.02it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.71it/s]


Epoch: 3, Train Loss: 0.4641079604625702, Val Loss: 0.914806067943573, Val AUC: 0.740816706143825


Training Batches: 100%|██████████| 2251/2251 [01:26<00:00, 25.96it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.82it/s]


Epoch: 4, Train Loss: 0.46169334650039673, Val Loss: 0.9138443470001221, Val AUC: 0.7395095219418661


Batches: 100%|██████████| 2251/2251 [01:05<00:00, 34.56it/s]
Batches: 100%|██████████| 482/482 [00:15<00:00, 31.89it/s]
Training Batches: 100%|██████████| 2251/2251 [01:26<00:00, 26.14it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.36it/s]


Epoch: 0, Train Loss: 0.4961237609386444, Val Loss: 0.9304590821266174, Val AUC: 0.7292831435097245


Training Batches: 100%|██████████| 2251/2251 [01:26<00:00, 26.13it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.76it/s]


Epoch: 1, Train Loss: 0.4638499915599823, Val Loss: 0.8991677165031433, Val AUC: 0.7424007271231149


Training Batches: 100%|██████████| 2251/2251 [01:27<00:00, 25.85it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.75it/s]


Epoch: 2, Train Loss: 0.4524284303188324, Val Loss: 0.9032407402992249, Val AUC: 0.7574341064704088


Training Batches: 100%|██████████| 2251/2251 [01:26<00:00, 26.01it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.50it/s]


Epoch: 3, Train Loss: 0.4366084933280945, Val Loss: 0.8871848583221436, Val AUC: 0.759292641098798


Training Batches: 100%|██████████| 2251/2251 [01:26<00:00, 25.93it/s]
Batches: 100%|██████████| 482/482 [00:14<00:00, 32.82it/s]


Epoch: 4, Train Loss: 0.44076114892959595, Val Loss: 0.8817405700683594, Val AUC: 0.7631146997300685


Batches: 100%|██████████| 2251/2251 [01:05<00:00, 34.53it/s]
Batches: 100%|██████████| 482/482 [00:15<00:00, 31.81it/s]
