In [13]:
import math
import logging
import time
import sys
import argparse
import torch
import numpy as np
import pickle
from pathlib import Path

from evaluation.evaluation import eval_edge_prediction
from model.tgn import TGN
from utils.utils import EarlyStopMonitor, RandEdgeSampler, get_neighbor_finder
from utils.data_processing import get_data, compute_time_statistics

In [14]:
# Random Seeds for Reproducability
torch.manual_seed(0)
np.random.seed(0)
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x31561c8b0>

In [15]:

args = {
    "data": "wikipedia",
    "bs": 200,
    "prefix": "",
    "n_degree": 10,
    "n_head": 2,
    "n_epoch": 50,
    "n_layer": 1,
    "lr": 0.0001,
    "patience": 5,
    "n_runs": 1,
    "drop_out": 0.1,
    "gpu": 0,
    "node_dim": 100,
    "time_dim": 100,
    "backprop_every": 1,
    "use_memory": True,
    "embedding_module": "graph_attention",
    "message_function": "identity",
    "memory_updater": "gru",
    "aggregator": "last",
    "memory_update_at_end": False,
    "message_dim": 100,
    "memory_dim": 172,
    "different_new_nodes": False,
    "uniform": False,
    "randomize_features": False,
    "use_destination_embedding_in_message": False,
    "use_source_embedding_in_message": False,
    "dyrep": False,
}

# Extract key parameters for easier use
BATCH_SIZE = args["bs"]
NUM_NEIGHBORS = args["n_degree"]
NUM_EPOCH = args["n_epoch"]
NUM_HEADS = args["n_head"]
DROP_OUT = args["drop_out"]
GPU = args["gpu"]
DATA = args["data"]
NUM_LAYER = args["n_layer"]
LEARNING_RATE = args["lr"]
NODE_DIM = args["node_dim"]
TIME_DIM = args["time_dim"]
USE_MEMORY = args["use_memory"]
MESSAGE_DIM = args["message_dim"]
MEMORY_DIM = args["memory_dim"]

# Configure device
device = torch.device(f"cuda:{GPU}" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [21]:
train_rand_sampler = RandEdgeSampler(train_data.sources, train_data.destinations)
val_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=0)
nn_val_rand_sampler = RandEdgeSampler(new_node_val_data.sources, new_node_val_data.destinations,
                                      seed=1)
test_rand_sampler = RandEdgeSampler(full_data.sources, full_data.destinations, seed=2)
nn_test_rand_sampler = RandEdgeSampler(new_node_test_data.sources,
                                       new_node_test_data.destinations,
                                       seed=3)

In [16]:
# Load data
node_features, edge_features, full_data, train_data, val_data, test_data, new_node_val_data, new_node_test_data = get_data(
    DATA,
    different_new_nodes_between_val_and_test=args["different_new_nodes"],
    randomize_features=args["randomize_features"],
)

# Initialize neighbor finders
train_ngh_finder = get_neighbor_finder(train_data, args["uniform"])
full_ngh_finder = get_neighbor_finder(full_data, args["uniform"])

The dataset has 157474 interactions, involving 9227 different nodes
The training dataset has 81029 interactions, involving 6141 different nodes
The validation dataset has 23621 interactions, involving 3256 different nodes
The test dataset has 23621 interactions, involving 3564 different nodes
The new node validation dataset has 12016 interactions, involving 2120 different nodes
The new node test dataset has 11715 interactions, involving 2437 different nodes
922 nodes were used for the inductive testing, i.e. are never seen during training


In [37]:
# Initialize the TGN model
tgn = TGN(
    neighbor_finder=train_ngh_finder,
    node_features=node_features,
    edge_features=edge_features,
    device=device,
    n_layers=NUM_LAYER,
    n_heads=NUM_HEADS,
    dropout=DROP_OUT,
    use_memory=USE_MEMORY,
    message_dimension=MESSAGE_DIM,
    memory_dimension=MEMORY_DIM,
    memory_update_at_start=not args["memory_update_at_end"],
    embedding_module_type=args["embedding_module"],
    message_function=args["message_function"],
    aggregator_type=args["aggregator"],
    memory_updater_type=args["memory_updater"],
    n_neighbors=NUM_NEIGHBORS,
)

# Set up loss and optimizer
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(tgn.parameters(), lr=LEARNING_RATE)

# Move model to device
tgn = tgn.to(device)

