In [None]:
# Mount Google Drive so files under /content/drive are readable and writable
# from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Work from the project folder on Drive; the relative ./data paths used by the
# data-preparation cells resolve against this directory.
cd /content/drive/MyDrive/mamba/

/content/drive/MyDrive/mamba


In [None]:
# Install DGL from the wheel index for torch 2.2 / CUDA 12.1 (see the URL).
# Earlier attempts against other wheel indexes are kept below for reference.
## !pip install  dgl -f https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html
## !pip install dgl-cu121
## !pip install  dgl -f https://data.dgl.ai/wheels/torch-2.1/cu118/repo.html
!pip install  dgl -f https://data.dgl.ai/wheels/torch-2.2/cu121/repo.html

Looking in links: https://data.dgl.ai/wheels/torch-2.2/cu121/repo.html


In [None]:
!pip install causal-conv1d>=1.2.0
!pip install mamba-ssm



In [None]:
##pip install  dgl -f https://data.dgl.ai/wheels/torch-2.1/cu121/repo.html
# !pip install dgl -f https://data.dgl.ai/wheels/cu121/repo.html.

In [None]:
import os
os.environ["DGLBACKEND"] = "pytorch"
import ssl
import time
from six.moves import urllib
import numpy as np
import pandas as pd
import torch
import dgl
import copy
import argparse
import inspect
from dgl.dataloading import Sampler
from sklearn.metrics import average_precision_score, roc_auc_score

In [None]:
# Experiment hyper-parameters. They are declared through argparse but parsed
# from an empty argument list, so inside the notebook every option takes its
# default; edit the defaults (or assign to `args.<name>`) to change a setting.
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=50,
                    help='epochs for training on entire dataset')
parser.add_argument("--num_negative_samples", type=int, default=1,
                    help="number of negative samplers per positive samples")
parser.add_argument("--fast_mode", action="store_true", default=False,
                    help="Fast Mode uses batch temporal sampling, history within same batch cannot be obtained")
parser.add_argument("--n_neighbors", type=int, default=12,
                    help="number of neighbors while doing embedding")
parser.add_argument("--batch_size", type=int, default=1024,
                    help="Size of each batch")

parser.add_argument("--embedding_dim", type=int, default=100,
                    help="Embedding dim for link prediction")
parser.add_argument("--memory_dim", type=int, default=100,
                    help="dimension of memory")
parser.add_argument("--temporal_dim", type=int, default=100,
                    help="Temporal dimension for time encoding")
parser.add_argument("--num_heads", type=int, default=8,
                    help="Number of heads for multihead attention mechanism")
parser.add_argument("--memory_updater", type=str, default='ssm',
                    help="Recurrent unit for memory update")
parser.add_argument("--k_hop", type=int, default=1,
                    help="sampling k-hop neighborhood")
# The previous help text ("Enable memory ... disable memory ...") contradicted
# itself; the flag name and store_true action mean: set it to turn memory off.
parser.add_argument("--not_use_memory", action="store_true", default=False,
                    help="Disable the memory module of the TGN model")
# Not currently exposed:
# parser.add_argument("--aggregator", type=str, default='last',
#                     help="Aggregation method for memory update")

# args=[] stops argparse from reading the kernel's own command line.
args = parser.parse_args(args=[])

# Base name of the dataset; used to build the ./data/<name>.* file paths.
data_file = "wikipedia"

In [None]:
def preprocess(data_name):
    """Parse a comma-separated interaction file into an edge table and features.

    The first line of the file is skipped as a header. Every following line is
    split on commas and interpreted as:
    column 0 -> int source id, column 1 -> int destination id,
    column 2 -> float timestamp, column 3 -> float label,
    columns 4.. -> float feature values.

    Parameters
    ----------
    data_name : str
        Path of the file to read.

    Returns
    -------
    tuple
        ``(frame, features)`` where ``frame`` is a ``pandas.DataFrame`` with
        columns ``u``, ``i``, ``ts``, ``label`` and ``idx`` (``idx`` is the
        0-based position of the row after the header) and ``features`` is a
        ``numpy`` array built from the per-row feature vectors.
    """
    columns = {'u': [], 'i': [], 'ts': [], 'label': [], 'idx': []}
    edge_features = []

    with open(data_name) as handle:
        next(handle)  # discard the header row
        for row_number, raw_line in enumerate(handle):
            fields = raw_line.strip().split(',')

            columns['u'].append(int(fields[0]))
            columns['i'].append(int(fields[1]))
            columns['ts'].append(float(fields[2]))
            columns['label'].append(float(fields[3]))
            columns['idx'].append(row_number)

            edge_features.append(np.array([float(value) for value in fields[4:]]))

    return pd.DataFrame(columns), np.array(edge_features)


In [None]:
def reindex(df, bipartite=True):
    """Return a copy of ``df`` with node and edge ids shifted to start at 1.

    Parameters
    ----------
    df : pandas.DataFrame
        Edge table with at least the columns ``u``, ``i`` and ``idx``.
    bipartite : bool
        When true, ``u`` and ``i`` are treated as two separate id spaces:
        both must be contiguous (checked with assertions) and ``i`` is first
        offset by ``max(u) + 1`` so the two ranges do not overlap.

    Returns
    -------
    pandas.DataFrame
        A new frame; ``df`` itself is not modified.
    """
    shifted = df.copy()

    if bipartite:
        # Each side must use a gap-free id range.
        for column in ('u', 'i'):
            ids = df[column]
            assert (ids.max() - ids.min() + 1 == len(ids.unique()))

        # Place destination ids directly after the largest source id.
        shifted['i'] = df['i'] + (df['u'].max() + 1)

    # The +1 shift is applied in both modes.
    for column in ('u', 'i', 'idx'):
        shifted[column] = shifted[column] + 1

    return shifted

In [None]:
def run(data_name, bipartite=True, node_feat_dim=172):
    """Preprocess ``./data/<data_name>.csv`` and write the derived files.

    Reads the raw CSV with ``preprocess``, shifts ids with ``reindex`` and
    writes three files next to it:

    * ``./data/ml_<data_name>.csv``      - the re-indexed edge table,
    * ``./data/ml_<data_name>.npy``      - edge features with an all-zero row
      prepended (row 0, matching the 1-based ``idx`` produced by ``reindex``),
    * ``./data/ml_<data_name>_node.npy`` - an all-zero node feature matrix
      with ``max node id + 1`` rows.

    Parameters
    ----------
    data_name : str
        Dataset base name (without directory or extension).
    bipartite : bool
        Forwarded to ``reindex``.
    node_feat_dim : int
        Number of columns of the zero node-feature matrix. Defaults to 172,
        the value that used to be hard-coded here.
    """
    PATH = './data/{}.csv'.format(data_name)
    OUT_DF = './data/ml_{}.csv'.format(data_name)
    OUT_FEAT = './data/ml_{}.npy'.format(data_name)
    OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)

    df, feat = preprocess(PATH)
    new_df = reindex(df, bipartite)

    # Prepend a zero row so feature row k lines up with the 1-based edge idx k.
    empty = np.zeros(feat.shape[1])[np.newaxis, :]
    feat = np.vstack([empty, feat])

    # Node features are placeholders: all zeros, one row per node id.
    max_idx = max(new_df.u.max(), new_df.i.max())
    rand_feat = np.zeros((max_idx + 1, node_feat_dim))

    new_df.to_csv(OUT_DF)
    np.save(OUT_FEAT, feat)
    np.save(OUT_NODE_FEAT, rand_feat)

