# In [None]:
from timeit import default_timer as timer
import matplotlib.pyplot as plt
import numpy as np
import torch
import pdb
import matplotlib.cm as cm
import os
from scipy.stats import kde
from moviepy.editor import *
from matplotlib.ticker import MaxNLocator
# from pdf2image import convert_from_path
# Module-level torch device: CUDA when torch reports it as available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def visualize_generation_one_graph(self, X, Y, H_full, Y_row=None):
    '''
        Plot samples next to their model reconstructions and, when
        `self.final_viz` is set, record two-sample statistics.

        self: the object which has many methods, including computing L_g
        X, Y, H_full: tensors indexed along dim 0 by the same row mask
        Y_row: a specific choice of Y such that we only show inverse at this Y
        If None, plot everything

        Side effects visible in this function:
          * overrides several matplotlib rcParams (font sizes),
          * sets `self.viz` to True around the `get_L_g` call and always
            resets it to False afterwards,
          * prints how long the inversion took,
          * calls `plt_generation_fig`, and `plot_and_save_corr` if
            `self.C == 1`,
          * if `self.final_viz`, stores in `self.two_sample_stat[Y_row]` the
            list [N, MMD for alphas 0.1/1.0/5.0/10.0, Energy].
    '''
    # For two moon, after training the models
    # Basically visualize how the original density is gradually transformed to the data density
    # NOTE: due to speed in inversion, we just examine result over a subset of total data
    plt.rcParams['axes.titlesize'] = 18
    plt.rcParams['legend.fontsize'] = 13
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14
    plt.rcParams['figure.titlesize'] = 24
    if Y_row is not None:
        # Keep only the rows whose label matches Y_row in every column
        which_rows = (Y == Y_row).all(dim=1)
    else:
        which_rows = torch.tensor([True]).repeat(X.shape[0])
    with torch.no_grad():
        # The mask must be moved to cpu before it can index a numpy array
        batch_idx = np.arange(X.shape[0])[which_rows.cpu()]
        self.viz = True
        try:
            L_g_now = np.around(self.get_L_g(batch_idx, X, Y).item(), 2)
        finally:
            # Reset the flag even when get_L_g raises, so the object is not
            # left in visualization mode
            self.viz = False
    X, Y, H_full = X[which_rows], Y[which_rows], H_full[which_rows]
    start = timer()
    # Detach: the result is only plotted / compared, and downstream code
    # converts it with `.numpy()`, which fails on tensors that require grad
    X_pred = self.model.inverse(H_full, self.edge_index).cpu().detach()
    H_full = H_full.cpu().detach()
    N = X.shape[0]
    print(f'Invert {N} samples took {timer()-start} secs')
    with torch.no_grad():
        H_pred = self.model.forward(X.flatten(
            start_dim=1), self.edge_index, logdet=False).cpu()
    H_pred = H_pred.reshape(X.shape)
    X = X.cpu()
    # # Visualize X and Inverse of H
    # num_to_plot = 1000 if '8_gaussian' in self.path else 600
    # if self.V > 1 or self.C > 2:
    #     num_to_plot = 100
    num_to_plot = N
    plt_generation_fig(self, X[:num_to_plot], X_pred[:num_to_plot],
                       Y[:num_to_plot], H_full[:num_to_plot], H_pred[:num_to_plot], L_g_now)
    if self.C == 1:
        # Graph GP, so we want to visualize the covariances
        plot_and_save_corr(self, X, H_full)
    # Also report quantitative metrics:
    if self.final_viz:
        X_sub, X_pred = X.flatten(
            start_dim=1), X_pred.flatten(start_dim=1)
        # Record num of obs. first, then one MMD value per bandwidth,
        # then the Energy statistic
        self.two_sample_stat[Y_row] = [N]
        for alphas in [[0.1], [1.0], [5.0], [10.0]]:
            ret = self.two_sample_mtd(
                X_sub, X_pred, alphas=alphas, method='MMD')
            self.two_sample_stat[Y_row].append(ret)
        ret = self.two_sample_mtd(X_sub, X_pred, method='Energy')
        self.two_sample_stat[Y_row].append(ret)


