# In [None]:
from sklearn.neighbors import kneighbors_graph
import networkx as nx
import numpy as np
import pandas as pd
import torch
import itertools
from sklearn.datasets import make_moons, make_circles
from torch_geometric.nn import GCNConv, ChebConv, SAGEConv
import pdb
import pickle5 as pickle
# import pickle
from torch.distributions.multivariate_normal import MultivariateNormal
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import networkx as nx
# Module-wide torch device: CUDA when torch reports it as available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

''' 1. Non_graph Simulation Helpers '''


class graph_simulate_3node():
    """Simulate features and binary labels on a 3-node graph (edges 0-1, 1-2).

    For every label vector Y in {0, 1}^3, latent node features Z of shape
    (num_sample, 3, 2) are drawn conditionally on Y and then mixed across
    nodes by a row-normalized matrix P built from the adjacency matrix,
    giving X = P @ Z.

    The boolean attributes set in ``__init__`` are switches that callers may
    flip before calling ``get_full_data``.
    """

    def __init__(self, num_sample):
        """Store the per-label-vector sample count and default all switches.

        Args:
            num_sample: number of samples drawn for each of the 8 label
                vectors.
        """
        self.num_sample = num_sample
        # Control the design of Z
        self.complex_X = False  # If Z are two-moon with rotation + raw H translation
        # Control how Z goes to X, where either "small_averaging" or "disporportional" is used
        self.small_averaging = False  # If A[i,i] += 10
        self.disporportional = False  # If A has different diagonal entries
        self.change_A = False  # If A[0,1] = A[1,0] = 2
        self.P_square = False  # If X=P^2Z
        self.plot_X_Z = False  # If Z and X are scatter-plotted per label vector

    def get_full_data(self):
        """Generate X and Y for all 8 label vectors and shuffle them jointly.

        Sets ``self.X_full`` (8 * num_sample, 3, 2) and ``self.Y_full``
        (8 * num_sample, 3). Only the shuffle is seeded (seed 1103); the
        sampling of Z uses whatever state the global torch RNG is in.
        """
        X_full = []
        Y_full = []
        for i, Y in enumerate(itertools.product([0, 1], repeat=3)):
            print(f'Y={Y}')
            if self.complex_X:
                self.get_Z_from_Y_non_symmetric(Y)
            else:
                self.get_Z_from_Y(Y)
            self.get_X_from_Z()
            if i == 0:
                # P does not depend on Y, so print it once.
                print(self.P)
            X_full.append(self.X.float())
            Y_full.append(torch.tensor(Y).repeat(self.num_sample, 1).float())
            if self.plot_X_Z:
                fig, ax = plt.subplots(2, 1, figsize=(8, 4))
                Z_flat = self.Z.flatten(start_dim=0, end_dim=1)
                X_flat = self.X.flatten(start_dim=0, end_dim=1)
                ax[0].scatter(Z_flat[:, 0], Z_flat[:, 1])
                ax[0].set_title('Z')
                ax[1].scatter(X_flat[:, 0], X_flat[:, 1])
                ax[1].set_title('X')
                fig.tight_layout()
                plt.show()
                plt.close()
        torch.manual_seed(1103)
        idx = torch.randperm(self.num_sample * 8)  # Randomly shuffle
        X_full, Y_full = torch.vstack(X_full)[idx], torch.vstack(Y_full)[idx]
        self.X_full, self.Y_full = X_full, Y_full

    def get_Z_from_Y(self, Y):
        """Draw Gaussian latent features Z for one label vector.

        Each node's feature is sampled from N([0, 1.5], 0.1 I) when its label
        is 0 and from N([0, -1.5], 0.1 I) otherwise; node 0 is shifted by -4
        and node 2 by +4 along the first coordinate. Sets ``self.Z`` with
        shape (num_sample, 3, 2).

        Args:
            Y: length-3 sequence of 0/1 node labels.
        """
        # Highly symmetric Z, can cause training issues
        mean0, mean1, cov = torch.tensor([0., 1.5]), torch.tensor(
            [0., -1.5]), (torch.eye(2) * 0.1)
        base0_dist = MultivariateNormal(mean0, cov)
        base1_dist = MultivariateNormal(mean1, cov)
        offset = 4
        # First get Z
        Z = []
        for i, y in enumerate(Y):
            base_dist = base0_dist if y == 0 else base1_dist
            H_sample = base_dist.rsample(sample_shape=(self.num_sample,))
            if i == 0:
                H_sample[:, 0] -= offset
            if i == 2:
                H_sample[:, 0] += offset
            Z.append(H_sample)
        # hstack gives (num_sample, 6); the reshape puts node i's 2-d feature
        # at Z[:, i, :].
        Z = torch.hstack(Z)
        Z = Z.reshape(self.num_sample, 3, 2)
        self.Z = Z

    def get_Z_from_Y_non_symmetric(self, Y):
        """Draw latent features Z mixing two-moon halves and Gaussians.

        The distribution used for node i depends on both i and its label y
        (see the branches below). Two-moon samples come from the module-level
        ``gen_two_moon_data`` with a seed derived from (i, y); Gaussian
        samples use the global torch RNG. Sets ``self.Z`` with shape
        (num_sample, 3, 2).

        Args:
            Y: length-3 sequence of 0/1 node labels.
        """
        # Very non-convex Z, can cause learning difficulties
        mean0, mean1, cov = torch.tensor([0., 1.5]), torch.tensor(
            [0., -1.5]), (torch.eye(2) * 0.1)
        base0_dist = MultivariateNormal(mean0, cov)
        base1_dist = MultivariateNormal(mean1, cov)
        # First get Z
        Z = []
        moon_shift_dict = {0: [0.57, 0.73], 1: [-0.57, -0.73]}
        offset_dict = {0: [-4, 1], 1: [0, 2], 2: [4, 0]}
        # 2-d rotation by 270 degrees, applied to some of the moon halves.
        theta = np.radians(270)
        c, s = np.cos(theta), np.sin(theta)
        rotate_matrix = np.array(((c, -s), (s, c)))
        for i, y in enumerate(Y):
            # Highly tailored, so we design depending on i and y
            # See notability image for example
            rstate = 1103 + i * 3 + y * 10
            moon_shift, offset = moon_shift_dict[y], offset_dict[i]
            if i == 0:
                # Node 0: always a moon half; rotated when y == 1.
                # 2 * num_sample points are requested so that one half is
                # used per sample.
                X_np, y_np, _, _ = gen_two_moon_data(
                    2 * self.num_sample, 'two_moon', random_state=rstate)
                if y == 0:
                    # Upper moon, with shift
                    Z_sample = X_np[y_np == 0]
                else:
                    Z_sample = rotate_matrix.dot(X_np[y_np == 1].T).T
                Z_sample[:, 0] += moon_shift[0] + offset[0]
                Z_sample[:, 1] += moon_shift[1] + offset[1]
                Z_sample = torch.from_numpy(Z_sample).float()
            elif i == 1:
                # Node 1: Gaussian when y == 0, un-rotated moon half otherwise.
                if y == 0:
                    base_dist = base0_dist
                    Z_sample = base_dist.rsample(
                        sample_shape=(self.num_sample,))
                    Z_sample[:, 0] += offset[0]
                    Z_sample[:, 1] += offset[1]
                else:
                    X_np, y_np, _, _ = gen_two_moon_data(
                        2 * self.num_sample, 'two_moon', random_state=rstate)
                    Z_sample = X_np[y_np == 1]
                    Z_sample[:, 0] += moon_shift[0] + offset[0]
                    Z_sample[:, 1] += moon_shift[1] + offset[1]
                    Z_sample = torch.from_numpy(Z_sample).float()
            else:
                # Node 2: rotated moon half when y == 0, Gaussian otherwise.
                if y == 0:
                    X_np, y_np, _, _ = gen_two_moon_data(
                        2 * self.num_sample, 'two_moon', random_state=rstate)
                    Z_sample = rotate_matrix.dot(X_np[y_np == 0].T).T
                    Z_sample[:, 0] += moon_shift[0] + offset[0]
                    Z_sample[:, 1] += moon_shift[1] + offset[1]
                    Z_sample = torch.from_numpy(Z_sample).float()
                else:
                    base_dist = base1_dist
                    Z_sample = base_dist.rsample(
                        sample_shape=(self.num_sample,))
                    Z_sample[:, 0] += offset[0]
                    Z_sample[:, 1] += offset[1]
            Z.append(Z_sample)
        Z = torch.hstack(Z)
        Z = Z.reshape(self.num_sample, 3, 2)
        self.Z = Z

    def get_X_from_Z(self):
        """Mix ``self.Z`` across nodes: X = P @ Z with P = D^{-1} A.

        A is the 3x3 adjacency matrix with self-loops and no 0-2 edge,
        optionally modified by the ``change_A`` / ``small_averaging`` /
        ``disporportional`` switches; P is squared when ``P_square`` is set.
        Sets ``self.P`` (3, 3) and ``self.X`` (same shape as ``self.Z``).
        """
        # Then get X from Z
        A = np.ones((3, 3))
        A[0, 2], A[2, 0] = 0, 0
        # Emphasize the connection between node 0 and 1, which is used in training.
        if self.change_A:
            A[0, 1] = 2
            A[1, 0] = 2
        if self.small_averaging:
            for i in range(3):
                A[i, i] += 10
        # Disporportional weight
        if self.disporportional:
            A[0, 0] = 2
            A[1, 1] = 3
            A[2, 2] = 5
        # Row-normalize so every row of P sums to one.
        D_inv = np.diag(1 / np.sum(A, axis=1))
        P_mat = torch.from_numpy(D_inv.dot(A)).type(torch.float)
        if self.P_square:
            P_mat = P_mat @ P_mat
        self.P = P_mat
        X = P_mat @ self.Z
        self.X = X

    def select_Y(self, Y_rows):
        """Keep only the samples whose label vector equals one of ``Y_rows``.

        Overwrites ``self.X_full`` and ``self.Y_full`` with the selected rows.

        Args:
            Y_rows: iterable of length-3 label vectors (tensors or plain
                sequences). An empty iterable selects nothing.
        """
        Y_full = self.Y_full
        # Build the mask on the same device as the data it indexes. Moving it
        # to the module-level `device` breaks indexing on CUDA machines,
        # because X_full / Y_full are created on the CPU.
        keep = torch.zeros(
            Y_full.shape[0], dtype=torch.bool, device=Y_full.device)
        for Y_row in Y_rows:
            row = torch.as_tensor(
                Y_row, dtype=Y_full.dtype, device=Y_full.device)
            keep |= (Y_full == row).all(dim=1)
        self.X_full, self.Y_full = self.X_full[keep], Y_full[keep]


