In [47]:
import torch
import numpy as np
import torch_geometric
from torch import Tensor
from torch.nn import Linear, Embedding
from torch_geometric.data import Data, Dataset
from torch_geometric.nn.conv import GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score, RocCurveDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm
#torch.set_default_dtype(torch.float32)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
import gc
torch.cuda.empty_cache()
gc.collect()
#print(torch.cuda.memory_summary())

Using device: cuda


2427

In [68]:
data = torch.load("data/act-mooc/graph.pt")
#data = torch.load("data/junyi/graph.pt")

In [69]:
data_hom = data.to_homogeneous()

In [70]:
data = data_hom
# no data features is implemented as just one random number drawn from normal
data.x = torch.zeros(data.num_nodes, 1)
data.edge_attr = data.edge_attr.float()
data.time = data.time.float()
data = data.to(device)
data

Data(edge_index=[2, 823498], node_id=[7144], edge_attr=[823498, 4], time=[823498], edge_y=[823498], node_type=[7144], edge_type=[823498], x=[7144, 1])

In [73]:
unique_node_ids = data.node_id[data.node_type==0].unique()
min_times_per_node = {node_id: data.time[data.edge_index[0,:] == node_id].min() for node_id in unique_node_ids}

# Subtract the minimum time from each node's times
for node_id, min_time in min_times_per_node.items():
    data.time[data.edge_index[0,:] == node_id] -= min_time
    data.time[data.edge_index[1,:] == node_id] -= min_time

In [82]:
data.node_type

tensor([0, 0, 0,  ..., 1, 1, 1], device='cuda:0')

node_type==1 are resources
let's split randomly by users

In [83]:
time_array = data.time.cpu().numpy()
quantile_70 = np.quantile(time_array, 0.7)
quantile_85 = np.quantile(time_array, 0.85)

# Create masks for splitting the data
train_mask = time_array < quantile_70
val_mask = time_array < quantile_85
test_mask = np.ones_like(time_array, dtype=bool)

# Function to create a new Data object from the original data and a mask
def create_subset(data, mask):
    # Filter data using the mask
    subset = Data()
    for key, item in data:
        if key in ['node_id', 'node_type', 'x']:
            subset[key] = item
        elif key in ['edge_index']:
            subset[key] = item[:,mask]
        else:
            subset[key] = item[mask]
    return subset

# Create subsets
train_data = create_subset(data, train_mask)
val_data = create_subset(data, val_mask)
test_data = create_subset(data, test_mask)

In [84]:
def create_mask_in_batches(edge_index, nodes, batch_size):
    mask = torch.zeros(edge_index.size(1), dtype=torch.bool, device=device)
    for i in range(0, len(nodes), batch_size):
        batch_nodes = nodes[i:i + batch_size]
        batch_mask = ((edge_index[0].unsqueeze(1) == batch_nodes).any(dim=1)) | \
                     ((edge_index[1].unsqueeze(1) == batch_nodes).any(dim=1))
        mask |= batch_mask
    return mask

This class iterates over batches of users and returns all edges they are connected with.

In [85]:
class TemporalLoader:
    def __init__(self, data, start_time, batch_size=32):
        self.data = data
        self.time_order = torch.argsort(self.data.time)
        self.start_time = start_time
        self.start_index = torch.nonzero(self.data.time[self.time_order]>=start_time)[0]
        self.index = self.start_index
        self.batch_size = batch_size
        self.length = len(data.edge_y) - self.start_index

    def __iter__(self):
        return self

    def __len__(self):
        return self.length // self.batch_size

    def reset(self):
        self.index = 0

    def __next__(self):
        if self.index + self.batch_size >= self.length:
            self.index = 0
            raise StopIteration

        # these edges are predicted
        mask = torch.zeros(self.data.edge_index.size(1), dtype=torch.bool).to(device)
        mask[self.time_order[self.index+self.start_index:self.index+self.start_index+self.batch_size]] = 1

        # but add neighbors that may be useful for the prediction = edges containing the same IDs in the time before
        before_mask = torch.zeros(self.data.edge_index.size(1), dtype=torch.bool).to(device)
        before_mask[self.time_order[:self.index+self.start_index]] = 1
        first_time = self.data.time[self.time_order[self.index+self.start_index]]

        edge_nodes = self.data.edge_index[:,mask].unique()
        neighbor_mask = create_mask_in_batches(self.data.edge_index, edge_nodes, 256)

        mask |= (neighbor_mask & before_mask)
        
        batch = create_subset(self.data, mask)
        batch.edge_y[batch.time < first_time] = -1
        self.index += self.batch_size
        return batch

In [86]:
batch_size=512
train_loader = TemporalLoader(train_data, 0, batch_size=batch_size)
val_loader = TemporalLoader(val_data, quantile_70, batch_size=batch_size)
test_loader = TemporalLoader(test_data, quantile_85, batch_size=batch_size)

In [87]:
for batch in val_loader:
    assert batch.time[batch.edge_y>=0].min() >= quantile_70
    assert batch.time[batch.edge_y>=0].max() < quantile_85

In [88]:
for batch in test_loader:
    assert batch.time[batch.edge_y>=0].min() >= quantile_85

In [89]:
class TimeEncoder(torch.nn.Module):
    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        self.time_lin = Linear(1, time_dim)

    def forward(self, t: Tensor):
        return self.time_lin(t.view(-1, 1)).cos()

class GATModel(torch.nn.Module):
    def __init__(self, emb_dim, hidden_channels, out_channels, edge_dim, time_enc, resource_ids):
        super().__init__()
        self.conv1_user_to_resource = GATConv(emb_dim, hidden_channels, edge_dim=edge_dim)
        self.conv1_resource_to_user = GATConv(emb_dim, hidden_channels, edge_dim=edge_dim)

        self.conv2_user_to_resource = GATConv(hidden_channels, out_channels, edge_dim=edge_dim)
        self.conv2_resource_to_user = GATConv(hidden_channels, out_channels, edge_dim=edge_dim)
        self.time_enc = time_enc
        self.mapping = {nid: i+1 for i, nid in enumerate(resource_ids)}
        self.embedding = Embedding(len(resource_ids)+1, emb_dim)

    def forward(self, node_id: Tensor, node_type: Tensor, edge_index: Tensor, edge_attr: Tensor, t: Tensor, edge_type: Tensor) -> Tensor:
        # x: Node feature matrix of shape [num_nodes, in_channels]
        # edge_index: Graph connectivity matrix of shape [2, num_edges]
        # edge_attr: Edge features
        # t: Timestamp of edges
        time_enc = self.time_enc(t)
        edge_attr = torch.cat([time_enc, edge_attr], dim=-1)
        mask = edge_type == 0

        node_ids = torch.tensor([self.mapping.get(nid, 0) for nid in node_id]).to(device)
        x = self.embedding(node_ids)
        x_resources = self.conv1_user_to_resource(x, edge_index[:, mask], edge_attr[mask]).relu()
        x_users = self.conv1_resource_to_user(x, edge_index[:, ~mask], edge_attr[~mask]).relu()
        users = node_type == 0
        h = torch.zeros_like(x_users)
        h[users] = x_users[users]
        h[~users] = x_resources[~users]

        h_resources = self.conv2_user_to_resource(h, edge_index[:, mask], edge_attr[mask]).relu()
        h_users = self.conv2_resource_to_user(h, edge_index[:, ~mask], edge_attr[~mask]).relu()
        o = torch.zeros_like(h_users)
        o[users,:] = h_users[users,:]
        o[~users,:] = h_resources[~users,:]
        
        return o

class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, time_enc):
        super().__init__()
        self.time_enc = time_enc
        dim = 2 * in_channels + time_enc.time_dim # 2*
        self.lin = Linear(dim, dim)
        self.lin_final = Linear(dim, 1)

    def forward(self, z_src, z_dst, t):
        time_enc = self.time_enc(t)
        input = torch.cat([z_src, z_dst, time_enc], dim=-1) # [z_src] # now only user and time embedding
        h = self.lin(input)
        h = h.relu()
        
        return self.lin_final(h)

emb_dim = hidden_dim = time_dim = 50
identity_dim = 50

time_enc = TimeEncoder(time_dim).to(device)

gnn = GATModel(identity_dim, hidden_dim, emb_dim, time_dim+data.edge_attr.size(1), time_enc, data.node_id[data.node_type == 1]).to(device)
link_pred = LinkPredictor(emb_dim, time_enc).to(device)

