# In [None]:
# from torch.nn.utils.parametrizations import spectral_norm
import time
import torch
import torch.nn as nn
import numpy as np
from torch.autograd.functional import jacobian
from torch_geometric.nn import GCNConv, ChebConv, SAGEConv
from torch_geometric.nn import Sequential as Graph_Sequential
from torch.nn import Parameter
from L3net import GraphConv_Bases
import torch_geometric as pyg
# print(torch.__file__)
import pdb
from timeit import default_timer as timer
# Default compute device for tensors created in this module: CUDA when present, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# IResNet


class input_tranpose(nn.Module):
    '''
    Module wrapper around a two-axis transpose, so the swap can be placed
    inside an ``nn.Sequential`` (or PyG ``Sequential``) pipeline.

    Args:
        dim0: first dimension to swap (default 1).
        dim1: second dimension to swap (default 2).
    '''

    def __init__(self, dim0=1, dim1=2):
        super().__init__()
        self.dim0, self.dim1 = dim0, dim1

    def forward(self, x):
        # Tensor.transpose is the method form of torch.transpose.
        return x.transpose(self.dim0, self.dim1)


class InvResBlock(nn.Module):
    r'''
    Invertible residual block F_b = I + g acting on flattened node features.

    g = W_i \circ \phi ... \circ W_1, where the spectral norm of each W_i should
    be strictly smaller than 1 and \phi is a contractive nonlinearity (ReLU,
    ELU, tanh, ...), so F_b can be inverted by fixed-point iteration.
    NOTE(review): nothing in this class enforces the spectral-norm bound; it
    relies on the caller's initialisation / training.

    Args:
        C: number of feature channels per node.
        model_args: ``[dim, input_non_linear]``. ``dim`` is the hidden width
            of g. ``input_non_linear`` is currently unused.
        version: architecture name forwarded to ``layers_append``.
        A_: optional graph operator forwarded to ``layers_append``. It is used
            by the L3net versions.
    '''

    def __init__(self, C, model_args, version, A_):
        super().__init__()
        dim, _ = model_args  # second entry is accepted but unused
        act = nn.ELU(inplace=True)
        layers = layers_append([], act, C, dim, C, version=version, A_=A_)
        self.bottleneck_block = nn.Sequential(*layers)
        self.actnorm = None
        self.C = C

    def forward(self, x):
        '''
        Apply F_b(x) = x + g(x).

        Args:
            x: flattened features of shape ``(N, n*C)``. They are reshaped to
                ``(N, n, C)`` before g is applied.

        Returns:
            ``[output, Fx]``. ``output`` is flattened back to ``(N, n*C)``, so
            the block Jacobian is nC-by-nC. ``Fx = g(x)`` is computed on the
            reshaped input.
        '''
        n = int(x.shape[1] / self.C)
        x = x.reshape(x.shape[0], n, self.C)
        Fx = self.bottleneck_block(x)
        output = (Fx + x).flatten(start_dim=1)
        return [output, Fx]

    def inverse(self, y, maxIter=100, eps=1e-7):
        '''
        Invert ``forward`` with the fixed-point iteration x <- y - g(x).

        Args:
            y: block output. It may be flattened ``(N, n*C)`` as returned by
                ``forward``, or already shaped ``(N, n, C)``.
            maxIter: maximum number of fixed-point iterations.
            eps: stop once the norm of an update is ``<= eps``.

        Returns:
            The reconstructed input, in the same shape as ``y``.
        '''
        out_shape = y.shape
        if y.dim() == 2:
            # forward() applies g on (N, n, C), so the bottleneck must see that
            # layout here too; the flattened tensor has the wrong last dim.
            y = y.reshape(y.shape[0], -1, self.C)
        x_pre = y
        x_now = y  # defined even when maxIter == 0
        for _ in range(maxIter):
            x_now = y - self.bottleneck_block(x_pre)
            if torch.linalg.norm(x_now - x_pre) <= eps:
                break
            x_pre = x_now
        if self.actnorm is not None:
            x_now = self.actnorm.inverse(x_now)
        return x_now.reshape(out_shape)