def save_trajectory_revised(self, X, Y, H_full):
    '''
        Push samples through the model one block at a time and save one
        figure per block as `viz-{t:05d}.jpg` inside `self.path`.

        NOTE: Here X can either be true data sample OR the base sample
        Then the def. of H_full also changes
        If X are base samples, then we must use forward mapping of blocks

        Direction is chosen by `self.from_X_to_H`: if True the blocks are
        applied in order with their forward call, otherwise they are visited
        in reverse order with `block.inverse`.

        Each frame contains: the per-block transport cost so far (mean over
        samples of half the squared norm of the displacement), the target
        samples, the current estimates, a KDE-coloured scatter of the
        estimates and a quiver plot of the displacement of this block.

        Side effects: appends to `self.transport_cost_XtoH_ls` or
        `self.transport_cost_HtoX_ls`, writes image files, shows and then
        closes each figure.
    '''
    # Function-scope import: the `scipy.stats.kde` namespace is deprecated,
    # `gaussian_kde` is the public name
    from scipy.stats import gaussian_kde

    V_tmp = X.shape[1]
    savedir = f'{self.path}'
    fsize = 22
    N = X.shape[0]
    blocks = self.model.blocks if self.from_X_to_H else reversed(
        self.model.blocks)
    X_np, H_np = X.flatten(start_dim=0, end_dim=1), H_full.flatten(
        start_dim=0, end_dim=1)
    if self.C > 2:
        # Re-interpret each sample as V_tmp points with 2 coordinates
        V_tmp = int(self.C/2)
        C_tmp = 2
        X_np = X.reshape(N, V_tmp, C_tmp).flatten(start_dim=0, end_dim=1)
        H_np = H_full.reshape(N, V_tmp, C_tmp).flatten(start_dim=0, end_dim=1)
    # Common axis limits for every frame, taken from the initial inputs
    xmin, xmax = min(X_np[:, 0].min(), H_np[:, 0].min()).item(), max(
        X_np[:, 0].max(), H_np[:, 0].max()).item()
    ymin, ymax = min(X_np[:, 1].min(), H_np[:, 1].min()).item(), max(
        X_np[:, 1].max(), H_np[:, 1].max()).item()

    # Point colours depend only on V_tmp, N, Y and self.path, none of which
    # change between blocks, so compute them once
    colors = np.tile(cm.rainbow(np.linspace(0, 1, V_tmp)), (N, 1))
    if V_tmp == 1:
        # Two-moon or 8_gaussian
        if '8_gaussian' in self.path:
            colors = np.repeat('r', N)
            colors[(Y == 1).cpu().detach().numpy().flatten()] = 'm'
            colors[(Y == 2).cpu().detach().numpy().flatten()] = 'y'
            colors[(Y == 3).cpu().detach().numpy().flatten()] = 'k'
        else:
            colors = np.repeat('black', N)
            colors[(Y == 1).cpu().detach().numpy().flatten()] = 'blue'

    with torch.no_grad():
        # Gradually invert H_full through each layer to see how it matches the original density
        for t, block in enumerate(blocks):
            if self.from_X_to_H:
                # Here H_full is actually X
                X_flat = X.flatten(start_dim=1)
                if self.edge_index is not None:
                    X_pred, Fx = block(
                        X_flat, self.edge_index, self.edge_weight)
                else:
                    X_pred, Fx = block(X_flat)
                transport_cost = (torch.linalg.norm(Fx.flatten(start_dim=1),
                                                    dim=1)**2/2).sum().item()/N
                X_pred = X_pred.reshape(X.shape)
                self.transport_cost_XtoH_ls.append(transport_cost)
            else:
                if self.C > 2:
                    H_full = H_full.flatten(start_dim=1)
                if self.edge_index is not None:
                    X_pred = block.inverse(
                        H_full, self.edge_index, self.edge_weight)
                else:
                    X_pred = block.inverse(H_full)
                transport_cost = (torch.linalg.norm(
                    (X_pred-H_full).flatten(start_dim=1), dim=1)**2/2).sum().item()/N
                self.transport_cost_HtoX_ls.append(transport_cost)
            if self.C > 2:
                X = X.reshape(N, V_tmp, C_tmp)
                X_pred = X_pred.reshape(N, V_tmp, C_tmp)
                H_full = H_full.reshape(N, V_tmp, C_tmp)
            if self.from_X_to_H:
                transport_cost_ls = self.transport_cost_XtoH_ls
            else:
                transport_cost_ls = self.transport_cost_HtoX_ls
            # Include transport cost on the top
            fig = plt.figure(figsize=(8, 11))
            spec = fig.add_gridspec(5, 2)
            # Plot transport cost
            ax = fig.add_subplot(spec[0, :])
            ax.plot(transport_cost_ls, '-o')
            if self.from_X_to_H:
                ax.set_title(
                    r'$W_2$ transport cost of $X \rightarrow H$ over blocks')
            else:
                ax.set_title(
                    r'$W_2$ transport cost of $H \rightarrow X$ over blocks')
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.set_facecolor('lightblue')
            plt_dict = {0: X, 1: X_pred}
            if self.from_X_to_H:
                plt_dict[0] = H_full
            # Plot target and estimates
            for j in range(2):
                ax = fig.add_subplot(
                    spec[1:3, 0]) if j == 0 else fig.add_subplot(spec[3:, 0])
                ax.set_facecolor('lightblue')
                ax.set_xlim(xmin, xmax)
                ax.set_ylim(ymin, ymax)
                if j == 0:
                    ax.xaxis.set_visible(False)
                if self.from_X_to_H:
                    title = r'Targets $H$' if j == 0 else r'Estimates $\hat{H}$'
                else:
                    title = r'Targets $X$' if j == 0 else r'Estimates $\hat{X}$'
                XorXpred = plt_dict[j]
                XorXpred_tmp = XorXpred.flatten(
                    start_dim=0, end_dim=1).cpu().numpy()
                if self.V > 1 or (self.V == 1 and self.C > 2):
                    ax.plot(XorXpred_tmp[:, 0], XorXpred_tmp[:, 1],
                            linestyle='dashed', linewidth=0.075)
                ax.scatter(XorXpred_tmp[:, 0],
                           XorXpred_tmp[:, 1], color=colors)
                ax.set_title(title, fontsize=fsize)
            # plot the density
            ax = fig.add_subplot(spec[1:3, 1])
            ax.set_facecolor('lightblue')
            ax.set_xlim(xmin, xmax)
            ax.set_ylim(ymin, ymax)
            X_pred_tmp = X_pred.flatten(start_dim=0, end_dim=1).cpu().numpy()
            # Try to get density overlaid but different colors, since I have multiple blobs
            x, y = X_pred_tmp[:, 0], X_pred_tmp[:, 1]
            xy = np.vstack([x, y])
            # Colour every point by the KDE evaluated at that same point
            k = gaussian_kde(xy)(xy)
            ax.scatter(x, y, c=k, s=2)
            if self.from_X_to_H:
                title = r"Density of $\hat{H}$"
            else:
                title = r"Density of $\hat{X}$"
            ax.set_title(title, fontsize=fsize)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            # plot the vector field
            ax = fig.add_subplot(spec[3:, 1])
            ax.set_facecolor('lightblue')
            ax.set_xlim(xmin, xmax)
            ax.set_ylim(ymin, ymax)
            X_pred_pre = X if self.from_X_to_H else H_full
            X_pred_pre = X_pred_pre.flatten(
                start_dim=0, end_dim=1).cpu().numpy()
            directions = X_pred_tmp - X_pred_pre
            # Squared displacement length, used as the arrow colour. Computed
            # directly rather than as exp(2*log(.)) so that zero displacement
            # does not trigger a divide-by-zero warning
            sq_length = np.hypot(directions[:, 0], directions[:, 1])**2
            # Smaller scale = larger arrow
            ax.quiver(
                x, y, directions[:, 0], directions[:, 1],
                sq_length, cmap="coolwarm", scale=3.5, width=0.015, pivot="mid")
            ax.set_title("Vector Field", fontsize=fsize)
            ax.yaxis.set_visible(False)
            fig.tight_layout()
            plt.savefig(os.path.join(
                savedir, f"viz-{t:05d}.jpg"))
            plt.show()
            # One figure is created per block; close it so figures do not
            # accumulate in memory over the whole trajectory
            plt.close(fig)
            # Update H or X as input for next plot
            if self.from_X_to_H:
                X = X_pred.clone()
            else:
                H_full = X_pred.clone()


