This code is partly adapted from: https://github.com/ZhuangDingyi/STZINB

In [1]:
import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.init as init
import numpy as np
import pandas as pd
import scipy.sparse as sp
from scipy.linalg import eigvalsh
from scipy.linalg import fractional_matrix_power
from torch.utils.data import Dataset, DataLoader
import pickle as pkl
from scipy.spatial.distance import jensenshannon
import time
from scipy.stats import nbinom

In [3]:
# Pick the compute device once for the whole notebook: CUDA when available, CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


# Spatial Module

In [4]:
import math

class InterGAT(nn.Module):
    """
    Attention-only block that pools the neighbours of a target into one feature vector.

    Every (target, neighbour) pair is scored by a two-layer MLP. Neighbours whose
    adjacency weight does not exceed ``threshold`` are masked before the softmax,
    and the neighbour features are averaged with the resulting attention weights.
    """

    # Score given to masked neighbours; large enough that softmax assigns them zero weight.
    _MASK_VALUE = -9e15

    def __init__(self, in_channels, alpha=0.2, threshold=0.0, concat = True):
        """
        :param in_channels: feature size of one node vector (the scoring MLP reads 2*in_channels).
        :param alpha: negative slope of the LeakyReLU applied to the raw scores.
        :param threshold: neighbours with ``adj <= threshold`` are excluded from the attention.
        :param concat: if True the pooled features are passed through ELU, otherwise returned as is.
        """
        super().__init__()
        self.alpha = alpha
        self.threshold = threshold
        self.concat = concat
        self.in_channels = in_channels
        self.attn1 = nn.Linear(2*in_channels, 16)
        self.attn2 = nn.Linear(16, 1)
        self.relu = nn.ReLU()
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, input_target, input_neigh, adj):
        """
        :param input_target: [batch_size, n_feat] features of the target.
        :param input_neigh: [batch_size, n_neigh, n_feat] features of its neighbours.
        :param adj: [batch_size, n_neigh] adjacency weights between target and neighbours.
        :return: [batch_size, n_feat] attention-weighted average of the neighbour features.
        """
        n_neigh = input_neigh.size(1)

        # Pair the target with each of its neighbours: [batch_size, n_neigh, 2 * n_feat].
        target_tiled = input_target.unsqueeze(1).expand(-1, n_neigh, -1)
        pair_feat = torch.cat([target_tiled, input_neigh], dim=-1)

        # Raw attention score per neighbour: [batch_size, n_neigh].
        e = self.leakyrelu(self.attn2(self.relu(self.attn1(pair_feat)))).squeeze(-1)

        # The mask is created next to `e` (same device and dtype) so the module
        # does not depend on any global device variable.
        masked = torch.full_like(e, self._MASK_VALUE)
        attention = torch.where(adj > self.threshold, e, masked)
        attention = F.softmax(attention, dim=-1)  # [batch_size, n_neigh]

        # [batch_size, 1, n_neigh] @ [batch_size, n_neigh, n_feat] -> [batch_size, n_feat]
        h_prime = torch.matmul(attention.unsqueeze(1), input_neigh).squeeze(1)
        if self.concat:
            return F.elu(h_prime)
        return h_prime

class InterMGCN(nn.Module):
    """
    Single graph-convolution step that aggregates neighbour features for one target.

    Neighbour features are projected with a learned ``[c_in, c_out]`` weight,
    summed with the adjacency weights, shifted by an optional bias and passed
    through ReLU.
    """

    def __init__(self, c_in, c_out, graph_conv_act_func, enable_bias=True):
        """
        :param c_in: feature size of the incoming neighbour vectors.
        :param c_out: feature size of the aggregated output.
        :param graph_conv_act_func: activation name, used only to choose the weight initialiser.
        :param enable_bias: whether a learnable bias is added after aggregation.
        """
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.enable_bias = enable_bias
        self.graph_conv_act_func = graph_conv_act_func
        self.relu = nn.ReLU()
        # Parameters are created on the default device and follow the module when
        # it is moved with ``.to(...)``, like every other layer of the model.
        self.weight = nn.Parameter(torch.empty(c_in, c_out))
        if enable_bias:
            self.bias = nn.Parameter(torch.empty(c_out))
        else:
            self.register_parameter('bias', None)
        self.initialize_parameters()

    def initialize_parameters(self):
        """Initialise the weight according to the activation name and zero the bias."""
        act = self.graph_conv_act_func
        if act in ('sigmoid', 'tanh'):
            init.xavier_uniform_(tensor=self.weight, gain=init.calculate_gain(act))
        else:
            # ReLU-like activations; also the fallback for unrecognised names so the
            # weight is never left holding uninitialised memory.
            init.kaiming_uniform_(self.weight)
        if self.bias is not None:
            init.zeros_(self.bias)

    def forward(self, input_target, input_neigh, adj):
        """
        :param input_target: [batch_size, n_feat] target features (not used by this layer).
        :param input_neigh: [batch_size, n_node, c_in] neighbour features.
        :param adj: [batch_size, n_node] adjacency weights between target and neighbours.
        :return: [batch_size, c_out] aggregated, ReLU-activated features.
        """
        # Project every neighbour to c_out channels. The projected tensor (not the
        # raw input) must be the one aggregated below.
        x_first_mul = torch.matmul(input_neigh, self.weight)  # [batch_size, n_node, c_out]
        # [batch_size, 1, n_node] @ [batch_size, n_node, c_out] -> [batch_size, c_out]
        x_second_mul = torch.matmul(adj.unsqueeze(1), x_first_mul).squeeze(1)
        if self.bias is not None:
            x_gcnconv = x_second_mul + self.bias
        else:
            x_gcnconv = x_second_mul
        return self.relu(x_gcnconv)

class InterGraphConvLayer(nn.Module):
    """
    Wrapper that selects the spatial encoder used for one target and its neighbours.

    With ``graph_conv_type == "gat"`` the features are first mapped from ``c_in``
    to ``c_out`` by a shared linear layer and then pooled by ``InterGAT``; for any
    other value ``InterMGCN`` is used directly on the raw features.
    """

    def __init__(self, Ks, c_in, c_out, graph_conv_type, graph_conv_act_func):
        """
        :param Ks: stored on the instance; not used by this layer itself.
        :param c_in: feature size of the incoming vectors.
        :param c_out: feature size produced by the encoder.
        :param graph_conv_type: "gat" for attention, anything else for graph convolution.
        :param graph_conv_act_func: activation name forwarded to ``InterMGCN``.
        """
        super().__init__()
        self.Ks = Ks
        self.graph_conv_type = graph_conv_type
        self.graph_conv_act_func = graph_conv_act_func
        self.enable_bias = True
        if graph_conv_type != "gat":
            self.gcnconv = InterMGCN(c_in, c_out, graph_conv_act_func)
        else:
            self.align = nn.Linear(c_in, c_out)
            self.gcnconv = InterGAT(c_out)

    def forward(self, x_target, x_neigh, graph_conv_matrix=None):
        """
        :param x_target: target features, last dimension of size c_in.
        :param x_neigh: neighbour features, last dimension of size c_in.
        :param graph_conv_matrix: adjacency weights handed to the encoder.
        :return: whatever the selected encoder returns for these inputs.
        """
        uses_attention = self.graph_conv_type == "gat"
        # Only the attention variant needs the inputs aligned to c_out beforehand.
        target = self.align(x_target) if uses_attention else x_target
        neighbours = self.align(x_neigh) if uses_attention else x_neigh
        return self.gcnconv(target, neighbours, graph_conv_matrix)

# Loss

In [5]:
def nb_nll_loss(y,n,p,y_mask=None, weight=True):
    """
    Summed negative log-likelihood of ``y`` under a negative binomial NB(n, p).

    y: true values
    n: shape parameter of the distribution (positive)
    p: probability parameter of the distribution, in (0, 1)
    y_mask: optional mask; only entries with ``y_mask > 0`` contribute
    weight: accepted for interface compatibility; not used by this loss
    """
    if y_mask is not None:
        keep = y_mask > 0
        y, n, p = y[keep], n[keep], p[keep]

    # A saturated sigmoid/softplus can yield exactly 0 or 1, which turns the log
    # terms into +-inf and `0 * inf` into NaN. Clamping leaves interior values untouched.
    eps = torch.finfo(p.dtype).eps
    n = n.clamp_min(eps)
    p = p.clamp(eps, 1 - eps)

    nll = torch.lgamma(n) + torch.lgamma(y+1) - torch.lgamma(n+y) - n*torch.log(p) - y*torch.log(1-p)
    return torch.sum(nll)

def nb_zeroinflated_nll_loss(y,n,p,pi,y_mask=None, weight=0.):
    """
    Summed negative log-likelihood of ``y`` under a zero-inflated negative binomial.

    The likelihood of one observation is
        P(y = 0) = pi + (1 - pi) * p**n
        P(y > 0) = (1 - pi) * NB(y; n, p)
    see https://stats.idre.ucla.edu/r/dae/zinb/

    y: true values
    n: shape parameter of the negative binomial (positive)
    p: probability parameter of the negative binomial, in (0, 1)
    pi: zero-inflation probability, in (0, 1)
    y_mask: optional mask; only entries with ``y_mask > 0`` contribute
    weight: focal exponent; when > 0 every log-likelihood term is scaled by
        ``(1 - likelihood) ** weight``
    """
    if y_mask is not None:
        keep = y_mask > 0
        y, n, p, pi = y[keep], n[keep], p[keep], pi[keep]

    # Keep parameters strictly inside their domains so that no log/lgamma term
    # becomes infinite; interior values are unchanged by the clamps.
    eps = torch.finfo(p.dtype).eps
    n = n.clamp_min(eps)
    p = p.clamp(eps, 1 - eps)
    pi = pi.clamp(eps, 1 - eps)

    idx_yeq0 = y==0
    idx_yg0  = y>0

    n_yeq0 = n[idx_yeq0]
    p_yeq0 = p[idx_yeq0]
    pi_yeq0 = pi[idx_yeq0]

    n_yg0 = n[idx_yg0]
    p_yg0 = p[idx_yg0]
    pi_yg0 = pi[idx_yg0]
    yg0 = y[idx_yg0]

    # A zero is either a structural zero (pi) or a zero drawn from the NB component
    # ((1 - pi) * p**n); the two probabilities are added, not multiplied.
    L_yeq0 = torch.log(pi_yeq0 + (1-pi_yeq0)*torch.pow(p_yeq0,n_yeq0))
    L_yg0  = torch.log(1-pi_yg0) + torch.lgamma(n_yg0+yg0) - torch.lgamma(yg0+1) - torch.lgamma(n_yg0) + n_yg0*torch.log(p_yg0) + yg0*torch.log(1-p_yg0)
    if weight > 0:
        # Focal re-weighting. The base is clamped at 0 because rounding can push
        # exp(L) marginally above 1, and a negative base with a fractional exponent is NaN.
        L_yeq0 = (1-torch.exp(L_yeq0)).clamp_min(0) ** weight * L_yeq0
        L_yg0 = (1-torch.exp(L_yg0)).clamp_min(0) ** weight * L_yg0
    return -torch.sum(L_yeq0)-torch.sum(L_yg0)

# Prediction Layer