class GP_graph():
    """Sample Gaussian node signals whose covariance is tied to a graph.

    Two designs are available through ``gen_1d_GP_data``:

    * ``'ChebNet'``: the precision matrix is the output of a ``ChebConv``
      layer (K=3, fixed diagonal weights) on a chordal cycle graph with V
      nodes.
    * ``'Local'``: a hand-built correlation matrix on a small graph.
    """

    def __init__(self, num_sample, V):
        """Store the sample count and number of nodes.

        Args:
            num_sample: number of samples drawn for each of train and test.
            V: number of nodes; must pass the module-level ``isPrime`` check.

        Raises:
            ValueError: if ``V`` is not prime.
        """
        self.num_sample = num_sample
        self.V = V  # Must be a prime
        if not isPrime(self.V):
            raise ValueError('V Must be a prime for Chordal Cycle Graph')

    def gen_1d_GP_data(self, Sigma_type='ChebNet'):
        """Build the covariance and draw train/test samples from N(1, Sigma).

        Sets ``self.X_train`` / ``self.X_test`` with shape
        (num_sample, V, 1), all-zero ``self.Y_train`` / ``self.Y_test`` with
        shape (num_sample, V), plus ``self.Sigma``, ``self.edge_index`` and
        ``self.edge_weights``. Train sampling is seeded with 1103 and test
        sampling with 111.

        Args:
            Sigma_type: ``'ChebNet'`` or ``'Local'``.

        Raises:
            ValueError: if ``Sigma_type`` is not one of the supported values.
        """
        # Fail early and clearly; otherwise an unknown type only surfaces
        # later as an unbound `Sigma`.
        if Sigma_type not in ('ChebNet', 'Local'):
            raise ValueError(
                f"Unknown Sigma_type {Sigma_type!r}; expected 'ChebNet' or 'Local'")
        if Sigma_type == 'ChebNet':
            self.get_graph_for_Cheb()
            # So sigma = I + L, normalized and rescaled
            layer = ChebConv(self.V, self.V, K=3).to(device)
            for name, param in layer.named_parameters():
                if name != 'bias':
                    with torch.no_grad():
                        weight_val = torch.eye(self.V).to(device)
                        a, b, c = 0.5, 0.1, 0.5
                        # By design, Sigma^{-1} = b*\hat L + 2*c*\hat L^2 + a*(1-c)*I
                        # NOTE(review): if the third Chebyshev term is
                        # 2*\hat L^2 - I, the identity coefficient would be
                        # (a - c) rather than a*(1-c) -- confirm.
                        # For Sigma off-diagonal enteires large, b should be small, c reasonably large, and a*(1-c) not large
                        # The parameter name is matched on the index of the
                        # layer's per-order weight it contains.
                        if '0' in name:
                            weight_val *= a
                        if '1' in name:
                            weight_val *= b
                        if '2' in name:
                            weight_val *= c
                        # To make sure covariance matrix invertible
                        param.data = weight_val
            # Feeding the identity through the layer yields the full V x V
            # operator it represents.
            X = torch.eye(self.V, self.V).to(device)
            # Get edge_weights so correlations are higher on off-diagonal
            n_edge = self.edge_index.shape[1]
            edge_weights = torch.ones(n_edge)
            # # Perturb on-diagnal entries
            # num_e = 0
            # for e in range(n_edge):
            #     if num_e > int(self.V/2):
            #         break
            #     edge = self.edge_index[:, e]
            #     which_e = (self.edge_index.T == edge).all(dim=1).cpu()
            #     e_loc = np.arange(n_edge)[which_e][0]
            #     if edge[0] == edge[1]:
            #         num_e += 1
            #         edge_weights[e_loc] *= 10
            # # Perturb off-diagnal entries
            # exclude_idx = []
            # for e in range(n_edge):
            #     if e > int((n_edge-self.V)/4):
            #         # e.g., only pertueb half of edge weights
            #         break
            #     if e in exclude_idx:
            #         continue
            #     edge = self.edge_index[:, e]
            #     if edge[0] != edge[1]:
            #         # Smaller values make Sigma larger, because it is the inverse
            #         mult = np.log(e+1)*1e-2
            #         edge_weights[e] *= mult
            #         oppo_edge = torch.tensor([edge[1], edge[0]]).to(device)
            #         which_e = (self.edge_index.T == oppo_edge).all(dim=1).cpu()
            #         oppo_e = np.arange(n_edge)[which_e][0]
            #         edge_weights[oppo_e] *= mult
            #         exclude_idx.append(oppo_e)
            self.edge_weights = edge_weights.to(device)
            Sigma_inv = layer(X, self.edge_index,
                              edge_weight=self.edge_weights)
            Sigma = torch.inverse(Sigma_inv).to(device)
            # Rescale the covariance to a correlation matrix for reporting.
            gaid = torch.diag(1 / torch.diag(Sigma)**0.5)
            Sigma_corr = gaid @ Sigma @ gaid
            print(f'Corr matrix: {Sigma_corr}')
            Unique_Y, counts_Y = torch.unique(
                torch.round(Sigma_corr, decimals=2), return_counts=True)
            # Only report correlations that are not close to zero.
            idx = torch.abs(Unique_Y) > 0.1
            print(
                f'Correlation Dist. are {Unique_Y[idx].cpu().detach().tolist()}, \n with frequency {counts_Y[idx].tolist()}')
            self.Sigma = Sigma
            fig, ax = plt.subplots(figsize=(4, 4))
            c = ax.matshow(Sigma_corr.cpu().detach().numpy())
            cbar_ax = fig.add_axes([1, 0.1, 0.1, 0.8])
            plt.colorbar(c, cax=cbar_ax)
            ax.set_title(r'Corr of $\Sigma$')
        if Sigma_type == 'Local':
            # Consider KNN graph, where I just manually change the value of correlation matrix Sigma after created
            # NOTE, the 3-node graph has issue that \Sigma^-1 not local
            self.get_graph_and_Sigma_for_local()
            self.edge_weights = torch.ones(self.edge_index.shape[1]).to(device)
            Sigma = self.Sigma
        Mu = torch.ones(self.V).to(device)
        X_dist = MultivariateNormal(Mu, Sigma)
        torch.manual_seed(1103)
        X_train = X_dist.rsample(sample_shape=(
            self.num_sample,)).cpu().detach().reshape(self.num_sample, self.V, 1)
        Y_train = torch.zeros(self.num_sample, self.V)
        self.X_train, self.Y_train = X_train, Y_train
        torch.manual_seed(111)
        X_test = X_dist.rsample(sample_shape=(
            self.num_sample,)).cpu().detach().reshape(self.num_sample, self.V, 1)
        Y_test = torch.zeros(self.num_sample, self.V)
        self.X_test, self.Y_test = X_test, Y_test

    def get_graph_for_Cheb(self):
        """Build the chordal cycle graph on V nodes with self-loops.

        Sets ``self.edge_index`` (2 x E tensor on ``device`` containing both
        directions of every edge plus all self-loops) and ``self.fig`` (a
        drawing of the graph).
        """
        # Get a graph where the locality is low (e.g., not easy to just go to another edge)
        G = nx.chordal_cycle_graph(self.V)
        G.add_edges_from([(i, i) for i in range(self.V)])
        fig, ax = plt.subplots(figsize=(4, 4))
        nx.draw(G, ax=ax, with_labels=True, node_color='white')
        self.fig = fig
        # edge_index = list(G.edges) # For erdo renyi graph
        # Keep only the two endpoints of each edge entry and drop duplicates.
        edge_index = np.unique([list(i)[:2]
                                for i in list(G.edges)], axis=0).tolist()
        # Track membership in a set so each lookup is O(1); the list keeps
        # the original insertion order.
        seen = {tuple(edge) for edge in edge_index}
        # Add the reverse of every edge that is only present in one direction.
        for k, j in list(edge_index):
            if (j, k) not in seen:
                seen.add((j, k))
                edge_index.append([j, k])
        # Make sure every node has a self-loop.
        for i in range(self.V):
            if (i, i) not in seen:
                seen.add((i, i))
                edge_index.append([i, i])
        edge_index = torch.tensor(edge_index).T.to(device)
        self.edge_index = edge_index

    def get_graph_and_Sigma_for_local(self):
        """Build a small graph and a hand-designed correlation matrix Sigma.

        With ``self.knn`` False (hard-coded below) this uses a fixed 3-node
        path graph and a 3x3 tridiagonal correlation matrix, so it only
        matches ``self.V == 3``. Sets ``self.edge_index``, ``self.Sigma`` and
        ``self.fig`` and shows a diagnostic 2x2 figure.
        """
        # The KNN graph suggested by Prof. Cheng
        self.knn = False
        if self.knn:
            n, knn = self.V, 2
            np.random.seed(1103)
            # Points on the upper half circle, slightly jittered.
            T = np.sort(np.random.rand(n))
            X = np.array([[np.cos(np.pi * t), np.sin(np.pi * t)] for t in T])
            X = X + np.random.rand(X.shape[0], X.shape[1]) * 0.05
            A = kneighbors_graph(X, knn, mode='connectivity',
                                 include_self=True).toarray()
            # Symmetrize the kNN adjacency and make it binary.
            A = A + A.T
            A[A > 0] = 1
            S = np.diag(np.ones(n)) * knn * 2 + A
            D = np.diag(1 / np.sqrt(np.sum(S, axis=1)))
            S = D @ S @ D
            S = (S + S.T) / 2
            # Convert to correlation matrix
            diag = np.sqrt(np.diag(np.diag(S)))
            gaid = np.linalg.inv(diag)
            S = gaid @ S @ gaid
            # NOTE, keep this small, as o/w S_inv would not be local
            offset = np.min(S[S > 0]) / n
            S[0, 1] += offset
            S[1, 0] += offset
            S[n - 2, n - 1] -= offset
            S[n - 1, n - 2] -= offset
            S_inv = np.linalg.inv(S)
            rows, cols = np.where(A == 1)
            edges = list(zip(rows.tolist(), cols.tolist()))
            self.edge_index = torch.tensor(
                [list(i) for i in edges]).T.to(device)
        else:
            # Path graph 0 - 1 - 2 in both directions, plus self-loops.
            self.edge_index = torch.tensor(
                [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]).to(device)
            rho, rho1 = 0.6, -0.4
            S = np.array([[1, rho, 0], [rho, 1, rho1], [0, rho1, 1]])
            S_inv = np.linalg.inv(S)
            print(S_inv)  # Check if "local"
            # No node coordinates in this design; placeholder for the plot.
            X = np.zeros((3, 2))
        fig, ax = plt.subplots(2, 2, figsize=(
            8, 8), constrained_layout=True)
        gr = nx.Graph()
        gr.add_edges_from(self.edge_index.T.tolist())
        nx.draw(gr, ax=ax[0, 0], with_labels=True, node_color='white')
        ax[0, 1].plot(X[:, 0], X[:, 1], 'o')
        ax[1, 0].matshow(S)
        ax[1, 0].set_title(r'Corr of $\Sigma$')
        c = ax[1, 1].matshow(S_inv)
        ax[1, 1].set_title(r'Corr of $\Sigma^{-1}$')
        cbar_ax = fig.add_axes([1, 0.025, 0.1, 0.4])
        plt.colorbar(c, cax=cbar_ax)
        plt.show()
        plt.close()
        # Separate stand-alone drawing of the graph, kept on the instance.
        fig, ax = plt.subplots(figsize=(4, 4))
        gr = nx.Graph()
        gr.add_edges_from(self.edge_index.T.tolist())
        nx.draw(gr, ax=ax, with_labels=True, node_color='white')
        self.fig = fig
        self.Sigma = torch.from_numpy(S).float().to(device)


