In [1]:
import os
import random
import sys

import numpy as np
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPStrategy

sys.path.append('..')
sys.path.append('FrustraSeq')
from FrustraSeq.models.FrustraSeq import FrustraSeq
from FrustraSeq.dataloader import FrustrationDataModule
from FrustraSeq.utils import run_eval_metrics

# Experiment configuration. `config` is passed to FrustraSeq (training and checkpoint
# loading), and individual keys are read by the data module and Trainer cells below.
config = {
    "experiment_name": "it5_DEB_FOCAL",
    "parquet_path": "../data/frustration/v8_frustration_v2.parquet.gzip",
    "set_key": "split_test",  # split_test (Gonzalo's proteins in test), set_old (split of the previous dataset) or split0-3
    "cath_sampling_n": 10,  # previously 100; None for no sampling
    "batch_size": 10,  # 32 for fine-tuning; 512 otherwise
    "num_workers": 10,
    "max_seq_length": 100,
    "precision": "full",
    "pLM_model": "../data/protT5",
    "prefix_prostT5": "<AA2fold>",
    "pLM_dim": 1024,  # 1280 for ESM
    "no_label_token": -100,
    "finetune": False,
    "lora_r": 4,
    "lora_alpha": 1,
    # T5 attention projections. Further T5 options: "wi", "wo", "w1", "w2", "w3", "fc1", "fc2", "fc3".
    # For ESM use ["query", "key", "value", "fc1", "fc2"].
    "lora_modules": ["q", "k", "v", "o"],
    # Class weights for CE, e.g. [2.65750085, 0.68876299, 0.8533673] or [10.0, 2.0, 2.5]
    # (inverse class frequencies normalised to the first class: 1/0.13, 1/0.48, 1/0.39).
    "ce_weighting": None,
    "use_focal_loss_instead_of_ce": True,
    "notes": "",
}

# Head hyper-parameters, nested under config["architecture"].
architecture = {}
architecture["lr"] = 1e-4
architecture["dropout"] = 0.1
architecture["kernel_1"] = 7
architecture["padding_1"] = architecture["kernel_1"] // 2  # "same" length for odd kernel sizes
architecture["kernel_2"] = 7
architecture["padding_2"] = architecture["kernel_2"] // 2  # "same" length for odd kernel sizes
architecture["hidden_dim_0"] = 64
architecture["hidden_dim_1"] = 10
config["architecture"] = architecture

# Seed every RNG the run may touch (sampling, shuffling, dropout, weight init) so that a
# Restart & Run All reproduces the same split sampling and initialisation.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the torch RNG on all devices

# Optional speed-up on tensor cores: torch.set_float32_matmul_precision("high")
trainer_precision = "bf16-mixed"  # alternative: "32"

# Only consumed by the (commented-out) DDPStrategy in the training cell. It is enabled when the
# encoder is not fine-tuned — presumably because some parameters then receive no gradients.
find_unused = not config["finetune"]



  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Data, callbacks and logger for the training run.
# Worker-dependent loader options are derived from `num_workers`: torch's DataLoader rejects
# persistent_workers=True or a non-None prefetch_factor when num_workers == 0, so setting
# num_workers to 0 (e.g. for debugging) no longer requires editing these by hand.
# NOTE(review): assumes FrustrationDataModule forwards these kwargs to torch DataLoader.
use_workers = config["num_workers"] > 0
data_module = FrustrationDataModule(df=None,
                                    parquet_path=config["parquet_path"],
                                    batch_size=config["batch_size"],
                                    set_key=config["set_key"],
                                    max_seq_length=config["max_seq_length"],
                                    num_workers=config["num_workers"],
                                    persistent_workers=use_workers,
                                    # pinned host memory only speeds up CUDA copies; without CUDA it is a no-op
                                    pin_memory=torch.cuda.is_available(),
                                    prefetch_factor=1 if use_workers else None,
                                    sample_size=None,
                                    cath_sampling_n=config["cath_sampling_n"])

# `patience` counts validation checks, not epochs.
early_stop = EarlyStopping(monitor="val_loss",
                           patience=5,
                           min_delta=0.0001,
                           mode='min',
                           verbose=True)
# Keeps only the single best model (by val_loss) at ./<experiment_name>/best_val_model.ckpt.
checkpoint = ModelCheckpoint(monitor="val_loss",
                             dirpath=f"./{config['experiment_name']}",
                             filename="best_val_model",
                             save_top_k=1,
                             mode='min',
                             save_weights_only=False)
logger = WandbLogger(project="FrustraSeq",
                     name=config["experiment_name"],
                     save_dir=f"./{config['experiment_name']}",
                     log_model=False,
                     offline=True,
                     )
lr_logger = LearningRateMonitor(logging_interval='step')

In [None]:
# Training. Single-device run; the accelerator is auto-detected (MPS on Apple Silicon,
# otherwise CUDA or CPU) instead of being hard-coded to "mps", so the cell also runs elsewhere.
trainer = Trainer(default_root_dir=f"./{config['experiment_name']}",
                  accelerator="auto",
                  devices=1,
                  # multi-GPU: strategy=DDPStrategy(find_unused_parameters=find_unused),
                  max_epochs=1,
                  logger=logger,
                  log_every_n_steps=10,
                  val_check_interval=0.2,  # validate every 20% of a training epoch
                  callbacks=[early_stop, checkpoint, lr_logger],
                  precision=trainer_precision,
                  gradient_clip_val=1,
                  enable_progress_bar=False,
                  deterministic=False,
                  accumulate_grad_batches=1,  # 8 was used for fine-tuning; possibly less for future FT runs
                  )

