### Import necessary modules

In [None]:
import numpy as np 
import pandas as pd
import matplotlib.pyplot as plt
import os
import string
import seaborn as sns
from scipy.spatial import distance
from scipy.cluster import hierarchy
import matplotlib.colors as cl
import glob
from cryodrgn import utils

### Read in data

Running the following cells produces a dataframe in which each row contains the occupancy data for one of the maps sampled from latent space, and each column corresponds to the occupancy of a given subunit across maps, normalized by the reference volume. Change the variables in the first box as necessary to reflect the path where your occupancies.csv file from the calc_occupancy.py script is located.

This notebook can also be used for datasets other than EMPIAR 10076; simply change the number of volumes as necessary, and redefine the chains dictionary. 

In [None]:
#change this to wherever your experimental map occupancy table is stored
#(this path is read with pd.read_csv in the next cell)
occupancies = #FILL ME IN

#change these as necessary for different numbers of volumes or different PDB models
#NOTE(review): num_volumes is not referenced by any other cell shown in this notebook
num_volumes = 500

#chains maps each PDB model name (the first header level of the occupancy table)
#to the ordered list of subunit names in that model. The position of a subunit in
#its list corresponds to its chain letter (first entry = chain a, second = chain b, ...);
#this is how table columns are renamed below and how chains are addressed in ChimeraX.
chains = {}
chains['prots1'] = ['L2', 'L3', 'L4', 'L5', 'L6', 'L9', 'L11', 'L13'] + ['L' + str(i) for i in range(14,26)] + ['L27', 'L28', 'L29', 'L30', 'L32', 'L33']
chains['prots2'] = ['L34', 'L35', 'L36']
chains['RNA1'] = ['H' + str(i) for i in range(1,15)] + ['H16'] + ['H' + str(i) for i in range(18,26)] + ['H25a', 'H26', 'H27']
chains['RNA2'] = ['H28', 'H29'] + ['H' + str(i) for i in range(31,36)] + ['H35a'] + ['H' + str(i) for i in range(36,47)] + ['H26a', 'H47', 'H48', 'H49', 'H49b', 'H50', 'H51']
chains['RNA3'] = ['H52', 'H53', 'H54', 'H55', 'H49a'] + ['H' + str(i) for i in range(56,70)] + ['H' + str(i) for i in range(71,78)] 
chains['RNA4'] = ['H' + str(i) for i in range(78,102)]
chains['RNA_5S'] = ['H' + str(i) + '_5S' for i in range(1,6)]

In [None]:
#load the occupancy table: the first column is the index and the two header rows
#give (PDB model name, chain letter) for every column
df = pd.read_csv(occupancies, index_col = 0, header = [0,1])
#discard every column that contains a missing value
df = df.dropna(axis = 1)

alpha_list = string.ascii_lowercase
df_rename = pd.DataFrame(index = df.index)

#relabel each (model, chain letter) column with the subunit name listed in chains
for pdb_name, chain in df.columns:
    chain_id = chains[pdb_name][alpha_list.index(chain)]
    df_rename[chain_id] = df[(pdb_name, chain)]

### Data normalization

Normalization methods may vary on a dataset-by-dataset basis. We implement here a method that effectively rescales all the occupancies so they fall between what we consider zero occupancy and what we consider full occupancy. These designations of zero and complete occupancy are determined as percentiles of the full set of values from df_rename. For this dataset, we recommend using the 10th and 90th percentiles as zero and full occupancies, respectively; again, the appropriate percentiles may differ for other datasets.

In [None]:
#change this to alter the percentile of the data to be set as zero occupancy
low_cutoff = 10

#change this to alter the percentile of the data to be set as full occupancy
high_cutoff = 90

#change this to indicate the directory where you want the output dataframe stored
#(norm_occs.txt is written into this directory)
outdir = './'#FILL ME IN

In [None]:
#percentiles are taken over every value in the table, not per subunit
vals = df_rename.values.flatten()
q_low, q_high = np.percentile(vals, [low_cutoff, high_cutoff])

#clamp to [q_low, q_high], then rescale so q_low maps to 0 and q_high maps to 1
df_rename = (df_rename.clip(lower = q_low, upper = q_high) - q_low) / (q_high - q_low)