def isPrime(n):
    """Return True if ``n`` is a prime number, using trial division.

    Values below 2 (0, 1 and negatives) are not prime and return False.
    """
    if n < 2:
        return False
    # A composite n has a divisor no larger than sqrt(n).
    for i in range(2, int(n**0.5) + 1):
        if n % i == 0:
            return False
    return True


'''2. Real data helpers, Traffic '''


class trffic_data():
    '''Load the traffic-flow data and turn it into graph-signal samples.

    Each sample is a (num_sensors, d) matrix holding the d most recent
    observations of every sensor (newest first), with a per-sensor label.
    '''

    def __init__(self, d):
        '''
        Input:
            d here means how long in the past we look at each node. It is thus the in-channel dimension
        '''
        self.d = d

    def get_traffic_train_test(self, num_neighbor=3, sub=False):
        '''
            Description:
                Data are available hourly, with Yt,i = 1 (resp. 2) if the current traffic flow lies outside the upper (resp. lower) 90% quantile over the past four days of traffic flow of its nearest four neighbors based on sensor proximity.

            Reads four pickle files from the working directory and sets
            self.X_train, self.X_test, self.Y_train, self.Y_test and
            self.edge_index. Label 2 is merged into label 0 at the end.

            Input:
                num_neighbor: maximum number of neighbors linked to each sensor.
                sub: if True, keep only the last half of the training rows
                    and the first half of the test rows.
        '''
        d = self.d
        # Traffic flow multi-class detection
        # NOTE: pickle can execute arbitrary code on load; only use trusted files.
        with open('flow_frame_train_0.7_no_drop_data.p', 'rb') as fp:
            Xtrain = pickle.load(fp)
        with open('flow_frame_test_0.7_no_drop_data.p', 'rb') as fp:
            Xtest = pickle.load(fp).to_numpy()
        with open('true_anomalies.p', 'rb') as fp:
            Yvals = pickle.load(fp).to_numpy()
        # Define edge index
        sensors = np.array(list(Xtrain.columns))
        Xtrain = Xtrain.to_numpy()
        scaler = StandardScaler()
        Xtrain = scaler.fit_transform(Xtrain)
        # Reuse the training statistics for the test set. Re-fitting on the
        # test data leaks test information and puts train and test on
        # different scales, although the first test windows below are
        # stitched together from both.
        Xtest = scaler.transform(Xtest)
        # Labels are stored for train and test back to back.
        Ytrain = Yvals[:Xtrain.shape[0], :]
        Ytest = Yvals[Xtrain.shape[0]:, :]
        if sub:
            N = int(Xtrain.shape[0] / 2)  # 50% or /2 already pretty good
            N1 = int(Xtest.shape[0] / 2)
            # Keep the rows adjacent to the train/test boundary.
            Xtrain = Xtrain[-N:]
            Xtest = Xtest[:N1]
            Ytrain = Ytrain[-N:]
            Ytest = Ytest[:N1]
        with open('sensor_neighbors.p', 'rb') as fp:
            neighbor_dict = pickle.load(fp)
        # # Randomly select 15 nodes
        # np.random.seed(1103)
        # chosen_nodes = np.random.choice(len(sensors), 15, replace=False)
        # sensors = sensors[chosen_nodes]
        # Xtrain, Xtest, Ytrain, Ytest = Xtrain[:, chosen_nodes], Xtest[:,
        #                                                               chosen_nodes], Ytrain[:, chosen_nodes], Ytest[:, chosen_nodes]
        # Map sensor id -> column position.
        sensors_dict = {i: j for (i, j) in zip(sensors, range(len(sensors)))}
        edge_index = []
        # num_neighbor = 3
        for k, sensor in enumerate(sensors):
            neighbors = neighbor_dict[sensor]
            num_n = 0
            # Walk the neighbor list in its stored order and link the first
            # `num_neighbor` entries that are sensors in the data.
            for p in range(len(sensors)):
                if num_n >= num_neighbor:
                    break
                if neighbors[p] in sensors_dict:
                    edge_index.append([k, sensors_dict[neighbors[p]]])
                    num_n += 1
        edge_index = torch.from_numpy(np.array(edge_index).T).type(torch.long)
        # Define graphs, similarly as the solar data
        X_train = []
        X_test = []
        Y_train = []
        Y_test = []
        for t in range(d - 1, Xtrain.shape[0]):
            # Rows t-d+1..t, transposed to (sensor, time) and flipped so the
            # newest observation comes first.
            X_train.append(
                np.flip(Xtrain[t - d + 1:t + 1].T, 1))
            Y_train.append(Ytrain[t])
        for t in range(Xtest.shape[0]):
            if t < d - 1:
                # Not enough test history yet: pad with the last d-t-1
                # training rows.
                temp = np.c_[np.flip(Xtest[:t + 1].T, 1),
                             np.flip(Xtrain[-(d - t) + 1:].T, 1)]
            else:
                temp = np.flip(Xtest[t - d + 1:t + 1].T, 1)
            X_test.append(temp)
            Y_test.append(Ytest[t])

        def stack_float(arrays):
            # np.flip returns views with negative strides, which torch cannot
            # wrap directly, hence the copy.
            return torch.stack([torch.from_numpy(val.copy()).type(
                torch.FloatTensor) for val in arrays])

        X_train = stack_float(X_train)
        X_test = stack_float(X_test)
        Y_train = stack_float(Y_train)
        Y_test = stack_float(Y_test)
        # Merge label 2 into label 0, leaving a binary target.
        Y_train[Y_train == 2] = 0
        Y_test[Y_test == 2] = 0
        self.X_train, self.X_test, self.Y_train, self.Y_test, self.edge_index = X_train, X_test, Y_train, Y_test, edge_index

    def plot_traffic(self):
        '''Plot the graph and scatter the first two feature channels.

        Uses the *_sub attributes, so select_Y must have been called for
        both train and test beforehand. Stores the figure in self.fig.
        '''
        plt.rcParams['axes.titlesize'] = 20
        plt.rcParams['figure.titlesize'] = 28
        fig, axs = plt.subplots(3, 1, figsize=(
            3, 7), constrained_layout=True)
        G = nx.Graph()
        G.add_edges_from(self.edge_index.cpu().detach().numpy().T)
        pos = nx.circular_layout(G)
        i = 0
        nx.draw(G, pos, ax=axs[i], with_labels=True, node_color='lightblue')
        N, N1 = self.Y_train_sub.numel(), self.Y_test_sub.numel()
        # One color per (sample, node) pair: red where the label is 1.
        colors = np.repeat('black', N)
        colors1 = np.repeat('black', N1)
        colors[(self.Y_train_sub.flatten() == 1).cpu(
        ).detach().numpy().flatten()] = 'red'
        colors1[(self.Y_test_sub.flatten() == 1).cpu(
        ).detach().numpy().flatten()] = 'red'
        Xtrain, Xtest = self.X_train_sub.flatten(
            start_dim=0, end_dim=1), self.X_test_sub.flatten(start_dim=0, end_dim=1)
        i += 1
        axs[i].scatter(Xtrain[:, 0], Xtrain[:, 1], s=1, color=colors)
        # axs[0].plot(Xtrain[:, 0], Xtrain[:, 1],
        #             linestyle='dashed', linewidth=0.075)
        axs[i].set_title('Train X')
        i += 1
        axs[i].scatter(Xtest[:, 0], Xtest[:, 1], s=1, color=colors1)
        # axs[1].plot(Xtest[:, 0], Xtest[:, 1],
        #             linestyle='dashed', linewidth=0.075)
        axs[i].set_title('Test X')
        self.fig = fig
        plt.show()
        plt.close()

    def select_Y(self, Y_rows, train=True):
        '''Select the samples whose label vector equals one of Y_rows.

        The result is stored in X_train_sub / Y_train_sub (train=True) or
        X_test_sub / Y_test_sub (train=False).

        Input:
            Y_rows: iterable of label vectors (tensors or plain sequences),
                or None to keep every sample. An empty iterable selects
                nothing.
            train: whether to select from the training or the test split.
        '''
        if train:
            X, Y = self.X_train, self.Y_train
        else:
            X, Y = self.X_test, self.Y_test
        if Y_rows is not None:
            # Keep the mask on the device of the tensors it indexes. Moving
            # it to the module-level `device` fails on CUDA machines, because
            # the data tensors are created on the CPU.
            keep = torch.zeros(Y.shape[0], dtype=torch.bool, device=Y.device)
            for Y_row in Y_rows:
                row = torch.as_tensor(Y_row, dtype=Y.dtype, device=Y.device)
                keep |= (Y == row).all(dim=1)
            X, Y = X[keep], Y[keep]
        if train:
            self.X_train_sub, self.Y_train_sub = X, Y
        else:
            self.X_test_sub, self.Y_test_sub = X, Y