In [None]:
def TemporalDataset(dataset, force_reload = False):
    """Load the temporal interaction graph for ``dataset``, building it if needed.

    If ``./data/<dataset>.bin`` exists and ``force_reload`` is false, the
    cached DGL graph is loaded from it. Otherwise the raw CSV is downloaded
    (only when ``./data/<dataset>.csv`` is missing), preprocessed with
    ``run``, converted to a directed DGL graph and saved to the ``.bin`` file.

    Parameters
    ----------
    dataset : str
        Dataset base name; also used in the download URL.
    force_reload : bool
        Rebuild the graph even when a cached ``.bin`` file exists.

    Returns
    -------
    dgl.DGLGraph
        Graph with edge data ``timestamp``, ``label`` and ``feats`` and node
        data ``dgl.NID``. On a fresh build ``dgl.NID`` is set to the node ids;
        on a cache hit the graph is returned exactly as stored.
    """
    bin_path = './data/{}.bin'.format(dataset)
    csv_path = './data/{}.csv'.format(dataset)

    if not force_reload and os.path.exists(bin_path):
        print("Data is exist directly loaded.")
        gs, _ = dgl.load_graphs(bin_path)
        return gs[0]

    if not os.path.exists(csv_path):
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs('./data', exist_ok=True)

        url = 'https://snap.stanford.edu/jodie/{}.csv'.format(dataset)
        print("Start Downloading File....")
        # NOTE(review): certificate verification is disabled here, so the
        # download is not authenticated. Kept as-is to avoid breaking
        # environments that rely on it; prefer ssl.create_default_context().
        context = ssl._create_unverified_context()
        # Read the whole payload before touching the destination, and close
        # the response afterwards. A failed download then cannot leave a
        # truncated CSV that a later run would mistake for a valid one.
        with urllib.request.urlopen(url, context=context) as response:
            payload = response.read()
        with open(csv_path, "wb") as handle:
            handle.write(payload)

    print("Start Process Data ...")
    run(dataset)
    raw_connection = pd.read_csv('./data/ml_{}.csv'.format(dataset))
    raw_feature = np.load('./data/ml_{}.npy'.format(dataset))
    # run() stores 1-based node ids; subtract 1 to get 0-based graph node ids.
    src = raw_connection['u'].to_numpy()-1
    dst = raw_connection['i'].to_numpy()-1
    # Create directed graph
    g = dgl.graph((src, dst))
    g.edata['timestamp'] = torch.from_numpy(
        raw_connection['ts'].to_numpy())
    g.edata['label'] = torch.from_numpy(raw_connection['label'].to_numpy())
    # Row 0 of the saved features is the zero padding row added by run().
    g.edata['feats'] = torch.from_numpy(raw_feature[1:, :]).float()
    g.ndata[dgl.NID] = g.nodes()
    dgl.save_graphs(bin_path, [g])
    return g

In [None]:
# Build the temporal graph for the configured dataset. force_reload=True
# re-runs preprocessing and rewrites ./data/<name>.bin on every execution;
# pass force_reload=False to reuse the cached graph instead.
data =  TemporalDataset(data_file, force_reload=True)

Start Process Data ...


In [None]:
# Sanity check: node ids stored on the graph and the (src, dst) endpoint
# tensors of its edges.
print(data.ndata[dgl.NID])
print(data.edges())

tensor([   0,    1,    2,  ..., 9224, 9225, 9226])
(tensor([   0,    1,    1,  ..., 2399, 7479, 2399]), tensor([8227, 8228, 8228,  ..., 8722, 9147, 8722]))


In [None]:
# print(data.edges())

In [None]:
# print(data.edata['feats'])

In [None]:
# print(data.edata['timestamp'])

In [None]:
# class WikiDataset(dgl.data.DGLDataset():
    # __init__(self):


In [None]:
num_nodes = data.num_nodes()
num_edges = data.num_edges()
# Cumulative fractions of the edge list, used further down as cut points
# between the train / validation / test ranges.
TRAIN_SPLIT = 0.7
VALID_SPLIT = 0.85

# Seed NumPy and PyTorch so the random node selection and sampling below are
# repeatable across runs.
np.random.seed(2021)
torch.manual_seed(2021)

<torch._C.Generator at 0x7908d13e7290>

In [None]:
# print(num_nodes)

In [None]:
# Edge index where the test portion starts: edges from this index onward are
# the ones used below to collect the test nodes.
trainval_div = int(VALID_SPLIT*num_edges)

In [None]:
# Timestamp of the edge sitting at the train+validation / test boundary.
test_split_ts = data.edata['timestamp'][trainval_div]

In [None]:
# Unique node ids appearing as source or destination of any edge at or after
# the boundary index.
test_nodes = torch.cat([data.edges()[0][trainval_div:], data.edges()[1][trainval_div:]]).unique().numpy()

In [None]:
# Draw 10% of those nodes, without replacement, to play the role of "new"
# nodes. The draw depends on the NumPy seed set earlier.
test_new_nodes = np.random.choice(test_nodes, int(0.1*len(test_nodes)), replace=False)

In [None]:
# print(test_nodes.shape)

In [None]:
# print(test_new_nodes.shape)

In [None]:
# Subgraphs holding every edge that points into / out of a selected new node.
# Their edata[dgl.EID] is used below as edge ids of `data`.
in_subg = dgl.in_subgraph(data, test_new_nodes)
out_subg = dgl.out_subgraph(data, test_new_nodes)

In [None]:
# print(len(in_subg.nodes()))
# print(in_subg.edges())

In [None]:
# Collect the edges touching a new node whose timestamp is earlier than the
# test boundary, so they can be removed and the connection information is not
# learned before the test period.
# Inbound edges: the boolean mask over in_subg's edges selects the matching
# original edge ids.
new_node_in_eid_delete = in_subg.edata[dgl.EID][in_subg.edata['timestamp'] < test_split_ts]
# Outbound edges, same rule.
new_node_out_eid_delete = out_subg.edata[dgl.EID][out_subg.edata['timestamp'] < test_split_ts]
# Union of both sets with duplicates dropped.
new_node_eid_delete = torch.cat([new_node_in_eid_delete, new_node_out_eid_delete]).unique()

In [None]:
# Work on a deep copy so `data` itself stays intact.
graph_new_node = copy.deepcopy(data)
# Only the pre-boundary edges of the selected new nodes are dropped, so these
# nodes have no edges earlier than test_split_ts in this graph while their
# later edges are kept.
graph_new_node.remove_edges(new_node_eid_delete)

In [None]:
# print(dgl.NID)

In [None]:
# For the graph without new nodes, every edge touching a selected node is
# removed, whatever its timestamp.
in_eid_delete = in_subg.edata[dgl.EID]
out_eid_delete = out_subg.edata[dgl.EID]
eid_delete = torch.cat([in_eid_delete, out_eid_delete]).unique()

# Deep copy again so `data` stays intact.
graph_no_new_node = copy.deepcopy(data)
graph_no_new_node.remove_edges(eid_delete)
# graph_no_new_node now has no edge incident to any selected node; the nodes
# themselves are still present.

In [None]:
# print(graph_no_new_node.nodes())

In [None]:
# in_nid = in_subg.edata[dgl.NID]
# print(in_subg.edata)

In [None]:
# Uniform negative sampler: k negative pairs per positive edge.
neg_sampler = dgl.dataloading.negative_sampler.Uniform(k=args.num_negative_samples)

# Graphs used for neighbour sampling get a reverse copy of every edge (edge
# data copied along). In fast mode no sampling graph is built and both
# variables stay None.
if args.fast_mode:
    g_sampling = None
    new_node_g_sampling = None
else:
    g_sampling = dgl.add_reverse_edges(graph_no_new_node, copy_edata=True)
    new_node_g_sampling = dgl.add_reverse_edges(graph_new_node, copy_edata=True)

# Report edge counts. The sampling graphs are only printed when they exist;
# calling .num_edges() on None used to crash this cell under --fast_mode.
print(graph_no_new_node.num_edges())
if g_sampling is not None:
    print(g_sampling.num_edges())
print(graph_new_node.num_edges())
if new_node_g_sampling is not None:
    print(new_node_g_sampling.num_edges())

134309
268618
139206
278412


In [None]:
# Seed edge ids for each split, expressed as contiguous id ranges.
# Train: the first TRAIN_SPLIT fraction of graph_no_new_node's edges.
train_seed = torch.arange(int(TRAIN_SPLIT*graph_no_new_node.num_edges()))
print(train_seed)
# Validation: from the end of train up to the shifted test boundary.
# NOTE(review): the boundary is trainval_div minus the number of edges removed
# from graph_new_node, yet it is applied to graph_no_new_node too, which had
# the larger set eid_delete removed. Confirm the offset is intended for both.
valid_seed = torch.arange(int(TRAIN_SPLIT*graph_no_new_node.num_edges()), trainval_div-new_node_eid_delete.size(0))
print(valid_seed)
# Test: from the boundary to the last edge of graph_no_new_node.
test_seed = torch.arange(trainval_div-new_node_eid_delete.size(0), graph_no_new_node.num_edges())
print(test_seed)
# New-node test: same start, but running to the last edge of graph_new_node.
test_new_node_seed = torch.arange(trainval_div-new_node_eid_delete.size(0), graph_new_node.num_edges())
print(test_new_node_seed)