num_batch = math.ceil(len(train_data.sources) / BATCH_SIZE)

In [36]:
# Training loop
new_nodes_val_aps = []
val_aps = []
epoch_times = []
total_epoch_times = []
train_losses = []

early_stopper = EarlyStopMonitor(max_round=args["patience"])
for epoch in range(NUM_EPOCH):
    start_epoch = time.time()
    
    # Initialize memory
    if USE_MEMORY:
        tgn.memory.__init_memory__()
    
    tgn.set_neighbor_finder(train_ngh_finder)
    m_loss = []
    print(f"Starting epoch {epoch}")
    for k in range(0, num_batch, args["backprop_every"]):
        loss = 0
        optimizer.zero_grad()
        for j in range(args["backprop_every"]):
            batch_idx = k + j

            if batch_idx >= num_batch:
                continue

            start_idx = batch_idx * BATCH_SIZE
            end_idx = min(len(train_data.sources), start_idx + BATCH_SIZE)
            
            sources_batch = train_data.sources[start_idx:end_idx]
            destinations_batch = train_data.destinations[start_idx:end_idx]
            edge_idxs_batch = train_data.edge_idxs[start_idx: end_idx]
            timestamps_batch = train_data.timestamps[start_idx:end_idx]

            size = len(sources_batch)
            _, negatives_batch = train_rand_sampler.sample(size)

            with torch.no_grad():
                pos_label = torch.ones(size, dtype=torch.float, device=device)
                neg_label = torch.zeros(size, dtype=torch.float, device=device)
            tgn = tgn.train()
            pos_prob, neg_prob = tgn.compute_edge_probabilities(sources_batch, destinations_batch, negatives_batch,
                                                            timestamps_batch, edge_idxs_batch, NUM_NEIGHBORS)
            loss += criterion(pos_prob.squeeze(), pos_label) + criterion(neg_prob.squeeze(), neg_label)
        loss /= args["backprop_every"]
        loss.backward()
        optimizer.step()
        m_loss.append(loss.item())

        if USE_MEMORY:
            tgn.memory.detach_memory()

        epoch_time = time.time() - start_epoch
        epoch_times.append(epoch_time)

        ### Validation
        tgn.set_neighbor_finder(full_ngh_finder)
        if USE_MEMORY:
            train_memory_backup = tgn.memory.backup_memory()
        val_ap, val_auc = eval_edge_prediction(model=tgn,
                                                negative_edge_sampler=val_rand_sampler,
                                                data=val_data,
                                                n_neighbors=NUM_NEIGHBORS)
        if USE_MEMORY:
            val_memory_backup = tgn.memory.backup_memory()
            tgn.memory.restore_memory(train_memory_backup)
        nn_val_ap, nn_val_auc = eval_edge_prediction(model=tgn,
                                    negative_edge_sampler=val_rand_sampler,
                                    data=new_node_val_data,
                                    n_neighbors=NUM_NEIGHBORS)
        if USE_MEMORY:
            tgn.memory.restore_memory(val_memory_backup)

        new_nodes_val_aps.append(nn_val_ap)
        val_aps.append(val_ap)
        train_losses.append(np.mean(m_loss))
            
        total_epoch_time = time.time() - start_epoch
        total_epoch_times.append(total_epoch_time)
    
    print(f"Epoch {epoch} completed in f{total_epoch_time:.2f}s Loss: {np.mean(epoch_loss):.4f}")
    print(f"val auc: {val_auc}, new node val auc: {nn_val_auc}")
    print(f"val ap: {val_ap}, new node val ap: {nn_val_ap}")

    if early_stopper.early_stop_check(val_ap):
        print('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
        print(f'Loading the best model at epoch {early_stopper.best_epoch}')
        best_model_path = get_checkpoint_path(early_stopper.best_epoch)
        tgn.load_state_dict(torch.load(best_model_path))
        print(f'Loaded the best model at epoch {early_stopper.best_epoch} for inference')
        tgn.eval()
        break
    else:
        torch.save(tgn.state_dict(), get_checkpoint_path(epoch))

    if USE_MEMORY:
        val_memory_backup = tgn.memory.backup_memory()

Starting epoch 0


AssertionError: Trying to update memory to time in the past

In [20]:
!which python
!which jupyter-notebook

/Users/jonathansneh/.pyenv/versions/3.9.7/bin/python
/Users/jonathansneh/.pyenv/versions/3.9.7/bin/jupyter-notebook
