In [15]:
import torch
import torch.nn as nn
from torch.nn import init
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler

import numpy as np
from numpy.random import choice

import random
import time
import argparse
import os


from scipy.stats import rankdata
from scipy.sparse import lil_matrix, find

### 构建模型

In [4]:
class Decayer(nn.Module):
    def __init__(self, device, w, decay_method='exp'):
        super(Decayer,self).__init__()
        self.decay_method = decay_method
        self.linear = nn.Linear(1,1,False).to(device)
        self.w = w

    def exponetial_decay(self, delta_t):
        return torch.exp(-self.w*delta_t)
    
    def log_decay(self, delta_t):
        return 1/torch.log(2.7183 + self.w*delta_t)
    
    def rev_decay(self, delta_t):
        return 1/(1 + self.w*delta_t)

    def forward(self,delta_t):
        if self.decay_method == 'exp':
            return self.exponetial_decay(delta_t)
        elif self.decay_method == 'log':
            return self.log_decay(delta_t)
        elif self.decay_method == 'rev':
            return self.rev_decay(delta_t)
        else:
            return self.exponetial_decay(delta_t)

In [2]:
class Combiner(nn.Module):
    def __init__(self, input_size, output_size,act, bias = True ):
        super(Combiner,self).__init__()
        self.h2o = nn.Linear(input_size,output_size,bias)
        self.l2o = nn.Linear(input_size,output_size,bias)
        if act == 'tanh':
            self.act = nn.Tanh()
        elif act == 'sigmoid':
            self.act = nn.Sigmoid()
        else:
            self.act = nn.ReLU() 

    def forward(self, head_info, tail_info):
        node_output = self.h2o(head_info) + self.l2o(tail_info)
        node_output_tanh = self.act(node_output)
        return node_output_tanh

In [3]:
class Edge_updater_nn(nn.Module):
    def __init__(self, node_input_size, output_size , act = 'tanh',relation_input_size = None, bias = True):
        super(Edge_updater_nn,self).__init__()
        self.h2o = nn.Linear(node_input_size,output_size,bias)
        self.l2o = nn.Linear(node_input_size,output_size,bias)
        if relation_input_size is not None:
            self.r2o = nn.Linear(relation_input_size,output_size,bias)
        if act == 'tanh':
            self.act = nn.Tanh()
        elif act == 'sigmoid':
            self.act = nn.Sigmoid()
        else:
            self.act = nn.ReLU() 

    def forward(self, head_node, tail_node, relation=None):

        if relation is None:
            edge_output = self.h2o(head_node) + self.l2o(tail_node)
        else:
            edge_output = self.h2o(head_node) + self.l2o(tail_node) + self.r2o(relation)
        edge_output_act = self.act(edge_output)
        return edge_output_act

In [5]:
class TLSTM(nn.Module):
    def __init__(self,input_size, hidden_size,  bias = True):
        super(TLSTM,self).__init__()
        self.i2h = nn.Linear(input_size, 4*hidden_size, bias)
        self.h2h = nn.Linear(hidden_size, 4*hidden_size, bias)
        self.c2s = nn.Sequential(nn.Linear(hidden_size, hidden_size, bias), nn.Tanh())
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()

    def forward(self,input, cell, hidden, transed_delta_t):
        cell_short = self.c2s(cell)
        cell_new = cell - cell_short + cell_short* transed_delta_t 
        gates = self.i2h(input) + self.h2h(hidden)
        ingate, forgate, cellgate, outgate = gates.chunk(4,1)
        ingate = self.sigmoid(ingate)
        forgate = self.sigmoid(forgate)
        cellgate = self.tanh(cellgate)
        outgate = self.sigmoid(outgate)
        cell_output = forgate*cell_new + ingate*cellgate
        hidden_output = outgate*self.tanh(cell_output) 
        
        return cell_output, hidden_output

In [7]:
class Attention(nn.Module):
    def __init__(self, embedding_dims):
        super(Attention,self).__init__()
        self.bilinear = nn.Bilinear(embedding_dims,embedding_dims,1)
        self.softmax = nn.Softmax(0)

    def forward(self,node1, node2):
        return self.softmax( self.bilinear(node1, node2).view(-1,1) )