In [None]:
#save the normalized occupancies as a tab-separated .txt file
#os.path.join keeps the path valid whether or not outdir ends with a path separator
df_rename.to_csv(os.path.join(outdir, 'norm_occs.txt'), sep='\t')

### Hierarchical clustering

Hierarchical clustering can be performed either directly in python with the code below, or in an external interactive clustering program using the tab-separated norm_occs.txt file saved above. The code below returns the nodes for the rows and columns separately; these outputs are parsed in the next sections to automatically extract volume classes and structural blocks, and to write scripts to visualize the blocks and volumes in ChimeraX. Adjusting the linkage method or distance metric may change the results of the clustering. Given a single clustering result, changing the row or column threshold adjusts the threshold distance at which classes are defined. Changing the default save_file variable to a path will save the resulting figure.

In [None]:
#define clustering choices
#figsize: figure size handed to plt.subplots
figsize= #FILL ME IN
#row_threshold / col_threshold: dendrogram distances at which the row (volume) and
#column (subunit) trees are cut into classes
row_threshold= #FILL ME IN
col_threshold= #FILL ME IN
#linkage method and distance metric used for the hierarchical clustering
linkage_methods = 'ward'
distance_metric = 'euclidean'
#colormaps: heatmap, row dendrogram classes, column dendrogram classes
cmap='Blues'
row_cmap = 'Spectral'
col_cmap = 'viridis'
#path to save the figure to; None skips saving
save_file = None

In [None]:
def plot_occupancy(data, linkage_method='ward', distance_metric='euclidean', figsize=(10,10), cmap='Blues', row_threshold=-1, col_threshold=-1, row_map = 'Spectral', col_map = 'viridis', save_file = None):
    """Cluster an occupancy table along both axes and draw it as a heatmap with dendrograms.

    Parameters
    ----------
    data : pandas.DataFrame
        Table to cluster; rows and columns are clustered independently.
    linkage_method : str
        Linkage method passed to scipy's hierarchy.linkage.
    distance_metric : str
        Metric passed to scipy's distance.pdist.
    figsize : tuple
        Figure size passed to plt.subplots.
    cmap : str
        Colormap of the heatmap.
    row_threshold, col_threshold : float
        Distances at which the row / column trees are cut into classes; also
        drawn as dashed lines on the dendrograms.
    row_map, col_map : str
        Colormaps from which the per-class dendrogram colors are sampled.
    save_file : str or None
        If truthy, the figure is saved to this path.

    Returns
    -------
    (row_nodes, col_nodes) : tuple of dict
        The dictionaries returned by hierarchy.dendrogram for rows and columns.
    """
    # Use the linkage_method argument (not a notebook-level global) so the
    # function honours what the caller asked for.
    col_linkage = hierarchy.linkage(distance.pdist(data.T, metric=distance_metric), method=linkage_method)
    row_linkage = hierarchy.linkage(distance.pdist(data, metric=distance_metric), method=linkage_method)

    # Number of flat clusters obtained when cutting each tree at its threshold.
    row_total = np.max(hierarchy.fcluster(row_linkage, t = row_threshold, criterion = 'distance'))
    col_total = np.max(hierarchy.fcluster(col_linkage, t = col_threshold, criterion = 'distance'))

    fig, axes = plt.subplots(2, 2, gridspec_kw={'width_ratios': [1, 4], 'height_ratios': [1,4]}, figsize=figsize)

    # One color per class, sampled evenly from each colormap.
    # plt.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    row_map = plt.get_cmap(row_map)
    col_map = plt.get_cmap(col_map)
    row_colors = [cl.rgb2hex(row_map(i/row_total)) for i in range(0, row_total)]
    # One color per *column* class (the count must come from col_total).
    col_colors = [cl.rgb2hex(col_map(i/col_total)) for i in range(0, col_total)]

    # Column dendrogram, drawn above the heatmap.
    hierarchy.set_link_color_palette(col_colors)
    col_nodes = hierarchy.dendrogram(col_linkage, color_threshold=col_threshold, ax = axes[0][1], get_leaves=True, labels=data.columns.tolist(), above_threshold_color = 'black')
    axes[0][1].axhline(col_threshold, ls='dashed', color='grey')
    axes[0][1].spines['top'].set_visible(False)
    axes[0][1].spines['right'].set_visible(False)
    axes[0][1].spines['bottom'].set_visible(False)

    # Row dendrogram, drawn to the left of the heatmap.
    hierarchy.set_link_color_palette(row_colors)
    row_nodes = hierarchy.dendrogram(row_linkage, color_threshold=row_threshold, ax = axes[1][0], orientation='left', get_leaves=True, labels=data.index.tolist(), above_threshold_color = 'black')
    axes[1][0].axvline(row_threshold, ls='dashed', color='grey')
    axes[1][0].spines['top'].set_visible(False)
    axes[1][0].spines['right'].set_visible(False)
    axes[1][0].spines['left'].set_visible(False)

    # Reorder columns and rows to follow the dendrogram leaf order.
    data_ordered = data[[data.columns[i] for i in col_nodes['leaves']]]
    data_ordered = data_ordered.reindex(data.index[row_nodes['leaves']])
    heatmap = axes[1][1].pcolor(data_ordered, cmap=cmap)
    axes[1][1].set_xticks([])
    axes[1][1].set_yticks([])

    # The top-left panel only hosts the colorbar.
    fig.colorbar(heatmap, ax=axes[0][0], fraction=0.75, label='occupancy')
    axes[0][0].set_xticks([])
    axes[0][0].set_yticks([])
    axes[0][0].axis('off')

    plt.tight_layout()
    if save_file:
        fig.savefig(save_file)

    return (row_nodes, col_nodes)

