# Interpreting weights of the output layer
The neural network has three layers and two sets of weights. We neglect the nonlinear activation functions, assuming the network operates in the linear regime. The input layer contains cytokines $C_i$, $i \in \{1, 2, 3, 4, 5 \}$. The intermediate layer contains our latent space representation, nodes $n_j$ (node 1, $n_1$, and node 2, $n_2$). The output layer contains un-normalized probabilities $p(k)$ for each ligand $k$ (N4, Q4, T4, V4, G4, E1).

The first weight matrix is the projection matrix $P_{ji}$, which projects the cytokine data to a 2D plane in cytokine space, and gives our latent space representation. We plot those weights here, to know the composition of each latent space variable.

Then, we look at the other set of weights, the $w_{kj}$, that give the probability distribution of the ligand identities as a function of the latent space:

$$ p(k) = \frac{ e^{\sum_j w_{kj} n_j }}{ \sum_{k'} e^{\sum_j w_{k'j} n_j}} $$

We now want to know which values of node 1 and node 2 correspond to a high probability for each ligand.

In [None]:
%matplotlib inline
# Some imports
import numpy as np
import scipy as sp
import seaborn as sns
import pandas as pd
import pickle, json

from matplotlib import pyplot as plt
import matplotlib as mpl
from matplotlib.colors import Normalize, LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
import matplotlib.colors as clr

# Paths used to locate the data files.
# NOTE: os.path.abspath('../') is resolved against the current working directory,
# so main_dir_path is the parent of whatever folder the notebook is executed from.
import os, sys
main_dir_path = os.path.abspath('../')
# Sub-path joined after main_dir_path when loading the trained-network files
datapath = os.path.join("data", "trained-networks")

## Aesthetics

In [None]:
def set_rcparams_1():
    """Reset seaborn styling and apply the small-panel matplotlib rcParams.

    Label sizes target the Science format (figure width 2.25 inches or
    4.75 inches); three subplots squeezed in a row give 4.75/3 = 1.583333
    inches per panel.
    """
    sns.reset_orig()
    small_panel_settings = {
        "figure.figsize": (1.55, 1.65),
        "font.size": 8,
        "axes.labelsize": 7,
        "legend.fontsize": 7,
        "xtick.labelsize": 6,
        "ytick.labelsize": 6,
        "xtick.major.pad": 2.,  # distance to major tick label in points
        "xtick.minor.pad": 2.,
        "axes.labelpad": 1.,
        "axes.linewidth": 0.8,
        # "axes.spines.top" and "axes.spines.right" are deliberately left untouched
        "figure.dpi": 250,  # the default was 75 on the author's machine
    }
    plt.rcParams.update(small_panel_settings)
    
def set_rcparams_2():
    """Apply the larger-panel matplotlib rcParams (Science parameters)."""
    science_settings = (
        ("figure.figsize", [2.5, 2.]),
        ("lines.linewidth", 2.),
        ("font.size", 8.),
        ("axes.labelsize", 8.),
        ("legend.fontsize", 8.),
        ("xtick.labelsize", 6.),
        ("ytick.labelsize", 6.),
    )
    for key, value in science_settings:
        plt.rcParams[key] = value

In [None]:
def build_colors(whichc="latent"):
    """Return the color palette used for one family of plot elements.

    Args:
        whichc (str): palette to build, one of "latent", "cytokines",
            "theoretical_5", "theoretical_6" or "peptides".

    Returns:
        A list of colors, except for "cytokines" where a dict mapping
        each cytokine name to its color is returned.

    Raises:
        ValueError: if whichc is not one of the recognized palette names.
    """
    if whichc == "latent":
        # Colors for (both nodes, node 1, node 2), as mutable RGBA lists
        base = [list(clr.to_rgba(name)) for name in ("crimson", "goldenrod", "maroon")]
        base[1] = sns.set_hls_values(color=base[1], h=None, l=0.6, s=None)  # lighter goldenrod
        base[0] = sns.set_hls_values(color=base[0], h=None, l=0.5, s=None)  # lighter crimson
        # Only the node 1 and node 2 colors are handed back
        return base[1:]

    if whichc == "cytokines":
        shades = sns.cubehelix_palette(5, start=.5, rot=-.75)  # blue-green colors
        # Based on MI order, from highest to lowest (highest MI is darkest)
        ordered_cytokines = ["IL-2", "IFNg", "IL-17A", "IL-6", "TNFa"][::-1]
        return {cyt: shades[k] for k, cyt in enumerate(ordered_cytokines)}

    if whichc == "theoretical_5":
        muted = [sns.set_hls_values(c, s=0.4, l=0.6)
                 for c in sns.color_palette("deep", 5)]
        muted[-1] = (0, 0, 0, 1)  # Make the null peptide black
        return muted

    if whichc == "theoretical_6":
        deep10 = sns.color_palette("deep", 10)
        # The old color scheme was simply sns.color_palette("deep", 6);
        # the current one swaps in the seventh "deep" color at position 2.
        picked = [deep10[0], deep10[6]] + deep10[1:5]
        muted = [sns.set_hls_values(c, s=0.4, l=0.6) for c in picked]
        muted[-1] = (0, 0, 0, 1)  # Make the null peptide black.
        return muted

    if whichc == "peptides":
        ten_colors = sns.color_palette(sns.color_palette(), 10)
        # First four colors, then skip index 4 and take index 5
        return ten_colors[:4] + [ten_colors[5]]

    raise ValueError("Unrecognized palette type whichc: {}".format(whichc))