''' 2. Real-data helpers, Solar '''


class solar_data():
    """Load solar (DHI) data for one city and build graph-signal samples."""

    def __init__(self, num_obs_per_day, city):
        """Record the sampling rate and city; V is 10 for 'CA' and 0 otherwise."""
        self.city = city
        self.num_obs_per_day = num_obs_per_day
        self.C = 2  # Passed as the memory depth `d` when building samples
        self.V = 0
        if self.city == 'CA':
            self.V = 10

    def get_solar(self):
        """Build train/test tensors and the edge index for ``self.city``.

        Reads the 2017 and 2018 data through ``get_DHI``, splits the stacked
        series 3:1 into train/test, writes and re-reads the anomaly labels,
        and sets ``self.X_train``, ``self.Y_train``, ``self.X_test``,
        ``self.Y_test`` and ``self.edge_index``.
        """
        graph_connect = {'CA': False, 'LA': True}
        yearly = [get_DHI(self.city, year, self.V, self.num_obs_per_day)
                  for year in ('2017', '2018')]
        DHI_full = np.r_[yearly[0], yearly[1]]
        split = int(DHI_full.shape[0] * 3 / 4)
        DHI_train = DHI_full[:split]
        DHI_test = DHI_full[split:]
        # Anomalies have same frequency as data
        # get_anomaly writes the two CSV files that are loaded right below.
        get_anomaly(DHI_full, self.city, split)
        train_anom = np.loadtxt(
            f'{self.city}_anomalies_train.csv', delimiter=',')
        test_anom = np.loadtxt(
            f'{self.city}_anomalies_test.csv', delimiter=',')
        samples = get_solar_train_test(
            DHI_train, DHI_test, train_anom, test_anom, d=self.C)

        def as_float_stack(arrays):
            # Copy each array before wrapping it in a float tensor.
            tensors = []
            for arr in arrays:
                tensors.append(torch.from_numpy(
                    arr.copy()).type(torch.FloatTensor))
            return torch.stack(tensors)

        # Each has dimension (N-by-V-by-C), where the response is only N-by-V
        X_train, X_test, Y_train, Y_test = (
            as_float_stack(part) for part in samples)
        _, count_train = np.unique(Y_train.numpy(), return_counts=True)
        _, count_test = np.unique(Y_test.numpy(), return_counts=True)
        print(
            f'#1/#0 in training data is {count_train[1]/count_train[0]}')
        print(f'#1/#0 in test data is {count_test[1]/count_test[0]}')
        fully_connected = graph_connect[self.city]
        self.edge_index = get_edge_list(Y_train, self.V, fully_connected)
        self.X_train, self.Y_train = X_train, Y_train
        self.X_test, self.Y_test = X_test, Y_test

    def plot_solar(self):
        """Draw the graph and scatter the first two channels of train/test X.

        Points whose label equals 1 are red, all others black. The figure is
        stored in ``self.fig`` and then shown and closed.
        """
        plt.rcParams['axes.titlesize'] = 20
        plt.rcParams['figure.titlesize'] = 28
        fig, axs = plt.subplots(3, 1, figsize=(
            3, 7), constrained_layout=True)
        graph = nx.Graph()
        graph.add_edges_from(self.edge_index.cpu().detach().numpy().T)
        layout = nx.circular_layout(graph)
        nx.draw(graph, layout, ax=axs[0], with_labels=True,
                node_color='lightblue')
        # One color per (sample, node) pair, for train and test.
        train_colors = np.repeat('black', self.Y_train.numel())
        test_colors = np.repeat('black', self.Y_test.numel())
        train_is_one = (self.Y_train.flatten() == 1).cpu(
        ).detach().numpy().flatten()
        test_is_one = (self.Y_test.flatten() == 1).cpu(
        ).detach().numpy().flatten()
        train_colors[train_is_one] = 'red'
        test_colors[test_is_one] = 'red'
        # Merge the sample and node axes so every row is one plotted point.
        train_pts = self.X_train.flatten(start_dim=0, end_dim=1)
        test_pts = self.X_test.flatten(start_dim=0, end_dim=1)
        axs[1].scatter(train_pts[:, 0], train_pts[:, 1], s=1,
                       color=train_colors)
        axs[1].set_title('Train X')
        axs[2].scatter(test_pts[:, 0], test_pts[:, 1], s=1,
                       color=test_colors)
        axs[2].set_title('Test X')
        self.fig = fig
        plt.show()
        plt.close()