In [None]:
#cluster the full table; cmap is passed explicitly so the heatmap colormap chosen above takes effect
row_nodes, col_nodes = plot_occupancy(
    df_rename,
    linkage_method = linkage_methods,
    distance_metric = distance_metric,
    figsize = figsize,
    cmap = cmap,
    row_threshold = row_threshold,
    col_threshold = col_threshold,
    row_map = row_cmap,
    col_map = col_cmap,
    save_file = save_file,
)

### Extract classes from clustering

Both the volume and subunit classes defined in the clustering above can be extracted for manual inspection and for mapping into latent space and finding centroid volumes. Here we create two dictionaries with a common set of keys for each of row_nodes and col_nodes. The keys are the class/structural block IDs, and the values of the colors_dict and groups_dict dictionaries indicate the color (as a hexadecimal string) of each cluster in the above dendrogram, and the rows or columns that belong to that class, respectively.   

In [None]:
def extract_groups(nodes):
    """Group dendrogram leaves by their link color.

    nodes is a dictionary holding 'leaves_color_list' and 'ivl' (leaf colors and
    leaf labels, in the same order). Returns (colors_dict, groups_dict), both
    keyed by 0..n-1: colors_dict[k] is the color of group k and groups_dict[k]
    is an array of the labels carrying that color. Groups are numbered so that
    the color whose first leaf appears latest in the leaf order gets key 0.
    """
    leaf_colors = np.array(nodes['leaves_color_list'])
    leaf_labels = np.array(nodes['ivl'])

    # position of the first leaf of every distinct color
    _, first_seen = np.unique(leaf_colors, return_index = True)

    colors_dict = {}
    groups_dict = {}
    # walk the distinct colors from latest first appearance to earliest
    for key, leaf_pos in enumerate(sorted(first_seen, reverse = True)):
        color = nodes['leaves_color_list'][leaf_pos]
        colors_dict[key] = color
        groups_dict[key] = leaf_labels[np.where(leaf_colors == color)[0]]
    return colors_dict, groups_dict

In [None]:
#volume classes from the row dendrogram: class ID -> color, and class ID -> row labels in that class
vol_colors, vol_classes = extract_groups(row_nodes)

In [None]:
#structural blocks from the column dendrogram: block ID -> color, and block ID -> subunit names in that block
subunit_colors, subunit_blocks = extract_groups(col_nodes)

In [None]:
#save the volume classes with cryodrgn's save_pkl helper; the relative path resolves against the current working directory
utils.save_pkl(vol_classes, 'vol_classes.pkl')

### Visualize volume classes in ChimeraX