In [None]:
def colorbar(mappable):
    """Attach a colorbar to the right of the axes that holds mappable.

    Copied from https://joseph-long.com/writing/colorbars/

    Args:
        mappable: the plotted artist (must expose an .axes attribute)
            for which the colorbar is drawn.

    Returns:
        The colorbar returned by the figure's colorbar() call.
    """
    host_ax = mappable.axes
    host_fig = host_ax.figure
    # Carve a strip 5% as wide as the host axes out of its right-hand side
    bar_ax = make_axes_locatable(host_ax).append_axes("right", size="5%", pad=0.1)
    return host_fig.colorbar(mappable, cax=bar_ax)

# Input weights bar plot

In [None]:
def input_weights_plot(wfile=os.path.join(main_dir_path, datapath, "mlp_input_weights-thomasRecommendedTraining.npy"), 
                       minmaxfile=os.path.join(main_dir_path, datapath, "min_max-thomasRecommendedTraining.pkl"), 
                       dosave=True):
    """Bar plot of the input-layer weights of each cytokine onto each latent node.

    Args:
        wfile (str): path of the .npy file holding the input weight matrix.
            It is transposed after loading, and its columns are then indexed
            by cytokine.
        minmaxfile (str): path of a pickle file holding a (df_min, df_max)
            pair; the "Cytokine" index level of df_min gives the order of
            the cytokines in the weight matrix. Pickle files can execute
            arbitrary code on load, so only pass trusted files.
        dosave (bool): if True, save the figure as a PDF in
            figures/latentspace/ under main_dir_path.
    """
    # Latent space colors (node 1, node 2)
    nodePalette = build_colors(whichc="latent")

    # Cytokines colors
    cyt_palette = build_colors(whichc="cytokines")
    nice_cyto_labels = {"IL-2":"IL-2", "IFNg":r"IFN-$\gamma$", "IL-17A":"IL-17A", "IL-6":"IL-6", "TNFa":"TNF"}

    input_weights = np.load(wfile).T
    # Order of the cytokines in the columns of this matrix, from left to right
    with open(minmaxfile, "rb") as handle:
        df_min, df_max = pickle.load(handle)
    cyto_weights_order = df_min.index.get_level_values("Cytokine").to_list()

    fig, ax = plt.subplots()
    fig.set_size_inches(2.25, 1.75)  # A bit larger since supplementary
    ax.axhline(0, color=(0.5, 0.5, 0.5, 1), linewidth=0.5, zorder=1)
    # Make two sub-histograms, one group of bars per latent node
    x_increms = np.linspace(-0.25, 0.25, len(cyto_weights_order))
    bwidth = x_increms[1] - x_increms[0]
    x_locs = np.asarray([0, 1])

    # Plot the two histograms: each call draws one cytokine's bar in both groups,
    # the bar edges taking the color of the corresponding latent node
    cyto_plot_order = ["IFNg", "IL-2", "IL-17A", "IL-6", "TNFa"]
    for i, cyt in enumerate(cyto_plot_order):
        widx = cyto_weights_order.index(cyt)
        ax.bar(x_locs+x_increms[i], width=bwidth, height=input_weights[:, widx],
                color=cyt_palette[cyt], edgecolor=nodePalette,
                label=nice_cyto_labels[cyt], zorder=3)

    # Labeling, etc.
    ax.set_ylabel("Weight to LS (a. u.)")
    #ax.legend(ncol=len(cyto_weights_order))
    ax.set_xticks(np.concatenate([x_locs[0]+x_increms, x_locs[1]+x_increms]))
    ax.set_xticklabels(cyto_plot_order*2, rotation=90)

    # Annotate LS1 and LS2: leave 40 % headroom above the bars for the banners
    ylims = ax.get_ylim()
    ax.set_ylim(ylims[0], ylims[1]*1.4)
    for i in range(2):
        # Rectangle, then annotate on top
        boxwidth = x_increms[-1]-x_increms[0] + bwidth
        xy_bottomleft = (x_locs[i]+x_increms[0]-bwidth/2, ylims[1]*1.1)
        xy_center = (xy_bottomleft[0]+boxwidth/2, xy_bottomleft[1]+ylims[1]*0.1)
        ax.add_artist(mpl.patches.Rectangle(xy_bottomleft,
            width=boxwidth, height=ylims[1]*0.2, color=nodePalette[i]))
        ax.annotate(r"$LS_{}$".format(i+1), xy=xy_center, xytext=xy_center,
                    ha="center", va="center", xycoords="data", color="w")

    fig.tight_layout()
    if dosave:
        fig.savefig(os.path.join(main_dir_path, "figures", "latentspace", "latentspace_weights_barplot.pdf"),
                    transparent=True, bbox_inches="tight")
    plt.show()
    plt.close()