def get_anomaly(raw_data, city, N):
    """Flag anomalies in a (T, V) array and save them as two CSV files.

    Entry (t, l) with t >= 1 is marked 1 when the relative change from
    t-1 to t is above 1 or below -0.5, or when the current value is below 40
    and the previous value is below 35. Row 0 is never flagged.

    The first N rows are written to ``{city}_anomalies_train.csv`` and the
    rest to ``{city}_anomalies_test.csv``.
    """
    # NOTE: just run once is enough.
    T, V = raw_data.shape
    anomalies = np.zeros((T, V))
    previous, current = raw_data[:-1], raw_data[1:]
    rate_inc = (current - previous) / previous
    # NOTE: use this a little arbitrary rule, as o/w too hard to do.
    jump = (rate_inc > 1) | (rate_inc < -0.5)
    both_low = (current < 40) & (previous < 35)
    # anomalies[1:] is a view, so the assignment writes into `anomalies`.
    anomalies[1:][jump | both_low] = 1
    np.savetxt(f'{city}_anomalies_train.csv', anomalies[:N], delimiter=',')
    np.savetxt(f'{city}_anomalies_test.csv', anomalies[N:], delimiter=',')


def get_DHI(city, year, V, num_obs_per_day=2):
    """Read ``{city}_{year}.csv`` and return window-averaged DHI per location.

    The 'DHI' column is split into consecutive chunks of 365 * 48 values,
    one per location. Each output row is the mean of 24 consecutive
    values of a chunk, with window starts spaced
    int(24 / (num_obs_per_day - 1)) values apart.

    NOTE(review): num_obs_per_day == 1 divides by zero in the spacing
    formula -- confirm the supported range.

    Returns:
        Array of shape (365 * num_obs_per_day, V).
    """
    dhi = pd.read_csv(f'{city}_{year}.csv')['DHI'].to_numpy()
    days = 365
    per_location = days * 48
    stride = int(24 / (num_obs_per_day - 1))
    n_rows = days * num_obs_per_day
    X_array = np.zeros((n_rows, V))
    for loc in range(V):
        series = dhi[loc * per_location:(loc + 1) * per_location]
        X_array[:, loc] = [np.mean(series[row * stride:row * stride + 24])
                           for row in range(n_rows)]
    return X_array