For each volume class defined above, this section writes out a .py script that can be opened in ChimeraX. Each script will open all the volumes from the given class so they can be manually compared. Note that the full path of the directory containing the 500 sampled maps must be provided in the vol_dir variable. 

In [None]:
#directory containing the sampled maps (opened as vol_000.mrc, vol_001.mrc, ...)
vol_dir = #FILL ME IN
#directory where the per-class ChimeraX scripts (class<ID>.py) are written; created if it does not exist
out_dir = #FILL ME IN

In [None]:
def check_dir(dirname, make = True):
    """Return dirname with a trailing '/', optionally creating the directory.

    Parameters
    ----------
    dirname : str
        Directory path, with or without a trailing '/'.
    make : bool
        When True, create the directory (including any missing parent
        directories) if it does not already exist.

    Returns
    -------
    str
        dirname, guaranteed to end with '/'.
    """
    if make:
        # makedirs handles nested paths and, with exist_ok, avoids the
        # check-then-create race of os.path.exists followed by os.mkdir
        os.makedirs(dirname, exist_ok = True)
    if not dirname.endswith('/'):
        dirname = dirname + '/'
    return dirname
        
def write_vol_classes(groups_dict, voldir, outdir):
    """Write one ChimeraX script per volume class.

    groups_dict maps a class ID to the integer volume numbers in that class.
    For every class, 'class<ID>.py' is written into outdir (created if needed);
    the script opens each member as '<voldir>vol_NNN.mrc', with the number
    zero-padded to three digits.
    """
    voldir = check_dir(voldir, make = False)
    outdir = check_dir(outdir)

    for class_id, members in groups_dict.items():
        script_path = outdir + 'class' + str(class_id) + '.py'
        with open(script_path, 'w') as f:
            f.write('from chimerax.core.commands import run\n')
            f.writelines(
                'run(session, "open {}vol_{:03d}.mrc")\n'.format(voldir, vol)
                for vol in members
            )
    return

In [None]:
#write one ChimeraX script per volume class into out_dir
write_vol_classes(vol_classes, vol_dir, out_dir)

### Visualize structural blocks in ChimeraX

We also write out a single script to open the aligned PDB files and color each chain according to the structural block assigned in clustering. This script can then be opened with ChimeraX. Note that the full path of the directory containing the aligned files must be provided in the aligned_dir variable, and that the only PDB files in that directory should be the aligned files you want to color. 

In [None]:
#directory containing the aligned PDB files, one '<model>.pdb' per key of the chains dictionary
aligned_dir = #FILL ME IN 
#path of the ChimeraX .py script to write
out_file = #FILL ME IN

In [None]:
def write_subunit_blocks(groups_dict, colors_dict, aligneddir, outfile):
    """Write a ChimeraX script that opens the aligned PDBs and colors chains by block.

    groups_dict maps a block ID to the subunit names in that block and
    colors_dict maps the same IDs to a color string. Every '*.pdb' file found in
    aligneddir is opened in glob order, so a model number is the 1-based
    position of its file in that listing. A subunit's chain letter is the
    uppercase letter at its position within its model's list in the
    notebook-level chains dictionary.
    """
    aligneddir = check_dir(aligneddir, make = False)
    pdb_list = glob.glob(aligneddir + '*.pdb')
    letters = string.ascii_uppercase

    command_list = []
    for block_id, members in groups_dict.items():
        specs = []
        for subunit in members:
            # model (key of chains) whose subunit list contains this subunit
            model = [name for name in chains if subunit in chains[name]][0]
            model_number = pdb_list.index(aligneddir + model + '.pdb') + 1
            chain_letter = letters[chains[model].index(subunit)]
            specs.append(f'#{model_number}/{chain_letter} ')
        # one color command per block: all of its chain specifiers, then the color
        command_list.append('run(session, "color ' + ''.join(specs) + f'{colors_dict[block_id]}")\n')

    with open(outfile, 'w') as f:
        f.write('from chimerax.core.commands import  run \n')
        f.writelines(f'run(session, "open {pdb}")\n' for pdb in pdb_list)
        f.writelines(command_list)
    return

In [None]:
#write the single ChimeraX script that colors every chain by its structural block
write_subunit_blocks(subunit_blocks, subunit_colors, aligned_dir, out_file)