# Resume if a checkpoint from an earlier (e.g. interrupted) run exists.
# NOTE: this is the best-validation checkpoint, not necessarily the most recent training step.
rank = os.environ.get('RANK', -1)
ckpt_file = os.path.join(config['experiment_name'], "best_val_model.ckpt")
ckpt_path = ckpt_file if os.path.exists(ckpt_file) else None
if ckpt_path is not None:
    print(f"RANK {rank}: Resuming training from checkpoint: {ckpt_file}")
else:
    print(f"RANK {rank}: Starting new training run")

model = FrustraSeq(config=config)

trainer.fit(model, datamodule=data_module, ckpt_path=ckpt_path)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


RANK -1: Starting new training run
Using focal loss instead of cross-entropy loss for classification head. Overrides ce_weighting if set.
RANK -1: Model initialized.




Loaded 5620 samples from ../data/frustration/v8_frustration_v2.parquet.gzip
Created train/val/test masks
Initialized res_idx_mask and frst_vals tensors
Populated res_idx_mask and frst_vals tensors
Created train dataset
Created val dataset
Created test dataset
Train/Val/Test split: 4990/300/330 samples


/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /Users/janleusch/Documents/phd/pLMtrainer/FrustraSeq/notebooks/it5_DEB_FOCAL exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.


RANK -1: lora params: 0, head params: 8


/Users/janleusch/anaconda3/envs/biotrainer/lib/python3.12/site-packages/pytorch_lightning/utilities/model_summary/model_summary.py:231: Precision bf16-mixed is not supported by the model summary.  Estimated model size in MB will not be accurate. Using 32 bits instead.

  | Name        | Type           | Params | Mode 
-------------------------------------------------------
0 | encoder     | T5EncoderModel | 1.2 B  | eval 
1 | CNN         | Sequential     | 3.2 M  | train
2 | reg_head    | Sequential     | 11     | train
3 | cls_head    | Sequential     | 33     | train
4 | mse_loss_fn | MSELoss        | 0      | train
5 | ce_loss_fn  | FocalLoss      | 0      | train
-------------------------------------------------------
1.2 B     Trainable params
0         Non-trainable params
1.2 B     Total params
4,845.538 Total estimated model params size (MB)
16        Modules in train mode
439       Modules in eval mode


Starting validation 2025-12-22 17:57:52




Val batch with no valid residues - skipping
Ending validation 2025-12-22 17:57:59
Starting training epoch 0 at 2025-12-22 17:57:59




Starting validation 2025-12-22 18:00:20
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping


Metric val_loss improved. New best score: 0.890


Ending validation 2025-12-22 18:01:09
Starting validation 2025-12-22 18:04:03
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping


Metric val_loss improved by 0.320 >= min_delta = 0.0001. New best score: 0.570


Ending validation 2025-12-22 18:04:49
Starting validation 2025-12-22 18:07:25
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping


Metric val_loss improved by 0.060 >= min_delta = 0.0001. New best score: 0.510


Ending validation 2025-12-22 18:08:10
Starting validation 2025-12-22 18:10:44
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping
Val batch with no valid residues - skipping


Metric val_loss improved by 0.018 >= min_delta = 0.0001. New best score: 0.492


Ending validation 2025-12-22 18:11:28



Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
socket.send() raised exception.
socket.send() raised exception.


Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x320ca9550>> (for post_run_cell), with arguments args (<ExecutionResult object at 311f1e120, execution_count=3 error_before_exec=None error_in_exec=1 info=<ExecutionInfo object at 311f1efc0, raw_cell="trainer = Trainer(default_root_dir=f"./{config['ex.." transformed_cell="trainer = Trainer(default_root_dir=f"./{config['ex.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/janleusch/Documents/phd/pLMtrainer/FrustraSeq/notebooks/train_notebook.ipynb#W4sZmlsZQ%3D%3D> result=None>,),kwargs {}:

In [None]:
# Evaluate the best checkpoint on the validation split ("tune on val set"). The test loop is
# run over the val dataloader so that model.save_preds_dict(set="val") has val predictions.
test_trainer = Trainer(accelerator="auto",
                       devices=1,  # only use one device for inference
                       logger=logger,
                       log_every_n_steps=10,
                       precision=trainer_precision,
                       enable_progress_bar=False,
                       )
best_ckpt_file = os.path.join(config["experiment_name"], "best_val_model.ckpt")
model = FrustraSeq.load_from_checkpoint(checkpoint_path=best_ckpt_file, config=config)

# Route test_dataloader to the val split only for this call. Deleting the instance attribute
# restores the class's own test_dataloader, so later test() calls really use the test split.
data_module.test_dataloader = data_module.val_dataloader
try:
    test_trainer.test(model, datamodule=data_module)
finally:
    del data_module.test_dataloader
model.save_preds_dict(set="val")

# Close the .npz archive once the metrics have been computed.
with np.load(f"./{config['experiment_name']}/val_preds.npz") as val_preds:
    metrics = run_eval_metrics(val_preds, return_cls_report_dict=False)
print(metrics["cls_report"])
print(metrics["pearson_r"])
print(metrics["mean_absolute_error"])