class InvResBlock_Graph(nn.Module):
    r'''
    Graph invertible residual block F_b = I + g on flattened node features.

    g = W_t \circ \phi ... \circ W_1 is built from graph convolutions and
    linear layers (see ``layers_append``). The spectral norm of each W_i should
    be strictly smaller than 1 and \phi is a contractive nonlinearity (ReLU,
    ELU, tanh, ...), so F_b can be inverted by fixed-point iteration.

    Args:
        C: number of feature channels per node.
        model_args: ``[dim, K]``, the hidden width and the Chebyshev filter
            size forwarded to ``layers_append``.
        version: architecture name forwarded to ``layers_append``.
    '''

    def __init__(self, C, model_args, version='two_GCN_one_FC'):
        super().__init__()
        [dim, K] = model_args
        act = nn.ELU(inplace=True)
        layers = layers_append([], act, C, dim, C, K, version)
        self.bottleneck_block = Graph_Sequential(
            'x, edge_index, edge_weight', layers)
        # self.actnorm = ActNorm2D(C) # extremely slow
        self.actnorm = None
        self.C = C

    def forward(self, x, edge_index, edge_weight):
        r'''
        Apply F_b(x) = x + g(x; edge_index, edge_weight).

        Args:
            x: flattened node features, one graph per row, x \in R^{N x nC}.
                It is reshaped to ``(N, n, C)`` before g is applied.
            edge_index, edge_weight: graph passed to every graph layer of g.

        Returns:
            ``[output, Fx]``. ``output`` is flattened back to ``(N, n*C)``, so
            the block Jacobian is nC-by-nC. ``Fx = g(x)`` is computed on the
            reshaped input.
        '''
        n = int(x.shape[1] / self.C)
        x = x.reshape(x.shape[0], n, self.C)
        Fx = self.bottleneck_block(x, edge_index, edge_weight)
        output = (Fx + x).flatten(start_dim=1)
        return [output, Fx]

    def inverse(self, y, edge_index, edge_weight, maxIter=100, eps=1e-7):
        '''
        Invert ``forward`` with the fixed-point iteration x <- y - g(x).

        Args:
            y: block output. It may be flattened ``(N, n*C)`` as returned by
                ``forward``, or already shaped ``(N, n, C)``.
            edge_index, edge_weight: the same graph used in ``forward``.
            maxIter: maximum number of fixed-point iterations.
            eps: stop once the norm of an update is ``<= eps``.

        Returns:
            The reconstructed input, in the same shape as ``y``.
        '''
        out_shape = y.shape
        if y.dim() == 2:
            # Mirror forward(): g was built for (N, n, C) inputs.
            y = y.reshape(y.shape[0], -1, self.C)
        x_pre = y
        x_now = y  # defined even when maxIter == 0
        for _ in range(maxIter):
            x_now = y - self.bottleneck_block(x_pre, edge_index, edge_weight)
            if torch.linalg.norm(x_now - x_pre) <= eps:
                break
            x_pre = x_now
        if self.actnorm is not None:
            x_now = self.actnorm.inverse(x_now)
        return x_now.reshape(out_shape)


