# OUT OF DATE - NEEDS TO BE UPDATED

In [1]:
#@title **1. Setup**

#@markdown ### Identification
# Form fields: account name and credentials. The values are stored as plain
# notebook variables, so clear them before saving or sharing this notebook.
huggingface_username = "Synthyra"  #@param {type:"string"}
#@markdown ---
huggingface_token = ""            #@param {type:"string"}
#@markdown ---
wandb_api_key = ""                #@param {type:"string"}
#@markdown ---
synthyra_api_key = ""             #@param {type:"string"}
#@markdown ---
github_token = ""                 #@param {type:"string"}
#@markdown ---


# HTTPS clone URL with the GitHub token embedded as the credential part.
github_clone_path = f"https://{github_token}@github.com/Synthyra/ProbePackageHolder.git"
# The steps below are disabled (commented out). When enabled they clone the
# repository through the URL above, change into it and install its requirements.
# !git clone {github_clone_path}
# %cd ProbePackageHolder
# !pip install -r requirements.txt --quiet


In [None]:
#@title **2. Session/Directory Settings**

import torch
import argparse
from types import SimpleNamespace
from base_models.get_base_models import BaseModelArguments, standard_benchmark
from data.hf_data import HFDataArguments
from data.supported_datasets import supported_datasets
from embedder import EmbeddingArguments
from probes.get_probe import ProbeArguments
from probes.trainers import TrainerArguments
from main import MainProcess


# Build the pipeline driver from an empty argument namespace with GUI=True.
# The statements after this form, and the cells that follow, populate
# main.full_args attribute by attribute.
main = MainProcess(argparse.Namespace(), GUI=True)

#@markdown **Paths**

#@markdown These will be created automatically if they don't exist

#@markdown **Log Directory**
log_dir = "logs"                            #@param {type:"string"}
#@markdown ---

#@markdown **Results Directory**
results_dir = "results"                    #@param {type:"string"}
#@markdown ---

#@markdown **Model Save Directory**
model_save_dir = "weights"                 #@param {type:"string"}
#@markdown ---

#@markdown **Embedding Save Directory**
embedding_save_dir = "embeddings"          #@param {type:"string"}
#@markdown ---

#@markdown **Download Directory**
#@markdown - Where embeddings are downloaded from on Hugging Face
download_dir = "Synthyra/mean_pooled_embeddings"  #@param {type:"string"}
#@markdown ---


# Copy the credentials and directory choices onto the shared argument
# namespace. The order is fixed so the resulting attribute order is stable.
for _field, _value in (
    ("hf_token", huggingface_token),
    ("wandb_api_key", wandb_api_key),
    ("synthyra_api_key", synthyra_api_key),
    ("log_dir", log_dir),
    ("results_dir", results_dir),
    ("model_save_dir", model_save_dir),
    ("embedding_save_dir", embedding_save_dir),
    ("download_dir", download_dir),
    ("replay_path", None),
):
    setattr(main.full_args, _field, _value)

# Give the logger a copy of the arguments WITHOUT credential-like keys, using
# the same filter every later cell applies before touching the logger. The
# unfiltered copy used here before exposed hf_token / wandb_api_key /
# synthyra_api_key to the logging layer.
# NOTE(review): assumes start_log_gui() does not read the credentials from
# logger_args - confirm against MainProcess.
args_dict = {
    key: value
    for key, value in vars(main.full_args).items()
    if not (key == 'all_seqs' or 'token' in key.lower() or 'api' in key.lower())
}
main.logger_args = SimpleNamespace(**args_dict)
main.start_log_gui()

#@markdown Press play to setup the session:

In [None]:
#@title **2. Data Settings**

#@markdown **Max Sequence Length**
max_length = 2048          #@param {type:"integer"}
#@markdown ---

#@markdown **Trim Sequences**
#@markdown - If true, sequences are removed if they are longer than the maximum length
#@markdown - If false, sequences are truncated to the maximum length
trim = False               #@param {type:"boolean"}
#@markdown ---