In [6]:
class FNN_NBNorm(nn.Module):
    """
    Prediction head that maps a feature vector to negative binomial parameters.

    Two independent MLPs (input_dim -> 256 -> 128 -> output_dim, ReLU and dropout
    in between) produce the raw values for ``n`` and ``p``.
    """

    def __init__(self, input_dim, output_dim, dropout, setting="daily_cnt", prob="ZINB"):
        """
        :param input_dim: size of the last dimension of the input features.
        :param output_dim: number of values produced per parameter.
        :param dropout: dropout probability used inside both MLPs.
        :param setting: stored on the instance; not read by this head.
        :param prob: stored on the instance; not read by this head.
        """
        super().__init__()
        self.setting = setting
        self.prob = prob
        self.n_fc = self._make_head(input_dim, output_dim, dropout)
        self.p_fc = self._make_head(input_dim, output_dim, dropout)

    @staticmethod
    def _make_head(input_dim, output_dim, dropout):
        """Build one three-layer MLP with ReLU + dropout after the two hidden layers."""
        return nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(128, output_dim),
        )

    def forward(self, x):
        """
        :param x: features whose last dimension has size ``input_dim``.
        :return: tuple ``(n, p)`` with a trailing singleton dimension squeezed away;
            ``n`` is positive (softplus) and ``p`` lies in (0, 1) (sigmoid).
        """
        shape_param = F.softplus(self.n_fc(x))  # Some parameters can be tuned here
        prob_param = torch.sigmoid(self.p_fc(x))
        return shape_param.squeeze(-1), prob_param.squeeze(-1)

class FNN_NBNorm_ZeroInflated(nn.Module):
    """
    Prediction head that maps features to zero-inflated negative binomial parameters.

    Three independent MLPs (input_dim -> 256 -> 128 -> output_dim, ReLU and dropout
    in between) produce the raw values for ``n``, ``p`` and ``pi``.
    """

    def __init__(self, input_dim, output_dim, dropout, setting="daily_cnt"):
        """
        :param input_dim: size of the last dimension of the input features.
        :param output_dim: number of values produced per parameter.
        :param dropout: dropout probability used inside the MLPs.
        :param setting: if it contains "daily" the outputs are squeezed, otherwise
            they are reshaped to ``[batch, station, 2, 3]``.
        """
        super().__init__()
        self.setting = setting
        self.n_fc = self._make_head(input_dim, output_dim, dropout)
        self.p_fc = self._make_head(input_dim, output_dim, dropout)
        self.pi_fc = self._make_head(input_dim, output_dim, dropout)

    @staticmethod
    def _make_head(input_dim, output_dim, dropout):
        """Build one three-layer MLP with ReLU + dropout after the two hidden layers."""
        return nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(128, output_dim),
        )

    def forward(self, x):
        """
        :param x: [batch_size, n_max_station, input_dim] features.
        :return: tuple ``(n, p, pi)``; ``n`` is positive (softplus), ``p`` and ``pi``
            lie in (0, 1) (sigmoid).
        """
        batch_size, n_max_station, _ = x.shape
        shape_param = F.softplus(self.n_fc(x))  # Some parameters can be tuned here
        prob_param = torch.sigmoid(self.p_fc(x))
        zero_param = torch.sigmoid(self.pi_fc(x))
        outputs = (shape_param, prob_param, zero_param)
        if "daily" in self.setting:
            return tuple(out.squeeze(-1) for out in outputs)
        return tuple(out.reshape(batch_size, n_max_station, 2, 3) for out in outputs)

class FNN_Prediction(nn.Module):
    """
    Plain regression head: one MLP (input_dim -> 256 -> 128 -> output_dim) with
    ReLU and dropout after each hidden layer and no output activation.
    """

    def __init__(self, input_dim, output_dim, dropout, setting="daily_cnt"):
        """
        :param input_dim: size of the last dimension of the input features.
        :param output_dim: number of predicted values per station.
        :param dropout: dropout probability used inside the MLP.
        :param setting: if it contains "daily" the output is squeezed, otherwise it
            is reshaped to ``[batch, station, 2, 3]``.
        """
        super().__init__()
        self.setting = setting
        hidden_layers = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(p=dropout),
        ]
        self.out_fc = nn.Sequential(*hidden_layers, nn.Linear(128, output_dim))

    def forward(self, x):
        """
        :param x: [batch_size, n_max_station, input_dim] features.
        :return: predictions, squeezed on the last axis for daily settings and
            reshaped to ``[batch_size, n_max_station, 2, 3]`` otherwise.
        """
        batch_size, n_max_station, _ = x.shape
        prediction = self.out_fc(x)
        if "daily" not in self.setting:
            return prediction.reshape(batch_size, n_max_station, 2, 3)
        return prediction.squeeze(-1)

# Model

In [7]:
class Graph_NBNorm_ZeroInflated(nn.Module):
    """
    OD-flow model: spatial encoders over station and OD-pair neighbourhoods
    followed by a prediction head.

    For every (origin, destination) pair the model concatenates the raw pair
    features, a month embedding and, optionally, graph encodings of the origin,
    the destination and the OD pair, and feeds the result to the head selected
    by ``dist``.
    """

    def __init__(self, dim_station, dim_interact, n_month_embed=12, dropout=0.5, setting="daily_density",
                 dist="zinb", graph_conv_type="gat", od_graph=True, node_graph=True, builtEnv=True,
                 n_gcn=8, n_od_gcn=8):
        """
        :param dim_station: feature size of one station.
        :param dim_interact: feature size of one OD pair.
        :param n_month_embed: size of the month embedding.
        :param dropout: dropout probability of the prediction head.
        :param setting: must contain "daily" (one output) or "hour" (six outputs);
            "daily_cnt" additionally appends the last OD feature to the head input.
        :param dist: "zinb" or "nb" for the distributional heads, anything else for
            the plain regression head.
        :param graph_conv_type: "gat" or any other value for graph convolution.
        :param od_graph: whether OD-pair neighbourhoods are encoded.
        :param node_graph: whether origin/destination neighbourhoods are encoded.
        :param builtEnv: whether the built-environment graphs are used next to the distance graphs.
        :param n_gcn: output size of each station-level encoder.
        :param n_od_gcn: output size of each OD-level encoder.
        :raises ValueError: if ``setting`` contains neither "daily" nor "hour".
        """
        super().__init__()
        self.month_embedding = nn.Embedding(12, n_month_embed)
        self.setting = setting
        self.dist = dist
        self.graph_conv_type = graph_conv_type
        self.od_graph = od_graph
        self.node_graph = node_graph
        self.builtEnv = builtEnv
        self.graph_conv_act_func = "relu"

        # --- feature encoding layers -------------------------------------------------
        # `input_dim` must equal the width of the tensor concatenated in `forward`.
        input_dim = dim_station*2+dim_interact+n_month_embed
        if self.node_graph:
            self.ori_gcn_dist = InterGraphConvLayer(1, dim_station, n_gcn, self.graph_conv_type, self.graph_conv_act_func)
            # The same encoder is applied to the origin and to the destination,
            # so it contributes two feature groups.
            input_dim += 2 * n_gcn
            if self.builtEnv:
                self.ori_gcn_builtEnv = InterGraphConvLayer(1, dim_station, n_gcn, self.graph_conv_type, self.graph_conv_act_func)
                input_dim += 2 * n_gcn

        if self.od_graph:
            self.od_gcn_dist = InterGraphConvLayer(1, dim_station * 2 + dim_interact, n_od_gcn, self.graph_conv_type, self.graph_conv_act_func)
            input_dim += n_od_gcn
            if self.builtEnv:
                self.od_gcn_builtEnv = InterGraphConvLayer(1, dim_station * 2 + dim_interact, n_od_gcn, self.graph_conv_type, self.graph_conv_act_func)
                input_dim += n_od_gcn
        if self.setting == "daily_cnt":
            input_dim += 1
        self.input_dim = input_dim

        # --- prediction layer --------------------------------------------------------
        if "daily" in self.setting:
            output_dim = 1
        elif "hour" in self.setting:
            output_dim = 6
        else:
            raise ValueError(f"unsupported setting: {self.setting!r} (expected 'daily' or 'hour' in the name)")
        self.output_dim = output_dim
        if self.dist == "zinb":
            self.pred_fc = FNN_NBNorm_ZeroInflated(self.input_dim, self.output_dim, dropout=dropout, setting=setting)
        elif self.dist == "nb":
            self.pred_fc = FNN_NBNorm(self.input_dim, self.output_dim, dropout=dropout, setting=setting)
        else:
            self.pred_fc = FNN_Prediction(self.input_dim, self.output_dim, dropout=dropout, setting=setting)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def process_od_features(self, ori_feat, ori_neigh_feat, des_feat, des_neigh_feat, od_feat, od_neigh_feat):
        """
        Build the features of the target OD pair and of its neighbouring OD pairs.

        Origin candidates are the origin plus its neighbours, destination candidates
        the destination plus its neighbours; every combination forms one OD pair.
        Combination 0 (origin itself, destination itself) is the target, the
        remaining ``(n_neigh+1)**2 - 1`` combinations are its neighbours.

        :return: ``(x_od_target, x_od_neigh)``
        """
        batch_size, n_global_station, n_neigh, n_station_feat = des_neigh_feat.shape
        # Origin candidates, copied for every destination station.
        x_od_ori_feat = torch.cat([ori_feat.unsqueeze(1), ori_neigh_feat], 1).unsqueeze(1).repeat(1, n_global_station, 1, 1)
        # Destination candidates of every destination station.
        x_od_des_feat = torch.cat([des_feat.unsqueeze(2), des_neigh_feat], 2)
        # Cartesian product of origin and destination candidates.
        x_od_ori_feat = x_od_ori_feat.unsqueeze(-2).repeat(1, 1, 1, n_neigh+1, 1)
        x_od_des_feat = x_od_des_feat.unsqueeze(-3).repeat(1, 1, n_neigh+1, 1, 1)
        x_od_ori_des_feat = torch.cat([x_od_ori_feat, x_od_des_feat], -1).reshape(batch_size, n_global_station, (n_neigh+1)*(n_neigh+1), -1)
        if self.setting == "daily_density":
            x_od_target = torch.cat([x_od_ori_des_feat[:, :, 0], od_feat], -1)
        else:
            # The last OD feature is held back here; for "daily_cnt" it is appended
            # separately at the end of `forward`.
            x_od_target = torch.cat([x_od_ori_des_feat[:, :, 0], od_feat[:, :, :-1]], -1)
        x_od_neigh = torch.cat([x_od_ori_des_feat[:, :, 1:], od_neigh_feat], -1)
        return x_od_target, x_od_neigh

    def forward(self, ori_feat, ori_dist_feat, ori_dist_adj, ori_builtEnv_feat, ori_builtEnv_adj,
                des_feat, des_dist_feat, des_dist_adj, des_builtEnv_feat, des_builtEnv_adj,
                od_feat, od_dist_feat, od_dist_adj, od_builtEnv_feat, od_builtEnv_adj,
                x_month, des_mask):
        """
        :param ori_dist_feat: [batch_size, n_neigh, n_feat]
        :param ori_dist_adj: [batch_size, n_neigh]
        :param des_dist_feat: [batch_size, n_global_sta, n_neigh, n_feat]
        :param des_dist_adj: [batch_size, n_global_sta, n_neigh]
        :param od_feat: [batch_size, n_global_sta, n_feat]
        :param od_dist_feat: [batch_size, n_global_sta, (n_neigh+1)*(n_neigh+1)-1, n_feat]
        :param od_dist_adj: [batch_size, n_global_sta, (n_neigh+1)*(n_neigh+1)-1]
        :param x_month: [batch_size] month index of each sample
        :param des_mask: accepted for interface compatibility; not used here.

        The ``*_builtEnv_*`` arguments mirror their ``*_dist_*`` counterparts and are
        only read when ``builtEnv`` is enabled.
        :return: output of the prediction head selected by ``dist``.
        """
        batch_size, n_global_station, n_neigh, n_station_feat = des_dist_feat.shape
        n_pairs = batch_size * n_global_station
        n_od_neigh = (n_neigh+1)*(n_neigh+1)-1

        # load od features
        x_od_target, x_od_dist = self.process_od_features(ori_feat, ori_dist_feat, des_feat, des_dist_feat, od_feat, od_dist_feat)
        if self.builtEnv:
            _, x_od_builtEnv = self.process_od_features(ori_feat, ori_builtEnv_feat, des_feat, des_builtEnv_feat, od_feat, od_builtEnv_feat)
        x_month = self.month_embedding(x_month)
        x_od_feat_ls = [x_od_target, x_month.unsqueeze(1).repeat(1, n_global_station, 1)]

        # origin spatial feature encoding
        if self.node_graph:
            x_ori_graph_dist = self.ori_gcn_dist(ori_feat, ori_dist_feat, ori_dist_adj) # batch_size, c
            x_od_feat_ls.append(x_ori_graph_dist.unsqueeze(1).repeat(1, n_global_station, 1))
            if self.builtEnv:
                # The built-environment encoder must see the built-environment
                # neighbourhood, not the distance one.
                x_ori_graph_builtEnv = self.ori_gcn_builtEnv(ori_feat, ori_builtEnv_feat, ori_builtEnv_adj) # batch_size, c
                x_od_feat_ls.append(x_ori_graph_builtEnv.unsqueeze(1).repeat(1, n_global_station, 1))

        # destination spatial feature encoding (stations flattened into the batch axis)
        if self.node_graph:
            x_des_graph_dist = self.ori_gcn_dist(
                des_feat.reshape(n_pairs, -1),
                des_dist_feat.reshape(n_pairs, n_neigh, -1),
                des_dist_adj.reshape(n_pairs, n_neigh),
            ).reshape(batch_size, n_global_station, -1)
            x_od_feat_ls.append(x_des_graph_dist)
            if self.builtEnv:
                x_des_graph_builtEnv = self.ori_gcn_builtEnv(
                    des_feat.reshape(n_pairs, -1),
                    des_builtEnv_feat.reshape(n_pairs, n_neigh, -1),
                    des_builtEnv_adj.reshape(n_pairs, n_neigh),
                ).reshape(batch_size, n_global_station, -1)
                x_od_feat_ls.append(x_des_graph_builtEnv)

        # od spatial feature encoding
        if self.od_graph:
            x_od_graph_dist = self.od_gcn_dist(
                x_od_target.reshape(n_pairs, -1),
                x_od_dist.reshape(n_pairs, n_od_neigh, -1),
                od_dist_adj.reshape(n_pairs, n_od_neigh),
            ).reshape(batch_size, n_global_station, -1)
            x_od_feat_ls.append(x_od_graph_dist)
            if self.builtEnv:
                x_od_graph_builtEnv = self.od_gcn_builtEnv(
                    x_od_target.reshape(n_pairs, -1),
                    x_od_builtEnv.reshape(n_pairs, n_od_neigh, -1),
                    od_builtEnv_adj.reshape(n_pairs, n_od_neigh),
                ).reshape(batch_size, n_global_station, -1)
                x_od_feat_ls.append(x_od_graph_builtEnv)
        if self.setting == "daily_cnt":
            x_od_feat_ls.append(od_feat[:, :, -1:])

        # prediction
        x_od = torch.cat(x_od_feat_ls, -1)
        return self.pred_fc(x_od)