In [12]:
class DyGNN(nn.Module):
    def __init__(self, num_embeddings, embedding_dims, edge_output_size, device, w, is_att=False, transfer=False, nor=0, if_no_time=0, threhold=None, second_order=False, if_updated=0, drop_p=0, num_negative=5, act='tanh', if_propagation=1, decay_method='exp', weight=None, relation_size=None, bias=True):
        super(DyGNN, self).__init__()
        self.embedding_dims = embedding_dims
        self.num_embeddings = num_embeddings
        self.nor = nor
        
        # self.weight = weight.to(device)
        self.device = device
        self.transfer = transfer
        self.if_propagation = if_propagation
        self.if_no_time = if_no_time
        self.second_order = second_order
        
        # self.cuda = cuda
        self.combiner = Combiner(embedding_dims, embedding_dims, act).to(device)
        self.decay_method = decay_method
        self.if_updated = if_updated
        self.threhold = threhold
        
        print('Only propagate to relevance nodes below time interval: ', threhold)
        
        # self.tanh = nn.Tanh().to(device)
        if act == 'tanh':
            self.act = nn.Tanh().to(device)
        elif act == 'sigmoid':
            self.act = nn.Sigmoid().to(device)
        else:
            self.act = nn.ReLU().to(device) 
            
        self.decayer = Decayer(device, w, decay_method)
        self.edge_updater_head = Edge_updater_nn(embedding_dims, edge_output_size,act, relation_size).to(device)
        self.edge_updater_tail = Edge_updater_nn(embedding_dims, edge_output_size,act, relation_size).to(device)

        if if_no_time:
            self.node_updater_head = nn.LSTMCell(edge_output_size, embedding_dims, bias).to(device)
            self.node_updater_tail = nn.LSTMCell(edge_output_size, embedding_dims, bias).to(device) 
        else:
            self.node_updater_head = TLSTM(edge_output_size, embedding_dims).to(device)
            self.node_updater_tail = TLSTM(edge_output_size, embedding_dims).to(device)	

        self.tran_head_edge_head = nn.Linear(edge_output_size, embedding_dims, bias).to(device)
        self.tran_head_edge_tail = nn.Linear(edge_output_size, embedding_dims, bias).to(device)	
        self.tran_tail_edge_head = nn.Linear(edge_output_size, embedding_dims, bias).to(device)
        self.tran_tail_edge_tail = nn.Linear(edge_output_size, embedding_dims, bias).to(device)
        
        self.is_att = is_att
        if self.is_att:
            self.attention = Attention(embedding_dims).to(device)

        self.num_negative = num_negative

        self.recent_timestamp = torch.zeros((num_embeddings, 1), dtype = torch.float, requires_grad = False).to(device)

        self.interaction_timestamp = lil_matrix((num_embeddings,num_embeddings),dtype = np.float32)
        
        self.cell_head = nn.Embedding(num_embeddings, embedding_dims, weight).to(device)
        self.cell_head.weight.requires_grad = False
        self.cell_tail = nn.Embedding(num_embeddings, embedding_dims, weight).to(device)
        self.cell_tail.weight.requires_grad = False
        
        self.hidden_head = nn.Embedding(num_embeddings, embedding_dims, weight).to(device)
        self.hidden_head.weight.requires_grad = False
        self.hidden_tail = nn.Embedding(num_embeddings, embedding_dims, weight).to(device)
        self.hidden_tail.weight.requires_grad = False
        
        self.node_representations = nn.Embedding(num_embeddings, embedding_dims, weight).to(device)
        self.node_representations.weight.requires_grad = False

        if transfer:
            self.transfer2head = nn.Linear(embedding_dims, embedding_dims, False).to(device)
            self.transfer2tail = nn.Linear(embedding_dims, embedding_dims, False).to(device)
            if drop_p>=0:
                self.dropout = nn.Dropout(p=drop_p).to(device)
                
        self.cell_head_copy = nn.Embedding.from_pretrained(self.cell_head.weight.clone()).to(device)
        self.cell_tail_copy = nn.Embedding.from_pretrained(self.cell_tail.weight.clone()).to(device)
        self.hidden_head_copy = nn.Embedding.from_pretrained(self.hidden_head.weight.clone()).to(device)
        self.hidden_tail_copy = nn.Embedding.from_pretrained(self.hidden_tail.weight.clone()).to(device)
        self.node_representations_copy = nn.Embedding.from_pretrained(self.node_representations.weight.clone()).to(device)

        # if cuda:
        #     self.cell_head = self.cell_head.cuda()
        #     self.cell_tail = self.cell_tail.cuda()
        #     self.node_representations = self.node_representations.cuda()
        #     self.recent_timestamp = self.recent_timestamp.cuda()
        #     self.tran_head_edge_head.cuda()
        #     self.tran_head_edge_head.cuda()
        #     self.tran_tail_edge_head.cuda()
        #     self.tran_tail_edge_tail.cuda()
        
    def reset_time(self):
        self.recent_timestamp = torch.zeros((self.num_embeddings, 1), dtype = torch.float, requires_grad = False).to(self.device)
        self.interaction_timestamp = lil_matrix((self.num_embeddings,self.num_embeddings),dtype = np.float32)
    
    def reset_reps(self):
        self.cell_head = nn.Embedding.from_pretrained(self.cell_head_copy.weight.clone()).to(self.device)
        self.cell_tail = nn.Embedding.from_pretrained(self.cell_tail_copy.weight.clone()).to(self.device)
        self.hidden_head = nn.Embedding.from_pretrained(self.hidden_head_copy.weight.clone()).to(self.device)
        self.hidden_tail = nn.Embedding.from_pretrained(self.hidden_tail_copy.weight.clone()).to(self.device)
        self.node_representations = nn.Embedding.from_pretrained(self.node_representations_copy.weight.clone()).to(self.device)
        
    def link_pred_with_update(self,test_data):
        pass
    
    
    def forward(self, interactions):
        test_time = False
        
        all_head_nodes = set()
        all_tail_nodes = set()
        
        steps = len(interactions[:,0])

        node2timetsamp = dict()
        node2cell_head = dict()
        node2cell_tail = dict()
        node2hidden_head = dict()
        node2hidden_tail = dict()
        node2rep = dict()
        
        output_rep_head = []
        output_rep_tail = []
        tail_neg_list = []
        head_neg_list = []
        
        if test_time:
            old_time = time.time()
        for i in range(steps):
            i_condi = i%200 == 1
            if test_time and i_condi:
                time1 = time.time() 
                print('----------------------------------------------------')
                print(i,'1 step time', str(time1 - old_time) )
                old_time = time1
                
            head_index = int(interactions[i,0])
            tail_index = int(interactions[i,1])
            all_head_nodes.add(head_index)
            all_tail_nodes.add(tail_index)
            
            head_inx_lt = torch.LongTensor([head_index]).to(self.device)
            tail_inx_lt = torch.LongTensor([tail_index]).to(self.device)
            
            timestamp = interactions[i,2]
            current_t = torch.FloatTensor([timestamp]).view(-1,1).to(self.device)

            head_prev_t = self.recent_timestamp[head_index]
            tail_prev_t = self.recent_timestamp[tail_index]

            if test_time and i_condi:
                time2 = time.time()
                print('test_point2', str(time2-time1))

            if head_index in node2rep:
                head_node_rep = node2rep[head_index]
            else:
                head_node_rep = self.node_representations(head_inx_lt)

            if tail_index in node2rep:
                tail_node_rep = node2rep[tail_index]
            else:
                tail_node_rep = self.node_representations(tail_inx_lt)
                
                
            if head_index in node2hidden_head:
                head_node_cell_head = node2cell_head[head_index]
                head_node_hidden_head = node2hidden_head[head_index]
            else:
                head_node_cell_head = self.cell_head(head_inx_lt)
                head_node_hidden_head = self.hidden_head(head_inx_lt)
            if head_index in node2hidden_tail:
                head_node_hidden_tail = node2hidden_tail[head_index]
            else:
                head_node_hidden_tail = self.hidden_tail(head_inx_lt)
                
            
            if tail_index in node2hidden_tail:
                tail_node_cell_tail = node2cell_tail[tail_index]
                tail_node_hidden_tail = node2hidden_tail[tail_index]
            else:
                tail_node_cell_tail = self.cell_tail(tail_inx_lt)
                tail_node_hidden_tail = self.hidden_tail(tail_inx_lt)
              
            
            if tail_index in node2hidden_head:
                tail_node_hidden_head = node2hidden_head[tail_index]
            else:
                tail_node_hidden_head = self.hidden_head(tail_inx_lt)
                
            
            if test_time and i_condi:
                time3 = time.time()
                print('prepare rep time', str(time3-time2))

            head_delta_t = current_t - head_prev_t
            tail_delta_t = current_t - tail_prev_t

            with torch.no_grad():
                self.recent_timestamp[[head_index, tail_index]] = current_t

            transed_head_delta_t = self.decayer(head_delta_t)
            transed_tail_delta_t = self.decayer(tail_delta_t)

            edge_info_head = self.edge_updater_head(head_node_rep, tail_node_rep)
            edge_info_tail = self.edge_updater_tail(head_node_rep, tail_node_rep)
            
            
            if self.if_no_time:
                updated_head_node_hidden_head,updated_head_node_cell_head  = self.node_updater_head(edge_info_head, ( head_node_hidden_head, head_node_cell_head ))
            else:
                updated_head_node_cell_head, updated_head_node_hidden_head = self.node_updater_head(edge_info_head, head_node_cell_head, head_node_hidden_head , transed_head_delta_t)

            updated_head_node_rep = self.combiner(updated_head_node_hidden_head, head_node_hidden_tail)

            node2cell_head[head_index] = updated_head_node_cell_head
            node2hidden_head[head_index] = updated_head_node_hidden_head
            node2rep[head_index] = updated_head_node_rep
            
            
            if self.if_updated:
                output_rep_head.append(updated_head_node_rep)
            else:
                output_rep_head.append(head_node_rep)

            if self.if_no_time:
                updated_tail_node_hidden_tail, updated_tail_node_cell_tail, = self.node_updater_tail(edge_info_tail, (tail_node_hidden_tail, tail_node_cell_tail))
            else:
                updated_tail_node_cell_tail, updated_tail_node_hidden_tail = self.node_updater_tail(edge_info_tail, tail_node_cell_tail, tail_node_hidden_tail, transed_tail_delta_t)
            updated_tail_node_rep = self.combiner(tail_node_hidden_head, updated_tail_node_hidden_tail)

            node2cell_tail[tail_index] = updated_tail_node_cell_tail
            node2hidden_tail[tail_index] = updated_tail_node_hidden_tail
            node2rep[tail_index] = updated_tail_node_rep
            
            
            if self.if_updated:
                output_rep_tail.append(updated_tail_node_rep)
            else:
                output_rep_tail.append(tail_node_rep)

            if test_time and i_condi:
                time4 = time.time()
                print('update reps', str(time4-time3))


            if self.if_propagation:
                head_node_head_neighbors, head_node_tail_neighbors = self.propagation(head_index, current_t, edge_info_head, 'head', node2cell_head, node2hidden_head, node2cell_tail, node2hidden_tail, node2rep, self.threhold, self.second_order)
                tail_node_head_neighbors, tail_node_tail_neighbors = self.propagation(tail_index, current_t, edge_info_tail, 'tail', node2cell_head, node2hidden_head, node2cell_tail, node2hidden_tail, node2rep, self.threhold, self.second_order)
            else:
                head_node_head_neighbors, head_node_tail_neighbors, n_i_1, n_i_2 = self.get_neighbors(head_index,current_t, self.threhold)
                tail_node_head_neighbors, tail_node_tail_neighbors, n_i_1, n_i_2 = self.get_neighbors(tail_index, current_t, self.threhold)
                head_node_head_neighbors = set(head_node_head_neighbors)
                head_node_tail_neighbors = set(head_node_tail_neighbors)
                tail_node_head_neighbors = set(tail_node_head_neighbors)
                tail_node_tail_neighbors = set(tail_node_tail_neighbors)
                
            
            if test_time and i_condi:
                time5 = time.time()
                if self.if_propagation:
                    print('propagation time', str(time5-time4))
                else:
                    print('Get neighbors time', str(time5-time4))
            all_head_nodes = all_head_nodes | head_node_head_neighbors | tail_node_head_neighbors
            all_tail_nodes = all_tail_nodes | head_node_tail_neighbors | tail_node_tail_neighbors

            ### generate negative samples ###
            tail_candidates = all_tail_nodes - {head_index,tail_index} - head_node_tail_neighbors
            if len(tail_candidates)==0:
                tail_neg_samples = list(choice(range(self.num_embeddings), size=self.num_negative))

            else:
                tail_neg_samples = list(choice(list(tail_candidates), size = self.num_negative))

            head_candidates = all_head_nodes - {tail_index,head_index}- tail_node_head_neighbors
            if len(head_candidates) ==0:
                head_neg_samples = list(choice(range(self.num_embeddings), size=self.num_negative))
            else:
                head_neg_samples = list(choice(list(head_candidates), size = self.num_negative))
                
            
            if test_time and i_condi: 
                time6 = time.time()
                print('get negative samples time', str(time6 - time5))

            for i in tail_neg_samples:
                if i in node2rep:
                    tail_neg_list.append(node2rep[i])
                else:
                    i_lt = torch.LongTensor([i]).to(self.device)

                    tail_neg_list.append(self.node_representations(i_lt))
            for i in head_neg_samples:
                if i in node2rep:
                    head_neg_list.append(node2rep[i])
                else:
                    i_lt = torch.LongTensor([i]).to(self.device)

                    head_neg_list.append(self.node_representations(i_lt))
            if test_time and i_condi: 
                time7 = time.time()
                print('Prepare neg reps time', str(time7 - time6))
                
        ###update interaction time###
            self.interaction_timestamp[head_index, tail_index] = current_t[0,0]    
        
        ###### Prepare modifed cell, hidden and rep to write back to the memory ########
        cell_head_inx = list(node2cell_head.keys())
        output_cell_head = list(node2cell_head.values())

        cell_tail_inx = list(node2cell_tail.keys())
        output_cell_tail = list(node2cell_tail.values())


        hidden_head_inx = list(node2hidden_head.keys())
        output_hidden_head = list(node2hidden_head.values())

        hidden_tail_inx = list(node2hidden_tail.keys())
        output_hidden_tail = list(node2hidden_tail.values())

        rep_inx = list(node2rep.keys())
        output_rep = list(node2rep.values())
           
        
        output_cell_head_tensor = torch.cat([*output_cell_head]).view(-1,self.embedding_dims)
        output_hidden_head_tensor = torch.cat([*output_hidden_head]).view(-1,self.embedding_dims)
        output_rep_head_tensor = torch.cat([*output_rep_head]).view(-1,self.embedding_dims)

        output_cell_tail_tensor = torch.cat([*output_cell_tail]).view(-1,self.embedding_dims)
        output_hidden_tail_tensor = torch.cat([*output_hidden_tail]).view(-1,self.embedding_dims)
        output_rep_tail_tensor = torch.cat([*output_rep_tail]).view(-1,self.embedding_dims)

        output_rep_tensor = torch.cat([*output_rep]).view(-1,self.embedding_dims)

        tail_neg_tensors = torch.cat([*tail_neg_list]).view(-1,self.embedding_dims)
        head_neg_tensors = torch.cat([*head_neg_list]).view(-1,self.embedding_dims)

        if self.transfer:
            output_rep_head_tensor = self.dropout(self.transfer2head(output_rep_head_tensor))
            output_rep_tail_tensor = self.dropout(self.transfer2tail(output_rep_tail_tensor))

            head_neg_tensors =self.dropout(self.transfer2head(head_neg_tensors))
            tail_neg_tensors = self.dropout(self.transfer2tail(tail_neg_tensors))

        if self.nor:
            output_rep_head_tensor = nn.functional.normalize(output_rep_head_tensor)
            output_rep_tail_tensor = nn.functional.normalize(output_rep_tail_tensor)

            head_neg_tensors = nn.functional.normalize(head_neg_tensors)
            tail_neg_tensors = nn.functional.normalize(tail_neg_tensors)



        with torch.no_grad():
            self.cell_head.weight[cell_head_inx,:] = output_cell_head_tensor
            self.hidden_head.weight[hidden_head_inx,:] = output_hidden_head_tensor

            self.cell_tail.weight[cell_tail_inx,:] = output_cell_tail_tensor
            self.hidden_tail.weight[hidden_tail_inx,:] = output_hidden_tail_tensor

            self.node_representations.weight[rep_inx,:] = output_rep_tensor

        return output_rep_head_tensor, output_rep_tail_tensor, head_neg_tensors, tail_neg_tensors
    
    
    def get_rep(self, nodes, rep_type, rep_dict):
        if rep_type == 'node_rep':
            rep = self.node_representations(torch.LongTensor(nodes).to(self.device))
        elif rep_type == 'cell_head':
            rep = self.cell_head(torch.LongTensor(nodes).to(self.device))
        elif rep_type == 'cell_tail':
            rep = self.cell_tail(torch.LongTensor(nodes).to(self.device))
        elif rep_type == 'hidden_head':
            rep = self.hidden_head(torch.LongTensor(nodes).to(self.device))    
        else:
            rep = self.hidden_tail(torch.LongTensor(nodes).to(self.device))     
        for nei in nodes:
            if nei in rep_dict:
                rep[nodes.index(nei),:] = rep_dict[nei]    
        return  rep  
    
    
    def get_neighbors(self,node,current_t,threhold=None):
        row_inx, col_inx, timestamps = find(self.interaction_timestamp) 

        head_inx = list(np.where(col_inx == node)[0])
        head_neighbors = row_inx[head_inx]
        head_timestamps = timestamps[head_inx]

        tail_inx = list(np.where(row_inx == node)[0])
        tail_neighbors = col_inx[tail_inx]
        tail_timestamps = timestamps[tail_inx]
        if threhold is not None:

            head_inx_th = (current_t.item() -  head_timestamps ) <=threhold
            head_neighbors = head_neighbors[head_inx_th]
            head_timestamps = head_timestamps[head_inx_th]


            tail_inx_th = (current_t.item() - tail_timestamps) <=threhold
            tail_timestamps = tail_timestamps[tail_inx_th]
            tail_neighbors = tail_neighbors[tail_inx_th]

        return head_neighbors, tail_neighbors , head_timestamps, tail_timestamps
    
    
    def get_att_score(self,node, neighbors, node2rep):
        nei_reps = self.get_rep(neighbors, 'node_rep', node2rep)
        node_rep  = self.get_rep([node], 'node_rep', node2rep)

        node_reps = node_rep.repeat(len(neighbors),1)

        return self.attention(node_reps, nei_reps)
    
    
    def propagation(self, node, current_t, edge_info, node_type, node2cell_head, node2hidden_head, node2cell_tail, node2hidden_tail, node2rep, threhold = None, second_order=False):
        head_neighbors, tail_neighbors, head_timestamps, tail_timestamps = self.get_neighbors(node, current_t,threhold)

        head_neighbors = list(head_neighbors)
        head_timestamps = list(head_timestamps)
        if len(head_neighbors)>0:
            if node_type == 'head':
                head_nei_edge_info = self.tran_head_edge_head(edge_info)

            else: 
                head_nei_edge_info = self.tran_tail_edge_head(edge_info)

            head_delta_ts = current_t.repeat(len(head_timestamps),1) - torch.FloatTensor(head_timestamps).to(self.device).view(-1,1)
            transed_head_delta_ts = self.decayer(head_delta_ts)

            
            head_nei_cell = self.get_rep(head_neighbors, 'cell_head',node2cell_head)
            if self.if_no_time:
                tran_head_nei_edge_info = head_nei_edge_info.repeat(len(head_neighbors),1)
            else:
                tran_head_nei_edge_info = head_nei_edge_info.repeat(len(head_neighbors),1) * transed_head_delta_ts


            if self.is_att:
                att_score_head = self.get_att_score(node, head_neighbors, node2rep)
                tran_head_nei_edge_info = tran_head_nei_edge_info*att_score_head


            head_nei_cell = head_nei_cell + tran_head_nei_edge_info
            head_nei_hidden = self.act(head_nei_cell)
            head_nei_tail_hidden = self.get_rep(head_neighbors, 'hidden_tail', node2hidden_tail)
            head_nei_rep = self.combiner(head_nei_hidden, head_nei_tail_hidden)

            for i, nei in enumerate(head_neighbors):
                node2cell_head[nei] = head_nei_cell[i].view(-1,self.embedding_dims)
                node2hidden_head[nei] = head_nei_hidden[i].view(-1,self.embedding_dims)
                node2rep[nei] = head_nei_rep[i].view(-1,self.embedding_dims)

            if second_order:
                for  head_node_sec in head_neighbors:
                    self.second_propagation(head_node_sec, current_t , tran_head_nei_edge_info[0,:], 'head', node2cell_head, node2hidden_head, node2cell_tail, node2hidden_tail, node2rep, threhold)

        tail_neighbors = list(tail_neighbors)
        tail_timestamps = list(tail_timestamps)
        if len(tail_neighbors)>0:

            if node_type == 'head':
                tail_nei_edge_info = self.tran_head_edge_tail(edge_info)
            else: 
                tail_nei_edge_info = self.tran_tail_edge_tail(edge_info)

            tail_delta_ts = current_t.repeat(len(tail_timestamps),1) - torch.FloatTensor(tail_timestamps).to(self.device).view(-1,1) 
            transed_tail_delta_ts = self.decayer(tail_delta_ts)


            tail_nei_cell = self.get_rep(tail_neighbors, 'cell_tail', node2cell_tail)
            if self.if_no_time:
                tran_tail_nei_edge_info = tail_nei_edge_info.repeat(len(tail_neighbors),1)
            else:
                tran_tail_nei_edge_info = tail_nei_edge_info.repeat(len(tail_neighbors),1) * transed_tail_delta_ts

            if self.is_att:
                att_score_tail = self.get_att_score(node, tail_neighbors, node2rep)
                tran_head_nei_edge_info = tran_tail_nei_edge_info*att_score_tail

            tail_nei_cell = tail_nei_cell + tran_tail_nei_edge_info
            tail_nei_hidden = self.act(tail_nei_cell)
            tail_nei_head_hidden = self.get_rep(tail_neighbors, 'hidden_head', node2hidden_head)
            tail_nei_rep = self.combiner(tail_nei_head_hidden, tail_nei_hidden)

            for i, nei in enumerate(tail_neighbors):
                node2cell_tail[nei] = tail_nei_cell[i].view(-1,self.embedding_dims)
                node2hidden_tail[nei] = tail_nei_hidden[i].view(-1,self.embedding_dims)
                node2rep[nei]= tail_nei_rep[i].view(-1, self.embedding_dims)

            if second_order:
                for tail_node_sec in tail_neighbors:
                    self.second_propagation(tail_node_sec, current_t, tran_tail_nei_edge_info[0,:], 'tail', node2cell_head, node2hidden_head, node2cell_tail, node2hidden_tail, node2rep, threhold)

        return set(head_neighbors), set(tail_neighbors)

    
    def second_propagation(self, node, current_t, edge_info, node_type, node2cell_head, node2hidden_head, node2cell_tail, node2hidden_tail, node2rep, threhold = None):
        head_neighbors, tail_neighbors, head_timestamps, tail_timestamps = self.get_neighbors(node,current_t, threhold)

        head_neighbors = list(head_neighbors)
        head_timestamps = list(head_timestamps)
        if len(head_neighbors) > 0:
            if node_type == 'head':
                head_nei_edge_info = self.tran_head_edge_head(edge_info)
            else: 
                head_nei_edge_info = self.tran_tail_edge_head(edge_info)

            head_delta_ts = current_t.repeat(len(head_timestamps),1) - torch.FloatTensor(head_timestamps).to(self.device).view(-1,1)
            transed_head_delta_ts = self.decayer(head_delta_ts)

            head_nei_cell = self.get_rep(head_neighbors, 'cell_head',node2cell_head)
            if self.if_no_time:
                tran_head_nei_edge_info = head_nei_edge_info.repeat(len(head_neighbors),1)
            else:
                tran_head_nei_edge_info = head_nei_edge_info.repeat(len(head_neighbors),1) * transed_head_delta_ts

            if self.is_att:
                att_score_head = self.get_att_score(node, head_neighbors, node2rep)
                tran_head_nei_edge_info = tran_head_nei_edge_info*att_score_head

            head_nei_cell = head_nei_cell + tran_head_nei_edge_info
            head_nei_hidden = self.act(head_nei_cell)
            head_nei_tail_hidden = self.get_rep(head_neighbors, 'hidden_tail', node2hidden_tail)
            head_nei_rep = self.combiner(head_nei_hidden, head_nei_tail_hidden)

            for i, nei in enumerate(head_neighbors):
                node2cell_head[nei] = head_nei_cell[i].view(-1,self.embedding_dims)
                node2hidden_head[nei] = head_nei_hidden[i].view(-1,self.embedding_dims)
                node2rep[nei] = head_nei_rep[i].view(-1,self.embedding_dims)

        tail_neighbors = list(tail_neighbors)
        tail_timestamps = list(tail_timestamps)
        
        if len(tail_neighbors) > 0:
            if node_type == 'head':
                tail_nei_edge_info = self.tran_head_edge_tail(edge_info)
            else: 
                tail_nei_edge_info = self.tran_tail_edge_tail(edge_info)

            tail_delta_ts = current_t.repeat(len(tail_timestamps),1) - torch.FloatTensor(tail_timestamps).to(self.device).view(-1,1) 
            transed_tail_delta_ts = self.decayer(tail_delta_ts)

            
            tail_nei_cell = self.get_rep(tail_neighbors, 'cell_tail', node2cell_tail)
            if self.if_no_time:
                tran_tail_nei_edge_info = tail_nei_edge_info.repeat(len(tail_neighbors),1)
            else:
                tran_tail_nei_edge_info = tail_nei_edge_info.repeat(len(tail_neighbors),1) * transed_tail_delta_ts

            if self.is_att:
                att_score_tail = self.get_att_score(node, tail_neighbors, node2rep)
                tran_head_nei_edge_info = tran_tail_nei_edge_info*att_score_tail


            tail_nei_cell = tail_nei_cell + tran_tail_nei_edge_info
            tail_nei_hidden = self.act(tail_nei_cell)
            tail_nei_head_hidden = self.get_rep(tail_neighbors, 'hidden_head', node2hidden_head)
            tail_nei_rep = self.combiner(tail_nei_head_hidden, tail_nei_hidden)

            for i, nei in enumerate(tail_neighbors):
                node2cell_tail[nei] = tail_nei_cell[i].view(-1,self.embedding_dims)
                node2hidden_tail[nei] = tail_nei_hidden[i].view(-1,self.embedding_dims)
                node2rep[nei]= tail_nei_rep[i].view(-1,self.embedding_dims)
                
        return set(head_neighbors), set(tail_neighbors)

    def loss(self, interactions):
        output_rep_head_tensor, output_rep_tail_tensor, head_neg_tensors, tail_neg_tensors = self.forward(interactions)

        head_pos_tensors = output_rep_head_tensor.clone().repeat(1,self.num_negative).view(-1,self.embedding_dims)
        tail_pos_tensors = output_rep_tail_tensor.clone().repeat(1,self.num_negative).view(-1,self.embedding_dims)

        num_pp = output_rep_head_tensor.size()[0]
        labels_p = torch.FloatTensor([1]*num_pp).to(self.device)
        labels_n = torch.FloatTensor([0]*num_pp*2*self.num_negative).to(self.device)

        labels = torch.cat((labels_p,labels_n))

        scores_p = torch.bmm(output_rep_head_tensor.view(num_pp,1,self.embedding_dims),output_rep_tail_tensor.view(num_pp,self.embedding_dims,1))
        scores_n_1 = torch.bmm(head_neg_tensors.view(num_pp*self.num_negative,1,self.embedding_dims), tail_pos_tensors.view(num_pp*self.num_negative, self.embedding_dims,1))
        scores_n_2 = torch.bmm(head_pos_tensors.view(num_pp*self.num_negative,1,self.embedding_dims), tail_neg_tensors.view(num_pp*self.num_negative, self.embedding_dims,1))

        scores = torch.cat((scores_p,scores_n_1,scores_n_2)).view(num_pp*(1+2*self.num_negative))
        bce_with_logits_loss = nn.BCEWithLogitsLoss()
        loss = bce_with_logits_loss(scores,labels)

        return loss        