In [None]:
# Apply the small-panel rcParams, then draw the input-weights bar plot
# (shown only; not written to disk because dosave=False).
set_rcparams_1()
input_weights_plot(
    wfile=os.path.join(main_dir_path, datapath, "mlp_input_weights-thomasRecommendedTraining.npy"), 
    minmaxfile=os.path.join(main_dir_path, datapath, "min_max-thomasRecommendedTraining.pkl"), 
    dosave=False  # change to True if want to save plot
)

# Output weights line plot

In [None]:
def output_weights_plot(wfile=os.path.join(main_dir_path, datapath, "mlp_output_weights-thomasRecommendedTraining.npy"),
                        potencyfile=os.path.join(main_dir_path, "data", "potencies_df_2021.json"),
                        idealec50file=os.path.join(main_dir_path, "results", "chancap", "antigen_classes_log10ec50s_HighMI_13.json"),
                        dosave=False, n_categories=6):
    """Line plot of the output-layer weights of both latent nodes versus antigen EC50.

    The weights of the six peptides (N4, Q4, T4, V4, G4, E1) are plotted
    against their mean EC50 on a log axis, and the EC50 of each theoretical
    antigen class is marked with a dashed vertical line.

    Args:
        wfile (str): path of the .npy file with the output weights. Its columns
            are reversed after loading so that N4 comes first; it is then used
            as a 2 x 6 table (latent node x peptide).
        potencyfile (str): path of a JSON table readable by pandas; the log10
            of its values is averaged along axis 1 and indexed by peptide name.
        idealec50file (str): path of a JSON list of log10 EC50 values, one per
            theoretical antigen class.
        dosave (bool): if True, save the figure as a PDF in
            figures/latentspace/ under main_dir_path.
        n_categories (int): selects the "theoretical_<n_categories>" palette
            of build_colors for the vertical lines.
    """
    # Raw string: the LaTeX backslashes (\l, \m) are not valid Python escapes
    ec50_label = r"$\log_{10} \mathrm{EC_{50}}$"

    # Latent space labels
    lslabels = [r"$LS_1$", r"$LS_2$"]
    # Colors for Node 1, Node 2
    nodePalette = build_colors(whichc="latent")
    # For theoretical peptides
    theoretical_antigen_colors = build_colors(whichc="theoretical_{}".format(n_categories))

    # To recover weights in the expected order starting with N4, take the mlp weights in reverse order
    # because the code used to train the neural networks stores them in the opposite order.
    outputWeights = np.load(wfile)[:, ::-1]

    # Prepare EC50 and weights in a dataframe
    log10ec50_table = np.log10(pd.read_json(potencyfile)).mean(axis=1)
    peptides = ['N4','Q4','T4','V4','G4','E1']
    log10ec50s = log10ec50_table[peptides]
    df = pd.DataFrame(outputWeights,
                columns=pd.MultiIndex.from_tuples(zip(peptides, log10ec50s)),
                index=['1','2'])
    df.columns.names = ['Antigen', ec50_label]
    df.index.name = 'Latent Space'
    # One row per antigen (sorted by log10 EC50), one column per latent node
    newDf = (df.stack(['Antigen', ec50_label])
                .unstack('Latent Space')
                .sort_values(by=ec50_label))
    #newDf.iloc[:,1] = newDf.iloc[:,1]*-1
    plottingDf = newDf.stack().to_frame('Weight').reset_index()

    # Linear interpolation to find the weights of theoretical peptides
    with open(idealec50file, 'r') as handle:
        ideal_log10ec50s = json.load(handle)
    idealPeptides = [10**x for x in ideal_log10ec50s]
    # NOTE(review): the interpolated weights built below are not used by the
    # plot. The first class is pinned to N4 and class index 4 to E1; any index
    # above 4 falls through and reuses the peptide pair of the previous
    # iteration -- confirm this is intended when there are six classes.
    newWeightMatrix = []
    for i,idealPeptide in enumerate(idealPeptides):
        logIdealPeptide = np.log10(idealPeptide)
        if i == 0:
            y_interpolation = newDf.query("Antigen == 'N4'").values.tolist()[0]
        elif i == 4:
            y_interpolation = newDf.query("Antigen == 'E1'").values.tolist()[0]
        else:
            if i == 1:
                lowPeptide,highPeptide = 'T4','Q4'
            elif i == 2:
                lowPeptide,highPeptide = 'V4','T4'
            elif i == 3:
                lowPeptide,highPeptide = 'G4','V4'
            else:
                pass
            y_low = newDf.query("Antigen == @lowPeptide").values
            y_high = newDf.query("Antigen == @highPeptide").values
            lowPeptideVal = plottingDf.query("Antigen == @lowPeptide")[ec50_label].iloc[0]
            highPeptideVal = plottingDf.query("Antigen == @highPeptide")[ec50_label].iloc[0]
            # Fractional position of the ideal EC50 between the two bracketing peptides
            x_interpolation = (logIdealPeptide - lowPeptideVal)/(highPeptideVal-lowPeptideVal)
            y_interpolation = x_interpolation * (y_high - y_low) + y_low
            y_interpolation = np.squeeze(y_interpolation).tolist()
        newWeightMatrix.append(y_interpolation)

    interpolated_weights_df = pd.DataFrame(np.asarray(newWeightMatrix),
                        index=[str(x) for x in range(len(idealPeptides))],columns=newDf.columns)
    interpolated_weights_df.index.name = 'Antigen'

    # Plot weights of actual peptides + theoretical peptides
    plottingDf['EC50'] = [10**x for x in plottingDf[ec50_label]]

    # Plotting
    fig, axis = plt.subplots()
    fig.set_size_inches(2.25, 1.75)
    markers = ["o", "s"]
    for i in range(2):
        idx = plottingDf["Latent Space"] == str(i+1)
        axis.plot(plottingDf['EC50'].loc[idx], plottingDf["Weight"].loc[idx], lw=2.5,
                marker=markers[i], color=nodePalette[i], label=lslabels[i], ms=5, zorder=10+i)

    for i,idealPeptide in enumerate(idealPeptides):
        axis.axvline(x=idealPeptide,color=theoretical_antigen_colors[i],linestyle='--', lw=2., zorder=i)

    axis.legend(handlelength=2, loc="upper left")
    axis.set_xscale("log")
    axis.set_ylabel("Output layer weight (a. u.)")
    locmin = mpl.ticker.LogLocator(base=10.0,
            subs=np.linspace(0.1,0.9,num=9,endpoint=True).tolist(),numticks=50)
    axis.xaxis.set_minor_locator(locmin)
    axis.xaxis.set_minor_formatter(mpl.ticker.NullFormatter())
    xticks = [10**5,10**4,10**3,10**2,10**1,10**0]
    xticklabels = ['10$^{'+str(int(np.log10(x)))+'}$' for x in xticks]
    axis.set_xticks(xticks)
    axis.set_xticklabels(xticklabels)
    # Strongest antigens (lowest EC50) end up on the right
    axis.invert_xaxis()
    axis.set_xlabel('Antigen EC$_{50}$ (#)')
    fig.tight_layout()
    if dosave:
        fig.savefig(os.path.join(main_dir_path, "figures", "latentspace", "supp_panel_nodeInterpretations-weights.pdf"),
            transparent=True, bbox_inches='tight')
    plt.show()
    plt.close()