# Data Container

In [8]:
# data_dir=data_dir, knn_dir = knn_dir
class DataInput(object):
    """
    Loads every array the model needs from disk and keeps it in ``self.dataset``.

    All tensors are moved to the module-level ``device``. The constructor performs
    the complete loading sequence and prints progress messages along the way.

    NOTE: several inputs are read with ``pickle``; only point the ``*_dir``
    arguments at files from a trusted source.
    """

    # Minimum size of the global-id -> local-id lookup table used by the kNN loaders.
    N_GLOBAL_STATION = 1101

    def __init__(self, local2global_station_dir, month_od_opendays_dir, month_od_nonzero_dir,
                 feat_dir, od_feat_dir, month_feat_dir, month_od_feat_dir,
                 manhattan_dir, month_station_opendays_dir,
                 dist_knn_dir, dist_adj_dir,
                 builtEnv_knn_dir, builtEnv_adj_dir,
                 n_train_timestep=50, k=5,
                 setting="daily_cnt"):
        """
        :param *_dir: paths of the input files read by the matching ``load_*`` method.
        :param n_train_timestep: number of leading timesteps that form the training period.
        :param k: number of nearest neighbours kept from the kNN arrays.
        :param setting: prediction setting; "daily" settings aggregate OD counts per
            station pair, other settings are expected to start with the number of
            hours per time bin.
        """
        self.timestep_ls = self.load_timestep_ls()
        print("done load timestep_ls...")
        self.local2global_station_dir = local2global_station_dir
        self.month_od_opendays_dir = month_od_opendays_dir
        self.month_od_nonzero_dir = month_od_nonzero_dir
        self.feat_dir = feat_dir
        self.od_feat_dir = od_feat_dir
        self.month_feat_dir = month_feat_dir
        self.month_od_feat_dir = month_od_feat_dir
        self.manhattan_dir = manhattan_dir
        self.month_station_opendays_dir = month_station_opendays_dir
        self.dist_knn_dir = dist_knn_dir
        self.dist_adj_dir = dist_adj_dir
        self.builtEnv_knn_dir = builtEnv_knn_dir
        self.builtEnv_adj_dir = builtEnv_adj_dir
        self.n_train_timestep = n_train_timestep
        self.k = k
        self.setting = setting

        self.dataset = dict()
        self.dataset["local2global_station"], self.dataset["n_station"] = self.load_local2global_station()
        print("done load local2global_station...")
        self.dataset["local_exist_mask"] = self.load_local_exist_mask()
        print("done load exist_mask...")
        self.dataset["feature"] = self.load_feature()
        print("done load feature...")
        self.dataset["month_feature"] = self.load_month_feat_ls()
        print("done load month_feature...")
        self.dataset["od_feature"] = self.load_od_feature()
        print("done load od_feature...")
        self.dataset["month_od_feature"] = self.load_month_od_feat_ls()
        print("done load month_od_feature...")
        self.dataset["month_od_opendays"] = self.load_month_od_opendays_ls()
        print("done load month_od_opendays...")
        self.dataset["month_od_nonzero"] = self.load_month_od_nonzero_ls()
        print("done load month_od_nonzero...")
        self.dataset["timestep2month"] = self.load_month()
        self.dataset["month_dist_knn"] = self.load_dist_knn()
        self.dataset["dist_adj_diag1"] = self.load_dist_adj_diag1()
        self.dataset["month_builtEnv_knn"] = self.load_builtEnv_knn_ls()
        self.dataset["builtEnv_adj_diag1"] = self.load_builtEnv_adj_ls_diag1()
        self.min_outflow, self.max_outflow = self.compute_min_max_outflow(pre_min=0, pre_max=733.85)
        print("min_daily_outflow", self.min_outflow, "max_daily_outflow", self.max_outflow)
        self.min_od_cnt, self.max_od_cnt, self.min_od_density, self.max_od_density = self.compute_min_max_od()
        print("min_od_cnt", self.min_od_cnt, "max_od_cnt", self.max_od_cnt)
        print("min_od_density", self.min_od_density, "max_od_density", self.max_od_density)

    @staticmethod
    def _rescale(values, lower, upper):
        """Min-max scale ``values`` to [-1, 1]; a zero range is treated as 1 to avoid 0/0."""
        span = upper - lower
        span = np.where(span == 0, 1, span)
        return 2 * ((values - lower) / span) - 1

    def load_timestep_ls(self):
        """Return the "YYYYMM" labels from 2013-07 to 2019-12, one per timestep."""
        timestep_ls = []
        for year in range(2013, 2020):
            first_month = 7 if year == 2013 else 1  # the series starts in July 2013
            timestep_ls.extend(f"{year}{month:02d}" for month in range(first_month, 13))
        return timestep_ls

    def load_local2global_station(self):
        """Return the per-timestep local->global station index tensors and their lengths."""
        with open(self.local2global_station_dir, 'rb') as handle:
            local2global_station_ls = pkl.load(handle)
        local2global_station_ls = [torch.from_numpy(local2global_station).long().to(device) for local2global_station in local2global_station_ls]
        n_station_ls = [len(local2global_station) for local2global_station in local2global_station_ls]
        return local2global_station_ls, torch.LongTensor(n_station_ls).to(device)

    def load_local_exist_mask(self):
        """
        For every timestep after the training period, flag the stations that
        already appeared during training. Training timesteps get the placeholder 0.
        """
        train_station = set()
        local_exist_mask = list()
        for i in range(self.n_train_timestep):
            train_station = train_station.union(set(self.dataset["local2global_station"][i].tolist()))
            local_exist_mask.append(0)
        train_station = torch.LongTensor(list(train_station)).to(device)
        for local2global_station in self.dataset["local2global_station"][self.n_train_timestep:]:
            exist_mask = torch.isin(local2global_station, train_station) * 1
            local_exist_mask.append(exist_mask.long().to(device))
        return local_exist_mask

    def load_feature(self):
        """Load the station feature table, drop the size columns and scale each column to [-1, 1]."""
        feat = pd.read_csv(self.feat_dir)
        feat.fillna(0, inplace=True)
        feat.drop(columns=["station size", "size found"], inplace=True)
        feat = feat.values
        feat_min, feat_max = feat.min(axis=0), feat.max(axis=0)
        feat = self._rescale(feat, feat_min, feat_max)
        print("feat shape{}".format(feat.shape))
        return torch.from_numpy(feat).float().to(device)

    def load_month_feat_ls(self):
        """Load the pickled list of per-month station feature arrays as float tensors."""
        with open(self.month_feat_dir, 'rb') as handle:
            month_feat_ls = pkl.load(handle)
        month_feat_ls = [torch.from_numpy(month_feat).float().to(device) for month_feat in month_feat_ls]
        return month_feat_ls

    def load_od_feature(self):
        """
        Build the static OD-pair features: the loaded OD array
        (travel_dist, geo_dist, bearing), four distance-band indicators and a
        four-way one-hot of the manhattan array, each channel scaled to [-1, 1].
        """
        od_feat = np.load(self.od_feat_dir) # [n_station, n_station, 3]travel_dist, geo_dist, bearing
        dist = od_feat[:, :, 0]
        np.fill_diagonal(dist, 0)
        od_feat[:, :, 0] = dist
        # distance category features
        dist_500 = (dist <= 500) * 1
        dist_3000 = (dist > 500) * (dist <= 3000)
        dist_5000 = (dist > 3000) * (dist <= 5000)
        dist_inf = (dist > 5000) * 1
        dist_cat = np.stack([dist_500, dist_3000, dist_5000, dist_inf], -1) # n_s, n_s, 4
        # manhattan features, one-hot encoded over 4 categories
        manhattan = np.load(self.manhattan_dir)
        manhattan_binary = np.zeros((manhattan.shape[0], manhattan.shape[1], 4)) # n_s, n_s, 4
        n_station = manhattan_binary.shape[0]
        ori_idx = np.stack([np.arange(n_station)] * n_station, 1)
        des_idx = np.stack([np.arange(n_station)] * n_station, 0)
        manhattan_binary[ori_idx.reshape(-1), des_idx.reshape(-1), manhattan.reshape(-1)] = 1
        od_feat = np.concatenate([od_feat, dist_cat, manhattan_binary], -1)
        od_min, od_max = np.min(od_feat.min(axis=0), axis=0), np.max(od_feat.max(axis=0), axis=0)
        print(np.stack([od_min, od_max]))
        od_feat = self._rescale(od_feat, od_min, od_max)
        print("od feat shape{}".format(od_feat.shape))
        return torch.from_numpy(od_feat).float().to(device)

    def load_month_od_feat_ls(self):
        """Load the pickled list of per-month OD feature arrays as float tensors."""
        with open(self.month_od_feat_dir, 'rb') as handle:
            month_od_feat_ls = pkl.load(handle)
        month_od_feat_ls = [torch.from_numpy(month_od_feat).float().to(device) for month_od_feat in month_od_feat_ls]
        return month_od_feat_ls

    def load_month_station_opendays_ls(self):
        """Load the pickled list of per-month station open-day arrays as long tensors."""
        with open(self.month_station_opendays_dir, 'rb') as handle:
            month_station_opendays_ls = pkl.load(handle)
        month_station_opendays_ls = [torch.from_numpy(month_station_opendays).long().to(device) for \
                                     month_station_opendays in month_station_opendays_ls]
        print(len(month_station_opendays_ls))
        print(month_station_opendays_ls[0].shape)
        return month_station_opendays_ls

    def load_month_od_opendays_ls(self):
        """
        Load the per-month OD open-day arrays; for "daily" settings the leading
        axis is summed away.
        """
        with open(self.month_od_opendays_dir, 'rb') as handle:
            month_od_opendays_ls = pkl.load(handle) # 2, n_station, n_station
        if "daily" in self.setting:
            month_od_opendays_ls = [torch.from_numpy(np.sum(month_od_opendays, 0)).long().to(device) for month_od_opendays in month_od_opendays_ls]
        else:
            month_od_opendays_ls = [torch.from_numpy(month_od_opendays).long().to(device) for month_od_opendays in month_od_opendays_ls]
        return month_od_opendays_ls

    def load_month_od_nonzero_ls(self):
        """
        Aggregate the per-month non-zero OD records
        (weekday, hour, start_sta, end_sta, cnt).

        "daily" settings sum ``cnt`` per (start_sta, end_sta). Other settings keep
        the hours 8-19, bin them into blocks of ``int(setting[0])`` hours and sum
        ``cnt`` per (weekday, bin, start_sta, end_sta).

        :return: list with one long tensor per month; the columns are the grouping
            keys followed by the summed count.
        """
        with open(self.month_od_nonzero_dir, 'rb') as handle:
            month_od_nonzero_ls = pkl.load(handle)
        is_daily = "daily" in self.setting
        raw_columns = ["weekday", "hour", "start_sta", "end_sta", "cnt"]
        new_month_od_nonzero_ls = []
        for month_od_nonzero in month_od_nonzero_ls:
            month_od_df = pd.DataFrame(data=month_od_nonzero, columns=raw_columns)
            if is_daily:
                group_keys = ["start_sta", "end_sta"]
            else: # We focus on 8AM-8PM
                month_od_df = month_od_df.loc[(month_od_df["hour"] >= 8) & (month_od_df["hour"] < 20)].copy()
                month_od_df["hour"] = month_od_df["hour"] - 8
                month_od_df[self.setting] = np.floor(month_od_df["hour"] / int(self.setting[0])).astype(int)
                group_keys = ["weekday", self.setting, "start_sta", "end_sta"]
            # reset_index() turns the group keys back into columns in memory, so no
            # temporary CSV file (and no existing "../data" directory) is needed.
            month_od_df = month_od_df.groupby(group_keys)["cnt"].sum().reset_index()
            new_month_od_nonzero_ls.append(torch.from_numpy(month_od_df.values).long().to(device))
        del month_od_nonzero_ls
        return new_month_od_nonzero_ls

    def load_month(self):
        """Return the zero-based month index (0 = January) of every timestep."""
        time_arr = np.array([int(timestep[4:]) - 1 for timestep in self.load_timestep_ls()], dtype=np.int32)
        print("month shape{}".format(time_arr.shape))
        return torch.from_numpy(time_arr).long().to(device)

    def _load_local_knn(self, knn_dir):
        """
        Load a [n_month, n_station, k'] array of global neighbour ids, keep the first
        ``k`` neighbours and translate the ids into each month's local numbering.
        """
        knn = np.load(knn_dir)[:, :, :self.k] # n_month, n_station, k1
        table_size = max(self.N_GLOBAL_STATION, knn.shape[1])
        local_knn_ls = []
        for i, local2global_station in enumerate(self.dataset["local2global_station"]):
            local2global_station = local2global_station.detach().cpu().numpy()
            local_knn = knn[i, local2global_station] # n_local_station, k
            # NOTE(review): global ids missing from this month keep the table's
            # initial value and therefore map to local index 0 - confirm this is intended.
            global2local_station = np.zeros(table_size, dtype=np.int32)
            global2local_station[local2global_station] = np.arange(len(local2global_station))
            local_knn = global2local_station[local_knn.reshape(-1)].reshape(-1, self.k)
            local_knn_ls.append(torch.from_numpy(local_knn).long().to(device))
        del knn
        return local_knn_ls

    def load_dist_knn(self):
        """Per-month distance-based nearest neighbours in local station numbering."""
        return self._load_local_knn(self.dist_knn_dir)

    def load_dist_adj_diag1(self):
        """Load the distance adjacency matrix and set its diagonal to 1."""
        adj = np.load(self.dist_adj_dir)
        np.fill_diagonal(adj, 1)
        return torch.from_numpy(adj).float().to(device)

    def load_builtEnv_knn_ls(self):
        """Per-month built-environment nearest neighbours in local station numbering."""
        return self._load_local_knn(self.builtEnv_knn_dir)

    def load_builtEnv_adj_ls_diag1(self):
        """
        Slice the per-month built-environment adjacency down to the stations that
        exist in each month and set the diagonal of every slice to 1.
        """
        adj = np.load(self.builtEnv_adj_dir)
        local_builtEnv_adj_ls = []
        for i, local2global_station in enumerate(self.dataset["local2global_station"]):
            local2global_station = local2global_station.detach().cpu().numpy()
            local_builtEnv_adj = adj[i][local2global_station][:, local2global_station]
            np.fill_diagonal(local_builtEnv_adj, 1)
            local_builtEnv_adj_ls.append(torch.from_numpy(local_builtEnv_adj).float().to(device))
        del adj
        return local_builtEnv_adj_ls

    def compute_min_max_outflow(self, pre_min=None, pre_max=None):
        """
        Return ``(min_outflow, max_outflow)``; when ``pre_min`` is given the
        supplied pair is returned unchanged.

        NOTE(review): the constructor always passes ``pre_min``. The fallback below
        reads ``dataset["month_station_opendays"]``, which the constructor never
        fills, and indexes a 2-D array with four indices - it needs rework before
        it can be relied on.
        """
        if pre_min is not None:
            return pre_min, pre_max
        min_outflow, max_outflow = 0, 0
        for month_od_nonzero, month_station_opendays in zip(self.dataset["month_od_nonzero"], self.dataset["month_station_opendays"]):
            month_od = np.zeros((month_station_opendays.shape[1], month_station_opendays.shape[1]))
            month_od[month_od_nonzero[:, 0], month_od_nonzero[:, 1], month_od_nonzero[:, 2], month_od_nonzero[:, 3]] = month_od_nonzero[:, 4]
            month_outflow = np.sum(np.sum(month_od, 1), -1)
            month_outflow = month_outflow.astype(float) / (month_station_opendays + 1e-32)
            if np.min(month_outflow) < min_outflow:
                min_outflow = np.min(month_outflow)
            if np.max(month_outflow) > max_outflow:
                max_outflow = np.max(month_outflow)
        return min_outflow, max_outflow

    def compute_min_max_od(self, pre_min=None, pre_max=None):
        """
        Return ``(min_od_cnt, max_od_cnt, min_od_density, max_od_density)`` over the
        timesteps from ``n_train_timestep`` onwards, where density is the count
        divided by the OD open days.

        NOTE(review): when ``pre_min`` is given only the pair
        ``(pre_min, pre_max)`` is returned, which does not match the four values the
        constructor unpacks. The indexing below also assumes the three-column
        (start_sta, end_sta, cnt) layout produced for "daily" settings.
        """
        if pre_min is not None:
            return pre_min, pre_max
        min_od_density, max_od_density, min_od_cnt, max_od_cnt = 0, 0, 0, 0
        for month_od_nonzero, month_od_opendays in zip(self.dataset["month_od_nonzero"][self.n_train_timestep:],
                                                       self.dataset["month_od_opendays"][self.n_train_timestep:]):
            month_od_nonzero = month_od_nonzero.detach().cpu().numpy()
            month_od_opendays = month_od_opendays.detach().cpu().numpy()

            # Dense OD count matrix of this month.
            month_od = np.zeros((month_od_opendays.shape[1], month_od_opendays.shape[1]))
            month_od[month_od_nonzero[:, 0], month_od_nonzero[:, 1]] = month_od_nonzero[:, 2]
            if np.min(month_od) < min_od_cnt:
                min_od_cnt = np.min(month_od)
            if np.max(month_od) > max_od_cnt:
                max_od_cnt = np.max(month_od)

            # Counts per open day; the tiny constant avoids division by zero.
            month_od = month_od / (month_od_opendays + 1e-32)
            if np.min(month_od) < min_od_density:
                min_od_density = np.min(month_od)
            if np.max(month_od) > max_od_density:
                max_od_density = np.max(month_od)
        return min_od_cnt, max_od_cnt, min_od_density, max_od_density

    def minmax_normalize(self, x, min, max):
        """Map ``x`` from [min, max] to [-1, 1]."""
        x = (x - min) / (max - min)
        x = 2 * x - 1
        return x

    def minmax_denormalize(self, x, min, max):
        """Inverse of ``minmax_normalize``: map ``x`` from [-1, 1] back to [min, max]."""
        x = (x + 1) / 2
        x = (max - min) * x + min
        return x