def trajectory_to_gif(self):
    '''
        Assemble the `viz-%05d.jpg` frames found in `self.path` into
        `trajectory_epoch{self.epoch}_XtoH` / `_HtoX` (suffix chosen by
        `self.from_X_to_H`), first as .mp4 via ffmpeg, then as .gif via
        moviepy.

        Raises subprocess.CalledProcessError if ffmpeg exits with a non-zero
        status, instead of silently continuing with a missing or stale .mp4.
    '''
    import subprocess
    savedir = f'{self.path}'
    # Smaller framerate reduces picture speed (desirable if num blocks small)
    # 10 for 40 blocks was pretty fast
    suffix = '_XtoH' if self.from_X_to_H else '_HtoX'
    out_path = os.path.join(savedir, f'trajectory_epoch{self.epoch}{suffix}')
    # Pass the arguments as a list: splitting a formatted command string on
    # whitespace breaks as soon as `self.path` contains a space
    command = ['ffmpeg', '-framerate', '5', '-y', '-i',
               os.path.join(savedir, 'viz-%05d.jpg'), out_path+'.mp4']
    # stdout is discarded, as before
    subprocess.run(command, stdout=subprocess.DEVNULL, check=True)
    clip = VideoFileClip(out_path+'.mp4')
    try:
        clip.write_gif(out_path + '.gif')
    finally:
        # Release the reader held by the clip
        clip.close()


