# Federated Protein Downstream Fine-tuning

<div class="alert alert-block alert-info"> <b>NOTE</b> This notebook was tested on a single A1000 GPU and is compatible with BioNeMo Framework v2.3. To leverage additional or higher-performance GPUs, you can modify the configuration files and simulation script to accommodate multiple devices and increase thread utilization respectively.</div>

The example datasets used here are made available by [Therapeutics Data Commons](https://tdcommons.ai/) through PyTDC.

This example shows three different downstream tasks for fine-tuning a BioNeMo ESM-style model on different datasets.
We separate the scripts and job configurations into three folders based on the dataset names:


1. `tap`: therapeutic antibody profiling
2. `sabdab`: SAbDab, the structural antibody database
3. `scl`: subcellular location prediction

## Setup

Ensure that you have read through the Getting Started section, can run the BioNeMo Framework docker container, and have configured the NGC Command Line Interface (CLI) within the container. It is assumed that this notebook is being executed from within the container.

<div class="alert alert-block alert-info"> <b>NOTE</b> Some of the cells below generate long text output. We use <pre>%%capture --no-display --no-stderr cell_output</pre> to suppress this output. Comment out or delete this line in the cells below to restore the full output.</div>

### Import and install all required packages

In [1]:
# %%capture --no-display --no-stderr cell_output
! pip install fuzzywuzzy PyTDC --no-dependencies  # install tdc without dependencies to avoid version conflicts in the BioNeMo container
! pip install nvflare~=2.5
#! pip install biopython
#! pip install scikit-learn
#! pip install matplotlib
#! pip install protobuf==3.20
#! pip install huggingface-hub==0.22.0

import os
import warnings

warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")


### Task 1: Cross-endpoint multi-task fitting

#### Data: Five computational developability guidelines for therapeutic antibody profiling
See https://tdcommons.ai/single_pred_tasks/develop/#tap
- 241 Antibodies (both chains)

#### Task Description: *Regression*. 
Given the antibody's heavy chain and light chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second is the light chain.

The dataset includes five metrics measuring the developability of an antibody:
 - Complementarity-determining region (CDR) length (trivial, excluded)
 - Patches of surface hydrophobicity (PSH)
 - Patches of positive charge (PPC)
 - Patches of negative charge (PNC)
 - Structural Fv charge symmetry parameter (SFvCSP)

In the data preparation script, one can choose between uniform sampling of the data among clients and
heterogeneous data splits using a Dirichlet sampling strategy. 
Here, the value of `alpha` controls the level of heterogeneity: lower values produce more heterogeneous splits. Below, we show uniform sampling next to Dirichlet sampling with `alpha=1.0`.
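To make the Dirichlet strategy concrete, here is a minimal sketch of one common way to implement it: draw one share per client from a symmetric Dirichlet distribution, then cut a shuffled index list according to those shares. This is an illustration only and not necessarily what `prepare_tap_data.py` does. The function name `dirichlet_split`, the seed, and the choice of six clients are assumptions made for the example.

```python
import numpy as np


def dirichlet_split(n_samples, n_clients, alpha, seed=0):
    """Split sample indices among clients with Dirichlet-distributed shares.

    Hypothetical helper for illustration; not taken from the example scripts.
    """
    rng = np.random.default_rng(seed)
    # One share per client; the shares are non-negative and sum to 1.
    proportions = rng.dirichlet(alpha * np.ones(n_clients))
    # Shuffle the sample indices so each client gets a random subset.
    indices = rng.permutation(n_samples)
    # Convert cumulative shares into cut points within the shuffled index list.
    cut_points = (np.cumsum(proportions)[:-1] * n_samples).astype(int)
    return np.split(indices, cut_points)


# 241 antibodies as in the TAP dataset; six clients is an assumed setting.
splits = dirichlet_split(n_samples=241, n_clients=6, alpha=1.0)
for site, idx in enumerate(splits, start=1):
    print(f"site-{site}: {len(idx)} samples (frac={len(idx) / 241:.3f})")
```

Every sample is assigned to exactly one client, so the client subsets are disjoint and together cover the full training set.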

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python prepare_tap_data.py

|                                Uniform sampling                                 |                                    Dirichlet sampling                                     |
|:-------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------:|
| <img src="./tap/figs/tap_uniform.svg" alt="Uniform data sampling" width="150"/> | <img src="./tap/figs/tap_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="150"/> |


**Run training (central, local, & FL)**

You can change the FL job that's going to be simulated inside the `run_sim_tap.py` script.

In [None]:
! cd /bionemo_nvflare_examples/downstream/tap && python run_sim_tap.py

### Task 2: Cross-compound task fitting

#### Data: Predicting Antibody Developability from Sequence using Machine Learning
See https://tdcommons.ai/single_pred_tasks/develop/#sabdab-chen-et-al
- 2,409 Antibodies (both chains)

#### Task Description: *Binary classification*. 
Given the antibody's heavy chain and light chain sequences, predict its developability. The input X is a list of two sequences, where the first is the heavy chain and the second is the light chain.

In [6]:
# you may need to fix these paths to your own scripts
! cd /bionemo_nvflare_examples/downstream/sabdab && python prepare_sabdab_data.py

Found local copy...
Loading...
Done!
Sampling with alpha=1.0
Save 80 training proteins for site-1 (frac=0.041)
Save 365 training proteins for site-2 (frac=0.190)
Save 216 training proteins for site-3 (frac=0.112)
Save 578 training proteins for site-4 (frac=0.300)
Save 568 training proteins for site-5 (frac=0.295)
Save 119 training proteins for site-6 (frac=0.062)
Saved 1927 training and 482 testing proteins.
  TRAIN Pos/Neg ratio: neg=366, pos=1561: 4.265
  TRAIN Trivial accuracy: 0.810
  TEST Pos/Neg ratio: neg=116, pos=366: 3.155
  TEST Trivial accuracy: 0.759
[[       nan 0.04657534 0.02314815 0.0449827  0.04929577 0.03361345]
 [       nan        nan 0.18055556 0.17128028 0.19542254 0.18487395]
 [       nan        nan        nan 0.11591696 0.10211268 0.08403361]
 [       nan        nan        nan        nan 0.28521127 0.32773109]
 [       nan        nan        nan        nan        nan 0.28571429]
 [       nan        nan        nan        nan        nan        nan]]
Avg. overlap: 14
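
The "Trivial accuracy" values in the output above appear to match the accuracy of a baseline that always predicts the majority class. The short check below recomputes them from the printed class counts. The helper name `majority_class_accuracy` is introduced here for illustration and is not part of the data preparation script.

```python
def majority_class_accuracy(n_neg, n_pos):
    """Accuracy obtained by always predicting the more frequent class."""
    return max(n_neg, n_pos) / (n_neg + n_pos)


# Class counts copied from the data preparation output above.
train_acc = majority_class_accuracy(n_neg=366, n_pos=1561)
test_acc = majority_class_accuracy(n_neg=116, n_pos=366)

print(f"TRAIN trivial accuracy: {train_acc:.3f}")
print(f"TEST trivial accuracy: {test_acc:.3f}")
```

This baseline is a useful reference point when reading accuracy numbers on an imbalanced dataset such as this one.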

Again, we are using the Dirichlet sampling strategy to generate heterogeneous data distributions among clients.
Lower values of `alpha` generate higher levels of heterogeneity.
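
The effect of `alpha` can be checked numerically. The sketch below draws many sets of client shares from a symmetric Dirichlet distribution and measures how far the shares spread around the uniform share. It is a standalone illustration: the function name `mean_share_std`, the number of draws, and the seed are assumptions, and it does not reproduce the exact splits used in this notebook.

```python
import numpy as np


def mean_share_std(alpha, n_clients=6, n_draws=2000, seed=0):
    """Average standard deviation of client shares over many Dirichlet draws."""
    rng = np.random.default_rng(seed)
    # Each row is one draw of client shares that sums to 1.
    shares = rng.dirichlet(alpha * np.ones(n_clients), size=n_draws)
    # Spread of the shares within each draw, averaged over all draws.
    return shares.std(axis=1).mean()


std_alpha10 = mean_share_std(10.0)
std_alpha1 = mean_share_std(1.0)

print(f"alpha=10.0: mean share std = {std_alpha10:.3f}")
print(f"alpha=1.0:  mean share std = {std_alpha1:.3f}")
```

A larger spread for `alpha=1.0` means that client dataset sizes differ more strongly from one another than for `alpha=10.0`.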

|                                            Alpha 10.0                                             |                                            Alpha 1.0                                            |
|:-------------------------------------------------------------------------------------------------:|:-----------------------------------------------------------------------------------------------:|
| <img src="./sabdab/figs/sabdab_alpha10.0.svg" alt="Dirichlet sampling (alpha=10.0)" width="150"/> | <img src="./sabdab/figs/sabdab_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="150"/> |


**Run training (central, local, & FL)**

You can change the FL job to be simulated by changing the arguments of the `run_sim_sabdab.py` script. The ESM2 fine-tuning arguments, such as the learning rate, can be modified inside the script itself.

First check its arguments.

In [9]:
!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --help

Traceback (most recent call last):
  File "/bionemo_nvflare_examples/downstream/sabdab/run_sim_sabdab.py", line 15, in <module>
    from nvflare.job_config.script_runner import BaseScriptRunner
ModuleNotFoundError: No module named 'nvflare'


**1. Central training**

To simulate central training, we use one client, running one round of training for several steps. Note that if the `--exp_name` argument contains `"central"`, the combined training dataset is used.

In [10]:
!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=1 --num_rounds=1 --local_steps=300 --exp_name central



**2. Local training**

To simulate local training, we use six clients, each running one round of training for several steps on its own data.

In [11]:
!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=6 --num_rounds=1 --local_steps=300 --exp_name local



**3. FedAvg training**

To simulate federated training, we use six clients, running several rounds with FedAvg, each with a smaller number of local steps.
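
For readers unfamiliar with FedAvg, the aggregation step can be sketched as a weighted average of client model weights, where each client is weighted by its number of training samples. This is a generic illustration of the algorithm, not NVFlare's implementation. The function name `fedavg`, the parameter name `"w"`, and the toy values are assumptions made for the example.

```python
import numpy as np


def fedavg(client_weights, client_sizes):
    """Average client parameters, weighting each client by its sample count.

    Generic sketch of the FedAvg aggregation step; names are hypothetical.
    """
    total = float(sum(client_sizes))
    averaged = {}
    for name in client_weights[0]:
        # Weighted sum of this parameter across all clients.
        averaged[name] = sum(
            weights[name] * (size / total)
            for weights, size in zip(client_weights, client_sizes)
        )
    return averaged


# Two toy clients, each holding a single parameter vector "w".
clients = [{"w": np.array([1.0, 2.0])}, {"w": np.array([3.0, 4.0])}]
global_weights = fedavg(clients, client_sizes=[100, 300])
print(global_weights["w"])
```

With the settings used below, `--num_rounds=30` and `--local_steps=10` give each client 300 local steps in total, the same step count as in the local training run.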

In [12]:
!cd /bionemo_nvflare_examples/downstream/sabdab && python run_sim_sabdab.py --num_clients=6 --num_rounds=30 --local_steps=10 --exp_name fedavg



#### Results with heterogeneous data sampling (alpha=10.0)
| Setting | Accuracy  |
|:-------:|:---------:|
|  Local  |   0.821   |
|   FedAvg    | **0.833** |

#### Results with heterogeneous data sampling (alpha=1.0)
| Setting | Accuracy  |
|:-------:|:---------:|
|  Local  |   0.813   |
|   FedAvg    | **0.835** |

### Task 3: Subcellular location prediction with ESM2nv 650M
Follow the data download and preparation in [task_fitting.ipynb](../task_fitting/task_fitting.ipynb).

Here, we use a heterogeneous sampling with `alpha=1.0`.

<img src="./scl/figs/scl_alpha1.0.svg" alt="Dirichlet sampling (alpha=1.0)" width="300"/>


In [None]:
# for this to work run the task_fitting notebook first in ../nvflare_with_bionemo/task_fitting/task_fitting.ipynb
! cd /bionemo_nvflare_examples/downstream/scl && python run_sim_scl.py

Note that you can switch between local and FL jobs by modifying the `run_sim_scl.py` script.

#### Results with heterogeneous data sampling (alpha=10.0)
| Setting | Accuracy  |
|:-------:|:---------:|
|  Local  |   0.773   |
|   FedAvg    | **0.776** |


<img src="./scl/figs/scl_results.svg" alt="Subcellular location prediction results" width="300"/>