In [None]:
# Draw the output-weights line plot with the six-class theoretical palette
# (shown only; not written to disk because dosave=False).
output_weights_plot(wfile=os.path.join(main_dir_path, datapath, "mlp_output_weights-thomasRecommendedTraining.npy"),
                        potencyfile=os.path.join(main_dir_path, "data", "potencies_df_2021.json"),
                        idealec50file=os.path.join(main_dir_path, "results", "chancap", "antigen_classes_log10ec50s_HighMI_13.json"),
                        dosave=False, n_categories=6)

# Latent space domains defined by output weights
Two representations: the actual (W1, W2) vectors in latent space, and the different domains defined in latent space by those vectors and the logistic function. 

### Define useful functions

In [None]:
# Switch to the larger-panel rcParams for the figures that follow
set_rcparams_2()

In [None]:
def compute_probs(wqj, bq, n1, n2, beta=1):
    """ Return an array of the probability of each peptide 
    at each (n1, n2) point specified, for the weights matrix wqj. 

    The probabilities are the softmax over peptides q of
    beta * sum_j w_{qj} n_j + b_q (beta scales the weighted sum only, not
    the offsets). The maximum logit is subtracted at every grid point
    before exponentiating, which leaves the result unchanged but avoids
    overflow (inf/inf = nan) for large weights or grid values.

    Args:
        wqj (np.2darray): weight matrix w_{qj}, size (nb_peptides, nb_latent_dimensions)
        bq (np.1darray): offsets, size (nb_peptides,)
        n1 (np.ndarray): grid of node 1 values at which to evaluate the probability
        n2 (np.ndarray): grid of node 2 values. Should have the same shape as n1. 
        beta (float): inverse temperature, i.e. sharpness of the boundaries

    Returns:
        pq (np.ndarray): an array of probabilities values at each point of the grid. The
            first dimension is index q, the peptide to which the prob values correspond. 
            The other dimensions are those of the n1, n2 grid. 

    Raises:
        ValueError: if wqj is not 2d with two columns, if bq does not have
            one entry per row of wqj, or if n1 and n2 differ in shape.
    """
    # Accept lists and scalars as well as arrays
    wqj = np.asarray(wqj)
    bq = np.asarray(bq)
    n1 = np.asarray(n1)
    n2 = np.asarray(n2)

    # Check dimensions
    if wqj.ndim != 2:
        raise ValueError("The weight matrix wqj should be 2d, now it is: \n {}".format(wqj))
    elif wqj.shape[1] != 2:
        raise ValueError("wqj must have two columns, one per dimension. Now shape = {}".format(wqj.shape))
    if bq.shape[0] != wqj.shape[0]:
        raise ValueError("bq has shape {}, it should have length of wqj's first axis, {}".format(
            bq.shape[0], wqj.shape[0]))
    if n1.shape != n2.shape:
        raise ValueError("n1 and n2 dimensions don't match. ")

    nb_peptides = wqj.shape[0]
    if nb_peptides == 0:
        # Nothing to normalize over; keep the (0, *grid) output shape
        return np.zeros((0,) + n1.shape)

    # Stack the two node grids on a leading axis of length 2, then contract
    # that axis with the columns of wqj: result has shape (nb_peptides, *grid).
    n12 = np.stack([n1, n2], axis=0)
    logits = beta * np.tensordot(wqj, n12, axes=(1, 0))
    # Broadcast one offset per peptide over all the grid dimensions
    logits = logits + bq.reshape((-1,) + (1,) * n1.ndim)
    logits = np.asarray(logits, dtype=float)

    # Softmax over the peptide axis, shifted by the max logit for stability
    logits = logits - np.max(logits, axis=0, keepdims=True)
    pq = np.exp(logits)
    pq = pq / np.sum(pq, axis=0, keepdims=True)

    return pq

