In [1]:
import dgl
from dgl.data.utils import load_graphs

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random

from model.model import HTGNN, NodePredictor
from utils.pytorchtools import EarlyStopping
from utils.data import load_COVID_data

# Seed every RNG the notebook touches so Restart & Run All reproduces outputs.
# The stdlib `random` module was previously left unseeded even though a later
# cell draws from it.
SEED = 0
dgl.seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # seeds every visible GPU, including the current one

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
x = np.empty(0)  # zero-length float64 vector (same as np.array([]))

In [5]:
# Two small integer vectors used to try out np.concatenate.
y, z = np.array([1, 2, 3]), np.array([4, 5, 6])

In [11]:
np.concatenate([y, z])  # join the two 1-D vectors end to end

array([1, 2, 3, 4, 5, 6])

: 

In [7]:
a = list(range(1, 13))  # the integers 1..12


In [11]:
# Every third element starting at index 1 (index 0 would give a[::3]).
a[slice(1, None, 3)]

[2, 5, 8, 11]

In [12]:
a[slice(2, None, 3)]  # every third element starting at index 2

[3, 6, 9, 12]

In [3]:
# Define the 3x3 matrix explicitly. The cell that originally created it was
# deleted, so on a fresh kernel `x` would still be the empty array from above.
x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
x

array([[1, 2, 3],
       [4, 5, 6],
       [7, 8, 9]])

In [6]:
x.mean(axis=0)  # column-wise mean (collapse the row axis)

array([4., 5., 6.])

In [20]:
# Number of full 64-item batches in 10000 items, then a random start index
# leaving room for a window after it. randrange(n) makes exactly the same
# draw as random.choice(range(n)).
batchs = 10000 // 64
start = random.randrange(batchs - 6)

In [21]:
# Display the start index sampled in the cell above.
start

30

In [22]:
# Three consecutive indices beginning at the sampled `start`.
range(start, start + 3)

range(30, 33)

In [12]:
# Draw two random 2x3 tensors (x first, then y) and stack them row-wise into 4x3.
x, y = torch.randn(2, 3), torch.randn(2, 3)
z = torch.cat([x, y], dim=0)
z

tensor([[-0.8567,  1.1006, -1.0712],
        [ 0.1227, -0.5663,  0.3731],
        [-0.8920, -1.5091,  0.3704],
        [ 1.4565,  0.9398,  0.7748]])

In [13]:
z.ndim  # number of dimensions (same value as z.dim())

2

In [14]:
# Rebuild from the source tensors instead of reassigning `z` from itself, so
# re-running this cell is idempotent: cat -> (4, 3), transpose -> (3, 4),
# then add a leading dimension of size 1 -> (1, 3, 4).
z = torch.cat((x, y)).T.unsqueeze(0)

In [17]:
z.size()  # torch.Size of the reshaped tensor

torch.Size([1, 3, 4])

: 

In [2]:
# Fall back to CPU when no GPU is visible so the notebook still runs.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Path is relative to the working directory (presumably the repo root).
glist, _ = load_graphs('data/covid_graphs.bin')
# Number of temporal snapshots per sample; the loaded graphs carry
# per-timestep edge types suffixed _t0 .. _t6.
time_window = 7

train_feats, train_labels, val_feats, val_labels, test_feats, test_labels = load_COVID_data(glist, time_window)


In [9]:
train_feats[0].nodes(ntype='state')  # IDs of all 'state' nodes in the first training graph

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50],
       dtype=torch.int32)

In [10]:
# 2-hop in-neighbourhood subgraph around state node 2. With store_ids=True DGL
# records the parent graph's IDs on the subgraph; inverse_indices presumably
# locates the seed node inside `sg` (see DGL docs).
sg, inverse_indices = dgl.khop_in_subgraph(
    train_feats[0],
    {'state': 2},
    k=2,
    store_ids=True,
)

In [11]:
# Summary of the subgraph extracted above: per-type node and edge counts.
sg

