In [1]:
import torch
import numpy as np
import torch_geometric
from torch import Tensor
from torch.nn import Linear
from torch_geometric.data import Data, Dataset
from torch_geometric.nn.conv import GATConv
from torch_geometric.utils import negative_sampling
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score, RocCurveDisplay
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc

# Seed torch's global RNG so that the random node features and the model
# initialisation are the same after every "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Collect unreachable Python objects first so the tensors they own are
# released, and only then ask the CUDA caching allocator to give back its
# unused blocks. empty_cache() is the last statement so the cell shows no
# stray return value.
gc.collect()
torch.cuda.empty_cache()

Using device: cuda


0

In [3]:
# Print the CUDA allocator statistics. The notebook can also run on the CPU
# (see the `device` selection above), in which case there is nothing to report.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [4]:
# Load the pre-built graph object from disk. The commented-out line is the
# alternative input file (data/act-mooc); the active one is data/junyi.
# NOTE(review): torch.load unpickles the file, so only load graph files that
# come from a trusted source.
#data = torch.load("data/act-mooc/graph.pt")
data = torch.load("data/junyi/graph.pt")

In [5]:
# Convert the loaded graph into its homogeneous representation. The result
# carries `node_type` / `edge_type` tensors (see the printed Data object in
# the next cell), which the rest of the notebook uses to tell node kinds apart.
data_hom = data.to_homogeneous()

In [6]:
# From here on `data` refers to the homogeneous graph.
data = data_hom
# The graph has no node features, so every node gets a single placeholder
# feature drawn from a standard normal distribution.
data.x = torch.randn(data.num_nodes, 1)
# Cast edge attributes and timestamps to float; both are later fed through
# float-valued layers (the time encoder and the edge-feature concatenation).
data.edge_attr = data.edge_attr.float()
data.time = data.time.float()
# Move every tensor of the graph to the selected device.
data = data.to(device)
data

Data(edge_index=[2, 32434622], node_id=[74460], edge_attr=[32434622, 6], time=[32434622], edge_y=[32434622], node_type=[74460], edge_type=[32434622], x=[74460, 1])

Nodes with `node_type == 1` are resources; nodes with `node_type == 0` are users.
Let's split the data randomly by user.

In [7]:
# Shuffle with a dedicated, seeded generator so the train/val/test split is
# identical on every run, independent of how much global RNG state was
# consumed by earlier cells.
SPLIT_SEED = 42
split_generator = torch.Generator().manual_seed(SPLIT_SEED)

# User nodes are identified by their position in the homogeneous graph
# (node_type == 0). Positions are what `edge_index` refers to, so the masks
# built from these ids are correct regardless of what `node_id` stores.
user_nodes = torch.nonzero(data.node_type == 0).view(-1)
shuffle = torch.randperm(user_nodes.size(0), generator=split_generator)
user_nodes = user_nodes[shuffle.to(user_nodes.device)]
# NOTE(review): this still uses `node_id`; if `node_id` holds per-type ids
# these are not positions in `edge_index` - confirm before using it for masks.
resource_nodes = data.node_id[data.node_type==1]

# 70 % train / 15 % validation / 15 % test, split by user.
lim_train = int(user_nodes.size(0) * 0.7)
lim_val = int(user_nodes.size(0) * 0.85)
train_nodes = user_nodes[:lim_train]
val_nodes = user_nodes[lim_train:lim_val]
test_nodes = user_nodes[lim_val:]

# The three parts must partition the user set.
assert train_nodes.size(0) + val_nodes.size(0) + test_nodes.size(0) == user_nodes.size(0)

In [8]:
def create_mask(edge_index, nodes):
    """Return a boolean mask selecting the edges that touch any of `nodes`.

    Args:
        edge_index: Integer tensor of shape [2, num_edges] holding the source
            (row 0) and destination (row 1) node of every edge.
        nodes: Node ids to select; a tensor or any sequence of integers.

    Returns:
        Bool tensor of shape [num_edges], located on `edge_index.device`.
        Entry i is True when the source or the destination of edge i is
        contained in `nodes`.
    """
    # Normalise `nodes` to a flat tensor that lives next to `edge_index`, so the
    # function does not depend on any module-level device setting.
    nodes = torch.as_tensor(nodes, dtype=edge_index.dtype, device=edge_index.device).reshape(-1)
    if nodes.numel() == 0:
        return torch.zeros(edge_index.size(1), dtype=torch.bool, device=edge_index.device)
    # One vectorised membership test per endpoint instead of one full edge scan
    # per node.
    return torch.isin(edge_index[0], nodes) | torch.isin(edge_index[1], nodes)