In [None]:
# Load the output-layer weights, transpose them and reverse the rows so that
# each row holds the two latent-node weights of one peptide, N4 first.
out_weights_matrix = np.load(os.path.join(main_dir_path, datapath, "mlp_output_weights-thomasRecommendedTraining.npy"))
out_weights_matrix = out_weights_matrix.T[::-1]
print(out_weights_matrix)
# Output-layer intercepts, one per peptide
offsets = np.load(os.path.join(main_dir_path, datapath, "mlp_intercepts_output-thomasRecommendedTraining.npy"))
offsets = offsets[::-1]  # Also need to reverse here to have N4 first
print(offsets)

In [None]:
# Plotting the actual vectors, to get a first sense of where those vectors are
peptides = ["N4", "Q4", "T4", "V4", "G4", "E1"]
pep_cmap = sns.color_palette(n_colors=len(peptides))

fig, ax = plt.subplots()
for i in range(len(peptides)):
    # Near-invisible marker at the tip of the vector, so that the axes limits include it
    ax.plot(out_weights_matrix[i, 0], out_weights_matrix[i, 1], marker="o", ms=0.1, color=pep_cmap[i])
    # Arrow between the origin and the weight vector; the peptide name is written at the vector's tip
    ax.annotate(peptides[i], xytext=out_weights_matrix[i], xy=(0, 0), xycoords="data",
                arrowprops=dict(color=pep_cmap[i], linewidth=2.5,arrowstyle="<-"))
