
# GENBAIT Tutorial

This notebook provides a detailed walkthrough of the GENBAIT Python package, which selects optimal baits for BioID experiments using a genetic algorithm (GA) and evaluates the selections with various clustering metrics. We will guide you through loading BioID data, running the genetic algorithm for bait selection, and evaluating the results using NMF-derived and non-NMF-derived metrics.



## Importing required modules

In [4]:
import genbait as gb
import warnings
warnings.filterwarnings('ignore')



## Setting Input, Output Directories, and Parameters for NMF and GA

In this section, we define the input and output directories for file handling, as well as key parameters for Non-negative Matrix Factorization (NMF) and the Genetic Algorithm (GA) used in feature (bait) selection.

- **input_directory**: This variable defines the path where the input files, such as the SAINT file and other data, are stored.
- **output_directory**: This variable specifies the location where the output files will be saved after processing, including the results of the genetic algorithm and NMF computations.

- **primary_baits**: A list of primary baits that can be used to filter the dataset. If set to `None`, all baits from the dataset will be used. You can uncomment this line to provide specific bait proteins.

- **n_components**: Specifies the number of NMF components (latent features). This parameter controls how many underlying factors will be extracted from the BioID dataset. Each component typically corresponds to a group of preys sharing a biological context, such as a subcellular compartment.

- **n_baits_to_select**: Defines the number of baits (features) to be selected by the genetic algorithm. The goal of the GA is to find an optimal subset of baits that retain the biological significance of the dataset.

- **subset_range**: This is the range of the number of baits that can be selected by the genetic algorithm. It is calculated based on the `n_baits_to_select` value to ensure that the number of selected baits remains within a valid range (neither too many nor too few).
  - Function: `calculate_subset_range`
  - Input: `n_baits_to_select` (the desired number of baits to select).
  - Output: A tuple representing the valid range of baits that the GA can select.


In [5]:
input_directory = 'data/input_files/'  
output_directory = 'data/output_files/'  

# primary_baits = ['Gene1', 'Gene2', 'Gene3'] 

n_components = 20  
n_baits_to_select = 50  

subset_range = gb.calculate_subset_range(n_baits_to_select)  

## Loading and Normalizing the BioID Data

In this step, we load and preprocess the BioID data using the `load_bioid_data` function. This function reads the SAINT file and performs preprocessing steps such as calculating average control intensities, correcting for control values, and normalizing the data.

- **Function**: `load_bioid_data`
- **Parameters**:
  - `input_file_path`: The path to the SAINT file (located in the `input_directory`).
  - `output_file_directory`: The directory where the processed and normalized data will be saved. The normalized DataFrame will also be stored as a CSV file for further use.
  
- **Process**:
  - **Step 1**: The SAINT file is read and baits (proteins) are processed.
  - **Step 2**: Controls are averaged, and prey spectral counts are adjusted by subtracting these control averages.
  - **Step 3**: The data is filtered by a False Discovery Rate (FDR) threshold, pivoted to a matrix format (baits as rows, preys as columns), and normalized using MinMax scaling.

- **Output**: 
  - Returns the normalized DataFrame (`df_norm`), which can be used in subsequent analyses such as feature selection and clustering.



In [6]:
df_norm = gb.load_bioid_data(
    input_file_path=f'{input_directory}saint-latest.txt',
    output_file_directory=output_directory
)

In [7]:
df_norm

Bait/PreyGene,AAAS,AAK1,AAR2,AARS2,AASDH,AASS,AATF,ABCA3,ABCB1,ABCB10,...,ZNF830,ZNF850,ZNF91,ZRANB2,ZSCAN18,ZSCAN21,ZW10,ZWINT,ZYX,ZZZ3
AARS2,0.0,0.000000,0.0,0.0,0.0,0.347015,0.000000,0.0,0.0,0.266667,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0
ACBD5,0.0,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0
ACTB,0.0,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.384215,0.0
ACTR1A,0.0,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0
ACTR2,0.0,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
VCL,0.0,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.278706,0.0
VIM,0.0,0.123211,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0
ZFPL1,0.0,0.000000,0.0,0.0,0.0,0.000000,0.000000,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.000000,0.0,0.000000,0.0,0.0,0.000000,0.0
ZNF330_target,0.0,0.000000,0.0,0.0,0.0,0.000000,0.381833,0.0,0.0,0.000000,...,0.0,0.0,0.0,0.210237,0.0,0.579048,0.0,0.0,0.000000,0.0
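The preprocessing steps described above can be sketched in plain pandas. This is a minimal illustration under stated assumptions, not GENBAIT's implementation: it assumes SAINTexpress-style columns (`Bait`, `PreyGene`, `AvgSpec`, `ctrlCounts` holding pipe-separated control counts, and `BFDR`), a BFDR cutoff of 0.01, and per-prey (column-wise) MinMax scaling. The helper names `mean_of_pipe_separated` and `normalize_saint_like` are hypothetical.

```python
import pandas as pd


def mean_of_pipe_separated(value) -> float:
    """Average a pipe-separated count string, e.g. "2|0|1" -> 1.0."""
    counts = [float(x) for x in str(value).split("|")]
    return sum(counts) / len(counts)


def normalize_saint_like(saint: pd.DataFrame, fdr_threshold: float = 0.01) -> pd.DataFrame:
    """Illustrative control correction, FDR filtering, pivoting and MinMax scaling.

    Assumption: column names follow SAINTexpress output; GENBAIT may differ.
    """
    df = saint.copy()
    # Step 2: subtract the average control count from each prey's average spectral count
    df["ctrl_avg"] = df["ctrlCounts"].apply(mean_of_pipe_separated)
    df["corrected"] = (df["AvgSpec"] - df["ctrl_avg"]).clip(lower=0)
    # Step 3a: keep only interactions passing the FDR threshold
    df = df[df["BFDR"] <= fdr_threshold]
    # Step 3b: baits as rows, preys as columns; absent interactions become 0
    matrix = df.pivot_table(index="Bait", columns="PreyGene", values="corrected",
                            aggfunc="max", fill_value=0)
    # Step 3c: MinMax-scale each prey column to [0, 1]; constant columns map to 0
    col_min = matrix.min(axis=0)
    col_range = (matrix.max(axis=0) - col_min).replace(0, 1)
    return (matrix - col_min) / col_range


# Toy SAINT-like table (hypothetical values); B3's only interaction fails the FDR cutoff
saint = pd.DataFrame({
    "Bait": ["B1", "B1", "B2", "B2", "B3"],
    "PreyGene": ["P1", "P2", "P1", "P3", "P2"],
    "AvgSpec": [10.0, 4.0, 6.0, 8.0, 2.0],
    "ctrlCounts": ["2|2", "0|3", "2|2", "1|1", "0|0"],
    "BFDR": [0.0, 0.0, 0.005, 0.0, 0.2],
})
toy_norm = normalize_saint_like(saint)
print(toy_norm)
```

Note that MinMax scaling maps the smallest value in each prey column to 0, so a bait with weak but real signal for a prey can end up at 0 after normalization.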


## Running the Genetic Algorithm

In this step, we use the `run_ga` function to run a genetic algorithm (GA) for selecting an optimal subset of features (baits) from the BioID dataset. The GA evolves a population of possible solutions over multiple generations to find the best subset of baits.

- **Function**: `run_ga`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `n_components`: The number of NMF components used in the analysis.
  - `subset_range`: The range of the number of selected baits (features).
  - `population_size`: The number of individuals in the GA population. In this example, it is set to 500.
  - `n_generations`: The number of generations for which the GA will run. In this example, it is set to 100.
  - `cxpb`: Probability of crossover. Default is 0.3.
  - `mutpb`: Probability of mutation. Default is 0.1.

- **Returns**:
  - `pop`: The final population of individuals after the GA has evolved over the specified generations.
  - `logbook`: A logbook that contains statistical information about the evolution process, including metrics like the average fitness, minimum fitness, and maximum fitness for each generation.
  - `hof`: The Hall of Fame, which contains the best individuals (i.e., subsets of baits) discovered during the GA run.




In [None]:
pop, logbook, hof = gb.run_ga(
    df_norm=df_norm,
    n_components=n_components,
    subset_range=subset_range,
    population_size=500,
    n_generations=100
)
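To make the GA loop concrete, here is a small, self-contained sketch of subset selection with a genetic algorithm in NumPy. It is not GENBAIT's implementation: the function `toy_subset_ga`, tournament selection, one-point crossover, single-bit mutation, and the toy fitness function are all illustrative assumptions. Since `run_ga` takes `n_components`, GENBAIT's fitness presumably involves NMF on the selected baits; here the fitness simply rewards picking a hypothetical set of "informative" baits. `cxpb`, `mutpb`, and `subset_range` play the same roles as the parameters described above.

```python
import numpy as np


def toy_subset_ga(n_features, fitness, subset_range, population_size=60,
                  n_generations=40, cxpb=0.3, mutpb=0.1, seed=4):
    """Minimal GA over boolean bait masks; returns (best_mask, best_fitness)."""
    rng = np.random.default_rng(seed)
    lo, hi = subset_range
    # Each individual is a boolean mask over baits (True = bait selected),
    # initialized so that the expected subset size is mid-range
    p_select = (lo + hi) / (2 * n_features)
    pop = rng.random((population_size, n_features)) < p_select

    def score(mask):
        # Subsets outside the allowed size range get the worst possible fitness
        k = int(mask.sum())
        return fitness(mask) if lo <= k <= hi else -np.inf

    best_mask, best_fit = None, -np.inf
    for generation in range(n_generations + 1):
        fits = np.array([score(ind) for ind in pop])
        # Keep the best individual seen so far (a one-slot "hall of fame")
        i = int(np.argmax(fits))
        if fits[i] > best_fit:
            best_fit, best_mask = float(fits[i]), pop[i].copy()
        if generation == n_generations:
            break  # the final population is evaluated but not bred
        # Tournament selection: each slot is filled by the fittest of 3 random individuals
        contenders = rng.integers(0, population_size, size=(population_size, 3))
        winners = contenders[np.arange(population_size), np.argmax(fits[contenders], axis=1)]
        children = pop[winners].copy()
        # One-point crossover on consecutive pairs with probability cxpb
        for j in range(0, population_size - 1, 2):
            if rng.random() < cxpb:
                cut = int(rng.integers(1, n_features))
                tail = children[j, cut:].copy()
                children[j, cut:] = children[j + 1, cut:]
                children[j + 1, cut:] = tail
        # Mutation: with probability mutpb, flip one randomly chosen bait in an individual
        for j in range(population_size):
            if rng.random() < mutpb:
                flip = int(rng.integers(0, n_features))
                children[j, flip] = not children[j, flip]
        pop = children
    return best_mask, best_fit


# Hypothetical fitness: reward the first 10 "informative" baits, penalize the others
informative = np.zeros(30, dtype=bool)
informative[:10] = True


def toy_fitness(mask):
    return float((mask & informative).sum() - 0.5 * (mask & ~informative).sum())


best_mask, best_fit = toy_subset_ga(30, toy_fitness, subset_range=(8, 12))
print(int(best_mask.sum()), best_fit)
```

Assigning the worst fitness to out-of-range subsets is one simple way to enforce `subset_range`; a real implementation might instead repair or penalize such individuals.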

## Saving GA results
Here, we save the results of the GA using the `save_ga_results` function, which saves the population, logbook, and Hall of Fame to specified file paths.

- **Parameters**:
  - `population`: The final population after the GA run.
  - `logbook`: The logbook tracking the evolution of the population.
  - `hof`: Hall of Fame containing the best individuals found by the GA.
  - `pop_file_path`, `logbook_file_path`, `hof_file_path`: Paths to save the files.





In [None]:
gb.save_ga_results(
    population=pop,
    logbook=logbook,
    hof=hof,
    pop_file_path=f"{output_directory}popfile.pkl",
    logbook_file_path=f"{output_directory}logbookfile.pkl",
    hof_file_path=f"{output_directory}hoffile.pkl"
)

## Loading GA results
Here, we load the previously saved population, logbook, and Hall of Fame from their respective file paths. This lets you rerun the evaluation steps without repeating the GA run.

In [8]:
population, logbook, hof = gb.load_ga_results(
    pop_file_path=f"{output_directory}popfile.pkl",
    logbook_file_path=f"{output_directory}logbookfile.pkl",
    hof_file_path=f"{output_directory}hoffile.pkl"
)
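The `.pkl` extension suggests these files are Python pickles. Assuming that, the save/load round trip amounts to the sketch below. `save_results` and `load_results` are hypothetical helpers, not GENBAIT functions, and the stored objects are toy stand-ins for the real population, logbook, and Hall of Fame.

```python
import os
import pickle
import tempfile


def save_results(objects: dict, directory: str) -> None:
    # Write each object to its own .pkl file, e.g. {"pop": ...} -> directory/pop.pkl
    for name, obj in objects.items():
        with open(os.path.join(directory, f"{name}.pkl"), "wb") as fh:
            pickle.dump(obj, fh)


def load_results(names, directory: str) -> dict:
    # Read the same files back, keyed by name
    loaded = {}
    for name in names:
        with open(os.path.join(directory, f"{name}.pkl"), "rb") as fh:
            loaded[name] = pickle.load(fh)
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    save_results({"pop": [[0, 1, 1]], "logbook": [{"gen": 0, "max": 0.9}], "hof": [[1, 1, 0]]}, tmp)
    restored = load_results(["pop", "logbook", "hof"], tmp)
print(restored["hof"])
```

Unpickling can execute code embedded in the file, so only load result files you trust. Any custom classes stored in a pickle must also be importable when it is loaded.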

## Calculating GA results

In this step, we use the `calculate_ga_results` function to retrieve the best subset of baits selected by the genetic algorithm and to calculate the NMF correlation between the full dataset and the selected subset. The function returns the selected baits together with summary and per-component NMF correlation values.

- **Function**: `calculate_ga_results`
- **Parameters**:
  - `hof`: The Hall of Fame from the genetic algorithm, which contains the best individuals (subsets of baits).
  - `df_norm`: The normalized BioID data (DataFrame).
  - `n_components`: The number of NMF components used in the analysis.
  
- **Returns**:
  - `selected_baits`: The list of baits selected by the genetic algorithm as the optimal subset.
  - `mean_nmf_correlation`: The mean correlation between the NMF basis matrix of the original dataset and that of the subset dataset.
  - `min_nmf_correlation`: The minimum correlation between the NMF basis matrix of the original dataset and that of the subset dataset, i.e. the worst-preserved component.
  - `all_nmf_correaltions`: An array of correlation values, one per NMF component.




In [9]:
selected_baits, mean_nmf_correlation, min_nmf_correlation, all_nmf_correaltions = gb.calculate_ga_results(hof, df_norm, n_components)

In [10]:
selected_baits

['ACBD5',
 'ACTR3',
 'AIFM1',
 'AKAP1_target',
 'ANAPC2',
 'ARF6',
 'CBX3',
 'CENPA',
 'CKAP4',
 'COIL',
 'CS',
 'CYP2C1_sigseq',
 'DCTN1',
 'DDX23',
 'DHX8',
 'EMD',
 'FLOT1',
 'HIST1H2BG',
 'HSD17B11',
 'KRAS_target',
 'KRT19',
 'LAMP1',
 'LAMP2',
 'LAMTOR1_target_peptide',
 'LRRC59_Nterm',
 'LYN',
 'MAPRE3',
 'NCBP3',
 'NDC80',
 'NOP56',
 'OCLN',
 'PCM1',
 'PDHA1',
 'PDIA4',
 'PXMP2',
 'RAB11A',
 'RAB35',
 'RAB4A',
 'RPL31',
 'RPN1',
 'RPS20',
 'RPS6',
 'SEC61B_Cterm',
 'SFXN1',
 'STX7',
 'SYNE3_Cterm',
 'SYNE3_Nterm',
 'TERF2IP',
 'TRIM36_Nterm',
 'ZFPL1']

In [11]:
mean_nmf_correlation

0.9628075922637978

In [12]:
min_nmf_correlation

0.9025194439128281

In [13]:
all_nmf_correaltions

array([0.90251944, 0.95496961, 0.97857123, 0.95774577, 0.97135343,
       0.95147357, 0.96787425, 0.9518188 , 0.96214486, 0.93408249,
       0.9544204 , 0.99657838, 0.99440013, 0.98931542, 0.95406028,
       0.96171887, 0.98038172, 0.95886556, 0.93932722, 0.99453041])
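One plausible way to compute a per-component NMF correlation is sketched below: factorize the full and subset matrices separately, match components one-to-one, and correlate the matched prey-side component vectors. This is an assumption about the general approach, not GENBAIT's code; the `nndsvd` initialization, the Hungarian matching via `scipy.optimize.linear_sum_assignment`, and the helper name `nmf_component_correlations` are illustrative choices. Because removing baits removes rows but keeps every prey column, the prey-side matrices (`components_`) of both factorizations have the same shape and can be compared directly.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.decomposition import NMF


def nmf_component_correlations(X_full, X_subset, n_components, seed=4):
    """Pearson correlation between matched prey-side NMF components of two bait sets."""
    def prey_components(X):
        model = NMF(n_components=n_components, init="nndsvd", max_iter=1000, random_state=seed)
        model.fit(X)
        return model.components_  # shape: (n_components, n_preys)

    H_full, H_sub = prey_components(X_full), prey_components(X_subset)
    # Cross-correlation block: rows = full components, columns = subset components
    corr = np.corrcoef(H_full, H_sub)[:n_components, n_components:]
    # Match components one-to-one so that the total correlation is maximized
    rows, cols = linear_sum_assignment(-corr)
    matched = corr[rows, cols]
    return matched.mean(), matched.min(), matched


rng = np.random.default_rng(0)
X = rng.random((30, 40))  # toy matrix: 30 baits x 40 preys
mean_corr, min_corr, per_component = nmf_component_correlations(X, X[:20], n_components=3)
print(round(mean_corr, 3), round(min_corr, 3))
```

Matching is needed because NMF does not return components in a canonical order, so component 0 of one factorization need not correspond to component 0 of the other.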

## Calculating NMF Cosine similarity

In this step, we use the `calculate_nmf_cosine_similarity` function to compute the cosine similarity between the NMF components derived from the full dataset and those derived from the selected-bait subset. Higher values indicate that the selected baits better reproduce the component structure of the original data.

- **Function**: `calculate_nmf_cosine_similarity`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The subset of baits selected by the genetic algorithm.
  - `n_components`: The number of NMF components used in the analysis.
  
- **Returns**:
  - `mean_nmf_cos_similarity_score`: The mean cosine similarity score between the NMF components of the original dataset and the selected subset.
  - `min_nmf_cos_similarity_score`: The minimum cosine similarity score between the NMF components of the original dataset and the selected subset.
  - `all_nmf_cos_similarity_scores`: A list of cosine similarity scores for each NMF component, showing the similarity for individual components.



In [14]:
mean_nmf_cos_similarity_score, min_nmf_cos_similarity_score, all_nmf_cos_similarity_scores = gb.calculate_nmf_cosine_similarity(df_norm, selected_baits, n_components)

In [15]:
mean_nmf_cos_similarity_score

0.9647851800497926

In [16]:
min_nmf_cos_similarity_score

0.9114302775417854

In [17]:
all_nmf_cos_similarity_scores

array([0.91143028, 0.95803041, 0.98029793, 0.96046514, 0.97242552,
       0.95618445, 0.96943639, 0.9526346 , 0.962459  , 0.93820777,
       0.95752685, 0.99673303, 0.99489931, 0.98997593, 0.95592155,
       0.96211448, 0.98075699, 0.96157342, 0.93981945, 0.99481111])
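Cosine similarity between matched components can be sketched as follows. The helper `matched_cosine_similarity` is hypothetical and assumes its inputs are prey-side NMF matrices (components by preys) from the full and subset factorizations; the tiny hand-written matrices make the matching step visible. Because NMF components are non-negative, these similarities fall between 0 and 1.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def matched_cosine_similarity(H_full, H_sub):
    """Cosine similarity between one-to-one matched rows (components) of two matrices."""
    A = H_full / np.linalg.norm(H_full, axis=1, keepdims=True)
    B = H_sub / np.linalg.norm(H_sub, axis=1, keepdims=True)
    sim = A @ B.T  # sim[i, j] = cosine(full component i, subset component j)
    rows, cols = linear_sum_assignment(-sim)  # maximize total similarity
    matched = sim[rows, cols]
    return matched.mean(), matched.min(), matched


H_full = np.array([[1.0, 0.0], [0.0, 1.0]])
H_sub = np.array([[0.0, 1.0], [1.0, 1.0]])
mean_cos, min_cos, per_component = matched_cosine_similarity(H_full, H_sub)
print(per_component)  # full component 0 pairs with subset component 1, and vice versa
```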

## Calculating NMF KL divergence


In this step, we use the `calculate_nmf_kl_divergence` function to compute the Kullback-Leibler (KL) divergence between the NMF components of the original data and those of the selected-bait subset. KL divergence measures how one probability distribution diverges from a second, reference distribution; it is zero for identical distributions and is not symmetric. Unlike the correlation and cosine similarity metrics, lower values indicate closer agreement between the original data and the subset selected by the genetic algorithm.

- **Function**: `calculate_nmf_kl_divergence`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the GA.
  - `n_components`: The number of NMF components used in the analysis.

- **Returns**:
  - `mean_nmf_kl_divergence_score`: The average KL divergence score across all NMF components.
  - `max_nmf_kl_divergence_score`: The maximum KL divergence score across all NMF components, i.e. the worst-preserved component.
  - `all_nmf_kl_divergence_scores`: The individual KL divergence scores for each component.


In [18]:
mean_nmf_kl_divergence_score, max_nmf_kl_divergence_score, all_nmf_kl_divergence_scores = gb.calculate_nmf_kl_divergence(df_norm, selected_baits, n_components)

In [19]:
mean_nmf_kl_divergence_score

1.9866790073854574

In [20]:
max_nmf_kl_divergence_score

4.277723462622161

In [21]:
all_nmf_kl_divergence_scores

array([2.6238254 , 1.67745437, 0.66506218, 0.90228065, 1.12757156,
       1.44977502, 1.93827632, 3.78036364, 3.25671706, 2.52029273,
       2.20218745, 0.52032607, 0.30324526, 0.67319525, 3.17328875,
       3.82603123, 2.05850216, 1.90574664, 4.27772346, 0.85171493])
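As a sketch under stated assumptions: treat each matched NMF component as a distribution over preys, normalize it to sum to 1, and compute KL(full || subset) with `scipy.stats.entropy`, which uses the natural logarithm by default. The direction of the divergence, the small `eps` that avoids division by zero, the assumption that rows are already matched, and the helper name `matched_kl_divergences` are not details confirmed by GENBAIT.

```python
import numpy as np
from scipy.stats import entropy


def matched_kl_divergences(H_full, H_sub, eps=1e-10):
    """KL(full || subset) per component, assuming rows are already matched one-to-one."""
    scores = np.array([
        # entropy(p, q) normalizes both vectors to sum to 1 and returns sum(p * log(p / q))
        entropy(p + eps, q + eps)
        for p, q in zip(H_full, H_sub)
    ])
    return scores.mean(), scores.max(), scores


H_full = np.array([[1.0, 1.0], [2.0, 2.0]])
H_sub = np.array([[1.0, 3.0], [4.0, 4.0]])
mean_kl, max_kl, per_component = matched_kl_divergences(H_full, H_sub)
print(per_component)  # the second pair has the same shape after normalization, so KL is ~0
```

Because each vector is normalized first, KL divergence compares the shape of a component's prey profile and ignores its overall scale.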

## Calculating NMF GO Jaccard Index

In this step, we use the `calculate_nmf_go_jaccard` function to compute the Jaccard Index between the Gene Ontology (GO) Cellular Component (CC) terms associated with the original data and the selected baits. The Jaccard Index is used to measure the similarity between two sets, specifically the overlap between GO terms of the original and selected baits.

- **Function**: `calculate_nmf_go_jaccard`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `n_components`: The number of NMF components used in the analysis.

- **Returns**:
  - `mean_nmf_go_jaccard_index`: The mean Jaccard index score across all NMF components.
  - `min_nmf_go_jaccard_index`: The minimum Jaccard index score across all NMF components.
  - `all_nmf_go_jaccard_indices`: A dictionary mapping each NMF component index to its Jaccard index score.



In [22]:
mean_nmf_go_jaccard_index, min_nmf_go_jaccard_index, all_nmf_go_jaccard_indices = gb.calculate_nmf_go_jaccard(df_norm, selected_baits, n_components=n_components)


In [23]:
mean_nmf_go_jaccard_index

0.7278483823159123

In [24]:
min_nmf_go_jaccard_index

0.4074074074074074

In [25]:
all_nmf_go_jaccard_indices

{6: 0.8297872340425532,
 4: 0.9558823529411765,
 2: 0.8529411764705882,
 9: 0.7209302325581395,
 5: 0.7586206896551724,
 0: 0.436046511627907,
 8: 0.7205882352941176,
 3: 0.9642857142857143,
 7: 0.4074074074074074,
 1: 0.6111111111111112,
 12: 0.7727272727272727,
 10: 0.5535714285714286,
 11: 0.6461538461538462,
 15: 0.7023809523809523,
 13: 0.8153846153846154,
 17: 0.6439393939393939,
 16: 0.8,
 19: 0.75,
 14: 0.7540983606557377,
 18: 0.8611111111111112}
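The Jaccard index itself is simple: the size of the intersection of two sets divided by the size of their union. The sketch below applies it per component to hypothetical GO:CC term sets. How GENBAIT derives each component's term set is not shown here, so the dictionaries are illustrative only.

```python
def jaccard_index(a: set, b: set) -> float:
    """len(a & b) / len(a | b); defined here as 1.0 when both sets are empty."""
    union = a | b
    return len(a & b) / len(union) if union else 1.0


# Hypothetical GO:CC term sets per NMF component, keyed by component index
terms_full = {
    0: {"GO:0005730", "GO:0005634"},  # nucleolus, nucleus
    1: {"GO:0005739"},                # mitochondrion
}
terms_subset = {
    0: {"GO:0005634", "GO:0005654"},  # nucleus, nucleoplasm
    1: {"GO:0005739"},                # mitochondrion
}

per_component = {c: jaccard_index(terms_full[c], terms_subset[c]) for c in terms_full}
print(per_component)  # component 0 shares 1 of 3 terms; component 1 matches exactly
```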

## Calculating NMF ARI (Adjusted Rand Index)

In this step, we use the `calculate_nmf_ari` function to compute the Adjusted Rand Index (ARI) between the NMF-based clustering of the original data and that of the selected-bait subset. ARI measures the agreement between two clusterings of the same items, corrected for chance: a value of 1 indicates identical partitions, and values near 0 indicate chance-level agreement.

- **Function**: `calculate_nmf_ari`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `n_components`: The number of NMF components used in the analysis.

- **Returns**:
  - `nmf_ari_score`: The ARI score indicating how well the selected baits' clustering matches the clustering of the original data.



In [26]:
nmf_ari_score = gb.calculate_nmf_ari(df_norm, selected_baits, n_components)

In [27]:
nmf_ari_score

0.707807499705288
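A common way to turn NMF output into a hard clustering is to assign each prey to its highest-weighted component; whether GENBAIT does exactly this is an assumption, and `dominant_component_labels` is a hypothetical helper. The sketch also shows why no component matching is needed for ARI: `sklearn.metrics.adjusted_rand_score` is invariant to permutations of the cluster labels.

```python
import numpy as np
from sklearn.metrics import adjusted_rand_score


def dominant_component_labels(H):
    """Assign each prey (column of H) to the NMF component with the largest weight."""
    return np.argmax(H, axis=0)


# Toy prey-side matrices: 2 components x 4 preys
H_full = np.array([[5.0, 4.0, 0.0, 1.0],
                   [0.0, 1.0, 3.0, 2.0]])
H_sub = np.array([[0.0, 1.0, 2.0, 3.0],
                  [6.0, 5.0, 0.0, 0.0]])
labels_full = dominant_component_labels(H_full)  # [0, 0, 1, 1]
labels_sub = dominant_component_labels(H_sub)    # [1, 1, 0, 0]
ari = adjusted_rand_score(labels_full, labels_sub)
print(ari)  # same partition with swapped labels
```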

## Calculating min NMF purity score

In this step, we use the `calculate_min_nmf_purity` function to evaluate how well the NMF components derived from a subset of selected baits preserve the original structure of the full dataset. This is done by measuring, for each NMF component, the fraction of preys that are consistently assigned to the same component in both the full and subset data. The minimum of these preservation values across components is returned as a conservative indicator of component stability.

- **Function**: `calculate_min_nmf_purity`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `n_components`: The number of NMF components used in the analysis.

- **Returns**:
  - `min_purity`: The minimum, across NMF components, of the fraction of preys assigned to the same component in both the full and subset data. Higher values indicate more stable components.

In [28]:
min_nmf_purity_score = gb.calculate_min_nmf_purity(df_norm, selected_baits, n_components)

In [29]:
min_nmf_purity_score

0.32235294117647056
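The purity idea can be sketched with hard prey labels from the full and subset data (for example, dominant NMF components). This is a hypothetical reading of the description above: for each full-data component, the most common subset label among that component's preys is taken as its match. `min_component_purity` is an illustrative name, not a GENBAIT function.

```python
from collections import Counter

import numpy as np


def min_component_purity(labels_full, labels_sub):
    """For each full-data component, the fraction of its preys sharing the most
    common subset label; returns the minimum over components."""
    labels_full = np.asarray(labels_full)
    labels_sub = np.asarray(labels_sub)
    purities = []
    for component in np.unique(labels_full):
        members = labels_sub[labels_full == component]
        most_common_count = Counter(members.tolist()).most_common(1)[0][1]
        purities.append(most_common_count / len(members))
    return min(purities)


# Component 0 keeps 2 of its 3 preys together; component 1 keeps all 3
purity = min_component_purity([0, 0, 0, 1, 1, 1], [2, 2, 1, 0, 0, 0])
print(purity)
```

Taking the minimum rather than the mean makes the score conservative: a single poorly preserved component drives it down.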

## Calculating Remaining Preys Percentage and Count

In this step, we use the `calculate_remaining_preys_percentage` function to compute the fraction and count of preys (proteins) that remain detected by the subset of baits chosen by the genetic algorithm. This helps assess how much of the original prey coverage the selected baits retain.

- **Function**: `calculate_remaining_preys_percentage`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `n_components`: The number of NMF components used in the analysis.

- **Returns**:
  - `remaining_preys_percentage`: The fraction of preys that remain detected by the selected baits, expressed between 0 and 1 (for example, 0.799 corresponds to about 80%).
  - `remaining_preys_count`: The total number of preys remaining in the selected subset.



In [30]:
remaining_preys_percentage, remaining_preys_count = gb.calculate_remaining_preys_percentage(df_norm, selected_baits, n_components)

In [31]:
remaining_preys_percentage

0.7990506329113924

In [32]:
remaining_preys_count

3535
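Here is a minimal pandas sketch of the remaining-prey calculation. It assumes that a prey "remains" when it has a non-zero value for at least one selected bait, and that the denominator counts preys detected by any bait. GENBAIT's exact definition may differ, and `remaining_preys` is a hypothetical helper.

```python
import pandas as pd


def remaining_preys(df, selected_baits):
    """Fraction and count of detected preys that still have signal in >= 1 selected bait."""
    detected_full = (df > 0).any(axis=0)                     # preys seen with any bait
    detected_sub = (df.loc[selected_baits] > 0).any(axis=0)  # preys seen with selected baits
    count = int(detected_sub.sum())
    return count / int(detected_full.sum()), count


toy = pd.DataFrame(
    [[0.5, 0.0, 0.0, 0.0],
     [0.0, 0.2, 0.0, 0.0],
     [0.0, 0.0, 0.9, 0.0]],
    index=["B1", "B2", "B3"], columns=["P1", "P2", "P3", "P4"],
)
fraction, count = remaining_preys(toy, ["B1", "B3"])
print(fraction, count)  # P2 is lost and P4 was never detected
```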

## Calculating GO Retrieval Percentage

In this step, we use the `calculate_go_retrieval_percentage` function to calculate the Gene Ontology (GO) retrieval percentage for the selected baits. GO retrieval evaluates how many of the preys are associated with known GO Cellular Component (GO:CC) terms, based on the annotations in the provided GAF (GO Annotation File).

- **Function**: `calculate_go_retrieval_percentage`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `gaf_path`: The file path to the GAF file, which contains GO annotations for various genes/proteins.

- **Returns**:
  - `go_retrieval_percentage`: The percentage of preys in the selected baits that are associated with known GO:CC terms.



In [33]:
go_retrieval_percentage = gb.calculate_go_retrieval_percentage(df_norm, selected_baits, gaf_path=f'{input_directory}goa_human.gaf')

In [34]:
go_retrieval_percentage

94.70046082949308
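For reference, a GAF 2.x file is tab-separated with 17 columns. Header lines start with `!`, column 3 holds the gene symbol, column 4 the qualifier (e.g. `NOT|located_in`), column 5 the GO ID, and column 9 the aspect (`C` for Cellular Component). The sketch below collects CC-annotated symbols from a tiny in-memory GAF and computes the percentage of preys with at least one CC annotation. The helpers and toy rows are hypothetical, and skipping `NOT` annotations is an assumption about how GENBAIT treats negated annotations. For a real file, pass `open(gaf_path)` instead of the `StringIO` object.

```python
import io


def cc_annotated_genes(gaf_handle):
    """Gene symbols with at least one non-negated GO Cellular Component annotation."""
    genes = set()
    for line in gaf_handle:
        if line.startswith("!"):  # GAF header/comment lines
            continue
        cols = line.rstrip("\n").split("\t")
        symbol, qualifier, aspect = cols[2], cols[3], cols[8]
        # Aspect "C" = Cellular Component; skip annotations qualified with NOT
        if aspect == "C" and "NOT" not in qualifier.split("|"):
            genes.add(symbol)
    return genes


def gaf_row(symbol, qualifier, go_id, aspect):
    # Build a 17-column GAF 2.x line with placeholder values (hypothetical toy data)
    cols = ["UniProtKB", "X0", symbol, qualifier, go_id, "PMID:0", "IDA", "", aspect,
            "", "", "protein", "taxon:9606", "20240101", "UniProt", "", ""]
    return "\t".join(cols) + "\n"


toy_gaf = io.StringIO(
    "!gaf-version: 2.2\n"
    + gaf_row("GENE1", "located_in", "GO:0005634", "C")       # nucleus
    + gaf_row("GENE2", "involved_in", "GO:0006915", "P")      # apoptotic process
    + gaf_row("GENE3", "NOT|located_in", "GO:0005739", "C")   # negated mitochondrion
)
annotated = cc_annotated_genes(toy_gaf)
preys = ["GENE1", "GENE2", "GENE3", "GENE4"]
percentage = 100 * sum(p in annotated for p in preys) / len(preys)
print(percentage)  # only GENE1 counts
```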

## Calculating Leiden ARI (Adjusted Rand Index)

In this step, we use the `calculate_leiden_ari` function to compare Leiden clusterings of the full dataset and of the selected-bait subset. The Leiden algorithm is a graph-based community detection method whose resolution parameter controls how many clusters are found. For each resolution, the function computes the Adjusted Rand Index (ARI) between the clustering of the full dataset and the clustering of the selected subset.

- **Function**: `calculate_leiden_ari`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `resolutions`: List of resolutions for the Leiden clustering algorithm. Higher resolutions lead to more clusters.
  - `seed`: (Optional) Random seed for reproducibility (default is 4).

- **Returns**:
  - `leiden_results`: A dictionary where the keys are resolution values, and the values are the ARI scores between the clustering of the original data and the selected baits.



In [35]:
leiden_results = gb.calculate_leiden_ari(df_norm, selected_baits, resolutions=[0.5, 1.0, 1.5])

In [36]:
for resolution, ari in leiden_results.items():
    print(f"Resolution: {resolution}, ARI: {ari}")

Resolution: 0.5, ARI: 0.6153919222091556
Resolution: 1.0, ARI: 0.6239641760362682
Resolution: 1.5, ARI: 0.6047543804319405


## Calculating GMM ARI (Adjusted Rand Index)

In this step, we use the `calculate_gmm_ari` function to compare Gaussian Mixture Model (GMM) clusterings of the full dataset and of the selected-bait subset. For each specified number of clusters, a GMM is fitted to both datasets, and the Adjusted Rand Index (ARI) between the two resulting cluster assignments is reported.

- **Function**: `calculate_gmm_ari`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `cluster_numbers`: List of cluster numbers to fit the Gaussian Mixture Model (GMM). Each number represents how many clusters are formed.
  - `seed`: (Optional) Random seed for reproducibility (default is 4).

- **Returns**:
  - `gmm_results`: A dictionary where the keys are cluster numbers, and the values are the ARI scores comparing the clustering of the full dataset and the selected baits subset for each cluster number.


In [37]:
gmm_results = gb.calculate_gmm_ari(df_norm, selected_baits, cluster_numbers=[15, 20, 25, 30])

In [38]:
for cluster_number, ari in gmm_results.items():
    print(f"Cluster Number: {cluster_number}, ARI: {ari}")

Cluster Number: 15, ARI: 0.3513392100332473
Cluster Number: 20, ARI: 0.4330585428372074
Cluster Number: 25, ARI: 0.4792579710872932
Cluster Number: 30, ARI: 0.3882589070019658
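The GMM comparison can be sketched with scikit-learn, assuming preys are clustered using their bait profiles as features, so the subset simply has fewer feature columns. The helper `gmm_ari_by_k` and the toy block-structured matrix are illustrative, not GENBAIT code. Fixing `random_state` keeps each fit reproducible, mirroring the `seed` parameter.

```python
import numpy as np
import pandas as pd
from sklearn.metrics import adjusted_rand_score
from sklearn.mixture import GaussianMixture


def gmm_ari_by_k(df, selected_baits, cluster_numbers, seed=4):
    """Cluster preys by their bait profiles (full vs. subset) and compare with ARI."""
    X_full = df.T.to_numpy()                     # preys x all baits
    X_sub = df.loc[selected_baits].T.to_numpy()  # preys x selected baits
    results = {}
    for k in cluster_numbers:
        labels_full = GaussianMixture(n_components=k, random_state=seed).fit_predict(X_full)
        labels_sub = GaussianMixture(n_components=k, random_state=seed).fit_predict(X_sub)
        results[k] = adjusted_rand_score(labels_full, labels_sub)
    return results


rng = np.random.default_rng(0)
# Toy data: 6 baits x 60 preys, with three blocks of preys enriched on different bait pairs
toy = pd.DataFrame(rng.random((6, 60)) * 0.1, index=[f"B{i}" for i in range(6)])
for block in range(3):
    toy.iloc[2 * block:2 * block + 2, 20 * block:20 * block + 20] += 1.0
results = gmm_ari_by_k(toy, ["B0", "B2", "B4"], cluster_numbers=[2, 3, 4])
print(results)
```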


## Calculating GMM Mean Correlation

In this step, we use the `calculate_gmm_mean_correlation` function to compare the soft (probabilistic) GMM cluster memberships of the full dataset and of the selected-bait subset. For each specified number of clusters, a GMM is fitted to both datasets, and the function computes the mean correlation between the cluster membership probabilities from the original and subset data. Higher values indicate that the selected baits better preserve the clustering structure of the full dataset.

- **Function**: `calculate_gmm_mean_correlation`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `selected_baits`: The list of baits (features) selected by the genetic algorithm.
  - `cluster_numbers`: List of cluster numbers to fit the Gaussian Mixture Model (GMM). Each number represents how many clusters are formed.
  - `seed`: (Optional) Random seed for reproducibility (default is 4).

- **Returns**:
  - `gmm_corr_results`: A dictionary where the keys are cluster numbers and the values are the mean correlation values between the soft clustering of the full dataset and the selected baits subset for each cluster number.


  


In [39]:
gmm_corr_results = gb.calculate_gmm_mean_correlation(df_norm, selected_baits, cluster_numbers=[15, 20, 25, 30])

In [40]:
for cluster_number, mean_corr in gmm_corr_results.items():
    print(f"Cluster Number: {cluster_number}, Mean Correlation: {mean_corr}")

Cluster Number: 15, Mean Correlation: 0.3461583652175666
Cluster Number: 20, Mean Correlation: 0.3670970901887084
Cluster Number: 25, Mean Correlation: 0.39059547502128106
Cluster Number: 30, Mean Correlation: 0.339023289173151
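Soft clustering comparisons can be sketched similarly: fit a GMM to each dataset, take the membership probabilities from `predict_proba`, match clusters one-to-one, and average the correlations of the matched probability columns. The matching step and the helper name `gmm_soft_correlation` are assumptions about the general approach, not GENBAIT's code.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from sklearn.mixture import GaussianMixture


def gmm_soft_correlation(X_full, X_sub, k, seed=4):
    """Mean Pearson correlation between matched GMM membership-probability columns."""
    P_full = GaussianMixture(n_components=k, random_state=seed).fit(X_full).predict_proba(X_full)
    P_sub = GaussianMixture(n_components=k, random_state=seed).fit(X_sub).predict_proba(X_sub)
    # corr[i, j] = correlation of full-data cluster i with subset cluster j across preys
    corr = np.corrcoef(P_full.T, P_sub.T)[:k, k:]
    rows, cols = linear_sum_assignment(-corr)  # one-to-one matching, maximizing correlation
    return corr[rows, cols].mean()


rng = np.random.default_rng(0)
# Toy data: 60 preys x 4 baits, three well-separated groups of preys
centers = np.array([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]], dtype=float)
X_full = np.repeat(centers, 20, axis=0) + rng.normal(scale=0.05, size=(60, 4))
X_sub = X_full[:, :3]  # keep only the first three baits
r = gmm_soft_correlation(X_full, X_sub, k=3)
print(r)
```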


## Other bait selection methods

In this step, we use the `run_lasso` and `run_nn` functions to select baits with Lasso and a neural network, respectively.
- **Functions**: `run_lasso`, `run_nn`
- **Parameters**:
  - `df_norm`: The normalized BioID data (DataFrame).
  - `k`: Number of desired baits.
  - `seed`: (Optional) Random seed for reproducibility (default is 4).
  
  
- **Returns**:
  - `baits`: List of selected bait names (row indices from df_norm).


In [41]:
lasso_baits = gb.run_lasso(df_norm, 50)

In [44]:
lasso_baits

['SYNE3_Nterm',
 'COIL',
 'TERF2IP',
 'AKAP1_target',
 'RAB11A',
 'PCM1',
 'RPS20',
 'CYP2C1_sigseq',
 'SFXN1',
 'DDX23',
 'HIST1H2BG',
 'DCTN1',
 'DHX8',
 'FBL',
 'SV40_NLS',
 'MAPRE3',
 'NDC80',
 'SEC61B_Cterm',
 'CBX3',
 'PDHA1',
 'LAMTOR1_target_peptide',
 'RPL31',
 'B3GAT1',
 'LAMP2',
 'ACTB',
 'CD3EAP',
 'ASF1A',
 'AIFM1',
 'COX4I1',
 'METTL7A',
 'HNRNPA1',
 'LRRC59_Cterm',
 'COX8A_target',
 'LAMP1',
 'RPS24',
 'RPS6',
 'ANAPC2',
 'NIFK',
 'LAMTOR1',
 'KRAS_target',
 'PEX3',
 'CENPA',
 'GOLGA2',
 'PARP1',
 'EBAG9',
 'SEC61B_Nterm',
 'AKAP1',
 'RPN1',
 'MAPRE1',
 'KRAS_Q61H']

In [45]:
nn_baits = gb.run_nn(df_norm, 50)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name | Type   | Params | Mode 
----------------------------------------
0 | fc1  | Linear | 12.4 K | train
1 | relu | ReLU   | 0      | train
2 | fc2  | Linear | 1.3 K  | train
----------------------------------------
13.7 K    Trainable params
0         Non-trainable params
13.7 K    Total params
0.055     Total estimated model params size (MB)
3         Modules in train mode
0         Modules in eval mode



`Trainer.fit` stopped: `max_epochs=10` reached.
PermutationExplainer explainer: 3540it [00:42, 65.93it/s]                          


In [46]:
nn_baits

['CYP2C1_sigseq',
 'COIL',
 'TERF2IP',
 'PDHA1',
 'SYNE3_Nterm',
 'RAB11A',
 'COX8A_target',
 'RHOB',
 'LAMTOR1_target_peptide',
 'KRAS_Q61H',
 'KRAS_target',
 'AARS2',
 'HIST1H2BG',
 'CS',
 'AKAP1_target',
 'DDX23',
 'SV40_NLS',
 'CD3EAP',
 'LAMP2',
 'PARP1',
 'CBX3',
 'LAMTOR1',
 'STX7',
 'RAB9A',
 'LRRC59_Cterm',
 'SEC61B_Cterm',
 'METTL7A',
 'ZNF330_target',
 'RPL31',
 'GJD3',
 'TRAP1',
 'RPN1',
 'CEP135',
 'PCM1',
 'RAB35',
 'LYN',
 'RAB5A',
 'ASF1A',
 'ATP2A1',
 'ARF6',
 'LAMP3',
 'EBAG9',
 'C11orf52',
 'ACTB',
 'SEC61B_Nterm',
 'RAB4A',
 'B3GAT1',
 'LAMP1',
 'ELOVL5',
 'CENPA']