class DataGenerator(object):
    """Builds the ``DataLoader`` objects for the train, valid and test splits."""

    def __init__(self, data_class):
        # Read later for ``dataset["n_station"]`` and ``n_train_timestep``.
        self.data_class = data_class

    def get_data_loader(self, batch_size: int):
        """Return a dict of loaders keyed ``'train'``, ``'valid'`` and ``'test'``.

        Samples are built twice, once per ``PrepareSample`` mode; the
        ``'train'`` pass yields both the training loader (shuffled) and the
        validation loader, the ``'test'`` pass yields the test loader.

        :param batch_size: batch size shared by all three loaders.
        """
        def _make_loader(inputs, output, shuffle):
            wrapped = PrepareDataset(device=device, inputs=inputs, output=output)
            return DataLoader(dataset=wrapped, batch_size=batch_size, shuffle=shuffle)

        loaders = dict()
        for mode in ('train', 'test'):
            samples = PrepareSample(device=device,
                                    n_station_ls=self.data_class.dataset["n_station"],
                                    n_train_timestep=self.data_class.n_train_timestep,
                                    mode=mode)
            if mode != 'train':
                loaders['test'] = _make_loader(samples.inputs, samples.output, False)
                continue
            loaders['train'] = _make_loader(samples.inputs, samples.output, True)
            loaders['valid'] = _make_loader(samples.val_inputs, samples.val_output, False)
        return loaders