### Plot individual subunit occupancy distributions

To view the occupancy distributions of a set of subunits, users can provide a list of subunits (or a whole subunit block from the subunit_blocks dictionary) to the plot_distribution function. To log-scale the y-axis of the figure, set log_plot = True. To overlay a dashed line indicating a particular subunit occupancy, provide a dictionary whose keys are the subunits being plotted and whose values are the desired values, e.g. threshold_dict['H68'] = 0.5

In [None]:
#list of subunit names (columns of df_rename) to plot, e.g. one entry of subunit_blocks
subunits = #FILL ME IN
#set to True to log-scale the y-axis of the histograms
log_plot = False

In [None]:
def plot_distribution(data, subs, color = 'steelblue', log = False, thresholds = None, outfile = None):
    """Plot a histogram of occupancies for each requested subunit, three panels per row.

    Parameters
    ----------
    data : pandas.DataFrame
        Occupancy table; each entry of subs must be a column of it.
    subs : sequence of str
        Subunits to plot, one panel each.
    color : str
        Bar color.
    log : bool
        Passed to hist; True log-scales the y-axis.
    thresholds : dict or None
        Optional mapping subunit -> occupancy value. A dashed vertical line is
        drawn at that value on the panel of every subunit present in the mapping.
    outfile : str or None
        If truthy, the figure is saved to this path at 300 dpi.

    Raises
    ------
    ValueError
        If subs is empty.
    """
    if len(subs) == 0:
        raise ValueError('subs must contain at least one subunit to plot')

    nrows = int(np.ceil(len(subs)/3))
    figsize = (8, 2*nrows)
    fig, ax = plt.subplots(nrows, 3, figsize = figsize, sharex = True, sharey = True)
    ax = ax.flatten()

    # 24 equal-width bins covering the range 0-1
    bins = np.linspace(0, 1, 25)
    for i, sub in enumerate(subs):
        ax[i].hist(data[sub], bins = bins, color = color, log = log)
        ax[i].set_title(sub)
        if thresholds and sub in thresholds:
            # axvline spans the full panel height, so it also works on a log-scaled y-axis
            ax[i].axvline(thresholds[sub], ls = '--', color = 'k')

    # remove the unused panels of the last row
    for unused in ax[len(subs):]:
        fig.delaxes(unused)

    fig.text(0.5, 0, 'occupancy')
    fig.text(0, 0.5, 'counts', rotation = 90)
    plt.tight_layout()

    if outfile:
        fig.savefig(outfile, dpi = 300)
    return

In [None]:
#honour the log_plot setting defined above instead of a hard-coded False
plot_distribution(df_rename, subunits, log = log_plot)

### Examine subunit-subunit correlations

Plotting subunit-subunit correlations can be informative for determining whether there is positive or negative cooperativity between any two given subunits. Here, users can provide a list of subunit pairs to compare, where each subunits[i] is a list of two subunits to be plotted against each other. Overlaid dashed lines can again be added by supplying a thresholds dictionary (as described above). 

In [None]:
#list of [subunit_x, subunit_y] pairs; each pair is drawn as one scatter panel
subunits = #FILL ME IN

In [None]:
def subunit_corr(data, subs, color = 'steelblue', thresholds = None, outfile = None):
    """Scatter the occupancies of pairs of subunits against each other, three panels per row.

    data is the occupancy table and subs a sequence of two-element sequences
    [x_subunit, y_subunit], one panel per pair. When thresholds (a mapping
    subunit -> value) is given, a dashed vertical line marks the x subunit's
    value and a dashed horizontal line the y subunit's value. A truthy outfile
    saves the figure to that path at 300 dpi.
    """
    n_pairs = len(subs)
    n_rows = (n_pairs + 2) // 3  # ceil(n_pairs / 3)
    fig, grid = plt.subplots(n_rows, 3, figsize = (8, 2*n_rows), sharex = True, sharey = True)
    panels = grid.flatten()

    for panel, pair in zip(panels, subs):
        x_sub, y_sub = pair[0], pair[1]
        panel.scatter(data[x_sub], data[y_sub], color = color, s = 10, alpha = 0.1)
        panel.set_xlim(-0.05, 1.05)
        panel.set_ylim(-0.05, 1.05)
        panel.set_xlabel(x_sub)
        panel.set_ylabel(y_sub)
        if thresholds:
            x_cut = thresholds[x_sub]
            y_cut = thresholds[y_sub]
            panel.plot([x_cut, x_cut], [0, 1], '--k')
            panel.plot([0, 1], [y_cut, y_cut], '--k')

    # drop the empty panels left over in the last row
    for spare in panels[n_pairs:]:
        fig.delaxes(spare)

    plt.tight_layout()
    if outfile:
        plt.savefig(outfile, dpi = 300)
    return