def get_solar_train_test(train_DHI, test_DHI, train_anom, test_anom, d=2):
    """Turn DHI series into windowed samples with per-time-step labels.

    Args:
        train_DHI, test_DHI: 2-d arrays indexed (time, location).
        train_anom, test_anom: labels indexed by the same time axis.
        d: the dimension of the input signal, which intuitively is the
            memory depth.

    Returns:
        ``[X_train, X_test, Y_train, Y_test]`` as lists. Each X entry is a
        (location, d) array holding the current and the d-1 preceding
        observations, newest first. Training samples start at time d-1; the
        first d-1 test samples borrow their missing history from the end of
        the training series.
    """
    X_train = []
    X_test = []
    Y_train = []
    Y_test = []
    N, N1 = train_DHI.shape[0], test_DHI.shape[0]
    scaler = StandardScaler()
    train_DHI = scaler.fit_transform(train_DHI)
    # Scale the test data with the training statistics. Re-fitting on the
    # test set leaks test information and makes the stitched train/test
    # windows below mix two different scales.
    test_DHI = scaler.transform(test_DHI)
    for t in range(d - 1, N):
        X_train.append(np.flip(train_DHI[t - d + 1:t + 1], 0).T)
        Y_train.append(train_anom[t])
    for t in range(N1):
        # Use raw DHI, including today
        if t < d - 1:
            # t+1 test rows are available; take the remaining d-t-1 rows
            # from the tail of the training data.
            temp = np.r_[
                np.flip(test_DHI[:t + 1], 0), np.flip(train_DHI[-(d - t) + 1:], 0)]
        else:
            temp = np.flip(test_DHI[t - d + 1:t + 1], 0)
        X_test.append(temp.T)
        Y_test.append(test_anom[t])
    return [X_train, X_test, Y_train, Y_test]