In [16]:
class Temporal_Dataset(Dataset):
    def __init__(self, file_name, starting = 0,skip_rows=0, div =3600):
        self.data = np.loadtxt(fname=file_name, skiprows=skip_rows)[:,[0,1,3]]
        self.time = self.data[:,2]
        self.trans_time = (self.time - self.time[0])/div
        self.data[:,2] = self.trans_time
        self.data[:, [0,1]] = self.data[:,[0,1]] - starting

    def __len__(self):
        return self.time.shape[0]

    def __getitem__(self,idx):
        sample = self.data[idx,:]
        return sample

### 参数设置

In [17]:
def get_args():
    parser = argparse.ArgumentParser(description = 'Show description')
    
    parser.add_argument('-data', '--dataset', type = str, help = 'which dataset to run', default = 'uci')
    parser.add_argument('-b', '--batch_size', type= int, help = 'batch_size', default = 200)
    parser.add_argument('-l', '--learning_rate', type = float, help = 'learning_rate', default = 0.001)
    parser.add_argument('-nn', '--num_negative', type = int, help = 'num_negative', default = 5)
    parser.add_argument('-tr', '--train_ratio', type = float, help = 'train_ratio', default = 0.8)
    parser.add_argument('-vr', '--valid_ratio', type = float, help = 'valid_ratio', default = 0.01)
    parser.add_argument('-act', '--act', type = str, help = 'act function', default = 'tanh')
    parser.add_argument('-trans', '--transfer', type = int, help = 'transfer to head, tail representations', default = 1)
    parser.add_argument('-dp' , '--drop_p', type = float, help = 'dropout_rate', default = 0)
    parser.add_argument('-ip', '--if_propagation', type = int, help = 'if_propagation', default=1)
    parser.add_argument('-ia', '--is_att', type = int, help = 'use attention or not', default=1)
    parser.add_argument('-w', '--w', type = float, help = 'w for decayer', default = 2)
    parser.add_argument('-s', '--seed', type = int, help = 'random seed', default = 0)
    parser.add_argument('-rp', '--reset_rep', type = int, help = 'whether reset rep', default = 1)
    parser.add_argument('-dc', '--decay_method', type = str, help = 'decay_method', default = 'log')
    parser.add_argument('-nor', '--nor', type = int , help = 'normalize or not', default = 0)
    parser.add_argument('-iu', '--if_updated', type = int, help = 'use updated representation in loss', default = 0)
    parser.add_argument('-wd', '--weight_decay', type = float, help = 'weight decay', default = 0.001)
    parser.add_argument('-nt', '--if_no_time', type = int, help = 'if no time interval information', default = 0)
    parser.add_argument('-th', '--threhold', type = float, help = 'the threhold to filter the neighbors, if None, do not filter', default = None)
    parser.add_argument('-2hop', '--second_order', type = int, help = 'whether to use 2-hop prop', default = 0)
    
    args = parser.parse_args(args=[])
    return args