class PrepareSample(object):
    """Enumerate every (timestep, local station) pair of one data split.

    ``mode == "train"`` covers snapshots ``[0, n_train_timestep)`` and splits
    the enumerated pairs into a training part (``inputs`` / ``output``) and a
    validation part (``val_inputs`` / ``val_output``) using the flat indices
    stored in ``val_idx_path``.  Any other mode covers snapshots
    ``[n_train_timestep, len(n_station_ls))`` and fills only ``inputs`` /
    ``output``.

    :param device: device the index tensors are moved to (anything
        ``Tensor.to`` accepts).
    :param n_station_ls: per-snapshot station counts, indexable by snapshot.
    :param n_train_timestep: number of leading snapshots used for training.
    :param mode: ``"train"`` or anything else (treated as test).
    :param val_idx_path: ``.npy`` file holding the validation sample indices.
    """

    def __init__(self, device: str, n_station_ls, n_train_timestep, mode,
                 val_idx_path: str = "../data/val_idx.npy"):
        self.device = device
        self.n_station_ls = n_station_ls
        self.n_train_timestep = n_train_timestep
        self.mode = mode
        self.val_idx_path = val_idx_path
        self.inputs, self.output, self.val_inputs, self.val_output = None, None, None, None
        self.prepare_xy()

    def _pack(self, split_name, timestep, local_station):
        """Turn two index arrays into the (inputs, output) dict pair of a split."""
        x, y = dict(), dict()
        # Use the device handed to the constructor rather than a module global.
        x['timestep'] = torch.from_numpy(timestep).long().to(self.device)
        x['local_station'] = torch.from_numpy(local_station).long().to(self.device)
        print(f"==================={split_name}======================")
        print('timestep', x['timestep'].shape)
        print('local_station', x['local_station'].shape)
        return x, y  # the target dict is left empty here

    def prepare_xy(self):
        """Enumerate the samples of ``self.mode`` and store them as index tensors."""
        if self.mode == "train":
            min_snapshot, max_snapshot = 0, self.n_train_timestep
        else:
            min_snapshot, max_snapshot = self.n_train_timestep, len(self.n_station_ls)

        # One sample per station of every snapshot in the range.
        timestep_parts, station_parts = [], []
        for i in range(min_snapshot, max_snapshot):
            n_station = int(self.n_station_ls[i])
            timestep_parts.append(np.full(n_station, i, dtype=np.int64))
            station_parts.append(np.arange(n_station, dtype=np.int64))
        if timestep_parts:
            timestep = np.concatenate(timestep_parts, 0)
            local_station = np.concatenate(station_parts, 0)
        else:
            # An empty snapshot range yields empty splits instead of a concatenate error.
            timestep = np.empty(0, dtype=np.int64)
            local_station = np.empty(0, dtype=np.int64)

        if self.mode != "train":
            self.inputs, self.output = self._pack("test", timestep, local_station)
            return

        # Everything not listed in the validation index file is a training sample.
        val_idx = np.load(self.val_idx_path)
        is_train = np.ones(len(timestep), dtype=bool)
        is_train[val_idx] = False
        train_idx = np.flatnonzero(is_train)
        print("val_idx", len(val_idx), "train_idx", len(train_idx))
        self.inputs, self.output = self._pack("train", timestep[train_idx], local_station[train_idx])
        self.val_inputs, self.val_output = self._pack("valid", timestep[val_idx], local_station[val_idx])


class PrepareDataset(Dataset):
    """``Dataset`` yielding ``(timestep, local_station)`` index pairs.

    :param device: stored on the instance; not used by the methods below.
    :param inputs: dict with ``'timestep'`` and ``'local_station'`` entries.
    :param output: stored on the instance; not returned by ``__getitem__``.
    """

    def __init__(self, device: str, inputs: dict, output: dict):
        self.device = device
        self.inputs = inputs
        self.output = output
        # Report whether each index container is already a tensor.
        for key in ("timestep", "local_station"):
            print(key, torch.is_tensor(self.inputs[key]))

    def __len__(self):
        timestep = self.inputs['timestep']
        return timestep.shape[0]

    def __getitem__(self, item):
        timestep = self.inputs['timestep'][item]
        local_station = self.inputs['local_station'][item]
        return timestep, local_station

# Model Trainer

In [17]:
import random