#@markdown **Dataset Names**
#@markdown Valid options (comma-separated):

#@markdown *Multi-label classification:*

#@markdown - EC, GO-CC, GO-BP, GO-MF

#@markdown *Single-label classification:*

#@markdown - MB, DeepLoc-2, DeepLoc-10, solubility, localization, material-production, cloning-clf, number-of-folds

#@markdown *Regression:*

#@markdown - enzyme-kcat, temperature-stability, optimal-temperature, optimal-ph, fitness-prediction, stability-prediction, fluorescence-prediction

#@markdown *PPI:*

#@markdown - human-ppi, peptide-HLA-MHC-affinity

#@markdown *Tokenwise:*

#@markdown - SecondaryStructure-3, SecondaryStructure-8
# The value is split on commas and each name is looked up in
# `supported_datasets` by the code that follows this form.
dataset_names = "EC, DeepLoc-2, DeepLoc-10, enzyme-kcat"  #@param {type:"string"}
#@markdown ---

# Resolve the comma-separated dataset names to their data paths. The names are
# validated up front so a typo produces a message naming the offending entries
# (and the names that are accepted) instead of a bare KeyError from the lookup.
requested_names = [name.strip() for name in dataset_names.split(",") if name.strip()]
unknown_names = [name for name in requested_names if name not in supported_datasets]
if unknown_names:
    raise ValueError(
        f"Unknown dataset name(s): {', '.join(unknown_names)}. "
        f"Supported names: {', '.join(sorted(map(str, supported_datasets)))}"
    )
data_paths = [supported_datasets[name] for name in requested_names]

# Record the data options on the shared argument namespace and build the
# dataset arguments from everything collected so far.
main.full_args.data_paths = data_paths
main.full_args.max_length = max_length
main.full_args.trim = trim
main.data_args = HFDataArguments(**main.full_args.__dict__)

# The logger receives a copy without 'all_seqs' and without any key that
# contains 'token' or 'api' (case-insensitive), i.e. no credentials.
args_dict = {k: v for k, v in main.full_args.__dict__.items() if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()}
main.logger_args = SimpleNamespace(**args_dict)
main.get_datasets()

#@markdown Press play to load datasets:

In [5]:
#@title **3. Model Selection**

#@markdown Comma-separated model names.
#@markdown If empty, defaults to `standard_benchmark`.
#@markdown Valid options (comma-separated):
#@markdown - `ESM2-8, ESM2-35, ESM2-150, ESM2-650`
#@markdown - `ESMC-300, ESMC-600`
#@markdown - `Random, Random-Transformer`
model_names = "ESMC-300"  #@param {type:"string"}
#@markdown ---

# Split the comma-separated form value into clean, non-empty model names;
# an empty selection falls back to `standard_benchmark`.
selected_models = [
    candidate
    for candidate in (piece.strip() for piece in model_names.split(","))
    if candidate
] or standard_benchmark

main.full_args.model_names = selected_models
main.model_args = BaseModelArguments(**vars(main.full_args))

# Logger copy of the arguments: drops 'all_seqs' and every key containing
# 'token' or 'api' (case-insensitive).
args_dict = {
    key: value
    for key, value in vars(main.full_args).items()
    if not (key == 'all_seqs' or 'token' in key.lower() or 'api' in key.lower())
}
main.logger_args = SimpleNamespace(**args_dict)
main._write_args()

#@markdown *Press play to choose models:*


In [None]:
#@title **4. Embedding Settings**
#@markdown **Batch size**
batch_size = 4                #@param {type:"integer"}
#@markdown ---

#@markdown **Number of dataloader workers**
#@markdown - We recommend 0 for small sets of sequences, but 4-8 for larger sets
num_workers = 0               #@param {type:"integer"}
#@markdown ---