tensor([    0,     1,     2,  ..., 94013, 94014, 94015])
tensor([ 94016,  94017,  94018,  ..., 115581, 115582, 115583])
tensor([115584, 115585, 115586,  ..., 134306, 134307, 134308])
tensor([115584, 115585, 115586,  ..., 139203, 139204, 139205])


In [None]:
# print(new_node_g_sampling.nodes())
# print(g_sampling.ndata[dgl.NID] )

In [None]:
from dgl.dataloading import BlockSampler
from functools import partial
class TemporalSampler(BlockSampler):
    """ Temporal Sampler builds computational and temporal dependency of node representations via
    temporal neighbors selection and screening.

    The sampler expects input node to have same time stamps, in the case of TGN, it should be
    either positive [src,dst] pair or negative samples. It will first take in-subgraph of seed
    nodes and then screening out edges which happen after that timestamp. Finally it will sample
    a fixed number of neighbor edges using random or topk sampling.

    Parameters
    ----------
    sampler_type : str
        sampler indication string of the final sampler.

        If 'topk' then sample topk most recent nodes

        If 'uniform' then uniform randomly sample k nodes

    k : int
        maximum number of neighbors to sample

        default 10 neighbors as paper stated

    Raises
    ------
    dgl.DGLError
        If ``sampler_type`` is neither 'topk' nor 'uniform'.

    Examples
    ----------
    Please refers to examples/pytorch/tgn/train.py

    """

    def __init__(self, sampler_type='topk', k=10):
        super(TemporalSampler, self).__init__()
        if sampler_type == 'topk':
            self.sampler = partial(
                dgl.sampling.select_topk, k=k, weight='timestamp')
        elif sampler_type == 'uniform':
            self.sampler = partial(dgl.sampling.sample_neighbors, fanout=k)
        else:
            # The bare name DGLError was never imported in this notebook, so
            # this branch used to die with a NameError instead of the
            # intended error. Reach it through the already-imported package.
            raise dgl.DGLError(
                "Sampler string invalid please use \'topk\' or \'uniform\'")

    def sample(
        self, g, seed_nodes, exclude_eids=None, timestamp = 0,
    ):  # pylint: disable=arguments-differ
        """Sample a list of blocks from the given seed nodes.

        Returns the ``(src_nodes, dest_nodes, blocks)`` triple produced by
        ``sample_blocks`` after passing it through ``assign_lazy_features``.
        ``exclude_eids`` is forwarded to ``sample_blocks``, which ignores it.
        """
        result = self.sample_blocks(g, seed_nodes, timestamp, exclude_eids=exclude_eids)
        return self.assign_lazy_features(result)

    def _temporal_neighbor_subgraph(self, g, seed_nodes, timestamp, keep_org_id=False):
        """Shared core of both frontier builders.

        Takes the in-edges of ``seed_nodes``, keeps those whose timestamp is
        strictly before ``timestamp`` (or <= 0) and runs the configured
        neighbour sampler on what is left.

        Returns ``(src_nodes, dest_nodes, final_subgraph, local_seeds)``: the
        unique edge endpoints of the sampled subgraph, the subgraph itself,
        and the seed nodes re-expressed as node ids of that subgraph.
        """
        full_neighbor_subgraph = dgl.in_subgraph(g, seed_nodes)

        # One self loop per seed node. Presumably the new edges get a
        # zero-filled timestamp (TODO confirm against dgl.add_edges); the
        # "<= 0" clause of the mask below then keeps them, so every seed node
        # survives edge_subgraph and can be found in root2sub_dict.
        full_neighbor_subgraph = dgl.add_edges(full_neighbor_subgraph,
                                               seed_nodes, seed_nodes)

        # Keep only edges earlier than the query time. `|` is the explicit
        # spelling of the boolean "+" used before.
        edge_ts = full_neighbor_subgraph.edata['timestamp']
        temporal_edge_mask = (edge_ts < timestamp) | (edge_ts <= 0)
        temporal_subgraph = dgl.edge_subgraph(
            full_neighbor_subgraph, temporal_edge_mask)

        # Map node ids of the parent graph to node ids of the relabelled subgraph.
        temp2origin = temporal_subgraph.ndata[dgl.NID]
        root2sub_dict = dict(
            zip(temp2origin.tolist(), temporal_subgraph.nodes().tolist()))

        if keep_org_id:
            temporal_subgraph.ndata["_orgID"] = g.ndata[dgl.NID][temp2origin]

        local_seeds = [root2sub_dict[int(n)] for n in seed_nodes]

        final_subgraph = self.sampler(g=temporal_subgraph, nodes=local_seeds)
        # NOTE(review): in current DGL releases remove_self_loop() returns a
        # new graph rather than editing in place, so the result is discarded
        # and the self loops stay. Left unchanged because assigning the result
        # would change the sampled edges downstream code has been run with.
        final_subgraph.remove_self_loop()

        src_nodes, dest_nodes = final_subgraph.edges()
        return src_nodes.unique(), dest_nodes.unique(), final_subgraph, local_seeds

    def sampler_frontier(self,
                         block_id,
                         g,
                         seed_nodes,
                         timestamp):
        """Sample the temporal neighbourhood of ``seed_nodes`` as a DGL block.

        ``block_id`` is accepted for interface compatibility and not used.
        Returns ``(src_nodes, dest_nodes, block)`` where ``block`` is
        ``to_block`` of the sampled subgraph with the seeds as destinations.
        """
        src_nodes, dest_nodes, final_subgraph, local_seeds = \
            self._temporal_neighbor_subgraph(g, seed_nodes, timestamp)
        # A second to_block(final_subgraph) call used to be computed here and
        # thrown away; it has been dropped.
        block = dgl.transforms.to_block(final_subgraph, local_seeds)
        return src_nodes, dest_nodes, block

    def sampler_frontier_for_Batch(self,
                         block_id,
                         g,
                         seed_nodes,
                         timestamp):
        """Sample the temporal neighbourhood of ``seed_nodes`` as a plain graph.

        Same procedure as ``sampler_frontier``, except that the sampled
        subgraph is returned without conversion to a block and the temporal
        subgraph is tagged with node data ``_orgID`` before sampling.
        ``block_id`` is accepted for interface compatibility and not used.
        """
        src_nodes, dest_nodes, final_subgraph, _ = \
            self._temporal_neighbor_subgraph(g, seed_nodes, timestamp,
                                             keep_org_id=True)
        return src_nodes, dest_nodes, final_subgraph

    def sample_blocks(self,
                      g,
                      seed_nodes, timestamp,
                      exclude_eids=None):
        """Return ``(src_nodes, dest_nodes, [block])`` for one sampling layer.

        ``exclude_eids`` is accepted for interface compatibility and ignored.
        """
        s, d, block = self.sampler_frontier(0, g, seed_nodes, timestamp)
        return (s, d, [block])

    def sample_blocks_for_Batch(self,
                      g,
                      seed_nodes, timestamp,
                      exclude_eids=None):
        """Return ``(src_nodes, dest_nodes, [subgraph])`` for one sampling layer.

        Graph-returning counterpart of ``sample_blocks``. ``exclude_eids`` is
        accepted for interface compatibility and ignored.
        """
        s, d, subgraph = self.sampler_frontier_for_Batch(0, g, seed_nodes, timestamp)
        return (s, d, [subgraph])


In [None]:
from dgl.dataloading import Sampler
from dgl.dataloading import EdgePredictionSampler, set_edge_lazy_features, compact_graphs, Mapping, EID, NID, F, context_of, heterograph, find_exclude_eids

