In [None]:
import seaborn as sns
from string import ascii_lowercase
import matplotlib
#import svgutils.transform as st
from matplotlib import gridspec, pyplot as plt


# Global matplotlib defaults: compact legends and tight tick-label padding.
matplotlib.rcParams.update(
    {
        # Legend appearance
        "legend.labelspacing": 0.2,
        "legend.fontsize": 8,
        "legend.handletextpad": 0.5,
        "legend.handlelength": 0.5,
        "legend.framealpha": 0.5,
        "legend.markerscale": 0.7,
        "legend.borderpad": 0.35,
        # Distance between tick marks and their labels
        "xtick.major.pad": 1.0,
        "ytick.major.pad": 1.0,
        "xtick.minor.pad": 0.9,
        "ytick.minor.pad": 0.9,
    }
)


def getSetup(figsize, gridd, multz=None, empts=None):
    """
    Build a figure laid out on a grid and return its axes.

    Parameters
    ----------
    figsize : passed straight to ``plt.figure(figsize=...)``.
    gridd : sequence unpacked into ``gridspec.GridSpec`` (rows, columns).
    multz : optional mapping ``{start_index: extra_cells}``; the subplot that
        starts at ``start_index`` covers that grid cell plus the following
        ``extra_cells`` cells.
    empts : optional collection of grid indices that receive no subplot.

    Returns
    -------
    tuple
        ``(axes_list, figure)`` with axes in grid order.
    """
    sns.set(
        style="whitegrid",
        font_scale=0.7,
        color_codes=True,
        palette="colorblind",
        rc={"grid.linestyle": "dotted", "axes.linewidth": 0.6},
    )

    # Fall back to "nothing skipped" / "nothing spanning" when not supplied.
    empts = [] if empts is None else empts
    multz = dict() if multz is None else multz

    # Figure and the grid its subplots are placed on
    f = plt.figure(figsize=figsize, constrained_layout=True)
    gs1 = gridspec.GridSpec(*gridd, figure=f)

    ax = []
    total_cells = gridd[0] * gridd[1]
    cell = 0
    while cell < total_cells:
        if cell in multz:
            # Spanning subplot: occupies this cell and the next `extra` cells,
            # so jump past all of them afterwards.
            extra = multz[cell]
            ax.append(f.add_subplot(gs1[cell: cell + extra + 1]))
            cell += extra + 1
            continue

        if cell not in empts:
            # Ordinary single-cell subplot
            ax.append(f.add_subplot(gs1[cell]))
        cell += 1

    return (ax, f)


def _panel_label(index):
    """ Return the lowercase label for a zero-based panel index: a..z, then aa, ab, ... """
    label = ""
    remaining = index + 1
    while remaining > 0:
        remaining, offset = divmod(remaining - 1, len(ascii_lowercase))
        label = ascii_lowercase[offset] + label
    return label


def subplotLabel(axs):
    """
    Place subplot labels on figure.

    Each axis in ``axs`` gets a bold lowercase letter drawn just outside its
    top-left corner (axes coordinates (-0.2, 1.2)). The first 26 axes are
    labelled "a".."z"; further axes continue with "aa", "ab", ... instead of
    raising IndexError.
    """
    for ii, ax in enumerate(axs):
        ax.text(-0.2, 1.2, _panel_label(ii), transform=ax.transAxes, fontsize=16, fontweight="bold", va="top")

In [None]:
def new_plot_latent(axis, autoencoder, data, dims=(0, 1), exp_data=None):
    """
    Scatter the encoded position of every sample on two latent dimensions.

    Parameters
    ----------
    axis : object with the matplotlib Axes methods used here (``scatter``,
        ``set_xlabel``, ``set_ylabel``, ``set_title``, ``legend``).
    autoencoder : object whose ``encoder(x)`` returns the latent code of ``x``.
    data : iterable of samples. Samples are coloured by position: indices
        0-98 black ('sin'), 99-198 blue ('pol'), 199 and above green ('rand').
        NOTE(review): this assumes the dataset is ordered in those three
        groups -- confirm against the code that builds it.
    dims : pair of latent indices plotted on the x and y axes.
    exp_data : optional iterable of samples, plotted in red as 'Experimental'.

    Notes
    -----
    Relies on a module-level ``device`` defined elsewhere in the notebook.
    The encoder output is indexed as ``z[dim]``, i.e. it is assumed to be
    one-dimensional per sample -- TODO confirm.
    """
    assert len(dims) == 2
    dim_x, dim_y = dims

    def _encode(sample):
        # Run the encoder on `device`, then bring the result back to the host.
        latent = autoencoder.encoder(sample.to(device))
        return latent.to('cpu').detach().numpy()

    # Only the first point of each group carries a label, so the legend has
    # one entry per group, and a group with a single point is still listed.
    labelled = set()
    for i, x in enumerate(data):
        z = _encode(x)
        if i < 99:
            colour, name = 'k', 'sin'
        elif i < 199:
            colour, name = 'b', 'pol'
        else:
            colour, name = 'g', 'rand'
        label = None if name in labelled else name
        labelled.add(name)
        axis.scatter(z[dim_x], z[dim_y], c=colour, label=label)

    # Identity check: `!= None` is an elementwise comparison on tensors/arrays.
    if exp_data is not None:
        for i, x in enumerate(exp_data):
            z = _encode(x)
            axis.scatter(z[dim_x], z[dim_y], c='r', label=('Experimental' if i == 0 else None))

    axis.set_xlabel('Latent Dim {}'.format(dim_x))
    axis.set_ylabel('Latent Dim {}'.format(dim_y))
    axis.set_title('Kymograph position in Latent Space')
    axis.legend()

In [None]:
# 2x2 grid: panel i plots latent dimension i against dimension i + 1.
ax, f = getSetup((8, 5), (2, 2))

# Use a distinct loop name so `ax` still refers to the list of axes afterwards.
for i, panel in enumerate(ax):
    new_plot_latent(panel, VAE, dataset, dims=[i, i + 1])

In [None]:
# 2x2 grid: panel i plots latent dimension i against dimension i + 2.
ax, f = getSetup((8, 5), (2, 2))

# Use a distinct loop name so `ax` still refers to the list of axes afterwards.
for i, panel in enumerate(ax):
    new_plot_latent(panel, VAE, dataset, dims=[i, i + 2])