# BioNeMo ESM2 fine-tuning

Guide: [BioNeMo Framework 2.3 — ESM2 fine-tuning](https://docs.nvidia.com/bionemo-framework/2.3/user-guide/examples/bionemo-esm2/finetune/)

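**Startup**

Log in to the NGC container registry, then pull the image. At the login prompt, the username is the literal string `$oauthtoken` and the password is your NGC API key.

```bash
docker login nvcr.io
docker pull nvcr.io/nvidia/clara/bionemo-framework:2.3
```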

**Starting a shell inside the container**

```bash
docker run \
  --rm -it \
  --gpus all \
  --network host \
  --shm-size=4g \
  -e WANDB_API_KEY \
  -e NGC_CLI_API_KEY \
  -e NGC_CLI_ORG \
  -e NGC_CLI_TEAM \
  -e NGC_CLI_FORMAT_TYPE \
  -v $LOCAL_DATA_PATH:$DOCKER_DATA_PATH \
  -v $LOCAL_MODELS_PATH:$DOCKER_MODELS_PATH \
  -v $LOCAL_RESULTS_PATH:$DOCKER_RESULTS_PATH \
  nvcr.io/nvidia/clara/bionemo-framework:2.3
```


**Running a model training script inside the container**

```bash
docker run --rm -it --gpus all \
  -e NGC_CLI_API_KEY \
  -e WANDB_API_KEY \
  -v $LOCAL_DATA_PATH:$DOCKER_DATA_PATH \
  -v $LOCAL_MODELS_PATH:$DOCKER_MODELS_PATH \
  -v $LOCAL_RESULTS_PATH:$DOCKER_RESULTS_PATH \
  nvcr.io/nvidia/clara/bionemo-framework:2.3 \
  python $DOCKER_RESULTS_PATH/training.py --option1 --option2 --output=$DOCKER_RESULTS_PATH
```
  
**Running JupyterLab inside the container**

`--NotebookApp.token=''` disables token authentication, so only expose this port on a trusted network.

```bash
docker run --rm -d --gpus all \
  -p $JUPYTER_PORT:$JUPYTER_PORT \
  -e NGC_CLI_API_KEY \
  -e WANDB_API_KEY \
  -v $LOCAL_DATA_PATH:$DOCKER_DATA_PATH \
  -v $LOCAL_MODELS_PATH:$DOCKER_MODELS_PATH \
  -v $LOCAL_RESULTS_PATH:$DOCKER_RESULTS_PATH \
  nvcr.io/nvidia/clara/bionemo-framework:2.3 \
  jupyter lab \
    --allow-root \
    --ip=* \
    --port=$JUPYTER_PORT \
    --no-browser \
    --NotebookApp.token='' \
    --NotebookApp.allow_origin='*' \
    --ContentsManager.allow_hidden=True \
    --notebook-dir=$DOCKER_RESULTS_PATH
```


# BioNeMo ESM2 fine-tuning classes

Reference: [BioNeMo ESM2 fine-tuning (latest docs)](https://docs.nvidia.com/bionemo-framework/latest/main/examples/bionemo-esm2/finetune/)

1. **Loss reduction class**
2. **Fine-tuned model head**
3. **Fine-tuned model**
4. **Fine-tuning config**
5. **Dataset**

**Sequence-level regression**

Run this in a notebook cell. The leading `!` makes IPython run it as a shell command. IPython then fills the `{...}` placeholders from the Python variables `pretrain_checkpoint_path`, `regression_data_path`, and `work_dir`, which are defined in the cells below.

```
%%capture --no-display cell_output

! finetune_esm2 \
    --restore-from-checkpoint-path {pretrain_checkpoint_path} \
    --train-data-path {regression_data_path} \
    --valid-data-path {regression_data_path} \
    --config-class ESM2FineTuneSeqConfig \
    --dataset-class InMemorySingleValueDataset \
    --task-type "regression" \
    --mlp-ft-dropout 0.25 \
    --mlp-hidden-size 256 \
    --mlp-target-size 1 \
    --experiment-name "sequence-level-regression" \
    --num-steps 50 \
    --num-gpus 1 \
    --limit-val-batches 10 \
    --val-check-interval 10 \
    --log-every-n-steps 10 \
    --encoder-frozen \
    --lr 1e-5 \
    --lr-multiplier 1e2 \
    --scale-lr-layer "regression_head" \
    --result-dir {work_dir} \
    --micro-batch-size 4 \
    --label-column "labels" \
    --precision "bf16-mixed"
```

In [None]:
class RegressorLossReduction(BERTMLMLossWithReduction):
    """Sum-of-squared-errors loss for the sequence-level regression head.

    Compares ``forward_out["regression_output"]`` against ``batch["labels"]`` and
    returns ``(loss_sum, num_valid_tokens, {"loss_sum_and_ub_size": ...})``, where
    ``num_valid_tokens`` is the number of target elements in the micro-batch.
    """

    def forward(
        self, batch: Dict[str, torch.Tensor], forward_out: Dict[str, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, Dict[str, torch.Tensor]]:
        """Compute the squared-error sum between predictions and labels.

        Args:
            batch: micro-batch; ``batch["labels"]`` holds the regression targets.
            forward_out: model output; ``forward_out["regression_output"]`` holds
                the predictions.

        Returns:
            The loss sum (carries gradients), the number of target elements, and a
            dict whose ``"loss_sum_and_ub_size"`` entry is a detached
            ``[loss_sum, count]`` tensor.

        Raises:
            RuntimeError: if the labels do not have as many elements as the
                predictions and so cannot be reshaped to match them.
        """
        regression_output = forward_out["regression_output"]
        targets = batch["labels"].to(dtype=regression_output.dtype)
        # Align targets with predictions (e.g. [b] -> [b, 1]). Otherwise a
        # [b, 1] - [b] subtraction silently broadcasts to [b, b] and inflates the loss.
        targets = targets.reshape(regression_output.shape)
        num_valid_tokens = torch.tensor(targets.numel(), dtype=torch.int, device=targets.device)
        loss_sum = ((regression_output - targets) ** 2).sum()
        # Detached copy: this summary tensor carries no gradient.
        loss_sum_and_ub_size = torch.cat([loss_sum.clone().detach().view(1), num_valid_tokens.view(1)])
        return loss_sum, num_valid_tokens, {"loss_sum_and_ub_size": loss_sum_and_ub_size}

class MegatronMLPHead(MegatronModule):
    """MLP head with Linear layers of widths hidden_size -> mlp_hidden_size -> mlp_target_size.

    ReLU followed by dropout is applied after every layer except the last one.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__(config)
        layer_sizes = [config.hidden_size, config.mlp_hidden_size, config.mlp_target_size]
        self.linear_layers = torch.nn.ModuleList(
            [torch.nn.Linear(i, o) for i, o in zip(layer_sizes[:-1], layer_sizes[1:])]
        )
        self.act = torch.nn.ReLU()
        # The fine-tune config declares the dropout as `mlp_ft_dropout`. Reading only
        # `ft_dropout` raises AttributeError with that config, so prefer the new name
        # and fall back to `ft_dropout` for configs that expose only that one.
        dropout_p = getattr(config, "mlp_ft_dropout", None)
        if dropout_p is None:
            dropout_p = config.ft_dropout
        self.dropout = torch.nn.Dropout(p=dropout_p)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MLP and return the last layer's output as a single tensor."""
        for layer in self.linear_layers[:-1]:
            hidden_states = self.dropout(self.act(layer(hidden_states)))
        return self.linear_layers[-1](hidden_states)

class ESM2FineTuneSeqModel(ESM2Model):
    """ESM2 encoder with an MLP regression head on top of its embeddings.

    The head is created only when ``post_process`` is True. Its prediction is added
    to the parent's output under the ``"regression_output"`` key.
    """

    def __init__(self, config, *args, post_process: bool = True, include_embeddings: bool = False, **kwargs):
        # The head consumes output["embeddings"], so embeddings are always requested
        # from the parent and the caller's `include_embeddings` value is ignored.
        super().__init__(config, *args, post_process=post_process, include_embeddings=True, **kwargs)

        # Freeze the encoder. This runs before the head is created, so only the
        # parameters registered by the parent are frozen and the head stays trainable.
        if config.encoder_frozen:
            for _, param in self.named_parameters():
                param.requires_grad = False

        if post_process:
            self.regression_head = MegatronMLPHead(config)

    def forward(self, *args, **kwargs):
        """Run the parent forward pass and, when a head exists, add ``"regression_output"``.

        Returns:
            The parent's output. If this instance has a regression head, the output
            also gains ``"regression_output"``, computed from ``output["embeddings"]``.
        """
        output = super().forward(*args, **kwargs)
        # Built with post_process=False means no head was created. Return the parent's
        # output unchanged instead of raising AttributeError on self.regression_head.
        if not hasattr(self, "regression_head"):
            return output
        output["regression_output"] = self.regression_head(output["embeddings"])
        return output

@dataclass
class ESM2FineTuneSeqConfig(
    ESM2GenericConfig[ESM2FineTuneSeqModel, RegressorLossReduction], iom.IOMixinWithGettersSetters
):
    """Fine-tuning config that pairs ``ESM2FineTuneSeqModel`` with ``RegressorLossReduction``.

    Adds the fields the regression model and its MLP head read: ``encoder_frozen``
    and the ``mlp_*`` head sizes and dropout.
    """

    model_cls: Type[ESM2FineTuneSeqModel] = ESM2FineTuneSeqModel
    # The usual case fine-tunes a base checkpoint that has no regression head, so keys
    # with the "regression_head" prefix are skipped when the initial checkpoint is loaded.
    # If you load a checkpoint that already contains this head and want to keep its
    # weights, remove this line or set it to [].
    initial_ckpt_skip_keys_with_these_prefixes: List[str] = field(default_factory=lambda: ["regression_head"])
    encoder_frozen: bool = True  # when True, the model sets requires_grad=False on all encoder parameters
    # MLP head layer parameters
    mlp_ft_dropout: float = 0.25  # head dropout probability; NOTE(review): the head must read this exact attribute name
    mlp_hidden_size: int = 256  # width of the head's hidden layer
    mlp_target_size: int = 1  # number of regression outputs per sequence

    def get_loss_reduction_class(self) -> Type[BERTMLMLossWithReduction]:
        """Return the loss-reduction class used to train this model."""
        return RegressorLossReduction

class InMemorySingleValueDataset(InMemoryProteinDataset):
    """In-memory protein dataset whose label is a single scalar per sequence.

    Each label becomes a 1-element float tensor, so stacked labels are ``[b, 1]``.
    This assumes the parent collates samples by stacking (TODO confirm).
    """

    def __init__(
        self,
        sequences: pd.Series,
        labels: pd.Series,
        task_type: str = "regression",
        tokenizer: tokenizer.BioNeMoESMTokenizer = tokenizer.get_tokenizer(),
        seed: int = np.random.SeedSequence().entropy,
    ):
        # NOTE(review): both defaults above are evaluated once, when the class is defined.
        # Every instance built without explicit arguments therefore shares the same
        # tokenizer object and the same seed value.
        super().__init__(sequences, labels, task_type, tokenizer, seed)

    def transform_label(self, label: float | str) -> Tensor:
        """Return ``label`` as a 1-element ``torch.float`` tensor.

        Numeric strings (e.g. ``"0.42"``) are parsed with ``float`` first, because
        ``torch.tensor`` does not accept ``str`` values.

        Raises:
            ValueError: if ``label`` is a string that does not parse as a number.
        """
        return torch.tensor([float(label)], dtype=torch.float)


# NOTE(review): `data_path` is not defined anywhere in this notebook. Presumably it
# should point at the CSV written to `regression_data_path`; confirm before running.
dataset = InMemorySingleValueDataset.from_csv(data_path)

# This example uses the same dataset for both training and validation.
data_module = ESM2FineTuneDataModule(
    train_dataset=dataset,
    valid_dataset=dataset,
    micro_batch_size=4,  # batch processed on a single device
    global_batch_size=8,  # total batch across all devices; keep it a multiple of micro_batch_size
)

import os
import shutil
import warnings

import pandas as pd


# Silence warnings so the tutorial output stays readable.
warnings.filterwarnings("ignore")
warnings.simplefilter("ignore")

cleanup: bool = True
work_dir = "/workspace/bionemo2/esm2_finetune_tutorial"

# When cleanup is enabled, remove the results of any previous run first.
if cleanup and os.path.exists(work_dir):
    shutil.rmtree(work_dir)

if os.path.exists(work_dir):
    print(f"Directory '{work_dir}' already exists.")
else:
    os.makedirs(work_dir)
    print(f"Directory '{work_dir}' created.")

from bionemo.core.data.load import load


# Resolve the pretrained checkpoint tagged "esm2/8m:2.0". The returned path is what
# the finetune_esm2 command receives through --restore-from-checkpoint-path.
pretrain_checkpoint_path = load("esm2/8m:2.0")
print(pretrain_checkpoint_path)
# Rule of thumb for choosing the training length:
#   num_steps * global_batch_size = len(dataset) * desired_num_epochs

artificial_sequence_data = [
    "TLILGWSDKLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
    "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "GRFNVWLGGNESKIRQVLKAVKEIGVSPTLFAVYEKN",
    "DELTALGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "KLGSLLNQLAIANESLGGGTIAVMAERDKEDMELDIGKMEFDFKGTSVI",
    "LFGAIGNAISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
    "LGGLLHDIGKPVQRAGLYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "LYSGDHSTQGARFLRDLAENTGRAEYELLSLF",
    "ISAIHGQSAVEELVDAFVGGARISSAFPYSGDTYYLPKP",
    "SGSKASSDSQDANQCCTSCEDNAPATSYCVECSEPLCETCVEAHQRVKYTKDHTVRSTGPAKT",
]

# Toy regression target for the tutorial: each sequence's length divided by 100.
regression_data = list(
    zip(
        artificial_sequence_data,
        (len(sequence) / 100.0 for sequence in artificial_sequence_data),
    )
)

# Two-column table ("sequences", "labels"), saved as CSV without the index column.
df = pd.DataFrame(regression_data, columns=["sequences", "labels"])
regression_data_path = os.path.join(work_dir, "regression_data.csv")
df.to_csv(regression_data_path, index=False)

# naity's CAFA5 fine-tuned ESM2

Repository: [naity/finetune-esm](https://github.com/naity/finetune-esm)

Install the requirements before running the script; otherwise `--help` fails on missing imports.

```bash
git clone https://github.com/naity/finetune-esm.git
pip install -r finetune-esm/requirements.txt
python finetune-esm/train.py --help
```

Example LoRA training run:

```bash
python finetune-esm/train.py \
  --experiment-name esm2_t6_8M_UR50D_lora \
  --dataset-loc data/cafa5/top100_train_split.parquet \
  --targets-loc data/cafa5/train_bp_top100_targets.npy \
  --esm-model esm2_t6_8M_UR50D \
  --num-workers 1 \
  --num-devices 1 \
  --training-mode lora \
  --learning-rate 0.0001 \
  --num-epochs 5
```

Browse the logged runs with MLflow:

```bash
mlflow server --host 127.0.0.1 --port 8080 --backend-store-uri ./finetune_results/mlflow
```

# ESM3 Academy: tailoring output heads

Source: [Customizing ESM3 for specialized tasks](https://esm3academy.com/customizing-esm3-for-specialized-tasks/)

In [None]:
#Adding Classification Heads
class ClassificationModel(nn.Module):
    """Sequence-level classifier: a linear layer on the first-token embedding of ``esm_model``.

    Args:
        esm_model: backbone called as ``esm_model(tokens)``. Its output must hold
            ``outputs["representations"][0]`` indexable as ``[:, 0, :]``.
            NOTE(review): fair-esm models only fill ``representations`` for layers
            requested via ``repr_layers=``; confirm your backbone returns key 0.
        num_classes: number of output logits.
        embedding_dim: width of that representation. Defaults to 768, the value
            previously hard-coded; set it to match the backbone.
    """

    def __init__(self, esm_model, num_classes, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, tokens):
        outputs = self.esm(tokens)
        cls_embedding = outputs["representations"][0][:, 0, :]  # first (CLS/BOS) position
        return self.fc(cls_embedding)
    
#Adding Token Classification Heads

class TokenClassificationModel(nn.Module):
    """Per-residue classifier: one linear layer applied at every sequence position.

    Args:
        esm_model: backbone called as ``esm_model(tokens)``. Assumes
            ``outputs["representations"][0]`` is ``[batch, length, embedding_dim]``
            (TODO confirm for your backbone).
        num_classes: number of labels per residue.
        embedding_dim: width of the representation. Defaults to 768, the value
            previously hard-coded.
    """

    def __init__(self, esm_model, num_classes, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc = nn.Linear(embedding_dim, num_classes)  # residue-level labels

    def forward(self, tokens):
        outputs = self.esm(tokens)
        residue_embeddings = outputs["representations"][0]
        return self.fc(residue_embeddings)

#Adding Regression Heads

class RegressionModel(nn.Module):
    """Single-output regressor on the first-token embedding of ``esm_model``.

    Args:
        esm_model: backbone called as ``esm_model(tokens)``. Its output must hold
            ``outputs["representations"][0]`` indexable as ``[:, 0, :]``.
        embedding_dim: width of that representation. Defaults to 768, the value
            previously hard-coded.
    """

    def __init__(self, esm_model, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc = nn.Linear(embedding_dim, 1)  # single regression output

    def forward(self, tokens):
        outputs = self.esm(tokens)
        cls_embedding = outputs["representations"][0][:, 0, :]  # first (CLS/BOS) position
        return self.fc(cls_embedding)

#Multi-task model 

class MultiTaskModel(nn.Module):
    """Two linear heads sharing the first-token embedding of one backbone.

    Args:
        esm_model: backbone called as ``esm_model(tokens)``. Its output must hold
            ``outputs["representations"][0]`` indexable as ``[:, 0, :]``.
        num_classes_task1: logits produced by the first head.
        num_classes_task2: logits produced by the second head.
        embedding_dim: width of that representation. Defaults to 768, the value
            previously hard-coded.

    ``forward`` returns ``(task1_output, task2_output)``.
    """

    def __init__(self, esm_model, num_classes_task1, num_classes_task2, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc_task1 = nn.Linear(embedding_dim, num_classes_task1)
        self.fc_task2 = nn.Linear(embedding_dim, num_classes_task2)

    def forward(self, tokens):
        outputs = self.esm(tokens)
        cls_embedding = outputs["representations"][0][:, 0, :]  # first (CLS/BOS) position
        task1_output = self.fc_task1(cls_embedding)
        task2_output = self.fc_task2(cls_embedding)
        return task1_output, task2_output

# Function prediction

class FunctionClassifier(nn.Module):
    """Protein-function classifier: a linear layer on the first-token embedding.

    Args:
        esm_model: backbone called as ``esm_model(tokens)``. Its output must hold
            ``outputs["representations"][0]`` indexable as ``[:, 0, :]``.
        num_classes: number of function classes.
        embedding_dim: width of that representation. Defaults to 768, the value
            previously hard-coded.
    """

    def __init__(self, esm_model, num_classes, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, tokens):
        outputs = self.esm(tokens)
        cls_embedding = outputs["representations"][0][:, 0, :]  # first (CLS/BOS) position
        return self.fc(cls_embedding)

class StabilityPredictor(nn.Module):
    """Single-output stability regressor on the first-token embedding.

    Args:
        esm_model: backbone called as ``esm_model(tokens)``. Its output must hold
            ``outputs["representations"][0]`` indexable as ``[:, 0, :]``.
        embedding_dim: width of that representation. Defaults to 768, the value
            previously hard-coded.
    """

    def __init__(self, esm_model, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc = nn.Linear(embedding_dim, 1)  # single regression output

    def forward(self, tokens):
        outputs = self.esm(tokens)
        cls_embedding = outputs["representations"][0][:, 0, :]  # first (CLS/BOS) position
        return self.fc(cls_embedding)

class HybridModel(nn.Module):
    """Concatenate the first-token ESM embedding with CNN features, then project.

    Args:
        esm_model: sequence backbone. Its embedding width is read from
            ``embedding_dim``, or from ``embed_dim`` when that attribute is missing
            (fair-esm's ESM-2 models expose ``embed_dim``).
        cnn_model: structural encoder exposing ``output_dim``. Assumes it returns
            ``[batch, output_dim]`` features (TODO confirm).
        output_dim: width of the final prediction.
    """

    def __init__(self, esm_model, cnn_model, output_dim):
        super().__init__()
        self.esm = esm_model
        self.cnn = cnn_model
        esm_dim = getattr(esm_model, "embedding_dim", None)
        if esm_dim is None:
            esm_dim = esm_model.embed_dim
        self.fc = nn.Linear(esm_dim + cnn_model.output_dim, output_dim)

    def forward(self, sequence_tokens, structural_data):
        sequence_embeddings = self.esm(sequence_tokens)["representations"][0][:, 0, :]
        structural_features = self.cnn(structural_data)
        combined_features = torch.cat((sequence_embeddings, structural_features), dim=1)
        return self.fc(combined_features)

class GraphProteinModel(nn.Module):
    """Join per-residue ESM embeddings with GNN embeddings along dim 1, then project.

    The projection takes ``gnn_model.output_dim`` inputs. That only works if the
    concatenated tensor's last dimension equals that width.
    NOTE(review): this assumes both embeddings are ``[batch, items, D]`` with
    ``D == gnn_model.output_dim``, so graph nodes act as extra positions. Confirm
    this is the intended fusion.
    """

    def __init__(self, esm_model, gnn_model, output_dim):
        super().__init__()
        self.esm = esm_model
        self.gnn = gnn_model
        self.fc = nn.Linear(gnn_model.output_dim, output_dim)

    def forward(self, sequence_tokens, graph_data):
        per_residue = self.esm(sequence_tokens)["representations"][0]
        fused = torch.cat([per_residue, self.gnn(graph_data)], dim=1)
        return self.fc(fused)

class MutationalEffectModel(nn.Module):
    """Score a mutation from the difference of mutant and wild-type first-token embeddings.

    Args:
        esm_model: backbone called once per input. Its output must hold
            ``outputs["representations"][0]`` indexable as ``[:, 0, :]``.
        embedding_dim: width of that representation. Defaults to 768, the value
            previously hard-coded.

    ``forward(tokens_wt, tokens_mutant)`` returns ``fc(mutant - wild_type)``.
    """

    def __init__(self, esm_model, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc = nn.Linear(embedding_dim, 1)  # single functional score

    def forward(self, tokens_wt, tokens_mutant):
        embeddings_wt = self.esm(tokens_wt)["representations"][0][:, 0, :]
        embeddings_mutant = self.esm(tokens_mutant)["representations"][0][:, 0, :]
        delta_embedding = embeddings_mutant - embeddings_wt
        return self.fc(delta_embedding)

import torch
import torch.nn as nn
from esm import pretrained

# Load a pretrained backbone and its alphabet.
# NOTE(review): confirm that `esm3_t30_150M` exists in the installed `esm` package.
# fair-esm's 150M ESM-2 loader is named `esm2_t30_150M_UR50D`. Also check the
# backbone's embedding width against the 768 assumed by the classification head.

model, alphabet = pretrained.esm3_t30_150M()

# Batch converter; the training loop below unpacks its output as
# (labels, strings, tokens). Presumably it takes (label, sequence) pairs — confirm.
batch_converter = alphabet.get_batch_converter()

# Modify the model for classification

class ClassificationModel(nn.Module):
    """Sequence-level classifier: a linear layer on the first-token embedding of ``esm_model``.

    Args:
        esm_model: backbone called as ``esm_model(tokens)``. Its output must hold
            ``outputs["representations"][0]`` indexable as ``[:, 0, :]``.
            NOTE(review): fair-esm models only fill ``representations`` for layers
            requested via ``repr_layers=``; confirm your backbone returns key 0.
        num_classes: number of output logits.
        embedding_dim: width of that representation. Defaults to 768, the value
            previously hard-coded; set it to match the backbone.
    """

    def __init__(self, esm_model, num_classes, embedding_dim: int = 768):
        super().__init__()
        self.esm = esm_model
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, tokens):
        outputs = self.esm(tokens)
        cls_embedding = outputs["representations"][0][:, 0, :]  # first (CLS/BOS) position
        return self.fc(cls_embedding)

num_classes = 5  # example: 5 functional categories

classification_model = ClassificationModel(model, num_classes)

# Training loop

# The optimizer receives all parameters, including the backbone's (self.esm), so the
# whole model is fine-tuned, not only the linear head.
optimizer = torch.optim.Adam(classification_model.parameters(), lr=1e-5)

loss_function = nn.CrossEntropyLoss()

# NOTE(review): `epochs` and `dataloader` are not defined in this notebook; define them
# before running. `dataloader` must yield (labels, strings, tokens) triples.
# CrossEntropyLoss expects `batch_labels` as class indices (LongTensor) or class
# probabilities, so string labels would need to be mapped to indices first.
for epoch in range(epochs):
    for batch_labels, batch_strs, batch_tokens in dataloader:
        optimizer.zero_grad()  # clear gradients accumulated by the previous step
        predictions = classification_model(batch_tokens)
        loss = loss_function(predictions, batch_labels)
        loss.backward()
        optimizer.step()

# Novozymes enzyme stability prediction: fine-tuning ESM to predict Tm

Notebook: [EDA and fine-tune ESM](https://www.kaggle.com/code/jinyuansun/eda-and-finetune-esm)

# CAFA 5 protein function prediction competition

Public notebooks: [CAFA 5 Protein Function Prediction — Code](https://www.kaggle.com/competitions/cafa-5-protein-function-prediction/code)