#@markdown **Download embeddings from Hugging Face**
#@markdown - If there is a precomputed embedding type that's useful to you, it is probably faster to download it
#@markdown - HIGHLY recommended for CPU users
download_embeddings = False   #@param {type:"boolean"}
#@markdown ---

#@markdown **Full residue embeddings**
#@markdown - If true, embeddings are saved as a matrix of shape `(L, d)`
#@markdown - If false, embeddings are pooled to `(d,)`
matrix_embed = False          #@param {type:"boolean"}
#@markdown ---

#@markdown **Embedding Pooling Types**
#@markdown - If more than one is passed, embeddings are concatenated
#@markdown Valid options (comma-separated):
#@markdown - `mean, max, norm, median, std, var, cls, parti`
#@markdown - `parti` (pool parti) must be used on its own
# Split on commas into a list of pooling names by the code after this form.
embedding_pooling_types = "mean, std"  #@param {type:"string"}
#@markdown ---

#@markdown **Embedding Data Type**
#@markdown - Embeddings are cast to this data type for storage
# The selected name is mapped to the torch dtype of the same name after this
# form; the float8 options are read from torch.float8_e4m3fn / torch.float8_e5m2.
embed_dtype = "float32"       #@param ["float32","float16","bfloat16","float8_e4m3fn","float8_e5m2"]
#@markdown ---

#@markdown **Save embeddings to SQLite**
#@markdown - If true, embeddings are saved to a SQLite database
#@markdown - They will be accessed on the fly by the trainer
#@markdown - This is HIGHLY recommended for matrix embeddings
#@markdown - If false, embeddings are saved to a .pth file but loaded all at once
sql = False                   #@param {type:"boolean"}
#@markdown ---

# Push the embedding form values onto the shared argument namespace.
main.full_args.all_seqs = main.all_seqs
main.full_args.batch_size = batch_size
main.full_args.num_workers = num_workers
main.full_args.download_embeddings = download_embeddings
main.full_args.matrix_embed = matrix_embed
main.full_args.embedding_pooling_types = [p.strip() for p in embedding_pooling_types.split(",") if p.strip()]

# Map the dtype name to the torch dtype of the same name. The attribute is
# looked up lazily so that a torch build without the float8 dtypes falls back
# to float32 with a message instead of failing with AttributeError.
supported_embed_dtypes = ("float32", "float16", "bfloat16", "float8_e4m3fn", "float8_e5m2")
resolved_embed_dtype = getattr(torch, embed_dtype, None) if embed_dtype in supported_embed_dtypes else None
if resolved_embed_dtype is None:
    print(f"Invalid or unsupported embedding dtype: {embed_dtype}. Using float32.")
    resolved_embed_dtype = torch.float32

# Store dtype and SQLite choice on full_args, like every other setting, so they
# are part of what EmbeddingArguments is built from below and of the logged
# arguments. They are also kept on `main` itself, where this cell stored them
# before, so nothing that reads main.embed_dtype / main.sql breaks.
# NOTE(review): confirm which of the two locations MainProcess and
# EmbeddingArguments actually consume.
main.full_args.embed_dtype = resolved_embed_dtype
main.full_args.sql = sql
main.embed_dtype = resolved_embed_dtype
main.sql = sql


main.embedding_args = EmbeddingArguments(**main.full_args.__dict__)

# Logger copy of the arguments: drops 'all_seqs' and every key containing
# 'token' or 'api' (case-insensitive).
args_dict = {k: v for k, v in main.full_args.__dict__.items() if k != 'all_seqs' and 'token' not in k.lower() and 'api' not in k.lower()}
main.logger_args = SimpleNamespace(**args_dict)
main.save_embeddings_to_disk()

#@markdown *Press play to embed sequences:*


In [7]:
#@title **5. Probe Settings**