# Test forward run of the models
batch = next(iter(train_loader)).to(device)
train_loader.reset() # reset
embs = gnn(batch.node_id, batch.node_type, batch.edge_index, batch.edge_attr, batch.time, batch.edge_type)
print(embs.shape)
link_preds = link_pred(embs[batch.edge_index[0]], embs[batch.edge_index[1]], batch.time)
print(link_preds.shape)

torch.Size([7144, 50])
torch.Size([512, 1])


In [90]:
sum(p.numel() for p in gnn.parameters())

26600

In [91]:
@torch.no_grad()
def test(gnn, link_pred, time_enc):
    gnn.eval()
    link_pred.eval()
    time_enc.eval()
    scores = []
    for batch in tqdm(test_loader, desc='Validation Batches'):
        if (batch.edge_y>0).any():
            mask = (batch.edge_y < 0)
            node_embs = gnn(batch.node_id, batch.node_type, batch.edge_index[:,mask], batch.edge_attr[mask], batch.time[mask], batch.edge_type[mask])
            mask = (batch.edge_y >= 0) & (batch.edge_type == 1)
            link_preds = link_pred(node_embs[batch.edge_index[0,mask]], node_embs[batch.edge_index[1,mask]], batch.time[mask]).sigmoid()
            #RocCurveDisplay.from_predictions(batch.edge_y.cpu(), link_preds.cpu())
            #plt.show()
            scores += [roc_auc_score(batch.edge_y[mask].cpu().numpy(), link_preds.cpu().numpy())]
    return np.mean(scores)

In [None]:
optimizer = torch.optim.Adam(set(time_enc.parameters()) | set(gnn.parameters()) | set(link_pred.parameters()), lr=0.001)
criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(0, 20):
    gnn.train()
    link_pred.train()
    time_enc.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc='Training Batches'):
        batch.to(device)
        # forward pass with edges that are older than this batch
        mask = batch.edge_y < 0
        node_embs = gnn(batch.node_id, batch.node_type, batch.edge_index[:,mask], batch.edge_attr[mask], batch.time[mask], batch.edge_type[mask])
        
        # first predict for all dropout edges
        mask = (batch.edge_y == 1) & (batch.edge_type == 1)
        positive_edges = batch.edge_index[:, mask]
        if positive_edges.size(1) == 0: continue
        pos_out = link_pred(node_embs[positive_edges[0]], node_embs[positive_edges[1]], batch.time[mask])
        loss = criterion(pos_out, torch.ones_like(pos_out))

        # sample the same number of 0 edges from edge_index
        negative_indices = torch.nonzero((batch.edge_y == 0) & (batch.edge_type == 1)).squeeze()
        negative_indices = negative_indices[torch.randperm(negative_indices.size(0))][:positive_edges.size(1)]
        negative_edges = batch.edge_index[:, negative_indices]
        neg_out = link_pred(node_embs[negative_edges[0]], node_embs[negative_edges[1]], batch.time[negative_indices])
        loss += criterion(neg_out, torch.zeros_like(neg_out))

        # backward pass and gradient update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss

    test_auc = test(gnn, link_pred, time_enc)
    print(f'Epoch: {epoch}, Train Loss: {total_loss / train_data.edge_y.sum()}, Test AUC: {test_auc}')
    

Training Batches: 100%|██████████| 1125/1125 [00:49<00:00, 22.96it/s]
Validation Batches: 100%|██████████| 241/241 [00:07<00:00, 34.34it/s]


Epoch: 0, Train Loss: 0.20640122890472412, Test AUC: 0.4833717630479497


Training Batches: 100%|██████████| 1125/1125 [00:48<00:00, 23.33it/s]
Validation Batches: 100%|██████████| 241/241 [00:06<00:00, 34.80it/s]


Epoch: 1, Train Loss: 0.19788284599781036, Test AUC: 0.6908669020238983


Training Batches: 100%|██████████| 1125/1125 [00:48<00:00, 23.36it/s]
Validation Batches:  89%|████████▉ | 214/241 [00:06<00:00, 40.90it/s]

In [None]:
print(torch.cuda.memory_summary())