class ModelTrainer(object):
    def __init__(self, hour_prob_model_name: str, hour_prob_model: nn.Module, hour_prob_epochs: int,
              optimizer, lr:float, wd:float, dataset, data_class, alpha=0.01):
        self.hour_prob_model_name = hour_prob_model_name
        self.hour_prob_model = hour_prob_model
        self.hour_prob_optimizer = optimizer(params=self.hour_prob_model.parameters(), lr=lr, weight_decay=wd)
        self.alpha = alpha
        self.mse_loss = nn.MSELoss(reduction="sum")
        self.hour_prob_epochs = hour_prob_epochs
        self.data_class = data_class
        self.dataset = dataset # n_station, n_feat1
        self.setting = self.hour_prob_model.setting
        self.dist = self.hour_prob_model.dist

    def prepare_graph_daily_density_batch_data(self, timestep_idx, local_station_idx, dim_station_feat=44, dim_od_feat=14):
        """Build the zero-padded input and target tensors of one mini-batch.

        Sample ``i`` is the origin station ``local_station_idx[i]`` of snapshot
        ``timestep_idx[i]``; every station of that snapshot acts as a
        destination.  Destination axes are padded with zeros up to the largest
        station count among the snapshots in the batch.

        :param timestep_idx: snapshot index of every sample (1-D, indexable tensor).
        :param local_station_idx: snapshot-local origin index of every sample.
        :param dim_station_feat: width of the concatenated static + monthly
            station feature vector.
        :param dim_od_feat: the OD feature tensors built here have
            ``dim_od_feat - 1`` columns.
        :return: a 21-tuple, with ``B`` = batch size, ``S`` = padded station
            count, ``K`` = ``self.data_class.k`` and ``P = (K + 1) ** 2``:
            origin features ``(B, F)``; origin distance-neighbour features
            ``(B, K, F)`` and adjacency ``(B, K)``; origin built-environment
            neighbour features and adjacency (same shapes); the five
            destination counterparts with an extra ``S`` axis after ``B``;
            OD features ``(B, S, D)``; OD distance-neighbour features
            ``(B, S, P - 1, D)`` and adjacency ``(B, S, P - 1)``; OD
            built-environment neighbour features and adjacency (same shapes);
            OD counts ``(B, S)``; OD open days ``(B, S)``; origin outflow
            ``(B,)``; destination inflow ``(B, S)``; destination validity mask
            ``(B, S)``; and ``self.dataset["timestep2month"][timestep_idx]``.
        """
        batch_size, n_max_station = timestep_idx.shape[0], torch.max(self.dataset["n_station"][timestep_idx])
        n_neigh = self.data_class.k
        # Number of (origin-or-neighbour, destination-or-neighbour) combinations per OD pair.
        n_pair = (n_neigh + 1) * (n_neigh + 1)
        # origin features
        batch_ori_feature = torch.zeros((batch_size, dim_station_feat)).float().to(device)
        batch_ori_dist_feature = torch.zeros((batch_size, n_neigh, dim_station_feat)).float().to(device)
        batch_ori_dist_adj = torch.zeros((batch_size, n_neigh)).float().to(device)
        batch_ori_builtEnv_feature = torch.zeros((batch_size, n_neigh, dim_station_feat)).float().to(device)
        batch_ori_builtEnv_adj = torch.zeros((batch_size, n_neigh)).float().to(device)
        # destination features
        batch_des_feature = torch.zeros((batch_size, n_max_station, dim_station_feat)).float().to(device)
        batch_des_dist_feature = torch.zeros((batch_size, n_max_station, n_neigh, dim_station_feat)).float().to(device)
        batch_des_dist_adj = torch.zeros((batch_size, n_max_station, n_neigh)).float().to(device)
        batch_des_builtEnv_feature = torch.zeros((batch_size, n_max_station, n_neigh, dim_station_feat)).float().to(device)
        batch_des_builtEnv_adj = torch.zeros((batch_size, n_max_station, n_neigh)).float().to(device)
        # OD features of the target pair
        batch_od_feature = torch.zeros((batch_size, n_max_station, dim_od_feat-1)).float().to(device)
        batch_od_opendays = torch.zeros((batch_size, n_max_station)).long().to(device)
        # OD features of the distance-neighbour pairs (the target pair itself is excluded, hence n_pair - 1)
        batch_od_dist_feature = torch.zeros((batch_size, n_max_station, n_pair-1, dim_od_feat-1)).float().to(device)
        batch_od_dist_adj = torch.zeros((batch_size, n_max_station, n_pair-1)).to(device)
        # OD features of the built-environment-neighbour pairs
        batch_od_builtEnv_feature = torch.zeros((batch_size, n_max_station, n_pair-1, dim_od_feat-1)).float().to(device)
        batch_od_builtEnv_adj = torch.zeros((batch_size, n_max_station, n_pair-1)).to(device)
        batch_des_mask = torch.zeros((batch_size, n_max_station)).long().to(device)
        # flow targets
        batch_outflow = torch.zeros(batch_size).float().to(device)
        batch_inflow = torch.zeros((batch_size, n_max_station)).float().to(device)
        batch_od = torch.zeros((batch_size, n_max_station)).float().to(device)
        for i, (timestep, local_station) in enumerate(zip(timestep_idx.tolist(), local_station_idx.tolist())):
            # Maps snapshot-local station indices to rows of the global tables.
            global_station = self.dataset["local2global_station"][timestep]
            n_local = len(global_station)
            # origin features: static part (global table) + monthly part (local table)
            x_ori_feature = self.dataset["feature"][global_station[local_station]] # n_feat
            x_ori_month_feature = self.dataset["month_feature"][timestep][local_station] # n_month_feat
            batch_ori_feature[i] = torch.cat([x_ori_feature, x_ori_month_feature], -1) # batch_size, n_feat
            # origin distance neighbours
            ori_dist_knn = self.dataset["month_dist_knn"][timestep][local_station]
            x_ori_dist_feature = self.dataset["feature"][global_station[ori_dist_knn]]
            x_ori_dist_month_feature = self.dataset["month_feature"][timestep][ori_dist_knn]
            batch_ori_dist_feature[i] = torch.cat([x_ori_dist_feature, x_ori_dist_month_feature], -1)
            batch_ori_dist_adj[i] = self.dataset["dist_adj_diag1"][global_station[local_station], global_station[ori_dist_knn]]
            # origin built-environment neighbours
            ori_builtEnv_knn = self.dataset["month_builtEnv_knn"][timestep][local_station]
            x_ori_builtEnv_feature = self.dataset["feature"][global_station[ori_builtEnv_knn]]
            x_ori_builtEnv_month_feature = self.dataset["month_feature"][timestep][ori_builtEnv_knn]
            batch_ori_builtEnv_feature[i] = torch.cat([x_ori_builtEnv_feature, x_ori_builtEnv_month_feature], -1)
            # The adjacency must be read for the same built-environment neighbours whose
            # features were gathered just above (not for the distance neighbours), matching
            # the destination-side lookup further down.
            batch_ori_builtEnv_adj[i] = self.dataset["builtEnv_adj_diag1"][timestep][local_station, ori_builtEnv_knn]
            # destination features
            x_des_feature = self.dataset["feature"][global_station] # n_local_station, n_feat
            x_des_month_feature = self.dataset["month_feature"][timestep] # n_local_station, n_month_feat
            batch_des_feature[i, :n_local] = torch.cat([x_des_feature, x_des_month_feature], -1)
            # destination distance neighbours
            des_dist_knn = self.dataset["month_dist_knn"][timestep] # n_local_station, k
            x_des_dist_feature = self.dataset["feature"][global_station[des_dist_knn].reshape(-1)].reshape(n_local, n_neigh, -1)
            x_des_dist_month_feature = self.dataset["month_feature"][timestep][des_dist_knn.reshape(-1)].reshape(n_local, n_neigh, -1)
            batch_des_dist_feature[i, :n_local] = torch.cat([x_des_dist_feature, x_des_dist_month_feature], -1)
            batch_des_dist_adj[i, :n_local] = self.dataset["dist_adj_diag1"][global_station.unsqueeze(-1).repeat(1, n_neigh).reshape(-1),
                                          global_station[des_dist_knn.reshape(-1)]].reshape(n_local, n_neigh)
            # destination built-environment neighbours
            station_arange_idx = torch.arange(n_local).to(device)
            des_builtEnv_knn = self.dataset["month_builtEnv_knn"][timestep] # n_local_station, k
            x_des_builtEnv_feature = self.dataset["feature"][global_station[des_builtEnv_knn].reshape(-1)].reshape(n_local, n_neigh, -1)
            x_des_builtEnv_month_feature = self.dataset["month_feature"][timestep][des_builtEnv_knn.reshape(-1)].reshape(n_local, n_neigh, -1)
            batch_des_builtEnv_feature[i, :n_local] = torch.cat([x_des_builtEnv_feature, x_des_builtEnv_month_feature], -1)
            batch_des_builtEnv_adj[i, :n_local] = self.dataset["builtEnv_adj_diag1"][timestep][station_arange_idx.unsqueeze(-1).repeat(1, n_neigh).reshape(-1),
                                          des_builtEnv_knn.reshape(-1)].reshape(n_local, n_neigh)
            # OD pairs over distance neighbours: every (origin or its neighbour) x (destination or its neighbour)
            ori_target_neigh_knn = torch.LongTensor([local_station] + ori_dist_knn.tolist()).to(device) # n_neigh + 1
            ori_target_neigh_knn = ori_target_neigh_knn.unsqueeze(0).repeat(n_local, 1) # n_station, n_neigh + 1
            des_target_neigh_knn = torch.cat([station_arange_idx.unsqueeze(-1), des_dist_knn], -1) # n_station, n_neigh + 1
            ori_target_neigh_knn = ori_target_neigh_knn.unsqueeze(-1).repeat(1, 1, n_neigh+1)
            des_target_neigh_knn = des_target_neigh_knn.unsqueeze(-2).repeat(1, n_neigh+1, 1)
            x_od_feature = self.dataset["od_feature"][global_station[ori_target_neigh_knn.reshape(-1)],
                                      global_station[des_target_neigh_knn.reshape(-1)]]
            x_od_feature = x_od_feature.reshape(n_local, n_pair, -1) # n_local_station, n_od_feat
            x_od_month_feature = self.dataset["month_od_feature"][timestep][ori_target_neigh_knn.reshape(-1), des_target_neigh_knn.reshape(-1)] # n_local_station, n_od_month_feat
            x_od_month_feature = x_od_month_feature.reshape(n_local, n_pair, -1)
            x_od_feature = torch.cat([x_od_feature, x_od_month_feature], -1)
            # Combination 0 is (origin, destination) itself; the rest are the neighbour pairs.
            batch_od_feature[i, :n_local] = x_od_feature[:, 0]
            batch_od_dist_feature[i, :n_local] = x_od_feature[:, 1:]
            ori_target_neigh_adj = self.dataset["dist_adj_diag1"][global_station[local_station], global_station[ori_target_neigh_knn.reshape(-1)]]
            des_target_neigh_adj = self.dataset["dist_adj_diag1"][global_station[station_arange_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, n_neigh+1,n_neigh+1).reshape(-1)],
                                        global_station[des_target_neigh_knn.reshape(-1)]]
            # Pair adjacency = mean of the origin-side and destination-side adjacencies.
            od_target_neigh_adj = (ori_target_neigh_adj + des_target_neigh_adj) / 2
            batch_od_dist_adj[i, :n_local] = od_target_neigh_adj.reshape(n_local, n_pair)[:, 1:]
            # OD pairs over built-environment neighbours (same construction, different kNN lists)
            ori_target_neigh_knn = torch.LongTensor([local_station] + ori_builtEnv_knn.tolist()).to(device) # n_neigh + 1
            ori_target_neigh_knn = ori_target_neigh_knn.unsqueeze(0).repeat(n_local, 1) # n_station, n_neigh + 1
            des_target_neigh_knn = torch.cat([station_arange_idx.unsqueeze(-1), des_builtEnv_knn], -1) # n_station, n_neigh + 1
            ori_target_neigh_knn = ori_target_neigh_knn.unsqueeze(-1).repeat(1, 1, n_neigh+1)
            des_target_neigh_knn = des_target_neigh_knn.unsqueeze(-2).repeat(1, n_neigh+1, 1)
            x_od_feature = self.dataset["od_feature"][global_station[ori_target_neigh_knn.reshape(-1)],
                                      global_station[des_target_neigh_knn.reshape(-1)]]
            x_od_feature = x_od_feature.reshape(n_local, n_pair, -1) # n_local_station, n_od_feat
            x_od_month_feature = self.dataset["month_od_feature"][timestep][ori_target_neigh_knn.reshape(-1), des_target_neigh_knn.reshape(-1)] # n_local_station, n_od_month_feat
            x_od_month_feature = x_od_month_feature.reshape(n_local, n_pair, -1)
            x_od_feature = torch.cat([x_od_feature, x_od_month_feature], -1)
            batch_od_builtEnv_feature[i, :n_local] = x_od_feature[:, 1:]
            ori_target_neigh_adj = self.dataset["builtEnv_adj_diag1"][timestep][local_station, ori_target_neigh_knn.reshape(-1)]
            des_target_neigh_adj = self.dataset["builtEnv_adj_diag1"][timestep][station_arange_idx.unsqueeze(-1).unsqueeze(-1).repeat(1, n_neigh+1,n_neigh+1).reshape(-1),
                                        des_target_neigh_knn.reshape(-1)]
            od_target_neigh_adj = (ori_target_neigh_adj + des_target_neigh_adj) / 2
            batch_od_builtEnv_adj[i, :n_local] = od_target_neigh_adj.reshape(n_local, n_pair)[:, 1:]

            # targets: densify the sparse (origin, destination, count) rows of this snapshot
            y_od_nonzero = self.dataset["month_od_nonzero"][timestep]
            y_od = torch.zeros((n_local, n_local)).long().to(device)
            y_od[y_od_nonzero[:, 0], y_od_nonzero[:, 1]] = y_od_nonzero[:, 2]
            y_od_opendays = self.dataset["month_od_opendays"][timestep]
            batch_outflow[i] = torch.sum(y_od[local_station])
            batch_inflow[i, :n_local] = torch.sum(y_od, 0)
            batch_od[i, :n_local] = y_od[local_station]
            batch_od_opendays[i, :n_local] = y_od_opendays[local_station]
            # 1 for real destinations, 0 for padding.
            batch_des_mask[i, :n_local] = 1
        batch_month = self.dataset["timestep2month"][timestep_idx]
        return batch_ori_feature, batch_ori_dist_feature, batch_ori_dist_adj, batch_ori_builtEnv_feature, batch_ori_builtEnv_adj, \
            batch_des_feature, batch_des_dist_feature, batch_des_dist_adj, batch_des_builtEnv_feature, batch_des_builtEnv_adj, \
            batch_od_feature, batch_od_dist_feature, batch_od_dist_adj, batch_od_builtEnv_feature, batch_od_builtEnv_adj, \
            batch_od, batch_od_opendays, batch_outflow, batch_inflow, \
            batch_des_mask, batch_month

    def prepare_ori_evaluation_mask(self, timestep_idx, day_idx, local_station_idx):
        """Look up the ``local_exist_mask`` entry of every origin in the batch.

        :param timestep_idx: snapshot index of every sample.
        :param day_idx: not used by this method.
        :param local_station_idx: snapshot-local origin index of every sample.
        :return: long tensor on ``device`` with one mask value per sample.
        """
        exist_mask = []
        for timestep, local_station in zip(timestep_idx.tolist(), local_station_idx.tolist()):
            snapshot_mask = self.dataset["local_exist_mask"][timestep]
            exist_mask.append(snapshot_mask[local_station])
        return torch.LongTensor(exist_mask).to(device)

    def prepare_od_evaluation_mask(self, timestep_idx, local_station_idx):
        """Classify every (origin, destination) pair by the two ``local_exist_mask`` flags.

        :param timestep_idx: snapshot index of every sample.
        :param local_station_idx: snapshot-local origin index of every sample.
        :return: three long tensors of shape ``(batch, max stations in batch)``:
            both flags set, exactly one flag set, and neither flag set.
            Padded destination slots keep the fill value -1 in all three.
        """
        n_batch = timestep_idx.shape[0]
        n_max_station = torch.max(self.dataset["n_station"][timestep_idx])

        def _padding():
            # -1 marks destination slots that do not exist in a snapshot.
            return -torch.ones((n_batch, n_max_station)).long().to(device)

        both_exist = _padding()
        one_exists = _padding()
        none_exist = _padding()
        samples = zip(timestep_idx.tolist(), local_station_idx.tolist())
        for row, (timestep, local_station) in enumerate(samples):
            des_exists = self.dataset["local_exist_mask"][timestep]
            n_local = len(des_exists)
            ori_exists = des_exists[local_station]
            both_exist[row, :n_local] = ori_exists * des_exists
            mixed = ori_exists * (1 - des_exists) + (1 - ori_exists) * des_exists
            one_exists[row, :n_local] = (mixed > 0) * 1
            none_exist[row, :n_local] = (1 - ori_exists) * (1 - des_exists)
        return both_exist, one_exists, none_exist


    def train_hour_od(self, data_processor:dict, modes:list, model_dir:str, early_stopper=10, n_sample=100, mask_zero_rate=0.5,
                        weight=0, reg=0):
        """Train the OD model with validation-based checkpointing and early stopping.

        For every epoch each mode in ``modes`` is run in order.  After a
        ``'valid'`` pass the summed loss divided by the number of processed
        samples is compared with the best value so far: on improvement (or a
        tie) the model is saved to
        ``model_dir/<hour_prob_model_name>_best_model.pkl``, otherwise the
        patience counter is decremented and training stops when it hits zero.
        The best checkpoint is reloaded into the model after every validation
        pass.

        :param data_processor: dict mapping a mode name to an iterable of
            ``(timestep, local_station)`` batches.
        :param modes: mode names to run each epoch; ``'train'`` enables
            gradient updates, ``'valid'`` triggers checkpointing.
        :param model_dir: directory the best checkpoint is written to.
        :param early_stopper: number of consecutive non-improving validation
            passes tolerated before stopping.
        :param n_sample: maximum number of batches processed per mode and epoch.
        :param mask_zero_rate: not used by this method.
        :param weight: forwarded to ``nb_zeroinflated_nll_loss`` when
            ``self.dist == "zinb"``.
        :param reg: not used by this method.
        :return: ``None``.
        """
        # Remember the configured patience so the counter can be reset to it.
        patience = early_stopper
        checkpoint_path = model_dir + f'/{self.hour_prob_model_name}_best_model.pkl'
        checkpoint = {'epoch':0, 'state_dict':self.hour_prob_model.state_dict()}
        val_loss = np.inf
        start_time = datetime.datetime.now()
        for epoch in range(1, self.hour_prob_epochs+1):
            running_loss = {mode:0.0 for mode in modes}
            for mode in modes:
                if mode == 'train':
                    self.hour_prob_model.train()
                else:
                    self.hour_prob_model.eval()
                step = 0
                curr_sample = 0
                for timestep, local_station in data_processor[mode]:
                    ori_feat, ori_dist_feat, ori_dist_adj, ori_builtEnv_feat, ori_builtEnv_adj, \
                          des_feat, des_dist_feat, des_dist_adj, des_builtEnv_feat, des_builtEnv_adj, \
                          od_feat, od_dist_feat, od_dist_adj, od_builtEnv_feat, od_builtEnv_adj, \
                          y_od, y_od_opendays, y_outflow, y_inflow, des_mask, x_month = \
                          self.prepare_graph_daily_density_batch_data(timestep, local_station)
                    # Only destinations with at least one open day take part in the loss.
                    des_mask = (y_od_opendays > 0) * 1
                    if self.setting == "daily_density":
                        y_od = y_od / (y_od_opendays + 1e-32)
                    else:
                        od_feat = torch.cat([od_feat, (y_od_opendays / 31.).unsqueeze(-1)], -1)
                    with torch.set_grad_enabled(mode=(mode == 'train')):
                        # The forward call is identical for all output distributions;
                        # only the unpacking and the loss differ.
                        outputs = self.hour_prob_model(ori_feat, ori_dist_feat, ori_dist_adj, ori_builtEnv_feat, ori_builtEnv_adj,
                                                       des_feat, des_dist_feat, des_dist_adj, des_builtEnv_feat, des_builtEnv_adj,
                                                       od_feat, od_dist_feat, od_dist_adj, od_builtEnv_feat, od_builtEnv_adj,
                                                       x_month, des_mask)
                        if self.dist == "zinb":
                            n_train, p_train, pi_train = outputs
                            loss = nb_zeroinflated_nll_loss(y_od, n_train, p_train, pi_train, des_mask, weight=weight)
                        elif self.dist == "nb":
                            n_train, p_train = outputs
                            loss = nb_nll_loss(y_od, n_train, p_train, des_mask)
                        else:
                            loss = self.mse_loss(outputs[des_mask > 0], y_od[des_mask > 0])
                        if mode == 'train':
                            self.hour_prob_optimizer.zero_grad()
                            loss.backward()
                            self.hour_prob_optimizer.step()
                    # Accumulate a plain float: adding the tensor itself would keep every
                    # batch's autograd graph alive for the whole epoch.
                    running_loss[mode] += loss.item()
                    step += y_od.shape[0]
                    curr_sample += 1
                    if curr_sample == n_sample: break
                if mode == 'valid':
                    epoch_val_loss = running_loss[mode] / step
                    if epoch_val_loss <= val_loss:
                        print(f'Epoch {epoch}, Val_loss drops from {val_loss:.5} to {epoch_val_loss:.5}. '
                              f'Update model checkpoint..')
                        val_loss = epoch_val_loss
                        checkpoint.update(epoch=epoch, state_dict=self.hour_prob_model.state_dict())
                        torch.save(checkpoint, checkpoint_path)
                        # Reset to the caller's patience rather than a fixed value.
                        early_stopper = patience
                    else:
                        print(f'Epoch {epoch}, Val_loss does not improve from {val_loss:.5}.')
                        early_stopper -= 1
                        if early_stopper == 0:
                            print(f'Early stopping at epoch {epoch}..')
                            return
                    # Continue the next epoch from the best weights seen so far.
                    # NOTE(review): the optimizer state is not rolled back here — confirm this is intended.
                    saved_checkpoint_od = torch.load(checkpoint_path)
                    self.hour_prob_model.load_state_dict(saved_checkpoint_od['state_dict'])
        print('training', datetime.datetime.now() - start_time)
        return

    def test_batch_od(self, ground_truth, prediction, n_test=None, p_test=None, pi_test=None):
        """exist_exist"""
        SE = self.batch_SE(prediction, ground_truth)
        AE = self.batch_AE(prediction, ground_truth)
        batch_CPC_up, batch_CPC_bottom = self.batch_CPC(prediction, ground_truth)
        CPC_up = batch_CPC_up
        CPC_bottom = batch_CPC_bottom
        if self.dist=="zinb":
            MPIW, PICP = self.batch_zinb_MPIW(n_test, p_test, pi_test, ground_truth)
        elif self.dist=="nb":
            MPIW, PICP = self.batch_nb_MPIW(n_test, p_test, ground_truth)
        else:
            MPIW, PICP = 0, 0
        n_sample = len(ground_truth)
        return [SE, AE, CPC_up, CPC_bottom, n_sample, MPIW, PICP]

    def test_hour_od(self, epoch, data_processor:dict, modes:list, model_dir:str):
        """Evaluate the best saved checkpoint and aggregate metrics per subset.

        The checkpoint ``model_dir/<hour_prob_model_name>_best_model.pkl`` is
        loaded, the model is run over ``data_processor[mode]`` and the additive
        metric parts from ``test_batch_od`` are summed for every combination of
        a flow filter (``"all"``, ``"od=0"``, ``"od>0"``, ``"od>=1"``,
        ``"od>=5"``) and an OD filter (``"ee"``, ``"ea"``, ``"aa"``, taken from
        ``prepare_od_evaluation_mask``).  RMSE, MAE, CPC, MPIW and PICP are
        derived from those sums.

        :param epoch: not used by this method.
        :param data_processor: dict mapping a mode name to an iterable of
            ``(timestep, local_station)`` batches.
        :param modes: mode names to evaluate.
        :param model_dir: directory holding the checkpoint.
        :return: ``metric_dict[flow_setting][od_setting][metric_name]`` of the
            last mode in ``modes`` (an empty dict when ``modes`` is empty).
        """
        saved_checkpoint_od = torch.load(model_dir + f'/{self.hour_prob_model_name}_best_model.pkl', map_location=torch.device('cpu'))
        self.hour_prob_model.load_state_dict(saved_checkpoint_od['state_dict'])
        self.hour_prob_model.eval()
        start_time = datetime.datetime.now()

        flow_settings = ["all", "od=0", "od>0", "od>=1", "od>=5"]
        od_settings = ["ee", "ea", "aa"]
        metric_names = ["SE", "AE", "CPC_up", "CPC_bottom", "n_sample", "MPIW", "PICP"]
        metric_dict = {}
        for mode in modes:
            # Fresh accumulators for every mode.
            metric_dict = {
                flow_setting: {
                    od_setting: {metric_name: 0 for metric_name in metric_names}
                    for od_setting in od_settings
                }
                for flow_setting in flow_settings
            }
            for timestep, local_station in data_processor[mode]:
                ori_feat, ori_dist_feat, ori_dist_adj, ori_builtEnv_feat, ori_builtEnv_adj, \
                      des_feat, des_dist_feat, des_dist_adj, des_builtEnv_feat, des_builtEnv_adj, \
                      od_feat, od_dist_feat, od_dist_adj, od_builtEnv_feat, od_builtEnv_adj, \
                      y_true, y_od_opendays, y_outflow, y_inflow, des_mask, x_month = \
                      self.prepare_graph_daily_density_batch_data(timestep, local_station)
                # Only destinations with at least one open day are evaluated.
                des_mask = (y_od_opendays > 0) * 1
                if self.setting == "daily_density":
                    y_true = y_true / (y_od_opendays + 1e-32)
                else:
                    od_feat = torch.cat([od_feat, (y_od_opendays / 31.).unsqueeze(-1)], -1)

                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    outputs = self.hour_prob_model(ori_feat, ori_dist_feat, ori_dist_adj, ori_builtEnv_feat, ori_builtEnv_adj,
                                                   des_feat, des_dist_feat, des_dist_adj, des_builtEnv_feat, des_builtEnv_adj,
                                                   od_feat, od_dist_feat, od_dist_adj, od_builtEnv_feat, od_builtEnv_adj,
                                                   x_month, des_mask)
                if self.dist == "zinb":
                    n_test, p_test, pi_test = outputs
                    mean_pred = (1 - pi_test)*(n_test/p_test - n_test)
                    dist_params = [n_test, p_test, pi_test]
                elif self.dist == "nb":
                    n_test, p_test = outputs
                    # Mean of NB(n, p) is n(1-p)/p = n/p - n, the same convention the
                    # ZINB branch above uses (it was n/p here).
                    # NOTE(review): confirm against the parameterization in nb_nll_loss.
                    mean_pred = n_test/p_test - n_test
                    dist_params = [n_test, p_test]
                else:
                    y_pred = outputs
                    y_pred[y_pred < 0] = 0  # negative predictions are clipped to zero
                    mean_pred = y_pred
                    dist_params = []

                exist_exist_mask, exist_add_mask, add_add_mask = self.prepare_od_evaluation_mask(timestep, local_station)
                od_masks = {"ee": exist_exist_mask, "ea": exist_add_mask, "aa": add_add_mask}
                for od_setting in od_settings:
                    tmp_mask = od_masks[od_setting] * des_mask
                    for flow_setting in flow_settings:
                        if flow_setting == "od=0":
                            mask = tmp_mask * (y_true==0)
                        elif flow_setting == "od>0":
                            mask = tmp_mask * (y_true>0)
                        elif flow_setting == "od>=1":
                            mask = tmp_mask * (y_true>=1)
                        elif flow_setting == "od>=5":
                            mask = tmp_mask * (y_true>=5)
                        else:
                            mask = tmp_mask

                        # Keep the selected entries of every tensor as flat numpy arrays.
                        selected = mask > 0
                        metrics = self.test_batch_od(*[
                            tensor[selected].detach().cpu().numpy().flatten()
                            for tensor in [y_true, mean_pred] + dist_params
                        ])

                        for metric_name, metric in zip(metric_names, metrics):
                            metric_dict[flow_setting][od_setting][metric_name] += metric

            # overall performance: turn the accumulated sums into averages / ratios
            for flow_setting in flow_settings:
                for od_setting in od_settings:
                    cell = metric_dict[flow_setting][od_setting]
                    n_sample = cell["n_sample"] + 1e-32  # avoids division by zero for empty subsets
                    cell["RMSE"] = np.sqrt(cell["SE"] / n_sample)
                    cell["MAE"] = cell["AE"] / n_sample
                    cell["CPC"] = 2. * cell["CPC_up"] / (cell["CPC_bottom"]+1e-32)
                    cell["MPIW"] = cell["MPIW"] / n_sample
                    cell["PICP"] = cell["PICP"] / n_sample
        print('test', datetime.datetime.now() - start_time)
        return metric_dict
    @staticmethod
    def RMSE(y_pred:np.array, y_true:np.array):
        return np.sqrt(np.mean(np.square(y_pred - y_true)))
    @staticmethod
    def batch_SE(y_pred:np.array, y_true:np.array):
        return np.sum(np.square(y_pred - y_true))
    @staticmethod
    def MAE(y_pred:np.array, y_true:np.array):
        return np.mean(np.abs(y_pred - y_true))
    @staticmethod
    def batch_AE(y_pred:np.array, y_true:np.array):
        return np.sum(np.abs(y_pred - y_true))
    @staticmethod
    def MAPE(y_pred:np.array, y_true:np.array, epsilon=1e-0):   # zero division
        return np.mean(np.abs(y_pred - y_true) / (y_true + epsilon))
    @staticmethod
    def batch_APE(y_pred:np.array, y_true:np.array, epsilon=1e-0):   # zero division
        return np.sum(np.abs(y_pred - y_true) / (y_true + epsilon))
    @staticmethod
    def CPC(y_pred:np.array, y_true:np.array, numerator_only=False):
        """
        Common part of commuters: ``2 * sum(min(pred, true)) / (sum(pred) + sum(true))``.

        :param y_pred: predicted flows.
        :param y_true: observed flows.
        :param numerator_only: when true the denominator is fixed to 1.0, so
            only ``2 * sum(min(pred, true))`` is returned.
        :return: the CPC value, or 0.0 when the denominator is not positive.
        """
        denominator = 1.0 if numerator_only else (np.sum(y_pred) + np.sum(y_true))
        # Guard: nothing to compare when the total flow is zero (or not positive).
        if not denominator > 0:
            return 0.0
        shared_flow = np.sum(np.minimum(y_pred, y_true))
        return 2.0 * shared_flow / denominator
    @staticmethod
    def batch_CPC(y_pred:np.array, y_true:np.array, numerator_only=False):
        """
        Return the two partial sums from which CPC can be assembled later.

        :param y_pred: predicted flows.
        :param y_true: observed flows.
        :param numerator_only: accepted for signature parity with ``CPC``;
            not used by this function.
        :return: tuple ``(sum(min(pred, true)), sum(pred) + sum(true))``.
        """
        shared_flow = np.sum(np.minimum(y_pred, y_true))
        total_flow = np.sum(y_pred) + np.sum(y_true)
        return shared_flow, total_flow
    @staticmethod
    def NRMSE(y_pred:np.array, y_true:np.array):
        """
        Root-mean-square error normalised by the mean of the observations.

        :param y_pred: predicted values.
        :param y_true: observed values; must broadcast against ``y_pred``.
        :return: ``sqrt(mean((y_pred - y_true)^2)) / mean(y_true)``.
        """
        residual = y_pred - y_true
        rmse = np.sqrt(np.mean(np.square(residual)))
        truth_mean = np.mean(y_true)
        return rmse / truth_mean
    @staticmethod
    def NMAE(y_pred:np.array, y_true:np.array):
        """
        Mean absolute error normalised by the mean of the observations.

        :param y_pred: predicted values.
        :param y_true: observed values; must broadcast against ``y_pred``.
        :return: ``mean(|y_pred - y_true|) / mean(y_true)``.
        """
        mae = np.mean(np.abs(y_pred - y_true))
        truth_mean = np.mean(y_true)
        return mae / truth_mean
    @staticmethod
    def PCC(y_pred:np.array, y_true:np.array):
        """
        Pearson correlation coefficient between flattened predictions and truth.

        :param y_pred: predicted values (any shape; flattened before use).
        :param y_true: observed values (any shape; flattened before use).
        :return: the off-diagonal entry of the 2x2 ``np.corrcoef`` matrix.
        """
        flat_pred = y_pred.flatten()
        flat_true = y_true.flatten()
        corr_matrix = np.corrcoef(flat_pred, flat_true)
        return corr_matrix[0, 1]
    @staticmethod
    def JSD(y_pred:np.array, y_true:np.array, mask: np.array):
        """
        Mean row-wise Jensen-Shannon value between predicted and observed rows.

        All three inputs are reshaped to 2-D, keeping the last axis
        (``[..., n_des_station]``) as the row content. For every row only the
        entries whose mask value is > 0 are compared, using
        ``scipy.spatial.distance.jensenshannon``.

        :param y_pred: predictions, ``[..., n_des_station]``.
        :param y_true: observations, ``[..., n_des_station]``.
        :param mask: same layout; entries > 0 mark destinations to include.
        :return: mean of the per-row values.
        """
        pred_rows = y_pred.reshape(-1, y_pred.shape[-1])
        true_rows = y_true.reshape(-1, y_true.shape[-1])
        keep_rows = mask.reshape(-1, mask.shape[-1]) > 0
        per_row = []
        for row_idx, pred_row in enumerate(pred_rows):
            keep = keep_rows[row_idx]
            per_row.append(jensenshannon(pred_row[keep], true_rows[row_idx][keep]))
        return np.mean(np.array(per_row))

    @staticmethod
    def batch_JSD(y_pred:np.array, y_true:np.array, mask: np.array):
        """
        Sum of row-wise Jensen-Shannon values between predicted and observed rows.

        All three inputs are reshaped to 2-D, keeping the last axis
        (``[..., n_des_station]``) as the row content. For every row only the
        entries whose mask value is > 0 are compared, using
        ``scipy.spatial.distance.jensenshannon``.

        :param y_pred: predictions, ``[..., n_des_station]``.
        :param y_true: observations, ``[..., n_des_station]``.
        :param mask: same layout; entries > 0 mark destinations to include.
        :return: sum of the per-row values (not averaged).
        """
        pred_rows = y_pred.reshape(-1, y_pred.shape[-1])
        true_rows = y_true.reshape(-1, y_true.shape[-1])
        keep_rows = mask.reshape(-1, mask.shape[-1]) > 0
        per_row = []
        for row_idx, pred_row in enumerate(pred_rows):
            keep = keep_rows[row_idx]
            per_row.append(jensenshannon(pred_row[keep], true_rows[row_idx][keep]))
        return np.sum(np.array(per_row))

    def batch_zinb_MPIW(self, ns, ps, pis, targets, lower_percentile=0.1, upper_percentile=0.9):
        """
        Summed prediction-interval width and coverage count under a
        zero-inflated negative binomial (ZINB) distribution.

        The ZINB cdf is tabulated on the integer grid ``0 .. round(max_range) - 1``,
        where ``max_range`` is ``data_class.max_od_cnt`` for the ``"daily_cnt"``
        setting and ``data_class.max_od_density`` otherwise. Each bound is the
        number of grid points whose cdf is still below the percentile, so a
        bound is capped at the grid size when the cdf never reaches it.

        :param ns: NB ``n`` parameter per sample (assumed 1-D -- TODO confirm).
        :param ps: NB ``p`` parameter per sample, aligned with ``ns``.
        :param pis: zero-inflation probability per sample (indexed as 1-D).
        :param targets: observed values compared against the bounds.
        :param lower_percentile: cdf level of the lower bound.
        :param upper_percentile: cdf level of the upper bound.
        :return: ``(sum of interval widths, number of targets inside their
            interval)``; ``(0, 0)`` when ``ns`` is empty.
        """
        if len(ns) == 0:
            return 0, 0

        if self.setting == "daily_cnt":
            support_max = self.data_class.max_od_cnt
        else:
            support_max = self.data_class.max_od_density
        # Column vector of candidate outcomes so it broadcasts against the samples.
        support = np.arange(0, support_max.round())[:, None]

        # ZINB pmf = (1 - pi) * NB pmf, plus the extra point mass pi at outcome 0.
        zinb_pmf = nbinom.pmf(support, ns, ps) * (1 - pis[None, :])
        zinb_pmf[0] = pis + zinb_pmf[0]
        zinb_cdf = np.cumsum(zinb_pmf, 0)

        # Counting grid points below the percentile gives the quantile index.
        lower_bound = np.sum(zinb_cdf < lower_percentile, 0)
        upper_bound = np.sum(zinb_cdf < upper_percentile, 0)

        widths = upper_bound - lower_bound
        covered = (targets >= lower_bound) * (targets <= upper_bound)
        return np.sum(widths), np.sum(covered)

    def batch_nb_MPIW(self, ns, ps, targets, lower_percentile=0.1, upper_percentile=0.9):
        """
        Summed prediction-interval width and coverage count under a negative
        binomial distribution.

        The bounds are the NB quantiles at ``lower_percentile`` and
        ``upper_percentile``. They are taken with ``nbinom.ppf`` for each
        percentile separately, because ``nbinom.interval(upper - lower, ...)``
        always yields an equal-tailed interval and therefore gives the wrong
        bounds whenever the two percentiles are not symmetric around 0.5.

        :param ns: NB ``n`` parameter per sample.
        :param ps: NB ``p`` parameter per sample, aligned with ``ns``.
        :param targets: observed values compared against the bounds.
        :param lower_percentile: cdf level of the lower bound.
        :param upper_percentile: cdf level of the upper bound.
        :return: ``(sum of interval widths, number of targets inside their
            interval)``; ``(0, 0)`` when ``ns`` is empty.
        """
        if len(ns) == 0:
            return 0, 0
        # ppf(q) is the smallest outcome whose cdf reaches q.
        lower_bound = nbinom.ppf(lower_percentile, ns, ps)
        upper_bound = nbinom.ppf(upper_percentile, ns, ps)
        mpiw = upper_bound - lower_bound
        picp = (targets >= lower_bound) * (targets <= upper_bound)
        return np.sum(mpiw), np.sum(picp)