In [None]:
# from dgl.dataloading import Sampler
class TemporalEdgePredictionSampler(Sampler):
    """Sampler class that wraps an existing sampler for node classification into another
    one for edge classification or link prediction.

    The wrapped sampler is called once per mini-batch, with the timestamp of
    the first seed edge as the cut-off for the whole batch.

    Parameters
    ----------
    sampler
        Wrapped sampler. Its ``sample`` method must accept at least
        ``(g, seed_nodes, exclude_eids)`` plus a ``timestamp`` keyword, and
        the object must expose ``output_device``.
    exclude, reverse_eids, reverse_etypes
        Forwarded unchanged to ``find_exclude_eids``.
    negative_sampler : callable, optional
        Called as ``negative_sampler(g, seed_edges)``; when given, a graph of
        negative pairs is returned as well.
    prefetch_labels : list, optional
        Edge features to mark for lazy prefetching on the pair graph.

    See also
    --------
    as_edge_prediction_sampler
    """

    def __init__(
        self,
        sampler,
        exclude=None,
        reverse_eids=None,
        reverse_etypes=None,
        negative_sampler=None,
        prefetch_labels=None,
    ):
        super().__init__()
        # Check if the sampler's sample method has an optional third argument.
        argspec = inspect.getfullargspec(sampler.sample)
        if len(argspec.args) < 4:  # ['self', 'g', 'indices', 'exclude_eids']
            # A space was missing between the two string fragments, which made
            # the message read "...please add anoptional third argument...".
            raise TypeError(
                "This sampler does not support edge or link prediction; please add an "
                "optional third argument for edge IDs to exclude in its sample() method."
            )
        self.reverse_eids = reverse_eids
        self.reverse_etypes = reverse_etypes
        self.exclude = exclude
        self.sampler = sampler
        self.negative_sampler = negative_sampler
        self.prefetch_labels = prefetch_labels or []
        self.output_device = sampler.output_device

    def _build_neg_graph(self, g, seed_edges):
        """Build a graph whose edges are the sampled negative pairs.

        The result has the same number of nodes per node type as ``g``. Edge
        types the negative sampler returned nothing for get empty endpoint
        tensors.
        """
        neg_srcdst = self.negative_sampler(g, seed_edges)
        if not isinstance(neg_srcdst, Mapping):
            assert len(g.canonical_etypes) == 1, (
                "graph has multiple or no edge types; "
                "please return a dict in negative sampler."
            )
            neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}

        dtype = F.dtype(list(neg_srcdst.values())[0][0])
        ctx = context_of(seed_edges) if seed_edges is not None else g.device
        neg_edges = {
            etype: neg_srcdst.get(
                etype,
                (
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                ),
            )
            for etype in g.canonical_etypes
        }
        neg_pair_graph = heterograph(
            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
        )
        return neg_pair_graph

    def assign_lazy_features(self, result):
        """Assign lazy features for prefetching."""
        pair_graph = result[1]
        set_edge_lazy_features(pair_graph, self.prefetch_labels)
        # In-place updates
        return result

    def sample(self, g, seed_edges):  # pylint: disable=arguments-differ
        """Samples a list of blocks, as well as a subgraph containing the sampled
        edges from the original graph.

        If :attr:`negative_sampler` is given, also returns another graph containing the
        negative pairs as edges.

        Returns ``(input_nodes, pair_graph, blocks)`` without a negative
        sampler, otherwise ``(input_nodes, pair_graph, neg_graph, blocks)``.
        """
        # Timestamps of the seed edges; only the first one is used below.
        timestamps = g.edata['timestamp'][seed_edges]
        if isinstance(seed_edges, Mapping):
            seed_edges = {
                g.to_canonical_etype(k): v for k, v in seed_edges.items()
            }
        exclude = self.exclude
        pair_graph = g.edge_subgraph(
            seed_edges, relabel_nodes=False, output_device=self.output_device
        )

        eids = pair_graph.edata[EID]

        # compact_graphs finds all the nodes that have zero in-degree and zero
        # out-degree in all the given graphs, and eliminates them from all the
        # graphs.
        if self.negative_sampler is not None:
            neg_graph = self._build_neg_graph(g, seed_edges)
            pair_graph, neg_graph = compact_graphs([pair_graph, neg_graph])
        else:
            pair_graph = compact_graphs(pair_graph)

        # Re-attach the original edge ids after compaction.
        pair_graph.edata[EID] = eids
        seed_nodes = pair_graph.ndata[NID]

        exclude_eids = find_exclude_eids(
            g,
            seed_edges,
            exclude,
            self.reverse_eids,
            self.reverse_etypes,
            self.output_device,
        )

        # One neighbourhood sample for the whole batch, cut off at the
        # timestamp of the first seed edge.
        input_nodes, _, blocks = self.sampler.sample(g, seed_nodes, exclude_eids, timestamp = timestamps[0])

        if self.negative_sampler is None:
            return self.assign_lazy_features((input_nodes, pair_graph, blocks))
        else:
            return self.assign_lazy_features(
                (input_nodes, pair_graph, neg_graph, blocks)
            )