ax.set(xlabel="Node 1 (a. u.)", ylabel="Node 2 (a. u.)")
# Arbitrary units: hide the ticks and the top/right spines
ax.set_xticks([])
ax.set_yticks([])
for i in ["top", "right"]:
    ax.spines[i].set_visible(False)
ax.set_aspect("equal")
plt.show()
plt.close()

## Probability domain of each antigen in latent space

In [None]:
# It gives what it should, so we can use it with the actual weights and compute the probabilities
# for all peptides at each point in latent space. 
npts = 81
interv = 20
# NOTE(review): dx is not used anywhere in this cell
dx = interv * 2 / (npts-1)
# The second axis is reversed so that the first image row holds the largest node 2 value
node1, node2 = np.meshgrid(np.linspace(-interv, interv, npts), np.linspace(-interv, interv, npts)[::-1])
# Stretch and shift the region of latent space that is sampled
node1 *= 1.5
node1 += interv/4
node2 -= interv/4
# node 2 increases upwards

# Compute probabilities on the grid
probs = compute_probs(out_weights_matrix, offsets, node1,  node2, beta=0.5)

# Single plot, different colors for peptides. 
fig, ax = plt.subplots()
fig.set_size_inches(2.25, 1.75)
pep_cmap = sns.color_palette(n_colors=probs.shape[0])

# One RGBA layer per peptide: constant peptide color, alpha equal to its probability
for q in range(probs.shape[0]):
    probq = probs[q]
    rgba_matrix = np.zeros([probq.shape[0], probq.shape[1], 4])
    rgba_matrix[:, :, :3] = np.asarray(pep_cmap[q])[None, None, :]
    rgba_matrix[:, :, 3] = probq
    # NOTE(review): extent is (-interv, interv) on both axes and does not reflect the
    # stretch/shift applied to node1 and node2 above; the ticks are hidden below
    ax.imshow(rgba_matrix, aspect="equal", extent=(-interv, interv, -interv, interv))

ax.set_xlabel(r"$LS_1$")
#ax.set_xticks(range(0, npts, 2))
ax.set_xticks([])
ax.set_xticklabels([])
ax.set_yticks([])
ax.set_ylabel(r"$LS_2$")
ax.tick_params(which="both", length=2., width=0.8)
#ax.set_yticks(range(len(lig_labels)))
ax.set_yticklabels([])

# Add a custom legend
nn_antigens = ["N4", "Q4", "T4", "V4", "G4", "E1"]
handles = [mpl.patches.Patch(color=pep_cmap[i], label=nn_antigens[i]) for i in range(len(nn_antigens))]
leg = ax.legend(handles=handles, labels=nn_antigens, loc="upper left", bbox_to_anchor=(1, 1), 
                 framealpha=0.7, ncol=1, labelspacing=0.4, handletextpad=0.3, columnspacing=1.2, 
                 handlelength=1.5, frameon=False, title="Antigen")
fig.tight_layout()

# Uncomment to save the figure
#fig.savefig(os.path.join(main_dir_path, "figures", "latentspace", "output_layer_zones_per_peptide.pdf"), 
#    bbox_inches="tight", bbox_extra_artists=(leg,), transparent=True)
plt.show()
plt.close()