# Load Data

In [10]:
"""load demand data"""
# Input file locations, relative to the notebook's working directory.
# The variables are named *_dir but each one points at a single file.

# Files under ../data/feature (pickle, CSV and .npy).
local2global_station_dir = "../data/feature/local2global_station.pkl"
month_od_opendays_dir = "../data/feature/od_month_opendays.pkl"
month_station_opendays_dir = "../data/feature/station_month_opendays.pkl"
month_od_nonzero_dir = "../data/feature/od_month_nonzero.pkl"
feat_dir = "../data/feature/BSS_feat_summary.csv"
od_feat_dir = "../data/feature/od_interact_feat.npy"
month_feat_dir = "../data/feature/BSS_month_feat_summary.pkl"
month_od_feat_dir = "../data/feature/od_month_interact_feat.pkl"
manhattan_dir = "../data/feature/od_manhattan.npy"
# Files under ../data/adj (.npy); presumably kNN / adjacency arrays for the
# geographic and built-environment graphs -- confirm against DataInput.
dist_knn_dir = "../data/adj/month_geo_knn_0m.npy"
dist_adj_dir = "../data/adj/Bike_geo_none.npy"
builtEnv_knn_dir = "../data/adj/month_builtEnv_knn_0m.npy"
builtEnv_adj_dir = "../data/adj/Bike_builtEnv_month.npy"