def plot_and_save_corr(self, X, H_full):
    '''
        Draw a 2x3 grid of correlation matrices and store the figure in
        `self.fig_corr`.

        Top row: correlation of X, of the model inverse of H_full, and their
        difference. Bottom row: the same for H_full and the model forward
        pass of X. Each row gets one colorbar on the left (tied to the
        middle panel) and one on the right (tied to the difference panel).
    '''
    X_hat = self.model.inverse(H_full, self.edge_index)
    H_hat = self.model.forward(
        X.flatten(start_dim=1), self.edge_index, logdet=False)
    corr_X, corr_X_hat = get_corrcoef(X), get_corrcoef(X_hat)
    corr_H, corr_H_hat = get_corrcoef(H_full), get_corrcoef(H_hat)
    # (row, true corr., estimated corr., panel titles, colorbar bottom edge)
    layout = [
        (0, corr_X, corr_X_hat,
         (r'Correlation of $X$', r'Correlation of $\hat{X}$',
          r'Diff. of Correlation in $X$'), 0.5),
        (1, corr_H, corr_H_hat,
         (r'Correlation of $H$', r'Correlation of $\hat{H}$',
          r'Diff. of Correlation in $H$'), 0),
    ]
    fig, ax = plt.subplots(2, 3, figsize=(12, 8), constrained_layout=True)
    for row, corr_true, corr_est, titles, cbar_bottom in layout:
        ax[row, 0].matshow(corr_true)
        ax[row, 0].set_title(titles[0])
        est_img = ax[row, 1].matshow(corr_est)
        ax[row, 1].set_title(titles[1])
        # Left colorbar, placed outside the grid in figure coordinates
        left_cax = fig.add_axes([-0.15, cbar_bottom, 0.1, 0.4])
        plt.colorbar(est_img, cax=left_cax)
        diff_img = ax[row, 2].matshow(corr_true-corr_est)
        ax[row, 2].set_title(titles[2])
        # Right colorbar for the difference panel
        right_cax = fig.add_axes([1, cbar_bottom, 0.1, 0.4])
        plt.colorbar(diff_img, cax=right_cax)
    self.fig_corr = fig


def get_corrcoef(input):
    '''
        Return the correlation-coefficient matrix of `input` as a numpy array.

        input: tensor with at least 2 dimensions. Dim 0 is treated as the
            observations and dim 1 as the variables (the tensor is transposed
            before calling `torch.corrcoef`). If there are more than 2
            dimensions, only index 0 of the third dimension is used.

        Raises ValueError if `input` has fewer than 2 dimensions.
    '''
    n_dims = len(input.shape)
    if n_dims < 2:
        raise ValueError('Input must be at least 2 dimensional.')
    # Keep only the first entry along the third dimension when there is one
    samples = input[:, :, 0] if n_dims > 2 else input
    return torch.corrcoef(samples.T).cpu().detach().numpy()


