# Filament Tools Suite (FiTSuite) - typecluster v0: Empty Template

Last modified: 6/18/23, DL

You should be able to run all cells as-is and then add your use case at the end. Parameters and inputs are detailed below.

We suggest making a copy of this Jupyter notebook for each dataset to be processed, so that the full record of processing is stored.

## Setup

### *General instructions*
The following is a simple way to set up a virtual environment with venv if you don't know how to:
- Run the following commands in a new terminal, then shut down your current Jupyter notebook and server
    - `python3 -m venv FiTSuite`
    - `source FiTSuite/bin/activate.csh` (csh/tcsh) or `source FiTSuite/bin/activate` (bash/zsh)
    - `python3 -m pip install ipykernel`
    - `python3 -m ipykernel install --user --name=FiTSuite`
- Re-run the command `jupyter notebook` or `python3 -m notebook` or whatever you normally do and open this notebook
    - Currently incompatible with JupyterLab 4 due to plotly issues; use JupyterLab 3, e.g. `pip install jupyterlab==3`
- Upon reopening, choose FiTSuite as Kernel by going to `Kernel > Change kernel > FiTSuite`

When done running the Jupyter notebook and shutting down the server, you can deactivate the virtual environment with `deactivate` in the second terminal

After you have done this once, next time just run `jupyter notebook` to get back here, and change kernel to `FiTSuite` again

### *LMB instructions*
The following is a simple way to set up a virtual environment with venv if you don't know how to:
- Run the following commands in a new terminal from your LMB home directory:
    - `module avail python`
        - Should see something like "python3/3.10.7" or higher
    - `module load python3/3.10.7` 
        - or whatever the latest version of python is
    - `python3 -m venv FiTSuite`
    - `source FiTSuite/bin/activate.csh` (csh/tcsh) or `source FiTSuite/bin/activate` (bash/zsh)
    - `python3 -m pip install jupyter notebook` (could also be pip3)
        - Currently incompatible with JupyterLab 4 due to plotly issues; use JupyterLab 3, e.g. `pip install jupyterlab==3`
    - `python3 -m ipykernel install --user --name=FiTSuite`
- Run the command `jupyter notebook` or `python3 -m notebook` or whatever you normally do and open this notebook
- Upon reopening, choose FiTSuite as Kernel by going to `Kernel > Change kernel > FiTSuite`

When done running the Jupyter notebook and shutting down the server, you can deactivate the virtual environment with `deactivate`

After you have done this once, next time just reactivate the virtual environment, run `jupyter notebook` to get back here, and make sure kernel is still set to `FiTSuite`

Alternatively, you can source the environment from my home directory: `source lmb/home/dli/FiTSuite/bin/activate.csh`. This is likely the easiest option.
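
As a quick sanity check that the notebook kernel is really running from the virtual environment, the sketch below compares `sys.prefix` with `sys.base_prefix` (they differ inside a venv) and looks for the environment name in the interpreter path. The helper names `in_virtualenv` and `kernel_matches` are made up for illustration and are not part of FiTSuite.

```python
import sys
from pathlib import Path

def in_virtualenv() -> bool:
    """Return True when the interpreter runs inside a venv (prefix differs from the base install)."""
    return sys.prefix != sys.base_prefix

def kernel_matches(env_name: str = "FiTSuite") -> bool:
    """Return True when env_name is a directory component of the interpreter or prefix path."""
    return env_name in Path(sys.executable).parts or env_name in Path(sys.prefix).parts

print("Running inside a venv:", in_virtualenv())
print("Interpreter path mentions FiTSuite:", kernel_matches())
```

If either line prints `False`, the kernel is probably not the `FiTSuite` one; switch kernels before continuing.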

In [None]:
# If kernel is selected properly, should expect to see FiTSuite in the string below
import sys
print(sys.executable)

In [None]:
# Should be >=3.10
!python3 --version

### Package Installation
FiTSuite requires a number of packages. These can be installed in several ways; here are two:


1. Install using the `FiTSuite_requirements.txt`, e.g.
    ```
    # Run this ONLY the first time you create the venv to install all required packages (and some unnecessary ones oops)
    !python3 -m pip install -U -r FiTSuite_requirements.txt
    # Important relatively uncommon ones are starfile, seaborn, fastcluster, plotly, ipywidgets
    # NEED ipywidgets < 8 because version 8 was a major overhaul that is not compatible with plotly
    ```
2. Install the missing packages explicitly, e.g. (may need to add more)
    ```
    # Install missing/uncommon packages
    !pip install fastcluster
    !pip install scikit-learn
    !pip install umap-learn
    !pip install starfile
    !pip install mrcfile
    !pip install plotly
    !pip install seaborn
    !pip install scikit-image
    !pip install ipywidgets==7.7.1 jupyterlab-widgets==1.1.1 
    # NEED ipywidgets < 8 because version 8 was a major overhaul that is not compatible with plotly

    ! jupyter nbextension enable --py widgetsnbextension
    ```
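
Before importing everything, it can help to check which packages are actually visible to the kernel. The following is a minimal sketch using only the standard library; the `REQUIRED_MODULES` list and both helper names are assumptions made for this example (the list mirrors the pip commands above, translated to import names), not part of FiTSuite.

```python
from importlib import metadata
from importlib.util import find_spec

# Import names (not pip names) for the packages listed above; assumed for this sketch.
REQUIRED_MODULES = ["fastcluster", "sklearn", "umap", "starfile", "mrcfile",
                    "plotly", "seaborn", "skimage", "ipywidgets"]

def missing_modules(modules):
    """Return the import names that cannot be found in the current environment."""
    return [name for name in modules if find_spec(name) is None]

def major_version(dist_name):
    """Return the installed major version of a distribution, or None if absent or unparsable."""
    try:
        return int(metadata.version(dist_name).split(".")[0])
    except (metadata.PackageNotFoundError, ValueError):
        return None

missing = missing_modules(REQUIRED_MODULES)
print("Missing modules:", missing if missing else "none")

widgets_major = major_version("ipywidgets")
if widgets_major is not None and widgets_major >= 8:
    print("ipywidgets >= 8 detected; the notes above pin ipywidgets==7.7.1")
```

Anything reported as missing can then be installed with one of the two approaches above.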

In [None]:
# Import relevant packages (some are for other functions in FiTSuite)
import numpy as np
import pandas as pd
import starfile
import seaborn as sns
import collections
from collections import Counter
from matplotlib import pyplot as plt
from scipy.cluster.hierarchy import fcluster, dendrogram
import scipy
import mrcfile
import fastcluster
import matplotlib.patches as patches
import IPython.display
from skimage import filters
import matplotlib.font_manager

# Packages for interactive functionality
from ipywidgets import widgets, interactive, HBox, VBox
import plotly.express as px
import plotly.graph_objects as go

# Only needed for KMeans or dimensionality reduction plots; comment these out if you do not use them
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap.umap_ as umap
from sklearn.cluster import KMeans

! jupyter nbextension enable --py widgetsnbextension

### Helper Functions

- File I/O
    - read_particle_star(filename: str) -> (pd.DataFrame, collections.OrderedDict)
    - read_models_star(filename: str) -> (pd.DataFrame, collections.OrderedDict)
    - read_mrcs_file(filename: str) -> (list, np.recarray)
    - read_mrc_file(filename: str) -> (np.ndarray, np.recarray)
    - write_particle_star(all_data: collections.OrderedDict, particle_df: pd.DataFrame, filename: str, overwrite: bool = True)
    - write_models_star(classes_df: pd.DataFrame, filename: str, overwrite: bool = True)
- Add columns to individual particle_df
    - hash_particle_df(particle_df: pd.DataFrame, verbose: bool = True) -> pd.DataFrame
    - add_classnumber_to_particle_df(hashed_particle_df: pd.DataFrame, particleHashClassDict: dict, verbose: bool = True) -> pd.DataFrame
    - add_classID_to_particle_df(hashed_particle_df: pd.DataFrame, classIDdict: dict, verbose: bool = True) -> pd.DataFrame
    - add_filamentID_to_particle_df(hashed_particle_df: pd.DataFrame, filamentIDdict: dict, verbose: bool = True) -> pd.DataFrame
    - add_particleID_to_particle_df(hashed_particle_df: pd.DataFrame, particleIDdict: dict, verbose: bool = True) -> pd.DataFrame
- Manipulating and calculating counts_df s
    - compute_particle_counts_df(hashed_particle_df: pd.DataFrame, classIDdict: dict = None, filamentIDdict: dict = None, verbose: bool = True) -> (pd.DataFrame, pd.DataFrame, Counter, dict, Counter, dict)
- Plotting
    - clusterplot_counts_df(counts_df: pd.DataFrame, filepath: str = None, savefig: bool = True, dpi: int = 300, metric: str = 'cosine', vmax: int = None, cmap: str = 'inferno', standardize: int = None, row_colors: list = None, col_colors: list = None, row_linkage: np.ndarray = None, col_linkage: np.ndarray = None, figsize_x: int = 25, figsize_y: int = 25, label_fontsize: int = 15, dendrogram_ratio: float = 0.2, colors_ratio: float = 0.03, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15) -> sns.matrix.ClusterGrid
    - dimensionality_reduction_plot(counts_df, mode: str = "UMAP", savefig: bool = True, filepath: str = None, figsize_x: float = 25, figsize_y: float = 25, subfigure_label_fontsize = 25, title_fontsize = 30, axis_label_fontsize = 20, alpha = 0.1, row_colors: list = None, col_colors: list = None, random_seed: int = 365)
    - plot_class_averages(class_mrcs_file: str, frame_order: list = None, classes_df: pd.DataFrame = None, savefig: bool = True, filepath: str = None, dpi: int = 300, number_of_images: int = None, images_per_row: int = 10, label_x: int = 15, label_y: int = 225, label_fontsize: int = 10, figsize_x: float = 25, color_list: list = None, panel_label: bool = False, panel_label_letter: str = "c",  panel_label_fontsize: int = 15, panel_x = 0.013333, panel_y = 0.98)
- Computing and using clusterings
    - typecluster_initial_clustering(particles_star_file: str, output_path: str = None, savefig: bool = True, verbose: bool = True, metric: str = "cosine",  vmax: int = None, cmap: str = "inferno", standardize: int = None, figsize_x: int = 25, figsize_y: int =25, label_fontsize: int = 15, dpi: int = 300, dendrogram_ratio: float = 0.2, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15) -> (sns.matrix.ClusterGrid, pd.DataFrame, pd.DataFrame, collections.OrderedDict, Counter, dict, Counter, dict)
    - reorder_class_average_star_by_clustered(model_star_file: str, clustered: sns.matrix.ClusterGrid, classIDdict: dict, output_path: str = None, savestar: bool = True, overwrite: bool = True) -> pd.DataFrame
    - reorder_counts_df_by_clustered(counts_df: pd.DataFrame, clustered: sns.matrix.ClusterGrid) -> pd.DataFrame
- Interactive Selection
    - typecluster_interactive_select(counts_df: pd.DataFrame, clustered: sns.matrix.ClusterGrid, vmax: int = 30, cmap: list = px.colors.sequential.Inferno, figsize_x: int = 1000, figsize_y: int =1000) -> list
    - typecluster_interactive_output(sliders: list, mode: str, output_path: str, clustered: sns.matrix.ClusterGrid, counts_df: pd.DataFrame, IDadded_particle_df: pd.DataFrame = None, all_data: collections.OrderedDict = None, classIDdict: dict = None, model_star_file: str = None, overwrite: bool = True, drophashes: bool = True) -> pd.DataFrame
- (Semi-)Automated Selection
    - typecluster_compute_kmeans_labels(counts_df, col_cluster_number: int = 10, row_cluster_number: int = 10, random_seed: int = 365) -> (list, list)
    - typecluster_dendrogram_threshold_select(clustered: sns.matrix.ClusterGrid, threshold: float)
    - typecluster_interactive_dendrogram_thresholding(clustered_df_name: str = "clustered") -> (widgets.FloatSlider, widgets.widget_box.HBox)
    - typecluster_compute_dendrogram_threshold_labels(clustered: sns.matrix.ClusterGrid, threshold: float, output_path: str, counts_df: pd.DataFrame, savefig: bool = True, vmax: int = None, cmap: str = "inferno", standardize: int = None, figsize_x: int = 25, figsize_y: int = 25, label_fontsize: int = 15, dpi: int = 300, dendrogram_ratio: float = 0.2, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15) -> (list, Counter, sns.matrix.ClusterGrid)
    - convert_labels_to_colors(labels, color_dict: dict = None, plot_legend: bool = True, verbose: bool = True) -> list
    - typecluster_output_filaments_from_labels(col_labels: list, output_path: str, IDadded_particle_df: pd.DataFrame, all_data: collections.OrderedDict, particle_threshold: int = 1000, overwrite: bool = True, drophashes: bool = True, verbose: bool = True) -> pd.DataFrame
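
To make the data flow behind these helpers concrete, here is a toy sketch of the core idea: hash particles into filaments, count particles per class per filament, then cluster filaments hierarchically. It is a simplified stand-in rather than the FiTSuite implementation: the particle values are invented, and it calls `scipy.cluster.hierarchy.linkage` directly (average linkage, cosine distance) instead of going through `sns.clustermap`.

```python
import numpy as np
import pandas as pd
from scipy.cluster.hierarchy import linkage, fcluster

# Toy particle table (hypothetical values) with the RELION columns the helpers rely on.
particles = pd.DataFrame({
    "rlnMicrographName":          ["mic1.mrc"] * 6 + ["mic2.mrc"] * 6,
    "rlnHelicalTubeID":           [1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2],
    "rlnHelicalTrackLengthAngst": [0.0, 10.0, 20.0] * 4,
    "rlnClassNumber":             [1, 1, 2, 3, 3, 4, 1, 2, 2, 4, 4, 3],
})

# Hashing: tube ID + micrograph identifies a filament; adding track length identifies a particle.
particles["filamentHash"] = particles["rlnHelicalTubeID"].astype(str) + particles["rlnMicrographName"]
particles["particleHash"] = particles["filamentHash"] + particles["rlnHelicalTrackLengthAngst"].astype(str)
assert particles["particleHash"].is_unique, "duplicate particles in toy table"

# Map class numbers and filament hashes to consecutive integer IDs.
class_ids = {c: i for i, c in enumerate(sorted(particles["rlnClassNumber"].unique()))}
filament_ids = {h: i for i, h in enumerate(particles["filamentHash"].unique())}

# Counts matrix: rows are classes, columns are filaments.
counts = np.zeros((len(class_ids), len(filament_ids)))
for c, h in zip(particles["rlnClassNumber"], particles["filamentHash"]):
    counts[class_ids[c], filament_ids[h]] += 1

# Average-linkage (UPGMA) clustering of filaments (columns) using cosine distance,
# then cut the tree into two groups of filaments.
col_linkage = linkage(counts.T, method="average", metric="cosine")
filament_labels = fcluster(col_linkage, t=2, criterion="maxclust")

print(counts)
print(filament_labels)
```

In this toy table, filaments that share classes 1 and 2 end up in one group and those sharing classes 3 and 4 in the other, which is the kind of separation the clustering helpers aim for on real data.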

In [None]:
# Functions for file input/output
def read_particle_star(filename: str) -> (pd.DataFrame, collections.OrderedDict):
    '''
    Reads in particle star files, returns a DataFrame with particles only and an OrderedDict with all data
    '''
    if filename[-5:] != ".star":
        raise ValueError(f"{filename} is not in .star format")
    else:
        try:
            all_data = starfile.read(filename).copy()
            particle_df = all_data["particles"].copy()
        except:
            raise ValueError(f"Particle/data .star file {filename} does not contain particles or does not exist")
        return particle_df, all_data

def read_models_star(filename: str) -> (pd.DataFrame, collections.OrderedDict):
    '''
    Reads in models star files, returns a DataFrame with model classes only and an OrderedDict with all data
    '''
    if filename[-5:] != ".star":
        raise ValueError("File is not in .star format")
    else:
        try:
            all_data = starfile.read(filename).copy()
            models_df = all_data["model_classes"].copy()
        except:
            raise ValueError(f"Models .star file {filename} does not contain 2D classes or does not exist")
        return models_df, all_data
    
def read_mrcs_file(filename: str) -> (list, np.recarray):
    '''
    Reads in .mrcs files, writes out a list containing the frames and a Numpy recarray containing voxel sizes 
    '''
    if filename[-5:] != ".mrcs":
        raise ValueError(f"{filename} is not in .mrcs format")
    else:
        data = []
        try:
            with mrcfile.open(filename,permissive=True) as mrcs:
                for i, frame in enumerate(mrcs.data):
                    data.append(frame)
                voxel_size = mrcs.voxel_size
        except:
            raise ValueError(f"Unable to parse .mrcs file {filename}")
        return data, voxel_size
    
def read_mrc_file(filename: str) -> (np.ndarray, np.recarray):
    '''
    Reads in .mrc files, writes out a Numpy ndarray containing the data and a Numpy recarray containing voxel sizes 
    '''
    if filename[-4:] != ".mrc":
        raise ValueError(f"{filename} is not in .mrc format")
    else:
        try:
            with mrcfile.open(filename,permissive=True) as mrc:
                data = mrc.data
                voxel_size = mrc.voxel_size
        except:
            raise ValueError(f"Unable to parse .mrc file {filename}")
        return data, voxel_size
    
def write_particle_star(all_data: collections.OrderedDict, particle_df: pd.DataFrame, filename: str, overwrite: bool = True):
    '''
    Takes in the data_df as an OrderedDict and a DataFrame of particles and writes out a particle.star combining the two
    '''
    data_output = all_data.copy()
    data_output["particles"] = particle_df
    starfile.write(data_output, f"{filename}_particles.star", overwrite = overwrite)
    print("Particles saved as " + f"{filename}_particles.star")

def write_models_star(classes_df: pd.DataFrame, filename: str, overwrite: bool = True):
    '''
    Takes in classes as a DataFrame and writes out a classes.star
    '''
    starfile.write(classes_df, f"{filename}_classes.star", overwrite = overwrite)
    print("Class averages saved as " + f"{filename}_classes.star")
    
# Functions for adding columns to particle_df
def hash_particle_df(particle_df: pd.DataFrame, verbose: bool = True) -> pd.DataFrame:
    '''
    Adds a particle and filament hash to a copy of the input dataframe, replacing old hashes
    Optionally prints out total number of filaments and total particles
    '''
    hashed_particle_df = particle_df.copy()
    
    # Remove old hashes
    if "Hash ID" in hashed_particle_df.columns: # From old versions of FiTSuite, so keep for back compatibility
        if verbose:
            print("Dropping old 'Hash ID' column in hashed_particle_df")
        hashed_particle_df = hashed_particle_df.drop(columns = ["Hash ID"])
    if "Hash" in hashed_particle_df.columns: # From old versions of FiTSuite, so keep for back compatibility
        if verbose:
            print("Dropping old 'Hash' column in hashed_particle_df")
        hashed_particle_df = hashed_particle_df.drop(columns = ["Hash"])
    if "filamentHash" in hashed_particle_df.columns:
        if verbose:
            print("Dropping old 'filamentHash' column in hashed_particle_df")
        hashed_particle_df = hashed_particle_df.drop(columns = ["filamentHash"])
    if "particleHash" in hashed_particle_df.columns:
        if verbose:
            print("Dropping old 'particleHash' column in hashed_particle_df")
        hashed_particle_df = hashed_particle_df.drop(columns = ["particleHash"])
        
    # Add in hashes
    try:
        hashed_particle_df["filamentHash"] = hashed_particle_df["rlnHelicalTubeID"].astype(str) + hashed_particle_df["rlnMicrographName"]
        hashed_particle_df["particleHash"] = hashed_particle_df["rlnHelicalTubeID"].astype(str) + hashed_particle_df["rlnMicrographName"] + hashed_particle_df["rlnHelicalTrackLengthAngst"].astype(str)
    except:
        raise KeyError("Unable to hash particle_df as not all required fields exist")
        
    total_filaments = hashed_particle_df["filamentHash"].nunique()
    total_particles = hashed_particle_df["particleHash"].nunique()
    assert total_particles == len(hashed_particle_df.index), f"particle_df contains duplicate particles"
    if verbose:
        print(f"Total filaments: {total_filaments} and total particles: {total_particles}")
        
    return hashed_particle_df

def add_classnumber_to_particle_df(hashed_particle_df: pd.DataFrame, particleHashClassDict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the rlnClassNumber of each particle from a dict to a copy of the DataFrame
    Replaces old rlnClassNumber if it exists
    '''
    classadded_particle_df = hashed_particle_df.copy()
    
    # Remove old labels
    if  'rlnClassNumber' in classadded_particle_df.columns:
        if verbose:
            print("Dropping old 'rlnClassNumber' column in classadded_particle_df")
        classadded_particle_df = classadded_particle_df.drop(columns = ["rlnClassNumber"])
    
    # Add in classes
    try:
        classadded_particle_df['rlnClassNumber']=classadded_particle_df['particleHash'].map(particleHashClassDict)
    except:
        raise KeyError("Unable to assign classes as not all particle assignments exist or particle_df not hashed")
     
    total_classes = classadded_particle_df['rlnClassNumber'].nunique()
    if verbose:
        print("Total classes:", total_classes)
        
    return classadded_particle_df

def add_classID_to_particle_df(hashed_particle_df: pd.DataFrame, classIDdict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the classID of each particle from a dict to a copy of the DataFrame
    Replaces old classID if it exists
    '''
    classIDadded_particle_df = hashed_particle_df.copy()
    
    # Remove old labels
    if 'classID' in classIDadded_particle_df.columns:
        if verbose:
            print("Dropping old 'classID' column in classIDadded_particle_df")
        classIDadded_particle_df = classIDadded_particle_df.drop(columns = ["classID"])
    
    # Add in classes
    try:
        classIDadded_particle_df['classID'] = classIDadded_particle_df['rlnClassNumber'].map(classIDdict)
    except:
        raise KeyError("Unable to assign class IDs as not all class assignments exist or particle_df does not contain class numbers")
     
    total_classIDs = classIDadded_particle_df['classID'].nunique()
    if verbose:
        print("Total class IDs:", total_classIDs)
        
    return classIDadded_particle_df

def add_filamentID_to_particle_df(hashed_particle_df: pd.DataFrame, filamentIDdict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the filamentID of each particle from a dict to a copy of the DataFrame
    Replaces old filamentID if it exists
    '''
    filamentIDadded_particle_df = hashed_particle_df.copy()
    
    # Remove old labels
    if  'filamentID' in filamentIDadded_particle_df.columns:
        if verbose:
            print("Dropping old 'filamentID' column in filamentIDadded_particle_df")
        filamentIDadded_particle_df = filamentIDadded_particle_df.drop(columns = ["filamentID"])
    
    # Add in filament IDs
    try:
        filamentIDadded_particle_df['filamentID']=filamentIDadded_particle_df['filamentHash'].map(filamentIDdict)
    except:
        raise KeyError("Unable to assign filament IDs as not all particle assignments exist or particle_df not hashed")
     
    total_filamentIDs = filamentIDadded_particle_df['filamentID'].nunique()
    if verbose:
        print("Total filament IDs:", total_filamentIDs)
        
    return filamentIDadded_particle_df

def add_particleID_to_particle_df(hashed_particle_df: pd.DataFrame, particleIDdict: dict, verbose: bool = True) -> pd.DataFrame:
    '''
    Takes in a hashed_particle_df and adds the particleID of each particle from a dict to a copy of the DataFrame
    Replaces old particleID if it exists
    '''
    particleIDadded_particle_df = hashed_particle_df.copy()
    
    # Remove old labels
    if  'particleID' in particleIDadded_particle_df.columns:
        if verbose:
            print("Dropping old 'particleID' column in particleIDadded_particle_df")
        particleIDadded_particle_df = particleIDadded_particle_df.drop(columns = ["particleID"])
    
    # Add in particle IDs (map from the particleIDdict argument, not filamentIDdict)
    try:
        particleIDadded_particle_df['particleID']=particleIDadded_particle_df['particleHash'].map(particleIDdict)
    except:
        raise KeyError("Unable to assign particle IDs as not all particle assignments exist or particle_df not hashed")
     
    total_particleIDs = particleIDadded_particle_df['particleID'].nunique()
    if verbose:
        print("Total particle IDs:", total_particleIDs)
        
    return particleIDadded_particle_df

# Functions for computing or manipulating counts_df s
def compute_particle_counts_df(hashed_particle_df: pd.DataFrame, classIDdict: dict = None, filamentIDdict: dict = None, 
                               verbose: bool = True) -> (pd.DataFrame, pd.DataFrame, Counter, dict, Counter, dict):
    '''
    Computes (class x filament) dataframe of counts of particles per class per filament given a hashed_particle_df
    If provided with a previously computed classIDdict or filamentIDdict can use those too
    Also outputs other useful info - counts of particles per class/filament and dicts for filament/class hash to ID lookup
    '''
    # Check that input is valid
    assert "rlnClassNumber" in hashed_particle_df.columns, "hashed_particle_df does not contain class numbers"
    assert "filamentHash" in hashed_particle_df.columns, "hashed_particle_df does not contain filament hashes"
    
    # Compute dimensions of matrix and make up dicts to calculate IDs from
    total_classes = hashed_particle_df["rlnClassNumber"].nunique()
    total_filaments = hashed_particle_df["filamentHash"].nunique()
    total_particles = len(hashed_particle_df.index)
    classcount = Counter(hashed_particle_df["rlnClassNumber"].to_list())
    if classIDdict is None:
        classIDdict = {val:idx for idx, val in enumerate(sorted(classcount.keys()))}
    filamentcount = Counter(hashed_particle_df["filamentHash"].to_list())
    if filamentIDdict is None:
        filamentIDdict = {val:idx for idx, val in enumerate(filamentcount.keys())}
    if verbose:
        print(f"Total classes = {total_classes}, total particles = {total_particles} and total filaments = {total_filaments}")
        print("Particles per class: ", classcount)
    
    # Add filament and class IDs to hashed particles
    IDadded_particle_df = add_classID_to_particle_df(hashed_particle_df, classIDdict, verbose = False)
    IDadded_particle_df = add_filamentID_to_particle_df(IDadded_particle_df, filamentIDdict, verbose = False)

    #Compute counts_df (note that pivot table approach runs much slower)
    # counts_df = pd.pivot_table(IDadded_particle_df, values='particleHash', index='classID', columns='filamentID', aggfunc=pd.Series.nunique, fill_value = 0)
    countsmatrix = np.zeros((len(classIDdict),len(filamentIDdict)))
    for x, y in zip(IDadded_particle_df["classID"],IDadded_particle_df["filamentID"]):
        countsmatrix[x, y] += 1
    
    # Making row and col labels strings facilitates plotting, row labels are original class numbers
    counts_df = pd.DataFrame(countsmatrix, index = [str(x) for x in classIDdict.keys()], columns = [str(x) for x in range(len(filamentIDdict))])
    counts_df.index.name = "rlnClassNumber"
    counts_df.columns.name = "filamentID"
    
    return counts_df, IDadded_particle_df, classcount, classIDdict, filamentcount, filamentIDdict

# Functions for plotting figures
def clusterplot_counts_df(counts_df: pd.DataFrame, filepath: str = None, savefig: bool = True, dpi: int = 300,
                          metric: str = 'cosine', vmax: int = None, cmap: str = 'inferno', standardize: int = None,
                          row_colors: list = None, col_colors: list = None, 
                          row_linkage: np.ndarray = None, col_linkage: np.ndarray = None, 
                          figsize_x: int = 25, figsize_y: int = 25, label_fontsize: int = 15, 
                          dendrogram_ratio: float = 0.2, colors_ratio: float = 0.03, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), 
                          panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15) -> sns.matrix.ClusterGrid:
    '''
    Takes in counts_df and plots clustermap using a UPGMA/average hierarchical clustering
    Can modify various parameters and optionally save figure
    Refer to https://seaborn.pydata.org/generated/seaborn.clustermap.html for most params
    cmaps that are perceptually uniform sequential are viridis, plasma, inferno, magma, cividis
    '''
    clustered = sns.clustermap(counts_df, metric = metric, vmax = vmax, cmap = cmap, standard_scale = standardize,
                               figsize =(figsize_x, figsize_y), row_linkage = row_linkage, col_linkage = col_linkage,
                               row_colors = row_colors, col_colors = col_colors, cbar_kws={'label': 'Particle Counts'},
                               dendrogram_ratio = dendrogram_ratio, colors_ratio = colors_ratio, cbar_pos=cbar_pos)
    
    # Add labels
    clustered.ax_heatmap.set_ylabel("2D Class Average Numbers", fontsize = label_fontsize)
    clustered.ax_heatmap.set_xlabel("Filament IDs", fontsize = label_fontsize)
    cbar = clustered.ax_heatmap.collections[0].colorbar
    cbar.ax.yaxis.label.set_fontsize(label_fontsize)
    
    # For figure plotting can add a panel label, if you do this almost certainly want to move the cbar_pos 
    # With a panel_label_fontsize of 15, cbar_pos = (0.1, 0.82, 0.05, 0.16) seems to work
    if panel_label:
        font = {'fontname': 'Arial', 'fontsize': panel_label_fontsize, 'fontweight': 'bold'}
        plt.text(0.02, 0.97, panel_label_letter, fontdict = font, transform=plt.gcf().transFigure)
    plt.show()
    
    if savefig:
        if filepath is not None:
            clustered.savefig(f"{filepath}_clustered.png", dpi=dpi)
            print("Figure saved as "+f"{filepath}_clustered.png")
        else:
            clustered.savefig("counts_df_clustered.png", dpi=dpi)
            print("Figure saved as counts_df_clustered.png")
    
    return clustered

def dimensionality_reduction_plot(counts_df, mode: str = "UMAP", savefig: bool = True, filepath: str = None, 
                                figsize_x: float = 25, figsize_y: float = 25, subfigure_label_fontsize = 25, 
                                title_fontsize = 30, axis_label_fontsize = 20, alpha = 0.1,
                                row_colors: list = None, col_colors: list = None, random_seed: int = 365):
    '''
    Performs dimensionality reduction on filaments or classes and plots/optionally saves figure
    Can choose from PCA/t-SNE/UMAP, and can also color by type by providing color labels
    Generally, no real reason to use this over the hierarchical clustering plot but it is available.
    '''
    # Choose right mode
    if mode == "PCA":
        optimiser = PCA(n_components=2, random_state = random_seed)
        # Seems weird to have a seed for PCA, but it just depends on the solver used
    elif mode == "t-SNE":
        optimiser = TSNE(n_components=2, random_state = random_seed)
    elif mode == "UMAP":
        optimiser = umap.UMAP(random_state = random_seed)
    else:
        raise ValueError("Mode not defined")
        
    # Compute dimensionality reduction
    countsmatrix = np.array(counts_df)
    ClassReduced = optimiser.fit_transform(countsmatrix)
    ClassReducedrownormed = optimiser.fit_transform(countsmatrix/np.sum(countsmatrix, axis=1, keepdims=True))
    FilsReduced = optimiser.fit_transform(np.transpose(countsmatrix))
    FilsReducedcolnormed = optimiser.fit_transform(np.transpose(countsmatrix/np.sum(countsmatrix, axis=0, keepdims=True)))

    # Plot figure
    fig, axs = plt.subplots(2, 2, figsize = (figsize_x, figsize_y))
    if row_colors is None:
        axs[0,0].scatter(ClassReduced[:, 0], ClassReduced[:, 1], alpha = alpha, s = 100)
        axs[0,1].scatter(ClassReducedrownormed[:, 0], ClassReducedrownormed[:, 1], alpha = alpha, s = 100)
    else:
        axs[0,0].scatter(ClassReduced[:, 0], ClassReduced[:, 1], alpha = alpha, s = 100, c = row_colors)
        axs[0,1].scatter(ClassReducedrownormed[:, 0], ClassReducedrownormed[:, 1], alpha = alpha, s = 100, c = row_colors)
    if col_colors is None:
        axs[1,0].scatter(FilsReduced[:, 0], FilsReduced[:, 1], alpha = alpha)
        axs[1,1].scatter(FilsReducedcolnormed[:, 0], FilsReducedcolnormed[:, 1], alpha = alpha)
    else:
        axs[1,0].scatter(FilsReduced[:, 0], FilsReduced[:, 1], alpha = alpha, c = col_colors)
        axs[1,1].scatter(FilsReducedcolnormed[:, 0], FilsReducedcolnormed[:, 1], alpha = alpha, c = col_colors)
    
    # Add labels, adjust layout
    axs[0,0].set_title("Class Averages", fontname = 'Arial', fontsize = subfigure_label_fontsize)
    axs[0,1].set_title("Normalized Class Averages", fontname = 'Arial', fontsize = subfigure_label_fontsize)
    axs[1,0].set_title("Unique Filaments", fontname = 'Arial', fontsize = subfigure_label_fontsize)
    axs[1,1].set_title("Normalized Unique Filaments", fontname = 'Arial', fontsize = subfigure_label_fontsize)
    for axis in [axs[0,0], axs[0,1], axs[1,0], axs[1,1]]:
        axis.set_xlabel(f"{mode} dimension 1", fontname = 'Arial', fontsize = axis_label_fontsize)
        axis.set_ylabel(f"{mode} dimension 2", fontname = 'Arial', fontsize = axis_label_fontsize)
    fig.suptitle(f"Dimensionality Reduction by {mode}", fontname = 'Arial', fontsize=title_fontsize, y = 0.99)
    plt.subplots_adjust(wspace=0.1, hspace=0.1, top=0.95)
    # plt.tight_layout()
    plt.show()
    
    if savefig:
        if filepath is not None:
            fig.savefig(f"{filepath}_dimensionalityreduction_{mode}.png")
            print(f"Figure saved as {filepath}_dimensionalityreduction_{mode}.png")
        else:
            fig.savefig(f"counts_df_dimensionalityreduction_{mode}.png")
            print(f"Figure saved as counts_df_dimensionalityreduction_{mode}.png")
            
def plot_class_averages(class_mrcs_file: str, frame_order: list = None, classes_df: pd.DataFrame = None,
                        savefig: bool = True, filepath: str = None, dpi: int = 300,
                        number_of_images: int = None, images_per_row: int = 10, 
                        label_x: int = 15, label_y: int = 225,
                        label_fontsize: int = 10, figsize_x: float = 25, color_list: list = None,
                        panel_label: bool = False, panel_label_letter: str = "c",  panel_label_fontsize: int = 15,
                        panel_x = 0.013333, panel_y = 0.98):
    '''
    Plots specified number of class averages from a mrcs file or the maximum possible
    Can plot them in a specified order from a list/classes_df (frame_order prioritized over classes_df) with colors
    
    Params:
        class_mrcs_file = path to .mrcs file as string (required)
        frame_order = list containing frames (indexed from 0) desired
        classes_df = DataFrame, containing the frames in their desired order in the 'rlnReferenceImage' column
        savefig = True or False, boolean flag
        filepath = path name that will be appended to the front of anything written out
        dpi = int, default 300
        number_of_images = int, number of images desired, will plot less if there aren't enough
        images_per_row = int, default 10
        label_fontsize = int, default 10
        figsize_x = float, default 25, width of figure, will be autoscaled
        color_list = list of colors already in the desired frame order
        label_x = 15
        label_y = 225
        panel_label: Controls whether to add panel label for figure making
        panel_label_letter: str, default "c"
        panel_label_fontsize: int, default 15
        panel_x = 0.013333
        panel_y = 0.98
    '''
    frame_data, voxel_size = read_mrcs_file(class_mrcs_file)
    
    # Set up params
    if frame_order is not None:
        frame_order_list = frame_order
    elif classes_df is not None:
        frame_order_list = [int(framename.split("@")[0])-1 for framename in classes_df['rlnReferenceImage']]
    else:
        frame_order_list = list(range(0, len(frame_data)))
        
    if number_of_images is None:
        n = len(frame_order_list)
    else:
        n = min(number_of_images, len(frame_order_list)) 
        
    # Plot figure
    if panel_label:
        fig, axes = plt.subplots(int(np.ceil(n/images_per_row)), images_per_row, figsize = (figsize_x, (int(np.ceil(n/images_per_row))+0.1)*figsize_x/images_per_row))
    else:
        fig, axes = plt.subplots(int(np.ceil(n/images_per_row)), images_per_row, figsize = (figsize_x, int(np.ceil(n/images_per_row))*figsize_x/images_per_row))
    axes = axes.flatten()
    height, width = frame_data[0].shape # array shape is (rows, columns), i.e. (height, width)
    for i, ax in enumerate(axes):
        if i < n:
            try:
                ax.imshow(frame_data[frame_order_list[i]], cmap="gray", origin="lower")
                # sns.heatmap(frame_data[frame_order_list[i]], ax = ax, cmap = 'gray', cbar = False, square = True)
                ax.text(label_x, label_y, str(frame_order_list[i]+1), fontname = 'Arial', fontfamily = 'sans-serif', fontsize = label_fontsize, c = 'white') 
                ax.axis('off')
                if color_list is not None:
                    ax.add_patch(patches.Rectangle((0, 0), width, height, fill = False, lw = 4, ec = color_list[i]))
            except Exception: # e.g. frame index out of range; leave this panel blank
                ax.axis('off')
        else:
            ax.axis('off')
    if panel_label:
        fig.suptitle(panel_label_letter, x = panel_x, y = panel_y, fontname = 'Arial', fontsize = panel_label_fontsize, fontweight = 'bold')
        plt.tight_layout(w_pad = 0.1, h_pad = 0.1, pad = 0.4)
    else:
        plt.tight_layout(w_pad = 0.5) #, h_pad=0.5)  
    plt.show()
    
    if savefig:
        if filepath is not None:
            fig.savefig(f"{filepath}_classaverages.png", dpi=dpi)
            print(f"Figure saved as {filepath}_classaverages.png")
        else:
            fig.savefig("classaverages.png", dpi=dpi)
            print("Figure saved as classaverages.png")
            
# Functions for generating clustering and using clustering
def typecluster_initial_clustering(particles_star_file: str, output_path: str = None, savefig: bool = True, verbose: bool = True,
                                   metric: str = "cosine",  vmax: int = None, cmap: str = "inferno", standardize: int = None, 
                                   figsize_x: int = 25, figsize_y: int =25, label_fontsize: int = 15, dpi: int = 300,
                                   dendrogram_ratio: float = 0.2, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), 
                                   panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15
                                   ) -> (sns.matrix.ClusterGrid, pd.DataFrame, pd.DataFrame, collections.OrderedDict, Counter, dict, Counter, dict):
    '''
    Computes 2D class x unique filament particle count matrix
    Plots and saves initial hierarchical clustering figure with dendrograms and performs other useful calculations
    If saved, hierarchical clustering will be saved as {output_path}_clustered.png or "counts_df_clustered.png"
    
    General params:
        particles_star_file = path to particle/data.star file as string, needs to have rlnClassNumber assigned
                              e.g. from Select or Class2D jobs (required!)
        savefig = True or False, boolean flag
        output_path = path name that will be appended to the front of anything written out
        verbose = True or False, boolean flag for extra text output
    
    Clustering Figure params (all optional):
        metric = distance metric as string, default is "cosine", "jaccard" and "euclidean" also can work well
        vmax = int, maximum value for colormap in hierarchical clustering e.g. 15
        cmap = color map as string, default "inferno", ideally perceptually uniform sequential cmaps
        standardize = int, normalize counts matrix by row (0) or column (1)
        figsize_x, figsize_y = ints, default 25, figure dimensions to be saved
        label_fontsize: int, default 15, 
        dpi: int, default 300,
        dendrogram_ratio: float controlling ratio of plot that is dendrogram 
        cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), 
        panel_label: Controls whether to add panel label for figure making
        panel_label_letter: str, default "a",  
        panel_label_fontsize: int, default 15
    '''
    # Read in relevant data and calculate class x filament particle count matrix
    particle_df, all_data = read_particle_star(particles_star_file)
    hashed_particle_df = hash_particle_df(particle_df, verbose = False)
    counts_df, IDadded_particle_df, classcount, classIDdict, filamentcount, filamentIDdict = compute_particle_counts_df(hashed_particle_df, verbose = verbose) 
    
    # Compute and plot hierarchical clustering
    clustered = clusterplot_counts_df(counts_df, filepath = output_path, savefig = savefig, dpi = dpi,
                                      metric = metric, vmax = vmax, cmap = cmap, standardize = standardize,
                                      figsize_x = figsize_x, figsize_y = figsize_y, label_fontsize = label_fontsize, 
                                      dendrogram_ratio = dendrogram_ratio, cbar_pos = cbar_pos, 
                                      panel_label = panel_label, panel_label_letter = panel_label_letter,  panel_label_fontsize = panel_label_fontsize)
    
    return clustered, counts_df, IDadded_particle_df, all_data, classcount, classIDdict, filamentcount, filamentIDdict

def reorder_class_average_star_by_clustered(model_star_file: str, clustered: sns.matrix.ClusterGrid, classIDdict: dict,
                                            output_path: str = None, savestar: bool = True, overwrite: bool = True
                                           ) -> pd.DataFrame:
    '''
    Reorders class averages in a model_classes_df and appends a rlnClassNumber and clusteredOrder column
    Writes out if desired, and can be visualized by plot_class_averages
    
    Params:
        model_star_file = path to model.star file as string, required if classes are to be written out
        clustered = sns.matrix.ClusterGrid, output e.g. from typecluster_initial_clustering
        classIDdict = dict, output e.g. from typecluster_initial_clustering
        output_path = path name that will be appended to the front of anything written out
        savestar = bool, default True, controls whether reordered star file will be written out
        overwrite = bool, default True, controls whether it is okay to overwrite existing star
    '''
    # Read in order
    ordered_row_indices = clustered.dendrogram_row.reordered_ind
    classes_to_keep_dict = {classID:ordered_row_indices.index(classIDdict[classID]) for classID in classIDdict.keys()}

    # Read in and add columns to model_star
    model_classes_df, all_model_data = read_models_star(model_star_file)
    reordered_classes_df = model_classes_df.copy()
    reordered_classes_df['rlnClassNumber'] = list(np.arange(1,len(reordered_classes_df.index)+1))
    
    # Filter 2D class files with new order and write out
    reordered_classes_df = reordered_classes_df[reordered_classes_df['rlnClassNumber'].isin(set(classes_to_keep_dict.keys()))]
    reordered_classes_df['clusteredOrder'] = reordered_classes_df['rlnClassNumber'].map(classes_to_keep_dict)
    reordered_classes_df = reordered_classes_df.sort_values(by=['clusteredOrder'])
    
    if savestar:
        if output_path is None:
            output_path = "clustered"
        write_models_star(reordered_classes_df, filename = f"{output_path}_reordered", overwrite = overwrite)
    
    return reordered_classes_df

def reorder_counts_df_by_clustered(counts_df: pd.DataFrame, clustered: sns.matrix.ClusterGrid) -> pd.DataFrame:
    '''
    Takes in a counts_df and a clustermap and rearranges counts_df to match order
    '''
    new_cols = counts_df.columns[clustered.dendrogram_col.reordered_ind]
    new_ind = counts_df.index[clustered.dendrogram_row.reordered_ind]
    reordered_counts_df = counts_df.loc[new_ind, new_cols]
    return reordered_counts_df

# Functions for interactive selection, visualization and output
def typecluster_interactive_select(counts_df: pd.DataFrame, clustered: sns.matrix.ClusterGrid, vmax: int = 30, 
                                   cmap: list = px.colors.sequential.Inferno, figsize_x: int = 1000, 
                                   figsize_y: int =1000) -> list:
    '''
    Replots hierarchically ordered counts_df using plotly, allowing for interactive selection of params
    The returned sliders are then passed to a typecluster_interactive_output function
    
    Params:
        counts_df = DataFrame containing counts, e.g. from typecluster_initial_clustering
        clustered = Clustermap, e.g. from typecluster_initial_clustering
        vmax = int, maximum value for colormap in hierarchical clustering e.g. 15
        cmap = color map as list from plotly, default px.colors.sequential.Inferno
        figsize_x, figsize_y = ints, default 1000, figure width and height in pixels
    '''
    # Compute new df that is reordered by the clustering
    clustered_df = reorder_counts_df_by_clustered(counts_df, clustered)
    clustered_df.index = [clustered_df.index[x]+"_"+str(x) for x in range(len(clustered_df.index))]
    clustered_df.columns = [clustered_df.columns[x]+"_"+str(x) for x in range(len(clustered_df.columns))]
    
    # Set up sliders
    ynum, xnum = clustered_df.shape
    x1 = widgets.IntSlider(value=1, min=0.0, max=xnum-1, step=1.0, description='x-start:', continuous_update=False)
    y1 = widgets.IntSlider(value=1, min=0.0, max=ynum-1, step=1.0, description='y-start:', continuous_update=False)
    x2 = widgets.IntSlider(value=2, min=0.0, max=xnum-1, step=1.0, description='x-end:', continuous_update=False)
    y2 = widgets.IntSlider(value=2, min=0.0, max=ynum-1, step=1.0, description='y-end:', continuous_update=False)
    container = widgets.HBox(children=[x1, y1, x2, y2])

    # Set up figure
    fig = px.imshow(clustered_df, aspect = 'auto', labels=dict(y="2D Class Number (name + coordinate #)", x="Filament Hash ID (name + coordinate #)", color="Counts"),
                   color_continuous_scale=cmap, range_color = [0,vmax])
    fig.add_selection(x0=1, y0=1, x1=2, y1=2, line=dict(color="White",width=1,dash="dash"))
    fig.update_layout(dragmode = "zoom")
    fig.update_layout(height = figsize_y, width = figsize_x)
    interactive_fig = go.FigureWidget(fig)

    # Set up interaction
    def response(change):
        with interactive_fig.batch_update():
            interactive_fig.layout.selections[0]['x0'] = x1.value
            interactive_fig.layout.selections[0]['x1'] = x2.value
            interactive_fig.layout.selections[0]['y0'] = y1.value
            interactive_fig.layout.selections[0]['y1'] = y2.value

    x1.observe(response, names="value")
    x2.observe(response, names="value")
    y1.observe(response, names="value")
    y2.observe(response, names="value")
    
    # Show figure
    sliders = [x1, x2, y1, y2]
    output = widgets.VBox([container, interactive_fig])
    display(output)
    
    return sliders
    
def typecluster_interactive_output(sliders: list, mode: str, output_path: str, clustered: sns.matrix.ClusterGrid, 
                                   counts_df: pd.DataFrame, IDadded_particle_df: pd.DataFrame = None, all_data: collections.OrderedDict = None,
                                   classIDdict: dict = None, model_star_file: str = None, overwrite: bool = True,
                                   drophashes: bool = True) -> pd.DataFrame:
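    '''
    Writes out the classes, filaments, or particles inside the box selected with typecluster_interactive_select
    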
    Params:
        sliders: list of 4 sliders from typecluster_interactive_select
        mode = can be "Filaments", "Classes", or "Particles" depending on what is to be written out
        output_path = path name that will be appended to the front of anything written out (e.g. /TypeClassifer/group1)
        counts_df = DataFrame containing counts, e.g. from typecluster_initial_clustering
        clustered = Clustermap, e.g. from typecluster_initial_clustering
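        IDadded_particle_df = particle_df with filament and class IDs, e.g. from typecluster_initial_clustering
        all_data = original read in of star file
        classIDdict = dict, output e.g. from typecluster_initial_clustering, required if mode = "Classes"
        model_star_file = path for model.star file, required if mode = "Classes"
        overwrite = boolean that controls whether it is okay to overwrite existing star files
        drophashes = boolean that controls whether to drop hashes when writing out star files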
    '''
    # Extract values from sliders
    [x1, x2, y1, y2] = sliders
    class_start = min(y1.value, y2.value)
    class_end = max(y1.value, y2.value)
    hash_start = min(x1.value, x2.value)
    hash_end = max(x1.value, x2.value)
    ynum, xnum = counts_df.shape
    
    # Find indices corresponding to these regions (both ends inclusive)
    # Slicing keeps single-row/column selections as lists, so the .loc call below always returns a DataFrame
    selected_2D_class_inds = clustered.dendrogram_row.reordered_ind[class_start:class_end+1]
    selected_hash_inds = clustered.dendrogram_col.reordered_ind[hash_start:hash_end+1]
    
    # Write df that only contains selected particles from clustered matrix
    all_selected_df = counts_df.loc[counts_df.index[selected_2D_class_inds], counts_df.columns[selected_hash_inds]]
    selected_classes = set([int(x) for x in all_selected_df.index])
    selected_filaments = set([int(x) for x in all_selected_df.columns])
    
    # Write out correct type of file
    if mode == "Classes": # Writing out all selected 2D classes
        assert (model_star_file is not None and classIDdict is not None), "Not all required inputs provided"
        filtered_df = reorder_class_average_star_by_clustered(model_star_file, clustered, classIDdict, savestar = False)
        filtered_df = filtered_df[filtered_df['rlnClassNumber'].isin(selected_classes)]
        write_models_star(filtered_df, f"{output_path}_selected", overwrite = overwrite)
    elif mode == "Filaments": # Writing out all particles from selected filaments
        assert (IDadded_particle_df is not None and all_data is not None), "Not all required inputs provided"
        filtered_df = IDadded_particle_df.copy()
        filtered_df = filtered_df[filtered_df["filamentID"].isin(selected_filaments)]
        if drophashes:
            try:
                columns_to_drop = list(set(["filamentID", "particleID", "filamentHash", "particleHash", "classID"]).intersection(filtered_df.columns))
                filtered_df = filtered_df.drop(columns = columns_to_drop)
            except Exception:
                print("Unable to drop hashes")
        write_particle_star(all_data, filtered_df, f"{output_path}_selected_filament", overwrite = overwrite)
    elif mode == "Particles":
        assert (IDadded_particle_df is not None and all_data is not None), "Not all required inputs provided"
        filtered_df = IDadded_particle_df.copy()
        filtered_df = filtered_df[filtered_df["filamentID"].isin(selected_filaments)&filtered_df["rlnClassNumber"].isin(selected_classes)]
        if drophashes:
            try:
                columns_to_drop = list(set(["filamentID", "particleID", "filamentHash", "particleHash", "classID"]).intersection(filtered_df.columns))
                filtered_df = filtered_df.drop(columns = columns_to_drop)
            except Exception:
                print("Unable to drop hashes")
        write_particle_star(all_data, filtered_df, f"{output_path}_selected", overwrite = overwrite)
    else:
        raise ValueError("Current mode is undefined. Please choose from 'Classes', 'Filaments', or 'Particles'")
    
    return filtered_df

# Functions for automated cluster selection, visualization, and output
def typecluster_compute_kmeans_labels(counts_df, col_cluster_number: int = 10, row_cluster_number: int = 10, random_seed: int = 365) -> (list, list): 
    '''
    Compute k-means labels for class averages (row_clusters) and unique filaments (col_clusters)
    Returns a list for each containing the cluster number of each
    '''
    countsmatrix = np.array(counts_df)
    # Compute k-means on rows without normalization, but could in theory do /np.sum(countsmatrix, axis=0, keepdims=True))
    row_labels = KMeans(n_clusters=row_cluster_number, random_state=random_seed, n_init = 'auto').fit_predict(countsmatrix) 
    # Column k-means done with normalization
    countsmatrix_colnormalized = countsmatrix/np.sum(countsmatrix, axis=0, keepdims=True)
    col_labels = KMeans(n_clusters=col_cluster_number, random_state=random_seed, n_init = 'auto').fit_predict(np.transpose(countsmatrix_colnormalized))
    return row_labels, col_labels

def typecluster_dendrogram_threshold_select(clustered: sns.matrix.ClusterGrid, threshold: float):
    '''
    Can try out various cosine distance thresholds here and visualize them on the filament (column) dendrogram
    '''
    fig, ax = plt.subplots(figsize = (40, 25))
    dendrogram(clustered.dendrogram_col.linkage, ax = ax, color_threshold = threshold, no_labels = True)
    
def typecluster_interactive_dendrogram_thresholding(clustered_df_name: str = "clustered") -> (widgets.FloatSlider, widgets.widget_box.HBox):
    '''
    SUPER hacky way to do interactive threshold selection on dendrograms
    Wish it was done in a much better way but wasted too much time trying to fix a widgets implementation
    Notably, second output needs to be hbox
    '''
    # Set up widgets
    threshold_slider = widgets.FloatSlider(min=0, max=1, step=0.001, description='Threshold:', readout_format='.3f')
    run_button = widgets.Button(description='Run')
    
    # Set up response
    def run_button_clicked(button):
        command1 = "IPython.display.clear_output()" # So we don't keep a bunch of plots
        exec(command1)
        command2 = "display(hbox)" # Have to regenerate sliders, and thus need to return it as well
        exec(command2)
        command3 = f"typecluster_dendrogram_threshold_select({clustered_df_name}, {threshold_slider.value})"
        exec(command3)
    
    # Make widget
    run_button.on_click(run_button_clicked)
    hbox = widgets.HBox([threshold_slider, run_button])
    display(hbox)
    return threshold_slider, hbox

def typecluster_compute_dendrogram_threshold_labels(clustered: sns.matrix.ClusterGrid, threshold: float, 
                                                    output_path: str, counts_df: pd.DataFrame, savefig: bool = True,
                                                    vmax: int = None, cmap: str = "inferno", standardize: int = None,
                                                    figsize_x: int = 25, figsize_y: int = 25, label_fontsize: int = 15, 
                                                    dpi: int = 300, dendrogram_ratio: float = 0.2, cbar_pos: tuple = (0.02, 0.81, 0.05, 0.17), 
                                                    panel_label: bool = False, panel_label_letter: str = "a",  panel_label_fontsize: int = 15
                                                   ) -> (list, sns.matrix.ClusterGrid):
    '''
    Uses fcluster to find filament clusters below a maximum distance threshold (nominally cosine distance)
    Outputs cluster labels for filaments and the labeled clustermap; filaments per cluster are printed by convert_labels_to_colors
    Also plots labeled clusters on clustermap
    
    Params:
        clustered = Clustermap, e.g. from typecluster_initial_clustering
        threshold = float that is the maximum average cosine distance for a cluster
                    probably either a float threshold or threshold_slider.value
        counts_df = DataFrame containing counts, e.g. from typecluster_initial_clustering
        output_path = path name that will be appended to the front of anything written out (e.g. /TypeClassifer/group1)
        
    Otherwise, most params the same as `typecluster_initial_clustering` for visualization purposes
    '''
    col_clustered_labels = fcluster(clustered.dendrogram_col.linkage, t=threshold, criterion='distance')
    
    # Compute figure
    col_color_labels = convert_labels_to_colors(col_clustered_labels)
    output_pathused = output_path+f"_{threshold}threshold"
    labeled_clustered = clusterplot_counts_df(counts_df, filepath = output_pathused, savefig = savefig, dpi = dpi,
                                              row_linkage = clustered.dendrogram_row.linkage, col_linkage = clustered.dendrogram_col.linkage,
                                              col_colors = col_color_labels, vmax = vmax, cmap = cmap, standardize = standardize,
                                              figsize_x = figsize_x, figsize_y = figsize_y, label_fontsize = label_fontsize, 
                                              dendrogram_ratio = dendrogram_ratio, cbar_pos = cbar_pos, 
                                              panel_label = panel_label, panel_label_letter = panel_label_letter,  panel_label_fontsize = panel_label_fontsize)

    return col_clustered_labels, labeled_clustered

def convert_labels_to_colors(labels, color_dict: dict = None, plot_legend: bool = True, verbose: bool = True) -> list:
    '''
    Converts a list of labels (ints) to a list of colors either through a specified color_dict or by uniform sampling
    '''
    counted_labels = Counter(labels)
    if verbose:
        print(f"Number of filaments/classes per label: {counted_labels}")
    unique_labels = len(counted_labels.keys())
    
    if color_dict is not None:
        assert len(color_dict) >= unique_labels, "Not enough colors specified in provided color_dict"
        try:
            colors = [color_dict[label] for label in labels]
        except KeyError as err:
            raise KeyError("Unable to convert labels to colors using provided color_dict") from err
    else:
        colors = sns.color_palette("husl", unique_labels)
        color_dict = {label_id: colors[i] for i, label_id in enumerate(counted_labels.keys())}
        colors = [color_dict[label] for label in labels]
    
    if plot_legend:
        # Create subplots for each color
        num_rows = int(np.ceil(unique_labels/20))
        fig, axs = plt.subplots(num_rows, 20, figsize=(20, num_rows))

        # Flatten the axs array if it has multiple dimensions
        if isinstance(axs, np.ndarray):
            axs = axs.flatten()

        # Iterate over the color list and display each color in a subplot
        for i, label in enumerate(sorted(counted_labels.keys())):
            ax = axs[i]
            ax.set_facecolor(color_dict[label])
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_title(str(label))
        
        # Remove empty subplots if necessary
        if unique_labels < len(axs):
            for j in range(unique_labels, len(axs)):
                axs[j].axis('off')
        
        # Adjust spacing between subplots
        plt.subplots_adjust(hspace=0.5, wspace=0.1)
        plt.show()
    
    return colors

def typecluster_output_filaments_from_labels(col_labels: list, output_path: str, IDadded_particle_df: pd.DataFrame,
                                             all_data: collections.OrderedDict, particle_threshold: int = 1000, 
                                             overwrite: bool = True, drophashes: bool = True, verbose: bool = True
                                            ) -> pd.DataFrame:
    '''
    Writes out .star files containing particles from clusters with particle #s exceeding a threshold
    Clusters that fail this threshold are merged and output as one joint merged particle .star file
    Filament clusters specified by col_labels as a list, which come from k-means or dendrogram thresholding
    
    Params:
        col_labels = List of labels, probably from k-means or dendrogram thresholding
        output_path = path name that will be appended to the front of anything written out (e.g. /TypeClassifer/job009)
        IDadded_particle_df = DataFrame with filamentIDs, probably from typecluster_initial_clustering
        all_data = reading in from read_particle_star, template for output
        particle_threshold = minimum number of particles in a filament cluster to output it standalone (float or int)
        overwrite = boolean that controls whether it is okay to overwrite existing star files
        drophashes = boolean that controls whether to drop hashes when writing out star files
        verbose = boolean that controls verbosity of text output
    '''
    
    # Compute some stats
    counted_labels = Counter(col_labels)
    num_filaments = IDadded_particle_df['filamentID'].nunique()
    num_particles = len(IDadded_particle_df.index)
    if verbose:
        print(f"Number of filaments/classes per label: {counted_labels}")
    
    # Map particles to label
    labeled_particle_df = IDadded_particle_df.copy()
    filamentID_label_dict = {i: label for i, label in enumerate(col_labels)} # use the argument, not a notebook-level global
    if 'clusterLabel' in labeled_particle_df.columns:
        labeled_particle_df = labeled_particle_df.drop(columns = ["clusterLabel"])
    labeled_particle_df['clusterLabel'] = labeled_particle_df['filamentID'].map(filamentID_label_dict)
    
    # Iterate through clusters and extract particles
    failed_cluster_labels = []
    for cluster_label in sorted(counted_labels.keys()):
        filtered_df = labeled_particle_df.copy()
        filtered_df = filtered_df[filtered_df['clusterLabel']==cluster_label]
        filtered_particle_count = filtered_df.shape[0]
        
        # Write out if passing particle threshold, else, group with other clusters that failed
        if filtered_particle_count >= particle_threshold:
            if drophashes:
                try: # We keep cluster labels but could also drop in the future, it helps keep dimensions consistent
                    columns_to_drop = list(set(["filamentID", "particleID", "filamentHash", "particleHash", "classID"]).intersection(filtered_df.columns))
                    filtered_df = filtered_df.drop(columns = columns_to_drop)
                except Exception:
                    print("Unable to drop hashes")
            if verbose:
                print(f"{filtered_particle_count} particles from filament cluster {cluster_label}")
            write_particle_star(all_data, filtered_df, f"{output_path}_cluster{cluster_label}", overwrite = overwrite)
        else:
            failed_cluster_labels.append(cluster_label)
 
    # Deal with all of the failed clusters, if any
    num_failed_filaments = 0
    num_failed_particles = 0
    if failed_cluster_labels:
        failed_clusters_df = labeled_particle_df.copy()
        failed_clusters_df = failed_clusters_df[failed_clusters_df['clusterLabel'].isin(set(failed_cluster_labels))]
        num_failed_filaments = failed_clusters_df['filamentID'].nunique()
        num_failed_particles = len(failed_clusters_df.index)
        if verbose:
            print(f"Clusters {failed_cluster_labels} did not pass particle threshold of {particle_threshold}, so {num_failed_particles} particles will be merged together.")
        write_particle_star(all_data, failed_clusters_df, f"{output_path}_clustersfailed", overwrite = overwrite)
        
    # Output summary statistics and output relevant df
    if verbose:
        print(f"{num_failed_particles} particles out of {num_particles} total particles in merged clusters = {num_failed_particles/num_particles*100}%.")
        print(f"{num_failed_filaments} filaments out of {num_filaments} total filaments in merged clusters = {num_failed_filaments/num_filaments*100}%.")
    return labeled_particle_df

## Suggestions

Works best with manually extracted filaments rather than auto-picked filaments: auto-picked filaments are more likely to be wrong or short, in which case information about filament structure is lost.

It also works better if particles have not been randomly split up by a selection task, so that all particles in any given filament are retained. It can run with a lot of particles as long as the number of filaments and the number of clusters are relatively low (<50,000 for fast performance on a Mac), since run time scales as (# of filaments choose 2).
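
To get a feel for the (# of filaments choose 2) scaling, the number of pairwise distances can be counted directly. This is only a rough sketch: the helper name `n_pairwise_distances` is made up for illustration, and the memory figure assumes that every pairwise distance is stored once as a 64-bit float, which may not match what the clustering does internally.

```python
from math import comb

def n_pairwise_distances(n_filaments: int) -> int:
    """Number of unique filament pairs (hypothetical helper, not part of FiTSuite)."""
    return comb(n_filaments, 2)

for n in (1_000, 10_000, 50_000):
    pairs = n_pairwise_distances(n)
    # Assumption: one float64 (8 bytes) per stored pairwise distance
    print(f"{n:>6} filaments -> {pairs:,} pairs (~{pairs * 8 / 1e9:.2f} GB as float64)")
```

Going from 10,000 to 50,000 filaments multiplies the number of pairs by roughly 25, which is why the filament count matters more than the particle count.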

The selection and output functions can be repeated as many times as needed; just copy and paste the relevant commands.

Suggest using automated selection and output on the first high-level 2D classification job, and iterating with extraction/2D classification until the data looks homogeneous. Interactive selection/output can be used throughout to visualize or output specific filament types/data that you care about.
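
As a toy illustration of what the automated (dendrogram threshold) selection does, the sketch below clusters four made-up filaments by the cosine distance between their class-count profiles and cuts the tree at a threshold with `fcluster`. The numbers are invented, filaments are placed on rows here for brevity (in this notebook they are the columns of `counts_df`), and average linkage is an assumption rather than a statement about what `clusterplot_counts_df` uses.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

# Made-up counts: rows = filaments, columns = 2D classes
toy_counts = np.array([
    [10, 1,  0],   # filament 0: mostly class 1
    [ 9, 2,  0],   # filament 1: mostly class 1
    [ 0, 1, 12],   # filament 2: mostly class 3
    [ 1, 0, 11],   # filament 3: mostly class 3
])

# Hierarchical clustering on cosine distance (average linkage is an assumption)
toy_linkage = linkage(toy_counts, method="average", metric="cosine")

# Cut the tree: filaments joined below the threshold share a label
toy_labels = fcluster(toy_linkage, t=0.5, criterion="distance")
print(toy_labels)
```

Filaments 0 and 1 end up with one label and filaments 2 and 3 with another, because the two groups only merge at a cosine distance close to 1. Lowering the threshold splits the data into more, smaller clusters.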

For LMB users: It is typically easiest to mount directly onto /cephfs or /cephfs2 and read/write into there.

# For you to run!

You should be all set to start. You can always call `help(function_of_your_choice)` on any function for a more detailed listing of its parameter options.

In [None]:
# Key Parameters:
particle_star = "../ExampleData/SL251-2_job007_run_it045_data.star" # Path to data.star file from Class2D job
model_star = "../ExampleData/SL251-2_job007_run_it045_model.star" # Path to model.star file from Class2D job
class_mrcs = "../ExampleData/SL251-2_job007_run_it045_classes.mrcs" # Path to classes.mrcs file from Class2D job
output = "../ExampleOutput/SL251-2_job007_run_it045" # Desired output path that will be appended to all outputs
savefig = True # Controls whether or not to save the clustered heatmap figure
savestar = True # Controls whether or not to save the reordered class averages 
saveclassaverageplot = True # Controls whether or not to save the class average figures
overwrite = False # Controls whether it is okay to overwrite existing .star files with the same name

clustered, counts_df, IDadded_particle_df, all_data, classcount, classIDdict, filamentcount, filamentIDdict = typecluster_initial_clustering(particle_star, output_path = output, savefig = savefig, vmax = 30)
reordered_classes_df = reorder_class_average_star_by_clustered(model_star, clustered, classIDdict, output_path = output, savestar = savestar, overwrite = overwrite)
plot_class_averages(class_mrcs, classes_df = reordered_classes_df, savefig = saveclassaverageplot, filepath = output)

In [None]:
'''
CELL 1 OF 2 FOR INTERACTIVE SELECTION/OUTPUT, COPY AND PASTE AS NEEDED

Note: indexes from 0 and we don't actually need x-end > x-start or y-end > y-start
The start and end coordinate x-y pairs are connected though
Selections include the boxes where the dotted outlines appear
Can save an image of your selection with the camera-looking "Download plot as a png" button
'''

sliders = typecluster_interactive_select(counts_df, clustered)
# We can keep track of our selected group here: group1, x: 860-976 y: 20-28

In [None]:
'''
CELL 2 OF 2 FOR INTERACTIVE SELECTION/OUTPUT, COPY AND PASTE AS NEEDED
'''

# Key Parameters:
selection_output = output + "_group1" # Path to write out selection to 
saveclassaverageplot = True # Controls whether or not to save the class average figures
overwrite = False # Controls whether it is okay to overwrite existing .star files with the same name
# These functions also require the parameters and data structures from the initial clustering cell

# Comment out any of the lines below that you do not want to output!
print([slider.value for slider in sliders]) # Just to remember what we selected in the previous cell
typecluster_interactive_output(sliders, "Filaments", selection_output, clustered, counts_df, IDadded_particle_df = IDadded_particle_df, all_data = all_data, model_star_file = model_star, classIDdict = classIDdict, overwrite = overwrite)
typecluster_interactive_output(sliders, "Particles", selection_output, clustered, counts_df, IDadded_particle_df = IDadded_particle_df, all_data = all_data, model_star_file = model_star, classIDdict = classIDdict, overwrite = overwrite)
selected_interactive_class_averages_df = typecluster_interactive_output(sliders, "Classes", selection_output, clustered, counts_df, model_star_file = model_star, classIDdict = classIDdict, overwrite = overwrite)
plot_class_averages(class_mrcs, classes_df = selected_interactive_class_averages_df, savefig = saveclassaverageplot, filepath = selection_output)
selected_interactive_class_averages_df 

In [None]:
'''
CELL 1 OF 3 FOR AUTOMATED SELECTION/OUTPUT - DENDROGRAM THRESHOLDING

Don't mess with the naming of the variables for the output of this function!
Choose a threshold to try, then hit run. Can repeat until it looks good
'''

threshold_slider, hbox = typecluster_interactive_dendrogram_thresholding("clustered") 

# Looks like 0.84 is a reasonable threshold for this data for the first round of reclassification

In [None]:
'''
CELL 2 OF 3 FOR AUTOMATED SELECTION/OUTPUT - DENDROGRAM THRESHOLDING
'''

# Params:
output = output
savefig = True # Controls whether or not to save the colored clustered heatmap figure

col_clustered_labels, labeled_clustered = typecluster_compute_dendrogram_threshold_labels(clustered, threshold_slider.value, output, counts_df, savefig = savefig, vmax = 30)

In [None]:
'''
CELL 2.5 OF 3 FOR AUTOMATED SELECTION/OUTPUT - DENDROGRAM THRESHOLDING (OPTIONAL)

Dimensionality reduction plot; if used at all, run it after cell 2 of 3
'''

# Params:
savefig = True
output = output
mode = "UMAP" # choose from "UMAP", "PCA", and "t-SNE"

col_clustered_colors = convert_labels_to_colors(col_clustered_labels)
dimensionality_reduction_plot(counts_df, mode = mode, savefig = savefig, filepath = output, col_colors = col_clustered_colors)
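
Turning cluster labels into plot colors, as `convert_labels_to_colors` is used for above, amounts to a lookup from label to color. The sketch below is a stand-in under that assumption; the real function's palette and return type are not shown in this notebook, and `labels_to_colors` is a hypothetical name.

```python
def labels_to_colors(labels, palette=("tab:blue", "tab:orange", "tab:green", "tab:red")):
    """Map each distinct label to a color, cycling through the palette if needed."""
    unique = sorted(set(labels))
    lookup = {label: palette[i % len(palette)] for i, label in enumerate(unique)}
    return [lookup[label] for label in labels]


# One color per input label, in the same order as the labels
print(labels_to_colors([2, 0, 2, 1]))
# -> ['tab:green', 'tab:blue', 'tab:green', 'tab:orange']
```

Items that share a label always receive the same color, which is what makes the colored strips next to a heatmap readable.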

In [None]:
'''
CELL 3 OF 3 FOR AUTOMATED SELECTION/OUTPUT - DENDROGRAM THRESHOLDING
'''

# Params:
output = output
particle_threshold = 1000 # Minimum number of particles per cluster; clusters below this threshold will be merged
overwrite = False # Controls whether it is okay to overwrite existing .star files with the same name

labeled_particle_df = typecluster_output_filaments_from_labels(col_clustered_labels, output, IDadded_particle_df, all_data, verbose = True, particle_threshold = particle_threshold, overwrite = overwrite)
labeled_particle_df
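
The comment on `particle_threshold` says only that clusters below the threshold "will be merged". One plausible reading, assumed here and not confirmed by the notebook, is that all undersized clusters are pooled into a single leftover group. The function name `merge_small_clusters` and the `merged_label` value are hypothetical.

```python
from collections import Counter


def merge_small_clusters(labels, particle_counts, particle_threshold, merged_label=-1):
    """Relabel filaments whose cluster holds fewer particles than the threshold.

    labels: cluster label per filament
    particle_counts: number of particles per filament (same order as labels)
    """
    totals = Counter()
    for label, n in zip(labels, particle_counts):
        totals[label] += n
    small = {label for label, total in totals.items() if total < particle_threshold}
    return [merged_label if label in small else label for label in labels]


# Cluster 0 has 1400 particles; clusters 1 and 2 have 300 and 350
labels = [0, 0, 1, 2, 2]
particle_counts = [800, 600, 300, 150, 200]
merged = merge_small_clusters(labels, particle_counts, particle_threshold=1000)
print(merged)  # -> [0, 0, -1, -1, -1]
```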

In [None]:
'''
CELL 0 OF 2 FOR AUTOMATED SELECTION/OUTPUT - K-MEANS (OPTIONAL)

Can use this cell if an initial clustering has not been performed
'''
particle_df, all_data = read_particle_star(particle_star)
hashed_particle_df = hash_particle_df(particle_df)
counts_df, IDadded_particle_df, classcount, classIDdict, filamentcount, filamentIDdict = compute_particle_counts_df(hashed_particle_df)
counts_df
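
Conceptually, the counts table built above records how many particles of each filament fall into each class average. The sketch below shows that tally in plain Python. It assumes rows are class averages and columns are filaments (inferred from the `row_cluster_number` / `col_cluster_number` arguments in the k-means cell), and it is not the implementation of `compute_particle_counts_df`. The name `count_particles` is hypothetical.

```python
from collections import Counter


def count_particles(particles):
    """Tally particles per (class, filament) pair.

    particles: iterable of (class_id, filament_id) tuples, one per particle.
    Returns sorted class ids, sorted filament ids, and a row-major count matrix.
    """
    pair_counts = Counter(particles)
    class_ids = sorted({c for c, _ in pair_counts})
    filament_ids = sorted({f for _, f in pair_counts})
    # Counter returns 0 for pairs that never occur
    matrix = [[pair_counts[(c, f)] for f in filament_ids] for c in class_ids]
    return class_ids, filament_ids, matrix


# Six hypothetical particles from two filaments assigned to two classes
particles = [(1, "A"), (1, "A"), (2, "A"), (2, "B"), (2, "B"), (2, "B")]
class_ids, filament_ids, matrix = count_particles(particles)
print(matrix)  # -> [[2, 0], [1, 3]]
```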

In [None]:
'''
CELL 1 OF 2 FOR AUTOMATED SELECTION/OUTPUT - K-MEANS
'''

# Params:
class_average_cluster_number = 5 # Number of clusters to use for k-means on class averages
filament_cluster_number = 10 # Number of clusters to use for k-means on filaments
clustering_computed = True # Set to True if the initial clustering was already computed, otherwise False; controls whether the heatmap is plotted
kmeansoutput = output + "kmeans"
savefig = True # Controls whether or not to save the colored clustered heatmap figure

row_kmeans_labels, col_kmeans_labels = typecluster_compute_kmeans_labels(counts_df, col_cluster_number = filament_cluster_number, row_cluster_number = class_average_cluster_number)
if clustering_computed:
    row_colors = convert_labels_to_colors(row_kmeans_labels)
    col_colors = convert_labels_to_colors(col_kmeans_labels)
    kmeans_clusterplot = clusterplot_counts_df(
        counts_df, savefig = savefig, filepath = kmeansoutput, vmax = 30, figsize_x = 20, figsize_y = 20,
        label_fontsize = 20, cbar_pos = (0.1, 0.81, 0.05, 0.17), panel_label = True,
        panel_label_fontsize = 25, row_colors = row_colors, col_colors = col_colors,
        row_linkage = clustered.dendrogram_row.linkage, col_linkage = clustered.dendrogram_col.linkage)
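
For readers unfamiliar with k-means, the labelling step can be illustrated with a small pure-Python version of Lloyd's algorithm. This is a teaching sketch only: `typecluster_compute_kmeans_labels` is not shown in this notebook and presumably relies on a library implementation. The deterministic "first k points" initialisation and the toy data are assumptions made for the example.

```python
def kmeans(points, k, iterations=20):
    """Lloyd's algorithm on a list of equal-length numeric tuples.

    Centroids start at the first k points so the result is deterministic.
    Returns one integer label per point.
    """
    centroids = [tuple(p) for p in points[:k]]
    labels = [0] * len(points)
    for _ in range(iterations):
        # Assignment step: nearest centroid by squared Euclidean distance
        labels = [
            min(range(k), key=lambda j: sum((a - b) ** 2 for a, b in zip(p, centroids[j])))
            for p in points
        ]
        # Update step: move each centroid to the mean of its members
        for j in range(k):
            members = [p for p, label in zip(points, labels) if label == j]
            if members:
                centroids[j] = tuple(sum(col) / len(members) for col in zip(*members))
    return labels


# Hypothetical per-filament count profiles over two class averages
profiles = [(9, 0), (0, 9), (8, 1), (1, 8), (10, 0)]
print(kmeans(profiles, k=2))  # -> [0, 1, 0, 1, 0]
```

Filaments dominated by the same class average end up with the same label, which is the grouping the colored heatmap is meant to reveal.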

In [None]:
'''
CELL 1.5 OF 2 FOR AUTOMATED SELECTION/OUTPUT - K-MEANS (OPTIONAL)

Optional, but a dimensionality reduction plot is worth making if the initial clustering was not computed
Run after cell 1 of 2
'''

# Params:
savefig = True
kmeansoutput = output + "kmeans"
mode = "UMAP" # choose from "UMAP", "PCA", and "t-SNE"

row_colors = convert_labels_to_colors(row_kmeans_labels)
col_colors = convert_labels_to_colors(col_kmeans_labels)
dimensionality_reduction_plot(counts_df, mode = mode, savefig = savefig, filepath = kmeansoutput, row_colors = row_colors, col_colors = col_colors)

In [None]:
'''
CELL 2 OF 2 FOR AUTOMATED SELECTION/OUTPUT - K-MEANS
'''

# Params:
kmeansoutput = output + "kmeans" #  Path for output
particle_threshold = 1000 # Minimum number of particles per cluster; clusters below this threshold will be merged
overwrite = False # Controls whether it is okay to overwrite existing .star files with the same name

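# Use the k-means column (filament) labels here, not the dendrogram-threshold labels from the previous section
labeled_kmeans_particle_df = typecluster_output_filaments_from_labels(col_kmeans_labels, kmeansoutput, IDadded_particle_df, all_data, verbose = True, particle_threshold = particle_threshold, overwrite = overwrite)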
labeled_kmeans_particle_df