In [36]:
def link_prediction(data, reps):
    head_list = list(data[:,0])
    tail_list = list(data[:,1])
    head_reps = reps[head_list,:]
    tail_reps = reps[tail_list,:]

In [37]:
def get_loss(data, head_reps, tail_reps,device):
    head_list = list(data[:,0])
    tail_list = list(data[:,1])

    head_tensors = head_reps(torch.LongTensor(head_list).to(device))
    tail_tensors = tail_reps(torch.LongTensor(tail_list).to(device))
    scores = torch.bmm(head_tensors.view(len(head_list),1,head_tensors.size()[1]),tail_tensors.view(len(head_list),head_tensors.size()[1],1)).view(len(head_list))
    labels = torch.FloatTensor([1]*len(head_list)).to(device)
    bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
    loss = bce_with_logits_loss(scores,labels)
    
    return loss

In [38]:
def rank(node, true_candidate, node2candidate, node_reps, candidate_reps, device, pri = False):
    node_tensor = node_reps(torch.LongTensor([node]).to(device)).view(-1,1)
    candidates = list(node2candidate[node])

    candidates.append(true_candidate)

    length = len(candidates)

    candidate_tensors = candidate_reps(torch.LongTensor(candidates).to(device))

    scores = torch.mm(candidate_tensors, node_tensor)
    negative_scores_numpy = -scores.view(1,-1).to('cpu').numpy()
    rank = rankdata(negative_scores_numpy)[-1]

    if pri:
        print(node , true_candidate)
        print(scores.view(-1))
        print(rank, 'out of',length)

    return rank, length