class InvResNet(nn.Module):
    '''
    Stack of invertible residual blocks (iResNet).

    Refer to https://github.com/jhjacobsen/invertible-resnet/blob/master/models/conv_iResNet.py, where they stack multiple blocks together
        Line 418 class conv_iResNet(nn.Module) stacks blocks together,
        Line 56 builds the block
    '''

    def __init__(self, C, output_dim=1, nblocks=5, model_args=(0.9, 64, 3), graph=False, dim_inc=False, version='two_GCN_one_FC', A_=None):
        r'''
        Args:
            C: number of feature channels per node of the input.
            output_dim: output size of the linear classifier ``fc``.
            nblocks: number of residual blocks.
            model_args: ``(reduce_factor, dim, K)``. ``reduce_factor`` scales
                all non-``fc`` parameters at init (see ``small_weights``).
                ``dim`` is the hidden width. ``K`` is the Chebyshev filter size
                used by graph blocks.
            graph: build ``InvResBlock_Graph`` blocks instead of ``InvResBlock``.
            dim_inc: if X \in R^C lies in low dimension (e.g., graph node
                features), first lift it to R^{C'} with C' = 2C. Flowing in
                the larger space may avoid pathological situations.
            version: architecture name forwarded to every block.
            A_: graph operator forwarded to non-graph blocks (L3net versions).
        '''
        super().__init__()
        self.C = C
        dim, K = model_args[1], model_args[2]
        self.dim_inc = dim_inc
        # Actual dimension in which the distribution flows
        self.C_prime = 2 * C if self.dim_inc else C
        if graph:
            blocks = [InvResBlock_Graph(self.C_prime, [dim, K], version)
                      for _ in range(nblocks)]
        else:
            # InvResBlock ignores the second model_args entry; the block index
            # is only a placeholder there.
            blocks = [InvResBlock(self.C_prime, [dim, b], version, A_)
                      for b in range(nblocks)]
        if self.dim_inc:
            # NOTE(review): SimpleDimInc_block is not defined in this chunk;
            # it is assumed to be provided elsewhere in the file.
            blocks = [SimpleDimInc_block(C, self.C_prime), *blocks]
        self.blocks = nn.ModuleList(blocks)
        # Both Linear layers are constructed, as before, so parameter-init RNG
        # consumption is unchanged. When C_prime > 2 the classifier reads only
        # 2 features; presumably classification() is fed a 2-dim summary.
        # TODO confirm with callers.
        self.fc = nn.Linear(self.C_prime, output_dim)
        if self.C_prime > 2:
            self.fc = nn.Linear(2, output_dim)
        self.reduce_factor = model_args[0]
        self.small_weights()

    def forward(self, x, edge_index=None, edge_weight=None, logdet=True):
        '''
        Push flattened features ``x`` of shape ``(N, n*C)`` through all blocks.

        Args:
            x: flattened node features, one sample per row.
            edge_index, edge_weight: graph passed to every block when
                ``edge_index`` is not None.
            logdet: also accumulate the per-block log|det J|.

        Returns:
            ``(x, log_det, transport_cost)`` if ``logdet`` else ``x``.
            ``transport_cost`` sums ||g_b(x)||^2 / 2 over samples and blocks.
        '''
        start = 0
        if self.dim_inc:
            # Increase the dimension first; this block is not residual and
            # contributes no log-det term here.
            n = int(x.shape[1] / self.C)
            x = x.reshape(x.shape[0], n, self.C)
            x = self.blocks[0](x).flatten(start_dim=1)  # So it is in R^N times nC
            start = 1
        graph_args = (edge_index, edge_weight) if edge_index is not None else ()
        log_det = 0
        transport_cost = 0
        for block in self.blocks[start:]:
            if logdet:
                # Exact log|det| from the full per-sample Jacobian, which costs
                # O((nC)^3) per sample. slogdet avoids the overflow/underflow of
                # log(abs(det)) for large nC.
                # TODO: cheaper unbiased estimator (Residual Flow paper idea).
                jac = batch_jacobian(block, x, *graph_args)
                log_det = log_det + torch.linalg.slogdet(jac).logabsdet.sum()
            x, Fx = block(x, *graph_args)
            transport_cost += (torch.linalg.norm(Fx.flatten(start_dim=1),
                                                 dim=1)**2 / 2).sum()
        if logdet:
            return x, log_det, transport_cost
        return x

    def inverse(self, y, edge_index=None, edge_weight=None, maxIter=50):
        '''
        Invert ``forward`` block by block (last block first) without tracking
        gradients. ``maxIter`` bounds each block's fixed-point iteration.
        '''
        with torch.no_grad():
            start = 1 if self.dim_inc else 0
            for block in reversed(self.blocks[start:]):
                if edge_index is not None:
                    y = block.inverse(y, edge_index, edge_weight, maxIter)
                else:
                    y = block.inverse(y, maxIter)
            if self.dim_inc:
                y = self.blocks[0].inverse(y)
        return y

    def classification(self, H):
        '''
            Yield a linear classifier output ``fc(H)``.
        '''
        return self.fc(H)

    def small_weights(self):
        '''
        Multiply every parameter except those of ``fc`` by ``reduce_factor``
        in place, then re-enable its gradient.
        '''
        for name, W in self.named_parameters():
            if 'fc' not in name:
                with torch.no_grad():
                    # Of course, this is user-specified. It is just for initialization
                    # In fact, should not be too small, as it would make the transport cost too negligible
                    # And losses decay too slowly
                    # And the model more likely get non-invertible...
                    W.mul_(self.reduce_factor)
                W.requires_grad = True