In [11]:
import os
import random
def setup_seed(seed):
    """
    Seed the random number generators used here and pin cuDNN behaviour.

    Sets ``PYTHONHASHSEED`` in the environment, seeds torch (CPU and CUDA),
    NumPy and the stdlib ``random`` module with the same value, then switches
    cuDNN to deterministic mode with benchmarking disabled.

    :param seed: integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seeding order as before: torch (CPU, current GPU, all GPUs), NumPy, random.
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [None]:
# ---- Experiment configuration ----
setting = "daily_density"
dist = "zinb"

# Graph options passed to the model.
od_graph = True
node_graph = False
graph_conv_type = "gat"
builtEnv = False

# Optimisation hyper-parameters.
learn_rate = 2e-3
weight_decay = 1e-6
dropout = 0.5
batch_size = 32
optimizer = optim.Adam

# Dimensions and neighbourhood size.
n_time = 8
n_feat = 44
n_interact = 13
k = 9

# ---- Data loading ----
data_input = DataInput(
    local2global_station_dir=local2global_station_dir,
    month_od_opendays_dir=month_od_opendays_dir,
    month_od_nonzero_dir=month_od_nonzero_dir,
    feat_dir=feat_dir,
    od_feat_dir=od_feat_dir,
    month_feat_dir=month_feat_dir,
    month_od_feat_dir=month_od_feat_dir,
    manhattan_dir=manhattan_dir,
    month_station_opendays_dir=month_station_opendays_dir,
    dist_knn_dir=dist_knn_dir,
    dist_adj_dir=dist_adj_dir,
    builtEnv_knn_dir=builtEnv_knn_dir,
    builtEnv_adj_dir=builtEnv_adj_dir,
    n_train_timestep=50,
    k=k,
    setting=setting,
)
data_generator = DataGenerator(data_input)
data_processor = data_generator.get_data_loader(batch_size=batch_size)

# Train and Test Model

In [None]:
result_dict = dict()
for ep in range(5):
    setup_seed(ep)
    hour_prob_model_name = f"zinb_gnn_k{k}_ep{ep}"
    hour_prob_model = Graph_NBNorm_ZeroInflated(dim_station=n_feat, dim_interact=n_interact, dropout=dropout, setting=setting,
                          dist=dist, od_graph=od_graph, node_graph=node_graph,
                          graph_conv_type=graph_conv_type, builtEnv=builtEnv).to(device)
    trainer = ModelTrainer(hour_prob_model_name=hour_prob_model_name, hour_prob_model=hour_prob_model, hour_prob_epochs=5,
                optimizer=optimizer, lr=learn_rate, wd=weight_decay,
                dataset=data_input.dataset, data_class=data_input)
    model_dir = "../model"
    st_time = time.time()
    trainer.train_hour_od(data_processor=data_processor, modes=["train", "valid"], model_dir=model_dir,
                        n_sample=100, weight=0.)
    result = trainer.test_hour_od(epoch=ep, data_processor=data_processor, modes=["test"], model_dir=model_dir)