# Latent space domains for theoretical antigen classes

In [None]:
# Load weights again to avoid potential mix ups
outputWeights = np.load(os.path.join(main_dir_path, datapath, "mlp_output_weights-thomasRecommendedTraining.npy"))
outputWeights = outputWeights[:, ::-1]  # reverse the columns so that N4 comes first

# Raw string: the LaTeX backslashes (\l, \m) are not valid Python escapes
longEC50name = r"$\log_{10} \mathrm{EC_{50}}$"

log10ec50_table = np.log10(pd.read_json(os.path.join(main_dir_path, "data", "potencies_df_2021.json"))).mean(axis=1)
peptides = ['N4', 'Q4', 'T4', 'V4', 'G4', 'E1']
log10ec50s = log10ec50_table[peptides]
df = pd.DataFrame(outputWeights,columns=pd.MultiIndex.from_tuples(zip(peptides, log10ec50s)), index=['1','2'])
df.columns.names = ['Antigen', longEC50name]
df.index.name = 'Latent Space'

# From here on, log10ec50s holds the log10 EC50s of the theoretical antigen classes
with open(os.path.join(main_dir_path, "results", "chancap", "antigen_classes_log10ec50s_HighMI_13.json"), 'r') as handle:
    log10ec50s = json.load(handle)
idealPeptides = [10**x for x in log10ec50s]

# One row per (antigen, log10 EC50, latent node), sorted by log10 EC50
newDf = df.stack(['Antigen', longEC50name]).unstack('Latent Space').sort_values(by=longEC50name)
fittingDf = newDf.stack().to_frame('Weight')
# The level renamed 'EC50' still holds log10 EC50 values
fittingDf.index.names = ['EC50' if x == longEC50name else x for x in fittingDf.index.names]

## WEIGHT 1
fittingDf1 = fittingDf.query("`Latent Space` == '1'").reset_index()

# x values shared by both nodes: log10 EC50 of the peptides, in increasing order.
# This array must NOT be overwritten below: the interpolation cells use it as
# the abscissa for weights1, weights2 and the intercepts.
ec50s = fittingDf1['EC50'].values
idsort = np.argsort(ec50s)
ec50s = ec50s[idsort]
# y values for node 1, in the same order as ec50s
weights1 = fittingDf1['Weight'].values
weights1 = weights1[idsort]

### WEIGHT 2
fittingDf2 = fittingDf.query("`Latent Space` == '2'").reset_index()
# y values for node 2, sorted by log10 EC50 so that they line up with ec50s.
# (A previous version replaced ec50s here by |EC50 - max(EC50)| re-sorted in
# increasing order, without reordering the weights; the interpolations that
# follow then paired each weight with the wrong x value.)
idsort = np.argsort(fittingDf2['EC50'].values)
weights2 = fittingDf2['Weight'].values
weights2 = weights2[idsort]

In [None]:
# Linear interpolation for weights seems to work better? Not really. 
# ec50s must hold the peptides' log10 EC50 values in increasing order, with
# weights1 and weights2 in that same order; interp1d raises a ValueError if
# a theoretical class lies outside the range spanned by ec50s.
ideal_log10ec50s = np.log10(idealPeptides)
w1_interp = sp.interpolate.interp1d(ec50s, weights1, kind="linear")
w2_interp = sp.interpolate.interp1d(ec50s, weights2, kind="linear")
# One row per theoretical antigen class, one column per latent node
weights_matrix = np.asarray([w1_interp(ideal_log10ec50s),
                             w2_interp(ideal_log10ec50s)]).T
print(ideal_log10ec50s)
print(weights_matrix.T)

In [None]:
# Interpolation for intercepts
offsets_peps = np.load(os.path.join(main_dir_path, datapath, "mlp_intercepts_output-thomasRecommendedTraining.npy"))
offsets_peps = offsets_peps[::-1]  # Need to invert order here too
interp_offsets = sp.interpolate.interp1d(ec50s, offsets_peps, kind="linear")
offsets_theory = interp_offsets(log10ec50s)
# NOTE(review): the interpolated intercepts computed on the previous line are discarded
# here; the theoretical classes end up with all-zero offsets. Confirm this is intended.
offsets_theory = np.zeros(len(log10ec50s))

In [None]:
# Single plot, different colors for peptides.
# You can show the actual weight vector arrows by changing the following option to True
add_arrows = False