In [39]:
def get_previous_links(data):
    previous_links = set()
    for i in range(len(data)):
        head, tail, time = data[i]
        previous_links.add((int(head), int(tail)))
    return previous_links 

In [40]:
def get_node2candidate(train_data, all_nodes, pri = False):
    head_node2candidate = dict()
    tail_node2candidate = dict()

    pri = True
    if pri:
        start_time = time.time()
        print('Start to build node2candidate')

    for i in range(len(train_data)):

        head, tail, not_in_use = train_data[i]
        head = int(head)
        tail = int(tail)
        if head not in head_node2candidate:
            head_node2candidate[head] = all_nodes

        if tail not in tail_node2candidate:
            tail_node2candidate[tail] = all_nodes

    if pri: 
        end_time = time.time()
        print('node2candidate built in' , str(end_time-start_time))
        
    return head_node2candidate, tail_node2candidate

In [41]:
def get_ranks(test_data,head_reps, tail_reps, device, head_node2candidate, tail_node2candidate, pri=False, previous_links = None, bo = False):
    head_ranks = []
    tail_ranks = []
    head_lengths = []
    tail_lengths = []

    for interactioin in test_data:
        head_node, tail_node , time = interactioin
        head_node = int(head_node)
        tail_node = int(tail_node)
        if pri:
            print('--------------', head_node, tail_node, '---------------')

        if bo:
            if previous_links is not None: 
                if head_node in head_node2candidate and tail_node in tail_node2candidate and tail_node in head_node2candidate and head_node in tail_node2candidate and (head_node, tail_node) not in previous_links:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)
            else:
                if head_node in head_node2candidate and tail_node in tail_node2candidate and tail_node in head_node2candidate and head_node in tail_node2candidate:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)
        else:
            if previous_links is not None: 
                if head_node in head_node2candidate and tail_node in tail_node2candidate and (head_node, tail_node) not in previous_links:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)
            else:
                if head_node in head_node2candidate and tail_node in tail_node2candidate:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)

    return head_ranks, tail_ranks, head_lengths, tail_lengths