def plt_generation_fig(self, X, X_pred, Y, H_full, H_pred, L_g_now):
    '''
        Scatter-plot X, X_pred, H_full and H_pred side by side and store the
        figure in `self.fig_gen`.

        Only the first two panels (X and X_pred) are drawn when both
        `self.final_viz` and `self.plot_sub` are set; otherwise all four.
        Panels 0/1 and panels 2/3 share their axes. `L_g_now` is appended to
        the title of the last panel.

        Every tensor is flattened over its first two dims before plotting.
        If `self.C > 2` the tensors are first reshaped to
        (N, int(self.C/2), 2); if `self.C == 1` a zero column is appended so
        that there is a second coordinate to plot.
        Y is only used to colour the points when there is one node per sample.
    '''
    def to_points(tensor):
        # Flatten over the first two dims and convert to numpy; detach/cpu so
        # that tensors on GPU or with autograd history do not break `.numpy()`
        pts = tensor.flatten(start_dim=0, end_dim=1).detach().cpu().numpy()
        if self.C == 1:
            pts = np.c_[pts, np.zeros(pts.shape)]
        return pts

    def link_axes(ax, ref):
        # `get_shared_x_axes().join` was removed in Matplotlib 3.8; prefer
        # the public sharex/sharey and fall back for old versions
        if hasattr(ax, 'sharex'):
            ax.sharex(ref)
            ax.sharey(ref)
        else:
            ax.get_shared_x_axes().join(ref, ax)
            ax.get_shared_y_axes().join(ref, ax)

    plt_dict = {0: X, 1: X_pred, 2: H_full, 3: H_pred}
    V_tmp = X.shape[1]
    N = X.shape[0]
    if self.C > 2:
        # NOTE: this is because FC treated graph example in R^V-x-C as a vector in \R^V-by-C, so that we need reshaping for visualization
        V_tmp = int(self.C/2)
        C_tmp = 2
        plt_dict = {key: val.reshape(N, V_tmp, C_tmp)
                    for key, val in plt_dict.items()}
    if self.final_viz and self.plot_sub:
        title_dict = {
            0: r'$X|Y$', 1: r'$\hat{X}|Y=F^{-1}(H|Y)$'}
        fig, axs = plt.subplots(1, 2, figsize=(2 * 4, 4))
    else:
        title_dict = {
            0: r'$X|Y$', 1: r'$\hat{X}|Y=F^{-1}(H|Y)$', 2: r'$H|Y$', 3: r'$\hat{H}|Y=F(X|Y)$'}
        fig, axs = plt.subplots(1, 4, figsize=(4 * 4, 4))
    # Plot X and X_pred=F^-1(H)
    which_cmap = cm.coolwarm
    real_data = 'solar' in self.path or 'traffic' in self.path
    scatter_kwargs = {}
    if real_data:
        scatter_kwargs['s'] = 20
        lwidth = 0.025
        # Colour each node by the variance of its values, taken over
        # samples (dim 0) and channels (dim 2)
        shades = torch.var(plt_dict[0], dim=[0, 2]).cpu().detach()
        shades, _ = torch.sort(shades, descending=True)
        # All between 0 and 1
        shades = ((shades-shades.min())/(shades.max()-shades.min()))
        cutoff = 0.7
        shades[shades > cutoff] = shades[shades > cutoff]**2  # Make them lighter
        shades = torch.flip(shades, dims=(0,)).numpy()
    else:
        lwidth = 0.075
        shades = np.linspace(0, 1, V_tmp)
    print(f'1st Var to Last Var, lightest to darkest: {shades}')
    colors = np.tile(which_cmap(shades), (N, 1))
    if V_tmp == 1:
        # Two-moon or 8_gaussian
        if '8_gaussian' in self.path:
            colors = np.repeat('r', N)
            colors[(Y == 1).cpu().detach().numpy().flatten()] = 'm'
            colors[(Y == 2).cpu().detach().numpy().flatten()] = 'y'
            colors[(Y == 3).cpu().detach().numpy().flatten()] = 'k'
        else:
            colors = np.repeat('black', N)
            colors[(Y == 1).cpu().detach().numpy().flatten()] = 'blue'
    # Limits of the X panels come from the true X, using the same (possibly
    # zero-padded) points that are plotted
    X_tmp = to_points(plt_dict[0])
    for j, title in title_dict.items():
        ax = axs[j]
        XorXpred_tmp = to_points(plt_dict[j])
        if self.V > 1 or (self.V == 1 and self.C > 2):
            ax.plot(XorXpred_tmp[:, 0], XorXpred_tmp[:, 1],
                    linestyle='dashed', linewidth=lwidth)
        ax.scatter(XorXpred_tmp[:, 0],
                   XorXpred_tmp[:, 1], color=colors, **scatter_kwargs)
        ax.set_title(title)
        if j < 2:
            ax.set_xlim(left=X_tmp[:, 0].min(), right=X_tmp[:, 0].max())
            if self.C != 1:
                # With C == 1 the second coordinate is all-zero padding, so
                # there is no meaningful y-range to pin
                ax.set_ylim(bottom=X_tmp[:, 1].min(), top=X_tmp[:, 1].max())
        if j == 1:
            link_axes(ax, axs[0])
        if j == 3:
            link_axes(ax, axs[2])
            ax.set_title(
                f'{title}, L_g is {L_g_now}')
    fig.tight_layout()
    self.fig_gen = fig
    plt.show()