npts = 81
interv = 20
# The second axis is reversed so that the first image row holds the largest node 2 value
node1, node2 = np.meshgrid(np.linspace(-interv, interv, npts), np.linspace(-interv, interv, npts)[::-1])
node1 *= 1.5

# Without arrows the sampled region is shifted; with arrows it stays centered
# on the origin, which is where the arrows start.
if not add_arrows:
    orig_offsets = np.asarray((interv/4, -interv/4))
else:
    orig_offsets = np.zeros(2)
node1 += orig_offsets[0]
node2 += orig_offsets[1]

probs = compute_probs(weights_matrix, offsets_theory, node1,  node2, beta=0.5)
theoretical_antigen_colors = build_colors(whichc="theoretical_6")

fig, ax = plt.subplots()
fig.set_size_inches(2.25, 1.75)
pep_cmap = theoretical_antigen_colors
pep_cmap[-1] = (0, 0, 0)  # RGB only: the alpha channel is filled with the probabilities below

# One RGBA layer per class, drawn from the last class to the first:
# constant class color, alpha equal to the class probability
for q in range(probs.shape[0]-1, -1, -1):
    probq = probs[q]
    rgba_matrix = np.zeros([probq.shape[0], probq.shape[1], 4])
    rgba_matrix[:, :, :3] = np.asarray(pep_cmap[q])[None, None, :]
    rgba_matrix[:, :, 3] = probq
    # NOTE(review): extent is (-interv, interv) on both axes and does not reflect the
    # stretch/shift applied to node1 and node2 above; the ticks are hidden below
    ax.imshow(rgba_matrix, aspect="equal", 
              extent=(-interv, interv, -interv, interv))

    # Also show the weight vector itself
    if add_arrows:
        scl = 1.5
        ax.plot(weights_matrix[q, 0]*scl, weights_matrix[q, 1]*scl, marker="o", ms=0.1, color=(1, 1, 1, 0))
        ax.annotate("", xytext=weights_matrix[q]*scl, xy=(0, 0), xycoords="data",
                arrowprops=dict(edgecolor="w", facecolor=pep_cmap[q], linewidth=2.5,arrowstyle="<|-"))

ax.set_xlabel(r"$LS_1$ (a. u.)")
#ax.set_xticks(range(0, npts, 2))
ax.set_xticklabels([])
ax.set_ylabel(r"$LS_2$ (a. u.)")
#ax.tick_params(which="both", length=2., width=0.8)
ax.set_yticks([])
ax.set_xticks([])
ax.set_yticklabels([])

# Add a custom legend: each EC50 written as (one-digit mantissa) x 10^(decade)
# floor (rather than int truncation) gives the right decade for EC50 < 1 as well
decadesPeptides = [int(np.floor(np.log10(a))) for a in idealPeptides]
mantissasPeptides = [int(round(a/10**d, 0)) for a, d in zip(idealPeptides, decadesPeptides)]
for i, mantissa in enumerate(mantissasPeptides):
    # A mantissa that rounds up to 10 is carried over to the next decade
    if mantissa == 10:
        mantissasPeptides[i] = 1
        decadesPeptides[i] += 1
# Braces around the exponent so that negative or multi-digit exponents are fully superscripted
labelsPeptides = [r"${}\times 10^{{{}}}$".format(m, d) for m, d in zip(mantissasPeptides, decadesPeptides)]
handles = [mpl.patches.Patch(color=pep_cmap[i], label=label) for i, label in enumerate(labelsPeptides)]
leg = ax.legend(handles=handles, labels=labelsPeptides, loc="upper left", framealpha=0., 
          fontsize=7, labelspacing=0.2, handletextpad=0.3, columnspacing=1.2, 
          handlelength=1.5, bbox_to_anchor=(1, 1.07), title="Theoretical\nAntigen\n" + r"EC${}_{50}$ (#)", 
          title_fontsize=8
          )
plt.setp(leg.get_title(), multialignment='center')

fig.tight_layout()

# Uncomment to save the figure
#if add_arrows:
#    figname = "output_layer_zones_per_theoretical_antigen_witharrows_HighMI_13.pdf"
#else:
#    figname = "output_layer_zones_per_theoretical_antigen_noarrows_HighMI_13.pdf"
#fig.savefig(os.path.join(main_dir_path, "figures", "latentspace", figname), bbox_inches="tight", transparent=True)
plt.show()
plt.close()