# CGAN


class CGAN_net(nn.Module):
    '''
    Single (non-stacked) network used by the conditional GAN.

    Very similar to the bottleneck of InvResBlock, but no residual connection
    and no stacking of blocks. It maps C input features to C - Y_dim
    features. With ``classify=True`` it becomes a discriminator that outputs a
    probability.

    Args:
        C: input feature dimension.
        dim: hidden width.
        Y_dim: number of features removed from the output (C - Y_dim out).
        nblocks: number of ``layers_append`` stages. Stages are separated by
            BatchNorm over the hidden features.
        classify: append a 'three_FC' head + sigmoid (discriminator mode).
        graph: run the body with ``(x, edge_index, edge_weight)`` via PyG
            ``Sequential``; otherwise with ``x`` only.
        version: architecture name forwarded to ``layers_append``.
        A_: graph operator forwarded to ``layers_append`` (L3net versions).
    '''

    def __init__(self, C, dim, Y_dim=2, nblocks=10, classify=False, graph=True, version='two_L3_two_FC', A_=None):
        super().__init__()
        act = nn.ELU(inplace=True)
        full_layers = []
        # BatchNorm normalises the last dim here, so move features to dim 1 and back.
        trans = input_tranpose(1, 2)
        if nblocks > 1:
            for i in range(nblocks - 1):
                c_in = C if i == 0 else dim
                full_layers += layers_append([], act,
                                             c_in, dim, dim, version=version, A_=A_)
                full_layers.append(trans)
                full_layers.append(pyg.nn.BatchNorm(dim))
                full_layers.append(trans)
            full_layers += layers_append([], act, dim,
                                         dim, C - Y_dim, version=version, A_=A_)
        else:
            full_layers = layers_append([], act, C,
                                        dim, C - Y_dim, version=version, A_=A_)
        self.graph = graph
        if self.graph:
            # e.g., ChebNet
            self.bottleneck_block = Graph_Sequential(
                'x, edge_index, edge_weight', full_layers)
        else:
            self.bottleneck_block = nn.Sequential(*full_layers)
        # self.actnorm = ActNorm2D(C) # extremely slow
        self.actnorm = None
        self.classify = classify
        if self.classify:
            last_layer = layers_append([], act,
                                       C - Y_dim, 32, 1, version='three_FC')
            self.D_output = nn.Sequential(*last_layer)

    def forward(self, x, edge_index, edge_weight):
        '''
        Run the network. ``edge_index``/``edge_weight`` are ignored when
        ``graph`` is False.

        Returns:
            The body output, or, when ``classify`` is True, a sigmoid
            probability whose value is clamped to [1e-4, 1 - 1e-4].
        '''
        if self.graph:
            output = self.bottleneck_block(x, edge_index, edge_weight)
        else:
            output = self.bottleneck_block(x)
        if not self.classify:
            return output
        # For the min-max GAN. (For a Wasserstein GAN one would instead
        # return self.D_output(output) directly.)
        offset = 1e-4
        output = torch.sigmoid(self.D_output(output))
        # Straight-through clamp: the returned value lies in
        # [offset, 1 - offset] (presumably so a log(D) / log(1 - D) loss stays
        # finite), while gradients flow as if unclamped because the
        # correction is detached. Deriving the correction from `output`
        # keeps it on the network's own device/dtype. The previous
        # torch.zeros(...).to(device) broke when the model was not on the
        # module-level `device`.
        correction = (output.clamp(offset, 1 - offset) - output).detach()
        return output + correction

    def set_requires_grad(self, TorF):
        '''Set ``requires_grad`` to ``TorF`` on every parameter (e.g. to freeze D).'''
        for param in self.parameters():
            param.requires_grad = TorF