def plt_generation_fig_competitor(self, X, X_pred, Y, H=None, H_pred=None):
    '''
        Competitor-model counterpart of `plt_generation_fig`; stores the
        figure in `self.fig_gen`.

        If `self.final_viz` is set, only X_pred is drawn (single panel) and
        H / H_pred are not used. Otherwise X, X_pred, H and H_pred are drawn
        in four panels, with panels 0/1 and panels 2/3 sharing their axes;
        H and H_pred must then be provided.

        Every tensor is flattened over its first two dims before plotting;
        if `self.C == 1` a zero column is appended so that there is a second
        coordinate to plot. Y is only used to colour the points when there is
        one node per sample. Also overrides several matplotlib rcParams.

        Raises TypeError if `self.final_viz` is false and H or H_pred is None.
    '''
    def to_points(tensor):
        # Flatten over the first two dims and convert to numpy; detach/cpu so
        # that tensors on GPU or with autograd history do not break `.numpy()`
        pts = tensor.flatten(start_dim=0, end_dim=1).detach().cpu().numpy()
        if self.C == 1:
            pts = np.c_[pts, np.zeros(pts.shape)]
        return pts

    def link_axes(ax, ref):
        # `get_shared_x_axes().join` was removed in Matplotlib 3.8; prefer
        # the public sharex/sharey and fall back for old versions
        if hasattr(ax, 'sharex'):
            ax.sharex(ref)
            ax.sharey(ref)
        else:
            ax.get_shared_x_axes().join(ref, ax)
            ax.get_shared_y_axes().join(ref, ax)

    plt.rcParams['axes.titlesize'] = 18
    plt.rcParams['legend.fontsize'] = 13
    plt.rcParams['xtick.labelsize'] = 14
    plt.rcParams['ytick.labelsize'] = 14
    plt.rcParams['figure.titlesize'] = 24
    X = X.cpu()
    N = X.shape[0]
    # num_to_plot = 1000 if '8_gaussian' in self.path else 600
    # if self.V > 1 or self.C > 2:
    #     num_to_plot = 100
    num_to_plot = N
    X, X_pred, Y = X[:num_to_plot], X_pred[:num_to_plot], Y[:num_to_plot]
    N = X.shape[0]
    if self.final_viz:
        plt_dict = {0: X_pred}
        title_dict = {0: r'$\hat{X}|Y=G^{-1}(H, Y)$'}
        fig, single_ax = plt.subplots(figsize=(4, 4))
        axs = [single_ax]
    else:
        if H is None or H_pred is None:
            # Previously failed with "'NoneType' object is not subscriptable"
            raise TypeError(
                'H and H_pred must be provided when self.final_viz is False.')
        H, H_pred = H[:num_to_plot], H_pred[:num_to_plot]
        plt_dict = {0: X, 1: X_pred, 2: H, 3: H_pred}
        title_dict = {
            0: r'$X|Y$', 1: r'$\hat{X}|Y=G^{-1}(H, Y)$', 2: r'$H$', 3: r'$\hat{H}=G(X, Y)$'}
        fig, axs = plt.subplots(1, 4, figsize=(4 * 4, 4))
    V_tmp = X.shape[1]
    # Plot X and X_pred=F^-1(H)
    which_cmap = cm.coolwarm
    real_data = 'solar' in self.path or 'traffic' in self.path
    scatter_kwargs = {}
    if real_data:
        scatter_kwargs['s'] = 20
        lwidth = 0.025
        # Colour each node by the variance of its values, taken over
        # samples (dim 0) and channels (dim 2)
        shades = torch.var(X, dim=[0, 2]).cpu().detach()
        shades, _ = torch.sort(shades, descending=True)
        # All between 0 and 1
        shades = ((shades-shades.min())/(shades.max()-shades.min()))
        cutoff = 0.7
        shades[shades > cutoff] = shades[shades > cutoff]**2  # Make them lighter
        shades = torch.flip(shades, dims=(0,)).numpy()
    else:
        lwidth = 0.075
        shades = np.linspace(0, 1, V_tmp)
    print(f'1st Var to Last Var, lightest to darkest: {shades}')
    colors = np.tile(which_cmap(shades), (X.shape[0], 1))
    if V_tmp == 1:
        # Two-moon or 8_gaussian
        if '8_gaussian' in self.path:
            colors = np.repeat('r', N)
            colors[(Y == 1).cpu().detach().numpy().flatten()] = 'm'
            colors[(Y == 2).cpu().detach().numpy().flatten()] = 'y'
            colors[(Y == 3).cpu().detach().numpy().flatten()] = 'k'
        else:
            colors = np.repeat('black', N)
            colors[(Y == 1).cpu().detach().numpy().flatten()] = 'blue'
    # Limits of the X panels come from the true X, using the same (possibly
    # zero-padded) points that are plotted
    X_tmp = to_points(X)
    for j, title in title_dict.items():
        ax = axs[j]
        XorXpred_tmp = to_points(plt_dict[j])
        if self.V > 1 or (self.V == 1 and self.C > 2):
            ax.plot(XorXpred_tmp[:, 0], XorXpred_tmp[:, 1],
                    linestyle='dashed', linewidth=lwidth)
        ax.scatter(XorXpred_tmp[:, 0],
                   XorXpred_tmp[:, 1], color=colors, **scatter_kwargs)
        ax.set_title(title)
        if j < 2:
            ax.set_xlim(left=X_tmp[:, 0].min(), right=X_tmp[:, 0].max())
            if self.C != 1:
                # With C == 1 the second coordinate is all-zero padding, so
                # there is no meaningful y-range to pin
                ax.set_ylim(bottom=X_tmp[:, 1].min(), top=X_tmp[:, 1].max())
        if j == 1:
            link_axes(ax, axs[0])
        if j == 3:
            link_axes(ax, axs[2])
    fig.tight_layout()
    self.fig_gen = fig
    plt.show()


