In [10]:
import numpy as np
from tqdm import trange
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
# Notebook-wide plot defaults: 8x6 inch figures, nearest-neighbour image interpolation.
plt.rcParams.update({
    'figure.figsize': (8, 6),
    'image.interpolation': 'nearest',
})

In [11]:
RC = np.genfromtxt("RC.csv", delimiter=",")
# genfromtxt silently turns missing or non-numeric cells into NaN; fail fast here
# instead of letting NaNs propagate into the SVD performed by plot_PCA.
assert RC.ndim == 2, f"expected a 2-D matrix in RC.csv, got shape {RC.shape}"
assert np.isfinite(RC).all(), "RC.csv contains missing or non-numeric entries"
print(RC)
RC.shape

[[ 0.006231  0.001632  0.       ... -0.015785  0.        0.      ]
 [ 0.000892  0.005495  0.009537 ... -0.026889 -0.014335  0.      ]
 [ 0.       -0.024729  0.       ...  0.        0.       -0.001739]
 ...
 [-0.022169  0.       -0.037553 ...  0.        0.       -0.0047  ]
 [ 0.        0.        0.       ...  0.        0.       -0.014941]
 [ 0.        0.        0.010073 ...  0.        0.        0.      ]]


(97, 400)

In [12]:
def plot_colortable(colors, text_on=True):
    """Draw one square colour swatch per entry of ``colors`` on a grid.

    Adapted from the matplotlib "List of named colors" gallery example
    (ref: https://matplotlib.org/stable/gallery/color/named_colors.html#sphx-glr-gallery-color-named-colors-py).

    Swatches are placed row-major (left to right, then top to bottom) in
    32 x 32 cells on a near-square grid, with the figure sized so that each
    cell is 32 pixels at 72 dpi.

    Parameters
    ----------
    colors : sequence
        Matplotlib colour specs (e.g. RGBA tuples); ``colors[i]`` fills cell ``i``.
    text_on : bool, default True
        If True, write each swatch's index at the centre of its cell.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure; it is neither saved nor closed here.
    """
    nunit = len(colors)
    # ceil (not floor) so every colour gets a cell when nunit is not a perfect
    # square; flooring pushed the trailing swatches outside the axes limits.
    # For perfect squares (e.g. 400 -> 20 x 20) the grid is unchanged.
    ncols = max(1, int(np.ceil(np.sqrt(nunit))))
    nrows = max(1, int(np.ceil(nunit / ncols)))
    swatch_width = cell_width = cell_height = 32
    width = cell_width * ncols
    height = cell_height * nrows
    dpi = 72
    fig, ax = plt.subplots(figsize=(width / dpi, height / dpi), dpi=dpi)
    # Data coordinates in cell units; y is inverted so row 0 is drawn at the top.
    ax.set_xlim(0, width)
    ax.set_ylim(height, 0)
    ax.yaxis.set_visible(False)
    ax.xaxis.set_visible(False)
    ax.set_axis_off()

    for unit_idx, color in enumerate(colors):
        row, col = divmod(unit_idx, ncols)
        x = col * cell_width
        y = row * cell_height

        if text_on:
            ax.text(x + cell_width / 2, y + cell_height / 2, unit_idx, fontsize=6,
                    horizontalalignment='center', verticalalignment='center')

        ax.add_patch(
            Rectangle(xy=(x, y), width=swatch_width, height=swatch_width,
                      facecolor=color, edgecolor='0.7')
        )

    return fig

def rescale(vec, qt, axis=1):
    """Linearly map ``vec`` to [0, 1] between per-slice quantiles, clipping the tails.

    Along ``axis`` (default 1, i.e. per row of a 2-D array), values at or below
    the ``qt`` quantile become 0, values at or above the ``1 - qt`` quantile
    become 1, and values in between are scaled linearly.

    Parameters
    ----------
    vec : array_like
        Data to rescale.
    qt : float
        Lower quantile in [0, 1]; the upper quantile is ``1 - qt``.
    axis : int, default 1
        Axis along which the quantiles are computed.

    Returns
    -------
    np.ndarray
        Array with the same shape as ``vec``, clipped to [0, 1].
    """
    vec = np.asarray(vec)
    qtmin = np.quantile(vec, qt, axis=axis, keepdims=True)
    qtmax = np.quantile(vec, 1 - qt, axis=axis, keepdims=True)
    span = qtmax - qtmin
    # When both quantiles coincide (constant or very sparse slice) the ratio is
    # 0/0 = NaN. Use a step instead: values above the quantile -> 1, others -> 0,
    # which is what the clipped ratio yields everywhere except the NaN points.
    degenerate = span == 0
    scaled = np.where(
        degenerate,
        (vec > qtmin).astype(float),
        (vec - qtmin) / np.where(degenerate, 1, span),
    )
    return np.clip(scaled, 0, 1)

def get_colors(Vt, alpha=0.5):
    """Turn each column of the 2-D array ``Vt`` into a colour tuple ``(*column, alpha)``.

    For valid matplotlib RGBA colours, ``Vt`` should have 3 rows (R, G, B) with
    values in [0, 1]; that is how ``plot_PCA`` calls it.

    Returns a list with one tuple per column of ``Vt``.
    """
    _, n_units = Vt.shape
    return [tuple(Vt[:, j]) + (alpha,) for j in range(n_units)]

def plot_PCA(Phi, filename='', qt=0.05, alpha=0.8):
    """Colour each column of ``Phi`` by its top-3 SVD scores and draw them as a swatch grid.

    The scores on the first three singular components become the R, G and B
    channels of one swatch per column of ``Phi``. Each channel is rescaled to
    [0, 1] between its ``qt`` and ``1 - qt`` quantiles via ``rescale``, and the
    swatches are drawn with ``plot_colortable``.

    Parameters
    ----------
    Phi : np.ndarray, shape (n_features, n_units)
        Matrix whose columns are the units to colour (RC here is 97 x 400).
    filename : str, default ''
        If non-empty, the figure is saved to this path with ``bbox_inches='tight'``.
    qt : float, default 0.05
        Lower quantile used to clip each colour channel.
    alpha : float, default 0.8
        Opacity of every swatch.

    Raises
    ------
    ValueError
        If ``Phi`` has fewer than 3 rows or columns, so three components cannot be formed.
    """
    n_components = 3  # one per RGB channel
    if min(Phi.shape) < n_components:
        raise ValueError(
            f"plot_PCA needs at least {n_components} rows and columns, got shape {Phi.shape}"
        )
    # NOTE(review): Phi is not mean-centred before the SVD, so these are scores of
    # an uncentred SVD rather than textbook PCA; confirm this is intended.
    U, S, _ = np.linalg.svd(Phi.T, full_matrices=False)
    # Same as U @ np.diag(S)[:, :3], without materialising the full diagonal matrix.
    principal_score = U[:, :n_components] * S[:n_components]  # (n_units, 3)
    colors = get_colors(rescale(principal_score.T, qt), alpha=alpha)
    fig = plot_colortable(colors, text_on=False)
    if filename:
        fig.savefig(filename, bbox_inches='tight')
    # Close exactly the figure created above, not whatever pyplot considers current.
    plt.close(fig)

# Render the RC colour map and write it to RC.pdf.
plot_PCA(RC, filename="RC.pdf")