def get_edge_list(Y_train, n=10, fully_connected=True):
    """Return a 2 x E ``torch.long`` edge index over ``n`` nodes.

    Args:
        Y_train: (num_samples, num_nodes) training labels; only used when
            ``fully_connected`` is False.
        n: number of nodes.
        fully_connected: if True, return all n * n ordered pairs (including
            self-loops). Otherwise link every node k to the nodes whose
            training labels agree with k's most often, then add the reverse
            of every edge so the result is symmetric.
    """
    if fully_connected:
        edge_index = [[a, b] for a in range(n) for b in range(n)]
    else:
        # Infer edge connection in a nearest neighbor fashion, by including connection among node k and all nodes whose training labels are the most similar to k (e.g., in terms of equality). The reason is that this likely indicates influence.
        Y_temp = np.array(Y_train)
        edge_index = []
        # Three targets per node. Node k itself has the maximal agreement
        # count, so it is normally one of them; ties with identical label
        # columns are resolved by argsort order.
        num_include = 3
        for k in range(n):
            same_num = np.array([np.sum(Y_temp[:, k] == Y_temp[:, j])
                                 for j in range(Y_temp.shape[1])])
            # Indices of the largest counts, most similar first.
            include_ones = same_num.argsort()[-num_include:][::-1]
            for j in include_ones:
                edge_index.append([k, int(j)])
        # Also, to ensure the connection is symmetric, I included all edges where there was a directed edge before
        print(f'{len(edge_index)} directed edges initially')
        # Set for O(1) membership tests; the list keeps insertion order.
        seen = {tuple(edge) for edge in edge_index}
        for k, j in list(edge_index):
            if (j, k) not in seen:
                seen.add((j, k))
                edge_index.append([j, k])
        print(f'{len(edge_index)} undirected edges after insertion')
    # Cast in both branches so the dtype never depends on the platform's
    # default numpy integer type.
    return torch.from_numpy(np.array(edge_index).T).type(torch.long)