# # RNVP
#
#
# class R_NVP(nn.Module):
#     def __init__(self, d, c=0, hidden=64, version='three_FC'):
#         super().__init__()
#         self.d, self.c = d, c
#         # act = nn.ELU(inplace=True)
#         act = nn.LeakyReLU(inplace=True)
#         k = int((d+c)/2)
#         self.k = k
#         layers1 = layers_append([], act, d+c-k, hidden,
#                                 k, version=version)
#         layers2 = layers_append([], act, k, hidden,
#                                 d+c-k, version=version)
#         self.s1 = nn.Sequential(*layers1)
#         self.t1 = nn.Sequential(*layers1.copy())
#         self.s2 = nn.Sequential(*layers2)
#         self.t2 = nn.Sequential(*layers2.copy())
#
#     def forward(self, X, Y):
#         # Y denotes the condition variable
#         if self.c > 0:
#             # 1st block, takes in conditional variable
#             input = torch.cat([X, Y], dim=-1)
#         else:
#             input = X
#         x1, x2 = input[:, :self.k], input[:, self.k:]
#         s1_out = self.s1(x2)
#         v1 = x1*torch.exp(s1_out)+self.t1(x2)
#         s2_out = self.s2(v1)
#         v2 = x2*torch.exp(s2_out)+self.t2(v1)
#         v_tot = torch.cat([v1, v2], dim=-1)
#         log_jacob = s1_out.sum(-1)+s2_out.sum(-1)
#         return v_tot, log_jacob
#
#     def inverse(self, Z, Y):
#         if self.c1 > 0:
#             # Last block, takes in conditional variable
#             input = torch.cat([Z, Y], dim=-1)
#         else:
#             input = Z
#         z1, z2 = input[:, :self.k], input[:, self.k:]
#         x2_hat = (z2 - self.t2(z1)) * torch.exp(-self.s2(z1))
#         x1_hat = (z1 - self.t1(x2_hat)) * torch.exp(-self.s1(x2_hat))
#         return torch.cat([x1_hat, x2_hat], -1)
#
#
# class stacked_NVP(nn.Module):
#     # NOTE: this is very easy to be non-invertible...
#     def __init__(self, d, c, hidden, num_b, version='one_Cheb_two_FC'):
#         super().__init__()
#         self.d = d
#         bijectors = [R_NVP(d, c, hidden=hidden, version=version)]
#         for _ in range(num_b-1):
#             bijectors += [R_NVP(d+c, 0, hidden=hidden, version=version)]
#         self.bijectors = nn.ModuleList(bijectors)
#         self.small_weight(factor=0.15)
#
#     def forward(self, X, Y):
#         log_jacobs = []
#         for bijector in self.bijectors:
#             X, log_jac = bijector(X, Y)
#             log_jacobs.append(log_jac)
#         return X[:, :self.d], sum(log_jacobs)
#
#     def inverse(self, Z, Y):
#         self.bijectors[-1].c1 = 1  # place holder
#         for k in range(len(self.bijectors)-1):
#             self.bijectors[k].c1 = 0
#         with torch.no_grad():
#             for bijector in reversed(self.bijectors):
#                 Z = bijector.inverse(Z, Y)
#         return Z[:, :self.d]
#
#     def small_weight(self, factor):
#         # Scale parameter, as o/w the torch.exp() can make things explode
#         for W in self.parameters():
#             with torch.no_grad():
#                 W.mul_(factor)
#             W.requires_grad = True

# class R_NVP(nn.Module):
#     def __init__(self, d, k, c, hidden, version='three_FC'):
#         super().__init__()
#         self.d, self.k = d, k
#         act = nn.ELU(inplace=True)
#         layers1 = layers_append([], act, d-k+c, hidden,
#                                 k, version=version)
#         layers2 = layers_append([], act, k+c, hidden,
#                                 d-k, version=version)
#         self.s1 = nn.Sequential(*layers1)
#         self.t1 = nn.Sequential(*layers1.copy())
#         self.s2 = nn.Sequential(*layers2)
#         self.t2 = nn.Sequential(*layers2.copy())