### 模型训练

In [64]:
def train(args, data, num_nodes, model_save_dir):
    batch_size = args.batch_size
    learning_rate = args.learning_rate
    num_negative = args.num_negative
    act = args.act
    transfer = args.transfer
    drop_p = args.drop_p
    if_propagation = args.if_propagation
    w = args.w
    is_att = args.is_att
    seed = args.seed
    reset_rep = args.reset_rep
    decay_method = args.decay_method
    nor = args.nor
    if_updated = args.if_updated
    weight_decay = args.weight_decay
    if_no_time = args.if_no_time
    threhold = args.threhold
    second_order = args.second_order
    num_iter = 4

    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    train_ratio = args.train_ratio
    valid_ratio = args.valid_ratio
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_data = data[0:int(len(data)*train_ratio)]
    validation_data = data[int(len(data)*train_ratio):int(len(data)*(train_ratio+valid_ratio))]
    test_data = data[int(len(data)*(train_ratio + valid_ratio)):len(data)]
    print('Data length: ', len(data))
    print('Train length: ', len(train_data))
    sampler = SequentialSampler(train_data)
    data_loader = DataLoader(train_data, batch_size, sampler = sampler)

    all_nodes = set(range(num_nodes))
    print('num_nodes',len(all_nodes))
    head_node2candidate, tail_node2candidate = get_node2candidate(train_data, all_nodes)



    model_save_dir = model_save_dir  + 'nt_' +str(if_no_time)+ '_wd_' + str(weight_decay) + '_up_' + str(if_updated) +'_w_' + str(w) +'_b_' + str(batch_size) + '_l_' + str(learning_rate) + '_tr_' + str(train_ratio) + '_nn_' +str(num_negative)+'_' + act + '_trans_' +str(transfer) + '_dr_p_' + str(drop_p) + '_prop_' + str(if_propagation) + '_att_' +str(is_att) + '_rp_' + str(reset_rep) + '_dcm_' + decay_method + '_nor_' + str(nor)
    if threhold is not None:
        model_save_dir = model_save_dir + '_th_' + str(threhold)

    if second_order:
        model_save_dir = model_save_dir + '_2hop'
        
    if not os.path.exists(model_save_dir):
        os.makedirs(model_save_dir)

    dyGnn = DyGNN(num_nodes,64,64,device, w,is_att ,transfer,nor,if_no_time, threhold,second_order, if_updated,drop_p, num_negative, act, if_propagation, decay_method )
    dyGnn.train()
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,dyGnn.parameters()),lr = learning_rate, weight_decay=weight_decay)


    old_head_rank = num_nodes/2
    old_tail_rank = num_nodes/2

    for epoch in range(num_iter):
        print('epoch: ', epoch)
        print('Resetting time...')
        dyGnn.reset_time()
        print('Time reset')
        if reset_rep:

            dyGnn.reset_reps()
            print('reps reset')

        x = int(5000/batch_size)
        y = int(10000/batch_size)


        for i, interactions in enumerate(data_loader):

            # Compute and print loss.
            loss = dyGnn.loss(interactions)
            if i%x==0:
                #dyGnn.reset_reps()
                print(i,' train_loss: ', loss.item())

                if transfer:
                    head_reps = nn.Embedding.from_pretrained(dyGnn.transfer2head(dyGnn.node_representations.weight))
                    tail_reps = nn.Embedding.from_pretrained(dyGnn.transfer2tail(dyGnn.node_representations.weight))
                else:
                    head_reps = dyGnn.node_representations
                    tail_reps = dyGnn.node_representations

                head_reps = nn.Embedding.from_pretrained(nn.functional.normalize(head_reps.weight))
                tail_reps = nn.Embedding.from_pretrained(nn.functional.normalize(tail_reps.weight))


            if i%y==-1:

                if transfer:
                    head_reps = nn.Embedding.from_pretrained(dyGnn.transfer2head(dyGnn.node_representations.weight))
                    tail_reps = nn.Embedding.from_pretrained(dyGnn.transfer2tail(dyGnn.node_representations.weight))
                else:
                    head_reps = dyGnn.node_representations
                    tail_reps = dyGnn.node_representations

                head_reps = nn.Embedding.from_pretrained(nn.functional.normalize(head_reps.weight))
                tail_reps = nn.Embedding.from_pretrained(nn.functional.normalize(tail_reps.weight))

                head_ranks, tail_ranks, not_in_use, not_in_use2= get_ranks(validation_data,head_reps, tail_reps, device, head_node2candidate, tail_node2candidate)
                head_ranks_numpy = np.asarray(head_ranks)
                tail_ranks_numpy = np.asarray(tail_ranks)
                print('head_rank mean: ', np.mean(head_ranks_numpy),' ; ', 'head_rank var: ', np.var(head_ranks_numpy))
                print('tail_rank mean: ', np.mean(tail_ranks_numpy),' ; ', 'tail_rank var: ', np.var(tail_ranks_numpy))


            optimizer.zero_grad()

            loss.backward()
            optimizer.step()


        if transfer:
            head_reps = nn.Embedding.from_pretrained(dyGnn.transfer2head(dyGnn.node_representations.weight))
            tail_reps = nn.Embedding.from_pretrained(dyGnn.transfer2tail(dyGnn.node_representations.weight))
        else:
            head_reps = dyGnn.node_representations
            tail_reps = dyGnn.node_representations
        head_reps = nn.Embedding.from_pretrained(nn.functional.normalize(head_reps.weight))
        tail_reps = nn.Embedding.from_pretrained(nn.functional.normalize(tail_reps.weight))

        valid_loss = get_loss(validation_data, head_reps, tail_reps, device)
        head_ranks, tail_ranks, head_lengths, tail_lengths = get_ranks(validation_data, head_reps, tail_reps, device, head_node2candidate, tail_node2candidate)
        head_ranks_numpy = np.asarray(head_ranks)
        tail_ranks_numpy = np.asarray(tail_ranks)
        head_lengths_numpy = np.asarray(head_lengths)
        tail_lengths_numpy = np.asarray(tail_lengths)

        mean_head_rank = np.mean(head_ranks_numpy)
        mean_tail_rank = np.mean(tail_ranks_numpy)


        print('head_length mean: ', np.mean(head_lengths_numpy), ';', 'num_test: ', head_lengths_numpy.shape[0])
        print('tail_lengths mean: ', np.mean(tail_lengths_numpy), ';', 'num_test: ', tail_lengths_numpy.shape[0])
        print('head_rank mean: ', mean_head_rank,' ; ', 'head_rank var: ', np.var(head_ranks_numpy))
        print('tail_rank mean: ', mean_tail_rank,' ; ', 'tail_rank var: ', np.var(tail_ranks_numpy))
        print('reverse head_rank mean: ', np.mean(1/head_ranks_numpy))
        print('reverse tail_rank mean: ', np.mean(1/tail_ranks_numpy))
        print('head_rank HITS 100: ', (head_ranks_numpy<=100).sum())
        print('tail_rank_HITS 100: ', (tail_ranks_numpy<=100).sum())
        print('head_rank HITS 50: ', (head_ranks_numpy<=50).sum())
        print('tail_rank_HITS 50: ', (tail_ranks_numpy<=50).sum())
        print('head_rank HITS 20: ', (head_ranks_numpy<=20).sum())
        print('tail_rank_HITS 20: ', (tail_ranks_numpy<=20).sum())

        if mean_head_rank < old_head_rank or mean_tail_rank < old_tail_rank:
            model_save_path = model_save_dir + '/' + 'model_after_epoch_' + str(epoch) + '.pt'
            torch.save(dyGnn.state_dict(), model_save_path)
            print('model saved in: ', model_save_path)

            with open(model_save_dir + '/' + '0valid_results.txt','a') as f:
                f.write('epoch: ' + str(epoch) + '\n')
                f.write('head_rank mean: ' + str(mean_head_rank) + ' ; ' +  'head_rank var: ' + str(np.var(head_ranks_numpy)) + '\n')
                f.write('tail_rank mean: ' + str(mean_tail_rank) + ' ; ' +  'tail_rank var: ' + str(np.var(tail_ranks_numpy)) + '\n')
                f.write('head_rank HITS 100: ' + str ( (head_ranks_numpy<=100).sum()) + '\n')
                f.write('tail_rank_HITS 100: ' + str ( (tail_ranks_numpy<=100).sum()) + '\n')
                f.write('head_rank HITS 50: ' + str( (head_ranks_numpy<=50).sum()) + '\n')
                f.write('tail_rank_HITS 50: ' + str( (tail_ranks_numpy<=50).sum()) + '\n')
                f.write('head_rank HITS 20: ' + str( (head_ranks_numpy<=20).sum()) + '\n')
                f.write('tail_rank_HITS 20: ' + str( (tail_ranks_numpy<=20).sum()) + '\n')
                f.write('============================================================================\n')
            old_head_rank = mean_head_rank + 200
            old_tail_rank = mean_tail_rank + 200

### 主函数

In [65]:
args = get_args()
model_save_dir = r"C:\Users\sss\Desktop\DyGNN-main\saved_models/"
if args.dataset == "uci":
    data = Temporal_Dataset(r'C:\Users\sss\Desktop\DyGNN-main\Dataset/UCI_email_1899_59835/opsahl-ucsocial/out.opsahl-ucsocial', 1, 2)
    # print(dir(data))
    num_nodes = 1899
    model_save_dir = model_save_dir + 'UCI/'
    print('Train on UCI_message dataset')
    train(args, data, num_nodes, model_save_dir)   
else:
    print('Please choose a dataset to run')

Train on UCI_message dataset
Data length:  59835
Train length:  47868
num_nodes 1899
Start to build node2candidate
node2candidate built in 0.11668658256530762
Only propagate to relevance nodes below time interval:  None
epoch:  0
Resetting time...
Time reset
reps reset
0  train_loss:  0.7923987507820129
25  train_loss:  0.30442488193511963


KeyboardInterrupt: 