'''3. Simulation non-graph helper and other helpers'''


def gen_two_moon_data(N, data_name, random_state=1103):
    """Generate a 2-d toy classification data set with ``N`` points.

    Args:
        N: total number of samples.
        data_name: ``'two_moon'`` (features are standardized) or
            ``'two_circles'`` (features are left as generated).
        random_state: seed passed to the sklearn generator.

    Returns:
        ``[X_np, y_np, X, y]``: the numpy features and labels, followed by
        float tensor copies placed on the module-level ``device``.

    Raises:
        ValueError: for any other ``data_name``.
    """
    if data_name == 'two_moon':
        X_np, y_np = make_moons(noise=0.05,
                                n_samples=N, random_state=random_state)
        X_np = StandardScaler().fit_transform(X_np)
    elif data_name == 'two_circles':
        # Honor the caller's seed here too (it used to be hard-coded to
        # 1103, which is still the default).
        X_np, y_np = make_circles(
            n_samples=N, noise=0.05, random_state=random_state, factor=0.6)
    else:
        raise ValueError('Not considered yet')
    X = torch.from_numpy(X_np).float().to(device)
    y = torch.from_numpy(y_np).float().to(device)
    return [X_np, y_np, X, y]


def draw_graph(edge_index, edge_index_est, graph_type, overleaf_path):
    """Draw two graphs side by side and save the figure as a PDF.

    ``edge_index`` is drawn in the left panel and ``edge_index_est`` in the
    right one, both with a circular layout. The file is written to
    ``{overleaf_path}simulated_graph_{graph_type}.pdf``.
    """
    def build(edges):
        # Undirected graph from a 2 x E edge tensor, plus its layout.
        graph = nx.Graph()
        graph.add_edges_from(edges.cpu().detach().numpy().T)
        return graph, nx.circular_layout(graph)

    true_graph, true_layout = build(edge_index)
    fig_network, ax1 = plt.subplots(1, 2, figsize=(8, 4))
    nx.draw(true_graph, true_layout, with_labels=True,
            node_color='lightblue', ax=ax1[0])
    est_graph, est_layout = build(edge_index_est)
    nx.draw(est_graph, est_layout, with_labels=True,
            node_color='lightblue', ax=ax1[1])
    fig_network.savefig(f'{overleaf_path}simulated_graph_{graph_type}.pdf',
                        dpi=200, bbox_inches='tight', pad_inches=0)


# Centers of the 8 Gaussian components used by gen_8_gaussian_data.
verts = [
    (-2.4142, 1.),
    (-1., 2.4142),
    (1.,  2.4142),
    (2.4142,  1.),
    (2.4142, -1.),
    (1., -2.4142),
    (-1., -2.4142),
    (-2.4142, -1.)
]
# For each labelling scheme, the class index given to each of the 8 centers.
label_maps = {
    'all':  [0, 1, 2, 3, 4, 5, 6, 7],
    'some': [0, 0, 1, 1, 2, 2, 3, 3],
    'none': [0, 0, 0, 0, 0, 0, 0, 0],
}


def gen_8_gaussian_data(labels, N, random_state=0):
    """Sample ``N`` points from 8 Gaussians centered at ``verts``.

    Args:
        labels: key of ``label_maps`` ('all', 'some' or 'none') choosing how
            the 8 components are grouped into classes.
        N: total number of observations over all 8 Gaussians. When N is not
            a multiple of 8 the component sizes differ by at most one.
        random_state: seed for the global numpy RNG (which this function
            reseeds).

    Returns:
        ``(pos, labels)``: float tensors of shape (N, 2) and (N, 8), where
        ``labels`` is one-hot over the class indices of the chosen scheme.
    """
    # N denotes number of obs for all 8 Gaussian
    np.random.seed(random_state)
    mapping = label_maps[labels]

    pos = np.random.normal(size=(N, 2), scale=0.2)
    one_hot = np.zeros((N, 8))
    # Chunk boundaries that cover all N rows. With N // 8 rows per chunk the
    # remainder rows would stay at the origin with an all-zero label row.
    # For N divisible by 8 the boundaries are unchanged.
    n_modes = len(verts)
    bounds = [(i * N) // n_modes for i in range(n_modes + 1)]

    for i, v in enumerate(verts):
        lo, hi = bounds[i], bounds[i + 1]
        pos[lo:hi, :] += v
        one_hot[lo:hi, mapping[i]] = 1.

    shuffling = np.random.permutation(N)
    pos = torch.tensor(pos[shuffling], dtype=torch.float)
    one_hot = torch.tensor(one_hot[shuffling], dtype=torch.float)

    return pos, one_hot


##########