In [None]:
# from dgl.dataloading import Sampler
class BatchedTemporalEdgePredictionSampler(Sampler):
    """Sampler class that wraps an existing sampler for node classification into another
    one for edge classification or link prediction.

    Unlike a single per-batch sample, this variant samples one temporal
    neighbourhood per seed edge (and per negative pair), each at that edge's
    own timestamp, and merges the results with ``dgl.batch``.

    Parameters
    ----------
    sampler
        Wrapped sampler. It must provide
        ``sample_blocks_for_Batch(g, seed_nodes, timestamp=...)`` returning
        ``(src_nodes, dest_nodes, [subgraph])``, a ``sample`` method accepting
        at least ``(g, seed_nodes, exclude_eids)``, and ``output_device``.
    exclude, reverse_eids, reverse_etypes
        Forwarded unchanged to ``find_exclude_eids``. The result is not
        consumed by the per-edge sampling in this class.
    negative_sampler : callable, optional
        Called as ``negative_sampler(g, seed_edges)``. Must expose an integer
        attribute ``k`` (negatives per positive edge), which is used to line
        up timestamps with the negative pairs.
    prefetch_labels : list, optional
        Edge features to mark for lazy prefetching on the pair graph.

    See also
    --------
    as_edge_prediction_sampler
    """

    def __init__(
        self,
        sampler,
        exclude=None,
        reverse_eids=None,
        reverse_etypes=None,
        negative_sampler=None,
        prefetch_labels=None,
    ):
        super().__init__()
        # Check if the sampler's sample method has an optional third argument.
        argspec = inspect.getfullargspec(sampler.sample)
        if len(argspec.args) < 4:  # ['self', 'g', 'indices', 'exclude_eids']
            # A space was missing between the two string fragments, which made
            # the message read "...please add anoptional third argument...".
            raise TypeError(
                "This sampler does not support edge or link prediction; please add an "
                "optional third argument for edge IDs to exclude in its sample() method."
            )
        self.reverse_eids = reverse_eids
        self.reverse_etypes = reverse_etypes
        self.exclude = exclude
        self.sampler = sampler
        self.negative_sampler = negative_sampler
        self.prefetch_labels = prefetch_labels or []
        self.output_device = sampler.output_device

    def _build_neg_graph(self, g, seed_edges):
        """Build a graph whose edges are the sampled negative pairs.

        The result has the same number of nodes per node type as ``g``. Edge
        types the negative sampler returned nothing for get empty endpoint
        tensors.
        """
        neg_srcdst = self.negative_sampler(g, seed_edges)
        if not isinstance(neg_srcdst, Mapping):
            assert len(g.canonical_etypes) == 1, (
                "graph has multiple or no edge types; "
                "please return a dict in negative sampler."
            )
            neg_srcdst = {g.canonical_etypes[0]: neg_srcdst}

        dtype = F.dtype(list(neg_srcdst.values())[0][0])
        ctx = context_of(seed_edges) if seed_edges is not None else g.device
        neg_edges = {
            etype: neg_srcdst.get(
                etype,
                (
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                    F.copy_to(F.tensor([], dtype), ctx=ctx),
                ),
            )
            for etype in g.canonical_etypes
        }
        neg_pair_graph = heterograph(
            neg_edges, {ntype: g.num_nodes(ntype) for ntype in g.ntypes}
        )
        return neg_pair_graph

    def assign_lazy_features(self, result):
        """Assign lazy features for prefetching."""
        pair_graph = result[1]
        set_edge_lazy_features(pair_graph, self.prefetch_labels)
        # In-place updates
        return result

    def sample(self, g, seed_edges):  # pylint: disable=arguments-differ
        """Samples a list of blocks, as well as a subgraph containing the sampled
        edges from the original graph.

        If :attr:`negative_sampler` is given, also returns another graph containing the
        negative pairs as edges.

        Returns ``(input_nodes, pair_graph, blocks)`` without a negative
        sampler, otherwise ``(input_nodes, pair_graph, neg_graph, blocks)``.
        ``blocks`` is a one-element list holding the batched per-edge
        subgraphs, positives first and then negatives. ``input_nodes``
        concatenates the ``dgl.NID`` source data of the positive subgraphs
        only.
        """
        if isinstance(seed_edges, Mapping):
            seed_edges = {
                g.to_canonical_etype(k): v for k, v in seed_edges.items()
            }
        exclude = self.exclude
        pair_graph = g.edge_subgraph(
            seed_edges, relabel_nodes=False, output_device=self.output_device
        )

        eids = pair_graph.edata[EID]

        # compact_graphs finds all the nodes that have zero in-degree and zero
        # out-degree in all the given graphs, and eliminates them from all the
        # graphs.
        neg_graph = None
        if self.negative_sampler is not None:
            neg_graph = self._build_neg_graph(g, seed_edges)
            pair_graph, neg_graph = compact_graphs([pair_graph, neg_graph])
        else:
            pair_graph = compact_graphs(pair_graph)

        # Re-attach the original edge ids after compaction.
        pair_graph.edata[EID] = eids

        # Computed as in the non-batched sampler; the per-edge sampling below
        # does not consume it.
        exclude_eids = find_exclude_eids(
            g,
            seed_edges,
            exclude,
            self.reverse_eids,
            self.reverse_etypes,
            self.output_device,
        )

        batch_graphs = []
        nodes_id = []
        timestamps = []

        # Positive edges: one temporal neighbourhood per (src, dst) pair, cut
        # off at that edge's own timestamp. g.edges() is evaluated once here
        # instead of twice.
        all_src, all_dst = g.edges()
        pos_timestamps = pair_graph.edata['timestamp']
        for i, edge in enumerate(zip(all_src[seed_edges], all_dst[seed_edges])):
            ts = pos_timestamps[i]
            timestamps.append(ts)
            _, _, subgs = self.sampler.sample_blocks_for_Batch(g, list(edge), timestamp=ts)
            subg = subgs[0]
            # Stamp every node of the subgraph with the query time.
            subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
            nodes_id.append(subg.srcdata[dgl.NID])
            batch_graphs.append(subg)

        # Negative pairs. This section used to run unconditionally and crashed
        # on `self.negative_sampler.k` / `neg_graph` when no negative sampler
        # was configured, although the return statement handles that case.
        if neg_graph is not None:
            # Each positive timestamp is repeated k times so that entry i
            # matches negative edge i (assumes the sampler emits k negatives
            # per seed edge, in seed order).
            neg_timestamps = torch.tensor(timestamps).repeat_interleave(self.negative_sampler.k)
            # Translate compacted node ids back to ids of g.
            neg_nids = neg_graph.ndata[dgl.NID]
            neg_src, neg_dst = neg_graph.edges()
            for i, (u, v) in enumerate(zip(neg_src, neg_dst)):
                neg_edge = [neg_nids[u], neg_nids[v]]
                ts = neg_timestamps[i]
                _, _, subgs = self.sampler.sample_blocks_for_Batch(g,
                                                        neg_edge,
                                                        timestamp=ts)
                subg = subgs[0]
                subg.ndata['timestamp'] = ts.repeat(subg.num_nodes())
                batch_graphs.append(subg)

        blocks = [dgl.batch(batch_graphs)]
        input_nodes = torch.cat(nodes_id)

        if neg_graph is None:
            return self.assign_lazy_features((input_nodes, pair_graph, blocks))
        return self.assign_lazy_features((input_nodes, pair_graph, neg_graph, blocks))



In [None]:

class NeighborSampler(Sampler):
    """Multi-layer neighbor sampler that returns one message-flow block per fanout.

    Args:
        fanouts: iterable of per-layer fanouts; it is walked in reverse, so the
            last entry is applied to the original seed nodes.
    """

    def __init__(self, fanouts):
        super().__init__()
        self.fanouts = fanouts

    # NOTE: There is an additional third argument. For homogeneous graphs,
    #   it is an 1-D tensor of integer IDs. For heterogeneous graphs, it
    #   is a dictionary of ID tensors. We usually set its default value to be None.
    def sample(self, g, seed_nodes, exclude_eids=None):
        """Sample neighbors layer by layer, starting from ``seed_nodes``.

        Args:
            g: graph to sample from (must provide ``sample_neighbors``).
            seed_nodes: nodes whose outputs are requested.
            exclude_eids: forwarded to ``sample_neighbors`` as ``exclude_edges``.

        Returns:
            ``(input_nodes, output_nodes, subgs)`` where ``output_nodes`` is the
            original ``seed_nodes``, ``input_nodes`` is the source-node ID set of
            the outermost block, and ``subgs`` lists the blocks outermost-first.
        """
        output_nodes = seed_nodes
        subgs = []
        for fanout in reversed(self.fanouts):
            # Sample a fixed number of neighbors of the current seed nodes.
            sg = g.sample_neighbors(seed_nodes, fanout, exclude_edges=exclude_eids)
            # Convert this subgraph to a message flow graph.
            sg = dgl.to_block(sg, seed_nodes)
            # The source nodes of this block become the seeds of the next layer.
            # Use the qualified ``dgl.NID``: a bare ``NID`` is not imported here.
            seed_nodes = sg.srcdata[dgl.NID]
            subgs.insert(0, sg)
        input_nodes = seed_nodes
        return input_nodes, output_nodes, subgs

In [None]:
def _get_device():
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    device = torch.device(device)
    if device.type == 'cuda' and device.index is None:
        device = torch.device('cuda', torch.cuda.current_device())
    return device

In [None]:
# Resolve the compute device once; later cells pass `device` to the
# dataloaders and the model.
device = _get_device()
print(device)

cuda:0


In [None]:
# Build the sampling pipeline: a temporal neighbor sampler configured with
# k = args.n_neighbors, a uniform negative sampler drawing
# args.num_negative_samples negatives per positive edge, and the batched
# edge-prediction sampler that wraps both.
temporal_sampler = TemporalSampler(k=args.n_neighbors)
neg_sampler = dgl.dataloading.negative_sampler.Uniform(k=args.num_negative_samples)
temporal_edge_sampler = BatchedTemporalEdgePredictionSampler(temporal_sampler,  negative_sampler=neg_sampler)
# sampler = NeighborSampler(args.n_neighbors)
# edge_collator = TemporalEdgeCollator

In [None]:
# Sanity check: show the endpoints and timestamps of the first eight edges.
src, dst = graph_no_new_node.edges()
print(src[0:8], dst[0:8])
print(graph_no_new_node.edata['timestamp'][0:8])

tensor([0, 1, 1, 2, 1, 2, 1, 4]) tensor([8227, 8228, 8228, 8229, 8228, 8229, 8228, 8231])
tensor([  0.,  36.,  77., 131., 150., 153., 217., 218.], dtype=torch.float64)


In [None]:
# # sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
# dataloader = dgl.dataloading.DataLoader(
#     graph_no_new_node, train_seed, temporal_edge_sampler,
#     batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=0, device=device)#collate_fn = edge_collator

In [None]:
# sampler = dgl.dataloading.MultiLayerNeighborSampler([15, 10, 5])
def _make_edge_dataloader(graph, seed_edges):
    """Build an edge DataLoader over ``seed_edges`` of ``graph``.

    All splits share the same settings: the shared ``temporal_edge_sampler``,
    ``args.batch_size``, no shuffling, no dropped tail batch, in-process
    loading (``num_workers=0``) and output placed on ``device``.
    NOTE(review): shuffle=False presumably keeps seed edges in their stored
    (chronological) order for the temporal sampler -- confirm.
    """
    return dgl.dataloading.DataLoader(
        graph, seed_edges, temporal_edge_sampler,
        batch_size=args.batch_size, shuffle=False, drop_last=False,
        num_workers=0, device=device)


