In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import accuracy_score

from layer import TemporalLayer, SpatialLayer
from utils import get_graph_for_day, get_label

In [None]:
# Load one graph snapshot per day (days 6 and 7) via the project's get_graph_for_day
# helper and convert each per-day array to a torch tensor.
node_states, adjacencies, relation_es = [], [], []

for day in range(6, 8):
    (feat, adjacency_matrix, point_e,
     relation_e, entity_index) = get_graph_for_day(day, 10)
    # NOTE: point_e and entity_index are rebound on every iteration; later cells
    # use the values returned for the last day in the range.
    node_states.append(torch.tensor(feat).float())   # node features as float32
    adjacencies.append(torch.tensor(adjacency_matrix))
    relation_es.append(torch.tensor(relation_e))

In [None]:
(train_mask, test_mask, train_label, test_label) = get_label()  # train/test node masks and labels from the project's utils

In [None]:
# Convert the split masks and labels to tensors. torch.as_tensor returns an existing
# tensor unchanged, so re-running this cell neither re-copies nor emits the
# "copy construct from a tensor" warning that torch.tensor(tensor) produces.
train_mask = torch.as_tensor(train_mask)
test_mask = torch.as_tensor(test_mask)
# nn.CrossEntropyLoss expects class-index targets of dtype int64 (long); cast explicitly
# in case get_label() returns a narrower integer or float dtype.
train_label = torch.as_tensor(train_label).long()
test_label = torch.as_tensor(test_label).long()

In [None]:
class HTGT(torch.nn.Module):
    """Temporal-then-spatial graph classifier over a sequence of graph snapshots.

    For each snapshot, ``temporal_layer`` is applied to the node states together with
    the stacked embeddings produced for the earlier snapshots, then ``spatial_layer``
    mixes the result using that snapshot's adjacency and encodings. The embedding of
    the last snapshot, restricted to ``mask``, is projected to two class logits.
    """

    def __init__(self, entity_index):
        """Build the sub-layers.

        Args:
            entity_index: forwarded unchanged to ``TemporalLayer``; its meaning is
                defined by the project's ``layer`` module (not visible here).
        """
        super().__init__()
        self.temporal_layer = TemporalLayer(entity_index, 128, 32, F.relu)
        self.spatial_layer = SpatialLayer(32, 32, 15, 5, nn.ELU())
        self.mlp = nn.Linear(32, 2)
        # Kept for backward compatibility. forward() now keeps its history in a local
        # variable, so an exception mid-forward cannot leave stale tensors here.
        self.buffer = []

    def forward(self, node_states, adjacencies, point_enc, relation_encs, mask):
        """Run all snapshots through the model and return logits for the masked nodes.

        Args:
            node_states: sequence of per-snapshot node-state tensors; ``len()`` of each
                is used as the node count N when sizing the initial history.
            adjacencies: per-snapshot adjacency inputs, indexed in step with node_states.
            point_enc: one encoding passed to the spatial layer for every snapshot.
            relation_encs: per-snapshot relation encodings, indexed in step with node_states.
            mask: boolean or index mask selecting the nodes of the final embedding to classify.

        Returns:
            Unnormalised 2-class logits for the selected nodes (output of ``self.mlp``).

        Raises:
            ValueError: if ``node_states`` is empty.
        """
        if len(node_states) == 0:
            raise ValueError("node_states must contain at least one snapshot")

        history = None  # stacked per-snapshot embeddings, concatenated along dim 1
        for index, node_state in enumerate(node_states):
            adjacency = adjacencies[index]
            relation_enc = relation_encs[index]
            if history is None:
                # No earlier snapshots: start from an all-zero history of shape (N, 1, 32).
                # The 32 must match the spatial layer's output width. Allocate on the
                # input's device instead of forcing CPU.
                this_buffer = torch.zeros(
                    (len(node_state), 1, 32), dtype=torch.float32, device=node_state.device
                )
            else:
                this_buffer = history
            emb = self.temporal_layer(node_state, this_buffer)
            emb = self.spatial_layer(emb, adjacency, point_enc, relation_enc)
            step = torch.unsqueeze(emb, 1)
            history = step if history is None else torch.cat([history, step], 1)
        logits = self.mlp(emb[mask].float())
        return logits

    def predict(self, node_states, adjacencies, point_enc, relation_encs, mask):
        """Return hard 0/1 class predictions (int64 tensor) for the masked nodes.

        Class 0 is chosen only when its probability is strictly greater than class 1's;
        ties (and NaNs) resolve to class 1. No autograd graph is built.
        """
        with torch.no_grad():
            logits = self.forward(node_states, adjacencies, point_enc, relation_encs, mask)
            # Normalise over the two class logits of each node. dim=1 is what the
            # deprecated implicit-dim call picked for 2-D input, so results are unchanged.
            probs = F.softmax(logits, dim=1)
        return (~(probs[:, 0] > probs[:, 1])).long()

In [None]:
# Seed torch's global RNG before constructing the model so that initial weights drawn
# from it (e.g. nn.Linear's default init) are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
model = HTGT(entity_index)

In [None]:
LEARNING_RATE = 0.005  # Adam step size

# Cross-entropy over the model's 2-class logits, optimised with Adam over all parameters.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Full-batch training on the train mask; report loss and test accuracy periodically
# rather than on every one of the 5000 epochs.
NUM_EPOCHS = 5000
LOG_EVERY = 100  # print/evaluate every N epochs (and on the final epoch)

for epoch in range(NUM_EPOCHS):
    model.train()
    y_pred = model(node_states, adjacencies, point_e, relation_es, train_mask)
    loss = criterion(y_pred, train_label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        # Evaluate in eval mode (disables dropout etc. if the layers use any) and
        # without building an autograd graph.
        model.eval()
        with torch.no_grad():
            test_pred = model.predict(node_states, adjacencies, point_e, relation_es, test_mask)
        # sklearn's signature is accuracy_score(y_true, y_pred).
        test_acc = accuracy_score(test_label, test_pred)
        print(f"epoch {epoch:4d}  train loss {loss.item():.4f}  test acc {test_acc:.4f}")

In [None]:
# Sanity check on the first loaded day's adjacency: rows selected by the training mask.
# Display only the shape, not the full matrix slice, to keep notebook output small.
adjacencies[0][train_mask].shape