def losses_and_error_plt_real_data_on_graph(loss_g_ls_train, loss_g_ls_test, loss_c_ls_train, loss_c_ls_test, classify_error_ls_train, classify_error_ls_test):
    '''
        Plot training (black) and test (blue) curves and return the figure.

        If the smallest training classification error equals 1, only the
        generative loss is drawn in a single panel. Otherwise three stacked
        panels show the generative loss, the classification loss and the
        classification error.

        The figure is shown and the current figure is closed before returning.
        Also overrides several matplotlib rcParams (font sizes).
    '''
    plt.rcParams.update({
        'axes.titlesize': 18,
        'font.size': 18,
        'figure.titlesize': 22,
        'legend.fontsize': 18,
    })

    def draw_pair(axis, train_curve, test_curve, heading):
        # One panel: training and test curve plus the panel title
        axis.plot(train_curve, label=r'Training', color='black')
        axis.plot(test_curve, label=r'Test', color='blue')
        axis.set_title(heading)

    # Quick plot
    generative_only = np.min(classify_error_ls_train) == 1
    if generative_only:
        fig, ax = plt.subplots(figsize=(4, 4))
        draw_pair(ax, loss_g_ls_train, loss_g_ls_test, 'Generative Loss')
        ax.legend()
    else:
        fig, ax = plt.subplots(3, 1, figsize=(
            4, 8), constrained_layout=True)
        panels = [
            (loss_g_ls_train, loss_g_ls_test, 'Generative Loss'),
            (loss_c_ls_train, loss_c_ls_test, r'$\mu*$Classification Loss'),
            (classify_error_ls_train, classify_error_ls_test,
             'Classification Error'),
        ]
        for axis, (train_curve, test_curve, heading) in zip(ax, panels):
            draw_pair(axis, train_curve, test_curve, heading)
        for axis in ax:
            axis.legend()
    plt.show()
    plt.close()
    return fig

##########