#     def forward(self, X, Y):
#         # Y denotes the condition variable
#         x1, x2 = X[:, :self.k], X[:, self.k:]
#         x2_long = torch.cat([x2, Y], dim=-1)
#         s1_out, t1_out = self.s1(x2_long), self.t1(x2_long)
#         v1 = x1*torch.exp(s1_out)+t1_out
#         s2_out, t2_out = self.s2(
#             torch.cat([v1, Y], dim=-1)), self.t2(torch.cat([v1, Y], dim=-1))
#         v2 = x2*torch.exp(s2_out)+t2_out
#         v_tot = torch.cat([v1, v2], dim=-1)
#         log_jacob = s1_out.sum(-1)+s2_out.sum(-1)
#         return v_tot, log_jacob

#     def inverse(self, Z, Y):
#         z1, z2 = Z[:, :self.k], Z[:, self.k:]
#         z1_long = torch.cat([z1, Y], dim=-1)
#         x2_hat = (z2 - self.t2(z1_long)) * \
#             torch.exp(-self.s2(z1_long))
#         x1_hat = (z1 - self.t1(torch.cat([x2_hat, Y], dim=-1))) * \
#             torch.exp(-self.s1(torch.cat([x2_hat, Y],
#                       dim=-1)))
#         return torch.cat([x1_hat, x2_hat], -1)


# class stacked_NVP(nn.Module):
#     def __init__(self, d, k, c, hidden, num_b, version='one_Cheb_two_FC'):
#         super().__init__()
#         self.bijectors = nn.ModuleList([
#             R_NVP(d, k, c, hidden=hidden, version=version) for _ in range(num_b)
#         ])
#         self.small_weight(factor=0.5)

#     def forward(self, X, Y):
#         log_jacobs = []
#         for bijector in self.bijectors:
#             X, log_jac = bijector(X, Y)
#             log_jacobs.append(log_jac)
#         return X, sum(log_jacobs)

#     def inverse(self, z, Y):
#         with torch.no_grad():
#             for bijector in reversed(self.bijectors):
#                 z = bijector.inverse(z, Y)
#         return z

#     def small_weight(self, factor):
#         # Scale parameter, as o/w the torch.exp() can make things explode
#         for W in self.parameters():
#             with torch.no_grad():
#                 W.mul_(factor)
#             W.requires_grad = True

# Append net:


def layers_append(layers, act, C, dim, C1, K=3, version='one_Cheb_two_FC', A_=None):
    '''
    Append the layers of the architecture named ``version`` to ``layers``.

    The list is modified in place and also returned. Graph convolutions
    (GCNConv / ChebConv) are appended as ``(module, signature)`` tuples for
    ``torch_geometric.nn.Sequential``. L3net versions wrap ``GraphConv_Bases``
    between two transposes of dims 1 and 2, using a single shared transpose
    module. The same ``act`` instance is reused between all layers. Modules are
    constructed in the same order as the layers appear.

    Args:
        layers: list to append to.
        act: activation module.
        C: input feature dimension.
        dim: hidden width.
        C1: output feature dimension.
        K: Chebyshev filter size (Cheb versions only).
        version: one of 'one_GCN_one_FC', 'one_GCN_two_FC', 'two_GCN_one_FC',
            'one_Cheb_two_FC', 'one_Cheb_three_FC', 'two_Cheb_two_FC',
            'one_Cheb', 'three_FC', 'four_FC', 'one_L3_two_FC',
            'one_L3_three_FC', 'two_L3_two_FC', 'one_L3'.
        A_: graph operator for the L3net versions. Its ``shape[0]`` selects
            the neighbourhood orders passed to ``GraphConv_Bases``.

    Returns:
        ``layers`` with the new modules appended.

    Raises:
        ValueError: if ``version`` is not recognised. Previously an unknown
            name silently appended nothing, yielding an empty network.
    '''
    graph_sig = 'x, edge_index, edge_weight -> x'

    def gcn(c_in, c_out):
        return (GCNConv(c_in, c_out), graph_sig)

    def cheb(c_in, c_out):
        return (ChebConv(c_in, c_out, K=K), graph_sig)

    def head(n_hidden):
        # act, [Linear(dim, dim), act] * n_hidden, Linear(dim, C1)
        out = []
        for _ in range(n_hidden):
            out += [act, nn.Linear(dim, dim)]
        out += [act, nn.Linear(dim, C1)]
        return out

    def l3_orders():
        # Neighbourhood orders for GraphConv_Bases, picked by the size of A_.
        order_list = [1]
        if A_ is not None:
            n_nodes = A_.shape[0]
            if n_nodes == 3:
                # The simulation example
                order_list = [0, 1, 2]
            elif n_nodes == 10:
                # Solar data: two bases, with 1 & 2 hop neighbors
                order_list = [1, 2]
            elif n_nodes in (15, 20):
                # Traffic data
                order_list = [0, 1, 2]
        return order_list

    def l3_wrapped(*c_pairs, order_list=None):
        # One GraphConv_Bases per (c_in, c_out) pair, inside a shared transpose.
        trans = input_tranpose(1, 2)
        orders = l3_orders() if order_list is None else order_list
        convs = [GraphConv_Bases(c_in, c_out, A_, order_list=orders)
                 for c_in, c_out in c_pairs]
        return [trans, *convs, trans]

    builders = {
        'one_GCN_one_FC': lambda: [gcn(C, dim), *head(0)],
        'one_GCN_two_FC': lambda: [gcn(C, dim), *head(1)],
        # layers.append(pyg.nn.BatchNorm(dim)) after the 1st GCN had issues.
        'two_GCN_one_FC': lambda: [gcn(C, dim), act, gcn(dim, dim), *head(0)],
        # NOTE: BatchNorm inside the Cheb heads breaks invertibility.
        'one_Cheb_two_FC': lambda: [cheb(C, dim), *head(1)],
        'one_Cheb_three_FC': lambda: [cheb(C, dim), *head(2)],
        'two_Cheb_two_FC': lambda: [cheb(C, dim), act, cheb(dim, dim), *head(1)],
        'one_Cheb': lambda: [cheb(C, dim)],
        'three_FC': lambda: [nn.Linear(C, dim), *head(1)],
        'four_FC': lambda: [nn.Linear(C, dim), *head(2)],
        'one_L3_two_FC': lambda: [*l3_wrapped((C, dim)), *head(1)],
        'one_L3_three_FC': lambda: [*l3_wrapped((C, dim)), *head(2)],
        'two_L3_two_FC': lambda: [*l3_wrapped((C, dim), (dim, dim)), *head(1)],
        'one_L3': lambda: l3_wrapped((C, dim), order_list=[0, 1, 2]),
    }
    try:
        build = builders[version]
    except KeyError:
        raise ValueError(f'Unknown layer version: {version!r}') from None
    layers.extend(build())
    return layers
# Small nets


class SmallGenNet(nn.Module):
    '''
        Linear map from the (one-hot encoded) response Y to the conditional
        mean of the base distribution.

        Args:
            Y_dim: length of the encoded response vector.
            C: dimension of the produced mean.
    '''

    def __init__(self, Y_dim, C):
        super(SmallGenNet, self).__init__()
        self.fc = nn.Linear(in_features=Y_dim, out_features=C)

    def forward(self, Y):
        mean = self.fc(Y)
        return mean


# Helpers
def batch_jacobian(func, x, edge_index=None, edge_weight=None):
    '''
    Per-sample Jacobians of ``func(x)[0]`` with respect to ``x``.

    ``x`` is expected as ``(Batch, Length)``. The first output of ``func`` is
    summed over the batch dimension before differentiating, which yields the
    per-sample Jacobians only if samples do not interact inside ``func``.
    The graph tensors are forwarded to ``func`` only when ``edge_index`` is
    not None. The result is built with ``create_graph=True``, so it stays
    differentiable.

    Returns:
        Tensor of shape ``(Batch, out_dim, Length)``.
    '''
    extra = () if edge_index is None else (edge_index, edge_weight)

    def summed_first_output(inp):
        return func(inp, *extra)[0].sum(dim=0)

    jac = jacobian(summed_first_output, x, create_graph=True)  # (out, B, L)
    return jac.permute(1, 0, 2)


#######