def create_mask_in_batches(edge_index, nodes, batch_size=None):
    """Return a boolean mask selecting the edges that touch any of `nodes`.

    Args:
        edge_index: Integer tensor of shape [2, num_edges] holding the source
            (row 0) and destination (row 1) node of every edge.
        nodes: Node ids to select; a tensor or any sequence of integers.
        batch_size: Accepted for backward compatibility and ignored. The
            membership test no longer needs to be chunked by hand.

    Returns:
        Bool tensor of shape [num_edges], located on `edge_index.device`.
        Entry i is True when the source or the destination of edge i is
        contained in `nodes`.
    """
    # Keep everything on the device of `edge_index` rather than on a
    # module-level device.
    nodes = torch.as_tensor(nodes, dtype=edge_index.dtype, device=edge_index.device).reshape(-1)
    if nodes.numel() == 0:
        return torch.zeros(edge_index.size(1), dtype=torch.bool, device=edge_index.device)
    # torch.isin tests membership directly, so no explicit
    # (num_edges x batch_size) comparison matrix is built here.
    return torch.isin(edge_index[0], nodes) | torch.isin(edge_index[1], nodes)

# One edge mask per split: an edge belongs to a split when one of its
# endpoints is a user assigned to that split.
train_mask, val_mask, test_mask = (
    create_mask_in_batches(data.edge_index, split_nodes, batch_size=256)
    for split_nodes in (train_nodes, val_nodes, test_nodes)
)

def create_subset(data, mask):
    """Build a new `Data` object that keeps all nodes but only the masked edges.

    Args:
        data: Graph to take the attributes from.
        mask: Boolean (or index) selector over the edge dimension.

    Returns:
        A fresh `Data` object. 'node_id', 'node_type' and 'x' are copied by
        reference, 'edge_index' is filtered along its second dimension, and
        every other attribute is filtered along its first dimension.
    """
    node_level_keys = ('node_id', 'node_type', 'x')
    subset = Data()
    for key, item in data:
        if key == 'edge_index':
            # Edges are stored column-wise in edge_index.
            subset[key] = item[:, mask]
            continue
        # Node-level attributes are shared unchanged; anything else is treated
        # as a per-edge attribute.
        subset[key] = item if key in node_level_keys else item[mask]
    return subset

# Materialise one edge-filtered graph per split.
train_data, val_data, test_data = (
    create_subset(data, split_mask)
    for split_mask in (train_mask, val_mask, test_mask)
)

This class iterates over batches of users and, for each batch, returns the subgraph containing all edges incident to those users.