train_dataloader = _make_edge_dataloader(graph_no_new_node, train_seed)
valid_dataloader = _make_edge_dataloader(graph_no_new_node, valid_seed)
# The new-node test split is the only one drawn from `graph_new_node`.
test_new_node_dataloader = _make_edge_dataloader(graph_new_node, test_new_node_seed)
test_dataloader = _make_edge_dataloader(graph_no_new_node, test_seed)


In [None]:
# with dataloader.enable_cpu_affinity():
# input_nodes, pair_graph, neg_graph, blocks = example_minibatch = next(
#     iter(dataloader)
# )
# with dataloader.enable_cpu_affinity():
#     input_nodes, pair_graph, neg_graph, blocks = example_minibatch = next(
#         iter(dataloader)
#     )
# # print("mini_batch: ", example_minibatch)
# print("input_nodes: ", input_nodes)
# print("pair graph: ", pair_graph)
# print("Pair Graph NID: ", pair_graph.ndata[dgl.NID])
# print("Pair Graph EID: ", pair_graph.edata[dgl.EID])
# # print(neg_graph.ndata[dgl.NID])
# # print(neg_graph.edata[dgl.EID])
# print(blocks[0].srcdata[dgl.NID])
# print(blocks[0].dstdata[dgl.NID])
# print(blocks[0].ndata[dgl.NID])

In [None]:
# for _, positive_pair_g, negative_pair_g, blocks in dataloader:
#     block = blocks[0]
#     print(block)
#     print(block.edata)
#     print(block.ndata)
#     print(block.edges())


In [None]:
# Width of the per-edge feature vectors and total node count of `data`;
# both are passed to the TGN constructor.
edge_dim = data.edata['feats'].shape[1]
num_node = data.num_nodes()

In [None]:
from modules import MemoryModule, MemoryOperation, MsgLinkPredictor, TemporalTransformerConv, TimeEncode
import torch.nn as nn

In [None]:
class TGN(nn.Module):
    """Temporal graph network: node memory, temporal attention embedding, link predictor.

    The class wires together the helper modules imported from ``modules``:
    ``TimeEncode``, ``MemoryModule``, ``MemoryOperation``,
    ``TemporalTransformerConv`` and ``MsgLinkPredictor``.

    Args:
        edge_feat_dim: edge feature size, passed to the memory operation and
            the attention layer.
        memory_dim: per-node memory size, passed to ``MemoryModule`` and the
            attention layer.
        temporal_dim: size handed to ``TimeEncode``.
        embedding_dim: output size of the attention layer and input size of
            the link predictor.
        num_heads: number of attention heads.
        num_nodes: number of rows allocated in the memory module.
        n_neighbors: stored on the instance; not otherwise used in this class.
        memory_updater_type: updater name forwarded to ``MemoryOperation``.
        mem_device: forwarded to ``MemoryModule`` as ``mem_device``.
        layers: number of layers forwarded to ``TemporalTransformerConv``.
    """

    def __init__(self,
                 edge_feat_dim,
                 memory_dim,
                 temporal_dim,
                 embedding_dim,
                 num_heads,
                 num_nodes,
                 n_neighbors=10,
                 memory_updater_type='gru',
                 mem_device=None,
                 layers=1):
        super(TGN, self).__init__()
        self.memory_dim = memory_dim
        self.edge_feat_dim = edge_feat_dim
        self.temporal_dim = temporal_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.n_neighbors = n_neighbors
        self.memory_updater_type = memory_updater_type
        self.num_nodes = num_nodes
        self.layers = layers

        # One time encoder instance is shared by the memory update and the
        # attention embedding.
        self.temporal_encoder = TimeEncode(self.temporal_dim)

        self.memory = MemoryModule(self.num_nodes,
                                   self.memory_dim, mem_device=mem_device)

        self.memory_ops = MemoryOperation(self.memory_updater_type,
                                          self.memory,
                                          self.edge_feat_dim,
                                          self.temporal_encoder)

        self.embedding_attn = TemporalTransformerConv(self.edge_feat_dim,
                                                      self.memory_dim,
                                                      self.temporal_encoder,
                                                      self.embedding_dim,
                                                      self.num_heads,
                                                      layers=self.layers,
                                                      allow_zero_in_degree=True)

        self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)

    def embed(self, postive_graph, negative_graph, blocks):
        """Score the positive and negative edges of one mini-batch.

        Args:
            postive_graph: graph of positive edges; its ``ndata[dgl.NID]``
                selects which embeddings feed the predictor.
            negative_graph: graph of negative edges, passed to the predictor.
            blocks: list whose first element is the graph to embed; it must
                carry ``ndata[dgl.NID]`` and ``ndata['timestamp']``.

        Returns:
            ``(pred_pos, pred_neg)`` as produced by ``MsgLinkPredictor``.
        """
        emb_graph = blocks[0]
        # Gather the memory rows of every node in the embedding graph.
        emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
        emb_t = emb_graph.ndata['timestamp']
        embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
        # Map original node ID -> row index in `embedding`. If an ID occurs
        # more than once in the embedding graph, the last occurrence wins.
        emb2pred = dict(
            zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist()))
        # Original author's note: the positive and negative graphs share the
        # same ID mapping, so one lookup via the positive graph serves both.
        feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
        feat = embedding[feat_id]
        pred_pos, pred_neg = self.msg_linkpredictor(
            feat, postive_graph, negative_graph)
        return pred_pos, pred_neg

    def update_memory(self, subg):
        """Run the memory operation on ``subg`` and write the results back.

        Both the new memory vectors and the per-node timestamps of the
        returned graph are stored, keyed by its ``ndata[dgl.NID]``.
        """
        new_g = self.memory_ops(subg)
        self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata['memory'])
        self.memory.set_last_update_t(
            new_g.ndata[dgl.NID], new_g.ndata['timestamp'])

    # Some memory operation wrappers
    def detach_memory(self):
        """Delegate to ``MemoryModule.detach_memory``."""
        self.memory.detach_memory()

    def reset_memory(self):
        """Delegate to ``MemoryModule.reset_memory``."""
        self.memory.reset_memory()

    def store_memory(self):
        """Return a deep-copied snapshot of the memory and last-update times.

        Returns:
            dict with keys ``'memory'`` and ``'last_t'``.
        """
        memory_checkpoint = {}
        memory_checkpoint['memory'] = copy.deepcopy(self.memory.memory)
        memory_checkpoint['last_t'] = copy.deepcopy(self.memory.last_update_t)
        return memory_checkpoint

    def restore_memory(self, memory_checkpoint):
        """Restore a snapshot produced by :meth:`store_memory`.

        The checkpoint objects are assigned directly (not copied), so the
        checkpoint should not be reused after further memory updates.
        """
        self.memory.memory = memory_checkpoint['memory']
        # Must target `last_update_t`, the attribute that store_memory() reads;
        # writing to `last_update_time` left the real timestamps un-restored.
        self.memory.last_update_t = memory_checkpoint['last_t']

In [None]:
# from torch.cuda.amp import autocast, GradScaler
# scaler = GradScaler()