In [None]:
#scatter each requested pair of subunits against each other
subunit_corr(df_rename, subunits)

### Recluster subsets of the dataframe

It may be useful in some cases to extract and recluster some subset of the volumes that have high (or low) occupancy for a given subunit or subunits. That can be done here by providing a list of subunits to filter by (e.g. ['H68', 'H79']), a list of limits by which to filter each subunit (e.g. [0.5, 0.5]), and then a list of directions for the filtration provided as 'greater' or 'lesser', based on whether you want to retain volumes greater than or less than the limit. The resulting filtered dataframe is then reclustered and the resulting volume classes and structural blocks can be exported to ChimeraX as before.

In [None]:
#define the subset of the dataframe
#subunits: subunit names to filter on; limits: occupancy cutoff for each subunit;
#direction: 'greater' or 'lesser' for each subunit (keep volumes above, or at/below, its cutoff)
subunits = #FILL ME IN 
limits = #FILL ME IN
direction = #FILL ME IN 

#define clustering choices
#figure size and the row/column dendrogram cut distances for the reclustered subset
sub_figsize= #FILL ME IN
sub_row_threshold= #FILL ME IN
sub_col_threshold= #FILL ME IN
linkage_methods = 'ward'
distance_metric = 'euclidean'
cmap='Blues'
row_cmap = 'Spectral'
col_cmap = 'viridis'
#path to save the reclustered figure to
sub_save_file = #FILL ME IN

In [None]:
#keep only the volumes that satisfy every (subunit, limit, direction) filter
if not (len(subunits) == len(limits) == len(direction)):
    raise ValueError('subunits, limits and direction must all have the same length')

sub_df = df_rename
for sub, limit, how in zip(subunits, limits, direction):
    #every entry is validated, including the first, so a typo is never silently treated as 'lesser'
    if how == 'greater':
        sub_df = sub_df[sub_df[sub] > limit]
    elif how == 'lesser':
        sub_df = sub_df[sub_df[sub] <= limit]
    else:
        raise ValueError(f"direction for {sub} must be 'greater' or 'lesser', got {how!r}")

#an empty subset cannot be clustered
if len(sub_df) == 0:
    raise ValueError('no volumes satisfy the requested filters')

In [None]:
#recluster the filtered subset; cmap is passed explicitly so the heatmap colormap chosen above takes effect
sub_row_nodes, sub_col_nodes = plot_occupancy(
    sub_df,
    linkage_method = linkage_methods,
    distance_metric = distance_metric,
    figsize = sub_figsize,
    cmap = cmap,
    row_threshold = sub_row_threshold,
    col_threshold = sub_col_threshold,
    row_map = row_cmap,
    col_map = col_cmap,
    save_file = sub_save_file,
)

In [None]:
#volume classes (rows) and structural blocks (columns) of the reclustered subset
sub_vol_colors, sub_vol_classes = extract_groups(sub_row_nodes)
sub_subunit_colors, sub_subunit_blocks = extract_groups(sub_col_nodes)

In [None]:
#directory for the per-class ChimeraX scripts of the reclustered subset
sub_out_dir = #FILL ME IN
#path of the ChimeraX script that colors the structural blocks of the reclustered subset
sub_out_file = #FILL ME IN

#write the scripts, reusing vol_dir and aligned_dir defined in the earlier cells
write_vol_classes(sub_vol_classes, vol_dir, sub_out_dir)
write_subunit_blocks(sub_subunit_blocks, sub_subunit_colors, aligned_dir, sub_out_file)