In [9]:
class UserLoader:
    """Iterate over a graph in batches of user nodes.

    Each step takes the next `batch_size` entries of `user_nodes` and yields
    the subgraph of `data` made of every edge that touches one of them.

    Args:
        data: Graph to draw the edges from.
        user_nodes: Sequence (or 1-D tensor) of user node ids.
        batch_size: Number of users per batch; must be positive.

    Raises:
        ValueError: If `batch_size` is not positive.
    """

    def __init__(self, data, user_nodes, batch_size=32):
        if batch_size <= 0:
            # A non-positive step would never advance the cursor.
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.data = data
        self.index = 0
        self.user_nodes = user_nodes
        self.batch_size = batch_size

    def __len__(self):
        """Number of batches in one full pass (the last one may be smaller)."""
        return (len(self.user_nodes) + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        # Always start a new pass from the first user. Without this, a pass
        # that was abandoned half-way (e.g. `next(iter(loader))` or an
        # interrupted epoch) would make the next loop resume mid-way.
        self.index = 0
        return self

    def __next__(self):
        if self.index >= len(self.user_nodes):
            self.index = 0
            raise StopIteration

        selected_nodes = self.user_nodes[self.index:self.index+self.batch_size]
        mask = create_mask_in_batches(self.data.edge_index, selected_nodes, batch_size=256)
        batch = create_subset(self.data, mask)
        self.index += self.batch_size
        return batch

In [10]:
# Number of users per batch, shared by all three loaders.
batch_size = 2048
train_loader, val_loader, test_loader = (
    UserLoader(split_data, split_nodes, batch_size=batch_size)
    for split_data, split_nodes in (
        (train_data, train_nodes),
        (val_data, val_nodes),
        (test_data, test_nodes),
    )
)

In [11]:
class TimeEncoder(torch.nn.Module):
    """Map scalar timestamps to `time_dim`-dimensional vectors.

    Each timestamp is passed through a learnable linear layer (1 -> time_dim)
    followed by an element-wise cosine.
    """

    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        self.time_lin = Linear(in_features=1, out_features=time_dim)

    def forward(self, t: Tensor):
        # Flatten to a column so that every timestamp is one input row.
        column = t.view(-1, 1)
        projected = self.time_lin(column)
        return torch.cos(projected)

class GATModel(torch.nn.Module):
    """Two-layer GAT encoder that uses time-aware edge features.

    Args:
        in_channels: Size of the input node features.
        hidden_channels: Output size of the first GATConv layer.
        out_channels: Output size of the second GATConv layer.
        edge_dim: Size of the edge features handed to both layers, i.e. the
            time encoding concatenated with the raw edge attributes.
        time_enc: Module that turns edge timestamps into feature vectors.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, edge_dim, time_enc):
        super().__init__()
        self.conv1 = GATConv(in_channels, hidden_channels, edge_dim=edge_dim)
        self.conv2 = GATConv(hidden_channels, out_channels, edge_dim=edge_dim)
        self.time_enc = time_enc

    def forward(self, x: Tensor, edge_index: Tensor, edge_attr: Tensor, t: Tensor) -> Tensor:
        """Compute node embeddings.

        Args:
            x: Node feature matrix of shape [num_nodes, in_channels].
            edge_index: Graph connectivity matrix of shape [2, num_edges].
            edge_attr: Edge features.
            t: Timestamp of edges.
        """
        # Prepend the encoded timestamp to every edge's raw attributes.
        temporal = self.time_enc(t)
        edge_features = torch.cat((temporal, edge_attr), dim=-1)
        hidden = torch.relu(self.conv1(x, edge_index, edge_features))
        return self.conv2(hidden, edge_index, edge_features)

class LinkPredictor(torch.nn.Module):
    """Score an edge from its two endpoint embeddings and its timestamp.

    Args:
        in_channels: Size of a single node embedding.
        time_enc: Module that encodes timestamps; must expose `time_dim`.
    """

    def __init__(self, in_channels, time_enc):
        super().__init__()
        self.time_enc = time_enc
        # Source embedding + destination embedding + time encoding.
        width = in_channels * 2 + time_enc.time_dim
        self.lin = Linear(width, width)
        self.lin_final = Linear(width, 1)

    def forward(self, z_src, z_dst, t):
        """Return one raw (pre-sigmoid) score per edge."""
        features = torch.cat((z_src, z_dst, self.time_enc(t)), dim=-1)
        hidden = torch.relu(self.lin(features))
        return self.lin_final(hidden)

# Embedding size, hidden size and time-encoding size are all set to 100.
emd_dim = hidden_dim = time_dim = 100

# A single time encoder instance is shared by the GNN and the link predictor.
time_enc = TimeEncoder(time_dim).to(device)
gnn = GATModel(data.x.size(1), hidden_dim, emd_dim, time_dim+data.edge_attr.size(1), time_enc).to(device)
link_pred = LinkPredictor(emd_dim, time_enc).to(device)

# Test forward run of the models
batch = next(iter(train_loader)).to(device)
train_loader = UserLoader(train_data, train_nodes, batch_size=batch_size) # reset
# This is only a shape check, so run it without autograd: otherwise the
# module-level `embs` / `link_preds` would keep the whole computation graph of
# a full batch alive in device memory.
with torch.no_grad():
    embs = gnn(batch.x, batch.edge_index, batch.edge_attr, batch.time)
    print(embs.shape)
    link_preds = link_pred(embs[batch.edge_index[0]], embs[batch.edge_index[1]], batch.time)
    print(link_preds.shape)

torch.Size([74460, 100])
torch.Size([857684, 1])


In [12]:
@torch.no_grad()
def test(gnn, link_pred, loader=None):
    """Evaluate the models and return the mean per-batch ROC AUC.

    Args:
        gnn: Node-embedding model; switched to eval mode.
        link_pred: Edge-scoring model; switched to eval mode.
        loader: Iterable of batches to evaluate on. Defaults to the
            module-level `val_loader`.

    Returns:
        Mean of the ROC AUC values computed per batch, or NaN when no batch
        could be scored.
    """
    if loader is None:
        loader = val_loader
    gnn.eval()
    link_pred.eval()
    scores = []
    for batch in tqdm(loader, desc='Validation Batches'):
        labels = batch.edge_y
        # ROC AUC is undefined when only one class is present and
        # roc_auc_score raises in that case, so such batches are skipped.
        if torch.unique(labels).numel() < 2:
            continue
        node_embs = gnn(batch.x, batch.edge_index, batch.edge_attr, batch.time)
        link_preds = link_pred(node_embs[batch.edge_index[0]], node_embs[batch.edge_index[1]], batch.time).sigmoid()
        #RocCurveDisplay.from_predictions(batch.edge_y.cpu(), link_preds.cpu())
        #plt.show()
        # Scores come out as [num_edges, 1]; flatten them to match the labels.
        scores.append(roc_auc_score(labels.cpu().numpy(), link_preds.view(-1).cpu().numpy()))
    if not scores:
        return float('nan')
    return np.mean(scores)

In [16]:
# The time encoder is a sub-module of both `gnn` and `link_pred`, so its
# parameters show up several times. Keep each parameter once, in a stable
# order (a plain set would make the optimizer's parameter order differ from
# run to run).
trainable_params = list({
    id(param): param
    for module in (time_enc, gnn, link_pred)
    for param in module.parameters()
}.values())
optimizer = torch.optim.Adam(trainable_params, lr=0.0001)
criterion = torch.nn.BCEWithLogitsLoss()

for epoch in range(0, 20):
    gnn.train()
    link_pred.train()
    # Loss of the last trained batch; stays NaN if every batch was skipped.
    loss = float('nan')

    for batch in tqdm(train_loader, desc='Training Batches'):
        batch.to(device)

        # view(-1) keeps these 1-D even when a class has exactly one edge.
        positive_indices = torch.nonzero(batch.edge_y == 1).view(-1)
        negative_indices = torch.nonzero(batch.edge_y == 0).view(-1)
        # A batch without positives or without negatives would produce a NaN
        # loss (mean over an empty tensor) and corrupt the weights.
        if positive_indices.numel() == 0 or negative_indices.numel() == 0:
            continue

        # forward pass
        node_embs = gnn(batch.x, batch.edge_index, batch.edge_attr, batch.time)

        # first predict for all dropout edges
        positive_edges = batch.edge_index[:, positive_indices]
        pos_out = link_pred(node_embs[positive_edges[0]], node_embs[positive_edges[1]], batch.time[positive_indices])
        loss = criterion(pos_out, torch.ones_like(pos_out))

        # sample the same number of 0 edges from edge_index
        shuffle = torch.randperm(negative_indices.size(0), device=negative_indices.device)
        negative_indices = negative_indices[shuffle][:positive_edges.size(1)]
        negative_edges = batch.edge_index[:, negative_indices]
        neg_out = link_pred(node_embs[negative_edges[0]], node_embs[negative_edges[1]], batch.time[negative_indices])
        loss = loss + criterion(neg_out, torch.zeros_like(neg_out))

        # backward pass and gradient update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # NOTE: test() evaluates on the validation loader, so the value printed as
    # "Test AUC" is the validation AUC.
    test_auc = test(gnn, link_pred)
    print(f'Epoch: {epoch}, Train Loss: {loss}, Test AUC: {test_auc}')
    

Training Batches: 20it [00:15,  1.28it/s]
Validation Batches: 6it [00:02,  2.02it/s]


Epoch: 0, Train Loss: 1.3535363674163818, Test AUC: 0.6941514756658488


Training Batches: 25it [00:19,  1.28it/s]
Validation Batches: 6it [00:02,  2.03it/s]


Epoch: 1, Train Loss: 1.3568916320800781, Test AUC: 0.699999510903583


Training Batches: 25it [00:19,  1.28it/s]
Validation Batches: 6it [00:02,  2.02it/s]


Epoch: 2, Train Loss: 1.3589524030685425, Test AUC: 0.687752449460473


Training Batches: 1it [00:01,  1.45s/it]


KeyboardInterrupt: 

In [13]:
# Print the CUDA allocator statistics after training. The notebook can also
# run on the CPU (see the `device` selection above), in which case there is
# nothing to report.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |   8863 MiB |  10822 MiB |   4524 GiB |   4515 GiB |
|       from large pool |   8858 MiB |  10818 MiB |   4524 GiB |   4515 GiB |
|       from small pool |      4 MiB |      6 MiB |      0 GiB |      0 GiB |
|---------------------------------------------------------------------------|
| Active memory         |   8863 MiB |  10822 MiB |   4524 GiB |   4515 GiB |
|       from large pool |   8858 MiB |  10818 MiB |   4524 GiB |   4515 GiB |
|       from small pool |      4 MiB |      6 MiB |      0 GiB |      0 GiB |
|---------------------------------------------------------------