#@markdown **Probe Type**
#@markdown - `linear`: a MLP for pooled embeddings
#@markdown - `transformer`: a transformer model for matrix embeddings
#@markdown - `retrievalnet`: custom combination of cross-attention and convolution for matrix embeddings
probe_type = "linear"     #@param ["linear", "transformer", "retrievalnet"]
#@markdown ---

#@markdown **Tokenwise**
#@markdown - If true, the objective is to predict a property of each token (matrix embeddings only)
#@markdown - If false, the objective is to predict a property of the entire sequence (pooled embeddings OR matrix embeddings)
tokenwise = False         #@param {type:"boolean"}
#@markdown ---

#@markdown **Pre-LayerNorm**
#@markdown - If true, a LayerNorm is applied as the first layer of the probe
#@markdown - Typically improves performance
pre_ln = True             #@param {type:"boolean"}
#@markdown ---

#@markdown **Number of layers**
#@markdown - Number of hidden layers in the probe
#@markdown - Linear probes have 1 input layer and 2 output layers, so 1 layer is a 4 layer MLP
#@markdown - This refers to how many transformer blocks are used in the transformer probe
#@markdown - Same for retrievalnet probes
n_layers = 1              #@param {type:"integer"}
#@markdown ---

#@markdown **Hidden dimension**
#@markdown - The hidden dimension of the model
#@markdown - 2048 - 8192 is recommended for linear probes, 384 - 1536 is recommended for transformer probes
hidden_dim = 8192         #@param {type:"integer"}
#@markdown ---

#@markdown **Dropout**
#@markdown - Dropout rate for the probe
#@markdown - 0.2 is recommended for linear, 0.1 otherwise
dropout = 0.2             #@param {type:"number"}
#@markdown ---

#@markdown **Classifier dimension**
#@markdown - The dimension of the classifier layer (transformer, retrievalnet probes only)
classifier_dim = 4096     #@param {type:"integer"}
#@markdown ---

#@markdown **Classifier Dropout**
#@markdown - Dropout rate for the classifier layer
classifier_dropout = 0.2  #@param {type:"number"}
#@markdown ---

#@markdown **Number of heads**
#@markdown - Number of attention heads in models with attention
#@markdown - between `hidden_dim // 128` and `hidden_dim // 32` is recommended
n_heads = 4               #@param {type:"integer"}
#@markdown ---

#@markdown **Rotary Embeddings**
#@markdown - If true, rotary embeddings are used with attention layers
rotary = True             #@param {type:"boolean"}
#@markdown ---

#@markdown **Probe Pooling Types**
#@markdown - If more than one is passed, embeddings are concatenated
#@markdown Valid options (comma-separated):
#@markdown - `mean, max, norm, median, std, var, cls`
#@markdown - Is how the transformer or retrievalnet embeddings are pooled for sequence-wise tasks
# Raw comma-separated string; it is split into a list right after this form.
probe_pooling_types_str = "mean, cls"  #@param {type:"string"}

# Turn the comma-separated pooling string into a list of non-empty names.
probe_pooling_types = [
    pooling
    for pooling in (chunk.strip() for chunk in probe_pooling_types_str.split(","))
    if pooling
]

# Transfer every probe option onto the shared argument namespace, in the same
# order as the form fields.
for _name, _setting in (
    ("probe_type", probe_type),
    ("tokenwise", tokenwise),
    ("pre_ln", pre_ln),
    ("n_layers", n_layers),
    ("hidden_dim", hidden_dim),
    ("dropout", dropout),
    ("classifier_dim", classifier_dim),
    ("classifier_dropout", classifier_dropout),
    ("n_heads", n_heads),
    ("rotary", rotary),
    ("probe_pooling_types", probe_pooling_types),
):
    setattr(main.full_args, _name, _setting)

main.probe_args = ProbeArguments(**vars(main.full_args))

# Logger copy of the arguments: drops 'all_seqs' and every key containing
# 'token' or 'api' (case-insensitive).
args_dict = {
    key: value
    for key, value in vars(main.full_args).items()
    if not (key == 'all_seqs' or 'token' in key.lower() or 'api' in key.lower())
}
main.logger_args = SimpleNamespace(**args_dict)
main._write_args()