In [None]:
def train(model, dataloader, sampler, criterion, optimizer, args):
    """Run one training pass over ``dataloader``.

    For every batch the positive edges are scored against all-ones targets and
    the negative edges against all-zeros targets; the two losses are summed,
    back-propagated and applied. The model memory is then detached and, unless
    ``args.not_use_memory`` is set, updated from the positive graph.

    Args:
        model: object exposing ``train``, ``embed``, ``detach_memory`` and
            ``update_memory``.
        dataloader: iterable of ``(_, positive_graph, negative_graph, blocks)``.
        sampler: accepted for signature parity with ``test_val``; unused here.
        criterion: loss callable ``criterion(pred, target)``.
        optimizer: optimizer stepped once per batch.
        args: namespace providing ``batch_size``, ``fast_mode`` and
            ``not_use_memory``.

    Returns:
        ``(total_loss, dl_tt, tr_tt)``: the sum of per-batch losses scaled by
        ``args.batch_size``, the seconds spent waiting on the dataloader, and
        the seconds measured from one batch's end to the next batch's end
        (so the latter includes the dataloader wait).
    """
    model.train()
    loss_sum = 0
    load_seconds = 0
    loop_seconds = 0
    iter_started = time.time()
    load_started = time.time()
    for step, (_, pos_graph, neg_graph, blocks) in enumerate(dataloader):
        # Time elapsed since the previous batch finished = dataloader wait.
        load_seconds = load_seconds + (time.time() - load_started)

        optimizer.zero_grad()
        pos_logits, neg_logits = model.embed(pos_graph, neg_graph, blocks)
        loss = criterion(pos_logits, torch.ones_like(pos_logits))
        loss += criterion(neg_logits, torch.zeros_like(neg_logits))
        scaled_loss = float(loss) * args.batch_size
        loss_sum += scaled_loss

        # Only the very first batch keeps its autograd graph, and only when
        # fast mode is off.
        keep_graph = step == 0 and not args.fast_mode
        loss.backward(retain_graph=keep_graph)
        optimizer.step()
        model.detach_memory()
        if not args.not_use_memory:
            model.update_memory(pos_graph)

        elapsed = time.time() - iter_started
        loop_seconds = loop_seconds + elapsed
        print("Data loader Batch: ", step, "Time: ", elapsed, " dl_total: ", load_seconds, " tr_total: ", loop_seconds, " loss: ", scaled_loss)

        iter_started = time.time()
        load_started = time.time()
    return loss_sum, load_seconds, loop_seconds


def test_val(model, dataloader, sampler, criterion, args):
    """Evaluate ``model`` on ``dataloader`` and return mean AP and mean AUC.

    Runs under ``torch.no_grad()`` with the model in eval mode. For each batch
    the sigmoid of the positive and negative predictions is scored against
    1/0 labels with ``average_precision_score`` and ``roc_auc_score``. The
    memory is updated from the positive graph unless ``args.not_use_memory``
    is set, and in fast mode the sampler receives the latest update times.

    Args:
        model: object exposing ``eval``, ``embed``, ``update_memory`` and
            ``memory.last_update_t``.
        dataloader: iterable of ``(_, positive_graph, negative_graph, blocks)``.
        sampler: only used when ``args.fast_mode`` is true
            (``attach_last_update``).
        criterion: loss callable; the loss is accumulated locally but not
            returned.
        args: namespace providing ``batch_size``, ``not_use_memory`` and
            ``fast_mode``.

    Returns:
        ``(mean_ap, mean_auc)`` as Python floats, averaged over batches.
    """
    model.eval()
    scale = args.batch_size
    running_loss = 0
    ap_scores = []
    auc_scores = []
    n_batches = 0
    with torch.no_grad():
        for _, pos_graph, neg_graph, blocks in dataloader:
            pos_logits, neg_logits = model.embed(pos_graph, neg_graph, blocks)

            batch_loss = criterion(pos_logits, torch.ones_like(pos_logits))
            batch_loss += criterion(neg_logits, torch.zeros_like(neg_logits))
            running_loss += float(batch_loss)*scale

            probs = torch.cat([pos_logits, neg_logits], dim=0).sigmoid().cpu()
            labels = torch.cat(
                [torch.ones(pos_logits.size(0)), torch.zeros(neg_logits.size(0))], dim=0)

            if not args.not_use_memory:
                model.update_memory(pos_graph)
            if args.fast_mode:
                sampler.attach_last_update(model.memory.last_update_t)

            # Zero out NaN / out-of-range scores so the sklearn metrics do not fail.
            invalid = torch.logical_or(torch.abs(probs) > 10, torch.isnan(probs))
            probs[invalid] = 0

            ap_scores.append(average_precision_score(labels, probs))
            auc_scores.append(roc_auc_score(labels, probs))
            n_batches += 1

    mean_ap = float(torch.tensor(ap_scores).mean())
    mean_auc = float(torch.tensor(auc_scores).mean())
    return mean_ap, mean_auc



In [None]:
# Build the TGN model from the parsed hyper-parameters. The edge feature width
# and node count are read from `data`.
edge_dim = data.edata['feats'].shape[1]
num_node = data.num_nodes()
model = TGN(edge_feat_dim=edge_dim,
                memory_dim=args.memory_dim,
                temporal_dim=args.temporal_dim,
                embedding_dim=args.embedding_dim,
                num_heads=args.num_heads,
                num_nodes=num_node,
                n_neighbors=args.n_neighbors,
                memory_updater_type=args.memory_updater,
                mem_device = device,
                layers=args.k_hop)

# The train/eval functions feed raw predictions to the loss and apply the
# sigmoid separately, hence the logits variant of BCE.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
# Log file handle consumed by the training-loop cell.
# NOTE(review): opened in 'w' mode, so re-running this cell truncates logging.txt.
f = open("logging.txt", 'w')
# if args.fast_mode:
#     sampler.reset()
# model.to(device)
model = model.to(device=device)


In [None]:
# Sanity check: print the model's memory tensor (the printed repr includes
# its device).
print(model.memory.memory)
# next(model.memory.memory.parameters()).is_cuda


Parameter containing:
tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')


In [None]:
# Local import: the KeyboardInterrupt handler below calls
# traceback.print_exc(), and without this import the handler itself raised
# NameError.
import traceback

# Main training driver. Each epoch: train, validate, snapshot the memory,
# evaluate on the regular test split, roll the memory back to the snapshot,
# evaluate on the new-node test split, log, then reset the memory.
sampler = temporal_edge_sampler
print("At the beginning")
i = -1

try:
    for i in range(args.epochs):
        train_loss, dl_tt, tr_tt = train(model, train_dataloader, sampler,
                                         criterion, optimizer, args)
        val_ap, val_auc = test_val(
            model, valid_dataloader, sampler, criterion, args)
        # Snapshot the memory after validation so both test splits start from
        # the same state.
        memory_checkpoint = model.store_memory()
        test_ap, test_auc = test_val(
            model, test_dataloader, sampler, criterion, args)
        model.restore_memory(memory_checkpoint)
        # The dedicated new-node sampler for fast mode is disabled; the same
        # sampler is reused for the new-node split.
        sample_nn = sampler
        nn_test_ap, nn_test_auc = test_val(
            model, test_new_node_dataloader, sample_nn, criterion, args)
        log_content = []
        log_content.append("Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
            i, train_loss, val_ap, val_auc))
        log_content.append(
            "Epoch: {}; Test AP: {:.3f} AUC: {:.3f}\n".format(i, test_ap, test_auc))
        log_content.append("Epoch: {}; Test New Node AP: {:.3f} AUC: {:.3f}\n".format(
            i, nn_test_ap, nn_test_auc))
        log_content.append("total time: dataloading: {} and overall: {}\n".format(dl_tt, tr_tt))
        f.writelines(log_content)
        # Flush every epoch so results survive a crash or kernel restart.
        f.flush()
        model.reset_memory()
        if i < args.epochs-1 and args.fast_mode:
            sampler.reset()
        print(log_content[0], log_content[1], log_content[2])
except KeyboardInterrupt:
    traceback.print_exc()
    error_content = "Training Interreputed!"
    # A single string is written with write(); writelines() would iterate it
    # character by character.
    f.write(error_content)
finally:
    # Close the log on every exit path, not only on interrupt.
    f.close()
print("========Training is Done========")

At the beginning
Data loader Batch:  0 Time:  6.376658916473389  dl_total:  6.228303670883179  tr_total:  6.376658916473389  loss:  1687.0816650390625


  if p.grad is not None:
  if p.grad is not None:


Data loader Batch:  1 Time:  6.330328941345215  dl_total:  12.382943153381348  tr_total:  12.706987857818604  loss:  1507.561279296875
Data loader Batch:  2 Time:  6.643822908401489  dl_total:  18.822046041488647  tr_total:  19.350810766220093  loss:  1442.23486328125
Data loader Batch:  3 Time:  6.173504590988159  dl_total:  24.84015965461731  tr_total:  25.524315357208252  loss:  1487.2637939453125
Data loader Batch:  4 Time:  6.240732431411743  dl_total:  30.92544412612915  tr_total:  31.765047788619995  loss:  1664.20263671875
Data loader Batch:  5 Time:  6.165156841278076  dl_total:  36.93287110328674  tr_total:  37.93020462989807  loss:  1529.678955078125
Data loader Batch:  6 Time:  6.2133190631866455  dl_total:  42.98992562294006  tr_total:  44.14352369308472  loss:  1628.1861572265625
Data loader Batch:  7 Time:  6.214906215667725  dl_total:  49.04952335357666  tr_total:  50.35842990875244  loss:  1540.064208984375
Data loader Batch:  8 Time:  6.281335115432739  dl_total:  55.

  if p.grad is not None:
  if p.grad is not None:


Data loader Batch:  0 Time:  6.260049343109131  dl_total:  6.116761207580566  tr_total:  6.260049343109131  loss:  1970.0810546875
Data loader Batch:  1 Time:  5.9008800983428955  dl_total:  11.868306636810303  tr_total:  12.160929441452026  loss:  1502.07177734375
Data loader Batch:  2 Time:  6.278243064880371  dl_total:  17.99577498435974  tr_total:  18.439172506332397  loss:  1483.22314453125
Data loader Batch:  3 Time:  6.203287124633789  dl_total:  24.04429602622986  tr_total:  24.642459630966187  loss:  1456.50927734375
Data loader Batch:  4 Time:  6.299283742904663  dl_total:  30.151549339294434  tr_total:  30.94174337387085  loss:  1526.59228515625
Data loader Batch:  5 Time:  6.225048542022705  dl_total:  36.18826103210449  tr_total:  37.166791915893555  loss:  1460.021240234375
Data loader Batch:  6 Time:  6.250863552093506  dl_total:  42.275803327560425  tr_total:  43.41765546798706  loss:  1550.279541015625
Data loader Batch:  7 Time:  6.1586503982543945  dl_total:  48.2713

  if p.grad is not None:
  if p.grad is not None:


Data loader Batch:  0 Time:  6.244021654129028  dl_total:  6.102843523025513  tr_total:  6.244021654129028  loss:  1758.8250732421875
Data loader Batch:  1 Time:  6.282165765762329  dl_total:  12.049316644668579  tr_total:  12.526187419891357  loss:  1442.7701416015625
Data loader Batch:  2 Time:  6.1607019901275635  dl_total:  17.88352918624878  tr_total:  18.68688941001892  loss:  1438.3907470703125
Data loader Batch:  3 Time:  5.917953252792358  dl_total:  23.6495680809021  tr_total:  24.60484266281128  loss:  1456.6630859375
Data loader Batch:  4 Time:  6.145230054855347  dl_total:  29.643171787261963  tr_total:  30.750072717666626  loss:  1493.844970703125
Data loader Batch:  5 Time:  6.154524564743042  dl_total:  35.64248251914978  tr_total:  36.90459728240967  loss:  1418.40234375
Data loader Batch:  6 Time:  6.175535678863525  dl_total:  41.66331696510315  tr_total:  43.08013296127319  loss:  1433.293701171875
Data loader Batch:  7 Time:  6.113873481750488  dl_total:  47.612486

  if p.grad is not None:
  if p.grad is not None:


Data loader Batch:  0 Time:  6.301778554916382  dl_total:  6.1577980518341064  tr_total:  6.301778554916382  loss:  1450.47412109375
Data loader Batch:  1 Time:  6.182621002197266  dl_total:  12.192085027694702  tr_total:  12.484399557113647  loss:  1451.831298828125
Data loader Batch:  2 Time:  6.225046873092651  dl_total:  18.268280267715454  tr_total:  18.7094464302063  loss:  1354.2354736328125
Data loader Batch:  3 Time:  6.182330846786499  dl_total:  24.287883281707764  tr_total:  24.891777276992798  loss:  1364.15869140625
Data loader Batch:  4 Time:  6.012247562408447  dl_total:  30.14881205558777  tr_total:  30.904024839401245  loss:  1436.343017578125
Data loader Batch:  5 Time:  6.219927787780762  dl_total:  36.207263469696045  tr_total:  37.12395262718201  loss:  1426.390869140625
Data loader Batch:  6 Time:  6.252297639846802  dl_total:  42.294633626937866  tr_total:  43.37625026702881  loss:  1431.716796875
Data loader Batch:  7 Time:  6.247730255126953  dl_total:  48.380

  if p.grad is not None:
  if p.grad is not None:


Data loader Batch:  0 Time:  6.0008721351623535  dl_total:  5.856108903884888  tr_total:  6.0008721351623535  loss:  1562.5166015625
Data loader Batch:  1 Time:  6.208112478256226  dl_total:  11.903465986251831  tr_total:  12.208984613418579  loss:  1244.8040771484375
Data loader Batch:  2 Time:  6.148840427398682  dl_total:  17.889243602752686  tr_total:  18.35782504081726  loss:  1235.9228515625
Data loader Batch:  3 Time:  6.197173357009888  dl_total:  23.93748140335083  tr_total:  24.55499839782715  loss:  1368.166748046875
Data loader Batch:  4 Time:  6.074987888336182  dl_total:  29.85518765449524  tr_total:  30.62998628616333  loss:  1424.8466796875
Data loader Batch:  5 Time:  6.201061964035034  dl_total:  35.90401554107666  tr_total:  36.831048250198364  loss:  1318.6396484375
Data loader Batch:  6 Time:  6.198243141174316  dl_total:  41.87521934509277  tr_total:  43.02929139137268  loss:  1376.8466796875
Data loader Batch:  7 Time:  6.268340587615967  dl_total:  47.9869740009

  if p.grad is not None:
  if p.grad is not None:


Data loader Batch:  0 Time:  6.308718919754028  dl_total:  6.139158248901367  tr_total:  6.308718919754028  loss:  1450.4423828125
Data loader Batch:  1 Time:  6.259567499160767  dl_total:  12.242077589035034  tr_total:  12.568286418914795  loss:  1340.667724609375
Data loader Batch:  2 Time:  5.983293533325195  dl_total:  18.07577085494995  tr_total:  18.55157995223999  loss:  1263.19775390625
Data loader Batch:  3 Time:  6.212126731872559  dl_total:  24.13049602508545  tr_total:  24.76370668411255  loss:  1290.6575927734375
Data loader Batch:  4 Time:  6.152943849563599  dl_total:  30.126096725463867  tr_total:  30.916650533676147  loss:  1298.635009765625
Data loader Batch:  5 Time:  6.135547876358032  dl_total:  36.11189675331116  tr_total:  37.05219841003418  loss:  1277.7816162109375
Data loader Batch:  6 Time:  6.127482652664185  dl_total:  42.0844612121582  tr_total:  43.179681062698364  loss:  1281.990966796875
Data loader Batch:  7 Time:  6.2112531661987305  dl_total:  48.139

  if p.grad is not None:
  if p.grad is not None:


Data loader Batch:  0 Time:  6.2715044021606445  dl_total:  6.119182109832764  tr_total:  6.2715044021606445  loss:  1403.3604736328125
Data loader Batch:  1 Time:  6.158106803894043  dl_total:  12.118695735931396  tr_total:  12.429611206054688  loss:  1244.132080078125
Data loader Batch:  2 Time:  6.174557447433472  dl_total:  18.141486167907715  tr_total:  18.60416865348816  loss:  1171.12646484375
Data loader Batch:  3 Time:  6.113211631774902  dl_total:  24.08805227279663  tr_total:  24.71738028526306  loss:  1331.2384033203125
Data loader Batch:  4 Time:  6.213238000869751  dl_total:  30.151872873306274  tr_total:  30.930618286132812  loss:  1277.6956787109375
Data loader Batch:  5 Time:  6.226851940155029  dl_total:  36.21554446220398  tr_total:  37.15747022628784  loss:  1300.624267578125
Data loader Batch:  6 Time:  6.243377208709717  dl_total:  42.30465602874756  tr_total:  43.40084743499756  loss:  1261.029296875
Data loader Batch:  7 Time:  6.127832412719727  dl_total:  48.1

NameError: name 'traceback' is not defined

In [None]:
# Sanity check: first eight edge endpoints of graph_no_new_node, plus the
# '_ID' node data stored on it.
p, q = graph_no_new_node.edges()
print(p[:8])
print(q[:8])
print(graph_no_new_node.ndata['_ID'])