Graph(num_nodes={'county': 216, 'state': 13},
      num_edges={('county', 'affiliate_r_t0', 'state'): 216, ('county', 'affiliate_r_t1', 'state'): 216, ('county', 'affiliate_r_t2', 'state'): 216, ('county', 'affiliate_r_t3', 'state'): 216, ('county', 'affiliate_r_t4', 'state'): 216, ('county', 'affiliate_r_t5', 'state'): 216, ('county', 'affiliate_r_t6', 'state'): 216, ('county', 'nearby_county_t0', 'county'): 1412, ('county', 'nearby_county_t1', 'county'): 1412, ('county', 'nearby_county_t2', 'county'): 1412, ('county', 'nearby_county_t3', 'county'): 1412, ('county', 'nearby_county_t4', 'county'): 1412, ('county', 'nearby_county_t5', 'county'): 1412, ('county', 'nearby_county_t6', 'county'): 1412, ('state', 'affiliate_t0', 'county'): 216, ('state', 'affiliate_t1', 'county'): 216, ('state', 'affiliate_t2', 'county'): 216, ('state', 'affiliate_t3', 'county'): 216, ('state', 'affiliate_t4', 'county'): 216, ('state', 'affiliate_t5', 'county'): 216, ('state', 'affiliate_t6', 'county'): 216,

In [6]:
# Same 2-hop extraction, now seeded from two state nodes (1 and 2).
sg, inverse_indices = dgl.khop_in_subgraph(
    train_feats[0],
    {'state': [1, 2]},
    k=2,
    store_ids=True,
)

In [7]:
# Summary of the two-seed subgraph: per-type node and edge counts.
sg

Graph(num_nodes={'county': 245, 'state': 14},
      num_edges={('county', 'affiliate_r_t0', 'state'): 245, ('county', 'affiliate_r_t1', 'state'): 245, ('county', 'affiliate_r_t2', 'state'): 245, ('county', 'affiliate_r_t3', 'state'): 245, ('county', 'affiliate_r_t4', 'state'): 245, ('county', 'affiliate_r_t5', 'state'): 245, ('county', 'affiliate_r_t6', 'state'): 245, ('county', 'nearby_county_t0', 'county'): 1543, ('county', 'nearby_county_t1', 'county'): 1543, ('county', 'nearby_county_t2', 'county'): 1543, ('county', 'nearby_county_t3', 'county'): 1543, ('county', 'nearby_county_t4', 'county'): 1543, ('county', 'nearby_county_t5', 'county'): 1543, ('county', 'nearby_county_t6', 'county'): 1543, ('state', 'affiliate_t0', 'county'): 245, ('state', 'affiliate_t1', 'county'): 245, ('state', 'affiliate_t2', 'county'): 245, ('state', 'affiliate_t3', 'county'): 245, ('state', 'affiliate_t4', 'county'): 245, ('state', 'affiliate_t5', 'county'): 245, ('state', 'affiliate_t6', 'county'): 245,

In [14]:
# Map each edge-type name of `sg` to its [start, end) slice in a flat
# concatenation of all edges, taken in `sg.canonical_etypes` order.
# Uses its own `offset` counter, so it no longer overwrites the tensor `z`
# or leaves a stray `k` in the notebook namespace.
# NOTE(review): keys are the bare etype names; if two canonical etypes ever
# shared a name, the later one would overwrite the earlier entry.
dic = {}
offset = 0
for canonical_etype in sg.canonical_etypes:
    n_edges = sg.num_edges(canonical_etype)  # num_edges replaces the number_of_edges alias
    dic[canonical_etype[1]] = (offset, offset + n_edges)
    offset += n_edges


In [15]:
# Inspect the edge-type -> (start, end) offset table.
# NOTE(review): the saved output below (29 edges per affiliate type) does not
# match either `sg` displayed above (216 / 245), so it came from a different
# subgraph. Re-run to refresh.
dic

{'affiliate_r_t0': (0, 29),
 'affiliate_r_t1': (29, 58),
 'affiliate_r_t2': (58, 87),
 'affiliate_r_t3': (87, 116),
 'affiliate_r_t4': (116, 145),
 'affiliate_r_t5': (145, 174),
 'affiliate_r_t6': (174, 203),
 'nearby_county_t0': (203, 334),
 'nearby_county_t1': (334, 465),
 'nearby_county_t2': (465, 596),
 'nearby_county_t3': (596, 727),
 'nearby_county_t4': (727, 858),
 'nearby_county_t5': (858, 989),
 'nearby_county_t6': (989, 1120),
 'affiliate_t0': (1120, 1149),
 'affiliate_t1': (1149, 1178),
 'affiliate_t2': (1178, 1207),
 'affiliate_t3': (1207, 1236),
 'affiliate_t4': (1236, 1265),
 'affiliate_t5': (1265, 1294),
 'affiliate_t6': (1294, 1323),
 'nearby_state_t0': (1323, 1324),
 'nearby_state_t1': (1324, 1325),
 'nearby_state_t2': (1325, 1326),
 'nearby_state_t3': (1326, 1327),
 'nearby_state_t4': (1327, 1328),
 'nearby_state_t5': (1328, 1329),
 'nearby_state_t6': (1329, 1330)}

In [17]:
print('\n\n\n')  # visual spacer in the output area







In [24]:
# Edge-type names of the first training graph: each relation appears once per
# time step, suffixed _t0 .. _t6.
train_feats[0].etypes

['affiliate_r_t0',
 'affiliate_r_t1',
 'affiliate_r_t2',
 'affiliate_r_t3',
 'affiliate_r_t4',
 'affiliate_r_t5',
 'affiliate_r_t6',
 'nearby_county_t0',
 'nearby_county_t1',
 'nearby_county_t2',
 'nearby_county_t3',
 'nearby_county_t4',
 'nearby_county_t5',
 'nearby_county_t6',
 'affiliate_t0',
 'affiliate_t1',
 'affiliate_t2',
 'affiliate_t3',
 'affiliate_t4',
 'affiliate_t5',
 'affiliate_t6',
 'nearby_state_t0',
 'nearby_state_t1',
 'nearby_state_t2',
 'nearby_state_t3',
 'nearby_state_t4',
 'nearby_state_t5',
 'nearby_state_t6']

In [25]:
c = dict.fromkeys(sg.etypes, 0)  # one zeroed counter per edge-type name (avoids shadowing builtin `type`)

In [23]:
sg.num_nodes('state')  # number of 'state' nodes in the current subgraph

1

In [5]:
# Full first training graph, for comparison with the extracted subgraphs.
train_feats[0]

Graph(num_nodes={'county': 3223, 'state': 51},
      num_edges={('county', 'affiliate_r_t0', 'state'): 3141, ('county', 'affiliate_r_t1', 'state'): 3141, ('county', 'affiliate_r_t2', 'state'): 3141, ('county', 'affiliate_r_t3', 'state'): 3141, ('county', 'affiliate_r_t4', 'state'): 3141, ('county', 'affiliate_r_t5', 'state'): 3141, ('county', 'affiliate_r_t6', 'state'): 3141, ('county', 'nearby_county_t0', 'county'): 22176, ('county', 'nearby_county_t1', 'county'): 22176, ('county', 'nearby_county_t2', 'county'): 22176, ('county', 'nearby_county_t3', 'county'): 22176, ('county', 'nearby_county_t4', 'county'): 22176, ('county', 'nearby_county_t5', 'county'): 22176, ('county', 'nearby_county_t6', 'county'): 22176, ('state', 'affiliate_t0', 'county'): 3141, ('state', 'affiliate_t1', 'county'): 3141, ('state', 'affiliate_t2', 'county'): 3141, ('state', 'affiliate_t3', 'county'): 3141, ('state', 'affiliate_t4', 'county'): 3141, ('state', 'affiliate_t5', 'county'): 3141, ('state', 'affiliate

In [7]:
# Rebuild HTGNN encoder + node predictor and load the trained weights.
time_window = 7  # temporal snapshots per sample; must match the checkpoint
# Fall back to CPU when no GPU is visible (previously hard-coded to 'cuda:0').
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Relative to the working directory, like 'data/covid_graphs.bin' above.
# NOTE(review): was the absolute path /home/jiazhengli/xdgnn/HTGNN/output/...;
# confirm the notebook is run from the HTGNN repo root.
checkpoint_path = 'output/COVID19/checkpoint_HTGNN_0.pt'
n_hid = 8  # HTGNN hidden size; also the predictor's input size
graph_atom = test_feats[0]  # presumably used by HTGNN as the graph schema template
htgnn = HTGNN(graph=graph_atom, n_inp=1, n_hid=n_hid, n_layers=2, n_heads=1, time_window=time_window, norm=False, device=device)
predictor = NodePredictor(n_inp=n_hid, n_classes=1)
model = nn.Sequential(htgnn, predictor).to(device)
# map_location lets a checkpoint saved on GPU load onto the selected device.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [11]:
# Run only the HTGNN encoder on the subgraph. `htgnn` is the same module object
# as model[0], since nn.Sequential holds its children by reference.
h = htgnn(sg.to(device), 'state')

In [27]:
# Toy heterograph: 4 user->game 'plays' edges and 2 developer->game 'develops' edges.
g = dgl.heterograph({
    ('user', 'plays', 'game'): (torch.tensor([0, 1, 1, 2]), torch.tensor([0, 0, 1, 1])),
    ('developer', 'develops', 'game'): (torch.tensor([0, 1]), torch.tensor([0, 1])),
})
# Drop 'plays' edges 0 and 1; store_ids=True keeps the surviving edges'
# original IDs in g.edata[dgl.EID].
g = dgl.remove_edges(g, torch.tensor([0, 1]), etype='plays', store_ids=True)
g.edges(form='all', etype='plays')

(tensor([1, 2]), tensor([1, 1]), tensor([0, 1]))

In [29]:
sum(g.num_edges(etype) for etype in g.canonical_etypes)  # total edges across all edge types

4

In [28]:
# Original edge IDs kept by store_ids=True, per canonical edge type. The
# remaining 'plays' edges keep IDs 2 and 3 after edges 0 and 1 were removed.
g.edata[dgl.EID]

{('developer', 'develops', 'game'): tensor([0, 1]),
 ('user', 'plays', 'game'): tensor([2, 3])}