#@markdown ---
#@markdown Press play to configure the probe:


In [8]:
#@title **6. Training Settings**

#@markdown **Use LoRA**
#@markdown - If true, LoRA is used on the base model
use_lora = False               #@param {type:"boolean"}
#@markdown ---

#@markdown **Hybrid Probe**
#@markdown - If true, the probe is trained on frozen embeddings
#@markdown - Then, the base model is finetuned alongside the probe
hybrid_probe = False           #@param {type:"boolean"}
#@markdown ---

#@markdown **Full Finetuning**
#@markdown - If true, the base model is finetuned for the task
full_finetuning = False        #@param {type:"boolean"}
#@markdown ---

#@markdown **Number of epochs**
num_epochs = 200               #@param {type:"integer"}
#@markdown ---

#@markdown **Trainer Batch Size**
#@markdown - The batch size for probe training
#@markdown - We recommend between 32 and 256 with some combination of this and gradient accumulation steps
trainer_batch_size = 64        #@param {type:"integer"}
#@markdown ---

#@markdown **Gradient Accumulation Steps**
gradient_accumulation_steps = 1  #@param {type:"integer"}
#@markdown ---

#@markdown **Learning Rate**
lr = 0.0001                    #@param {type:"number"}
#@markdown ---

#@markdown **Weight Decay**
#@markdown - If you are having issues with overfitting, try increasing this
weight_decay = 0.0             #@param {type:"number"}
#@markdown ---

#@markdown **Early Stopping Patience**
#@markdown - We recommend keeping the epochs high and using this to gauge convergence
patience = 10                   #@param {type:"integer"}
#@markdown ---

# Transfer the trainer form values onto the shared argument namespace, in the
# same order as the form fields.
for _option, _choice in (
    ("use_lora", use_lora),
    ("hybrid_probe", hybrid_probe),
    ("full_finetuning", full_finetuning),
    ("num_epochs", num_epochs),
    ("trainer_batch_size", trainer_batch_size),
    ("gradient_accumulation_steps", gradient_accumulation_steps),
    ("lr", lr),
    ("weight_decay", weight_decay),
    ("patience", patience),
):
    setattr(main.full_args, _option, _choice)

main.trainer_args = TrainerArguments(**vars(main.full_args))

# Logger copy of the arguments: drops 'all_seqs' and every key containing
# 'token' or 'api' (case-insensitive).
args_dict = {
    key: value
    for key, value in vars(main.full_args).items()
    if not (key == 'all_seqs' or 'token' in key.lower() or 'api' in key.lower())
}
main.logger_args = SimpleNamespace(**args_dict)
main._write_args()

#@markdown ---
#@markdown Press play to run the trainer:
main.run_nn_probe()

In [None]:
#@title **7. Log Replay**

#@markdown **Replay Path**
#@markdown - Replay everything from a log by passing the path to the log file
replay_path = ""  #@param {type:"string"}
#@markdown ---

from logger import LogReplayer

if not replay_path.strip():
    # The form defaults to an empty path; treat that as "nothing to replay" so
    # running every cell in order does not end in a file error here.
    print("No replay path provided; skipping log replay.")
else:
    replayer = LogReplayer(replay_path)
    replay_args = replayer.parse_log()
    replay_args.replay_path = replay_path

    # Overwrite only the settings that already exist on full_args. full_args is
    # an attribute namespace (it is read through __dict__ and written with dot
    # access everywhere else), which does not support item assignment, so the
    # values are written with setattr rather than full_args[key] = value.
    for key, value in replay_args.__dict__.items():
        if key in main.full_args.__dict__:
            setattr(main.full_args, key, value)

    replayer.run_replay(main)

#@markdown ---
#@markdown Press to replay logs:
